Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / torch / common / moe.py: 85%

172 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""PyTorch MoE building blocks: router, experts and orchestrator. 

16 

17Provides :class:`FeedForward`, :class:`GroupedExperts`, 

18:class:`TokenChoiceTopKRouter`, and the top-level :class:`MoE` orchestrator. 

19Load-balancing utilities (expert bias update and auxiliary loss) are also 

20included. 

21 

22All modules compose naturally with EP / TP parallel strategies via DTensor. 

23Distributed collectives are handled by 

24:mod:`hyper_parallel.core.expert_parallel.expert_parallel`; this module 

25contains only single-device computation. 

26""" 

27import math 

28from typing import Optional 

29 

30import torch 

31from torch import nn 

32import torch.nn.functional as F 

33 

34from hyper_parallel.core.dtensor.dtensor import DTensor 

35 

36__all__ = [ 

37 "FeedForward", 

38 "GroupedExperts", 

39 "TokenChoiceTopKRouter", 

40 "MoE", 

41 "update_expert_bias", 

42] 

43 

44 

45# --------------------------------------------------------------------------- 

46# Grouped expert computation kernels 

47# --------------------------------------------------------------------------- 

48 

49def _run_experts_for_loop( 

50 w1: torch.Tensor, 

51 w2: torch.Tensor, 

52 w3: torch.Tensor, 

53 x: torch.Tensor, 

54 num_tokens_per_expert: torch.Tensor, 

55) -> torch.Tensor: 

56 """Run per-expert SwiGLU via a sequential loop (reference path). 

57 

58 Args: 

59 w1: Shape ``[num_experts, hidden_dim, dim]``. 

60 w2: Shape ``[num_experts, dim, hidden_dim]``. 

61 w3: Shape ``[num_experts, hidden_dim, dim]``. 

62 x: Routed tokens in expert-major order, shape 

63 ``[total_routed_tokens, dim]``. 

64 num_tokens_per_expert: 1-D integer tensor of length ``num_experts``. 

65 

66 Returns: 

67 Expert output of shape ``[total_routed_tokens, dim]``. 

68 """ 

69 # Use ``torch.cat`` instead of ``zeros_like + in-place slice assign``. 

70 # On some backends (notably torch_npu) in-place slice assignment onto a 

71 # ``requires_grad=False`` leaf tensor does not reliably upgrade it to a 

72 # non-leaf with a ``grad_fn`` — the forward result may end up with 

73 # ``grad_fn=None`` and downstream ``backward()`` fails with 

74 # "element 0 of tensors does not require grad and does not have a grad_fn". 

75 # 

76 # Drain ``num_tokens_per_expert`` to host **once** via ``.tolist()`` 

77 # rather than calling ``int(n)`` per loop iteration — a single D2H 

78 # copy instead of ``num_local_experts`` separate ones. Per-iter 

79 # ``.item()`` would stall the host between expert kernels and shrink 

80 # the dual-pipe overlap window. 

81 counts_list = num_tokens_per_expert.tolist() 

82 parts = [] 

83 offset = 0 

84 for e, n in enumerate(counts_list): 

85 if n == 0: 

86 continue 

87 x_e = x[offset:offset + n] 

88 h = F.silu(x_e @ w1[e].T) * (x_e @ w3[e].T) 

89 parts.append(h @ w2[e].T) 

90 offset += n 

91 if not parts: 

92 # No routed tokens: return a grad-connected zero (not ``zeros_like``). 

93 return x * 0.0 

94 return torch.cat(parts, dim=0) 

95 

96 

def _run_experts_grouped_mm_gpu(
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: torch.Tensor,
    x: torch.Tensor,
    num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Fused grouped matmul path for NVIDIA GPU using ``torch._grouped_mm``.

    Args:
        w1: Shape ``[num_experts, hidden_dim, dim]``.
        w2: Shape ``[num_experts, dim, hidden_dim]``.
        w3: Shape ``[num_experts, hidden_dim, dim]``.
        x: Shape ``[total_routed_tokens, dim]``.
        num_tokens_per_expert: 1-D integer tensor of length ``num_experts``.

    Returns:
        Expert output of shape ``[total_routed_tokens, dim]``.
    """
    # ``offs`` must be the inclusive cumulative sum over *all* experts
    # (length == num_experts): offs[i] is the END offset of group i.
    # Building it from ``num_tokens_per_expert[:-1]`` would describe only
    # num_experts - 1 groups, silently dropping the last expert's tokens
    # and producing an output with fewer rows than ``x``.
    offs = torch.cumsum(num_tokens_per_expert, dim=0).to(torch.int32)
    # w1/w3 stored as [num_experts, hidden_dim, dim]; grouped_mm expects
    # [num_experts, dim, hidden_dim], so transpose the inner two dims.
    w1_t = w1.transpose(1, 2).contiguous()  # [num_experts, dim, hidden_dim]
    w3_t = w3.transpose(1, 2).contiguous()
    w2_t = w2.transpose(1, 2).contiguous()  # [num_experts, hidden_dim, dim]
    h1 = torch._grouped_mm(x, w1_t, offs=offs)  # pylint: disable=protected-access
    h3 = torch._grouped_mm(x, w3_t, offs=offs)  # pylint: disable=protected-access
    h = F.silu(h1) * h3
    return torch._grouped_mm(h, w2_t, offs=offs)  # pylint: disable=protected-access

