Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / dtensor / device_mesh.py: 55%
780 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""device mesh"""
17import copy
18import os
19import threading
20from types import TracebackType
21from typing import Any, List, Literal, Optional, Sequence, Type, Union
22import numpy as np
24from hyper_parallel.core.dtensor._mesh_layout import IntTuple, _MeshLayout, _contiguous_strides, _is_int
25from hyper_parallel.platform import get_platform
26from hyper_parallel.platform.platform import EXISTING_COMM_GROUPS, PlatformType
# Resolve the active platform adapter (PyTorch or MindSpore) once at import
# time; ``Tensor`` is that platform's tensor class, used throughout this module.
platform = get_platform()
Tensor = platform.Tensor
32class _MeshEnv(threading.local):
33 """Per-thread stack of active :class:`DeviceMesh` (PyTorch ``_mesh_resources`` parity)."""
35 def __init__(self) -> None:
36 super().__init__()
37 self.mesh_stack: List["DeviceMesh"] = []
39 def get_current_mesh(self) -> "DeviceMesh":
40 """Return the innermost active :class:`DeviceMesh` for this thread (PyTorch parity)."""
41 if len(self.mesh_stack) == 0:
42 raise RuntimeError("No device mesh is currently active!")
43 return self.mesh_stack[-1]
# Module-wide (but per-thread, via threading.local) registry of active meshes,
# driven by DeviceMesh.__enter__/__exit__.
_mesh_resources = _MeshEnv()

# A per-dimension backend override value: a backend name, or None for default.
BackendConfig = Optional[str]
51def _get_sub_rank_list(mesh_shape, mesh_dim_names, rank_list, sub_mesh_dim_names, current_rank):
52 """
53 Get the sub rank list for a sub mesh.
55 Args:
56 mesh_shape (tuple[int]): The shape of the original mesh.
57 mesh_dim_names (tuple[str]): The mesh dim names of the original mesh dimensions.
58 rank_list (tuple[int]): A tuple of ranks that participate in this mesh.
59 sub_mesh_dim_names (tuple[str]): The mesh dim names of the sub mesh to extract.
60 current_rank (int): The current process rank.
62 Returns:
63 list: The sub rank list for the sub mesh.
64 """
65 mesh_tensor = np.array(rank_list).reshape(mesh_shape)
67 for dim_index, dim_name in enumerate(mesh_dim_names):
68 if dim_name in sub_mesh_dim_names:
69 continue
71 dim_size = mesh_shape[dim_index]
72 sliced_tensors = np.split(mesh_tensor, dim_size, axis=dim_index)
74 for sliced_tensor in sliced_tensors:
75 rank_exists = np.isin(np.array([current_rank]), sliced_tensor).any()
76 if rank_exists:
77 mesh_tensor = sliced_tensor
78 break
80 sub_rank_list = mesh_tensor.reshape(-1).tolist()
81 return sub_rank_list
def _normalize_backend_value(value: Any) -> BackendConfig:
    """Coerce one raw override entry into a backend name or ``None``.

    Accepts a bare string, ``None``, or a non-empty tuple whose first element
    is the backend; anything else normalizes to ``None``.
    """
    if value is None:
        return None
    if isinstance(value, str):
        return value
    if isinstance(value, tuple) and value:
        head = value[0]
        return head if head is None or isinstance(head, str) else None
    return None
def _normalize_backend_override(
        backend_override: dict[Union[int, str], Any],
        ndim: int,
        mesh_dim_names: Optional[tuple[str, ...]] = None,
) -> tuple[BackendConfig, ...]:
    """Normalize backend overrides keyed by dim index or dim name.

    Returns one normalized entry per mesh dimension (``None`` when no
    override applies) and rejects redundant or unknown keys.
    """
    pending = dict(backend_override)
    names = mesh_dim_names or ()
    result: list[BackendConfig] = []

    for idx in range(ndim):
        name = names[idx] if idx < len(names) else None
        if name is not None and name in pending:
            # A dim may be overridden by index OR by name, never both.
            if idx in pending:
                raise RuntimeError(
                    f"Found redundant dim index {idx} and name {name} in backend_override"
                )
            result.append(_normalize_backend_value(pending.pop(name)))
        elif idx in pending:
            result.append(_normalize_backend_value(pending.pop(idx)))
        else:
            result.append(None)

    # Anything left over never matched a dim index or name.
    if pending:
        raise RuntimeError(
            f"Found invalid keys in backend_override: got {list(pending.keys())}, "
            f"expected integers in range [0, {ndim}) or one of {names}"
        )
    return tuple(result)
def _should_defer_group_init(sub_layout: _MeshLayout, backend_override: BackendConfig) -> bool:
    """Whether this mesh dimension should skip eager process-group creation."""
    # "fake" backends never create real groups; singleton dims need no group.
    if backend_override == "fake":
        return True
    return sub_layout.numel() == 1
class DeviceMesh:
    """
    Topological abstraction describing cluster devices.

    Args:
        device_type (str): Device type. Valid values depend on the active platform:

            - **PyTorch** (same as ``torch.distributed.device_mesh.DeviceMesh``):
              ``"cpu"``, ``"cuda"``, ``"npu"``.
            - **MindSpore** (mapped to the corresponding communication backend):
              ``"cpu"`` → mccl, ``"gpu"`` → nccl, ``"npu"`` → hccl.
        mesh (Union[Tensor, list, tuple, np.ndarray, None]): A multi-dimensional array, list, or integer
            tensor describing the device layout. The IDs in the mesh are global IDs of the
            default process group, representing the multi-dimensional networking structure
            of devices in distributed training (e.g., [[0,1],[2,3]] represents a 2x2 device mesh).
            If a list or non-int32 tensor is provided, it will be automatically converted
            to an int32 tensor. If None, a 1D mesh containing all ranks
            (i.e., ``[0, 1, ..., world_size-1]``) will be created automatically.
        mesh_dim_names (tuple[str]): A tuple[str] of mesh dim names for each dimension of mesh.
        _init_backend (boolean): Whether initial process group.

    Attributes:
        ndim (int): Number of dimensions in the mesh.
        mesh_shape (tuple[int]): Shape of the device mesh.
        rank_list (tuple[int]): Flattened list of ranks from the mesh.
        root_mesh (DeviceMesh): The parent mesh if this is a sub mesh, None otherwise.
        sub_mesh (list[DeviceMesh]): List of child meshes created from this mesh.

    Context manager:
        Use ``with device_mesh:`` to set the **current** mesh for this thread.
    """

    # Instance attribute type declarations (populated in __init__).
    device_type: Literal["cpu", "cuda", "gpu", "npu"]
    mesh: Union[Tensor, list, tuple, np.ndarray]
    mesh_dim_names: Union[tuple[str, ...], list[str], None]

    # Allowed device types per platform; a platform missing from this map
    # skips validation entirely (see _validate_device_type).
    _VALID_DEVICE_TYPES = {
        PlatformType.PYTORCH: {"cpu", "cuda", "npu"},
        PlatformType.MINDSPORE: {"cpu", "gpu", "npu"},
    }
    def __init__(self,
                 device_type: Literal["cpu", "cuda", "gpu", "npu"],
                 mesh: Union[Tensor, list, tuple, np.ndarray, None] = None,
                 *,
                 mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
                 _init_backend: bool = True,
                 _layout: Optional[_MeshLayout] = None,
                 _rank_map: Optional[Tensor] = None,
                 _root_mesh: Optional['DeviceMesh'] = None,
                 ):
        """Build a mesh view.

        The private ``_layout``/``_rank_map``/``_root_mesh`` arguments are
        used internally when constructing sub-mesh views that share the root
        mesh's rank map.
        """
        self._validate_device_type(device_type)
        self.device_type = device_type

        # Ensure the default process group exists before any rank queries below.
        if _init_backend:
            platform.init_process_group()

        self._layout, self._rank_map = self._resolve_layout_and_rank_map(mesh, _layout, _rank_map)
        self._rank = platform.get_rank()
        self._root_mesh = _root_mesh
        self._refresh_mesh_view()
        self._set_mesh_dim_names(mesh_dim_names)
        self._initialize_runtime_state(_init_backend)
        # NOTE(review): coordinate lookup is skipped under MS simulation mode —
        # presumably the local rank may be absent from the mesh there; confirm.
        if os.getenv("MS_SIMULATION_LEVEL") is None:
            self._coordinate_on_dim = self._compute_coordinate_on_dim()
198 @classmethod
199 def _validate_device_type(cls, device_type: str) -> None:
200 """Validate that the requested device type is supported on the active platform."""
201 valid_device_types = cls._VALID_DEVICE_TYPES.get(platform.platform_type)
202 if valid_device_types is not None and device_type not in valid_device_types:
203 raise ValueError(
204 f"Invalid device_type '{device_type}' for {platform.platform_type.name} platform. "
205 f"Valid device types are: {sorted(valid_device_types)}"
206 )
    @classmethod
    def _resolve_layout_and_rank_map(
            cls,
            mesh: Union[Tensor, list, tuple, np.ndarray, None],
            layout: Optional[_MeshLayout],
            rank_map: Optional[Tensor],
    ) -> tuple[_MeshLayout, Tensor]:
        """Build the internal layout and rank map from either public or private constructor inputs.

        Exactly one construction path is taken: an explicit ``mesh`` (or the
        implicit world-sized default), or the private ``layout``/``rank_map``
        pair supplied by sub-mesh construction.

        Raises:
            TypeError: If both ``mesh`` and ``layout``/``rank_map`` are given,
                or if neither full specification is available.
            ValueError: If ``mesh`` is 0-d, or ``layout`` has overlapping dims.
        """
        if mesh is not None and (layout is not None or rank_map is not None):
            raise TypeError("Cannot provide both explicit mesh and private _layout/_rank_map arguments.")

        # Default: a 1-D mesh spanning every rank in the default process group.
        if mesh is None and (layout is None or rank_map is None):
            world_size = platform.get_world_size()
            mesh = list(range(world_size))

        if mesh is not None:
            mesh_tensor = cls._convert_mesh_to_tensor(mesh)
            if mesh_tensor.ndim == 0:
                raise ValueError("mesh must be at least 1-dimensional")
            return cls._build_layout_from_mesh(mesh_tensor), cls._build_rank_map_from_mesh(mesh_tensor)

        # Private path: both layout and rank_map were provided.
        rank_map_tensor = cls._convert_rank_map_to_tensor(rank_map)
        if layout is None or rank_map_tensor is None:
            raise TypeError("The mesh argument is required except for private _layout/_rank_map construction.")
        if not layout.check_non_overlap():
            raise ValueError(f"Invalid overlapping layout {layout}.")
        return layout, rank_map_tensor
    def _refresh_mesh_view(self) -> None:
        """Materialize the visible mesh tensor and the derived shape/rank metadata."""
        # Expand (layout, rank_map) into the fully materialized mesh; the
        # per-rank view along its leading axis is selected below.
        full_mesh_np = self._layout.remap_to_numpy(platform.tensor_to_numpy(self._rank_map))
        full_mesh = Tensor(full_mesh_np).int()
        self.mesh = self._get_mesh_tensor_from_full_mesh(full_mesh, current_rank=self._rank)
        self._mesh_shape = tuple(self.mesh.shape)
        self._rank_list = tuple(platform.tensor_to_numpy(self.mesh).reshape(-1).tolist())
        # Flattened view of the *root* rank map (all ranks, not just this view).
        self._flatten_rank_map = tuple(platform.tensor_to_numpy(self._rank_map).reshape(-1).tolist())
        self._dev_num = np.prod(np.array(self._mesh_shape))
        self._dev_rank = len(self._mesh_shape)
247 def _set_mesh_dim_names(
248 self,
249 mesh_dim_names: Union[tuple[str, ...], list[str], None],
250 ) -> None:
251 """Validate mesh dim names and build lookup tables for named access."""
252 self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
253 if self.mesh_dim_names is None:
254 return
256 if len(self._mesh_shape) != len(self.mesh_dim_names):
257 raise ValueError(
258 f'mesh dimensions ({len(self._mesh_shape)}) should be equal to '
259 f'mesh_dim_names length ({len(self.mesh_dim_names)})'
260 )
261 if len(set(self.mesh_dim_names)) != len(self.mesh_dim_names):
262 raise ValueError(f'Each element of mesh_dim_names {self.mesh_dim_names} should be different')
263 inter_key = "interleaved_parallel"
264 if inter_key in self.mesh_dim_names and self.mesh_dim_names.index(inter_key) != len(self.mesh_dim_names) - 1:
265 raise ValueError(
266 "'interleaved_parallel' should be at the last dim of mesh_dim_names, means virtual sharding."
267 )
268 self._dev_name_to_dev_id = {
269 name: self._dev_rank - i - 1 for i, name in enumerate(self.mesh_dim_names)
270 }
271 self._dev_name_to_index = {name: i for i, name in enumerate(self.mesh_dim_names)}
    def _initialize_runtime_state(self, init_backend: bool) -> None:
        """Initialize caches and optional process-group state for the mesh view."""
        self._cache_rank_list_along_axis = {}
        self._global_shape_map = {}
        self._sub_mesh_cache = {}
        # Name -> flattened DeviceMesh derived from this mesh (root meshes only use it).
        self._flatten_mapping: dict[str, 'DeviceMesh'] = {}
        self._ndim = len(self._mesh_shape)
        # Per-dimension backend override and (mesh, dim) provenance of each group.
        self._dim_group_backends = (None,) * self._ndim
        self._dim_group_sources = tuple((self, dim) for dim in range(self._ndim))
        self._sub_mesh: List['DeviceMesh'] = []
        if not init_backend:
            # _dim_group_names stays unset; group accessors raise until it exists.
            return
        self._dim_group_names = self._init_process_groups(
            self._mesh_shape,
            self.mesh_dim_names,
            self._rank_list,
        )
291 @staticmethod
292 def _build_layout_from_mesh(mesh: Tensor) -> _MeshLayout:
293 mesh_shape = tuple(mesh.shape)
294 return _MeshLayout(mesh_shape, _contiguous_strides(mesh_shape))
    @staticmethod
    def _build_rank_map_from_mesh(mesh: Tensor) -> Tensor:
        # Flatten the mesh into a 1-D int tensor (row-major order) as rank map.
        return Tensor(platform.tensor_to_numpy(mesh).reshape(-1)).int()
300 @staticmethod
301 def _convert_rank_map_to_tensor(rank_map: Tensor) -> Tensor:
302 if isinstance(rank_map, Tensor):
303 rank_map_np = platform.tensor_to_numpy(rank_map)
304 else:
305 rank_map_np = np.array(rank_map)
306 return Tensor(rank_map_np.reshape(-1).astype(np.int32)).int()
    @staticmethod
    def _get_mesh_tensor_from_full_mesh(full_mesh: Tensor, current_rank: Optional[int] = None) -> Tensor:
        """Select the per-rank mesh view from a fully materialized layout remap.

        ``full_mesh`` carries one candidate view per leading-axis entry; the
        one containing ``current_rank`` is returned (trivially the only one
        when the leading axis has size 1).
        """
        if full_mesh.shape[0] == 1:
            return full_mesh[0]

        if current_rank is None:
            current_rank = platform.get_rank()

        rank_coords = (full_mesh == current_rank).nonzero()
        if rank_coords.shape[0] > 0:
            # First hit decides the view if the rank appears more than once.
            return full_mesh[rank_coords[0, 0]]
        raise RuntimeError(
            "In order to get the mesh tensor of a DeviceMesh it needs to "
            "either have all its original dimensions or contain the local rank."
        )
    def _compute_coordinate_on_dim(self):
        """Compute the current rank coordinates inside this mesh view.

        Returns ``None`` when the current rank does not appear in this view.
        """
        return self._compute_coordinates_from_mesh(self.mesh, self._rank)
329 @staticmethod
330 def _compute_coordinates_from_mesh(
331 mesh_tensor: Tensor,
332 rank: int,
333 ):
334 """Locate one rank inside a mesh tensor and return its coordinates."""
335 rank_coords = (mesh_tensor == rank).nonzero()
336 if rank_coords.shape[0] not in (0, 1):
337 raise AssertionError(
338 f"rank_coords.shape[0] must be 0 or 1, got {rank_coords.shape[0]}"
339 )
341 if rank_coords.shape[0] == 0:
342 return None
344 coords = rank_coords[0].tolist()
345 return tuple(coords)
347 def size(self, mesh_dim=None) -> int:
348 if mesh_dim is not None:
349 return self.mesh.shape[mesh_dim]
350 return self.mesh.numel()
352 def get_coordinate(self):
353 return self._coordinate_on_dim if self._coordinate_on_dim else None
    def __enter__(self) -> "DeviceMesh":
        # Make this mesh the thread-local "current" mesh.
        _mesh_resources.mesh_stack.append(self)
        return self
    def __exit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_val: Optional[BaseException],
            exc_tb: Optional[TracebackType],
    ) -> None:
        # Pop unconditionally; exceptions propagate (nothing is suppressed).
        _mesh_resources.mesh_stack.pop()
367 @staticmethod
368 def _convert_mesh_to_tensor(mesh: Union[Tensor, list, tuple, np.ndarray]) -> Tensor:
369 """Convert a public mesh input into an int32 platform tensor."""
370 if isinstance(mesh, Tensor):
371 mesh = platform.tensor_to_numpy(mesh)
372 elif isinstance(mesh, (list, tuple)):
373 mesh = np.array(mesh)
374 elif not isinstance(mesh, np.ndarray):
375 raise TypeError(
376 f"mesh must be Tensor, list, tuple or numpy array, but got {type(mesh)}"
377 )
379 mesh = mesh.astype(np.int32)
380 return Tensor(mesh).int()
    @staticmethod
    def _init_one_process_group(mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...],
                                dim_name: str, rank_list: tuple[int, ...]) -> str:
        """Create one process-group family for the named mesh dimension.

        Every rank's subgroup along ``dim_name`` is computed so that all
        processes call ``split_group`` with identical arguments; the returned
        key identifies the subgroup containing the local rank.
        """
        group_key = None
        split_ranks = set()
        if not isinstance(dim_name, tuple):
            dim_name = (dim_name,)
        for rank in rank_list:
            split_rank = _get_sub_rank_list(mesh_shape, mesh_dim_names, rank_list, dim_name, rank)
            sorted_rank = tuple(sorted(split_rank))
            split_ranks.add(sorted_rank)
            if rank == platform.get_rank():
                group_key = str(sorted_rank)
        # Deterministic ordering so every rank passes identical arguments.
        split_ranks = sorted([list(item) for item in split_ranks])
        platform.split_group(split_ranks=split_ranks)
        return group_key
    @staticmethod
    def _build_dim_split_ranks(
            sub_layout: _MeshLayout,
            rank_map: Tensor,
    ) -> tuple[list[list[int]], Optional[str]]:
        """Build rank lists and the local cache key for one logical mesh axis.

        Returns:
            (sorted list of de-duplicated subgroup rank lists,
             key of the subgroup containing the local rank or ``None``).

        Raises:
            RuntimeError: If the local rank appears in more than one row.
        """
        # Each leading-axis entry of the remap is one subgroup along this axis.
        pg_ranks_by_dim = sub_layout.remap_to_numpy(platform.tensor_to_numpy(rank_map))
        current_rank = platform.get_rank()
        split_ranks = []
        split_ranks_set = set()
        group_key = None
        for dim_mesh in np.array(pg_ranks_by_dim):
            subgroup_ranks = tuple(int(rank) for rank in np.array(dim_mesh).reshape(-1).tolist())
            subgroup_ranks_sorted = tuple(sorted(subgroup_ranks))
            # De-duplicate subgroups; the full list is re-sorted below.
            if subgroup_ranks_sorted not in split_ranks_set:
                split_ranks_set.add(subgroup_ranks_sorted)
                split_ranks.append(list(subgroup_ranks_sorted))
            if current_rank in subgroup_ranks:
                if group_key is not None:
                    raise RuntimeError(
                        "Each device mesh dimension should get only one process group per rank."
                    )
                group_key = str(subgroup_ranks_sorted)
        split_ranks = sorted(split_ranks)
        return split_ranks, group_key
426 @staticmethod
427 def _cache_group_if_needed(group_key: Optional[str], group: Any) -> None:
428 if group_key is not None and group is not None and group_key not in EXISTING_COMM_GROUPS:
429 EXISTING_COMM_GROUPS[group_key] = group
    @staticmethod
    def _init_process_groups_for_layout(
            layout: _MeshLayout,
            rank_map: Tensor,
            mesh_dim_names: Union[tuple[str, ...], None],
            backend_override: Optional[tuple[BackendConfig, ...]] = None,
    ) -> list:
        """Initialize process groups for each top-level axis in the given layout.

        Returns:
            list: One group key per axis; ``None`` for deferred (fake backend
            or singleton) dimensions.

        Raises:
            ValueError: If ``backend_override`` length mismatches the layout.
        """
        if mesh_dim_names is None:
            mesh_dim_names = tuple(f"dim_{dim}" for dim in range(len(layout)))
        if backend_override is None:
            backend_override = (None,) * len(layout)
        if len(backend_override) != len(layout):
            raise ValueError(
                f"backend_override length {len(backend_override)} must match layout rank {len(layout)}"
            )

        dim_group_names = []
        for dim, sub_layout in enumerate(layout):
            split_ranks, group_key = DeviceMesh._build_dim_split_ranks(sub_layout, rank_map)
            # "fake" backends and size-1 dims skip eager group creation.
            if _should_defer_group_init(sub_layout, backend_override[dim]):
                dim_group_names.append(None)
                continue
            group = platform.split_group(split_ranks=split_ranks)
            DeviceMesh._cache_group_if_needed(group_key, group)
            dim_group_names.append(group_key)
        return dim_group_names
    @staticmethod
    def _init_process_groups(mesh_shape: tuple[int, ...], mesh_dim_names: Union[tuple[str, ...], None],
                             rank_list: tuple[int, ...],
                             backend_override: Optional[tuple[BackendConfig, ...]] = None) -> list:
        """Create per-dimension process groups for a contiguous mesh of the given shape."""
        layout = _MeshLayout(mesh_shape, _contiguous_strides(mesh_shape))
        rank_map = DeviceMesh._convert_rank_map_to_tensor(rank_list)
        return DeviceMesh._init_process_groups_for_layout(
            layout,
            rank_map,
            mesh_dim_names,
            backend_override=backend_override,
        )
    @property
    def rank(self):
        # Global rank of the current process.
        return self._rank

    @property
    def mesh_shape(self):
        return self._mesh_shape

    @property
    def rank_list(self):
        # Flattened ranks of this mesh view.
        return self._rank_list

    @property
    def ndim(self) -> int:
        return self._ndim

    @property
    def shape(self) -> tuple:
        # Alias of mesh_shape (PyTorch DeviceMesh parity).
        return self._mesh_shape

    @property
    def root_mesh(self) -> Optional['DeviceMesh']:
        # None for a root mesh; the parent for a sub-mesh view.
        return self._root_mesh

    @root_mesh.setter
    def root_mesh(self, value: Optional['DeviceMesh']):
        self._root_mesh = value

    @property
    def sub_mesh(self) -> List['DeviceMesh']:
        return self._sub_mesh

    def get_flatten_mapping(self) -> dict:
        # Name -> flattened DeviceMesh created from this (root) mesh.
        return self._flatten_mapping

    def add_flatten_mapping(self, name: str, mesh: 'DeviceMesh') -> None:
        self._flatten_mapping[name] = mesh
    def __getitem__(self, sub_mesh_dim_names: Union[str, tuple[str, ...]]) -> 'DeviceMesh':
        """Slice a sub-mesh by dim name(s), reusing flattened/cached results when possible."""
        if not self.mesh_dim_names:
            raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!")

        sub_mesh_dim_names = DeviceMesh._normalize_sub_mesh_dim_names(sub_mesh_dim_names)
        flatten_mapping = self._get_root_mesh().get_flatten_mapping()

        # A single flattened name resolves directly to its flattened mesh.
        flattened_result = self._try_get_from_flatten_mapping(sub_mesh_dim_names, flatten_mapping)
        if flattened_result is not None:
            return flattened_result

        # Layout construction also validates the requested names and order.
        layout = self._get_slice_mesh_layout(sub_mesh_dim_names)
        if sub_mesh_dim_names in self._sub_mesh_cache:
            return self._sub_mesh_cache[sub_mesh_dim_names]
        if layout == self._layout:
            # Requesting every dim of this mesh is the identity slice.
            return self
        return self._create_and_cache_sub_mesh(sub_mesh_dim_names, layout)
528 @staticmethod
529 def _normalize_sub_mesh_dim_names(sub_mesh_dim_names: Union[str, tuple[str, ...]]) -> tuple[str, ...]:
530 """Normalize a slice selector into a non-empty tuple of mesh dim names."""
531 if isinstance(sub_mesh_dim_names, str):
532 sub_mesh_dim_names = (sub_mesh_dim_names,)
534 if not isinstance(sub_mesh_dim_names, tuple):
535 raise TypeError(
536 f"sub_mesh_dim_names must be str or tuple, but got {type(sub_mesh_dim_names)}"
537 )
539 if len(sub_mesh_dim_names) == 0:
540 raise ValueError("sub_mesh_dim_names cannot be empty")
542 return sub_mesh_dim_names
544 @staticmethod
545 def _try_get_from_flatten_mapping(sub_mesh_dim_names: tuple[str, ...],
546 flatten_mapping: dict) -> Optional['DeviceMesh']:
547 if len(sub_mesh_dim_names) == 1 and sub_mesh_dim_names[0] in flatten_mapping:
548 return flatten_mapping[sub_mesh_dim_names[0]]
549 return None
551 def _get_mesh_dim_by_name(self, mesh_dim_name: str) -> int:
552 """Resolve a named mesh axis to its integer position."""
553 mesh_dim_names = self.mesh_dim_names or ()
554 if len(mesh_dim_names) == 0:
555 raise KeyError("No mesh_dim_names found.")
556 if mesh_dim_name not in mesh_dim_names:
557 raise KeyError(
558 f"Mesh dimension '{mesh_dim_name}' does not exist. "
559 f"Available mesh dimensions are: {mesh_dim_names}"
560 )
561 return mesh_dim_names.index(mesh_dim_name)
    def _get_slice_mesh_layout(self, sub_mesh_dim_names: tuple[str, ...]) -> _MeshLayout:
        """Construct the layout corresponding to one named sub-mesh slice request.

        Raises:
            KeyError: For unknown dim names.
            ValueError: For out-of-order original dims or descending strides.
            NotImplementedError: For non-contiguous (nested-stride) flattened dims.
            RuntimeError: When the requested dims overlap.
        """
        root_mesh = self._get_root_mesh()
        slice_from_root = self == root_mesh
        # Flattened dim names are only addressable when slicing the root mesh.
        flatten_name_to_layout = (
            {key: mesh._layout for key, mesh in root_mesh.get_flatten_mapping().items()}
            if slice_from_root else {}
        )
        valid_dim_names = [*(self.mesh_dim_names or ()), *flatten_name_to_layout]
        if not all(name in valid_dim_names for name in sub_mesh_dim_names):
            raise KeyError(
                f"Invalid mesh_dim_names {sub_mesh_dim_names} specified. "
                f"Valid mesh_dim_names are {valid_dim_names}."
            )

        # When only original dims are requested, they must keep their order.
        if all(name in (self.mesh_dim_names or ()) for name in sub_mesh_dim_names):
            indices = [self.mesh_dim_names.index(name) for name in sub_mesh_dim_names]
            if indices != sorted(indices):
                raise ValueError(
                    f"sub_mesh_dim_names {sub_mesh_dim_names} must follow the order of "
                    f"original mesh_dim_names {self.mesh_dim_names}"
                )

        # Gather each requested dim's (sizes, strides) from this mesh's layout
        # or from the root's flattened layouts.
        sliced_sizes: list[IntTuple] = []
        sliced_strides: list[IntTuple] = []
        for name in sub_mesh_dim_names:
            if name in (self.mesh_dim_names or ()):
                layout = self._layout[self.mesh_dim_names.index(name)]
            else:
                layout = flatten_name_to_layout[name]
            sliced_sizes.append(layout.sizes)
            sliced_strides.append(layout.strides)

        # Strides must be scalar and non-descending right-to-left so the slice
        # forms a well-ordered contiguous sub-layout.
        pre_stride = -1
        for stride in reversed(sliced_strides):
            if not _is_int(stride):
                raise NotImplementedError(
                    "Currently, this only allows slicing out a contiguous flattened dim."
                )
            if stride < pre_stride:
                raise ValueError(
                    f"Invalid mesh_dim_names {sub_mesh_dim_names} specified. "
                    "Mesh dim indices should be in ascending order."
                )
            pre_stride = stride

        if len(sliced_sizes) == 1:
            layout = _MeshLayout(sliced_sizes[0], sliced_strides[0])
        else:
            layout = _MeshLayout(tuple(sliced_sizes), tuple(sliced_strides))
        if not layout.check_non_overlap():
            raise RuntimeError(f"Slicing overlapping dim_names {sub_mesh_dim_names} is not allowed.")
        return layout
    def _create_and_cache_sub_mesh(self, sub_mesh_dim_names: tuple[str, ...], layout: _MeshLayout) -> 'DeviceMesh':
        """Create a sub-mesh view, copy group metadata, and cache the result."""
        root_mesh = self._get_root_mesh()
        # The sub-mesh shares the root's rank map; only the layout differs.
        sub_mesh = DeviceMesh(
            device_type=self.device_type,
            mesh_dim_names=sub_mesh_dim_names,
            _init_backend=False,
            _layout=layout,
            _rank_map=root_mesh._rank_map,
            _root_mesh=root_mesh,
        )

        # Carry per-dimension group name/backend/source metadata over from each
        # dim's origin: this mesh for original dims, the flattened mesh otherwise.
        slice_dim_group_name = []
        slice_dim_group_backends: list[BackendConfig] = []
        slice_dim_group_sources: list[tuple['DeviceMesh', int]] = []
        for name in sub_mesh_dim_names:
            if name in (self.mesh_dim_names or ()):
                dim_index = self.mesh_dim_names.index(name)
                if hasattr(self, "_dim_group_names"):
                    slice_dim_group_name.append(self._dim_group_names[dim_index])
                    slice_dim_group_backends.append(self._dim_group_backends[dim_index])
                    if hasattr(self, "_dim_group_sources"):
                        slice_dim_group_sources.append(self._dim_group_sources[dim_index])  # pylint: disable=W0212
                    else:
                        slice_dim_group_sources.append((self, dim_index))
            elif name in root_mesh.get_flatten_mapping():
                flatten_mesh = root_mesh.get_flatten_mapping()[name]
                if hasattr(flatten_mesh, "_dim_group_names"):
                    slice_dim_group_name.append(flatten_mesh._dim_group_names[0])
                    slice_dim_group_backends.append(flatten_mesh._dim_group_backends[0])
                    if hasattr(flatten_mesh, "_dim_group_sources"):
                        slice_dim_group_sources.append(flatten_mesh._dim_group_sources[0])  # pylint: disable=W0212
                    else:
                        slice_dim_group_sources.append((flatten_mesh, 0))
        if slice_dim_group_name:
            sub_mesh._dim_group_names = slice_dim_group_name  # pylint: disable=W0212
        if slice_dim_group_backends:
            sub_mesh._dim_group_backends = tuple(slice_dim_group_backends)  # pylint: disable=W0212
        if slice_dim_group_sources:
            sub_mesh._dim_group_sources = tuple(slice_dim_group_sources)  # pylint: disable=W0212

        self._sub_mesh_cache[sub_mesh_dim_names] = sub_mesh
        self.sub_mesh.append(sub_mesh)
        return sub_mesh
    def get_group(self, mesh_dim: Optional[Union[int, str]] = None):
        """Return the communication group for one mesh axis.

        ``mesh_dim`` may be an axis index, an axis name (including a flattened
        mesh name registered on the root mesh), or omitted for 1-D meshes.

        Raises:
            RuntimeError: If groups are uninitialized, or ``mesh_dim`` is
                omitted on a multi-dimensional mesh.
        """
        if not hasattr(self, "_dim_group_names"):
            raise RuntimeError("DeviceMesh process groups not initialized!")

        if self.ndim > 1 and mesh_dim is None:
            raise RuntimeError(
                f"Found the DeviceMesh have {self.ndim} dimensions. "
                "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1."
            )

        # Flattened names resolve through the root mesh's flatten mapping.
        root_mesh = self._get_root_mesh()
        if isinstance(mesh_dim, str) and mesh_dim in root_mesh.get_flatten_mapping():
            flattened_mesh = root_mesh.get_flatten_mapping()[mesh_dim]
            return flattened_mesh.get_comm_group_by_axis(mesh_dim)

        return self.get_comm_group_by_axis(mesh_dim)
680 def get_all_groups(self) -> list:
681 if not hasattr(self, "_dim_group_names"):
682 raise RuntimeError("DeviceMesh process groups not initialized!")
684 return [self.get_group(i) for i in range(self.ndim)]
    @staticmethod
    def from_group(group: Union[Any, list[Any]],
                   device_type: str,
                   mesh: Union[Tensor, list, tuple, np.ndarray] = None,
                   mesh_dim_names: Union[tuple[str, ...], list[str]] = None
                   ) -> 'DeviceMesh':
        """Build a DeviceMesh from an existing process group or a list of groups.

        A single group yields a 1-D mesh over that group's ranks; a list of
        groups requires an explicit ``mesh`` with one dimension per group.
        Supplied groups are registered in EXISTING_COMM_GROUPS so they are
        reused instead of recreated.
        """
        if not isinstance(group, list):
            group_ranks = platform.get_process_group_ranks(group)
            group_key = str(tuple(sorted(group_ranks)))
            if not platform.get_created_group(group_ranks):
                EXISTING_COMM_GROUPS[group_key] = group
            # NOTE(review): a tuple `mesh` equal to group_ranks still fails the
            # second comparison (tuple != list) — confirm intended strictness.
            if (
                    isinstance(mesh, Tensor) and mesh.tolist() != group_ranks
            ) or (
                    mesh is not None
                    and not isinstance(mesh, Tensor)
                    and mesh != group_ranks
            ):
                raise ValueError(
                    f"Invalid mesh_shape {str(mesh)} for 1D group with ranks {group_ranks}"
                )
            device_mesh = DeviceMesh(device_type, group_ranks, mesh_dim_names=mesh_dim_names, _init_backend=False)
            device_mesh._dim_group_names = [group_key]  # pylint: disable=W0212
            return device_mesh

        groups = list(group)
        if len(groups) == 0:
            raise ValueError("Expect at least one group be specified.")
        if mesh is None:
            raise ValueError("mesh_shape is must specified when group is a list.")
        mesh = DeviceMesh._convert_mesh_to_tensor(mesh)
        if mesh.ndim != len(groups):
            raise ValueError("mesh dimensions must match group dimensions.")
        device_mesh = DeviceMesh(device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False)
        device_mesh._dim_group_names = []  # pylint: disable=W0212
        # One supplied group per mesh dimension, registered under its rank key.
        for dim_group in groups:
            group_ranks = platform.get_process_group_ranks(dim_group)
            group_key = str(tuple(sorted(group_ranks)))
            if not platform.get_created_group(group_ranks):
                EXISTING_COMM_GROUPS[group_key] = dim_group
            device_mesh._dim_group_names.append(group_key)  # pylint: disable=W0212
        return device_mesh
    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
        """Return the local coordinate of the current rank along one mesh dimension.

        Args:
            mesh_dim: Axis index or axis name; may be omitted only for 1-D meshes.

        Raises:
            RuntimeError: If the mesh is multi-dimensional and ``mesh_dim`` is omitted.
            ValueError: For unknown names/indices, or if this rank is not in the mesh.
        """
        if self.ndim > 1 and mesh_dim is None:
            raise RuntimeError(
                f"Found the DeviceMesh have {self.ndim} dimensions. "
                "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1."
            )

        if mesh_dim is None:
            mesh_dim = 0

        if isinstance(mesh_dim, str):
            if mesh_dim not in self.mesh_dim_names:  # pylint: disable=E1135
                raise ValueError(
                    f"mesh_dim '{mesh_dim}' not found in mesh_dim_names {self.mesh_dim_names}"
                )
            dim_index = self.mesh_dim_names.index(mesh_dim)
        else:
            if not isinstance(mesh_dim, int) or mesh_dim < 0 or mesh_dim >= self.ndim:
                raise ValueError(
                    f"mesh_dim must be an integer in range [0, {self.ndim}), "
                    f"but got {mesh_dim}"
                )
            dim_index = mesh_dim

        if self._rank not in self._rank_list:
            raise ValueError(
                f"Current rank {self._rank} not found in rank_list {self._rank_list}"
            )

        # Unflatten the rank's row-major position into per-axis coordinates.
        idx = self._rank_list.index(self._rank)
        coord = [0] * len(self._mesh_shape)
        temp = idx
        for i in range(len(self._mesh_shape) - 1, -1, -1):
            coord[i] = temp % self._mesh_shape[i]
            temp //= self._mesh_shape[i]

        return coord[dim_index]
    def flatten(self, mesh_dim_name: Optional[str] = None) -> 'DeviceMesh':
        """Return a flattened view of this mesh named ``mesh_dim_name``.

        Delegates to ``_create_flatten_mesh`` (defined elsewhere in this
        class, outside this chunk).
        """
        return self._create_flatten_mesh(mesh_dim_name)
772 def _get_root_mesh(self) -> 'DeviceMesh':
773 """Return the canonical root mesh for this view."""
774 if self._root_mesh is None:
775 return self
776 return self._root_mesh._get_root_mesh() # pylint: disable=protected-access
    @staticmethod
    def _validate_concatenate_inputs(
            meshes: Sequence['DeviceMesh'],
    ) -> tuple['DeviceMesh', tuple[str, ...], tuple[int, ...]]:
        """Validate concatenate inputs and return the shared root metadata.

        Returns:
            (root mesh, concatenated dim names, shared flattened rank map).

        Raises:
            ValueError: On empty input, mismatched roots/rank maps, or inputs
                lacking ``mesh_dim_names``.
        """
        if len(meshes) == 0:
            raise ValueError("DeviceMesh.concatenate expects at least one mesh.")
        # Single mesh: trivially consistent with itself.
        if len(meshes) == 1:
            return meshes[0]._get_root_mesh(), tuple(meshes[0].mesh_dim_names or ()), meshes[0]._flatten_rank_map

        root_mesh = meshes[0]._get_root_mesh()  # pylint: disable=protected-access
        requested_dim_names: list[str] = []
        flatten_rank_map = meshes[0]._flatten_rank_map  # pylint: disable=protected-access
        for mesh in meshes:
            # NOTE(review): to_hash() is defined outside this chunk; it serves
            # as the root-mesh identity check here.
            if mesh._get_root_mesh().to_hash() != root_mesh.to_hash():  # pylint: disable=protected-access
                raise ValueError("DeviceMesh.concatenate expects all meshes to share the same root mesh.")
            if mesh._flatten_rank_map != flatten_rank_map:  # pylint: disable=protected-access
                raise ValueError("DeviceMesh.concatenate expects all meshes to share the same root mesh.")
            if not mesh.mesh_dim_names:
                raise ValueError("DeviceMesh.concatenate requires mesh_dim_names on every input mesh.")
            requested_dim_names.extend(mesh.mesh_dim_names)
        return root_mesh, tuple(requested_dim_names), flatten_rank_map
801 @staticmethod
802 def _validate_concatenate_root_order(root_mesh: 'DeviceMesh', requested_dim_names: tuple[str, ...]) -> None:
803 """Require original root dims to stay in root order when concatenating by name."""
804 root_dim_names = tuple(root_mesh.mesh_dim_names) if root_mesh.mesh_dim_names else ()
805 if not root_dim_names or not all(dim_name in root_dim_names for dim_name in requested_dim_names):
806 return
808 requested_indices = [root_dim_names.index(dim_name) for dim_name in requested_dim_names]
809 if requested_indices != sorted(requested_indices):
810 raise ValueError(
811 "DeviceMesh.concatenate expects meshes to follow the root mesh order. "
812 f"Got root mesh dims {root_dim_names} and requested dims {requested_dim_names}."
813 )
815 @staticmethod
816 def _collect_concatenate_metadata(
817 meshes: Sequence['DeviceMesh'],
818 ) -> tuple[
819 list[str],
820 list[IntTuple],
821 list[IntTuple],
822 list[Optional[str]],
823 list[BackendConfig],
824 list[tuple['DeviceMesh', int]],
825 ]:
826 """Collect layout and process-group metadata from all concatenate inputs."""
827 concat_dim_names: list[str] = []
828 concat_sizes: list[IntTuple] = []
829 concat_strides: list[IntTuple] = []
830 concat_dim_group_names: list[Optional[str]] = []
831 concat_dim_group_backends: list[BackendConfig] = []
832 concat_dim_group_sources: list[tuple['DeviceMesh', int]] = []
834 for mesh in meshes:
835 for dim, sub_layout in enumerate(mesh._layout): # pylint: disable=protected-access
836 concat_sizes.append(sub_layout.sizes)
837 concat_strides.append(sub_layout.strides)
838 if hasattr(mesh, "_dim_group_names"):
839 concat_dim_group_names.append(mesh._dim_group_names[dim]) # pylint: disable=protected-access
840 concat_dim_group_backends.append(mesh._dim_group_backends[dim]) # pylint: disable=protected-access
841 if hasattr(mesh, "_dim_group_sources"):
842 concat_dim_group_sources.append(mesh._dim_group_sources[dim]) # pylint: disable=protected-access
843 else:
844 concat_dim_group_sources.append((mesh, dim))
845 concat_dim_names.extend(mesh.mesh_dim_names)
847 if len(set(concat_dim_names)) != len(concat_dim_names):
848 raise ValueError(
849 f"DeviceMesh.concatenate expects disjoint mesh dims, but got {tuple(concat_dim_names)}."
850 )
851 return (
852 concat_dim_names,
853 concat_sizes,
854 concat_strides,
855 concat_dim_group_names,
856 concat_dim_group_backends,
857 concat_dim_group_sources,
858 )
860 @staticmethod
861 def _build_concatenate_layout(concat_sizes: list[IntTuple], concat_strides: list[IntTuple]) -> _MeshLayout:
862 """Build the layout represented by concatenated top-level mesh axes."""
863 if len(concat_sizes) == 1:
864 return _MeshLayout(concat_sizes[0], concat_strides[0])
865 return _MeshLayout(tuple(concat_sizes), tuple(concat_strides))
867 @staticmethod
868 def _set_concatenated_group_state(
869 mesh: 'DeviceMesh',
870 dim_group_names: list[Optional[str]],
871 dim_group_backends: list[BackendConfig],
872 dim_group_sources: list[tuple['DeviceMesh', int]],
873 ) -> None:
874 """Attach inherited process-group metadata to a concatenated mesh view."""
875 if dim_group_names:
876 mesh._dim_group_names = dim_group_names # pylint: disable=W0212
877 if dim_group_backends:
878 mesh._dim_group_backends = tuple(dim_group_backends) # pylint: disable=W0212
879 if dim_group_sources:
880 mesh._dim_group_sources = tuple(dim_group_sources) # pylint: disable=W0212
882 @staticmethod
883 def concatenate(meshes: Sequence['DeviceMesh']) -> 'DeviceMesh':
884 """Concatenate multiple sub-mesh views into one wider layout-backed mesh."""
885 if len(meshes) == 1:
886 return meshes[0]
887 root_mesh, requested_dim_names, _ = DeviceMesh._validate_concatenate_inputs(meshes)
888 DeviceMesh._validate_concatenate_root_order(root_mesh, requested_dim_names)
889 (
890 concat_dim_names,
891 concat_sizes,
892 concat_strides,
893 concat_dim_group_names,
894 concat_dim_group_backends,
895 concat_dim_group_sources,
896 ) = DeviceMesh._collect_concatenate_metadata(meshes)
897 concat_layout = DeviceMesh._build_concatenate_layout(concat_sizes, concat_strides)
898 if not concat_layout.check_non_overlap():
899 raise ValueError(f"Cannot concatenate overlapping meshes: {meshes}")
901 res_mesh = DeviceMesh(
902 meshes[0].device_type,
903 mesh_dim_names=tuple(concat_dim_names),
904 _init_backend=False,
905 _layout=concat_layout,
906 _rank_map=meshes[0]._rank_map, # pylint: disable=protected-access
907 _root_mesh=meshes[0]._get_root_mesh(), # pylint: disable=protected-access
908 )
909 DeviceMesh._set_concatenated_group_state(
910 res_mesh,
911 concat_dim_group_names,
912 concat_dim_group_backends,
913 concat_dim_group_sources,
914 )
915 return res_mesh
917 _concatenate = concatenate
    def _create_flatten_mesh(
        self,
        mesh_dim_name: Optional[str] = None,
        backend_override: BackendConfig = None,
    ) -> 'DeviceMesh':
        """Create or reuse a flattened one-dimensional mesh view.

        Args:
            mesh_dim_name: Name for the flattened dim; defaults to this mesh's
                dim names joined with underscores (e.g. ``"dp_tp"``).
            backend_override: Backend config applied to the flattened dim's
                process group, if one is created.

        Returns:
            A 1-D ``DeviceMesh`` view registered on the root mesh. A cached
            view is returned when the same name was flattened with the same
            layout before.

        Raises:
            ValueError: If ``mesh_dim_name`` collides with a root dim name, or
                was previously flattened with a different layout.
        """
        root_mesh = self._get_root_mesh()

        if mesh_dim_name is None:
            mesh_dim_name = "_".join(self.mesh_dim_names)

        # Already a 1-D mesh carrying the requested name: nothing to flatten.
        if self.ndim == 1 and mesh_dim_name in self.mesh_dim_names:  # pylint: disable=E1135
            return self

        invalid_dim_names = root_mesh.mesh_dim_names
        if mesh_dim_name in invalid_dim_names:
            raise ValueError(
                f"'{mesh_dim_name}' already exists in the root mesh mesh_dim_names "
                f"{invalid_dim_names}. Please specify another valid mesh_dim_name."
            )

        # Collapse this view's layout into a single logical axis.
        flattened_mesh_layout = self._layout.coalesce()
        if len(flattened_mesh_layout) > 1:
            flattened_mesh_layout = flattened_mesh_layout.nest()

        # Reuse a previously flattened view only when the layout matches.
        flatten_mapping = root_mesh.get_flatten_mapping()
        if mesh_dim_name in flatten_mapping:
            cached_mesh = flatten_mapping[mesh_dim_name]
            if cached_mesh._layout == flattened_mesh_layout:  # pylint: disable=protected-access
                return cached_mesh
            raise ValueError(
                f"Flatten mesh with mesh_dim_name '{mesh_dim_name}' has been created "
                f"before with different layout. Please specify another valid mesh_dim_name."
            )

        res_flattened_mesh = DeviceMesh(
            device_type=root_mesh.device_type,
            mesh_dim_names=(mesh_dim_name,),
            _init_backend=False,
            _layout=flattened_mesh_layout,
            _rank_map=root_mesh._rank_map,
            _root_mesh=root_mesh,
        )
        res_flattened_mesh._dim_group_backends = (backend_override,)  # pylint: disable=W0212
        # Only create process groups when this view already had them;
        # presumably _init_process_groups_for_layout returns per-dim group
        # keys — TODO confirm against its definition.
        if hasattr(self, "_dim_group_names"):
            res_flattened_mesh._dim_group_names = DeviceMesh._init_process_groups_for_layout(  # pylint: disable=W0212
                res_flattened_mesh._layout,
                root_mesh._rank_map,
                res_flattened_mesh.mesh_dim_names,
                backend_override=(backend_override,),
            )

        # Register the new view on the root so later lookups hit the caches.
        root_mesh.add_flatten_mapping(mesh_dim_name, res_flattened_mesh)
        root_mesh._sub_mesh_cache[(mesh_dim_name,)] = res_flattened_mesh  # pylint: disable=W0212
        root_mesh.sub_mesh.append(res_flattened_mesh)

        return res_flattened_mesh
    def _create_unflatten_mesh(
        self,
        dim: int,
        mesh_sizes: tuple[int, ...],
        mesh_dim_names: tuple[str, ...],
        backend_override: tuple[BackendConfig, ...],
    ) -> 'DeviceMesh':
        """Split one logical mesh axis into multiple named axes.

        Args:
            dim: Index of the axis to split.
            mesh_sizes: Sizes of the new axes; their product must equal the
                size of the original axis at ``dim``.
            mesh_dim_names: One name per new axis.
            backend_override: One backend config per new axis.

        Returns:
            A new ``DeviceMesh`` view sharing this mesh's root.

        Raises:
            ValueError: If ``prod(mesh_sizes)`` differs from the original
                axis size.
        """
        # Layout of the new axes in row-major order within the split dim.
        inner_layout = _MeshLayout(mesh_sizes, _contiguous_strides(mesh_sizes))
        original_layout = self._layout[dim]
        if inner_layout.numel() != original_layout.numel():
            raise ValueError(
                f"The product of mesh_sizes={mesh_sizes} is {inner_layout.numel()}, "
                f"but the original dimension at dim={dim} has size {original_layout.numel()}."
            )

        # Compose the inner layout into the original axis, then splice the
        # result back in place of that single axis.
        partial_layout = original_layout.composition(inner_layout)
        unflattened_layout = self._layout.splice(dim, dim + 1, partial_layout)
        unflattened_mesh_dim_names = list(self.mesh_dim_names or ())
        unflattened_mesh_dim_names[dim: dim + 1] = list(mesh_dim_names)

        root_mesh = self._get_root_mesh()
        res_mesh = DeviceMesh(
            self.device_type,
            mesh_dim_names=tuple(unflattened_mesh_dim_names),
            _init_backend=False,
            _layout=unflattened_layout,
            _rank_map=root_mesh._rank_map,
            _root_mesh=root_mesh,
        )

        # Replace the split dim's backend entry with one entry per new axis.
        dim_group_backends = list(self._dim_group_backends)
        dim_group_backends[dim: dim + 1] = list(backend_override)
        res_mesh._dim_group_backends = tuple(dim_group_backends)  # pylint: disable=W0212

        # Create process groups for the new axes only if this view had groups.
        if hasattr(self, "_dim_group_names"):
            dim_group_names = list(self._dim_group_names)
            dim_group_names[dim: dim + 1] = DeviceMesh._init_process_groups_for_layout(
                partial_layout,
                root_mesh._rank_map,
                mesh_dim_names,
                backend_override=backend_override,
            )
            res_mesh._dim_group_names = dim_group_names  # pylint: disable=W0212

        return res_mesh
1024 def _flatten(self, mesh_dim_name: Optional[str] = None, backend_override: Any = None) -> 'DeviceMesh':
1025 return self._create_flatten_mesh(
1026 mesh_dim_name,
1027 backend_override=_normalize_backend_value(backend_override),
1028 )
1030 def _unflatten(
1031 self,
1032 dim: Union[int, str],
1033 mesh_sizes: tuple[int, ...],
1034 mesh_dim_names: tuple[str, ...],
1035 backend_override: Optional[dict[Union[int, str], Any]] = None,
1036 ) -> 'DeviceMesh':
1037 """Torch-compatible helper that expands one mesh axis into a nested layout."""
1038 if isinstance(dim, int):
1039 if dim < 0 or dim >= self.ndim:
1040 raise ValueError(f"dim {dim} specified in `_unflatten` is out of range {self.ndim}")
1041 else:
1042 mesh_dim_names_tuple = self.mesh_dim_names or ()
1043 if dim not in mesh_dim_names_tuple:
1044 raise ValueError(f"dim {dim} specified in `_unflatten` is not in {mesh_dim_names_tuple}")
1045 dim = mesh_dim_names_tuple.index(dim)
1047 if len(mesh_sizes) != len(mesh_dim_names):
1048 raise RuntimeError("mesh_dim_names must have same length as mesh_sizes in _unflatten!")
1050 backend_override_tuple = (
1051 _normalize_backend_override(backend_override, len(mesh_sizes), mesh_dim_names)
1052 if backend_override is not None
1053 else (None,) * len(mesh_dim_names)
1054 )
1055 return self._create_unflatten_mesh(dim, mesh_sizes, mesh_dim_names, backend_override_tuple)
1057 def assert_axis(self, axis, operate_name):
1058 if not self.mesh_dim_names:
1059 raise RuntimeError(f"mesh_dim_names not specified, {operate_name} is not supported.")
1060 if axis not in self.mesh_dim_names: # pylint: disable=E1135
1061 raise ValueError(
1062 f"The axis name must be one of mesh dim name {self.mesh_dim_names}, but got {axis}"
1063 )
1065 def axis_id(self, axis):
1066 if axis == "None":
1067 return -1
1068 self.assert_axis(axis, "axis_id")
1069 return self._dev_name_to_dev_id[axis]
1071 def axis_index(self, axis):
1072 self.assert_axis(axis, "axis_index")
1073 return self._dev_name_to_index[axis]
1075 def get_device_num_along_axis(self, axis):
1076 self.assert_axis(axis, "get_device_num_along_axis")
1077 return self.mesh_shape[self.mesh_dim_names.index(axis)]
1079 def get_rank_list_along_axis(self, mesh_dim):
1080 """Return the ranks that share every other coordinate with the current rank."""
1081 if mesh_dim in self._cache_rank_list_along_axis:
1082 return self._cache_rank_list_along_axis[mesh_dim]
1083 self.assert_axis(mesh_dim, "get_rank_list_along_axis")
1085 mesh_shape = self.mesh_shape
1086 mesh_dim_names = self.mesh_dim_names
1087 rank_list = self.rank_list
1088 rank = self.rank
1090 if rank not in rank_list:
1091 raise ValueError(f"Rank {rank} not found in rank_list")
1093 idx = rank_list.index(rank)
1094 coord = [0] * len(mesh_shape)
1095 temp = idx
1096 for i in range(len(mesh_shape) - 1, -1, -1):
1097 coord[i] = temp % mesh_shape[i]
1098 temp //= mesh_shape[i]
1100 dim_index = mesh_dim_names.index(mesh_dim)
1101 strides = [1] * len(mesh_shape)
1102 for i in range(len(mesh_shape) - 2, -1, -1):
1103 strides[i] = strides[i + 1] * mesh_shape[i + 1]
1105 result_ranks = []
1106 for v in range(mesh_shape[dim_index]):
1107 new_coord = coord.copy()
1108 new_coord[dim_index] = v
1109 new_idx = 0
1110 for i in range(len(mesh_shape)):
1111 new_idx += new_coord[i] * strides[i]
1113 result_ranks.append(rank_list[new_idx])
1115 self._cache_rank_list_along_axis[mesh_dim] = result_ranks
1116 return result_ranks
1118 def get_global_shape(self, slice_shape, tensor_map):
1119 """Infer the global tensor shape from a shard shape and tensor-map metadata."""
1120 map_key = hash((slice_shape, tensor_map))
1121 if map_key in self._global_shape_map:
1122 return self._global_shape_map[map_key]
1123 if tensor_map is None:
1124 raise ValueError(
1125 "tensor_map is not set. Please configure the tensor map by calling the layout."
1126 )
1127 if len(slice_shape) != len(tensor_map):
1128 raise ValueError(
1129 f"Length of slice_shape ({len(slice_shape)}) must match "
1130 f"the length of tensor_map ({len(tensor_map)})."
1131 )
1133 n_dims = len(self._mesh_shape)
1134 factors = [1] * len(slice_shape)
1136 for dev_idx, size in enumerate(self._mesh_shape):
1137 reverse_idx = n_dims - 1 - dev_idx
1138 for axis_idx, mapping in enumerate(tensor_map):
1139 if isinstance(mapping, int):
1140 if mapping == -1:
1141 continue
1142 if mapping == reverse_idx:
1143 factors[axis_idx] *= size
1144 break
1145 elif isinstance(mapping, tuple):
1146 if reverse_idx in mapping:
1147 factors[axis_idx] *= size
1148 break
1150 global_shape = []
1151 for i, dim in enumerate(slice_shape):
1152 global_shape.append(dim * factors[i])
1153 self._global_shape_map[map_key] = tuple(global_shape)
1154 return tuple(global_shape)
    def _materialize_dim_group(self, mesh_dim: int) -> Optional[str]:
        """Create a deferred process group for one mesh dimension on first use.

        Args:
            mesh_dim: Index of the mesh dimension whose group is needed.

        Returns:
            The cache key under which the communication group is stored in
            ``EXISTING_COMM_GROUPS``.
        """
        # Lazily create per-dim bookkeeping if groups were never initialized.
        if not hasattr(self, "_dim_group_names"):
            self._dim_group_names = [None] * self.ndim  # pylint: disable=W0201

        # Derived views record where each dim's group originally lives;
        # delegate creation there so both views share one group.
        if hasattr(self, "_dim_group_sources"):
            source_mesh, source_dim = self._dim_group_sources[mesh_dim]  # pylint: disable=W0212
            if source_mesh is not self or source_dim != mesh_dim:
                source_group_key = source_mesh._materialize_dim_group(source_dim)  # pylint: disable=W0212
                self._dim_group_names[mesh_dim] = source_group_key
                return source_group_key

        # Reuse an already-materialized group when its key is still cached.
        group_key = self._dim_group_names[mesh_dim]
        if group_key is not None and group_key in EXISTING_COMM_GROUPS:
            return group_key

        # Build the rank split for this dim's layout and create the group.
        split_ranks, group_key = DeviceMesh._build_dim_split_ranks(self._layout[mesh_dim], self._rank_map)
        group = platform.split_group(split_ranks=split_ranks)
        DeviceMesh._cache_group_if_needed(group_key, group)
        self._dim_group_names[mesh_dim] = group_key
        return group_key
1178 def get_comm_group_by_axis(self, mesh_dim: Union[str, int]):
1179 """Return the cached or lazily materialized process group for one mesh axis."""
1180 if self.ndim == 1 and mesh_dim is None:
1181 mesh_dim = 0
1183 if isinstance(mesh_dim, str):
1184 if self.mesh_dim_names is None or len(self.mesh_dim_names) == 0:
1185 raise ValueError(f"DeviceMesh mesh_dim_names is not set, string mesh_dim {mesh_dim}, is not support.")
1186 if mesh_dim not in self.mesh_dim_names: # pylint: disable=E1135
1187 raise ValueError(
1188 f"mesh_dim can pass a string or integer, but string mesh_dim '{mesh_dim}' not found in "
1189 f"mesh_dim_names {self.mesh_dim_names}"
1190 )
1191 mesh_dim = self.mesh_dim_names.index(mesh_dim)
1192 else:
1193 if not isinstance(mesh_dim, int) or mesh_dim < 0 or mesh_dim >= self.ndim:
1194 raise ValueError(
1195 f"mesh_dim can pass a string or integer, if not string, mesh_dim should be a integer in range "
1196 f"[0, {self.ndim}), but got {mesh_dim}"
1197 )
1199 if not hasattr(self, "_dim_group_names"):
1200 raise RuntimeError("DeviceMesh process groups not initialized!")
1202 group_key = self._dim_group_names[mesh_dim]
1203 if group_key is None or group_key not in EXISTING_COMM_GROUPS:
1204 group_key = self._materialize_dim_group(mesh_dim)
1205 if group_key not in EXISTING_COMM_GROUPS:
1206 raise ValueError(f"{group_key} not in group cache {EXISTING_COMM_GROUPS.keys()}")
1207 return EXISTING_COMM_GROUPS[group_key]
1209 def get_devices_for_axis(self, mesh_dim: Union[str, int], rank: int):
1210 """List peer ranks that share all coordinates except the requested axis."""
1211 if isinstance(mesh_dim, str):
1212 if not self.mesh_dim_names:
1213 raise ValueError("_mesh_dim_names is not set, string mesh_dim is not supported, please pass a integer.")
1214 mesh_dim_names = self.mesh_dim_names
1215 if mesh_dim not in mesh_dim_names: # pylint: disable=E1135
1216 raise ValueError(f"mesh_dim '{mesh_dim}' not found in mesh_dim_names {mesh_dim_names}")
1217 mesh_dim = mesh_dim_names.index(mesh_dim)
1219 mesh_shape = self._mesh_shape
1220 if mesh_dim < 0 or mesh_dim >= self.ndim:
1221 raise ValueError(f"mesh_dim {mesh_dim} can not out of range [0, {self.ndim})")
1222 rank_list = self._rank_list
1223 if rank not in rank_list:
1224 raise ValueError(f"Rank {rank} not found in rank_list")
1226 idx = rank_list.index(rank)
1227 coord = [0] * len(mesh_shape)
1228 temp = idx
1229 for i in range(len(mesh_shape) - 1, -1, -1):
1230 coord[i] = temp % mesh_shape[i]
1231 temp //= mesh_shape[i]
1233 strides = [1] * len(mesh_shape)
1234 for i in range(len(mesh_shape) - 2, -1, -1):
1235 strides[i] = strides[i + 1] * mesh_shape[i + 1]
1237 result_ranks = []
1238 for v in range(mesh_shape[mesh_dim]):
1239 new_coord = coord.copy()
1240 new_coord[mesh_dim] = v
1241 new_idx = 0
1242 for i in range(len(mesh_shape)):
1243 new_idx += new_coord[i] * strides[i]
1245 result_ranks.append(rank_list[new_idx])
1247 return result_ranks
1249 def to_hash(self):
1250 map_key = (self.mesh_shape, self.mesh_dim_names, self.rank_list)
1251 return map_key
1253 def __repr__(self):
1254 return (
1255 f"DeviceMesh(device_type='{self.device_type}', mesh_shape={self._mesh_shape}, "
1256 f"mesh_dim_names={self.mesh_dim_names}, rank_list={self._rank_list})"
1257 )
1259 def __str__(self):
1260 return self.__repr__()
1262 def __deepcopy__(self, memo):
1263 cls = self.__class__
1264 result = cls.__new__(cls)
1265 memo[id(self)] = result
1266 for k, v in self.__dict__.items():
1267 if k in ("_root_mesh", "_dim_group_sources"):
1268 setattr(result, k, v)
1269 else:
1270 setattr(result, k, copy.deepcopy(v, memo))
1271 return result
1274_DEVICE_MESH_MAP = {}
def _create_device_mesh(device_type: str,
                        mesh_shape: tuple[int, ...],
                        *,
                        mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
                        rank_list: tuple[int, ...],
                        init_backend: bool = True) -> "DeviceMesh":
    """Create or reuse a cached DeviceMesh with the requested topology.

    Args:
        device_type: Target device backend name.
        mesh_shape: Size of every mesh dimension.
        mesh_dim_names: Optional names, one per mesh dimension.
        rank_list: Global ranks laid out in row-major ``mesh_shape`` order.
        init_backend: Whether a newly created mesh initializes process groups.

    Returns:
        The cached or newly created ``DeviceMesh``.
    """
    mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
    # Key the cache on the identity tuple itself rather than hash(...):
    # hash values can collide, which would silently hand back a mesh with
    # a different topology.
    map_key = (mesh_shape, mesh_dim_names, rank_list)
    if map_key not in _DEVICE_MESH_MAP:
        # reshape also validates len(rank_list) == prod(mesh_shape).
        mesh = np.array(rank_list).reshape(mesh_shape)
        _DEVICE_MESH_MAP[map_key] = DeviceMesh(device_type, mesh,
                                               mesh_dim_names=mesh_dim_names,
                                               _init_backend=init_backend)
    return _DEVICE_MESH_MAP[map_key]
def init_device_mesh(
    device_type: str,
    mesh_shape: tuple[int, ...],
    *,
    mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
    rank_list: Optional[tuple[int, ...]] = None,
    init_backend: bool = True,
) -> "DeviceMesh":
    """Initialize a cached DeviceMesh from the provided shape, names, and ranks.

    Args:
        device_type: Target device backend name.
        mesh_shape: Positive size of every mesh dimension (must be a tuple).
        mesh_dim_names: Optional unique, non-empty names, one per dimension.
        rank_list: Explicit global ranks; when omitted, the contiguous block
            of ranks containing the current rank is used.
        init_backend: Whether to initialize the process group / backend.

    Returns:
        The cached or newly created ``DeviceMesh``.

    Raises:
        TypeError: If ``mesh_shape`` or ``mesh_dim_names`` has the wrong type.
        ValueError: If sizes, names, or ``rank_list`` length are invalid.
        RuntimeError: If the current rank cannot be determined for automatic
            ``rank_list`` generation.
    """
    # Validate all arguments up front, BEFORE any side effects: the original
    # ordering could initialize the process group (and generate a rank list)
    # for inputs that were about to be rejected.
    if not isinstance(mesh_shape, tuple):
        raise TypeError(f'mesh_shape must be a tuple, but got {type(mesh_shape)}')

    for size in mesh_shape:
        if not isinstance(size, int) or size <= 0:
            raise ValueError(
                f"Each element of mesh_shape must be a positive integer, but got {mesh_shape}"
            )

    if mesh_dim_names is not None:
        if not isinstance(mesh_dim_names, (tuple, list)):
            raise TypeError(
                f'mesh_dim_names must be a tuple or list, but got {type(mesh_dim_names)}'
            )
        mesh_dim_names = tuple(mesh_dim_names)
        if len(mesh_shape) != len(mesh_dim_names):
            raise ValueError(
                f'mesh_shape ({len(mesh_shape)}) and mesh_dim_names '
                f'({len(mesh_dim_names)}) should have same length'
            )
        if len(set(mesh_dim_names)) != len(mesh_dim_names):
            raise ValueError(f'Each element of mesh_dim_names {mesh_dim_names} should be different')
        if any(not isinstance(name, str) or name == "" for name in mesh_dim_names):
            raise ValueError(f'Each element of mesh_dim_names {mesh_dim_names} should be a non-empty string')

    total_devices = int(np.prod(np.array(mesh_shape)))
    if rank_list is not None:
        if len(rank_list) != total_devices:
            raise ValueError(
                f"rank_list length ({len(rank_list)}) must equal mesh size ({total_devices})"
            )
    else:
        if init_backend:
            platform.init_process_group()
        try:
            current_rank = platform.get_rank()
        except Exception as exc:
            raise RuntimeError(
                "init_device_mesh: failed to get current rank for automatic rank_list generation. "
                "Either pass rank_list explicitly, or ensure the process group is initialized before calling "
                "init_device_mesh (or set init_backend=True to let init_device_mesh initialize it)."
            ) from exc
        # Use the contiguous block of `total_devices` ranks containing us.
        base = current_rank - (current_rank % total_devices)
        rank_list = tuple(range(base, base + total_devices))

    return _create_device_mesh(
        device_type,
        mesh_shape,
        mesh_dim_names=mesh_dim_names,
        rank_list=rank_list,
        init_backend=init_backend,
    )