Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_nonzero.py: 97%
31 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for Nonzero operator.
"""
from hyper_parallel.core.dtensor.layout import Layout
from .parallel_ops import DistributedOp


class NonzeroDistributedOp(DistributedOp):
    """Distributed implementation for torch.nonzero."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for torch.nonzero.

        PyTorch semantics:
        - Returns a tensor containing the indices of all non-zero elements.
        - If as_tuple=False (default): returns a 2-D tensor of shape (z, n),
          where z is the number of non-zero elements and n is the number of
          dimensions of the input.
        - If as_tuple=True: returns a tuple of 1-D tensors, one per input
          dimension (see the example below).

        Distributed semantics:
        - Because the output shape depends dynamically on the input data,
          the input tensor MUST be fully replicated (unsharded) across the
          mesh; otherwise each rank would produce a different local shape.
        - The output layout is therefore also fully replicated.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (list, optional): Additional positional/keyword
                values, e.g. ``as_tuple``.

        Returns:
            Layout or tuple[Layout]: Replicated output layout(s) matching
            PyTorch's as_tuple return signature.
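
        Example (standard torch.nonzero behavior, shown here to illustrate
        the two return signatures; torch itself is not imported by this
        module):

            >>> t = torch.tensor([[1, 0], [0, 2]])
            >>> torch.nonzero(t)                  # as_tuple=False: shape (2, 2)
            tensor([[0, 0],
                    [1, 1]])
            >>> torch.nonzero(t, as_tuple=True)   # one 1-D tensor per dim
            (tensor([0, 1]), tensor([0, 1]))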
50 """
51 # =====================================================================
52 # FIX: Explicitly enforce the base class's partial status guardrail
53 # =====================================================================
54 if not self._allow_partial_inputs:
55 self._check_partial_inputs(layouts)
57 if not layouts or layouts[0] is None:
58 raise ValueError(
59 f"Operation {self.op_name}: nonzero requires a valid input tensor layout."
60 )
62 input_layout = layouts[0]
63 in_tensor_map = input_layout.tensor_map
64 input_ndim = len(in_tensor_map)
66 # Rule 1: Input must be fully replicated due to data-dependent dynamic shapes
67 for dim_sharding in in_tensor_map:
68 if dim_sharding != -1:
69 raise ValueError(
70 f"Operation {self.op_name}: input tensor must be fully replicated "
71 f"(unsharded). nonzero produces dynamic shapes that depend on data values, "
72 f"which causes shape mismatches across ranks if the tensor is sharded."
73 )
75 # Rule 2: Parse 'as_tuple' from extra_args (defaults to False)
76 as_tuple = False
77 if extra_args:
78 for arg in extra_args:
79 if isinstance(arg, bool):
80 as_tuple = arg
81 break
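        # Note: the loop above treats the first boolean found in extra_args
        # as the as_tuple flag; non-boolean entries are skipped, and when no
        # boolean is present PyTorch's default (as_tuple=False) applies.
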
        mesh_shape = input_layout.mesh_shape
        alias_name = input_layout.alias_name
        rank_list = input_layout.rank_list

        def _create_replicated_layout(ndim):
            """Helper to create a fully replicated layout for a given dimension."""
            layout = Layout(
                mesh_shape=mesh_shape,
                alias_name=alias_name,
                rank_list=rank_list
            )
            # Replicated layout maps all dimensions to "None"
            alias_map = tuple("None" for _ in range(ndim))
            return layout(*alias_map)
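        # For example, _create_replicated_layout(2) yields
        # layout("None", "None"): a 2-D output replicated on every mesh axis.
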
        # Rule 3: Construct the return layout based on as_tuple flag
        if as_tuple:
            # Returns a tuple of 1D tensors, one for each dimension in the input
            out_layout = _create_replicated_layout(1)
            return tuple(out_layout for _ in range(input_ndim))

        # Default: Returns a single 2D tensor of shape (z, n)
        return _create_replicated_layout(2)
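

# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not part of the operator): how a layout is
# expected to flow through infer_layout. The Layout constructor arguments and
# the tensor_map convention (-1 == replicated) are inferred from the code
# above, not from hyper_parallel documentation, and the NonzeroDistributedOp
# construction below is hypothetical.
#
#   input_layout = Layout(mesh_shape=(2, 4), alias_name=("dp", "tp"),
#                         rank_list=list(range(8)))("None", "None")
#   op = NonzeroDistributedOp(...)                 # hypothetical construction
#   op.infer_layout((input_layout,))               # -> replicated 2-D layout
#   op.infer_layout((input_layout,), [True])       # -> tuple of 1-D layouts,
#                                                  #    one per input dim
# ---------------------------------------------------------------------------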