127 

128 

def _run_experts_grouped_mm_npu(
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: torch.Tensor,
    x: torch.Tensor,
    num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Fused grouped matmul path for Ascend NPU using ``torch_npu.npu_grouped_matmul``.

    Requires ``torch_npu`` to be installed.

    Args:
        w1: Shape ``[num_experts, hidden_dim, dim]``.
        w2: Shape ``[num_experts, dim, hidden_dim]``.
        w3: Shape ``[num_experts, hidden_dim, dim]``.
        x: Routed tokens in expert-major order, shape
            ``[total_routed_tokens, dim]``.
        num_tokens_per_expert: 1-D integer tensor of length ``num_experts``.

    Returns:
        Expert output of shape ``[total_routed_tokens, dim]``.
    """
    # Deferred import so the module stays importable without torch_npu.
    import torch_npu  # pylint: disable=C0415

    # npu_grouped_matmul computes y = x @ weight (no implicit transpose).
    # Our weight storage is [num_experts, out_dim, in_dim], matching F.linear's
    # convention (weight.T for y = x @ weight.T). Transpose each expert shard
    # so the shapes satisfy: [tokens, in_dim] @ [in_dim, out_dim] = [tokens, out_dim].
    num_experts = w1.shape[0]
    # Single D2H copy of the histogram; ``torch.split`` needs host ints.
    counts = num_tokens_per_expert.tolist()
    # A zero count yields an empty [0, dim] slice for that expert.
    # NOTE(review): assumes npu_grouped_matmul accepts zero-row operands —
    # TODO confirm against the torch_npu op documentation.
    x_list = list(torch.split(x, counts, dim=0))
    # w1, w3: [E, hidden_dim, dim] → transposed per-expert: [dim, hidden_dim]
    # w2: [E, dim, hidden_dim] → transposed per-expert: [hidden_dim, dim]
    w1_list = [w1[e].T.contiguous() for e in range(num_experts)]
    w2_list = [w2[e].T.contiguous() for e in range(num_experts)]
    w3_list = [w3[e].T.contiguous() for e in range(num_experts)]

    # npu_grouped_matmul: multi-multi-multi mode (x[i] @ weight[i]).
    # group_type=-1 selects independent per-expert matmul (no shared axis).
    h1_list = torch_npu.npu_grouped_matmul(x_list, w1_list, group_type=-1)
    h3_list = torch_npu.npu_grouped_matmul(x_list, w3_list, group_type=-1)
    # Per-expert SwiGLU gating between the two grouped projections.
    h_list = [F.silu(h1) * h3 for h1, h3 in zip(h1_list, h3_list)]
    out_list = torch_npu.npu_grouped_matmul(h_list, w2_list, group_type=-1)
    # Re-concatenate in expert-major order, matching the input layout.
    return torch.cat(out_list, dim=0)

172 

173 

174# --------------------------------------------------------------------------- 

175# FeedForward — shared expert / standard SwiGLU FFN 

176# --------------------------------------------------------------------------- 

177 

class FeedForward(nn.Module):
    """SwiGLU feed-forward network, used as a shared (always-active) expert.

    Computes ``w2(silu(w1(x)) * w3(x))``: a gated linear unit where the
    ``w1`` branch is passed through SiLU and gates the ``w3`` branch.

    Args:
        dim: Input embedding dimension.
        hidden_dim: Intermediate hidden dimension.
        bias: Whether the three projections carry a learnable bias.
            Defaults to ``False``.

    Example::
        >>> ff = FeedForward(dim=256, hidden_dim=512)
        >>> out = ff(torch.randn(4, 16, 256))
        >>> out.shape
        torch.Size([4, 16, 256])
    """

    def __init__(self, dim: int, hidden_dim: int, bias: bool = False) -> None:
        super().__init__()
        # Construction order (w1, w2, w3) is kept stable so seeded
        # initialisation reproduces the same weights.
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU feed-forward transform.

        Args:
            x: Input tensor of shape ``(..., dim)``.

        Returns:
            Tensor with the same leading shape and last dimension ``dim``.
        """
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)

