Coverage for hyper_parallel / core / tensor_redistribution.py: 85%
209 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""tensor_redistribution"""
17from hyper_parallel.core.dtensor import DTensor
18from hyper_parallel.core.redistribute_infer import RedistributionOperatorInfer
19from hyper_parallel.platform import get_platform
# Platform backend resolved once at import time; every collective / tensor helper used
# in this module (create_group, all-gather-concat, all-to-all, chunk, reduce-scatter, ...)
# is reached through this object.
platform = get_platform()
23def _construct_layout_tuple_for_transform_operator_list(from_layout, to_layout, from_full_shape):
24 """_construct_layout_tuple_for_transform_operator_list"""
25 from_layout_dict = from_layout.to_dict()
26 to_layout_dict = to_layout.to_dict()
27 from_layout_tuple = (from_layout_dict["mesh_shape"], from_layout_dict["tensor_map"], list(from_full_shape))
28 to_layout_tuple = (to_layout_dict["mesh_shape"], to_layout_dict["tensor_map"], list(from_full_shape)) # TODO: 考虑reshape的场景
29 return from_layout_tuple, to_layout_tuple
class TensorRedistribution:
    """
    Redistribute a DTensor from its current layout to a target layout.

    A redistribution is executed as a list of ``(op_name, op_args)`` entries: ``op_name``
    selects a handler from ``self._construct_op_operator`` and that handler is applied to
    the *local* tensor with ``*op_args``. Inferred op lists are memoised in
    ``self._transform_cache``.
    """

    def __init__(self):
        self.is_init = False
        self.rank_list = None  # rank_list for current stage (captured on first redistribution)
        self.rank_id = None  # current rank id (captured on first redistribution)
        # Inferred op lists keyed by layout strings + rank id (optionally + global shape).
        self._transform_cache = {}
        # Dispatch table: op name -> handler(local_tensor, *op_args).
        self._construct_op_operator = {
            "Reshape": self._construct_reshape,
            "AllConcat": self._construct_all_concat,
            "StridedSlice": self._construct_strided_slice,
            "all_concat": self._construct_all_concat_new,
            "all_split": self._construct_all_split,
            "all_to_all": self._construct_all_to_all
        }

    def _apply_transform_operators(self, x, transform_operator_list):
        """Apply every ``(op_name, op_args)`` entry of ``transform_operator_list`` to ``x`` in order.

        Returns:
            The transformed local tensor.
        """
        for transform_operator in transform_operator_list:
            x = self._construct_op_operator[transform_operator[0]](x, *transform_operator[1])
        return x

    def _construct_reshape(self, x, *args):
        """Reshape ``x`` with ``view``; args: (*shape)."""
        return x.view(args)

    def _construct_all_concat(self, x, *args):
        """All-gather ``x`` over a group and concatenate along one dim.

        args: (*rank_list, concat_dim) -- the group's ranks followed by the concat dim;
        the number of gathered pieces equals the number of ranks.
        """
        rank_list = args[0:-1]
        concat_dim = args[-1]
        group = platform.create_group(rank_list)
        concat_size = len(rank_list)
        return platform.differentiable_all_gather_concat(x, group, concat_size, concat_dim)

    def _construct_strided_slice(self, x, *args):
        """Strided-slice ``x``.

        args: (*begin, *end, *strides) flattened, each part of length ``len(args) // 3``.
        """
        dims = len(args) // 3
        return platform.construct_strided_slice(x, args[0: dims], args[dims: 2 * dims], args[2 * dims:])

    def _construct_all_concat_new(self, x, *args):
        """All-gather ``x`` over ``group`` and concatenate along ``concat_dim``.

        args: (concat_dim, concat_size, group) where ``group`` is a rank list.
        """
        rank_list = args[2]
        concat_dim = args[0]
        concat_size = args[1]
        group = platform.create_group(rank_list)
        return platform.differentiable_all_gather_concat(x, group, concat_size, concat_dim)

    def _construct_all_split(self, x, *args):
        """Keep this rank's chunk of ``x`` (no communication).

        args: (split_dim, split_size, group) where ``group`` is a rank list; the chunk
        index is this rank's position in that list.
        """
        rank_list = list(args[2])
        split_dim = args[0]
        split_size = args[1]
        idx = rank_list.index(self.rank_id)
        return platform.chunk(x, split_dim, split_size, idx)

    def _construct_all_to_all(self, x, *args):
        """All-to-all: split ``x`` along ``split_dim`` and concatenate received pieces along ``concat_dim``.

        args: (split_dim, concat_dim, split_count, group) where ``group`` is a rank list.
        The tensor is rearranged so that the ``split_count`` pieces are outermost along
        dim 0 before the collective, and the received (rank-major along dim 0) result is
        rearranged back so the pieces are concatenated along ``concat_dim``.
        NOTE(review): assumes ``platform.differentiable_all_to_all`` exchanges equal
        dim-0 chunks in group-rank order -- confirm against the platform implementation.

        Raises:
            ValueError: if ``x.shape[split_dim]`` is not divisible by ``split_count``.
        """
        split_dim, concat_dim, split_count, rank_list = args
        group = platform.create_group(rank_list)
        original_shape = x.shape

        dim_size = original_shape[split_dim]
        if dim_size % split_count != 0:
            raise ValueError(f"Dimension {split_dim} with size {dim_size} "
                             f"cannot be evenly split into {split_count} parts")

        split_size = dim_size // split_count
        final_shape = list(original_shape)
        if split_dim != concat_dim:
            final_shape[split_dim] = split_size
            final_shape[concat_dim] = final_shape[concat_dim] * split_count
        final_shape = tuple(final_shape)

        # When every dim before split_dim has size 1, the pieces are already outermost in
        # memory, so a plain view (dropping those leading 1-dims) is enough.
        pre_special_handle = all(original_shape[i] == 1 for i in range(split_dim))
        if pre_special_handle:
            reshape_shape = (split_count * split_size,) + original_shape[split_dim + 1:]
            x_reshaped = x.view(reshape_shape)
        else:
            # Split split_dim into (split_count, split_size) and move split_count to dim 0.
            reshape_dims = list(original_shape)
            reshape_dims[split_dim] = split_count
            reshape_dims.insert(split_dim + 1, split_size)

            trans_dims = list(range(len(reshape_dims)))
            trans_dims.remove(split_dim)
            trans_dims.insert(0, split_dim)

            x_reshaped = x.reshape(reshape_dims).permute(trans_dims).contiguous()

            # Merge split_count into the following dim so dim 0 is split_count-major.
            reshape_shape = list(x_reshaped.shape)
            reshape_shape[0] = reshape_shape[0] * reshape_shape[1]
            reshape_shape.pop(1)
            reshape_shape = tuple(reshape_shape)
            x_reshaped = x_reshaped.reshape(reshape_shape)
        x_reshaped = x_reshaped.contiguous()
        output_tensor = platform.differentiable_all_to_all(
            input_data=x_reshaped,
            output_shape=reshape_shape,
            group=group
        )

        # All dims before concat_dim are 1: rank-major dim-0 order already equals the
        # concatenation along concat_dim.
        post_special_handle = all(final_shape[i] == 1 for i in range(concat_dim))
        if post_special_handle:
            return output_tensor.view(final_shape)

        if pre_special_handle and split_dim > 0:
            # The fast path dropped the leading 1-dims, but the generic post-processing
            # below expects the same rank as the permuted path:
            # (split_count * d0, d1, ..., split_size, rest...). Restore it with a view.
            aligned_shape = ((split_count * original_shape[0],) + tuple(original_shape[1:split_dim])
                             + (split_size,) + tuple(original_shape[split_dim + 1:]))
            output_tensor = output_tensor.view(aligned_shape)

        # Unfold dim 0 into (split_count, per-rank dim 0) ...
        output_reshape = list(output_tensor.shape)
        output_reshape[0] = split_count
        output_reshape.insert(1, output_tensor.shape[0] // split_count)

        # ... move split_count right before concat_dim ...
        out_trans_dims = list(range(len(output_reshape)))
        first_dim = out_trans_dims.pop(0)
        if concat_dim >= len(out_trans_dims):
            out_trans_dims.append(first_dim)
        else:
            out_trans_dims.insert(concat_dim, first_dim)

        final_output = output_tensor.reshape(output_reshape).permute(out_trans_dims).contiguous()

        # ... and merge it into concat_dim.
        final_reshape = list(final_output.shape)
        if concat_dim < len(final_reshape) - 1:
            final_reshape[concat_dim] = final_reshape[concat_dim] * final_reshape[concat_dim + 1]
            final_reshape.pop(concat_dim + 1)

        return final_output.reshape(final_reshape)

    def _apply_eazy_redistribute(self, src_layout, dst_layout):
        """Return True when the shape-independent inferrer can be used.

        That requires equal ``mesh_shape`` and ``rank_list`` and tensor maps of equal length.
        """
        if (src_layout.mesh_shape != dst_layout.mesh_shape or
                src_layout.rank_list != dst_layout.rank_list):
            return False
        return len(dst_layout.tensor_map) == len(src_layout.tensor_map)

    def _redistribution_without_shape(self, local_x, src_layout, dst_layout, key):
        """Infer a shape-independent op list, cache it under ``key`` and apply it to ``local_x``."""
        inferrer = RedistributionOperatorInfer(
            dev_mat=src_layout.mesh_shape,
            in_tensor_map=list(src_layout.tensor_map),
            out_tensor_map=list(dst_layout.tensor_map)
        )
        op_list = inferrer.InferOpsList(self.rank_id, self.rank_list)
        self._transform_cache[key] = op_list
        return self._apply_transform_operators(local_x, op_list)

    def redistribution(self, input_x, to_layout):
        """Redistribute the DTensor ``input_x`` to ``to_layout``.

        Steps:
        1. Resolve a partial (pending-reduction) input layout first.
        2. On the first call, capture this rank's id and the stage rank list.
        3. Reuse a cached op list when available; otherwise infer one, either with the
           shape-independent inferrer or with the platform tensor-transform utility.

        Args:
            input_x: source DTensor.
            to_layout: target layout; must share the rank list captured on the first call.

        Returns:
            A DTensor built from the transformed local tensor with ``to_layout``'s mesh
            and placements.

        Raises:
            ValueError: if ``to_layout.rank_list`` differs from the captured rank list.
        """
        x_layout = input_x.layout
        x = input_x
        if x_layout.is_partial():
            # Solve partial status first.
            if x_layout.mesh_shape == to_layout.mesh_shape:
                x = self.reduce_partial(input_x, to_layout)
            else:
                # NOTE(review): reduce_partial rejects a partial target layout and x_layout
                # is partial here, so this branch appears to always raise -- confirm intent.
                x = self.reduce_partial(input_x, x_layout)

        from_layout = x.layout
        if not self.is_init:
            self.rank_id = platform.get_rank()
            self.rank_list = from_layout.rank_list
            self.is_init = True
        if self.rank_list != to_layout.rank_list:
            raise ValueError(f"The from_layout rank list: {self.rank_list} is not equal to "
                             f"to_layout rank list: {to_layout.rank_list}")

        # Shape-independent cache entry (written by _redistribution_without_shape).
        key = from_layout.compact_str + to_layout.compact_str + str(self.rank_id)
        if key in self._transform_cache:
            local_x = self._apply_transform_operators(x.to_local(), self._transform_cache[key])
            return DTensor.from_local(local_x, to_layout.mesh, to_layout.placements)

        # Shape-dependent cache entry (written by _infer_transform_operator_list).
        full_shape = x.shape
        key_and_shape = key + str(full_shape)
        local_x = x.to_local()
        if key_and_shape in self._transform_cache:
            local_x = self._apply_transform_operators(local_x, self._transform_cache[key_and_shape])
            return DTensor.from_local(local_x, to_layout.mesh, to_layout.placements)

        if self._apply_eazy_redistribute(from_layout, to_layout):
            # is_partial is a method: the bound method itself would always be truthy.
            if from_layout.is_partial():
                from_layout.reset_partial()
            local_x = self._redistribution_without_shape(local_x, from_layout, to_layout, key)
        else:
            transform_operator_list = self._infer_transform_operator_list(from_layout, to_layout,
                                                                          full_shape, key_and_shape)
            local_x = self._apply_transform_operators(local_x, transform_operator_list)
        return DTensor.from_local(local_x, to_layout.mesh, to_layout.placements)

    def _infer_transform_operator_list(self, from_layout, to_layout, from_full_shape, key):
        """Infer the op list with the platform tensor-transform utility and cache it under ``key``.

        NOTE(review): the meaning of the positional ``False`` flag passed to
        ``transform_tensor_sharding`` is not visible here.
        """
        from_layout_tuple, to_layout_tuple = \
            _construct_layout_tuple_for_transform_operator_list(from_layout, to_layout, from_full_shape)
        self._transform_cache[key] = \
            platform.get_tensor_transform().transform_tensor_sharding(from_layout_tuple, to_layout_tuple,
                                                                      self.rank_list, False, self.rank_id)
        return self._transform_cache[key]

    def _allreduce_along_dev_dim(self, x, op, layout, dev_dim):
        """All-reduce ``x`` over the communication group of device axis ``dev_dim``.

        'avg' is a sum all-reduce divided by the size of ``dev_dim``; 'all' is performed
        on an int32 cast of ``x.bool()`` and returns a bool tensor; any other ``op`` is
        passed through to the platform. 0-d tensors are temporarily unsqueezed to 1-d
        (presumably because the collective needs rank >= 1 -- confirm).
        """
        group = layout.get_comm_group_by_axis(dev_dim)
        zero_dim = x.dim() == 0
        if zero_dim:
            x = x.unsqueeze(0)
        if op == 'avg':
            dev_num = layout.mesh_shape[layout.alias_name.index(dev_dim)]
            x = platform.differentiable_all_reduce(x, 'sum', group)
            x = x / dev_num
        elif op == 'all':
            x_int32 = platform.tensor_type_cast(x.bool(), 'int32')  # True -> 1, False -> 0
            x = platform.differentiable_all_reduce(x_int32, 'all', group)
            x = x.bool()
        else:
            x = platform.differentiable_all_reduce(x, op, group)
        if zero_dim:
            x = x.squeeze(0)
        return x

    def _reduce_scatter_along_dev_dim_with_axis(self, x, axis, op, layout, dev_dim):
        """Reduce-scatter ``x`` along tensor ``axis`` over the group of device axis ``dev_dim``."""
        dev_num = layout.mesh_shape[layout.alias_name.index(dev_dim)]
        group = layout.get_comm_group_by_axis(dev_dim)
        # Use the module-level platform handle; the instance has no ``platform`` attribute.
        output_tensor = platform.reduce_scatter(x, dev_num, axis, op, group)
        return output_tensor

    def reduce_partial(self, input_x, to_layout):
        """Resolve the pending (partial) reductions of ``input_x``.

        For every device axis on which ``input_x.layout`` is partial, a ReduceScatter is
        issued when ``to_layout`` shards a tensor dim along that axis, otherwise an
        AllReduce. The output layout is ``input_x.layout`` called with the updated alias
        tensor map (reduce-scattered axes added) with its partial status reset.

        Args:
            input_x: source DTensor.
            to_layout: non-partial layout with the same mesh shape, used to decide which
                partial axes can be reduce-scattered.

        Returns:
            ``input_x`` unchanged if its layout is None or not partial, otherwise a new
            DTensor holding the reduced local tensor.

        Raises:
            ValueError: if the mesh shapes differ or ``to_layout`` is itself partial.
        """
        from_layout = input_x.layout
        x = input_x
        # is_partial is a method: testing the bound method itself would always be truthy.
        if from_layout is None or not from_layout.is_partial():
            return x

        x = x.to_local()
        if from_layout.mesh_shape != to_layout.mesh_shape:
            raise ValueError(f"For reduce partial, mesh_shape between from_layout and to_layout must be the same, "
                             f"but got {from_layout.mesh_shape} and {to_layout.mesh_shape}")
        if to_layout.is_partial():
            raise ValueError(f"For reduce partial, to_layout must be non-partial status, but got to_layout.partial: "
                             f"{to_layout.partial}")

        # Position of each device axis inside a multi-axis tensor-map entry (0 for a
        # single-axis entry); used to order reduce-scatters outer-to-inner.
        dev_map_order = {}
        for dev_axis in to_layout.alias_tensor_map:
            if isinstance(dev_axis, tuple):
                for i, sub_dev_axis in enumerate(dev_axis):
                    dev_map_order[sub_dev_axis] = i
            else:
                dev_map_order[dev_axis] = 0

        pending_reduce_op_list = []  # List[Tuple[comm_op, op, dev_axis, reduce_dim]]
        for dev_axis_index, op in enumerate(from_layout.partial):
            if op is None:
                continue
            dev_axis = from_layout.alias_name[dev_axis_index]
            apply_shard_dim = to_layout.get_dev_axis_apply_shard_axis(dev_axis)
            # Compare with None explicitly: tensor dim 0 is a valid shard dim but falsy.
            # NOTE(review): assumes the helper returns None when dev_axis shards no tensor dim.
            comm_op = "ReduceScatter" if apply_shard_dim is not None else "AllReduce"
            pending_reduce_op_list.append((comm_op, op, dev_axis, apply_shard_dim))

        # Sort reduce ops:
        # 1. ReduceScatter is executed before AllReduce.
        # 2. If a tensor dim is split by several dev axes, the outer one executes first,
        #    e.g. ("cp", "tp") reduce-scatters along "cp" before "tp".
        # 3. Lower dev axis id executes before higher dev axis id.
        sorted_pending_reduce_op_list = \
            sorted(pending_reduce_op_list, key=lambda reduce_pair: (reduce_pair[0] != "ReduceScatter",
                                                                    dev_map_order.get(reduce_pair[2], 0),
                                                                    to_layout.mesh.axis_id(reduce_pair[2])))

        output_alias_tensor_map = list(from_layout.alias_tensor_map)
        for comm_op, op, dev_axis, reduce_axis in sorted_pending_reduce_op_list:
            if comm_op == "AllReduce":
                x = self._allreduce_along_dev_dim(x, op, from_layout, dev_axis)
            elif comm_op == "ReduceScatter":
                x = self._reduce_scatter_along_dev_dim_with_axis(x, reduce_axis, op, from_layout, dev_axis)
                # Record that reduce_axis is now (additionally) sharded along dev_axis.
                current_entry = output_alias_tensor_map[reduce_axis]
                if current_entry == "None":
                    output_alias_tensor_map[reduce_axis] = dev_axis
                elif isinstance(current_entry, tuple):
                    output_alias_tensor_map[reduce_axis] = current_entry + (dev_axis,)
                else:
                    output_alias_tensor_map[reduce_axis] = (current_entry, dev_axis)

        output_layout = from_layout(*output_alias_tensor_map)
        output_layout.reset_partial()
        return DTensor.from_local(x, output_layout.mesh, output_layout.placements)
# Module-level shared TensorRedistribution instance; its op-list cache and the rank id /
# rank list it captures on first use persist across all calls made through it.
_tensor_redistribution = TensorRedistribution()