# Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_argmax_with_value_ops.py: 82%
# 40 statements
# coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Distributed implementation for ArgMaxWithValue operator.
"""

from .parallel_ops import DistributedOp

20 

21 

class ArgMaxWithValueDistributedOp(DistributedOp):
    """Distributed implementation for ArgMaxWithValue operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for the ArgMaxWithValue operator.

        Args:
            layouts (sequence): Operator layouts; ``layouts[0]`` is the input
                tensor layout (the remaining two presumably correspond to the
                two outputs — confirm against the caller). Exactly 3 layouts
                are required.
            extra_args (sequence): ``(axis, keep_dims)`` — the reduction axis
                (int, may be negative) and whether the reduced dimension is
                kept. Exactly 2 entries are required.

        Returns:
            tuple: ``(layout, layout)`` — the same inferred layout for both
            outputs (index and value).

        Raises:
            ValueError: If input layouts have partial status, if the number of
                layouts or extra args is wrong, if ``axis`` is not an int or is
                out of range, or if the axis dimension is sharded.
        """
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # Parse input layout
        if len(layouts) != 3:
            raise ValueError(f"ArgMaxWithValue requires 3 layouts, but {len(layouts)}")
        # Guard against the default `extra_args=None`: the original code called
        # len(None) here, raising TypeError instead of the documented ValueError.
        num_extra = 0 if extra_args is None else len(extra_args)
        if num_extra != 2:
            raise ValueError(f"ArgMaxWithValue requires 2 extra args, but {num_extra}")

        input_layout = layouts[0]
        input_mesh_shape = input_layout.mesh_shape
        input_tensor_map = input_layout.tensor_map
        input_alias_tensor_map = input_layout.alias_tensor_map
        axis, keep_dims = extra_args[0], extra_args[1]
        rank = len(input_tensor_map)

        if not isinstance(axis, int):
            raise ValueError(f"ArgMaxWithValue axis must be int, but got {type(axis)}")
        if axis < 0:
            # Normalize a negative axis to its positive equivalent.
            axis += rank
        if axis < 0 or axis >= rank:
            raise ValueError(f"ArgMaxWithValue axis out of range: axis={extra_args[0]}, rank={rank}")

        def is_shard(index):
            # True if tensor dim `index` maps to a mesh dim whose size is > 1.
            # The tensor map indexes mesh dims from the last axis, hence the
            # `len(mesh) - 1 - elem` reversal; -1 means "not mapped".
            mapping = input_tensor_map[index]
            mesh_rank = len(input_mesh_shape)
            if isinstance(mapping, tuple):
                # A dim may be mapped to several mesh dims; sharded if any is real.
                return any(
                    elem != -1 and input_mesh_shape[mesh_rank - 1 - elem] != 1
                    for elem in mapping
                )
            return mapping != -1 and input_mesh_shape[mesh_rank - 1 - mapping] != 1

        # The reduced dim cannot be sharded: argmax over a split axis would
        # need cross-device index reconciliation, which this op does not do.
        if is_shard(axis):
            raise ValueError(f"{self.__class__.__name__} cannot perform sharding on axis dim")

        # Build the output alias map: drop the reduced dim, or replace it with
        # "None" (replicated placeholder — confirm semantics in Layout) when
        # keep_dims is set.
        if not keep_dims:
            tensor_map = input_alias_tensor_map[:axis] + input_alias_tensor_map[axis + 1:]
        else:
            tensor_map = input_alias_tensor_map[:axis] + ("None",) + input_alias_tensor_map[axis + 1:]

        # Calling the input layout with the new alias map yields the output layout.
        output_layout = input_layout(*tensor_map)
        # Index and value outputs share the same layout.
        return output_layout, output_layout