Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/utils/shape_utils.py: 91%
23 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1"""
2Utility functions for distributed tensor operations.
4This module provides helper functions for computing local shapes, global offsets,
5and other layout-related calculations in distributed settings.
6"""
7from hyper_parallel.core.dtensor.layout import Layout

def compute_local_shape_and_global_offset(global_shape, device_mesh, placement):
    """
    Compute the local shard shape for the current rank.

    Args:
        global_shape: Shape of the global tensor.
        device_mesh: Device mesh for distributed execution.
        placement: Sharding placements for each dimension.
            Supports Placement objects or alias strings.

    Returns:
        list: The local shard shape on the current rank.
    """
    from hyper_parallel.core.dtensor.dtensor import _is_alias_placements  # pylint: disable=C0415
    total_layout = Layout.from_device_mesh(device_mesh)
    if _is_alias_placements(placement):
        layout = total_layout(*placement)
    else:
        layout = total_layout(placement)
    layout.placement_to_tensor_map(len(global_shape))
    slice_shape = list(global_shape)
    alias_tensor_map = layout.alias_tensor_map
    for i, axis_name in enumerate(alias_tensor_map):
        if isinstance(axis_name, str):
            axis_name = (axis_name,)
        for sub_axis_name in axis_name:
            if sub_axis_name != "None":
                num_devices = layout.mesh.get_device_num_along_axis(sub_axis_name)
                local_rank = layout.mesh.get_local_rank(sub_axis_name)
                global_size = slice_shape[i]
                remainder = global_size % num_devices
                # Uneven split: the first `remainder` ranks get one extra element
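                # Illustrative check of this rule: for a dimension of size 10 split
                # across 4 devices, 10 % 4 == 2, so ranks 0 and 1 hold 3 elements
                # each while ranks 2 and 3 hold 2 each (3 + 3 + 2 + 2 == 10).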
                if remainder != 0 and local_rank < remainder:
                    slice_shape[i] = global_size // num_devices + 1
                else:
                    slice_shape[i] = global_size // num_devices
    return slice_shape
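
# --- Illustrative sketch (not part of the module) ----------------------------
# A minimal, self-contained rendition of the per-dimension split rule used in
# compute_local_shape_and_global_offset above, so the arithmetic can be checked
# without building a device mesh or Layout. The helper name `local_dim_size` is
# hypothetical and exists only for this illustration.

def local_dim_size(global_size, num_devices, local_rank):
    """Shard size held by `local_rank` when `global_size` is split across
    `num_devices`, with the first `remainder` ranks taking one extra element."""
    base, remainder = divmod(global_size, num_devices)
    return base + 1 if local_rank < remainder else base

# Example: a (10, 8) tensor sharded on dim 0 across 4 devices.
# Ranks 0 and 1 hold a (3, 8) shard; ranks 2 and 3 hold a (2, 8) shard.
assert [local_dim_size(10, 4, r) for r in range(4)] == [3, 3, 2, 2]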