Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_outer.py: 92% (37 statements)
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for the Outer operator.
"""

from hyper_parallel.core.dtensor.layout import Layout
from .parallel_ops import DistributedOp


class OuterDistributedOp(DistributedOp):
    """Distributed implementation for torch.outer."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for torch.outer.

        PyTorch semantics:
        - Computes the outer product of two 1-D tensors.
        - If input is of size N and vec2 is of size M, the output is of size (N, M).
        - Input tensors must be 1-D.

        Distributed semantics:
        - The 0th dimension of the output inherits the layout of input.
        - The 1st dimension of the output inherits the layout of vec2.
        - The two inputs cannot be sharded along the same device mesh dimension.

        Args:
            layouts (tuple): Layouts of the inputs. Expected:
                layouts[0] (Layout): Layout of the first 1-D tensor (input).
                layouts[1] (Layout): Layout of the second 1-D tensor (vec2).
            extra_args (tuple, optional): Unused for outer.

        Returns:
            Layout: The 2-D output tensor layout.
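
        Example:
            Illustrative sketch only; the mesh shape, the ("dp", "mp") aliases,
            the rank list, and the alias-based Layout call are assumptions for
            illustration, where op is an OuterDistributedOp instance::

                layout = Layout(mesh_shape=(2, 4), alias_name=("dp", "mp"),
                                rank_list=list(range(8)))
                in_layout = layout("dp")    # 1-D input sharded along "dp"
                vec2_layout = layout("mp")  # 1-D vec2 sharded along "mp"
                out_layout = op.infer_layout((in_layout, vec2_layout))
                # out_layout is 2-D: dim 0 follows "dp", dim 1 follows "mp"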
48 """
        if not layouts or len(layouts) != 2:
            raise ValueError(
                f"Operation {self.op_name}: requires exactly 2 input layouts."
            )

        layout1, layout2 = layouts[0], layouts[1]

        if layout1 is None or layout2 is None:
            raise ValueError(
                f"Operation {self.op_name}: requires both inputs to have valid layouts."
            )

        map1 = layout1.tensor_map
        map2 = layout2.tensor_map

        if len(map1) != 1 or len(map2) != 1:
            raise ValueError(
                f"Operation {self.op_name}: requires exactly 1-D tensors as inputs, "
                f"but got {len(map1)}-D and {len(map2)}-D."
            )

        dim0_map = map1[0]
        dim1_map = map2[0]

        # Helper to extract all sharded mesh dimensions for a tensor dimension
        def _get_flattened_map(dim_map):
            if isinstance(dim_map, int):
                return {dim_map} if dim_map != -1 else set()
            return set(dim_map)
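
        # Illustrative behaviour of the helper above (assuming the usual
        # tensor_map convention where -1 marks a replicated dimension):
        #   _get_flattened_map(2)      -> {2}
        #   _get_flattened_map((0, 1)) -> {0, 1}
        #   _get_flattened_map(-1)     -> set()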
        set1 = _get_flattened_map(dim0_map)
        set2 = _get_flattened_map(dim1_map)

        # Ensure the two 1-D tensors are not sharded on the same mesh dimension
        if set1.intersection(set2):
            raise ValueError(
                f"Operation {self.op_name}: the two inputs cannot be sharded on the "
                f"same device mesh dimension. Conflict on mesh index: {set1.intersection(set2)}"
            )

        # Build output tensor map: (input_dim, vec2_dim)
        output_map = [dim0_map, dim1_map]

        # Construct output layout
        mesh_shape = layout1.mesh_shape
        alias_name = layout1.alias_name
        rank_list = layout1.rank_list

        def idx_to_alias(idx_item, aliases):
            # Handles both single int and nested tuple mapping
            if isinstance(idx_item, int):
                if idx_item == -1:
                    return "None"
                return aliases[len(aliases) - idx_item - 1]
            # Handle multi-axis sharding (tuple)
            return tuple(
                "None" if sub_idx == -1 else aliases[len(aliases) - sub_idx - 1]
                for sub_idx in idx_item
            )
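
        # Illustrative behaviour of idx_to_alias, assuming alias_name = ("dp", "mp");
        # per the indexing above, tensor_map index 0 maps to the last alias:
        #   idx_to_alias(0, ("dp", "mp"))      -> "mp"
        #   idx_to_alias(1, ("dp", "mp"))      -> "dp"
        #   idx_to_alias(-1, ("dp", "mp"))     -> "None"
        #   idx_to_alias((1, 0), ("dp", "mp")) -> ("dp", "mp")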
        output_alias_map = tuple(idx_to_alias(idx, alias_name) for idx in output_map)

        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list
        )
        output_layout = output_layout(*output_alias_map)
        return output_layout