Coverage for hyper_parallel / core / hsdp / hsdp_async_grad_hook.py: 76%
88 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""HSDP async gradient hook"""
16from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel
17from hyper_parallel.core.hsdp.hsdp_grad_hook import HSDPGradHook
class HSDPAsyncGradHook(HSDPGradHook):
    """HSDP gradient hook with async communication op.

    Builds per-parameter gradient hooks that launch the collective
    (all-reduce / reduce-scatter) asynchronously and defer dtype
    restoration and gradient scaling until the communication handle
    completes.
    """

    def _pre_process(self, grad):
        """Process before calling the async comm op.

        Optionally casts the gradient to ``self.reduce_dtype`` for the
        collective. Returns ``(origin_dtype, grad)`` where ``origin_dtype``
        is ``None`` when no cast was performed.
        """
        if self.reduce_dtype is None:
            return None, grad
        return grad.dtype, grad.to(self.reduce_dtype)

    def _post_process(self, grad, origin_dtype):
        """Process after the async comm op has finished.

        Restores the original dtype (when a cast happened in
        ``_pre_process``) and applies ``self.grad_scale`` if it is not 1.0.
        """
        restored = grad if origin_dtype is None else grad.to(origin_dtype)
        if self.grad_scale == 1.0:
            return restored
        return restored * self.grad_scale

    # pylint: disable=W0613
    def _get_final_async_grad_hook(self, param, async_hook, post_hook=None):
        """Wrap an async comm closure into the final gradient hook.

        ``async_hook`` must return ``(output, handle)``; ``handle`` is
        ``None`` when no async communication was issued. The deferred
        ``post_process`` (optional ``post_hook`` plus ``_post_process``)
        is either run immediately or registered on the handle.
        """
        def async_hook_handler(grad):
            origin_dtype, pre_grad = self._pre_process(grad)
            output, handle = async_hook(pre_grad)

            def post_process():
                # Apply the optional extra step, then dtype/scale fixup,
                # writing the result back into the comm output tensor.
                post_output = output if post_hook is None else post_hook(output)
                output.data = self._post_process(post_output, origin_dtype)

            if handle is not None:
                # Defer post-processing until the collective completes.
                self.platform.set_grad_reduce_handle(handle, post_process)
            else:
                post_process()
            return output

        return async_hook_handler

    def _get_async_param_unsharded_hook(self, hsdp_param):
        """get hook for unsharded param."""
        def all_reduce_now(grad):
            return self.platform.all_reduce(grad, hsdp_param.unsharded_group_info, async_op=True)

        def accumulate_then_all_reduce(grad):
            hsdp_param.acc_grad.add_(grad)
            if not self.requires_grad_sync:
                # Accumulation-only step: no collective, no handle.
                return grad, None
            return self.platform.all_reduce(hsdp_param.acc_grad, hsdp_param.unsharded_group_info, async_op=True)

        chosen = accumulate_then_all_reduce if self.requires_acc_grad else all_reduce_now
        return self._get_final_async_grad_hook(hsdp_param.param, chosen)

    def _get_async_param_fully_sharded_hook(self, hsdp_param):
        """get hook for fully sharded param."""
        def reduce_scatter_now(grad):
            return self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info, async_op=True)

        def accumulate_then_reduce_scatter(grad):
            hsdp_param.acc_grad.add_(grad)
            if not self.requires_grad_sync:
                # Accumulation-only step: no collective, no handle.
                return grad, None
            return self.platform.reduce_scatter_tensor(hsdp_param.acc_grad, hsdp_param.sharded_group_info,
                                                       async_op=True)

        def accumulate_after_reduce_scatter(output):
            # Fold the freshly reduced shard into the accumulator.
            hsdp_param.acc_grad.add_(output)
            return hsdp_param.acc_grad

        if not self.requires_acc_grad:
            return self._get_final_async_grad_hook(hsdp_param.param, reduce_scatter_now)
        if self.shard_level == OptimizerLevel.SHARD_OPT:
            return self._get_final_async_grad_hook(hsdp_param.param, accumulate_then_reduce_scatter)
        return self._get_final_async_grad_hook(hsdp_param.param, reduce_scatter_now,
                                               accumulate_after_reduce_scatter)

    def _get_async_param_partial_sharded_hook(self, hsdp_param):
        """get hook for partial sharded param."""
        def scatter_then_all_reduce(grad):
            # Sync reduce-scatter inside the sharded group, then async
            # all-reduce across the unsharded group.
            scattered, _ = self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info, async_op=False)
            return self.platform.all_reduce(scattered, hsdp_param.unsharded_group_info, async_op=True)

        def accumulate_then_scatter_all_reduce(grad):
            hsdp_param.acc_grad.add_(grad)
            if not self.requires_grad_sync:
                # Accumulation-only step: no collective, no handle.
                return grad, None
            scattered, _ = self.platform.reduce_scatter_tensor(hsdp_param.acc_grad, hsdp_param.sharded_group_info,
                                                               async_op=False)
            return self.platform.all_reduce(scattered, hsdp_param.unsharded_group_info, async_op=True)

        def scatter_accumulate_all_reduce(grad):
            scattered, _ = self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info, async_op=False)
            hsdp_param.acc_grad.add_(scattered)
            if not self.requires_grad_sync:
                return scattered, None
            return self.platform.all_reduce(hsdp_param.acc_grad, hsdp_param.unsharded_group_info, async_op=True)

        if not self.requires_acc_grad:
            chosen = scatter_then_all_reduce
        elif self.shard_level == OptimizerLevel.SHARD_OPT:
            chosen = accumulate_then_scatter_all_reduce
        else:
            chosen = scatter_accumulate_all_reduce
        return self._get_final_async_grad_hook(hsdp_param.param, chosen)

    def get_hook(self, hsdp_param):
        """get hook for param gradient process."""
        if hsdp_param.sharded:
            if hsdp_param.fully_sharded:
                return self._get_async_param_fully_sharded_hook(hsdp_param)
            return self._get_async_param_partial_sharded_hook(hsdp_param)
        if hsdp_param.dp_size == 1:
            return self._get_hsdp_param_single_node_hook(hsdp_param)
        return self._get_async_param_unsharded_hook(hsdp_param)