Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/dtensor/redistribute_infer.py: 53%

331 statements  


# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""redistribute_infer"""
from typing import Dict, List, Tuple, Union


class Status:
    SUCCESS = 0
    FAILED = 1


CONCAT_BY_AXIS = 0
SPLIT_BY_AXIS = 1
PERMUTE_BY_AXIS = 2
NONE = -1


class TensorMap:
    """Enhanced tensor map struct supporting tuples for combined dimensions"""
    def __init__(self, dims: List[Union[int, Tuple[int, ...]]]):
        self.dims = dims

    def get_dim_by_idx(self, index: int) -> Union[int, Tuple[int, ...]]:
        return self.dims[index] if index < len(self.dims) else NONE

    def get_index_by_value(self, value: Union[int, Tuple[int, ...]]) -> int:
        for i, dim in enumerate(self.dims):
            if dim == value:
                return i
        return NONE

    def get_index_contain_value(self, value: Union[int, Tuple[int, ...]]) -> int:
        for i, dim in enumerate(self.dims):
            if not isinstance(dim, tuple):
                continue
            if isinstance(value, tuple) and value == dim[len(dim) - len(value):]:
                return i
            if not isinstance(value, tuple) and value == dim[-1]:
                return i
        return NONE


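# Illustrative sketch (not part of the original module): exercises the three
# TensorMap lookups on a map that shards one tensor axis over two combined
# device dimensions. The helper name below is hypothetical.
def _example_tensor_map():
    tensor_map = TensorMap([(1, 0), NONE, 2])
    assert tensor_map.get_dim_by_idx(0) == (1, 0)            # combined device dims
    assert tensor_map.get_dim_by_idx(5) == NONE              # out-of-range index
    assert tensor_map.get_index_by_value(2) == 2             # exact match
    assert tensor_map.get_index_contain_value(0) == 0        # 0 ends the tuple (1, 0)
    assert tensor_map.get_index_contain_value((1, 0)) == 0   # suffix match
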

class DevMat:
    """
    Represents a multi-dimensional grid of devices where each dimension has a specific size.
    Supports operations to retrieve device groups along single or combined dimensions.

    Attributes:
        dims (List[int]): Sizes of each dimension in the mesh shape.
        _combined_dims (Dict[Tuple[int, ...], int]): Cache for precomputed combined dimension sizes.
    """

    def __init__(self, dims: List[int]):
        """
        Initialize mesh shape dimensions.

        Args:
            dims: List of integers representing the size of each dimension.
        """
        self.dims = dims
        self._combined_dims: Dict[Tuple[int, ...], int] = {}

    def get_dim_by_reverse_idx(self, idx: Union[int, Tuple[int, ...]]) -> int:
        """
        Get dimension size by reverse index or product of combined dimensions.

        For a single integer index `i`, returns the size of the dimension at reverse
        position (i.e., `dims[len(dims) - 1 - i]`). For a tuple of indices, returns the
        product of sizes for the specified reverse-indexed dimensions.

        Args:
            idx: Integer dimension index or tuple of indices.

        Returns:
            Dimension size (for integer) or product of sizes (for tuple).
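
        Example (illustrative, derived from the indexing above):
            DevMat([2, 4, 8]).get_dim_by_reverse_idx(0) returns 8 (the last
            dimension), and get_dim_by_reverse_idx((0, 1)) returns 8 * 4 = 32.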

88 """ 

89 if isinstance(idx, tuple): 

90 return self._get_combined_size(idx) 

91 return self.dims[len(self.dims) - 1 - idx] 

92 

93 def _get_combined_size(self, dims: Union[int, Tuple[int, ...]]) -> int: 

94 """ 

95 Compute and cache the product of sizes for combined dimensions. 

96 

97 Args: 

98 dims: Tuple of dimension indices (reverse-indexed). 

99 

100 Returns: 

101 Product of sizes for the specified dimensions. 

102 """ 

103 if dims in self._combined_dims: 

104 return self._combined_dims[dims] 

105 size = 1 

106 for d in dims: 

107 size *= self.dims[len(self.dims) - 1 - d] 

108 self._combined_dims[dims] = size 

109 return size 

110 

111 def _get_devices_along_dim(self, rank: int, rank_list: List[int], dim: int) -> List[int]: 

112 """ 

113 Get devices sharing the same coordinates. 

114 

115 Devices are grouped such that only the specified dimension varies. The mesh shape 

116 is assumed to be in row-major order (last dimension changes fastest). 

117 

118 Args: 

119 rank: Target device rank. 

120 rank_list: Flattened list of all devices in row-major order. 

121 dim: Target dimension index (0-indexed from outermost). 

122 

123 Returns: 

124 List of devices in the same group as `rank` along `dim`. 

125 

126 Raises: 

127 ValueError: For invalid dimension or mismatched rank_list size. 
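
        Example (illustrative):
            With DevMat([2, 3]) and rank_list=[0, 1, 2, 3, 4, 5], rank 0 is
            grouped as [0, 3] along dim 0 and as [0, 1, 2] along dim 1.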

128 """ 

129 if dim < 0 or dim >= len(self.dims): 

130 raise ValueError(f"Dimension {dim} out of range [0, {len(self.dims)})") 

131 

132 # Trivial case: dimension size is 1 

133 if self.dims[dim] == 1: 

134 return [rank] 

135 

136 total_devices = 1 

137 for d in self.dims: 

138 total_devices *= d 

139 

140 # Validate rank_list length 

141 if len(rank_list) != total_devices: 

142 raise ValueError(f"rank_list length ({len(rank_list)}) doesn't match " 

143 f"mesh shape product ({total_devices})") 

144 

145 # Compute stride for the dimension 

146 stride = 1 

147 for i in range(dim + 1, len(self.dims)): 

148 stride *= self.dims[i] 

149 

150 # Find local index of rank in rank_list 

151 try: 

152 local_index = rank_list.index(rank) 

153 except ValueError as e: 

154 raise ValueError(f"Rank {rank} not in rank_list") from e 

155 

156 # Calculate base index and generate group 

157 index_in_dim = (local_index // stride) % self.dims[dim] 

158 base = local_index - index_in_dim * stride 

159 group = [rank_list[base + k * stride] for k in range(self.dims[dim])] 

160 

161 return group 

162 

163 def get_devices_along_dim(self, rank: int, rank_list: List[int], dim: Union[int, List[int]]) -> List[int]: 

164 """ 

165 Get devices sharing the same coordinates. 

166 

167 For a single dimension, returns devices where only that dimension varies. 

168 For a tuple of dimensions, returns devices where ONLY the specified dimensions vary, 

169 sharing fixed coordinates in all other dimensions. 

170 

171 Args: 

172 rank: Target device rank. 

173 rank_list: Flattened list of all devices in row-major order. 

174 dim: Single dimension index or tuple of indices. 

175 

176 Returns: 

177 List of devices in the same hyperplane as `rank` orthogonal to `dim`. 

178 

179 Raises: 

180 ValueError: For invalid dimensions or mismatched rank_list size. 
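
        Example (illustrative):
            With DevMat([2, 3]) and rank_list=[0, 1, 2, 3, 4, 5],
            get_devices_along_dim(0, rank_list, [0, 1]) expands rank 0's group
            along dim 0 and then along dim 1, returning [0, 1, 2, 3, 4, 5].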

181 """ 

182 if isinstance(dim, list): 

183 result = self._get_devices_along_dim(rank, rank_list, dim[0]) 

184 current_layer_len = len(result) 

185 current_layer_step = 0 

186 dim_index = 1 

187 while dim_index < len(dim): 

188 sub_rank = result.pop(0) 

189 result.extend(self._get_devices_along_dim(sub_rank, rank_list, dim[dim_index])) 

190 current_layer_step += 1 

191 if current_layer_step == current_layer_len: 

192 dim_index += 1 

193 current_layer_step = 0 

194 current_layer_len = len(result) 

195 return result 

196 return self._get_devices_along_dim(rank, rank_list, dim) 

197 

198 

class RedistributionOperatorInfer:
    """
    Infers communication operators for tensor redistribution in distributed systems.

    Determines the sequence of communication operations (split, concat, permute)
    required to transform a tensor from an input device mapping to an output device mapping.

    Args:
        dev_mat: Mesh shape dimensions representing the device grid
        in_tensor_map: Input tensor's device mapping for each tensor dimension
        out_tensor_map: Output tensor's device mapping for each tensor dimension
        use_permute: Whether to use permute operator (all-to-all) when possible (default: True)
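
    Example (illustrative; the inferred operator list is not asserted):
        infer = RedistributionOperatorInfer(dev_mat=[2, 2],
                                            in_tensor_map=[1, -1],
                                            out_tensor_map=[-1, 1])
        ops = infer.infer_ops_list(rank=0, rank_list=[0, 1, 2, 3])
        # Moving the shard from tensor axis 0 to tensor axis 1 over the same
        # device dimension is expected to yield a single all-to-all.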

211 """ 

212 def __init__(self, dev_mat: List[int], 

213 in_tensor_map: List[Union[int, Tuple[int, ...]]], 

214 out_tensor_map: List[Union[int, Tuple[int, ...]]], 

215 use_permute: bool = True): 

216 

217 self.operator_list_: List[Tuple[int, Tuple]] = [] 

218 self.map_: Dict[int, Union[int, Tuple[int, ...]]] = {} 

219 self.use_permute = use_permute 

220 

221 # Initialize with expanded dimensions 

222 self.dev_ranks = len(dev_mat) 

223 self.dev_mat_ = DevMat(dev_mat) 

224 self.in_tensor_map_ = TensorMap(in_tensor_map) 

225 self.out_tensor_map_ = TensorMap(out_tensor_map) 

226 

227 self.map_ = {i: self.in_tensor_map_.get_dim_by_idx(i) 

228 for i in range(len(in_tensor_map))} 

229 

230 def insert_operator(self, op_type: int, args: Tuple) -> int: 

231 """ 

232 Adds an operator to the internal operator sequence. 

233 

234 Args: 

235 op_type: Operator type constant (SPLIT_BY_AXIS, CONCAT_BY_AXIS, PERMUTE_BY_AXIS) 

236 args: Operator-specific arguments tuple 

237 

238 Returns: 

239 Status.SUCCESS on success, Status.FAILED on error 

240 """ 

241 self.operator_list_.append((op_type, args)) 

242 return Status.SUCCESS 

243 

244 def infer_redistribution_operator(self) -> int: 

245 """ 

246 Main inference driver coordinating the redistribution sequence. 

247 

248 Executes in 3 phases until mapping is resolved: 

249 1. Split operations 

250 2. Permute/All-to-All operations 

251 3. Concat operations 

252 

253 Returns: 

254 Status.SUCCESS if full sequence inferred, Status.FAILED otherwise 

255 """ 

256 while self.map_: 

257 len_global = len(self.operator_list_) 

258 

259 while self.map_: 

260 len_split_by_axis = len(self.operator_list_) 

261 

262 # Step 1: infer split op 

263 if self.infer_split_by_axis() == Status.FAILED: 

264 return Status.FAILED 

265 

266 # Step 2: infer alltoall op 

267 while self.map_: 

268 len_permute_by_axis = len(self.operator_list_) 

269 if self.infer_permute_by_axis() == Status.FAILED: 

270 return Status.FAILED 

271 if len_permute_by_axis == len(self.operator_list_): 

272 break 

273 

274 if len_split_by_axis == len(self.operator_list_): 

275 break 

276 

277 # Step 3: infer allconcat op 

278 if self.infer_concat_by_axis() == Status.FAILED: 

279 return Status.FAILED 

280 

281 if len_global == len(self.operator_list_) and self.map_: 

282 index = next(iter(self.map_.keys())) 

283 in_dim = self.map_[index] 

284 self.map_[index] = NONE 

285 dev_dim = self.dev_mat_.get_dim_by_reverse_idx(in_dim) 

286 args = (index, in_dim, dev_dim) 

287 if self.insert_operator(CONCAT_BY_AXIS, args) == Status.FAILED: 

288 return Status.FAILED 

289 

290 return Status.SUCCESS 

291 

    def _handle_simple_split_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                  out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle the simple case where the input dimension is NONE and the output dimension has no conflict"""
        if in_dim != NONE:
            return False

        conflict = any(v == out_dim for v in self.map_.values())
        if isinstance(out_dim, tuple):
            conflict_tuple = any(isinstance(v, tuple) and v[len(v) - len(out_dim):] == out_dim
                                 for v in self.map_.values())
        else:
            conflict_tuple = any(isinstance(v, tuple) and v[-1] == out_dim for v in self.map_.values())

        if not conflict and not conflict_tuple:
            dev_dim = self.dev_mat_.get_dim_by_reverse_idx(out_dim)
            args = (index, out_dim, dev_dim)
            return self.insert_operator(SPLIT_BY_AXIS, args) == Status.SUCCESS

        return False

    def _handle_tuple_split_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                 out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle the case where the output dimension is a tuple and the input dimension matches its prefix"""
        if not isinstance(out_dim, tuple):
            return False

        if ((not isinstance(in_dim, tuple) and in_dim == out_dim[0]) or
                (isinstance(in_dim, tuple) and in_dim == out_dim[:len(in_dim)])):

            if isinstance(in_dim, tuple):
                out_dim_rest = out_dim[-1] if len(out_dim[len(in_dim):]) == 1 else out_dim[len(in_dim):]
            else:
                out_dim_rest = out_dim[-1] if len(out_dim[1:]) == 1 else out_dim[1:]

            conflict = any(v == out_dim_rest for v in self.map_.values())
            if not conflict:
                dev_dim = self.dev_mat_.get_dim_by_reverse_idx(out_dim_rest)
                args = (index, out_dim_rest, dev_dim)
                return self.insert_operator(SPLIT_BY_AXIS, args) == Status.SUCCESS

        return False

    def infer_split_by_axis(self) -> int:
        """
        Infers split operations for the current mapping state.

        Conditions for split:
        - Tensor dimension changes from unmapped to mapped
        - No conflicts in the target device dimension

        Updates internal mapping state and operator list.

        Returns:
            Status.SUCCESS if operations inferred, Status.FAILED on error
        """
        keys = list(self.map_.keys())
        for index in keys:
            if index not in self.map_:
                continue

            in_dim = self.map_[index]
            out_dim = self.out_tensor_map_.get_dim_by_idx(index)

            if in_dim == out_dim:
                del self.map_[index]
                continue

            # Handle simple case: input dimension is NONE
            if self._handle_simple_split_case(index, in_dim, out_dim):
                del self.map_[index]
                continue

            # Handle tuple case: output dimension is a tuple
            if self._handle_tuple_split_case(index, in_dim, out_dim):
                del self.map_[index]
                continue

        return Status.SUCCESS

    def _handle_none_dim_permute_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                      out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle the permute case where the input dimension is NONE"""
        if in_dim != NONE:
            return False

        # Check for conflicts in output dimension
        conflict = any(v == out_dim for v in self.map_.values())
        if not conflict:
            return False

        # Handle regular dimension conflict
        concat_axis = self.in_tensor_map_.get_index_by_value(out_dim)
        if concat_axis == NONE:
            return False

        split_dev_num = self.dev_mat_.get_dim_by_reverse_idx(out_dim)

        if self.use_permute:
            # Concat tensor map value, used to derive the communication group
            concat_map = self.in_tensor_map_.get_dim_by_idx(concat_axis)
            concat_dev_num = self.dev_mat_.get_dim_by_reverse_idx(concat_map)
            args_permute = (concat_dev_num, index, concat_axis, concat_map, split_dev_num)

            if self.insert_operator(PERMUTE_BY_AXIS, args_permute) == Status.FAILED:
                return False
        else:
            args_concat = (concat_axis, out_dim, split_dev_num)
            args_split = (index, out_dim, split_dev_num)

            if self.insert_operator(CONCAT_BY_AXIS, args_concat) == Status.FAILED:
                return False
            if self.insert_operator(SPLIT_BY_AXIS, args_split) == Status.FAILED:
                return False

        del self.map_[index]
        self.map_[concat_axis] = NONE
        return True

    def _handle_none_dim_tuple_permute_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                            out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle the permute case where the input dimension is NONE and the output dimension is a tuple with conflicts"""
        if in_dim != NONE:
            return False

        if isinstance(out_dim, tuple):
            conflict_tuple = any(isinstance(v, tuple) and v[len(v) - len(out_dim):] == out_dim
                                 for v in self.map_.values())
        else:
            conflict_tuple = any(isinstance(v, tuple) and v[-1] == out_dim for v in self.map_.values())

        if not conflict_tuple:
            return False

        concat_axis = self.in_tensor_map_.get_index_contain_value(out_dim)
        if concat_axis == NONE:
            return False

        split_dev_num = self.dev_mat_.get_dim_by_reverse_idx(out_dim)

        if self.use_permute:
            # Concat tensor map value, used to derive the communication group
            concat_map = out_dim
            concat_dev_num = self.dev_mat_.get_dim_by_reverse_idx(concat_map)
            args_permute = (concat_dev_num, index, concat_axis, concat_map, split_dev_num)

            if self.insert_operator(PERMUTE_BY_AXIS, args_permute) == Status.FAILED:
                return False
        else:
            args_concat = (concat_axis, out_dim, split_dev_num)
            args_split = (index, out_dim, split_dev_num)

            if self.insert_operator(CONCAT_BY_AXIS, args_concat) == Status.FAILED:
                return False
            if self.insert_operator(SPLIT_BY_AXIS, args_split) == Status.FAILED:
                return False

        del self.map_[index]
        out_dim_len = 1 if not isinstance(out_dim, tuple) else len(out_dim)
        rest_size = len(self.map_[concat_axis]) - out_dim_len
        new_map_item = self.map_[concat_axis][:rest_size] if rest_size > 1 else self.map_[concat_axis][0]
        self.map_[concat_axis] = new_map_item
        return True

    def _handle_tuple_dim_permute_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                       out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle the permute case where both input and output dimensions are tuples"""
        if not isinstance(out_dim, tuple):
            return False

        if not ((not isinstance(in_dim, tuple) and in_dim == out_dim[0]) or
                (isinstance(in_dim, tuple) and in_dim == out_dim[:len(in_dim)])):
            return False

        if isinstance(in_dim, tuple):
            out_dim_rest = out_dim[-1] if len(out_dim[len(in_dim):]) == 1 else out_dim[len(in_dim):]
        else:
            out_dim_rest = out_dim[-1] if len(out_dim[1:]) == 1 else out_dim[1:]

        conflict = any(v == out_dim_rest for v in self.map_.values())
        if not conflict:
            return False

        concat_axis = self.in_tensor_map_.get_index_by_value(out_dim_rest)
        if concat_axis == NONE:
            return False

        split_dev_num = self.dev_mat_.get_dim_by_reverse_idx(out_dim_rest)

        if self.use_permute:
            # Concat tensor map value, used to derive the communication group
            concat_map = out_dim_rest
            concat_dev_num = self.dev_mat_.get_dim_by_reverse_idx(concat_map)
            args_permute = (concat_dev_num, index, concat_axis, concat_map, split_dev_num)

            if self.insert_operator(PERMUTE_BY_AXIS, args_permute) == Status.FAILED:
                return False
        else:
            args_concat = (concat_axis, out_dim_rest, split_dev_num)
            args_split = (index, out_dim_rest, split_dev_num)

            if self.insert_operator(CONCAT_BY_AXIS, args_concat) == Status.FAILED:
                return False
            if self.insert_operator(SPLIT_BY_AXIS, args_split) == Status.FAILED:
                return False

        del self.map_[index]
        self.map_[concat_axis] = NONE
        return True

    def infer_permute_by_axis(self) -> int:
        """
        Infers permutation (all-to-all) operations for dimension conflicts.

        Handles cases where:
        - The input dimension is unmapped but the output dimension is already occupied
        - Uses either the permute operator or a split+concat pair based on the use_permute flag

        Returns:
            Status.SUCCESS if operations inferred, Status.FAILED on error
        """
        keys = list(self.map_.keys())
        for index in keys:
            if index not in self.map_:
                continue

            in_dim = self.map_[index]
            out_dim = self.out_tensor_map_.get_dim_by_idx(index)

            if in_dim == out_dim:
                del self.map_[index]
                continue

            # Handle different permute cases
            if self._handle_none_dim_permute_case(index, in_dim, out_dim):
                continue

            if self._handle_none_dim_tuple_permute_case(index, in_dim, out_dim):
                continue

            if self._handle_tuple_dim_permute_case(index, in_dim, out_dim):
                continue

        return Status.SUCCESS

    def _handle_tuple_concat_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                  out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle the concat case where the input dimension is a tuple and the output matches its prefix"""
        if not isinstance(in_dim, tuple):
            return False

        if not ((not isinstance(out_dim, tuple) and out_dim == in_dim[0]) or
                (isinstance(out_dim, tuple) and out_dim == in_dim[:len(out_dim)])):
            return False

        if isinstance(out_dim, tuple):
            in_dim_rest = in_dim[-1] if len(in_dim[len(out_dim):]) == 1 else in_dim[len(out_dim):]
        else:
            in_dim_rest = in_dim[-1] if len(in_dim[1:]) == 1 else in_dim[1:]

        concat_dev_num = self.dev_mat_.get_dim_by_reverse_idx(in_dim_rest)
        args = (index, in_dim_rest, concat_dev_num)

        if self.insert_operator(CONCAT_BY_AXIS, args) == Status.FAILED:
            return False

        del self.map_[index]
        return True

    def _handle_simple_concat_case(self, index: int, in_dim: Union[int, Tuple[int, ...]],
                                   out_dim: Union[int, Tuple[int, ...]]) -> bool:
        """Handle the simple concat case where the input dimension is mapped but the output is not"""
        if in_dim == NONE:
            return False

        if self.out_tensor_map_.get_index_by_value(in_dim) != NONE:
            return False

        concat_dev_num = self.dev_mat_.get_dim_by_reverse_idx(in_dim)
        args = (index, in_dim, concat_dev_num)

        if self.insert_operator(CONCAT_BY_AXIS, args) == Status.FAILED:
            return False

        if out_dim == NONE:
            del self.map_[index]
        else:
            self.map_[index] = NONE

        return True

    def infer_concat_by_axis(self) -> int:
        """
        Infers concat operations for the current mapping state.

        Conditions for concat:
        - Input dimension is mapped but output is unmapped
        - Device dimension needs consolidation

        Returns:
            Status.SUCCESS if operations inferred, Status.FAILED on error
        """
        keys = list(self.map_.keys())
        for index in keys:
            if index not in self.map_:
                continue

            in_dim = self.map_[index]
            out_dim = self.out_tensor_map_.get_dim_by_idx(index)

            # Handle tuple concat case
            if self._handle_tuple_concat_case(index, in_dim, out_dim):
                continue

            # Handle simple concat case
            if self._handle_simple_concat_case(index, in_dim, out_dim):
                continue

        return Status.SUCCESS

    def infer_ops_list(self, rank: int, rank_list: List[int]):
        """
        Converts the internal operator sequence to executable communication operations.

        Args:
            rank: Current device rank
            rank_list: Full list of device ranks in row-major order

        Returns:
            List of executable communication operations as tuples:
            - ("all_concat", (dim, size, group))
            - ("all_split", (dim, size, group))
            - ("all_to_all", (split_dim, concat_dim, size, group))
        """
        self.infer_redistribution_operator()
        ops_list = []
        for op in self.operator_list_:
            if op[0] == CONCAT_BY_AXIS:
                tensor_map = [self.dev_ranks - 1 - d for d in op[1][1]] if isinstance(op[1][1], tuple) \
                    else self.dev_ranks - 1 - op[1][1]
                group = self.dev_mat_.get_devices_along_dim(rank, rank_list, tensor_map)
                concat_dim = op[1][0]
                concat_size = op[1][2]
                if concat_size == 1:
                    continue
                ops_list.append(("all_concat", (concat_dim, concat_size, group)))
            elif op[0] == SPLIT_BY_AXIS:
                tensor_map = [self.dev_ranks - 1 - d for d in op[1][1]] if isinstance(op[1][1], tuple) \
                    else self.dev_ranks - 1 - op[1][1]
                group = self.dev_mat_.get_devices_along_dim(rank, rank_list, tensor_map)
                split_dim = op[1][0]
                split_size = op[1][2]
                if split_size == 1:
                    continue
                ops_list.append(("all_split", (split_dim, split_size, group)))
            else:
                tensor_map = [self.dev_ranks - 1 - d for d in op[1][3]] if isinstance(op[1][3], tuple) \
                    else self.dev_ranks - 1 - op[1][3]
                group = self.dev_mat_.get_devices_along_dim(rank, rank_list, tensor_map)
                concat_dim = op[1][2]
                split_dim = op[1][1]
                permute_size = op[1][0]
                if permute_size == 1:
                    continue
                ops_list.append(("all_to_all", (split_dim, concat_dim, permute_size, group)))
        return ops_list
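

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): redistributes a
# tensor sharded on axis 0 into one sharded on axis 1 over a 2x2 device grid.
# The inferred operator sequence is printed rather than asserted; for this toy
# case a single all-to-all over ranks [0, 2] is expected.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_infer = RedistributionOperatorInfer(dev_mat=[2, 2],
                                                in_tensor_map=[1, NONE],
                                                out_tensor_map=[NONE, 1])
    example_ops = example_infer.infer_ops_list(rank=0, rank_list=[0, 1, 2, 3])
    print(example_ops)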