# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for InplaceScatterValue operator.
"""

from .parallel_ops import DistributedOp


class InplaceScatterValueDistributedOp(DistributedOp):
    """Distributed implementation for InplaceScatterValue operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for InplaceScatterValue.

        Requirements:
        1. Must have exactly 4 inputs: input, dim, index, value
        2. extra_args must have exactly 2 elements: dim (int), value (scalar)
        3. Output layout = input layout (inplace)
        """

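        # `layouts` follows the operator's input order (input, dim, index, value);
        # dim and value are scalars, so only positions 0 and 2 carry tensor layouts.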
        if not layouts or len(layouts) != 4:
            raise ValueError(
                f"Operation {self.op_name}: InplaceScatterValue requires exactly 4 inputs: "
                f"input, dim, index, value. Got {len(layouts) if layouts else 0}."
            )

        input_layout = layouts[0]
        index_layout = layouts[2]
        if input_layout is None or not hasattr(input_layout, "tensor_map"):
            raise ValueError(
                f"Operation {self.op_name}: input tensor layout must be a layout object with a tensor_map."
            )
        if index_layout is None or not hasattr(index_layout, "tensor_map"):
            raise ValueError(
                f"Operation {self.op_name}: index tensor layout must be a layout object with a tensor_map."
            )

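        # tensor_map convention: one entry per tensor axis; -1 means the axis is
        # replicated (unsharded), other values name the mesh axis it is sharded over.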
        input_map = input_layout.tensor_map
        index_map = index_layout.tensor_map
        ndim = len(input_map)
        if len(input_map) != len(index_map):
            raise ValueError(
                f"Operation {self.op_name}: input and index must have the same number of dimensions. "
                f"Got input rank={len(input_map)}, index rank={len(index_map)}"
            )

        if not extra_args or len(extra_args) != 2:
            raise ValueError(
                f"Operation {self.op_name}: extra_args must contain exactly 2 elements: "
                f"dim (int), value (scalar). Got {len(extra_args) if extra_args else 0}."
            )
        dim = extra_args[0]

        if not isinstance(dim, int):
            raise ValueError(f"Operation {self.op_name}: 'dim' must be an integer.")

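        # Normalize Python-style negative dim (e.g. dim=-1 on a rank-3 tensor -> 2),
        # then bounds-check the result.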
        if dim < 0:
            dim += ndim
        if dim < 0 or dim >= ndim:
            raise ValueError(
                f"Operation {self.op_name}: dim {dim} is out of bounds for tensor with {ndim} dims."
            )

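        # input and index must be sharded identically on every axis; together with
        # the replication check below, this also keeps index unsharded along dim.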
        for axis, (input_axis_map, index_axis_map) in enumerate(zip(input_map, index_map)):
            if input_axis_map != index_axis_map:
                raise ValueError(
                    f"Operation {self.op_name}: input and index must use the same sharding on axis {axis}. "
                    f"Got input tensor_map={input_map}, index tensor_map={index_map}, dim={dim}"
                )

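        # If dim itself were sharded, an index value could point at an element stored
        # on another device, so the scatter dimension must stay replicated.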
        if input_map[dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Scatter along sharded dimension {dim} is not supported. "
                f"The target dimension must be replicated (unsharded)."
            )

        return input_layout
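
# Illustrative smoke check (a sketch, not part of the operator): it fakes the
# layout objects with SimpleNamespace and bypasses DistributedOp's constructor
# (whose signature is not shown here) via __new__, purely to demonstrate the
# validation rules in infer_layout. `op_name` is assumed to be a plain instance
# attribute. Because of the relative import above, run this as a module, e.g.
# `python -m hyper_parallel.core.shard.ops.parallel_inplace_scatter_value`.
if __name__ == "__main__":
    from types import SimpleNamespace

    op = InplaceScatterValueDistributedOp.__new__(InplaceScatterValueDistributedOp)
    op.op_name = "InplaceScatterValue"  # normally provided by DistributedOp

    # input/index sharded on mesh axis 0 along tensor axis 0; tensor axis 1 replicated.
    input_layout = SimpleNamespace(tensor_map=[0, -1])
    index_layout = SimpleNamespace(tensor_map=[0, -1])
    layouts = [input_layout, None, index_layout, None]  # (input, dim, index, value)

    # dim=1 is replicated (-1), so inference succeeds and echoes the input layout.
    assert op.infer_layout(layouts, extra_args=[1, 0.0]) is input_layout

    # dim=0 is sharded, so the same call must be rejected.
    try:
        op.infer_layout(layouts, extra_args=[0, 0.0])
    except ValueError as err:
        print(f"rejected as expected: {err}")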