Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / dtensor / device_mesh.py: 55%

780 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""device mesh""" 

16 

17import copy 

18import os 

19import threading 

20from types import TracebackType 

21from typing import Any, List, Literal, Optional, Sequence, Type, Union 

22import numpy as np 

23 

24from hyper_parallel.core.dtensor._mesh_layout import IntTuple, _MeshLayout, _contiguous_strides, _is_int 

25from hyper_parallel.platform import get_platform 

26from hyper_parallel.platform.platform import EXISTING_COMM_GROUPS, PlatformType 

27 

# Resolve the active execution platform (PyTorch or MindSpore) once at import
# time; all device-mesh logic below goes through this abstraction.
platform = get_platform()
# Platform-specific tensor class (e.g. torch.Tensor or the MindSpore equivalent).
Tensor = platform.Tensor

30 

31 

32class _MeshEnv(threading.local): 

33 """Per-thread stack of active :class:`DeviceMesh` (PyTorch ``_mesh_resources`` parity).""" 

34 

35 def __init__(self) -> None: 

36 super().__init__() 

37 self.mesh_stack: List["DeviceMesh"] = [] 

38 

39 def get_current_mesh(self) -> "DeviceMesh": 

40 """Return the innermost active :class:`DeviceMesh` for this thread (PyTorch parity).""" 

41 if len(self.mesh_stack) == 0: 

42 raise RuntimeError("No device mesh is currently active!") 

43 return self.mesh_stack[-1] 

44 

45 

# Module-level singleton holding the per-thread stack of active meshes;
# DeviceMesh.__enter__/__exit__ push/pop onto it.
_mesh_resources = _MeshEnv()

# Backend override for one mesh dimension: a backend name, or None for the default.
BackendConfig = Optional[str]

49 

50 

51def _get_sub_rank_list(mesh_shape, mesh_dim_names, rank_list, sub_mesh_dim_names, current_rank): 

52 """ 

53 Get the sub rank list for a sub mesh. 

54 

55 Args: 

56 mesh_shape (tuple[int]): The shape of the original mesh. 

57 mesh_dim_names (tuple[str]): The mesh dim names of the original mesh dimensions. 

58 rank_list (tuple[int]): A tuple of ranks that participate in this mesh. 

59 sub_mesh_dim_names (tuple[str]): The mesh dim names of the sub mesh to extract. 

60 current_rank (int): The current process rank. 

61 

62 Returns: 

63 list: The sub rank list for the sub mesh. 

