Coverage for hyper_parallel / core / fully_shard / hsdp_utils.py: 83%

71 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""HSDP optimizer shared level""" 

16from dataclasses import dataclass, field 

17from typing import Any, List, Optional, Sequence, Tuple 

18from enum import auto, Enum 

19from torch import nn 

20 

class OptimizerLevel(Enum):
    """
    Degree of sharding applied by the HSDP optimizer.

    SHARD_OPT:
        only the optimizer state is split across ranks.
    SHARD_OPT_GRAD:
        optimizer state and gradients are split.
    SHARD_OPT_GRAD_PARAM:
        optimizer state, gradients and weights are all split.
    """
    # Explicit values match what ``auto()`` assigns (1-based, declaration order).
    SHARD_OPT = 1
    SHARD_OPT_GRAD = 2
    SHARD_OPT_GRAD_PARAM = 3

34 

class GroupInfo:
    """
    Lightweight record describing a communication group.

    Attributes:
        group_name: Human-readable name of the communication group.
        group: The process-group handle (opaque; created by the caller).
        rank_size: Number of ranks participating in the group.
    """
    def __init__(self, group_name, group, rank_size):
        self.group_name = group_name
        self.group = group
        self.rank_size = rank_size

    def __repr__(self):
        # Aid debugging: show the identifying fields, not the raw group handle.
        return (f"{type(self).__name__}(group_name={self.group_name!r}, "
                f"rank_size={self.rank_size})")

43 

44 

class HSDPConfigV2:
    """HSDPConfigV2 inspired by torch fully_shard"""

    def __init__(self, mesh, reshard_after_forward, shard_placement_fn, mp_policy, offload_policy, ignored_param,
                 reduce_dtype=None, comm_async=False, comm_fusion=False, bucket_size=-1):
        """
        HSDP config init method.

        Args:
            mesh: device mesh describing how parameters are distributed.
            reshard_after_forward: whether parameters are re-sharded after forward.
            shard_placement_fn: optional callable choosing the shard placement per parameter.
            mp_policy: mixed-precision policy; when truthy, its ``reduce_dtype``
                attribute becomes the gradient reduction dtype.
            offload_policy: policy controlling parameter/gradient offload.
            ignored_param: parameters excluded from sharding (currently not stored
                on the config -- NOTE(review): confirm this is intended).
            reduce_dtype: unused; the effective reduce dtype comes from ``mp_policy``.
            comm_async: accepted for compatibility but currently forced to False.
            comm_fusion: accepted for compatibility but currently forced to False.
            bucket_size: accepted for compatibility but currently forced to 9999.
        """
        self.mesh = mesh
        self.reshard_after_forward = reshard_after_forward
        self.shard_placement_fn = shard_placement_fn
        self.mp_policy = mp_policy
        self.offload_policy = offload_policy
        # Derive the gradient-reduction dtype from the mixed-precision policy when one is given.
        self.reduce_dtype = self.mp_policy.reduce_dtype if self.mp_policy else None
        # TODO: the attributes below are scheduled for removal; until then they are
        # pinned to fixed values and the matching constructor arguments are ignored.
        self.comm_async = False
        self.comm_fusion = False
        self.bucket_size = 9999
        self.grad_fusion = False

75 

class ShardedState(Enum):
    """
    Sharding state of a parameter managed by HSDP.
    """
    # Explicit values match what ``auto()`` assigns (1-based, declaration order).
    SHARDED = 1
    SHARDED_POST_FORWARD = 2
    UNSHARDED = 3

83 

class FSDPSchedulerState(Enum):
    """
    Scheduler state:
    - PRE_FORWARD:
        already run hook before forward.
    - FORWARD:
        already run hook after forward.
    - PRE_BACKWARD:
        already run hook before backward.
    - BACKWARD:
        already run hook after backward.
    """
    PRE_FORWARD = auto()
    FORWARD = auto()
    PRE_BACKWARD = auto()
    BACKWARD = auto()

100 

101 

@dataclass
class ParamModuleInfo:
    """
    Tracks parameter ownership and supports shared weights in HSDP.

    This dataclass maintains the mapping between a parameter and its module(s),
    enabling parameter swapping during sharding/unsharding transitions. Shared
    weights are parameters referenced by multiple modules (e.g., tied embeddings).

    This class tracks all references to ensure proper parameter replacement during
    sharding/unsharding operations.

    Attributes:
        module: The module that owns this parameter.
        param_name: Attribute name of the parameter in the module (e.g., "weight").
        shared_modules: List of other modules sharing this same parameter object.
        shared_param_names: Corresponding parameter names in shared_modules (aligned by index).
    """
    # Owning module and the parameter's attribute name inside it.
    module: nn.Module
    param_name: str
    # Extra (module, name) references to the same parameter object; the two
    # lists are index-aligned and populated by _get_param_module_infos.
    shared_modules: List[nn.Module] = field(default_factory=list)
    shared_param_names: List[str] = field(default_factory=list)

124 

125 

@dataclass
class ExtensionsData:
    """
    Carries state between the pre- and post-all-gather extension hooks.

    Custom all-gather extensions can stash arbitrary metadata during the
    pre-all-gather phase and read it back in the post phase; the original
    input sizes are kept so gathered outputs can be reshaped back to their
    pre-flattening dimensions.

    Attributes:
        all_gather_metadata: Custom metadata handed from pre to post all-gather.
        all_gather_input_sizes: Tensor shapes recorded before flattening.
    """
    all_gather_metadata: Optional[Any] = None
    all_gather_input_sizes: Sequence[Tuple[int, ...]] = ()

    def clear(self):
        """Restore both fields to their default (empty) values."""
        self.all_gather_metadata, self.all_gather_input_sizes = None, ()

146 

147 

148def _named_parameters_with_duplicates( 

149 module: nn.Module, **kwargs: Any 

150) -> list[tuple[str, nn.Parameter]]: 

151 """ 

