# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""layout"""

import copy
import functools
import numpy as np


from hyper_parallel.core.dtensor.placement_types import Placement, Shard, StridedShard, Replicate, Partial
from hyper_parallel.core.dtensor.device_mesh import DeviceMesh, _create_device_mesh
from hyper_parallel.platform import get_platform

platform = get_platform()


def _infer_slice_area_by_rank(mesh_shape, tensor_map, rank_id: int, full_shape: tuple):  # -> tuple[tuple[int]]:
    """Return, for the given rank, the index range of each axis to slice from the full tensor.
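
    A hand-worked example: on a (2, 4) mesh with tensor_map (1, 0), dimension 0 is split 2 ways and
    dimension 1 is split 4 ways, so rank 5 of an (8, 8) tensor owns rows 4:8 and columns 2:4.

    >>> _infer_slice_area_by_rank((2, 4), (1, 0), 5, (8, 8))
    [(4, 8), (2, 4)]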
    """

    def _get_dev_num_alone_dim(mesh_shape, dim):
        """Return the device count along the mapped device dimension."""
        return mesh_shape[-dim - 1] if dim != -1 else 1

    def _rank_id_to_dev_id_list(mesh_shape, rank_id):
        """Infer dev id list by rank_id and mesh_shape"""
        dims = len(mesh_shape)
        dev_id_list = [0] * dims
        for i in range(dims - 1, -1, -1):
            dev_id_list[i] = rank_id % mesh_shape[i]
            rank_id = rank_id // mesh_shape[i]
        return dev_id_list

    dev_id_list = _rank_id_to_dev_id_list(mesh_shape, rank_id)

    dims = len(full_shape)
    area = []
    for axis in range(dims):
        mapping = tensor_map[axis]
        if isinstance(mapping, int):
            mapping = (mapping,)
        split_num = 1
        for dim in mapping:
            split_num *= _get_dev_num_alone_dim(mesh_shape, dim)

        slice_id = 0
        coef = 1
        for dim in reversed(mapping):
            if dim == -1:
                continue
            slice_id += dev_id_list[-dim - 1] * coef
            coef *= _get_dev_num_alone_dim(mesh_shape, dim)
        slice_size = full_shape[axis] // split_num
        start = slice_id * slice_size
        end = start + slice_size
        area.append((start, end))
    return area


def _get_slice_tensor_by_layout(global_tensor, layout):
    """Slice the local tensor for the current rank out of the global tensor according to the layout."""
    inner_rank_id = layout.rank_list.index(layout.mesh.rank)
    slice_area = _infer_slice_area_by_rank(layout.mesh_shape, layout.tensor_map, inner_rank_id, global_tensor.shape)

    def get_slice_data(full_data, offset):
        area = ()
        for begin, end in offset:
            area += (slice(begin, end),)
        return full_data[area].clone()

    local_tensor = get_slice_data(global_tensor, slice_area)
    return local_tensor


def _infer_slice_shape_by_layout(global_shape, layout):
    """Infer slice shape from global_shape and layout"""
    slice_shape = list(global_shape)
    alias_tensor_map = layout.alias_tensor_map
    for i in range(len(global_shape)):
        axis_name = alias_tensor_map[i]
        if isinstance(axis_name, str):
            axis_name = (axis_name,)
        for sub_axis_name in axis_name:
            if sub_axis_name != "None":
                slice_shape[i] = slice_shape[i] // layout.mesh.get_device_num_along_axis(sub_axis_name)
    return slice_shape


class Layout:
    """
    Topological abstraction describing how cluster devices are arranged for tensor slice placement.

    Note:
        - It is valid only in semi auto parallel or auto parallel mode.
        - The product of the `mesh_shape` elements must be equal to the device count in a pipeline stage.
        - When the layout function is invoked to construct a sharding strategy, each alias name is only allowed to be
          used once to shard a tensor.

    Args:
        mesh_shape (tuple): Describes the shape of the device arrangement; its element type is int.
        alias_name (tuple): The alias name for each axis of `mesh_shape`; its length should be equal to the length
            of `mesh_shape`, and its element type is str. When "interleaved_parallel" is used as an alias name,
            the tensor is split into multiple copies along the corresponding partition dimension on a single card.
        rank_list (tuple, optional): Data is allocated to the devices according to `rank_list`. Default: ``None``.

    Raises:
        TypeError: `mesh_shape` is not a tuple type.
        TypeError: `alias_name` is not a tuple type.
        TypeError: `rank_list` is not a list type.
        ValueError: `mesh_shape` length is not equal to `alias_name` length.
        TypeError: The element of `mesh_shape` is not int type.
        TypeError: The element of `alias_name` is not a str type.
        TypeError: The element of `rank_list` is not int type.
        ValueError: The element of `alias_name` is an empty str.
        ValueError: The element of `alias_name` is "None".
        ValueError: `alias_name` contains repeated elements.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore.parallel import Layout
        >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
        >>> layout0 = layout("dp", "mp")
        >>> print(layout0.to_dict())
        {'mesh_shape': (2, 2, 2), 'tensor_map': (2, 0), 'interleaved_parallel': False,
        'alias_name': ('dp', 'sp', 'mp'), 'rank_list': (0, 1, 2, 3, 4, 5, 6, 7)}
        >>> layout = Layout((2, 2, 2), ("dp", "sp", "interleaved_parallel"))
        >>> layout1 = layout(("dp", "interleaved_parallel"), "sp")
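        >>> # A layout can also be bound with Placement objects (an illustrative sketch; the
        >>> # Placement classes here come from hyper_parallel.core.dtensor.placement_types):
        >>> from hyper_parallel.core.dtensor.placement_types import Shard, Replicate
        >>> layout2 = layout(Shard(dim=0), Replicate(), Replicate())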
    """

    def __init__(self, mesh_shape, alias_name, rank_list=None, init_backend=True):
        self._alias_name = alias_name
        self._tensor_map = None
        if not rank_list:
            self._rank_list = tuple(range(np.prod(np.array(mesh_shape))))
        else:
            self._rank_list = tuple(rank_list)
        self._partial = [None] * len(mesh_shape)  # partial status for each dev dim
        self._support_partial_op = ['sum', 'max', 'min', 'avg', 'prod', 'all', None]
        self._alias_tensor_map = None
        self._mesh = _create_device_mesh("npu", mesh_shape, mesh_dim_names=alias_name, rank_list=self._rank_list,
                                         init_backend=init_backend)
        self._compact_str = self._to_compact_string()
        self._placements = None
        self.partial_ops = {}  # Initialized in _build_dim_map_from_placements()

    @classmethod
    def from_device_mesh(cls, device_mesh: DeviceMesh) -> 'Layout':
        """
        Create a Layout from an existing DeviceMesh.

        Args:
            device_mesh (DeviceMesh): The device mesh to create the layout from.

        Returns:
            Layout: A new Layout instance initialized with the properties of the provided device mesh.

        Examples:
            >>> from hyper_parallel.core.dtensor.layout import Layout, DeviceMesh
            >>> device_mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "mp"))
            >>> layout = Layout.from_device_mesh(device_mesh)
        """
        obj = cls.__new__(cls)
        obj._mesh = device_mesh
        obj._alias_name = device_mesh.mesh_dim_names
        obj._rank_list = device_mesh.rank_list
        obj._tensor_map = None
        obj._partial = [None] * len(device_mesh.mesh_shape)
        obj._support_partial_op = ['sum', 'max', 'min', 'avg', 'prod', 'all', None]
        obj._alias_tensor_map = None
        obj._placements = None
        obj._compact_str = obj._to_compact_string()
        return obj

    def __call__(self, *alias_tensor_map):
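        """Bind a tensor-dimension sharding to this layout and return a new Layout.

        The arguments are either one alias name (or tuple of alias names) per tensor dimension,
        e.g. ``layout("dp", "None")``, or one ``Placement`` object per mesh dimension,
        e.g. ``layout(Shard(dim=0), Replicate())``. The original layout is not modified, and any
        inherited Partial status is cleared on the returned copy.
        """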
        obj = copy.deepcopy(self)

        # Clear the inherited partial status.
        # When creating a new layout mapping configuration via __call__,
        # it should not inherit the dynamic execution state (Partial) of the original layout.
        # If the user intends to create a Partial placement, it will be parsed from alias_tensor_map.
        obj._partial = [None] * len(obj.mesh_shape)

        if len(alias_tensor_map) == 1 and isinstance(alias_tensor_map[0], (list, tuple)):
            if len(alias_tensor_map[0]) > 0 and isinstance(alias_tensor_map[0][0], Placement):
                return self._process_placement_layout(obj, alias_tensor_map[0])

        if len(alias_tensor_map) > 0 and isinstance(alias_tensor_map[0], Placement):
            return self._process_placement_layout(obj, alias_tensor_map)

        return self._process_alias_layout(obj, alias_tensor_map)

    def __deepcopy__(self, memo):
        """Deep copy layout without rebuilding the underlying device mesh."""
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        return result

    @staticmethod
    def _process_placement_layout(obj, placements):
        """Process layout defined by Placement types."""
        obj.set_placements(placements)
        return copy.deepcopy(obj)

    @staticmethod
    def _process_alias_layout(obj, alias_tensor_map):
        """Process layout defined by alias strings."""
        obj.set_alias_tensor_map(alias_tensor_map)
        tensor_map = ()
        writed_map = ()
        for ele in alias_tensor_map:
            if isinstance(ele, tuple):
                ele_map = ()
                for item in ele:
                    if item == "None":
                        ele_map += (-1,)
                        continue
                    if item not in obj.alias_name:
                        raise ValueError(f'The axis {item} is not found in {obj.alias_name}')
                    if item in writed_map:
                        raise ValueError(f'The axis {item} has been set more than once in {obj.alias_name}')
                    ele_map += (len(obj.alias_name) - 1 - obj.alias_name.index(item),)
                    writed_map += (item,)
                tensor_map += (ele_map,)
                continue
            if ele == "None":
                tensor_map += (-1,)
                continue
            if ele not in obj.alias_name:
                raise ValueError(f'The axis {ele} is not found in {obj.alias_name}')
            if ele in writed_map:
                raise ValueError(f'The axis {ele} has been set more than once in {obj.alias_name}')
            tensor_map += (len(obj.alias_name) - 1 - obj.alias_name.index(ele),)
            writed_map += (ele,)
        obj.set_tensor_map(tensor_map)
        obj.tensor_map_to_placement()
        obj.update_compact_str()
        return copy.deepcopy(obj)

    def to_dict(self):
        """
        Transform layout to a dictionary.
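
        Examples:
            >>> # Illustrative sketch; assumes an 8-device environment where the mesh backend
            >>> # can be initialized.
            >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
            >>> layout("dp", "mp").to_dict()["tensor_map"]
            (2, 0)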
        """
        if self._mesh.mesh_shape is None:
            raise ValueError("The mesh_shape of layout is None")
        if self._tensor_map is None:
            raise ValueError("The tensor_map of layout is None")
        interleaved_parallel = "interleaved_parallel" in self._mesh.mesh_dim_names
        return {"mesh_shape": self._mesh.mesh_shape, "tensor_map": self._tensor_map,
                "interleaved_parallel": interleaved_parallel, "alias_name": self._mesh.mesh_dim_names,
                "rank_list": self._rank_list}

    def placement_to_tensor_map(self, dim):
        """
        Transform placements to a tensor map.

        This method converts the `placements` configuration (consisting of Shard, StridedShard,
        Replicate, Partial) into a `tensor_map` representation used for distributed tensor operations.

        Args:
            dim (int): The dimension of the tensor. Must be a non-negative integer.

        Returns:
            list: A list representing the tensor map, where each element corresponds to a tensor dimension.
                A value of -1 indicates the dimension is not sharded, an integer indicates the mesh
                dimension index along which the tensor dimension is sharded, and a tuple indicates
                that the same tensor dimension is sharded multiple times in order.

        Raises:
            ValueError: If `dim` is negative.
            ValueError: If a shard dimension in `placements` is out of bounds for the given tensor dimension.
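
        Examples:
            >>> # Illustrative sketch; assumes an 8-device environment.
            >>> from hyper_parallel.core.dtensor.placement_types import Shard, Replicate
            >>> layout = Layout((2, 4), ("dp", "mp"))
            >>> sharded = layout(Shard(dim=0), Replicate())
            >>> sharded.placement_to_tensor_map(2)
            [1, -1]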
        """
        if dim < 0:
            raise ValueError(f"Tensor dimension must be non-negative, but got {dim}")
        if dim == 0:
            return self._handle_zero_dim_placement()

        dim_map = self._build_dim_map_from_placements(dim)
        tensor_map = self._convert_dim_map_to_tensor_map(dim_map)
        self.set_tensor_map(tuple(tensor_map))
        self._alias_tensor_map = self._build_readable_tensor_map()
        self.update_compact_str()
        return tensor_map

    def _handle_zero_dim_placement(self):
        """Handle the special case of a zero-dimensional tensor."""
        self.set_tensor_map(())
        self._alias_tensor_map = ()
        for mesh_idx, placement in enumerate(self.placements):
            if isinstance(placement, Partial):
                self._partial[mesh_idx] = self._extract_reduce_op(placement)
        return []

    def _build_dim_map_from_placements(self, dim):
        """Build dimension map from placements."""
        dim_map = [-1] * dim
        self.partial_ops = {}
        for mesh_idx, placement in enumerate(self.placements):
            if isinstance(placement, Shard):
                shard_dim = placement.dim
                if shard_dim < -dim or shard_dim >= dim:
                    raise ValueError(f"Shard dimension {shard_dim} is out of bounds for tensor of dimension {dim}")
                if shard_dim < 0:
                    shard_dim += dim
                if dim_map[shard_dim] == -1:
                    dim_map[shard_dim] = [mesh_idx]
                else:
                    dim_map[shard_dim].append(mesh_idx)
            elif isinstance(placement, Partial):
                self._partial[mesh_idx] = self._extract_reduce_op(placement)
        self._validate_strided_shard_split_factor(dim_map)
        self._reorder_dim_map_for_strided_shard(dim_map)
        return dim_map

    @staticmethod
    def _placement_split_factor(placement):
        """Return the effective split factor carried by a placement."""
        return placement.split_factor if isinstance(placement, StridedShard) else 1

    @staticmethod
    def _build_order_positions(shard_order):
        """Build a mesh axis to order position mapping."""
        return {mesh_idx: order_idx for order_idx, mesh_idx in enumerate(shard_order)}

    def _compute_expected_split_factors(self, shard_axes, shard_order):
        """Infer the split_factor each mesh axis should carry for the given sharding order."""
        order_positions = self._build_order_positions(shard_order)
        expected_split_factors = {}
        for mesh_idx in shard_axes:
            split_factor = 1
            for right_mesh_idx in shard_axes:
                if right_mesh_idx <= mesh_idx:
                    continue
                if order_positions[right_mesh_idx] < order_positions[mesh_idx]:
                    split_factor *= self.mesh_shape[right_mesh_idx]
            expected_split_factors[mesh_idx] = split_factor
        return expected_split_factors

    def _get_effective_shard_axes(self, shard_axes):
        """Return shard axes ordered by their effective sharding order."""
        return sorted(
            shard_axes,
            key=lambda mesh_idx: self._placement_split_factor(self.placements[mesh_idx]),
        )

    def _reorder_dim_map_for_strided_shard(self, dim_map):
        """Reorder dim_map entries to reflect the effective sharding order."""
        for i, shard_axes in enumerate(dim_map):
            if shard_axes == -1 or len(shard_axes) <= 1:
                continue
            dim_map[i] = self._get_effective_shard_axes(shard_axes)

    def _validate_strided_shard_split_factor(self, dim_map):
        """Validate that split factors match the effective sharding order."""
        for shard_axes in dim_map:
            if shard_axes == -1:
                continue
            shard_order = self._get_effective_shard_axes(shard_axes)
            expected_split_factors = self._compute_expected_split_factors(
                shard_axes, shard_order
            )
            for mesh_idx in shard_axes:
                placement = self.placements[mesh_idx]
                actual_split_factor = self._placement_split_factor(placement)
                expected_split_factor = expected_split_factors[mesh_idx]
                if actual_split_factor != expected_split_factor:
                    raise ValueError(
                        f"StridedShard split_factor mismatch on mesh axis {mesh_idx}: "
                        f"expected {expected_split_factor}, got {actual_split_factor}."
                    )

    @staticmethod
    def _extract_reduce_op(placement):
        """Extract reduce operation name from Partial placement."""
        op_name = getattr(placement, "reduce_op", "sum")
        if isinstance(op_name, str):
            op_name = op_name.lower()
        return op_name

    def _convert_dim_map_to_tensor_map(self, dim_map):
        """Convert dimension map to tensor map format."""
        device_dim_count = len(self.mesh_shape)
        tensor_map = []
        for mesh_idx in dim_map:
            if mesh_idx == -1:
                tensor_map.append(-1)
                continue
            mapped_axes = tuple(device_dim_count - 1 - axis for axis in mesh_idx)
            tensor_map.append(mapped_axes[0] if len(mapped_axes) == 1 else mapped_axes)
        return tensor_map

    def _build_readable_tensor_map(self):
        """Build human-readable alias tensor map from tensor_map."""
        mesh_dim_names = self._mesh.mesh_dim_names
        has_names = mesh_dim_names is not None

        def _map_dim(dim):
            """Convert a dimension index to a dimension name."""
            if dim == -1:
                return "None"
            if not has_names:
                return f"dim_{dim}"
            return mesh_dim_names[len(mesh_dim_names) - 1 - dim]

        readable_map = []
        for item in self._tensor_map:
            if isinstance(item, tuple):
                mapped_tuple = tuple(_map_dim(dim) for dim in item)
                readable_map.append(mapped_tuple)
            else:
                readable_map.append(_map_dim(item))
        return tuple(readable_map)

    def tensor_map_to_placement(self):
        """
        Transform tensor map to placement.

        This method converts the existing `tensor_map` and `partial` status into a list of `Placement` objects
        (Shard, StridedShard, Replicate, Partial). This is the inverse operation of `placement_to_tensor_map`.

        Returns:
            list[Placement]: A list of Placement objects describing the distribution strategy for each
                dimension of the device mesh.

        Raises:
            ValueError: If `tensor_map` is not configured (None).
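
        Examples:
            >>> # Illustrative sketch; assumes an 8-device environment.
            >>> layout = Layout((2, 4), ("dp", "mp"))
            >>> sharded = layout("mp", "None")
            >>> placements = sharded.tensor_map_to_placement()  # conceptually [Replicate(), Shard(dim=0)]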
        """
        if self._tensor_map is None:
            raise ValueError("The tensor_map is None, cannot transform to placements.")
        mesh_ndim = len(self.mesh_shape)
        placements = [Replicate()] * mesh_ndim
        for tensor_dim, mapping in enumerate(self._tensor_map):
            mapping_list = mapping if isinstance(mapping, tuple) else (mapping,)
            valid_mapping = [map_val for map_val in mapping_list if map_val != -1]
            mesh_indices = [mesh_ndim - 1 - map_val for map_val in valid_mapping]
            shard_axes = sorted(mesh_indices)
            expected_split_factors = self._compute_expected_split_factors(
                shard_axes, mesh_indices
            )
            for mesh_idx in shard_axes:
                split_factor = expected_split_factors[mesh_idx]
                placement = (
                    StridedShard(dim=tensor_dim, split_factor=split_factor)
                    if split_factor > 1
                    else Shard(dim=tensor_dim)
                )
                placements[mesh_idx] = placement
        for mesh_idx, op in enumerate(self.partial):
            if op is not None:
                placements[mesh_idx] = Partial(reduce_op=op)
        self.set_placements(placements)
        return placements

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.update_mesh(init_backend=False)

    @property
    def mesh(self):
        """
        Get the device mesh associated with this layout.

        Returns:
            DeviceMesh: The device mesh describing the device topology.
        """
        return self._mesh

    def update_mesh(self, init_backend: bool = True):
        """Recreate the internal DeviceMesh from current layout properties.

        Args:
            init_backend (bool): Whether to initialize the communication backend
                (process groups). Set to ``False`` during deserialization to
                avoid creating process groups with a stale rank_list from the
                sender side. Default ``True``.
        """
        self._mesh = _create_device_mesh("npu", self.mesh_shape, mesh_dim_names=self.alias_name,
                                         rank_list=self.rank_list, init_backend=init_backend)

    @property
    def rank_list(self):
        """
        Get the list of ranks participating in this layout.

        Returns:
            tuple[int]: The rank list.
        """
        return self._rank_list

    @rank_list.setter
    def rank_list(self, val):
        self._rank_list = val

    @property
    def mesh_shape(self):
        """mesh shape"""
        return self._mesh.mesh_shape

    @property
    def alias_name(self):
        """alias name"""
        return self._mesh.mesh_dim_names

    @property
    def alias_tensor_map(self):
        """alias tensor map"""
        return self._alias_tensor_map

    @property
    def alias_placements(self):
        """Return alias_tensor_map when it contains multi-axis tuples, otherwise placements.

        alias_tensor_map preserves multi-axis ordering information
        (e.g., (("dp", "tp"), "None") vs (("tp", "dp"), "None"))
        that Placement objects cannot represent, since both map to
        [Shard(0), Shard(0)].

        For single-axis layouts, Placement objects are preferred because they
        also carry Partial status, which alias_tensor_map cannot encode.

        Use this property when constructing DTensors from an existing Layout
        to avoid the lossy Placement round-trip for multi-axis cases.
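
        Examples:
            >>> # Illustrative sketch; assumes an 8-device environment.
            >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
            >>> multi = layout(("dp", "mp"), "None")
            >>> multi.alias_placements
            (('dp', 'mp'), 'None')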
        """
        if self._alias_tensor_map is not None and any(
            isinstance(item, tuple) for item in self._alias_tensor_map
        ):
            return self._alias_tensor_map
        return self._placements

    def set_alias_tensor_map(self, alias_tensor_map):
        """Set alias_tensor_map."""
        self._alias_tensor_map = alias_tensor_map

    @property
    def placements(self):
        """placements"""
        return self._placements

    def set_placements(self, placements):
        """Set placements."""
        self._placements = placements

    @property
    def tensor_map(self):
        """tensor map"""
        return self._tensor_map

    def set_tensor_map(self, tensor_map):
        """Set tensor_map."""
        self._tensor_map = tensor_map

    @property
    def partial(self):
        """partial status"""
        return self._partial

    def set_partial_by_dev_axis(self, axis, op):
        """Set the partial status for the specified device axis, meaning a reduction by `op` is pending."""
        if op not in self._support_partial_op:
            raise ValueError(f"Partial op must be one of {self._support_partial_op}, but got {op}")
        if self.is_dev_axis_apply_shard(axis):
            raise ValueError("Partial dim must be replicate.")
        self._partial[self._mesh.axis_index(axis)] = op
        self.tensor_map_to_placement()
        self.update_compact_str()

    def get_partial_by_dev_id(self, axis):
        """Get the partial status for the specified device axis."""
        return self.partial[self._mesh.axis_index(axis)]

    def is_dev_axis_apply_shard(self, axis):
        """Return True if the device axis is used to shard a tensor dimension."""
        axis_id = self._mesh.axis_id(axis)

        def flatten(input_x):
            flatten_res = []
            for item in input_x:
                if isinstance(item, tuple):
                    flatten_res.extend(flatten(item))
                else:
                    flatten_res.append(item)
            return flatten_res

        flatten_tensor_map = flatten(self.tensor_map)
        return axis_id in flatten_tensor_map

    def get_dev_axis_apply_shard_axis(self, axis):
        """Return the tensor dimension split by the device axis, or None if the axis shards nothing."""
        for dim, dim_map in enumerate(self.alias_tensor_map):
            if (isinstance(dim_map, tuple) and axis in dim_map) or axis == dim_map:
                return dim
        return None

    def reset_partial(self):
        """Clear the partial status of all device axes."""
        self._partial = [None] * len(self.mesh_shape)
        self.tensor_map_to_placement()
        self.update_compact_str()

    def is_partial(self):
        """Return True if any device dimension in mesh_shape is partial."""
        return any(self.partial)

    def get_global_shape(self, slice_shape):
        """Get the global shape inferred from a local slice shape."""
        return self._mesh.get_global_shape(slice_shape, self._tensor_map)

    def get_devices_for_axis(self, axis, rank):
        """
        Get the repeated rank list for a device axis that does not shard the tensor.

        Args:
            axis (str): Axis name.
            rank (int): Global rank.

        Returns:
            list: Rank list used for the reduction.
        """
        return self._mesh.get_devices_for_axis(axis, rank)

    def get_comm_group_by_axis(self, axis):
        """Get the communication group along the given device axis."""
        return self._mesh.get_comm_group_by_axis(axis)

    def repeat_num(self):
        """
        Number of repeated placements.

        For example, with:
            layout = Layout((2, 4), ("dp", "mp"))
            x_layout = layout("dp", "None")
        the repeat number equals the total device count (8) divided by the device count of the used
        axis "dp" (2), which is 4.
        """
        if self._tensor_map is None:
            raise ValueError(f"The tensor_map is None, the mesh_shape is {self._mesh.mesh_shape},"
                             f" alias_name is {self._mesh.mesh_dim_names}")

        all_device_num = functools.reduce(lambda x, y: x * y, self._mesh.mesh_shape)
        used_dev_num = 1
        for ele in self._tensor_map:
            if isinstance(ele, tuple):
                for item in ele:
                    if item >= 0:
                        used_dev_num *= self._mesh.mesh_shape[len(self._mesh.mesh_shape) - item - 1]
                continue
            if ele >= 0:
                used_dev_num *= self._mesh.mesh_shape[len(self._mesh.mesh_shape) - ele - 1]

        return all_device_num // used_dev_num

    def _to_compact_string(self):
        """
        Generate a compact string key for dictionary lookups.

        Returns:
            str: The compact key string.
        """
        mesh_key = self._mesh.to_hash()
        hash_key = (self._tensor_map, self.partial)
        hash_key += mesh_key
        return str(hash_key)

    @property
    def compact_str(self):
        """Compact string key of this layout."""
        return self._compact_str

    def update_compact_str(self):
        """Refresh the compact string key after the layout changes."""
        self._compact_str = self._to_compact_string()

    def to_string(self):
        """
        Dump the layout as a human-readable string.

        Returns:
            str: The layout string.
        """
        device_info = f"Mesh shape: {self._mesh.mesh_shape}"
        alias_info = f"Alias Names: {self._mesh.mesh_dim_names}"
        rank_info = f"Rank List: {self._rank_list}"
        partial_info = f"Partial: {self.partial}"

        if self._tensor_map is None:
            tensor_info = "Tensor Map: Not configured"
        else:
            readable_map = []
            for item in self._tensor_map:
                if isinstance(item, tuple):
                    # handle nested tuple
                    mapped_tuple = tuple(
                        self._mesh.mesh_dim_names[len(self._mesh.mesh_dim_names) - 1 - dim] if dim != -1 else "None"
                        for dim in item
                    )
                    readable_map.append(mapped_tuple)
                else:
                    readable_map.append(
                        self._mesh.mesh_dim_names[len(self._mesh.mesh_dim_names) - 1 - item] if item != -1 else "None"
                    )

            tensor_info = f"Tensor Map: {tuple(readable_map)}"

        interleaved = "Yes" if "interleaved_parallel" in self._mesh.mesh_dim_names else "No"
        interleaved_info = f"Interleaved Parallel: {interleaved}"

        return (
            f"Layout Configuration:\n"
            f" {device_info}\n"
            f" {alias_info}\n"
            f" {partial_info}\n"
            f" {tensor_info}\n"
            f" {interleaved_info}\n"
            f" {rank_info}"
        )

    def __str__(self):
        """__str__"""
        return self.to_string()

    def __repr__(self):
        """__repr__"""
        return f"<Layout at {hex(id(self))}>"

    def __eq__(self, other):
        """__eq__"""
        if not isinstance(other, Layout):
            return False

        if (self.mesh_shape != other.mesh_shape or
                self.alias_name != other.alias_name or
                self.partial != other.partial or
                self.rank_list != other.rank_list):
            return False

        if self._tensor_map is None or other.tensor_map is None:
            return self._tensor_map is other.tensor_map
        return self._tensor_map == other.tensor_map