Coverage for hyper_parallel / core / shard / ops / parallel_transpose.py: 79%

38 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Transpose operator. 

17""" 

18 

19from hyper_parallel.core.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

22 

class TransposeDistributedOp(DistributedOp):
    """Distributed implementation for Transpose operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for Transpose operator.

        Based on the op_name initialized in the base class, this method switches behavior:
        1. op_name == 'Transpose' or 'permute': Implements MindSpore Transpose behavior
           or PyTorch permute behavior.
            - extra_args expected: (perm,) where perm is a list/tuple of indices.
              Negative indices are accepted and counted from the last dimension.
            - Rules: Output dimension ``k`` takes the tensor-map entry of input
              dimension ``perm[k]``.
        2. op_name == 'transpose' or 'TransposeExtView': Implements PyTorch transpose behavior.
            - extra_args expected: (dim0, dim1) where dim0 and dim1 are integers.
              Negative indices are accepted and counted from the last dimension.
            - Rules: Output layout is the input layout with the tensor-map entries
              of dim0 and dim1 swapped.

        Args:
            layouts (tuple): Layouts of input tensor; only ``layouts[0]`` is used.
            extra_args (tuple): Arguments for the operator (see above).

        Returns:
            Layout: Layout for output tensor, built on the same mesh_shape,
            alias_name and rank_list as the input layout.

        Raises:
            ValueError: If extra_args does not have the expected form, the
                permutation is not a valid permutation of the input dimensions,
                a dimension is out of range, or op_name is not supported.
        """
        layout = layouts[0]
        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)

        if self.op_name in ("Transpose", "permute"):
            # MindSpore style: Transpose(input, input_perm) / PyTorch permute(input, dims)
            # extra_args should contain a single element: the permutation tuple
            if not extra_args or not isinstance(extra_args[0], (list, tuple)):
                raise ValueError(
                    f"For '{self.op_name}', expected permutation tuple in extra_args, got {extra_args}"
                )

            axis = extra_args[0]

            if ndim != len(axis):
                raise ValueError(f"Input tensor shape and permutation must have the same size. "
                                 f"Got {ndim} and {len(axis)}")

            # Normalize negative indices, then check that the result is a true
            # permutation: every index in range and none repeated.
            seen = set()
            normalized_axis = []
            for v in axis:
                idx = v + ndim if v < 0 else v
                if idx < 0 or idx >= ndim or idx in seen:
                    raise ValueError(f"Invalid permutation {axis} for rank {ndim}")
                seen.add(idx)
                normalized_axis.append(idx)

            out_tensor_map = tuple(in_tensor_map[i] for i in normalized_axis)

        elif self.op_name in ("transpose", "TransposeExtView"):
            # PyTorch style: transpose(input, dim0, dim1)
            # extra_args should contain two elements: dim0 and dim1.
            # The emptiness guard keeps a None extra_args on the ValueError path.
            if not extra_args or len(extra_args) != 2:
                raise ValueError(f"For '{self.op_name}', expected (dim0, dim1), got {extra_args}")

            dim0, dim1 = extra_args

            if not isinstance(dim0, int) or not isinstance(dim1, int):
                raise ValueError(f"Dimensions must be integers, got {dim0}, {dim1}")

            # Handle negative indices on copies so that error messages can
            # report the values the caller actually passed.
            norm_dim0 = dim0 + ndim if dim0 < 0 else dim0
            norm_dim1 = dim1 + ndim if dim1 < 0 else dim1

            # Validate dimensions
            if not (0 <= norm_dim0 < ndim and 0 <= norm_dim1 < ndim):
                raise ValueError(f"Transpose dimensions out of bounds: ({dim0}, {dim1}) for rank {ndim}")

            # Swap the dimensions in the tensor map
            swapped = list(in_tensor_map)
            swapped[norm_dim0], swapped[norm_dim1] = swapped[norm_dim1], swapped[norm_dim0]
            out_tensor_map = tuple(swapped)

        else:
            raise ValueError(
                f"Unsupported op_name: {self.op_name}. "
                f"Expected 'Transpose', 'permute', 'transpose' or 'TransposeExtView'."
            )

        # The device mesh is unchanged by a transpose; only the mapping from
        # tensor dimensions to mesh axes is reordered.
        output_layout = Layout(
            mesh_shape=layout.mesh_shape,
            alias_name=layout.alias_name,
            rank_list=layout.rank_list
        )

        return output_layout(*out_tensor_map)