hyper_parallel/platform/mindspore/hsdp/scheduler.py

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore HSDP scheduler."""

import warnings
from pathlib import Path
from importlib import resources

import mindspore as ms
from mindspore import ops
from mindspore import jit_class, nn

from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel
from hyper_parallel.core.hsdp.hsdp_scheduler import HSDPScheduler
from hyper_parallel.platform import get_platform
from hyper_parallel.platform.mindspore.platform_graph import MindSporeGraphPlatform
from hyper_parallel.platform.mindspore.hsdp.state import MindSporeHSDPState
from hyper_parallel.platform.mindspore.hsdp.grad_hook import MindSporeHSDPGradHook
from hyper_parallel.platform.mindspore.hsdp.async_grad_hook import MindSporeHSDPAsyncGradHook


@jit_class
class MindSporeHSDPScheduler(HSDPScheduler):
    """MindSporeHSDPScheduler implements the configured optimizer sharding level."""

    HYPER_PARALLEL_MINDSPORE_SO = "libhyper_parallel_mindspore.so"

    def _init_platform(self):
        """Initialize the platform."""
        if self.config.use_eager_hook:
            self.platform = get_platform()
        else:
            self.platform = MindSporeGraphPlatform()

    def _new_cell_state(self):
        """Create a new cell state."""
        # TODO: why reset use_eager_hook here?
        # self.config.use_eager_hook = ms.get_context("mode") != ms.GRAPH_MODE
        self.hsdp_state = MindSporeHSDPState(self.cell, self.config, self.platform)

    def _new_grad_hook(self):
        """Create a new grad hook; asynchronous communication requires eager hooks."""
        if self.config.use_eager_hook and self.config.comm_async:
            self.grad_hook = MindSporeHSDPAsyncGradHook(self.config, self.platform)
        else:
            self.grad_hook = MindSporeHSDPGradHook(self.config, self.platform)

    def _register_forward_backward_hooks(self):
        """Register module forward and backward hooks."""
        self.cell.register_forward_pre_hook(self._hsdp_forward_pre_hook)
        self.cell.register_backward_pre_hook(self._hsdp_backward_pre_hook)
        if self.shard_level == OptimizerLevel.SHARD_OPT_GRAD_PARAM:
            self.cell.register_forward_hook(self._hsdp_forward_hook)
            self.cell.register_backward_hook(self._hsdp_backward_hook)
        elif self.requires_acc_grad:
            self.cell.register_backward_hook(self._hsdp_acc_backward_hook)
        else:
            self.cell.register_backward_hook(self._hsdp_backward_hook)
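
    # Hook layout per branch above (a summary derived from this method, not
    # separate documentation):
    #
    #   SHARD_OPT_GRAD_PARAM -> forward_pre, forward, backward_pre, backward
    #   requires_acc_grad    -> forward_pre, backward_pre, acc_backward
    #   otherwise            -> forward_pre, backward_pre, backward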

    def _register_hooks(self):
        """Register hooks: eager mode defers to the base class, graph mode uses graph hooks."""
        if self.config.use_eager_hook:
            super()._register_hooks()
        else:
            self._register_graph_hook()

    @staticmethod
    def get_pass_library_path():
        """Safely locate the custom pass library path (compatible with Python 3.8+)."""
        try:
            # Python 3.9+
            if hasattr(resources, "files"):
                return resources.files(
                    "hyper_parallel.platform.mindspore.custom_pass") / \
                    MindSporeHSDPScheduler.HYPER_PARALLEL_MINDSPORE_SO
            # Python 3.8 fallback
            import pkg_resources  # pylint: disable=C0415
            return Path(pkg_resources.resource_filename(
                "hyper_parallel.platform.mindspore.custom_pass",
                MindSporeHSDPScheduler.HYPER_PARALLEL_MINDSPORE_SO
            ))
        except Exception as e:
            warnings.warn(
                f"Failed to locate MindSpore custom pass library: {e}")
            return None
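
    # Usage sketch (mirrors _register_custom_passes below; illustrative only):
    #
    #   so_path = MindSporeHSDPScheduler.get_pass_library_path()
    #   if so_path is not None and so_path.exists():
    #       print(f"custom pass library at {so_path}")
    #   else:
    #       print("library missing; graph passes will be unavailable")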

    def _register_custom_passes(self):
        """Register custom graph optimization passes with MindSpore."""
        so_path = self.get_pass_library_path()
        if so_path and so_path.exists():
            success = ms.graph.register_custom_pass(
                pass_name="DuplicatePrimOnMultiUsersPass",
                plugin_so_path=str(so_path),
                device="cpu",
                pass_type=ms.graph.CustomPassType.FULL_GRAPH)
            if not success:
                print(f"Failed to register MindSpore custom pass from {so_path}.")
            return success

        print(f"Failed to locate MindSpore custom pass library {so_path}.")
        return False
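
    # Caveat sketch: resources.files() can return a non-filesystem Traversable
    # (e.g., under zip imports), in which case str(so_path) may not point at a
    # loadable file even when exists() is True. A stricter check, if needed:
    #
    #   loadable = so_path is not None and Path(str(so_path)).is_file()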

    def _get_param_forward_hook(self, hsdp_param):
        """Get param forward hook."""
        if self.shard_level == OptimizerLevel.SHARD_OPT_GRAD_PARAM:
            # pylint: disable=W0212
            allgather = ops._add_attr(self.platform.all_gather_into_tensor, duplicate_on_multiple_users=True)

            def stateless_param_forward_hook(origin_param):
                output, _ = allgather(origin_param, hsdp_param.sharded_group_info)
                return output

            if not self._register_custom_passes():
                raise RuntimeError(
                    "MindSpore custom pass registration failed but is mandatory for optimizer level "
                    f"{OptimizerLevel.SHARD_OPT_GRAD_PARAM}. "
                    "This optimization level requires graph transformations provided by the custom pass library "
                    f"({self.HYPER_PARALLEL_MINDSPORE_SO}). Ensure MindSpore is installed and the pass library was "
                    "successfully built during package installation."
                )

            return stateless_param_forward_hook

        def stateful_param_forward_hook(origin_param):
            unsharded_data, _ = self.platform.all_gather_into_tensor(origin_param, hsdp_param.sharded_group_info)
            return unsharded_data
        return stateful_param_forward_hook
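
    # What the hooks above compute, reduced to raw MindSpore communication ops
    # (a standalone sketch; the platform wrapper is the real code path and the
    # group handling here is simplified):
    #
    #   from mindspore.communication import init, get_group_size
    #   from mindspore.ops import AllGather
    #
    #   init()                          # set up the default communication group
    #   all_gather = AllGather()        # concatenates shards along dim 0
    #   full_param = all_gather(shard)  # shape[0] == shard.shape[0] * get_group_size()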

    def _get_param_backward_hook(self, hsdp_param):
        """Get hook for param backward process."""
        grad_hook = self.grad_hook.get_hook(hsdp_param)

        def backward_hook(grad):
            return grad_hook(grad)

        def backward_acc_grad_hook(grad):
            return grad_hook(grad)

        # Both wrappers currently delegate to grad_hook unchanged; they differ
        # only in name.
        if self.requires_acc_grad:
            return backward_acc_grad_hook
        return backward_hook

    def _get_parameter_forward_hook(self, hsdp_forward_hook, hsdp_grad_hook):
        """
        Get parameter forward hook according to the hsdp_forward_hook and hsdp_grad_hook.
        """
        class ForwardHookNet(nn.Cell):
            def __init__(self, hsdp_forward_hook) -> None:
                super().__init__()
                self.hsdp_forward_hook = hsdp_forward_hook

            def construct(self, param):
                return self.hsdp_forward_hook(param)

            def bprop(self, param, out, dout):  # pylint: disable=W0613
                # Identity bprop: pass the incoming gradient straight through,
                # so autodiff does not differentiate through the forward hook.
                return (dout,)

        fwd_hook_net = ForwardHookNet(hsdp_forward_hook)
        insert_grad_of = ops.InsertGradientOf(hsdp_grad_hook)

        def parameter_forward_hook(param):
            return insert_grad_of(fwd_hook_net(param))
        return parameter_forward_hook
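
    # How ops.InsertGradientOf behaves, in a standalone sketch independent of
    # this scheduler (identity in the forward pass; the wrapped function runs
    # on the gradient in the backward pass):
    #
    #   import mindspore as ms
    #   from mindspore import Tensor, ops
    #
    #   def clip(grad):
    #       return ops.clip_by_value(grad, -1.0, 1.0)
    #
    #   hook = ops.InsertGradientOf(clip)
    #
    #   def net(x):
    #       return hook(x) * x    # forward unchanged; dx passes through clip
    #
    #   print(ms.grad(net)(Tensor(2.0, ms.float32)))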

    def _register_graph_hook(self):
        """Register param forward and grad hooks."""
        params_hooks = []
        for hsdp_param in self.hsdp_state.hsdp_params:
            if not hsdp_param.sharded:
                hsdp_param.param.register_hook(self.grad_hook.get_hook(hsdp_param))
            else:
                param_fwd_hook = self._get_parameter_forward_hook(
                    self._get_param_forward_hook(hsdp_param), self._get_param_backward_hook(hsdp_param))
                params_hooks.append(
                    {"params": [hsdp_param.param], "hook": param_fwd_hook}
                )
        self.cell.register_parameter_forward_hook(params_hooks)
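
    # Shape of the registration payload built above (illustrative values):
    #
    #   params_hooks = [
    #       {"params": [param_a], "hook": forward_hook_a},
    #       {"params": [param_b], "hook": forward_hook_b},
    #   ]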

    def _get_grad_buffer_hook(self, hsdp_param):
        """Set grad for hsdp parameter."""
        origin_hook = super()._get_grad_buffer_hook(hsdp_param)

        def set_grad_hook(grad):
            grad = origin_hook(grad)
            hsdp_param.param.grad = grad
            return grad
        return set_grad_hook
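

# The method above uses a hook-chaining pattern: wrap the parent's hook, run
# it first, then add a side effect. Reduced to a standalone sketch (names are
# illustrative, not part of this library):
#
#   def chain_with_side_effect(inner_hook, param):
#       def hook(grad):
#           grad = inner_hook(grad)   # preserve parent behavior
#           param.grad = grad         # extra side effect: stash the grad
#           return grad
#       return hook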