# Coverage for hyper_parallel/core/shard/ops/parallel_concat.py: 12%
# 32 statements
# coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Distributed implementation for Concat operator.
"""

18 

from hyper_parallel.core.layout import Layout

from .parallel_ops import DistributedOp

class ConcatDistributedOp(DistributedOp):
    """Distributed implementation for Concat operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for Concat operations.

        Args:
            layouts (tuple): Layouts of input tensors.
            extra_args (tuple): Extra arguments.
                For MindSpore Concat: (axis, )
                For PyTorch cat: (dim, ) or () - dim defaults to 0.

        Returns:
            tuple: Layout for output tensor.

        Raises:
            ValueError: If input layouts are not compatible or have partial status.
        """
        # Reject partial-status inputs unless the op explicitly allows them.
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # The first input serves as the reference layout for all checks and
        # for constructing the output.
        reference = layouts[0]
        tensor_rank = len(reference.tensor_map)

        # Resolve the concatenation axis. Both PyTorch "cat" and MindSpore
        # "Concat" carry it as the first extra argument; when it is absent
        # the axis falls back to 0.
        axis = 0
        if self.op_name in ("cat", "Concat") and extra_args:
            axis = extra_args[0]

        # Normalize a negative axis, then validate the resulting range.
        if axis < 0:
            axis += tensor_rank
        if not 0 <= axis < tensor_rank:
            raise ValueError(
                f"Operation {self.op_name}: dim value is out of valid range"
            )

        ref_map = reference.tensor_map
        ref_mesh = reference.mesh_shape

        def _without_axis(tensor_map):
            # Sharding entries for every dimension except the concat axis.
            return tensor_map[:axis] + tensor_map[axis + 1:]

        for other in layouts[1:]:
            # Inputs without a layout are skipped rather than compared.
            if not other:
                continue

            if other.mesh_shape != ref_mesh:
                raise ValueError(
                    f"Operation {self.op_name}: Concat inputs must have same mesh_shape"
                )

            # Every input must shard all non-concat dimensions exactly like
            # the reference input.
            if _without_axis(other.tensor_map) != _without_axis(ref_map):
                raise ValueError(
                    f"Operation {self.op_name}: Except for dim, the tensor map of inputs must be equal"
                )

        # The output inherits the first input's mesh and alias strategy.
        output_layout = Layout(
            mesh_shape=reference.mesh_shape,
            alias_name=reference.alias_name,
            rank_list=reference.rank_list,
        )(*reference.alias_tensor_map)

        return (output_layout,)