Coverage for hyper_parallel / core / shard / ops / parallel_scatter_update.py: 73%
41 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for ScatterUpdate operator.
17"""
18from hyper_parallel.core.layout import Layout
19from .parallel_ops import DistributedOp
class ScatterUpdateDistributedOp(DistributedOp):
    """Distributed implementation for the ScatterUpdate operator.

    ScatterUpdate writes ``updates`` into ``input`` at row positions given by
    ``indices``.  Under sharding this implementation requires that the scatter
    dimension (input dim 0), all of ``indices``, and the leading
    ``indices``-rank dimensions of ``updates`` stay replicated, while the
    trailing dimensions of ``updates`` are sharded exactly like ``input[1:]``.
    The output then inherits the input's layout unchanged.
    """

    def infer_layout(self, layouts, extra_args):
        """
        Infer the output layout for ScatterUpdate.

        Args:
            layouts (tuple): Tuple containing
                (input_layout, indices_layout, updates_layout).
            extra_args: Additional arguments (not used).

        Returns:
            Layout: Output layout (same device mesh, tensor map, and
            partial-reduction state as the input layout).

        Raises:
            ValueError: If the number of layouts is not 3, the input layout
                has no mesh shape, or the sharding strategy is invalid.
        """
        if len(layouts) != 3:
            raise ValueError(f"{self.__class__.__name__} requires exactly 3 input layouts")

        input_layout, indices_layout, updates_layout = layouts

        if input_layout.mesh_shape is None:
            # The condition checks mesh_shape specifically, so the message
            # names it (the old text claimed the whole layout was None).
            raise ValueError("mesh_shape of input layout cannot be None")

        self._validate_strategy(input_layout, indices_layout, updates_layout)

        # The output tensor keeps the exact device arrangement of the input:
        # same mesh, same axis aliases, same rank list, same tensor map.
        output_layout = Layout(
            mesh_shape=input_layout.mesh_shape,
            alias_name=input_layout.alias_name,
            rank_list=input_layout.rank_list
        )
        output_layout = output_layout(*input_layout.alias_tensor_map)

        # Propagate any partial-reduction state from the input layout so the
        # output is marked partial on the same device axes.
        for i, partial_op in enumerate(input_layout.partial):
            if partial_op is not None:
                dev_axis_name = input_layout.alias_name[i]
                output_layout.set_partial_by_dev_axis(dev_axis_name, partial_op)

        return output_layout

    def _validate_strategy(self, input_layout, indices_layout, updates_layout):
        """Validate the sharding strategy for ScatterUpdate.

        Checks, in order:
          1. input has a non-empty tensor map and its dim 0 is unsharded;
          2. indices is fully unsharded;
          3. the first ``indices.ndim`` dimensions of updates are unsharded;
          4. updates rank equals ``indices.ndim + input.ndim - 1``;
          5. updates trailing dims are sharded identically to ``input[1:]``.

        Args:
            input_layout (Layout): Layout of the tensor being updated.
            indices_layout (Layout): Layout of the scatter indices.
            updates_layout (Layout): Layout of the update values.

        Raises:
            ValueError: If any of the constraints above is violated.
        """
        input_map = input_layout.alias_tensor_map
        indices_map = indices_layout.alias_tensor_map
        updates_map = updates_layout.alias_tensor_map

        if not input_map:
            raise ValueError(f"{self.op_name}: input tensor map is empty")

        # Dim 0 is the scatter dimension; sharding it would scatter rows that
        # live on other devices.  "None" marks an unsharded tensor dimension.
        if input_map[0] != "None":
            raise ValueError(
                f"{self.op_name}: first dimension of input cannot be sharded"
            )

        # Indices must be fully replicated so every device sees all targets.
        for i, axis in enumerate(indices_map):
            if axis != "None":
                raise ValueError(
                    f"{self.op_name}: indices cannot be sharded, "
                    f"but dimension {i} is sharded on '{axis}'"
                )

        # The leading indices-rank dims of updates mirror indices, so they
        # must be replicated as well.
        indices_ndim = len(indices_map)
        for i in range(indices_ndim):
            if i >= len(updates_map):
                raise ValueError(
                    f"{self.op_name}: updates rank is smaller than indices rank"
                )
            if updates_map[i] != "None":
                raise ValueError(
                    f"{self.op_name}: first {indices_ndim} dimensions of updates cannot be sharded, "
                    f"but dimension {i} is sharded on '{updates_map[i]}'"
                )

        # updates shape is indices.shape + input.shape[1:], so its rank is
        # indices.ndim + input.ndim - 1.
        expected_updates_ndim = indices_ndim + len(input_map) - 1
        if len(updates_map) != expected_updates_ndim:
            raise ValueError(
                f"{self.op_name}: updates rank mismatch. "
                f"Expected {expected_updates_ndim}, got {len(updates_map)}"
            )

        # Trailing dims of updates must be sharded exactly like input[1:] so
        # each device writes only slices it owns.
        for i in range(1, len(input_map)):
            updates_idx = indices_ndim + i - 1
            if input_map[i] != updates_map[updates_idx]:
                raise ValueError(
                    f"{self.op_name}: updates sharding must match input[1:]. "
                    f"Mismatch at input dim {i}: '{input_map[i]}' != '{updates_map[updates_idx]}'"
                )