Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_outer.py: 92% (37 statements)


# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for the Outer operator.
"""

from hyper_parallel.core.dtensor.layout import Layout
from .parallel_ops import DistributedOp


class OuterDistributedOp(DistributedOp):
    """Distributed implementation for torch.outer."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for torch.outer.

        PyTorch semantics:
            - Computes the outer product of two 1-D tensors.
            - If input has size N and vec2 has size M, the output has size (N, M).
            - Both input tensors must be 1-D.

        Distributed semantics:
            - The 0th dimension of the output inherits the layout of input.
            - The 1st dimension of the output inherits the layout of vec2.
            - The two inputs cannot be sharded along the same device mesh dimension.

        Args:
            layouts (tuple): Layouts of the inputs. Expected:
                layouts[0] (Layout): Layout of the first 1-D tensor (input).
                layouts[1] (Layout): Layout of the second 1-D tensor (vec2).
            extra_args (tuple, optional): Unused for outer.

        Returns:
            Layout: The layout of the 2-D output tensor.
        """

        if not layouts or len(layouts) != 2:
            raise ValueError(
                f"Operation {self.op_name}: requires exactly 2 input layouts."
            )

        layout1, layout2 = layouts[0], layouts[1]

        if layout1 is None or layout2 is None:
            raise ValueError(
                f"Operation {self.op_name}: requires both inputs to have valid layouts."
            )

        map1 = layout1.tensor_map
        map2 = layout2.tensor_map

        if len(map1) != 1 or len(map2) != 1:
            raise ValueError(
                f"Operation {self.op_name}: requires exactly 1-D tensors as inputs, "
                f"but got {len(map1)}-D and {len(map2)}-D."
            )

        dim0_map = map1[0]
        dim1_map = map2[0]
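        # Each map entry is either -1 (replicated), a mesh-dimension index
        # such as 0, or a tuple like (0, 1) when the tensor dimension spans
        # several mesh axes.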

        # Helper to extract all sharded mesh dimensions for a tensor dimension.
        # -1 marks a replicated (unsharded) axis, so it is excluded from the
        # result in both branches: it is not a real mesh dimension and must
        # not trigger a spurious sharding conflict below.
        def _get_flattened_map(dim_map):
            if isinstance(dim_map, int):
                return {dim_map} if dim_map != -1 else set()
            return {sub_idx for sub_idx in dim_map if sub_idx != -1}

        set1 = _get_flattened_map(dim0_map)
        set2 = _get_flattened_map(dim1_map)
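        # Worked examples: _get_flattened_map(-1) -> set(),
        # _get_flattened_map(0) -> {0}, _get_flattened_map((0, 1)) -> {0, 1},
        # and _get_flattened_map((0, -1)) -> {0} (the replicated axis drops).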

        # Ensure the two 1-D tensors are not sharded on the same mesh
        # dimension; otherwise both output dimensions would compete for the
        # same devices.
        conflict = set1.intersection(set2)
        if conflict:
            raise ValueError(
                f"Operation {self.op_name}: the two inputs cannot be sharded on the "
                f"same device mesh dimension. Conflict on mesh index: {conflict}"
            )
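        # For example, if both inputs are sharded along mesh dim 0, set1 and
        # set2 both contain 0 and the check above raises.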

        # Build the output tensor map: (input_dim, vec2_dim)
        output_map = [dim0_map, dim1_map]

        # Construct the output layout from the first input's mesh metadata.
        mesh_shape = layout1.mesh_shape
        alias_name = layout1.alias_name
        rank_list = layout1.rank_list

        def idx_to_alias(idx_item, aliases):
            # Maps a mesh-dimension index to its alias name. Handles both a
            # single int and a tuple (multi-axis sharding). Mesh indices count
            # from the last alias, hence the reversed lookup.
            if isinstance(idx_item, int):
                if idx_item == -1:
                    return "None"
                return aliases[len(aliases) - idx_item - 1]
            return tuple(
                "None" if sub_idx == -1 else aliases[len(aliases) - sub_idx - 1]
                for sub_idx in idx_item
            )
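        # For instance, with aliases ("dp", "tp") (hypothetical names):
        # idx_to_alias(0, ("dp", "tp")) -> "tp", idx_to_alias(1, ...) -> "dp",
        # idx_to_alias((1, 0), ...) -> ("dp", "tp"), and -1 becomes "None".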

        output_alias_map = tuple(idx_to_alias(idx, alias_name) for idx in output_map)

        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list,
        )

        output_layout = output_layout(*output_alias_map)
        return output_layout
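
# ---------------------------------------------------------------------------
# Usage sketch (illustrative comments only; the "dp"/"tp" axis names and the
# 2x2 mesh are hypothetical, and `op` stands for an already-constructed
# OuterDistributedOp). The Layout construction mirrors infer_layout above:
#
#     mesh_layout = Layout(
#         mesh_shape=(2, 2),
#         alias_name=("dp", "tp"),
#         rank_list=list(range(4)),
#     )
#     input_layout = mesh_layout("dp")   # 1-D input sharded along "dp"
#     vec2_layout = mesh_layout("tp")    # 1-D vec2 sharded along "tp"
#     out_layout = op.infer_layout((input_layout, vec2_layout))
#     # out_layout: dim 0 sharded along "dp", dim 1 along "tp"
# ---------------------------------------------------------------------------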