Coverage for hyper_parallel / core / shard / ops / parallel_sort.py: 84%

25 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Sort operator. 

17""" 

18 

19from .parallel_ops import DistributedOp 

20 

21 

class SortDistributedOp(DistributedOp):
    """Distributed implementation for Sort operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for Sort operator.

        The sort operator expects the sorting dimension to be fully available on each device
        (i.e., not sharded). If the dimension is sharded, a global sort cannot be performed
        locally without redistribution.

        Args:
            layouts (tuple): Layouts of input tensor. Only ``layouts[0]`` is inspected.
            extra_args (tuple): Arguments for the operator. Expected: (dim, descending, stable).
                If empty, dim defaults to -1. Only ``extra_args[0]`` (dim) is inspected here.

        Returns:
            tuple: (Layout, Layout) representing the layouts for (values, indices). Both
            entries are the input layout object itself.

        Raises:
            TypeError: If ``dim`` is not an int (a ``bool`` is rejected as well).
            ValueError: If ``dim`` is out of range for the tensor rank, or if the sorting
                dimension is sharded.
        """
        layout = layouts[0]

        # extra_args[0] corresponds to 'dim' in torch.sort(input, dim, ...); default to the last dim.
        dim = extra_args[0] if extra_args else -1

        # bool is a subclass of int, but True/False is never a meaningful dimension index
        # (torch likewise rejects bool for integer dim arguments), so exclude it explicitly.
        if isinstance(dim, bool) or not isinstance(dim, int):
            raise TypeError(f"For 'sort', dimension must be int, but got {type(dim)}")

        # Get tensor map to check sharding status
        in_tensor_map = layout.tensor_map
        ndim = len(in_tensor_map)

        # An empty tensor_map (presumably a 0-d tensor) is range-checked as if it had rank 1,
        # matching torch's scalar dim wrapping, so dim in {-1, 0} is accepted instead of
        # every dim (including the default -1) being rejected.
        dim_bound = max(ndim, 1)
        if dim < -dim_bound or dim >= dim_bound:
            raise ValueError(
                f"Dimension out of range (expected to be in range of "
                f"[{-dim_bound}, {dim_bound-1}], but got {dim})"
            )

        # Nothing can be sharded when there are no tensor dimensions.
        if ndim == 0:
            return (layout, layout)

        # Handle negative dimension index
        if dim < 0:
            dim += ndim

        # In tensor_map, -1 means Replicate (not sharded); any other value names a mesh axis.
        # An entry may be a single axis or a list/tuple of axes, so normalize to a sequence.
        mapping = in_tensor_map[dim]
        mesh_axes = mapping if isinstance(mapping, (list, tuple)) else (mapping,)
        if any(axis != -1 for axis in mesh_axes):
            raise ValueError(
                f"For 'sort', sorting along a sharded dimension (dim {dim} mapped to {mapping}) is not supported. "
                "Please redistribute the tensor to Replicate status on this dimension before sorting."
            )

        # The output layouts for 'values' and 'indices' are the same as the input layout
        return (layout, layout)