Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/fully_shard/hsdp_param.py: 75%

150 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP parameter"""

from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.core.dtensor.placement_types import Replicate
from hyper_parallel.core.fully_shard.hsdp_utils import (
    FullyShardParamMode,
    GroupInfo,
    get_rank_list_for_axes,
    get_split_rank_lists_for_axes,
)
from hyper_parallel.core.fully_shard.utils import DDPMeshInfo, FSDPMeshInfo
from hyper_parallel.platform import get_platform

platform = get_platform()
_GROUP_INFO_CACHE = {}


def _build_group_info_from_rank_list(group_name: str, rank_list) -> GroupInfo:
    """Create group metadata from an explicit rank list."""
    normalized_rank_list = tuple(sorted(int(rank) for rank in rank_list))
    if len(normalized_rank_list) <= 1:
        return GroupInfo(f"{group_name}_invalid", None, 1)
    if normalized_rank_list in _GROUP_INFO_CACHE:
        cached_group = _GROUP_INFO_CACHE[normalized_rank_list]
        return GroupInfo(str(normalized_rank_list), cached_group, len(normalized_rank_list))
    try:
        group = platform.create_group(list(normalized_rank_list))
    except (RuntimeError, ValueError):  # pragma: no cover - UT may run without dist init
        group = None
    _GROUP_INFO_CACHE[normalized_rank_list] = group
    return GroupInfo(str(normalized_rank_list), group, len(normalized_rank_list))
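# Illustrative sketch only (hypothetical ranks; assumes the distributed backend is
# initialized with at least 4 ranks). Group membership is normalized to a sorted
# tuple and memoized in _GROUP_INFO_CACHE, so equivalent rank lists share one
# communicator, while degenerate (size <= 1) lists yield an "invalid" GroupInfo:
#
#     first = _build_group_info_from_rank_list("fully_shard_unsharded_group", [2, 0])
#     second = _build_group_info_from_rank_list("fully_shard_unsharded_group", (0, 2))
#     # both map to the cache key (0, 2); first.rank_size == second.rank_size == 2
#     solo = _build_group_info_from_rank_list("fully_shard_unsharded_group", [3])
#     # solo.rank_size == 1 and no process group is created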

def _build_group_info_from_process_group(
    group_name: str,
    process_group,
    rank_size: int,
    *,
    resolved_group_name: str | None = None,
) -> GroupInfo:
    """Create group metadata from an existing process group."""
    if process_group is None or rank_size <= 1:
        return GroupInfo(f"{group_name}_invalid", None, 1)
    return GroupInfo(resolved_group_name or group_name, process_group, rank_size)

class HSDPParamV2:
    """
    HSDP parameter.
    """

    def __init__(
        self,
        param,
        module_info,
        mesh_info,
        post_forward_mesh_info,
        shard_placement_fn,
        mp_policy,
        offload_policy,
        threshold,
    ):
        """
        Initialize HSDPParamV2.

        Args:
            param (nn.Parameter): The original parameter to shard.
            module_info (ParamModuleInfo): Ownership and shared-weight metadata for the parameter.
            mesh_info (FSDPMeshInfo): Mesh topology describing shard/replicate dimensions.
            post_forward_mesh_info: Mesh info used after forward (reserved for subclass use).
            shard_placement_fn (Callable, optional): Returns a Shard placement for the parameter,
                or None to use the default (Shard(0)).
            mp_policy (MixedPrecisionPolicy, optional): Mixed precision dtype policy.
            offload_policy (OffloadPolicy, optional): CPU offload policy.
            threshold: Minimum parameter size to enable sharding (reserved for subclass use).
        """
        raise NotImplementedError("HSDP param subclasses must implement __init__")

    def _init_sharded_param(self, param, shard_placement_fn):
        """Add and initialize the sharded parameter."""
        raise NotImplementedError("HSDP param subclasses must implement _init_sharded_param")

    def init_dtype_attrs(self, mp_policy):
        """Initialize dtype attributes from mixed precision policy."""
        raise NotImplementedError("HSDP param subclasses must implement init_dtype_attrs")

    def init_all_gather_outputs(
        self, all_gather_input_numels, all_gather_input_dtypes, world_size, device, force_recreate=False
    ):
        """Allocate or reuse output buffers for all-gather communication."""
        raise NotImplementedError("HSDP param subclasses must implement init_all_gather_outputs")

    def init_unsharded_param(self):
        """Reconstruct the full unsharded parameter from all-gather outputs."""
        raise NotImplementedError("HSDP param subclasses must implement init_unsharded_param")

    def to_sharded(self):
        """Transition parameter from unsharded back to sharded state and free unsharded storage."""
        raise NotImplementedError("HSDP param subclasses must implement to_sharded")

    def to_unsharded(self):
        """Transition parameter to unsharded state after all-gather completes."""
        raise NotImplementedError("HSDP param subclasses must implement to_unsharded")

    def to_sharded_dtensor(self, tensor):
        """Wrap a local sharded tensor as a DTensor with the correct mesh and placements."""
        raise NotImplementedError("HSDP param subclasses must implement to_sharded_dtensor")

    def to_accumulated_grad_if_needed(self):
        """Move unsharded grad to accumulated grad buffer if dtype conversion is required."""
        raise NotImplementedError("HSDP param subclasses must implement to_accumulated_grad_if_needed")

    def accumulate_unsharded_grad_if_needed(self):
        """Accumulate unsharded param grad into accumulated grad buffer if both exist."""
        raise NotImplementedError("HSDP param subclasses must implement accumulate_unsharded_grad_if_needed")

    def alloc_all_gather_outputs(self):
        """Resize all-gather output buffers to their full capacity for communication."""
        raise NotImplementedError("HSDP param subclasses must implement alloc_all_gather_outputs")

    def free_unsharded_param(self):
        """Release storage of all-gather outputs and inner tensors to free device memory."""
        raise NotImplementedError("HSDP param subclasses must implement free_unsharded_param")

    @property
    def all_gather_inputs(self):
        """Return the local sharded tensor(s) to use as input for all-gather communication."""
        raise NotImplementedError("HSDP param subclasses must implement all_gather_inputs")

    @property
    def unsharded_param(self):
        """Return the full unsharded parameter after all-gather."""
        raise NotImplementedError("HSDP param subclasses must implement unsharded_param")

    @property
    def unsharded_grad_data(self):
        """Return the unsharded_param.grad."""
        raise NotImplementedError("HSDP param subclasses must implement unsharded_grad_data")

    @property
    def unsharded_accumulated_grad_data(self):
        """Return the unsharded accumulated gradient buffer."""
        raise NotImplementedError("HSDP param subclasses must implement unsharded_accumulated_grad_data")

    @property
    def _sharded_local_tensor(self):
        """Return the underlying local tensor of the sharded DTensor parameter."""
        raise NotImplementedError("HSDP param subclasses must implement _sharded_local_tensor")

    def _get_unsharded_param_data(self, async_op=False):
        """Perform all-gather to obtain unsharded parameter data, returning (tensor, handle)."""
        raise NotImplementedError("HSDP param subclasses must implement _get_unsharded_param_data")

    def unshard(self, async_op=False):
        """Trigger all-gather to unshard the parameter, optionally asynchronously."""
        raise NotImplementedError("HSDP param subclasses must implement unshard")

    def wait_for_unshard(self):
        """Wait for all-gather to complete and transition parameter to unsharded state."""
        raise NotImplementedError("HSDP param subclasses must implement wait_for_unshard")

    def shard(self):
        """Transition parameter from unsharded back to sharded state."""
        raise NotImplementedError("HSDP param subclasses must implement shard")

    def reduce_scatter_grad(self):
        """Perform reduce-scatter on the unsharded gradient to produce a sharded gradient."""
        raise NotImplementedError("HSDP param subclasses must implement reduce_scatter_grad")

    def all_reduce_grad(self):
        """Perform all-reduce on gradient across the replicate dimension (HSDP mode only)."""
        raise NotImplementedError("HSDP param subclasses must implement all_reduce_grad")
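    # The methods above define the state-transition and collective-communication
    # contract that concrete HSDP parameter subclasses must implement; the helpers
    # below carry the shared layout and process-group logic.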

    def _resolve_process_group_name(self, group_name: str, process_group) -> str:
        """Resolve the name recorded in GroupInfo for an existing process group."""
        del process_group
        return group_name
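    # Layout note: the base SPMD layout depends on the parameter mode. In
    # DTENSOR_UNIFIED mode the fully_shard mesh is concatenated in front of the
    # original DTensor mesh and the data-parallel prefix dims start as Replicate;
    # in DTENSOR_COMPAT mode the original DTensor mesh and placements are kept
    # as-is; otherwise the parameter starts fully replicated on the fully_shard mesh.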

    def _get_base_spmd_placements(self) -> tuple:
        """Return placements before explicit data-parallel semantics are applied."""
        if (
            getattr(self, "param_mode", None) == FullyShardParamMode.DTENSOR_UNIFIED
            and getattr(self, "_orig_param_is_dtensor", False)
        ):
            self._spmd_mesh = DeviceMesh.concatenate([self.mesh_info.mesh, self._orig_dtensor_mesh])
            dp_prefix_placements = tuple(Replicate() for _ in range(self.mesh_info.mesh.ndim))
            return dp_prefix_placements + tuple(self._orig_dtensor_placements)

        if (
            getattr(self, "param_mode", None) == FullyShardParamMode.DTENSOR_COMPAT
            and getattr(self, "_orig_param_is_dtensor", False)
        ):
            self._spmd_mesh = self._orig_dtensor_mesh
            return tuple(self._orig_dtensor_placements)

        self._spmd_mesh = self.mesh_info.mesh
        return tuple(Replicate() for _ in range(self._spmd_mesh.ndim))

    def _get_data_parallel_shard_placement(self, placements: list, shard_placement):
        """Return the placement to apply on the explicit fully_shard dimension."""
        del placements
        return shard_placement
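    # Illustrative sketch only (hypothetical 2x4 mesh; assumes the parameter is
    # FSDP-sharded with the default Shard(0) placement on the explicit fully_shard
    # dim and replicated on the DDP dim): starting from a base layout of
    #     [Replicate(), Replicate()]
    # _apply_data_parallel_placements would return
    #     (Replicate(), Shard(0))
    # when _spmd_replicate_mesh_dim == 0 and _spmd_shard_mesh_dim == 1.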

    def _apply_data_parallel_placements(self, placements: list, shard_placement) -> tuple:
        """Apply explicit DDP/FSDP placements on top of the base SPMD layout."""
        if len(placements) != self._spmd_mesh.ndim:
            raise AssertionError(
                f"Expected {self._spmd_mesh.ndim} unified placements, got {len(placements)}: {placements}"
            )

        spmd_replicate_mesh_dim = getattr(self, "_spmd_replicate_mesh_dim", None)
        if (
            isinstance(self.mesh_info, DDPMeshInfo)
            and spmd_replicate_mesh_dim is not None
            and not getattr(self, "_orig_param_is_dtensor", False)
        ):
            placements[spmd_replicate_mesh_dim] = Replicate()

        spmd_shard_mesh_dim = getattr(self, "_spmd_shard_mesh_dim", None)
        if (
            getattr(self, "uses_param_shard", False)
            and isinstance(self.mesh_info, FSDPMeshInfo)
            and spmd_shard_mesh_dim is not None
        ):
            placements[spmd_shard_mesh_dim] = self._get_data_parallel_shard_placement(
                placements, shard_placement
            )
        return tuple(placements)
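    # Group layout: the sharded group is the FSDP shard process group taken from
    # mesh_info, the unsharded group is derived from the final placement layout
    # (see _build_layout_driven_group_info), and rank_size combines the two as
    # max(1, shard_size * dp_size).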

    def _init_group_infos(self) -> None:
        """Initialize sharded/unsharded communication groups from the current layout."""
        if (
            getattr(self, "uses_param_shard", False)
            and getattr(self, "is_sharded", False)
            and isinstance(self.mesh_info, FSDPMeshInfo)
        ):
            resolved_group_name = self._resolve_process_group_name(
                "fully_shard_sharded_group",
                self.mesh_info.shard_process_group,
            )
            self.sharded_group_info = _build_group_info_from_process_group(
                "fully_shard_sharded_group",
                self.mesh_info.shard_process_group,
                self.mesh_info.shard_mesh_size,
                resolved_group_name=resolved_group_name,
            )
        else:
            self.sharded_group_info = GroupInfo("fully_shard_sharded_group_invalid", None, 1)

        self.unsharded_group_info = self._build_layout_driven_group_info()
        self.shard_size = self.sharded_group_info.rank_size
        self.dp_size = self.unsharded_group_info.rank_size
        self.rank_size = max(1, self.shard_size * self.dp_size)
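    # Group-construction strategy (in order of preference): reuse a named
    # single-axis group via the mesh's get_group(); otherwise split a group from
    # the platform using per-axis rank lists; otherwise fall back to creating a
    # group from the explicit flattened rank list. Any failure in the first two
    # paths silently falls through to the next one.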

    def _build_layout_driven_group_info(self) -> GroupInfo:
        """Build the group that should all-reduce an unsharded gradient from the final layout."""
        group_axes = [
            axis
            for axis, placement in enumerate(self._spmd_placements)
            if placement.is_replicate()
        ]
        spmd_shard_mesh_dim = getattr(self, "_spmd_shard_mesh_dim", None)
        if getattr(self, "uses_param_shard", False) and spmd_shard_mesh_dim is not None:
            group_axes = [axis for axis in group_axes if axis != spmd_shard_mesh_dim]
        if not group_axes:
            return GroupInfo("fully_shard_unsharded_group_invalid", None, 1)

        group_dim_names = getattr(self._spmd_mesh, "mesh_dim_names", None)
        if group_dim_names:
            try:
                mesh_axis_names = tuple(group_dim_names[axis] for axis in group_axes)
                if len(mesh_axis_names) == 1:
                    axis_name = mesh_axis_names[0]
                    process_group = self._spmd_mesh.get_group(axis_name)
                    if process_group is not None:
                        rank_size = self._spmd_mesh.mesh_shape[group_dim_names.index(axis_name)]
                        resolved_group_name = self._resolve_process_group_name(
                            "fully_shard_unsharded_group",
                            process_group,
                        )
                        return _build_group_info_from_process_group(
                            "fully_shard_unsharded_group",
                            process_group,
                            rank_size,
                            resolved_group_name=resolved_group_name,
                        )

                split_rank_lists = get_split_rank_lists_for_axes(self._spmd_mesh, group_axes)
                process_group = platform.split_group(split_ranks=split_rank_lists)
                if process_group is not None:
                    rank_size = 1
                    for axis in group_axes:
                        rank_size *= self._spmd_mesh.mesh_shape[axis]
                    resolved_group_name = self._resolve_process_group_name(
                        "fully_shard_unsharded_group",
                        process_group,
                    )
                    return _build_group_info_from_process_group(
                        "fully_shard_unsharded_group",
                        process_group,
                        rank_size,
                        resolved_group_name=resolved_group_name,
                    )
            except (
                AssertionError,
                AttributeError,
                KeyError,
                RuntimeError,
                TypeError,
                ValueError,
            ):
                pass

        rank_list = get_rank_list_for_axes(self._spmd_mesh, group_axes)
        return _build_group_info_from_rank_list("fully_shard_unsharded_group", rank_list)
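    # Gradient normalization: a DTensor gradient is first reduced if it carries
    # Partial placements, redistributed back to the original DTensor mesh and
    # placements if its layout drifted, and finally unwrapped to a local tensor;
    # plain tensors pass through unchanged.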

    def _normalize_unsharded_grad_to_local(self, grad, *, reduce_partial_dtensor: bool = True):
        """Normalize a pending gradient to the local tensor expected by fully_shard collectives."""
        if not isinstance(grad, DTensor):
            return grad

        if reduce_partial_dtensor and any(placement.is_partial() for placement in grad.placements):
            grad = grad.reduce_partial()

        orig_dtensor_mesh = getattr(self, "_orig_dtensor_mesh", None)
        orig_dtensor_placements = getattr(self, "_orig_dtensor_placements", None)
        if (
            orig_dtensor_mesh is not None
            and grad.device_mesh.to_hash() != orig_dtensor_mesh.to_hash()
        ) or (
            orig_dtensor_placements is not None
            and tuple(grad.placements) != tuple(orig_dtensor_placements)
        ):
            grad = grad.redistribute(orig_dtensor_mesh, orig_dtensor_placements)
        return grad.to_local()