Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / fully_shard / api.py: 42%
297 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""hybrid shard data parallel interface"""
16import warnings
17from collections import namedtuple
18from typing import Any, List, Mapping, cast, Optional, Union
20from hyper_parallel.platform.platform import PlatformType
21from hyper_parallel.core.fully_shard.utils import MixedPrecisionPolicy, OffloadPolicy
22from hyper_parallel import DeviceMesh, init_device_mesh
23from hyper_parallel.platform import get_platform
24from hyper_parallel.core.dtensor.dtensor import DTensor, distribute_tensor
25from hyper_parallel.core.fully_shard.hsdp_utils import (
26 get_managed_modules_parameters,
27 is_dtensor_managed_param,
28 get_dtensor_managed_mesh,
29)
# Backend abstraction (MindSpore or PyTorch), resolved once at import time.
platform = get_platform()
# Cache: original module class -> dynamically created HSDP-extended subclass,
# so repeated fully_shard calls on the same class reuse one generated type.
origin_class_to_extend_class = {}
def _resolve_comm_fusion_zero_copy_default(
    platform_type: PlatformType,
    comm_fusion: bool,
    comm_fusion_zero_copy: Optional[bool],
) -> bool:
    """Resolve backend-specific default for the comm_fusion zero-copy path.

    Args:
        platform_type (PlatformType): Active backend.
        comm_fusion (bool): Whether communication fusion is enabled.
        comm_fusion_zero_copy (Optional[bool]): Explicit user override;
            ``None`` means "use the backend default".

    Returns:
        bool: ``True`` when the zero-copy path should be used.
    """
    # An explicit user choice always wins over backend defaults.
    if comm_fusion_zero_copy is not None:
        return comm_fusion_zero_copy
    # Zero-copy only applies when comm fusion is on, and of the supported
    # backends only PyTorch defaults it to enabled (MindSpore stays off).
    # This collapses the original duplicated `return False` branches.
    return comm_fusion and platform_type == PlatformType.PYTORCH
def _check_strict_keys(
    module: platform.Module, state_dict: Mapping[str, Any],
) -> None:
    """Raise ``RuntimeError`` if *state_dict* keys do not match *module*."""
    expected_keys = set(module.state_dict().keys())
    provided_keys = set(state_dict.keys())
    # Build one message fragment per mismatch category, in a fixed order
    # (missing first, then unexpected), mirroring nn.Module semantics.
    problems: list[str] = []
    for label, keys in (
        ("Missing", expected_keys - provided_keys),
        ("Unexpected", provided_keys - expected_keys),
    ):
        if keys:
            joined = ", ".join(repr(k) for k in sorted(keys))
            problems.append(label + " key(s): " + joined)
    if problems:
        raise RuntimeError(
            f"Error(s) in loading state_dict for "
            f"{module.__class__.__name__}:\n\t"
            + "\n\t".join(problems)
        )
def _resolve_local_tensor(
    key: str, val: platform.Tensor, target: DTensor,
) -> platform.Tensor:
    """Return the local shard tensor to be loaded into *target*."""
    # Hyper DTensor input: its local shard is already the right piece.
    if isinstance(val, DTensor):
        return val.to_local()
    val_shape = tuple(val.shape)
    # Plain tensor that already matches the local shard: use it directly.
    if val_shape == tuple(target.local_shape):
        return val
    # Plain tensor with the full global shape: shard it onto the target's
    # mesh/placements first, then take this rank's local piece.
    if val_shape == tuple(target.shape):
        placements = (
            target.layout.alias_placements if target.layout else target.placements
        )
        sharded = distribute_tensor(val, target.device_mesh, placements)
        return sharded.to_local()
    raise ValueError(
        f"load '{key}': plain tensor shape {val_shape} "
        f"matches neither local shard {tuple(target.local_shape)} "
        f"nor global {tuple(target.shape)}."
    )
102class _UnshardHandle:
103 """Unshard handle for user call HSDPModule.unshard(async_op=True)"""
104 def __init__(self, hsdp_state=None):
105 """
106 Initialize an async unshard handle.
108 Args:
109 hsdp_state (HSDPState, optional): The state to wait on. None means a no-op handle.
110 """
111 self._hsdp_state = hsdp_state
113 def wait(self):
114 """Block until the async unshard operation completes."""
115 if self._hsdp_state is not None:
116 self._hsdp_state.wait_for_unshard()
117 self._hsdp_state = None
class HSDPModule:
    """
    The hsdp block of neural networks with hsdp interface.

    Instances are normally produced by re-basing an existing module class
    onto ``HSDPModule`` (see ``_extend_module_with_hsdp_interface``), in
    which case ``__init__`` below never runs; methods therefore guard with
    ``hasattr(self, "hsdp_scheduler")`` before using the scheduler.

    Supported Platforms:
        ``MindSpore`` ``torch``
    """

    def __init__(self):
        """Initialize HSDPModule."""
        self.hsdp_scheduler = None  # Initialized in hsdp_init()

    # pylint: disable=C0415
    def hsdp_init(self, platform_type, module, mesh, reshard_after_forward,
                  shard_placement_fn, mp_policy, offload_policy, ignored_params, replicate_params, device,
                  comm_fusion, comm_fusion_zero_copy: Optional[bool] = None):
        """init hsdp2 scheduler.

        Imports are deferred into the branches so that only the active
        backend's scheduler module is loaded.
        """
        if platform_type == PlatformType.MINDSPORE:
            from hyper_parallel.platform.mindspore.fully_shard.scheduler import MindSporeHSDPSchedulerV2
            scheduler_class = MindSporeHSDPSchedulerV2
        else:
            from hyper_parallel.platform.torch.fully_shard.scheduler import TorchHSDPSchedulerV2
            scheduler_class = TorchHSDPSchedulerV2

        # None -> backend default (PyTorch: on with comm_fusion; MindSpore: off).
        resolved_comm_fusion_zero_copy = _resolve_comm_fusion_zero_copy_default(
            platform_type,
            comm_fusion,
            comm_fusion_zero_copy,
        )

        self.hsdp_scheduler = scheduler_class(module,
                                              mesh,
                                              reshard_after_forward,
                                              shard_placement_fn,
                                              mp_policy,
                                              offload_policy,
                                              ignored_params,
                                              replicate_params,
                                              device,
                                              comm_fusion,
                                              resolved_comm_fusion_zero_copy,
                                              )

    def set_requires_gradient_sync(self, requires_grad_sync):
        r"""
        set requires grad sync flag.

        Args:
            requires_grad_sync(bool): requires_grad_sync is used to control gradient sync process.

        Raises:
            ValueError: If `requires_grad_sync` is not bool.
        """
        if not isinstance(requires_grad_sync, bool):
            raise ValueError(f"requires_grad_sync must be bool but got {requires_grad_sync}.")
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")
        # Propagate the flag to every HSDP-managed submodule (including self).
        for _, module in platform.get_cells_and_names(self):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_requires_grad_sync(requires_grad_sync)

    def zero_grad(self):
        """zero accumulated grads"""
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")
        if platform.platform_type == PlatformType.PYTORCH:
            # PyTorch modules already implement zero_grad; defer to the
            # original class through the extended class's MRO.
            return super().zero_grad()
        for _, module in platform.get_cells_and_names(self):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.zero_grad()

    def set_modules_to_forward_prefetch(self, modules):
        """set forward prefetch module list to prefetch all gather for unsharded parameters"""
        if not isinstance(modules, (tuple, list)):
            raise ValueError("modules must be HSDPModule list")
        for module in modules:
            if not isinstance(module, HSDPModule):
                raise ValueError(f"modules must be HSDPModule list but got {type(module)} in list.")
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")
        self.hsdp_scheduler.set_forward_prefetch_cells(modules)

    def set_modules_to_backward_prefetch(self, modules):
        """set backward prefetch module list to prefetch all gather for unsharded parameters"""
        if not isinstance(modules, (tuple, list)):
            raise ValueError("modules must be HSDPModule list")
        for module in modules:
            if not isinstance(module, HSDPModule):
                raise ValueError(f"modules must be HSDPModule list but got {type(module)} in list.")
        if not hasattr(self, "hsdp_scheduler"):
            # Message unified with the sibling methods (previously said
            # "call fully_shard interface first.").
            raise ValueError("call hsdp interface first.")
        self.hsdp_scheduler.set_backward_prefetch_cells(modules)

    def reshard(self) -> None:
        """reshard all sharded parameters"""
        if not self.hsdp_scheduler:
            raise ValueError("hsdp_scheduler is None")
        hsdp_state = self.hsdp_scheduler.hsdp_state
        if hsdp_state:
            hsdp_state.shard()

    def unshard(self, async_op: bool = False):
        """unshard all sharded parameters

        Args:
            async_op (bool): When True, return a handle the caller can
                ``wait()`` on instead of blocking here.

        Returns:
            Optional[_UnshardHandle]: A handle when ``async_op`` is True
            (a no-op handle if there is no hsdp_state, so callers may
            always ``wait()`` on the result), otherwise None.

        Raises:
            ValueError: If `async_op` is not bool or the scheduler is unset.
        """
        if not isinstance(async_op, bool):
            raise ValueError(f"async_op should be a bool, got {type(async_op)}")
        if not self.hsdp_scheduler:
            raise ValueError("hsdp_scheduler is None")
        hsdp_state = self.hsdp_scheduler.hsdp_state
        if hsdp_state:
            hsdp_state.unshard(async_op)  # pylint: disable=too-many-function-args
        if async_op:
            # Fix: previously returned None when hsdp_state was absent,
            # breaking callers that do `unshard(async_op=True).wait()`.
            # _UnshardHandle(None) is documented as a no-op handle.
            return _UnshardHandle(hsdp_state=hsdp_state if hsdp_state else None)
        return None

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
        assign: bool = False,
    ):
        """
        Load state dict by copying directly into local shards.

        Bypasses ``super().load_state_dict()`` because the standard PyTorch
        implementation triggers ``copy_`` through the DTensor dispatcher, which
        is not registered in the hyper-parallel layout system.

        Each value in ``state_dict`` is dispatched by type:
        - hyper DTensor: extract local shard and copy directly.
        - plain Tensor whose shape == local shard shape: copy as-is.
        - plain Tensor whose shape == global shape: distribute via
          ``distribute_tensor``, then copy the local shard.

        Args:
            state_dict (Mapping[str, Any]): Fully-qualified parameter/buffer
                names mapped to tensors (DTensor or plain Tensor).
            strict (bool): If ``True`` (default), missing or unexpected keys
                raise ``RuntimeError``, matching ``nn.Module.load_state_dict``
                semantics.
            assign (bool): Accepted for API compatibility with
                ``nn.Module.load_state_dict(assign=True)`` but currently
                ignored; HSDP always copies into existing DTensor storage.

        Raises:
            RuntimeError: When ``strict`` is ``True`` and keys do not match.
            ValueError: When a plain tensor shape matches neither the local
                shard shape nor the global shape of the target DTensor.
        """
        if assign:
            warnings.warn(
                "HSDPModule.load_state_dict: assign=True is ignored; "
                "HSDP always copies into existing DTensor parameters.",
                stacklevel=2,
            )
        self_module = cast(platform.Module, self)

        # Map fully-qualified names to the live parameter/buffer objects.
        target_map: dict[str, platform.Tensor] = {}
        for name, p in platform.parameters_dict(self_module):
            target_map[name] = p
        for name, b in self_module.named_buffers():
            target_map[name] = b

        if strict:
            _check_strict_keys(self_module, state_dict)

        with platform.no_grad():
            for key, val in state_dict.items():
                target = target_map.get(key)
                if target is None:
                    # Non-strict mode: silently skip unknown keys.
                    continue

                if isinstance(target, DTensor):
                    val = _resolve_local_tensor(key, val, target)
                platform.load_into_param(target, val)

        # Trigger load_state_dict post-hooks so that HSDP internal
        # bookkeeping (e.g. _sharded_param_data) stays in sync.
        # Pass an IncompatibleKeys with the same attribute names as PyTorch
        # so external hooks can safely read .missing_keys/.unexpected_keys.
        _IK = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])
        incompatible_keys = _IK([], [])
        for _, module in platform.get_cells_and_names(self_module):
            hooks = module._load_state_dict_post_hooks  # pylint: disable=protected-access
            for hook in hooks.values():
                hook(module, incompatible_keys)

    def set_is_last_backward(self, is_last_backward: bool):
        """set is_last_backward flag

        Raises:
            ValueError: If the hsdp scheduler has not been initialized
                (guard added for consistency with reshard()/unshard()).
        """
        if not self.hsdp_scheduler:
            raise ValueError("hsdp_scheduler is None")
        self.hsdp_scheduler.scheduler_ctx.is_last_backward = is_last_backward

    def set_requires_all_reduce(self, requires_all_reduce: bool, *, recurse: bool = True) -> None:
        """set requires_all_reduce flag"""
        if not isinstance(requires_all_reduce, bool):
            raise ValueError(
                f"requires_all_reduce should be a bool, got {type(requires_all_reduce)}"
            )
        if not recurse:
            raise NotImplementedError(
                "Currently impl is equal to recurse=True, "
                "need support module_param mapping."
            )
        self_module = cast(platform.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_requires_all_reduce(requires_all_reduce)

    def set_reshard_after_forward(self, reshard_after_forward: bool, recurse: bool = True) -> None:
        """set reshard_after_forward flag"""
        if not isinstance(reshard_after_forward, bool):
            raise ValueError(
                f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}"
            )
        if not recurse:
            raise NotImplementedError(
                "Currently impl is equal to recurse=True, "
                "need support module_param mapping."
            )
        self_module = cast(platform.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_reshard_after_forward(reshard_after_forward)

    def set_reshard_after_backward(self, reshard_after_backward: bool, recurse: bool = True) -> None:
        """set reshard_after_backward flag"""
        if not isinstance(reshard_after_backward, bool):
            raise ValueError(
                f"reshard_after_backward should be a bool, got {type(reshard_after_backward)}"
            )
        if not recurse:
            raise NotImplementedError(
                "Currently impl is equal to recurse=True, "
                "need support module_param mapping."
            )
        self_module = cast(platform.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_reshard_after_backward(reshard_after_backward)

    def set_reduce_op_type(self, reduce_op_type) -> None:
        """
        Set reduce_op_type for all gradient reductions in fully_shard.

        Supports ``"avg"`` and ``"sum"``. Local-parameter FSDP/HSDP keeps the
        historical ``"avg"`` default, while DTensor-based paths default to ``"sum"``.
        """
        if hsdp_state := self.hsdp_scheduler.hsdp_state:
            hsdp_state.set_reduce_op_type(reduce_op_type)
def _extend_module_with_hsdp_interface(module):
    """Dynamically extend module's class to inherit from HSDPModule, adding HSDP capabilities."""
    base = module.__class__
    extended = origin_class_to_extend_class.get(base, None)
    if extended is None:
        # Create (and memoize) a subclass placing HSDPModule first in the
        # MRO so its methods shadow same-named ones on the original class.
        extended = type(f"HSDP{base.__name__}", (HSDPModule, base), {})
        origin_class_to_extend_class[base] = extended
    module.__class__ = extended
