Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_concat.py: 100% (22 statements)
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for Concat operator.
"""

from .parallel_ops import DistributedOp


class ConcatDistributedOp(DistributedOp):
    """Distributed implementation for Concat."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for Concat and normalize the concatenation dimension.
        Raises an error if the specified concatenation dimension is sharded.

        Args:
            layouts (tuple): Layouts of the input tensors and scalar arguments.
            extra_args (list): Additional arguments (e.g., dim). Modified in place
                to store the normalized non-negative dimension.

        Returns:
            Layout: The inferred output layout (identical to the input layouts).
        """
        # Filter out None values, which correspond to scalar arguments (e.g., dim)
        valid_layouts = [layout for layout in layouts if layout is not None]
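        # Illustrative example (the exact packing of `layouts` is an assumption):
        # a call such as cat((t1, t2), dim=1) might arrive here as
        # (layout_t1, layout_t2, None), which this filter reduces to
        # [layout_t1, layout_t2].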

        if not valid_layouts:
            raise ValueError(f"Operation {self.op_name}: cat requires at least one input DTensor.")

        # In this framework, we assume inputs must be aligned to the same layout
        # for concatenation. We use the first valid layout as the reference.
        base_layout = valid_layouts[0]

        # Verify consistency across all valid input layouts
        for layout in valid_layouts:
            if layout != base_layout:
                raise ValueError(
                    f"Operation {self.op_name}: All input tensors must have the same layout. "
                    f"Expected layout: {base_layout}, Mismatched layout: {layout}"
                )

        # Extract dim from extra_args, assuming the framework populates it correctly
        dim = extra_args[0] if extra_args else 0

        # Normalize a negative dim to its non-negative equivalent
        ndim = len(base_layout.tensor_map)
        actual_dim = dim if dim >= 0 else dim + ndim
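        # e.g. with ndim == 3, dim == -1 normalizes to actual_dim == 2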

        # Check if the concatenation dimension is sharded
        mapping = base_layout.tensor_map[actual_dim]
        mapping_list = mapping if isinstance(mapping, tuple) else (mapping,)
        is_sharded = any(m != -1 for m in mapping_list)
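        # Example of the convention checked above: tensor_map == (0, -1) marks dim 0
        # as sharded and dim 1 as replicated (-1); an entry may itself be a tuple,
        # e.g. (0, 1), presumably when a dimension is split over several device axes.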

        if is_sharded:
            raise ValueError(
                f"Operation {self.op_name}: Concatenation along a sharded dimension "
                f"(dim={dim}, normalized_dim={actual_dim}) is not supported."
            )

        # Store the normalized actual_dim back into extra_args
        # so that get_expand_impl can use it directly as a non-negative integer
        if extra_args:
            extra_args[0] = actual_dim
        else:
            extra_args.append(actual_dim)

        return base_layout