# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16Distributed implementation for Argsort operator.
17"""
19from .parallel_ops import DistributedOp
22class ArgsortDistributedOp(DistributedOp):
23 """Distributed implementation for torch.argsort."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for torch.argsort.

        PyTorch semantics:
        - Signature: torch.argsort(input, dim=-1, descending=False, stable=False)
        - Returns the indices that sort a tensor along a given dimension.
        - The output tensor has the exact same shape as the input tensor.
        - Distributed constraint: the dimension being sorted (`dim`) MUST NOT be sharded,
          as sorting requires full visibility of the elements along that axis.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (list): Additional scalar arguments. Expected:
                extra_args[0] (int): The dimension to sort along (default: -1).
                extra_args[1] (bool): descending flag.
                extra_args[2] (bool): stable flag.

        Returns:
            Layout: Output tensor layout (identical to the input layout, provided the
                sorted dimension is valid and unsharded).
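
        Example (illustrative only; assumes ``Layout`` can be built from a
        ``tensor_map`` tuple, where -1 marks an unsharded dimension)::

            layout = Layout(tensor_map=(0, -1))  # dim 0 sharded on mesh axis 0
            op.infer_layout((layout,), [1])      # OK: returns `layout` unchanged
            op.infer_layout((layout,), [0])      # ValueError: dim 0 is sharded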
        """
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: argsort requires a valid input tensor layout."
            )

        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map
        input_ndim = len(in_tensor_map)

        # 1. Parse 'dim' from extra_args (default is -1 per PyTorch semantics).
        dim = -1
        if extra_args:
            # We assume the first extra argument is 'dim' based on positional unpacking.
            # Note: bool is a subclass of int in Python, so it must be excluded
            # explicitly; otherwise a leading 'descending' flag would be misread as 'dim'.
            if isinstance(extra_args[0], int) and not isinstance(extra_args[0], bool):
                dim = extra_args[0]
            # Fallback in case kwargs ordering puts booleans first.
            elif (isinstance(extra_args[0], bool) and len(extra_args) > 1
                    and isinstance(extra_args[1], int)
                    and not isinstance(extra_args[1], bool)):
                dim = extra_args[1]
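        # Illustration (assumed calling convention): torch.argsort(x, dim=1) would
        # arrive here as extra_args == [1, False, False], resolving 'dim' to 1.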

        # 2. Normalize negative dimensions.
        actual_dim = dim
        if actual_dim < 0:
            actual_dim += input_ndim
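        # e.g. dim=-1 on a 3-D input normalizes to actual_dim = -1 + 3 = 2.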

        # 3. Validate dimension bounds.
        if actual_dim < 0 or actual_dim >= input_ndim:
            raise ValueError(
                f"Operation {self.op_name}: dim {dim} is out of bounds for "
                f"tensor of dimension {input_ndim}."
            )
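        # e.g. for a 2-D input, only dim values in {-2, -1, 0, 1} pass this check.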

        # 4. Enforce distributed constraint: the sorting dimension cannot be sharded.
        # In tensor_map, a value of -1 means unsharded. Any value >= 0 represents
        # the device mesh axis index that shards this dimension.
        if in_tensor_map[actual_dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform argsort along dimension {dim} "
                f"because it is currently sharded across device mesh axis {in_tensor_map[actual_dim]}. "
                f"Please redistribute the tensor to unshard this dimension before sorting."
            )
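        # e.g. with tensor_map == (0, -1), argsort along dim 0 raises here (sharded
        # on mesh axis 0), while argsort along dim 1 or dim=-1 is allowed.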

        # 5. The shape and distribution of the indices are identical to the input.
        return input_layout