Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_conv3d.py: 69%
78 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for Conv3d operator.
"""

from hyper_parallel.core.dtensor.layout import Layout
from .parallel_ops import DistributedOp


class Conv3dDistributedOp(DistributedOp):
    """
    Distributed implementation for torch.nn.functional.conv3d.
    Supports Data Parallel, Tensor Parallel (Column/Row), and Spatial Parallel.
    """

    def __init__(self, op_name):
        super().__init__(op_name)
        self._allow_partial_inputs = False

    def _validate_row_parallelism(self, in_map, w_map, groups):
        """
        Validate constraints for Row Parallelism.
        """
        # 1. Handle Groups Constraint for Row Parallelism
        if groups > 1:
            if in_map[1] != -1 or w_map[1] != -1:
                # Row Parallelism with groups > 1 requires advanced group-wise communication
                raise ValueError(f"{self.op_name}: Sharding on C_in with groups > 1 is not supported.")

        # 2. Check Row Parallelism (Sharding on Channel In)
        # Input: (N, C_in, D, H, W), Weight: (C_out, C_in/groups, kD, kH, kW)
        if in_map[1] != -1:
            if in_map[1] != w_map[1]:
                raise ValueError(f"{self.op_name}: Input C_in and Weight C_in must be sharded on the same axis.")
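    # Illustrative example (not from the original source; the tensor-map values
    # below are made up). Assuming the usual convention that -1 marks a
    # replicated dimension and a non-negative entry names the mesh axis a
    # dimension is sharded on:
    #   in_map = [0, 1, -1, -1, -1]   # N on mesh axis 0, C_in on mesh axis 1
    #   w_map  = [-1, 1, -1, -1, -1]  # weight C_in on the same mesh axis 1
    # passes validation (with groups == 1), while w_map = [-1, 2, -1, -1, -1]
    # would raise, because input C_in and weight C_in land on different axes.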
    def _validate_column_parallelism(self, w_layout, b_layout, groups):
        """
        Validate constraints for Column Parallelism.
        """
        w_map = w_layout.tensor_map
        w_map_0 = w_map[0][0] if isinstance(w_map[0], tuple) else w_map[0]

        if w_map_0 != -1:
            # Check bias alignment
            if b_layout is not None:
                b_map = b_layout.tensor_map
                b_map_0 = b_map[0][0] if isinstance(b_map[0], tuple) else b_map[0]
                if w_map_0 != b_map_0:
                    raise ValueError(f"{self.op_name}: Weight C_out and Bias C_out must be sharded on the same axis.")

            # Check groups divisibility for Column Parallelism
            if groups > 1:
                axis_name = w_layout.alias_name[len(w_layout.alias_name) - 1 - w_map_0]
                dev_num = w_layout.mesh.get_device_num_along_axis(axis_name)

                if groups % dev_num != 0:
                    raise ValueError(
                        f"{self.op_name}: For Column Parallelism, groups ({groups}) "
                        f"must be divisible by tp_size ({dev_num})."
                    )
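    # For example (hypothetical numbers): with groups == 4 and the weight's
    # C_out sharded across a tensor-parallel axis of size 2, each device owns
    # 2 whole groups and the check passes; groups == 3 with tp_size == 2 would
    # raise, since the groups cannot be split evenly across devices.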
    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for Conv3d based on the PyTorch functional.conv3d signature:
        (input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

        Args:
            layouts (tuple): (input_layout, weight_layout, bias_layout)
            extra_args (tuple): (stride, padding, dilation, groups)
        """
        self._check_partial_inputs(layouts)

        if not layouts or len(layouts) < 2:
            raise ValueError(f"{self.op_name}: Requires at least input and weight layouts.")

        in_layout, w_layout = layouts[0], layouts[1]
        b_layout = layouts[2] if len(layouts) > 2 else None

        # Extract groups from extra_args (indices follow the functional.conv3d signature:
        # stride=0, padding=1, dilation=2, groups=3)
        groups = extra_args[3] if extra_args and len(extra_args) > 3 else 1

        in_map = in_layout.tensor_map
        w_map = w_layout.tensor_map

        # Validate dimensions
        if len(in_map) != 5 or len(w_map) != 5:
            raise ValueError(f"{self.op_name}: Input and weight must be 5D.")

        # Delegate validation to helper methods to reduce cyclomatic complexity
        self._validate_row_parallelism(in_map, w_map, groups)
        self._validate_column_parallelism(w_layout, b_layout, groups)

        # Construct Output Map (N, C_out, D_out, H_out, W_out)
        out_map = [
            in_map[0],  # N
            w_map[0],   # C_out
            in_map[2],  # D
            in_map[3],  # H
            in_map[4]   # W
        ]

        # Build Layout
        mesh_shape = in_layout.mesh_shape
        alias_name = in_layout.alias_name
        rank_list = in_layout.rank_list

        def idx_to_alias(idx):
            if idx == -1:
                return "None"
            return alias_name[len(alias_name) - idx - 1]

        output_alias_map = tuple(idx_to_alias(idx) for idx in out_map)
        output_layout = Layout(mesh_shape, alias_name, rank_list)
        output_layout = output_layout(*output_alias_map)

        # Set Partial status for Row Parallelism: each shard holds a partial sum
        # over its slice of C_in, to be reduced across that mesh axis.
        if in_map[1] != -1:
            partial_axis = idx_to_alias(in_map[1])
            output_layout.set_partial_by_dev_axis(partial_axis, "sum")

        return output_layout
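    # Illustrative example (hypothetical maps, not from the original source):
    # with in_map = [0, -1, -1, -1, -1] (batch sharded on mesh axis 0) and
    # w_map = [1, -1, -1, -1, -1] (C_out sharded on mesh axis 1), the output
    # map becomes [0, 1, -1, -1, -1]: N follows the input, C_out follows the
    # weight, and the spatial dims follow the input. If instead C_in were
    # sharded (in_map[1] != -1), the output would additionally be marked
    # partial-sum along that mesh axis.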
    def get_expand_impl(self, func, infer_result, layouts, extra_args=None):
        """
        Get expand implementation for the operator.
        Intercepts the execution to handle Grouped Convolution with Column Parallelism.
        """
        _, w_layout = layouts[0], layouts[1]
        w_map = w_layout.tensor_map

        # Extract the exact mesh mapping for C_out
        w_map_0 = w_map[0][0] if isinstance(w_map[0], tuple) else w_map[0]

        # If the weight is NOT sharded on C_out (dim 0), native conv3d works fine.
        if w_map_0 == -1:
            return None

        parsed_groups = extra_args[3] if extra_args and len(extra_args) > 3 else 1

        mesh = w_layout.mesh
        # Find the mesh axis name where C_out is sharded
        axis_name = w_layout.alias_name[len(w_layout.alias_name) - 1 - w_map_0]
        dev_num = mesh.get_device_num_along_axis(axis_name)
        local_rank = mesh.get_local_rank(axis_name)

        # Pre-compute the local group count and group boundaries for this device once,
        # so they are not recalculated on every forward pass.
        local_groups = parsed_groups // dev_num if parsed_groups > 1 else 1
        start_group = local_rank * local_groups
        end_group = start_group + local_groups
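        # Worked example (hypothetical numbers): parsed_groups = 4 split across
        # dev_num = 2 devices gives local_groups = 2; on local_rank = 1 this
        # device owns groups [2, 4), i.e. start_group = 2 and end_group = 4.
        # With c_in = 8 in the closure below, c_in_per_group = 2, so the input
        # slice covers channels 4:8.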
        def distributed_conv3d_impl(input_tensor, weight_tensor, bias=None, stride=1, padding=0, dilation=1, groups=1):
            # If standard convolution, fall back to the native PyTorch function
            if groups == 1:
                return func(input_tensor, weight_tensor, bias, stride, padding, dilation, groups)

            # --- Handling groups > 1 with Column Parallelism ---
            # Calculate the input channel chunk size per group
            c_in = input_tensor.shape[1]
            c_in_per_group = c_in // groups

            # Map the pre-calculated groups to the actual input channels.
            # Uses start_group and end_group captured from the enclosing scope.
            start_channel = start_group * c_in_per_group
            end_channel = end_group * c_in_per_group

            # Slice the replicated input to match the local groups
            sliced_input = input_tensor[:, start_channel:end_channel, ...]

            # Execute native conv3d with the sliced input and the adjusted local group count
            return func(sliced_input, weight_tensor, bias, stride, padding, dilation, local_groups)

        return distributed_conv3d_impl
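
# The block below is an illustrative, standalone self-check (not part of the
# original module) of the slicing identity that distributed_conv3d_impl relies
# on: for a grouped convolution, the output channels owned by one
# column-parallel rank can be computed from the matching slice of input
# channels and weight rows alone. All shapes and the rank/device counts are
# made up for the example.
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F

    n, c_in, c_out, groups = 2, 8, 12, 4
    x = torch.randn(n, c_in, 5, 5, 5)
    w = torch.randn(c_out, c_in // groups, 3, 3, 3)

    # Reference: full grouped convolution on one device.
    full = F.conv3d(x, w, groups=groups)

    # Pretend tp_size (dev_num) is 2 and this process is local_rank 1.
    dev_num, local_rank = 2, 1
    local_groups = groups // dev_num
    start_group = local_rank * local_groups
    end_group = start_group + local_groups
    c_in_per_group = c_in // groups
    c_out_per_group = c_out // groups

    # Slice the input channels and weight rows belonging to this rank's groups,
    # then run a smaller grouped convolution with the local group count.
    x_local = x[:, start_group * c_in_per_group:end_group * c_in_per_group, ...]
    w_local = w[start_group * c_out_per_group:end_group * c_out_per_group]
    local = F.conv3d(x_local, w_local, groups=local_groups)

    expected = full[:, start_group * c_out_per_group:end_group * c_out_per_group, ...]
    assert torch.allclose(local, expected, atol=1e-5)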