Coverage for hyper_parallel / core / shard / ops / parallel_isin.py: 92%

24 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Isin operator. 

17""" 

18 

19from hyper_parallel.core.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

class IsinDistributedOp(DistributedOp):
    """Layout inference for the distributed form of ``torch.isin``."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Derive the output layout of ``torch.isin(elements, test_elements, ...)``.

        The result of ``isin`` is a boolean tensor shaped like ``elements``, so the
        output is given the same sharding as ``elements``. Every element has to be
        compared with the complete set of ``test_elements``; for that reason
        ``test_elements`` is only accepted when none of its dimensions is sharded.

        Args:
            layouts (tuple): Input layouts.
                layouts[0] (Layout): Layout of ``elements`` (mandatory).
                layouts[1] (Layout): Layout of ``test_elements`` (mandatory).
            extra_args: Not used by this operator.

        Returns:
            Layout: A new layout built from the mesh shape, alias names and rank
            list of ``elements`` and carrying the tensor map of ``elements``.

        Raises:
            ValueError: If either layout is missing or ``None``, or if any entry
                of the ``test_elements`` tensor map differs from ``-1``.
        """
        # `elements` layout must be present.
        src_layout = layouts[0] if layouts else None
        if src_layout is None:
            raise ValueError(
                f"Operation {self.op_name}: 'elements' requires a valid tensor layout."
            )

        # `test_elements` layout must be present as well.
        lookup_layout = layouts[1] if len(layouts) >= 2 else None
        if lookup_layout is None:
            raise ValueError(
                f"Operation {self.op_name}: 'test_elements' requires a valid tensor layout."
            )

        # Any tensor-map entry other than -1 marks a sharded dimension, which is
        # rejected because membership needs the whole `test_elements` tensor.
        for axis in lookup_layout.tensor_map:
            if axis == -1:
                continue
            raise ValueError(
                f"Operation {self.op_name}: 'test_elements' must be unsharded. "
                f"Current tensor_map: {lookup_layout.tensor_map}."
            )

        mesh = src_layout.mesh_shape
        names = src_layout.alias_name
        ranks = src_layout.rank_list
        dim_map = src_layout.tensor_map

        # Translate each tensor-map index back to its alias: -1 becomes the
        # literal "None"; any other index is counted from the END of the alias
        # sequence (index 0 selects the last alias).
        # NOTE(review): the arithmetic assumes every entry is a plain int; confirm
        # whether Layout can hold tuple entries for `elements`.
        dim_aliases = tuple(
            "None" if axis == -1 else names[len(names) - axis - 1]
            for axis in dim_map
        )

        # Build a fresh Layout on the same mesh and apply the recovered aliases.
        return Layout(
            mesh_shape=mesh,
            alias_name=names,
            rank_list=ranks,
        )(*dim_aliases)