Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_sort.py: 95%
38 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for Sort operator.
"""
from typing import Tuple

from .parallel_ops import DistributedOp

def _normalize_sort_args(x, dim=-1, descending=False, stable=False):
    return (x,), {'dim': dim, 'descending': descending, 'stable': stable}
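
# Illustrative only (comment added for clarity): _normalize_sort_args(x, descending=True)
# returns ((x,), {'dim': -1, 'descending': True, 'stable': False}), so callers can always
# read dim/descending/stable from the normalized kwargs dict regardless of how the
# operator was invoked.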


class SortDistributedOp(DistributedOp):
    """Distributed implementation for Sort operator."""
    _MS_PRIMITIVE_OP_NAMES = frozenset({'SortExt'})

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """
        Preprocess arguments for Sort operator.

        Args:
            args (tuple): Input arguments, first element is the input tensor.
            kwargs (dict): Keyword arguments (dim, descending, stable).

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_sort_args(*args, **kwargs)
        input_tensor = args[0]
        dim = kwargs['dim']
        descending = kwargs['descending']
        stable = kwargs['stable']

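        # Added for clarity, inferred from the branch below: ops listed in
        # _MS_PRIMITIVE_OP_NAMES (e.g. 'SortExt') receive dim/descending/stable
        # positionally, while all other ops receive them as keyword arguments.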
        if self.op_name in self._MS_PRIMITIVE_OP_NAMES:
            local_args = (input_tensor.to_local(), dim, descending, stable)
            local_kwargs = {}
        else:
            local_args = (input_tensor.to_local(),)
            local_kwargs = {'dim': dim, 'descending': descending, 'stable': stable}

        cache_values = [input_tensor.layout, dim]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layouts for Sort operator.

        Rules:
            1. Input must not have Partial status.
            2. dim must be an integer within the valid range [-ndim, ndim-1].
            3. The sort dimension must not be sharded (including StridedShard multi-axis mappings).
            4. Output values and indices layouts are identical to the input layout.

        Args:
            cache_values (list): [input_layout, dim] where dim is the sort dimension.

        Returns:
            tuple: ((values_layout, indices_layout), None)

        Raises:
            ValueError: If input has Partial status, dim is out of range, or the sort
                dimension is sharded.
        """
        layout = cache_values[0]
        dim = cache_values[1]

        self._check_partial_inputs([layout])

        if not isinstance(dim, int):
            raise ValueError(
                f"For {self.op_name}, dimension should be int, but got {type(dim)}"
            )

        # Get the tensor map to check the sharding status of each dimension.
        in_tensor_map = layout.tensor_map
        ndim = len(in_tensor_map)

        if dim < -ndim or dim >= ndim:
            raise ValueError(
                f"For {self.op_name}, dimension out of range "
                f"(expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})"
            )

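        # Illustrative example (added for clarity): for a 3-D input, ndim == 3, so valid
        # dim values are -3..2; a negative dim such as -1 is normalized to 2 below.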
        if dim < 0:
            dim += ndim

        # Check if the sorting dimension is sharded.
        # In tensor_map, -1 means Replicate (not sharded); any other value implies sharding.
        mapping = in_tensor_map[dim]
        if isinstance(mapping, (list, tuple)):
            is_sharded = any(m != -1 for m in mapping)
        else:
            is_sharded = mapping != -1

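        # Example (assumed layout, for illustration only): tensor_map == (0, -1) shards
        # dim 0 on a mesh axis and replicates dim 1, so sorting along dim 1 is allowed
        # while sorting along dim 0 raises below. A StridedShard entry such as (1, 0) is
        # also treated as sharded because it contains a value other than -1.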
        if is_sharded:
            raise ValueError(
                f"For {self.op_name}, sorting along a sharded dimension "
                f"(dim {dim} mapped to {mapping}) is not supported. "
                f"Please redistribute the tensor to Replicate on this dimension before sorting."
            )

        return ((layout, layout), None)
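

# Minimal usage sketch (assumed caller-side names; `op` and `dist_tensor` are not
# defined in this module):
#     local_args, local_kwargs, cache = op.preprocess((dist_tensor,), {'dim': 0})
#     (values_layout, indices_layout), _ = op.infer_layout(cache)
# Per Rule 4 in infer_layout, both returned layouts equal the input tensor's layout.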