Coverage for hyper_parallel / core / utils.py: 100%

14 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1""" 

2Utility functions for distributed tensor operations. 

3 

4This module provides helper functions for computing local shapes, global offsets, 

5and other layout-related calculations in distributed settings. 

6""" 

7from hyper_parallel.core.layout import Layout 

8 

def compute_local_shape_and_global_offset(global_shape, device_mesh, placement):
    """
    Compute the local shard shape of a tensor distributed over a device mesh.

    Despite its name, this function returns only the local shape; no global
    offset is computed or returned. The name is kept unchanged so existing
    callers keep working.

    Args:
        global_shape: Shape of the global tensor (a sequence of ints; it is
            copied with ``list(...)`` and its length is used as the tensor rank).
        device_mesh: Device mesh for distributed execution. It is converted to
            a ``Layout`` via ``Layout.from_device_mesh``.
        placement: Sharding placement(s) for the tensor. They are passed to the
            mesh-level layout object to build the tensor-level layout.

    Returns:
        list: The local shard shape. Each dimension of ``global_shape`` is
        floor-divided by the number of devices along every mesh axis it is
        mapped to. Axis entries equal to the string ``"None"`` are skipped,
        meaning that dimension is not split along that entry.
    """
    total_layout = Layout.from_device_mesh(device_mesh)
    layout = total_layout(placement)
    # The return value is discarded. Presumably this call populates
    # `layout.alias_tensor_map` for a tensor of this rank.
    # NOTE(review): confirm that side effect.
    layout.placement_to_tensor_map(len(global_shape))

    local_shape = list(global_shape)
    for dim, axis_name in enumerate(layout.alias_tensor_map):
        # A dimension may be sharded along one mesh axis (a str) or along
        # several (a tuple of str). Normalize to a tuple so both cases share
        # one loop.
        axis_names = (axis_name,) if isinstance(axis_name, str) else axis_name
        for sub_axis_name in axis_names:
            if sub_axis_name == "None":
                continue
            # NOTE(review): floor division silently truncates when the
            # dimension is not divisible by the device count. Confirm that
            # uneven sharding is not expected here.
            local_shape[dim] //= layout.mesh.get_device_num_along_axis(sub_axis_name)
    return local_shape