def _get_root_modules(modules: List[platform.Module]) -> List[platform.Module]:
    """
    Returns the modules in ``modules`` that are root modules (i.e. parent-less)
    with respect to the set ``modules``. In other words, these are the modules
    in ``modules`` that are not the child of any other module in ``modules``.

    Aligned with PyTorch torch.distributed.utils._get_root_modules.
    """

    def _descendants(mod):
        # Both traversals include the module itself; harmless because the
        # identity check below skips self-containment.
        if platform.platform_type == PlatformType.MINDSPORE:
            return {child for _, child in mod.cells_and_names()}
        return set(mod.modules())

    submodule_sets: dict[platform.Module, set] = {
        m: _descendants(m) for m in modules
    }
    # A candidate is a root iff no *other* listed module contains it.
    return [
        candidate
        for candidate in modules
        if not any(
            candidate is not owner and candidate in children
            for owner, children in submodule_sets.items()
        )
    ]
def _check_module_valid(platform_type, module):
    """check module valid"""
    # Import the backend base class lazily so only the active framework loads.
    if platform_type == PlatformType.MINDSPORE:
        from mindspore.nn.cell import Cell
        if not isinstance(module, Cell):
            raise ValueError(f"module's type must be nn.cell but got {type(module)}.")
        return
    from torch.nn import Module
    if not isinstance(module, Module):
        raise ValueError(f"module's type must be nn.Module but got {type(module)}.")
