Coverage for hyper_parallel / core / redistribute_infer.py: 61%

331 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""redistribute_infer""" 

16from typing import Dict, List, Tuple, Union 

17 

class Status:
    # Return codes used throughout the redistribution inference routines.
    SUCCESS = 0
    FAILED = 1

21 

22 

# Operator type tags stored as the first element of each (op_type, args) entry
# in RedistributionOperatorInfer.operator_list_. InferOpsList lowers them to
# "all_concat", "all_split" and "all_to_all" operations respectively.
CONCAT_BY_AXIS = 0
SPLIT_BY_AXIS = 1
PERMUTE_BY_AXIS = 2
# Sentinel for "no device dimension mapped"; also returned by the TensorMap
# lookup helpers as a not-found index.
NONE = -1

27 

class TensorMap:
    """Tensor map structure.

    Each entry describes how one tensor axis is mapped onto the device matrix:
    a single (reverse-indexed) device dimension, or a tuple of device
    dimensions when the axis is sharded over combined dimensions.
    """

    def __init__(self, dims: List[Union[int, Tuple[int, ...]]]):
        self.dims = dims

    def GetDimByIdx(self, index: int) -> Union[int, Tuple[int, ...]]:
        """Return the mapping of tensor axis ``index``, or NONE when out of range."""
        if index < len(self.dims):
            return self.dims[index]
        return NONE

    def GetIndexByValue(self, value: Union[int, Tuple[int, ...]]) -> int:
        """Return the first axis whose mapping equals ``value`` exactly, or NONE."""
        for axis, mapped in enumerate(self.dims):
            if mapped == value:
                return axis
        return NONE

    def GetIndexContainValue(self, value: Union[int, Tuple[int, ...]]) -> int:
        """Return the first tuple-mapped axis whose suffix matches ``value``, or NONE.

        A tuple ``value`` matches when it equals the trailing slice of the axis
        mapping; a scalar ``value`` matches the mapping's last element.
        """
        value_is_tuple = isinstance(value, tuple)
        for axis, mapped in enumerate(self.dims):
            if not isinstance(mapped, tuple):
                continue
            suffix = mapped[len(mapped) - len(value):] if value_is_tuple else mapped[-1]
            if suffix == value:
                return axis
        return NONE

51 

class DevMat:
    """
    Represents a multi-dimensional grid of devices where each dimension has a specific size.
    Supports operations to retrieve device groups along single or combined dimensions.

    Attributes:
        dims (List[int]): Sizes of each dimension in the mesh shape.
        _combined_dims (Dict[Tuple[int, ...], int]): Cache for precomputed combined dimension sizes.
    """

    def __init__(self, dims: List[int]):
        """
        Initialize mesh shape dimensions.

        Args:
            dims: List of integers representing the size of each dimension.
        """
        self.dims = dims
        self._combined_dims: Dict[Tuple[int, ...], int] = {}

    def GetDimByReverseIdx(self, idx: Union[int, Tuple[int, ...]]) -> int:
        """
        Get dimension size by reverse index or product of combined dimensions.

        A scalar ``idx`` addresses the mesh from the innermost dimension
        (``dims[len(dims) - 1 - idx]``); a tuple yields the product of the
        sizes of the addressed dimensions.

        Args:
            idx: Integer dimension index or tuple of indices.

        Returns:
            Dimension size (for integer) or product of sizes (for tuple).
        """
        if isinstance(idx, tuple):
            return self._GetCombinedSize(idx)
        return self.dims[len(self.dims) - 1 - idx]

    def _GetCombinedSize(self, dims: Union[int, Tuple[int, ...]]) -> int:
        """
        Compute and cache the product of sizes for combined dimensions.

        Args:
            dims: Tuple of dimension indices (reverse-indexed).

        Returns:
            Product of sizes for the specified dimensions.
        """
        if dims in self._combined_dims:
            return self._combined_dims[dims]
        last = len(self.dims) - 1
        product = 1
        for rev_idx in dims:
            product *= self.dims[last - rev_idx]
        self._combined_dims[dims] = product
        return product

    def _GetDevicesAlongDim(self, rank: int, rank_list: List[int], dim: int) -> List[int]:
        """
        Get devices sharing the same coordinates in all dimensions except ``dim``.

        The mesh is assumed row-major (last dimension changes fastest), so the
        devices of one group are ``stride`` apart in ``rank_list``.

        Args:
            rank: Target device rank.
            rank_list: Flattened list of all devices in row-major order.
            dim: Target dimension index (0-indexed from outermost).

        Returns:
            List of devices in the same group as `rank` along `dim`.

        Raises:
            ValueError: For invalid dimension or mismatched rank_list size.
        """
        if not 0 <= dim < len(self.dims):
            raise ValueError(f"Dimension {dim} out of range [0, {len(self.dims)})")

        # A size-1 dimension yields a singleton group.
        if self.dims[dim] == 1:
            return [rank]

        expected_devices = 1
        for size in self.dims:
            expected_devices *= size
        if len(rank_list) != expected_devices:
            raise ValueError(f"rank_list length ({len(rank_list)}) doesn't match "
                             f"mesh shape product ({expected_devices})")

        # Distance in rank_list between neighbours along `dim`.
        stride = 1
        for inner_size in self.dims[dim + 1:]:
            stride *= inner_size

        try:
            pos = rank_list.index(rank)
        except ValueError as e:
            raise ValueError(f"Rank {rank} not in rank_list") from e

        # Coordinate of `rank` along `dim`, then the group's first member.
        coord = (pos // stride) % self.dims[dim]
        origin = pos - coord * stride
        return [rank_list[origin + k * stride] for k in range(self.dims[dim])]

    def GetDevicesAlongDim(self, rank: int, rank_list: List[int], dim: Union[int, List[int]]) -> List[int]:
        """
        Get devices sharing the same coordinates.

        For a single dimension, returns devices where only that dimension varies.
        For a list of dimensions, returns devices where ONLY the specified
        dimensions vary, sharing fixed coordinates in all other dimensions.

        Args:
            rank: Target device rank.
            rank_list: Flattened list of all devices in row-major order.
            dim: Single dimension index or list of indices.

        Returns:
            List of devices in the same hyperplane as `rank` orthogonal to `dim`.

        Raises:
            ValueError: For invalid dimensions or mismatched rank_list size.
        """
        if not isinstance(dim, list):
            return self._GetDevicesAlongDim(rank, rank_list, dim)

        # Expand dimension by dimension: each member of the current group is
        # replaced by its whole group along the next dimension, preserving order.
        group = self._GetDevicesAlongDim(rank, rank_list, dim[0])
        for next_dim in dim[1:]:
            expanded: List[int] = []
            for member in group:
                expanded.extend(self._GetDevicesAlongDim(member, rank_list, next_dim))
            group = expanded
        return group

194 

195 

class RedistributionOperatorInfer:
    """
    Infers communication operators for tensor redistribution in distributed systems.

    Determines the sequence of communication operations (split, concat, permute)
    required to transform a tensor from an input device mapping to an output device mapping.

    Args:
        dev_mat: Mesh shape dimensions representing the device grid
        in_tensor_map: Input tensor's device mapping for each tensor dimension
        out_tensor_map: Output tensor's device mapping for each tensor dimension
        use_permute: Whether to use permute operator (all-to-all) when possible (default: True)
    """
    def __init__(self, dev_mat: List[int],
                 in_tensor_map: List[Union[int, Tuple[int, ...]]],
                 out_tensor_map: List[Union[int, Tuple[int, ...]]],
                 use_permute: bool = True):

        self.operator_list_: List[Tuple[int, Tuple]] = []
        self.map_: Dict[int, Union[int, Tuple[int, ...]]] = {}
        self.use_permute = use_permute

        # Initialize with expanded dimensions
        self.dev_ranks = len(dev_mat)
        self.dev_mat_ = DevMat(dev_mat)
        self.in_tensor_map_ = TensorMap(in_tensor_map)
        self.out_tensor_map_ = TensorMap(out_tensor_map)

        # Working state: tensor-axis index -> current device mapping.
        # Entries are deleted once the axis matches the output mapping.
        self.map_ = {i: self.in_tensor_map_.GetDimByIdx(i)
                     for i in range(len(in_tensor_map))}

    def InsertOperator(self, op_type: int, args: Tuple) -> int:
        """
        Adds an operator to the internal operator sequence.

        Args:
            op_type: Operator type constant (SPLIT_BY_AXIS, CONCAT_BY_AXIS, PERMUTE_BY_AXIS)
            args: Operator-specific arguments tuple

        Returns:
            Status.SUCCESS on success, Status.FAILED on error
        """
        self.operator_list_.append((op_type, args))
        return Status.SUCCESS

    def InferRedistributionOperator(self) -> int:
        """
        Main inference driver coordinating the redistribution sequence.

        Executes in 3 phases until mapping is resolved:
        1. Split operations
        2. Permute/All-to-All operations
        3. Concat operations

        Returns:
            Status.SUCCESS if full sequence inferred, Status.FAILED otherwise
        """
        while self.map_:
            len_global = len(self.operator_list_)

            while self.map_:
                len_split_by_axis = len(self.operator_list_)

                # Step 1: infer split op
                if self.InferSplitByAxis() == Status.FAILED:
                    return Status.FAILED

                # Step 2: infer alltoall op
                while self.map_:
                    len_permute_by_axis = len(self.operator_list_)
                    if self.InferPermuteByAxis() == Status.FAILED:
                        return Status.FAILED
                    # No progress in this pass: leave the inner loop.
                    if len_permute_by_axis == len(self.operator_list_):
                        break

                if len_split_by_axis == len(self.operator_list_):
                    break

            # Step 3: infer allconcat op
            if self.InferConcatByAxis() == Status.FAILED:
                return Status.FAILED

            # Deadlock breaker: if a full round made no progress, force one axis
            # to be gathered (mapped -> NONE) so the next round can proceed.
            # NOTE(review): assumes the stuck axis is still mapped (in_dim != NONE);
            # an unmapped stuck axis would index the device matrix out of range.
            if len_global == len(self.operator_list_) and self.map_:
                index = next(iter(self.map_.keys()))
                in_dim = self.map_[index]
                self.map_[index] = NONE
                dev_dim = self.dev_mat_.GetDimByReverseIdx(in_dim)
                args = (index, in_dim, dev_dim)
                if self.InsertOperator(CONCAT_BY_AXIS, args) == Status.FAILED:
                    return Status.FAILED

        return Status.SUCCESS

    def _HandleSimpleSplitCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                               out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle the simple case where input dimension is None and output dimension is not conflicting"""
        if in_dim != NONE:
            return False

        # The target device dimension must not be used (exactly or as a tuple
        # suffix) by any other axis still being resolved.
        conflict = any(v == out_dim for v in self.map_.values())
        if isinstance(out_dim, tuple):
            conflict_tuple = any(isinstance(v, tuple) and v[len(v) - len(out_dim):] == out_dim
                                 for v in self.map_.values())
        else:
            conflict_tuple = any(isinstance(v, tuple) and v[-1] == out_dim for v in self.map_.values())

        if not conflict and not conflict_tuple:
            dev_dim = self.dev_mat_.GetDimByReverseIdx(out_dim)
            args = (index, out_dim, dev_dim)
            return self.InsertOperator(SPLIT_BY_AXIS, args) == Status.SUCCESS

        return False

    def _HandleTupleSplitCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                              out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle the case where output dimension is a tuple and input dimension matches prefix"""
        if not isinstance(out_dim, tuple):
            return False

        if ((not isinstance(in_dim, tuple) and in_dim == out_dim[0]) or
                (isinstance(in_dim, tuple) and in_dim == out_dim[:len(in_dim)])):

            # Split only over the remaining (suffix) device dimensions.
            if isinstance(in_dim, tuple):
                out_dim_rest = out_dim[-1] if len(out_dim[len(in_dim):]) == 1 else out_dim[len(in_dim):]
            else:
                out_dim_rest = out_dim[-1] if len(out_dim[1:]) == 1 else out_dim[1:]

            conflict = any(v == out_dim_rest for v in self.map_.values())
            if not conflict:
                dev_dim = self.dev_mat_.GetDimByReverseIdx(out_dim_rest)
                args = (index, out_dim_rest, dev_dim)
                return self.InsertOperator(SPLIT_BY_AXIS, args) == Status.SUCCESS

        return False

    def InferSplitByAxis(self) -> int:
        """
        Infers split operations for the current mapping state.

        Conditions for split:
        - Tensor dimension changes from unmapped to mapped
        - No conflicts in target device dimension

        Updates internal mapping state and operator list.

        Returns:
            Status.SUCCESS if operations inferred, Status.FAILED on error
        """
        keys = list(self.map_.keys())
        for index in keys:
            # Handlers may delete other entries; skip ones already resolved.
            if index not in self.map_:
                continue

            in_dim = self.map_[index]
            out_dim = self.out_tensor_map_.GetDimByIdx(index)

            if in_dim == out_dim:
                del self.map_[index]
                continue

            # Handle simple case: input dimension is None
            if self._HandleSimpleSplitCase(index, in_dim, out_dim):
                del self.map_[index]
                continue

            # Handle tuple case: output dimension is a tuple
            if self._HandleTupleSplitCase(index, in_dim, out_dim):
                del self.map_[index]
                continue

        return Status.SUCCESS

    def _HandleNoneDimPermuteCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                  out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle permute case where input dimension is None"""
        if in_dim != NONE:
            return False

        # Check for conflicts in output dimension
        conflict = any(v == out_dim for v in self.map_.values())
        if not conflict:
            return False

        # Handle regular dimension conflict
        concat_axis = self.in_tensor_map_.GetIndexByValue(out_dim)
        # Bugfix: GetIndexByValue returns the sentinel NONE (-1), never None;
        # the previous `is None` check could never fire and a not-found result
        # leaked through as index -1.
        if concat_axis == NONE:
            return False

        split_dev_num = self.dev_mat_.GetDimByReverseIdx(out_dim)

        if self.use_permute:
            # concat tensor map value, to get the communication group
            concat_map = self.in_tensor_map_.GetDimByIdx(concat_axis)
            concat_dev_num = self.dev_mat_.GetDimByReverseIdx(concat_map)
            args_permute = (concat_dev_num, index, concat_axis, concat_map, split_dev_num)

            if self.InsertOperator(PERMUTE_BY_AXIS, args_permute) == Status.FAILED:
                return False
        else:
            # Without permute, emulate the exchange as concat followed by split.
            args_concat = (concat_axis, out_dim, split_dev_num)
            args_split = (index, out_dim, split_dev_num)

            if self.InsertOperator(CONCAT_BY_AXIS, args_concat) == Status.FAILED:
                return False
            if self.InsertOperator(SPLIT_BY_AXIS, args_split) == Status.FAILED:
                return False

        del self.map_[index]
        self.map_[concat_axis] = NONE
        return True

    def _HandleNoneDimTuplePermuteCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                       out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle permute case where input dimension is None and output dimension is a tuple with conflicts"""
        if in_dim != NONE:
            return False

        # Conflict where out_dim appears as the suffix of a tuple-mapped axis.
        if isinstance(out_dim, tuple):
            conflict_tuple = any(isinstance(v, tuple) and v[len(v) - len(out_dim):] == out_dim
                                 for v in self.map_.values())
        else:
            conflict_tuple = any(isinstance(v, tuple) and v[-1] == out_dim for v in self.map_.values())

        if not conflict_tuple:
            return False

        concat_axis = self.in_tensor_map_.GetIndexContainValue(out_dim)
        # Bugfix: GetIndexContainValue returns NONE (-1) on failure, not None.
        if concat_axis == NONE:
            return False

        split_dev_num = self.dev_mat_.GetDimByReverseIdx(out_dim)

        if self.use_permute:
            # concat tensor map value, to get the communication group
            concat_map = out_dim
            concat_dev_num = self.dev_mat_.GetDimByReverseIdx(concat_map)
            args_permute = (concat_dev_num, index, concat_axis, concat_map, split_dev_num)

            if self.InsertOperator(PERMUTE_BY_AXIS, args_permute) == Status.FAILED:
                return False
        else:
            args_concat = (concat_axis, out_dim, split_dev_num)
            args_split = (index, out_dim, split_dev_num)

            if self.InsertOperator(CONCAT_BY_AXIS, args_concat) == Status.FAILED:
                return False
            if self.InsertOperator(SPLIT_BY_AXIS, args_split) == Status.FAILED:
                return False

        del self.map_[index]
        # Strip the consumed suffix from the conflicting axis' tuple mapping.
        out_dim_len = 1 if not isinstance(out_dim, tuple) else len(out_dim)
        rest_size = len(self.map_[concat_axis]) - out_dim_len
        new_map_item = self.map_[concat_axis][:rest_size] if rest_size > 1 else self.map_[concat_axis][0]
        self.map_[concat_axis] = new_map_item
        return True

    def _HandleTupleDimPermuteCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                   out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle permute case where both input and output dimensions are tuples"""
        if not isinstance(out_dim, tuple):
            return False

        if not ((not isinstance(in_dim, tuple) and in_dim == out_dim[0]) or
                (isinstance(in_dim, tuple) and in_dim == out_dim[:len(in_dim)])):
            return False

        if isinstance(in_dim, tuple):
            out_dim_rest = out_dim[-1] if len(out_dim[len(in_dim):]) == 1 else out_dim[len(in_dim):]
        else:
            out_dim_rest = out_dim[-1] if len(out_dim[1:]) == 1 else out_dim[1:]

        conflict = any(v == out_dim_rest for v in self.map_.values())
        if not conflict:
            return False

        concat_axis = self.in_tensor_map_.GetIndexByValue(out_dim_rest)
        # Bugfix: GetIndexByValue returns NONE (-1) on failure, not None.
        if concat_axis == NONE:
            return False

        split_dev_num = self.dev_mat_.GetDimByReverseIdx(out_dim_rest)

        if self.use_permute:
            # concat tensor map value, to get the communication group
            concat_map = out_dim_rest
            concat_dev_num = self.dev_mat_.GetDimByReverseIdx(concat_map)
            args_permute = (concat_dev_num, index, concat_axis, concat_map, split_dev_num)

            if self.InsertOperator(PERMUTE_BY_AXIS, args_permute) == Status.FAILED:
                return False
        else:
            args_concat = (concat_axis, out_dim_rest, split_dev_num)
            args_split = (index, out_dim_rest, split_dev_num)

            if self.InsertOperator(CONCAT_BY_AXIS, args_concat) == Status.FAILED:
                return False
            if self.InsertOperator(SPLIT_BY_AXIS, args_split) == Status.FAILED:
                return False

        del self.map_[index]
        self.map_[concat_axis] = NONE
        return True

    def InferPermuteByAxis(self) -> int:
        """
        Infers permutation (all-to-all) operations for dimension conflicts.

        Handles cases where:
        - Input dimension is unmapped but output dimension is already occupied
        - Uses either permute operator or split+concat pair based on use_permute flag

        Returns:
            Status.SUCCESS if operations inferred, Status.FAILED on error
        """
        keys = list(self.map_.keys())
        for index in keys:
            # Handlers may delete/replace other entries; skip resolved ones.
            if index not in self.map_:
                continue

            in_dim = self.map_[index]
            out_dim = self.out_tensor_map_.GetDimByIdx(index)

            if in_dim == out_dim:
                del self.map_[index]
                continue

            # Handle different permute cases
            if self._HandleNoneDimPermuteCase(index, in_dim, out_dim):
                continue

            if self._HandleNoneDimTuplePermuteCase(index, in_dim, out_dim):
                continue

            if self._HandleTupleDimPermuteCase(index, in_dim, out_dim):
                continue

        return Status.SUCCESS

    def _HandleTupleConcatCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                               out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle concat case where input dimension is a tuple and output matches prefix"""
        if not isinstance(in_dim, tuple):
            return False

        if not ((not isinstance(out_dim, tuple) and out_dim == in_dim[0]) or
                (isinstance(out_dim, tuple) and out_dim == in_dim[:len(out_dim)])):
            return False

        # Gather only over the surplus (suffix) device dimensions.
        if isinstance(out_dim, tuple):
            in_dim_rest = in_dim[-1] if len(in_dim[len(out_dim):]) == 1 else in_dim[len(out_dim):]
        else:
            in_dim_rest = in_dim[-1] if len(in_dim[1:]) == 1 else in_dim[1:]

        concat_dev_num = self.dev_mat_.GetDimByReverseIdx(in_dim_rest)
        args = (index, in_dim_rest, concat_dev_num)

        if self.InsertOperator(CONCAT_BY_AXIS, args) == Status.FAILED:
            return False

        del self.map_[index]
        return True

    def _HandleSimpleConcatCase(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle simple concat case where input dimension is mapped but output is not"""
        if in_dim == NONE:
            return False

        # Only gather a device dimension that the output map does not use at all.
        if self.out_tensor_map_.GetIndexByValue(in_dim) != NONE:
            return False

        concat_dev_num = self.dev_mat_.GetDimByReverseIdx(in_dim)
        args = (index, in_dim, concat_dev_num)

        if self.InsertOperator(CONCAT_BY_AXIS, args) == Status.FAILED:
            return False

        if out_dim == NONE:
            del self.map_[index]
        else:
            # Axis still needs a split/permute to reach out_dim in a later pass.
            self.map_[index] = NONE

        return True

    def InferConcatByAxis(self) -> int:
        """
        Infers concat operations for the current mapping state.

        Conditions for concat:
        - Input dimension is mapped but output is unmapped
        - Device dimension needs consolidation

        Returns:
            Status.SUCCESS if operations inferred, Status.FAILED on error
        """
        keys = list(self.map_.keys())
        for index in keys:
            if index not in self.map_:
                continue

            in_dim = self.map_[index]
            out_dim = self.out_tensor_map_.GetDimByIdx(index)

            # Handle tuple concat case
            if self._HandleTupleConcatCase(index, in_dim, out_dim):
                continue

            # Handle simple concat case
            if self._HandleSimpleConcatCase(index, in_dim, out_dim):
                continue

        return Status.SUCCESS

    def InferOpsList(self, rank: int, rank_list: List[int]):
        """
        Converts internal operator sequence to executable communication operations.

        Args:
            rank: Current device rank
            rank_list: Full list of device ranks in row-major order

        Returns:
            List of executable communication operations as tuples:
            - ("all_concat", (dim, size, group))
            - ("all_split", (dim, size, group))
            - ("all_to_all", (split_dim, concat_dim, size, group))
        """
        self.InferRedistributionOperator()
        ops_list = []
        for op in self.operator_list_:
            if op[0] == CONCAT_BY_AXIS:
                # Convert reverse-indexed device dim(s) to forward indices for DevMat.
                tensor_map = [self.dev_ranks - 1 - d for d in op[1][1]] if isinstance(op[1][1], tuple) \
                    else self.dev_ranks - 1 - op[1][1]
                group = self.dev_mat_.GetDevicesAlongDim(rank, rank_list, tensor_map)
                concat_dim = op[1][0]
                concat_size = op[1][2]
                if concat_size == 1:
                    continue  # degenerate op over a size-1 group is a no-op
                ops_list.append(("all_concat", (concat_dim, concat_size, group)))
            elif op[0] == SPLIT_BY_AXIS:
                tensor_map = [self.dev_ranks - 1 - d for d in op[1][1]] if isinstance(op[1][1], tuple) \
                    else self.dev_ranks - 1 - op[1][1]
                group = self.dev_mat_.GetDevicesAlongDim(rank, rank_list, tensor_map)
                split_dim = op[1][0]
                split_size = op[1][2]
                if split_size == 1:
                    continue
                ops_list.append(("all_split", (split_dim, split_size, group)))
            else:
                # PERMUTE_BY_AXIS: group is determined by the concat map value.
                tensor_map = [self.dev_ranks - 1 - d for d in op[1][3]] if isinstance(op[1][3], tuple) \
                    else self.dev_ranks - 1 - op[1][3]
                group = self.dev_mat_.GetDevicesAlongDim(rank, rank_list, tensor_map)
                concat_dim = op[1][2]
                split_dim = op[1][1]
                permute_size = op[1][0]
                if permute_size == 1:
                    continue
                ops_list.append(("all_to_all", (split_dim, concat_dim, permute_size, group)))
        return ops_list