Coverage for hyper_parallel / core / layout.py: 91%
336 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""layout"""
17import copy
18import functools
19import numpy as np
21from hyper_parallel.core.placement_types import Placement, Shard, Replicate, Partial
22from hyper_parallel.core.device_mesh import DeviceMesh, _create_device_mesh
23from hyper_parallel.platform import get_platform
# Module-level platform handle; Layout.repeat_num uses it to query the world size.
platform = get_platform()
28def _infer_slice_area_by_rank(mesh_shape, tensor_map, rank_id: int, full_shape: tuple): # -> tuple[tuple[int]]:
29 """Return the range of each axis from full tensor for slice in current rank."""
31 def _get_dev_num_alone_dim(mesh_shape, dim):
32 """_get_dev_num_alone_dim."""
33 return mesh_shape[-dim - 1] if dim != -1 else 1
35 def _rank_id_to_dev_id_list(mesh_shape, rank_id):
36 """Infer dev id list by rank_id and mesh_shape"""
37 dims = len(mesh_shape)
38 dev_id_list = [0] * dims
39 for i in range(dims - 1, -1, -1):
40 dev_id_list[i] = rank_id % mesh_shape[i]
41 rank_id = rank_id // mesh_shape[i]
42 return dev_id_list
44 dev_id_list = _rank_id_to_dev_id_list(mesh_shape, rank_id)
46 dims = len(full_shape)
47 area = []
48 for axis in range(dims):
49 mapping = tensor_map[axis]
50 if isinstance(mapping, int):
51 mapping = (mapping,)
52 split_num = 1
53 for dim in mapping:
54 split_num *= _get_dev_num_alone_dim(mesh_shape, dim)
56 slice_id = 0
57 coef = 1
58 for dim in reversed(mapping):
59 if dim == -1:
60 continue
61 slice_id += dev_id_list[-dim - 1] * coef
62 coef *= _get_dev_num_alone_dim(mesh_shape, dim)
63 slice_size = full_shape[axis] // split_num
64 start = slice_id * slice_size
65 end = start + slice_size
66 area.append((start, end))
67 return area
def _get_slice_tensor_by_layout(global_tensor, layout):
    """Extract the local shard of ``global_tensor`` owned by the current rank per ``layout``."""
    # Position of the current device inside the layout's rank list.
    local_rank = layout.rank_list.index(layout.mesh.rank)
    slice_area = _infer_slice_area_by_rank(
        layout.mesh_shape, layout.tensor_map, local_rank, global_tensor.shape)
    # One slice object per axis, applied in a single indexing step.
    index = tuple(slice(begin, end) for begin, end in slice_area)
    return global_tensor[index]
85def _infer_slice_shape_by_layout(global_shape, layout):
86 """Infer slice shape from global_shape and layout"""
87 slice_shape = list(global_shape)
88 alias_tensor_map = layout.alias_tensor_map
89 for i in range(len(global_shape)):
90 axis_name = alias_tensor_map[i]
91 if isinstance(axis_name, str):
92 axis_name = (axis_name,)
93 for sub_axis_name in axis_name:
94 if sub_axis_name != "None":
95 slice_shape[i] = slice_shape[i] // layout.mesh.get_device_num_along_axis(sub_axis_name)
96 return slice_shape
class Layout:
    """
    Topological abstraction describing cluster devices for tensor slice placement on the cluster.

    Note:
        - It is valid only in semi auto parallel or auto parallel mode.
        - The multiplication result of the `mesh_shape` must be equal to the device count in a pipeline stage.
        - When the layout function is invoked to constructs a sharding strategy, each alias name is only allowed to be
          used once to shard a tensor.

    Args:
        mesh_shape (tuple): Describe the shape of devices arrangement, its element type is int.
        alias_name (tuple): The alias name for each axis of mesh_shape, its length should equal the length of
            `mesh_shape` and its element type is string.
            When using "interleaved_parallel" as an alias name, the tensor would be split into multiple
            copies on the corresponding partition dimension on a single card.
        rank_list (tuple, optional): Data is allocated to the device according to rank_list. Default: ``None``.

    Raises:
        TypeError: `mesh_shape` is not a tuple type.
        TypeError: `alias_name` is not a tuple type.
        TypeError: 'rank_list' is not a list type.
        ValueError: `mesh_shape` length is not equal to `alias_name` length.
        TypeError: The element of `mesh_shape` is not int type.
        TypeError: The element of `alias_name` is not a str type.
        TypeError: The element of `rank_list` is not int type.
        ValueError: The element of `alias_name` is an empty str.
        ValueError: The element of `alias_name` is "None".
        ValueError: `alias_name` contains repeated element.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore.parallel import Layout
        >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
        >>> layout0 = layout("dp", "mp")
        >>> print(layout0.to_dict())
        {"mesh_shape": (2, 2, 2), "tensor_map": (2, 0), "interleaved_parallel": False,
        'alias_name': {'dp', 'sp', 'mp'}, "rank_list": [0, 1, 2, 3, 4, 5, 6, 7]}
        >>> layout = Layout((2, 2, 2), ("dp", "sp", "interleaved_parallel"))
        >>> layout1 = layout(("dp", "interleaved_parallel"), "sp")
    """

    def __init__(self, mesh_shape, alias_name, rank_list=None, init_backend=True):
        self._alias_name = alias_name
        self._tensor_map = None
        # Default rank list covers every device of the mesh in row-major order.
        if not rank_list:
            self._rank_list = tuple(range(np.prod(np.array(mesh_shape))))
        else:
            self._rank_list = tuple(rank_list)
        self._partial = [None] * len(mesh_shape)  # partial status for each dev dim
        self._support_partial_op = ['sum', 'max', 'min', 'avg', 'prod', 'all', None]
        self._alias_tensor_map = None
        self._mesh = _create_device_mesh("npu", mesh_shape, mesh_dim_names=alias_name, rank_list=self._rank_list,
                                         init_backend=init_backend)
        self._compact_str = self._to_compact_string()
        self._placements = None

    @classmethod
    def from_device_mesh(cls, device_mesh: DeviceMesh) -> 'Layout':
        """
        Create a Layout from an existing DeviceMesh.

        Bypasses ``__init__`` (and therefore mesh creation/backend init) and
        adopts the given mesh directly.

        Args:
            device_mesh (DeviceMesh): The device mesh to create layout from.

        Returns:
            Layout: A new Layout instance initialized with the properties of the provided device mesh.

        Examples:
            >>> from hyper_parallel.core.layout import Layout, DeviceMesh
            >>> device_mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "mp"))
            >>> layout = Layout.from_device_mesh(device_mesh)
        """
        obj = cls.__new__(cls)
        obj._mesh = device_mesh
        obj._alias_name = device_mesh.mesh_dim_names
        obj._rank_list = device_mesh.rank_list
        obj._tensor_map = None
        obj._partial = [None] * len(device_mesh.mesh_shape)
        obj._support_partial_op = ['sum', 'max', 'min', 'avg', 'prod', 'all', None]
        obj._alias_tensor_map = None
        obj._placements = None
        obj._compact_str = obj._to_compact_string()
        return obj

    def __call__(self, *alias_tensor_map):
        """Build a new, sharded Layout from alias names or Placement objects.

        Accepts either alias strings (e.g. ``layout("dp", "None")``) or
        Placement instances (Shard/Replicate/Partial), optionally packed in a
        single list/tuple. Returns a deep copy; ``self`` is never mutated.
        """
        obj = copy.deepcopy(self)
        # Case 1: a single list/tuple of Placement objects.
        if len(alias_tensor_map) == 1 and isinstance(alias_tensor_map[0], (list, tuple)):
            if len(alias_tensor_map[0]) > 0 and isinstance(alias_tensor_map[0][0], Placement):
                return self._process_placement_layout(obj, alias_tensor_map[0])
        # Case 2: Placement objects passed as varargs.
        if len(alias_tensor_map) > 0 and isinstance(alias_tensor_map[0], Placement):
            return self._process_placement_layout(obj, alias_tensor_map)
        # Case 3: alias-name strings / tuples of strings.
        return self._process_alias_layout(obj, alias_tensor_map)

    def _process_placement_layout(self, obj, placements):
        """Process layout defined by Placement types."""
        obj.set_placements(placements)
        return copy.deepcopy(obj)

    def _process_alias_layout(self, obj, alias_tensor_map):
        """Process layout defined by alias strings.

        Translates each alias name into a tensor-map value
        ``len(alias_name) - 1 - index`` (i.e. mesh dims counted from the last
        axis; "None" -> -1), rejecting unknown or repeated aliases.
        """
        obj.set_alias_tensor_map(alias_tensor_map)
        tensor_map = ()
        writed_map = ()  # aliases already consumed; each may shard only once
        for ele in alias_tensor_map:
            if isinstance(ele, tuple):
                # Multi-axis sharding of a single tensor dimension.
                ele_map = ()
                for item in ele:
                    if item == "None":
                        ele_map += (-1,)
                        continue
                    if item not in obj.alias_name:
                        raise ValueError(f'The axis {item} is not found in {obj.alias_name}')
                    if item in writed_map:
                        raise ValueError(f'The axis {item} has been set more than one in {obj.alias_name}')
                    ele_map += (len(obj.alias_name) - 1 - obj.alias_name.index(item),)
                    writed_map += (item,)
                tensor_map += (ele_map,)
                continue
            if ele == "None":
                tensor_map += (-1,)
                continue
            if ele not in obj.alias_name:
                raise ValueError(f'The axis {ele} is not found in {obj.alias_name}')
            if ele in writed_map:
                raise ValueError(f'The axis {ele} has been set more than one in {obj.alias_name}')
            tensor_map += (len(obj.alias_name) - 1 - obj.alias_name.index(ele),)
            writed_map += (ele,)
        obj.set_tensor_map(tensor_map)
        obj.tensor_map_to_placement()
        obj.update_compact_str()
        return copy.deepcopy(obj)

    def to_dict(self):
        """
        Transform layout to a dictionary.

        Raises:
            ValueError: If the mesh shape or the tensor map is not configured.
        """
        if self._mesh.mesh_shape is None:
            raise ValueError("The device_shape of layout is None")
        if self._tensor_map is None:
            raise ValueError("The tensor_map of layout is None")
        interleaved_parallel = "interleaved_parallel" in self._mesh.mesh_dim_names
        return {"mesh_shape": self._mesh.mesh_shape, "tensor_map": self._tensor_map,
                "interleaved_parallel": interleaved_parallel, "alias_name": self._mesh.mesh_dim_names,
                "rank_list": self._rank_list}

    def placement_to_tensor_map(self, dim):
        """
        Transform placement to tensor map.

        This method converts the `placements` configuration (consisting of Shard, Replicate, Partial)
        into a `tensor_map` representation used for distributed tensor operations.

        Args:
            dim (int): The dimension of the tensor. Must be a non-negative integer.

        Returns:
            tuple: A tuple representing the tensor map, where each element corresponds to a tensor dimension.
                A value of -1 indicates the dimension is not sharded, while other values indicate
                the mesh dimension index along which the tensor dimension is sharded.

        Raises:
            ValueError: If `dim` is negative.
            ValueError: If a shard dimension in `placements` is out of bounds for the given tensor dimension.
            ValueError: If a tensor dimension is sharded by multiple mesh axes.
        """
        if dim < 0:
            raise ValueError(f"Tensor dimension must be positive, but got {dim}")
        if dim == 0:
            # Scalar tensors carry no tensor map but may still be Partial.
            return self._handle_zero_dim_placement()
        dim_map = self._build_dim_map_from_placements(dim)
        tensor_map = self._convert_dim_map_to_tensor_map(dim_map)
        self.set_tensor_map(tuple(tensor_map))
        self._alias_tensor_map = self._build_readable_tensor_map()
        self.update_compact_str()
        return tensor_map

    def _handle_zero_dim_placement(self):
        """Handle the special case of zero-dimensional tensor."""
        self.set_tensor_map(())
        self._alias_tensor_map = ()
        # Only Partial placements are meaningful for a 0-d tensor.
        for mesh_idx, placement in enumerate(self.placements):
            if isinstance(placement, Partial):
                self._partial[mesh_idx] = self._extract_reduce_op(placement)
        return []

    def _build_dim_map_from_placements(self, dim):
        """Build dimension map (tensor dim -> mesh axis index, -1 if unsharded) from placements."""
        dim_map = [-1] * dim
        # NOTE(review): partial_ops is assigned here but never read in this
        # class — confirm whether external code relies on it.
        self.partial_ops = {}
        for mesh_idx, placement in enumerate(self.placements):
            if isinstance(placement, Shard):
                shard_dim = placement.dim
                if shard_dim < -dim or shard_dim >= dim:
                    raise ValueError(f"Shard dimension {shard_dim} is out of bounds for tensor of dimension {dim}")
                if shard_dim < 0:
                    # Normalize negative tensor dims.
                    shard_dim += dim
                if dim_map[shard_dim] != -1:
                    raise ValueError(f"Dimension {shard_dim} has been sharded by Mesh axis {dim_map[shard_dim]}")
                dim_map[shard_dim] = mesh_idx
            elif isinstance(placement, Partial):
                self._partial[mesh_idx] = self._extract_reduce_op(placement)
        return dim_map

    def _extract_reduce_op(self, placement):
        """Extract reduce operation name from Partial placement (lowercased; defaults to 'sum')."""
        op_name = getattr(placement, "reduce_op", "sum")
        if isinstance(op_name, str):
            op_name = op_name.lower()
        return op_name

    def _convert_dim_map_to_tensor_map(self, dim_map):
        """Convert dimension map (mesh-axis indices) to tensor map (indices from the last mesh axis)."""
        device_dim_count = len(self.mesh_shape)
        return [
            device_dim_count - 1 - mesh_idx if mesh_idx != -1 else -1
            for mesh_idx in dim_map
        ]

    def _build_readable_tensor_map(self):
        """Build human-readable alias tensor map from tensor_map."""
        readable_map = []
        for item in self._tensor_map:
            if self._mesh.mesh_dim_names is None:
                readable_map.append("None")
            elif isinstance(item, tuple):
                # Multi-axis sharding: translate each mesh dim back to its alias.
                mapped_tuple = tuple(
                    self._mesh.mesh_dim_names[len(self._mesh.mesh_dim_names) - 1 - dim] if dim != -1 else "None"
                    for dim in item
                )
                readable_map.append(mapped_tuple)
            else:
                readable_map.append(
                    self._mesh.mesh_dim_names[len(self._mesh.mesh_dim_names) - 1 - item] if item != -1 else "None"
                )
        return tuple(readable_map)

    def tensor_map_to_placement(self):
        """
        Transform tensor map to placement.

        This method converts the existing `tensor_map` and `partial` status into a list of `Placement` objects
        (Shard, Replicate, Partial). This is the inverse operation of `placement_to_tensor_map`.

        Returns:
            list[Placement]: A list of Placement objects describing the distribution strategy for each
                dimension of the device mesh.

        Raises:
            ValueError: If `tensor_map` is not configured (None).
        """
        if self._tensor_map is None:
            raise ValueError("The tensor_map is None, cannot transform to placements.")
        mesh_ndim = len(self.mesh_shape)
        placements = [Replicate()] * mesh_ndim
        for tensor_dim, mapping in enumerate(self._tensor_map):
            mapping_list = mapping if isinstance(mapping, tuple) else (mapping,)
            for map_val in mapping_list:
                if map_val != -1:
                    # tensor_map counts mesh dims from the end.
                    root_mesh_idx = mesh_ndim - 1 - map_val
                    placements[root_mesh_idx] = Shard(dim=tensor_dim)
        # Partial status overrides any placement on that mesh axis.
        for mesh_idx, op in enumerate(self.partial):
            if op is not None:
                placements[mesh_idx] = Partial(reduce_op=op)
        self.set_placements(placements)
        return placements

    def __setstate__(self, state):
        # After unpickling, the device mesh is rebuilt rather than restored.
        self.__dict__.update(state)
        self.update_mesh()

    @property
    def mesh(self):
        """Underlying DeviceMesh."""
        return self._mesh

    def update_mesh(self):
        """Recreate the device mesh from the current shape/alias/rank state."""
        self._mesh = _create_device_mesh("npu", self.mesh_shape, mesh_dim_names=self.alias_name,
                                         rank_list=self.rank_list)

    @property
    def rank_list(self):
        """rank list"""
        return self._rank_list

    @rank_list.setter
    def rank_list(self, val):
        self._rank_list = val

    @property
    def mesh_shape(self):
        """mesh shape"""
        return self._mesh.mesh_shape

    @property
    def alias_name(self):
        """alias name"""
        return self._mesh.mesh_dim_names

    @property
    def alias_tensor_map(self):
        """Human-readable tensor map (alias names instead of mesh-dim indices)."""
        return self._alias_tensor_map

    def set_alias_tensor_map(self, alias_tensor_map):
        """Set alias_tensor_map"""
        self._alias_tensor_map = alias_tensor_map

    @property
    def placements(self):
        """placements"""
        return self._placements

    def set_placements(self, placements):
        """Set placements."""
        self._placements = placements

    @property
    def tensor_map(self):
        """tensor map"""
        return self._tensor_map

    def set_tensor_map(self, tensor_map):
        """Set tensor_map."""
        self._tensor_map = tensor_map

    @property
    def partial(self):
        """partial status"""
        return self._partial

    def set_partial_by_dev_axis(self, axis, op):
        """Set the partial status for the specified dev ID, means pending to do reduce by op.

        Raises:
            ValueError: If `op` is unsupported, or the axis is already sharded.
        """
        if op not in self._support_partial_op:
            raise ValueError(f"Partial op must be one of {self._support_partial_op}, but got {op}")
        if self.is_dev_axis_apply_shard(axis):
            raise ValueError("Partial dim must be replicate.")
        self._partial[self._mesh.axis_index(axis)] = op
        # Keep placements and the compact key in sync with the new partial state.
        self.tensor_map_to_placement()
        self.update_compact_str()

    def get_partial_by_dev_id(self, axis):
        """Get the partial status for the specified dev id"""
        return self.partial[self._mesh.axis_index(axis)]

    def is_dev_axis_apply_shard(self, axis):
        """Return true if device axis is applying shard"""
        # NOTE(review): this uses mesh.axis_id while sibling methods use
        # mesh.axis_index — confirm the two are intentionally different.
        axis_id = self._mesh.axis_id(axis)

        def flatten(input_x):
            # Flatten nested tuples in the tensor map into a single list.
            flatten_res = []
            for item in input_x:
                if isinstance(item, tuple):
                    flatten_res.extend(flatten(item))
                else:
                    flatten_res.append(item)
            return flatten_res

        flatten_tensor_map = flatten(self.tensor_map)
        return axis_id in flatten_tensor_map

    def get_dev_axis_apply_shard_axis(self, axis):
        """Return the axis which be split by axis. If axis not be apply to shard, return None."""
        for dim, dim_map in enumerate(self.alias_tensor_map):
            if (isinstance(dim_map, tuple) and axis in dim_map) or axis == dim_map:
                return dim
        return None

    def reset_partial(self):
        """Clear the partial status of every mesh axis and refresh placements."""
        self._partial = [None] * len(self.mesh_shape)
        self.tensor_map_to_placement()
        self.update_compact_str()

    def is_partial(self):
        """Return true if any dim in mesh_shape is partial"""
        return any(self.partial)

    def get_global_shape(self, slice_shape):
        """get global shape"""
        return self._mesh.get_global_shape(slice_shape, self._tensor_map)

    def get_devices_for_axis(self, axis, rank):
        """
        Get the repeat rank list when the axis is not shard.

        Args:
            axis (str): Axis name.
            rank (int): Global rank

        Returns:
            list: reduce rank list
        """
        return self._mesh.get_devices_for_axis(axis, rank)

    def get_comm_group_by_axis(self, axis):
        """Return the communication group for the given mesh axis."""
        return self._mesh.get_comm_group_by_axis(axis)

    def repeat_num(self):
        """
        Number of repeated placements.
        In pipeline parallel, only the last stage return repeat num, other stages return -1.
        For example:
            layout = Layout((2, 4), ("dp", "mp"))
            x_layout = layout("dp", "None")
        The repeat_num is equal to all device num 8 divided by device num corresponding to used axis 2, that is 4.

        Raises:
            ValueError: If the tensor map is not configured.
        """
        if self._tensor_map is None:
            raise ValueError(f"The tensor_map is None, the mesh_shape is {self._mesh.mesh_shape},"
                             f" alias_name is {self._mesh.mesh_dim_names}")

        # if it is not the last stage, return -1
        group_size = platform.get_world_size()
        if self._rank_list[-1] != (group_size - 1):
            return -1

        all_device_num = functools.reduce(lambda x, y: x * y, self._mesh.mesh_shape)
        used_dev_num = 1
        # Multiply the device counts of every mesh axis actually used for sharding.
        for ele in self._tensor_map:
            if isinstance(ele, tuple):
                for item in ele:
                    if item >= 0:
                        used_dev_num *= self._mesh.mesh_shape[len(self._mesh.mesh_shape) - item - 1]
                continue
            if ele >= 0:
                used_dev_num *= self._mesh.mesh_shape[len(self._mesh.mesh_shape) - ele - 1]

        return all_device_num // used_dev_num

    def _to_compact_string(self):
        """
        generate dict key

        Combines tensor map, partial state, and the mesh hash into a single
        string usable as a cache key.

        Returns:
            str: string for compact
        """
        mesh_key = self._mesh.to_hash()
        hash_key = (self._tensor_map, self.partial)
        hash_key += mesh_key
        return str(hash_key)

    @property
    def compact_str(self):
        """Cached compact key; refresh with update_compact_str() after mutation."""
        return self._compact_str

    def update_compact_str(self):
        """Recompute the cached compact key from the current state."""
        self._compact_str = self._to_compact_string()

    def to_string(self):
        """
        layout dump

        Returns:
            str: layout string
        """
        device_info = f"Mesh shape: {self._mesh.mesh_shape}"
        alias_info = f"Alias Names: {self._mesh.mesh_dim_names}"
        rank_info = f"Rank List: {self._rank_list}"
        partial_info = f"Partial: {self.partial}"

        if self._tensor_map is None:
            tensor_info = "Tensor Map: Not configured"
        else:
            readable_map = []
            for item in self._tensor_map:
                if isinstance(item, tuple):
                    # Handle nested tuples (multi-axis sharding).
                    mapped_tuple = tuple(
                        self._mesh.mesh_dim_names[len(self._mesh.mesh_dim_names) - 1 - dim] if dim != -1 else "None"
                        for dim in item
                    )
                    readable_map.append(mapped_tuple)
                else:
                    readable_map.append(
                        self._mesh.mesh_dim_names[len(self._mesh.mesh_dim_names) - 1 - item] if item != -1 else "None"
                    )

            tensor_info = f"Tensor Map: {tuple(readable_map)}"

        interleaved = "Yes" if "interleaved_parallel" in self._mesh.mesh_dim_names else "No"
        interleaved_info = f"Interleaved Parallel: {interleaved}"

        return (
            f"Layout Configuration:\n"
            f"  {device_info}\n"
            f"  {alias_info}\n"
            f"  {partial_info}\n"
            f"  {tensor_info}\n"
            f"  {interleaved_info}\n"
            f"  {rank_info}"
        )

    def __str__(self):
        """__str__"""
        return self.to_string()

    def __repr__(self):
        """__repr__"""
        return f"<Layout at {hex(id(self))}>"

    def __eq__(self, other):
        """
        Structural equality: mesh shape, alias names, partial state, rank list
        and tensor map must all match. Note defining __eq__ leaves the class
        unhashable unless __hash__ is defined elsewhere.
        """
        if not isinstance(other, Layout):
            return False

        if (self.mesh_shape != other.mesh_shape or
                self.alias_name != other.alias_name or
                self.partial != other.partial or
                self.rank_list != other.rank_list):
            return False

        # Treat two unconfigured tensor maps as equal only if both are None.
        if self._tensor_map is None or other.tensor_map is None:
            return self._tensor_map is other.tensor_map
        return self._tensor_map == other.tensor_map