Coverage for hyper_parallel / core / device_mesh.py: 76%

442 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""device mesh""" 

16 

17import os 

18from typing import Optional, Union, List, Any 

19import numpy as np 

20from hyper_parallel.platform import get_platform 

21 

# Platform abstraction resolved once at import time (backend-specific ops live here).
platform = get_platform()
# Tensor type re-exported from the platform, used to hold the mesh layout.
Tensor = platform.Tensor

# Module-level registry mapping a group name to its process-group handle.
_group_map = {}

26 

27 

28def _get_sub_rank_list(mesh_shape, mesh_dim_names, rank_list, sub_mesh_dim_names, current_rank): 

29 """ 

30 Get the sub rank list for a sub mesh. 

31 

32 Args: 

33 mesh_shape (tuple[int]): The shape of the original mesh. 

34 mesh_dim_names (tuple[str]): The mesh dim names of the original mesh dimensions. 

35 rank_list (tuple[int]): A tuple of ranks that participate in this mesh. 

36 sub_mesh_dim_names (tuple[str]): The mesh dim names of the sub mesh to extract. 

37 current_rank (int): The current process rank. 

38 

39 Returns: 

40 list: The sub rank list for the sub mesh. 

41 """ 

42 # Reshape rank list into mesh tensor according to mesh shape 

43 mesh_tensor = np.array(rank_list).reshape(mesh_shape) 

44 

45 # Iterate through each dimension of the original mesh 

46 for dim_index, dim_name in enumerate(mesh_dim_names): 

47 

48 # Skip dimensions that are included in the sub mesh 

49 if dim_name in sub_mesh_dim_names: 

50 continue 

51 

52 # Split mesh tensor along current dimension 

53 dim_size = mesh_shape[dim_index] 

54 sliced_tensors = np.split(mesh_tensor, dim_size, axis=dim_index) 

55 

56 # Find and keep only the slice containing the current rank 

57 for sliced_tensor in sliced_tensors: 

58 rank_exists = np.isin(np.array([current_rank]), sliced_tensor).any() 

59 if rank_exists: 

60 mesh_tensor = sliced_tensor 

61 break 

62 

63 # Flatten the resulting tensor to get the sub rank list 

64 sub_rank_list = mesh_tensor.reshape(-1).tolist() 

65 return sub_rank_list 

66 

67 

68class DeviceMesh: 

69 """ 

70 Topological abstraction describing cluster devices. 

71 

72 Args: 

73 device_type (str): Device type. 

74 mesh (Union[Tensor, list, tuple, np.ndarray]): A multi-dimensional array, list, or integer 

75 tensor describing the device layout. The IDs in the mesh are global IDs of the 

76 default process group, representing the multi-dimensional networking structure 

77 of devices in distributed training (e.g., [[0,1],[2,3]] represents a 2x2 device mesh). 

78 If a list or non-int32 tensor is provided, it will be automatically converted 

79 to an int32 tensor. 

80 mesh_dim_names (tuple[str]): A tuple[str] of mesh dim names for each dimension of mesh. 

81 _init_backend (boolean): Whether initial process group. 

82 

83 Attributes: 

84 ndim (int): Number of dimensions in the mesh. 

85 mesh_shape (tuple[int]): Shape of the device mesh. 

86 rank_list (tuple[int]): Flattened list of ranks from the mesh. 

87 root_mesh (DeviceMesh): The parent mesh if this is a sub mesh, None otherwise. 

88 sub_mesh (list[DeviceMesh]): List of child meshes created from this mesh. 

89 

90 Examples: 

91 >>> # Using Tensor 

92 >>> mesh = Tensor([[0, 1], [2, 3]]) 

93 >>> device_mesh = DeviceMesh("npu", mesh, mesh_dim_names=("dp", "tp")) 

94 >>> # Using list 

95 >>> device_mesh = DeviceMesh("npu", [[0, 1], [2, 3]], mesh_dim_names=("dp", "tp")) 

96 >>> # Get sub mesh 

97 >>> dp_mesh = device_mesh["dp"] 

98 >>> # Access ndim 

99 >>> print(device_mesh.ndim) # Output: 2 

100 >>> print(device_mesh.mesh_shape) # Output: (2, 2) 

101 >>> print(device_mesh.rank_list) # Output: (0, 1, 2, 3) 

