Coverage for hyper_parallel/core/shard/ops/parallel_pad.py: 78%

32 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Pad operator. 

17""" 

18 

19from .parallel_ops import DistributedOp 

20 

21 

class PadDistributedOp(DistributedOp):
    """Distributed implementation for Pad operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer the output layout for the Pad operator.

        Pad grows the tensor along the padded dimensions, which breaks the
        uniform per-device slicing assumption; padding is therefore only
        permitted on dimensions whose layout alias is "None" (replicated).

        Args:
            layouts (tuple): Layouts of the input tensors; only layouts[0]
                (the padded tensor's layout) is inspected.
            extra_args (tuple): Non-tensor operator arguments; extra_args[0]
                must be the PyTorch-style 'pad' tuple
                (last_dim_left, last_dim_right, 2nd_last_left, ...).

        Returns:
            Layout: The input layout unchanged — padding only replicated
            dimensions keeps the mesh-to-tensor mapping valid even though
            the local tensor shape grows.

        Raises:
            ValueError: If the pad argument is malformed, pads more
                dimensions than the tensor has, or touches a sharded
                dimension.
        """
        in_layout = layouts[0]
        alias_map = in_layout.alias_tensor_map
        rank = len(alias_map)

        # The dispatcher strips tensor inputs, so extra_args begins with the
        # pad tuple (optionally followed by mode/value arguments).
        if not extra_args or not isinstance(extra_args[0], (tuple, list)):
            raise ValueError(f"For '{self.op_name}', expected pad tuple as the first element in extra_args, "
                             f"but got {extra_args}")

        padding = extra_args[0]
        pad_len = len(padding)

        if pad_len % 2 != 0:
            raise ValueError(f"Pad tuple length must be even, but got {pad_len}")

        num_padded_dims = pad_len // 2
        if num_padded_dims > rank:
            raise ValueError(f"Padding {num_padded_dims} dimensions but tensor only has {rank} dimensions.")

        # Walk the (left, right) pairs; pair 0 pads the LAST tensor dimension,
        # pair 1 the second-to-last, and so on (PyTorch pad convention).
        for offset, (left, right) in enumerate(zip(padding[0::2], padding[1::2])):
            if left == 0 and right == 0:
                continue  # This dimension is untouched by the pad.

            dim_index = rank - 1 - offset
            axis_alias = alias_map[dim_index]
            # Normalize to a sequence so single- and multi-axis aliases share
            # one test; an alias other than "None" means the dim is sharded.
            aliases = axis_alias if isinstance(axis_alias, (tuple, list)) else (axis_alias,)
            if any(sub_alias != "None" for sub_alias in aliases):
                raise ValueError(
                    f"Distributed Pad operator does not support padding on a sharded dimension. "
                    f"Dimension {dim_index} (alias: {axis_alias}) is sharded. "
                    f"Please redistribute the tensor to Replicate status on this dimension before padding."
                )

        # Only replicated dimensions were padded, so the device-mesh mapping
        # remains valid and the layout passes through unchanged.
        return in_layout

101 

102 # Note: get_expand_impl is not overridden because we default to returning None. 

103 # OpDispatcher will use the original function (e.g., torch.nn.functional.pad) on the local tensor. 

104 # Since we ensured the padded dimensions are Replicated, local padding is mathematically correct.