# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP scheduler"""
import functools
from hyper_parallel.core.device_mesh import DeviceMesh
from hyper_parallel.core.fully_shard.hsdp_utils import HSDPConfigV2, FSDPSchedulerState
from hyper_parallel.core.fully_shard.hsdp_grad_hook import HSDPGradHook
from hyper_parallel.core.fully_shard.hsdp_async_grad_hook import HSDPAsyncGradHook


class HSDPSchedulerContext:
    """Scheduler-level context flags shared across a backward pass."""

    def __init__(self) -> None:
        self.post_backward_final_callback_queued: bool = False
        self.is_last_backward: bool = True
        self.post_optim_event = None


class HSDPSchedulerV2:
    """HSDPSchedulerV2 is used to schedule HSDP sharding, unsharding and gradient hooks for a cell."""

    def __init__(self, cell, mesh, reshard_after_forward, shard_placement_fn,
                 mp_policy, offload_policy, ignored_params, device):
        """Initialize the HSDP scheduler."""
        self.cell = cell
        self.mesh: DeviceMesh = mesh
        self.reshard_after_forward = reshard_after_forward
        self.shard_placement_fn = shard_placement_fn
        self.mp_policy = mp_policy
        self.offload_policy = offload_policy
        self.ignored_params = ignored_params
        self.device = device
        self.scheduler_state = None
        self.forward_prefetch_cells = []
        self.backward_prefetch_cells = []
        self.scheduler_ctx = HSDPSchedulerContext()
        self.config = HSDPConfigV2(
            mesh,
            reshard_after_forward,
            shard_placement_fn,
            mp_policy,
            offload_policy,
            ignored_params
        )
        self._init_platform()
        self._new_cell_state()
        self._new_grad_hook()
        self._register_hooks()

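    # Platform hooks: _init_platform(), _new_cell_state() and _register_hooks() are
    # implemented by platform-specific subclasses, which are expected to set
    # self.platform and self.hsdp_state used throughout the rest of this class.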

    def _init_platform(self):
        """Initialize the platform."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _init_platform")

    def _new_cell_state(self):
        """Create a new cell state."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _new_cell_state")

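    # Gradient-hook selection: when config.comm_async is enabled, HSDPAsyncGradHook
    # issues gradient communication asynchronously; otherwise the synchronous
    # HSDPGradHook is used.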

    def _new_grad_hook(self):
        """Create a new grad hook."""
        if self.config.comm_async:
            self.grad_hook = HSDPAsyncGradHook(self.config, self.platform)
        else:
            self.grad_hook = HSDPGradHook(self.config, self.platform)

    def _register_hooks(self):
        """Register hooks."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _register_hooks.")

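    # With config.grad_fusion enabled, the per-parameter hook only stores the gradient
    # and marks the parameter ready (see _get_grad_buffer_hook); otherwise each
    # parameter uses the hook provided by self.grad_hook.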

    def _register_grad_hook(self):
        """Register parameter grad hooks."""
        for hsdp_param in self.hsdp_state.hsdp_params:
            if not hsdp_param.param.requires_grad:
                continue
            if self.config.grad_fusion:
                hsdp_param.param.register_hook(self._get_grad_buffer_hook(hsdp_param))
            else:
                hsdp_param.param.register_hook(self.grad_hook.get_hook(hsdp_param))

    def _register_forward_backward_hooks(self):
        """Register module forward and backward hooks."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _register_forward_backward_hooks.")

    def set_reshard_after_forward(self, reshard_after_forward: bool):
        """Set the reshard_after_forward flag."""
        if not isinstance(reshard_after_forward, bool):
            raise ValueError(f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}")
        self.reshard_after_forward = reshard_after_forward
        self.config.reshard_after_forward = reshard_after_forward

    def set_reshard_after_backward(self, reshard_after_backward: bool):
        """Set the reshard_after_backward flag."""
        if not isinstance(reshard_after_backward, bool):
            raise ValueError(f"reshard_after_backward should be a bool, got {type(reshard_after_backward)}")
        if self.hsdp_state is not None:
            self.hsdp_state.reshard_after_backward = reshard_after_backward

    def set_requires_all_reduce(self, requires_all_reduce: bool):
        """Set the requires_all_reduce flag."""
        if not isinstance(requires_all_reduce, bool):
            raise ValueError(f"requires_all_reduce should be a bool, got {type(requires_all_reduce)}")
        if self.hsdp_state is not None:
            self.hsdp_state.all_reduce_grads = requires_all_reduce

    def set_requires_grad_sync(self, requires_grad_sync: bool):
        """Set the requires_grad_sync flag to control gradient synchronization."""
        if not isinstance(requires_grad_sync, bool):
            raise ValueError(f"requires_grad_sync should be a bool, got {type(requires_grad_sync)}")
        self.requires_grad_sync = requires_grad_sync
        self.hsdp_state.set_requires_grad_sync(requires_grad_sync)

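    # Note: self.requires_acc_grad is not initialized in __init__ and is assumed to be
    # set by a platform-specific subclass; zero_grads() only clears gradients when it
    # is truthy.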

    def zero_grads(self):
        """Set gradient to zero."""
        if self.requires_acc_grad:
            self.hsdp_state.zero_grads()

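    # Forward path: the pre-hook optionally casts floating-point inputs to the
    # mixed-precision param_dtype, unshards this cell's parameters and prefetches the
    # configured forward-prefetch cells; the post-hook reshards parameters (if
    # reshard_after_forward) and casts outputs to output_dtype. Both hooks return early
    # when the scheduler is already in PRE_BACKWARD, which is assumed to correspond to
    # a forward re-run inside backward (e.g. activation recomputation).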

    # pylint: disable=W0613
    def _hsdp_forward_pre_hook(self, cell, args, kwargs):
        """Forward pre-hook that unshards parameters before the forward computation."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return args, kwargs
        self.scheduler_state = FSDPSchedulerState.PRE_FORWARD
        if self.mp_policy.cast_forward_inputs and self.mp_policy.param_dtype:
            cast_fn = functools.partial(self.platform.cast_fp_tensor, self.mp_policy.param_dtype)
            args = self.platform.apply_to_tensors(cast_fn, args)
            kwargs = self.platform.apply_to_tensors(cast_fn, kwargs)
        self.hsdp_state.unshard()
        for prefetch_cell in self.forward_prefetch_cells:
            prefetch_cell.hsdp_scheduler.hsdp_state.prefetch()
        return args, kwargs

    # pylint: disable=W0613
    def _hsdp_forward_hook(self, cell, inputs, outputs):
        """Forward hook that reshards parameters after the forward pass to save memory."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return
        self.scheduler_state = FSDPSchedulerState.FORWARD
        if self.reshard_after_forward:
            self.hsdp_state.shard()
        if self.mp_policy.output_dtype is not None:
            outputs = self.platform.apply_to_tensors(
                functools.partial(self.platform.cast_fp_tensor, self.mp_policy.output_dtype),
                outputs,
            )
        return outputs

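    # Backward path: the pre-hook unshards parameters again (if they were resharded
    # after forward) and prefetches the configured backward-prefetch cells; the
    # post-hook hands control to self.hsdp_state.post_backward(), where gradient
    # reduction and resharding are expected to happen.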

    # pylint: disable=W0613
    def _hsdp_backward_pre_hook(self, cell, grad_outputs):
        """Backward pre-hook that unshards parameters before the backward computation."""
        self.scheduler_state = FSDPSchedulerState.PRE_BACKWARD
        if self.reshard_after_forward:
            self.hsdp_state.unshard()
        for prefetch_cell in self.backward_prefetch_cells:
            prefetch_cell.hsdp_scheduler.hsdp_state.prefetch()

    # pylint: disable=W0613
    def _hsdp_backward_hook(self, cell, grad_inputs, grad_outputs):
        """Backward hook that triggers post-backward gradient handling and resharding."""
        self.scheduler_state = FSDPSchedulerState.BACKWARD
        self.hsdp_state.post_backward()

    def _get_grad_buffer_hook(self, hsdp_param):
        """Return a grad hook that stores the gradient and marks the parameter as ready."""

        def hook(grad):
            hsdp_param.grad = grad
            self.hsdp_state.set_grad_ready(hsdp_param)
            return grad

        return hook

    def set_forward_prefetch_cells(self, hsdp_cell_list):
        """Set forward prefetch cells."""
        self.forward_prefetch_cells = hsdp_cell_list

    def set_backward_prefetch_cells(self, hsdp_cell_list):
        """Set backward prefetch cells."""
        self.backward_prefetch_cells = hsdp_cell_list

    def set_requires_allreuce(self, requires_all_reduce):
        """Set requires_all_reduce on the HSDP state."""
        self.hsdp_state.requires_all_reduce = requires_all_reduce

    def reshard(self):
        """Reshard parameters after forward or backward."""
        self.hsdp_state.reshard()
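

# A minimal sketch (illustrative only, not part of this module) of how a platform-specific
# subclass might fill in the abstract hooks above. The names MyPlatformScheduler,
# MyPlatform, MyHSDPCellState and the hook-registration calls below are hypothetical
# assumptions for illustration, not APIs defined in this repository.
#
#     class MyPlatformScheduler(HSDPSchedulerV2):
#         def _init_platform(self):
#             self.platform = MyPlatform(self.device)   # hypothetical platform adapter
#             self.requires_acc_grad = False            # hypothetical flag used by zero_grads()
#
#         def _new_cell_state(self):
#             self.hsdp_state = MyHSDPCellState(self.cell, self.config, self.platform)  # hypothetical
#
#         def _register_hooks(self):
#             self._register_grad_hook()
#             self._register_forward_backward_hooks()
#
#         def _register_forward_backward_hooks(self):
#             # hypothetical platform API that attaches the forward/backward hooks to the cell
#             self.platform.register_forward_pre_hook(self.cell, self._hsdp_forward_pre_hook)
#             self.platform.register_forward_hook(self.cell, self._hsdp_forward_hook)
#             self.platform.register_backward_pre_hook(self.cell, self._hsdp_backward_pre_hook)
#             self.platform.register_backward_hook(self.cell, self._hsdp_backward_hook)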