Coverage for hyper_parallel / core / device_mesh.py: 76%
442 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""device mesh"""
17import os
18from typing import Optional, Union, List, Any
19import numpy as np
20from hyper_parallel.platform import get_platform
22platform = get_platform()
23Tensor = platform.Tensor
25_group_map = {}
28def _get_sub_rank_list(mesh_shape, mesh_dim_names, rank_list, sub_mesh_dim_names, current_rank):
29 """
30 Get the sub rank list for a sub mesh.
32 Args:
33 mesh_shape (tuple[int]): The shape of the original mesh.
34 mesh_dim_names (tuple[str]): The mesh dim names of the original mesh dimensions.
35 rank_list (tuple[int]): A tuple of ranks that participate in this mesh.
36 sub_mesh_dim_names (tuple[str]): The mesh dim names of the sub mesh to extract.
37 current_rank (int): The current process rank.
39 Returns:
40 list: The sub rank list for the sub mesh.
41 """
42 # Reshape rank list into mesh tensor according to mesh shape
43 mesh_tensor = np.array(rank_list).reshape(mesh_shape)
45 # Iterate through each dimension of the original mesh
46 for dim_index, dim_name in enumerate(mesh_dim_names):
48 # Skip dimensions that are included in the sub mesh
49 if dim_name in sub_mesh_dim_names:
50 continue
52 # Split mesh tensor along current dimension
53 dim_size = mesh_shape[dim_index]
54 sliced_tensors = np.split(mesh_tensor, dim_size, axis=dim_index)
56 # Find and keep only the slice containing the current rank
57 for sliced_tensor in sliced_tensors:
58 rank_exists = np.isin(np.array([current_rank]), sliced_tensor).any()
59 if rank_exists:
60 mesh_tensor = sliced_tensor
61 break
63 # Flatten the resulting tensor to get the sub rank list
64 sub_rank_list = mesh_tensor.reshape(-1).tolist()
65 return sub_rank_list
class DeviceMesh:
    """
    Topological abstraction describing cluster devices.

    Args:
        device_type (str): Device type.
        mesh (Union[Tensor, list, tuple, np.ndarray]): A multi-dimensional array, list, or integer
            tensor describing the device layout. The IDs in the mesh are global IDs of the
            default process group, representing the multi-dimensional networking structure
            of devices in distributed training (e.g., [[0,1],[2,3]] represents a 2x2 device mesh).
            If a list or non-int32 tensor is provided, it will be automatically converted
            to an int32 tensor.
        mesh_dim_names (tuple[str]): A tuple[str] of mesh dim names for each dimension of mesh.
        _init_backend (boolean): Whether initial process group.

    Attributes:
        ndim (int): Number of dimensions in the mesh.
        mesh_shape (tuple[int]): Shape of the device mesh.
        rank_list (tuple[int]): Flattened list of ranks from the mesh.
        root_mesh (DeviceMesh): The parent mesh if this is a sub mesh, None otherwise.
        sub_mesh (list[DeviceMesh]): List of child meshes created from this mesh.

    Examples:
        >>> # Using Tensor
        >>> mesh = Tensor([[0, 1], [2, 3]])
        >>> device_mesh = DeviceMesh("npu", mesh, mesh_dim_names=("dp", "tp"))
        >>> # Using list
        >>> device_mesh = DeviceMesh("npu", [[0, 1], [2, 3]], mesh_dim_names=("dp", "tp"))
        >>> # Get sub mesh
        >>> dp_mesh = device_mesh["dp"]
        >>> # Access ndim
        >>> print(device_mesh.ndim)  # Output: 2
        >>> print(device_mesh.mesh_shape)  # Output: (2, 2)
        >>> print(device_mesh.rank_list)  # Output: (0, 1, 2, 3)
    """
    def __init__(self,
                 device_type: str,
                 mesh: Union[Tensor, list, tuple, np.ndarray],
                 *,
                 mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
                 _init_backend: bool = True,
                 ):
        """
        Build a DeviceMesh from a device layout description.

        Args:
            device_type (str): Device type string (e.g. "npu").
            mesh (Union[Tensor, list, tuple, np.ndarray]): Device layout; converted
                to an int32 Tensor.
            mesh_dim_names (tuple[str] | list[str] | None): Optional names, one per mesh
                dimension; must be unique, and 'interleaved_parallel' (if present)
                must be the last name.
            _init_backend (bool): When True, initialize the process group backend and
                create one communication group per mesh dimension.

        Raises:
            ValueError: If mesh is 0-dimensional, mesh_dim_names length mismatches the
                mesh rank, names are duplicated, or 'interleaved_parallel' is not last.
        """
        self._device_type = device_type
        # Convert mesh to Tensor with int32 dtype
        mesh = self._convert_mesh_to_tensor(mesh)

        # Validate mesh dimensions
        if mesh.ndim == 0:
            raise ValueError("mesh must be at least 1-dimensional")

        # Extract mesh_shape and rank_list from mesh
        self._mesh_shape = tuple(mesh.shape)
        self._rank_list = tuple(platform.tensor_to_numpy(mesh).flatten().tolist())
        self._mesh = mesh
        self._dev_num = np.prod(np.array(self._mesh_shape))  # total device count
        self._dev_rank = len(self._mesh_shape)  # number of mesh dimensions
        # mesh_dim_names
        self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
        if self._mesh_dim_names is not None:
            # Validate mesh_dim_names
            if len(self._mesh_shape) != len(mesh_dim_names):
                raise ValueError(
                    f'mesh dimensions ({len(self._mesh_shape)}) should be equal to '
                    f'mesh_dim_names length ({len(mesh_dim_names)})'
                )
            if len(set(mesh_dim_names)) != len(mesh_dim_names):
                raise ValueError(f'Each element of mesh_dim_names {mesh_dim_names} should be different')
            inter_key = "interleaved_parallel"
            if inter_key in mesh_dim_names and mesh_dim_names.index(inter_key) != len(mesh_dim_names) - 1:
                raise ValueError(
                    "'interleaved_parallel' should be at the last dim of mesh_dim_names, means virtual sharding."
                )
            # Map each dim name to its device-dim id counted from the innermost
            # dimension (last name -> 0) and to its positional index (first -> 0).
            self._dev_name_to_dev_id = {
                name: self._dev_rank - i - 1 for i, name in enumerate(self._mesh_dim_names)
            }
            self._dev_name_to_index = {name: i for i, name in enumerate(self._mesh_dim_names)}

        self._rank = platform.get_rank()
        # Per-instance caches (rank lists per axis, global shapes, sub meshes).
        self._cache_rank_list_along_axis = {}
        self._global_shape_map = {}
        self._sub_mesh_cache = {}
        self._flatten_mapping: dict[str, 'DeviceMesh'] = {}
        self._ndim: int = len(self._mesh_shape)
        self._root_mesh: Optional['DeviceMesh'] = None
        self._sub_mesh: List['DeviceMesh'] = []
        if _init_backend:
            platform.init_process_group()
            self._dim_group_names = self._init_process_groups(self._mesh_shape, self._mesh_dim_names, self._rank_list)
        # NOTE(review): in simulation mode (MS_SIMULATION_LEVEL set) this attribute is
        # never created, so readers must tolerate its absence — confirm intended.
        if os.getenv("MS_SIMULATION_LEVEL") is None:
            self._coordinate_on_dim = self._compute_coordinate_on_dim()
160 def _compute_coordinate_on_dim(self):
161 # calculate the coordinates of the current global rank on the mesh
162 return self._compute_coordinates_from_mesh(self.mesh, self._rank)
164 @staticmethod
165 def _compute_coordinates_from_mesh(
166 mesh_tensor: Tensor,
167 rank: int,
168 ):
169 """
170 Compute the coordinates of a rank within a mesh tensor.
172 Args:
173 mesh_tensor (Tensor): The mesh tensor to search in
174 rank (int): The rank to find coordinates for
176 Returns:
177 A tuple of coordinates if the rank is found in the mesh, None otherwise
179 Raises:
180 AssertionError: If the rank appears more than once in the mesh
181 """
182 rank_coords = (mesh_tensor == rank).nonzero()
183 if rank_coords.shape[0] not in (0, 1):
184 raise AssertionError(
185 f"rank_coords.shape[0] must be 0 or 1, got {rank_coords.shape[0]}"
186 )
188 if rank_coords.shape[0] == 0:
189 return None
191 coords = rank_coords[0].tolist()
192 return tuple(coords)
194 def size(self, mesh_dim=None) -> int:
195 if mesh_dim is not None:
196 return self.mesh.shape[mesh_dim]
197 return self.mesh.numel()
199 def get_coordinate(self):
200 """
201 Return the relative indices of this rank relative to all
202 dimensions of the mesh. If this rank is not part of the mesh, return None.
203 """
204 return self._coordinate_on_dim if self._coordinate_on_dim else None
206 @staticmethod
207 def _convert_mesh_to_tensor(mesh: Union[Tensor, list, tuple, np.ndarray]) -> Tensor:
208 """Convert mesh to Tensor with int32 dtype."""
209 if isinstance(mesh, Tensor):
210 mesh = platform.tensor_to_numpy(mesh)
211 elif isinstance(mesh, (list, tuple)):
212 mesh = np.array(mesh)
213 elif not isinstance(mesh, np.ndarray):
214 raise TypeError(
215 f"mesh must be Tensor, list, tuple or numpy array, but got {type(mesh)}"
216 )
218 mesh = mesh.astype(np.int32)
219 return Tensor(mesh).int()
221 @staticmethod
222 def _init_one_process_group(mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...],
223 dim_name: str, rank_list: tuple[int, ...]) -> str:
224 """
225 init one process group
226 """
227 group_name = None
228 group_desc = f"mesh_{dim_name}"
229 split_ranks = set()
230 if not isinstance(dim_name, tuple):
231 dim_name = (dim_name,)
232 for rank in rank_list:
233 split_rank = _get_sub_rank_list(mesh_shape, mesh_dim_names, rank_list, dim_name, rank)
234 split_ranks.add(tuple(sorted(split_rank)))
235 split_ranks = sorted([list(item) for item in split_ranks])
236 group = platform.split_group(split_ranks=split_ranks, group_desc=group_desc)
237 if group:
238 if isinstance(group, str):
239 group_name = group
240 else:
241 group_name = group.group_name
242 _group_map[group_name] = group
243 return group_name
245 @staticmethod
246 def _init_process_groups(mesh_shape: tuple[int, ...], mesh_dim_names: Union[tuple[str, ...], None],
247 rank_list: tuple[int, ...]) -> list:
248 """
249 Init process groups. For every dim in mesh_shape, create split group for current rank.
251 Args:
252 mesh_shape (tuple[int, ...]): Shape of mesh.
253 mesh_dim_names (tuple[str, ...]): Names of every dimension of mesh.
254 rank_list (tuple[int, ...]): Rank list of current process group worked on.
255 """
256 if mesh_dim_names is None:
257 mesh_dim_names = []
258 for dim in range(len(mesh_shape)):
259 mesh_dim_names.append(f"dim_{dim}")
260 mesh_dim_names = tuple(mesh_dim_names)
262 dim_group_names = []
263 for dim in range(len(mesh_shape)):
264 dim_name = mesh_dim_names[dim]
265 dim_group_name = DeviceMesh._init_one_process_group(mesh_shape, mesh_dim_names, dim_name, rank_list)
266 dim_group_names.append(dim_group_name)
268 # Filter out None values. If any are None then they should all be None.
269 dim_non_none_group_names = [n for n in dim_group_names if n is not None]
270 assert not dim_non_none_group_names or len(dim_non_none_group_names) == len(dim_group_names)
271 return dim_non_none_group_names
    @property
    def mesh(self) -> Tensor:
        """Get the mesh tensor."""
        return self._mesh

    # NOTE(review): unlike the surrounding accessors this is a plain method, not a
    # property — callers must use device_mesh.device_type(); confirm this is intended.
    def device_type(self) -> str:
        """Get the device type."""
        return self._device_type

    @property
    def rank(self):
        """Global rank of the current process."""
        return self._rank

    @property
    def mesh_shape(self):
        """Shape of the device mesh as a tuple of ints."""
        return self._mesh_shape

    @property
    def mesh_dim_names(self):
        """Dim names of the mesh, or None when unnamed."""
        return self._mesh_dim_names

    @property
    def rank_list(self):
        """Flattened tuple of all ranks in the mesh."""
        return self._rank_list

    @property
    def ndim(self) -> int:
        """Number of mesh dimensions."""
        return self._ndim

    @property
    def root_mesh(self) -> Optional['DeviceMesh']:
        """The root mesh this sub mesh was sliced from, or None for a root mesh."""
        return self._root_mesh

    @root_mesh.setter
    def root_mesh(self, value: Optional['DeviceMesh']):
        """Set the parent mesh reference."""
        self._root_mesh = value

    @property
    def sub_mesh(self) -> List['DeviceMesh']:
        """Child meshes created from this mesh."""
        return self._sub_mesh

    def get_flatten_mapping(self) -> dict:
        """Get the flatten mapping dictionary."""
        return self._flatten_mapping

    def add_flatten_mapping(self, name: str, mesh: 'DeviceMesh') -> None:
        """Add a flattened mesh to the flatten mapping."""
        self._flatten_mapping[name] = mesh
    def __getitem__(self, sub_mesh_dim_names: Union[str, tuple[str, ...]]) -> 'DeviceMesh':
        """
        Get a sub DeviceMesh based on the specified dimension names.

        This method supports both original dimension names and flattened dimension names.
        For example, if a mesh has dimensions ("dp", "cp", "tp") and a flattened mesh
        "dp_cp" was created via flatten(), both mesh["dp"] and mesh["dp_cp"] are valid.

        Args:
            sub_mesh_dim_names: A string or tuple of strings specifying the dimension names
                for the sub mesh. Can be original dimension names or flattened
                dimension names registered in the root mesh's flatten_mapping.

        Returns:
            DeviceMesh: A new DeviceMesh representing the sub mesh.

        Raises:
            ValueError: If sub_mesh_dim_names is invalid or not a contiguous prefix.
            KeyError: If sub_mesh_dim_names contains names not in mesh_dim_names or flatten_mapping.

        Examples:
            >>> mesh = platform.tensor([[0, 1], [2, 3]])
            >>> device_mesh = DeviceMesh("npu", mesh, mesh_dim_names=("dp", "tp"))
            >>> dp_mesh = device_mesh["dp"]
            >>> print(dp_mesh.mesh_shape)  # Output: (2,)
            >>> print(dp_mesh.mesh_dim_names)  # Output: ("dp",)
            >>> # After creating a flattened mesh:
            >>> flat_mesh = device_mesh.flatten()
            >>> # Can also access via flattened name:
            >>> same_flat_mesh = device_mesh["dp_tp"]
        """
        if not self._mesh_dim_names:
            raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!")

        sub_mesh_dim_names = self._normalize_sub_mesh_dim_names(sub_mesh_dim_names)
        # Flattened names are always registered on the root mesh.
        flatten_mapping = self._get_root_mesh().get_flatten_mapping()

        # Try to get from flatten_mapping first
        flattened_result = self._try_get_from_flatten_mapping(sub_mesh_dim_names, flatten_mapping)
        if flattened_result is not None:
            return flattened_result

        # Validate dimension names
        self._validate_getitem_dimensions(sub_mesh_dim_names, flatten_mapping)

        # Get or create sub mesh for original dimensions
        return self._get_or_create_original_sub_mesh(sub_mesh_dim_names)
371 def _normalize_sub_mesh_dim_names(self, sub_mesh_dim_names: Union[str, tuple[str, ...]]) -> tuple[str, ...]:
372 """Convert sub_mesh_dim_names to tuple format and validate basic type."""
373 if isinstance(sub_mesh_dim_names, str):
374 sub_mesh_dim_names = (sub_mesh_dim_names,)
376 if not isinstance(sub_mesh_dim_names, tuple):
377 raise TypeError(
378 f"sub_mesh_dim_names must be str or tuple, but got {type(sub_mesh_dim_names)}"
379 )
381 if len(sub_mesh_dim_names) == 0:
382 raise ValueError("sub_mesh_dim_names cannot be empty")
384 return sub_mesh_dim_names
386 def _try_get_from_flatten_mapping(self, sub_mesh_dim_names: tuple[str, ...],
387 flatten_mapping: dict) -> Optional['DeviceMesh']:
388 """Try to get mesh from flatten_mapping. Returns None if not applicable."""
389 if len(sub_mesh_dim_names) == 1 and sub_mesh_dim_names[0] in flatten_mapping:
390 return flatten_mapping[sub_mesh_dim_names[0]]
391 return None
393 def _validate_getitem_dimensions(self, sub_mesh_dim_names: tuple[str, ...], flatten_mapping: dict):
394 """Validate dimension names for __getitem__ operation."""
395 valid_dim_names = list(self._mesh_dim_names) + list(flatten_mapping.keys())
397 # Validate all names exist
398 for name in sub_mesh_dim_names:
399 if name not in valid_dim_names:
400 raise KeyError(
401 f"Dimension name '{name}' not found in mesh_dim_names {self._mesh_dim_names} "
402 f"or flatten_mapping keys {list(flatten_mapping.keys())}"
403 )
405 # Check for mixed or multiple flattened dimensions
406 original_dims = [name for name in sub_mesh_dim_names if name in self._mesh_dim_names] # pylint: disable=E1135
407 flattened_dims = [name for name in sub_mesh_dim_names if name in flatten_mapping]
409 if len(flattened_dims) == len(sub_mesh_dim_names) and len(flattened_dims) > 1:
410 raise ValueError(
411 f"Slicing multiple flattened dimensions {flattened_dims} simultaneously "
412 f"is not supported. Please slice them separately."
413 )
415 if flattened_dims and original_dims:
416 raise ValueError(
417 f"Cannot mix original dimensions {original_dims} with flattened dimensions "
418 f"{flattened_dims} in a single slice operation."
419 )
421 def _get_or_create_original_sub_mesh(self, sub_mesh_dim_names: tuple[str, ...]) -> 'DeviceMesh':
422 """Get or create sub mesh for original (non-flattened) dimensions."""
423 # Validate dimension order
424 indices = [self._mesh_dim_names.index(name) for name in sub_mesh_dim_names]
425 if indices != sorted(indices):
426 raise ValueError(
427 f"sub_mesh_dim_names {sub_mesh_dim_names} must follow the order of "
428 f"original mesh_dim_names {self._mesh_dim_names}"
429 )
431 # Check cache
432 if sub_mesh_dim_names in self._sub_mesh_cache:
433 return self._sub_mesh_cache[sub_mesh_dim_names]
435 # Return self if requesting all dimensions
436 if len(sub_mesh_dim_names) == len(self._mesh_dim_names):
437 return self
439 # Create new sub mesh
440 return self._create_and_cache_sub_mesh(sub_mesh_dim_names, indices)
    def _create_and_cache_sub_mesh(self, sub_mesh_dim_names: tuple[str, ...],
                                   indices: List[int]) -> 'DeviceMesh':
        """Create a new sub mesh and cache it.

        Args:
            sub_mesh_dim_names (tuple[str, ...]): Names of the kept dimensions.
            indices (List[int]): Positions of those names in the original mesh.

        Returns:
            DeviceMesh: The newly created sub mesh (also cached on self).
        """
        sub_mesh_shape = tuple(self._mesh_shape[i] for i in indices)

        # Ranks of the slice containing the current rank along the dropped dims.
        sub_rank_list = _get_sub_rank_list(
            self._mesh_shape,
            self._mesh_dim_names,
            self._rank_list,
            sub_mesh_dim_names,
            self._rank
        )
        sub_rank_list = tuple(sub_rank_list)

        # Create sub mesh tensor using Tensor()
        sub_mesh_tensor = Tensor(sub_rank_list).reshape(sub_mesh_shape)

        # Create sub mesh
        # NOTE(review): device_type is hard-coded to "npu" here rather than inheriting
        # self._device_type — confirm this is intended.
        sub_mesh = DeviceMesh(
            device_type="npu",
            mesh=sub_mesh_tensor,
            mesh_dim_names=sub_mesh_dim_names,
            _init_backend=False
        )
        # Set root mesh reference
        sub_mesh.root_mesh = self._get_root_mesh()

        # Reuse the parent's per-dimension groups for the kept dimensions instead of
        # creating new ones (_init_backend=False above).
        slice_dim_group_name = []
        for name in sub_mesh_dim_names:
            # pylint: disable=E1135
            if name in self._mesh_dim_names:
                slice_dim_group_name.append(
                    self._dim_group_names[self._mesh_dim_names.index(name)]
                )
        sub_mesh._dim_group_names = slice_dim_group_name  # pylint: disable=W0212

        # Cache and track
        self._sub_mesh_cache[sub_mesh_dim_names] = sub_mesh
        # Add to sub_mesh list
        self.sub_mesh.append(sub_mesh)

        return sub_mesh
    def get_group(self, mesh_dim: Optional[Union[int, str]] = None):
        """
        Get the communication group for a specific mesh dimension.

        Args:
            mesh_dim: The dimension index or name. If None and mesh is 1D,
                returns the only group. If None and mesh is multi-dimensional,
                raises an error.

        Returns:
            The process group for the specified dimension.

        Raises:
            RuntimeError: If mesh_dim is None and mesh has more than 1 dimension.
            ValueError: If mesh_dim is invalid.

        Examples:
            >>> mesh = Tensor([[0, 1], [2, 3]])
            >>> device_mesh = DeviceMesh("npu", mesh, mesh_dim_names=("dp", "tp"))
            >>> dp_group = device_mesh.get_group("dp")
            >>> # or by index
            >>> dp_group = device_mesh.get_group(0)
        """
        # _dim_group_names only exists when groups were initialized in __init__
        # (or assigned manually, e.g. by from_group()).
        if not hasattr(self, "_dim_group_names"):
            raise RuntimeError("DeviceMesh process groups not initialized!")

        if self.ndim > 1 and mesh_dim is None:
            raise RuntimeError(
                f"Found the DeviceMesh have {self.ndim} dimensions. "
                "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1."
            )

        # Check if mesh_dim is a flattened dimension name in root mesh's flatten_mapping
        root_mesh = self._get_root_mesh()
        if isinstance(mesh_dim, str) and mesh_dim in root_mesh.get_flatten_mapping():
            # Return the group from the flattened mesh
            flattened_mesh = root_mesh.get_flatten_mapping()[mesh_dim]
            return flattened_mesh.get_comm_group_by_axis(mesh_dim)

        return self.get_comm_group_by_axis(mesh_dim)
526 @staticmethod
527 def from_group(group: Union[Any, list[Any]],
528 device_type: str,
529 mesh: Union[Tensor, list, tuple, np.ndarray] = None,
530 mesh_dim_names: Union[tuple[str, ...], list[str]] = None
531 ) -> 'DeviceMesh':
532 """
533 Create device mesh from group or group list.
535 Args:
536 group: The group or group list to create device mesh from.
537 device_type: Device type.
538 mesh:
539 For 1d group, mesh can pass None. If group is 1d and mesh is not None, the mesh must equal to
540 group_ranks get from group, or must be a tensor which tolist value equal to group_ranks.
541 For nd group, mesh must be passed.
542 mesh_dim_names: Names of every mesh dimension.
543 """
544 if not isinstance(group, list):
545 group_ranks = platform.get_process_group_ranks(group)
546 if (
547 isinstance(mesh, Tensor) and mesh.tolist() != group_ranks
548 ) or (
549 mesh is not None
550 and not isinstance(mesh, Tensor)
551 and mesh != group_ranks
552 ):
553 raise ValueError(
554 f"Invalid mesh_shape {str(mesh)} for 1D group with ranks {group_ranks}"
555 )
556 device_mesh = DeviceMesh(device_type, group_ranks, mesh_dim_names=mesh_dim_names, _init_backend=False)
557 if isinstance(group, str):
558 # pylint: disable=W0212
559 device_mesh._dim_group_names = [group]
560 _group_map[group] = group
561 else:
562 device_mesh._dim_group_names = [group.group_name] # pylint: disable=W0212
563 _group_map[group.group_name] = group
564 return device_mesh
566 groups = list(group)
567 if len(groups) == 0:
568 raise ValueError("Expect at least one group be specified.")
569 if mesh is None:
570 raise ValueError("mesh_shape is must specified when group is a list.")
571 mesh = DeviceMesh._convert_mesh_to_tensor(mesh)
572 if mesh.ndim != len(groups):
573 raise ValueError("mesh dimensions must match group dimensions.")
574 device_mesh = DeviceMesh(device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False)
575 # pylint: disable=W0212
576 device_mesh._dim_group_names = []
577 for dim_group in groups:
578 if isinstance(dim_group, str):
579 # pylint: disable=W0212
580 device_mesh._dim_group_names.append(dim_group)
581 _group_map[dim_group] = dim_group
582 else:
583 # pylint: disable=W0212
584 device_mesh._dim_group_names.append(dim_group.group_name)
585 _group_map[dim_group.group_name] = dim_group
586 return device_mesh
    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
        """
        Get the local rank within a specific mesh dimension.

        Args:
            mesh_dim: The dimension index or name. If None and mesh is 1D,
                uses dimension 0. If None and mesh is multi-dimensional,
                raises an error.

        Returns:
            int: The local rank within the specified dimension.

        Raises:
            RuntimeError: If mesh_dim is None and mesh has more than 1 dimension.
            ValueError: If mesh_dim is invalid or current rank not in rank_list.

        Examples:
            >>> mesh = Tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
            >>> device_mesh = DeviceMesh("npu", mesh, mesh_dim_names=("dp", "tp"))
            >>> # On rank 0
            >>> print(device_mesh.get_local_rank("dp"))  # Output: 0
            >>> print(device_mesh.get_local_rank("tp"))  # Output: 0
        """
        if self.ndim > 1 and mesh_dim is None:
            raise RuntimeError(
                f"Found the DeviceMesh have {self.ndim} dimensions. "
                "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1."
            )

        if mesh_dim is None:
            mesh_dim = 0

        # Convert string to index
        if isinstance(mesh_dim, str):
            # pylint: disable=E1135
            if mesh_dim not in self._mesh_dim_names:
                raise ValueError(
                    f"mesh_dim '{mesh_dim}' not found in mesh_dim_names {self._mesh_dim_names}"
                )
            dim_index = self._mesh_dim_names.index(mesh_dim)
        else:
            if not isinstance(mesh_dim, int) or mesh_dim < 0 or mesh_dim >= self.ndim:
                raise ValueError(
                    f"mesh_dim must be an integer in range [0, {self.ndim}), "
                    f"but got {mesh_dim}"
                )
            dim_index = mesh_dim

        if self._rank not in self._rank_list:
            raise ValueError(
                f"Current rank {self._rank} not found in rank_list {self._rank_list}"
            )

        # Calculate the coordinate of current rank in the mesh
        # (row-major decomposition of the flat index; last dim varies fastest).
        idx = self._rank_list.index(self._rank)
        coord = [0] * len(self._mesh_shape)
        temp = idx
        for i in range(len(self._mesh_shape) - 1, -1, -1):
            coord[i] = temp % self._mesh_shape[i]
            temp //= self._mesh_shape[i]

        return coord[dim_index]
    def flatten(self, mesh_dim_name: Optional[str] = None) -> 'DeviceMesh':
        """
        Returns a 1D DeviceMesh by flattening the current DeviceMesh.

        Args:
            mesh_dim_name (str, optional): The name for the flattened dimension.
                If not provided, the name will be generated by joining the original
                mesh dim names with underscore (e.g., "dp_tp" for ("dp", "tp")).
                This name will be used as the key in the root mesh's flatten_mapping.

        Returns:
            DeviceMesh: A 1D DeviceMesh with flattened dimensions.

        Raises:
            ValueError: If mesh_dim_name conflicts with existing mesh dim names.

        Examples:
            >>> mesh = Tensor([[0, 1], [2, 3]])
            >>> device_mesh = DeviceMesh("npu", mesh, mesh_dim_names=("dp", "tp"))
            >>> # Using default name
            >>> flat_mesh = device_mesh.flatten()
            >>> print(flat_mesh.mesh_dim_names)  # Output: ("dp_tp",)
            >>> # Using custom name
            >>> flat_mesh = device_mesh.flatten(mesh_dim_name="custom_name")
            >>> print(flat_mesh.mesh_dim_names)  # Output: ("custom_name",)
        """
        return self._create_flatten_mesh(mesh_dim_name)
679 def _get_root_mesh(self) -> 'DeviceMesh':
680 """Get the root mesh of this DeviceMesh."""
681 if self._root_mesh is None:
682 return self
683 # pylint: disable=protected-access
684 return self._root_mesh._get_root_mesh()
    def _create_flatten_mesh(self, mesh_dim_name: Optional[str] = None) -> 'DeviceMesh':
        """Create a flattened 1D mesh from the current mesh.

        Args:
            mesh_dim_name (str, optional): The name for the flattened dimension.
                If not provided, defaults to joining mesh dim names with underscore.

        Returns:
            DeviceMesh: A 1D mesh over this mesh's ranks, registered in the root
            mesh's flatten_mapping under mesh_dim_name.

        Raises:
            ValueError: If mesh_dim_name clashes with the root mesh's dim names, or
                was previously flattened with a different layout.
        """
        root_mesh = self._get_root_mesh()

        # Generate mesh_dim_name by joining mesh dim names if not provided
        if mesh_dim_name is None:
            mesh_dim_name = "_".join(self._mesh_dim_names)

        # Flatten a 1D device mesh into its original mesh_dim_names will return itself
        if self.ndim == 1 and mesh_dim_name in self._mesh_dim_names:  # pylint: disable=E1135
            return self

        # Check whether the mesh_dim_name for flattened mesh is valid
        # It should not conflict with existing mesh dim names in root mesh
        invalid_dim_names = root_mesh.mesh_dim_names
        if mesh_dim_name in invalid_dim_names:
            raise ValueError(
                f"'{mesh_dim_name}' already exists in the root mesh mesh_dim_names "
                f"{invalid_dim_names}. Please specify another valid mesh_dim_name."
            )

        # Quick return if the flatten mesh has been created before with same layout
        flatten_mapping = root_mesh.get_flatten_mapping()
        if mesh_dim_name in flatten_mapping:
            cached_mesh = flatten_mapping[mesh_dim_name]
            # Verify the cached mesh has the expected flattened size
            expected_size = int(np.prod(self._mesh_shape))
            if cached_mesh.mesh_shape == (expected_size,):
                return cached_mesh
            raise ValueError(
                f"Flatten mesh with mesh_dim_name '{mesh_dim_name}' has been created "
                f"before with different layout. Please specify another valid mesh_dim_name."
            )

        # Calculate the flattened mesh properties
        flattened_mesh_dim = (mesh_dim_name,)

        # Create flattened mesh tensor using Tensor()
        flattened_mesh_tensor = Tensor(self._rank_list)

        # Create the flattened mesh
        # NOTE(review): device_type is hard-coded to "npu" and _init_backend is left at
        # its default True (a new group is created for the flattened dim) — confirm intended.
        res_flattened_mesh = DeviceMesh(
            device_type="npu",
            mesh=flattened_mesh_tensor,
            mesh_dim_names=flattened_mesh_dim
        )
        # Set root mesh reference to the actual root mesh
        res_flattened_mesh.root_mesh = root_mesh

        # Cache the flattened mesh in root mesh's flatten_mapping
        root_mesh.add_flatten_mapping(mesh_dim_name, res_flattened_mesh)
        root_mesh.sub_mesh.append(res_flattened_mesh)

        return res_flattened_mesh
746 def axis_id(self, axis):
747 if axis == "None":
748 return -1
749 # pylint: disable=E1135
750 if axis not in self.mesh_dim_names:
751 raise ValueError(
752 f"The axis name must be one of mesh shape mesh dim name {self.mesh_dim_names}), "
753 f"but got {axis}"
754 )
755 return self._dev_name_to_dev_id[axis]
757 def axis_index(self, axis):
758 # pylint: disable=E1135
759 if axis not in self.mesh_dim_names:
760 raise ValueError(
761 f"The axis name must be one of mesh shape mesh dim name {self.mesh_dim_names}), "
762 f"but got {axis}"
763 )
764 return self._dev_name_to_index[axis]
766 def get_device_num_along_axis(self, axis):
767 """Return device num along specify device axis"""
768 # pylint: disable=E1135
769 if axis not in self.mesh_dim_names:
770 raise ValueError(
771 f"The axis must be one of device mesh dim name: {self.mesh_dim_names}, but got {axis}"
772 )
773 return self.mesh_shape[self.mesh_dim_names.index(axis)]
775 def get_rank_list_along_axis(self, mesh_dim):
776 """
777 Get the repeat rank list when the axis is not shard.
779 Args:
780 mesh_dim (str): mesh_dim name.
782 Returns:
783 list: reduce rank list
784 """
785 if mesh_dim in self._cache_rank_list_along_axis:
786 # shortcut, get rank list from cache
787 return self._cache_rank_list_along_axis[mesh_dim]
789 mesh_shape = self.mesh_shape
790 mesh_dim_names = self.mesh_dim_names
791 rank_list = self.rank_list
792 rank = self.rank
794 # pylint: disable=E1135
795 if mesh_dim not in mesh_dim_names:
796 raise ValueError(f"Axis '{mesh_dim}' not found in mesh_dim_names {mesh_dim_names}")
798 if rank not in rank_list:
799 raise ValueError(f"Rank {rank} not found in rank_list")
801 idx = rank_list.index(rank)
802 coord = [0] * len(mesh_shape)
803 temp = idx
804 for i in range(len(mesh_shape) - 1, -1, -1):
805 coord[i] = temp % mesh_shape[i]
806 temp //= mesh_shape[i]
808 dim_index = mesh_dim_names.index(mesh_dim)
809 strides = [1] * len(mesh_shape)
810 for i in range(len(mesh_shape) - 2, -1, -1):
811 strides[i] = strides[i + 1] * mesh_shape[i + 1]
813 result_ranks = []
814 for v in range(mesh_shape[dim_index]):
815 new_coord = coord.copy()
816 new_coord[dim_index] = v
817 new_idx = 0
818 for i in range(len(mesh_shape)):
819 new_idx += new_coord[i] * strides[i]
821 result_ranks.append(rank_list[new_idx])
823 self._cache_rank_list_along_axis[mesh_dim] = result_ranks
824 return result_ranks
    def get_global_shape(self, slice_shape, tensor_map):
        """Compute the full (global) tensor shape from a local slice shape.

        Each tensor_map entry describes, for one tensor axis, which device dim
        shards that axis: an int is a device-dim id counted from the last mesh
        dim (-1 meaning no device dim), and a tuple lists several such ids whose
        mesh sizes multiply together.

        Args:
            slice_shape (tuple): Shape of the local tensor slice.
            tensor_map (tuple): Device-dim mapping, one entry per tensor axis.

        Returns:
            tuple: slice_shape scaled up by the per-axis shard factors.

        Raises:
            ValueError: If tensor_map is None or its length mismatches slice_shape.
        """
        # Results are memoized per (slice_shape, tensor_map) pair.
        map_key = hash((slice_shape, tensor_map))
        if map_key in self._global_shape_map:
            return self._global_shape_map[map_key]
        if tensor_map is None:
            raise ValueError(
                "tensor_map is not set. Please configure the tensor map by calling the layout."
            )
        if len(slice_shape) != len(tensor_map):
            raise ValueError(
                f"Length of slice_shape ({len(slice_shape)}) must match "
                f"the length of tensor_map ({len(tensor_map)})."
            )

        n_dims = len(self._mesh_shape)
        factors = [1] * len(slice_shape)

        for dev_idx, size in enumerate(self._mesh_shape):
            # Device dims are addressed in reverse order: the last mesh dim is id 0.
            reverse_idx = n_dims - 1 - dev_idx
            # The break below means each device dim scales at most one tensor axis.
            for axis_idx, mapping in enumerate(tensor_map):
                if isinstance(mapping, int):
                    if mapping == -1:
                        continue
                    if mapping == reverse_idx:
                        factors[axis_idx] *= size
                        break
                elif isinstance(mapping, tuple):
                    if reverse_idx in mapping:
                        factors[axis_idx] *= size
                        break

        global_shape = []
        for i, dim in enumerate(slice_shape):
            global_shape.append(dim * factors[i])
        self._global_shape_map[map_key] = tuple(global_shape)
        return tuple(global_shape)
    def get_comm_group_by_axis(self, mesh_dim: Union[str, int]):
        """
        Get group for specified mesh_dim.

        Args:
            mesh_dim: Mesh dim or Mesh dim name.

        Return:
            group: group of specified mesh dim.

        Raises:
            ValueError: If mesh_dim is a name not present in mesh_dim_names, or an
                integer outside [0, ndim).
        """
        # Quick return if the current device_mesh is a 1D mesh
        if self.ndim == 1 and mesh_dim is None:
            mesh_dim = 0

        # Convert string to axis name
        if isinstance(mesh_dim, str):
            if self._mesh_dim_names is None or len(self._mesh_dim_names) == 0:
                raise ValueError(f"DeviceMesh mesh_dim_names is not set, string mesh_dim {mesh_dim}, is not support.")
            # pylint: disable=E1135
            if mesh_dim not in self._mesh_dim_names:
                raise ValueError(
                    f"mesh_dim can pass a string or integer, but string mesh_dim '{mesh_dim}' not found in "
                    f"mesh_dim_names {self._mesh_dim_names}"
                )
            mesh_dim = self._mesh_dim_names.index(mesh_dim)
        else:
            if not isinstance(mesh_dim, int) or mesh_dim < 0 or mesh_dim >= self.ndim:
                raise ValueError(
                    f"mesh_dim can pass a string or integer, if not string, mesh_dim should be a integer in range "
                    f"[0, {self.ndim}), but got {mesh_dim}"
                )

        # Group objects are registered by name in the module-level _group_map.
        group_name = self._dim_group_names[mesh_dim]
        assert group_name in _group_map, f"{group_name} not in _group_map keys {_group_map.keys()}"
        return _group_map[group_name]
900 def get_devices_for_axis(self, mesh_dim: Union[str, int], rank: int):
901 """
902 Get the repeat rank list when the axis is not shard.
904 Args:
905 mesh_dim (Union[str, int]): Mesh dim or dim name.
906 rank (int): Global rank
908 Returns:
909 list: reduce rank list
910 """
911 if isinstance(mesh_dim, str):
912 if not self._mesh_dim_names:
913 raise ValueError("_mesh_dim_names is not set, string mesh_dim is not supported, please pass a integer.")
914 mesh_dim_names = self._mesh_dim_names
915 # pylint: disable=E1135
916 if mesh_dim not in mesh_dim_names:
917 raise ValueError(f"mesh_dim '{mesh_dim}' not found in mesh_dim_names {mesh_dim_names}")
918 mesh_dim = mesh_dim_names.index(mesh_dim)
920 mesh_shape = self._mesh_shape
921 if mesh_dim < 0 or mesh_dim >= self.ndim:
922 raise ValueError(f"mesh_dim {mesh_dim} can not out of range [0, {self.ndim})")
923 rank_list = self._rank_list
924 if rank not in rank_list:
925 raise ValueError(f"Rank {rank} not found in rank_list")
927 idx = rank_list.index(rank)
928 coord = [0] * len(mesh_shape)
929 temp = idx
930 for i in range(len(mesh_shape) - 1, -1, -1):
931 coord[i] = temp % mesh_shape[i]
932 temp //= mesh_shape[i]
934 strides = [1] * len(mesh_shape)
935 for i in range(len(mesh_shape) - 2, -1, -1):
936 strides[i] = strides[i + 1] * mesh_shape[i + 1]
938 result_ranks = []
939 for v in range(mesh_shape[mesh_dim]):
940 new_coord = coord.copy()
941 new_coord[mesh_dim] = v
942 new_idx = 0
943 for i in range(len(mesh_shape)):
944 new_idx += new_coord[i] * strides[i]
946 result_ranks.append(rank_list[new_idx])
948 return result_ranks
950 def to_hash(self):
951 rank_ids = (self.rank_list[0], self.rank_list[-1])
952 map_key = (self.mesh_shape, self.mesh_dim_names, rank_ids)
953 return map_key
955 def __repr__(self):
956 """__repr__"""
957 return (
958 f"DeviceMesh(device_type='npu', mesh_shape={self._mesh_shape}, "
959 f"mesh_dim_names={self._mesh_dim_names}, rank_list={self._rank_list})"
960 )
962 def __str__(self):
963 """__str__"""
964 return self.__repr__()
# Process-wide cache of DeviceMesh objects, keyed by
# hash((mesh_shape, mesh_dim_names, (first_rank, last_rank))) — populated by _create_device_mesh.
_DEVICE_MESH_MAP = {}
def _create_device_mesh(device_type: str,
                        mesh_shape: tuple[int, ...],
                        *,
                        mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
                        rank_list: tuple[int, ...],
                        init_backend: bool = True, ):
    """
    Create or retrieve a cached DeviceMesh.

    Args:
        device_type (str): Device type.
        mesh_shape (tuple[int, ...]): The shape of the device mesh; rank_list is
            reshaped into a multi-dimensional array of this shape.
        mesh_dim_names (Union[tuple[str, ...], list[str], None]): A tuple of mesh dim names for each dimension.
        rank_list (tuple[int]): A tuple of ranks.
        init_backend (bool): Whether to initialize the device mesh backend.

    Returns:
        DeviceMesh: A cached DeviceMesh when an equivalent one exists, otherwise a new one.
    """
    mesh = np.array(rank_list).reshape(mesh_shape)
    # NOTE(review): cache key only records the first/last rank, assuming rank lists
    # with equal shape/names/endpoints are contiguous and identical — confirm invariant.
    rank_ids = (rank_list[0], rank_list[-1])
    # Normalize empty/None mesh_dim_names to None so cache keys stay stable.
    mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
    map_key = hash((mesh_shape, mesh_dim_names, rank_ids))
    if map_key not in _DEVICE_MESH_MAP:
        _DEVICE_MESH_MAP[map_key] = DeviceMesh(device_type, mesh,
                                               mesh_dim_names=mesh_dim_names,
                                               _init_backend=init_backend)
    # Key is guaranteed present here; index directly instead of .get(key, None).
    return _DEVICE_MESH_MAP[map_key]
def init_device_mesh(
        device_type: str,
        mesh_shape: tuple[int, ...],
        *,
        mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
        rank_list: Optional[tuple[int, ...]] = None,
        init_backend: bool = True,
) -> DeviceMesh:
    """
    Initialize a DeviceMesh based on mesh_shape and mesh_dim_names parameters.

    Creates an n-dimensional device layout (n = len(mesh_shape)) with each dimension
    labeled by the corresponding entry of mesh_dim_names. When rank_list is omitted,
    one is generated so that the current process rank is included (onecard/simulation:
    base, base+1, ..., base+n-1, with base aligned to the mesh size).

    Compared to constructing DeviceMesh directly, this function adds:
    - Automatic mesh array generation from mesh_shape
    - A caching mechanism that reuses existing DeviceMesh objects
    - Parameter validation

    Args:
        device_type (str): The type of device to create.
        mesh_shape (tuple[int]): Dimensions of the multi-dimensional device layout,
            e.g. (2, 4) builds a 2D mesh with 2 rows and 4 columns.
        mesh_dim_names (Union[tuple[str, ...], list[str], None]): Unique name for each
            mesh dimension; length must match len(mesh_shape).
        rank_list (tuple[int], optional): Flattened ranks of the mesh. Defaults to a
            generated list containing the current process rank.
        init_backend (bool): Whether to initialize the backend.

    Returns:
        DeviceMesh: A DeviceMesh object representing the device layout.

    Raises:
        TypeError: If mesh_shape or mesh_dim_names is not a tuple.
        ValueError: If mesh_shape and mesh_dim_names have different lengths,
            if mesh_dim_names contains duplicate or empty strings, or if
            rank_list length does not match the mesh size.

    Examples:
        >>> # Create a 2D mesh with shape (2, 2)
        >>> device_mesh = init_device_mesh(
        ...     device_type="npu",
        ...     mesh_shape=(2, 2),
        ...     mesh_dim_names=("dp", "tp")
        ... )
        >>> print(device_mesh.mesh_shape)      # Output: (2, 2)
        >>> print(device_mesh.mesh_dim_names)  # Output: ("dp", "tp")
        >>> print(device_mesh.rank_list)       # Output: (0, 1, 2, 3)

        >>> # Get sub mesh
        >>> dp_mesh = device_mesh["dp"]
        >>> print(dp_mesh.mesh_shape)          # Output: (2,)

        >>> # Create a larger mesh
        >>> mesh = init_device_mesh(device_type="npu", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp"))
        >>> print(mesh.rank_list)              # Output: (0, 1, 2, 3, 4, 5, 6, 7)
    """
    mesh_size = int(np.prod(np.array(mesh_shape)))
    if rank_list is None:
        # Build a contiguous rank window containing the current rank,
        # aligned down to a multiple of the mesh size (onecard/simulation).
        current_rank = platform.get_rank()
        base = current_rank - (current_rank % mesh_size)
        rank_list = tuple(range(base, base + mesh_size))
    else:
        if len(rank_list) != mesh_size:
            raise ValueError(
                f"rank_list length ({len(rank_list)}) must equal mesh size ({mesh_size})"
            )
        rank_list = tuple(rank_list)

    # Delegate to the caching factory.
    return _create_device_mesh(device_type, mesh_shape, mesh_dim_names=mesh_dim_names, rank_list=rank_list,
                               init_backend=init_backend)