211 

212 

213# --------------------------------------------------------------------------- 

214# GroupedExperts 

215# --------------------------------------------------------------------------- 

216 

class GroupedExperts(nn.Module):
    """Batch expert computation with optional grouped matrix-multiply.

    All expert weights live in one 3-D parameter per projection so EP / TP
    sharding strategies can shard the leading (expert) dimension via DTensor.

    Args:
        dim: Token embedding dimension.
        hidden_dim: Expert hidden dimension (SwiGLU intermediate size).
        num_experts: Total number of experts.
        use_grouped_mm: If ``True``, prefer a hardware-accelerated grouped
            matmul kernel (``torch._grouped_mm`` on GPU,
            ``torch_npu.npu_grouped_matmul`` on NPU); falls back to the
            sequential for-loop path when neither device is available.
            Defaults to ``False``.

    Note:
        When the weights are DTensors (e.g. after TP sharding via
        :class:`ExpertTensorParallel`), ``forward`` works on their
        ``.to_local()`` shards.

    Example::
        >>> experts = GroupedExperts(dim=8, hidden_dim=16, num_experts=4)
        >>> x = torch.randn(10, 8)
        >>> counts = torch.tensor([3, 2, 4, 1])
        >>> out = experts(x, counts)
        >>> out.shape
        torch.Size([10, 8])
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int,
        use_grouped_mm: bool = False,
    ) -> None:
        super().__init__()
        # Layout [num_experts, out_dim, in_dim]: per-expert linear is
        # x @ w[e].T, matching F.linear's weight convention.
        self.w1 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        self.num_experts = num_experts
        self.use_grouped_mm = use_grouped_mm
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Kaiming-uniform initialisation, flattened per expert."""
        for param in (self.w1, self.w2, self.w3):
            flat = param.view(param.shape[0], -1)
            nn.init.kaiming_uniform_(flat, a=math.sqrt(5))

    def forward(
        self,
        x: torch.Tensor,
        num_tokens_per_expert: torch.Tensor,
    ) -> torch.Tensor:
        """Run every expert on its assigned token slice.

        Args:
            x: Routed tokens in expert-major order,
                shape ``[total_routed_tokens, dim]``.
            num_tokens_per_expert: 1-D integer tensor of length
                ``num_local_experts`` with the token count per expert.

        Returns:
            Expert output of shape ``[total_routed_tokens, dim]``.
        """
        # Work on local shards when the parameters are DTensors (TP path).
        def _local(param):
            return param.to_local() if isinstance(param, DTensor) else param

        w1, w2, w3 = _local(self.w1), _local(self.w2), _local(self.w3)

        if self.use_grouped_mm:
            # Dispatch to the fused kernel for whichever device is present.
            if hasattr(torch, 'npu') and torch.npu.is_available():
                return _run_experts_grouped_mm_npu(
                    w1, w2, w3, x, num_tokens_per_expert
                )
            if torch.cuda.is_available():
                return _run_experts_grouped_mm_gpu(
                    w1, w2, w3, x, num_tokens_per_expert
                )
        # Reference path: also the fallback when no fused kernel exists.
        return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)

298 

299 

300# --------------------------------------------------------------------------- 

301# TokenChoiceTopKRouter 

302# --------------------------------------------------------------------------- 

303 

