Coverage for hyper_parallel / core / shard / ops / parallel_sort.py: 84%
25 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Sort operator.
17"""
19from .parallel_ops import DistributedOp
class SortDistributedOp(DistributedOp):
    """Distributed implementation for Sort operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer the output layouts for the Sort operator.

        Sort requires the sorting dimension to be fully present on every
        device (Replicate status). When that dimension is sharded, a global
        sort cannot be computed locally without redistribution, so an error
        is raised instead of silently producing wrong results.

        Args:
            layouts (tuple): Layouts of input tensor.
            extra_args (tuple): Arguments for the operator. Expected:
                (dim, descending, stable). If empty, dim defaults to -1.

        Returns:
            tuple: (Layout, Layout) representing the layouts for
                (values, indices); both match the input layout.

        Raises:
            TypeError: If the dimension argument is not an int.
            ValueError: If the dimension is out of range, or the sorting
                dimension is sharded across the device mesh.
        """
        in_layout = layouts[0]

        # torch.sort(input, dim, ...): extra_args[0] carries 'dim'.
        # With no extra arguments, fall back to the last dimension.
        sort_dim = extra_args[0] if extra_args else -1

        if not isinstance(sort_dim, int):
            raise TypeError(f"For 'sort', dimension must be int, but got {type(sort_dim)}")

        # The tensor map tells us, per tensor dimension, which mesh axes
        # (if any) shard that dimension.
        tensor_map = in_layout.tensor_map
        rank = len(tensor_map)

        # Validate the raw index, then normalize a negative one.
        if not -rank <= sort_dim < rank:
            raise ValueError(f"Dimension out of range (expected to be in range of [{-rank}, {rank-1}], but got {sort_dim})")
        if sort_dim < 0:
            sort_dim += rank

        # A tensor_map entry of -1 means Replicate (not sharded); any other
        # value implies sharding. The entry may be a single mesh axis or a
        # list/tuple of axes.
        mapping = tensor_map[sort_dim]
        if isinstance(mapping, (list, tuple)):
            sharded = any(axis != -1 for axis in mapping)
        else:
            sharded = mapping != -1

        if sharded:
            raise ValueError(
                f"For 'sort', sorting along a sharded dimension (dim {sort_dim} mapped to {mapping}) is not supported. "
                f"Please redistribute the tensor to Replicate status on this dimension before sorting."
            )

        # Both outputs ('values' and 'indices') keep the input layout.
        return (in_layout, in_layout)