Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/dtensor/_mesh_layout.py: 54%

209 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Internal layout helpers for DeviceMesh bookkeeping.""" 

16 

17import math 

18from dataclasses import dataclass 

19from typing import Any, Union 

20 

21import numpy as np 

22 

23 

# Recursive alias: an axis specification is either a bare int or an
# arbitrarily nested tuple of further axis specifications.
IntTuple = Union[int, tuple["IntTuple", ...]]

25 

26 

27def _is_int(value: Any) -> bool: 

28 return isinstance(value, int) and not isinstance(value, bool) 

29 

30 

31def _as_tuple(value: IntTuple) -> tuple[IntTuple, ...]: 

32 return value if isinstance(value, tuple) else (value,) 

33 

34 

35def _flatten_inttuple(value: IntTuple) -> tuple[int, ...]: 

36 if _is_int(value): 

37 return (value,) 

38 flattened: list[int] = [] 

39 for item in value: 

40 flattened.extend(_flatten_inttuple(item)) 

41 return tuple(flattened) 

42 

43 

44def _match_structure(shape: IntTuple, stride: IntTuple) -> bool: 

45 if _is_int(shape): 

46 return _is_int(stride) 

47 if not isinstance(shape, tuple) or not isinstance(stride, tuple) or len(shape) != len(stride): 

48 return False 

49 return all(_match_structure(sub_shape, sub_stride) for sub_shape, sub_stride in zip(shape, stride)) 

50 

51 

52def _numel(value: IntTuple) -> int: 

53 return int(np.prod(np.array(_flatten_inttuple(value), dtype=np.int64))) 

54 

55 

56def _contiguous_strides(mesh_shape: tuple[int, ...]) -> tuple[int, ...]: 

57 if len(mesh_shape) == 0: 

58 return () 

59 strides = [1] * len(mesh_shape) 

60 for idx in range(len(mesh_shape) - 2, -1, -1): 

61 strides[idx] = strides[idx + 1] * mesh_shape[idx + 1] 

62 return tuple(strides) 

63 

64 

65def _scale_inttuple(value: IntTuple, factor: int) -> IntTuple: 

66 if _is_int(value): 

67 return int(value) * factor 

68 return tuple(_scale_inttuple(item, factor) for item in value) 

69 

70 

71def _enumerate_offsets(shape: IntTuple, stride: IntTuple) -> list[int]: 

72 if _is_int(shape): 

73 return [i * int(stride) for i in range(shape)] 

74 

75 offsets = [0] 

76 for sub_shape, sub_stride in zip(_as_tuple(shape), _as_tuple(stride)): 

77 dim_offsets = _enumerate_offsets(sub_shape, sub_stride) 

78 offsets = [base + dim_offset for base in offsets for dim_offset in dim_offsets] 

79 return offsets 

80 

81 

def _canonicalize_axis(shape, stride) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Normalize one logical axis into a flattened shape/stride pair.

    Flattens nested specs, fills in contiguous strides when ``stride`` is None,
    drops size-1 dimensions, then merges adjacent dims that address contiguous
    offsets (previous stride == current stride * current size).

    Raises:
        ValueError: on mismatched lengths or a negative shape entry.
    """
    flat_shape = _flatten_inttuple(shape)
    if stride is None:
        # Contiguous strides are already flat, so no further flattening needed.
        flat_stride = _contiguous_strides(flat_shape)
    else:
        flat_stride = _flatten_inttuple(stride)
    if len(flat_shape) != len(flat_stride):
        raise ValueError(
            f"shape and stride must have the same length, got {len(flat_shape)} and {len(flat_stride)}"
        )

    # Validate entries and discard degenerate size-1 dims.
    kept: list[tuple[int, int]] = []
    for extent, step in zip(flat_shape, flat_stride):
        if extent < 0:
            raise ValueError(f"shape entries must be non-negative, got {flat_shape}")
        if extent == 1:
            continue
        kept.append((int(extent), int(step)))

    # Merge neighbouring dims whose offsets are contiguous.
    merged: list[list[int]] = []
    for extent, step in kept:
        if merged and merged[-1][1] == step * extent:
            merged[-1][0] *= extent
            merged[-1][1] = step
        else:
            merged.append([extent, step])
    return tuple(pair[0] for pair in merged), tuple(pair[1] for pair in merged)

111 

112 

113def _nested_from_flat(value: tuple[int, ...]) -> IntTuple: 

114 if len(value) == 1: 

115 return value[0] 

116 return tuple(value) 

117 

118 

