Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / torch / common / moe.py: 85%
172 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""PyTorch MoE building blocks: router, experts and orchestrator.
17Provides :class:`FeedForward`, :class:`GroupedExperts`,
18:class:`TokenChoiceTopKRouter`, and the top-level :class:`MoE` orchestrator.
19Load-balancing utilities (expert bias update and auxiliary loss) are also
20included.
22All modules compose naturally with EP / TP parallel strategies via DTensor.
23Distributed collectives are handled by
24:mod:`hyper_parallel.core.expert_parallel.expert_parallel`; this module
25contains only single-device computation.
26"""
27import math
28from typing import Optional
30import torch
31from torch import nn
32import torch.nn.functional as F
34from hyper_parallel.core.dtensor.dtensor import DTensor
# Public names re-exported via ``from ... import *``.
__all__ = [
    "FeedForward",
    "GroupedExperts",
    "TokenChoiceTopKRouter",
    "MoE",
    "update_expert_bias",
]
45# ---------------------------------------------------------------------------
46# Grouped expert computation kernels
47# ---------------------------------------------------------------------------
49def _run_experts_for_loop(
50 w1: torch.Tensor,
51 w2: torch.Tensor,
52 w3: torch.Tensor,
53 x: torch.Tensor,
54 num_tokens_per_expert: torch.Tensor,
55) -> torch.Tensor:
56 """Run per-expert SwiGLU via a sequential loop (reference path).
58 Args:
59 w1: Shape ``[num_experts, hidden_dim, dim]``.
60 w2: Shape ``[num_experts, dim, hidden_dim]``.
61 w3: Shape ``[num_experts, hidden_dim, dim]``.
62 x: Routed tokens in expert-major order, shape
63 ``[total_routed_tokens, dim]``.
64 num_tokens_per_expert: 1-D integer tensor of length ``num_experts``.
66 Returns:
67 Expert output of shape ``[total_routed_tokens, dim]``.
68 """
69 # Use ``torch.cat`` instead of ``zeros_like + in-place slice assign``.
70 # On some backends (notably torch_npu) in-place slice assignment onto a
71 # ``requires_grad=False`` leaf tensor does not reliably upgrade it to a
72 # non-leaf with a ``grad_fn`` — the forward result may end up with
73 # ``grad_fn=None`` and downstream ``backward()`` fails with
74 # "element 0 of tensors does not require grad and does not have a grad_fn".
75 #
76 # Drain ``num_tokens_per_expert`` to host **once** via ``.tolist()``
77 # rather than calling ``int(n)`` per loop iteration — a single D2H
78 # copy instead of ``num_local_experts`` separate ones. Per-iter
79 # ``.item()`` would stall the host between expert kernels and shrink
80 # the dual-pipe overlap window.
81 counts_list = num_tokens_per_expert.tolist()
82 parts = []
83 offset = 0
84 for e, n in enumerate(counts_list):
85 if n == 0:
86 continue
87 x_e = x[offset:offset + n]
88 h = F.silu(x_e @ w1[e].T) * (x_e @ w3[e].T)
89 parts.append(h @ w2[e].T)
90 offset += n
91 if not parts:
92 # No routed tokens: return a grad-connected zero (not ``zeros_like``).
93 return x * 0.0
94 return torch.cat(parts, dim=0)
def _run_experts_grouped_mm_gpu(
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: torch.Tensor,
    x: torch.Tensor,
    num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Fused grouped matmul path for NVIDIA GPU using ``torch._grouped_mm``.

    Args:
        w1: Shape ``[num_experts, hidden_dim, dim]``.
        w2: Shape ``[num_experts, dim, hidden_dim]``.
        w3: Shape ``[num_experts, hidden_dim, dim]``.
        x: Shape ``[total_routed_tokens, dim]``.
        num_tokens_per_expert: 1-D integer tensor of length ``num_experts``.

    Returns:
        Expert output of shape ``[total_routed_tokens, dim]``.
    """
    # ``offs`` holds the *end* offset of every expert's token span along
    # dim 0, i.e. the cumulative sum over ALL experts: its length must
    # equal num_experts (= the number of weight groups) and its last
    # entry must equal total_routed_tokens. Truncating the cumsum with
    # ``num_tokens_per_expert[:-1]`` produces num_experts - 1 boundaries
    # for num_experts weight matrices and drops the last expert's tokens.
    offs = torch.cumsum(num_tokens_per_expert, dim=0).to(torch.int32)
    # w1/w3 stored as [num_experts, hidden_dim, dim]; grouped_mm expects
    # [num_experts, dim, hidden_dim], so transpose the inner two dims.
    w1_t = w1.transpose(1, 2).contiguous()  # [num_experts, dim, hidden_dim]
    w3_t = w3.transpose(1, 2).contiguous()
    w2_t = w2.transpose(1, 2).contiguous()  # [num_experts, hidden_dim, dim]
    h1 = torch._grouped_mm(x, w1_t, offs=offs)  # pylint: disable=protected-access
    h3 = torch._grouped_mm(x, w3_t, offs=offs)  # pylint: disable=protected-access
    h = F.silu(h1) * h3
    return torch._grouped_mm(h, w2_t, offs=offs)  # pylint: disable=protected-access
