if not self.use_grouped_mm:
return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert, scores)
if hasattr(torch, 'npu') and torch.npu.is_available():
return _run_experts_grouped_mm_npu(w1, w2, w3, x, num_tokens_per_expert, scores)
if torch.cuda.is_available():
return _run_experts_grouped_mm_gpu(w1, w2, w3, x, num_tokens_per_expert, scores)
return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert, scores)
# ---------------------------------------------------------------------------
# TokenChoiceTopKRouter