Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / dtensor / tensor_redistribution.py: 48%
214 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""tensor_redistribution"""
17from hyper_parallel.core.dtensor.dtensor import DTensor
18from hyper_parallel.core.dtensor.redistribute_infer import RedistributionOperatorInfer
19from hyper_parallel.platform import get_platform
20platform = get_platform()
23def _construct_layout_tuple_for_transform_operator_list(from_layout, to_layout, from_full_shape):
24 """_construct_layout_tuple_for_transform_operator_list"""
25 from_layout_dict = from_layout.to_dict()
26 to_layout_dict = to_layout.to_dict()
27 from_layout_tuple = (
28 from_layout_dict["mesh_shape"], from_layout_dict["tensor_map"], list(from_full_shape)
29 )
30 # NOTE: consider reshape scenario when to_full_shape differs from from_full_shape
31 to_layout_tuple = (
32 to_layout_dict["mesh_shape"], to_layout_dict["tensor_map"], list(from_full_shape)
33 )
34 return from_layout_tuple, to_layout_tuple
class TensorRedistribution:
    """
    TensorRedistribution.

    Converts a DTensor's local shard from a source sharding layout to a
    destination layout by inferring a sequence of transform operators
    (reshape, strided-slice, all-gather, all-split, all-to-all, ...) and
    applying them in order.  Inferred operator lists are memoized in
    ``_transform_cache``, keyed by the layout pair (plus the full tensor
    shape for the general path) and the local rank id.
    """
    def __init__(self):
        # Rank lookup is deferred to the first redistribution() call so that
        # construction does not require the communication backend to be up.
        self.is_init = False
        self.rank_id = None  # current rank_id (global)
        # cache key (str) -> list of (op_name, op_args) transform operators
        self._transform_cache = {}
        # Dispatch table from inferred operator name to its constructor.
        # "all_concat" uses the staticmethod variant, which needs no instance
        # state; the rest are bound methods.
        self._construct_op_operator = {
            "Reshape": self._construct_reshape,
            "AllConcat": self._construct_all_concat,
            "StridedSlice": self._construct_strided_slice,
            "all_concat": TensorRedistribution._construct_all_concat_new,
            "all_split": self._construct_all_split,
            "all_to_all": self._construct_all_to_all
        }

    def _construct_reshape(self, x, *args):
        """args: (*shape) — reshape the local tensor to the given shape."""
        return x.view(args)

    def _construct_all_concat(self, x, *args):
        """args: (*rank_list, concat_dim)

        All-gather x across the ranks in rank_list and concatenate the
        gathered shards along concat_dim.
        """
        rank_list = args[0:-1]
        concat_dim = args[-1]
        group = platform.create_group(rank_list)
        concat_size = len(rank_list)
        return platform.differentiable_all_gather_concat(x, group, concat_size, concat_dim)

    def _construct_strided_slice(self, x, *args):
        """args: (begin, end, strides)

        The three per-dimension vectors arrive flattened back-to-back, each
        of length len(args) // 3.
        """
        dims = len(args) // 3
        return platform.construct_strided_slice(x, args[0: dims], args[dims: 2 * dims], args[2 * dims:])

    @staticmethod
    def _construct_all_concat_new(x, *args):
        """args: (concat_dim, concat_size, group)

        Newer-format all-gather + concat; args[2] is the rank list used to
        create the communication group.
        """
        rank_list = args[2]
        concat_dim = args[0]
        concat_size = args[1]
        group = platform.create_group(rank_list)
        return platform.differentiable_all_gather_concat(x, group, concat_size, concat_dim)

    def _construct_all_split(self, x, *args):
        """args: (split_dim, split_size, group)

        Split x into split_size chunks along split_dim and keep the chunk at
        this rank's position within the group's rank list.
        """
        rank_list = list(args[2])
        split_dim = args[0]
        split_size = args[1]
        # Position of the current rank inside the group decides which chunk
        # this process keeps; raises ValueError if rank_id is not a member.
        idx = rank_list.index(self.rank_id)
        return platform.chunk(x, split_dim, split_size, idx)

    def _construct_all_to_all(self, x, *args):
        """args: (split_dim, concat_dim, permute_size, group)

        Split x into ``permute_size`` parts along ``split_dim``, exchange the
        parts across the group with an all-to-all, and concatenate the
        received parts along ``concat_dim``.

        Raises:
            ValueError: if x's split_dim size is not divisible by the
                permute size.
        """
        split_dim, concat_dim, split_count, rank_list = args
        group = platform.create_group(rank_list)
        original_shape = x.shape

        dim_size = original_shape[split_dim]
        if dim_size % split_count != 0:
            raise ValueError(f"Dimension {split_dim} with size {dim_size} "
                             f"cannot be evenly split into {split_count} parts")

        split_size = dim_size // split_count
        # Expected local shape after the exchange: split_dim shrinks by
        # split_count, concat_dim grows by split_count (no-op if they coincide).
        final_shape = list(original_shape)
        if split_dim != concat_dim:
            final_shape[split_dim] = split_size
            final_shape[concat_dim] = final_shape[concat_dim] * split_count
        final_shape = tuple(final_shape)

        # Fast path: every dim before split_dim is size 1, so the split axis
        # can be collapsed to the front with a plain view (no permute needed).
        pre_special_handle = all(original_shape[i] == 1 for i in range(split_dim))
        if pre_special_handle:
            reshape_shape = (split_count * split_size,) + original_shape[split_dim + 1:]
            x_reshaped = x.view(reshape_shape)
        else:
            # General path: expose split_count as its own axis, move it to the
            # front, then fold it into the next axis so the backend all-to-all
            # exchanges one contiguous chunk per peer rank.
            reshape_dims = list(original_shape)
            reshape_dims[split_dim] = split_count
            reshape_dims.insert(split_dim + 1, split_size)

            trans_dims = list(range(len(reshape_dims)))
            trans_dims.remove(split_dim)
            trans_dims.insert(0, split_dim)

            x_reshaped = x.reshape(reshape_dims).permute(trans_dims).contiguous()

            reshape_shape = list(x_reshaped.shape)
            reshape_shape[0] = reshape_shape[0] * reshape_shape[1]
            reshape_shape.pop(1)
            reshape_shape = tuple(reshape_shape)
            x_reshaped = x_reshaped.reshape(reshape_shape)
        x_reshaped = x_reshaped.contiguous()
        output_tensor = platform.differentiable_all_to_all(
            input_data=x_reshaped,
            output_shape=reshape_shape,
            group=group
        )

        # If every dim before concat_dim is size 1, the received data is
        # already laid out correctly and a single view restores final_shape.
        post_special_handle = all(final_shape[i] == 1 for i in range(concat_dim))
        if post_special_handle:
            return output_tensor.view(final_shape)

        # When pre_special_handle collapsed leading size-1 dims, the A2A was executed
        # in a reduced-rank space where the effective concat axis is shifted left by
        # split_dim positions. Use recon_concat_dim for all post-A2A reshaping so
        # that split_count is merged into the correct dimension.
        recon_concat_dim = (concat_dim - split_dim) if pre_special_handle else concat_dim

        # Re-expose the per-source-rank axis at the front of the output...
        output_reshape = list(output_tensor.shape)
        output_reshape[0] = split_count
        output_reshape.insert(1, output_tensor.shape[0] // split_count)

        # ...then move that axis to sit just before the concat axis.
        out_trans_dims = list(range(len(output_reshape)))
        first_dim = out_trans_dims.pop(0)
        if recon_concat_dim >= len(out_trans_dims):
            out_trans_dims.append(first_dim)
        else:
            out_trans_dims.insert(recon_concat_dim, first_dim)

        final_output = output_tensor.reshape(output_reshape).permute(out_trans_dims).contiguous()

        # Merge the per-rank axis into the concat axis.
        final_reshape = list(final_output.shape)
        if recon_concat_dim < len(final_reshape) - 1:
            final_reshape[recon_concat_dim] = (
                final_reshape[recon_concat_dim] * final_reshape[recon_concat_dim + 1]
            )
            final_reshape.pop(recon_concat_dim + 1)

        result = final_output.reshape(final_reshape)
        if pre_special_handle:
            # Restore the leading size-1 dims collapsed by the fast path.
            result = result.view(final_shape)
        return result

    def _apply_eazy_redistribute(self, src_layout, dst_layout):
        """_apply_eazy_redistribute

        Return True when the layouts share mesh shape, rank list and
        tensor-map length, i.e. redistribution can be inferred without the
        full tensor shape (the "easy" path — only the maps' lengths are
        compared here, not their contents).
        """
        if (src_layout.mesh_shape != dst_layout.mesh_shape or
                src_layout.rank_list != dst_layout.rank_list):
            return False

        tensor_map_size = len(src_layout.tensor_map)
        if len(dst_layout.tensor_map) != tensor_map_size:
            return False
        return True

    def _redistribution_without_shape(self, local_x, src_layout, dst_layout, key, rank_list):
        """_redistribution_without_shape

        Infer the operator list for the shape-free ("easy") path, cache it
        under ``key``, and apply it to ``local_x``.
        """
        inferrer = RedistributionOperatorInfer(
            dev_mat=src_layout.mesh_shape,
            in_tensor_map=list(src_layout.tensor_map),
            out_tensor_map=list(dst_layout.tensor_map)
        )
        op_list = inferrer.infer_ops_list(self.rank_id, rank_list)
        self._transform_cache[key] = op_list
        for op in op_list:
            # op is (op_name, op_args); dispatch through the constructor table.
            local_x = self._construct_op_operator[op[0]](local_x, *op[1])
        return local_x

    def redistribution(self, input_x, to_layout):
        """tensor redistribution

        Redistribute DTensor ``input_x`` to ``to_layout`` and return the
        resulting DTensor.  Pending partial (reduce) status is resolved
        first; then a cached — or freshly inferred — operator list is
        replayed against the local shard.

        Raises:
            ValueError: if the two layouts do not cover the same rank list.
        """
        x_layout = input_x.layout
        x = input_x
        if input_x.layout.is_partial():
            # Solve partial status first
            if input_x.layout.mesh_shape == to_layout.mesh_shape:
                x = self.reduce_partial(input_x, to_layout)
            else:
                x = self.reduce_partial(input_x, x_layout)

        from_layout = x.layout
        if not self.is_init:
            # First call: bind this object to the current global rank.
            self.rank_id = platform.get_rank()
            self.is_init = True
        if from_layout.rank_list != to_layout.rank_list:
            raise ValueError(f"The from_layout rank list: {from_layout.rank_list} is not equal to "
                             f"to_layout rank list: {to_layout.rank_list}")
        key = from_layout.compact_str + to_layout.compact_str + str(self.rank_id)
        if key in self._transform_cache:
            # Shape-independent cache hit (populated by the "easy" path).
            x = x.to_local()
            transform_operator_list = self._transform_cache[key]
            for transform_operator in transform_operator_list:
                x = self._construct_op_operator[transform_operator[0]](x, *transform_operator[1])
            return DTensor.from_local(x, to_layout.mesh, to_layout.alias_placements)

        full_shape = x.shape
        key_and_shape = key + str(full_shape)
        x = x.to_local()
        if key_and_shape in self._transform_cache:
            # Shape-dependent cache hit (populated by the general path).
            transform_operator_list = self._transform_cache[key_and_shape]
            for transform_operator in transform_operator_list:
                x = self._construct_op_operator[transform_operator[0]](x, *transform_operator[1])
            return DTensor.from_local(x, to_layout.mesh, to_layout.alias_placements)

        rank_list = from_layout.rank_list
        if self._apply_eazy_redistribute(from_layout, to_layout):
            if from_layout.is_partial():
                from_layout.reset_partial()
            x = self._redistribution_without_shape(x, from_layout, to_layout, key, rank_list)
        else:
            # General path: operator inference needs the full tensor shape.
            transform_operator_list = self._infer_transform_operator_list(from_layout, to_layout,
                                                                          full_shape, key_and_shape, rank_list)
            for transform_operator in transform_operator_list:
                x = self._construct_op_operator[transform_operator[0]](x, *transform_operator[1])
        return DTensor.from_local(x, to_layout.mesh, to_layout.alias_placements)

    def _infer_transform_operator_list(self, from_layout, to_layout, from_full_shape, key, rank_list):
        """infer transform operator list

        Delegate to the platform tensor-transform backend and cache the
        resulting operator list under ``key`` (layout pair + shape + rank).
        """
        from_layout_tuple, to_layout_tuple = \
            _construct_layout_tuple_for_transform_operator_list(from_layout, to_layout, from_full_shape)
        self._transform_cache[key] = \
            platform.get_tensor_transform().transform_tensor_sharding(from_layout_tuple, to_layout_tuple,
                                                                      rank_list, False, self.rank_id)
        return self._transform_cache[key]

    @staticmethod
    def _allreduce_along_dev_dim(x, op, layout, dev_dim):
        """Do allreduce at specified axis along dev_dim.

        op: 'avg' (sum then divide by the device-axis size), 'all' (boolean
        reduce via an int32 cast), or any op string forwarded directly to the
        platform all-reduce (e.g. 'sum').  Zero-dim tensors are temporarily
        expanded so the collective has an axis to operate on.
        """
        group = layout.get_comm_group_by_axis(dev_dim)
        zero_dim = x.dim() == 0
        if zero_dim:
            x = x.unsqueeze(0)
        if op == 'avg':
            dev_num = layout.mesh_shape[layout.alias_name.index(dev_dim)]
            x = platform.differentiable_all_reduce(x, 'sum', group)
            x = x / dev_num
        elif op == 'all':
            x_int32 = platform.tensor_type_cast(x.bool(), 'int32')  # True→1, False→0
            # NOTE(review): the reduce-op string 'all' is passed through to the
            # platform backend — presumably it implements a logical-AND style
            # reduction there; verify against the backend implementation.
            x = platform.differentiable_all_reduce(x_int32, 'all', group)
            x = x.bool()
        else:
            x = platform.differentiable_all_reduce(x, op, group)
        if zero_dim:
            x = x.squeeze(0)
        return x

    def _reduce_scatter_along_dev_dim_with_axis(self, x, axis, op, layout, dev_dim):
        """Do reduce_scatter at specified axis along dev_dim.

        Reduces with ``op`` across the dev_dim communication group and
        scatters the result along tensor axis ``axis``.
        """
        dev_num = layout.mesh_shape[layout.alias_name.index(dev_dim)]
        group = layout.get_comm_group_by_axis(dev_dim)
        output_tensor = platform.differentiable_reduce_scatter(x, dev_num, axis, op, group)
        return output_tensor

    def reduce_partial(self, input_x, to_layout):
        """Reduce partial status.

        Resolve every pending reduce op recorded in ``input_x.layout.partial``
        via AllReduce or ReduceScatter collectives and return a DTensor whose
        layout is no longer partial.  No-op if the layout is not partial.

        Raises:
            ValueError: if the mesh shapes differ, or ``to_layout`` is itself
                partial.
        """
        from_layout = input_x.layout
        x = input_x
        if from_layout is None or not from_layout.is_partial():
            return x

        x = x.to_local()
        if from_layout.mesh_shape != to_layout.mesh_shape:
            raise ValueError(f"For reduce partial, mesh_shape between from_layout and to_layout must be the same, "
                             f"but got {from_layout.mesh_shape} and {to_layout.mesh_shape}")
        if to_layout.is_partial():
            raise ValueError(f"For reduce partial, to_layout must be non-partial status, but got to_layout.partial: "
                             f"{to_layout.partial}")

        # Position of each device axis inside its (possibly tuple) tensor-map
        # entry; used below to order ReduceScatters so outer splits run first.
        dev_map_order = {}
        for dev_axis in to_layout.alias_tensor_map:
            if isinstance(dev_axis, tuple):
                for i, sub_dev_axis in enumerate(dev_axis):
                    dev_map_order[sub_dev_axis] = i
            else:
                dev_map_order[dev_axis] = 0

        pending_reduce_op_list = []  # List[Tuple[comm_op, op, dev_dim, reduce_dim]]
        for dev_axis_index, op in enumerate(from_layout.partial):
            if op is None:
                continue
            dev_axis = from_layout.alias_name[dev_axis_index]
            apply_shard_dim = to_layout.get_dev_axis_apply_shard_axis(dev_axis)
            # If the destination shards along this device axis, fuse the
            # reduce with the scatter; otherwise use a plain all-reduce.
            comm_op = "ReduceScatter" if apply_shard_dim else "AllReduce"
            pending_reduce_op_list.append((comm_op, op, dev_axis, apply_shard_dim))

        # sort reduce op
        # 1. ReduceScatter is executed before AllReduce
        # 2. If multiple split, the dev axis split outer will be execute first.
        #    e.g. ("cp", "tp"), will execute reduce_scatter along "cp" before "tp"
        # 3. Lower dev_id execute before higher dev_id
        sorted_pending_reduce_op_list = \
            sorted(pending_reduce_op_list, key=lambda reduce_pair: (reduce_pair[0] != "ReduceScatter",
                                                                    dev_map_order.get(reduce_pair[2], 0),
                                                                    to_layout.mesh.axis_id(reduce_pair[2])))

        output_alias_tensor_map = list(from_layout.alias_tensor_map)
        for reduce_op_pair in sorted_pending_reduce_op_list:
            comm_op = reduce_op_pair[0]
            op = reduce_op_pair[1]
            dev_axis = reduce_op_pair[2]
            if comm_op == "AllReduce":
                x = TensorRedistribution._allreduce_along_dev_dim(x, op, from_layout, dev_axis)
            elif comm_op == "ReduceScatter":
                reduce_axis = reduce_op_pair[3]
                x = self._reduce_scatter_along_dev_dim_with_axis(x, reduce_axis, op, from_layout, dev_axis)
                # Record that tensor axis reduce_axis is now also sharded
                # along dev_axis in the output layout's alias map.
                if output_alias_tensor_map[reduce_axis] == "None":
                    output_alias_tensor_map[reduce_axis] = dev_axis
                elif isinstance(output_alias_tensor_map[reduce_axis], tuple):
                    output_alias_tensor_map[reduce_axis] += (dev_axis,)
                else:
                    output_alias_tensor_map[reduce_axis] = (output_alias_tensor_map[reduce_axis], dev_axis)

        output_layout = from_layout(*output_alias_tensor_map)
        output_layout.reset_partial()
        return DTensor.from_local(x, output_layout.mesh, output_layout.alias_placements)
341_tensor_redistribution = TensorRedistribution()