Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_isin.py: 96%

24 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Isin operator. 

17""" 

18 

19from hyper_parallel.core.dtensor.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

22 

class IsinDistributedOp(DistributedOp):
    """Distributed implementation for torch.isin."""

    @staticmethod
    def _map_entry_to_alias(entry, aliases):
        """Convert one ``tensor_map`` entry back into alias form for ``Layout.__call__``.

        Args:
            entry: A ``tensor_map`` entry. ``-1`` marks an unsharded dimension.
                An integer ``i`` is converted with
                ``aliases[len(aliases) - i - 1]``, i.e. it counts mesh axes
                from the end of ``aliases``. A tuple entry (assumed: one tensor
                dimension sharded over several mesh axes — confirm against
                ``Layout``) is converted element-wise.
            aliases: The ``alias_name`` sequence of the source layout.

        Returns:
            ``"None"`` for ``-1``, the alias string for an integer index, or a
            tuple of converted entries for a tuple entry.
        """
        if isinstance(entry, tuple):
            # Previously a tuple entry hit `len(aliases) - entry - 1` and raised
            # TypeError; map each mesh-axis index individually instead.
            return tuple(
                IsinDistributedOp._map_entry_to_alias(sub_entry, aliases)
                for sub_entry in entry
            )
        if entry == -1:
            return "None"
        return aliases[len(aliases) - entry - 1]

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for torch.isin(elements, test_elements, ...)

        PyTorch semantics:
        - Returns boolean tensor with SAME SHAPE as `elements`
        - Each element is tested against ALL values in `test_elements`, so requires GLOBAL view of `test_elements`

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Layout of `elements` tensor (required).
                layouts[1] (Layout): Layout of `test_elements` tensor (required).
            extra_args: Unused; accepted for interface compatibility.

        Returns:
            Layout: A new layout with the same mesh_shape, alias_name, rank_list
            and tensor_map as the `elements` layout. Only these four properties
            are carried over; any other state on the input layout is not copied.

        Raises:
            ValueError: If `layouts` is empty or `layouts[0]` is None.
            ValueError: If `layouts` has fewer than two entries or `layouts[1]` is None.
            ValueError: If any dimension of `test_elements` is sharded
                (any tensor_map entry other than -1).
        """
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: 'elements' requires a valid tensor layout."
            )
        elements_layout = layouts[0]

        if len(layouts) < 2 or layouts[1] is None:
            raise ValueError(
                f"Operation {self.op_name}: 'test_elements' requires a valid tensor layout."
            )
        test_elements_layout = layouts[1]

        # Each rank checks its local `elements` shard against the whole
        # `test_elements` set, so `test_elements` must be fully replicated.
        # An empty tensor_map (0-d tensor) passes this check.
        if any(shard_way != -1 for shard_way in test_elements_layout.tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: 'test_elements' must be unsharded. "
                f"Current tensor_map: {test_elements_layout.tensor_map}."
            )

        alias_name = elements_layout.alias_name
        # The output has the same shape as `elements`, so it reuses the
        # elements' sharding, rebuilt via the alias form that Layout.__call__ takes.
        output_aliases = tuple(
            self._map_entry_to_alias(entry, alias_name)
            for entry in elements_layout.tensor_map
        )

        output_layout = Layout(
            mesh_shape=elements_layout.mesh_shape,
            alias_name=alias_name,
            rank_list=elements_layout.rank_list,
        )
        return output_layout(*output_aliases)