class TokenChoiceTopKRouter(nn.Module):
    """Top-K router: each token independently picks its K best experts.

    Args:
        dim: Token embedding dimension (input to the gate).
        num_experts: Total number of experts.
        top_k: Experts selected per token. Defaults to ``1``.
        score_func: Activation applied to gate logits before top-k.
            One of ``"sigmoid"`` (default) or ``"softmax"``.
        num_expert_groups: For node-limited routing, number of expert
            groups. ``None`` disables node-limited routing.
        num_limited_groups: Groups kept in node-limited routing. Required
            when ``num_expert_groups`` is not ``None``.
        route_scale: Scalar multiplier applied to routing scores.
            Defaults to ``1.0``.

    Example::
        >>> router = TokenChoiceTopKRouter(dim=64, num_experts=8, top_k=2)
        >>> scores, indices, counts = router(torch.randn(32, 64))
        >>> scores.shape, indices.shape, counts.shape
        (torch.Size([32, 2]), torch.Size([32, 2]), torch.Size([8]))
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int = 1,
        score_func: str = "sigmoid",
        num_expert_groups: Optional[int] = None,
        num_limited_groups: Optional[int] = None,
        route_scale: float = 1.0,
    ) -> None:
        super().__init__()
        if score_func not in ("sigmoid", "softmax"):
            raise ValueError(
                f"score_func must be 'sigmoid' or 'softmax', got '{score_func}'."
            )
        if num_expert_groups is not None and num_limited_groups is None:
            raise ValueError(
                "num_limited_groups must be set when num_expert_groups is not None."
            )
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.num_experts = num_experts
        self.top_k = top_k
        self.score_func = score_func
        self.num_expert_groups = num_expert_groups
        self.num_limited_groups = num_limited_groups
        self.route_scale = route_scale

    def forward(
        self,
        x: torch.Tensor,
        expert_bias: Optional[torch.Tensor] = None,
    ):
        """Compute routing scores and top-K expert assignments.

        Args:
            x: Token tensor of shape ``[num_tokens, dim]``.
            expert_bias: Optional 1-D tensor of shape ``[num_experts]``
                added to the selection scores only (auxiliary-loss-free
                load balancing); never leaks into the returned
                ``top_scores``.

        Returns:
            Tuple of:

            - ``top_scores``: shape ``[num_tokens, top_k]`` — routing weights.
            - ``selected_experts``: shape ``[num_tokens, top_k]`` — expert IDs.
            - ``num_tokens_per_expert``: shape ``[num_experts]`` — load counts.
        """
        # Gate in float32 for numerical stability.
        logits = self.gate(x).float()
        if self.score_func == "sigmoid":
            scores = torch.sigmoid(logits)
        else:
            scores = F.softmax(logits, dim=-1)

        if self.route_scale != 1.0:
            scores = scores * self.route_scale

        # Selection scores may differ from routing weights: node-limited
        # masking and the load-balance bias affect only which experts win.
        selection_scores = (
            self._get_node_limited_routing_scores(scores)
            if self.num_expert_groups is not None
            else scores
        )
        if expert_bias is not None:
            selection_scores = selection_scores + expert_bias.float()

        _, selected_experts = selection_scores.topk(self.top_k, dim=-1)
        # Routing weights come from the unbiased, unmasked scores.
        top_scores = scores.gather(1, selected_experts)

        num_tokens_per_expert = torch.bincount(
            selected_experts.flatten(), minlength=self.num_experts
        )
        return top_scores, selected_experts, num_tokens_per_expert

    def _get_node_limited_routing_scores(self, scores: torch.Tensor) -> torch.Tensor:
        """Mask experts outside the best-scoring groups (node-limited routing).

        Args:
            scores: Routing scores of shape ``[num_tokens, num_experts]``.

        Returns:
            Scores with non-selected groups masked to ``-inf``.
        """
        num_tokens, num_experts = scores.shape
        group_size = num_experts // self.num_expert_groups

        # Score each group by its best expert.
        grouped = scores.view(num_tokens, self.num_expert_groups, group_size)
        best_per_group = grouped.max(dim=-1).values  # [num_tokens, num_groups]
        keep_groups = best_per_group.topk(self.num_limited_groups, dim=-1).indices

        # Expand the per-group keep mask to per-expert resolution.
        group_mask = scores.new_zeros(num_tokens, self.num_expert_groups)
        group_mask.scatter_(1, keep_groups, 1.0)
        expert_mask = group_mask.repeat_interleave(group_size, dim=1)
        return scores.masked_fill(expert_mask == 0, float("-inf"))

429 

430 

431# --------------------------------------------------------------------------- 

432# Load-balance auxiliary loss 

433# --------------------------------------------------------------------------- 

434 

435def _compute_load_balance_loss( 

436 top_scores: torch.Tensor, 

437 selected_experts: torch.Tensor, 

438 num_experts: int, 

439) -> torch.Tensor: 

440 """Compute load-balance auxiliary loss. 

