Coverage for hyper_parallel / core / utils.py: 100%
14 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
"""
Utility functions for distributed tensor operations.

This module provides helper functions for computing local shapes, global offsets,
and other layout-related calculations in distributed settings.
"""
from hyper_parallel.core.layout import Layout
def compute_local_shape_and_global_offset(global_shape, device_mesh, placement):
    """
    Compute the local shard shape of a distributed tensor.

    Builds the layout for ``placement`` on ``device_mesh``, then divides each
    sharded dimension of ``global_shape`` by the number of devices along the
    mesh axis (or axes) that shard it.

    Args:
        global_shape: Shape of the global tensor (sequence of ints).
        device_mesh: Device mesh for distributed execution.
        placement: Sharding placements for each dimension.

    Returns:
        list: The local shard shape, one entry per tensor dimension.

    NOTE(review): despite the function name, no global offset is computed or
    returned here — confirm against callers whether one is expected.
    """
    total_layout = Layout.from_device_mesh(device_mesh)
    layout = total_layout(placement)
    # Populate the layout's per-tensor-dimension axis mapping for this rank.
    layout.placement_to_tensor_map(len(global_shape))

    local_shape = list(global_shape)
    # alias_tensor_map maps each tensor dim to the mesh axis name (str) or
    # tuple of axis names it is sharded over; the string "None" is the
    # sentinel for an unsharded (replicated) dimension.
    for dim, axes in enumerate(layout.alias_tensor_map):
        if isinstance(axes, str):
            axes = (axes,)
        for axis in axes:
            if axis != "None":
                # Each sharding axis splits this dimension evenly across
                # the devices along that mesh axis.
                local_shape[dim] //= layout.mesh.get_device_num_along_axis(axis)
    return local_shape