Coverage for hyper_parallel / core / hsdp / hsdp_async_grad_hook.py: 76%

88 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""HSDP async gradient hook"""
from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel
from hyper_parallel.core.hsdp.hsdp_grad_hook import HSDPGradHook

18 

19 

class HSDPAsyncGradHook(HSDPGradHook):
    """HSDP gradient hook that overlaps gradient communication with computation.

    Each `get_hook` variant returns a callable suitable for registration as a
    parameter-gradient hook. The hook launches an asynchronous collective
    (all-reduce / reduce-scatter) and defers dtype restoration and gradient
    scaling until the communication handle completes.
    """

    def _pre_process(self, grad):
        """Cast the gradient to the configured reduce dtype before the comm op.

        Returns a ``(saved_dtype, grad)`` pair; ``saved_dtype`` is ``None``
        when no cast was applied (i.e. ``self.reduce_dtype`` is unset).
        """
        if self.reduce_dtype is None:
            return None, grad
        saved_dtype = grad.dtype
        return saved_dtype, grad.to(self.reduce_dtype)

    def _post_process(self, grad, origin_dtype):
        """Restore the original dtype and apply gradient scaling after the comm op."""
        if origin_dtype is not None:
            grad = grad.to(origin_dtype)
        # Skip the multiply entirely when no scaling is configured.
        return grad if self.grad_scale == 1.0 else grad * self.grad_scale

    # pylint: disable=W0613
    def _get_final_async_grad_hook(self, param, async_hook, post_hook=None):
        """Wrap ``async_hook`` so its result is finalized when communication ends.

        ``async_hook(grad)`` must return an ``(output, handle)`` pair; a ``None``
        handle means the communication already completed synchronously (or was
        skipped), so finalization runs immediately.
        """
        def async_hook_handler(grad):
            saved_dtype, casted_grad = self._pre_process(grad)
            output, handle = async_hook(casted_grad)

            def finalize():
                # Runs once the async communication has finished; writes the
                # post-processed result back into the comm op's output tensor.
                result = output if post_hook is None else post_hook(output)
                output.data = self._post_process(result, saved_dtype)

            if handle is None:
                finalize()
            else:
                # Defer finalization until the platform signals completion.
                self.platform.set_grad_reduce_handle(handle, finalize)
            return output

        return async_hook_handler

    def _get_async_param_unsharded_hook(self, hsdp_param):
        """get hook for unsharded param."""
        def all_reduce_hook(grad):
            return self.platform.all_reduce(grad, hsdp_param.unsharded_group_info, async_op=True)

        def acc_then_all_reduce_hook(grad):
            hsdp_param.acc_grad.add_(grad)
            if not self.requires_grad_sync:
                # Accumulation-only step: no communication, nothing to wait on.
                return grad, None
            return self.platform.all_reduce(hsdp_param.acc_grad, hsdp_param.unsharded_group_info, async_op=True)

        chosen = acc_then_all_reduce_hook if self.requires_acc_grad else all_reduce_hook
        return self._get_final_async_grad_hook(hsdp_param.param, chosen)

    def _get_async_param_fully_sharded_hook(self, hsdp_param):
        """get hook for fully sharded param."""
        def reduce_scatter_hook(grad):
            return self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info, async_op=True)

        def acc_then_reduce_scatter_hook(grad):
            hsdp_param.acc_grad.add_(grad)
            if not self.requires_grad_sync:
                # Accumulation-only step: no communication, nothing to wait on.
                return grad, None
            return self.platform.reduce_scatter_tensor(hsdp_param.acc_grad, hsdp_param.sharded_group_info,
                                                       async_op=True)

        def acc_after_reduce_scatter_post_hook(output):
            # Accumulate the (sharded) reduce-scatter output into acc_grad.
            hsdp_param.acc_grad.add_(output)
            return hsdp_param.acc_grad

        if not self.requires_acc_grad:
            return self._get_final_async_grad_hook(hsdp_param.param, reduce_scatter_hook)

        if self.shard_level == OptimizerLevel.SHARD_OPT:
            # Accumulate in the unsharded gradient before communicating.
            return self._get_final_async_grad_hook(hsdp_param.param, acc_then_reduce_scatter_hook)
        # Otherwise accumulate in the sharded gradient after communicating.
        return self._get_final_async_grad_hook(hsdp_param.param, reduce_scatter_hook,
                                               acc_after_reduce_scatter_post_hook)

    def _get_async_param_partial_sharded_hook(self, hsdp_param):
        """get hook for partial sharded param."""
        def reduce_scatter_then_all_reduce_hook(grad):
            # Sync reduce-scatter inside the shard group, then async all-reduce
            # across the unsharded group.
            scattered, _ = self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info, async_op=False)
            return self.platform.all_reduce(scattered, hsdp_param.unsharded_group_info, async_op=True)

        def acc_then_reduce_scatter_all_reduce_hook(grad):
            hsdp_param.acc_grad.add_(grad)
            if not self.requires_grad_sync:
                # Accumulation-only step: no communication, nothing to wait on.
                return grad, None
            scattered, _ = self.platform.reduce_scatter_tensor(hsdp_param.acc_grad, hsdp_param.sharded_group_info,
                                                               async_op=False)
            return self.platform.all_reduce(scattered, hsdp_param.unsharded_group_info, async_op=True)

        def reduce_scatter_acc_then_all_reduce_hook(grad):
            scattered, _ = self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info, async_op=False)
            hsdp_param.acc_grad.add_(scattered)
            if not self.requires_grad_sync:
                # Accumulation-only step: return the sharded output, no handle.
                return scattered, None
            return self.platform.all_reduce(hsdp_param.acc_grad, hsdp_param.unsharded_group_info, async_op=True)

        if not self.requires_acc_grad:
            chosen = reduce_scatter_then_all_reduce_hook
        elif self.shard_level == OptimizerLevel.SHARD_OPT:
            chosen = acc_then_reduce_scatter_all_reduce_hook
        else:
            chosen = reduce_scatter_acc_then_all_reduce_hook
        return self._get_final_async_grad_hook(hsdp_param.param, chosen)

    def get_hook(self, hsdp_param):
        """get hook for param gradient process."""
        if hsdp_param.sharded:
            if hsdp_param.fully_sharded:
                return self._get_async_param_fully_sharded_hook(hsdp_param)
            return self._get_async_param_partial_sharded_hook(hsdp_param)

        # Unsharded: with a single data-parallel rank no collective is needed.
        if hsdp_param.dp_size == 1:
            return self._get_hsdp_param_single_node_hook(hsdp_param)
        return self._get_async_param_unsharded_hook(hsdp_param)