Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_argmax_with_value_ops.py: 82%
40 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for ArgMaxWithValue operator.
17"""
19from .parallel_ops import DistributedOp
class ArgMaxWithValueDistributedOp(DistributedOp):
    """Distributed implementation for ArgMaxWithValue operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for ArgMaxWithValue operator.

        The reduction axis must not be sharded: a device holding only a slice of
        that dimension cannot compute the global arg-max on its own. Every other
        dimension keeps the input's sharding. The reduced dimension is either
        dropped (``keep_dims`` falsy) or kept with the ``"None"`` alias
        (``keep_dims`` truthy).

        Args:
            layouts (tuple): Input layouts; exactly 3 entries are required.
                Only ``layouts[0]`` (the layout of the input tensor) is used to
                derive the output layout.
            extra_args (tuple | None): ``(axis, keep_dims)``. ``axis`` is the int
                dimension to reduce; negative values count from the end.
                ``keep_dims`` selects whether the reduced dimension is kept.
                ``None`` is treated as "no extra args" and rejected.

        Returns:
            tuple: ``(output_layout, output_layout)``. The same layout is
            returned for both outputs of the operator.

        Raises:
            ValueError: If partial inputs are not allowed and input layouts
                have partial status; if the number of layouts or extra args is
                wrong; if ``axis`` is not an int (``bool`` is rejected) or is out
                of range; or if the reduction axis is sharded.
        """
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        if len(layouts) != 3:
            raise ValueError(f"ArgMaxWithValue requires 3 layouts, but {len(layouts)}")
        # extra_args defaults to None. Treat that as an empty tuple so the caller
        # gets the descriptive ValueError below instead of a TypeError from len(None).
        if extra_args is None:
            extra_args = ()
        if len(extra_args) != 2:
            raise ValueError(f"ArgMaxWithValue requires 2 extra args, but {len(extra_args)}")

        input_layout = layouts[0]
        input_mesh_shape = input_layout.mesh_shape
        input_tensor_map = input_layout.tensor_map
        # Normalize to a tuple so the slice + ("None",) concatenation below works
        # whether the layout stores its alias map as a list or a tuple.
        input_alias_tensor_map = tuple(input_layout.alias_tensor_map)
        axis, keep_dims = extra_args[0], extra_args[1]
        rank = len(input_tensor_map)

        # bool is a subclass of int. Reject it explicitly so axis=True is not
        # silently treated as axis=1.
        if isinstance(axis, bool) or not isinstance(axis, int):
            raise ValueError(f"ArgMaxWithValue axis must be int, but got {type(axis)}")
        if axis < 0:
            axis += rank
        if axis < 0 or axis >= rank:
            raise ValueError(f"ArgMaxWithValue axis out of range: axis={extra_args[0]}, rank={rank}")

        def is_shard(index):
            """Return True if tensor dim ``index`` is split across more than one device."""
            mapping = input_tensor_map[index]
            # A dim maps to either one mesh axis (int) or several (tuple), and -1
            # means "not mapped". Mesh axes are indexed from the right end of
            # mesh_shape. A mesh axis of size 1 does not actually split the data.
            mesh_axes = mapping if isinstance(mapping, tuple) else (mapping,)
            last = len(input_mesh_shape) - 1
            return any(
                elem != -1 and input_mesh_shape[last - elem] != 1
                for elem in mesh_axes
            )

        if is_shard(axis):
            raise ValueError(f"{self.__class__.__name__} cannot perform sharding on axis dim")

        # Build the output alias map: drop the reduced dim, or keep it under the
        # "None" alias (presumably "not mapped to any mesh axis" - confirm in Layout).
        before = input_alias_tensor_map[:axis]
        after = input_alias_tensor_map[axis + 1:]
        if keep_dims:
            tensor_map = before + ("None",) + after
        else:
            tensor_map = before + after

        # Calling the input layout with per-dim alias names derives the output layout.
        output_layout = input_layout(*tensor_map)
        # Both operator outputs share this layout.
        return output_layout, output_layout