# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hybrid shard data parallel interface"""
import warnings
from collections import namedtuple
from typing import Any, List, Mapping, cast, Optional, Union

from hyper_parallel.platform.platform import PlatformType
from hyper_parallel.core.fully_shard.utils import MixedPrecisionPolicy, OffloadPolicy
from hyper_parallel import DeviceMesh, init_device_mesh
from hyper_parallel.platform import get_platform
from hyper_parallel.core.dtensor.dtensor import DTensor, distribute_tensor
from hyper_parallel.core.fully_shard.hsdp_utils import (
    get_managed_modules_parameters,
    is_dtensor_managed_param,
    get_dtensor_managed_mesh,
)

platform = get_platform()

origin_class_to_extend_class = {}


def _resolve_comm_fusion_zero_copy_default(
    platform_type: PlatformType,
    comm_fusion: bool,
    comm_fusion_zero_copy: Optional[bool],
) -> bool:
    """Resolve backend-specific default for the comm_fusion zero-copy path."""
    if comm_fusion_zero_copy is not None:
        return comm_fusion_zero_copy
    if not comm_fusion:
        return False
    if platform_type == PlatformType.PYTORCH:
        return True
    if platform_type == PlatformType.MINDSPORE:
        return False
    return False


def _check_strict_keys(
    module: platform.Module, state_dict: Mapping[str, Any],
) -> None:
    """Raise ``RuntimeError`` if *state_dict* keys do not match *module*."""
    expected_keys = set(module.state_dict().keys())
    missing = expected_keys - set(state_dict.keys())
    unexpected = set(state_dict.keys()) - expected_keys
    error_msgs: list[str] = []
    if missing:
        error_msgs.append(
            "Missing key(s): " + ", ".join(repr(k) for k in sorted(missing))
        )
    if unexpected:
        error_msgs.append(
            "Unexpected key(s): " + ", ".join(repr(k) for k in sorted(unexpected))
        )
    if error_msgs:
        raise RuntimeError(
            f"Error(s) in loading state_dict for "
            f"{module.__class__.__name__}:\n\t"
            + "\n\t".join(error_msgs)
        )


def _resolve_local_tensor(
    key: str, val: platform.Tensor, target: DTensor,
) -> platform.Tensor:
    """Return the local shard tensor to be loaded into *target*."""
    if isinstance(val, DTensor):
        return val.to_local()
    local_shape = tuple(target.local_shape)
    global_shape = tuple(target.shape)
    val_shape = tuple(val.shape)
    if val_shape == local_shape:
        return val
    if val_shape == global_shape:
        wrapped = distribute_tensor(
            val, target.device_mesh,
            target.layout.alias_placements if target.layout else target.placements,
        )
        return wrapped.to_local()

    raise ValueError(
        f"load '{key}': plain tensor shape {val_shape} "
        f"matches neither local shard {local_shape} "
        f"nor global {global_shape}."
    )


class _UnshardHandle:
    """Handle returned by ``HSDPModule.unshard(async_op=True)``; call ``wait()`` to block until the unshard completes."""
    def __init__(self, hsdp_state=None):
        """
        Initialize an async unshard handle.

        Args:
            hsdp_state (HSDPState, optional): The state to wait on. None means a no-op handle.
        """
        self._hsdp_state = hsdp_state

    def wait(self):
        """Block until the async unshard operation completes."""
        if self._hsdp_state is not None:
            self._hsdp_state.wait_for_unshard()
            self._hsdp_state = None


class HSDPModule:
    """
    The HSDP block of a neural network, exposing the hsdp/fully_shard interface.

    Supported Platforms:
        ``MindSpore`` ``torch``
    """

    def __init__(self):
        """Initialize HSDPModule."""
        self.hsdp_scheduler = None  # Initialized in hsdp_init()

    # pylint: disable=C0415
    def hsdp_init(self, platform_type, module, mesh, reshard_after_forward,
                  shard_placement_fn, mp_policy, offload_policy, ignored_params, replicate_params, device,
                  comm_fusion, comm_fusion_zero_copy: Optional[bool] = None):
        """init hsdp2 scheduler."""
        scheduler_class = None
        if platform_type == PlatformType.MINDSPORE:
            from hyper_parallel.platform.mindspore.fully_shard.scheduler import MindSporeHSDPSchedulerV2
            scheduler_class = MindSporeHSDPSchedulerV2
        else:
            from hyper_parallel.platform.torch.fully_shard.scheduler import TorchHSDPSchedulerV2
            scheduler_class = TorchHSDPSchedulerV2

        resolved_comm_fusion_zero_copy = _resolve_comm_fusion_zero_copy_default(
            platform_type,
            comm_fusion,
            comm_fusion_zero_copy,
        )

        self.hsdp_scheduler = scheduler_class(
            module,
            mesh,
            reshard_after_forward,
            shard_placement_fn,
            mp_policy,
            offload_policy,
            ignored_params,
            replicate_params,
            device,
            comm_fusion,
            resolved_comm_fusion_zero_copy,
        )

    def set_requires_gradient_sync(self, requires_grad_sync):
        r"""
        Set the requires-gradient-sync flag.

        Args:
            requires_grad_sync (bool): Whether to synchronize (reduce) gradients during backward.

        Raises:
            ValueError: If `requires_grad_sync` is not bool.
        """
        if not isinstance(requires_grad_sync, bool):
            raise ValueError(f"requires_grad_sync must be bool but got {requires_grad_sync}.")
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")

        for _, module in platform.get_cells_and_names(self):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_requires_grad_sync(requires_grad_sync)

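    # Illustrative gradient-accumulation sketch (assumes a PyTorch backend and
    # that ``model`` has been wrapped with ``fully_shard``; ``micro_batches`` is
    # a placeholder list of inputs, not part of this API). Gradient
    # synchronization is skipped on all but the last micro-batch:
    #
    #     for i, batch in enumerate(micro_batches):
    #         model.set_requires_gradient_sync(i == len(micro_batches) - 1)
    #         model(batch).sum().backward()
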

    def zero_grad(self):
        """Zero the accumulated gradients."""
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")
        if platform.platform_type == PlatformType.PYTORCH:
            return super().zero_grad()
        for _, module in platform.get_cells_and_names(self):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.zero_grad()

    def set_modules_to_forward_prefetch(self, modules):
        """Set the modules whose parameter all-gathers should be prefetched during forward."""
        if not isinstance(modules, (tuple, list)):
            raise ValueError("modules must be HSDPModule list")
        for module in modules:
            if not isinstance(module, HSDPModule):
                raise ValueError(f"modules must be HSDPModule list but got {type(module)} in list.")
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")
        self.hsdp_scheduler.set_forward_prefetch_cells(modules)

    def set_modules_to_backward_prefetch(self, modules):
        """Set the modules whose parameter all-gathers should be prefetched during backward."""
        if not isinstance(modules, (tuple, list)):
            raise ValueError("modules must be HSDPModule list")
        for module in modules:
            if not isinstance(module, HSDPModule):
                raise ValueError(f"modules must be HSDPModule list but got {type(module)} in list.")
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call fully_shard interface first.")
        self.hsdp_scheduler.set_backward_prefetch_cells(modules)

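    # Illustrative explicit prefetch schedule (assumes ``blocks`` is a
    # placeholder list of transformer blocks that have each been wrapped with
    # ``fully_shard``): prefetch the next block's all-gather in forward and the
    # previous block's all-gather in backward:
    #
    #     for i, blk in enumerate(blocks):
    #         if i + 1 < len(blocks):
    #             blk.set_modules_to_forward_prefetch([blocks[i + 1]])
    #         if i - 1 >= 0:
    #             blk.set_modules_to_backward_prefetch([blocks[i - 1]])
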

    def reshard(self) -> None:
        """reshard all sharded parameters"""
        if not self.hsdp_scheduler:
            raise ValueError("hsdp_scheduler is None")
        hsdp_state = self.hsdp_scheduler.hsdp_state
        if hsdp_state:
            hsdp_state.shard()

    def unshard(self, async_op: bool = False):
        """unshard all sharded parameters"""
        if not isinstance(async_op, bool):
            raise ValueError(f"async_op should be a bool, got {type(async_op)}")
        if not self.hsdp_scheduler:
            raise ValueError("hsdp_scheduler is None")
        hsdp_state = self.hsdp_scheduler.hsdp_state
        if hsdp_state:
            hsdp_state.unshard(async_op)  # pylint: disable=too-many-function-args
        if async_op:
            return _UnshardHandle(hsdp_state=hsdp_state)
        return None

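    # Illustrative manual unshard/reshard sketch (assumes ``blk`` is an
    # HSDP-wrapped submodule and ``x`` is a placeholder input); the all-gather
    # is issued asynchronously and waited on just before the parameters are
    # needed:
    #
    #     handle = blk.unshard(async_op=True)
    #     ...  # overlap other work with the all-gather
    #     handle.wait()
    #     out = blk(x)
    #     blk.reshard()
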

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
        assign: bool = False,
    ):
        """
        Load state dict by copying directly into local shards.

        Bypasses ``super().load_state_dict()`` because the standard PyTorch
        implementation triggers ``copy_`` through the DTensor dispatcher, which
        is not registered in the hyper-parallel layout system.

        Each value in ``state_dict`` is dispatched by type:
        - hyper DTensor: extract local shard and copy directly.
        - plain Tensor whose shape == local shard shape: copy as-is.
        - plain Tensor whose shape == global shape: distribute via
          ``distribute_tensor``, then copy the local shard.

        Args:
            state_dict (Mapping[str, Any]): Fully-qualified parameter/buffer
                names mapped to tensors (DTensor or plain Tensor).
            strict (bool): If ``True`` (default), missing or unexpected keys
                raise ``RuntimeError``, matching ``nn.Module.load_state_dict``
                semantics.
            assign (bool): Accepted for API compatibility with
                ``nn.Module.load_state_dict(assign=True)`` but currently
                ignored; HSDP always copies into existing DTensor storage.

        Raises:
            RuntimeError: When ``strict`` is ``True`` and keys do not match.
            ValueError: When a plain tensor shape matches neither the local
                shard shape nor the global shape of the target DTensor.
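
        Example (illustrative; assumes ``model`` has already been wrapped by
        ``fully_shard`` and ``full_sd`` is a placeholder mapping from parameter
        names to global-shape tensors, e.g. taken from a single-device
        checkpoint)::

            model.load_state_dict(full_sd, strict=True)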

        """
        if assign:
            warnings.warn(
                "HSDPModule.load_state_dict: assign=True is ignored; "
                "HSDP always copies into existing DTensor parameters.",
                stacklevel=2,
            )
        self_module = cast(platform.Module, self)

        target_map: dict[str, platform.Tensor] = {}
        for name, p in platform.parameters_dict(self_module):
            target_map[name] = p
        for name, b in self_module.named_buffers():
            target_map[name] = b

        if strict:
            _check_strict_keys(self_module, state_dict)

        with platform.no_grad():
            for key, val in state_dict.items():
                target = target_map.get(key)
                if target is None:
                    continue

                if isinstance(target, DTensor):
                    val = _resolve_local_tensor(key, val, target)
                platform.load_into_param(target, val)

        # Trigger load_state_dict post-hooks so that HSDP internal
        # bookkeeping (e.g. _sharded_param_data) stays in sync.
        # Pass an IncompatibleKeys with the same attribute names as PyTorch
        # so external hooks can safely read .missing_keys/.unexpected_keys.
        _IK = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])
        incompatible_keys = _IK([], [])
        for _, module in platform.get_cells_and_names(self_module):
            hooks = module._load_state_dict_post_hooks  # pylint: disable=protected-access
            for hook in hooks.values():
                hook(module, incompatible_keys)

    def set_is_last_backward(self, is_last_backward: bool):
        """set is_last_backward flag"""
        self.hsdp_scheduler.scheduler_ctx.is_last_backward = is_last_backward

    def set_requires_all_reduce(self, requires_all_reduce: bool, *, recurse: bool = True) -> None:
        """set requires_all_reduce flag"""
        if not isinstance(requires_all_reduce, bool):
            raise ValueError(
                f"requires_all_reduce should be a bool, got {type(requires_all_reduce)}"
            )
        if not recurse:
            raise NotImplementedError(
                "recurse=False is not supported yet; the current implementation "
                "behaves as recurse=True (module-to-parameter mapping support is required)."
            )
        self_module = cast(platform.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_requires_all_reduce(requires_all_reduce)

    def set_reshard_after_forward(self, reshard_after_forward: bool, recurse: bool = True) -> None:
        """set reshard_after_forward flag"""
        if not isinstance(reshard_after_forward, bool):
            raise ValueError(
                f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}"
            )
        if not recurse:
            raise NotImplementedError(
                "recurse=False is not supported yet; the current implementation "
                "behaves as recurse=True (module-to-parameter mapping support is required)."
            )
        self_module = cast(platform.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_reshard_after_forward(reshard_after_forward)

    def set_reshard_after_backward(self, reshard_after_backward: bool, recurse: bool = True) -> None:
        """set reshard_after_backward flag"""
        if not isinstance(reshard_after_backward, bool):
            raise ValueError(
                f"reshard_after_backward should be a bool, got {type(reshard_after_backward)}"
            )
        if not recurse:
            raise NotImplementedError(
                "recurse=False is not supported yet; the current implementation "
                "behaves as recurse=True (module-to-parameter mapping support is required)."
            )
        self_module = cast(platform.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_reshard_after_backward(reshard_after_backward)

    def set_reduce_op_type(self, reduce_op_type) -> None:
        """
        Set reduce_op_type for all gradient reductions in fully_shard.

        Supports ``"avg"`` and ``"sum"``. Local-parameter FSDP/HSDP keeps the
        historical ``"avg"`` default, while DTensor-based paths default to ``"sum"``.
        """
        if hsdp_state := self.hsdp_scheduler.hsdp_state:
            hsdp_state.set_reduce_op_type(reduce_op_type)


def _extend_module_with_hsdp_interface(module):
    """Dynamically extend module's class to inherit from HSDPModule, adding HSDP capabilities."""
    origin_class = module.__class__
    extend_class = origin_class_to_extend_class.get(origin_class, None)
    if extend_class is None:
        extend_class = type(f"HSDP{origin_class.__name__}", (HSDPModule, origin_class), {})
        origin_class_to_extend_class[origin_class] = extend_class
    module.__class__ = extend_class


def _get_root_modules(modules: List[platform.Module]) -> List[platform.Module]:
    """
    Returns the modules in ``modules`` that are root modules (i.e. parent-less)
    with respect to the set ``modules``. In other words, these are the modules
    in ``modules`` that are not the child of any other module in ``modules``.

    Aligned with PyTorch torch.distributed.utils._get_root_modules.
    """
    root_modules: List[platform.Module] = []

    def _get_submodules(mod):
        if platform.platform_type == PlatformType.MINDSPORE:
            return set(c for _, c in mod.cells_and_names())
        return set(mod.modules())

    module_to_modules: dict[platform.Module, set] = {
        m: _get_submodules(m) for m in modules
    }
    for candidate in modules:
        is_root = True
        for mod, submodules in module_to_modules.items():
            if candidate is not mod and candidate in submodules:
                is_root = False
                break
        if is_root:
            root_modules.append(candidate)
    return root_modules


def _check_module_valid(platform_type, module):
    """Check that module is a valid platform module type."""
    if platform_type == PlatformType.MINDSPORE:
        from mindspore.nn.cell import Cell
        if not isinstance(module, Cell):
            raise ValueError(f"module's type must be nn.Cell but got {type(module)}.")
    else:
        from torch.nn import Module
        if not isinstance(module, Module):
            raise ValueError(f"module's type must be nn.Module but got {type(module)}.")


def _validate_module_for_fully_shard(
    module: Union[platform.Module, List[platform.Module]], platform_type
) -> None:
    """Validate module(s) for fully_shard. Platform-aware for single module."""
    if isinstance(module, list):
        if len(module) == 0:
            raise ValueError("fully_shard does not support empty list of modules.")
        for i, m in enumerate(module):
            try:
                _check_module_valid(platform_type, m)
            except ValueError:
                raise ValueError(
                    f"fully_shard expects nn.Module or list[nn.Module], "
                    f"but got list with {type(m).__name__} at index {i}."
                ) from None
    else:
        _check_module_valid(platform_type, module)


def _check_hsdp_input_valid(platform_type, module, shard_size, threshold, optimizer_level, enable_grad_accumulation,
                            grad_scale, reduce_dtype, comm_async, comm_fusion, bucket_size):
    """check hsdp input valid"""
    _check_module_valid(platform_type, module)
    if not isinstance(shard_size, int) or (shard_size <= 0 and shard_size != -1):
        raise ValueError(f"shard_size must be a positive integer or -1, but got {shard_size}.")
    if not isinstance(threshold, int) or threshold < 0:
        raise ValueError(f"threshold must be a non-negative integer, but got {threshold}.")
    if optimizer_level not in ["level1", "level2", "level3"]:
        raise ValueError(f"Optimizer level should be in ['level1', 'level2', 'level3'], but got {optimizer_level}.")
    if not isinstance(enable_grad_accumulation, bool):
        raise ValueError(f"enable_grad_accumulation must be bool but got {enable_grad_accumulation}.")
    if not isinstance(grad_scale, float):
        raise ValueError(f"grad_scale must be float but got {grad_scale}.")
    if platform_type == PlatformType.MINDSPORE:
        from mindspore._c_expression.typing import Type
        if reduce_dtype is not None and not isinstance(reduce_dtype, Type):
            raise ValueError(f"reduce_dtype must be mindspore.dtype but got {reduce_dtype}.")
    else:
        import torch
        if reduce_dtype is not None and not isinstance(reduce_dtype, torch.dtype):
            raise ValueError(f"reduce_dtype must be torch.dtype but got {reduce_dtype}.")
    if not isinstance(comm_async, bool):
        raise ValueError(f"comm_async must be bool but got {comm_async}.")
    if not isinstance(comm_fusion, bool):
        raise ValueError(f"comm_fusion must be bool but got {comm_fusion}.")
    if not isinstance(bucket_size, int) or (bucket_size < 0 and bucket_size != -1):
        raise ValueError(f"bucket_size must be a non-negative integer or -1, but got {bucket_size}.")


def _get_device_from_mesh(mesh: DeviceMesh):
    """Extract and validate the torch device from the device mesh."""
    device = None
    device_type = mesh.device_type
    if device_type not in ("npu", "cuda"):
        raise AssertionError(
            f"hyper_parallel.fully_shard supports devices in [torch.npu, torch.cuda], "
            f"but got '{device_type}'"
        )
    if platform.platform_type == PlatformType.PYTORCH:
        device_handle = platform.get_device_handle(device_type)
        if device_handle is None:
            raise ValueError(
                f"hyper_parallel.fully_shard can't find the device_handle of "
                f"'torch.{device_type}'; check the environment."
            )
        if device_handle.is_available():
            import torch
            device = torch.device(device_handle.current_device())
    else:
        device = device_type
    return device


def _normalize_replicate_params(
    replicate_params: Optional[set[platform.Parameter]],
) -> set[platform.Parameter]:
    """
    Normalize replicate_params for fully_shard.

    Args:
        replicate_params (Optional[set[nn.Parameter]]): Set of parameters to exclude from sharding.

    Returns:
        set[nn.Parameter]: Set of parameters to exclude from sharding.
    """
    if replicate_params is None:
        return set()
    out = set(replicate_params)
    for p in out:
        if not isinstance(p, (platform.Parameter, DTensor)):
            raise TypeError(
                "replicate_params must contain only nn.Parameter or DTensor, "
                f"got {type(p).__name__}."
            )
    return out


def _get_modules_parameters(modules, ignored_params=None):
    """Collect deduplicated parameters from module roots."""
    return get_managed_modules_parameters(modules, ignored_params)


def fully_shard(
    module: Union[platform.Module, List[platform.Module]],
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: bool = True,
    shard_placement_fn: None = None,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
    ignored_params: Optional[set[platform.Parameter]] = None,
    replicate_params: Optional[set[platform.Parameter]] = None,
    comm_fusion: bool = False,
    comm_fusion_zero_copy: Optional[bool] = None,
) -> Union[platform.Module, List[platform.Module]]:
    """
    Apply fully_shard to a module (or list of modules) for distributed training with parameter sharding.

    This interface provides PyTorch-compatible HSDP (Hybrid Sharded Data Parallelism)
    functionality, enabling efficient training of large models by sharding parameters
    across multiple devices. The module is automatically enhanced with distributed
    capabilities including parameter sharding, gradient synchronization, and memory
    management.

    When a list of modules is passed, they are treated as one FSDP unit (parameters
    grouped together). Both PyTorch and MindSpore platforms support list input.

    Parameters:
        module (nn.Module or List[nn.Module]):
            The module(s) to apply fully_shard to. Modified in-place. When a list
            is passed, parameters from all modules are grouped as one FSDP unit.

        mesh (Optional[DeviceMesh], default=None):
            The device mesh defining the process topology for distributed training.
            If None, fully_shard keeps pure-DTensor modules on their original
            distributed layout and only creates a default 1D mesh when local
            parameters need explicit data-parallel/FSDP management.

        reshard_after_forward (bool, default=True):
            Whether to automatically reshard parameters after forward. When True,
            parameters are resharded immediately after they are no longer needed,
            freeing memory for subsequent operations. Set to False if you want to
            keep parameters unsharded for the backward pass or for manual control.

        shard_placement_fn (Callable, default=None):
            A callable that determines how to shard each parameter. The function
            should accept a parameter and return a Shard object specifying the
            sharding dimension, or None to use default sharding (dimension 0).

        mp_policy (MixedPrecisionPolicy, default=MixedPrecisionPolicy()):
            Mixed precision training policy controlling data type conversions.

        offload_policy (OffloadPolicy, default=OffloadPolicy()):
            Memory offload policy for reducing device memory usage.

        ignored_params (Optional[set[nn.Parameter]], default=None):
            Set of parameters to exclude from fully_shard management entirely.
            These parameters are left on the original module as regular parameters,
            are not sharded, and do not participate in fully_shard gradient
            synchronization. Use this for parameters that should remain outside
            the fully_shard lifecycle.

        replicate_params (Optional[set[nn.Parameter]], default=None):
            Set of parameters to keep replicated while still managing them under
            fully_shard. These parameters are not sharded, but their gradients
            are still synchronized with DDP-style all-reduce over the current
            fully_shard communication domain. This differs from ``ignored_params``,
            which skips fully_shard management and gradient synchronization
            entirely for the selected parameters.

        comm_fusion (bool, default=False):
            Whether to enable all_gather fusion and reduce_scatter fusion.

        comm_fusion_zero_copy (Optional[bool], default=None):
            Whether to allow the experimental zero-copy path for
            ``comm_fusion``. When set to ``None``, fully_shard uses a backend-specific
            default:
            - PyTorch: enabled automatically when ``comm_fusion=True``
            - MindSpore: disabled automatically even when ``comm_fusion=True``
            When enabled, fully_shard may rebase sharded local parameter storage
            into one shared flat buffer so fused all-gather can read directly from
            contiguous memory. This path depends on optimizer compatibility with
            view-backed parameters.

    Returns:
        nn.Module or List[nn.Module]: The input module(s) with HSDP capabilities added.
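
    Example (illustrative sketch; assumes a PyTorch backend with an initialized
    process group, and that ``model``, ``model.layers`` and ``world_size`` are
    placeholder names, not part of this API)::

        mesh = init_device_mesh(device_type="npu", mesh_shape=(world_size,))
        for layer in model.layers:
            fully_shard(layer, mesh=mesh, reshard_after_forward=True)
        fully_shard(model, mesh=mesh)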

    """
    platform_type = platform.platform_type
    _validate_module_for_fully_shard(module, platform_type)
    if platform_type == PlatformType.MINDSPORE:
        from hyper_parallel.platform.mindspore.autograd_compat import enable_mindspore_backward_compat

        enable_mindspore_backward_compat()

    arg_module = module
    if isinstance(module, list):
        modules = tuple(_get_root_modules(module))
    else:
        modules = (module,)

    for mod in modules:
        _extend_module_with_hsdp_interface(mod)

    params = _get_modules_parameters(modules, ignored_params)
    has_dtensor_param = any(is_dtensor_managed_param(param) for param in params)
    replicate_params = _normalize_replicate_params(replicate_params)

    if mesh is None and not has_dtensor_param:
        mesh = init_device_mesh(device_type="npu", mesh_shape=(platform.get_world_size(),))
    if mesh is not None:
        device = _get_device_from_mesh(mesh)
    else:
        compat_mesh = next(
            (dtensor_mesh for param in params if (dtensor_mesh := get_dtensor_managed_mesh(param)) is not None),
            None,
        )
        if compat_mesh is None:
            raise ValueError("fully_shard could not resolve a DTensor mesh for compatibility mode.")
        device = _get_device_from_mesh(compat_mesh)

    init_modules = modules
    modules[0].hsdp_init(
        platform_type,
        init_modules,
        mesh,
        reshard_after_forward,
        shard_placement_fn,
        mp_policy,
        offload_policy,
        ignored_params,
        replicate_params,
        device,
        comm_fusion,
        comm_fusion_zero_copy,
    )
    # Share the same scheduler handle with other roots so mods[i].unshard()/prefetch work
    if len(modules) > 1:
        for mod in modules[1:]:
            mod.hsdp_scheduler = modules[0].hsdp_scheduler
    return arg_module


def get_model_state_dict(model, *, options=None):
    """Get model state dict with platform-specific implementation.

    Delegates to the platform-specific implementation at runtime.
    Users import from here instead of platform internals.
    """
    return platform.get_model_state_dict(model, options=options)


def hsdp_sync_stream():
    """Wait for hsdp gradient handle to be completed."""
    platform.wait_grad_handle()
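

# Illustrative end-to-end sketch (assumes a PyTorch backend, an initialized
# process group, and placeholder names ``model``, ``optimizer`` and
# ``dataloader`` that are not part of this module):
#
#     fully_shard(model)               # wrap with the default 1D mesh
#     for batch in dataloader:
#         loss = model(batch).sum()
#         loss.backward()
#         hsdp_sync_stream()           # wait for the hsdp gradient handle
#         optimizer.step()
#         optimizer.zero_grad()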