Coverage for hyper_parallel / core / shard / ops / parallel_concat.py: 12%
32 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Concat operator.
17"""
19from hyper_parallel.core.layout import Layout
20from .parallel_ops import DistributedOp
class ConcatDistributedOp(DistributedOp):
    """Distributed implementation for Concat operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for Concat operations.

        The first entry of ``layouts`` is the reference layout. Every other
        non-empty entry must use the same ``mesh_shape`` and the same
        ``tensor_map`` on all dimensions except the concatenation dimension.
        The output layout is rebuilt from the reference layout's mesh
        description and its ``alias_tensor_map``.

        Args:
            layouts (tuple): Layouts of input tensors. Falsy entries after the
                first one are skipped during the compatibility check.
            extra_args (tuple): Extra arguments.
                For MindSpore Concat: (axis, )
                For PyTorch cat: (dim, ) or () - dim defaults to 0.
                For any other ``op_name`` the dimension stays 0.

        Returns:
            tuple: One-element tuple holding the layout for the output tensor.

        Raises:
            ValueError: If ``layouts`` is empty, if the concatenation dimension
                is out of range for the reference layout, or if the input
                layouts are not compatible. The partial-status check is
                delegated to ``_check_partial_inputs`` (defined outside this
                class body) and is skipped when ``_allow_partial_inputs`` is set.
        """
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # Fail with the same error type and message style as the other
        # validation errors instead of a bare IndexError on layouts[0].
        if not layouts:
            raise ValueError(
                f"Operation {self.op_name}: at least one input layout is required"
            )

        # The first input is the reference every other input is compared to.
        base_layout = layouts[0]
        base_map = base_layout.tensor_map
        base_mesh_shape = base_layout.mesh_shape
        rank = len(base_map)

        # Both PyTorch 'cat' (dim) and MindSpore 'Concat' (axis) carry the
        # concatenation dimension as the first extra argument; when it is
        # absent the dimension falls back to 0.
        dim = 0
        if self.op_name in ("cat", "Concat") and extra_args:
            dim = extra_args[0]

        # Handle negative dimension
        if dim < 0:
            dim += rank

        if dim < 0 or dim >= rank:
            raise ValueError(
                f"Operation {self.op_name}: dim value is out of valid range"
            )

        # Sharding of the reference input on every dimension except `dim`;
        # computed once because it does not change across loop iterations.
        base_other_dims = base_map[:dim] + base_map[dim + 1:]

        for layout in layouts[1:]:
            if not layout:
                continue

            if layout.mesh_shape != base_mesh_shape:
                raise ValueError(
                    f"Operation {self.op_name}: Concat inputs must have same mesh_shape"
                )

            # The sharding strategy must be identical for all dimensions
            # except the concat dimension. A layout of different rank also
            # ends up here because the sliced maps differ in length.
            other_dims = layout.tensor_map[:dim] + layout.tensor_map[dim + 1:]
            if other_dims != base_other_dims:
                raise ValueError(
                    f"Operation {self.op_name}: Except for dim, the tensor map of inputs must be equal"
                )

        # NOTE(review): the sharding on `dim` itself is never compared across
        # inputs, and the output below reuses the first input's mapping for
        # that dimension too - confirm this is the intended behavior.

        # Create output layout on the same mesh as the reference input
        output_layout = Layout(
            mesh_shape=base_mesh_shape,
            alias_name=base_layout.alias_name,
            rank_list=base_layout.rank_list,
        )

        # Apply the alias strategy from the first input layout; the value
        # returned by the call is what gets handed back to the caller.
        output_layout = output_layout(*base_layout.alias_tensor_map)

        return (output_layout,)