Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_argsort.py: 86%

22 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Argsort operator. 

17""" 

18 

19from .parallel_ops import DistributedOp 

20 

21 

class ArgsortDistributedOp(DistributedOp):
    """Distributed implementation for torch.argsort."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for torch.argsort.

        PyTorch semantics:
        - Signature: torch.argsort(input, dim=-1, descending=False, stable=False)
        - Returns the indices that sort a tensor along a given dimension.
        - The output tensor has the exact same shape as the input tensor.
        - Distributed constraint: The dimension being sorted (`dim`) MUST NOT be sharded,
          as sorting requires full visibility of the elements along that axis.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (list): Additional scalar arguments. Expected:
                extra_args[0] (int): The dimension to sort along (default: -1).
                extra_args[1] (bool): descending flag.
                extra_args[2] (bool): stable flag.
                If extra_args[0] is a bool (a flag arrived first), extra_args[1]
                is used as `dim` when it is a real (non-bool) int; otherwise the
                default of -1 is kept.

        Returns:
            Layout: Output tensor layout (identical to input layout, provided the sorted
            dimension is valid and unsharded).

        Raises:
            ValueError: If no input layout is supplied, if `dim` is out of bounds for
                the input rank, or if the sorted dimension is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: argsort requires a valid input tensor layout."
            )

        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map
        input_ndim = len(in_tensor_map)

        # 1. Parse 'dim' from extra_args (default is -1 per PyTorch semantics)
        dim = -1
        if extra_args:
            first = extra_args[0]
            # bool is a subclass of int, so it must be tested BEFORE int; otherwise a
            # leading descending/stable flag would silently be used as dim 0 or 1.
            if isinstance(first, bool):
                # Fallback in case kwargs ordering puts booleans first: accept the
                # next argument as 'dim' only when it is a genuine integer.
                if len(extra_args) > 1:
                    second = extra_args[1]
                    if isinstance(second, int) and not isinstance(second, bool):
                        dim = second
            elif isinstance(first, int):
                # Normal positional case: the first extra argument is 'dim'.
                dim = first

        # 2. Normalize negative dimensions
        actual_dim = dim
        if actual_dim < 0:
            actual_dim += input_ndim

        # 3. Validate dimension bounds
        if actual_dim < 0 or actual_dim >= input_ndim:
            raise ValueError(
                f"Operation {self.op_name}: dim {dim} is out of bounds for "
                f"tensor of dimension {input_ndim}."
            )

        # 4. Enforce Distributed Constraint: The sorting dimension cannot be sharded.
        #    In tensor_map, a value of -1 means unsharded. Any value >= 0 represents
        #    the device mesh axis index that shards this dimension.
        if in_tensor_map[actual_dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform argsort along dimension {dim} "
                f"because it is currently sharded across device mesh axis {in_tensor_map[actual_dim]}. "
                f"Please redistribute the tensor to unshard this dimension before sorting."
            )

        # 5. The shape and distribution of the indices are identical to the input
        return input_layout

90 return input_layout