output_shape=[mesh_info.oep_size * num_tokens_per_expert.shape[0]],
group=mesh_info.oep_group,
async_op=True,
)
if handle is not None:
handle.wait()
gathered_counts = gathered_counts.view(mesh_info.oep_size, num_tokens_per_expert.shape[0])
source_token_totals = gathered_counts.sum(dim=1).tolist()
if any(total != routed_input.shape[0] for total in source_token_totals):
raise ValueError(
"DeredundencyTokenDispatcher requires equal routed token "
"counts within each OEP group because the shared token view "
f"uses all-gather, got totals {source_token_totals}."
)
gathered_routed = platform.differentiable_all_gather_concat(
routed_input, mesh_info.oep_group, mesh_info.oep_size, 0,
)
if router_coeff is None:
gathered_router_coeff = None
else:
gathered_router_coeff = platform.differentiable_all_gather_concat(
router_coeff, mesh_info.oep_group, mesh_info.oep_size, 0,
)
return gathered_counts, gathered_routed, gathered_router_coeff
@staticmethod
def dispatch(module: Module, inputs: tuple, device_mesh: DeviceMesh) -> tuple:
"""Dispatch tokens using OEP all-gather and IEP all-to-all.