102 """ 

103 

    def __init__(self,
                 device_type: str,
                 mesh: Union[Tensor, list, tuple, np.ndarray],
                 *,
                 mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
                 _init_backend: bool = True,
                 ):
        """
        Build the mesh: normalize the mesh tensor, validate dimension names,
        and (when ``_init_backend`` is True) create the per-dimension
        communication groups through the platform backend.
        """
        self._device_type = device_type
        # Convert mesh to Tensor with int32 dtype
        mesh = self._convert_mesh_to_tensor(mesh)

        # Validate mesh dimensions
        if mesh.ndim == 0:
            raise ValueError("mesh must be at least 1-dimensional")

        # Extract mesh_shape and rank_list from mesh
        self._mesh_shape = tuple(mesh.shape)
        self._rank_list = tuple(platform.tensor_to_numpy(mesh).flatten().tolist())
        self._mesh = mesh
        # Total device count and number of mesh dimensions.
        self._dev_num = np.prod(np.array(self._mesh_shape))
        self._dev_rank = len(self._mesh_shape)
        # mesh_dim_names: normalized to a tuple, or None when unnamed.
        self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
        if self._mesh_dim_names is not None:
            # Validate mesh_dim_names: one name per dimension, all unique.
            if len(self._mesh_shape) != len(mesh_dim_names):
                raise ValueError(
                    f'mesh dimensions ({len(self._mesh_shape)}) should be equal to '
                    f'mesh_dim_names length ({len(mesh_dim_names)})'
                )
            if len(set(mesh_dim_names)) != len(mesh_dim_names):
                raise ValueError(f'Each element of mesh_dim_names {mesh_dim_names} should be different')
            # "interleaved_parallel" denotes virtual sharding and must be the
            # innermost (last) dimension if present.
            inter_key = "interleaved_parallel"
            if inter_key in mesh_dim_names and mesh_dim_names.index(inter_key) != len(mesh_dim_names) - 1:
                raise ValueError(
                    "'interleaved_parallel' should be at the last dim of mesh_dim_names, means virtual sharding."
                )
            # name -> reversed index (id counted from the innermost axis),
            # and name -> positional index.
            self._dev_name_to_dev_id = {
                name: self._dev_rank - i - 1 for i, name in enumerate(self._mesh_dim_names)
            }
            self._dev_name_to_index = {name: i for i, name in enumerate(self._mesh_dim_names)}

        self._rank = platform.get_rank()
        # Per-instance caches for derived lookups (see the corresponding getters).
        self._cache_rank_list_along_axis = {}
        self._global_shape_map = {}
        self._sub_mesh_cache = {}
        self._flatten_mapping: dict[str, 'DeviceMesh'] = {}
        self._ndim: int = len(self._mesh_shape)
        # Parent/child links for sub-mesh and flatten bookkeeping.
        self._root_mesh: Optional['DeviceMesh'] = None
        self._sub_mesh: List['DeviceMesh'] = []
        if _init_backend:
            # Creates real communication groups; when False (sub meshes,
            # from_group), _dim_group_names is filled in by the caller.
            platform.init_process_group()
            self._dim_group_names = self._init_process_groups(self._mesh_shape, self._mesh_dim_names, self._rank_list)
        # Coordinate lookup is skipped under simulation mode (MS_SIMULATION_LEVEL set).
        if os.getenv("MS_SIMULATION_LEVEL") is None:
            self._coordinate_on_dim = self._compute_coordinate_on_dim()

159 

160 def _compute_coordinate_on_dim(self): 

161 # calculate the coordinates of the current global rank on the mesh 

162 return self._compute_coordinates_from_mesh(self.mesh, self._rank) 

163 

164 @staticmethod 

165 def _compute_coordinates_from_mesh( 

166 mesh_tensor: Tensor, 

167 rank: int, 

168 ): 

169 """ 

170 Compute the coordinates of a rank within a mesh tensor. 

171 

172 Args: 

173 mesh_tensor (Tensor): The mesh tensor to search in 

174 rank (int): The rank to find coordinates for 

175 

176 Returns: 

177 A tuple of coordinates if the rank is found in the mesh, None otherwise 

178 

179 Raises: 

180 AssertionError: If the rank appears more than once in the mesh 

181 """ 

182 rank_coords = (mesh_tensor == rank).nonzero() 

183 if rank_coords.shape[0] not in (0, 1): 

184 raise AssertionError( 

185 f"rank_coords.shape[0] must be 0 or 1, got {rank_coords.shape[0]}" 

186 ) 

187 

188 if rank_coords.shape[0] == 0: 

189 return None 

190 

191 coords = rank_coords[0].tolist() 

192 return tuple(coords) 

193 

194 def size(self, mesh_dim=None) -> int: 

195 if mesh_dim is not None: 

196 return self.mesh.shape[mesh_dim] 

197 return self.mesh.numel() 

198 

199 def get_coordinate(self): 

200 """ 

201 Return the relative indices of this rank relative to all 

202 dimensions of the mesh. If this rank is not part of the mesh, return None. 

203 """ 

204 return self._coordinate_on_dim if self._coordinate_on_dim else None 

205 

206 @staticmethod 

207 def _convert_mesh_to_tensor(mesh: Union[Tensor, list, tuple, np.ndarray]) -> Tensor: 

208 """Convert mesh to Tensor with int32 dtype.""" 

209 if isinstance(mesh, Tensor): 

210 mesh = platform.tensor_to_numpy(mesh) 

211 elif isinstance(mesh, (list, tuple)): 

212 mesh = np.array(mesh) 

213 elif not isinstance(mesh, np.ndarray): 

214 raise TypeError( 

215 f"mesh must be Tensor, list, tuple or numpy array, but got {type(mesh)}" 

216 ) 

217 

218 mesh = mesh.astype(np.int32) 

219 return Tensor(mesh).int() 

220 

221 @staticmethod 

222 def _init_one_process_group(mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...], 

223 dim_name: str, rank_list: tuple[int, ...]) -> str: 

224 """ 

225 init one process group 

226 """ 

227 group_name = None 

228 group_desc = f"mesh_{dim_name}" 

229 split_ranks = set() 

230 if not isinstance(dim_name, tuple): 

231 dim_name = (dim_name,) 

232 for rank in rank_list: 

233 split_rank = _get_sub_rank_list(mesh_shape, mesh_dim_names, rank_list, dim_name, rank) 

234 split_ranks.add(tuple(sorted(split_rank))) 

235 split_ranks = sorted([list(item) for item in split_ranks]) 

236 group = platform.split_group(split_ranks=split_ranks, group_desc=group_desc) 

237 if group: 

238 if isinstance(group, str): 

239 group_name = group 

240 else: 

241 group_name = group.group_name 

242 _group_map[group_name] = group 

243 return group_name 

244 

245 @staticmethod 

246 def _init_process_groups(mesh_shape: tuple[int, ...], mesh_dim_names: Union[tuple[str, ...], None], 

247 rank_list: tuple[int, ...]) -> list: 

248 """ 

249 Init process groups. For every dim in mesh_shape, create split group for current rank. 

250 

251 Args: 

252 mesh_shape (tuple[int, ...]): Shape of mesh. 

253 mesh_dim_names (tuple[str, ...]): Names of every dimension of mesh. 

254 rank_list (tuple[int, ...]): Rank list of current process group worked on. 

255 """ 

256 if mesh_dim_names is None: 

257 mesh_dim_names = [] 

258 for dim in range(len(mesh_shape)): 

259 mesh_dim_names.append(f"dim_{dim}") 

260 mesh_dim_names = tuple(mesh_dim_names) 

261 

262 dim_group_names = [] 

263 for dim in range(len(mesh_shape)): 

264 dim_name = mesh_dim_names[dim] 

265 dim_group_name = DeviceMesh._init_one_process_group(mesh_shape, mesh_dim_names, dim_name, rank_list) 

266 dim_group_names.append(dim_group_name) 

267 

268 # Filter out None values. If any are None then they should all be None. 

269 dim_non_none_group_names = [n for n in dim_group_names if n is not None] 

270 assert not dim_non_none_group_names or len(dim_non_none_group_names) == len(dim_group_names) 

271 return dim_non_none_group_names 

272 

    @property
    def mesh(self) -> Tensor:
        """Get the mesh tensor."""
        return self._mesh

    def device_type(self) -> str:
        """Get the device type."""
        # NOTE(review): unlike the surrounding accessors this is a plain method,
        # so callers must write device_mesh.device_type(). Confirm call sites
        # before converting it to a property.
        return self._device_type

    @property
    def rank(self):
        """Global rank of the current process (set from platform.get_rank())."""
        return self._rank

    @property
    def mesh_shape(self):
        """Shape of the device mesh as a tuple of ints."""
        return self._mesh_shape

    @property
    def mesh_dim_names(self):
        """Per-dimension names of the mesh, or None if unnamed."""
        return self._mesh_dim_names

    @property
    def rank_list(self):
        """Flattened tuple of all global ranks participating in this mesh."""
        return self._rank_list

    @property
    def ndim(self) -> int:
        """Number of mesh dimensions."""
        return self._ndim

    @property
    def root_mesh(self) -> Optional['DeviceMesh']:
        """The parent mesh if this is a sub mesh, None otherwise."""
        return self._root_mesh

    @root_mesh.setter
    def root_mesh(self, value: Optional['DeviceMesh']):
        """Set the parent mesh reference."""
        self._root_mesh = value

    @property
    def sub_mesh(self) -> List['DeviceMesh']:
        """Child meshes created from this mesh (sliced and flattened)."""
        return self._sub_mesh

    def get_flatten_mapping(self) -> dict:
        """Get the flatten mapping dictionary."""
        return self._flatten_mapping

    def add_flatten_mapping(self, name: str, mesh: 'DeviceMesh') -> None:
        """Add a flattened mesh to the flatten mapping."""
        self._flatten_mapping[name] = mesh

322 

    def __getitem__(self, sub_mesh_dim_names: Union[str, tuple[str, ...]]) -> 'DeviceMesh':
        """
        Get a sub DeviceMesh based on the specified dimension names.

        This method supports both original dimension names and flattened dimension names.
        For example, if a mesh has dimensions ("dp", "cp", "tp") and a flattened mesh
        "dp_cp" was created via flatten(), both mesh["dp"] and mesh["dp_cp"] are valid.

        Args:
            sub_mesh_dim_names: A string or tuple of strings specifying the dimension names
                for the sub mesh. Can be original dimension names or flattened
                dimension names registered in the root mesh's flatten_mapping.

        Returns:
            DeviceMesh: A new DeviceMesh representing the sub mesh.

        Raises:
            ValueError: If sub_mesh_dim_names is invalid or not a contiguous prefix.
            KeyError: If sub_mesh_dim_names contains names not in mesh_dim_names or flatten_mapping.

        Examples:
            >>> mesh = platform.tensor([[0, 1], [2, 3]])
            >>> device_mesh = DeviceMesh("npu", mesh, mesh_dim_names=("dp", "tp"))
            >>> dp_mesh = device_mesh["dp"]
            >>> print(dp_mesh.mesh_shape)  # Output: (2,)
            >>> print(dp_mesh.mesh_dim_names)  # Output: ("dp",)
            >>> # After creating a flattened mesh:
            >>> flat_mesh = device_mesh.flatten()
            >>> # Can also access via flattened name:
            >>> same_flat_mesh = device_mesh["dp_tp"]
        """
        if not self._mesh_dim_names:
            raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!")

        sub_mesh_dim_names = self._normalize_sub_mesh_dim_names(sub_mesh_dim_names)
        # Flattened names are registered on the root mesh, so resolve them there.
        flatten_mapping = self._get_root_mesh().get_flatten_mapping()

        # Try to get from flatten_mapping first
        flattened_result = self._try_get_from_flatten_mapping(sub_mesh_dim_names, flatten_mapping)
        if flattened_result is not None:
            return flattened_result

        # Validate dimension names
        self._validate_getitem_dimensions(sub_mesh_dim_names, flatten_mapping)

        # Get or create sub mesh for original dimensions
        return self._get_or_create_original_sub_mesh(sub_mesh_dim_names)

370 

371 def _normalize_sub_mesh_dim_names(self, sub_mesh_dim_names: Union[str, tuple[str, ...]]) -> tuple[str, ...]: 

372 """Convert sub_mesh_dim_names to tuple format and validate basic type.""" 

373 if isinstance(sub_mesh_dim_names, str): 

374 sub_mesh_dim_names = (sub_mesh_dim_names,) 

375 

376 if not isinstance(sub_mesh_dim_names, tuple): 

377 raise TypeError( 

378 f"sub_mesh_dim_names must be str or tuple, but got {type(sub_mesh_dim_names)}" 

379 ) 

380 

381 if len(sub_mesh_dim_names) == 0: 

382 raise ValueError("sub_mesh_dim_names cannot be empty") 

383 

384 return sub_mesh_dim_names 

385 

386 def _try_get_from_flatten_mapping(self, sub_mesh_dim_names: tuple[str, ...], 

387 flatten_mapping: dict) -> Optional['DeviceMesh']: 

388 """Try to get mesh from flatten_mapping. Returns None if not applicable.""" 

389 if len(sub_mesh_dim_names) == 1 and sub_mesh_dim_names[0] in flatten_mapping: 

390 return flatten_mapping[sub_mesh_dim_names[0]] 

391 return None 

392 

393 def _validate_getitem_dimensions(self, sub_mesh_dim_names: tuple[str, ...], flatten_mapping: dict): 

394 """Validate dimension names for __getitem__ operation.""" 

395 valid_dim_names = list(self._mesh_dim_names) + list(flatten_mapping.keys()) 

396 

397 # Validate all names exist 

398 for name in sub_mesh_dim_names: 

399 if name not in valid_dim_names: 

400 raise KeyError( 

401 f"Dimension name '{name}' not found in mesh_dim_names {self._mesh_dim_names} " 

402 f"or flatten_mapping keys {list(flatten_mapping.keys())}" 

403 ) 

404 

405 # Check for mixed or multiple flattened dimensions 

406 original_dims = [name for name in sub_mesh_dim_names if name in self._mesh_dim_names] # pylint: disable=E1135 

407 flattened_dims = [name for name in sub_mesh_dim_names if name in flatten_mapping] 

408 

409 if len(flattened_dims) == len(sub_mesh_dim_names) and len(flattened_dims) > 1: 

410 raise ValueError( 

411 f"Slicing multiple flattened dimensions {flattened_dims} simultaneously " 

412 f"is not supported. Please slice them separately." 

413 ) 

414 

415 if flattened_dims and original_dims: 

416 raise ValueError( 

417 f"Cannot mix original dimensions {original_dims} with flattened dimensions " 

418 f"{flattened_dims} in a single slice operation." 

419 ) 

420 

421 def _get_or_create_original_sub_mesh(self, sub_mesh_dim_names: tuple[str, ...]) -> 'DeviceMesh': 

422 """Get or create sub mesh for original (non-flattened) dimensions.""" 

423 # Validate dimension order 

424 indices = [self._mesh_dim_names.index(name) for name in sub_mesh_dim_names] 

425 if indices != sorted(indices): 

426 raise ValueError( 

427 f"sub_mesh_dim_names {sub_mesh_dim_names} must follow the order of " 

428 f"original mesh_dim_names {self._mesh_dim_names}" 

429 ) 

430 

431 # Check cache 

432 if sub_mesh_dim_names in self._sub_mesh_cache: 

433 return self._sub_mesh_cache[sub_mesh_dim_names] 

434 

435 # Return self if requesting all dimensions 

436 if len(sub_mesh_dim_names) == len(self._mesh_dim_names): 

437 return self 

438 

439 # Create new sub mesh 

440 return self._create_and_cache_sub_mesh(sub_mesh_dim_names, indices) 

441 

442 def _create_and_cache_sub_mesh(self, sub_mesh_dim_names: tuple[str, ...], 

443 indices: List[int]) -> 'DeviceMesh': 

444 """Create a new sub mesh and cache it.""" 

445 sub_mesh_shape = tuple(self._mesh_shape[i] for i in indices) 

446 

447 sub_rank_list = _get_sub_rank_list( 

448 self._mesh_shape, 

449 self._mesh_dim_names, 

450 self._rank_list, 

451 sub_mesh_dim_names, 

452 self._rank 

453 ) 

454 sub_rank_list = tuple(sub_rank_list) 

455 

456 # Create sub mesh tensor using Tensor() 

457 sub_mesh_tensor = Tensor(sub_rank_list).reshape(sub_mesh_shape) 

458 

459 # Create sub mesh 

460 sub_mesh = DeviceMesh( 

461 device_type="npu", 

462 mesh=sub_mesh_tensor, 

463 mesh_dim_names=sub_mesh_dim_names, 

464 _init_backend=False 

465 ) 

466 # Set root mesh reference 

467 sub_mesh.root_mesh = self._get_root_mesh() 

468 

469 slice_dim_group_name = [] 

470 for name in sub_mesh_dim_names: 

471 # pylint: disable=E1135 

472 if name in self._mesh_dim_names: 

473 slice_dim_group_name.append( 

474 self._dim_group_names[self._mesh_dim_names.index(name)] 

475 ) 

476 sub_mesh._dim_group_names = slice_dim_group_name # pylint: disable=W0212 

477 

478 # Cache and track 

479 self._sub_mesh_cache[sub_mesh_dim_names] = sub_mesh 

480 # Add to sub_mesh list 

481 self.sub_mesh.append(sub_mesh) 

482 

483 return sub_mesh 

484 

    def get_group(self, mesh_dim: Optional[Union[int, str]] = None):
        """
        Get the communication group for a specific mesh dimension.

        Args:
            mesh_dim: The dimension index or name. If None and mesh is 1D,
                returns the only group. If None and mesh is multi-dimensional,
                raises an error.

        Returns:
            The process group for the specified dimension.

        Raises:
            RuntimeError: If process groups were never initialized, or if
                mesh_dim is None and mesh has more than 1 dimension.
            ValueError: If mesh_dim is invalid.

        Examples:
            >>> mesh = Tensor([[0, 1], [2, 3]])
            >>> device_mesh = DeviceMesh("npu", mesh, mesh_dim_names=("dp", "tp"))
            >>> dp_group = device_mesh.get_group("dp")
            >>> # or by index
            >>> dp_group = device_mesh.get_group(0)
        """
        # _dim_group_names exists only when groups were created (or injected
        # by from_group / sub-mesh creation).
        if not hasattr(self, "_dim_group_names"):
            raise RuntimeError("DeviceMesh process groups not initialized!")

        if self.ndim > 1 and mesh_dim is None:
            raise RuntimeError(
                f"Found the DeviceMesh have {self.ndim} dimensions. "
                "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1."
            )

        # Check if mesh_dim is a flattened dimension name in root mesh's flatten_mapping
        root_mesh = self._get_root_mesh()
        if isinstance(mesh_dim, str) and mesh_dim in root_mesh.get_flatten_mapping():
            # Return the group from the flattened mesh
            flattened_mesh = root_mesh.get_flatten_mapping()[mesh_dim]
            return flattened_mesh.get_comm_group_by_axis(mesh_dim)

        return self.get_comm_group_by_axis(mesh_dim)

525 

526 @staticmethod 

527 def from_group(group: Union[Any, list[Any]], 

528 device_type: str, 

529 mesh: Union[Tensor, list, tuple, np.ndarray] = None, 

530 mesh_dim_names: Union[tuple[str, ...], list[str]] = None 

531 ) -> 'DeviceMesh': 

532 """ 

533 Create device mesh from group or group list. 

534 

535 Args: 

536 group: The group or group list to create device mesh from. 

537 device_type: Device type. 

538 mesh: 

539 For 1d group, mesh can pass None. If group is 1d and mesh is not None, the mesh must equal to 

540 group_ranks get from group, or must be a tensor which tolist value equal to group_ranks. 

541 For nd group, mesh must be passed. 

542 mesh_dim_names: Names of every mesh dimension. 

543 """ 

544 if not isinstance(group, list): 

545 group_ranks = platform.get_process_group_ranks(group) 

546 if ( 

547 isinstance(mesh, Tensor) and mesh.tolist() != group_ranks 

548 ) or ( 

549 mesh is not None 

550 and not isinstance(mesh, Tensor) 

551 and mesh != group_ranks 

552 ): 

553 raise ValueError( 

554 f"Invalid mesh_shape {str(mesh)} for 1D group with ranks {group_ranks}" 

555 ) 

556 device_mesh = DeviceMesh(device_type, group_ranks, mesh_dim_names=mesh_dim_names, _init_backend=False) 

557 if isinstance(group, str): 

558 # pylint: disable=W0212 

559 device_mesh._dim_group_names = [group] 

560 _group_map[group] = group 

561 else: 

562 device_mesh._dim_group_names = [group.group_name] # pylint: disable=W0212 

563 _group_map[group.group_name] = group 

564 return device_mesh 

565 

566 groups = list(group) 

567 if len(groups) == 0: 

568 raise ValueError("Expect at least one group be specified.") 

569 if mesh is None: 

570 raise ValueError("mesh_shape is must specified when group is a list.") 

571 mesh = DeviceMesh._convert_mesh_to_tensor(mesh) 

572 if mesh.ndim != len(groups): 

573 raise ValueError("mesh dimensions must match group dimensions.") 

574 device_mesh = DeviceMesh(device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False) 

575 # pylint: disable=W0212 

576 device_mesh._dim_group_names = [] 

577 for dim_group in groups: 

578 if isinstance(dim_group, str): 

579 # pylint: disable=W0212 

580 device_mesh._dim_group_names.append(dim_group) 

581 _group_map[dim_group] = dim_group 

582 else: 

583 # pylint: disable=W0212 

584 device_mesh._dim_group_names.append(dim_group.group_name) 

585 _group_map[dim_group.group_name] = dim_group 

586 return device_mesh 

587 

588 def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: 

589 """ 

590 Get the local rank within a specific mesh dimension. 

591 

592 Args: 

593 mesh_dim: The dimension index or name. If None and mesh is 1D, 

594 uses dimension 0. If None and mesh is multi-dimensional, 

595 raises an error. 

596 

597 Returns: 

598 int: The local rank within the specified dimension. 

599 

600 Raises: 

601 RuntimeError: If mesh_dim is None and mesh has more than 1 dimension. 

602 ValueError: If mesh_dim is invalid or current rank not in rank_list. 

603 

604 Examples: 

605 >>> mesh = Tensor([[0, 1, 2, 3], [4, 5, 6, 7]]) 

606 >>> device_mesh = DeviceMesh("npu", mesh, nesh_dim_names=("dp", "tp")) 

607 >>> # On rank 0 

608 >>> print(device_mesh.get_local_rank("dp")) # Output: 0 

609 >>> print(device_mesh.get_local_rank("tp")) # Output: 0 

610 """ 

611 if self.ndim > 1 and mesh_dim is None: 

612 raise RuntimeError( 

613 f"Found the DeviceMesh have {self.ndim} dimensions. " 

614 "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1." 

615 ) 

616 

617 if mesh_dim is None: 

618 mesh_dim = 0 

619 

620 # Convert string to index 

621 if isinstance(mesh_dim, str): 

622 # pylint: disable=E1135 

623 if mesh_dim not in self._mesh_dim_names: 

624 raise ValueError( 

625 f"mesh_dim '{mesh_dim}' not found in mesh_dim_names {self._mesh_dim_names}" 

626 ) 

627 dim_index = self._mesh_dim_names.index(mesh_dim) 

628 else: 

629 if not isinstance(mesh_dim, int) or mesh_dim < 0 or mesh_dim >= self.ndim: 

630 raise ValueError( 

631 f"mesh_dim must be an integer in range [0, {self.ndim}), " 

632 f"but got {mesh_dim}" 

633 ) 

634 dim_index = mesh_dim 

635 

636 if self._rank not in self._rank_list: 

637 raise ValueError( 

638 f"Current rank {self._rank} not found in rank_list {self._rank_list}" 

639 ) 

640 

641 # Calculate the coordinate of current rank in the mesh 

642 idx = self._rank_list.index(self._rank) 

643 coord = [0] * len(self._mesh_shape) 

644 temp = idx 

645 for i in range(len(self._mesh_shape) - 1, -1, -1): 

646 coord[i] = temp % self._mesh_shape[i] 

647 temp //= self._mesh_shape[i] 

648 

649 return coord[dim_index] 

650 

    def flatten(self, mesh_dim_name: Optional[str] = None) -> 'DeviceMesh':
        """
        Returns a 1D DeviceMesh by flattening the current DeviceMesh.

        Args:
            mesh_dim_name (str, optional): The name for the flattened dimension.
                If not provided, the name will be generated by joining the original
                mesh dim names with underscore (e.g., "dp_tp" for ("dp", "tp")).
                This name will be used as the key in the root mesh's flatten_mapping.

        Returns:
            DeviceMesh: A 1D DeviceMesh with flattened dimensions.

        Raises:
            ValueError: If mesh_dim_name conflicts with existing mesh dim names.

        Examples:
            >>> mesh = Tensor([[0, 1], [2, 3]])
            >>> device_mesh = DeviceMesh("npu", mesh, mesh_dim_names=("dp", "tp"))
            >>> # Using default name
            >>> flat_mesh = device_mesh.flatten()
            >>> print(flat_mesh.mesh_dim_names)  # Output: ("dp_tp",)
            >>> # Using custom name
            >>> flat_mesh = device_mesh.flatten(mesh_dim_name="custom_name")
            >>> print(flat_mesh.mesh_dim_names)  # Output: ("custom_name",)
        """
        return self._create_flatten_mesh(mesh_dim_name)

678 

679 def _get_root_mesh(self) -> 'DeviceMesh': 

680 """Get the root mesh of this DeviceMesh.""" 

681 if self._root_mesh is None: 

682 return self 

683 # pylint: disable=protected-access 

684 return self._root_mesh._get_root_mesh() 

685 

686 def _create_flatten_mesh(self, mesh_dim_name: Optional[str] = None) -> 'DeviceMesh': 

687 """Create a flattened 1D mesh from the current mesh. 

688 

689 Args: 

690 mesh_dim_name (str, optional): The name for the flattened dimension. 

691 If not provided, defaults to joining mesh dim names with underscore. 

692 """ 

693 root_mesh = self._get_root_mesh() 

694 

695 # Generate mesh_dim_name by joining mesh dim names if not provided 

696 if mesh_dim_name is None: 

697 mesh_dim_name = "_".join(self._mesh_dim_names) 

698 

699 # Flatten a 1D device mesh into its original mesh_dim_names will return itself 

700 if self.ndim == 1 and mesh_dim_name in self._mesh_dim_names: # pylint: disable=E1135 

701 return self 

702 

703 # Check whether the mesh_dim_name for flattened mesh is valid 

704 # It should not conflict with existing mesh dim names in root mesh 

705 invalid_dim_names = root_mesh.mesh_dim_names 

706 if mesh_dim_name in invalid_dim_names: 

707 raise ValueError( 

708 f"'{mesh_dim_name}' already exists in the root mesh mesh_dim_names " 

709 f"{invalid_dim_names}. Please specify another valid mesh_dim_name." 

710 ) 

711 

712 # Quick return if the flatten mesh has been created before with same layout 

713 flatten_mapping = root_mesh.get_flatten_mapping() 

714 if mesh_dim_name in flatten_mapping: 

715 cached_mesh = flatten_mapping[mesh_dim_name] 

716 # Verify the cached mesh has the expected flattened size 

717 expected_size = int(np.prod(self._mesh_shape)) 

718 if cached_mesh.mesh_shape == (expected_size,): 

719 return cached_mesh 

720 raise ValueError( 

721 f"Flatten mesh with mesh_dim_name '{mesh_dim_name}' has been created " 

722 f"before with different layout. Please specify another valid mesh_dim_name." 

723 ) 

724 

725 # Calculate the flattened mesh properties 

726 flattened_mesh_dim = (mesh_dim_name,) 

727 

728 # Create flattened mesh tensor using Tensor() 

729 flattened_mesh_tensor = Tensor(self._rank_list) 

730 

731 # Create the flattened mesh 

732 res_flattened_mesh = DeviceMesh( 

733 device_type="npu", 

734 mesh=flattened_mesh_tensor, 

735 mesh_dim_names=flattened_mesh_dim 

736 ) 

737 # Set root mesh reference to the actual root mesh 

738 res_flattened_mesh.root_mesh = root_mesh 

739 

740 # Cache the flattened mesh in root mesh's flatten_mapping 

741 root_mesh.add_flatten_mapping(mesh_dim_name, res_flattened_mesh) 

742 root_mesh.sub_mesh.append(res_flattened_mesh) 

743 

744 return res_flattened_mesh 

745 

746 def axis_id(self, axis): 

747 if axis == "None": 

748 return -1 

749 # pylint: disable=E1135 

750 if axis not in self.mesh_dim_names: 

751 raise ValueError( 

752 f"The axis name must be one of mesh shape mesh dim name {self.mesh_dim_names}), " 

753 f"but got {axis}" 

754 ) 

755 return self._dev_name_to_dev_id[axis] 

756 

757 def axis_index(self, axis): 

758 # pylint: disable=E1135 

759 if axis not in self.mesh_dim_names: 

760 raise ValueError( 

761 f"The axis name must be one of mesh shape mesh dim name {self.mesh_dim_names}), " 

762 f"but got {axis}" 

763 ) 

764 return self._dev_name_to_index[axis] 

765 

766 def get_device_num_along_axis(self, axis): 

767 """Return device num along specify device axis""" 

768 # pylint: disable=E1135 

769 if axis not in self.mesh_dim_names: 

770 raise ValueError( 

771 f"The axis must be one of device mesh dim name: {self.mesh_dim_names}, but got {axis}" 

772 ) 

773 return self.mesh_shape[self.mesh_dim_names.index(axis)] 

774 

775 def get_rank_list_along_axis(self, mesh_dim): 

776 """ 

777 Get the repeat rank list when the axis is not shard. 

778 

779 Args: 

780 mesh_dim (str): mesh_dim name. 

781 

782 Returns: 

783 list: reduce rank list 

