Coverage for hyper_parallel / core / shard / ops / parallel_argmax_with_value_ops.py: 9%

34 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Distributed implementation for ArgMaxWithValue operator.
"""

from .parallel_ops import DistributedOp

20 

21 

class ArgMaxWithValueDistributedOp(DistributedOp):
    """Distributed implementation for ArgMaxWithValue operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer the output layout for the ArgMaxWithValue operator.

        Args:
            layouts (tuple): Exactly 3 layouts; ``layouts[0]`` is the layout
                of input ``x`` (exposes ``mesh_shape``, ``tensor_map`` and
                ``alias_tensor_map``).
            extra_args (tuple): Exactly 2 items: ``(axis, keep_dims)``.

        Returns:
            tuple: The same layout object for both outputs (index, value).

        Raises:
            ValueError: If the layout/extra-arg counts are wrong, if a
                non-``axis`` dimension is sharded, or (via
                ``_check_partial_inputs``) if inputs carry partial status
                while partial inputs are disallowed.
        """
        # Reject partially-reduced inputs unless this op opts in to them.
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # Validate the expected arity of layouts and extra args.
        if len(layouts) != 3:
            raise ValueError(f"ArgMaxWithValue requires 3 layouts, but {len(layouts)}")
        if len(extra_args) != 2:
            raise ValueError(f"ArgMaxWithValue requires 2 extra args, but {len(extra_args)}")

        layout_in = layouts[0]
        mesh_shape = layout_in.mesh_shape
        tensor_map_in = layout_in.tensor_map
        alias_map = layout_in.alias_tensor_map
        axis, keep_dims = extra_args

        # tensor_map entries index mesh dims from the tail, so precompute
        # the base used to flip them into forward mesh_shape indices.
        mesh_last = len(mesh_shape) - 1

        def _dim_is_sharded(dim):
            # A dim is sharded when any of its mesh mappings is a real
            # (non -1) axis whose mesh extent is larger than 1.
            entry = tensor_map_in[dim]
            if isinstance(entry, tuple):
                return any(e != -1 and mesh_shape[mesh_last - e] != 1 for e in entry)
            return entry != -1 and mesh_shape[mesh_last - entry] != 1

        # Only the reduced axis may be sharded; every other dim must be
        # replicated for this op's distributed semantics to hold.
        for dim in range(len(tensor_map_in)):
            if dim != axis and _dim_is_sharded(dim):
                raise ValueError(f"{self.__class__.__name__} cannot perform sharding on non axis dim")

        # Drop the reduced axis from the alias map, or replace it with the
        # unsharded placeholder when keep_dims retains the dimension.
        if keep_dims:
            out_map = alias_map[:axis] + ("None",) + alias_map[axis + 1 :]
        else:
            out_map = alias_map[:axis] + alias_map[axis + 1 :]

        # Both outputs (indices and values) share the same layout.
        out_layout = layout_in(*out_map)
        return out_layout, out_layout