Coverage for hyper_parallel / core / shard / ops / parallel_topk.py: 89%
18 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for TopK operator.
17"""
19from .parallel_ops import DistributedOp
class TopKDistributedOp(DistributedOp):
    """Distributed implementation for TopK operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for TopK operator.

        TopK: values, indices = topk(input, k, dim)

        Rules:
            1. dim = -1 if not specified.
            2. The dimension `dim` MUST be unsharded to ensure global top-k correctness.
            3. Both values and indices have same layout as input.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple, optional): Requires k and optionally contains dim. Expected:
                extra_args[0] (int, required): K value.
                extra_args[1] (int, optional): Dimension to compute topk. Defaults to -1.

        Returns:
            tuple: Layouts for values and indices tensors.

        Raises:
            ValueError: If the input layout is missing, `dim` is out of range,
                or the input is sharded along the chosen `dim`.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("topk requires a valid input tensor layout.")

        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map

        dim = -1  # If dim is not given, the last dimension of the input is chosen.
        # Guard against extra_args being None (its documented default) before
        # calling len() on it; the original code raised TypeError here.
        if extra_args is not None and len(extra_args) >= 2 and extra_args[1] is not None:
            dim = extra_args[1]

        input_dim = len(in_tensor_map)
        if dim < 0:
            dim = input_dim + dim  # Normalize negative dims (-1 is the last dimension).
        if not 0 <= dim < input_dim:
            raise ValueError(f"Dimension out of range (expected to be in [0, {input_dim}), but got {dim}).")

        # The chosen dim must NOT be sharded: computing a correct global top-k
        # along a sharded dimension would require cross-device communication.
        if in_tensor_map[dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on params along the chosen dim"
            )

        # Both outputs (values, indices) inherit the input layout unchanged.
        return input_layout, input_layout