Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

| Source File | Diff Coverage (%) | Missing Lines |
|---|---|---|
| hyper_parallel/platform/torch/common/moe.py | 29.4% | 138-139,189-195,323,325,327 |
hyper_parallel/platform/torch/common/moe.py
134
135
136
137
138
139
140
141
142
143
    w2_t = w2.transpose(1, 2).contiguous()  # [num_experts, hidden_dim, dim]
    h1 = torch._grouped_mm(x, w1_t, offs=offs)  # pylint: disable=protected-access
    h3 = torch._grouped_mm(x, w3_t, offs=offs)  # pylint: disable=protected-access
    h = F.silu(h1) * h3
    if scores is not None:
        h = h * scores.unsqueeze(-1)
    return torch._grouped_mm(h, w2_t, offs=offs)  # pylint: disable=protected-access


def _run_experts_grouped_mm_npu(
185
186
187
188
189
190
191
192
193
194
195
196
197
198
    # group_type=-1 selects independent per-expert matmul (no shared axis).
    h1_list = torch_npu.npu_grouped_matmul(x_list, w1_list, group_type=-1)
    h3_list = torch_npu.npu_grouped_matmul(x_list, w3_list, group_type=-1)
    h_list = [F.silu(h1) * h3 for h1, h3 in zip(h1_list, h3_list)]
    if scores is not None:
        offset = 0
        for i, h in enumerate(h_list):
            n = counts[i]
            if n > 0:
                h_list[i] = h * scores[offset:offset + n].unsqueeze(-1)
            offset += n
    out_list = torch_npu.npu_grouped_matmul(h_list, w2_list, group_type=-1)
    return torch.cat(out_list, dim=0)

319
320
321
322
323
324
325
326
327
328
329
330
331
        if not self.use_grouped_mm:
            return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert, scores)

        if hasattr(torch, 'npu') and torch.npu.is_available():
            return _run_experts_grouped_mm_npu(w1, w2, w3, x, num_tokens_per_expert, scores)
        if torch.cuda.is_available():
            return _run_experts_grouped_mm_gpu(w1, w2, w3, x, num_tokens_per_expert, scores)

        return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert, scores)


# ---------------------------------------------------------------------------
# TokenChoiceTopKRouter