784 """ 

785 if mesh_dim in self._cache_rank_list_along_axis: 

786 # shortcut, get rank list from cache 

787 return self._cache_rank_list_along_axis[mesh_dim] 

788 

789 mesh_shape = self.mesh_shape 

790 mesh_dim_names = self.mesh_dim_names 

791 rank_list = self.rank_list 

792 rank = self.rank 

793 

794 # pylint: disable=E1135 

795 if mesh_dim not in mesh_dim_names: 

796 raise ValueError(f"Axis '{mesh_dim}' not found in mesh_dim_names {mesh_dim_names}") 

797 

798 if rank not in rank_list: 

799 raise ValueError(f"Rank {rank} not found in rank_list") 

800 

801 idx = rank_list.index(rank) 

802 coord = [0] * len(mesh_shape) 

803 temp = idx 

804 for i in range(len(mesh_shape) - 1, -1, -1): 

805 coord[i] = temp % mesh_shape[i] 

806 temp //= mesh_shape[i] 

807 

808 dim_index = mesh_dim_names.index(mesh_dim) 

809 strides = [1] * len(mesh_shape) 

810 for i in range(len(mesh_shape) - 2, -1, -1): 

811 strides[i] = strides[i + 1] * mesh_shape[i + 1] 

812 

813 result_ranks = [] 

814 for v in range(mesh_shape[dim_index]): 

815 new_coord = coord.copy() 

816 new_coord[dim_index] = v 

817 new_idx = 0 

818 for i in range(len(mesh_shape)): 

819 new_idx += new_coord[i] * strides[i] 

820 

821 result_ranks.append(rank_list[new_idx]) 

822 

823 self._cache_rank_list_along_axis[mesh_dim] = result_ranks 

824 return result_ranks 

825 

826 def get_global_shape(self, slice_shape, tensor_map): 

827 """get global shape""" 

828 map_key = hash((slice_shape, tensor_map)) 

829 if map_key in self._global_shape_map: 

830 return self._global_shape_map[map_key] 

831 if tensor_map is None: 

832 raise ValueError( 

833 "tensor_map is not set. Please configure the tensor map by calling the layout." 

834 ) 

835 if len(slice_shape) != len(tensor_map): 

836 raise ValueError( 

837 f"Length of slice_shape ({len(slice_shape)}) must match " 

838 f"the length of tensor_map ({len(tensor_map)})." 

839 ) 

840 

841 n_dims = len(self._mesh_shape) 

842 factors = [1] * len(slice_shape) 

843 

844 for dev_idx, size in enumerate(self._mesh_shape): 

845 reverse_idx = n_dims - 1 - dev_idx 

846 for axis_idx, mapping in enumerate(tensor_map): 

847 if isinstance(mapping, int): 

848 if mapping == -1: 

849 continue 

850 if mapping == reverse_idx: 

851 factors[axis_idx] *= size 

852 break 

853 elif isinstance(mapping, tuple): 

854 if reverse_idx in mapping: 

855 factors[axis_idx] *= size 

856 break 

857 

858 global_shape = [] 

859 for i, dim in enumerate(slice_shape): 

860 global_shape.append(dim * factors[i]) 

861 self._global_shape_map[map_key] = tuple(global_shape) 

862 return tuple(global_shape) 

863 

864 def get_comm_group_by_axis(self, mesh_dim: Union[str, int]): 

865 """ 

866 Get group for specified mesh_dim. 

867 

868 Args: 

869 mesh_dim: Mesh dim or Mesh dim name. 

870 

871 Return: 

872 group: group of specified mesh dim. 

873 """ 

874 # Quick return if the current device_mesh is a 1D mesh 

875 if self.ndim == 1 and mesh_dim is None: 

876 mesh_dim = 0 

877 

878 # Convert string to axis name 

879 if isinstance(mesh_dim, str): 

880 if self._mesh_dim_names is None or len(self._mesh_dim_names) == 0: 

881 raise ValueError(f"DeviceMesh mesh_dim_names is not set, string mesh_dim {mesh_dim}, is not support.") 

882 # pylint: disable=E1135 

883 if mesh_dim not in self._mesh_dim_names: 

884 raise ValueError( 

885 f"mesh_dim can pass a string or integer, but string mesh_dim '{mesh_dim}' not found in " 

886 f"mesh_dim_names {self._mesh_dim_names}" 

887 ) 

888 mesh_dim = self._mesh_dim_names.index(mesh_dim) 

889 else: 

890 if not isinstance(mesh_dim, int) or mesh_dim < 0 or mesh_dim >= self.ndim: 

891 raise ValueError( 

892 f"mesh_dim can pass a string or integer, if not string, mesh_dim should be a integer in range " 

893 f"[0, {self.ndim}), but got {mesh_dim}" 

894 ) 

895 

896 group_name = self._dim_group_names[mesh_dim] 

897 assert group_name in _group_map, f"{group_name} not in _group_map keys {_group_map.keys()}" 

898 return _group_map[group_name] 

899 

900 def get_devices_for_axis(self, mesh_dim: Union[str, int], rank: int): 

901 """ 

902 Get the repeat rank list when the axis is not shard. 

903 

904 Args: 

905 mesh_dim (Union[str, int]): Mesh dim or dim name. 

906 rank (int): Global rank 

907 

908 Returns: 

909 list: reduce rank list 

