Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_transpose.py: 95%
38 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Transpose operator.
17"""
19from hyper_parallel.core.dtensor.layout import Layout
20from .parallel_ops import DistributedOp
class TransposeDistributedOp(DistributedOp):
    """Distributed implementation for Transpose operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for Transpose operator.

        Based on the op_name initialized in the base class, this method switches behavior:
        1. op_name == 'Transpose', 'permute' or 'TransposeView': Implements MindSpore Transpose
           behavior or PyTorch permute behavior.
           - extra_args expected: (perm,) where perm is a tuple of indices.
           - Rules: Output layout is determined by input layout and permutation.
        2. op_name == 'transpose' or 'TransposeExtView': Implements PyTorch transpose behavior.
           - extra_args expected: (dim0, dim1) where dim0 and dim1 are integers.
           - Rules: Output layout is determined by swapping the specified dimensions in input layout.

        Args:
            layouts (tuple): Layouts of input tensor.
            extra_args (tuple): Arguments for the operator.

        Returns:
            Layout: Layout for output tensor.

        Raises:
            ValueError: If extra_args is missing or malformed, the permutation is
                invalid, a dimension is out of bounds, or op_name is unsupported.
        """
        layout = layouts[0]
        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)
        out_tensor_map = None

        if self.op_name in ("Transpose", "permute", "TransposeView"):
            # MindSpore style: Transpose(input, input_perm)
            # extra_args should contain a single element: the permutation tuple
            if not extra_args or not isinstance(extra_args[0], (list, tuple)):
                raise ValueError(f"For 'Transpose', expected permutation tuple in extra_args, got {extra_args}")

            axis = extra_args[0]

            if len(in_tensor_map) != len(axis):
                raise ValueError(f"Input tensor shape and permutation must have the same size. "
                                 f"Got {len(in_tensor_map)} and {len(axis)}")

            # check if axis is a permutation: every index in [0, ndim), no repeats
            seen = set()
            for v in axis:
                if v < 0 or v >= ndim or v in seen:
                    raise ValueError(f"Invalid permutation {axis} for rank {ndim}")
                seen.add(v)

            # Reorder the tensor map according to the permutation.
            out_tensor_map = tuple(in_tensor_map[i] for i in axis)

        elif self.op_name in ("transpose", "TransposeExtView"):
            # PyTorch style: transpose(input, dim0, dim1)
            # extra_args should contain two elements: dim0 and dim1.
            # BUGFIX: guard against extra_args being None (the default) before
            # calling len() — previously this raised TypeError instead of the
            # documented ValueError.
            if extra_args is None or len(extra_args) != 2:
                raise ValueError(f"For 'transpose', expected (dim0, dim1), got {extra_args}")

            dim0, dim1 = extra_args

            if not isinstance(dim0, int) or not isinstance(dim1, int):
                raise ValueError(f"Dimensions must be integers, got {dim0}, {dim1}")

            # Handle negative indices (PyTorch convention: -1 means last dim)
            if dim0 < 0:
                dim0 += ndim
            if dim1 < 0:
                dim1 += ndim

            # Validate dimensions after normalization
            if not (0 <= dim0 < ndim and 0 <= dim1 < ndim):
                raise ValueError(f"Transpose dimensions out of bounds: ({dim0}, {dim1}) for rank {ndim}")

            # Swap the two dimensions in the tensor map
            out_tensor_map_list = list(in_tensor_map)
            out_tensor_map_list[dim0], out_tensor_map_list[dim1] = out_tensor_map_list[dim1], out_tensor_map_list[dim0]
            out_tensor_map = tuple(out_tensor_map_list)

        else:
            # BUGFIX: the message now lists every op_name actually accepted above.
            raise ValueError(f"Unsupported op_name: {self.op_name}. Expected 'Transpose', 'permute', "
                             f"'TransposeView', 'transpose' or 'TransposeExtView'.")

        # Build the output layout on the same mesh and apply the permuted map.
        output_layout = Layout(
            mesh_shape=layout.mesh_shape,
            alias_name=layout.alias_name,
            rank_list=layout.rank_list
        )

        return output_layout(*out_tensor_map)