Coverage for hyper_parallel/core/hsdp/hsdp_scheduler.py: 90% of 89 statements (coverage.py v7.13.1, created at 2026-03-01 07:33 +0800)

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP scheduler"""
from enum import auto, Enum

from hyper_parallel.core.hsdp.hsdp_utils import HSDPConfig
from hyper_parallel.core.hsdp.hsdp_grad_hook import HSDPGradHook
from hyper_parallel.core.hsdp.hsdp_async_grad_hook import HSDPAsyncGradHook


class FSDPSchedulerState(Enum):
    """
    Scheduler state:
    - PRE_FORWARD:
        the pre-forward hook has already run.
    - FORWARD:
        the post-forward hook has already run.
    - PRE_BACKWARD:
        the pre-backward hook has already run.
    - BACKWARD:
        the post-backward hook has already run.
    """
    PRE_FORWARD = auto()
    FORWARD = auto()
    PRE_BACKWARD = auto()
    BACKWARD = auto()
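
# Over one training step, the hooks defined on HSDPScheduler below drive these
# states in order:
#   PRE_FORWARD  - _hsdp_forward_pre_hook unshards parameters before forward
#   FORWARD      - _hsdp_forward_hook reshards parameters after forward
#   PRE_BACKWARD - _hsdp_backward_pre_hook unshards parameters before backward
#   BACKWARD     - _hsdp_backward_hook (or _hsdp_acc_backward_hook) reshards
#                  parameters after backward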

class HSDPScheduler:
    """HSDPScheduler drives HSDP parameter sharding via forward/backward and gradient hooks."""

    def __init__(self, cell, shard_size, threshold, shard_level, requires_acc_grad, grad_scale, use_eager_hook,
                 reduce_dtype, comm_async, comm_fusion, bucket_size):
        """Initialize the HSDP scheduler."""
        self.cell = cell
        self.no_param_sharded = shard_size == 1
        self.shard_level = shard_level
        self.requires_acc_grad = requires_acc_grad
        self.requires_grad_sync = False
        self.scheduler_state = None

        self.forward_prefetch_cells = []
        self.backward_prefetch_cells = []
        self.config = HSDPConfig(
            shard_size,
            threshold,
            requires_acc_grad,
            grad_scale,
            shard_level,
            use_eager_hook,
            reduce_dtype,
            comm_async,
            comm_fusion,
            bucket_size
        )
        self._init_platform()
        self._new_cell_state()
        self._new_grad_hook()
        self._register_hooks()

    def _init_platform(self):
        """Initialize the platform."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _init_platform")

    def _new_cell_state(self):
        """Create a new cell state."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _new_cell_state")

    def _new_grad_hook(self):
        """Create a new grad hook."""
        if self.config.comm_async:
            self.grad_hook = HSDPAsyncGradHook(self.config, self.platform)
        else:
            self.grad_hook = HSDPGradHook(self.config, self.platform)

    def _register_hooks(self):
        """Register hooks."""
        self._register_grad_hook()
        if self.no_param_sharded:
            return
        self._register_forward_backward_hooks()

    def _register_grad_hook(self):
        """Register parameter grad hook."""
        for hsdp_param in self.hsdp_state.hsdp_params:
            if not hsdp_param.param.requires_grad:
                continue
            if self.config.grad_fusion:
                hsdp_param.param.register_hook(self._get_grad_buffer_hook(hsdp_param))
            else:
                hsdp_param.param.register_hook(self.grad_hook.get_hook(hsdp_param))

    def _register_forward_backward_hooks(self):
        """Register module forward and backward hook."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _register_forward_backward_hooks.")

    def set_requires_grad_sync(self, requires_grad_sync):
        """Set requires grad sync flag to control gradient sync."""
        self.requires_grad_sync = requires_grad_sync
        self.grad_hook.set_requires_grad_sync(requires_grad_sync)
        self.hsdp_state.set_requires_grad_sync(requires_grad_sync)

    def zero_grads(self):
        """Set gradient to zero."""
        if self.requires_acc_grad:
            self.hsdp_state.zero_grads()
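
    # Usage sketch (hypothetical, not part of this module's API): with gradient
    # accumulation, a training loop would typically disable gradient sync for all
    # but the last micro-batch, enable it for the final micro-batch so gradients
    # are synchronized once, and zero accumulated gradients after the optimizer
    # update, e.g.:
    #
    #     for step, batch in enumerate(micro_batches):     # hypothetical loop
    #         scheduler.set_requires_grad_sync(step == len(micro_batches) - 1)
    #         run_forward_backward(batch)                  # hypothetical helper
    #     apply_optimizer_update()                         # hypothetical helper
    #     scheduler.zero_grads()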

    # pylint: disable=W0613
    def _hsdp_forward_pre_hook(self, cell, inputs):
        """Forward pre-hook that unshards parameters for the forward pass."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            # Parameters were already unsharded by the backward pre-hook (e.g. when
            # forward is re-entered during backward), so there is nothing to do.
            return
        self.scheduler_state = FSDPSchedulerState.PRE_FORWARD
        if len(inputs) > 0:
            self.platform.set_tensor_requires_grad(inputs[0])
        self.hsdp_state.unshard()
        for prefetch_cell in self.forward_prefetch_cells:
            prefetch_cell.hsdp_scheduler.hsdp_state.prefetch()

    # pylint: disable=W0613
    def _hsdp_forward_hook(self, cell, inputs, outputs):
        """Forward hook that reshards parameters to save memory."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            # Keep parameters unsharded while backward is in progress.
            return
        self.scheduler_state = FSDPSchedulerState.FORWARD
        self.hsdp_state.shard()

    # pylint: disable=W0613
    def _hsdp_backward_pre_hook(self, cell, grad_outputs):
        """Backward pre-hook that unshards parameters for the backward pass."""
        self.scheduler_state = FSDPSchedulerState.PRE_BACKWARD
        self.hsdp_state.unshard()
        for prefetch_cell in self.backward_prefetch_cells:
            prefetch_cell.hsdp_scheduler.hsdp_state.prefetch()

    # pylint: disable=W0613
    def _hsdp_backward_hook(self, cell, grad_inputs, grad_outputs):
        """Backward hook that reshards parameters for the optimizer step or to save memory."""
        self.scheduler_state = FSDPSchedulerState.BACKWARD
        self.hsdp_state.shard()

    # pylint: disable=W0613
    def _hsdp_acc_backward_hook(self, cell, grad_inputs, grad_outputs):
        """Backward hook for gradient accumulation: reshard parameters only when requires_grad_sync is True."""
        self.scheduler_state = FSDPSchedulerState.BACKWARD
        if self.requires_grad_sync:
            self.hsdp_state.shard()

    def _get_grad_buffer_hook(self, hsdp_param):
        """Return a grad hook that stores the gradient and marks it ready."""
        def hook(grad):
            hsdp_param.grad = grad
            self.hsdp_state.set_grad_ready(hsdp_param)
            return grad
        return hook

    def set_forward_prefetch_cells(self, hsdp_cell_list):
        """Set forward prefetch cells."""
        self.forward_prefetch_cells = hsdp_cell_list

    def set_backward_prefetch_cells(self, hsdp_cell_list):
        """Set backward prefetch cells."""
        self.backward_prefetch_cells = hsdp_cell_list
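

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): HSDPScheduler leaves three hooks abstract,
# so a concrete backend is expected to subclass it and implement them. The
# skeleton below is hypothetical; `ExamplePlatform`, `ExampleCellState`, and
# the register_* calls are placeholders rather than the real hyper_parallel or
# framework API.
#
#     class ExampleScheduler(HSDPScheduler):
#         def _init_platform(self):
#             self.platform = ExamplePlatform()        # hypothetical backend helper
#
#         def _new_cell_state(self):
#             self.hsdp_state = ExampleCellState(self.cell, self.config)
#
#         def _register_forward_backward_hooks(self):
#             # Attach the scheduler hooks to the wrapped cell, for example:
#             self.cell.register_forward_pre_hook(self._hsdp_forward_pre_hook)
#             self.cell.register_forward_hook(self._hsdp_forward_hook)
#             self.cell.register_backward_pre_hook(self._hsdp_backward_pre_hook)
#             self.cell.register_backward_hook(self._hsdp_backward_hook)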