Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/dtensor/layout.py: 75%
406 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""layout"""
17import copy
18import functools
19import numpy as np
22from hyper_parallel.core.dtensor.placement_types import Placement, Shard, StridedShard, Replicate, Partial
23from hyper_parallel.core.dtensor.device_mesh import DeviceMesh, _create_device_mesh
24from hyper_parallel.platform import get_platform
26platform = get_platform()
29def _infer_slice_area_by_rank(mesh_shape, tensor_map, rank_id: int, full_shape: tuple): # -> tuple[tuple[int]]:
30 """Return the per-axis slice range of the full tensor for the given rank."""
32 def _get_dev_num_alone_dim(mesh_shape, dim):
33 """Return the device count of the mesh dimension referenced by `dim` in the tensor map (1 when dim is -1)."""
34 return mesh_shape[-dim - 1] if dim != -1 else 1
36 def _rank_id_to_dev_id_list(mesh_shape, rank_id):
37 """Infer dev id list by rank_id and mesh_shape"""
38 dims = len(mesh_shape)
39 dev_id_list = [0] * dims
40 for i in range(dims - 1, -1, -1):
41 dev_id_list[i] = rank_id % mesh_shape[i]
42 rank_id = rank_id // mesh_shape[i]
43 return dev_id_list
45 dev_id_list = _rank_id_to_dev_id_list(mesh_shape, rank_id)
47 dims = len(full_shape)
48 area = []
49 for axis in range(dims):
50 mapping = tensor_map[axis]
51 if isinstance(mapping, int):
52 mapping = (mapping,)
53 split_num = 1
54 for dim in mapping:
55 split_num *= _get_dev_num_alone_dim(mesh_shape, dim)
57 slice_id = 0
58 coef = 1
59 for dim in reversed(mapping):
60 if dim == -1:
61 continue
62 slice_id += dev_id_list[-dim - 1] * coef
63 coef *= _get_dev_num_alone_dim(mesh_shape, dim)
64 slice_size = full_shape[axis] // split_num
65 start = slice_id * slice_size
66 end = start + slice_size
67 area.append((start, end))
68 return area
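# A minimal worked sketch (values traced by hand through the logic above,
# assuming a (2, 2) mesh and tensor_map (1, 0), i.e. tensor axis 0 split by the
# first mesh dim and axis 1 by the second):
#   rank_id=3 -> dev_id_list [1, 1]
#   _infer_slice_area_by_rank((2, 2), (1, 0), 3, (4, 6)) -> [(2, 4), (3, 6)]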
71def _get_slice_tensor_by_layout(global_tensor, layout):
72 """Slice the local tensor out of the global tensor according to the layout."""
73 inner_rank_id = layout.rank_list.index(layout.mesh.rank)
74 slice_area = _infer_slice_area_by_rank(layout.mesh_shape, layout.tensor_map, inner_rank_id, global_tensor.shape)
76 def get_slice_data(full_data, offset):
77 area = ()
78 for begin, end in offset:
79 area += (slice(begin, end),)
80 return full_data[area].clone()
82 local_tensor = get_slice_data(global_tensor, slice_area)
83 return local_tensor
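# Sketch of the slicing step above: an offset of [(2, 4), (3, 6)] builds the
# index tuple (slice(2, 4), slice(3, 6)), i.e. global_tensor[2:4, 3:6].clone().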
86def _infer_slice_shape_by_layout(global_shape, layout):
87 """Infer slice shape from global_shape and layout"""
88 slice_shape = list(global_shape)
89 alias_tensor_map = layout.alias_tensor_map
90 for i in range(len(global_shape)):
91 axis_name = alias_tensor_map[i]
92 if isinstance(axis_name, str):
93 axis_name = (axis_name,)
94 for sub_axis_name in axis_name:
95 if sub_axis_name != "None":
96 slice_shape[i] = slice_shape[i] // layout.mesh.get_device_num_along_axis(sub_axis_name)
97 return slice_shape
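# Example of the arithmetic above (hypothetical aliases, assuming a (2, 4) mesh
# named ("dp", "mp") and alias_tensor_map ("mp", "None")):
#   global_shape (8, 6) -> slice_shape [8 // 4, 6] == [2, 6]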
100class Layout:
101 """
102 Topological abstraction describing how cluster devices are arranged for tensor slice placement.
104 Note:
105 - It is valid only in semi auto parallel or auto parallel mode.
106 - The product of the elements of `mesh_shape` must equal the device count in a pipeline stage.
107 - When the layout is invoked to construct a sharding strategy, each alias name may only be
108 used once to shard a tensor.
110 Args:
111 mesh_shape (tuple): Describe the shape of devices arrangement, its element type is int.
112 alias_name (tuple): The alias name for each axis of mesh_shape; its length must equal that of `mesh_shape` and its element type is string.
113 When using "interleaved_parallel" as an alias name, the tensor would be split into multiple
114 copies on the corresponding partition dimension on a single card.
115 rank_list (tuple, optional): Data is allocated to the device according to rank_list. Default: ``None``.
117 Raises:
118 TypeError: `mesh_shape` is not a tuple type.
119 TypeError: `alias_name` is not a tuple type.
120 TypeError: `rank_list` is not a list type.
121 ValueError: `mesh_shape` length is not equal to `alias_name` length.
122 TypeError: The element of `mesh_shape` is not int type.
123 TypeError: The element of `alias_name` is not a str type.
124 TypeError: The element of `rank_list` is not int type.
125 ValueError: The element of `alias_name` is an empty str.
126 ValueError: The element of `alias_name` is "None".
127 ValueError: `alias_name` contains repeated element.
129 Supported Platforms:
130 ``Ascend``
132 Examples:
133 >>> from hyper_parallel.core.dtensor.layout import Layout
134 >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
135 >>> layout0 = layout("dp", "mp")
136 >>> print(layout0.to_dict())
137 {'mesh_shape': (2, 2, 2), 'tensor_map': (2, 0), 'interleaved_parallel': False,
138 'alias_name': ('dp', 'sp', 'mp'), 'rank_list': (0, 1, 2, 3, 4, 5, 6, 7)}
139 >>> layout = Layout((2, 2, 2), ("dp", "sp", "interleaved_parallel"))
140 >>> layout1 = layout(("dp", "interleaved_parallel"), "sp")
141 """
143 def __init__(self, mesh_shape, alias_name, rank_list=None, init_backend=True):
144 self._alias_name = alias_name
145 self._tensor_map = None
146 if not rank_list:
147 self._rank_list = tuple(range(np.prod(np.array(mesh_shape))))
148 else:
149 self._rank_list = tuple(rank_list)
150 self._partial = [None] * len(mesh_shape) # partial status for each dev dim
151 self._support_partial_op = ['sum', 'max', 'min', 'avg', 'prod', 'all', None]
152 self._alias_tensor_map = None
153 self._mesh = _create_device_mesh("npu", mesh_shape, mesh_dim_names=alias_name, rank_list=self._rank_list,
154 init_backend=init_backend)
155 self._compact_str = self._to_compact_string()
156 self._placements = None
157 self.partial_ops = {} # Initialized in _build_dim_map_from_placements()
159 @classmethod
160 def from_device_mesh(cls, device_mesh: DeviceMesh) -> 'Layout':
161 """
162 Create a Layout from an existing DeviceMesh.
164 Args:
165 device_mesh (DeviceMesh): The device mesh to create layout from.
167 Returns:
168 Layout: A new Layout instance initialized with the properties of the provided device mesh.
170 Examples:
171 >>> from hyper_parallel.core.dtensor.layout import Layout, DeviceMesh
172 >>> device_mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "mp"))
173 >>> layout = Layout.from_device_mesh(device_mesh)
174 """
175 obj = cls.__new__(cls)
176 obj._mesh = device_mesh
177 obj._alias_name = device_mesh.mesh_dim_names
178 obj._rank_list = device_mesh.rank_list
179 obj._tensor_map = None
180 obj._partial = [None] * len(device_mesh.mesh_shape)
181 obj._support_partial_op = ['sum', 'max', 'min', 'avg', 'prod', 'all', None]
182 obj._alias_tensor_map = None
183 obj._placements = None
184 obj._compact_str = obj._to_compact_string()
185 return obj
187 def __call__(self, *alias_tensor_map):
188 obj = copy.deepcopy(self)
190 # Clear the inherited partial status.
191 # When creating a new layout mapping configuration via __call__,
192 # it should not inherit the dynamic execution state (Partial) of the original layout.
193 # If the user intends to create a Partial placement, it will be parsed from alias_tensor_map.
194 obj._partial = [None] * len(obj.mesh_shape)
196 if len(alias_tensor_map) == 1 and isinstance(alias_tensor_map[0], (list, tuple)):
197 if len(alias_tensor_map[0]) > 0 and isinstance(alias_tensor_map[0][0], Placement):
198 return self._process_placement_layout(obj, alias_tensor_map[0])
200 if len(alias_tensor_map) > 0 and isinstance(alias_tensor_map[0], Placement):
201 return self._process_placement_layout(obj, alias_tensor_map)
203 return self._process_alias_layout(obj, alias_tensor_map)
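    # Both invocation styles accepted above, as a sketch (aliases reuse the class
    # docstring example; the two calls express the same sharding intent):
    #   layout = Layout((2, 4), ("dp", "mp"))
    #   a = layout("dp", "None")            # alias-string style
    #   b = layout(Shard(0), Replicate())   # Placement style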
205 def __deepcopy__(self, memo):
206 """Deep copy layout without rebuilding the underlying device mesh."""
207 cls = self.__class__
208 result = cls.__new__(cls)
209 memo[id(self)] = result
210 for k, v in self.__dict__.items():
211 setattr(result, k, copy.deepcopy(v, memo))
212 return result
214 @staticmethod
215 def _process_placement_layout(obj, placements):
216 """Process layout defined by Placement types."""
217 obj.set_placements(placements)
218 return copy.deepcopy(obj)
220 @staticmethod
221 def _process_alias_layout(obj, alias_tensor_map):
222 """Process layout defined by alias strings."""
223 obj.set_alias_tensor_map(alias_tensor_map)
224 tensor_map = ()
225 writed_map = ()
226 for ele in alias_tensor_map:
227 if isinstance(ele, tuple):
228 ele_map = ()
229 for item in ele:
230 if item == "None":
231 ele_map += (-1,)
232 continue
233 if item not in obj.alias_name:
234 raise ValueError(f'The axis {item} is not found in {obj.alias_name}')
235 if item in writed_map:
236 raise ValueError(f'The axis {item} has been set more than once in {obj.alias_name}')
237 ele_map += (len(obj.alias_name) - 1 - obj.alias_name.index(item),)
238 writed_map += (item,)
239 tensor_map += (ele_map,)
240 continue
241 if ele == "None":
242 tensor_map += (-1,)
243 continue
244 if ele not in obj.alias_name:
245 raise ValueError(f'The axis {ele} is not found in {obj.alias_name}')
246 if ele in writed_map:
247 raise ValueError(f'The axis {ele} has been set more than once in {obj.alias_name}')
248 tensor_map += (len(obj.alias_name) - 1 - obj.alias_name.index(ele),)
249 writed_map += (ele,)
250 obj.set_tensor_map(tensor_map)
251 obj.tensor_map_to_placement()
252 obj.update_compact_str()
253 return copy.deepcopy(obj)
255 def to_dict(self):
256 """
257 Transform layout to a dictionary.
258 """
259 if self._mesh.mesh_shape is None:
260 raise ValueError("The mesh_shape of layout is None")
261 if self._tensor_map is None:
262 raise ValueError("The tensor_map of layout is None")
263 interleaved_parallel = "interleaved_parallel" in self._mesh.mesh_dim_names
264 return {"mesh_shape": self._mesh.mesh_shape, "tensor_map": self._tensor_map,
265 "interleaved_parallel": interleaved_parallel, "alias_name": self._mesh.mesh_dim_names,
266 "rank_list": self._rank_list}
268 def placement_to_tensor_map(self, dim):
269 """
270 Transform placement to tensor map.
272 This method converts the `placements` configuration (consisting of Shard, StridedShard,
273 Replicate, Partial)
274 into a `tensor_map` representation used for distributed tensor operations.
276 Args:
277 dim (int): The dimension of the tensor. Must be a non-negative integer.
279 Returns:
280 list: The tensor map (also stored on the layout as a tuple), where each element corresponds to a tensor dimension.
281 A value of -1 indicates the dimension is not sharded, an integer indicates the mesh
282 dimension index (counted from the last mesh axis) along which the tensor dimension is sharded,
283 and a tuple indicates the same tensor dimension is sharded across several mesh dimensions in order.
285 Raises:
286 ValueError: If `dim` is negative.
287 ValueError: If a shard dimension in `placements` is out of bounds for the given tensor dimension.
288 """
289 if dim < 0:
290 raise ValueError(f"Tensor dimension must be non-negative, but got {dim}")
291 if dim == 0:
292 return self._handle_zero_dim_placement()
294 dim_map = self._build_dim_map_from_placements(dim)
295 tensor_map = self._convert_dim_map_to_tensor_map(dim_map)
296 self.set_tensor_map(tuple(tensor_map))
297 self._alias_tensor_map = self._build_readable_tensor_map()
298 self.update_compact_str()
299 return tensor_map
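    # Worked sketch of the conversion (traced by hand against the helpers below,
    # assuming a (2, 2, 2) mesh named ("dp", "sp", "mp")):
    #   layout(Shard(1), Replicate(), Shard(0)).placement_to_tensor_map(2)
    #   -> dim_map [[2], [0]] -> tensor_map [0, 2]   (readable form: ("mp", "dp"))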
301 def _handle_zero_dim_placement(self):
302 """Handle the special case of zero-dimensional tensor."""
303 self.set_tensor_map(())
304 self._alias_tensor_map = ()
305 for mesh_idx, placement in enumerate(self.placements):
306 if isinstance(placement, Partial):
307 self._partial[mesh_idx] = self._extract_reduce_op(placement)
308 return []
310 def _build_dim_map_from_placements(self, dim):
311 """Build dimension map from placements."""
312 dim_map = [-1] * dim
313 self.partial_ops = {}
314 for mesh_idx, placement in enumerate(self.placements):
315 if isinstance(placement, Shard):
316 shard_dim = placement.dim
317 if shard_dim < -dim or shard_dim >= dim:
318 raise ValueError(f"Shard dimension {shard_dim} is out of bounds for tensor of dimension {dim}")
319 if shard_dim < 0:
320 shard_dim += dim
321 if dim_map[shard_dim] == -1:
322 dim_map[shard_dim] = [mesh_idx]
323 else:
324 dim_map[shard_dim].append(mesh_idx)
325 elif isinstance(placement, Partial):
326 self._partial[mesh_idx] = self._extract_reduce_op(placement)
327 self._validate_strided_shard_split_factor(dim_map)
328 self._reorder_dim_map_for_strided_shard(dim_map)
329 return dim_map
331 @staticmethod
332 def _placement_split_factor(placement):
333 """Return the effective split factor carried by a placement."""
334 return placement.split_factor if isinstance(placement, StridedShard) else 1
336 @staticmethod
337 def _build_order_positions(shard_order):
338 """Build a mesh axis to order position mapping."""
339 return {mesh_idx: order_idx for order_idx, mesh_idx in enumerate(shard_order)}
341 def _compute_expected_split_factors(self, shard_axes, shard_order):
342 """Infer the split_factor each mesh axis should carry for the given sharding order."""
343 order_positions = self._build_order_positions(shard_order)
344 expected_split_factors = {}
345 for mesh_idx in shard_axes:
346 split_factor = 1
347 for right_mesh_idx in shard_axes:
348 if right_mesh_idx <= mesh_idx:
349 continue
350 if order_positions[right_mesh_idx] < order_positions[mesh_idx]:
351 split_factor *= self.mesh_shape[right_mesh_idx]
352 expected_split_factors[mesh_idx] = split_factor
353 return expected_split_factors
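    # Sketch: if shard_axes == [0, 1] both shard one tensor dimension but the
    # effective order is [1, 0] (mesh axis 1 slices first), then for a (2, 4)
    # mesh this returns {0: 4, 1: 1}: axis 0 must carry split_factor 4 because
    # it further slices data that axis 1 already split into 4 strided pieces.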
355 def _get_effective_shard_axes(self, shard_axes):
356 """Return shard axes ordered by their effective sharding order."""
357 return sorted(
358 shard_axes,
359 key=lambda mesh_idx: self._placement_split_factor(self.placements[mesh_idx]),
360 )
362 def _reorder_dim_map_for_strided_shard(self, dim_map):
363 """Reorder dim_map entries to reflect the effective sharding order."""
364 for i, shard_axes in enumerate(dim_map):
365 if shard_axes == -1 or len(shard_axes) <= 1:
366 continue
367 dim_map[i] = self._get_effective_shard_axes(shard_axes)
369 def _validate_strided_shard_split_factor(self, dim_map):
370 """Validate that split factors match the effective sharding order."""
371 for shard_axes in dim_map:
372 if shard_axes == -1:
373 continue
374 shard_order = self._get_effective_shard_axes(shard_axes)
375 expected_split_factors = self._compute_expected_split_factors(
376 shard_axes, shard_order
377 )
378 for mesh_idx in shard_axes:
379 placement = self.placements[mesh_idx]
380 actual_split_factor = self._placement_split_factor(placement)
381 expected_split_factor = expected_split_factors[mesh_idx]
382 if actual_split_factor != expected_split_factor:
383 raise ValueError(
384 f"StridedShard split_factor mismatch on mesh axis {mesh_idx}: "
385 f"expected {expected_split_factor}, got {actual_split_factor}."
386 )
388 @staticmethod
389 def _extract_reduce_op(placement):
390 """Extract reduce operation name from Partial placement."""
391 op_name = getattr(placement, "reduce_op", "sum")
392 if isinstance(op_name, str):
393 op_name = op_name.lower()
394 return op_name
396 def _convert_dim_map_to_tensor_map(self, dim_map):
397 """Convert dimension map to tensor map format."""
398 device_dim_count = len(self.mesh_shape)
399 tensor_map = []
400 for mesh_idx in dim_map:
401 if mesh_idx == -1:
402 tensor_map.append(-1)
403 continue
404 mapped_axes = tuple(device_dim_count - 1 - axis for axis in mesh_idx)
405 tensor_map.append(mapped_axes[0] if len(mapped_axes) == 1 else mapped_axes)
406 return tensor_map
408 def _build_readable_tensor_map(self):
409 """Build human-readable alias tensor map from tensor_map."""
410 mesh_dim_names = self._mesh.mesh_dim_names
411 has_names = mesh_dim_names is not None
413 def _map_dim(dim):
414 """Convert a device dimension index to its dimension name."""
415 if dim == -1:
416 return "None"
417 if not has_names:
418 return f"dim_{dim}"
419 return mesh_dim_names[len(mesh_dim_names) - 1 - dim]
421 readable_map = []
422 for item in self._tensor_map:
423 if isinstance(item, tuple):
424 mapped_tuple = tuple(_map_dim(dim) for dim in item)
425 readable_map.append(mapped_tuple)
426 else:
427 readable_map.append(_map_dim(item))
428 return tuple(readable_map)
430 def tensor_map_to_placement(self):
431 """
432 Transform tensor map to placement.
434 This method converts the existing `tensor_map` and `partial` status into a list of `Placement` objects
435 (Shard, StridedShard, Replicate, Partial). This is the inverse operation of
436 `placement_to_tensor_map`.
438 Returns:
439 list[Placement]: A list of Placement objects describing the distribution strategy for each
440 dimension of the device mesh.
442 Raises:
443 ValueError: If `tensor_map` is not configured (None).
444 """
445 if self._tensor_map is None:
446 raise ValueError("The tensor_map is None, cannot transform to placements.")
447 mesh_ndim = len(self.mesh_shape)
448 placements = [Replicate()] * mesh_ndim
449 for tensor_dim, mapping in enumerate(self._tensor_map):
450 mapping_list = mapping if isinstance(mapping, tuple) else (mapping,)
451 valid_mapping = [map_val for map_val in mapping_list if map_val != -1]
452 mesh_indices = [mesh_ndim - 1 - map_val for map_val in valid_mapping]
453 shard_axes = sorted(mesh_indices)
454 expected_split_factors = self._compute_expected_split_factors(
455 shard_axes, mesh_indices
456 )
457 for mesh_idx in shard_axes:
458 split_factor = expected_split_factors[mesh_idx]
459 placement = (
460 StridedShard(dim=tensor_dim, split_factor=split_factor)
461 if split_factor > 1
462 else Shard(dim=tensor_dim)
463 )
464 placements[mesh_idx] = placement
465 for mesh_idx, op in enumerate(self.partial):
466 if op is not None:
467 placements[mesh_idx] = Partial(reduce_op=op)
468 self.set_placements(placements)
469 return placements
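    # Inverse of the worked sketch above: with a (2, 2, 2) mesh, tensor_map (0, 2)
    # converts back to placements [Shard(dim=1), Replicate(), Shard(dim=0)]
    # (all split factors are 1 here, so no StridedShard is needed).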
471 def __setstate__(self, state):
472 self.__dict__.update(state)
473 self.update_mesh(init_backend=False)
475 @property
476 def mesh(self):
477 """
478 Get the device mesh associated with this layout.
480 Returns:
481 DeviceMesh: The device mesh describing the device topology.
482 """
483 return self._mesh
485 def update_mesh(self, init_backend: bool = True):
486 """Recreate the internal DeviceMesh from current layout properties.
488 Args:
489 init_backend (bool): Whether to initialize communication backend
490 (process groups). Set to ``False`` during deserialization to
491 avoid creating process groups with a stale rank_list from the
492 sender side. Default ``True``.
493 """
494 self._mesh = _create_device_mesh("npu", self.mesh_shape, mesh_dim_names=self.alias_name,
495 rank_list=self.rank_list, init_backend=init_backend)
497 @property
498 def rank_list(self):
499 """
500 Get the list of ranks participating in this layout.
502 Returns:
503 tuple[int]: The rank list.
504 """
505 return self._rank_list
507 @rank_list.setter
508 def rank_list(self, val):
509 self._rank_list = val
511 @property
512 def mesh_shape(self):
513 """mesh shape"""
514 return self._mesh.mesh_shape
516 @property
517 def alias_name(self):
518 """alias name"""
519 return self._mesh.mesh_dim_names
521 @property
522 def alias_tensor_map(self):
523 return self._alias_tensor_map
525 @property
526 def alias_placements(self):
527 """Return alias_tensor_map when it contains multi-axis tuples, otherwise placements.
529 alias_tensor_map preserves multi-axis ordering information
530 (e.g., (("dp", "tp"), "None") vs (("tp", "dp"), "None"))
531 that Placement objects cannot represent, since both map to
532 [Shard(0), Shard(0)].
534 For single-axis layouts, Placement objects are preferred because they
535 also carry Partial status which alias_tensor_map cannot encode.
537 Use this property when constructing DTensors from an existing Layout
538 to avoid the lossy Placement round-trip for multi-axis cases.
539 """
540 if self._alias_tensor_map is not None and any(
541 isinstance(item, tuple) for item in self._alias_tensor_map
542 ):
543 return self._alias_tensor_map
544 return self._placements
546 def set_alias_tensor_map(self, alias_tensor_map):
547 """Set alias_tensor_map"""
548 self._alias_tensor_map = alias_tensor_map
550 @property
551 def placements(self):
552 """placements"""
553 return self._placements
555 def set_placements(self, placements):
556 """Set placements."""
557 self._placements = placements
559 @property
560 def tensor_map(self):
561 """tensor map"""
562 return self._tensor_map
564 def set_tensor_map(self, tensor_map):
565 """Set tensor_map."""
566 self._tensor_map = tensor_map
568 @property
569 def partial(self):
570 """partial status"""
571 return self._partial
573 def set_partial_by_dev_axis(self, axis, op):
574 """Set the partial status for the specified device axis, meaning a reduction by `op` is pending."""
575 if op not in self._support_partial_op:
576 raise ValueError(f"Partial op must be one of {self._support_partial_op}, but got {op}")
577 if self.is_dev_axis_apply_shard(axis):
578 raise ValueError("A partial device axis must be replicated (not sharded).")
579 self._partial[self._mesh.axis_index(axis)] = op
580 self.tensor_map_to_placement()
581 self.update_compact_str()
583 def get_partial_by_dev_id(self, axis):
584 """Get the partial status for the specified device axis."""
585 return self.partial[self._mesh.axis_index(axis)]
587 def is_dev_axis_apply_shard(self, axis):
588 """Return True if the given device axis is used to shard some tensor dimension."""
589 axis_id = self._mesh.axis_id(axis)
591 def flatten(input_x):
592 flatten_res = []
593 for item in input_x:
594 if isinstance(item, tuple):
595 flatten_res.extend(flatten(item))
596 else:
597 flatten_res.append(item)
598 return flatten_res
600 flatten_tensor_map = flatten(self.tensor_map)
601 return axis_id in flatten_tensor_map
603 def get_dev_axis_apply_shard_axis(self, axis):
604 """Return the tensor dimension that is split by the given device axis, or None if the axis shards nothing."""
605 for dim, dim_map in enumerate(self.alias_tensor_map):
606 if (isinstance(dim_map, tuple) and axis in dim_map) or axis == dim_map:
607 return dim
608 return None
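    # Sketch (hypothetical aliases): with alias_tensor_map ("mp", "None"),
    # get_dev_axis_apply_shard_axis("mp") returns 0 and
    # get_dev_axis_apply_shard_axis("dp") returns None.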
610 def reset_partial(self):
611 self._partial = [None] * len(self.mesh_shape)
612 self.tensor_map_to_placement()
613 self.update_compact_str()
615 def is_partial(self):
616 """Return true if any dim in mesh_shape is partial"""
617 return any(self.partial)
619 def get_global_shape(self, slice_shape):
620 """get global shape"""
621 return self._mesh.get_global_shape(slice_shape, self._tensor_map)
623 def get_devices_for_axis(self, axis, rank):
624 """
625 Get the list of ranks that hold repeated data along the axis when the axis is not sharded.
627 Args:
629 axis (str): Axis name.
630 rank (int): Global rank.
632 Returns:
633 list: Rank list to reduce over.
634 """
635 return self._mesh.get_devices_for_axis(axis, rank)
637 def get_comm_group_by_axis(self, axis):
638 return self._mesh.get_comm_group_by_axis(axis)
640 def repeat_num(self):
641 """
642 Number of repeated placements.
643 For example:
644 layout = Layout((2, 4), ("dp", "mp"))
645 x_layout = layout("dp", "None")
646 The repeat_num equals the total device count (8) divided by the device count of the used axes (2), i.e. 4.
647 """
648 if self._tensor_map is None:
649 raise ValueError(f"The tensor_map is None, the mesh_shape is {self._mesh.mesh_shape},"
650 f" alias_name is {self._mesh.mesh_dim_names}")
652 all_device_num = functools.reduce(lambda x, y: x * y, self._mesh.mesh_shape)
653 used_dev_num = 1
654 for ele in self._tensor_map:
655 if isinstance(ele, tuple):
656 for item in ele:
657 if item >= 0:
658 used_dev_num *= self._mesh.mesh_shape[len(self._mesh.mesh_shape) - item - 1]
659 continue
660 if ele >= 0:
661 used_dev_num *= self._mesh.mesh_shape[len(self._mesh.mesh_shape) - ele - 1]
663 return all_device_num // used_dev_num
665 def _to_compact_string(self):
666 """
667 Generate a compact string key for this layout.
669 Returns:
670 str: compact string key
671 """
672 mesh_key = self._mesh.to_hash()
673 hash_key = (self._tensor_map, self.partial)
674 hash_key += mesh_key
675 return str(hash_key)
677 @property
678 def compact_str(self):
679 return self._compact_str
681 def update_compact_str(self):
682 self._compact_str = self._to_compact_string()
684 def to_string(self):
685 """
686 Dump the layout as a human-readable string.
688 Returns:
689 str: layout string
690 """
691 device_info = f"Mesh shape: {self._mesh.mesh_shape}"
692 alias_info = f"Alias Names: {self._mesh.mesh_dim_names}"
693 rank_info = f"Rank List: {self._rank_list}"
694 partial_info = f"Partial: {self.partial}"
696 if self._tensor_map is None:
697 tensor_info = "Tensor Map: Not configured"
698 else:
699 readable_map = []
700 for item in self._tensor_map:
701 if isinstance(item, tuple):
702 # handle nested tuple
703 mapped_tuple = tuple(
704 self._mesh.mesh_dim_names[len(self._mesh.mesh_dim_names) - 1 - dim] if dim != -1 else "None"
705 for dim in item
706 )
707 readable_map.append(mapped_tuple)
708 else:
709 readable_map.append(
710 self._mesh.mesh_dim_names[len(self._mesh.mesh_dim_names) - 1 - item] if item != -1 else "None"
711 )
713 tensor_info = f"Tensor Map: {tuple(readable_map)}"
715 interleaved = "Yes" if "interleaved_parallel" in self._mesh.mesh_dim_names else "No"
716 interleaved_info = f"Interleaved Parallel: {interleaved}"
718 return (
719 f"Layout Configuration:\n"
720 f" {device_info}\n"
721 f" {alias_info}\n"
722 f" {partial_info}\n"
723 f" {tensor_info}\n"
724 f" {interleaved_info}\n"
725 f" {rank_info}"
726 )
728 def __str__(self):
729 """__str__"""
730 return self.to_string()
732 def __repr__(self):
733 """__repr__"""
734 return f"<Layout at {hex(id(self))}>"
736 def __eq__(self, other):
737 """
738 __eq__
739 """
740 if not isinstance(other, Layout):
741 return False
743 if (self.mesh_shape != other.mesh_shape or
744 self.alias_name != other.alias_name or
745 self.partial != other.partial or
746 self.rank_list != other.rank_list):
747 return False
749 if self._tensor_map is None or other.tensor_map is None:
750 return self._tensor_map is other.tensor_map
751 return self._tensor_map == other.tensor_map