def _run_experts_grouped_mm_npu(
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: torch.Tensor,
    x: torch.Tensor,
    num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Fused grouped matmul path for Ascend NPU via ``torch_npu.npu_grouped_matmul``.

    Requires ``torch_npu`` to be installed.

    Args:
        w1: Shape ``[num_experts, hidden_dim, dim]``.
        w2: Shape ``[num_experts, dim, hidden_dim]``.
        w3: Shape ``[num_experts, hidden_dim, dim]``.
        x: Shape ``[total_routed_tokens, dim]``.
        num_tokens_per_expert: 1-D integer tensor of length ``num_experts``.

    Returns:
        Expert output of shape ``[total_routed_tokens, dim]``.
    """
    import torch_npu  # pylint: disable=C0415

    # npu_grouped_matmul computes y = x @ weight with no implicit transpose,
    # while our storage is [num_experts, out_dim, in_dim] (F.linear's
    # convention, y = x @ weight.T). Transpose every expert shard so each
    # product is [tokens, in_dim] @ [in_dim, out_dim] = [tokens, out_dim].
    token_groups = list(torch.split(x, num_tokens_per_expert.tolist(), dim=0))
    # w1, w3: [E, hidden_dim, dim] → per-expert [dim, hidden_dim];
    # w2: [E, dim, hidden_dim] → per-expert [hidden_dim, dim].
    w1_t = [shard.T.contiguous() for shard in w1.unbind(0)]
    w2_t = [shard.T.contiguous() for shard in w2.unbind(0)]
    w3_t = [shard.T.contiguous() for shard in w3.unbind(0)]

    # Multi-multi-multi mode (x[i] @ weight[i]); group_type=-1 selects
    # independent per-expert matmuls with no shared axis.
    gate_groups = torch_npu.npu_grouped_matmul(token_groups, w1_t, group_type=-1)
    up_groups = torch_npu.npu_grouped_matmul(token_groups, w3_t, group_type=-1)
    hidden_groups = [F.silu(g) * u for g, u in zip(gate_groups, up_groups)]
    out_groups = torch_npu.npu_grouped_matmul(hidden_groups, w2_t, group_type=-1)
    return torch.cat(out_groups, dim=0)
174# ---------------------------------------------------------------------------
175# FeedForward — shared expert / standard SwiGLU FFN
176# ---------------------------------------------------------------------------
class FeedForward(nn.Module):
    """SwiGLU feed-forward block, also usable as an always-active shared expert.

    Computes ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: Input embedding dimension.
        hidden_dim: Intermediate hidden dimension.
        bias: Whether the linear layers carry a learnable bias.
            Defaults to ``False``.

    Example::
        >>> ff = FeedForward(dim=256, hidden_dim=512)
        >>> out = ff(torch.randn(4, 16, 256))
        >>> out.shape
        torch.Size([4, 16, 256])
    """

    def __init__(self, dim: int, hidden_dim: int, bias: bool = False) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU transformation.

        Args:
            x: Input tensor of shape ``(..., dim)``.

        Returns:
            Tensor with the same leading shape and last dimension ``dim``.
        """
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
213# ---------------------------------------------------------------------------
214# GroupedExperts
215# ---------------------------------------------------------------------------
class GroupedExperts(nn.Module):
    """Batched expert computation with an optional grouped-matmul fast path.

    Every projection stores all expert weights in one 3-D parameter so
    EP / TP strategies can shard the expert dimension through DTensor.

    Args:
        dim: Token embedding dimension.
        hidden_dim: Expert hidden dimension (SwiGLU intermediate size).
        num_experts: Total number of experts.
        use_grouped_mm: When ``True``, dispatch to a hardware grouped
            matmul kernel (``torch._grouped_mm`` on GPU,
            ``torch_npu.npu_grouped_matmul`` on NPU), falling back to the
            sequential loop when neither device is available.
            Defaults to ``False``.

    Note:
        When the weights are DTensors (e.g. after TP sharding via
        :class:`ExpertTensorParallel`), ``forward`` extracts the local
        shard with ``.to_local()`` before computing.

    Example::
        >>> experts = GroupedExperts(dim=8, hidden_dim=16, num_experts=4)
        >>> x = torch.randn(10, 8)
        >>> counts = torch.tensor([3, 2, 4, 1])
        >>> out = experts(x, counts)
        >>> out.shape
        torch.Size([10, 8])
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int,
        use_grouped_mm: bool = False,
    ) -> None:
        super().__init__()
        # Layout [num_experts, out_dim, in_dim] so every expert applies the
        # standard linear form x @ w[e].T.
        self.w1 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        self.num_experts = num_experts
        self.use_grouped_mm = use_grouped_mm
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Kaiming-uniform initialisation applied to every projection."""
        for param in (self.w1, self.w2, self.w3):
            nn.init.kaiming_uniform_(param.view(param.shape[0], -1), a=math.sqrt(5))

    def forward(
        self,
        x: torch.Tensor,
        num_tokens_per_expert: torch.Tensor,
    ) -> torch.Tensor:
        """Run every local expert over its slice of routed tokens.

        Args:
            x: Routed tokens in expert-major order,
                shape ``[total_routed_tokens, dim]``.
            num_tokens_per_expert: 1-D integer tensor of length
                ``num_local_experts`` holding per-expert token counts.

        Returns:
            Expert output of shape ``[total_routed_tokens, dim]``.
        """
        # Materialise local shards when TP wrapped the weights in DTensor.
        w1, w2, w3 = (
            w.to_local() if isinstance(w, DTensor) else w
            for w in (self.w1, self.w2, self.w3)
        )

        if self.use_grouped_mm:
            if hasattr(torch, "npu") and torch.npu.is_available():
                return _run_experts_grouped_mm_npu(w1, w2, w3, x, num_tokens_per_expert)
            if torch.cuda.is_available():
                return _run_experts_grouped_mm_gpu(w1, w2, w3, x, num_tokens_per_expert)
        return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)