@dataclass(frozen=True)
class _FlatLayout:
    """Canonicalized layout for one logical DeviceMesh axis.

    Instances are immutable; `shape` and `stride` are flat, equal-length int
    tuples produced by `_canonicalize_axis`.
    """

    shape: tuple[int, ...]
    stride: tuple[int, ...]

    def __init__(self, shape, stride=None) -> None:
        # The user-defined __init__ suppresses the dataclass-generated one;
        # object.__setattr__ is required because the dataclass is frozen.
        canonical_shape, canonical_stride = _canonicalize_axis(shape, stride)
        object.__setattr__(self, "shape", canonical_shape)
        object.__setattr__(self, "stride", canonical_stride)

    def numel(self) -> int:
        """Number of ranks covered by this axis; an empty axis counts as 1."""
        if not self.shape:
            return 1
        return math.prod(self.shape)

    def cosize(self) -> int:
        """One past the largest offset produced by this axis (1 when empty)."""
        offsets = self.all_ranks_from_zero()
        if not offsets:
            return 1
        return max(offsets) + 1

    def check_sorted(self) -> bool:
        """True when strides are already in descending order."""
        descending = tuple(sorted(self.stride, reverse=True))
        return descending == self.stride

    def check_orthogonal(self) -> bool:
        """True when, sorted by stride, each stride divides evenly into the next larger one's span."""
        if len(self.shape) < 2:
            return True
        ordered = sorted(zip(self.stride, self.shape), reverse=True)
        for (outer_stride, _), (inner_stride, inner_size) in zip(ordered, ordered[1:]):
            if outer_stride % (inner_stride * inner_size) != 0:
                return False
        return True

    def all_ranks_from_zero(self) -> list[int]:
        """Enumerate every rank offset addressed by this axis, starting at 0."""
        if not self.shape:
            return [0]
        offsets: list[int] = []
        for coord in np.ndindex(self.shape):
            offsets.append(int(sum(c * s for c, s in zip(coord, self.stride))))
        return offsets

157 

158 