def _validate_module_for_fully_shard(
    module: Union[platform.Module, List[platform.Module]], platform_type
) -> None:
    """Validate module(s) for fully_shard. Platform-aware for single module.

    Args:
        module: A single module or a non-empty list of modules.
        platform_type: Active backend, used to pick the expected base class.

    Raises:
        ValueError: If the list is empty, or any entry (or the single module)
            is not an instance of the active backend's module base class.
    """
    if isinstance(module, list):
        if not module:  # idiomatic emptiness check (was `len(module) == 0`)
            raise ValueError("fully_shard does not support empty list of modules.")
        for i, m in enumerate(module):
            try:
                _check_module_valid(platform_type, m)
            except ValueError:
                # Re-raise with the offending index; suppress the generic
                # per-module message via `from None`.
                raise ValueError(
                    f"fully_shard expects nn.Module or list[nn.Module], "
                    f"but got list with {type(m).__name__} at index {i}."
                ) from None
    else:
        _check_module_valid(platform_type, module)
def _check_hsdp_input_valid(platform_type, module, shard_size, threshold, optimizer_level, enable_grad_accumulation,
                            grad_scale, reduce_dtype, comm_async, comm_fusion, bucket_size):
    """check hsdp input valid"""
    _check_module_valid(platform_type, module)
    # shard_size: positive int; -1 is also accepted as a sentinel value.
    if not (isinstance(shard_size, int) and (shard_size > 0 or shard_size == -1)):
        raise ValueError(f"shard_size must be a positive integer, but got {shard_size}.")
    if not (isinstance(threshold, int) and threshold >= 0):
        raise ValueError(f"threshold must be a positive integer or 0, but got {threshold}.")
    if optimizer_level not in ("level1", "level2", "level3"):
        raise ValueError(f"Optimizer level should in ['level1', 'level2', 'level3'], but got {optimizer_level}.")
    if not isinstance(enable_grad_accumulation, bool):
        raise ValueError(f"enable_grad_accumulation must be bool but got {enable_grad_accumulation}.")
    if not isinstance(grad_scale, float):
        raise ValueError(f"grad_scale must be float but got {grad_scale}.")
    # reduce_dtype, when given, must be the active backend's dtype object.
    if platform_type == PlatformType.MINDSPORE:
        from mindspore._c_expression.typing import Type
        if reduce_dtype is not None and not isinstance(reduce_dtype, Type):
            raise ValueError(f"reduce_dtype must be mindspore.dtype but got {reduce_dtype}.")
    else:
        import torch
        if reduce_dtype is not None and not isinstance(reduce_dtype, torch.dtype):
            raise ValueError(f"reduce_dtype must be torch.dtype but got {reduce_dtype}.")
    if not isinstance(comm_async, bool):
        raise ValueError(f"comm_async must be bool but got {comm_async}.")
    if not isinstance(comm_fusion, bool):
        raise ValueError(f"comm_fusion must be bool but got {comm_fusion}.")
    # bucket_size: non-negative int; -1 is also accepted as a sentinel value.
    if not (isinstance(bucket_size, int) and (bucket_size >= 0 or bucket_size == -1)):
        raise ValueError(f"bucket_size must be a positive integer or 0, but got {bucket_size}.")
