Coverage for hyper_parallel / core / shard / ops / parallel_masked_scatter.py: 92%

12 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for MaskedScatter operator. 

17""" 

18 

19from .parallel_ops import DistributedOp 

20 

21 

class MaskedScatterDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.masked_scatter."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for torch.Tensor.masked_scatter.

        PyTorch semantics:
            masked_scatter_(mask, source) copies elements of ``source`` into
            ``self`` wherever ``mask`` is True, consuming ``source`` elements
            in flattened order.

        Why sharding is rejected:
            The position of each ``source`` element depends on how many True
            values precede it in the flattened ``mask``. Under sharding, each
            rank would need a global prefix sum over the mask to locate its
            offset into ``source``; without that communication the result is
            wrong. Hence every input (input, mask, source) must be fully
            Replicated (Unsharded).

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout.
                layouts[1] (Layout): Mask tensor layout.
                layouts[2] (Layout): Source tensor layout.
            extra_args (dict): Additional arguments.

        Returns:
            Layout: Output tensor layout (same as input).

        Raises:
            ValueError: If any input tensor is sharded.
        """
        # Delegate partial-state validation to the base class unless
        # partial inputs are explicitly allowed.
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        for idx, layout in enumerate(layouts):
            if layout is None:
                continue

            # In tensor_map, -1 marks a replicated dimension; any other
            # entry (an int >= 0, or a tuple for interleaved sharding)
            # marks a sharded dimension, which we must reject.
            if any(dim_map != -1 for dim_map in layout.tensor_map):
                raise ValueError(
                    f"Operation {self.op_name}: Input {idx} (Layout: {layout}) is sharded. "
                    f"masked_scatter currently only supports fully Replicated (Unsharded) tensors "
                    f"due to sequential dependency on source elements."
                )

        # The op yields a single tensor, so return one Layout (not a tuple);
        # it mirrors the input layout, already verified Replicated or None.
        return layouts[0]