# File: hyper_parallel/core/shard/ops/parallel_conv3d.py

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for the Conv3d operator.
"""

from hyper_parallel.core.dtensor.layout import Layout
from .parallel_ops import DistributedOp


class Conv3dDistributedOp(DistributedOp):
    """
    Distributed implementation of torch.nn.functional.conv3d.
    Supports data parallelism (N), tensor parallelism (column on C_out,
    row on C_in), and spatial parallelism (D/H/W).
    """
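
    # Sharding conventions at a glance (an illustrative sketch; the mesh axis
    # names are not fixed by this class):
    #   - Data parallel:    input sharded on N      -> output sharded on N
    #   - Column parallel:  weight sharded on C_out -> output sharded on C_out
    #   - Row parallel:     input and weight sharded on C_in -> each rank
    #     produces a partial output that must be summed (all-reduce) across ranks
    #   - Spatial parallel: input sharded on D/H/W  -> output keeps that sharding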

    def __init__(self, op_name):
        super().__init__(op_name)
        self._allow_partial_inputs = False

    def _validate_row_parallelism(self, in_map, w_map, groups):
        """
        Validate constraints for row parallelism (sharding on C_in).
        """
        # 1. Groups constraint: row parallelism with groups > 1 would require
        # group-wise communication, which is not implemented.
        if groups > 1:
            if in_map[1] != -1 or w_map[1] != -1:
                raise ValueError(f"{self.op_name}: Sharding on C_in with groups > 1 is not supported.")

        # 2. Row parallelism requires input and weight C_in on the same mesh axis.
        # Input: (N, C_in, D, H, W), Weight: (C_out, C_in/groups, kD, kH, kW)
        if in_map[1] != -1:
            if in_map[1] != w_map[1]:
                raise ValueError(f"{self.op_name}: Input C_in and Weight C_in must be sharded on the same axis.")
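
    # Why the axes must match (a sketch): with C_in sharded K ways along one
    # mesh axis, each rank convolves its C_in/K input channels against the
    # matching weight slice and yields a partial result; all-reducing the K
    # partials recovers the full convolution. Sharding input and weight along
    # different axes would leave the channel slices misaligned.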

    def _validate_column_parallelism(self, w_layout, b_layout, groups):
        """
        Validate constraints for column parallelism (sharding on C_out).
        """
        w_map = w_layout.tensor_map
        w_map_0 = w_map[0][0] if isinstance(w_map[0], tuple) else w_map[0]

        if w_map_0 != -1:
            # Bias must be sharded on the same axis as the weight's C_out.
            if b_layout is not None:
                b_map = b_layout.tensor_map
                b_map_0 = b_map[0][0] if isinstance(b_map[0], tuple) else b_map[0]
                if w_map_0 != b_map_0:
                    raise ValueError(f"{self.op_name}: Weight C_out and Bias C_out must be sharded on the same axis.")

            # groups must divide evenly across the tensor-parallel devices.
            if groups > 1:
                axis_name = w_layout.alias_name[len(w_layout.alias_name) - 1 - w_map_0]
                dev_num = w_layout.mesh.get_device_num_along_axis(axis_name)
                if groups % dev_num != 0:
                    raise ValueError(
                        f"{self.op_name}: For Column Parallelism, groups ({groups}) "
                        f"must be divisible by tp_size ({dev_num})."
                    )
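
    # Worked example (hedged): groups=4 over tp_size=2 gives each rank two
    # whole groups; groups=3 over tp_size=2 would split a group across ranks,
    # which grouped convolution cannot express, hence the divisibility check.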

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for Conv3d based on the torch.nn.functional.conv3d
        signature: (input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1).

        Args:
            layouts (tuple): (input_layout, weight_layout, bias_layout)
            extra_args (tuple): (stride, padding, dilation, groups)
        """
        self._check_partial_inputs(layouts)

        if not layouts or len(layouts) < 2:
            raise ValueError(f"{self.op_name}: Requires at least input and weight layouts.")

        in_layout, w_layout = layouts[0], layouts[1]
        b_layout = layouts[2] if len(layouts) > 2 else None

        # Extract groups from extra_args, which follows the functional.conv3d
        # argument order: stride=0, padding=1, dilation=2, groups=3.
        groups = extra_args[3] if extra_args and len(extra_args) > 3 else 1

        in_map = in_layout.tensor_map
        w_map = w_layout.tensor_map

        # Both input and weight must be 5D: (N, C_in, D, H, W) and
        # (C_out, C_in/groups, kD, kH, kW).
        if len(in_map) != 5 or len(w_map) != 5:
            raise ValueError(f"{self.op_name}: Input and weight must be 5D.")

        # Delegate validation to helper methods to keep cyclomatic complexity low.
        self._validate_row_parallelism(in_map, w_map, groups)
        self._validate_column_parallelism(w_layout, b_layout, groups)

        # Construct the output map (N, C_out, D_out, H_out, W_out): batch and
        # spatial shardings follow the input, channel sharding follows the weight.
        out_map = [
            in_map[0],  # N
            w_map[0],   # C_out
            in_map[2],  # D
            in_map[3],  # H
            in_map[4],  # W
        ]
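
        # Note (hedged): D/H/W shardings pass through unchanged here; any halo
        # exchange needed for overlapping receptive fields under spatial
        # parallelism is presumably handled outside this layout inference.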

        # Build the output layout on the same mesh as the input.
        mesh_shape = in_layout.mesh_shape
        alias_name = in_layout.alias_name
        rank_list = in_layout.rank_list

        def idx_to_alias(idx):
            # tensor_map indices count mesh axes from the last alias; -1 means
            # the dimension is replicated, rendered as the alias "None".
            if idx == -1:
                return "None"
            return alias_name[len(alias_name) - idx - 1]

        output_alias_map = tuple(idx_to_alias(idx) for idx in out_map)
        output_layout = Layout(mesh_shape, alias_name, rank_list)
        output_layout = output_layout(*output_alias_map)

        # For row parallelism, mark the output as a pending partial sum along
        # the axis on which C_in was sharded.
        if in_map[1] != -1:
            partial_axis = idx_to_alias(in_map[1])
            output_layout.set_partial_by_dev_axis(partial_axis, "sum")

        return output_layout
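
    # Example (a sketch; axis names hypothetical): with alias_name = ("dp", "tp"),
    # in_map = [1, 0, -1, -1, -1] and w_map = [-1, 0, -1, -1, -1] (N on "dp",
    # C_in on "tp" for both tensors), the output keeps N on "dp", is replicated
    # on the other dims, and is marked partial-sum along "tp" until the
    # framework inserts the all-reduce.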

    def get_expand_impl(self, func, infer_result, layouts, extra_args=None):
        """
        Return the expanded implementation for the operator, or None if the
        native op suffices. Intercepts execution to handle grouped convolution
        under column parallelism.
        """
        w_layout = layouts[1]
        w_map = w_layout.tensor_map

        # Extract the exact mesh mapping for C_out.
        w_map_0 = w_map[0][0] if isinstance(w_map[0], tuple) else w_map[0]

        # If the weight is NOT sharded on C_out (dim 0), native conv3d works as-is.
        if w_map_0 == -1:
            return None

        parsed_groups = extra_args[3] if extra_args and len(extra_args) > 3 else 1

        mesh = w_layout.mesh
        # Find the mesh axis name on which C_out is sharded.
        axis_name = w_layout.alias_name[len(w_layout.alias_name) - 1 - w_map_0]
        dev_num = mesh.get_device_num_along_axis(axis_name)
        local_rank = mesh.get_local_rank(axis_name)

        # Pre-compute this device's local group count and group boundaries.
        # Hoisting these out of the closure avoids repeating the arithmetic on
        # every forward pass.
        local_groups = parsed_groups // dev_num if parsed_groups > 1 else 1
        start_group = local_rank * local_groups
        end_group = start_group + local_groups
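
        # Worked example (a sketch): parsed_groups=4 on dev_num=2 devices gives
        # local_groups=2; rank 0 owns groups [0, 2) and rank 1 owns groups [2, 4).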

        def distributed_conv3d_impl(input_tensor, weight_tensor, bias=None, stride=1,
                                    padding=0, dilation=1, groups=1):
            # Standard convolution: fall back to the native PyTorch function.
            if groups == 1:
                return func(input_tensor, weight_tensor, bias, stride, padding, dilation, groups)

            # --- Handling groups > 1 with column parallelism ---
            # Compute the per-group input-channel size.
            c_in = input_tensor.shape[1]
            c_in_per_group = c_in // groups

            # Map the pre-computed group range to actual input channels, using
            # start_group and end_group captured from the enclosing scope.
            start_channel = start_group * c_in_per_group
            end_channel = end_group * c_in_per_group

            # Slice the replicated input down to this rank's local groups.
            sliced_input = input_tensor[:, start_channel:end_channel, ...]

            # Run native conv3d with the sliced input and the local group count.
            return func(sliced_input, weight_tensor, bias, stride, padding, dilation, local_groups)

        return distributed_conv3d_impl
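
# Minimal usage sketch (hypothetical; the registration/dispatch API is assumed,
# not shown in this file): the framework would construct
# Conv3dDistributedOp("conv3d"), call infer_layout() with the DTensor layouts
# of (input, weight, bias) and extra_args=(stride, padding, dilation, groups),
# and, if get_expand_impl() returns a callable, invoke it in place of F.conv3d.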