Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/fully_shard/scheduler.py: 46% of 112 statements


# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore HSDP scheduler"""
import mindspore as ms
from mindspore.common.api import _pynative_executor
from mindspore.utils._pytree import tree_flatten, tree_unflatten
from hyper_parallel.core.fully_shard.hsdp_scheduler import HSDPSchedulerV2, FSDPSchedulerState
from hyper_parallel.core.fully_shard.hsdp_utils import get_dtensor_managed_mesh
from hyper_parallel.platform.mindspore.fully_shard.hook_function import PostBackwardFunction
from hyper_parallel.platform.mindspore.fully_shard.param_group import get_comm_ctx
from hyper_parallel.platform.mindspore.fully_shard.state import MindSporeHSDPStateV2
from hyper_parallel.core.fully_shard.utils import FSDPMeshInfo, HSDPMeshInfo, DDPMeshInfo
from hyper_parallel.platform import get_platform


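# Hook lifecycle in PyNative mode, as wired up below: the forward pre hook runs
# the core HSDP pre-forward logic and plants PostBackwardFunction on the cell
# inputs; the forward hook registers tensor hooks on the outputs; that tensor
# hook marks the start of backward and queues _root_backward_hook, which
# finalizes gradient reduction once the whole backward graph has executed.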

class MindSporeHSDPSchedulerV2(HSDPSchedulerV2):
    """MindSpore HSDP scheduler.

    List-unit grouped forward hooks use :class:`HSDPSchedulerV2` defaults for
    ``_grouped_forward_pre_hook_skip`` / ``_grouped_forward_post_hook_skip`` (no overrides here).
    """
    def zero_grad(self) -> None:
        """Zero grad."""
        self.hsdp_state.zero_grad()


    def _register_hooks(self):
        """Register hooks."""
        self._register_forward_backward_hooks()


    def _init_platform(self):
        """Initialize the platform."""
        from hyper_parallel.platform.mindspore.platform import MindSporePlatform
        self.platform = get_platform()
        if not isinstance(self.platform, MindSporePlatform):
            raise ValueError(f"MindSporeHSDPSchedulerV2 expects MindSporePlatform, but got type: {type(self.platform)}")

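    # Mesh handling, as implemented below: with no explicit mesh, fall back to a
    # DDP-style compatibility mode driven by the shared DTensor parameter mesh;
    # a 1D mesh means pure FSDP (shard dim 0), a 2D mesh means HSDP (replicate
    # across dim 0, shard across dim 1).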

    def _new_cell_state(self):
        """Create a new cell state for mindspore."""
        params = self._get_managed_params()
        if self.mesh is None:
            compat_meshes = [get_dtensor_managed_mesh(param) for param in params]
            compat_meshes = [mesh for mesh in compat_meshes if mesh is not None]
            compat_mesh = compat_meshes[0] if compat_meshes else None
            if compat_mesh is None:
                raise ValueError(
                    "Cannot build fully_shard compatibility mesh_info "
                    "without a DTensor parameter mesh."
                )
            compat_mesh_hash = compat_mesh.to_hash()
            for param_mesh in compat_meshes[1:]:
                if param_mesh.to_hash() != compat_mesh_hash:
                    raise ValueError(
                        "fully_shard compatibility mode requires all DTensor parameters to share the same mesh."
                    )
            self.mesh_info = DDPMeshInfo(mesh=compat_mesh, replicate_mesh_dim=0)
        elif self.mesh.ndim == 1:
            self.mesh_info = FSDPMeshInfo(mesh=self.mesh, shard_mesh_dim=0)
        elif self.mesh.ndim == 2:
            self.mesh_info = HSDPMeshInfo(mesh=self.mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
        else:
            raise ValueError(
                "fully_shard only supports explicit 1D DP/FSDP meshes or 2D HSDP meshes. "
                f"Got mesh.ndim={self.mesh.ndim}."
            )
        self.hsdp_state = MindSporeHSDPStateV2(
            self.modules, self.mesh_info, self.config, self.platform, self.device
        )

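    # PostBackwardFunction.apply wraps every flattened arg/kwarg so that its
    # custom backward fires after this cell's input gradients are produced. The
    # requires_grad flags are copied back onto the processed tensors afterwards,
    # on the assumption that apply may hand back tensors that lost the flag
    # (hence the defensive try/except).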

    def _register_post_backward_hook(self, args, kwargs):
        """Register backward hook using backward function."""
        if not _pynative_executor.enable_grad():
            return args, kwargs
        args_list, args_spec = tree_flatten(args)
        kwargs_list, kwargs_spec = tree_flatten(kwargs)
        args_kwargs_list = list(args_list) + list(kwargs_list)
        if not any(
            isinstance(obj, ms.Tensor) and getattr(obj, "requires_grad", False)
            for obj in args_kwargs_list
        ):
            return args, kwargs
        processed_list = list(PostBackwardFunction.apply(self, *args_kwargs_list))
        for idx, (orig_obj, processed_obj) in enumerate(zip(args_kwargs_list, processed_list)):
            if isinstance(orig_obj, ms.Tensor) and isinstance(processed_obj, ms.Tensor):
                try:
                    processed_obj.requires_grad = bool(getattr(orig_obj, "requires_grad", False))
                except (AttributeError, RuntimeError, TypeError, ValueError):
                    pass
                processed_list[idx] = processed_obj
        args_kwargs_list = processed_list
        args_list = args_kwargs_list[: len(args_list)]
        kwargs_list = args_kwargs_list[len(args_list):]
        args = tree_unflatten(args_spec, args_list)
        kwargs = tree_unflatten(kwargs_spec, kwargs_list)
        return args, kwargs


    def _forward_pre_hook(self, cell, args, kwargs):
        """Execute forward pre hook and set up backward hook."""
        args, kwargs = self._hsdp_forward_pre_hook(cell, args, kwargs)
        return self._register_post_backward_hook(args, kwargs)


    def _register_backward_pre_hook(self, outputs):
        """Register output hook to trigger backward pre hook."""
        flat_outputs, _ = tree_flatten(outputs)
        for output in flat_outputs:
            if isinstance(output, ms.Tensor) and output._requires_grad:
                output.register_hook(self._backward_pre_hook)
        return outputs

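    # Recompute handling, per the branches below: while the scheduler is in
    # PRE_BACKWARD the hook is a no-op; otherwise the backward pre hooks are
    # (re-)armed on the outputs, and if we are already inside the root backward
    # (i.e. this is a recomputed forward) only the forward prefetch order is
    # restored instead of running the full HSDP forward hook.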

    def _forward_hook(self, cell, inputs, outputs):
        """Execute forward hook."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return
        self._register_backward_pre_hook(outputs)
        if HSDPSchedulerV2.root_bp_state:
            self._restore_forward_prefetch_after_recompute()
            return
        return self._hsdp_forward_hook(cell, inputs, outputs)

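    # The tensor hook below runs once per output tensor when backward reaches
    # it; queue_backward_final_callback schedules _root_backward_hook to run
    # after the whole backward pass has finished executing.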

    # pylint: disable=W0212
    def _backward_pre_hook(self, grad):
        """Execute backward pre hook."""
        _pynative_executor.queue_backward_final_callback(self._root_backward_hook)
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return grad
        HSDPSchedulerV2.root_bp_state = True
        self._hsdp_backward_pre_hook(self.cell, None)
        return grad


    def _root_backward_hook(self):
        """Root backward hook: finalize the outermost backward and clear recompute state."""
        apply_final_reduce = self.scheduler_state != FSDPSchedulerState.BACKWARD
        self._backward_hook()
        if apply_final_reduce:
            comm_ctx = get_comm_ctx()
            if comm_ctx.all_reduce_param_group is not None:
                comm_ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
                comm_ctx.all_reduce_param_group = None
            if comm_ctx.pre_param_group is not None:
                comm_ctx.pre_param_group.apply_fusion_reduced_grad()
                comm_ctx.pre_param_group = None
            self.hsdp_state.reduce_params()
            self.hsdp_state._finish_ignored_allreduce()
        HSDPSchedulerV2.root_bp_state = False


    def _backward_hook(self):
        """Execute backward hook."""
        if self.scheduler_state == FSDPSchedulerState.BACKWARD:
            return
        self._hsdp_backward_hook(self.cell, None, None)

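    # Two registration paths: when _fsdp_group_post_pending is unset, each
    # managed module gets the plain per-module hooks; otherwise the grouped
    # hook variants inherited from HSDPSchedulerV2 are installed instead.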

    def _register_forward_backward_hooks(self):
        """Register module forward and backward hook on all managed modules."""
        if self._fsdp_group_post_pending is None:
            for mod in self.modules:
                mod.register_forward_pre_hook(self._forward_pre_hook, with_kwargs=True)
                mod.register_forward_hook(self._forward_hook)
            return
        for mod in self.modules:
            mod.register_forward_pre_hook(self._grouped_forward_pre_hook, with_kwargs=True)
            mod.register_forward_hook(self._make_grouped_forward_post_hook(mod))