910 """ 

911 if isinstance(mesh_dim, str): 

912 if not self._mesh_dim_names: 

913 raise ValueError("_mesh_dim_names is not set, string mesh_dim is not supported, please pass a integer.") 

914 mesh_dim_names = self._mesh_dim_names 

915 # pylint: disable=E1135 

916 if mesh_dim not in mesh_dim_names: 

917 raise ValueError(f"mesh_dim '{mesh_dim}' not found in mesh_dim_names {mesh_dim_names}") 

918 mesh_dim = mesh_dim_names.index(mesh_dim) 

919 

920 mesh_shape = self._mesh_shape 

921 if mesh_dim < 0 or mesh_dim >= self.ndim: 

922 raise ValueError(f"mesh_dim {mesh_dim} can not out of range [0, {self.ndim})") 

923 rank_list = self._rank_list 

924 if rank not in rank_list: 

925 raise ValueError(f"Rank {rank} not found in rank_list") 

926 

927 idx = rank_list.index(rank) 

928 coord = [0] * len(mesh_shape) 

929 temp = idx 

930 for i in range(len(mesh_shape) - 1, -1, -1): 

931 coord[i] = temp % mesh_shape[i] 

932 temp //= mesh_shape[i] 

933 

934 strides = [1] * len(mesh_shape) 

935 for i in range(len(mesh_shape) - 2, -1, -1): 

936 strides[i] = strides[i + 1] * mesh_shape[i + 1] 

937 

938 result_ranks = [] 

939 for v in range(mesh_shape[mesh_dim]): 

940 new_coord = coord.copy() 

941 new_coord[mesh_dim] = v 

942 new_idx = 0 

943 for i in range(len(mesh_shape)): 

944 new_idx += new_coord[i] * strides[i] 

945 

946 result_ranks.append(rank_list[new_idx]) 

947 

948 return result_ranks 

949 

950 def to_hash(self): 

951 rank_ids = (self.rank_list[0], self.rank_list[-1]) 

952 map_key = (self.mesh_shape, self.mesh_dim_names, rank_ids) 

953 return map_key 

954 

955 def __repr__(self): 

956 """__repr__""" 

957 return ( 

958 f"DeviceMesh(device_type='npu', mesh_shape={self._mesh_shape}, " 

959 f"mesh_dim_names={self._mesh_dim_names}, rank_list={self._rank_list})" 

960 ) 

961 

    def __str__(self):
        """Return the same constructor-style summary as ``__repr__``."""
        return self.__repr__()

965 

966 

# Module-level cache so repeated requests for an equivalent mesh reuse the
# same DeviceMesh object (populated by _create_device_mesh).
_DEVICE_MESH_MAP = {}

968 

969 

def _create_device_mesh(device_type: str,
                        mesh_shape: tuple[int, ...],
                        *,
                        mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
                        rank_list: tuple[int, ...],
                        init_backend: bool = True, ):
    """
    Create or retrieve a cached DeviceMesh.

    Args:
        device_type (str): Device type.
        mesh_shape (tuple[int, ...]): The shape of the device mesh.
        mesh_dim_names (Union[tuple[str, ...], list[str], None]): A tuple of mesh dim names for each dimension.
        rank_list (tuple[int]): A tuple of rank.
        init_backend (bool): Whether to initialize the device mesh.

    Returns:
        DeviceMesh: A (possibly cached) DeviceMesh object.
    """
    mesh = np.array(rank_list).reshape(mesh_shape)
    # Key on the rank-list endpoints only, matching DeviceMesh.to_hash().
    rank_ids = (rank_list[0], rank_list[-1])
    mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
    # Use the tuple itself as the cache key rather than hash(...): a hash
    # collision would silently hand back the wrong mesh, and the tuple form
    # is consistent with DeviceMesh.to_hash().
    map_key = (mesh_shape, mesh_dim_names, rank_ids)
    if map_key not in _DEVICE_MESH_MAP:
        _DEVICE_MESH_MAP[map_key] = DeviceMesh(device_type, mesh,
                                               mesh_dim_names=mesh_dim_names,
                                               _init_backend=init_backend)
    # The key is guaranteed present here, so index directly.
    return _DEVICE_MESH_MAP[map_key]

998 

999 

def init_device_mesh(
    device_type: str,
    mesh_shape: tuple[int, ...],
    *,
    mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
    rank_list: Optional[tuple[int, ...]] = None,
    init_backend: bool = True,
) -> DeviceMesh:
    """
    Build (or fetch from cache) a DeviceMesh laid out as an n-D array.

    The mesh has one dimension per entry of ``mesh_shape``, optionally
    labeled through ``mesh_dim_names``. When ``rank_list`` is omitted, a
    contiguous block of ranks containing the current process rank is
    generated: ``base .. base + prod(mesh_shape) - 1`` where ``base`` is the
    current rank rounded down to a multiple of the mesh size (useful for
    onecard/simulation runs).

    Compared to constructing DeviceMesh directly, this helper generates the
    rank list, validates its length, and reuses cached DeviceMesh objects.

    Args:
        device_type (str): The type of device to create, e.g. ``"npu"``.
        mesh_shape (tuple[int]): Dimensions of the device layout; for
            example ``(2, 4)`` describes 2 rows of 4 devices.
        mesh_dim_names (Union[tuple[str, ...], list[str], None]): Optional
            name for each mesh dimension; its length should match
            ``mesh_shape``.
        rank_list (tuple[int], optional): Flattened ranks of the mesh. When
            ``None``, generated around the current rank as described above.
        init_backend (bool): Whether to initialize the backend.

    Returns:
        DeviceMesh: A DeviceMesh object representing the device layout.

    Raises:
        ValueError: If ``rank_list`` is given but its length differs from
            the mesh size.
            NOTE(review): shape/name consistency checks appear to be
            performed by the DeviceMesh constructor, not here — confirm.

    Examples:
        >>> device_mesh = init_device_mesh(
        ...     device_type="npu",
        ...     mesh_shape=(2, 2),
        ...     mesh_dim_names=("dp", "tp")
        ... )
        >>> print(device_mesh.mesh_shape)  # Output: (2, 2)
        >>> print(device_mesh.mesh_dim_names)  # Output: ("dp", "tp")
        >>> print(device_mesh.rank_list)  # Output: (0, 1, 2, 3)
    """
    mesh_size = int(np.prod(np.array(mesh_shape)))

    if rank_list is None:
        # Build a contiguous rank block, aligned to the mesh size, that
        # contains the current process rank.
        me = platform.get_rank()
        base = me - me % mesh_size
        rank_list = tuple(range(base, base + mesh_size))
    else:
        if len(rank_list) != mesh_size:
            raise ValueError(
                f"rank_list length ({len(rank_list)}) must equal mesh size ({mesh_size})"
            )
        rank_list = tuple(rank_list)

    # Delegate to the caching factory.
    return _create_device_mesh(device_type, mesh_shape, mesh_dim_names=mesh_dim_names,
                               rank_list=rank_list, init_backend=init_backend)