Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_transpose.py: 95%

38 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Transpose operator. 

17""" 

18 

19from hyper_parallel.core.dtensor.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

22 

class TransposeDistributedOp(DistributedOp):
    """Distributed implementation for Transpose operator."""

    # Op names that take a full permutation (MindSpore Transpose / PyTorch permute style).
    _PERMUTE_OPS = ("Transpose", "permute", "TransposeView")
    # Op names that swap exactly two dimensions (PyTorch transpose style).
    _SWAP_OPS = ("transpose", "TransposeExtView")

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for Transpose operator.

        Based on the op_name initialized in the base class, this method switches behavior:
        1. op_name in ('Transpose', 'permute', 'TransposeView'): MindSpore Transpose
           or PyTorch permute behavior.
           - extra_args expected: (perm,) where perm is a list/tuple of dimension
             indices. Negative indices are accepted and count from the last dim.
           - Rules: output dim i takes the tensor-map entry of input dim perm[i].
        2. op_name in ('transpose', 'TransposeExtView'): PyTorch transpose behavior.
           - extra_args expected: (dim0, dim1) where dim0 and dim1 are integers
             (negative values count from the last dim).
           - Rules: the tensor-map entries of dim0 and dim1 are swapped.

        Args:
            layouts (tuple): Layouts of input tensor; only layouts[0] is used.
            extra_args (tuple): Arguments for the operator, as described above.

        Returns:
            Layout: Layout for output tensor, built on the same mesh_shape,
            alias_name and rank_list as the input layout.

        Raises:
            ValueError: If extra_args is missing or malformed, a dimension is out of
                range, perm is not a permutation of the input dims, or op_name is
                not supported.
        """
        layout = layouts[0]
        in_tensor_map = layout.alias_tensor_map

        if self.op_name in self._PERMUTE_OPS:
            out_tensor_map = self._permute_tensor_map(in_tensor_map, extra_args)
        elif self.op_name in self._SWAP_OPS:
            out_tensor_map = self._swap_tensor_map(in_tensor_map, extra_args)
        else:
            supported = ", ".join(f"'{name}'" for name in self._PERMUTE_OPS + self._SWAP_OPS)
            raise ValueError(f"Unsupported op_name: {self.op_name}. Expected one of {supported}.")

        output_layout = Layout(
            mesh_shape=layout.mesh_shape,
            alias_name=layout.alias_name,
            rank_list=layout.rank_list
        )
        # NOTE(review): calling the Layout instance with the tensor-map entries
        # presumably binds the alias tensor map and returns the final layout —
        # confirm against Layout.__call__.
        return output_layout(*out_tensor_map)

    @staticmethod
    def _permute_tensor_map(in_tensor_map, extra_args):
        """Return in_tensor_map reordered by the permutation carried in extra_args[0]."""
        # MindSpore style: Transpose(input, input_perm); extra_args holds a single
        # element, the permutation sequence.
        if not extra_args or not isinstance(extra_args[0], (list, tuple)):
            raise ValueError(f"For 'Transpose', expected permutation tuple in extra_args, got {extra_args}")

        axis = extra_args[0]
        ndim = len(in_tensor_map)
        if ndim != len(axis):
            raise ValueError(f"Input tensor shape and permutation must have the same size. "
                             f"Got {ndim} and {len(axis)}")

        # Normalize negative indices first so that e.g. (0, -3) on rank 3 is
        # detected as a duplicate; every dim must then appear exactly once.
        perm = []
        seen = set()
        for v in axis:
            dim = v + ndim if v < 0 else v
            if dim < 0 or dim >= ndim or dim in seen:
                raise ValueError(f"Invalid permutation {axis} for rank {ndim}")
            seen.add(dim)
            perm.append(dim)

        return tuple(in_tensor_map[i] for i in perm)

    @staticmethod
    def _swap_tensor_map(in_tensor_map, extra_args):
        """Return in_tensor_map with the entries of the two dims in extra_args swapped."""
        # PyTorch style: transpose(input, dim0, dim1); extra_args holds (dim0, dim1).
        # The `not extra_args` guard turns the default None into a ValueError
        # instead of a TypeError from len(None).
        if not extra_args or len(extra_args) != 2:
            raise ValueError(f"For 'transpose', expected (dim0, dim1), got {extra_args}")

        dim0, dim1 = extra_args
        if not isinstance(dim0, int) or not isinstance(dim1, int):
            raise ValueError(f"Dimensions must be integers, got {dim0}, {dim1}")

        ndim = len(in_tensor_map)
        # Handle negative indices (counted from the last dimension).
        norm0 = dim0 + ndim if dim0 < 0 else dim0
        norm1 = dim1 + ndim if dim1 < 0 else dim1
        if not (0 <= norm0 < ndim and 0 <= norm1 < ndim):
            # Report the caller's original values, not the normalized ones.
            raise ValueError(f"Transpose dimensions out of bounds: ({dim0}, {dim1}) for rank {ndim}")

        out_tensor_map = list(in_tensor_map)
        out_tensor_map[norm0], out_tensor_map[norm1] = out_tensor_map[norm1], out_tensor_map[norm0]
        return tuple(out_tensor_map)