Coverage for hyper_parallel / core / redistribute_infer.py: 61%
331 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""redistribute_infer"""
16from typing import Dict, List, Tuple, Union
class Status:
    """Status codes returned by the redistribution-inference operations below."""
    SUCCESS = 0
    FAILED = 1
# Operator type tags stored in RedistributionOperatorInfer.operator_list_;
# InferOpsList translates them into "all_concat" / "all_split" / "all_to_all".
CONCAT_BY_AXIS = 0
SPLIT_BY_AXIS = 1
PERMUTE_BY_AXIS = 2
# Sentinel meaning "unmapped" / "not found" (TensorMap lookups return it).
NONE = -1
class TensorMap:
    """Tensor map whose entries may be tuples representing combined dimensions."""

    def __init__(self, dims: List[Union[int, Tuple[int, ...]]]):
        self.dims = dims

    def GetDimByIdx(self, index: int) -> Union[int, Tuple[int, ...]]:
        """Return the map entry at `index`, or NONE when out of range."""
        if index < len(self.dims):
            return self.dims[index]
        return NONE

    def GetIndexByValue(self, value: Union[int, Tuple[int, ...]]) -> int:
        """Return the index of the first entry equal to `value`, or NONE."""
        for position, entry in enumerate(self.dims):
            if entry == value:
                return position
        return NONE

    def GetIndexContainValue(self, value: Union[int, Tuple[int, ...]]) -> int:
        """
        Return the index of the first tuple entry whose suffix equals `value`
        (or whose last element equals a scalar `value`); NONE if absent.
        """
        for position, entry in enumerate(self.dims):
            # Only combined (tuple) entries can "contain" a value.
            if not isinstance(entry, tuple):
                continue
            if isinstance(value, tuple):
                if entry[len(entry) - len(value):] == value:
                    return position
            elif entry[-1] == value:
                return position
        return NONE
class DevMat:
    """
    A multi-dimensional grid of devices, one size per mesh dimension.

    Provides lookups of dimension sizes by reverse index and of the device
    groups that vary along one dimension (or several dimensions combined).

    Attributes:
        dims (List[int]): Sizes of each dimension in the mesh shape.
        _combined_dims (Dict[Tuple[int, ...], int]): Cache of products of
            combined dimension sizes.
    """

    def __init__(self, dims: List[int]):
        """
        Store the mesh shape.

        Args:
            dims: Size of each mesh dimension, outermost first.
        """
        self.dims = dims
        self._combined_dims: Dict[Tuple[int, ...], int] = {}

    def GetDimByReverseIdx(self, idx: Union[int, Tuple[int, ...]]) -> int:
        """
        Return a dimension size by reverse index.

        For an integer `i`, returns `dims[len(dims) - 1 - i]`. For a tuple of
        reverse indices, returns the product of the selected sizes.

        Args:
            idx: Reverse dimension index, or a tuple of such indices.

        Returns:
            The dimension size (int) or the product of sizes (tuple).
        """
        if isinstance(idx, tuple):
            return self._GetCombinedSize(idx)
        return self.dims[len(self.dims) - 1 - idx]

    def _GetCombinedSize(self, dims: Tuple[int, ...]) -> int:
        """
        Multiply (and cache) the sizes of the given reverse-indexed dimensions.

        Args:
            dims: Tuple of reverse dimension indices.

        Returns:
            Product of the selected dimension sizes.
        """
        cached = self._combined_dims.get(dims)
        if cached is not None:
            return cached
        last = len(self.dims) - 1
        product = 1
        for rev_idx in dims:
            product *= self.dims[last - rev_idx]
        self._combined_dims[dims] = product
        return product

    def _GetDevicesAlongDim(self, rank: int, rank_list: List[int], dim: int) -> List[int]:
        """
        Return the devices sharing all of rank's coordinates except along `dim`.

        The mesh is row-major: the last dimension varies fastest in rank_list.

        Args:
            rank: Target device rank.
            rank_list: Flattened device list in row-major order.
            dim: Dimension index, 0-indexed from the outermost dimension.

        Returns:
            Devices in rank's group along `dim`, in mesh order.

        Raises:
            ValueError: If `dim` is out of range, rank_list has the wrong
                length, or `rank` is absent from rank_list.
        """
        if not 0 <= dim < len(self.dims):
            raise ValueError(f"Dimension {dim} out of range [0, {len(self.dims)})")

        size = self.dims[dim]
        # A size-1 dimension groups each device with itself only.
        if size == 1:
            return [rank]

        expected = 1
        for extent in self.dims:
            expected *= extent
        if len(rank_list) != expected:
            raise ValueError(f"rank_list length ({len(rank_list)}) doesn't match "
                             f"mesh shape product ({expected})")

        # Row-major stride of `dim` is the product of all inner dimension sizes.
        stride = 1
        for inner in self.dims[dim + 1:]:
            stride *= inner

        try:
            pos = rank_list.index(rank)
        except ValueError as e:
            raise ValueError(f"Rank {rank} not in rank_list") from e

        # Zero out the coordinate along `dim`, then sweep it over all values.
        coord = (pos // stride) % size
        origin = pos - coord * stride
        return [rank_list[origin + step * stride] for step in range(size)]

    def GetDevicesAlongDim(self, rank: int, rank_list: List[int],
                           dim: Union[int, List[int]]) -> List[int]:
        """
        Return the devices sharing rank's coordinates outside of `dim`.

        With a single dimension index, only that dimension varies within the
        group. With a list of indices, the group spans the sub-grid formed by
        all listed dimensions, expanded one dimension at a time.

        Args:
            rank: Target device rank.
            rank_list: Flattened device list in row-major order.
            dim: Single dimension index or a list of dimension indices.

        Returns:
            Devices in the same hyperplane as `rank` spanned by `dim`.

        Raises:
            ValueError: Propagated from _GetDevicesAlongDim on invalid input.
        """
        if not isinstance(dim, list):
            return self._GetDevicesAlongDim(rank, rank_list, dim)

        # Layered expansion: every device collected so far is expanded along
        # the next listed dimension, preserving discovery order.
        layer = self._GetDevicesAlongDim(rank, rank_list, dim[0])
        for axis in dim[1:]:
            expanded: List[int] = []
            for member in layer:
                expanded.extend(self._GetDevicesAlongDim(member, rank_list, axis))
            layer = expanded
        return layer
class RedistributionOperatorInfer:
    """
    Infers communication operators for tensor redistribution in distributed systems.

    Determines the sequence of communication operations (split, concat, permute)
    required to transform a tensor from an input device mapping to an output
    device mapping.

    Args:
        dev_mat: Mesh shape dimensions representing the device grid.
        in_tensor_map: Input tensor's device mapping for each tensor dimension.
        out_tensor_map: Output tensor's device mapping for each tensor dimension.
        use_permute: Whether to use the permute operator (all-to-all) when
            possible (default: True).
    """

    def __init__(self, dev_mat: List[int],
                 in_tensor_map: List[Union[int, Tuple[int, ...]]],
                 out_tensor_map: List[Union[int, Tuple[int, ...]]],
                 use_permute: bool = True):
        # Accumulated (op_type, args) pairs, in execution order.
        self.operator_list_: List[Tuple[int, Tuple]] = []
        self.use_permute = use_permute

        self.dev_ranks = len(dev_mat)
        self.dev_mat_ = DevMat(dev_mat)
        self.in_tensor_map_ = TensorMap(in_tensor_map)
        self.out_tensor_map_ = TensorMap(out_tensor_map)

        # Working state: tensor-dim index -> current device-mapping value.
        # Entries are deleted as each dimension reaches its output mapping.
        self.map_: Dict[int, Union[int, Tuple[int, ...]]] = {
            i: self.in_tensor_map_.GetDimByIdx(i)
            for i in range(len(in_tensor_map))}

    def InsertOperator(self, op_type: int, args: Tuple) -> int:
        """
        Append an operator to the internal operator sequence.

        Args:
            op_type: Operator type constant (SPLIT_BY_AXIS, CONCAT_BY_AXIS,
                PERMUTE_BY_AXIS).
            args: Operator-specific arguments tuple.

        Returns:
            Status.SUCCESS (appending cannot fail).
        """
        self.operator_list_.append((op_type, args))
        return Status.SUCCESS

    def InferRedistributionOperator(self) -> int:
        """
        Main inference driver coordinating the redistribution sequence.

        Executes in 3 phases until the mapping is resolved:
        1. Split operations
        2. Permute/All-to-All operations
        3. Concat operations

        Returns:
            Status.SUCCESS if the full sequence was inferred, Status.FAILED otherwise.
        """
        while self.map_:
            len_global = len(self.operator_list_)

            while self.map_:
                len_split_by_axis = len(self.operator_list_)

                # Step 1: infer split op
                if self.InferSplitByAxis() == Status.FAILED:
                    return Status.FAILED

                # Step 2: infer alltoall op until no further progress is made
                while self.map_:
                    len_permute_by_axis = len(self.operator_list_)
                    if self.InferPermuteByAxis() == Status.FAILED:
                        return Status.FAILED
                    if len_permute_by_axis == len(self.operator_list_):
                        break

                if len_split_by_axis == len(self.operator_list_):
                    break

            # Step 3: infer allconcat op
            if self.InferConcatByAxis() == Status.FAILED:
                return Status.FAILED

            # Deadlock breaker: when a full pass produced no operator, force a
            # concat on one remaining dimension so the outer loop can progress.
            if len_global == len(self.operator_list_) and self.map_:
                index = next(iter(self.map_.keys()))
                in_dim = self.map_[index]
                self.map_[index] = NONE
                dev_dim = self.dev_mat_.GetDimByReverseIdx(in_dim)
                args = (index, in_dim, dev_dim)
                if self.InsertOperator(CONCAT_BY_AXIS, args) == Status.FAILED:
                    return Status.FAILED

        return Status.SUCCESS

    def _HandleSimpleSplitCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                               out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle the simple split: input dim unmapped and output dim unconflicted."""
        if in_dim != NONE:
            return False

        # out_dim must not be held by any pending dimension, either exactly or
        # as the suffix of a combined (tuple) mapping.
        conflict = any(v == out_dim for v in self.map_.values())
        if isinstance(out_dim, tuple):
            conflict_tuple = any(isinstance(v, tuple) and v[len(v) - len(out_dim):] == out_dim
                                 for v in self.map_.values())
        else:
            conflict_tuple = any(isinstance(v, tuple) and v[-1] == out_dim for v in self.map_.values())

        if not conflict and not conflict_tuple:
            dev_dim = self.dev_mat_.GetDimByReverseIdx(out_dim)
            args = (index, out_dim, dev_dim)
            return self.InsertOperator(SPLIT_BY_AXIS, args) == Status.SUCCESS

        return False

    def _HandleTupleSplitCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                              out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle a tuple output dim whose prefix already matches the input dim."""
        if not isinstance(out_dim, tuple):
            return False

        if ((not isinstance(in_dim, tuple) and in_dim == out_dim[0]) or
                (isinstance(in_dim, tuple) and in_dim == out_dim[:len(in_dim)])):
            # Only the remainder of the tuple (beyond the matched prefix) still
            # needs splitting; collapse a single-element remainder to a scalar.
            if isinstance(in_dim, tuple):
                out_dim_rest = out_dim[-1] if len(out_dim[len(in_dim):]) == 1 else out_dim[len(in_dim):]
            else:
                out_dim_rest = out_dim[-1] if len(out_dim[1:]) == 1 else out_dim[1:]

            conflict = any(v == out_dim_rest for v in self.map_.values())
            if not conflict:
                dev_dim = self.dev_mat_.GetDimByReverseIdx(out_dim_rest)
                args = (index, out_dim_rest, dev_dim)
                return self.InsertOperator(SPLIT_BY_AXIS, args) == Status.SUCCESS

        return False

    def InferSplitByAxis(self) -> int:
        """
        Infers split operations for the current mapping state.

        Conditions for split:
        - Tensor dimension changes from unmapped to mapped
        - No conflicts in the target device dimension

        Updates internal mapping state and operator list.

        Returns:
            Status.SUCCESS (helpers report failure by leaving the map untouched).
        """
        keys = list(self.map_.keys())
        for index in keys:
            # A handler for an earlier key may already have removed this one.
            if index not in self.map_:
                continue

            in_dim = self.map_[index]
            out_dim = self.out_tensor_map_.GetDimByIdx(index)

            if in_dim == out_dim:
                del self.map_[index]
                continue

            # Handle simple case: input dimension is None
            if self._HandleSimpleSplitCase(index, in_dim, out_dim):
                del self.map_[index]
                continue

            # Handle tuple case: output dimension is a tuple
            if self._HandleTupleSplitCase(index, in_dim, out_dim):
                del self.map_[index]
                continue

        return Status.SUCCESS

    def _HandleNoneDimPermuteCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                  out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle permute where the input dim is unmapped but out_dim is taken."""
        if in_dim != NONE:
            return False

        # Only applicable when out_dim is currently held by another dimension.
        conflict = any(v == out_dim for v in self.map_.values())
        if not conflict:
            return False

        concat_axis = self.in_tensor_map_.GetIndexByValue(out_dim)
        # BUGFIX: GetIndexByValue reports "not found" as NONE (-1), never None,
        # so the previous `is None` guard could not fire and -1 leaked into
        # GetDimByIdx (silently reading the last tensor-map entry).
        if concat_axis == NONE:
            return False

        split_dev_num = self.dev_mat_.GetDimByReverseIdx(out_dim)

        if self.use_permute:
            # concat tensor map value, to get the communication group
            concat_map = self.in_tensor_map_.GetDimByIdx(concat_axis)
            concat_dev_num = self.dev_mat_.GetDimByReverseIdx(concat_map)
            args_permute = (concat_dev_num, index, concat_axis, concat_map, split_dev_num)

            if self.InsertOperator(PERMUTE_BY_AXIS, args_permute) == Status.FAILED:
                return False
        else:
            # Without permute, emulate all-to-all with a concat followed by a split.
            args_concat = (concat_axis, out_dim, split_dev_num)
            args_split = (index, out_dim, split_dev_num)

            if self.InsertOperator(CONCAT_BY_AXIS, args_concat) == Status.FAILED:
                return False
            if self.InsertOperator(SPLIT_BY_AXIS, args_split) == Status.FAILED:
                return False

        del self.map_[index]
        self.map_[concat_axis] = NONE
        return True

    def _HandleNoneDimTuplePermuteCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                       out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle permute where the input dim is unmapped and out_dim conflicts with a tuple suffix."""
        if in_dim != NONE:
            return False

        if isinstance(out_dim, tuple):
            conflict_tuple = any(isinstance(v, tuple) and v[len(v) - len(out_dim):] == out_dim
                                 for v in self.map_.values())
        else:
            conflict_tuple = any(isinstance(v, tuple) and v[-1] == out_dim for v in self.map_.values())

        if not conflict_tuple:
            return False

        concat_axis = self.in_tensor_map_.GetIndexContainValue(out_dim)
        # BUGFIX: GetIndexContainValue returns NONE (-1) when absent, never None;
        # compare against the sentinel so the guard can actually trigger.
        if concat_axis == NONE:
            return False

        split_dev_num = self.dev_mat_.GetDimByReverseIdx(out_dim)

        if self.use_permute:
            # concat tensor map value, to get the communication group
            concat_map = out_dim
            concat_dev_num = self.dev_mat_.GetDimByReverseIdx(concat_map)
            args_permute = (concat_dev_num, index, concat_axis, concat_map, split_dev_num)

            if self.InsertOperator(PERMUTE_BY_AXIS, args_permute) == Status.FAILED:
                return False
        else:
            args_concat = (concat_axis, out_dim, split_dev_num)
            args_split = (index, out_dim, split_dev_num)

            if self.InsertOperator(CONCAT_BY_AXIS, args_concat) == Status.FAILED:
                return False
            if self.InsertOperator(SPLIT_BY_AXIS, args_split) == Status.FAILED:
                return False

        del self.map_[index]
        # Strip the consumed suffix from the combined mapping on concat_axis.
        # NOTE(review): assumes self.map_[concat_axis] is still present and is a
        # tuple longer than out_dim — confirm against callers.
        out_dim_len = 1 if not isinstance(out_dim, tuple) else len(out_dim)
        rest_size = len(self.map_[concat_axis]) - out_dim_len
        new_map_item = self.map_[concat_axis][:rest_size] if rest_size > 1 else self.map_[concat_axis][0]
        self.map_[concat_axis] = new_map_item
        return True

    def _HandleTupleDimPermuteCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                   out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle permute where out_dim is a tuple extending in_dim but the remainder conflicts."""
        if not isinstance(out_dim, tuple):
            return False

        if not ((not isinstance(in_dim, tuple) and in_dim == out_dim[0]) or
                (isinstance(in_dim, tuple) and in_dim == out_dim[:len(in_dim)])):
            return False

        # Remainder of out_dim beyond the prefix matched by in_dim.
        if isinstance(in_dim, tuple):
            out_dim_rest = out_dim[-1] if len(out_dim[len(in_dim):]) == 1 else out_dim[len(in_dim):]
        else:
            out_dim_rest = out_dim[-1] if len(out_dim[1:]) == 1 else out_dim[1:]

        conflict = any(v == out_dim_rest for v in self.map_.values())
        if not conflict:
            return False

        concat_axis = self.in_tensor_map_.GetIndexByValue(out_dim_rest)
        # BUGFIX: compare against the NONE sentinel (-1); the lookup never
        # returns None, so the old `is None` check was dead code.
        if concat_axis == NONE:
            return False

        split_dev_num = self.dev_mat_.GetDimByReverseIdx(out_dim_rest)

        if self.use_permute:
            # concat tensor map value, to get the communication group
            concat_map = out_dim_rest
            concat_dev_num = self.dev_mat_.GetDimByReverseIdx(concat_map)
            args_permute = (concat_dev_num, index, concat_axis, concat_map, split_dev_num)

            if self.InsertOperator(PERMUTE_BY_AXIS, args_permute) == Status.FAILED:
                return False
        else:
            args_concat = (concat_axis, out_dim_rest, split_dev_num)
            args_split = (index, out_dim_rest, split_dev_num)

            if self.InsertOperator(CONCAT_BY_AXIS, args_concat) == Status.FAILED:
                return False
            if self.InsertOperator(SPLIT_BY_AXIS, args_split) == Status.FAILED:
                return False

        del self.map_[index]
        self.map_[concat_axis] = NONE
        return True

    def InferPermuteByAxis(self) -> int:
        """
        Infers permutation (all-to-all) operations for dimension conflicts.

        Handles cases where:
        - Input dimension is unmapped but output dimension is already occupied
        - Uses either the permute operator or a concat+split pair based on the
          use_permute flag

        Returns:
            Status.SUCCESS (helpers report failure by leaving the map untouched).
        """
        keys = list(self.map_.keys())
        for index in keys:
            # A handler for an earlier key may already have removed this one.
            if index not in self.map_:
                continue

            in_dim = self.map_[index]
            out_dim = self.out_tensor_map_.GetDimByIdx(index)

            if in_dim == out_dim:
                del self.map_[index]
                continue

            # Handle different permute cases
            if self._HandleNoneDimPermuteCase(index, in_dim, out_dim):
                continue

            if self._HandleNoneDimTuplePermuteCase(index, in_dim, out_dim):
                continue

            if self._HandleTupleDimPermuteCase(index, in_dim, out_dim):
                continue

        return Status.SUCCESS

    def _HandleTupleConcatCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                               out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle concat where the input dim is a tuple whose prefix matches out_dim."""
        if not isinstance(in_dim, tuple):
            return False

        if not ((not isinstance(out_dim, tuple) and out_dim == in_dim[0]) or
                (isinstance(out_dim, tuple) and out_dim == in_dim[:len(out_dim)])):
            return False

        # The part of in_dim beyond the matched prefix must be gathered back.
        if isinstance(out_dim, tuple):
            in_dim_rest = in_dim[-1] if len(in_dim[len(out_dim):]) == 1 else in_dim[len(out_dim):]
        else:
            in_dim_rest = in_dim[-1] if len(in_dim[1:]) == 1 else in_dim[1:]

        concat_dev_num = self.dev_mat_.GetDimByReverseIdx(in_dim_rest)
        args = (index, in_dim_rest, concat_dev_num)

        if self.InsertOperator(CONCAT_BY_AXIS, args) == Status.FAILED:
            return False

        del self.map_[index]
        return True

    def _HandleSimpleConcatCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle concat where the input dim is mapped but unused in the output map."""
        if in_dim == NONE:
            return False

        # If the output map still uses in_dim somewhere, concat would conflict.
        if self.out_tensor_map_.GetIndexByValue(in_dim) != NONE:
            return False

        concat_dev_num = self.dev_mat_.GetDimByReverseIdx(in_dim)
        args = (index, in_dim, concat_dev_num)

        if self.InsertOperator(CONCAT_BY_AXIS, args) == Status.FAILED:
            return False

        if out_dim == NONE:
            # Fully resolved: the output wants this dimension unmapped.
            del self.map_[index]
        else:
            # Partially resolved: dimension is now unmapped, awaiting a split.
            self.map_[index] = NONE

        return True

    def InferConcatByAxis(self) -> int:
        """
        Infers concat operations for the current mapping state.

        Conditions for concat:
        - Input dimension is mapped but output is unmapped
        - Device dimension needs consolidation

        Returns:
            Status.SUCCESS (helpers report failure by leaving the map untouched).
        """
        keys = list(self.map_.keys())
        for index in keys:
            # A handler for an earlier key may already have removed this one.
            if index not in self.map_:
                continue

            in_dim = self.map_[index]
            out_dim = self.out_tensor_map_.GetDimByIdx(index)

            # Handle tuple concat case
            if self._HandleTupleConcatCase(index, in_dim, out_dim):
                continue

            # Handle simple concat case
            if self._HandleSimpleConcatCase(index, in_dim, out_dim):
                continue

        return Status.SUCCESS

    def _GetCommGroup(self, rank: int, rank_list: List[int],
                      map_value: Union[int, Tuple[int, ...]]) -> List[int]:
        """Translate a reverse-indexed tensor-map value into its device group."""
        if isinstance(map_value, tuple):
            dim = [self.dev_ranks - 1 - d for d in map_value]
        else:
            dim = self.dev_ranks - 1 - map_value
        return self.dev_mat_.GetDevicesAlongDim(rank, rank_list, dim)

    def InferOpsList(self, rank: int, rank_list: List[int]):
        """
        Converts the internal operator sequence to executable communication operations.

        Args:
            rank: Current device rank.
            rank_list: Full list of device ranks in row-major order.

        Returns:
            List of executable communication operations as tuples:
            - ("all_concat", (dim, size, group))
            - ("all_split", (dim, size, group))
            - ("all_to_all", (split_dim, concat_dim, size, group))
        """
        self.InferRedistributionOperator()
        ops_list = []
        for op_type, args in self.operator_list_:
            # Size-1 operations are no-ops; skip them before computing groups.
            if op_type == CONCAT_BY_AXIS:
                concat_dim, map_value, concat_size = args
                if concat_size == 1:
                    continue
                group = self._GetCommGroup(rank, rank_list, map_value)
                ops_list.append(("all_concat", (concat_dim, concat_size, group)))
            elif op_type == SPLIT_BY_AXIS:
                split_dim, map_value, split_size = args
                if split_size == 1:
                    continue
                group = self._GetCommGroup(rank, rank_list, map_value)
                ops_list.append(("all_split", (split_dim, split_size, group)))
            else:
                # PERMUTE_BY_AXIS args: (size, split_dim, concat_dim, map_value, _)
                permute_size, split_dim, concat_dim, map_value = args[0], args[1], args[2], args[3]
                if permute_size == 1:
                    continue
                group = self._GetCommGroup(rank, rank_list, map_value)
                ops_list.append(("all_to_all", (split_dim, concat_dim, permute_size, group)))
        return ops_list