Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/dtensor/_mesh_layout.py: 54%
209 statements
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Internal layout helpers for DeviceMesh bookkeeping."""

import math
from dataclasses import dataclass
from typing import Any, Union

import numpy as np


IntTuple = Union[int, tuple["IntTuple", ...]]


def _is_int(value: Any) -> bool:
    return isinstance(value, int) and not isinstance(value, bool)


def _as_tuple(value: IntTuple) -> tuple[IntTuple, ...]:
    return value if isinstance(value, tuple) else (value,)


def _flatten_inttuple(value: IntTuple) -> tuple[int, ...]:
    if _is_int(value):
        return (value,)
    flattened: list[int] = []
    for item in value:
        flattened.extend(_flatten_inttuple(item))
    return tuple(flattened)
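
# Usage sketch (illustrative comment, not part of the original module): nested
# IntTuples flatten depth-first, and a bare int becomes a 1-tuple.
#
#     >>> _flatten_inttuple((2, (3, 4)))
#     (2, 3, 4)
#     >>> _flatten_inttuple(5)
#     (5,)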


def _match_structure(shape: IntTuple, stride: IntTuple) -> bool:
    if _is_int(shape):
        return _is_int(stride)
    if not isinstance(shape, tuple) or not isinstance(stride, tuple) or len(shape) != len(stride):
        return False
    return all(_match_structure(sub_shape, sub_stride) for sub_shape, sub_stride in zip(shape, stride))
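
# Illustrative check (not from the source): shape and stride must mirror each
# other's nesting exactly, int against int and tuple against tuple.
#
#     >>> _match_structure((2, (3, 4)), (12, (4, 1)))
#     True
#     >>> _match_structure((2, 3), (6, (1,)))
#     False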


def _numel(value: IntTuple) -> int:
    return int(np.prod(np.array(_flatten_inttuple(value), dtype=np.int64)))


def _contiguous_strides(mesh_shape: tuple[int, ...]) -> tuple[int, ...]:
    if len(mesh_shape) == 0:
        return ()
    strides = [1] * len(mesh_shape)
    for idx in range(len(mesh_shape) - 2, -1, -1):
        strides[idx] = strides[idx + 1] * mesh_shape[idx + 1]
    return tuple(strides)
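
# Illustrative check (not from the source): strides are row-major, so the last
# dimension is contiguous and each earlier stride is the product of the sizes
# to its right.
#
#     >>> _contiguous_strides((2, 3, 4))
#     (12, 4, 1)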


def _scale_inttuple(value: IntTuple, factor: int) -> IntTuple:
    if _is_int(value):
        return int(value) * factor
    return tuple(_scale_inttuple(item, factor) for item in value)


def _enumerate_offsets(shape: IntTuple, stride: IntTuple) -> list[int]:
    if _is_int(shape):
        return [i * int(stride) for i in range(shape)]

    offsets = [0]
    for sub_shape, sub_stride in zip(_as_tuple(shape), _as_tuple(stride)):
        dim_offsets = _enumerate_offsets(sub_shape, sub_stride)
        offsets = [base + dim_offset for base in offsets for dim_offset in dim_offsets]
    return offsets
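
# Sketch of the enumeration order (illustrative): offsets are Cartesian sums
# of per-dimension offsets, with earlier dimensions varying slowest.
#
#     >>> _enumerate_offsets((2, 2), (1, 4))
#     [0, 4, 1, 5]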


def _canonicalize_axis(shape, stride) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Normalize one logical axis into a flattened shape/stride pair."""
    flat_shape = _flatten_inttuple(shape)
    flat_stride = _flatten_inttuple(stride if stride is not None else _contiguous_strides(flat_shape))
    if len(flat_shape) != len(flat_stride):
        raise ValueError(
            f"shape and stride must have the same length, got {len(flat_shape)} and {len(flat_stride)}"
        )

    normalized_shape: list[int] = []
    normalized_stride: list[int] = []
    for size, step in zip(flat_shape, flat_stride):
        if size < 0:
            raise ValueError(f"shape entries must be non-negative, got {flat_shape}")
        if size == 1:
            continue
        normalized_shape.append(int(size))
        normalized_stride.append(int(step))

    coalesced_shape: list[int] = []
    coalesced_stride: list[int] = []
    for size, step in zip(normalized_shape, normalized_stride):
        if coalesced_shape and coalesced_stride[-1] == step * size:
            coalesced_shape[-1] *= size
            coalesced_stride[-1] = step
        else:
            coalesced_shape.append(size)
            coalesced_stride.append(step)
    return tuple(coalesced_shape), tuple(coalesced_stride)
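
# Worked example (illustrative): size-1 dims are dropped, then adjacent dims
# whose strides line up contiguously are merged into one.
#
#     >>> _canonicalize_axis((1, 2, 3), None)   # contiguous strides (6, 3, 1)
#     ((6,), (1,))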


def _nested_from_flat(value: tuple[int, ...]) -> IntTuple:
    if len(value) == 1:
        return value[0]
    return tuple(value)


@dataclass(frozen=True)
class _FlatLayout:
    """Canonicalized layout for one logical DeviceMesh axis."""

    shape: tuple[int, ...]
    stride: tuple[int, ...]

    def __init__(self, shape, stride=None) -> None:
        flat_shape, flat_stride = _canonicalize_axis(shape, stride)
        object.__setattr__(self, "shape", flat_shape)
        object.__setattr__(self, "stride", flat_stride)

    def numel(self) -> int:
        return math.prod(self.shape) if len(self.shape) > 0 else 1

    def cosize(self) -> int:
        ranks = self.all_ranks_from_zero()
        return max(ranks) + 1 if ranks else 1

    def check_sorted(self) -> bool:
        return tuple(sorted(self.stride, reverse=True)) == self.stride

    def check_orthogonal(self) -> bool:
        if len(self.shape) < 2:
            return True
        stride, shape = zip(*sorted(zip(self.stride, self.shape), reverse=True))
        return all(
            stride[idx] % (stride[idx + 1] * shape[idx + 1]) == 0
            for idx in range(len(stride) - 1)
        )

    def all_ranks_from_zero(self) -> list[int]:
        if len(self.shape) == 0:
            return [0]
        return [
            int(sum(coord[dim] * self.stride[dim] for dim in range(len(self.shape))))
            for coord in np.ndindex(self.shape)
        ]
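
# Quick sanity sketch (illustrative): an axis of shape (2, 2) with stride
# (4, 1) enumerates ranks 0, 1, 4, 5 and spans a codomain of size 6.
#
#     >>> axis = _FlatLayout((2, 2), (4, 1))
#     >>> axis.all_ranks_from_zero()
#     [0, 1, 4, 5]
#     >>> axis.numel(), axis.cosize()
#     (4, 6)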


class _MeshLayout:
    """Minimal layout helper for DeviceMesh slicing, flattening, and concatenation."""

    def __init__(self, shape_or_axes: Union[IntTuple, list[_FlatLayout], tuple[_FlatLayout, ...]], stride=None):
        if stride is None and isinstance(shape_or_axes, (list, tuple)) and all(
            isinstance(axis, _FlatLayout) for axis in shape_or_axes
        ):
            axes = list(shape_or_axes)
            shape = tuple(_nested_from_flat(axis.shape) for axis in axes)
            stride = tuple(_nested_from_flat(axis.stride) for axis in axes)
        else:
            shape = shape_or_axes
        if not _match_structure(shape, stride):
            raise ValueError(f"shape {shape} and stride {stride} do not match")
        self.shape: IntTuple = shape
        self.stride: IntTuple = stride

    @classmethod
    def from_sizes_strides(
        cls,
        sizes: tuple[int, ...],
        strides: tuple[int, ...] | None = None,
    ) -> "_MeshLayout":
        if strides is None:
            strides = _contiguous_strides(sizes)
        return cls(sizes, strides)

    @property
    def sizes(self) -> IntTuple:
        return self.shape

    @property
    def strides(self) -> IntTuple:
        return self.stride

    @property
    def axes(self) -> tuple[_FlatLayout, ...]:
        return tuple(self[idx].collapse() for idx in range(len(self)))

    def __len__(self) -> int:
        return len(self.shape) if isinstance(self.shape, tuple) else 1

    def __iter__(self):
        for idx in range(len(self)):
            yield self[idx]

    def __getitem__(self, idx: int) -> "_MeshLayout":
        if isinstance(self.shape, tuple):
            if idx < -len(self.shape) or idx >= len(self.shape):
                raise IndexError(
                    f"Dim {idx} is out of range for layout with {len(self.shape)} dimensions."
                )
            return _MeshLayout(self.shape[idx], self.stride[idx])
        if idx not in (0, -1):
            raise IndexError("Dim is out of range for 1D layout.")
        return _MeshLayout(self.shape, self.stride)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _MeshLayout):
            return False
        return self.shape == other.shape and self.stride == other.stride

    def __repr__(self) -> str:
        return f"_MeshLayout(shape={self.shape}, stride={self.stride})"

    def numel(self) -> int:
        return _numel(self.shape)

    @property
    def top_level_sizes(self) -> tuple[int, ...]:
        return tuple(self[idx].numel() for idx in range(len(self)))

    def all_ranks_from_zero(self) -> list[int]:
        return _enumerate_offsets(self.shape, self.stride)

    def check_non_overlap(self) -> bool:
        ranks = self.all_ranks_from_zero()
        return len(ranks) == len(set(ranks))
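
    # Illustrative check (not from the source): with strides (1, 1) two
    # different coordinates land on the same rank, so the overlap check
    # fails; contiguous strides pass it.
    #
    #     >>> _MeshLayout((2, 2), (1, 1)).all_ranks_from_zero()
    #     [0, 1, 1, 2]
    #     >>> _MeshLayout((2, 2), (1, 1)).check_non_overlap()
    #     False
    #     >>> _MeshLayout.from_sizes_strides((2, 2)).check_non_overlap()
    #     True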

    def coalesce(self) -> "_MeshLayout":
        """Merge adjacent contiguous axes while preserving the represented layout."""
        if _is_int(self.shape):
            return self

        coalesced_shapes: list[IntTuple] = []
        coalesced_strides: list[IntTuple] = []
        for shape, stride in zip(self.shape, self.stride):
            child = _MeshLayout(shape, stride).coalesce()
            coalesced_shapes.append(child.shape)
            coalesced_strides.append(child.stride)

        merged_shapes: list[IntTuple] = []
        merged_strides: list[IntTuple] = []
        for shape, stride in zip(coalesced_shapes, coalesced_strides):
            if (
                merged_shapes
                and _is_int(merged_shapes[-1])
                and _is_int(merged_strides[-1])
                and _is_int(shape)
                and _is_int(stride)
                and merged_strides[-1] == stride * shape
            ):
                merged_shapes[-1] *= shape
                merged_strides[-1] = stride
            else:
                merged_shapes.append(shape)
                merged_strides.append(stride)

        if len(merged_shapes) == 1:
            return _MeshLayout(merged_shapes[0], merged_strides[0])
        return _MeshLayout(tuple(merged_shapes), tuple(merged_strides))
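
    # Worked example (illustrative): dims (2, 4) with strides (4, 1) cover
    # ranks 0..7 contiguously, so they merge into one dim of size 8.
    #
    #     >>> _MeshLayout((2, 4), (4, 1)).coalesce()
    #     _MeshLayout(shape=8, stride=1)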

    def composition(self, layout: "_MeshLayout") -> "_MeshLayout":
        if not _is_int(self.stride):
            raise NotImplementedError(
                "Currently, _unflatten only supports unflattening a mesh dim with scalar stride."
            )
        return _MeshLayout(layout.shape, _scale_inttuple(layout.stride, int(self.stride)))
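
    # Sketch (illustrative): composition scales every stride of the inner
    # layout by this dim's scalar stride, so a dim of stride 2 unflattened
    # into shape (2, 4) keeps stepping through the same ranks.
    #
    #     >>> _MeshLayout(8, 2).composition(_MeshLayout((2, 4), (4, 1)))
    #     _MeshLayout(shape=(2, 4), stride=(8, 2))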

    def nest(self) -> "_MeshLayout":
        if len(self) == 1:
            return self
        return _MeshLayout((self.shape,), (self.stride,))

    def splice(self, start: int, end: int, layout: "_MeshLayout") -> "_MeshLayout":
        sizes = list(_as_tuple(self.shape))
        strides = list(_as_tuple(self.stride))
        sizes[start:end] = list(_as_tuple(layout.shape))
        strides[start:end] = list(_as_tuple(layout.stride))
        if len(sizes) == 1:
            return _MeshLayout(sizes[0], strides[0])
        return _MeshLayout(tuple(sizes), tuple(strides))
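
    # Illustrative use (not from the source): splicing dims 1..3 of a
    # (2, 3, 4) layout with a single flattened dim of size 12.
    #
    #     >>> base = _MeshLayout((2, 3, 4), (12, 4, 1))
    #     >>> base.splice(1, 3, _MeshLayout((12,), (1,)))
    #     _MeshLayout(shape=(2, 12), stride=(12, 1))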

    def collapse(self) -> _FlatLayout:
        return _FlatLayout(self.shape, self.stride)

    def remap_to_numpy(self, rank_map) -> np.ndarray:
        """Materialize this layout as a dense numpy mesh over the provided rank map."""
        rank_map_np = np.asarray(rank_map).reshape(-1)
        base_offsets = self.all_ranks_from_zero()
        if len(base_offsets) == 0:
            raise ValueError("Cannot remap an empty layout.")

        groups: list[list[int]] = []
        used: set[int] = set()
        world_size = rank_map_np.shape[0]

        for anchor in range(world_size):
            if anchor in used:
                continue
            group = [anchor + offset for offset in base_offsets]
            if any(index >= world_size for index in group):
                continue
            if any(index in used for index in group):
                continue
            groups.append(group)
            used.update(group)

        if len(used) != world_size:
            raise ValueError(
                f"Layout {self} does not form a full partition over rank_map with world size {world_size}."
            )

        remapped = rank_map_np[np.array(groups, dtype=np.int64)]
        remapped = remapped.reshape((len(groups),) + self.top_level_sizes)
        return remapped
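
# End-to-end sketch (illustrative): a 1D layout of shape 2 and stride 2 over
# four ranks partitions [0, 1, 2, 3] into strided groups, one row per group.
#
#     >>> layout = _MeshLayout.from_sizes_strides((2,), (2,))
#     >>> layout.remap_to_numpy([0, 1, 2, 3])
#     array([[0, 2],
#            [1, 3]])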