Coverage for hyper_parallel/platform/mindspore/hsdp/scheduler.py: 50%
109 statements
coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore HSDP scheduler"""
import warnings
from pathlib import Path
from importlib import resources
import mindspore as ms
from mindspore import ops
from mindspore import jit_class, nn
from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel
from hyper_parallel.core.hsdp.hsdp_scheduler import HSDPScheduler
from hyper_parallel.platform import get_platform
from hyper_parallel.platform.mindspore.platform_graph import MindSporeGraphPlatform
from hyper_parallel.platform.mindspore.hsdp.state import MindSporeHSDPState
from hyper_parallel.platform.mindspore.hsdp.grad_hook import MindSporeHSDPGradHook
from hyper_parallel.platform.mindspore.hsdp.async_grad_hook import MindSporeHSDPAsyncGradHook


@jit_class
class MindSporeHSDPScheduler(HSDPScheduler):
    """MindSporeHSDPScheduler implements the configured HSDP optimizer (sharding) level on MindSpore."""

    HYPER_PARALLEL_MINDSPORE_SO = "libhyper_parallel_mindspore.so"
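    # This shared object is the custom graph-pass plugin shipped inside the
    # package; _register_custom_passes() below loads it so the
    # "DuplicatePrimOnMultiUsersPass" graph pass becomes available.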

    def _init_platform(self):
        """Initialize the platform."""
        if self.config.use_eager_hook:
            self.platform = get_platform()
        else:
            self.platform = MindSporeGraphPlatform()
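    # Note: with eager hooks enabled the scheduler reuses the default platform
    # returned by get_platform(); otherwise it builds a dedicated
    # MindSporeGraphPlatform, which pairs with the graph-mode hook
    # registration path in _register_hooks() below.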

    def _new_cell_state(self):
        """Create a new cell state."""
        # TODO: why reset use_eager_hook here?
        # self.config.use_eager_hook = ms.get_context("mode") != ms.GRAPH_MODE
        self.hsdp_state = MindSporeHSDPState(self.cell, self.config, self.platform)

    def _new_grad_hook(self):
        """Create a new grad hook."""
        if self.config.use_eager_hook and self.config.comm_async:
            self.grad_hook = MindSporeHSDPAsyncGradHook(self.config, self.platform)
        else:
            self.grad_hook = MindSporeHSDPGradHook(self.config, self.platform)
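    # The asynchronous grad hook is only used when eager hooks and async
    # communication are both enabled; in every other configuration the
    # scheduler falls back to the synchronous MindSporeHSDPGradHook.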

    def _register_forward_backward_hooks(self):
        """Register module forward and backward hooks."""
        self.cell.register_forward_pre_hook(self._hsdp_forward_pre_hook)
        self.cell.register_backward_pre_hook(self._hsdp_backward_pre_hook)
        if self.shard_level == OptimizerLevel.SHARD_OPT_GRAD_PARAM:
            self.cell.register_forward_hook(self._hsdp_forward_hook)
            self.cell.register_backward_hook(self._hsdp_backward_hook)
        elif self.requires_acc_grad:
            self.cell.register_backward_hook(self._hsdp_acc_backward_hook)
        else:
            self.cell.register_backward_hook(self._hsdp_backward_hook)
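    # The pre-hooks are always installed; the post-hooks depend on the shard
    # level: SHARD_OPT_GRAD_PARAM additionally needs a forward post-hook,
    # while gradient accumulation swaps in the accumulating backward hook.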

    def _register_hooks(self):
        """Register hooks."""
        if self.config.use_eager_hook:
            super()._register_hooks()
        else:
            self._register_graph_hook()

    @staticmethod
    def get_pass_library_pass():
        """Safely locate the pass library path (compatible with Python 3.8+)."""
        try:
            # Python 3.9+
            if hasattr(resources, "files"):
                return resources.files(
                    "hyper_parallel.platform.mindspore.custom_pass") / \
                    MindSporeHSDPScheduler.HYPER_PARALLEL_MINDSPORE_SO
            # Python 3.8 fallback
            import pkg_resources  # pylint: disable=C0415
            return Path(pkg_resources.resource_filename(
                "hyper_parallel.platform.mindspore.custom_pass",
                MindSporeHSDPScheduler.HYPER_PARALLEL_MINDSPORE_SO
            ))
        except Exception as e:
            warnings.warn(
                f"Failed to locate MindSpore custom pass library: {e}")
            return None
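    # On Python 3.9+ resources.files() returns an importlib Traversable, which
    # for a normal on-disk install is a pathlib.Path; the .exists() and str()
    # calls in _register_custom_passes() below rely on that, so packages
    # installed in non-filesystem locations (e.g. a zip) may not be handled
    # by this lookup.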

    def _register_custom_passes(self):
        """Register custom graph optimization passes with MindSpore."""
        so_path = self.get_pass_library_pass()
        if so_path and so_path.exists():
            success = ms.graph.register_custom_pass(
                pass_name="DuplicatePrimOnMultiUsersPass",
                plugin_so_path=str(so_path),
                device="cpu",
                pass_type=ms.graph.CustomPassType.FULL_GRAPH)
            if not success:
                print(f"Failed to register MindSpore custom pass from {so_path}.")
            return success

        print(f"Failed to locate MindSpore custom pass library {so_path}.")
        return False
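    # The plugin registered here provides "DuplicatePrimOnMultiUsersPass",
    # which _get_param_forward_hook() below requires at the
    # SHARD_OPT_GRAD_PARAM level; a registration failure there is treated as
    # a hard error rather than a warning.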

    def _get_param_forward_hook(self, hsdp_param):
        """Get param forward hook."""
        if self.shard_level == OptimizerLevel.SHARD_OPT_GRAD_PARAM:
            # pylint: disable=W0212
            allgather = ops._add_attr(self.platform.all_gather_into_tensor, duplicate_on_multiple_users=True)

            def stateless_param_forward_hook(origin_param):
                output, _ = allgather(origin_param, hsdp_param.sharded_group_info)
                return output

            if not self._register_custom_passes():
                raise RuntimeError(
                    "MindSpore custom pass registration failed but is mandatory for optimizer level "
                    f"{OptimizerLevel.SHARD_OPT_GRAD_PARAM}. "
                    "This optimization level requires graph transformations provided by the custom pass library "
                    f"({self.HYPER_PARALLEL_MINDSPORE_SO}). Ensure MindSpore is installed and the pass library was "
                    "successfully built during package installation."
                )

            return stateless_param_forward_hook

        def stateful_param_forward_hook(origin_param):
            unsharded_data, _ = self.platform.all_gather_into_tensor(origin_param, hsdp_param.sharded_group_info)
            return unsharded_data
        return stateful_param_forward_hook
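    # duplicate_on_multiple_users=True tags the all-gather so that the custom
    # "DuplicatePrimOnMultiUsersPass" can act on it in the compiled graph; a
    # plausible reading is that each consumer of the gathered parameter then
    # gets its own all-gather instead of sharing one long-lived gathered
    # tensor. The "stateful" variant used for the other shard levels simply
    # calls the platform all_gather_into_tensor directly.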

    def _get_param_backward_hook(self, hsdp_param):
        """Get hook for param backward process."""
        grad_hook = self.grad_hook.get_hook(hsdp_param)

        def backward_hook(grad):
            return grad_hook(grad)

        def backward_acc_grad_hook(grad):
            return grad_hook(grad)

        if self.requires_acc_grad:
            return backward_acc_grad_hook
        return backward_hook

    def _get_parameter_forward_hook(self, hsdp_forward_hook, hsdp_grad_hook):
        """
        Get parameter forward hook according to the hsdp_forward_hook and hsdp_grad_hook.
        """
        class ForwardHookNet(nn.Cell):
            def __init__(self, hsdp_forward_hook) -> None:
                super().__init__()
                self.hsdp_forward_hook = hsdp_forward_hook

            def construct(self, param):
                return self.hsdp_forward_hook(param)

            def bprop(self, param, out, dout):  # pylint: disable=W0613
                return (dout,)

        fwd_hook_net = ForwardHookNet(hsdp_forward_hook)
        insert_grad_of = ops.InsertGradientOf(hsdp_grad_hook)

        def parameter_forward_hook(param):
            return insert_grad_of(fwd_hook_net(param))
        return parameter_forward_hook
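    # ForwardHookNet applies the forward (all-gather) hook in construct() but
    # defines an identity bprop, so the gather itself passes gradients through
    # unchanged; the gradient-side work is attached separately via
    # ops.InsertGradientOf(hsdp_grad_hook), which runs the grad hook on the
    # backward pass of the hook's output.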

    def _register_graph_hook(self):
        """Register param forward and grad hook."""
        params_hooks = []
        for hsdp_param in self.hsdp_state.hsdp_params:
            if not hsdp_param.sharded:
                hsdp_param.param.register_hook(self.grad_hook.get_hook(hsdp_param))
            else:
                param_fwd_hook = self._get_parameter_forward_hook(
                    self._get_param_forward_hook(hsdp_param), self._get_param_backward_hook(hsdp_param))
                params_hooks.append(
                    {"params": [hsdp_param.param], "hook": param_fwd_hook}
                )
        self.cell.register_parameter_forward_hook(params_hooks)
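    # Graph-mode registration: parameters that are not sharded only need a
    # plain grad hook, while sharded parameters get the combined
    # forward/gradient hook built above and are handed to the cell via
    # register_parameter_forward_hook().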

    def _get_grad_buffer_hook(self, hsdp_param):
        """Return a grad-buffer hook that also sets .grad on the HSDP parameter."""
        origin_hook = super()._get_grad_buffer_hook(hsdp_param)

        def set_grad_hook(grad):
            grad = origin_hook(grad)
            hsdp_param.param.grad = grad
            return grad
        return set_grad_hook
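    # set_grad_hook wraps the base-class grad-buffer hook and additionally
    # mirrors its result onto hsdp_param.param.grad, presumably so that code
    # reading .grad directly sees the same tensor the buffer hook produced.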