Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_concat.py: 100%
22 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Concat operator.
17"""
19from .parallel_ops import DistributedOp
class ConcatDistributedOp(DistributedOp):
    """Distributed implementation for Concat."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for Concat and normalize the concatenation dimension.

        All non-None input layouts must compare equal; the first one is returned
        as the output layout. The concatenation dimension is read from
        ``extra_args[0]`` (defaulting to 0 when ``extra_args`` is empty or None),
        normalized to a non-negative index, and checked against the layout's
        ``tensor_map``: an entry of -1 (or a tuple consisting only of -1) is
        treated as "not sharded"; anything else is rejected.

        Args:
            layouts (tuple): Layouts of input tensors and scalar arguments.
                ``None`` entries are skipped (they correspond to scalar
                arguments such as dim).
            extra_args (list): Additional arguments (e.g., dim). When a list is
                supplied it is modified in-place so that its first element holds
                the normalized non-negative dimension. When None, nothing is
                written back.

        Returns:
            Layout: The inferred output layout (identical to the input layouts).

        Raises:
            ValueError: If there is no non-None layout, if the input layouts
                differ, if the dimension is outside ``[-ndim, ndim)``, or if
                the concatenation dimension is sharded.
        """
        # Filter out None values which correspond to scalar arguments (e.g., dim)
        valid_layouts = [layout for layout in layouts if layout is not None]
        if not valid_layouts:
            raise ValueError(f"Operation {self.op_name}: cat requires at least one input DTensor.")

        # Inputs are assumed to be aligned to the same layout for concatenation;
        # the first valid layout serves as the reference for all the others.
        base_layout = valid_layouts[0]
        for layout in valid_layouts[1:]:
            if layout != base_layout:
                raise ValueError(
                    f"Operation {self.op_name}: All input tensors must have the same layout. "
                    f"Expected layout: {base_layout}, Mismatched layout: {layout}"
                )

        # Extract dim from extra_args, assuming the framework populates it correctly
        dim = extra_args[0] if extra_args else 0

        # Reject out-of-range dims explicitly. Without this check a value below
        # -ndim would remain negative after normalization and silently index a
        # different axis, while a value >= ndim would surface as a bare IndexError.
        ndim = len(base_layout.tensor_map)
        if not -ndim <= dim < ndim:
            raise ValueError(
                f"Operation {self.op_name}: Concatenation dimension out of range "
                f"(expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})."
            )

        # Convert negative dim to positive
        actual_dim = dim if dim >= 0 else dim + ndim

        # A tensor_map entry may be a single value or a tuple of values;
        # the dimension counts as sharded as soon as any value differs from -1.
        mapping = base_layout.tensor_map[actual_dim]
        mapping_list = mapping if isinstance(mapping, tuple) else (mapping,)
        if any(m != -1 for m in mapping_list):
            raise ValueError(
                f"Operation {self.op_name}: Concatenation along a sharded dimension "
                f"(dim={dim}, normalized_dim={actual_dim}) is not supported."
            )

        # Store the normalized actual_dim back into extra_args
        # so get_expand_impl can use it directly as a positive integer.
        # extra_args defaults to None, in which case there is no caller-owned
        # list to update and the write-back is skipped.
        if extra_args:
            extra_args[0] = actual_dim
        elif extra_args is not None:
            extra_args.append(actual_dim)

        return base_layout