Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_topk.py: 94%
18 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for TopK operator.
17"""
19from .parallel_ops import DistributedOp
class TopKDistributedOp(DistributedOp):
    """Distributed implementation for TopK operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for TopK operator.

        TopK: values, indices = topk(input, k, dim)

        Rules:
            1. dim = -1 if not specified (including when `extra_args` is None).
            2. The dimension `dim` MUST be unsharded to ensure global top-k correctness.
            3. Both values and indices have same layout as input.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple, optional): Holds k and optionally dim. Expected:
                extra_args[0] (int): K value. It is not inspected by this method.
                extra_args[1] (int, optional): Dimension to compute topk. Defaults to -1.

        Returns:
            tuple: Layouts for values and indices tensors; both are the input layout.

        Raises:
            ValueError: If the input layout is missing, if `dim` is out of range
                for the input's tensor map, or if the chosen dim is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("topk requires a valid input tensor layout.")

        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map

        # If dim is not given, the last dimension of the input is chosen.
        # `extra_args` defaults to None, so it must be checked before len();
        # otherwise a call without extra_args fails with TypeError instead of
        # falling back to the default dim.
        dim = -1
        if extra_args is not None and len(extra_args) >= 2 and extra_args[1] is not None:
            dim = extra_args[1]

        input_dim = len(in_tensor_map)
        if dim < 0:
            dim = input_dim + dim  # -1 represents the last dimension
        if not 0 <= dim < input_dim:
            raise ValueError(f"Dimension out of range (expected to be in [0, {input_dim}), but got {dim}).")

        # The chosen dim must NOT be sharded: an entry other than -1 in the
        # tensor map is rejected here.
        if in_tensor_map[dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on params along the chosen dim"
            )

        return input_layout, input_layout