64 """ 

65 mesh_tensor = np.array(rank_list).reshape(mesh_shape) 

66 

67 for dim_index, dim_name in enumerate(mesh_dim_names): 

68 if dim_name in sub_mesh_dim_names: 

69 continue 

70 

71 dim_size = mesh_shape[dim_index] 

72 sliced_tensors = np.split(mesh_tensor, dim_size, axis=dim_index) 

73 

74 for sliced_tensor in sliced_tensors: 

75 rank_exists = np.isin(np.array([current_rank]), sliced_tensor).any() 

76 if rank_exists: 

77 mesh_tensor = sliced_tensor 

78 break 

79 

80 sub_rank_list = mesh_tensor.reshape(-1).tolist() 

81 return sub_rank_list 

82 

83 

def _normalize_backend_value(value: Any) -> BackendConfig:
    """Coerce a raw backend-override value into a backend name or ``None``.

    A plain string or ``None`` passes through; a non-empty tuple contributes
    its first element when that element is a string or ``None``; everything
    else normalizes to ``None``.
    """
    if value is None or isinstance(value, str):
        return value
    if isinstance(value, tuple) and value:
        head = value[0]
        if head is None or isinstance(head, str):
            return head
    return None

94 

95 

def _normalize_backend_override(
    backend_override: dict[Union[int, str], Any],
    ndim: int,
    mesh_dim_names: Optional[tuple[str, ...]] = None,
) -> tuple[BackendConfig, ...]:
    """Normalize backend overrides keyed by dim index or dim name.

    Returns one normalized backend value per dimension; raises when a
    dimension is specified both by index and by name, or when unknown keys
    remain after all dimensions are resolved.
    """
    pending = dict(backend_override)
    names = mesh_dim_names or ()
    result: list[BackendConfig] = []

    for idx in range(ndim):
        name = names[idx] if idx < len(names) else None
        named = name is not None and name in pending
        if named and idx in pending:
            raise RuntimeError(
                f"Found redundant dim index {idx} and name {name} in backend_override"
            )
        if named:
            result.append(_normalize_backend_value(pending.pop(name)))
        elif idx in pending:
            result.append(_normalize_backend_value(pending.pop(idx)))
        else:
            result.append(None)

    if pending:
        raise RuntimeError(
            f"Found invalid keys in backend_override: got {list(pending.keys())}, "
            f"expected integers in range [0, {ndim}) or one of {names}"
        )
    return tuple(result)

125 

126 

def _should_defer_group_init(sub_layout: _MeshLayout, backend_override: BackendConfig) -> bool:
    """Whether this mesh dimension should skip eager process-group creation.

    Deferred for the "fake" backend, or when the dimension holds one device
    (a single-rank group needs no communication).
    """
    if backend_override == "fake":
        return True
    return sub_layout.numel() == 1

130 

131 

class DeviceMesh:
    """
    Topological abstraction describing cluster devices.

    Args:
        device_type (str): Device type. Valid values depend on the active platform:

            - **PyTorch** (same as ``torch.distributed.device_mesh.DeviceMesh``):
              ``"cpu"``, ``"cuda"``, ``"npu"``.
            - **MindSpore** (mapped to the corresponding communication backend):
              ``"cpu"`` → mccl, ``"gpu"`` → nccl, ``"npu"`` → hccl.
        mesh (Union[Tensor, list, tuple, np.ndarray, None]): A multi-dimensional array, list, or integer
            tensor describing the device layout. The IDs in the mesh are global IDs of the
            default process group, representing the multi-dimensional networking structure
            of devices in distributed training (e.g., [[0,1],[2,3]] represents a 2x2 device mesh).
            If a list or non-int32 tensor is provided, it will be automatically converted
            to an int32 tensor. If None, a 1D mesh containing all ranks
            (i.e., ``[0, 1, ..., world_size-1]``) will be created automatically.
        mesh_dim_names (tuple[str]): A tuple[str] of mesh dim names for each dimension of mesh.
        _init_backend (boolean): Whether initial process group.

    Attributes:
        ndim (int): Number of dimensions in the mesh.
        mesh_shape (tuple[int]): Shape of the device mesh.
        rank_list (tuple[int]): Flattened list of ranks from the mesh.
        root_mesh (DeviceMesh): The parent mesh if this is a sub mesh, None otherwise.
        sub_mesh (list[DeviceMesh]): List of child meshes created from this mesh.

    Context manager:
        Use ``with device_mesh:`` to set the **current** mesh for this thread.
    """

    device_type: Literal["cpu", "cuda", "gpu", "npu"]
    mesh: Union[Tensor, list, tuple, np.ndarray]
    mesh_dim_names: Union[tuple[str, ...], list[str], None]

    # Device types accepted per platform. A platform missing from this table
    # skips validation entirely (see _validate_device_type).
    _VALID_DEVICE_TYPES = {
        PlatformType.PYTORCH: {"cpu", "cuda", "npu"},
        PlatformType.MINDSPORE: {"cpu", "gpu", "npu"},
    }

172 

    def __init__(self,
                 device_type: Literal["cpu", "cuda", "gpu", "npu"],
                 mesh: Union[Tensor, list, tuple, np.ndarray, None] = None,
                 *,
                 mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
                 _init_backend: bool = True,
                 _layout: Optional[_MeshLayout] = None,
                 _rank_map: Optional[Tensor] = None,
                 _root_mesh: Optional['DeviceMesh'] = None,
                 ):
        self._validate_device_type(device_type)
        self.device_type = device_type

        if _init_backend:
            platform.init_process_group()

        # Either the public `mesh` or the private (_layout, _rank_map) pair
        # defines the topology; _resolve_layout_and_rank_map rejects mixing them.
        self._layout, self._rank_map = self._resolve_layout_and_rank_map(mesh, _layout, _rank_map)
        self._rank = platform.get_rank()
        self._root_mesh = _root_mesh
        # Order matters: the mesh view must exist before dim names are
        # validated against its shape, and before runtime state is derived.
        self._refresh_mesh_view()
        self._set_mesh_dim_names(mesh_dim_names)
        self._initialize_runtime_state(_init_backend)
        # Under MindSpore simulation (MS_SIMULATION_LEVEL set) the coordinate
        # lookup is skipped, so `_coordinate_on_dim` may not exist afterwards —
        # NOTE(review): get_coordinate() depends on this attribute.
        if os.getenv("MS_SIMULATION_LEVEL") is None:
            self._coordinate_on_dim = self._compute_coordinate_on_dim()

197 

198 @classmethod 

199 def _validate_device_type(cls, device_type: str) -> None: 

200 """Validate that the requested device type is supported on the active platform.""" 

201 valid_device_types = cls._VALID_DEVICE_TYPES.get(platform.platform_type) 

202 if valid_device_types is not None and device_type not in valid_device_types: 

203 raise ValueError( 

204 f"Invalid device_type '{device_type}' for {platform.platform_type.name} platform. " 

205 f"Valid device types are: {sorted(valid_device_types)}" 

206 ) 

207 

    @classmethod
    def _resolve_layout_and_rank_map(
        cls,
        mesh: Union[Tensor, list, tuple, np.ndarray, None],
        layout: Optional[_MeshLayout],
        rank_map: Optional[Tensor],
    ) -> tuple[_MeshLayout, Tensor]:
        """Build the internal layout and rank map from either public or private constructor inputs.

        Exactly one construction path is taken: an explicit ``mesh``, the
        private ``(_layout, _rank_map)`` pair, or (when all are None) a default
        1-D mesh covering the whole world size.
        """
        # Mixing the public and private construction paths is ambiguous.
        if mesh is not None and (layout is not None or rank_map is not None):
            raise TypeError("Cannot provide both explicit mesh and private _layout/_rank_map arguments.")

        # Nothing provided at all: default to a flat [0, world_size) mesh.
        if mesh is None and (layout is None or rank_map is None):
            world_size = platform.get_world_size()
            mesh = list(range(world_size))

        if mesh is not None:
            mesh_tensor = cls._convert_mesh_to_tensor(mesh)
            if mesh_tensor.ndim == 0:
                raise ValueError("mesh must be at least 1-dimensional")
            return cls._build_layout_from_mesh(mesh_tensor), cls._build_rank_map_from_mesh(mesh_tensor)

        # Private path: validate and adopt the caller-supplied layout/rank map.
        rank_map_tensor = cls._convert_rank_map_to_tensor(rank_map)
        if layout is None or rank_map_tensor is None:
            raise TypeError("The mesh argument is required except for private _layout/_rank_map construction.")
        if not layout.check_non_overlap():
            raise ValueError(f"Invalid overlapping layout {layout}.")
        return layout, rank_map_tensor

235 

    def _refresh_mesh_view(self) -> None:
        """Materialize the visible mesh tensor and the derived shape/rank metadata."""
        # Expand the (layout, rank_map) pair into the full mesh array, then
        # select the slice visible to the current rank.
        full_mesh_np = self._layout.remap_to_numpy(platform.tensor_to_numpy(self._rank_map))
        full_mesh = Tensor(full_mesh_np).int()
        self.mesh = self._get_mesh_tensor_from_full_mesh(full_mesh, current_rank=self._rank)
        # Cached metadata derived from the visible mesh view.
        self._mesh_shape = tuple(self.mesh.shape)
        self._rank_list = tuple(platform.tensor_to_numpy(self.mesh).reshape(-1).tolist())
        # Flat view of the WHOLE rank map (not just this rank's slice).
        self._flatten_rank_map = tuple(platform.tensor_to_numpy(self._rank_map).reshape(-1).tolist())
        self._dev_num = np.prod(np.array(self._mesh_shape))
        self._dev_rank = len(self._mesh_shape)

246 

247 def _set_mesh_dim_names( 

248 self, 

249 mesh_dim_names: Union[tuple[str, ...], list[str], None], 

250 ) -> None: 

251 """Validate mesh dim names and build lookup tables for named access.""" 

252 self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None 

253 if self.mesh_dim_names is None: 

254 return 

255 

256 if len(self._mesh_shape) != len(self.mesh_dim_names): 

257 raise ValueError( 

258 f'mesh dimensions ({len(self._mesh_shape)}) should be equal to ' 

259 f'mesh_dim_names length ({len(self.mesh_dim_names)})' 

260 ) 

261 if len(set(self.mesh_dim_names)) != len(self.mesh_dim_names): 

262 raise ValueError(f'Each element of mesh_dim_names {self.mesh_dim_names} should be different') 

263 inter_key = "interleaved_parallel" 

264 if inter_key in self.mesh_dim_names and self.mesh_dim_names.index(inter_key) != len(self.mesh_dim_names) - 1: 

265 raise ValueError( 

266 "'interleaved_parallel' should be at the last dim of mesh_dim_names, means virtual sharding." 

267 ) 

268 self._dev_name_to_dev_id = { 

269 name: self._dev_rank - i - 1 for i, name in enumerate(self.mesh_dim_names) 

270 } 

271 self._dev_name_to_index = {name: i for i, name in enumerate(self.mesh_dim_names)} 

272 

    def _initialize_runtime_state(self, init_backend: bool) -> None:
        """Initialize caches and optional process-group state for the mesh view."""
        # Per-instance caches for rank slicing and sub-mesh lookups.
        self._cache_rank_list_along_axis = {}
        self._global_shape_map = {}
        self._sub_mesh_cache = {}
        self._flatten_mapping: dict[str, 'DeviceMesh'] = {}
        self._ndim = len(self._mesh_shape)
        # One backend override and one (mesh, dim) source record per dimension.
        self._dim_group_backends = (None,) * self._ndim
        self._dim_group_sources = tuple((self, dim) for dim in range(self._ndim))
        self._sub_mesh: List['DeviceMesh'] = []
        # NOTE: `_dim_group_names` only exists when init_backend is True;
        # group accessors guard on hasattr(self, "_dim_group_names").
        if not init_backend:
            return
        self._dim_group_names = self._init_process_groups(
            self._mesh_shape,
            self.mesh_dim_names,
            self._rank_list,
        )

290 

291 @staticmethod 

292 def _build_layout_from_mesh(mesh: Tensor) -> _MeshLayout: 

293 mesh_shape = tuple(mesh.shape) 

294 return _MeshLayout(mesh_shape, _contiguous_strides(mesh_shape)) 

295 

296 @staticmethod 

297 def _build_rank_map_from_mesh(mesh: Tensor) -> Tensor: 

298 return Tensor(platform.tensor_to_numpy(mesh).reshape(-1)).int() 

299 

300 @staticmethod 

301 def _convert_rank_map_to_tensor(rank_map: Tensor) -> Tensor: 

302 if isinstance(rank_map, Tensor): 

303 rank_map_np = platform.tensor_to_numpy(rank_map) 

304 else: 

305 rank_map_np = np.array(rank_map) 

306 return Tensor(rank_map_np.reshape(-1).astype(np.int32)).int() 

307 

308 @staticmethod 

309 def _get_mesh_tensor_from_full_mesh(full_mesh: Tensor, current_rank: Optional[int] = None) -> Tensor: 

310 """Select the per-rank mesh view from a fully materialized layout remap.""" 

311 if full_mesh.shape[0] == 1: 

312 return full_mesh[0] 

313 

314 if current_rank is None: 

315 current_rank = platform.get_rank() 

316 

317 rank_coords = (full_mesh == current_rank).nonzero() 

318 if rank_coords.shape[0] > 0: 

319 return full_mesh[rank_coords[0, 0]] 

320 raise RuntimeError( 

321 "In order to get the mesh tensor of a DeviceMesh it needs to " 

322 "either have all its original dimensions or contain the local rank." 

323 ) 

324 

    def _compute_coordinate_on_dim(self):
        """Compute the current rank coordinates inside this mesh view.

        Returns:
            tuple | None: Coordinates of ``self._rank`` inside ``self.mesh``,
            or ``None`` when the rank does not appear in this view.
        """
        return self._compute_coordinates_from_mesh(self.mesh, self._rank)

328 

329 @staticmethod 

330 def _compute_coordinates_from_mesh( 

331 mesh_tensor: Tensor, 

332 rank: int, 

333 ): 

334 """Locate one rank inside a mesh tensor and return its coordinates.""" 

335 rank_coords = (mesh_tensor == rank).nonzero() 

336 if rank_coords.shape[0] not in (0, 1): 

337 raise AssertionError( 

338 f"rank_coords.shape[0] must be 0 or 1, got {rank_coords.shape[0]}" 

339 ) 

340 

341 if rank_coords.shape[0] == 0: 

342 return None 

343 

344 coords = rank_coords[0].tolist() 

345 return tuple(coords) 

346 

347 def size(self, mesh_dim=None) -> int: 

348 if mesh_dim is not None: 

349 return self.mesh.shape[mesh_dim] 

350 return self.mesh.numel() 

351 

352 def get_coordinate(self): 

353 return self._coordinate_on_dim if self._coordinate_on_dim else None 

354 

    def __enter__(self) -> "DeviceMesh":
        # Entering the context makes this mesh the thread's current mesh
        # (retrievable via _mesh_resources.get_current_mesh()).
        _mesh_resources.mesh_stack.append(self)
        return self

358 

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Pop unconditionally; returning None lets any exception propagate.
        _mesh_resources.mesh_stack.pop()

366 

367 @staticmethod 

368 def _convert_mesh_to_tensor(mesh: Union[Tensor, list, tuple, np.ndarray]) -> Tensor: 

369 """Convert a public mesh input into an int32 platform tensor.""" 

370 if isinstance(mesh, Tensor): 

371 mesh = platform.tensor_to_numpy(mesh) 

372 elif isinstance(mesh, (list, tuple)): 

373 mesh = np.array(mesh) 

374 elif not isinstance(mesh, np.ndarray): 

375 raise TypeError( 

376 f"mesh must be Tensor, list, tuple or numpy array, but got {type(mesh)}" 

377 ) 

378 

379 mesh = mesh.astype(np.int32) 

380 return Tensor(mesh).int() 

381 

    @staticmethod
    def _init_one_process_group(mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...],
                                dim_name: str, rank_list: tuple[int, ...]) -> str:
        """Create one process-group family for the named mesh dimension.

        Args:
            mesh_shape: Shape of the full mesh.
            mesh_dim_names: Names of all mesh dimensions.
            dim_name: The dimension (or tuple of dimensions) to group by.
            rank_list: All ranks participating in the mesh.

        Returns:
            str: Cache key (stringified sorted rank tuple) for the group that
            contains the current rank; ``None`` when the current rank is not
            in ``rank_list``.
        """
        group_key = None
        split_ranks = set()
        if not isinstance(dim_name, tuple):
            dim_name = (dim_name,)
        # Compute, for every rank, the subgroup it belongs to along dim_name;
        # de-duplicate via a set of sorted tuples.
        for rank in rank_list:
            split_rank = _get_sub_rank_list(mesh_shape, mesh_dim_names, rank_list, dim_name, rank)
            sorted_rank = tuple(sorted(split_rank))
            split_ranks.add(sorted_rank)
            if rank == platform.get_rank():
                group_key = str(sorted_rank)
        # The platform split call expects a deterministic list-of-lists.
        split_ranks = sorted([list(item) for item in split_ranks])
        platform.split_group(split_ranks=split_ranks)
        return group_key

399 

    @staticmethod
    def _build_dim_split_ranks(
        sub_layout: _MeshLayout,
        rank_map: Tensor,
    ) -> tuple[list[list[int]], Optional[str]]:
        """Build rank lists and the local cache key for one logical mesh axis.

        Returns:
            tuple: ``(split_ranks, group_key)`` — the sorted, de-duplicated
            rank groups along this axis, and the stringified sorted rank tuple
            of the group containing the current rank (``None`` if absent).

        Raises:
            RuntimeError: If the current rank appears in more than one group
                along the same axis.
        """
        pg_ranks_by_dim = sub_layout.remap_to_numpy(platform.tensor_to_numpy(rank_map))
        current_rank = platform.get_rank()
        split_ranks = []
        split_ranks_set = set()
        group_key = None
        # Each slice along the leading axis is one candidate subgroup.
        for dim_mesh in np.array(pg_ranks_by_dim):
            subgroup_ranks = tuple(int(rank) for rank in np.array(dim_mesh).reshape(-1).tolist())
            subgroup_ranks_sorted = tuple(sorted(subgroup_ranks))
            if subgroup_ranks_sorted not in split_ranks_set:
                split_ranks_set.add(subgroup_ranks_sorted)
                split_ranks.append(list(subgroup_ranks_sorted))
            if current_rank in subgroup_ranks:
                # A rank must belong to exactly one group per mesh dimension.
                if group_key is not None:
                    raise RuntimeError(
                        "Each device mesh dimension should get only one process group per rank."
                    )
                group_key = str(subgroup_ranks_sorted)
        split_ranks = sorted(split_ranks)
        return split_ranks, group_key

425 

426 @staticmethod 

427 def _cache_group_if_needed(group_key: Optional[str], group: Any) -> None: 

428 if group_key is not None and group is not None and group_key not in EXISTING_COMM_GROUPS: 

429 EXISTING_COMM_GROUPS[group_key] = group 

430 

    @staticmethod
    def _init_process_groups_for_layout(
        layout: _MeshLayout,
        rank_map: Tensor,
        mesh_dim_names: Union[tuple[str, ...], None],
        backend_override: Optional[tuple[BackendConfig, ...]] = None,
    ) -> list:
        """Initialize process groups for each top-level axis in the given layout.

        Returns:
            list: One group key (or ``None`` when creation was deferred) per
            mesh dimension, in axis order.
        """
        # NOTE(review): the defaulted mesh_dim_names value is not read again
        # below — presumably kept for interface symmetry; confirm before removing.
        if mesh_dim_names is None:
            mesh_dim_names = tuple(f"dim_{dim}" for dim in range(len(layout)))
        if backend_override is None:
            backend_override = (None,) * len(layout)
        if len(backend_override) != len(layout):
            raise ValueError(
                f"backend_override length {len(backend_override)} must match layout rank {len(layout)}"
            )

        dim_group_names = []
        for dim, sub_layout in enumerate(layout):
            split_ranks, group_key = DeviceMesh._build_dim_split_ranks(sub_layout, rank_map)
            # "fake" backend or single-device axes skip eager group creation.
            if _should_defer_group_init(sub_layout, backend_override[dim]):
                dim_group_names.append(None)
                continue
            group = platform.split_group(split_ranks=split_ranks)
            DeviceMesh._cache_group_if_needed(group_key, group)
            dim_group_names.append(group_key)
        return dim_group_names

458 

459 @staticmethod 

460 def _init_process_groups(mesh_shape: tuple[int, ...], mesh_dim_names: Union[tuple[str, ...], None], 

461 rank_list: tuple[int, ...], 

462 backend_override: Optional[tuple[BackendConfig, ...]] = None) -> list: 

463 layout = _MeshLayout(mesh_shape, _contiguous_strides(mesh_shape)) 

464 rank_map = DeviceMesh._convert_rank_map_to_tensor(rank_list) 

465 return DeviceMesh._init_process_groups_for_layout( 

466 layout, 

467 rank_map, 

468 mesh_dim_names, 

469 backend_override=backend_override, 

470 ) 

471 

    @property
    def rank(self):
        """Global rank of the current process, as reported by the platform."""
        return self._rank

    @property
    def mesh_shape(self):
        """Shape of this mesh view as a tuple of ints."""
        return self._mesh_shape

    @property
    def rank_list(self):
        """Flattened tuple of global ranks participating in this mesh view."""
        return self._rank_list

    @property
    def ndim(self) -> int:
        """Number of mesh dimensions."""
        return self._ndim

    @property
    def shape(self) -> tuple:
        """Alias of :attr:`mesh_shape`."""
        return self._mesh_shape

    @property
    def root_mesh(self) -> Optional['DeviceMesh']:
        """Parent mesh when this is a sub mesh; ``None`` for a root mesh."""
        return self._root_mesh

    @root_mesh.setter
    def root_mesh(self, value: Optional['DeviceMesh']):
        self._root_mesh = value

    @property
    def sub_mesh(self) -> List['DeviceMesh']:
        """Child meshes created (and cached) from this mesh."""
        return self._sub_mesh

503 

    def get_flatten_mapping(self) -> dict:
        """Return the name -> flattened-DeviceMesh mapping registered on this mesh."""
        return self._flatten_mapping

    def add_flatten_mapping(self, name: str, mesh: 'DeviceMesh') -> None:
        """Register a flattened mesh under ``name`` for later slicing lookups."""
        self._flatten_mapping[name] = mesh

509 

    def __getitem__(self, sub_mesh_dim_names: Union[str, tuple[str, ...]]) -> 'DeviceMesh':
        """Slice a sub mesh out of this mesh by dimension name(s)."""
        if not self.mesh_dim_names:
            raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!")

        sub_mesh_dim_names = DeviceMesh._normalize_sub_mesh_dim_names(sub_mesh_dim_names)
        # A single name may refer to a previously flattened mesh on the root.
        flatten_mapping = self._get_root_mesh().get_flatten_mapping()

        flattened_result = self._try_get_from_flatten_mapping(sub_mesh_dim_names, flatten_mapping)
        if flattened_result is not None:
            return flattened_result

        # Layout is computed BEFORE the cache lookup, so invalid selectors
        # always raise — NOTE(review): presumably intentional; confirm.
        layout = self._get_slice_mesh_layout(sub_mesh_dim_names)
        if sub_mesh_dim_names in self._sub_mesh_cache:
            return self._sub_mesh_cache[sub_mesh_dim_names]
        if layout == self._layout:
            return self
        return self._create_and_cache_sub_mesh(sub_mesh_dim_names, layout)

527 

528 @staticmethod 

529 def _normalize_sub_mesh_dim_names(sub_mesh_dim_names: Union[str, tuple[str, ...]]) -> tuple[str, ...]: 

530 """Normalize a slice selector into a non-empty tuple of mesh dim names.""" 

531 if isinstance(sub_mesh_dim_names, str): 

532 sub_mesh_dim_names = (sub_mesh_dim_names,) 

533 

534 if not isinstance(sub_mesh_dim_names, tuple): 

535 raise TypeError( 

536 f"sub_mesh_dim_names must be str or tuple, but got {type(sub_mesh_dim_names)}" 

537 ) 

538 

539 if len(sub_mesh_dim_names) == 0: 

540 raise ValueError("sub_mesh_dim_names cannot be empty") 

541 

542 return sub_mesh_dim_names 

543 

544 @staticmethod 

545 def _try_get_from_flatten_mapping(sub_mesh_dim_names: tuple[str, ...], 

546 flatten_mapping: dict) -> Optional['DeviceMesh']: 

547 if len(sub_mesh_dim_names) == 1 and sub_mesh_dim_names[0] in flatten_mapping: 

548 return flatten_mapping[sub_mesh_dim_names[0]] 

549 return None 

550 

551 def _get_mesh_dim_by_name(self, mesh_dim_name: str) -> int: 

552 """Resolve a named mesh axis to its integer position.""" 

553 mesh_dim_names = self.mesh_dim_names or () 

554 if len(mesh_dim_names) == 0: 

555 raise KeyError("No mesh_dim_names found.") 

556 if mesh_dim_name not in mesh_dim_names: 

557 raise KeyError( 

558 f"Mesh dimension '{mesh_dim_name}' does not exist. " 

559 f"Available mesh dimensions are: {mesh_dim_names}" 

560 ) 

561 return mesh_dim_names.index(mesh_dim_name) 

562 

    def _get_slice_mesh_layout(self, sub_mesh_dim_names: tuple[str, ...]) -> _MeshLayout:
        """Construct the layout corresponding to one named sub-mesh slice request.

        Raises:
            KeyError: If a requested name is neither an original dim nor a
                flattened dim available on the root mesh.
            ValueError: If the requested names are out of order, or the
                resulting layout would overlap.
            NotImplementedError: If a flattened dim is non-contiguous.
        """
        root_mesh = self._get_root_mesh()
        slice_from_root = self == root_mesh
        # Flattened dims are only resolvable when slicing from the root mesh.
        flatten_name_to_layout = (
            {key: mesh._layout for key, mesh in root_mesh.get_flatten_mapping().items()}
            if slice_from_root else {}
        )
        valid_dim_names = [*(self.mesh_dim_names or ()), *flatten_name_to_layout]
        if not all(name in valid_dim_names for name in sub_mesh_dim_names):
            raise KeyError(
                f"Invalid mesh_dim_names {sub_mesh_dim_names} specified. "
                f"Valid mesh_dim_names are {valid_dim_names}."
            )

        # When every requested name is an original dim, its order must match
        # the original mesh_dim_names order.
        if all(name in (self.mesh_dim_names or ()) for name in sub_mesh_dim_names):
            indices = [self.mesh_dim_names.index(name) for name in sub_mesh_dim_names]
            if indices != sorted(indices):
                raise ValueError(
                    f"sub_mesh_dim_names {sub_mesh_dim_names} must follow the order of "
                    f"original mesh_dim_names {self.mesh_dim_names}"
                )

        # Gather the per-name (sizes, strides) pairs, original dims first.
        sliced_sizes: list[IntTuple] = []
        sliced_strides: list[IntTuple] = []
        for name in sub_mesh_dim_names:
            if name in (self.mesh_dim_names or ()):
                layout = self._layout[self.mesh_dim_names.index(name)]
            else:
                layout = flatten_name_to_layout[name]
            sliced_sizes.append(layout.sizes)
            sliced_strides.append(layout.strides)

        # Strides must be plain ints (contiguous flattened dims only) and
        # non-decreasing when read from the innermost dim outwards.
        pre_stride = -1
        for stride in reversed(sliced_strides):
            if not _is_int(stride):
                raise NotImplementedError(
                    "Currently, this only allows slicing out a contiguous flattened dim."
                )
            if stride < pre_stride:
                raise ValueError(
                    f"Invalid mesh_dim_names {sub_mesh_dim_names} specified. "
                    "Mesh dim indices should be in ascending order."
                )
            pre_stride = stride

        if len(sliced_sizes) == 1:
            layout = _MeshLayout(sliced_sizes[0], sliced_strides[0])
        else:
            layout = _MeshLayout(tuple(sliced_sizes), tuple(sliced_strides))
        if not layout.check_non_overlap():
            raise RuntimeError(f"Slicing overlapping dim_names {sub_mesh_dim_names} is not allowed.")
        return layout

616 

    def _create_and_cache_sub_mesh(self, sub_mesh_dim_names: tuple[str, ...], layout: _MeshLayout) -> 'DeviceMesh':
        """Create a sub-mesh view, copy group metadata, and cache the result."""
        root_mesh = self._get_root_mesh()
        # The sub mesh shares the root's rank map; no new process groups are
        # created here (_init_backend=False) — group info is copied below.
        sub_mesh = DeviceMesh(
            device_type=self.device_type,
            mesh_dim_names=sub_mesh_dim_names,
            _init_backend=False,
            _layout=layout,
            _rank_map=root_mesh._rank_map,
            _root_mesh=root_mesh,
        )

        # Carry over, per selected dim: the group key, the backend override,
        # and the (mesh, dim) pair the group originally came from.
        slice_dim_group_name = []
        slice_dim_group_backends: list[BackendConfig] = []
        slice_dim_group_sources: list[tuple['DeviceMesh', int]] = []
        for name in sub_mesh_dim_names:
            if name in (self.mesh_dim_names or ()):
                dim_index = self.mesh_dim_names.index(name)
                if hasattr(self, "_dim_group_names"):
                    slice_dim_group_name.append(self._dim_group_names[dim_index])
                slice_dim_group_backends.append(self._dim_group_backends[dim_index])
                if hasattr(self, "_dim_group_sources"):
                    slice_dim_group_sources.append(self._dim_group_sources[dim_index])  # pylint: disable=W0212
                else:
                    slice_dim_group_sources.append((self, dim_index))
            elif name in root_mesh.get_flatten_mapping():
                # Flattened meshes are 1-D, so their group metadata sits at index 0.
                flatten_mesh = root_mesh.get_flatten_mapping()[name]
                if hasattr(flatten_mesh, "_dim_group_names"):
                    slice_dim_group_name.append(flatten_mesh._dim_group_names[0])
                slice_dim_group_backends.append(flatten_mesh._dim_group_backends[0])
                if hasattr(flatten_mesh, "_dim_group_sources"):
                    slice_dim_group_sources.append(flatten_mesh._dim_group_sources[0])  # pylint: disable=W0212
                else:
                    slice_dim_group_sources.append((flatten_mesh, 0))
        if slice_dim_group_name:
            sub_mesh._dim_group_names = slice_dim_group_name  # pylint: disable=W0212
        if slice_dim_group_backends:
            sub_mesh._dim_group_backends = tuple(slice_dim_group_backends)  # pylint: disable=W0212
        if slice_dim_group_sources:
            sub_mesh._dim_group_sources = tuple(slice_dim_group_sources)  # pylint: disable=W0212

        self._sub_mesh_cache[sub_mesh_dim_names] = sub_mesh
        self.sub_mesh.append(sub_mesh)
        return sub_mesh

661 

    def get_group(self, mesh_dim: Optional[Union[int, str]] = None):
        """Return the communication group for one mesh axis.

        Args:
            mesh_dim: Axis index or name. May only be omitted for 1-D meshes.
                A name registered in the root mesh's flatten mapping resolves
                through the flattened mesh instead of this one.

        Raises:
            RuntimeError: If process groups were never initialized, or if
                ``mesh_dim`` is omitted on a multi-dimensional mesh.
        """
        if not hasattr(self, "_dim_group_names"):
            raise RuntimeError("DeviceMesh process groups not initialized!")

        if self.ndim > 1 and mesh_dim is None:
            raise RuntimeError(
                f"Found the DeviceMesh have {self.ndim} dimensions. "
                "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1."
            )

        root_mesh = self._get_root_mesh()
        if isinstance(mesh_dim, str) and mesh_dim in root_mesh.get_flatten_mapping():
            flattened_mesh = root_mesh.get_flatten_mapping()[mesh_dim]
            # get_comm_group_by_axis is defined elsewhere in this class/file.
            return flattened_mesh.get_comm_group_by_axis(mesh_dim)

        return self.get_comm_group_by_axis(mesh_dim)

679 

680 def get_all_groups(self) -> list: 

681 if not hasattr(self, "_dim_group_names"): 

682 raise RuntimeError("DeviceMesh process groups not initialized!") 

683 

684 return [self.get_group(i) for i in range(self.ndim)] 

685 

686 @staticmethod 

687 def from_group(group: Union[Any, list[Any]], 

688 device_type: str, 

689 mesh: Union[Tensor, list, tuple, np.ndarray] = None, 

690 mesh_dim_names: Union[tuple[str, ...], list[str]] = None 

691 ) -> 'DeviceMesh': 

692 """Build a DeviceMesh from an existing process group or a list of groups.""" 

693 if not isinstance(group, list): 

694 group_ranks = platform.get_process_group_ranks(group) 

695 group_key = str(tuple(sorted(group_ranks))) 

696 if not platform.get_created_group(group_ranks): 

697 EXISTING_COMM_GROUPS[group_key] = group 

698 if ( 

699 isinstance(mesh, Tensor) and mesh.tolist() != group_ranks 

700 ) or ( 

701 mesh is not None 

702 and not isinstance(mesh, Tensor) 

703 and mesh != group_ranks 

704 ): 

705 raise ValueError( 

706 f"Invalid mesh_shape {str(mesh)} for 1D group with ranks {group_ranks}" 

707 ) 

708 device_mesh = DeviceMesh(device_type, group_ranks, mesh_dim_names=mesh_dim_names, _init_backend=False) 

709 device_mesh._dim_group_names = [group_key] # pylint: disable=W0212 

710 return device_mesh 

711 

712 groups = list(group) 

713 if len(groups) == 0: 

714 raise ValueError("Expect at least one group be specified.") 

715 if mesh is None: 

716 raise ValueError("mesh_shape is must specified when group is a list.") 

717 mesh = DeviceMesh._convert_mesh_to_tensor(mesh) 

718 if mesh.ndim != len(groups): 

719 raise ValueError("mesh dimensions must match group dimensions.") 

720 device_mesh = DeviceMesh(device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False) 

721 device_mesh._dim_group_names = [] # pylint: disable=W0212 

722 for dim_group in groups: 

723 group_ranks = platform.get_process_group_ranks(dim_group) 

724 group_key = str(tuple(sorted(group_ranks))) 

725 if not platform.get_created_group(group_ranks): 

726 EXISTING_COMM_GROUPS[group_key] = dim_group 

727 device_mesh._dim_group_names.append(group_key) # pylint: disable=W0212 

728 return device_mesh 

729 

730 def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: 

731 """Return the local coordinate of the current rank along one mesh dimension.""" 

732 if self.ndim > 1 and mesh_dim is None: 

733 raise RuntimeError( 

734 f"Found the DeviceMesh have {self.ndim} dimensions. " 

735 "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1." 

736 ) 

737 

738 if mesh_dim is None: 

739 mesh_dim = 0 

740 

741 if isinstance(mesh_dim, str): 

742 if mesh_dim not in self.mesh_dim_names: # pylint: disable=E1135 

743 raise ValueError( 

744 f"mesh_dim '{mesh_dim}' not found in mesh_dim_names {self.mesh_dim_names}" 

745 ) 

746 dim_index = self.mesh_dim_names.index(mesh_dim) 

747 else: 

748 if not isinstance(mesh_dim, int) or mesh_dim < 0 or mesh_dim >= self.ndim: 

749 raise ValueError( 

750 f"mesh_dim must be an integer in range [0, {self.ndim}), " 

751 f"but got {mesh_dim}" 

752 ) 

753 dim_index = mesh_dim 

754 

755 if self._rank not in self._rank_list: 

756 raise ValueError( 

757 f"Current rank {self._rank} not found in rank_list {self._rank_list}" 

758 ) 

759 

760 idx = self._rank_list.index(self._rank) 

761 coord = [0] * len(self._mesh_shape) 

762 temp = idx 

763 for i in range(len(self._mesh_shape) - 1, -1, -1): 

764 coord[i] = temp % self._mesh_shape[i] 

765 temp //= self._mesh_shape[i] 

766 

767 return coord[dim_index] 

768 

    def flatten(self, mesh_dim_name: Optional[str] = None) -> 'DeviceMesh':
        """Flatten this mesh into a 1-D mesh named ``mesh_dim_name``.

        Delegates to ``_create_flatten_mesh`` (defined elsewhere in this file);
        presumably the result is registered in the root mesh's flatten mapping
        for later slicing — see ``add_flatten_mapping``.
        """
        return self._create_flatten_mesh(mesh_dim_name)

771 

772 def _get_root_mesh(self) -> 'DeviceMesh': 

773 """Return the canonical root mesh for this view.""" 

774 if self._root_mesh is None: 

775 return self 

776 return self._root_mesh._get_root_mesh() # pylint: disable=protected-access 

777 

778 @staticmethod 

779 def _validate_concatenate_inputs( 

780 meshes: Sequence['DeviceMesh'], 

781 ) -> tuple['DeviceMesh', tuple[str, ...], tuple[int, ...]]: 

782 """Validate concatenate inputs and return the shared root metadata.""" 

783 if len(meshes) == 0: 

784 raise ValueError("DeviceMesh.concatenate expects at least one mesh.") 

785 if len(meshes) == 1: 

786 return meshes[0]._get_root_mesh(), tuple(meshes[0].mesh_dim_names or ()), meshes[0]._flatten_rank_map 

787 

788 root_mesh = meshes[0]._get_root_mesh() # pylint: disable=protected-access 

789 requested_dim_names: list[str] = [] 

790 flatten_rank_map = meshes[0]._flatten_rank_map # pylint: disable=protected-access 

791 for mesh in meshes: 

792 if mesh._get_root_mesh().to_hash() != root_mesh.to_hash(): # pylint: disable=protected-access 

793 raise ValueError("DeviceMesh.concatenate expects all meshes to share the same root mesh.") 

794 if mesh._flatten_rank_map != flatten_rank_map: # pylint: disable=protected-access 

795 raise ValueError("DeviceMesh.concatenate expects all meshes to share the same root mesh.") 

796 if not mesh.mesh_dim_names: 

797 raise ValueError("DeviceMesh.concatenate requires mesh_dim_names on every input mesh.") 

798 requested_dim_names.extend(mesh.mesh_dim_names) 

799 return root_mesh, tuple(requested_dim_names), flatten_rank_map 

800 

801 @staticmethod 

802 def _validate_concatenate_root_order(root_mesh: 'DeviceMesh', requested_dim_names: tuple[str, ...]) -> None: 

803 """Require original root dims to stay in root order when concatenating by name.""" 

804 root_dim_names = tuple(root_mesh.mesh_dim_names) if root_mesh.mesh_dim_names else () 

805 if not root_dim_names or not all(dim_name in root_dim_names for dim_name in requested_dim_names): 

806 return 

807 

808 requested_indices = [root_dim_names.index(dim_name) for dim_name in requested_dim_names] 

809 if requested_indices != sorted(requested_indices): 

810 raise ValueError( 

811 "DeviceMesh.concatenate expects meshes to follow the root mesh order. " 

812 f"Got root mesh dims {root_dim_names} and requested dims {requested_dim_names}." 

813 ) 

814 

    @staticmethod
    def _collect_concatenate_metadata(
        meshes: Sequence['DeviceMesh'],
    ) -> tuple[
        list[str],
        list[IntTuple],
        list[IntTuple],
        list[Optional[str]],
        list[BackendConfig],
        list[tuple['DeviceMesh', int]],
    ]:
        """Collect layout and process-group metadata from all concatenate inputs.

        Returns:
            A 6-tuple of lists gathered per input dim, in input order:
            combined dim names, per-dim layout sizes, per-dim layout strides,
            inherited process-group keys, inherited backend configs, and the
            (mesh, dim) pair each group originally came from.

        Raises:
            ValueError: if two input meshes share a dimension name.
        """
        concat_dim_names: list[str] = []
        concat_sizes: list[IntTuple] = []
        concat_strides: list[IntTuple] = []
        concat_dim_group_names: list[Optional[str]] = []
        concat_dim_group_backends: list[BackendConfig] = []
        concat_dim_group_sources: list[tuple['DeviceMesh', int]] = []

        for mesh in meshes:
            for dim, sub_layout in enumerate(mesh._layout):  # pylint: disable=protected-access
                concat_sizes.append(sub_layout.sizes)
                concat_strides.append(sub_layout.strides)
                # NOTE(review): group metadata is only gathered for meshes that
                # already initialized _dim_group_names; mixing initialized and
                # uninitialized input meshes would leave these lists shorter
                # than the size/stride lists — confirm callers never mix them.
                if hasattr(mesh, "_dim_group_names"):
                    concat_dim_group_names.append(mesh._dim_group_names[dim])  # pylint: disable=protected-access
                    concat_dim_group_backends.append(mesh._dim_group_backends[dim])  # pylint: disable=protected-access
                    # Keep the original (mesh, dim) source so deferred groups
                    # can later be materialized from the right owner.
                    if hasattr(mesh, "_dim_group_sources"):
                        concat_dim_group_sources.append(mesh._dim_group_sources[dim])  # pylint: disable=protected-access
                    else:
                        concat_dim_group_sources.append((mesh, dim))
            concat_dim_names.extend(mesh.mesh_dim_names)

        # Concatenated dims must be pairwise distinct by name.
        if len(set(concat_dim_names)) != len(concat_dim_names):
            raise ValueError(
                f"DeviceMesh.concatenate expects disjoint mesh dims, but got {tuple(concat_dim_names)}."
            )
        return (
            concat_dim_names,
            concat_sizes,
            concat_strides,
            concat_dim_group_names,
            concat_dim_group_backends,
            concat_dim_group_sources,
        )

859 

860 @staticmethod 

861 def _build_concatenate_layout(concat_sizes: list[IntTuple], concat_strides: list[IntTuple]) -> _MeshLayout: 

862 """Build the layout represented by concatenated top-level mesh axes.""" 

863 if len(concat_sizes) == 1: 

864 return _MeshLayout(concat_sizes[0], concat_strides[0]) 

865 return _MeshLayout(tuple(concat_sizes), tuple(concat_strides)) 

866 

867 @staticmethod 

868 def _set_concatenated_group_state( 

869 mesh: 'DeviceMesh', 

870 dim_group_names: list[Optional[str]], 

871 dim_group_backends: list[BackendConfig], 

872 dim_group_sources: list[tuple['DeviceMesh', int]], 

873 ) -> None: 

874 """Attach inherited process-group metadata to a concatenated mesh view.""" 

875 if dim_group_names: 

876 mesh._dim_group_names = dim_group_names # pylint: disable=W0212 

877 if dim_group_backends: 

878 mesh._dim_group_backends = tuple(dim_group_backends) # pylint: disable=W0212 

879 if dim_group_sources: 

880 mesh._dim_group_sources = tuple(dim_group_sources) # pylint: disable=W0212 

881 

    @staticmethod
    def concatenate(meshes: Sequence['DeviceMesh']) -> 'DeviceMesh':
        """Concatenate multiple sub-mesh views into one wider layout-backed mesh.

        All inputs must share the same root mesh, carry disjoint dim names,
        and (when named after root dims) follow the root mesh's dim order.

        Args:
            meshes: the sub-mesh views to concatenate, in order.

        Returns:
            DeviceMesh: a view combining all input dims, rooted at the shared
            root mesh; a single input is returned unchanged.

        Raises:
            ValueError: if inputs are empty, roots mismatch, dim names repeat,
                root order is violated, or the combined layout overlaps.
        """
        if len(meshes) == 1:
            return meshes[0]
        root_mesh, requested_dim_names, _ = DeviceMesh._validate_concatenate_inputs(meshes)
        DeviceMesh._validate_concatenate_root_order(root_mesh, requested_dim_names)
        (
            concat_dim_names,
            concat_sizes,
            concat_strides,
            concat_dim_group_names,
            concat_dim_group_backends,
            concat_dim_group_sources,
        ) = DeviceMesh._collect_concatenate_metadata(meshes)
        concat_layout = DeviceMesh._build_concatenate_layout(concat_sizes, concat_strides)
        # Two logical dims addressing the same device slot would be ambiguous.
        if not concat_layout.check_non_overlap():
            raise ValueError(f"Cannot concatenate overlapping meshes: {meshes}")

        # The result is a view: no backend init, rank map shared with input 0.
        res_mesh = DeviceMesh(
            meshes[0].device_type,
            mesh_dim_names=tuple(concat_dim_names),
            _init_backend=False,
            _layout=concat_layout,
            _rank_map=meshes[0]._rank_map,  # pylint: disable=protected-access
            _root_mesh=meshes[0]._get_root_mesh(),  # pylint: disable=protected-access
        )
        # Inherit the inputs' process groups instead of creating new ones.
        DeviceMesh._set_concatenated_group_state(
            res_mesh,
            concat_dim_group_names,
            concat_dim_group_backends,
            concat_dim_group_sources,
        )
        return res_mesh

916 

    # Private alias of `concatenate` for callers using the underscored name.
    _concatenate = concatenate

918 

    def _create_flatten_mesh(
        self,
        mesh_dim_name: Optional[str] = None,
        backend_override: BackendConfig = None,
    ) -> 'DeviceMesh':
        """Create or reuse a flattened one-dimensional mesh view.

        Args:
            mesh_dim_name: name for the flattened dim; defaults to the current
                dim names joined with "_".
            backend_override: backend config recorded for the flattened dim.

        Returns:
            DeviceMesh: a cached or newly built 1-D view registered on this
            mesh's root mesh.

        Raises:
            ValueError: if the name collides with a root dim name, or was used
                before with a different layout.
        """
        root_mesh = self._get_root_mesh()

        if mesh_dim_name is None:
            mesh_dim_name = "_".join(self.mesh_dim_names)

        # Already a 1-D mesh carrying the requested name: nothing to do.
        if self.ndim == 1 and mesh_dim_name in self.mesh_dim_names:  # pylint: disable=E1135
            return self

        invalid_dim_names = root_mesh.mesh_dim_names
        if mesh_dim_name in invalid_dim_names:
            raise ValueError(
                f"'{mesh_dim_name}' already exists in the root mesh mesh_dim_names "
                f"{invalid_dim_names}. Please specify another valid mesh_dim_name."
            )

        # Coalesce the layout; if more than one component survives, nest them
        # so the result is a single logical axis (see _MeshLayout).
        flattened_mesh_layout = self._layout.coalesce()
        if len(flattened_mesh_layout) > 1:
            flattened_mesh_layout = flattened_mesh_layout.nest()

        # Reuse a previously flattened view registered under the same name,
        # but only when its layout matches exactly.
        flatten_mapping = root_mesh.get_flatten_mapping()
        if mesh_dim_name in flatten_mapping:
            cached_mesh = flatten_mapping[mesh_dim_name]
            if cached_mesh._layout == flattened_mesh_layout:  # pylint: disable=protected-access
                return cached_mesh
            raise ValueError(
                f"Flatten mesh with mesh_dim_name '{mesh_dim_name}' has been created "
                f"before with different layout. Please specify another valid mesh_dim_name."
            )

        # Build the view without backend init; groups come below, lazily.
        res_flattened_mesh = DeviceMesh(
            device_type=root_mesh.device_type,
            mesh_dim_names=(mesh_dim_name,),
            _init_backend=False,
            _layout=flattened_mesh_layout,
            _rank_map=root_mesh._rank_map,
            _root_mesh=root_mesh,
        )
        res_flattened_mesh._dim_group_backends = (backend_override,)  # pylint: disable=W0212
        # Only create process groups eagerly when this mesh already has them.
        if hasattr(self, "_dim_group_names"):
            res_flattened_mesh._dim_group_names = DeviceMesh._init_process_groups_for_layout(  # pylint: disable=W0212
                res_flattened_mesh._layout,
                root_mesh._rank_map,
                res_flattened_mesh.mesh_dim_names,
                backend_override=(backend_override,),
            )

        # Register the new view on the root mesh for future reuse.
        root_mesh.add_flatten_mapping(mesh_dim_name, res_flattened_mesh)
        root_mesh._sub_mesh_cache[(mesh_dim_name,)] = res_flattened_mesh  # pylint: disable=W0212
        root_mesh.sub_mesh.append(res_flattened_mesh)

        return res_flattened_mesh

976 

    def _create_unflatten_mesh(
        self,
        dim: int,
        mesh_sizes: tuple[int, ...],
        mesh_dim_names: tuple[str, ...],
        backend_override: tuple[BackendConfig, ...],
    ) -> 'DeviceMesh':
        """Split one logical mesh axis into multiple named axes.

        Args:
            dim: positional index of the axis to split.
            mesh_sizes: sizes of the new axes; must multiply to the old size.
            mesh_dim_names: names of the new axes (one per entry of mesh_sizes).
            backend_override: per-new-axis backend configs.

        Returns:
            DeviceMesh: a new view rooted at this mesh's root mesh.

        Raises:
            ValueError: if prod(mesh_sizes) differs from the original axis size.
        """
        inner_layout = _MeshLayout(mesh_sizes, _contiguous_strides(mesh_sizes))
        original_layout = self._layout[dim]
        if inner_layout.numel() != original_layout.numel():
            raise ValueError(
                f"The product of mesh_sizes={mesh_sizes} is {inner_layout.numel()}, "
                f"but the original dimension at dim={dim} has size {original_layout.numel()}."
            )

        # Compose the split into the original axis, then splice the resulting
        # sub-layout back in place of the single axis.
        partial_layout = original_layout.composition(inner_layout)
        unflattened_layout = self._layout.splice(dim, dim + 1, partial_layout)
        unflattened_mesh_dim_names = list(self.mesh_dim_names or ())
        unflattened_mesh_dim_names[dim: dim + 1] = list(mesh_dim_names)

        root_mesh = self._get_root_mesh()
        res_mesh = DeviceMesh(
            self.device_type,
            mesh_dim_names=tuple(unflattened_mesh_dim_names),
            _init_backend=False,
            _layout=unflattened_layout,
            _rank_map=root_mesh._rank_map,
            _root_mesh=root_mesh,
        )

        # Inherit backend configs, replacing the split axis with the new ones.
        dim_group_backends = list(self._dim_group_backends)
        dim_group_backends[dim: dim + 1] = list(backend_override)
        res_mesh._dim_group_backends = tuple(dim_group_backends)  # pylint: disable=W0212

        # Build groups for the new axes only when this mesh already has groups.
        if hasattr(self, "_dim_group_names"):
            dim_group_names = list(self._dim_group_names)
            dim_group_names[dim: dim + 1] = DeviceMesh._init_process_groups_for_layout(
                partial_layout,
                root_mesh._rank_map,
                mesh_dim_names,
                backend_override=backend_override,
            )
            res_mesh._dim_group_names = dim_group_names  # pylint: disable=W0212

        return res_mesh

1023 

1024 def _flatten(self, mesh_dim_name: Optional[str] = None, backend_override: Any = None) -> 'DeviceMesh': 

1025 return self._create_flatten_mesh( 

1026 mesh_dim_name, 

1027 backend_override=_normalize_backend_value(backend_override), 

1028 ) 

1029 

1030 def _unflatten( 

1031 self, 

1032 dim: Union[int, str], 

1033 mesh_sizes: tuple[int, ...], 

1034 mesh_dim_names: tuple[str, ...], 

1035 backend_override: Optional[dict[Union[int, str], Any]] = None, 

1036 ) -> 'DeviceMesh': 

1037 """Torch-compatible helper that expands one mesh axis into a nested layout.""" 

1038 if isinstance(dim, int): 

1039 if dim < 0 or dim >= self.ndim: 

1040 raise ValueError(f"dim {dim} specified in `_unflatten` is out of range {self.ndim}") 

1041 else: 

1042 mesh_dim_names_tuple = self.mesh_dim_names or () 

1043 if dim not in mesh_dim_names_tuple: 

1044 raise ValueError(f"dim {dim} specified in `_unflatten` is not in {mesh_dim_names_tuple}") 

1045 dim = mesh_dim_names_tuple.index(dim) 

1046 

1047 if len(mesh_sizes) != len(mesh_dim_names): 

1048 raise RuntimeError("mesh_dim_names must have same length as mesh_sizes in _unflatten!") 

1049 

1050 backend_override_tuple = ( 

1051 _normalize_backend_override(backend_override, len(mesh_sizes), mesh_dim_names) 

1052 if backend_override is not None 

1053 else (None,) * len(mesh_dim_names) 

1054 ) 

1055 return self._create_unflatten_mesh(dim, mesh_sizes, mesh_dim_names, backend_override_tuple) 

1056 

1057 def assert_axis(self, axis, operate_name): 

1058 if not self.mesh_dim_names: 

1059 raise RuntimeError(f"mesh_dim_names not specified, {operate_name} is not supported.") 

1060 if axis not in self.mesh_dim_names: # pylint: disable=E1135 

1061 raise ValueError( 

1062 f"The axis name must be one of mesh dim name {self.mesh_dim_names}, but got {axis}" 

1063 ) 

1064 

1065 def axis_id(self, axis): 

1066 if axis == "None": 

1067 return -1 

1068 self.assert_axis(axis, "axis_id") 

1069 return self._dev_name_to_dev_id[axis] 

1070 

1071 def axis_index(self, axis): 

1072 self.assert_axis(axis, "axis_index") 

1073 return self._dev_name_to_index[axis] 

1074 

1075 def get_device_num_along_axis(self, axis): 

1076 self.assert_axis(axis, "get_device_num_along_axis") 

1077 return self.mesh_shape[self.mesh_dim_names.index(axis)] 

1078 

1079 def get_rank_list_along_axis(self, mesh_dim): 

1080 """Return the ranks that share every other coordinate with the current rank.""" 

1081 if mesh_dim in self._cache_rank_list_along_axis: 

1082 return self._cache_rank_list_along_axis[mesh_dim] 

1083 self.assert_axis(mesh_dim, "get_rank_list_along_axis") 

1084 

1085 mesh_shape = self.mesh_shape 

1086 mesh_dim_names = self.mesh_dim_names 

1087 rank_list = self.rank_list 

1088 rank = self.rank 

1089 

1090 if rank not in rank_list: 

1091 raise ValueError(f"Rank {rank} not found in rank_list") 

1092 

1093 idx = rank_list.index(rank) 

1094 coord = [0] * len(mesh_shape) 

1095 temp = idx 

1096 for i in range(len(mesh_shape) - 1, -1, -1): 

1097 coord[i] = temp % mesh_shape[i] 

1098 temp //= mesh_shape[i] 

1099 

1100 dim_index = mesh_dim_names.index(mesh_dim) 

1101 strides = [1] * len(mesh_shape) 

1102 for i in range(len(mesh_shape) - 2, -1, -1): 

1103 strides[i] = strides[i + 1] * mesh_shape[i + 1] 

1104 

1105 result_ranks = [] 

1106 for v in range(mesh_shape[dim_index]): 

1107 new_coord = coord.copy() 

1108 new_coord[dim_index] = v 

1109 new_idx = 0 

1110 for i in range(len(mesh_shape)): 

1111 new_idx += new_coord[i] * strides[i] 

1112 

1113 result_ranks.append(rank_list[new_idx]) 

1114 

1115 self._cache_rank_list_along_axis[mesh_dim] = result_ranks 

1116 return result_ranks 

1117 

1118 def get_global_shape(self, slice_shape, tensor_map): 

1119 """Infer the global tensor shape from a shard shape and tensor-map metadata.""" 

1120 map_key = hash((slice_shape, tensor_map)) 

1121 if map_key in self._global_shape_map: 

1122 return self._global_shape_map[map_key] 

1123 if tensor_map is None: 

1124 raise ValueError( 

1125 "tensor_map is not set. Please configure the tensor map by calling the layout." 

1126 ) 

1127 if len(slice_shape) != len(tensor_map): 

1128 raise ValueError( 

1129 f"Length of slice_shape ({len(slice_shape)}) must match " 

1130 f"the length of tensor_map ({len(tensor_map)})." 

1131 ) 

1132 

1133 n_dims = len(self._mesh_shape) 

1134 factors = [1] * len(slice_shape) 

1135 

1136 for dev_idx, size in enumerate(self._mesh_shape): 

1137 reverse_idx = n_dims - 1 - dev_idx 

1138 for axis_idx, mapping in enumerate(tensor_map): 

1139 if isinstance(mapping, int): 

1140 if mapping == -1: 

1141 continue 

1142 if mapping == reverse_idx: 

1143 factors[axis_idx] *= size 

1144 break 

1145 elif isinstance(mapping, tuple): 

1146 if reverse_idx in mapping: 

1147 factors[axis_idx] *= size 

1148 break 

1149 

1150 global_shape = [] 

1151 for i, dim in enumerate(slice_shape): 

1152 global_shape.append(dim * factors[i]) 

1153 self._global_shape_map[map_key] = tuple(global_shape) 

1154 return tuple(global_shape) 

1155 

    def _materialize_dim_group(self, mesh_dim: int) -> Optional[str]:
        """Create a deferred process group for one mesh dimension on first use.

        Args:
            mesh_dim: positional index of the mesh dimension.

        Returns:
            Optional[str]: the cache key of the (possibly newly created) group.
        """
        # First-time call on a lazily built mesh: allocate the key slots.
        if not hasattr(self, "_dim_group_names"):
            self._dim_group_names = [None] * self.ndim  # pylint: disable=W0201

        # If this dim was inherited (e.g. via concatenate), delegate group
        # creation to the original owning (mesh, dim) pair and mirror its key.
        if hasattr(self, "_dim_group_sources"):
            source_mesh, source_dim = self._dim_group_sources[mesh_dim]  # pylint: disable=W0212
            if source_mesh is not self or source_dim != mesh_dim:
                source_group_key = source_mesh._materialize_dim_group(source_dim)  # pylint: disable=W0212
                self._dim_group_names[mesh_dim] = source_group_key
                return source_group_key

        # Fast path: a group for this dim already exists in the global cache.
        group_key = self._dim_group_names[mesh_dim]
        if group_key is not None and group_key in EXISTING_COMM_GROUPS:
            return group_key

        # Slow path: split a new group out of the world and cache its key.
        split_ranks, group_key = DeviceMesh._build_dim_split_ranks(self._layout[mesh_dim], self._rank_map)
        group = platform.split_group(split_ranks=split_ranks)
        DeviceMesh._cache_group_if_needed(group_key, group)
        self._dim_group_names[mesh_dim] = group_key
        return group_key

1177 

1178 def get_comm_group_by_axis(self, mesh_dim: Union[str, int]): 

1179 """Return the cached or lazily materialized process group for one mesh axis.""" 

1180 if self.ndim == 1 and mesh_dim is None: 

1181 mesh_dim = 0 

1182 

1183 if isinstance(mesh_dim, str): 

1184 if self.mesh_dim_names is None or len(self.mesh_dim_names) == 0: 

1185 raise ValueError(f"DeviceMesh mesh_dim_names is not set, string mesh_dim {mesh_dim}, is not support.") 

1186 if mesh_dim not in self.mesh_dim_names: # pylint: disable=E1135 

1187 raise ValueError( 

1188 f"mesh_dim can pass a string or integer, but string mesh_dim '{mesh_dim}' not found in " 

1189 f"mesh_dim_names {self.mesh_dim_names}" 

1190 ) 

1191 mesh_dim = self.mesh_dim_names.index(mesh_dim) 

1192 else: 

1193 if not isinstance(mesh_dim, int) or mesh_dim < 0 or mesh_dim >= self.ndim: 

1194 raise ValueError( 

1195 f"mesh_dim can pass a string or integer, if not string, mesh_dim should be a integer in range " 

1196 f"[0, {self.ndim}), but got {mesh_dim}" 

1197 ) 

1198 

1199 if not hasattr(self, "_dim_group_names"): 

1200 raise RuntimeError("DeviceMesh process groups not initialized!") 

1201 

1202 group_key = self._dim_group_names[mesh_dim] 

1203 if group_key is None or group_key not in EXISTING_COMM_GROUPS: 

1204 group_key = self._materialize_dim_group(mesh_dim) 

1205 if group_key not in EXISTING_COMM_GROUPS: 

1206 raise ValueError(f"{group_key} not in group cache {EXISTING_COMM_GROUPS.keys()}") 

1207 return EXISTING_COMM_GROUPS[group_key] 

1208 

1209 def get_devices_for_axis(self, mesh_dim: Union[str, int], rank: int): 

1210 """List peer ranks that share all coordinates except the requested axis.""" 

1211 if isinstance(mesh_dim, str): 

1212 if not self.mesh_dim_names: 

1213 raise ValueError("_mesh_dim_names is not set, string mesh_dim is not supported, please pass a integer.") 

1214 mesh_dim_names = self.mesh_dim_names 

1215 if mesh_dim not in mesh_dim_names: # pylint: disable=E1135 

1216 raise ValueError(f"mesh_dim '{mesh_dim}' not found in mesh_dim_names {mesh_dim_names}") 

1217 mesh_dim = mesh_dim_names.index(mesh_dim) 

1218 

1219 mesh_shape = self._mesh_shape 

1220 if mesh_dim < 0 or mesh_dim >= self.ndim: 

1221 raise ValueError(f"mesh_dim {mesh_dim} can not out of range [0, {self.ndim})") 

1222 rank_list = self._rank_list 

1223 if rank not in rank_list: 

1224 raise ValueError(f"Rank {rank} not found in rank_list") 

1225 

1226 idx = rank_list.index(rank) 

1227 coord = [0] * len(mesh_shape) 

1228 temp = idx 

1229 for i in range(len(mesh_shape) - 1, -1, -1): 

1230 coord[i] = temp % mesh_shape[i] 

1231 temp //= mesh_shape[i] 

1232 

1233 strides = [1] * len(mesh_shape) 

1234 for i in range(len(mesh_shape) - 2, -1, -1): 

1235 strides[i] = strides[i + 1] * mesh_shape[i + 1] 

1236 

1237 result_ranks = [] 

1238 for v in range(mesh_shape[mesh_dim]): 

1239 new_coord = coord.copy() 

1240 new_coord[mesh_dim] = v 

1241 new_idx = 0 

1242 for i in range(len(mesh_shape)): 

1243 new_idx += new_coord[i] * strides[i] 

1244 

1245 result_ranks.append(rank_list[new_idx]) 

1246 

1247 return result_ranks 

1248 

1249 def to_hash(self): 

1250 map_key = (self.mesh_shape, self.mesh_dim_names, self.rank_list) 

1251 return map_key 

1252 

1253 def __repr__(self): 

1254 return ( 

1255 f"DeviceMesh(device_type='{self.device_type}', mesh_shape={self._mesh_shape}, " 

1256 f"mesh_dim_names={self.mesh_dim_names}, rank_list={self._rank_list})" 

1257 ) 

1258 

1259 def __str__(self): 

1260 return self.__repr__() 

1261 

1262 def __deepcopy__(self, memo): 

1263 cls = self.__class__ 

1264 result = cls.__new__(cls) 

1265 memo[id(self)] = result 

1266 for k, v in self.__dict__.items(): 

1267 if k in ("_root_mesh", "_dim_group_sources"): 

1268 setattr(result, k, v) 

1269 else: 

1270 setattr(result, k, copy.deepcopy(v, memo)) 

1271 return result 

1272 

1273 

# Process-wide cache of DeviceMesh instances, keyed by mesh topology
# (shape, dim names, rank list); populated and read by _create_device_mesh.
_DEVICE_MESH_MAP = {}

1275 

1276 

def _create_device_mesh(device_type: str,
                        mesh_shape: tuple[int, ...],
                        *,
                        mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
                        rank_list: tuple[int, ...],
                        init_backend: bool = True, ):
    """Create or reuse a cached DeviceMesh with the requested topology.

    Args:
        device_type: backend device type string.
        mesh_shape: per-dimension device counts; ranks are reshaped to this.
        mesh_dim_names: optional names for each mesh dimension.
        rank_list: flat global ranks filling the mesh in row-major order.
        init_backend: forwarded to DeviceMesh to control process-group setup.

    Returns:
        The cached or newly constructed DeviceMesh.
    """
    mesh = np.array(rank_list).reshape(mesh_shape)
    mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
    # Key the cache on the topology itself. The previous code used
    # hash((...)) as the dict key, which silently aliases distinct topologies
    # whose hashes collide; a tuple key lets the dict resolve collisions by
    # equality while imposing the same hashability requirements.
    map_key = (tuple(mesh_shape), mesh_dim_names, tuple(rank_list))
    if map_key not in _DEVICE_MESH_MAP:
        _DEVICE_MESH_MAP[map_key] = DeviceMesh(device_type, mesh,
                                               mesh_dim_names=mesh_dim_names,
                                               _init_backend=init_backend)
    # The key is guaranteed present here, so index directly instead of
    # returning a misleading `.get(..., None)` fallback.
    return _DEVICE_MESH_MAP[map_key]

1292 

1293 

def init_device_mesh(
    device_type: str,
    mesh_shape: tuple[int, ...],
    *,
    mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
    rank_list: Optional[tuple[int, ...]] = None,
    init_backend: bool = True,
) -> DeviceMesh:
    """Initialize a cached DeviceMesh from the provided shape, names, and ranks.

    Args:
        device_type: backend device type string.
        mesh_shape: tuple of positive ints, one per mesh dimension.
        mesh_dim_names: optional unique, non-empty names, one per dimension.
        rank_list: explicit flat rank assignment; derived from the current
            rank when omitted.
        init_backend: when True and rank_list is omitted, initialize the
            default process group before querying the current rank.

    Returns:
        DeviceMesh: the cached or newly constructed mesh.

    Raises:
        TypeError: if mesh_shape / mesh_dim_names have the wrong container type.
        ValueError: if sizes, names, or rank_list are inconsistent.
        RuntimeError: if the current rank cannot be determined automatically.
    """
    # Validate the inputs *before* deriving ranks or touching the process
    # group. Previously a malformed mesh_shape was only rejected after
    # np.prod had computed a bogus device count and — when rank_list was
    # omitted — after platform.init_process_group() had already run.
    if not isinstance(mesh_shape, tuple):
        raise TypeError(f'mesh_shape must be a tuple, but got {type(mesh_shape)}')

    for size in mesh_shape:
        if not isinstance(size, int) or size <= 0:
            raise ValueError(
                f"Each element of mesh_shape must be a positive integer, but got {mesh_shape}"
            )

    if mesh_dim_names is not None:
        if not isinstance(mesh_dim_names, (tuple, list)):
            raise TypeError(
                f'mesh_dim_names must be a tuple or list, but got {type(mesh_dim_names)}'
            )
        mesh_dim_names = tuple(mesh_dim_names)
        if len(mesh_shape) != len(mesh_dim_names):
            raise ValueError(
                f'mesh_shape ({len(mesh_shape)}) and mesh_dim_names '
                f'({len(mesh_dim_names)}) should have same length'
            )
        if len(set(mesh_dim_names)) != len(mesh_dim_names):
            raise ValueError(f'Each element of mesh_dim_names {mesh_dim_names} should be different')
        if any(not isinstance(name, str) or name == "" for name in mesh_dim_names):
            raise ValueError(f'Each element of mesh_dim_names {mesh_dim_names} should be a non-empty string')

    total_devices = int(np.prod(np.array(mesh_shape)))
    if rank_list is not None:
        if len(rank_list) != total_devices:
            raise ValueError(
                f"rank_list length ({len(rank_list)}) must equal mesh size ({total_devices})"
            )
    else:
        if init_backend:
            platform.init_process_group()
        try:
            current_rank = platform.get_rank()
        except Exception as exc:
            raise RuntimeError(
                "init_device_mesh: failed to get current rank for automatic rank_list generation. "
                "Either pass rank_list explicitly, or ensure the process group is initialized before calling "
                "init_device_mesh (or set init_backend=True to let init_device_mesh initialize it)."
            ) from exc
        # Snap to the start of this rank's mesh-sized contiguous block.
        base = current_rank - (current_rank % total_devices)
        rank_list = tuple(range(base, base + total_devices))

    return _create_device_mesh(
        device_type,
        mesh_shape,
        mesh_dim_names=mesh_dim_names,
        rank_list=rank_list,
        init_backend=init_backend,
    )