441 

442 Standard formulation: ``loss = num_experts * Σ fraction_i * mean_score_i`` 

443 

444 Args: 

445 top_scores: Routing weights, shape ``[num_tokens, top_k]``. 

446 selected_experts: Expert IDs, shape ``[num_tokens, top_k]``. 

447 num_experts: Total number of experts. 

448 

449 Returns: 

450 Scalar loss tensor. 

451 """ 

452 num_tokens, top_k = top_scores.shape 

453 flat_experts = selected_experts.flatten() # [num_tokens * top_k] 

454 

455 # Fraction of tokens sent to each expert (soft, uses routing probabilities). 

456 one_hot = torch.zeros( 

457 num_tokens * top_k, num_experts, 

458 dtype=top_scores.dtype, device=top_scores.device, 

459 ).scatter_(1, flat_experts.unsqueeze(1), 1.0) 

460 # [num_tokens, top_k, num_experts] → mean per expert 

461 expert_fraction = one_hot.view(num_tokens, top_k, num_experts).float().mean(dim=(0, 1)) 

462 

463 # Mean routing probability per expert. 

464 prob = F.softmax(top_scores.float(), dim=-1) # [num_tokens, top_k] 

465 mean_score = prob.mean(dim=0) # [top_k] 

466 

467 # Scalar: num_experts * dot(mean_score, expert_fraction.T) 

468 loss = num_experts * (mean_score.unsqueeze(-1) * expert_fraction.unsqueeze(0)).sum() 

469 return loss 

470 

471 

472# --------------------------------------------------------------------------- 

473# MoE orchestrator 

474# --------------------------------------------------------------------------- 

475 

class MoE(nn.Module):
    """Mixture-of-Experts layer.

    Orchestrates routing, token permutation, expert computation, and output
    scatter-add. Supports shared experts, auxiliary-loss-free load balancing
    via expert bias, node-limited routing, and auxiliary load-balance loss.

    Args:
        dim: Token embedding dimension.
        hidden_dim: Expert hidden dimension.
        num_experts: Total number of experts.
        top_k: Experts selected per token. Defaults to ``1``.
        score_before_experts: If ``True``, multiply routed tokens by routing
            weights *before* expert computation; otherwise multiply expert
            outputs *after*. Defaults to ``True``.
        load_balance_coeff: When not ``None``, attaches an auxiliary load-
            balance loss as ``output._load_balance_loss``.
        shared_expert: Optional :class:`FeedForward` running on every token
            in parallel; output added to routed-expert output.
        router_kwargs: Extra keyword arguments forwarded to
            :class:`TokenChoiceTopKRouter`.
        use_grouped_mm: If ``True``, uses a hardware-accelerated grouped
            matmul kernel (e.g. ``npu_grouped_matmul``) inside
            :class:`GroupedExperts`. Defaults to ``False``.

    Note:
        *Auxiliary-loss-free load balancing*: call :func:`update_expert_bias`
        once per optimiser step to adjust ``expert_bias`` from the accumulated
        ``tokens_per_expert`` histogram.

    Example::
        >>> moe = MoE(dim=64, hidden_dim=128, num_experts=8, top_k=2)
        >>> out = moe(torch.randn(2, 16, 64))
        >>> out.shape
        torch.Size([2, 16, 64])
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int,
        top_k: int = 1,
        score_before_experts: bool = True,
        load_balance_coeff: Optional[float] = None,
        shared_expert: Optional[FeedForward] = None,
        router_kwargs: Optional[dict] = None,
        use_grouped_mm: bool = False,
    ) -> None:
        super().__init__()
        router_kw = router_kwargs or {}
        self.experts = GroupedExperts(
            dim=dim, hidden_dim=hidden_dim, num_experts=num_experts,
            use_grouped_mm=use_grouped_mm,
        )
        self.router = TokenChoiceTopKRouter(
            dim=dim, num_experts=num_experts, top_k=top_k, **router_kw,
        )
        self.shared_expert = shared_expert
        self.num_experts = num_experts
        self.top_k = top_k
        self.score_before_experts = score_before_experts
        self.load_balance_coeff = load_balance_coeff

        # Auxiliary-loss-free load-balance buffers (no gradient).
        self.register_buffer("expert_bias", torch.zeros(num_experts))
        self.register_buffer("tokens_per_expert", torch.zeros(num_experts))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MoE layer.

        Args:
            x: Input tensor of shape ``[batch, seq_len, dim]``.

        Returns:
            Output tensor of shape ``[batch, seq_len, dim]``. When
            ``load_balance_coeff`` is set, carries a ``_load_balance_loss``
            attribute with the auxiliary loss scalar.
        """
        # Extract local tensor if input arrives as DTensor (TP path).
        if isinstance(x, DTensor):
            x = x.to_local()

        bs, seq_len, dim = x.shape
        num_tokens = bs * seq_len
        x_flat = x.view(num_tokens, dim)  # [num_tokens, dim]

        # --- Routing ---
        # The router already returns the per-expert histogram
        # (bincount of ``selected_experts`` with minlength=num_experts),
        # so we reuse it instead of recomputing the identical bincount below.
        top_scores, selected_experts, num_tokens_per_expert = self.router(
            x_flat, self.expert_bias
        )

        # Accumulate token histogram without creating gradient nodes.
        with torch.no_grad():
            self.tokens_per_expert.add_(num_tokens_per_expert.float())

        # --- Token permutation: token-major → expert-major (inline argsort) ---
        # flat_experts[i] is the expert ID for the i-th (token, top_k) slot.
        flat_experts = selected_experts.flatten()  # [num_tokens * top_k]
        flat_indices = flat_experts.argsort(stable=True)  # expert-major permutation
        top_scores_sorted = top_scores.flatten()[flat_indices]  # [num_tokens * top_k]
        # Each entry in flat_indices maps to a position in [0, num_tokens * top_k);
        # divide by top_k to recover the original token row index.
        token_indices = flat_indices // self.top_k  # [num_tokens * top_k]

        # Gather routed tokens in expert-major order.
        routed_x = x_flat[token_indices]  # [num_tokens * top_k, dim]

        if self.score_before_experts:
            routed_x = routed_x * top_scores_sorted.unsqueeze(1)

        # --- Shared expert (parallel with routed experts) ---
        shared_out = None
        if self.shared_expert is not None:
            shared_out = self.shared_expert(x_flat)

        # --- Expert computation ---
        expert_out = self.experts(routed_x, num_tokens_per_expert)

        if not self.score_before_experts:
            expert_out = expert_out * top_scores_sorted.unsqueeze(1)

        # --- Scatter expert outputs back to token order ---
        # Use out-of-place ``scatter_add`` so autograd correctly records
        # ``ScatterAddBackward``; ``new_zeros + scatter_add_`` on some
        # backends (torch_npu) leaves the leaf un-upgraded and the result
        # without a ``grad_fn``.
        out = torch.zeros(
            num_tokens, dim, dtype=x_flat.dtype, device=x_flat.device,
        ).scatter_add(
            0,
            token_indices.unsqueeze(1).expand(-1, dim),
            expert_out,
        )

        if shared_out is not None:
            out = out + shared_out

        result = out.view(bs, seq_len, dim)

        # Auxiliary load-balance loss attached to the returned tensor.
        if self.load_balance_coeff is not None:
            lb_loss = self.load_balance_coeff * _compute_load_balance_loss(
                top_scores, selected_experts, self.num_experts
            )
            result._load_balance_loss = lb_loss  # pylint: disable=protected-access

        return result

627 

628 

629# --------------------------------------------------------------------------- 

630# Expert bias update for auxiliary-loss-free load balancing 

631# --------------------------------------------------------------------------- 

632 

def update_expert_bias(moe: MoE, lr: float = 1e-3) -> None:
    """Nudge ``moe.expert_bias`` towards balanced expert load.

    Implements the auxiliary-loss-free balancing step: experts that
    received fewer tokens than average get their bias raised (attracting
    more tokens next step), overloaded experts get it lowered. Call once
    per training step after the optimiser step; the ``tokens_per_expert``
    accumulator is reset afterwards.

    Args:
        moe: The :class:`MoE` module whose bias should be updated.
        lr: Step size for the bias update. Defaults to ``1e-3``.

    Example::
        >>> # After optimizer.step():
        >>> update_expert_bias(moe_layer, lr=1e-3)
    """
    with torch.no_grad():
        load = moe.tokens_per_expert.float()
        # sign(mean - load): +1 for under-loaded experts, -1 for over-loaded.
        moe.expert_bias.data.add_(lr * (load.mean() - load).sign())
        moe.tokens_per_expert.zero_()