Coverage for hyper_parallel / core / fully_shard / hsdp_grad_hook.py: 10%
113 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""HSDP gradient hook"""
16from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel
class HSDPGradHook:
    """Factory for per-parameter gradient hooks under HSDP.

    Depending on how a parameter is sharded (unsharded single-node,
    unsharded multi-node, fully sharded, or partially sharded), ``get_hook``
    returns a callable that performs the appropriate gradient communication
    (all-reduce / reduce-scatter via ``platform``), optional local gradient
    accumulation into ``hsdp_param.acc_grad``, an optional dtype cast to
    ``reduce_dtype`` around the collective, and a final multiplication by
    ``grad_scale``.
    """

    def __init__(self, config, platform):
        """Initialize the hook factory.

        Args:
            config: configuration object providing ``reduce_dtype``,
                ``grad_scale``, ``shard_level``, ``requires_acc_grad``
                and ``use_eager_hook``.
            platform: communication backend exposing ``all_reduce`` and
                ``reduce_scatter_tensor`` (both return ``(tensor, handle)``).
        """
        # dtype used for the communication; None means "no cast".
        self.reduce_dtype = config.reduce_dtype
        # Multiplier applied once to the final gradient (skipped when 1.0).
        self.grad_scale = config.grad_scale
        # OptimizerLevel controlling accumulate-then-reduce vs
        # reduce-then-accumulate ordering for sharded params.
        self.shard_level = config.shard_level
        self.requires_acc_grad = config.requires_acc_grad
        self.use_eager_hook = config.use_eager_hook
        # Toggled via set_requires_grad_sync(); while False, accumulating
        # hooks only add into acc_grad locally and skip the collective
        # (e.g. during gradient-accumulation micro-steps).
        self.requires_grad_sync = False
        self.platform = platform

    def _cast_hook(self, hook, grad):
        """Run ``hook`` with ``grad`` cast to ``reduce_dtype``.

        The result is cast back to the gradient's original dtype. When
        ``reduce_dtype`` is None the hook runs on ``grad`` unchanged.
        """
        if self.reduce_dtype is None:
            return hook(grad)
        origin_dtype = grad.dtype
        grad_cast = grad.to(self.reduce_dtype)
        output = hook(grad_cast)
        output = output.to(origin_dtype)
        return output

    def _get_final_grad_hook(self, param, grad_hook, no_cast=False):
        """Wrap ``grad_hook`` with the dtype cast and gradient scaling.

        Args:
            param: the parameter the hook belongs to. Currently unused;
                kept so callers have a uniform signature.
            grad_hook: the communication/accumulation hook to wrap.
            no_cast: when True, skip the ``reduce_dtype`` round-trip and
                only apply ``grad_scale``.

        Returns:
            A callable ``grad -> grad`` applying ``grad_hook`` and then,
            when ``grad_scale != 1.0``, exactly one multiplication by it.
        """
        def scale_with_cast_hook(grad):
            output = self._cast_hook(grad_hook, grad)
            if self.grad_scale != 1.0:
                return output * self.grad_scale
            return output

        def scale_hook(grad):
            output = grad_hook(grad)
            if self.grad_scale != 1.0:
                return output * self.grad_scale
            return output

        if no_cast:
            return scale_hook
        return scale_with_cast_hook

    def _get_hsdp_param_single_node_hook(self, hsdp_param):
        """Get hook for an unsharded param with a single node (dp_size == 1)."""
        def grad_identity_hook(grad):
            # No communication is needed on a single node. grad_scale is
            # applied once by the wrapper from _get_final_grad_hook.
            # BUGFIX: this hook previously multiplied by grad_scale itself,
            # which combined with the wrapper's scaling to apply
            # grad_scale twice (grad * grad_scale**2) whenever
            # grad_scale != 1.0.
            return grad

        def grad_hook(grad):
            # Accumulate locally; no collective on a single node.
            hsdp_param.acc_grad.add_(grad)
            return hsdp_param.acc_grad

        if not self.requires_acc_grad:
            return self._get_final_grad_hook(hsdp_param.param, grad_identity_hook, no_cast=True)
        return self._get_final_grad_hook(hsdp_param.param, grad_hook)

    def _get_hsdp_param_unsharded_hook(self, hsdp_param):
        """Get hook for an unsharded (replicated) param."""
        def grad_all_reduce_hook(grad):
            output, _ = self.platform.all_reduce(grad, hsdp_param.unsharded_group_info)
            return output

        def grad_acc_all_reduce_hook(grad):
            # Accumulate first; only communicate when grad sync is enabled.
            hsdp_param.acc_grad.add_(grad)
            if self.requires_grad_sync:
                output, _ = self.platform.all_reduce(hsdp_param.acc_grad, hsdp_param.unsharded_group_info)
                return output
            return hsdp_param.acc_grad

        if not self.requires_acc_grad:
            grad_hook = grad_all_reduce_hook
        else:
            grad_hook = grad_acc_all_reduce_hook
        return self._get_final_grad_hook(hsdp_param.param, grad_hook)

    def _get_hsdp_param_fully_sharded_hook(self, hsdp_param):
        """Get hook for a fully sharded param."""
        def grad_reduce_scatter_hook(grad):
            output, _ = self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info)
            return output

        def grad_acc_reduce_scatter_hook(grad):
            # SHARD_OPT: accumulate the full gradient, reduce-scatter only
            # when sync is requested.
            hsdp_param.acc_grad.add_(grad)
            if self.requires_grad_sync:
                output, _ = self.platform.reduce_scatter_tensor(hsdp_param.acc_grad, hsdp_param.sharded_group_info)
                return output
            return hsdp_param.acc_grad

        def grad_reduce_scatter_acc_hook(grad):
            # Deeper shard levels: reduce-scatter every step, accumulate
            # the (smaller) sharded gradient.
            output, _ = self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info)
            hsdp_param.acc_grad.add_(output)
            return hsdp_param.acc_grad

        if not self.requires_acc_grad:
            grad_hook = grad_reduce_scatter_hook
        elif self.shard_level == OptimizerLevel.SHARD_OPT:
            grad_hook = grad_acc_reduce_scatter_hook
        else:
            grad_hook = grad_reduce_scatter_acc_hook
        return self._get_final_grad_hook(hsdp_param.param, grad_hook)

    def _get_hsdp_param_partial_sharded_hook(self, hsdp_param):
        """Get hook for a partially sharded param.

        Partial sharding combines a reduce-scatter over the sharded group
        with an all-reduce over the unsharded (replica) group.
        """
        def grad_reduce_scatter_hook(grad):
            output, _ = self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info)
            sliced_grad, _ = self.platform.all_reduce(output, hsdp_param.unsharded_group_info)
            return sliced_grad

        def grad_acc_reduce_scatter_hook(grad):
            # SHARD_OPT: accumulate the full gradient, communicate only on sync.
            hsdp_param.acc_grad.add_(grad)
            if self.requires_grad_sync:
                output, _ = self.platform.reduce_scatter_tensor(hsdp_param.acc_grad, hsdp_param.sharded_group_info)
                sliced_grad, _ = self.platform.all_reduce(output, hsdp_param.unsharded_group_info)
                return sliced_grad
            return hsdp_param.acc_grad

        def grad_reduce_scatter_acc_hook(grad):
            # Reduce-scatter every step into the sharded accumulator;
            # all-reduce across replicas only on sync.
            output, _ = self.platform.reduce_scatter_tensor(grad, hsdp_param.sharded_group_info)
            hsdp_param.acc_grad.add_(output)
            if self.requires_grad_sync:
                output, _ = self.platform.all_reduce(hsdp_param.acc_grad, hsdp_param.unsharded_group_info)
                return output
            return hsdp_param.acc_grad

        if not self.requires_acc_grad:
            grad_hook = grad_reduce_scatter_hook
        elif self.shard_level == OptimizerLevel.SHARD_OPT:
            grad_hook = grad_acc_reduce_scatter_hook
        else:
            grad_hook = grad_reduce_scatter_acc_hook
        return self._get_final_grad_hook(hsdp_param.param, grad_hook)

    def get_hook(self, hsdp_param):
        """Return the gradient-processing hook matching the param's sharding.

        Dispatch order: unsharded single-node, unsharded multi-node,
        fully sharded, then partially sharded.
        """
        if not hsdp_param.sharded:
            if hsdp_param.dp_size == 1:
                return self._get_hsdp_param_single_node_hook(hsdp_param)
            return self._get_hsdp_param_unsharded_hook(hsdp_param)
        if hsdp_param.fully_sharded:
            return self._get_hsdp_param_fully_sharded_hook(hsdp_param)
        return self._get_hsdp_param_partial_sharded_hook(hsdp_param)

    def set_requires_grad_sync(self, requires_grad_sync):
        """Set the flag that enables/disables gradient sync in acc hooks."""
        self.requires_grad_sync = requires_grad_sync