Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/dtensor/redistribute_infer.py: 53% (331 statements)
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""redistribute_infer"""
from typing import Dict, List, Tuple, Union

class Status:
    SUCCESS = 0
    FAILED = 1


CONCAT_BY_AXIS = 0
SPLIT_BY_AXIS = 1
PERMUTE_BY_AXIS = 2
NONE = -1

class TensorMap:
    """Enhanced tensor map struct supporting tuples for combined dimensions"""
    def __init__(self, dims: List[Union[int, Tuple[int, ...]]]):
        self.dims = dims

    def get_dim_by_idx(self, index: int) -> Union[int, Tuple[int, ...]]:
        return self.dims[index] if index < len(self.dims) else NONE

    def get_index_by_value(self, value: Union[int, Tuple[int, ...]]) -> int:
        for i, dim in enumerate(self.dims):
            if dim == value:
                return i
        return NONE

    def get_index_contain_value(self, value: Union[int, Tuple[int, ...]]) -> int:
        for i, dim in enumerate(self.dims):
            if not isinstance(dim, tuple):
                continue
            if isinstance(value, tuple) and value == dim[len(dim) - len(value):]:
                return i
            if not isinstance(value, tuple) and value == dim[-1]:
                return i
        return NONE
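# Illustrative note (not part of the original module): a minimal sketch of how the
# suffix-matching lookup in TensorMap behaves, assuming the usual tensor-map
# convention where each entry names the device-matrix dimension(s) a tensor
# dimension is sharded over.
#
#     tmap = TensorMap([(1, 0), NONE])
#     tmap.get_index_by_value((1, 0))       # -> 0, exact match on the tuple entry
#     tmap.get_index_contain_value(0)       # -> 0, scalar matches the tuple's last element
#     tmap.get_index_contain_value((0,))    # -> 0, tuple matches the tuple's suffix
#     tmap.get_index_by_value(0)            # -> NONE, no exact scalar entry equals 0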

class DevMat:
    """
    Represents a multi-dimensional grid of devices where each dimension has a specific size.
    Supports operations to retrieve device groups along single or combined dimensions.

    Attributes:
        dims (List[int]): Sizes of each dimension in the mesh shape.
        _combined_dims (Dict[Tuple[int, ...], int]): Cache for precomputed combined dimension sizes.
    """

    def __init__(self, dims: List[int]):
        """
        Initialize mesh shape dimensions.

        Args:
            dims: List of integers representing the size of each dimension.
        """
        self.dims = dims
        self._combined_dims: Dict[Tuple[int, ...], int] = {}

    def get_dim_by_reverse_idx(self, idx: Union[int, Tuple[int, ...]]) -> int:
        """
        Get dimension size by reverse index or product of combined dimensions.

        For a single integer index `i`, returns the size of the dimension at reverse
        position (i.e., `dims[len(dims) - 1 - i]`). For a tuple of indices, returns the
        product of sizes for the specified reverse-indexed dimensions.

        Args:
            idx: Integer dimension index or tuple of indices.

        Returns:
            Dimension size (for an integer) or product of sizes (for a tuple).
        """
        if isinstance(idx, tuple):
            return self._get_combined_size(idx)
        return self.dims[len(self.dims) - 1 - idx]

    def _get_combined_size(self, dims: Tuple[int, ...]) -> int:
        """
        Compute and cache the product of sizes for combined dimensions.

        Args:
            dims: Tuple of dimension indices (reverse-indexed).

        Returns:
            Product of sizes for the specified dimensions.
        """
        if dims in self._combined_dims:
            return self._combined_dims[dims]
        size = 1
        for d in dims:
            size *= self.dims[len(self.dims) - 1 - d]
        self._combined_dims[dims] = size
        return size

    def _get_devices_along_dim(self, rank: int, rank_list: List[int], dim: int) -> List[int]:
        """
        Get the devices that share `rank`'s coordinates in every dimension except `dim`.

        Devices are grouped such that only the specified dimension varies. The mesh shape
        is assumed to be in row-major order (last dimension changes fastest).

        Args:
            rank: Target device rank.
            rank_list: Flattened list of all devices in row-major order.
            dim: Target dimension index (0-indexed from outermost).

        Returns:
            List of devices in the same group as `rank` along `dim`.

        Raises:
            ValueError: For an invalid dimension or mismatched rank_list size.
        """
        if dim < 0 or dim >= len(self.dims):
            raise ValueError(f"Dimension {dim} out of range [0, {len(self.dims)})")

        # Trivial case: dimension size is 1
        if self.dims[dim] == 1:
            return [rank]

        total_devices = 1
        for d in self.dims:
            total_devices *= d

        # Validate rank_list length
        if len(rank_list) != total_devices:
            raise ValueError(f"rank_list length ({len(rank_list)}) doesn't match "
                             f"mesh shape product ({total_devices})")

        # Compute stride for the dimension
        stride = 1
        for i in range(dim + 1, len(self.dims)):
            stride *= self.dims[i]

        # Find local index of rank in rank_list
        try:
            local_index = rank_list.index(rank)
        except ValueError as e:
            raise ValueError(f"Rank {rank} not in rank_list") from e

        # Calculate base index and generate group
        index_in_dim = (local_index // stride) % self.dims[dim]
        base = local_index - index_in_dim * stride
        group = [rank_list[base + k * stride] for k in range(self.dims[dim])]

        return group

    def get_devices_along_dim(self, rank: int, rank_list: List[int], dim: Union[int, List[int]]) -> List[int]:
        """
        Get the devices that share `rank`'s coordinates in every dimension not listed in `dim`.

        For a single dimension, returns devices where only that dimension varies.
        For a list of dimensions, returns devices where ONLY the specified dimensions vary,
        sharing fixed coordinates in all other dimensions.

        Args:
            rank: Target device rank.
            rank_list: Flattened list of all devices in row-major order.
            dim: Single dimension index or list of indices.

        Returns:
            List of devices in the same hyperplane as `rank` spanned by `dim`.

        Raises:
            ValueError: For invalid dimensions or mismatched rank_list size.
        """
        if isinstance(dim, list):
            result = self._get_devices_along_dim(rank, rank_list, dim[0])
            current_layer_len = len(result)
            current_layer_step = 0
            dim_index = 1
            while dim_index < len(dim):
                sub_rank = result.pop(0)
                result.extend(self._get_devices_along_dim(sub_rank, rank_list, dim[dim_index]))
                current_layer_step += 1
                if current_layer_step == current_layer_len:
                    dim_index += 1
                    current_layer_step = 0
                    current_layer_len = len(result)
            return result
        return self._get_devices_along_dim(rank, rank_list, dim)
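# Illustrative note (not part of the original module): a minimal sketch of the
# device-grouping and reverse-index conventions above, assuming a hypothetical
# 2 x 4 device grid laid out in row-major order as ranks 0..7.
#
#     mesh = DevMat([2, 4])
#     mesh.get_dim_by_reverse_idx(0)                      # -> 4 (last dimension)
#     mesh.get_dim_by_reverse_idx(1)                      # -> 2 (second-to-last dimension)
#     mesh.get_dim_by_reverse_idx((0, 1))                 # -> 8 (combined size 4 * 2)
#     mesh.get_devices_along_dim(5, list(range(8)), 1)    # -> [4, 5, 6, 7], row of rank 5
#     mesh.get_devices_along_dim(5, list(range(8)), 0)    # -> [1, 5], column of rank 5
#     mesh.get_devices_along_dim(5, list(range(8)), [0, 1])  # -> all eight ranks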

class RedistributionOperatorInfer:
    """
    Infers communication operators for tensor redistribution in distributed systems.

    Determines the sequence of communication operations (split, concat, permute)
    required to transform a tensor from an input device mapping to an output device mapping.

    Args:
        dev_mat: Mesh shape dimensions representing the device grid.
        in_tensor_map: Input tensor's device mapping for each tensor dimension.
        out_tensor_map: Output tensor's device mapping for each tensor dimension.
        use_permute: Whether to use the permute operator (all-to-all) when possible (default: True).
    """
    def __init__(self, dev_mat: List[int],
                 in_tensor_map: List[Union[int, Tuple[int, ...]]],
                 out_tensor_map: List[Union[int, Tuple[int, ...]]],
                 use_permute: bool = True):
        self.operator_list_: List[Tuple[int, Tuple]] = []
        self.map_: Dict[int, Union[int, Tuple[int, ...]]] = {}
        self.use_permute = use_permute

        # Initialize with expanded dimensions
        self.dev_ranks = len(dev_mat)
        self.dev_mat_ = DevMat(dev_mat)
        self.in_tensor_map_ = TensorMap(in_tensor_map)
        self.out_tensor_map_ = TensorMap(out_tensor_map)

        self.map_ = {i: self.in_tensor_map_.get_dim_by_idx(i)
                     for i in range(len(in_tensor_map))}

    def insert_operator(self, op_type: int, args: Tuple) -> int:
        """
        Adds an operator to the internal operator sequence.

        Args:
            op_type: Operator type constant (SPLIT_BY_AXIS, CONCAT_BY_AXIS, PERMUTE_BY_AXIS).
            args: Operator-specific arguments tuple.

        Returns:
            Status.SUCCESS on success, Status.FAILED on error.
        """
        self.operator_list_.append((op_type, args))
        return Status.SUCCESS

    def infer_redistribution_operator(self) -> int:
        """
        Main inference driver coordinating the redistribution sequence.

        Executes in 3 phases until the mapping is resolved:
        1. Split operations
        2. Permute/All-to-All operations
        3. Concat operations

        Returns:
            Status.SUCCESS if the full sequence is inferred, Status.FAILED otherwise.
        """
        while self.map_:
            len_global = len(self.operator_list_)

            while self.map_:
                len_split_by_axis = len(self.operator_list_)

                # Step 1: infer split op
                if self.infer_split_by_axis() == Status.FAILED:
                    return Status.FAILED

                # Step 2: infer alltoall op
                while self.map_:
                    len_permute_by_axis = len(self.operator_list_)
                    if self.infer_permute_by_axis() == Status.FAILED:
                        return Status.FAILED
                    if len_permute_by_axis == len(self.operator_list_):
                        break

                if len_split_by_axis == len(self.operator_list_):
                    break

            # Step 3: infer allconcat op
            if self.infer_concat_by_axis() == Status.FAILED:
                return Status.FAILED

            if len_global == len(self.operator_list_) and self.map_:
                index = next(iter(self.map_.keys()))
                in_dim = self.map_[index]
                self.map_[index] = NONE
                dev_dim = self.dev_mat_.get_dim_by_reverse_idx(in_dim)
                args = (index, in_dim, dev_dim)
                if self.insert_operator(CONCAT_BY_AXIS, args) == Status.FAILED:
                    return Status.FAILED

        return Status.SUCCESS
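    # Illustrative note (not part of the original module): a minimal sketch of what
    # the phase loop above produces for a simple case, assuming a hypothetical
    # device matrix [2, 4] (tensor-map value 1 names the size-2 device dimension).
    # With in_tensor_map=[1] and out_tensor_map=[NONE], no split or permute applies,
    # and the concat phase emits a single (CONCAT_BY_AXIS, (0, 1, 2)) operator:
    # tensor dimension 0 is gathered across the size-2 device dimension.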

    def _handle_simple_split_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                  out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle the simple case where the input dimension is NONE and the output dimension is not conflicting"""
        if in_dim != NONE:
            return False

        conflict = any(v == out_dim for v in self.map_.values())
        if isinstance(out_dim, tuple):
            conflict_tuple = any(isinstance(v, tuple) and v[len(v) - len(out_dim):] == out_dim
                                 for v in self.map_.values())
        else:
            conflict_tuple = any(isinstance(v, tuple) and v[-1] == out_dim for v in self.map_.values())

        if not conflict and not conflict_tuple:
            dev_dim = self.dev_mat_.get_dim_by_reverse_idx(out_dim)
            args = (index, out_dim, dev_dim)
            return self.insert_operator(SPLIT_BY_AXIS, args) == Status.SUCCESS

        return False

    def _handle_tuple_split_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                 out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle the case where the output dimension is a tuple and the input dimension matches its prefix"""
        if not isinstance(out_dim, tuple):
            return False

        if ((not isinstance(in_dim, tuple) and in_dim == out_dim[0]) or
                (isinstance(in_dim, tuple) and in_dim == out_dim[:len(in_dim)])):
            if isinstance(in_dim, tuple):
                out_dim_rest = out_dim[-1] if len(out_dim[len(in_dim):]) == 1 else out_dim[len(in_dim):]
            else:
                out_dim_rest = out_dim[-1] if len(out_dim[1:]) == 1 else out_dim[1:]

            conflict = any(v == out_dim_rest for v in self.map_.values())
            if not conflict:
                dev_dim = self.dev_mat_.get_dim_by_reverse_idx(out_dim_rest)
                args = (index, out_dim_rest, dev_dim)
                return self.insert_operator(SPLIT_BY_AXIS, args) == Status.SUCCESS

        return False

    def infer_split_by_axis(self) -> int:
        """
        Infers split operations for the current mapping state.

        Conditions for split:
        - Tensor dimension changes from unmapped to mapped
        - No conflicts in the target device dimension

        Updates the internal mapping state and operator list.

        Returns:
            Status.SUCCESS if operations inferred, Status.FAILED on error.
        """
        keys = list(self.map_.keys())
        for index in keys:
            if index not in self.map_:
                continue

            in_dim = self.map_[index]
            out_dim = self.out_tensor_map_.get_dim_by_idx(index)

            if in_dim == out_dim:
                del self.map_[index]
                continue

            # Handle simple case: input dimension is None
            if self._handle_simple_split_case(index, in_dim, out_dim):
                del self.map_[index]
                continue

            # Handle tuple case: output dimension is a tuple
            if self._handle_tuple_split_case(index, in_dim, out_dim):
                del self.map_[index]
                continue

        return Status.SUCCESS

    def _handle_none_dim_permute_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                      out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle permute case where the input dimension is NONE"""
        if in_dim != NONE:
            return False

        # Check for conflicts in output dimension
        conflict = any(v == out_dim for v in self.map_.values())
        if not conflict:
            return False

        # Handle regular dimension conflict
        concat_axis = self.in_tensor_map_.get_index_by_value(out_dim)
        if concat_axis == NONE:
            return False

        split_dev_num = self.dev_mat_.get_dim_by_reverse_idx(out_dim)

        if self.use_permute:
            # concat tensor map value, to get the communication group
            concat_map = self.in_tensor_map_.get_dim_by_idx(concat_axis)
            concat_dev_num = self.dev_mat_.get_dim_by_reverse_idx(concat_map)
            args_permute = (concat_dev_num, index, concat_axis, concat_map, split_dev_num)

            if self.insert_operator(PERMUTE_BY_AXIS, args_permute) == Status.FAILED:
                return False
        else:
            args_concat = (concat_axis, out_dim, split_dev_num)
            args_split = (index, out_dim, split_dev_num)

            if self.insert_operator(CONCAT_BY_AXIS, args_concat) == Status.FAILED:
                return False
            if self.insert_operator(SPLIT_BY_AXIS, args_split) == Status.FAILED:
                return False

        del self.map_[index]
        self.map_[concat_axis] = NONE
        return True

    def _handle_none_dim_tuple_permute_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                            out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle permute case where the input dimension is NONE and the output dimension conflicts with a tuple entry"""
        if in_dim != NONE:
            return False

        if isinstance(out_dim, tuple):
            conflict_tuple = any(isinstance(v, tuple) and v[len(v) - len(out_dim):] == out_dim
                                 for v in self.map_.values())
        else:
            conflict_tuple = any(isinstance(v, tuple) and v[-1] == out_dim for v in self.map_.values())

        if not conflict_tuple:
            return False

        concat_axis = self.in_tensor_map_.get_index_contain_value(out_dim)
        if concat_axis == NONE:
            return False

        split_dev_num = self.dev_mat_.get_dim_by_reverse_idx(out_dim)

        if self.use_permute:
            # concat tensor map value, to get the communication group
            concat_map = out_dim
            concat_dev_num = self.dev_mat_.get_dim_by_reverse_idx(concat_map)
            args_permute = (concat_dev_num, index, concat_axis, concat_map, split_dev_num)

            if self.insert_operator(PERMUTE_BY_AXIS, args_permute) == Status.FAILED:
                return False
        else:
            args_concat = (concat_axis, out_dim, split_dev_num)
            args_split = (index, out_dim, split_dev_num)

            if self.insert_operator(CONCAT_BY_AXIS, args_concat) == Status.FAILED:
                return False
            if self.insert_operator(SPLIT_BY_AXIS, args_split) == Status.FAILED:
                return False

        del self.map_[index]
        out_dim_len = 1 if not isinstance(out_dim, tuple) else len(out_dim)
        rest_size = len(self.map_[concat_axis]) - out_dim_len
        new_map_item = self.map_[concat_axis][:rest_size] if rest_size > 1 else self.map_[concat_axis][0]
        self.map_[concat_axis] = new_map_item
        return True

    def _handle_tuple_dim_permute_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                       out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle permute case where both input and output dimensions are tuples"""
        if not isinstance(out_dim, tuple):
            return False

        if not ((not isinstance(in_dim, tuple) and in_dim == out_dim[0]) or
                (isinstance(in_dim, tuple) and in_dim == out_dim[:len(in_dim)])):
            return False

        if isinstance(in_dim, tuple):
            out_dim_rest = out_dim[-1] if len(out_dim[len(in_dim):]) == 1 else out_dim[len(in_dim):]
        else:
            out_dim_rest = out_dim[-1] if len(out_dim[1:]) == 1 else out_dim[1:]

        conflict = any(v == out_dim_rest for v in self.map_.values())
        if not conflict:
            return False

        concat_axis = self.in_tensor_map_.get_index_by_value(out_dim_rest)
        if concat_axis == NONE:
            return False

        split_dev_num = self.dev_mat_.get_dim_by_reverse_idx(out_dim_rest)

        if self.use_permute:
            # concat tensor map value, to get the communication group
            concat_map = out_dim_rest
            concat_dev_num = self.dev_mat_.get_dim_by_reverse_idx(concat_map)
            args_permute = (concat_dev_num, index, concat_axis, concat_map, split_dev_num)

            if self.insert_operator(PERMUTE_BY_AXIS, args_permute) == Status.FAILED:
                return False
        else:
            args_concat = (concat_axis, out_dim_rest, split_dev_num)
            args_split = (index, out_dim_rest, split_dev_num)

            if self.insert_operator(CONCAT_BY_AXIS, args_concat) == Status.FAILED:
                return False
            if self.insert_operator(SPLIT_BY_AXIS, args_split) == Status.FAILED:
                return False

        del self.map_[index]
        self.map_[concat_axis] = NONE
        return True

    def infer_permute_by_axis(self) -> int:
        """
        Infers permutation (all-to-all) operations for dimension conflicts.

        Handles cases where:
        - Input dimension is unmapped but the output dimension is already occupied
        - Uses either the permute operator or a split+concat pair based on the use_permute flag

        Returns:
            Status.SUCCESS if operations inferred, Status.FAILED on error.
        """
        keys = list(self.map_.keys())
        for index in keys:
            if index not in self.map_:
                continue

            in_dim = self.map_[index]
            out_dim = self.out_tensor_map_.get_dim_by_idx(index)

            if in_dim == out_dim:
                del self.map_[index]
                continue

            # Handle different permute cases
            if self._handle_none_dim_permute_case(index, in_dim, out_dim):
                continue

            if self._handle_none_dim_tuple_permute_case(index, in_dim, out_dim):
                continue

            if self._handle_tuple_dim_permute_case(index, in_dim, out_dim):
                continue

        return Status.SUCCESS

    def _handle_tuple_concat_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                  out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle concat case where the input dimension is a tuple and the output matches its prefix"""
        if not isinstance(in_dim, tuple):
            return False

        if not ((not isinstance(out_dim, tuple) and out_dim == in_dim[0]) or
                (isinstance(out_dim, tuple) and out_dim == in_dim[:len(out_dim)])):
            return False

        if isinstance(out_dim, tuple):
            in_dim_rest = in_dim[-1] if len(in_dim[len(out_dim):]) == 1 else in_dim[len(out_dim):]
        else:
            in_dim_rest = in_dim[-1] if len(in_dim[1:]) == 1 else in_dim[1:]

        concat_dev_num = self.dev_mat_.get_dim_by_reverse_idx(in_dim_rest)
        args = (index, in_dim_rest, concat_dev_num)

        if self.insert_operator(CONCAT_BY_AXIS, args) == Status.FAILED:
            return False

        del self.map_[index]
        return True

    def _handle_simple_concat_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                   out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle simple concat case where the input dimension is mapped but the output is not"""
        if in_dim == NONE:
            return False

        if self.out_tensor_map_.get_index_by_value(in_dim) != NONE:
            return False

        concat_dev_num = self.dev_mat_.get_dim_by_reverse_idx(in_dim)
        args = (index, in_dim, concat_dev_num)

        if self.insert_operator(CONCAT_BY_AXIS, args) == Status.FAILED:
            return False

        if out_dim == NONE:
            del self.map_[index]
        else:
            self.map_[index] = NONE

        return True

    def infer_concat_by_axis(self) -> int:
        """
        Infers concat operations for the current mapping state.

        Conditions for concat:
        - Input dimension is mapped but the output is unmapped
        - Device dimension needs consolidation

        Returns:
            Status.SUCCESS if operations inferred, Status.FAILED on error.
        """
        keys = list(self.map_.keys())
        for index in keys:
            if index not in self.map_:
                continue

            in_dim = self.map_[index]
            out_dim = self.out_tensor_map_.get_dim_by_idx(index)

            # Handle tuple concat case
            if self._handle_tuple_concat_case(index, in_dim, out_dim):
                continue

            # Handle simple concat case
            if self._handle_simple_concat_case(index, in_dim, out_dim):
                continue

        return Status.SUCCESS

    def infer_ops_list(self, rank: int, rank_list: List[int]):
        """
        Converts the internal operator sequence to executable communication operations.

        Args:
            rank: Current device rank.
            rank_list: Full list of device ranks in row-major order.

        Returns:
            List of executable communication operations as tuples:
            - ("all_concat", (dim, size, group))
            - ("all_split", (dim, size, group))
            - ("all_to_all", (split_dim, concat_dim, size, group))
        """
        self.infer_redistribution_operator()
        ops_list = []
        for op in self.operator_list_:
            if op[0] == CONCAT_BY_AXIS:
                tensor_map = [self.dev_ranks - 1 - d for d in op[1][1]] if isinstance(op[1][1], tuple) \
                    else self.dev_ranks - 1 - op[1][1]
                group = self.dev_mat_.get_devices_along_dim(rank, rank_list, tensor_map)
                concat_dim = op[1][0]
                concat_size = op[1][2]
                if concat_size == 1:
                    continue
                ops_list.append(("all_concat", (concat_dim, concat_size, group)))
            elif op[0] == SPLIT_BY_AXIS:
                tensor_map = [self.dev_ranks - 1 - d for d in op[1][1]] if isinstance(op[1][1], tuple) \
                    else self.dev_ranks - 1 - op[1][1]
                group = self.dev_mat_.get_devices_along_dim(rank, rank_list, tensor_map)
                split_dim = op[1][0]
                split_size = op[1][2]
                if split_size == 1:
                    continue
                ops_list.append(("all_split", (split_dim, split_size, group)))
            else:
                tensor_map = [self.dev_ranks - 1 - d for d in op[1][3]] if isinstance(op[1][3], tuple) \
                    else self.dev_ranks - 1 - op[1][3]
                group = self.dev_mat_.get_devices_along_dim(rank, rank_list, tensor_map)
                concat_dim = op[1][2]
                split_dim = op[1][1]
                permute_size = op[1][0]
                if permute_size == 1:
                    continue
                ops_list.append(("all_to_all", (split_dim, concat_dim, permute_size, group)))
        return ops_list
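

# Illustrative note (not part of the original module): a minimal end-to-end sketch
# of the inference API above, assuming a hypothetical 2 x 4 device grid with ranks
# 0..7 and the reverse-indexed tensor-map convention used throughout this file
# (value 1 names the size-2 device dimension, NONE means "not sharded").
if __name__ == "__main__":
    # Move the sharding of a 2-D tensor from tensor dim 0 to tensor dim 1.
    infer = RedistributionOperatorInfer(dev_mat=[2, 4],
                                        in_tensor_map=[1, NONE],
                                        out_tensor_map=[NONE, 1])
    ops = infer.infer_ops_list(rank=0, rank_list=list(range(8)))
    # Expected single step: an all-to-all of size 2 inside the group [0, 4],
    # i.e. ("all_to_all", (1, 0, 2, [0, 4])) -- split tensor dim 1, concat tensor dim 0.
    print(ops)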