Coverage for hyper_parallel / core / shard / ops / parallel_isin.py: 92%
24 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Isin operator.
17"""
19from hyper_parallel.core.layout import Layout
20from .parallel_ops import DistributedOp
class IsinDistributedOp(DistributedOp):
    """Distributed implementation for torch.isin."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for torch.isin(elements, test_elements, ...).

        PyTorch semantics:
            - The result is a boolean tensor with the SAME SHAPE as `elements`.
            - Every element is compared against ALL values in `test_elements`,
              so a GLOBAL (replicated) view of `test_elements` is required.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Layout of `elements` tensor (required).
                layouts[1] (Layout): Layout of `test_elements` tensor (required).
            extra_args: Unused.

        Returns:
            Layout: Output layout identical to the `elements` layout
            (boolean tensor with the same sharding).

        Raises:
            ValueError: If either input layout is missing, or if
                `test_elements` is sharded on any dimension.
        """
        # Guard: `elements` layout is mandatory.
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: 'elements' requires a valid tensor layout."
            )
        # Guard: `test_elements` layout is mandatory as well.
        if len(layouts) < 2 or layouts[1] is None:
            raise ValueError(
                f"Operation {self.op_name}: 'test_elements' requires a valid tensor layout."
            )
        elements_layout, test_elements_layout = layouts[0], layouts[1]

        # `test_elements` must be fully replicated (-1 on every dimension):
        # each rank needs the complete set of test values.
        if any(way != -1 for way in test_elements_layout.tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: 'test_elements' must be unsharded. "
                f"Current tensor_map: {test_elements_layout.tensor_map}."
            )

        aliases = elements_layout.alias_name

        def map_entry_to_alias(entry):
            # -1 marks a replicated dimension; otherwise a tensor_map entry
            # indexes mesh axes from the rightmost alias, hence the reversal.
            if entry == -1:
                return "None"
            return aliases[len(aliases) - entry - 1]

        dim_aliases = tuple(
            map_entry_to_alias(entry) for entry in elements_layout.tensor_map
        )

        # Output mirrors the sharding of `elements` exactly.
        base_layout = Layout(
            mesh_shape=elements_layout.mesh_shape,
            alias_name=aliases,
            rank_list=elements_layout.rank_list
        )
        return base_layout(*dim_aliases)