Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_isin.py: 96%
24 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for the Isin operator.
"""

from hyper_parallel.core.dtensor.layout import Layout
from .parallel_ops import DistributedOp


class IsinDistributedOp(DistributedOp):
    """Distributed implementation for torch.isin."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for torch.isin(elements, test_elements, ...).

        PyTorch semantics:
        - The result is a boolean tensor with the SAME SHAPE as `elements`.
        - Each element is tested against ALL values in `test_elements`, so every
          rank requires a GLOBAL (replicated) view of `test_elements`.

        Args:
            layouts (tuple): Layouts of the inputs. Expected:
                layouts[0] (Layout): Layout of the `elements` tensor (required).
                layouts[1] (Layout): Layout of the `test_elements` tensor (required).
            extra_args: Not used.

        Returns:
            Layout: Output layout identical to the `elements` layout (the boolean
            result keeps the same sharding).
        """
        # Validate elements layout
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: 'elements' requires a valid tensor layout."
            )
        elements_layout = layouts[0]

        # Validate test_elements layout
        if len(layouts) < 2 or layouts[1] is None:
            raise ValueError(
                f"Operation {self.op_name}: 'test_elements' requires a valid tensor layout."
            )
        test_elements_layout = layouts[1]

        # test_elements must be unsharded
        if not all(shard_way == -1 for shard_way in test_elements_layout.tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: 'test_elements' must be unsharded. "
                f"Current tensor_map: {test_elements_layout.tensor_map}."
            )

        mesh_shape = elements_layout.mesh_shape
        alias_name = elements_layout.alias_name
        rank_list = elements_layout.rank_list
        tensor_map = elements_layout.tensor_map
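
        # Map each tensor_map entry back to a mesh-axis alias: -1 marks a
        # replicated tensor dim ("None"); any other entry indexes the mesh axes
        # counted from the last alias, so the lookup walks alias_name from the end.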
        def idx_to_alias(idx, aliases):
            if idx == -1:
                return "None"
            return aliases[len(aliases) - idx - 1]

        output_map_aliases = tuple(
            idx_to_alias(idx, alias_name) for idx in tensor_map
        )

        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list
        )
        output_layout = output_layout(*output_map_aliases)

        return output_layout
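

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the operator implementation).
# The concrete Layout constructor arguments, mesh aliases, and the way
# IsinDistributedOp is instantiated are assumptions inferred from the code
# above, not a verified API.
#
#     # Hypothetical 2x4 device mesh with axes aliased "dp" and "tp".
#     elements_layout = Layout(
#         mesh_shape=(2, 4), alias_name=("dp", "tp"), rank_list=list(range(8))
#     )("dp", "None")   # `elements` sharded along dim 0 on the "dp" axis
#     test_layout = Layout(
#         mesh_shape=(2, 4), alias_name=("dp", "tp"), rank_list=list(range(8))
#     )("None")         # `test_elements` must be fully replicated
#
#     op = IsinDistributedOp(...)  # construction depends on the DistributedOp base
#     out_layout = op.infer_layout((elements_layout, test_layout))
#     # out_layout mirrors elements_layout: the boolean result keeps its sharding.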