Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_sort.py: 95%

38 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Sort operator. 

17""" 

18 

19from typing import Tuple 

20 

21from .parallel_ops import DistributedOp 

22 

23 

24def _normalize_sort_args(x, dim=-1, descending=False, stable=False): 

25 return (x,), {'dim': dim, 'descending': descending, 'stable': stable} 

26 

27 

class SortDistributedOp(DistributedOp):
    """Distributed implementation for Sort operator.

    The sort is executed on each rank's local tensor (see ``preprocess``), so
    the sort dimension itself must not be sharded. The values and indices
    outputs both reuse the input layout unchanged.
    """
    # Op names whose local kernel receives (input, dim, descending, stable)
    # positionally; presumably MindSpore primitives, per the constant name.
    # Every other op name receives the sort options as keyword arguments.
    _MS_PRIMITIVE_OP_NAMES = frozenset({'SortExt'})

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """
        Preprocess arguments for Sort operator.

        Binds the call to (input, dim, descending, stable), replaces the
        distributed input with its local tensor, and records what
        ``infer_layout`` needs.

        Args:
            args (tuple): Input arguments, first element is the input tensor.
            kwargs (dict): Keyword arguments (dim, descending, stable).

        Returns:
            tuple: (local_args, local_kwargs, cache_values), where
            cache_values is ``[input_layout, dim]``.
        """
        args, kwargs = _normalize_sort_args(*args, **kwargs)
        input_tensor = args[0]
        dim = kwargs['dim']
        descending = kwargs['descending']
        stable = kwargs['stable']

        if self.op_name in self._MS_PRIMITIVE_OP_NAMES:
            # Primitive kernels take every sort option positionally.
            local_args = (input_tensor.to_local(), dim, descending, stable)
            local_kwargs = {}
        else:
            local_args = (input_tensor.to_local(),)
            local_kwargs = {'dim': dim, 'descending': descending, 'stable': stable}

        cache_values = [input_tensor.layout, dim]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layouts for Sort operator.

        Rules:
            1. Input must not have Partial status.
            2. dim must be an int (bool is rejected) within the valid range
               [-ndim, ndim-1]. A 0-d input is treated as having one
               dimension, so dim must be -1 or 0.
            3. The sort dimension must not be sharded (including StridedShard
               multi-axis mappings). A 0-d input has no shardable dimension.
            4. Output values and indices layouts are identical to the input layout.

        Args:
            cache_values (list): [input_layout, dim] where dim is the sort dimension.

        Returns:
            tuple: ((values_layout, indices_layout), None)

        Raises:
            ValueError: If input has Partial status, dim is not an int, dim is
                out of range, or the sort dimension is sharded.
        """
        layout = cache_values[0]
        dim = cache_values[1]

        self._check_partial_inputs([layout])

        # bool is a subclass of int: without the explicit check, dim=True
        # (e.g. a `descending` flag passed in the dim position) would
        # silently be treated as dim=1.
        if isinstance(dim, bool) or not isinstance(dim, int):
            raise ValueError(
                f"For {self.op_name}, dimension should be int, but got {type(dim)}"
            )

        # Get tensor map to check sharding status
        in_tensor_map = layout.tensor_map
        ndim = len(in_tensor_map)

        # A 0-d tensor accepts dim in [-1, 0], as in PyTorch's scalar dim
        # wrapping. Without this, the default dim=-1 would always be rejected
        # with an empty "[0, -1]" range.
        wrap_ndim = max(ndim, 1)
        if dim < -wrap_ndim or dim >= wrap_ndim:
            raise ValueError(
                f"For {self.op_name}, dimension out of range "
                f"(expected to be in range of [{-wrap_ndim}, {wrap_ndim - 1}], but got {dim})"
            )

        if ndim == 0:
            # Nothing can be sharded on a 0-d tensor; the layout passes through.
            return ((layout, layout), None)

        if dim < 0:
            dim += ndim

        # Check if the sorting dimension is sharded.
        # In tensor_map, -1 means Replicate (not sharded); any other value implies sharding.
        # A list/tuple entry maps one tensor dim onto several mesh axes.
        mapping = in_tensor_map[dim]
        if isinstance(mapping, (list, tuple)):
            is_sharded = any(m != -1 for m in mapping)
        else:
            is_sharded = mapping != -1

        if is_sharded:
            raise ValueError(
                f"For {self.op_name}, sorting along a sharded dimension "
                f"(dim {dim} mapped to {mapping}) is not supported. "
                f"Please redistribute the tensor to Replicate on this dimension before sorting."
            )

        return ((layout, layout), None)

57 return local_args, local_kwargs, cache_values 

58 

59 def infer_layout(self, cache_values: list) -> Tuple[tuple, None]: 

60 """ 

61 Infer output layouts for Sort operator. 

62 

63 Rules: 

64 1. Input must not have Partial status. 

65 2. dim must be an integer within the valid range [-ndim, ndim-1]. 

66 3. The sort dimension must not be sharded (including StridedShard multi-axis mappings). 

67 4. Output values and indices layouts are identical to the input layout. 

68 

69 Args: 

70 cache_values (list): [input_layout, dim] where dim is the sort dimension. 

71 

72 Returns: 

73 tuple: ((values_layout, indices_layout), None) 

74 

75 Raises: 

76 ValueError: If input has Partial status, dim is out of range, or the sort dimension 

77 is sharded. 

78 """ 

79 layout = cache_values[0] 

80 dim = cache_values[1] 

81 

82 self._check_partial_inputs([layout]) 

83 

84 if not isinstance(dim, int): 

85 raise ValueError( 

86 f"For {self.op_name}, dimension should be int, but got {type(dim)}" 

87 ) 

88 

89 # Get tensor map to check sharding status 

90 in_tensor_map = layout.tensor_map 

91 ndim = len(in_tensor_map) 

92 

93 if dim < -ndim or dim >= ndim: 

94 raise ValueError( 

95 f"For {self.op_name}, dimension out of range " 

96 f"(expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})" 

97 ) 

98 

99 if dim < 0: 

100 dim += ndim 

101 

102 # Check if the sorting dimension is sharded. 

103 # In tensor_map, -1 means Replicate (not sharded); any other value implies sharding. 

104 mapping = in_tensor_map[dim] 

105 if isinstance(mapping, (list, tuple)): 

106 is_sharded = any(m != -1 for m in mapping) 

107 else: 

108 is_sharded = mapping != -1 

109 

110 if is_sharded: 

111 raise ValueError( 

112 f"For {self.op_name}, sorting along a sharded dimension " 

113 f"(dim {dim} mapped to {mapping}) is not supported. " 

114 f"Please redistribute the tensor to Replicate on this dimension before sorting." 

115 ) 

116 

117 return ((layout, layout), None)