Coverage for hyper_parallel / core / fully_shard / hsdp_grad_hook.py: 10%

113 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""HSDP gradient hook""" 

16from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel 

17 

18 

class HSDPGradHook:
    """HSDP gradient hook factory.

    Builds per-parameter gradient post-processing hooks for hybrid sharded
    data parallel (HSDP) training. Depending on how a parameter is sharded,
    the produced hook all-reduces and/or reduce-scatters the gradient over
    the appropriate process groups, optionally accumulates into
    ``hsdp_param.acc_grad``, and applies a common cast-and-scale wrapper.
    """

    def __init__(self, config, platform):
        """Initialize from a config object and a communication platform.

        Args:
            config: carries ``reduce_dtype`` (dtype used during the
                communication op, or ``None`` to skip casting),
                ``grad_scale`` (multiplier applied to the final gradient),
                ``shard_level``, ``requires_acc_grad`` and ``use_eager_hook``.
            platform: backend exposing ``all_reduce`` and
                ``reduce_scatter_tensor``; both return ``(output, handle)``.
        """
        self.reduce_dtype = config.reduce_dtype
        self.grad_scale = config.grad_scale
        self.shard_level = config.shard_level
        self.requires_acc_grad = config.requires_acc_grad
        self.use_eager_hook = config.use_eager_hook
        # Toggled externally via set_requires_grad_sync(); when False,
        # accumulation hooks skip communication and just accumulate.
        self.requires_grad_sync = False
        self.platform = platform

    def _cast_hook(self, hook, grad):
        """Run ``hook`` with the gradient cast to ``reduce_dtype``.

        Casts ``grad`` to ``self.reduce_dtype`` before calling ``hook`` and
        casts the result back to the original dtype afterwards. If
        ``reduce_dtype`` is ``None``, ``hook`` is called on ``grad`` as-is.
        """
        if self.reduce_dtype is None:
            return hook(grad)
        origin_dtype = grad.dtype
        grad_cast = grad.to(self.reduce_dtype)
        output = hook(grad_cast)
        output = output.to(origin_dtype)
        return output

    def _get_final_grad_hook(self, param, grad_hook, no_cast=False):
        """Wrap ``grad_hook`` with dtype casting and gradient scaling.

        Args:
            param: the parameter the hook belongs to (kept for interface
                compatibility; not used here).
            grad_hook: callable taking a gradient tensor and returning the
                communicated/accumulated gradient.
            no_cast: if True, skip the reduce_dtype cast around ``grad_hook``.

        Returns:
            A callable applying ``grad_hook`` (optionally under the cast)
            and then multiplying by ``self.grad_scale`` exactly once when
            the scale differs from 1.0.
        """
        def scale_with_cast_hook(grad):
            output = self._cast_hook(grad_hook, grad)
            if self.grad_scale != 1.0:
                scale_output = output * self.grad_scale
                return scale_output
            return output

        def scale_hook(grad):
            output = grad_hook(grad)
            if self.grad_scale != 1.0:
                scale_output = output * self.grad_scale
                return scale_output
            return output

        if no_cast:
            return scale_hook
        return scale_with_cast_hook

    def _get_hsdp_param_single_node_hook(self, hsdp_param):
        """Get hook for an unsharded param when dp_size == 1 (no comm needed)."""
        def grad_identity_hook(grad):
            # Identity: scaling is applied exactly once by the wrapper from
            # _get_final_grad_hook. (Previously this hook also multiplied by
            # self.grad_scale, so the gradient was scaled by grad_scale**2
            # whenever grad_scale != 1.0 — inconsistent with all other paths.)
            return grad

        def grad_hook(grad):
            # Accumulate into the persistent accumulation buffer.
            hsdp_param.acc_grad.add_(grad)
            return hsdp_param.acc_grad

        if not self.requires_acc_grad:
            return self._get_final_grad_hook(hsdp_param.param, grad_identity_hook, no_cast=True)
        return self._get_final_grad_hook(hsdp_param.param, grad_hook)

    def _get_hsdp_param_unsharded_hook(self, hsdp_param):
        """Get hook for an unsharded param: all-reduce over the unsharded group."""
        def grad_all_reduce_hook(grad):
            output, _ = self.platform.all_reduce(grad, hsdp_param.unsharded_group_info)
            return output

        def grad_acc_all_reduce_hook(grad):
            # Always accumulate; only communicate when grad sync is enabled
            # (e.g. on the last micro-batch of an accumulation window).
            hsdp_param.acc_grad.add_(grad)
            if self.requires_grad_sync:
                output, _ = self.platform.all_reduce(hsdp_param.acc_grad, hsdp_param.unsharded_group_info)
                return output
            return hsdp_param.acc_grad

        if not self.requires_acc_grad:
            grad_hook = grad_all_reduce_hook
        else:
            grad_hook = grad_acc_all_reduce_hook
        return self._get_final_grad_hook(hsdp_param.param, grad_hook)

    def _get_hsdp_param_fully_sharded_hook(self, hsdp_param):
        """Get hook for a fully sharded param: reduce-scatter over the sharded group."""
        def grad_reduce_scatter_hook(grad):
            output, _ = self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info)
            return output

        def grad_acc_reduce_scatter_hook(grad):
            # Accumulate the full (unsharded) gradient; reduce-scatter only
            # when sync is requested. Used when the optimizer state is sharded
            # but accumulation happens on the full gradient (SHARD_OPT).
            hsdp_param.acc_grad.add_(grad)
            if self.requires_grad_sync:
                output, _ = self.platform.reduce_scatter_tensor(hsdp_param.acc_grad, hsdp_param.sharded_group_info)
                return output
            return hsdp_param.acc_grad

        def grad_reduce_scatter_acc_hook(grad):
            # Reduce-scatter first, then accumulate the sharded result:
            # acc_grad holds a shard-sized buffer in this mode.
            output, _ = self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info)
            hsdp_param.acc_grad.add_(output)
            return hsdp_param.acc_grad

        if not self.requires_acc_grad:
            grad_hook = grad_reduce_scatter_hook
        elif self.shard_level == OptimizerLevel.SHARD_OPT:
            grad_hook = grad_acc_reduce_scatter_hook
        else:
            grad_hook = grad_reduce_scatter_acc_hook
        return self._get_final_grad_hook(hsdp_param.param, grad_hook)

    def _get_hsdp_param_partial_sharded_hook(self, hsdp_param):
        """Get hook for a partially sharded param.

        Gradient is reduce-scattered over the sharded group, then the shard
        is all-reduced over the unsharded (replication) group.
        """
        def grad_reduce_scatter_hook(grad):
            output, _ = self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info)
            sliced_grad, _ = self.platform.all_reduce(output, hsdp_param.unsharded_group_info)
            return sliced_grad

        def grad_acc_reduce_scatter_hook(grad):
            # Accumulate the full gradient; do both communication steps only
            # when sync is requested (SHARD_OPT accumulation mode).
            hsdp_param.acc_grad.add_(grad)
            if self.requires_grad_sync:
                output, _ = self.platform.reduce_scatter_tensor(hsdp_param.acc_grad, hsdp_param.sharded_group_info)
                sliced_grad, _ = self.platform.all_reduce(output, hsdp_param.unsharded_group_info)
                return sliced_grad
            return hsdp_param.acc_grad

        def grad_reduce_scatter_acc_hook(grad):
            # Reduce-scatter each micro-batch and accumulate the shard; the
            # cross-replica all-reduce is deferred until sync is requested.
            output, _ = self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info)
            hsdp_param.acc_grad.add_(output)
            if self.requires_grad_sync:
                output, _ = self.platform.all_reduce(hsdp_param.acc_grad, hsdp_param.unsharded_group_info)
                return output
            return hsdp_param.acc_grad

        if not self.requires_acc_grad:
            grad_hook = grad_reduce_scatter_hook
        elif self.shard_level == OptimizerLevel.SHARD_OPT:
            grad_hook = grad_acc_reduce_scatter_hook
        else:
            grad_hook = grad_reduce_scatter_acc_hook
        return self._get_final_grad_hook(hsdp_param.param, grad_hook)

    def get_hook(self, hsdp_param):
        """Get the gradient-processing hook matching the param's sharding mode.

        Dispatch: unsharded + dp_size==1 -> single-node hook; unsharded ->
        all-reduce hook; fully sharded -> reduce-scatter hook; otherwise ->
        partial-shard hook (reduce-scatter + all-reduce).
        """
        if not hsdp_param.sharded:
            if hsdp_param.dp_size == 1:
                return self._get_hsdp_param_single_node_hook(hsdp_param)
            return self._get_hsdp_param_unsharded_hook(hsdp_param)

        if hsdp_param.fully_sharded:
            return self._get_hsdp_param_fully_sharded_hook(hsdp_param)

        return self._get_hsdp_param_partial_sharded_hook(hsdp_param)

    def set_requires_grad_sync(self, requires_grad_sync):
        """Set the grad-sync flag controlling whether accumulation hooks communicate."""
        self.requires_grad_sync = requires_grad_sync