Coverage for hyper_parallel / core / shard / ops / parallel_masked_scatter.py: 92%
12 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for MaskedScatter operator.
17"""
19from .parallel_ops import DistributedOp
class MaskedScatterDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.masked_scatter."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Derive the output layout for torch.Tensor.masked_scatter.

        PyTorch semantics:
            masked_scatter_(mask, source)
            Writes values from ``source`` into ``self`` wherever ``mask`` is
            True, consuming ``source`` elements in flattened (sequential) order.

        Distributed restrictions:
            The position from which each rank would read ``source`` depends on
            the count of True entries in ``mask`` on all preceding ranks, i.e.
            a global prefix sum over the flattened mask. Without that extra
            communication, sharded inputs cannot be handled correctly, so this
            implementation requires every input (input, mask, source) to be
            fully Replicated (Unsharded).

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout.
                layouts[1] (Layout): Mask tensor layout.
                layouts[2] (Layout): Source tensor layout.
            extra_args (dict): Additional arguments.

        Returns:
            Layout: Output tensor layout (same as input).

        Raises:
            ValueError: If any input tensor is sharded.
        """
        # Delegate partial-status validation to the base class unless partial
        # inputs are explicitly permitted for this op.
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # Every distributed tensor involved must be fully replicated.
        # In a tensor_map, -1 marks a replicated dim; any non-negative int or
        # tuple (interleaved sharding) marks a sharded dim.
        for idx, lo in enumerate(layouts):
            if lo is None:
                continue
            if any(entry != -1 for entry in lo.tensor_map):
                raise ValueError(
                    f"Operation {self.op_name}: Input {idx} (Layout: {lo}) is sharded. "
                    f"masked_scatter currently only supports fully Replicated (Unsharded) tensors "
                    f"due to sequential dependency on source elements."
                )

        # The op produces a single tensor whose layout matches the (replicated)
        # input, so return one Layout object rather than a tuple.
        return layouts[0]