Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_inplace_scatter_value.py: 94%
31 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for InplaceScatterValue operator.
17"""
19from .parallel_ops import DistributedOp
class InplaceScatterValueDistributedOp(DistributedOp):
    """Distributed implementation for InplaceScatterValue operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for InplaceScatterValue.

        Requirements:
        1. Must have exactly 4 inputs: input, dim, index, value
        2. extra_args must have exactly 2 elements: dim (int), value (scalar)
        3. input and index must have the same rank and identical sharding on
           every axis other than ``dim``
        4. Neither input nor index may be sharded along ``dim``
        5. Output layout = input layout (inplace)

        Args:
            layouts: Sequence of 4 input layouts ordered as
                (input, dim, index, value). Only ``layouts[0]`` (input) and
                ``layouts[2]`` (index) are inspected; each must be non-None and
                expose a ``tensor_map`` attribute.
            extra_args: Sequence of exactly 2 elements ``(dim, value)``.
                ``dim`` must be an ``int``; a negative value counts from the
                last axis. ``value`` is not inspected here.

        Returns:
            ``layouts[0]`` unchanged, because the operator writes into its
            input in place.

        Raises:
            ValueError: If any of the requirements above is violated.
        """
        if not layouts or len(layouts) != 4:
            raise ValueError(
                f"Operation {self.op_name}: InplaceScatterValue requires exactly 4 inputs: "
                f"input, dim, index, value. Got {len(layouts) if layouts else 0}."
            )

        input_layout = layouts[0]
        index_layout = layouts[2]
        if input_layout is None or not hasattr(input_layout, "tensor_map"):
            raise ValueError(
                f"Operation {self.op_name}: input tensor layout cannot be None."
            )
        if index_layout is None or not hasattr(index_layout, "tensor_map"):
            raise ValueError(
                f"Operation {self.op_name}: index tensor layout cannot be None."
            )

        input_map = input_layout.tensor_map
        index_map = index_layout.tensor_map
        ndim = len(input_map)
        if ndim != len(index_map):
            raise ValueError(
                f"Operation {self.op_name}: input and index must have the same number of dimensions. "
                f"Got input rank={len(input_map)}, index rank={len(index_map)}"
            )

        if not extra_args or len(extra_args) != 2:
            raise ValueError(
                f"Operation {self.op_name}: extra_args must contain exactly 2 elements: "
                f"dim (int), value (scalar). Got {len(extra_args) if extra_args else 0}."
            )
        raw_dim = extra_args[0]
        if not isinstance(raw_dim, int):
            raise ValueError(f"Operation {self.op_name}: 'dim' must be an integer.")
        # Normalize a negative dim, but report the caller's original value on
        # failure so the message matches what was actually passed in.
        dim = raw_dim + ndim if raw_dim < 0 else raw_dim
        if not 0 <= dim < ndim:
            raise ValueError(
                f"Operation {self.op_name}: dim {raw_dim} is out of bounds for tensor with {ndim} dims."
            )

        # Every axis except the scatter dim must be split identically for input
        # and index. (Presumably so each rank can apply its local index shard to
        # its local input shard — confirm against the local-op dispatch.)
        for axis, (input_axis_map, index_axis_map) in enumerate(zip(input_map, index_map)):
            if axis == dim:
                # The scatter dim has its own checks below; the message here
                # only describes non-dim axes.
                continue
            if input_axis_map != index_axis_map:
                raise ValueError(
                    f"Operation {self.op_name}: input and index must use the same sharding on non-dim axis {axis}. "
                    f"Got input tensor_map={input_map}, index tensor_map={index_map}, dim={dim}"
                )

        # Scattering along a split axis is not supported: both tensors must be
        # replicated (tensor_map entry -1) on ``dim``.
        if input_map[dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Scatter along sharded dimension {dim} is not supported. "
                f"The target dimension must be replicated (unsharded)."
            )
        if index_map[dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: index must not be sharded along scatter dimension {dim}. "
                f"Got index tensor_map={index_map}; the target dimension must be replicated (unsharded)."
            )

        return input_layout