Coverage for hyper_parallel / core / shard / ops / parallel_pad.py: 78%
32 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Pad operator.
17"""
19from .parallel_ops import DistributedOp
class PadDistributedOp(DistributedOp):
    """Distributed implementation for Pad operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for Pad operator.

        Padding enlarges the tensor, which breaks the uniform slicing
        assumption on sharded dimensions. Padding is therefore rejected on
        any sharded axis; otherwise the input layout is propagated unchanged.

        Args:
            layouts (tuple): Layouts of input tensor.
            extra_args (tuple): Arguments for the operator.
                extra_args[0] should be the 'pad' tuple.

        Returns:
            Layout: Layout for output tensor (same as input layout).

        Raises:
            ValueError: If padding is attempted on a sharded dimension.
        """
        in_layout = layouts[0]
        axis_aliases = in_layout.alias_tensor_map
        ndim = len(axis_aliases)

        # The dispatcher strips layout-carrying inputs, so the PyTorch-style
        # pad tuple is expected as the first remaining positional argument.
        if not extra_args or not isinstance(extra_args[0], (tuple, list)):
            raise ValueError(f"For '{self.op_name}', expected pad tuple as the first element in extra_args, "
                             f"but got {extra_args}")

        pad = extra_args[0]
        pad_len = len(pad)
        if pad_len % 2 != 0:
            raise ValueError(f"Pad tuple length must be even, but got {pad_len}")

        # PyTorch pad pairs run backwards from the last dimension:
        # (last_left, last_right, 2nd_last_left, 2nd_last_right, ...).
        num_padded_dims = pad_len // 2
        if num_padded_dims > ndim:
            raise ValueError(f"Padding {num_padded_dims} dimensions but tensor only has {ndim} dimensions.")

        for pair_idx in range(num_padded_dims):
            pad_left = pad[2 * pair_idx]
            pad_right = pad[2 * pair_idx + 1]
            if pad_left == 0 and pad_right == 0:
                # This dimension is untouched by the pad; sharding is irrelevant.
                continue

            # Pair 0 maps to the last tensor dimension (ndim - 1), pair 1 to
            # the second-to-last, and so on.
            dim_index = ndim - 1 - pair_idx
            axis_alias = axis_aliases[dim_index]

            # Any alias (or sub-alias, for composite mappings) other than the
            # string "None" marks this dimension as sharded.
            if isinstance(axis_alias, (tuple, list)):
                is_sharded = any(sub != "None" for sub in axis_alias)
            else:
                is_sharded = axis_alias != "None"

            if is_sharded:
                raise ValueError(
                    f"Distributed Pad operator does not support padding on a sharded dimension. "
                    f"Dimension {dim_index} (alias: {axis_alias}) is sharded. "
                    f"Please redistribute the tensor to Replicate status on this dimension before padding."
                )

        # No sharded dimension is padded: the local tensor shape grows, but the
        # device-mesh-to-dimension mapping stays valid, so the layout is reused.
        return in_layout

    # Note: get_expand_impl is not overridden because we default to returning None.
    # OpDispatcher will use the original function (e.g., torch.nn.functional.pad) on the local tensor.
    # Since we ensured the padded dimensions are Replicated, local padding is mathematically correct.