Coverage for hyper_parallel / core / shard / ops / parallel_transpose.py: 79%
38 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Transpose operator.
17"""
19from hyper_parallel.core.layout import Layout
20from .parallel_ops import DistributedOp
class TransposeDistributedOp(DistributedOp):
    """Distributed implementation for Transpose operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for Transpose operator.

        Based on the op_name initialized in the base class, this method switches behavior:
        1. op_name == 'Transpose' or 'permute': Implements MindSpore Transpose behavior or PyTorch permute behavior.
           - extra_args expected: (perm,) where perm is a tuple of indices; negative indices
             are accepted and interpreted as counting from the last dimension.
           - Rules: Output layout is determined by input layout and permutation.
        2. op_name == 'transpose' or 'TransposeExtView': Implements PyTorch transpose behavior.
           - extra_args expected: (dim0, dim1) where dim0 and dim1 are integers
             (negative indices allowed).
           - Rules: Output layout is determined by swapping the specified dimensions in input layout.

        Args:
            layouts (tuple): Layouts of input tensor.
            extra_args (tuple): Arguments for the operator.

        Returns:
            Layout: Layout for output tensor.

        Raises:
            ValueError: If extra_args is malformed, the permutation or dimensions are
                invalid for the input rank, or op_name is unsupported.
        """
        layout = layouts[0]
        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)
        out_tensor_map = None

        if self.op_name in ("Transpose", "permute"):
            # MindSpore style: Transpose(input, input_perm)
            # extra_args should contain a single element: the permutation tuple
            if not extra_args or not isinstance(extra_args[0], (list, tuple)):
                raise ValueError(f"For 'Transpose', expected permutation tuple in extra_args, got {extra_args}")

            axis = extra_args[0]

            if len(in_tensor_map) != len(axis):
                raise ValueError(f"Input tensor shape and permutation must have the same size. "
                                 f"Got {len(in_tensor_map)} and {len(axis)}")

            # Normalize negative indices first: both MindSpore Transpose and PyTorch
            # permute accept them, and this keeps behavior consistent with the
            # 'transpose' branch below.
            axis = tuple(v + ndim if v < 0 else v for v in axis)

            # check that (normalized) axis is a permutation of range(ndim)
            seen = set()
            for v in axis:
                if v < 0 or v >= ndim or v in seen:
                    raise ValueError(f"Invalid permutation {extra_args[0]} for rank {ndim}")
                seen.add(v)

            out_tensor_map = tuple(in_tensor_map[i] for i in axis)

        elif self.op_name in ("transpose", "TransposeExtView"):
            # PyTorch style: transpose(input, dim0, dim1)
            # extra_args should contain two elements: dim0 and dim1
            if len(extra_args) != 2:
                raise ValueError(f"For 'transpose', expected (dim0, dim1), got {extra_args}")

            dim0, dim1 = extra_args

            if not isinstance(dim0, int) or not isinstance(dim1, int):
                raise ValueError(f"Dimensions must be integers, got {dim0}, {dim1}")

            # Handle negative indices
            if dim0 < 0:
                dim0 += ndim
            if dim1 < 0:
                dim1 += ndim

            # Validate dimensions
            if not (0 <= dim0 < ndim and 0 <= dim1 < ndim):
                raise ValueError(f"Transpose dimensions out of bounds: ({dim0}, {dim1}) for rank {ndim}")

            # Swap the dimensions in the tensor map
            out_tensor_map_list = list(in_tensor_map)
            out_tensor_map_list[dim0], out_tensor_map_list[dim1] = out_tensor_map_list[dim1], out_tensor_map_list[dim0]
            out_tensor_map = tuple(out_tensor_map_list)

        else:
            raise ValueError(f"Unsupported op_name: {self.op_name}. "
                             f"Expected 'Transpose', 'permute', 'transpose' or 'TransposeExtView'.")

        # Rebuild a layout over the same device mesh, then apply the permuted
        # tensor map to produce the output layout.
        output_layout = Layout(
            mesh_shape=layout.mesh_shape,
            alias_name=layout.alias_name,
            rank_list=layout.rank_list
        )

        return output_layout(*out_tensor_map)