152 This API is required as some modules overwrite `named_parameters()` but do not support 

153 `remove_duplicate`. 

154 """ 

155 if "remove_duplicate" in kwargs: 

156 raise AssertionError( 

157 "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument." 

158 ) 

159 kwargs["remove_duplicate"] = False 

160 try: 

161 ret = list(module.named_parameters(**kwargs)) 

162 except AssertionError: 

163 kwargs.pop("remove_duplicate") 

164 ret = list(module.named_parameters(**kwargs)) 

165 return ret 

166 

def _get_param_module_infos(
        params: list[nn.Parameter], modules: tuple[nn.Module, ...]
) -> list['ParamModuleInfo']:
    """
    Build one ParamModuleInfo per entry of ``params``, in the same order.

    Shared parameter: lin1.weight = lin2.weight
    Shared module: mlp.lin1 = mlp.lin2
    Duplicates are deliberately kept while walking both modules and parameters
    so that every (module, name) reference to a shared parameter is recorded.
    """
    wanted = set(params)
    info_by_param: dict[nn.Parameter, ParamModuleInfo] = {}
    for root in modules:
        for _, submodule in root.named_modules(remove_duplicate=False):
            for name, param in _named_parameters_with_duplicates(
                    submodule, recurse=False
            ):
                if param not in wanted:
                    continue
                info = info_by_param.get(param)
                if info is None:
                    # First sighting: this (module, name) pair becomes the owner.
                    info_by_param[param] = ParamModuleInfo(submodule, name)
                else:
                    # Later sightings are shared references (index-aligned lists).
                    info.shared_modules.append(submodule)
                    info.shared_param_names.append(name)
    if len(info_by_param) != len(params):
        raise AssertionError(f"Some parameters are not in the module tree of {modules}")
    return [info_by_param[param] for param in params]