def _get_device_from_mesh(mesh: DeviceMesh):
    """Extract and validate the torch device from the device mesh."""
    device_type = mesh.device_type
    if device_type not in ("npu", "cuda"):
        raise AssertionError(
            f"hyper_parallel.fully_shard support device in [torch.npu, torch.cuda], "
            f"but got '{device_type}'"
        )
    # Non-PyTorch backends carry no torch device object (original behavior:
    # `device` stayed None).
    if platform.platform_type != PlatformType.PYTORCH:
        return None
    device_handle = platform.get_device_handle(device_type)
    if device_handle is None:
        raise ValueError(
            f"hyper_parallel.fully_shard can't find device_handle of "
            f"'torch.{device_type}', check the environment."
        )
    if not device_handle.is_available():
        # Fall back to the bare device-type string when the backend reports
        # no usable device.
        return device_type
    import torch
    return torch.device(device_handle.current_device())
def _normalize_replicate_params(
    replicate_params: Optional[set[platform.Parameter]],
) -> set[platform.Parameter]:
    """
    Normalize replicate_params for fully_shard
    Args:
        replicate_params (Optional[set[nn.Parameter]]): Set of parameters to exclude from sharding.
    Returns:
        set[nn.Parameter]: Set of parameters to exclude from sharding.
    """
    if replicate_params is None:
        return set()
    # Copy defensively so the caller's set is never aliased or mutated.
    normalized = set(replicate_params)
    for candidate in normalized:
        if not isinstance(candidate, (platform.Parameter, DTensor)):
            raise TypeError(
                "replicate_params must contain only nn.Parameter or DTensor, "
                f"got {type(candidate).__name__}."
            )
    return normalized
