Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/utils/shape_utils.py: 91%

23 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

"""
Utility functions for distributed tensor operations.

This module provides helpers for layout-related calculations in distributed
settings, currently computing the local (per-rank) shard shape of a tensor
from its global shape, a device mesh and its sharding placements.
"""
from hyper_parallel.core.dtensor.layout import Layout


def compute_local_shape_and_global_offset(global_shape, device_mesh, placement):
    """
    Compute the local (per-rank) shard shape of a distributed tensor.

    Despite the function name, only the local shape is computed and returned;
    no global offset is calculated here.

    Args:
        global_shape: Shape of the global tensor (a sequence of ints).
        device_mesh: Device mesh for distributed execution, used to build the
            full layout via ``Layout.from_device_mesh``.
        placement: Sharding placements for each dimension. Supports Placement
            objects or alias strings (alias strings are unpacked as separate
            positional arguments when building the layout).

    Returns:
        list: The local shard shape, with one entry per dimension of
        ``global_shape``.
    """
    # Function-scope import; presumably avoids a circular import with the
    # dtensor module (TODO confirm).
    from hyper_parallel.core.dtensor.dtensor import _is_alias_placements  # pylint: disable=C0415

    total_layout = Layout.from_device_mesh(device_mesh)
    if _is_alias_placements(placement):
        layout = total_layout(*placement)
    else:
        layout = total_layout(placement)
    # Return value is unused; this call is presumably what populates
    # ``layout.alias_tensor_map`` for a tensor of this rank (TODO confirm).
    layout.placement_to_tensor_map(len(global_shape))

    local_shape = list(global_shape)
    for dim, axis_names in enumerate(layout.alias_tensor_map):
        # An entry is either a single axis name or a tuple of names; each
        # named mesh axis splits this dimension in turn, narrowing the
        # already-split size.
        if isinstance(axis_names, str):
            axis_names = (axis_names,)
        for axis_name in axis_names:
            if axis_name == "None":
                # The string "None" is skipped: it does not split this dim.
                continue
            num_devices = layout.mesh.get_device_num_along_axis(axis_name)
            local_rank = layout.mesh.get_local_rank(axis_name)
            base, remainder = divmod(local_shape[dim], num_devices)
            # Uneven split: the first `remainder` ranks get one extra element.
            # This matches torch.tensor_split / numpy.array_split, NOT
            # torch.chunk (which uses ceil-sized chunks, e.g. 5 into 4 ->
            # [2, 2, 1]).
            # NOTE(review): confirm this agrees with how shards are actually
            # produced elsewhere in the project.
            if remainder != 0 and local_rank < remainder:
                local_shape[dim] = base + 1
            else:
                local_shape[dim] = base
    return local_shape