Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/fully_shard/param.py: 42% (414 statements)

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP parameter"""
from typing import List, Callable, Optional, cast, Tuple
import itertools
import mindspore as ms
from mindspore import nn
from mindspore.common.api import _no_grad
from mindspore import ops, Parameter
import mindspore.mint.distributed as dist
from mindspore.ops.function.comm_func import CommHandle
from hyper_parallel.core.fully_shard.utils import (
    MixedPrecisionPolicy,
    CPUOffloadPolicy,
    OffloadPolicy,
    FSDPMeshInfo,
    HSDPMeshInfo,
)
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.core.dtensor.layout import Layout
from hyper_parallel.core.fully_shard.hsdp_param import HSDPParamV2
from hyper_parallel.core.fully_shard.hsdp_utils import (
    ShardedState,
    FullyShardParamMode,
    unwrap_dtensor_param,
)
from hyper_parallel.core.dtensor.placement_types import Shard, StridedShard
from hyper_parallel.core.fully_shard.hsdp_utils import ParamModuleInfo
from hyper_parallel.platform.mindspore.fully_shard._version_utils import copy_without_bumping_version
from hyper_parallel.platform.mindspore.utils import normalize_runtime_device
from hyper_parallel.platform.mindspore.fully_shard.pack_utils import (
    build_rs_plan,
    pack_for_reduce_scatter,
    unpack_from_all_gather,
)


def _pack_for_reduce_scatter(local_tensor: ms.Tensor, shard_dim: int, world_size: int) -> ms.Tensor:
    """Pack one local gradient into the row-major reduce-scatter layout.

    MindSpore currently aligns with the torch non-comm-fusion V1 path:

    - shard on dim 0: identity flatten
    - shard on non-dim0: chunk on shard dim, then concatenate on dim 0
    """
    if world_size <= 1 or shard_dim == 0:
        return local_tensor
    chunks = ms.mint.chunk(local_tensor, world_size, dim=shard_dim)
    return ms.mint.cat(chunks, dim=0).contiguous()
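
# Illustrative sketch (not part of the original module): for a local gradient
# of shape (4, 6) sharded on dim 1 across world_size=2, the packing above
# chunks it into two (4, 3) pieces and concatenates them on dim 0, so each
# destination rank's rows sit contiguously in the flat reduce-scatter input:
#
#   x = ms.ops.arange(24, dtype=ms.float32).reshape(4, 6)
#   packed = _pack_for_reduce_scatter(x, shard_dim=1, world_size=2)
#   assert packed.shape == (8, 3)  # rows 0..3 go to rank 0, rows 4..7 to rank 1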


def _to_dtype_if_needed(
    tensor: ms.Tensor, dtype: Optional[ms.Type]
) -> ms.Tensor:
    """Cast tensor to the given dtype if it differs from the current dtype."""
    if dtype is not None and tensor.dtype != dtype:
        return tensor.to(dtype)
    return tensor


def make_contiguous_strides_for(shape, row_major=True):
    """
    Compute strides for a contiguous tensor of the given shape.

    Args:
        shape (tuple of int): The shape of the tensor. Each dimension must be a non-negative integer.
        row_major (bool):
            - If True (default), returns C-style (row-major) strides: the last dimension changes fastest.
            - If False, returns strides where the last two dimensions are Fortran-style
              (i.e., for batched matrix operations in BLAS/LAPACK): the second-to-last dim changes fastest.

    Returns:
        tuple of int: The computed strides.

    Examples:
        >>> make_contiguous_strides_for((2, 3, 4))
        (12, 4, 1)
        >>> make_contiguous_strides_for((2, 3, 4), row_major=False)
        (12, 1, 3)
        >>> make_contiguous_strides_for((5,))
        (1,)
        >>> make_contiguous_strides_for((5,), row_major=False)
        (1,)
        >>> make_contiguous_strides_for(())
        ()
    """
    if not isinstance(shape, (tuple, list)):
        raise TypeError("shape must be a tuple or list of non-negative integers")

    # Validate shape elements
    for dim in shape:
        if not isinstance(dim, int) or dim < 0:
            raise ValueError("All dimensions in shape must be non-negative integers")

    if not shape:
        return ()

    # Compute C-style (row-major) strides: stride[i] = product(shape[i+1:])
    strides = []
    multiplier = 1
    # Traverse shape in reverse order
    for size in reversed(shape):
        strides.append(multiplier)
        multiplier *= max(size, 1)  # handle size=0 gracefully (treat as 1 for stride calc)

    # Reverse to get the correct order
    c_strides = tuple(reversed(strides))

    if row_major:
        return c_strides
    # Column-major only affects the last two dimensions.
    if len(shape) < 2:
        return c_strides
    # Follow the PyTorch helper of the same name: keep the batch strides and
    # use (1, shape[-2]) for the trailing matrix dims. Example: shape=(B, M, N)
    # gives c_strides=(M*N, N, 1) and column-major strides (M*N, 1, M).
    return c_strides[:-2] + (1, max(shape[-2], 1))


class MindSporeHSDPParamV2(HSDPParamV2):
    """
    MindSpore HSDP parameter.
    """

    def __init__(
        self,
        param: Parameter,
        module_info: ParamModuleInfo,
        mesh_info: FSDPMeshInfo,
        shard_placement_fn: Optional[Callable[[Parameter], Optional[Shard]]] = None,
        mp_policy: Optional[MixedPrecisionPolicy] = None,
        offload_policy: Optional[OffloadPolicy] = None,
        device: Optional[str] = None,
        param_mode: Optional[FullyShardParamMode] = None,
        enable_fsdp_shard: bool = True,
    ):
        self._module_info: ParamModuleInfo = module_info
        self.mesh_info = mesh_info
        self.mp_policy = mp_policy
        self.device = device
        if param_mode is None:
            raise AssertionError("param_mode must be resolved before MindSporeHSDPParamV2 initialization.")
        self.param_mode = param_mode
        self.enable_fsdp_shard = enable_fsdp_shard
        self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
        self.pin_memory = (
            self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
        )
        self.grad_offload_event: Optional[ms.runtime.Event] = None
        dtensor_payload = unwrap_dtensor_param(param)
        self._orig_param_is_dtensor = dtensor_payload is not None
        self._orig_dtensor_mesh = dtensor_payload.device_mesh if dtensor_payload is not None else None
        self._orig_dtensor_placements = (
            tuple(dtensor_payload.placements) if dtensor_payload is not None else None
        )
        self._spmd_shard_mesh_dim = getattr(self.mesh_info, "shard_mesh_dim", None)
        self._spmd_replicate_mesh_dim = getattr(self.mesh_info, "replicate_mesh_dim", None)
        self._init_sharded_param(param, shard_placement_fn)
        self._init_group_infos()
        self.all_gather_outputs: List[ms.Tensor] = []
        self.unsharded_accumulated_grad = None
        self._unsharded_param: Optional[Parameter] = None
        self._param_fqn: Optional[str] = None
        # Communication attributes for the prefetch pattern
        self.prefetch_handle: Optional[CommHandle] = None
        self._reduce_scatter_output = None
        self.reduce_scatter_handle: Optional[CommHandle] = None
        self._all_reduce_output = None
        self.all_reduce_handle: Optional[CommHandle] = None
        self._post_load_hook_handle = (
            module_info.module.register_load_state_dict_post_hook(
                lambda *args, **kwargs: self.reset_sharded_param()
            )
        )

    @property
    def uses_param_shard(self) -> bool:
        return self.enable_fsdp_shard

    @property
    def is_dtensor_compat_mode(self) -> bool:
        return self.param_mode == FullyShardParamMode.DTENSOR_COMPAT

    def _get_data_parallel_shard_placement(self, placements: list, shard_placement: Shard):
        """Return the explicit fully_shard placement on the unified SPMD mesh."""
        split_factor = 1
        shard_mesh_dim = getattr(self, "_spmd_shard_mesh_dim", None)
        for mesh_idx, placement in enumerate(placements):
            if mesh_idx == shard_mesh_dim:
                continue
            if placement.is_shard(shard_placement.dim):
                split_factor *= self._spmd_mesh.mesh_shape[mesh_idx]
        if split_factor > 1:
            return StridedShard(shard_placement.dim, split_factor=split_factor)
        return shard_placement
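
    # Illustrative sketch (assumed mesh, not from the original source): on a
    # 3-D SPMD mesh of shape (dp=2, tp1=2, tp2=2) where both tensor-parallel
    # mesh dims already shard the same tensor dim as fully_shard, split_factor
    # accumulates to 2 * 2 = 4, so the data-parallel placement becomes
    # StridedShard(dim, split_factor=4) instead of a plain Shard(dim).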

    def _release_full_param_storage_if_safe(self, param_data: ms.Tensor) -> None:
        """Release the temporary full-parameter storage once the sharded param is installed.

        Skip storage reclamation only for meta tensors. Both plain Tensor inputs and DTensor local
        tensors should drop their original storage after the sharded Parameter has been installed
        onto the owning modules.
        """
        if param_data.is_meta:
            return
        storage = param_data.untyped_storage()
        if storage.size() != 0:
            storage.resize_(0)

    @_no_grad()
    def _init_sharded_param(
        self,
        param: Parameter,
        shard_placement_fn: Optional[Callable],
    ) -> None:
        param_device = normalize_runtime_device(param.device)
        if param_device not in ("meta", self.device):
            raise AssertionError(
                f"Expects the parameter to already be moved to device {self.device} but got {param.device}"
            )
        hsdp_placement = shard_placement_fn(param) if shard_placement_fn else None
        if hsdp_placement is None:
            hsdp_placement = Shard(0)
        elif hsdp_placement.dim < 0:
            # if dim is negative, add the number of dimensions of the parameter
            hsdp_placement = Shard(hsdp_placement.dim + param.ndim)

        if not isinstance(hsdp_placement, Shard):
            raise AssertionError(
                f"Expected Shard, got {type(hsdp_placement)}: {hsdp_placement}"
            )

        self.hsdp_placement = hsdp_placement
        base_placements = list(self._get_base_spmd_placements())
        self._spmd_placements = self._apply_data_parallel_placements(base_placements, hsdp_placement)
        param_data = unwrap_dtensor_param(param).to_local() if self._orig_param_is_dtensor else param

        shard_dim = hsdp_placement.dim
        self._orig_size = param_data.shape
        self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)

        if self.uses_param_shard and isinstance(self.mesh_info, FSDPMeshInfo):  # FSDP or HSDP
            shard_rank = self.mesh_info.shard_mesh_rank
            shard_world_size = self.mesh_info.shard_mesh_size
        else:  # DDP
            shard_rank = 0
            shard_world_size = 1

        self.is_sharded = bool(self.uses_param_shard and shard_world_size > 1)

        if param_data.shape[shard_dim] % shard_world_size != 0:
            raise NotImplementedError(
                f"Uneven sharding on dim {shard_dim} not supported: "
                f"shape={param_data.shape}, world_size={shard_world_size}"
            )
        chunks = ms.mint.chunk(param_data, shard_world_size, dim=shard_dim)
        sharded_param = chunks[shard_rank].clone().contiguous()
        self.sharded_size = sharded_param.shape
        self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size)
        self._sharded_param_data = sharded_param.view(-1)

        self._sharding_spec = Layout.from_device_mesh(self._spmd_mesh)
        self._sharding_spec.set_placements(self._spmd_placements)
        self._sharding_spec.placement_to_tensor_map(param.ndim)

        shard_dtensor = DTensor.from_local(sharded_param, self._spmd_mesh, self._spmd_placements)
        self.sharded_param = Parameter(shard_dtensor, name=param.name)
        set_requires_grad_if_needed(param, self.sharded_param)
        self.sharded_param.grad = None

        self._setattr_on_modules(self.sharded_param)
        self._release_full_param_storage_if_safe(param_data)
        self.sharded_param._hsdp_param_initialized = True
        self.sharded_state = ShardedState.SHARDED
        self.param_dtype = None
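
    # Illustrative numbers (assumption, not from the original source): an
    # (8, 4) parameter sharded with Shard(0) across a 4-rank shard group is
    # chunked into four (2, 4) pieces; each rank clones its own chunk and
    # keeps a flat 8-element view of it in self._sharded_param_data.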

    def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
        param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype)
        self.orig_dtype = self.sharded_param.dtype
        if reduce_dtype == param_dtype:
            reduce_dtype = None
        if param_dtype == self.orig_dtype:
            param_dtype = None
        self.param_dtype = param_dtype
        self.reduce_dtype = reduce_dtype
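
    # Illustrative policy resolution (assumed values): with fp32 master
    # weights and MixedPrecisionPolicy(param_dtype=bf16, reduce_dtype=fp32),
    # param_dtype stays bf16 (it differs from orig_dtype) and reduce_dtype
    # stays fp32; a None in either slot means "no cast needed" on that path.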

    def init_all_gather_outputs(
        self,
        all_gather_input_numels: list[int],
        all_gather_input_dtypes: list[ms.Type],
        world_size: int,
        device: str,
        force_recreate: bool = False,
    ):
        if not force_recreate and len(self.all_gather_outputs) > 0:
            return  # already initialized
        self.all_gather_outputs = [
            ms.mint.empty([numel * world_size], dtype=dtype, device=device.split(':')[0])
            for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes)
        ]

    def init_unsharded_param(self):
        """
        Initialize the unsharded parameter from all-gather outputs.

        This reconstructs the full parameter after all-gather by unpacking the
        gathered flat buffer back to the original tensor layout.
        """
        unsharded_param = self._get_unsharded_param_from_all_gather_output()
        if self._unsharded_param is not None:
            # Keep the Parameter identity stable across forward-reshard-backward
            # cycles so backward hooks continue to read gradients from the same
            # object that participated in the forward graph.
            if self._orig_param_is_dtensor:
                self._unsharded_param.set_data(unsharded_param)
            else:
                self._unsharded_param.data = unsharded_param
            set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)
            self._unsharded_param.grad = None
            return
        if self._orig_param_is_dtensor:
            self._unsharded_param = Parameter(
                unsharded_param,
                name=self.sharded_param.name,
                requires_grad=self.sharded_param.requires_grad,
            )
            return
        # In MindSpore, `Parameter(tensor)` copies into a new Tensor instead of
        # creating a view. We need the unsharded Parameter to share storage with
        # the all-gather output, so assign through `.data = tensor` instead.
        self._unsharded_param = Parameter(
            [],
            name=self.sharded_param.name,
            requires_grad=False,
        )
        self._unsharded_param.data = unsharded_param
        if self.sharded_param.requires_grad:
            self._unsharded_param.requires_grad = True
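
    # Storage-sharing sketch (assumed shapes, not from the original source):
    # the flat all-gather buffer holds world_size * local_numel elements, and
    # assigning it via `.data = tensor` lets the unsharded Parameter alias
    # that buffer, so resizing the buffer's storage to 0 later also releases
    # the unsharded parameter memory:
    #
    #   buf = ms.mint.empty([4 * 8], dtype=ms.float32)   # 4 ranks, 8 elems each
    #   p = Parameter([], name="w", requires_grad=False)
    #   p.data = buf.view(8, 4)                          # aliases buf's storage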

    def _get_unsharded_param_from_all_gather_output(self):
        """Reconstruct the full local parameter view from the packed all-gather output."""
        if len(self.all_gather_outputs) != 1:
            raise AssertionError(
                f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}"
            )
        unsharded_tensor = self.all_gather_outputs[0]
        plan = build_rs_plan(
            self,
            self._sharded_local_tensor,
            self.shard_world_size if self.is_sharded else 1,
        )
        unsharded_param = unpack_from_all_gather(unsharded_tensor, plan)
        if getattr(self, "_orig_param_is_dtensor", False):
            unsharded_param = DTensor.from_local(
                unsharded_param,
                self._orig_dtensor_mesh,
                self._orig_dtensor_placements,
            )
        return unsharded_param

    def to_sharded(self) -> None:
        if not self.uses_param_shard and self._unsharded_param is not None:
            # Replicate params keep the same local shape across shard/unshard,
            # so persist forward-time state updates before switching objects.
            src = (
                self._unsharded_param.to_local()
                if isinstance(self._unsharded_param, DTensor)
                else self._unsharded_param
            )
            dst = self.sharded_param.to_local() if isinstance(self.sharded_param, DTensor) else self.sharded_param
            copy_without_bumping_version(dst, src)
        self._setattr_on_modules(self.sharded_param)
        self.free_unsharded_param()
        self.sharded_state = ShardedState.SHARDED

    def to_unsharded(self) -> None:
        set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)
        self._setattr_on_modules(self._unsharded_param)
        self.sharded_state = ShardedState.UNSHARDED

    def _setattr_on_modules(self, param: Parameter) -> None:
        if getattr(self._module_info.module.__setattr__, "__func__", None) is nn.Cell.__setattr__:
            # fast path: the module uses the default Cell.__setattr__, so write
            # into the parameter dict directly
            self._module_info.module._params[self._module_info.param_name] = param
        else:
            # slow path: the module overrides __setattr__, so go through it
            setattr(self._module_info.module, self._module_info.param_name, param)

        # Iterate through all modules that share this parameter to prevent pointer desync.
        for shared_module, shared_param_name in zip(
            self._module_info.shared_modules, self._module_info.shared_param_names
        ):
            if getattr(shared_module.__setattr__, "__func__", None) is nn.Cell.__setattr__:
                shared_module._params[shared_param_name] = param
            else:
                setattr(shared_module, shared_param_name, param)

    def to_sharded_dtensor(self, tensor: ms.Tensor) -> DTensor:
        """
        Convert a local tensor representing either the sharded parameter or
        the sharded gradient to a DTensor.
        """
        return DTensor.from_local(
            tensor,
            self._sharding_spec.mesh,
            self._sharding_spec.placements
        )

    def _to_local_unsharded_grad(self, grad):
        """Normalize a pending gradient to the local tensor expected by fully_shard collectives."""
        return self._normalize_unsharded_grad_to_local(grad, reduce_partial_dtensor=False)

    def to_accumulated_grad_if_needed(self) -> None:
        if self._unsharded_param.grad is None:
            return
        unsharded_grad = self._unsharded_param.grad
        self._unsharded_param.grad = None
        if self.reduce_dtype is not None and unsharded_grad.dtype != self.reduce_dtype:
            unsharded_grad = unsharded_grad.to(self.reduce_dtype)
        if self.unsharded_accumulated_grad is None:
            self.unsharded_accumulated_grad = unsharded_grad
        else:
            self.unsharded_accumulated_grad += unsharded_grad

    def accumulate_unsharded_grad_if_needed(self) -> None:
        if (
            self.unsharded_accumulated_grad is not None
            and self.unsharded_param.grad is not None
        ):
            # Fold the freshly produced gradient into the accumulation buffer.
            self.unsharded_accumulated_grad += self._to_local_unsharded_grad(self.unsharded_param.grad)
            self.unsharded_param.grad = None

    def alloc_all_gather_outputs(self) -> None:
        for tensor in self.all_gather_outputs:
            expected_size = tensor.numel() * tensor.itemsize

            storage = tensor.untyped_storage()
            if storage.size() != expected_size:
                storage.resize_(expected_size)

    def free_unsharded_param(self) -> None:
        for tensor in itertools.chain(
            self.all_gather_outputs
        ):
            storage = tensor.untyped_storage()
            if storage.size() != 0:
                storage.resize_(0)

    @property
    def all_gather_inputs(self) -> list[ms.Tensor]:
        self._assert_in_states(ShardedState.SHARDED)
        sharded_param_data = self._sharded_param_data
        if self.offload_to_cpu:
            sharded_param_data = sharded_param_data.to(
                self.device, non_blocking=True
            )
        if self.param_dtype is not None and self.param_dtype != sharded_param_data.dtype:
            return [sharded_param_data.to(self.param_dtype)]
        return [sharded_param_data]

    @property
    def unsharded_param(self) -> Parameter:
        """Return the full unsharded parameter after all-gather."""
        return self._unsharded_param

    @property
    def unsharded_grad_data(self) -> ms.Tensor:
        """
        Get the unsharded gradient data as a local tensor.
        """
        grad = self.unsharded_param.grad
        if grad is None:
            raise AssertionError("Expects unsharded_param.grad to not be None")
        return self._to_local_unsharded_grad(grad)

    @property
    def unsharded_accumulated_grad_data(self) -> ms.Tensor:
        """
        Get the unsharded accumulated gradient data as a local tensor.
        """
        return self.unsharded_accumulated_grad

    @property
    def _sharded_local_tensor(self) -> ms.Tensor:
        """Return the underlying local tensor of the sharded DTensor parameter."""
        return cast(DTensor, self.sharded_param)._local_tensor

    @property
    def shard_world_size(self) -> int:
        """Get the world size of the shard dimension."""
        if isinstance(self.mesh_info, FSDPMeshInfo):
            return self.mesh_info.shard_mesh_size
        return 1

    @property
    def replicate_world_size(self) -> int:
        """Get the world size of the replicate dimension (HSDP only)."""
        if isinstance(self.mesh_info, HSDPMeshInfo):
            return self.mesh_info.replicate_mesh_size
        return 1

    def _assert_in_states(self, *states: ShardedState) -> None:
        """Assert that the current state is one of the expected states."""
        if self.sharded_state not in states:
            raise AssertionError(
                f"Expected sharded_state in {states}, got {self.sharded_state}"
            )

    def reset_sharded_param(self) -> None:
        """Reset the sharded param after load_state_dict."""
        module_info = self._module_info
        new_param = getattr(module_info.module, module_info.param_name)
        if new_param is not self.sharded_param:
            if isinstance(new_param, DTensor):
                self.sharded_param = new_param
                if not getattr(self.sharded_param, "_hsdp_param_initialized", None):
                    # Mark the newly loaded parameter as initialized.
                    self.sharded_param._hsdp_param_initialized = True
            elif isinstance(new_param, ms.Tensor):
                # If new_param is a plain Tensor, don't re-reference self.sharded_param;
                # just update self.sharded_param._local_tensor and self._sharded_param_data below.
                pass

        local_tensor = new_param._local_tensor if isinstance(new_param, DTensor) else new_param
        if local_tensor.is_meta:
            return
        updated_local_tensor = False
        # local_tensor can be padded twice:
        # 1st time in fully_shard(model)
        # 2nd time in model(input) lazy_init
        # The 2nd time should be a no-op if parameters remain unchanged, but
        # must not be a no-op if model.load_state_dict(...) ran before lazy_init.
        # This makes it possible for a trainer to call `sd = model.state_dict()`
        # before the training loop and use `sd` without calling .state_dict()
        # per iteration.
        same_local_tensor = False
        if isinstance(self._sharded_param_data, ms.Tensor):
            same_local_tensor = (
                # when sharding a param with shape (1, ...) over 2 ranks,
                # local_tensor on rank 1 can be size 0 and data_ptr() can be 0
                self._sharded_param_data.untyped_storage().data_ptr() > 0
                and self._sharded_param_data.untyped_storage().data_ptr()
                == local_tensor.untyped_storage().data_ptr()
            )
        sharded_size = self.sharded_size
        shard_dim = self.hsdp_placement.dim
        length = local_tensor.shape[shard_dim] if local_tensor.numel() > 0 else 0
        if not same_local_tensor:
            if local_tensor.shape != sharded_size:
                raise AssertionError(
                    f"Expected sharded_size to be {sharded_size}, got {local_tensor.shape}"
                )
            updated_local_tensor = True
        if self.pin_memory and not local_tensor.is_pinned():
            local_tensor = local_tensor.to("cpu").pin_memory()
            updated_local_tensor = True
        if not same_local_tensor:
            self._sharded_param_data = local_tensor.view(-1)
        if not isinstance(self.sharded_param, DTensor):
            raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}")
        if updated_local_tensor:
            # Only change the local tensor object if needed
            self.sharded_param._local_tensor = local_tensor.narrow(
                dim=shard_dim, start=0, length=length
            )
            if not self.sharded_param._local_tensor.is_contiguous():
                raise AssertionError(
                    "Expected sharded_param._local_tensor to be contiguous"
                )
        self._sharding_spec = cast(DTensor, self.sharded_param).layout
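
    # Aliasing-check sketch (illustrative): if the optimizer or load_state_dict
    # kept the original storage, data_ptr() of the cached flat view equals that
    # of the new local tensor and the reset is effectively a no-op; if a fresh
    # tensor was installed, the pointers differ, and both the flat view and the
    # DTensor's _local_tensor are rebuilt from the new storage.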

    def _get_unsharded_param_data(self, async_op: bool = False) -> Tuple[ms.Tensor, Optional[CommHandle]]:
        """
        Perform all-gather to get the unsharded parameter data.

        Args:
            async_op: Whether to execute asynchronously.

        Returns:
            (unsharded_param, handle): Unsharded parameter data and communication handle.
        """
        # Optimizer steps may refresh the underlying local tensor storage. Re-sync
        # the cached flat shard view before reading all_gather_inputs for the next
        # unshard cycle.
        self.reset_sharded_param()
        all_gather_input = self.all_gather_inputs[0]

        # If the parameter is not sharded (below threshold), no communication is needed
        if not self.is_sharded:
            self.init_all_gather_outputs(
                all_gather_input_numels=[all_gather_input.numel()],
                all_gather_input_dtypes=[all_gather_input.dtype],
                world_size=1,
                device=all_gather_input.device.split(':')[0],
            )
            self.alloc_all_gather_outputs()
            copy_without_bumping_version(self.all_gather_outputs[0], all_gather_input)
            return self.all_gather_outputs[0], None

        # Initialize the output buffer
        self.init_all_gather_outputs(
            all_gather_input_numels=[all_gather_input.numel()],
            all_gather_input_dtypes=[all_gather_input.dtype],
            world_size=self.shard_world_size,
            device=self._sharded_param_data.device.split(':')[0],
        )
        self.alloc_all_gather_outputs()

        # Get the communication group
        shard_group = self.mesh_info.shard_process_group if isinstance(self.mesh_info, FSDPMeshInfo) else None

        if shard_group is None or self.shard_world_size <= 1:
            # No communication needed, just copy
            copy_without_bumping_version(self.all_gather_outputs[0], all_gather_input)
            return self.all_gather_outputs[0], None

        # Execute all_gather_into_tensor
        handle = dist.all_gather_into_tensor(
            self.all_gather_outputs[0],
            all_gather_input,
            group=shard_group,
            async_op=async_op,
        )

        return self.all_gather_outputs[0], handle
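
    # Buffer-sizing sketch (assumed numbers): with a flat shard of 1024 bf16
    # elements and shard_world_size=8, the all-gather output buffer is
    # allocated as 1024 * 8 = 8192 elements, and each rank's shard lands at
    # offset rank * 1024 in that flat buffer before unpacking.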

    def unshard(self, async_op: bool = False) -> None:
        if self.prefetch_handle is not None:
            # Already triggered by HSDPState.prefetch(), so return directly.
            return  # no-op

        _, handle = self._get_unsharded_param_data(async_op=async_op)
        self.prefetch_handle = handle

    def wait_for_unshard(self) -> None:
        self._assert_in_states(ShardedState.SHARDED)

        if self.prefetch_handle is not None:
            self.prefetch_handle.wait()
            self.prefetch_handle = None

        self.init_unsharded_param()
        self.to_unsharded()

    def shard(self) -> None:
        """
        Transition the parameter from the unsharded back to the sharded state.
        """
        self._assert_in_states(ShardedState.UNSHARDED)
        self.to_sharded()

    def reduce_scatter_output(self):
        """Return the cached reduce-scatter output after waiting for pending async work."""
        if self.reduce_scatter_handle is not None:
            self.reduce_scatter_handle.wait()
            self.reduce_scatter_handle = None
        return self._reduce_scatter_output

    def clear_reduce_scatter_output(self):
        """Clear the cached reduce-scatter output."""
        self._reduce_scatter_output = None

    def reduce_scatter_grad(
        self,
        async_op: bool = True,
        dtype: Optional[ms.Type] = None,
        reduce_op: Optional[ops.ReduceOp] = ops.ReduceOp.SUM
    ) -> Tuple[ms.Tensor, Optional[CommHandle]]:
        """
        Perform reduce-scatter on the gradient to reduce and shard the full gradient.

        Args:
            async_op: Whether to execute asynchronously.
            dtype: Reduction dtype.
            reduce_op: Reduction operation applied by reduce-scatter (sum or average).

        Returns:
            (sharded_grad, handle): Sharded gradient and communication handle.
        """
        self._assert_in_states(ShardedState.UNSHARDED)

        # Prefer the accumulated gradient when gradient accumulation is active
        if self.unsharded_accumulated_grad is not None:
            grad = self.unsharded_accumulated_grad_data
        else:
            grad = self.unsharded_grad_data
        reduce_dtype = dtype or grad.dtype
        grad = grad.to(reduce_dtype)
        shard_group_info = getattr(self, "sharded_group_info", None)
        shard_group = shard_group_info.group if shard_group_info is not None else None
        shard_group_size = shard_group_info.rank_size if shard_group_info is not None else 1
        if shard_group is None and isinstance(self.mesh_info, FSDPMeshInfo):
            shard_group = self.mesh_info.shard_process_group
            shard_group_size = self.shard_world_size
        plan_world_size = (
            shard_group_size
            if self.is_sharded and shard_group is not None and shard_group_size > 1
            else 1
        )
        plan = build_rs_plan(self, grad, plan_world_size)
        grad_flat = pack_for_reduce_scatter(grad, plan).reshape(-1)

        # If the parameter is not sharded (below threshold), no reduce-scatter is needed
        if not self.is_sharded:
            return grad_flat, None

        if shard_group is None or shard_group_size <= 1:
            # No communication needed
            return grad_flat, None

        # Calculate the output size
        output_numel = grad_flat.numel() // shard_group_size
        self._reduce_scatter_output = ms.mint.empty(
            output_numel, dtype=reduce_dtype, device=grad.device.split(':')[0]
        )

        # Execute reduce_scatter_tensor
        self.reduce_scatter_handle = dist.reduce_scatter_tensor(
            self._reduce_scatter_output,
            grad_flat,
            op=reduce_op,
            group=shard_group,
            async_op=async_op,
        )

        return self._reduce_scatter_output, self.reduce_scatter_handle
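
    # Reduce-scatter sizing sketch (assumed numbers): a packed flat gradient
    # of 8192 elements over an 8-rank shard group yields an output of
    # 8192 // 8 = 1024 elements per rank, which are the reduced values for
    # exactly the rows that rank owns in the row-major packed layout.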

    def zero_grad(self):
        self.sharded_param.grad = None

    def all_reduce_grad(
        self,
        grad: Optional[ms.Tensor] = None,
        dtype: Optional[ms.Type] = None,
        async_op: bool = True,
        reduce_op: Optional[ops.ReduceOp] = ops.ReduceOp.SUM
    ) -> Tuple[ms.Tensor, Optional[CommHandle]]:
        """
        Perform all-reduce on the gradient (across the replicate dimension in HSDP mode).

        Args:
            grad: Gradient tensor to reduce. If None, uses unsharded_param.grad,
                or unsharded_accumulated_grad when gradient accumulation is active.
            dtype: Reduction dtype.
            async_op: Whether to execute asynchronously.
            reduce_op: Reduction operation to apply (defaults to ops.ReduceOp.SUM).

        Returns:
            (reduced_grad, handle): Reduced gradient and communication handle.
        """
        # If grad is not provided, get it from the parameter
        if grad is None:
            if self.unsharded_accumulated_grad is not None:
                grad = self.unsharded_accumulated_grad_data
            else:
                grad = self.unsharded_grad_data
        else:
            grad = self._to_local_unsharded_grad(grad)

        if dtype is not None and dtype != grad.dtype:
            grad = grad.to(dtype)
        reduce_group_info = self.unsharded_group_info
        if reduce_group_info.rank_size <= 1:
            return grad, None
        reduce_group = reduce_group_info.group
        if reduce_group is None:
            raise RuntimeError("Expected a valid unsharded all-reduce group when rank_size > 1")

        self._all_reduce_output = grad
        self.all_reduce_handle = dist.all_reduce(
            grad,
            op=reduce_op,
            group=reduce_group,
            async_op=async_op
        )
        return self._all_reduce_output, self.all_reduce_handle
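
    # HSDP two-stage reduction sketch (assumed 2x8 mesh): gradients are first
    # reduce-scattered across the 8-rank shard group, then the resulting shard
    # is all-reduced across the 2-rank replicate group, so every rank ends up
    # with the fully reduced gradient for its own shard.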

    def all_reduce_output(self):
        """Return the cached all-reduce output after waiting for pending async work."""
        if self.all_reduce_handle is not None:
            self.all_reduce_handle.wait()
            self.all_reduce_handle = None
        return self._all_reduce_output

    def clear_all_reduce_output(self):
        """Clear the cached all-reduce output."""
        self._all_reduce_output = None

    def apply_reduced_grad(self, reduced_grad, param_type):
        """
        Apply the reduced gradient to the sharded parameter.

        Reshapes ``reduced_grad`` to match the local shard, optionally
        offloads it to CPU, then accumulates or assigns onto
        ``self.sharded_param.grad``.

        Args:
            reduced_grad (ms.Tensor): Gradient after reduce-scatter
                and/or all-reduce.
            param_type (Optional[ms.Type]): Target dtype for the gradient.
        """
        sharded_grad = self.sharded_param.grad
        reduced_grad = reduced_grad.view(self.sharded_size)
        reduced_grad = _to_dtype_if_needed(reduced_grad, param_type)
        to_accumulate_grad = sharded_grad is not None
        need_synchronize = False
        if self.offload_to_cpu:
            non_blocking = self.pin_memory and not to_accumulate_grad
            reduced_grad = reduced_grad.to(
                "cpu", non_blocking=non_blocking
            )
            need_synchronize = True
        if sharded_grad is None:
            self.sharded_param.grad = self.to_sharded_dtensor(reduced_grad)
        else:
            self.sharded_param.grad._local_tensor += reduced_grad

        if self.unsharded_accumulated_grad_data is not None:
            self.unsharded_accumulated_grad = None
        elif self.unsharded_param.grad is not None:
            self.unsharded_param.grad = None
        return need_synchronize


def set_requires_grad_if_needed(
    src_tensor: ms.Tensor, dst_tensor: ms.Tensor
) -> None:
    if src_tensor.requires_grad != dst_tensor.requires_grad:
        dst_tensor.requires_grad_(src_tensor.requires_grad)