class _MeshLayout:
    """Minimal layout helper for DeviceMesh slicing, flattening, and concatenation.

    A layout pairs a (possibly nested) ``shape`` with a structurally matching
    ``stride``; together they map mesh coordinates to flat rank offsets.
    """

    def __init__(self, shape_or_axes: Union[IntTuple, list[_FlatLayout], tuple[_FlatLayout, ...]], stride=None):
        """Build from (shape, stride) IntTuples, or from a sequence of _FlatLayout axes.

        Args:
            shape_or_axes: an IntTuple shape (paired with ``stride``), or a
                list/tuple of `_FlatLayout` axes (one per logical dimension),
                in which case ``stride`` must be None.
            stride: IntTuple with the same nesting structure as the shape.

        Raises:
            ValueError: if shape and stride do not share the same structure.
        """
        # Axis-sequence form: rebuild nested shape/stride from canonical axes.
        if stride is None and isinstance(shape_or_axes, (list, tuple)) and all(
            isinstance(axis, _FlatLayout) for axis in shape_or_axes
        ):
            axes = list(shape_or_axes)
            shape = tuple(_nested_from_flat(axis.shape) for axis in axes)
            stride = tuple(_nested_from_flat(axis.stride) for axis in axes)
        else:
            shape = shape_or_axes
            if not _match_structure(shape, stride):
                raise ValueError(f"shape {shape} and stride {stride} do not match")
        self.shape: IntTuple = shape
        self.stride: IntTuple = stride

    @classmethod
    def from_sizes_strides(
        cls,
        sizes: tuple[int, ...],
        strides: tuple[int, ...] | None = None,
    ) -> "_MeshLayout":
        """Alternate constructor; derives contiguous (row-major) strides when omitted."""
        if strides is None:
            strides = _contiguous_strides(sizes)
        return cls(sizes, strides)

    @property
    def sizes(self) -> IntTuple:
        """Alias of ``shape``."""
        return self.shape

    @property
    def strides(self) -> IntTuple:
        """Alias of ``stride``."""
        return self.stride

    @property
    def axes(self) -> tuple[_FlatLayout, ...]:
        """Each top-level dimension collapsed into a canonical `_FlatLayout`."""
        return tuple(self[idx].collapse() for idx in range(len(self)))

    def __len__(self) -> int:
        # A scalar (non-tuple) layout counts as a single dimension.
        return len(self.shape) if isinstance(self.shape, tuple) else 1

    def __iter__(self):
        """Yield each top-level dimension as its own `_MeshLayout`."""
        for idx in range(len(self)):
            yield self[idx]

    def __getitem__(self, idx: int) -> "_MeshLayout":
        """Select one top-level dimension (negative indices follow Python rules)."""
        if isinstance(self.shape, tuple):
            if idx < -len(self.shape) or idx >= len(self.shape):
                raise IndexError(
                    f"Dim {idx} is out of range for layout with {len(self.shape)} dimensions."
                )
            return _MeshLayout(self.shape[idx], self.stride[idx])
        # Scalar layout: only indices 0 and -1 address its single dimension.
        if idx not in (0, -1):
            raise IndexError("Dim is out of range for 1D layout.")
        return _MeshLayout(self.shape, self.stride)

    def __eq__(self, other: object) -> bool:
        # NOTE(review): defining __eq__ without __hash__ makes instances
        # unhashable — confirm no caller stores layouts in sets/dict keys.
        if not isinstance(other, _MeshLayout):
            return False
        return self.shape == other.shape and self.stride == other.stride

    def __repr__(self) -> str:
        return f"_MeshLayout(shape={self.shape}, stride={self.stride})"

    def numel(self) -> int:
        """Total number of ranks addressed by this layout."""
        return _numel(self.shape)

    @property
    def top_level_sizes(self) -> tuple[int, ...]:
        """Size of each top-level dimension, with nested sub-dims multiplied out."""
        return tuple(self[idx].numel() for idx in range(len(self)))

    def all_ranks_from_zero(self) -> list[int]:
        """Flat rank offsets produced by enumerating all coordinates, based at 0."""
        return _enumerate_offsets(self.shape, self.stride)

    def check_non_overlap(self) -> bool:
        """True when no two coordinates map to the same rank offset."""
        ranks = self.all_ranks_from_zero()
        return len(ranks) == len(set(ranks))

    def coalesce(self) -> "_MeshLayout":
        """Merge adjacent contiguous axes while preserving the represented layout."""
        if _is_int(self.shape):
            return self

        # Recursively coalesce each child dimension first.
        coalesced_shapes: list[IntTuple] = []
        coalesced_strides: list[IntTuple] = []
        for shape, stride in zip(self.shape, self.stride):
            child = _MeshLayout(shape, stride).coalesce()
            coalesced_shapes.append(child.shape)
            coalesced_strides.append(child.stride)

        # Then merge neighbouring scalar dims whose offsets are contiguous:
        # previous stride == current stride * current size.
        merged_shapes: list[IntTuple] = []
        merged_strides: list[IntTuple] = []
        for shape, stride in zip(coalesced_shapes, coalesced_strides):
            if (
                merged_shapes
                and _is_int(merged_shapes[-1])
                and _is_int(merged_strides[-1])
                and _is_int(shape)
                and _is_int(stride)
                and merged_strides[-1] == stride * shape
            ):
                merged_shapes[-1] *= shape
                merged_strides[-1] = stride
            else:
                merged_shapes.append(shape)
                merged_strides.append(stride)

        # A single surviving dim collapses back to scalar form.
        if len(merged_shapes) == 1:
            return _MeshLayout(merged_shapes[0], merged_strides[0])
        return _MeshLayout(tuple(merged_shapes), tuple(merged_strides))

    def composition(self, layout: "_MeshLayout") -> "_MeshLayout":
        """Compose with ``layout`` by scaling its strides by this layout's scalar stride.

        Raises:
            NotImplementedError: when this layout's stride is not a scalar.
        """
        if not _is_int(self.stride):
            raise NotImplementedError(
                "Currently, _unflatten only supports unflattening a mesh dim with scalar stride."
            )
        return _MeshLayout(layout.shape, _scale_inttuple(layout.stride, int(self.stride)))

    def nest(self) -> "_MeshLayout":
        """Wrap a multi-dim layout into one nested dimension; 1D layouts pass through."""
        if len(self) == 1:
            return self
        return _MeshLayout((self.shape,), (self.stride,))

    def splice(self, start: int, end: int, layout: "_MeshLayout") -> "_MeshLayout":
        """Replace top-level dims [start, end) with the dims of ``layout``."""
        sizes = list(_as_tuple(self.shape))
        strides = list(_as_tuple(self.stride))
        sizes[start:end] = list(_as_tuple(layout.shape))
        strides[start:end] = list(_as_tuple(layout.stride))
        # A single remaining dim collapses back to scalar form.
        if len(sizes) == 1:
            return _MeshLayout(sizes[0], strides[0])
        return _MeshLayout(tuple(sizes), tuple(strides))

    def collapse(self) -> _FlatLayout:
        """Canonicalize the whole layout into a single flat axis."""
        return _FlatLayout(self.shape, self.stride)

    def remap_to_numpy(self, rank_map) -> np.ndarray:
        """Materialize this layout as a dense numpy mesh over the provided rank map.

        Greedily tiles ``rank_map`` with translated copies of this layout's
        offset pattern; every rank must be claimed by exactly one copy.

        Args:
            rank_map: array-like of ranks; flattened to 1-D before use.

        Returns:
            Array of shape ``(num_groups,) + self.top_level_sizes`` whose
            entries are the remapped ranks.

        Raises:
            ValueError: if the layout is empty or the copies do not partition
                the whole rank map.
        """
        rank_map_np = np.asarray(rank_map).reshape(-1)
        base_offsets = self.all_ranks_from_zero()
        if len(base_offsets) == 0:
            raise ValueError("Cannot remap an empty layout.")

        groups: list[list[int]] = []
        used: set[int] = set()
        world_size = rank_map_np.shape[0]

        # Try each anchor in order; keep a translated copy of the base offsets
        # only when it stays in bounds and is disjoint from all prior copies.
        for anchor in range(world_size):
            if anchor in used:
                continue
            group = [anchor + offset for offset in base_offsets]
            if any(index >= world_size for index in group):
                continue
            if any(index in used for index in group):
                continue
            groups.append(group)
            used.update(group)

        # The accepted copies must cover every rank exactly once.
        if len(used) != world_size:
            raise ValueError(
                f"Layout {self} does not form a full partition over rank_map with world size {world_size}."
            )

        remapped = rank_map_np[np.array(groups, dtype=np.int64)]
        remapped = remapped.reshape((len(groups),) + self.top_level_sizes)
        return remapped