ValueError: If both ``x_layout`` and ``w_layout`` have Partial on the same device
axis with different reduce operations (e.g. one is 'sum' and the other 'avg').
"""
if x_layout is None or w_layout is None:
return
# Propagate x's partial status to output
for dev_idx, op in enumerate(x_layout.partial):
if op is not None: