Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_scatter.py: 100%

32 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for Scatter operator.
"""

from .parallel_ops import DistributedOp

class ScatterDistributedOp(DistributedOp):
    """Distributed implementation for torch.scatter."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for scatter.

        Args:
            layouts (tuple): Layouts of the inputs. Expected order for scatter(input, dim, index, src):
                layouts[0]: input (DTensor)
                layouts[1]: dim (None, since dim is an int)
                layouts[2]: index (DTensor)
                layouts[3]: src (DTensor, or None if src is a scalar)
            extra_args (list): Contains the non-tensor arguments.
                extra_args[0]: dim (int)
                extra_args[1]: src (if src is a scalar)

        Returns:
            Layout: The output has the same layout as the input.
        """

        # 1. Check partial status
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)
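        # Note (added, hedged): "partial" here presumably refers to tensors
        # carrying partial reduction results; scatter on such values is
        # rejected unless the op explicitly allows partial inputs.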

        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: scatter requires a valid input tensor layout."
            )

        input_layout = layouts[0]
        input_map = input_layout.tensor_map
        ndim = len(input_map)
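        # `tensor_map` convention (added note, inferred from the checks below):
        # -1 marks a replicated tensor dimension, while a non-negative value
        # names the mesh axis it is sharded over, e.g. (0, -1) shards dim 0
        # along mesh axis 0 and replicates dim 1.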

        # 2. Extract and validate 'dim'
        if not extra_args or len(extra_args) < 1:
            raise ValueError(
                f"Operation {self.op_name}: scatter requires 'dim' parameter in extra_args."
            )

        dim = extra_args[0]
        if not isinstance(dim, int):
            raise ValueError(f"Operation {self.op_name}: 'dim' must be an integer.")

        # Normalize dim
        if dim < 0:
            dim += ndim
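        # Example (added for clarity): on a 3-D tensor, dim=-1 normalizes to 2;
        # dim=-4 stays negative and is rejected by the bounds check below.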

        if dim < 0 or dim >= ndim:
            raise ValueError(
                f"Operation {self.op_name}: dim {dim} is out of bounds for tensor with {ndim} dimensions."
            )
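        # Rationale (added, hedged): scatter indices can target arbitrary
        # positions along `dim`, so writes on a sharded dimension could cross
        # shard boundaries and require communication; hence the rule below.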

        # 3. Rule: Scatter dimension cannot be sharded
        if input_map[dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Scatter along sharded dimension {dim} is not supported. "
                "The target dimension must be Replicated (unsharded)."
            )
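        # Rationale (added, hedged): requiring `index` and `src` to share the
        # input's layout keeps corresponding elements on the same rank, so each
        # rank can run the scatter purely locally.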

        # 4. Rule: Index layout must match Input layout
        if len(layouts) > 2:
            index_layout = layouts[2]
            if index_layout is not None:
                if index_layout.tensor_map != input_map:
                    raise ValueError(
                        f"Operation {self.op_name}: Index tensor layout {index_layout.tensor_map} "
                        f"must match input tensor layout {input_map}."
                    )

        # 5. Rule: Src layout must match Input layout (if src is a tensor)
        if len(layouts) > 3:
            src_layout = layouts[3]
            if src_layout is not None:
                if src_layout.tensor_map != input_map:
                    raise ValueError(
                        f"Operation {self.op_name}: Src tensor layout {src_layout.tensor_map} "
                        f"must match input tensor layout {input_map}."
                    )

        # The output layout is the same as the input layout (scatter returns a
        # tensor with the input's shape, or modifies the input in place)
        return input_layout
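# --- Added illustrative sketch, not part of the original module ---
# Minimal usage example, assuming a hypothetical Layout class that exposes a
# tensor_map tuple and a default-constructible op; the real construction and
# registration APIs of hyper_parallel are not shown in this file.
#
#   input_layout = Layout(tensor_map=(0, -1))      # dim 0 sharded on mesh axis 0
#   index_layout = Layout(tensor_map=(0, -1))      # must match the input layout
#   op = ScatterDistributedOp()                    # constructor signature assumed
#   out = op.infer_layout(
#       (input_layout, None, index_layout, None),  # src passed as a scalar
#       extra_args=[1, 2.0],                       # scatter along replicated dim 1
#   )
#   assert out is input_layout                     # output keeps the input layout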