def _get_modules_parameters(modules, ignored_params=None):
    """Collect deduplicated parameters from module roots.

    Thin wrapper delegating to ``get_managed_modules_parameters``.

    Args:
        modules: Iterable of root modules to collect parameters from.
        ignored_params: Optional set of parameters to exclude.

    Returns:
        Whatever ``get_managed_modules_parameters`` returns — the
        deduplicated managed parameters for these modules.
    """
    return get_managed_modules_parameters(modules, ignored_params)
def fully_shard(
    module: Union[platform.Module, List[platform.Module]],
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: bool = True,
    shard_placement_fn: None = None,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
    ignored_params: Optional[set[platform.Parameter]] = None,
    replicate_params: Optional[set[platform.Parameter]] = None,
    comm_fusion: bool = False,
    comm_fusion_zero_copy: Optional[bool] = None,
) -> Union[platform.Module, List[platform.Module]]:
    """
    Apply fully_shard to a module (or list of modules) for distributed training with parameter sharding.

    This interface provides PyTorch-compatible HSDP (Hybrid Sharded Data Parallelism)
    functionality, enabling efficient training of large models by sharding parameters
    across multiple devices. The module is automatically enhanced with distributed
    capabilities including parameter sharding, gradient synchronization, and memory
    management.

    When a list of modules is passed, they are treated as one FSDP unit (parameters
    grouped together). Both PyTorch and MindSpore platforms support list input.

    Parameters:
        module (nn.Module or List[nn.Module]):
            The module(s) to apply fully_shard to. Modified in-place. When a list
            is passed, parameters from all modules are grouped as one FSDP unit.

        mesh (Optional[DeviceMesh], default=None):
            The device mesh defining the process topology for distributed training.
            If None, fully_shard keeps pure-DTensor modules on their original
            distributed layout and only creates a default 1D mesh when local
            parameters need explicit data-parallel/FSDP management.

        reshard_after_forward (bool, default=True):
            Whether to automatically reshard parameters after forward. When True,
            parameters are resharded immediately after they are no longer needed,
            freeing memory for subsequent operations. Set to False if you want to
            keep parameters unsharded for backward pass or manual control.

        shard_placement_fn (Callable, default=None):
            A callable that determines how to shard each parameter. The function
            should accept a parameter and return a Shard object specifying the
            sharding dimension, or None to use default sharding (dimension 0)

        mp_policy (MixedPrecisionPolicy, default=MixedPrecisionPolicy()):
            Mixed precision training policy controlling data type conversions.

        offload_policy (OffloadPolicy, default=OffloadPolicy()):
            Memory offload policy for reducing device memory usage.

        ignored_params (Optional[set[nn.Parameter]], default=None):
            Set of parameters to exclude from fully_shard management entirely.
            These parameters are left on the original module as regular parameters,
            are not sharded, and do not participate in fully_shard gradient
            synchronization. Use this for parameters that should remain outside
            the fully_shard lifecycle.

        comm_fusion (bool, default=False):
            Whether to enable all_gather fusion and reduce_scatter fusion.

        replicate_params (Optional[set[nn.Parameter]], default=None):
            Set of parameters to keep replicated while still managing them under
            fully_shard. These parameters are not sharded, but their gradients
            are still synchronized with DDP-style all-reduce over the current
            fully_shard communication domain. This differs from ``ignored_params``,
            which skips fully_shard management and gradient synchronization
            entirely for the selected parameters.

        comm_fusion_zero_copy (Optional[bool], default=None):
            Whether to allow the experimental zero-copy path for
            ``comm_fusion``. When set to ``None``, fully_shard uses a backend-specific
            default:
            - PyTorch: enabled automatically when ``comm_fusion=True``
            - MindSpore: disabled automatically even when ``comm_fusion=True``
            When enabled, fully_shard may rebase sharded local parameter storage
            into one shared flat buffer so fused all-gather can read directly from
            contiguous memory. This path depends on optimizer compatibility with
            view-backed parameters.

    Returns:
        nn.Module or List[nn.Module]: The input module(s) with HSDP capabilities added.
    """
    platform_type = platform.platform_type
    _validate_module_for_fully_shard(module, platform_type)
    if platform_type == PlatformType.MINDSPORE:
        from hyper_parallel.platform.mindspore.autograd_compat import enable_mindspore_backward_compat

        enable_mindspore_backward_compat()

    # Keep the original argument so the same shape (single module or list)
    # is returned to the caller.
    arg_module = module
    if isinstance(module, list):
        # Only parent-less modules are managed; children are reached through
        # their roots.
        modules = tuple(_get_root_modules(module))
    else:
        modules = (module,)

    # Re-base each root's class onto HSDPModule so the HSDP methods appear.
    for mod in modules:
        _extend_module_with_hsdp_interface(mod)

    params = _get_modules_parameters(modules, ignored_params)
    has_dtensor_param = any(is_dtensor_managed_param(param) for param in params)
    replicate_params = _normalize_replicate_params(replicate_params)

    # Only local (non-DTensor) parameters need a default 1D world mesh;
    # pure-DTensor modules keep their existing layout (compatibility mode).
    if mesh is None and not has_dtensor_param:
        mesh = init_device_mesh(device_type="npu", mesh_shape=(platform.get_world_size(),))
    if mesh is not None:
        device = _get_device_from_mesh(mesh)
    else:
        # Compatibility mode: borrow the mesh of the first DTensor-managed
        # parameter to resolve the device.
        compat_mesh = next(
            (dtensor_mesh for param in params if (dtensor_mesh := get_dtensor_managed_mesh(param)) is not None),
            None,
        )
        if compat_mesh is None:
            raise ValueError("fully_shard could not resolve a DTensor mesh for compatibility mode.")
        device = _get_device_from_mesh(compat_mesh)

    # One scheduler is built on the first root and manages all roots.
    init_modules = modules
    modules[0].hsdp_init(
        platform_type,
        init_modules,
        mesh,
        reshard_after_forward,
        shard_placement_fn,
        mp_policy,
        offload_policy,
        ignored_params,
        replicate_params,
        device,
        comm_fusion,
        comm_fusion_zero_copy,
    )
    # Share the same scheduler handle with other roots so mods[i].unshard()/prefetch work
    if len(modules) > 1:
        for mod in modules[1:]:
            mod.hsdp_scheduler = modules[0].hsdp_scheduler
    return arg_module
def get_model_state_dict(model, *, options=None):
    """Get model state dict with platform-specific implementation.

    Delegates to the platform-specific implementation at runtime.
    Users import from here instead of platform internals.

    Args:
        model: The module whose state dict should be gathered.
        options: Backend-specific options forwarded unchanged
            (semantics defined by the platform implementation).

    Returns:
        Whatever ``platform.get_model_state_dict`` returns for this backend.
    """
    return platform.get_model_state_dict(model, options=options)
def hsdp_sync_stream():
    """Wait for hsdp gradient handle to be completed.

    Thin delegate to ``platform.wait_grad_handle()``; blocks until the
    backend reports the pending gradient communication as finished.
    """
    platform.wait_grad_handle()