Coverage for hyper_parallel / core / shard / ops / parallel_topk.py: 89%

18 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for TopK operator. 

17""" 

18 

19from .parallel_ops import DistributedOp 

20 

class TopKDistributedOp(DistributedOp):
    """Distributed implementation for TopK operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for TopK operator.

        TopK: values, indices = topk(input, k, dim)

        Rules:
            1. dim = -1 if not specified.
            2. The dimension `dim` MUST be unsharded to ensure global top-k correctness.
            3. Both values and indices have the same layout as the input.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple, optional): Requires k and optionally contains dim. Expected:
                extra_args[0] (int, required): K value.
                extra_args[1] (int, optional): Dimension to compute topk. Defaults to -1.

        Returns:
            tuple: Layouts for values and indices tensors (both identical to the
            input layout, since topk preserves the input's sharding).

        Raises:
            ValueError: If the input layout is missing, `dim` is out of range,
                or the chosen `dim` is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("topk requires a valid input tensor layout.")

        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map

        dim = -1  # If dim is not given, the last dimension of the input is chosen.
        # BUGFIX: guard against extra_args being None (its declared default)
        # before calling len() on it — previously raised TypeError.
        if extra_args is not None and len(extra_args) >= 2 and extra_args[1] is not None:
            dim = extra_args[1]
        input_dim = len(in_tensor_map)
        if dim < 0:
            dim = input_dim + dim  # -1 represents the last dimension
        if not 0 <= dim < input_dim:
            raise ValueError(f"Dimension out of range (expected to be in [0, {input_dim}), but got {dim}).")

        # The chosen dim must NOT be sharded; a tensor_map entry of -1 marks
        # an unsharded (replicated) dimension here.
        if in_tensor_map[dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on params along the chosen dim"
            )

        return input_layout, input_layout