Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/fully_shard/scheduler.py: 46%
112 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore HSDP scheduler"""
import mindspore as ms
from mindspore.common.api import _pynative_executor
from mindspore.utils._pytree import tree_flatten, tree_unflatten
from hyper_parallel.core.fully_shard.hsdp_scheduler import HSDPSchedulerV2, FSDPSchedulerState
from hyper_parallel.core.fully_shard.hsdp_utils import get_dtensor_managed_mesh
from hyper_parallel.platform.mindspore.fully_shard.hook_function import PostBackwardFunction
from hyper_parallel.platform.mindspore.fully_shard.param_group import get_comm_ctx
from hyper_parallel.platform.mindspore.fully_shard.state import MindSporeHSDPStateV2
from hyper_parallel.core.fully_shard.utils import FSDPMeshInfo, HSDPMeshInfo, DDPMeshInfo
from hyper_parallel.platform import get_platform


class MindSporeHSDPSchedulerV2(HSDPSchedulerV2):
    """MindSpore HSDP scheduler.

    List-unit grouped forward hooks use :class:`HSDPSchedulerV2` defaults for
    ``_grouped_forward_pre_hook_skip`` / ``_grouped_forward_post_hook_skip`` (no overrides here).
    """
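
    # Added note (inferred from the method bodies below, not from upstream docs): the
    # class-level flag ``HSDPSchedulerV2.root_bp_state`` is shared by all scheduler
    # instances. ``_backward_pre_hook`` sets it when the outermost backward starts and
    # ``_root_backward_hook`` clears it; a forward pass observed while it is set is
    # treated as activation recomputation, so ``_forward_hook`` only restores the
    # forward-prefetch bookkeeping instead of running the normal HSDP forward hook.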

    def zero_grad(self) -> None:
        """Zero grad."""
        self.hsdp_state.zero_grad()

    def _register_hooks(self):
        """Register hooks."""
        self._register_forward_backward_hooks()

    def _init_platform(self):
        """Initialize the platform."""
        from hyper_parallel.platform.mindspore.platform import MindSporePlatform
        self.platform = get_platform()
        if not isinstance(self.platform, MindSporePlatform):
            raise ValueError(
                f"MindSporeHSDPSchedulerV2 expects MindSporePlatform, but got type: {type(self.platform)}"
            )

    def _new_cell_state(self):
        """Create a new cell state for MindSpore."""
        params = self._get_managed_params()
        if self.mesh is None:
            # Compatibility mode: derive the mesh from the DTensor parameters and require
            # that every DTensor parameter shares the same mesh.
            compat_meshes = [get_dtensor_managed_mesh(param) for param in params]
            compat_meshes = [mesh for mesh in compat_meshes if mesh is not None]
            compat_mesh = compat_meshes[0] if compat_meshes else None
            if compat_mesh is None:
                raise ValueError(
                    "Cannot build fully_shard compatibility mesh_info "
                    "without a DTensor parameter mesh."
                )
            compat_mesh_hash = compat_mesh.to_hash()
            for param_mesh in compat_meshes[1:]:
                if param_mesh.to_hash() != compat_mesh_hash:
                    raise ValueError(
                        "fully_shard compatibility mode requires all DTensor parameters to share the same mesh."
                    )
            self.mesh_info = DDPMeshInfo(mesh=compat_mesh, replicate_mesh_dim=0)
        elif self.mesh.ndim == 1:
            # 1D mesh: pure FSDP sharding along the single mesh dimension.
            self.mesh_info = FSDPMeshInfo(mesh=self.mesh, shard_mesh_dim=0)
        elif self.mesh.ndim == 2:
            # 2D mesh: HSDP, replicating on mesh dim 0 and sharding on mesh dim 1.
            self.mesh_info = HSDPMeshInfo(mesh=self.mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
        else:
            raise ValueError(
                "fully_shard only supports explicit 1D DP/FSDP meshes or 2D HSDP meshes. "
                f"Got mesh.ndim={self.mesh.ndim}."
            )
        self.hsdp_state = MindSporeHSDPStateV2(
            self.modules, self.mesh_info, self.config, self.platform, self.device
        )

    def _register_post_backward_hook(self, args, kwargs):
        """Register backward hook using backward function."""
        if not _pynative_executor.enable_grad():
            return args, kwargs
        # Flatten args/kwargs into one leaf list so PostBackwardFunction can wrap every input.
        args_list, args_spec = tree_flatten(args)
        kwargs_list, kwargs_spec = tree_flatten(kwargs)
        args_kwargs_list = list(args_list) + list(kwargs_list)
        if not any(
            isinstance(obj, ms.Tensor) and getattr(obj, "requires_grad", False)
            for obj in args_kwargs_list
        ):
            return args, kwargs
        processed_list = list(PostBackwardFunction.apply(self, *args_kwargs_list))
        # Restore requires_grad flags on the processed tensors, then rebuild the original structures.
        for idx, (orig_obj, processed_obj) in enumerate(zip(args_kwargs_list, processed_list)):
            if isinstance(orig_obj, ms.Tensor) and isinstance(processed_obj, ms.Tensor):
                try:
                    processed_obj.requires_grad = bool(getattr(orig_obj, "requires_grad", False))
                except (AttributeError, RuntimeError, TypeError, ValueError):
                    pass
            processed_list[idx] = processed_obj
        args_kwargs_list = processed_list
        args_list = args_kwargs_list[: len(args_list)]
        kwargs_list = args_kwargs_list[len(args_list):]
        args = tree_unflatten(args_spec, args_list)
        kwargs = tree_unflatten(kwargs_spec, kwargs_list)
        return args, kwargs
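
    # Added note: ``tree_flatten`` returns ``(leaves, spec)`` and, as used in this module,
    # ``tree_unflatten`` takes ``(spec, leaves)``. A minimal sketch of the round trip used
    # above (illustrative values only):
    #     leaves, spec = tree_flatten({"x": t1, "y": (t2, t3)})   # leaves == [t1, t2, t3]
    #     rebuilt = tree_unflatten(spec, leaves)                  # same nested structure
    # ``PostBackwardFunction.apply`` receives every leaf, presumably so its backward can
    # fire once the gradients for these inputs have been produced.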

    def _forward_pre_hook(self, cell, args, kwargs):
        """Execute the forward pre hook and register the post-backward hook."""
        args, kwargs = self._hsdp_forward_pre_hook(cell, args, kwargs)
        return self._register_post_backward_hook(args, kwargs)

    def _register_backward_pre_hook(self, outputs):
        """Register output hook to trigger backward pre hook."""
        flat_outputs, _ = tree_flatten(outputs)
        for output in flat_outputs:
            if isinstance(output, ms.Tensor) and output._requires_grad:
                output.register_hook(self._backward_pre_hook)
        return outputs

    def _forward_hook(self, cell, inputs, outputs):
        """Execute forward hook."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return
        self._register_backward_pre_hook(outputs)
        if HSDPSchedulerV2.root_bp_state:
            # A forward pass while backward is active indicates recomputation;
            # only restore the forward-prefetch state.
            self._restore_forward_prefetch_after_recompute()
            return
        return self._hsdp_forward_hook(cell, inputs, outputs)

    # pylint: disable=W0212
    def _backward_pre_hook(self, grad):
        """Execute backward pre hook."""
        # Queue the root finalize callback so it runs once when the whole backward pass completes.
        _pynative_executor.queue_backward_final_callback(self._root_backward_hook)
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return grad
        HSDPSchedulerV2.root_bp_state = True
        self._hsdp_backward_pre_hook(self.cell, None)
        return grad

    def _root_backward_hook(self):
        """Root backward hook: finalize the outermost backward and clear recompute state."""
        apply_final_reduce = self.scheduler_state != FSDPSchedulerState.BACKWARD
        self._backward_hook()
        if apply_final_reduce:
            comm_ctx = get_comm_ctx()
            if comm_ctx.all_reduce_param_group is not None:
                comm_ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
                comm_ctx.all_reduce_param_group = None
            if comm_ctx.pre_param_group is not None:
                comm_ctx.pre_param_group.apply_fusion_reduced_grad()
                comm_ctx.pre_param_group = None
            self.hsdp_state.reduce_params()
            self.hsdp_state._finish_ignored_allreduce()
        HSDPSchedulerV2.root_bp_state = False
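
    # Added note: the finalization above only runs when this is the outermost backward
    # (``scheduler_state != BACKWARD`` on entry). The ordering follows the communication
    # context returned by ``get_comm_ctx()``: drain any pending all-reduce param group and
    # apply its gradients, apply the fused reduced gradients of the pending pre param
    # group, then reduce the remaining managed parameters and finish what appears to be
    # the all-reduce for ignored parameters.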

    def _backward_hook(self):
        """Execute backward hook."""
        if self.scheduler_state == FSDPSchedulerState.BACKWARD:
            return
        self._hsdp_backward_hook(self.cell, None, None)

    def _register_forward_backward_hooks(self):
        """Register forward and backward hooks on all managed modules."""
        if self._fsdp_group_post_pending is None:
            for mod in self.modules:
                mod.register_forward_pre_hook(self._forward_pre_hook, with_kwargs=True)
                mod.register_forward_hook(self._forward_hook)
            return
        for mod in self.modules:
            mod.register_forward_pre_hook(self._grouped_forward_pre_hook, with_kwargs=True)
            mod.register_forward_hook(self._make_grouped_forward_post_hook(mod))
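

# Added commentary (rough sketch of the runtime flow, inferred from the hooks above):
# during a training step, each managed module's forward triggers ``_forward_pre_hook``
# (HSDP pre-forward work plus ``PostBackwardFunction`` registration on tensor inputs)
# and ``_forward_hook`` (which attaches ``_backward_pre_hook`` to tensor outputs).
# When backward starts, ``_backward_pre_hook`` queues ``_root_backward_hook`` as a
# backward-final callback and sets ``root_bp_state``; once autograd finishes, the
# queued callback performs the final gradient reductions and clears that state.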