Coverage for hyper_parallel / core / shard / ops / parallel_unbind.py: 96%

23 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Unbind operator. 

17""" 

18 

19from hyper_parallel.core.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

22 

class UnbindDistributedOp(DistributedOp):
    """Distributed implementation for Unbind operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for Unbind operator.

        Unbinding removes dimension ``dim`` from the input, so every output
        tensor uses the input layout with the mapping for ``dim`` dropped.
        The dimension being removed must not be sharded.

        Args:
            layouts (tuple): Layouts of input tensor; only ``layouts[0]`` is used.
            extra_args (list):
                - If configured with 'WithShape' suffix, the last element is input_shapes (list of tuples).
                - Preceding elements are scalar arguments; the first one, if present,
                  is ``dim`` (defaults to 0, as in ``torch.unbind``).

        Returns:
            tuple: A tuple of Layouts for the output tensors, one per slice along
            ``dim`` (i.e. ``shape[dim]`` entries, all the same Layout object).

        Raises:
            ValueError: If ``dim`` is out of range for the input rank (including a
                0-dimensional input), or if the dimension to unbind is sharded.
        """
        # Parse arguments provided by _with_layout_infer_with_shape
        input_shapes = extra_args[-1]
        args = extra_args[:-1]

        # Pytorch unbind(input, dim=0)
        dim = args[0] if args else 0

        layout = layouts[0]
        shape = input_shapes[0]
        tensor_map = layout.tensor_map
        alias_tensor_map = layout.alias_tensor_map
        ndim = len(shape)

        # Keep the caller's value so error messages report what was actually passed,
        # not the normalized index.
        requested_dim = dim

        # Handle negative dimension
        if dim < 0:
            dim += ndim

        if not 0 <= dim < ndim:
            if ndim == 0:
                raise ValueError(
                    f"For 'unbind', the input tensor has no dimensions, but got dim {requested_dim}"
                )
            raise ValueError(
                f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], "
                f"but got {requested_dim})"
            )

        # Check if the dimension to unbind is sharded.
        # tensor_map values != -1 indicate the dimension is split across a mesh axis.
        # We cannot unbind a sharded dimension without explicit redistribution (Gather),
        # as it would result in different tensor lists on different ranks.
        if tensor_map[dim] != -1:
            raise ValueError(
                f"For 'unbind', the dimension {dim} is sharded (mapped to mesh axis {tensor_map[dim]}). "
                f"Unbinding a sharded dimension is not supported. "
                f"Please redistribute the tensor to replicate this dimension first."
            )

        # Construct output layout: remove the mapping for the unbound dimension.
        # We use alias_tensor_map to utilize Layout's robust initialization.
        out_alias_map = alias_tensor_map[:dim] + alias_tensor_map[dim + 1:]

        # Create a base layout object to generate the new layout.
        # NOTE(review): only mesh_shape / alias_name / rank_list are carried over; any
        # other state the input Layout may hold (e.g. partial status, if tracked) is
        # not propagated here — TODO confirm this is intended.
        base_layout = Layout(
            mesh_shape=layout.mesh_shape,
            alias_name=layout.alias_name,
            rank_list=layout.rank_list
        )

        # Generate the specific output layout
        out_layout = base_layout(*out_alias_map)

        # Unbind returns a tuple of tensors, the number of which is the size of the unbound
        # dimension. That dimension is replicated (checked above), so its size is the same
        # on every rank.
        num_outputs = shape[dim]

        return (out_layout,) * num_outputs