Coverage for hyper_parallel / core / hsdp / hsdp_param.py: 93%
239 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""HSDP parameter"""
16import functools
17from hyper_parallel.core.dtensor import DTensor
18from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel, GroupInfo
class HSDPParam:
    """
    HSDP parameter.

    Wraps a single model parameter for hybrid sharded data parallelism
    (HSDP). On construction it:
      1. derives the rank topology from the parameter's DTensor layout
         (or falls back to the flat world group for plain parameters),
      2. decides the shard size from config / per-parameter overrides,
      3. allocates the sharded / unsharded parameter storage and the
         optional gradient-accumulation buffer,
      4. creates the two communication groups used to move the parameter
         between its sharded and unsharded states.
    """

    def __init__(self, cell, param_name, param, config, platform):
        """
        Args:
            cell: the cell (module) that owns the parameter.
            param_name: fully-qualified name of the parameter.
            param: the parameter object; may be a DTensor carrying a layout.
            config: HSDP configuration object (reads shard_size, threshold,
                shard_level, reduce_dtype, requires_acc_grad).
            platform: backend abstraction providing rank, group and tensor
                operations.
        """
        self.cell = cell
        self.param_name = param_name
        self.param = param
        self.config = config
        self.platform = platform
        self.shard_size = 1
        self.unsharded_param = None
        self.sharded_param = None
        self.acc_grad = None
        self.grad = None
        self.sharded = False
        self.fully_sharded = True
        # In-flight async all-gather state set by prefetch_unsharded().
        self.prefetch_handle = None
        self.prefetch_data = None
        # Offsets into shared flat buffers; filled in by external buffer
        # management code (not in this class).
        self.param_buffer_start_index = 0
        self.param_buffer_end_index = 0
        self.grad_buffer_start_index = 0
        self.grad_buffer_end_index = 0
        self._init_rank_info()
        self._init_param_shard_size()
        self._init_param()
        # dp_size is the replication degree across the sharded groups.
        self.dp_size = self.rank_size // self.shard_size
        group_name, group = self._create_sharded_dp_group()
        self.sharded_group_info = GroupInfo(group_name, group, self.shard_size)
        group_name, group = self._create_unsharded_dp_group()
        self.unsharded_group_info = GroupInfo(group_name, group, self.dp_size)

    def _init_rank_info(self):
        """init parameter rank info

        Computes, for this process, its position in the device mesh of the
        parameter's layout: rank_size (number of data-parallel ranks this
        parameter is replicated over), hsdp_rank (index within the
        data-parallel sub-mesh), and tp_rank (index within the
        tensor-parallel sub-mesh). Plain (non-DTensor) parameters use the
        whole world as the data-parallel group.
        """
        self.rank_id = self.platform.get_rank()
        self.hsdp_rank = self.rank_id
        self.local_rank = self.rank_id
        self.tp_rank = 0
        if not isinstance(self.param, DTensor) or self.param.layout is None:
            # No layout: the parameter is replicated across the whole world.
            self.rank_size = self.platform.get_world_size()
            return

        if len(self.param.layout.rank_list) == 1:
            # Single-rank layout: nothing to shard across.
            self.rank_size = 1
            return

        try:
            self.local_rank = self.param.layout.rank_list.index(self.rank_id)
        except ValueError as e:
            raise ValueError(f"HSDP invalid rank {self.rank_id} with rank list {self.param.layout.rank_list}.") from e

        # Mesh axes referenced by tensor_map are tensor-parallel (sharded)
        # axes; entries may be ints or tuples of ints, -1 means unsharded.
        tensor_map = self.param.layout.tensor_map
        sharded_axis_set = set()
        for axis in tensor_map:
            if isinstance(axis, int) and axis != -1:
                sharded_axis_set.add(axis)
                continue
            if isinstance(axis, tuple):
                for item in axis:
                    sharded_axis_set.add(item)
        self.sharded_axis_set = sharded_axis_set
        self.rank_size = 1
        self.unsharded_reverse_axis_list = []
        self.global_rank_stride_list = []
        self.hsdp_rank_stride_list = []
        self.tp_rank_stride_list = []
        device_dims = len(self.param.layout.mesh_shape)
        # Build row-major strides over the device mesh, walking mesh axes
        # from the last (fastest-varying) to the first. `axis` counts from
        # the right to match tensor_map's axis numbering; r_axis is the
        # corresponding left-based index into mesh_shape.
        stride = 1
        hsdp_stride = 1
        tp_stride = 1
        for axis in range(device_dims):
            r_axis = device_dims - 1 - axis
            self.global_rank_stride_list.append(stride)
            self.hsdp_rank_stride_list.append(hsdp_stride)
            self.tp_rank_stride_list.append(tp_stride)
            stride = stride * self.param.layout.mesh_shape[r_axis]
            if axis in self.sharded_axis_set:
                # Tensor-parallel axis: contributes to tp strides only.
                tp_stride = tp_stride * self.param.layout.mesh_shape[r_axis]
                continue
            # Data-parallel axis: contributes to hsdp strides and rank_size.
            hsdp_stride = hsdp_stride * self.param.layout.mesh_shape[r_axis]
            self.unsharded_reverse_axis_list.append(r_axis)
            self.rank_size = self.rank_size * self.param.layout.mesh_shape[r_axis]
        # Reverse so all lists are indexed by left-based mesh axis.
        self.global_rank_stride_list.reverse()
        self.hsdp_rank_stride_list.reverse()
        self.tp_rank_stride_list.reverse()
        self.unsharded_reverse_axis_list.reverse()

        # Decompose local_rank into per-axis mesh coordinates.
        rank_indices = []
        index = self.local_rank
        for stride in self.global_rank_stride_list:
            rank_indices.append(index // stride)
            index = index % stride
        self.rank_indices = rank_indices
        # hsdp_rank: linearized coordinate over the data-parallel axes only.
        hsdp_rank = 0
        for axis in self.unsharded_reverse_axis_list:
            hsdp_rank = hsdp_rank + rank_indices[axis] * self.hsdp_rank_stride_list[axis]
        self.hsdp_rank = hsdp_rank
        # tp_rank: linearized coordinate over the tensor-parallel axes only.
        tp_rank = 0
        for axis in range(device_dims):
            if axis in self.sharded_axis_set:
                r_axis = device_dims - 1 - axis
                tp_rank = tp_rank + rank_indices[r_axis] * self.tp_rank_stride_list[r_axis]
        self.tp_rank = tp_rank

    def _hsdp_rank_to_global_rank(self, hsdp_rank_list):
        """transform from hsdp rank to global rank

        For each hsdp (data-parallel) rank, rebuilds its coordinates along
        the data-parallel mesh axes, keeps this process's coordinates on
        every other axis, linearizes back to a layout-local rank, and maps
        it through layout.rank_list to a global rank (if a layout exists).
        """
        rank_list = []
        for hsdp_rank in hsdp_rank_list:
            local_index = hsdp_rank
            local_indices_dict = {}
            for axis in self.unsharded_reverse_axis_list:
                stride = self.hsdp_rank_stride_list[axis]
                local_indices_dict[axis] = local_index // stride
                local_index = local_index % stride
            global_rank = 0
            for axis, index in enumerate(self.rank_indices):
                # Data-parallel axes take the decomposed coordinate; other
                # axes keep this rank's own coordinate.
                index = local_indices_dict.get(axis, index)
                global_rank = global_rank + index * self.global_rank_stride_list[axis]
            if self.param.layout is not None:
                if global_rank >= len(self.param.layout.rank_list):
                    # FIX: joined f-strings previously produced "withrank list"
                    # (missing separating space).
                    raise ValueError(f"HSDP invalid index {global_rank} with "
                                     f"rank list len {len(self.param.layout.rank_list)}.")
                global_rank = self.param.layout.rank_list[global_rank]
            rank_list.append(global_rank)
        return rank_list

    def _get_op_rank_list(self):
        """get data parallel rank list

        Returns the global ranks that hold shards of this parameter
        (contiguous run of shard_size hsdp ranks containing this rank).
        """
        if isinstance(self.param, DTensor):
            rank_base = self.hsdp_rank // self.shard_size * self.shard_size
            hsdp_rank_list = [i + rank_base for i in range(self.shard_size)]
            return self._hsdp_rank_to_global_rank(hsdp_rank_list)
        rank_base = self.local_rank // self.shard_size * self.shard_size
        rank_list = [i + rank_base for i in range(self.shard_size)]
        return rank_list

    def _get_dp_rank_list(self):
        """get optimizer parallel rank list

        Returns the global ranks that hold this rank's replica of the
        shard (same shard offset, strided by shard_size across dp_size
        groups).
        """
        if isinstance(self.param, DTensor):
            rank_stride = self.shard_size
            rank_base = self.hsdp_rank % rank_stride
            hsdp_rank_list = [i * rank_stride + rank_base for i in range(self.dp_size)]
            return self._hsdp_rank_to_global_rank(hsdp_rank_list)
        rank_stride = self.shard_size
        rank_base = self.local_rank % rank_stride
        rank_list = [i * rank_stride + rank_base for i in range(self.dp_size)]
        return rank_list

    def _init_sharded_param(self):
        """add and init sharded param"""
        raise NotImplementedError("HSDP param subclasses must implement _init_sharded_param")

    def _init_unsharded_param(self):
        """add and init unsharded param"""
        raise NotImplementedError("HSDP param subclasses must implement _init_unsharded_param")

    def _get_unsharded_param_data(self, async_op):
        """get unsharded param data with async comm

        Returns (tensor, handle) from an all-gather over the sharded
        group; handle is meaningful only when async_op is True.
        """
        local_data = self.platform.get_param_local_data(self.param)
        return self.platform.all_gather_into_tensor(local_data, self.sharded_group_info, async_op=async_op)

    def _init_param_shard_size(self):
        """init parameter dp shard size

        Resolves the effective shard size for this parameter:
        per-parameter override (param.hsdp_shard_size) wins over
        config.shard_size; small parameters (below config.threshold) and
        scalars are not sharded; -1 means "shard as much as possible".
        The final value is clamped to a divisor of both the first local
        dimension and rank_size, and published back on the parameter as
        hsdp_effective_shard_size.
        """
        if hasattr(self.param, "hsdp_shard_size"):
            if not isinstance(self.param.hsdp_shard_size, int) or \
                (self.param.hsdp_shard_size <= 0 and self.param.hsdp_shard_size != -1):
                # FIX: message previously omitted that -1 is also accepted,
                # contradicting the guard above.
                raise ValueError(f"param's hsdp_shard_size must be a positive integer or -1, "
                                 f"but got {self.param.hsdp_shard_size}.")
            self.shard_size = self.param.hsdp_shard_size
        else:
            self.shard_size = self.config.shard_size
        local_shape = self.platform.get_param_local_shape(self.param)
        if len(local_shape) < 1:
            # Scalar parameter: nothing to shard along dim 0.
            self.shard_size = 1
            return
        param_type_size = self.platform.get_param_type_size(self.param)
        # Total byte size = element size * product of local dims.
        param_size = functools.reduce(lambda x, y: x * y, local_shape, param_type_size)
        if param_size < self.config.threshold:
            # Too small to be worth the communication overhead.
            self.shard_size = 1
            return
        if self.shard_size == -1 or local_shape[0] < self.shard_size:
            self.shard_size = local_shape[0]

        def _gcd(m, n):
            # Euclidean algorithm; both inputs must be positive.
            if m < n:
                m, n = n, m
            if n == 0:
                raise ValueError("HSDP invalid gcd input 0.")
            r = m % n
            if r == 0:
                return n
            return _gcd(n, r)

        # shard_size must divide both dim 0 and rank_size; fall back to 1
        # (no sharding) when no compatible size exists.
        rank_gcd = _gcd(local_shape[0], self.rank_size)
        self.shard_size = min(self.shard_size, rank_gcd)
        if rank_gcd % self.shard_size != 0:
            self.shard_size = 1
        self.param.hsdp_effective_shard_size = self.shard_size

    def _create_sharded_dp_group(self):
        """create communication group for sharded parameter

        Returns (group_name, group); group is None when shard_size <= 1
        (no sharding, no group needed).
        """
        if self.shard_size <= 1:
            return "hsdp_sharded_dp_group_invalid", None

        rank_list = self._get_op_rank_list()
        rank_list_str = "_".join([str(i) for i in rank_list])
        group_name = "hsdp_sharded_dp_group_" + rank_list_str
        group = self.platform.create_group(rank_list, group_name)
        return group_name, group

    def _create_unsharded_dp_group(self):
        """create communication group for unsharded parameter

        Returns (group_name, group); group is None when dp_size <= 1.
        """
        if self.dp_size <= 1:
            return "hsdp_unsharded_dp_group_invalid", None

        rank_list = self._get_dp_rank_list()
        rank_list_str = "_".join([str(i) for i in rank_list])
        # NOTE(review): "unshared" (vs "unsharded" everywhere else) looks
        # like a typo, but the name is a runtime group identifier, so it is
        # left unchanged in case other components match on it.
        group_name = "hsdp_unshared_dp_group_" + rank_list_str
        group = self.platform.create_group(rank_list, group_name)
        return group_name, group

    def _init_param(self):
        """init hsdp parameter

        Allocates sharded/unsharded storage (via subclass hooks) and the
        optional gradient-accumulation buffer, then records the sharding
        state flags.
        """
        self.param.acc_grad = None
        self.param_shape = self.platform.get_param_local_shape(self.param)

        if self.shard_size == 1:
            # Unsharded path: only the (full-shape) acc-grad buffer may be
            # needed.
            self.sharded = False
            self.fully_sharded = False
            if self.config.requires_acc_grad and self.param.requires_grad:
                acc_grad_type = self.param.dtype
                if self.config.reduce_dtype is not None:
                    acc_grad_type = self.config.reduce_dtype
                self.acc_grad = self.platform.new_zero_parameter(self.param_shape, acc_grad_type, False,
                                                                 self.param.device)
                self.param.acc_grad = self.acc_grad
            return

        # Capture the shape before the subclass hooks may repoint param data.
        origin_param_shape = list(self.param_shape)
        self._init_unsharded_param()
        self._init_sharded_param()
        if self.config.requires_acc_grad and self.param.requires_grad:
            acc_grad_shape = origin_param_shape
            if self.config.shard_level != OptimizerLevel.SHARD_OPT:
                # Gradients are sharded too: accumulate at shard shape.
                acc_grad_shape = self.sharded_param.shape
            acc_grad_type = self.param.dtype
            if self.config.reduce_dtype is not None:
                acc_grad_type = self.config.reduce_dtype
            self.acc_grad = self.platform.new_zero_parameter(acc_grad_shape, acc_grad_type, False, self.param.device)
            self.param.acc_grad = self.acc_grad
        self.sharded = True
        if self.shard_size == self.rank_size:
            self.fully_sharded = True
        else:
            self.fully_sharded = False

    def to_sharded(self):
        """change parameter to sharded state"""
        self.platform.update_param_data(self.param, self.sharded_param)

    def prefetch_unsharded(self):
        """prefetch unsharded param with async all gather

        No-op when a prefetch is already in flight.
        """
        if self.prefetch_handle is not None:
            return
        unsharded_param_data, handle = self._get_unsharded_param_data(async_op=True)
        self.prefetch_data = unsharded_param_data
        self.prefetch_handle = handle

    #pylint: disable=W0212
    def to_unsharded(self):
        """change parameter to unsharded state

        Uses the prefetched (async) gather result when available, otherwise
        performs a synchronous all-gather. In both cases the current local
        data is saved back into sharded_param before param is repointed at
        the full data.
        """
        if self.prefetch_handle is not None:
            self.prefetch_handle.wait()
            self.platform.update_param_data(self.sharded_param, self.platform.get_param_local_data(self.param))
            self.platform.update_param_data(self.param, self.prefetch_data)
            self.prefetch_handle = None
            self.prefetch_data = None
            return

        unsharded_param_data, _ = self._get_unsharded_param_data(async_op=False)
        self.platform.update_param_data(self.sharded_param, self.platform.get_param_local_data(self.param))
        self.platform.update_param_data(self.param, unsharded_param_data)

    def zero_acc_grad(self):
        """zero accumulation grad"""
        if self.param.acc_grad is not None:
            self.param.acc_grad.zero_()