Coverage for hyper_parallel / core / shard / ops / parallel_scatter_update.py: 73%

41 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for ScatterUpdate operator. 

17""" 

18from hyper_parallel.core.layout import Layout 

19from .parallel_ops import DistributedOp 

20 

21 

class ScatterUpdateDistributedOp(DistributedOp):
    """Distributed implementation for ScatterUpdate operator.

    ScatterUpdate writes ``updates`` into ``input`` at rows selected by
    ``indices``. The output reuses the input's layout unchanged, so layout
    inference reduces to validating the sharding strategy and cloning the
    input layout (including any partial-reduction state).
    """

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for ScatterUpdate.

        Args:
            layouts (tuple): Tuple containing (input_layout, indices_layout, updates_layout).
            extra_args: Additional arguments (not used).

        Returns:
            Layout: Output layout (same as input layout).

        Raises:
            ValueError: If ``layouts`` does not contain exactly 3 layouts, if
                the input layout has no mesh shape, or if the sharding
                strategy violates ScatterUpdate's constraints
                (see ``_validate_strategy``).
        """
        if len(layouts) != 3:
            raise ValueError(f"{self.__class__.__name__} requires exactly 3 input layouts")

        input_layout, indices_layout, updates_layout = layouts

        if input_layout.mesh_shape is None:
            # Fixed: the message previously said "Input layout cannot be
            # None", which did not match the condition actually tested here
            # (the layout exists but its mesh_shape is missing).
            raise ValueError("Input layout mesh_shape cannot be None")

        self._validate_strategy(input_layout, indices_layout, updates_layout)

        # Output inherits the input layout: same mesh, device-axis aliases,
        # rank list, and per-dimension tensor map.
        output_layout = Layout(
            mesh_shape=input_layout.mesh_shape,
            alias_name=input_layout.alias_name,
            rank_list=input_layout.rank_list
        )
        output_layout = output_layout(*input_layout.alias_tensor_map)

        # Propagate any partial-reduction state along device axes unchanged;
        # `partial` is indexed in parallel with `alias_name`.
        for i, partial_op in enumerate(input_layout.partial):
            if partial_op is not None:
                dev_axis_name = input_layout.alias_name[i]
                output_layout.set_partial_by_dev_axis(dev_axis_name, partial_op)

        return output_layout

    def _validate_strategy(self, input_layout, indices_layout, updates_layout):
        """Validate sharding strategy for ScatterUpdate.

        Constraints enforced (in order):
            * input tensor map is non-empty;
            * input dimension 0 (the scattered dimension) is not sharded;
            * indices are fully replicated (no dimension sharded);
            * the first ``indices_ndim`` dimensions of updates are not sharded;
            * updates rank equals ``indices_ndim + input_ndim - 1``
              (updates shape is ``indices.shape + input.shape[1:]``);
            * trailing dimensions of updates are sharded exactly like
              ``input[1:]``.

        Args:
            input_layout (Layout): Layout of the tensor being updated.
            indices_layout (Layout): Layout of the index tensor.
            updates_layout (Layout): Layout of the updates tensor.

        Raises:
            ValueError: If any of the constraints above is violated.
        """
        input_map = input_layout.alias_tensor_map
        indices_map = indices_layout.alias_tensor_map
        updates_map = updates_layout.alias_tensor_map

        if not input_map:
            raise ValueError(f"{self.op_name}: input tensor map is empty")

        # Scatter writes along dim 0; sharding it would require cross-device
        # communication for every updated row. "None" is the sentinel for an
        # unsharded dimension.
        if input_map[0] != "None":
            raise ValueError(
                f"{self.op_name}: first dimension of input cannot be sharded"
            )

        for i, axis in enumerate(indices_map):
            if axis != "None":
                raise ValueError(
                    f"{self.op_name}: indices cannot be sharded, "
                    f"but dimension {i} is sharded on '{axis}'"
                )

        # The leading dims of updates mirror the indices shape, so they must
        # be replicated just like indices.
        indices_ndim = len(indices_map)
        for i in range(indices_ndim):
            if i >= len(updates_map):
                raise ValueError(
                    f"{self.op_name}: updates rank is smaller than indices rank"
                )
            if updates_map[i] != "None":
                raise ValueError(
                    f"{self.op_name}: first {indices_ndim} dimensions of updates cannot be sharded, "
                    f"but dimension {i} is sharded on '{updates_map[i]}'"
                )

        # updates shape is indices.shape + input.shape[1:], which fixes its rank.
        expected_updates_ndim = indices_ndim + len(input_map) - 1
        if len(updates_map) != expected_updates_ndim:
            raise ValueError(
                f"{self.op_name}: updates rank mismatch. "
                f"Expected {expected_updates_ndim}, got {len(updates_map)}"
            )

        # Trailing dims of updates must line up with input[1:] sharding so
        # each device writes only into its local shard.
        for i in range(1, len(input_map)):
            updates_idx = indices_ndim + i - 1
            if input_map[i] != updates_map[updates_idx]:
                raise ValueError(
                    f"{self.op_name}: updates sharding must match input[1:]. "
                    f"Mismatch at input dim {i}: '{input_map[i]}' != '{updates_map[updates_idx]}'"
                )