"""
Get expand implementation for the operator.
Intercepts the execution to handle Grouped Convolution with Column Parallelism.
"""
w_layout = cache_values[1]
w_map = w_layout.alias_tensor_map
w_map_0 = w_map[0]
# If Weight is NOT sharded on C_out (dim=0), native conv3d works fine.
if w_map_0 == "None":
return None
parsed_groups = cache_values[6]
if parsed_groups == 1:
return None
mesh = w_layout.mesh
axes = w_map_0 if isinstance(w_map_0, tuple) else (w_map_0,)
dev_num = 1
local_rank = 0
for axis_name in axes:
axis_size = mesh.get_device_num_along_axis(axis_name)
dev_num *= axis_size
local_rank = local_rank * axis_size + mesh.get_local_rank(axis_name)
# Pre-calculate local groups and group boundaries for the current device ahead of time.
# This hoisting optimization avoids redundant calculations during every forward pass.
local_groups = parsed_groups // dev_num
start_group = local_rank * local_groups
end_group = start_group + local_groups
def distributed_conv3d_impl(input_tensor, weight_tensor, bias=None, stride=1, padding=0, dilation=1, groups=1):
# --- Handling Groups > 1 with Column Parallelism ---
# Calculate the input channel chunk size
c_in = input_tensor.shape[1]
c_in_per_group = c_in // groups
# Map the pre-calculated groups to the actual input channels
# Uses start_group and end_group captured from the outer scope
start_channel = start_group * c_in_per_group
end_channel = end_group * c_in_per_group
# Slice the replicated input to match the local groups
sliced_input = input_tensor[:, start_channel:end_channel, ...]
# Execute native conv3d with the sliced input and adjusted local groups
return func(sliced_input, weight_tensor, bias, stride, padding, dilation, local_groups)
return distributed_conv3d_impl