Coverage for hyper_parallel/core/shard/ops/parallel_expand_dims.py: 77% (26 statements) — coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for ExpandDims operator.
"""
from hyper_parallel.core.layout import Layout
from .parallel_ops import DistributedOp
class ExpandDimsDistributedOp(DistributedOp):
    """Distributed implementation for ExpandDims operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for ExpandDims.

        Args:
            layouts (tuple): Tuple containing the input layout as its first element.
            extra_args: axis parameter (int, or a list/tuple whose first element
                is the axis).

        Returns:
            Layout: Output layout with an unsharded dimension inserted at `axis`.

        Raises:
            ValueError: If no input layout is given, the layout (or its mesh
                shape) is None, the axis is missing, or the axis is out of the
                valid range [-(rank+1), rank].
        """
        if not layouts:
            raise ValueError(f"{self.__class__.__name__} requires at least one input layout")

        x_layout = layouts[0]
        # Guard the layout itself before touching its attributes; the original
        # only checked mesh_shape and would AttributeError on a None layout.
        if x_layout is None or x_layout.mesh_shape is None:
            raise ValueError("Input layout cannot be None.")

        # BUG FIX: `if not extra_args:` rejected a scalar axis of 0 (falsy int),
        # even though a bare-int axis is explicitly supported below. Test for
        # "missing" (None, or an empty list/tuple) instead of general falsiness.
        if extra_args is None or (isinstance(extra_args, (tuple, list)) and not extra_args):
            raise ValueError(f"{self.op_name}: axis parameter is required")

        axis = extra_args[0] if isinstance(extra_args, (tuple, list)) else extra_args

        in_rank = len(x_layout.alias_tensor_map)
        # Normalize a negative axis; for expand_dims the output rank is
        # in_rank + 1, so the valid input range is [-(in_rank+1), in_rank].
        if axis < 0:
            axis = axis + in_rank + 1
        if axis < 0 or axis > in_rank:
            raise ValueError(
                f"{self.op_name}: axis {axis} out of range for input rank {in_rank}. "
                f"Valid range is [{-in_rank-1}, {in_rank}]"
            )

        # The new dimension is never sharded: insert the "None" alias at `axis`.
        x_map = list(x_layout.alias_tensor_map)
        x_map.insert(axis, "None")

        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_layout = output_layout(*x_map)

        # Carry over any partial-reduction state, keyed by device-axis name.
        for i, partial_op in enumerate(x_layout.partial):
            if partial_op is not None:
                dev_axis_name = x_layout.alias_name[i]
                output_layout.set_partial_by_dev_axis(dev_axis_name, partial_op)

        return output_layout