Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_scatter.py: 100%
32 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Scatter operator.
17"""
19from .parallel_ops import DistributedOp
class ScatterDistributedOp(DistributedOp):
    """Distributed implementation for torch.scatter."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for scatter.

        The output reuses the input layout unchanged. Before returning it, the
        method validates that the scatter dimension is unsharded and that any
        provided index/src layouts carry the same tensor_map as the input.

        Args:
            layouts (tuple): Layouts of inputs. Expected order for
                scatter(input, dim, index, src):
                layouts[0]: input (DTensor)
                layouts[1]: dim (None, as it's int)
                layouts[2]: index (DTensor)
                layouts[3]: src (DTensor or None if scalar)
                Entries at positions 2 and 3 are optional; a missing or None
                entry is skipped rather than rejected.
            extra_args (list): Contains non-tensor arguments.
                extra_args[0]: dim (int), negative values count from the end
                extra_args[1]: src (if src is scalar); not read here

        Returns:
            Layout: The very object passed as layouts[0].

        Raises:
            ValueError: If the input layout is missing, 'dim' is missing, not
                an integer or out of range, the scatter dimension is sharded,
                or an index/src tensor_map differs from the input tensor_map.
        """
        # NOTE(review): _allow_partial_inputs, _check_partial_inputs and op_name
        # are not defined in this class; presumably inherited from DistributedOp.
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: scatter requires a valid input tensor layout."
            )

        input_layout = layouts[0]
        input_map = input_layout.tensor_map
        ndim = len(input_map)

        if not extra_args:
            raise ValueError(
                f"Operation {self.op_name}: scatter requires 'dim' parameter in extra_args."
            )

        requested_dim = extra_args[0]
        if not isinstance(requested_dim, int):
            raise ValueError(f"Operation {self.op_name}: 'dim' must be an integer.")

        # Map a negative dim onto its non-negative equivalent, but keep the
        # caller's value around so the error below reports what was passed in
        # instead of the shifted number.
        dim = requested_dim + ndim if requested_dim < 0 else requested_dim
        if not 0 <= dim < ndim:
            raise ValueError(
                f"Operation {self.op_name}: dim {requested_dim} is out of bounds "
                f"for tensor with {ndim} dimensions."
            )

        # A tensor_map entry of -1 is the only value accepted on the scatter
        # axis; anything else is treated as sharded and rejected.
        if input_map[dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Scatter along sharded dimension {dim} is not supported. "
                "The target dimension must be Replicated (unsharded)."
            )

        # Index (position 2) is validated before src (position 3), so a call
        # with both mismatching reports the index problem first.
        for position, label in ((2, "Index"), (3, "Src")):
            if len(layouts) <= position:
                continue
            other_layout = layouts[position]
            if other_layout is None:
                # Absent layout (e.g. scalar src): nothing to compare.
                continue
            if other_layout.tensor_map != input_map:
                raise ValueError(
                    f"Operation {self.op_name}: {label} tensor layout {other_layout.tensor_map} "
                    f"must match input tensor layout {input_map}."
                )

        # Output layout is the same as input layout (scatter modifies input or
        # returns a tensor shaped like input).
        return input_layout