# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""HSDP optimizer shared level""" 

from dataclasses import dataclass, field
from enum import auto, Enum
from typing import Any, List, Optional, Sequence

import numpy as np

from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.platform import get_platform
from hyper_parallel.platform.platform import PlatformType

platform = get_platform()


class HSDPConfigV2:
    """HSDPConfigV2, inspired by torch fully_shard."""

    def __init__(self,
                 mesh,
                 reshard_after_forward,
                 shard_placement_fn,
                 mp_policy,
                 offload_policy,
                 ignored_params=None,
                 replicate_params=None,
                 comm_fusion=False,
                 comm_fusion_zero_copy=False,
                 ):
        self.mesh = mesh
        self.reshard_after_forward = reshard_after_forward
        self.shard_placement_fn = shard_placement_fn
        self.mp_policy = mp_policy
        self.offload_policy = offload_policy
        self.ignored_params = ignored_params
        self.replicate_params = replicate_params
        self.reduce_dtype = self.mp_policy.reduce_dtype if self.mp_policy else None
        self.comm_fusion = comm_fusion
        self.comm_fusion_zero_copy = comm_fusion_zero_copy
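# A minimal construction sketch; the argument values below are hypothetical and
# only illustrate the fields the config records (`dp_mesh` is an assumed
# data-parallel DeviceMesh):
#
#     config = HSDPConfigV2(
#         mesh=dp_mesh,
#         reshard_after_forward=True,
#         shard_placement_fn=None,
#         mp_policy=None,          # reduce_dtype then resolves to None
#         offload_policy=None,
#         comm_fusion=True,
#     )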


class ShardedState(Enum):
    """
    Parameter shard state
    """
    SHARDED = auto()
    UNSHARDED = auto()


class FullyShardParamMode(Enum):
    """
    Internal fully_shard execution modes derived from the original parameter layout.

    LOCAL_PARAM:
        The parameter is a regular local tensor parameter and fully_shard owns the
        full data-parallel sharding behaviour.
    DTENSOR_COMPAT:
        The parameter already carries a DTensor layout and fully_shard is only used
        as the compatibility wrapper without adding an extra FSDP shard dimension.
    DTENSOR_UNIFIED:
        The parameter already carries a DTensor layout and fully_shard additionally
        contributes a data-parallel/FSDP mesh that must be unified with the
        existing distributed layout.
    """

    LOCAL_PARAM = auto()
    DTENSOR_COMPAT = auto()
    DTENSOR_UNIFIED = auto()


@dataclass
class GroupInfo:
    """Communication group metadata used by fully_shard."""

    group_name: str
    group: Any
    rank_size: int
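# Illustrative only; the group handle and rank size are hypothetical values:
#
#     info = GroupInfo(group_name="fsdp_shard_group", group=comm_group, rank_size=8)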


class FSDPSchedulerState(Enum):
    """
    Scheduler state:
    - PRE_FORWARD:
        the pre-forward hook has already run.
    - FORWARD:
        the post-forward hook has already run.
    - PRE_BACKWARD:
        the pre-backward hook has already run.
    - BACKWARD:
        the post-backward hook has already run.
    """
    PRE_FORWARD = auto()
    FORWARD = auto()
    PRE_BACKWARD = auto()
    BACKWARD = auto()
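# Reading the docstring above as a lifecycle, the states advance per training
# step in the order PRE_FORWARD -> FORWARD -> PRE_BACKWARD -> BACKWARD; this is
# a sketch of the intended hook ordering, while the scheduler that drives the
# transitions lives outside this module.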


@dataclass
class ParamModuleInfo:
    """
    Tracks parameter ownership and supports shared weights in HSDP.

    This dataclass maintains the mapping between a parameter and its module(s),
    enabling parameter swapping during sharding/unsharding transitions. Shared
    weights are parameters referenced by multiple modules (e.g., tied embeddings),
    and every reference is tracked so the parameter can be replaced consistently
    during sharding/unsharding operations.

    Attributes:
        module: The module that owns this parameter.
        param_name: Attribute name of the parameter in the module (e.g., "weight").
        shared_modules: List of other modules sharing this same parameter object.
        shared_param_names: Corresponding parameter names in shared_modules (aligned by index).
    """
    module: platform.Module
    param_name: str
    shared_modules: List[platform.Module] = field(default_factory=list)
    shared_param_names: List[str] = field(default_factory=list)
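# A sketch of the tied-embedding case described above (module names are
# hypothetical): if `model.embedding.weight is model.lm_head.weight`, that
# shared parameter maps to
#
#     ParamModuleInfo(
#         module=model.embedding,
#         param_name="weight",
#         shared_modules=[model.lm_head],
#         shared_param_names=["weight"],
#     )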


def _named_parameters_with_duplicates(
    module: platform.Module, **kwargs: Any
) -> list[tuple[str, platform.Parameter]]:
    """
    This API is required because some modules override `named_parameters()` but do not
    support `remove_duplicate`.
    """
    if "remove_duplicate" in kwargs:
        raise AssertionError(
            "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
        )

    def get_named_parameters(module, **kwargs):
        if platform.platform_type == PlatformType.PYTORCH:
            return module.named_parameters(**kwargs)
        return module.parameters_and_names(expand=False)

    kwargs["remove_duplicate"] = False
    try:
        ret = list(get_named_parameters(module, **kwargs))
    except AssertionError:
        kwargs.pop("remove_duplicate")
        ret = list(get_named_parameters(module, **kwargs))
    return ret


def _get_param_module_infos(
    params: list[platform.Parameter], modules: tuple[platform.Module, ...]
) -> list['ParamModuleInfo']:
    """
    Shared parameter: lin1.weight = lin2.weight
    Shared module: mlp.lin1 = mlp.lin2
    We do not remove duplicates when traversing both modules and parameters to
    find shared modules' parameters and shared parameters within a module.
    """
    params_set = set(params)
    param_to_module_info: dict[platform.Parameter, ParamModuleInfo] = {}

    def get_named_modules(module):
        if platform.platform_type == PlatformType.PYTORCH:
            return module.named_modules(remove_duplicate=False)
        return module.cells_and_names()

    for module in modules:
        for _, submodule in get_named_modules(module):
            for param_name, param in _named_parameters_with_duplicates(
                    submodule, recurse=False
            ):
                if param in params_set:
                    if param not in param_to_module_info:
                        param_to_module_info[param] = ParamModuleInfo(
                            submodule, param_name
                        )
                    else:
                        param_to_module_info[param].shared_modules.append(submodule)
                        param_to_module_info[param].shared_param_names.append(
                            param_name
                        )
    if len(param_to_module_info) != len(params):
        raise AssertionError(f"Some parameters are not in the module tree of {modules}")
    return [param_to_module_info[param] for param in params]
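# Worked example of the docstring's shared-parameter case (module names are
# hypothetical): if `mlp.lin1.weight is mlp.lin2.weight`, then
# `_get_param_module_infos([mlp.lin1.weight], (mlp,))` returns a single
# ParamModuleInfo whose `module`/`param_name` point at the first owner found
# during traversal (e.g. `mlp.lin1`, "weight") and whose
# `shared_modules`/`shared_param_names` record the duplicate reference
# (`mlp.lin2`, "weight").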


def get_managed_modules_parameters(
    modules: Sequence[platform.Module],
    ignored_params: Optional[Sequence[platform.Parameter]] = None,
) -> list[platform.Parameter]:
    """Collect deduplicated parameters from ``modules`` while skipping ignored params.

    Parameters that were already initialized by an inner ``fully_shard`` instance
    are intentionally excluded so nested ``fully_shard(mesh=None)`` resolves mesh
    mode from the parameters that the current wrapper will actually manage.
    """
    params: list[platform.Parameter] = []
    ignored_params_set = set(ignored_params or ())
    visited_params: set[platform.Parameter] = set()
    for mod in modules:
        for _, param in platform.parameters_dict(mod):
            if param in ignored_params_set or param in visited_params:
                continue
            if getattr(param, "_hsdp_param_initialized", False):
                continue
            visited_params.add(param)
            params.append(param)
    return params
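# Sketch of the nested-wrapping rule above, based only on this function's logic:
# a parameter whose `_hsdp_param_initialized` attribute is True was already
# claimed by an inner fully_shard call, so an outer
# `get_managed_modules_parameters([outer_module])` silently skips it and the
# outer wrapper ends up managing only the still-unclaimed parameters.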


def infer_fully_shard_param_mode(
    mesh: Optional[DeviceMesh],
    params: Optional[Sequence[Any]] = None,
) -> FullyShardParamMode:
    """
    Infer the internal fully_shard execution mode from parameter layout and mesh.

    The mode is intentionally phrased around whether parameters already carry a
    distributed layout instead of assuming the layout came from TP only. DTensor
    parameters may originate from TP, EP, or other distributed sharding paths.
    """
    has_dtensor_param = any(is_dtensor_managed_param(param) for param in params or ())
    if not has_dtensor_param:
        return FullyShardParamMode.LOCAL_PARAM
    if mesh is None:
        return FullyShardParamMode.DTENSOR_COMPAT
    return FullyShardParamMode.DTENSOR_UNIFIED
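# The three cases reduce to a small truth table; `local_param`, `dtensor_param`,
# and `dp_mesh` below are placeholder names for a plain local tensor parameter,
# a DTensor-backed parameter, and a fully_shard mesh:
#
#     infer_fully_shard_param_mode(dp_mesh, [local_param])    -> LOCAL_PARAM
#     infer_fully_shard_param_mode(None,    [dtensor_param])  -> DTENSOR_COMPAT
#     infer_fully_shard_param_mode(dp_mesh, [dtensor_param])  -> DTENSOR_UNIFIED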


def unwrap_dtensor_param(param: Any) -> Optional[DTensor]:
    """Return the DTensor payload carried by ``param`` if one exists."""
    if isinstance(param, DTensor):
        return param
    param_data = getattr(param, "data", None)
    if isinstance(param_data, DTensor):
        return param_data
    if all(hasattr(param, attr) for attr in ("_device_mesh", "_placements", "_local_tensor")):
        return param
    return None


def is_dtensor_managed_param(param: Any) -> bool:
    """Return whether a parameter already carries DTensor layout metadata."""
    return unwrap_dtensor_param(param) is not None


def get_dtensor_managed_mesh(param: Any) -> Optional[DeviceMesh]:
    """Return the DTensor mesh carried by ``param`` if one exists."""
    payload = unwrap_dtensor_param(param)
    if payload is None:
        return None
    return getattr(payload, "device_mesh", getattr(payload, "_device_mesh", None))


def get_rank_list_for_axes(
    mesh: DeviceMesh,
    axes: Sequence[int],
    rank: Optional[int] = None,
) -> list[int]:
    """Return ranks that vary along ``axes`` and keep all other coordinates fixed."""
    if rank is None:
        rank = mesh.rank
    if rank not in mesh.rank_list:
        raise ValueError(f"Rank {rank} not found in mesh rank list {mesh.rank_list}.")

    normalized_axes = tuple(sorted(set(axes)))
    if len(normalized_axes) == 0:
        return [rank]

    mesh_tensor = np.array(mesh.rank_list).reshape(mesh.mesh_shape)
    rank_index = mesh.rank_list.index(rank)
    # Decode the flat rank index into per-axis mesh coordinates (row-major order,
    # matching the reshape above).
    coord = [0] * len(mesh.mesh_shape)
    temp = rank_index
    for i in range(len(mesh.mesh_shape) - 1, -1, -1):
        coord[i] = temp % mesh.mesh_shape[i]
        temp //= mesh.mesh_shape[i]
    # Keep the rank's own coordinate on every fixed axis and take a full slice
    # along each requested axis.
    mesh_slice = []
    for axis, axis_coord in enumerate(coord):
        mesh_slice.append(slice(None) if axis in normalized_axes else axis_coord)
    selected = mesh_tensor[tuple(mesh_slice)]
    return [int(item) for item in np.array(selected).reshape(-1).tolist()]
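# Worked example with a hypothetical 2 x 4 mesh: for
# rank_list == [0, 1, 2, 3, 4, 5, 6, 7] and mesh_shape == (2, 4), rank 5 sits at
# coordinate (1, 1), so
#   - axes=[1] varies the size-4 axis and keeps axis 0 fixed -> [4, 5, 6, 7]
#   - axes=[0] varies the size-2 axis and keeps axis 1 fixed -> [1, 5]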


def get_split_rank_lists_for_axes(
    mesh: DeviceMesh,
    axes: Sequence[int],
) -> list[list[int]]:
    """Return all rank lists induced by varying ``axes`` and fixing the complementary axes."""
    normalized_axes = tuple(sorted(set(axes)))
    if len(normalized_axes) == 0:
        return [[int(rank) for rank in mesh.rank_list]]

    mesh_tensor = np.array(mesh.rank_list).reshape(mesh.mesh_shape)
    complementary_axes = tuple(
        axis for axis in range(len(mesh.mesh_shape)) if axis not in normalized_axes
    )
    if len(complementary_axes) == 0:
        return [[int(item) for item in np.array(mesh_tensor).reshape(-1).tolist()]]

    complementary_shape = tuple(mesh.mesh_shape[axis] for axis in complementary_axes)
    split_rank_lists: list[list[int]] = []
    # For each fixed assignment of the complementary coordinates, slice out the
    # ranks obtained by varying the requested axes.
    for complementary_coord in np.ndindex(*complementary_shape):
        mesh_slice = []
        coord_idx = 0
        for axis in range(len(mesh.mesh_shape)):
            if axis in normalized_axes:
                mesh_slice.append(slice(None))
            else:
                mesh_slice.append(complementary_coord[coord_idx])
                coord_idx += 1
        selected = mesh_tensor[tuple(mesh_slice)]
        split_rank_lists.append([int(item) for item in np.array(selected).reshape(-1).tolist()])
    return split_rank_lists
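# Worked example with the same hypothetical 2 x 4 mesh as above: for axes=[0]
# the complementary axis has size 4, so the mesh splits into the four column
# groups [[0, 4], [1, 5], [2, 6], [3, 7]]; for axes=[1] it splits into the two
# row groups [[0, 1, 2, 3], [4, 5, 6, 7]].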


def get_hsdp_state(module):
    """Return the HSDPState for a fully_shard-managed module, or None."""
    from hyper_parallel.core.fully_shard.api import HSDPModule  # pylint: disable=C0415
    if isinstance(module, HSDPModule):
        scheduler = getattr(module, "hsdp_scheduler", None)
        if scheduler is not None:
            return scheduler.hsdp_state
    return None