Coverage for hyper_parallel / core / shard / ops / parallel_argmax_with_value_ops.py: 9%
34 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for ArgMaxWithValue operator.
17"""
19from .parallel_ops import DistributedOp
class ArgMaxWithValueDistributedOp(DistributedOp):
    """Distributed implementation for ArgMaxWithValue operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for ArgMaxWithValue operator.

        Args:
            layouts (tuple): Input layouts; exactly 3 are required. Only
                ``layouts[0]`` (the layout of input x) is used to derive the outputs.
            extra_args (tuple): ``(axis, keep_dims)``. ``axis`` is the reduced
                dimension and may be negative (counted from the last dimension).
                When ``keep_dims`` is truthy the reduced dimension is kept in the
                output tensor map as an unsharded ``"None"`` entry; otherwise it
                is removed.

        Returns:
            tuple: ``(output_layout, output_layout)`` -- the same layout object is
            returned for both of the operator's outputs.

        Raises:
            ValueError: If input layouts have partial status (when partial inputs
                are not allowed), if the number of layouts or extra args is wrong,
                if ``axis`` is out of range, or if any dimension other than
                ``axis`` is sharded.
        """
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # Parse input layout
        if len(layouts) != 3:
            raise ValueError(f"ArgMaxWithValue requires 3 layouts, but {len(layouts)}")
        if len(extra_args) != 2:
            raise ValueError(f"ArgMaxWithValue requires 2 extra args, but {len(extra_args)}")

        input_layout = layouts[0]
        mesh_shape = input_layout.mesh_shape
        tensor_map = input_layout.tensor_map
        alias_tensor_map = input_layout.alias_tensor_map
        axis, keep_dims = extra_args[0], extra_args[1]

        # Normalize a negative axis so it can be compared with dimension indices
        # and used for slicing. Without this, axis=-1 both skips the reduction dim
        # in the sharding check and duplicates the tensor map when slicing.
        # A 0-d input still accepts axis 0 (or -1).
        ndim = len(tensor_map)
        bound = max(ndim, 1)
        if not -bound <= axis < bound:
            raise ValueError(
                f"ArgMaxWithValue requires axis in [{-bound}, {bound}), but got {axis}"
            )
        if axis < 0:
            axis += bound

        def _is_sharded(mapping):
            """Return True if a tensor-map entry splits its dim across >1 devices."""
            # An entry is either a single device-axis index or a tuple of them;
            # -1 means "not mapped". Device-axis indices count from the end of
            # mesh_shape, hence the reversed lookup.
            device_axes = mapping if isinstance(mapping, tuple) else (mapping,)
            return any(
                dev != -1 and mesh_shape[len(mesh_shape) - 1 - dev] != 1
                for dev in device_axes
            )

        # Only the reduction axis may be sharded.
        # NOTE(review): sharding on `axis` itself is accepted, and the returned
        # layout carries no marker of a cross-device reduction for it — confirm
        # that this is handled elsewhere in the framework.
        for dim, mapping in enumerate(tensor_map):
            if dim != axis and _is_sharded(mapping):
                raise ValueError(f"{self.__class__.__name__} cannot perform sharding on non axis dim")

        # Create output layout: drop the reduced dim, or keep it as unsharded.
        reduced = ("None",) if keep_dims else ()
        output_map = tuple(alias_tensor_map[:axis]) + reduced + tuple(alias_tensor_map[axis + 1:])

        output_layout = input_layout(*output_map)
        return output_layout, output_layout