Coverage for hyper_parallel / core / hsdp / hsdp_grad_buffer.py: 54%
164 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP gradient buffer"""
16from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel
class HSDPGradBuffer:
    """
    HSDP gradient buffer.

    Fuses the gradients of a group of parameters into a single contiguous
    buffer so that gradient accumulation and cross-device reduction
    (all-reduce / reduce-scatter) are issued as a few fused communication
    calls instead of one per parameter.

    Lifecycle (as used by this class itself):
      1. ``add_param`` is called once per parameter to register it.
      2. ``init`` lays out per-parameter index ranges and (optionally) the
         accumulation buffer.
      3. Each backward step, ``set_grad_ready`` is called once per ready
         gradient; when all registered grads are ready, the appropriate
         fused reduce path runs.

    NOTE(review): ``config``, ``init_hsdp_param`` and ``platform`` are
    project types; the attribute meanings documented below are inferred
    only from how this class reads them — confirm against their
    definitions.
    """
    def __init__(self, config, init_hsdp_param, platform):
        self.config = config
        self.platform = platform
        self.shard_size = init_hsdp_param.shard_size
        self.dp_size = init_hsdp_param.dp_size
        # Position of this rank inside its shard group.
        self.local_rank = init_hsdp_param.hsdp_rank % init_hsdp_param.shard_size
        self.dtype = init_hsdp_param.param.dtype
        # Process-group descriptors for the sharded (intra-shard) and
        # unsharded (replica) communication dimensions.
        self.sharded_group_info = init_hsdp_param.sharded_group_info
        self.unsharded_group_info = init_hsdp_param.unsharded_group_info
        self.sharded = init_hsdp_param.sharded
        self.fully_sharded = init_hsdp_param.fully_sharded
        self.device = init_hsdp_param.param.device
        self.hsdp_params = []            # params whose grads live in this buffer
        self.numel = 0                   # elements per shard row; set in init()
        self.requires_grad_sync = False  # gate for reduce when accumulating
        self.sharded_grad_buffer = None
        self.unsharded_grad_buffer = None
        self.prefetch_handle = None
        self.prefetch_data = None
        # Communication dtype: defaults to the param dtype, optionally
        # overridden by config.reduce_dtype (e.g. reduce in fp32).
        self.reduce_type = self.dtype
        if self.config.reduce_dtype is not None:
            self.reduce_type = self.config.reduce_dtype

    def add_param(self, hsdp_param):
        """add param to buffer"""
        self.hsdp_params.append(hsdp_param)

    def init(self):
        """init buffer

        Must be called after all add_param() calls and before
        set_grad_ready(): it computes per-param index ranges, allocates the
        accumulation buffer (if enabled), and defines the grad-ready
        counters used by set_grad_ready().
        """
        self._init_grad_buffer_index()
        self._init_acc_grad_buffer()
        self.num_grad = len(self.hsdp_params)
        self.num_ready_grad = 0

    def set_requires_grad_sync(self, requires_grad_sync):
        """set requires grad sync flag to control gradient sync."""
        self.requires_grad_sync = requires_grad_sync

    def _set_handle(self, handle, process=None):
        """set handle for async comm and wait handle for sync comm

        In async mode the (handle, optional post-processing callback) pair
        is handed to the platform to be completed later; in sync mode the
        handle is waited on immediately and the callback (if any) runs
        inline.
        """
        if self.config.comm_async:
            if process is not None:
                self.platform.set_grad_reduce_handle(handle, process)
            else:
                self.platform.set_grad_reduce_handle(handle)
        else:
            handle.wait()
            if process is not None:
                process()

    def _init_grad_buffer_index(self):
        """init grad buffer index

        Assigns each registered param a contiguous [start, end) element
        range inside the fused buffer, in registration order, and records
        the total element count in self.numel.
        """
        self.numel = 0
        for hsdp_param in self.hsdp_params:
            start_index = self.numel
            end_index = start_index + hsdp_param.param.numel()
            hsdp_param.grad_buffer_start_index = start_index
            hsdp_param.grad_buffer_end_index = end_index
            self.numel = end_index

    def _init_acc_grad_buffer(self):
        """init acc grad buffer

        Allocates a zeroed buffer used to accumulate gradients across
        micro-batches. Shape is (shard_size, numel) only at the SHARD_OPT
        level; other levels keep a single row (1, numel).
        No-op when gradient accumulation is disabled.
        """
        if not self.config.requires_acc_grad:
            return

        acc_grad_shape = (self.shard_size, self.numel)
        self.acc_grad_buffer_dim0 = self.shard_size
        if self.config.shard_level != OptimizerLevel.SHARD_OPT:
            acc_grad_shape = (1, self.numel)
            self.acc_grad_buffer_dim0 = 1
        # NOTE(review): third argument presumably requires_grad=False —
        # confirm against the platform's new_zero_parameter signature.
        self.acc_grad_buffer = self.platform.new_zero_parameter(acc_grad_shape, self.reduce_type, False, self.device)

    def _copy_grad_to_unsharded_buffer(self):
        """copy grad to buffer

        Gathers every param's .grad into a freshly allocated fused buffer
        of shape (shard_size, numel), casting to the reduce dtype when it
        differs from the param dtype.

        Raises:
            ValueError: if any registered param has no gradient.
        """
        self.unsharded_grad_buffer = self.platform.new_tensor((self.shard_size, self.numel), self.reduce_type,
                                                              self.device)
        for hsdp_param in self.hsdp_params:
            if hsdp_param.grad is None:
                raise ValueError("HSDP param grad can't be None with comm fusion.")
            start_index = hsdp_param.grad_buffer_start_index
            end_index = hsdp_param.grad_buffer_end_index
            buffer_view = self.unsharded_grad_buffer[:, start_index:end_index]
            if self.dtype != self.reduce_type:
                buffer_view[:] = hsdp_param.grad.view(self.shard_size, -1).to(self.reduce_type)[:]
            else:
                buffer_view[:] = hsdp_param.grad.view(self.shard_size, -1)[:]

    def _add_grad_to_acc_buffer(self):
        """add grad to acc buffer

        Accumulates the current grads into acc_grad_buffer in place, then
        drops the temporary fused buffer.
        """
        self._copy_grad_to_unsharded_buffer()
        self.acc_grad_buffer.add_(self.unsharded_grad_buffer)
        self.unsharded_grad_buffer = None

    def _set_grad_data(self, buffer=None):
        """set grad with buffer view

        Writes each param's .grad.data back from `buffer` (default: the
        reduced sharded_grad_buffer), casting back to the param dtype and
        applying config.grad_scale when set.

        Raises:
            ValueError: if any registered param has no gradient.
        """
        if buffer is None:
            buffer = self.sharded_grad_buffer

        if self.dtype != self.reduce_type:
            buffer = buffer.to(self.dtype)

        if self.config.grad_scale != 1.0:
            buffer = buffer * self.config.grad_scale

        for hsdp_param in self.hsdp_params:
            if hsdp_param.grad is None:
                raise ValueError("HSDP param grad can't be None with comm fusion.")
            start_index = hsdp_param.grad_buffer_start_index
            end_index = hsdp_param.grad_buffer_end_index
            # Per-param slice reshaped to the param's shape with dim 0
            # divided by shard_size (the sharded dimension).
            view_shape = list(hsdp_param.param_shape)
            view_shape[0] = view_shape[0] // self.shard_size
            buffer_view = buffer[:, start_index:end_index].view(view_shape)
            # `+ 0` produces a new tensor so the grad does not alias the
            # fused buffer — TODO(review): confirm this copy is intended.
            hsdp_param.grad.data = buffer_view + 0

    def _handle_single_node_grad(self):
        """single node don't need to reduce grad, only need to acc grad"""
        if not self.config.requires_acc_grad:
            return
        self._add_grad_to_acc_buffer()
        self._set_grad_data(self.acc_grad_buffer)

    def _reduce_no_shard_grad(self):
        """reduce grad when param is not sharded

        Single all-reduce over the unsharded (replica) group. With
        accumulation enabled, grads are first folded into the acc buffer
        and the all-reduce only fires on sync steps.
        """
        if not self.config.requires_acc_grad:
            self._copy_grad_to_unsharded_buffer()
            output, handle = self.platform.all_reduce(self.unsharded_grad_buffer, self.unsharded_group_info,
                                                      async_op=True)
        else:
            self._add_grad_to_acc_buffer()
            if not self.requires_grad_sync:
                return
            output, handle = self.platform.all_reduce(self.acc_grad_buffer, self.unsharded_group_info, async_op=True)
        self.sharded_grad_buffer = output
        self._set_handle(handle, self._set_grad_data)

    def _reduce_fully_shard_grad(self):
        """reducescatter grad when param is fully sharded

        Three cases:
          - no accumulation: reduce-scatter the fresh grads;
          - accumulation, shard level != SHARD_OPT: reduce-scatter first,
            then fold the scattered result into the acc buffer in the
            (possibly deferred) post_process callback;
          - accumulation at SHARD_OPT: accumulate locally and
            reduce-scatter the acc buffer only on sync steps.
        """
        if not self.config.requires_acc_grad:
            self._copy_grad_to_unsharded_buffer()
            output, handle = self.platform.reduce_scatter_tensor(self.unsharded_grad_buffer, self.sharded_group_info,
                                                                 async_op=True)
            self.sharded_grad_buffer = output
            self._set_handle(handle, self._set_grad_data)
        elif self.config.shard_level != OptimizerLevel.SHARD_OPT:
            self._copy_grad_to_unsharded_buffer()
            output, handle = self.platform.reduce_scatter_tensor(self.unsharded_grad_buffer, self.sharded_group_info,
                                                                 async_op=True)
            self.sharded_grad_buffer = output
            # Deferred until the comm completes (async) or run inline
            # (sync): accumulate then rebind grads from the acc buffer.
            def post_process():
                self.acc_grad_buffer.add_(output)
                self._set_grad_data(self.acc_grad_buffer)
            self._set_handle(handle, post_process)
        else:
            self._add_grad_to_acc_buffer()
            if not self.requires_grad_sync:
                return
            output, handle = self.platform.reduce_scatter_tensor(self.acc_grad_buffer, self.sharded_group_info,
                                                                 async_op=True)
            self.sharded_grad_buffer = output
            self._set_handle(handle, self._set_grad_data)

    def _reduce_partial_shard_grad(self):
        """reduce grad after reducescatter grad when param is partial sharded

        Two-stage reduction: a synchronous reduce-scatter inside the shard
        group, then an asynchronous all-reduce across the unsharded
        (replica) group. Accumulation branches mirror
        _reduce_fully_shard_grad.
        """
        if not self.config.requires_acc_grad:
            self._copy_grad_to_unsharded_buffer()
            output, _ = self.platform.reduce_scatter_tensor(self.unsharded_grad_buffer, self.sharded_group_info,
                                                            async_op=False)
            output, handle = self.platform.all_reduce(output, self.unsharded_group_info, async_op=True)
            self.sharded_grad_buffer = output
            self._set_handle(handle, self._set_grad_data)
        elif self.config.shard_level != OptimizerLevel.SHARD_OPT:
            self._copy_grad_to_unsharded_buffer()
            output, _ = self.platform.reduce_scatter_tensor(self.unsharded_grad_buffer, self.sharded_group_info,
                                                            async_op=False)
            # Accumulate the scattered result; the cross-replica
            # all-reduce only fires on sync steps.
            self.acc_grad_buffer.add_(output)
            if not self.requires_grad_sync:
                return
            output, handle = self.platform.all_reduce(self.acc_grad_buffer, self.unsharded_group_info, async_op=True)
            self.sharded_grad_buffer = output
            self._set_handle(handle, self._set_grad_data)
        else:
            self._add_grad_to_acc_buffer()
            if not self.requires_grad_sync:
                return
            output, _ = self.platform.reduce_scatter_tensor(self.acc_grad_buffer, self.sharded_group_info,
                                                            async_op=False)
            output, handle = self.platform.all_reduce(output, self.unsharded_group_info, async_op=True)
            self.sharded_grad_buffer = output
            self._set_handle(handle, self._set_grad_data)

    def _reduce_grads(self):
        """reduce or accumulate grad buffer

        Dispatches to the reduce path matching the sharding topology:
        unsharded (single node or replicated), fully sharded, or
        partially sharded.
        """
        if not self.sharded:
            if self.dp_size == 1:
                self._handle_single_node_grad()
            else:
                self._reduce_no_shard_grad()
        elif self.fully_sharded:
            self._reduce_fully_shard_grad()
        else:
            self._reduce_partial_shard_grad()

    def zero_grads(self):
        """zero grad buffer

        Clears the accumulation buffer in place; no-op when gradient
        accumulation is disabled.
        """
        if not self.config.requires_acc_grad:
            return
        self.acc_grad_buffer.zero_()

    def set_grad_ready(self):
        """set grad ready

        Counts one ready gradient; once every registered grad is ready the
        fused reduction runs and the counter resets for the next step.
        """
        self.num_ready_grad = self.num_ready_grad + 1
        if self.num_ready_grad == self.num_grad:
            self._reduce_grads()
            self.num_ready_grad = 0