300# ---------------------------------------------------------------------------
301# TokenChoiceTopKRouter
302# ---------------------------------------------------------------------------
class TokenChoiceTopKRouter(nn.Module):
    """Token-choice top-K router: every token independently picks K experts.

    Args:
        dim: Token embedding dimension (input to the gate).
        num_experts: Total number of experts.
        top_k: Experts selected per token. Defaults to ``1``.
        score_func: Activation applied to gate logits before topk,
            one of ``"sigmoid"`` (default) or ``"softmax"``.
        num_expert_groups: Number of expert groups for node-limited
            routing; ``None`` disables node-limited routing.
        num_limited_groups: Groups kept in node-limited routing. Required
            whenever ``num_expert_groups`` is given.
        route_scale: Scalar multiplier on routing scores.
            Defaults to ``1.0``.

    Example::
        >>> router = TokenChoiceTopKRouter(dim=64, num_experts=8, top_k=2)
        >>> scores, indices, counts = router(torch.randn(32, 64))
        >>> scores.shape, indices.shape, counts.shape
        (torch.Size([32, 2]), torch.Size([32, 2]), torch.Size([8]))
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int = 1,
        score_func: str = "sigmoid",
        num_expert_groups: Optional[int] = None,
        num_limited_groups: Optional[int] = None,
        route_scale: float = 1.0,
    ) -> None:
        super().__init__()
        if score_func not in ("sigmoid", "softmax"):
            raise ValueError(
                f"score_func must be 'sigmoid' or 'softmax', got '{score_func}'."
            )
        if num_expert_groups is not None and num_limited_groups is None:
            raise ValueError(
                "num_limited_groups must be set when num_expert_groups is not None."
            )
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.num_experts = num_experts
        self.top_k = top_k
        self.score_func = score_func
        self.num_expert_groups = num_expert_groups
        self.num_limited_groups = num_limited_groups
        self.route_scale = route_scale

    def forward(
        self,
        x: torch.Tensor,
        expert_bias: Optional[torch.Tensor] = None,
    ):
        """Score tokens and assign each one its top-K experts.

        Args:
            x: Token tensor of shape ``[num_tokens, dim]``.
            expert_bias: Optional 1-D tensor of shape ``[num_experts]``
                added to the scores only while choosing the top-K
                (auxiliary-loss-free load balancing). The returned
                ``top_scores`` stay unbiased.

        Returns:
            Tuple of:

            - ``top_scores``: shape ``[num_tokens, top_k]`` — routing weights.
            - ``selected_experts``: shape ``[num_tokens, top_k]`` — expert IDs.
            - ``num_tokens_per_expert``: shape ``[num_experts]`` — load counts.
        """
        # Run the gate in float32 for numerical stability.
        logits = self.gate(x).float()
        if self.score_func == "softmax":
            scores = F.softmax(logits, dim=-1)
        else:
            scores = torch.sigmoid(logits)
        if self.route_scale != 1.0:
            scores = scores * self.route_scale

        # Node-limited routing masks low-scoring expert groups to -inf
        # before selection.
        selection_scores = (
            self._get_node_limited_routing_scores(scores)
            if self.num_expert_groups is not None
            else scores
        )
        # The bias influences *which* experts win, never the returned weights.
        if expert_bias is not None:
            selection_scores = selection_scores + expert_bias.float()

        _, selected_experts = selection_scores.topk(self.top_k, dim=-1)
        # Routing weights always come from the unbiased, unmasked scores.
        top_scores = scores.gather(1, selected_experts)

        num_tokens_per_expert = torch.bincount(
            selected_experts.flatten(), minlength=self.num_experts
        )
        return top_scores, selected_experts, num_tokens_per_expert

    def _get_node_limited_routing_scores(self, scores: torch.Tensor) -> torch.Tensor:
        """Keep only the strongest expert groups; mask the rest to ``-inf``.

        Args:
            scores: Routing scores of shape ``[num_tokens, num_experts]``.

        Returns:
            Scores with non-selected groups masked to ``-inf``.
        """
        num_tokens, num_experts = scores.shape
        group_size = num_experts // self.num_expert_groups
        # Each group is represented by its strongest expert.
        group_scores = scores.view(
            num_tokens, self.num_expert_groups, group_size
        ).max(dim=-1).values  # [num_tokens, num_groups]
        _, keep_groups = group_scores.topk(self.num_limited_groups, dim=-1)

        keep_mask = scores.new_zeros(num_tokens, self.num_expert_groups)
        keep_mask.scatter_(1, keep_groups, 1.0)
        expanded = (
            keep_mask.unsqueeze(-1)
            .expand(num_tokens, self.num_expert_groups, group_size)
            .reshape(num_tokens, num_experts)
        )
        return scores.masked_fill(expanded == 0, float("-inf"))
431# ---------------------------------------------------------------------------
432# Load-balance auxiliary loss
433# ---------------------------------------------------------------------------
435def _compute_load_balance_loss(
436 top_scores: torch.Tensor,
437 selected_experts: torch.Tensor,
438 num_experts: int,
439) -> torch.Tensor:
440 """Compute load-balance auxiliary loss.
442 Standard formulation: ``loss = num_experts * Σ fraction_i * mean_score_i``
444 Args:
445 top_scores: Routing weights, shape ``[num_tokens, top_k]``.
446 selected_experts: Expert IDs, shape ``[num_tokens, top_k]``.
447 num_experts: Total number of experts.
449 Returns:
450 Scalar loss tensor.
451 """
452 num_tokens, top_k = top_scores.shape
453 flat_experts = selected_experts.flatten() # [num_tokens * top_k]
455 # Fraction of tokens sent to each expert (soft, uses routing probabilities).
456 one_hot = torch.zeros(
457 num_tokens * top_k, num_experts,
458 dtype=top_scores.dtype, device=top_scores.device,
459 ).scatter_(1, flat_experts.unsqueeze(1), 1.0)
460 # [num_tokens, top_k, num_experts] → mean per expert
461 expert_fraction = one_hot.view(num_tokens, top_k, num_experts).float().mean(dim=(0, 1))
463 # Mean routing probability per expert.
464 prob = F.softmax(top_scores.float(), dim=-1) # [num_tokens, top_k]
465 mean_score = prob.mean(dim=0) # [top_k]
467 # Scalar: num_experts * dot(mean_score, expert_fraction.T)
468 loss = num_experts * (mean_score.unsqueeze(-1) * expert_fraction.unsqueeze(0)).sum()
469 return loss
472# ---------------------------------------------------------------------------
473# MoE orchestrator
474# ---------------------------------------------------------------------------
class MoE(nn.Module):
    """Mixture-of-Experts layer.

    Orchestrates routing, token permutation, expert computation, and output
    scatter-add. Supports shared experts, auxiliary-loss-free load balancing
    via expert bias, node-limited routing, and an auxiliary load-balance loss.

    Args:
        dim: Token embedding dimension.
        hidden_dim: Expert hidden dimension.
        num_experts: Total number of experts.
        top_k: Experts selected per token. Defaults to ``1``.
        score_before_experts: If ``True``, multiply routed tokens by routing
            weights *before* expert computation; otherwise multiply expert
            outputs *after*. Defaults to ``True``.
        load_balance_coeff: When not ``None``, attaches an auxiliary load-
            balance loss as ``output._load_balance_loss``.
        shared_expert: Optional :class:`FeedForward` running on every token
            in parallel; its output is added to the routed-expert output.
        router_kwargs: Extra keyword arguments forwarded to
            :class:`TokenChoiceTopKRouter`.
        use_grouped_mm: If ``True``, uses a hardware-accelerated grouped
            matmul kernel (e.g. ``npu_grouped_matmul``) inside
            :class:`GroupedExperts`. Defaults to ``False``.

    Note:
        *Auxiliary-loss-free load balancing*: call :func:`update_expert_bias`
        once per optimiser step to adjust ``expert_bias`` from the accumulated
        ``tokens_per_expert`` histogram.

    Example::
        >>> moe = MoE(dim=64, hidden_dim=128, num_experts=8, top_k=2)
        >>> out = moe(torch.randn(2, 16, 64))
        >>> out.shape
        torch.Size([2, 16, 64])
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int,
        top_k: int = 1,
        score_before_experts: bool = True,
        load_balance_coeff: Optional[float] = None,
        shared_expert: Optional[FeedForward] = None,
        router_kwargs: Optional[dict] = None,
        use_grouped_mm: bool = False,
    ) -> None:
        super().__init__()
        router_kw = router_kwargs or {}
        self.experts = GroupedExperts(
            dim=dim, hidden_dim=hidden_dim, num_experts=num_experts,
            use_grouped_mm=use_grouped_mm,
        )
        self.router = TokenChoiceTopKRouter(
            dim=dim, num_experts=num_experts, top_k=top_k, **router_kw,
        )
        self.shared_expert = shared_expert
        self.num_experts = num_experts
        self.top_k = top_k
        self.score_before_experts = score_before_experts
        self.load_balance_coeff = load_balance_coeff

        # Auxiliary-loss-free load-balance buffers (no gradient).
        self.register_buffer("expert_bias", torch.zeros(num_experts))
        self.register_buffer("tokens_per_expert", torch.zeros(num_experts))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MoE layer.

        Args:
            x: Input tensor of shape ``[batch, seq_len, dim]``.

        Returns:
            Output tensor of shape ``[batch, seq_len, dim]``. When
            ``load_balance_coeff`` is set, carries a ``_load_balance_loss``
            attribute with the auxiliary loss scalar.
        """
        # Extract local tensor if input arrives as DTensor (TP path).
        if isinstance(x, DTensor):
            x = x.to_local()

        bs, seq_len, dim = x.shape
        num_tokens = bs * seq_len
        x_flat = x.view(num_tokens, dim)  # [num_tokens, dim]

        # --- Routing ---
        # token_counts is already bincount(selected_experts.flatten(),
        # minlength=num_experts), computed inside the router; reuse it below
        # instead of running a second, identical torch.bincount per forward.
        top_scores, selected_experts, token_counts = self.router(
            x_flat, self.expert_bias
        )

        # Accumulate the load histogram without creating gradient nodes.
        with torch.no_grad():
            self.tokens_per_expert.add_(token_counts.float())

        # --- Token permutation: token-major → expert-major ---
        # flat_experts[i] is the expert ID for the i-th (token, top_k) slot;
        # a stable argsort over it yields the expert-major permutation.
        flat_experts = selected_experts.flatten()  # [num_tokens * top_k]
        flat_indices = flat_experts.argsort(stable=True)
        top_scores_sorted = top_scores.flatten()[flat_indices]  # [num_tokens * top_k]
        # Each entry of flat_indices is a position in [0, num_tokens * top_k);
        # dividing by top_k recovers the original token row index.
        token_indices = flat_indices // self.top_k  # [num_tokens * top_k]

        # Gather routed tokens in expert-major order.
        routed_x = x_flat[token_indices]  # [num_tokens * top_k, dim]
        if self.score_before_experts:
            routed_x = routed_x * top_scores_sorted.unsqueeze(1)

        # --- Shared expert (parallel with routed experts) ---
        shared_out = None
        if self.shared_expert is not None:
            shared_out = self.shared_expert(x_flat)

        # --- Expert computation ---
        expert_out = self.experts(routed_x, token_counts)

        if not self.score_before_experts:
            expert_out = expert_out * top_scores_sorted.unsqueeze(1)

        # --- Scatter expert outputs back to token order ---
        # Use out-of-place ``scatter_add`` so autograd correctly records
        # ``ScatterAddBackward``; ``new_zeros + scatter_add_`` on some
        # backends (torch_npu) leaves the leaf un-upgraded and the result
        # without a ``grad_fn``.
        out = torch.zeros(
            num_tokens, dim, dtype=x_flat.dtype, device=x_flat.device,
        ).scatter_add(
            0,
            token_indices.unsqueeze(1).expand(-1, dim),
            expert_out,
        )

        if shared_out is not None:
            out = out + shared_out

        result = out.view(bs, seq_len, dim)

        # Auxiliary load-balance loss attached to the returned tensor.
        if self.load_balance_coeff is not None:
            lb_loss = self.load_balance_coeff * _compute_load_balance_loss(
                top_scores, selected_experts, self.num_experts
            )
            result._load_balance_loss = lb_loss  # pylint: disable=protected-access

        return result
629# ---------------------------------------------------------------------------
630# Expert bias update for auxiliary-loss-free load balancing
631# ---------------------------------------------------------------------------
def update_expert_bias(moe: "MoE", lr: float = 1e-3) -> None:
    """Update expert bias for auxiliary-loss-free load balancing.

    Should be called once per training step, after the optimiser step.
    Nudges ``moe.expert_bias`` by ``±lr`` per expert — up for under-loaded
    experts, down for over-loaded ones — then resets the
    ``tokens_per_expert`` accumulator.

    Args:
        moe: The :class:`MoE` module whose bias should be updated.
        lr: Step size for the bias update. Defaults to ``1e-3``.

    Example::
        >>> # After optimizer.step():
        >>> update_expert_bias(moe_layer, lr=1e-3)
    """
    with torch.no_grad():
        counts = moe.tokens_per_expert.float()
        # sign(avg - count) is +1 for under-loaded experts, -1 for
        # over-loaded ones. Mutate via in-place add under no_grad rather
        # than the discouraged ``.data +=`` idiom, which can mask
        # autograd errors.
        moe.expert_bias.add_(lr * (counts.mean() - counts).sign())
        moe.tokens_per_expert.zero_()