Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_topk.py: 94%

18 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for TopK operator.
"""

from .parallel_ops import DistributedOp


class TopKDistributedOp(DistributedOp):
    """Distributed implementation for TopK operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for the TopK operator.

        TopK: values, indices = topk(input, k, dim)

        Rules:
            1. dim defaults to -1 if not specified.
            2. The dimension `dim` MUST be unsharded to ensure global top-k correctness.
            3. Both values and indices have the same layout as the input.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple, optional): Requires k and optionally contains dim. Expected:
                extra_args[0] (int, required): K value.
                extra_args[1] (int, optional): Dimension along which to compute topk. Defaults to -1.

        Returns:
            tuple: Layouts for the values and indices tensors.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("topk requires a valid input tensor layout.")

        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map

        dim = -1  # If dim is not given, the last dimension of the input is chosen.
        if extra_args is not None and len(extra_args) >= 2 and extra_args[1] is not None:
            dim = extra_args[1]
        input_dim = len(in_tensor_map)
        if dim < 0:
            dim = input_dim + dim  # Negative dims count from the end; -1 is the last dimension.
        if not 0 <= dim < input_dim:
            raise ValueError(f"Dimension out of range (expected to be in [0, {input_dim}), but got {dim}).")

        # The chosen dim must NOT be sharded.
        if in_tensor_map[dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on params along the chosen dim"
            )

        return input_layout, input_layout
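
The rule encoded by infer_layout can be illustrated with a minimal, self-contained sketch. It assumes a layout is described only by its tensor_map (one device-mesh axis index per tensor dimension, with -1 meaning "not sharded") and does not use the real hyper_parallel Layout class; the helper name check_topk_layout is hypothetical and exists only for this example.

def check_topk_layout(tensor_map, dim=-1):
    """Mirror of the infer_layout rule: the topk dim must be unsharded."""
    input_dim = len(tensor_map)
    if dim < 0:
        dim = input_dim + dim  # normalize negative dims, e.g. -1 -> last axis
    if not 0 <= dim < input_dim:
        raise ValueError(f"dim {dim} out of range for rank {input_dim}")
    if tensor_map[dim] != -1:
        raise ValueError("topk dim must not be sharded")
    # values and indices inherit the input layout unchanged
    return tensor_map, tensor_map

# A (batch, vocab) tensor sharded on batch (mesh axis 0) but replicated on vocab:
# topk over the last dim is allowed, and both outputs keep the same layout.
print(check_topk_layout((0, -1), dim=-1))  # ((0, -1), (0, -1))

Running topk over the sharded batch axis of the same tensor (dim=0) would raise instead, matching the check on in_tensor_map[dim] above.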