Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_nonzero.py: 97%

31 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Nonzero operator. 

17""" 

18 

19from hyper_parallel.core.dtensor.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

22 

class NonzeroDistributedOp(DistributedOp):
    """Distributed implementation for torch.nonzero."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for torch.nonzero.

        PyTorch semantics:
        - Returns the indices of all non-zero elements of the input.
        - If as_tuple=False (default): returns a 2-D tensor of shape (z, n),
          where z is the number of non-zero elements and n is the number of
          input dimensions.
        - If as_tuple=True: returns a tuple of 1-D tensors, one per input
          dimension. As a special case, a 0-d input is treated as a 1-D tensor
          with one element, so it yields a tuple holding a single 1-D tensor.

        Distributed semantics:
        - Because the output shape depends on the input *values*, the input
          tensor MUST be fully replicated (unsharded) across the mesh.
          Otherwise, each rank would produce a different local shape.
        - The output layout(s) are likewise fully replicated.

        Args:
            layouts (tuple): Layouts of inputs. ``layouts[0]`` (Layout) is the
                input tensor layout and is required.
            extra_args (list, optional): Additional positional values. The
                first ``bool`` found is used as ``as_tuple``; if none is
                present, ``as_tuple`` defaults to False.

        Returns:
            Layout or tuple[Layout]: A replicated layout built with 2 dims when
            ``as_tuple`` is False; otherwise a tuple of replicated 1-dim
            layouts, one per input dimension (at least one, for 0-d inputs).

        Raises:
            ValueError: If ``layouts`` is empty or ``layouts[0]`` is None, or if
                any input dimension is sharded.
        """
        # Validate the input layout first so a missing layout produces a clear
        # ValueError rather than an error from inside _check_partial_inputs.
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: nonzero requires a valid input tensor layout."
            )

        # Enforce the base class's partial-input guardrail unless this op
        # explicitly allows partial inputs.
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map
        input_ndim = len(in_tensor_map)

        # Rule 1: the output shape is data-dependent, so every input dimension
        # must be unsharded (tensor_map entry -1); otherwise ranks would
        # disagree on the local output shape.
        if any(dim_sharding != -1 for dim_sharding in in_tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: input tensor must be fully replicated "
                "(unsharded). nonzero produces dynamic shapes that depend on data values, "
                "which causes shape mismatches across ranks if the tensor is sharded."
            )

        # Rule 2: 'as_tuple' is the first bool in extra_args (defaults to False).
        as_tuple = next(
            (arg for arg in (extra_args or ()) if isinstance(arg, bool)),
            False,
        )

        mesh_shape = input_layout.mesh_shape
        alias_name = input_layout.alias_name
        rank_list = input_layout.rank_list

        def _create_replicated_layout(ndim):
            """Build a layout on the input's mesh with all ``ndim`` dims replicated."""
            layout = Layout(
                mesh_shape=mesh_shape,
                alias_name=alias_name,
                rank_list=rank_list,
            )
            # Mapping every tensor dimension to "None" marks it as replicated.
            alias_map = ("None",) * ndim
            return layout(*alias_map)

        # Rule 3: construct the return layout(s) based on the as_tuple flag.
        if as_tuple:
            # One 1-D index tensor per input dimension. PyTorch treats a 0-d
            # input as 1-D, so it still yields exactly one tensor. Build a
            # separate layout object per element so they do not alias.
            num_outputs = max(input_ndim, 1)
            return tuple(_create_replicated_layout(1) for _ in range(num_outputs))

        # Default: a single 2-D index tensor of shape (z, n).
        return _create_replicated_layout(2)