Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/expert_parallel/expert_parallel.py: 95%

112 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Expert Parallelism distributed strategies. 

16 

17Provides token permutation helpers and four parallel styles that compose with 

18:class:`~hyper_parallel.core.expert_parallel.moe.GroupedExperts`: 

19 

20- :class:`BaseExpertParallel` — abstract base for EP strategies with 

21 all-to-all token dispatch/combine. 

22- :class:`ExpertParallel` — standard EP: each rank owns a shard of experts; 

23 tokens are routed via differentiable all-to-all. 

24- :class:`TensorParallel` — TP-only weight sharding for experts with no token 

25 dispatch; for use when EP degree = 1. 

26- :class:`ExpertTensorParallel` — combined EP + TP on a 2-D mesh ``[ep, tp]``; 

27 weights are doubly sharded, dispatch uses the EP sub-mesh. 

28""" 

29from abc import ABC, abstractmethod 

30from typing import Optional 

31 

32from hyper_parallel.core.dtensor.device_mesh import DeviceMesh 

33from hyper_parallel.core.dtensor.dtensor import ( 

34 distribute_module, 

35 distribute_tensor, 

36 _distribute_module_iter_params, 

37 _distribute_module_new_parameter, 

38 _distribute_module_param_source, 

39 _distribute_module_set_param, 

40) 

41from hyper_parallel.core.dtensor.placement_types import Shard 

42from hyper_parallel.core.tensor_parallel.style import ParallelStyle 

43from hyper_parallel.platform import get_platform 

44 

45platform = get_platform() 

46Module = platform.Module 

47 

48__all__ = [ 

49 "BaseExpertParallel", 

50 "ExpertParallel", 

51 "TensorParallel", 

52 "ExpertTensorParallel", 

53] 

54 

55 

56# --------------------------------------------------------------------------- 

57# Token permutation helpers 

58# --------------------------------------------------------------------------- 

59 

60def _generate_permute_indices( 

61 tokens_per_expert_group, 

62 experts_per_rank: int, 

63 num_ranks: int, 

64): 

65 """Generate permutation indices for rank-major → expert-major reordering. 

66 

67 After all-to-all, received tokens are laid out in rank-major order:: 

68 

69 [rank0·expert0 tokens | rank0·expert1 tokens | ... | 

70 rank1·expert0 tokens | rank1·expert1 tokens | ...] 

71 

72 Expert computation requires expert-major order:: 

73 

74 [all tokens for local expert 0 | all tokens for local expert 1 | ...] 

75 

76 Args: 

77 tokens_per_expert_group: 1-D integer tensor of shape 

78 ``[num_ranks * experts_per_rank]``. Entry ``[r * E + e]`` is the 

79 number of tokens received from rank ``r`` for local expert ``e``. 

80 experts_per_rank: Number of experts owned by each rank. 

81 num_ranks: EP degree (total number of ranks in the EP group). 

82 

83 Returns: 

84 Tuple of: 

85 

86 - ``permuted_indices``: 1-D long tensor of length 

87 ``total_received_tokens``. ``permuted_indices[i]`` is the source 

88 position in the rank-major buffer for destination position ``i`` in 

89 the expert-major buffer. 

90 - ``num_tokens_per_expert``: 1-D integer tensor of length 

91 ``experts_per_rank`` with the token count per local expert. 

92 """ 

93 counts = tokens_per_expert_group # [num_ranks * experts_per_rank] 

94 

95 # num_tokens_per_expert[e] = Σ_r counts[r * E + e] 

96 counts_2d = counts.view(num_ranks, experts_per_rank) # [R, E] 

97 num_tokens_per_expert = counts_2d.sum(dim=0) # [E] 

98 

99 # ``total`` must be a host int because ``arange`` needs a scalar size. 

100 # That single D2H drain is unavoidable. Everything else stays on 

101 # device — no per-block ``.item()`` in a loop. 

102 total = int(num_tokens_per_expert.sum()) 

103 if total == 0: 

104 return counts.new_zeros(0, dtype=counts.dtype), num_tokens_per_expert 

105 

106 # ---- Vectorized expert-major permutation, no host stalls ----------- 

107 # Source offsets in the rank-major receive buffer for each (r, e) block. 

108 src_offsets_rm = counts.cumsum(0) - counts # [R*E], starts of each block 

109 # Reorder src offsets to expert-major iteration order: block (e, r). 

110 src_offsets_em = ( 

111 src_offsets_rm.view(num_ranks, experts_per_rank).T.contiguous().view(-1) 

112 ) # [E*R] 

113 # Counts in expert-major iteration order. 

114 counts_em = counts_2d.T.contiguous().view(-1) # [E*R] 

115 

116 # ``repeat_interleave`` expands each block's src start to one entry per 

117 # token in that block — gives the source position of each output token. 

118 block_src_starts = src_offsets_em.repeat_interleave(counts_em) # [total] 

119 

120 # Destination block starts in expert-major order, then expanded. The 

121 # ``arange(total) - dst_block_starts_per_token`` produces 0..n-1 within 

122 # each block, i.e. the intra-block offset. 

123 dst_block_starts = counts_em.cumsum(0) - counts_em # [E*R] 

124 dst_block_starts_per_token = dst_block_starts.repeat_interleave(counts_em) 

125 intra = platform.arange(0, total, device=counts.device) - dst_block_starts_per_token 

126 

127 permuted_indices = (block_src_starts + intra).long() 

128 return permuted_indices, num_tokens_per_expert 

129 

130 
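# --- Illustration (not part of the library): a pure-Python sketch of the same
# rank-major -> expert-major reordering, under assumed sizes of 2 ranks and
# 2 local experts. It mirrors the vectorized logic above with plain lists so
# the index math can be checked by hand; the helper name is hypothetical.
def _permute_indices_reference(counts, experts_per_rank, num_ranks):
    starts, acc = [], 0
    for c in counts:                          # rank-major block start offsets
        starts.append(acc)
        acc += c
    permuted = []
    for e in range(experts_per_rank):         # expert-major iteration order
        for r in range(num_ranks):
            s = starts[r * experts_per_rank + e]
            permuted.extend(range(s, s + counts[r * experts_per_rank + e]))
    return permuted

# counts = [2, 1, 1, 2]: rank0 sends 2/1 tokens for experts 0/1, rank1 sends 1/2.
# Rank-major positions [0,1 | 2 | 3 | 4,5] become expert-major [0, 1, 3, 2, 4, 5].
assert _permute_indices_reference([2, 1, 1, 2], 2, 2) == [0, 1, 3, 2, 4, 5]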

131def _permute(x, tokens_per_expert_group, ep_degree: int, num_local_experts: int): 

132 """Apply rank-major → expert-major permutation to routed tokens. 

133 

134 Args: 

135 x: Received token tensor of shape 

136 ``[sum(tokens_per_expert_group), *feature_dims]``. 

137 tokens_per_expert_group: 1-D integer tensor of shape 

138 ``[ep_degree * num_local_experts]`` (output of the first 

139 all-to-all that exchanges token counts). 

140 ep_degree: EP group size (number of ranks). 

141 num_local_experts: Number of experts owned by this rank. 

142 

143 Returns: 

144 Tuple of: 

145 

146 - ``original_shape``: shape of *x* before permutation. 

147 - ``permuted_x``: tokens reordered to expert-major layout. 

148 - ``permuted_indices``: permutation indices (needed for 

149 :func:`_unpermute`). 

150 - ``num_tokens_per_expert``: token count per local expert. 

151 """ 

152 original_shape = x.shape 

153 permuted_indices, num_tokens_per_expert = _generate_permute_indices( 

154 tokens_per_expert_group, num_local_experts, ep_degree 

155 ) 

156 # ``x[permuted_indices]`` works for empty indices too (returns a 

157 # shape-0 tensor with a real grad_fn). Avoid the early-return with 

158 # ``new_zeros`` which would produce a leaf tensor without grad_fn and 

159 # silently break autograd for ranks that happen to receive zero tokens. 

160 permuted_x = x[permuted_indices] 

161 return original_shape, permuted_x, permuted_indices, num_tokens_per_expert 

162 

163 

164def _unpermute(out, original_shape, permuted_indices): 

165 """Reverse the permutation applied by :func:`_permute`. 

166 

167 Args: 

168 out: Expert-major output tensor of shape 

169 ``[sum(num_tokens_per_expert), *feature_dims]``. 

170 original_shape: Shape before permutation (from :func:`_permute`). 

171 permuted_indices: Permutation indices from :func:`_permute`. 

172 

173 Returns: 

174 Token tensor restored to the rank-major layout received after 

175 all-to-all, with shape ``original_shape``. 

176 """ 

177 # ``result[permuted_indices] = out`` is a differentiable scatter that 

178 # also handles the empty-index case (no-op assignment, but autograd 

179 # still connects ``result`` back to ``out``). Do NOT short-circuit 

180 # with a bare ``new_zeros`` — that returns a leaf tensor without 

181 # grad_fn and the downstream combine a2a loses its backward path, 

182 # which manifests as "element 0 of tensors does not require grad". 

183 result = out.new_zeros(*original_shape) 

184 result[permuted_indices] = out 

185 return result 

186 

187 
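# --- Illustration (not part of the library): a round-trip check of _permute /
# _unpermute, assuming a torch-backed platform so a plain torch tensor can
# stand in for the routed tokens. Restoring the permuted output must give back
# the original rank-major buffer exactly.
import torch

_x = torch.arange(12.0).reshape(6, 2)            # 6 received tokens, dim 2
_counts = torch.tensor([2, 1, 1, 2])             # [ep_degree * num_local_experts]
_shape, _permuted, _idx, _per_expert = _permute(_x, _counts, 2, 2)
assert torch.equal(_unpermute(_permuted, _shape, _idx), _x)
assert _per_expert.tolist() == [3, 3]            # expert 0: 2 + 1, expert 1: 1 + 2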

188# --------------------------------------------------------------------------- 

189# BaseExpertParallel — abstract base for all-to-all EP strategies 

190# --------------------------------------------------------------------------- 

191 

192class BaseExpertParallel(ParallelStyle, ABC): 

193 """Abstract base class for Expert Parallel strategies with token dispatch. 

194 

195 Subclasses implement :meth:`_partition_fn`, :meth:`_token_dispatch`, and 

196 :meth:`_token_combine`; this class wires them into :func:`distribute_module`. 

197 """ 

198 

199 def apply(self, module: Module, device_mesh: DeviceMesh) -> Module: 

200 """Apply EP sharding and dispatch/combine hooks to *module*. 

201 

202 Args: 

203 module: A :class:`~hyper_parallel.core.expert_parallel.moe.GroupedExperts` 

204 instance to shard. 

205 device_mesh: Device mesh for this EP strategy. 

206 

207 Returns: 

208 The module with distributed parameters and dispatch/combine hooks. 

209 """ 

210 return distribute_module( 

211 module, 

212 device_mesh, 

213 self._partition_fn, 

214 self._token_dispatch, 

215 self._token_combine, 

216 ) 

217 

218 @abstractmethod 

219 def _partition_fn( 

220 self, name: str, module: Module, device_mesh: DeviceMesh 

221 ) -> None: 

222 """Shard module parameters according to this strategy. 

223 

224 Args: 

225 name: Submodule name. 

226 module: The module whose parameters are being sharded. 

227 device_mesh: Device mesh for this EP strategy. 

228 """ 

229 

230 @abstractmethod 

231 def _token_dispatch(self, module: Module, inputs, device_mesh: DeviceMesh): 

232 """Pre-hook: route input tokens to their assigned ranks. 

233 

234 Args: 

235 module: The ``GroupedExperts`` module. 

236 inputs: Forward inputs tuple. 

237 device_mesh: Device mesh for this EP strategy. 

238 

239 Returns: 

240 Transformed inputs for local expert computation. 

241 """ 

242 

243 @abstractmethod 

244 def _token_combine(self, module: Module, routed_output, device_mesh: DeviceMesh): 

245 """Post-hook: gather expert outputs back to the originating ranks. 

246 

247 Args: 

248 module: The ``GroupedExperts`` module. 

249 routed_output: Expert output tensor in expert-major order. 

250 device_mesh: Device mesh for this EP strategy. 

251 

252 Returns: 

253 Token tensor in the original token-major layout. 

254 """ 

255 

256 
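# --- Illustration (not part of the library): the minimal shape of a concrete
# strategy. The class and its no-op hook bodies are hypothetical placeholders;
# real strategies implement the three hooks the way ExpertParallel does below.
class _NoOpExpertParallel(BaseExpertParallel):
    def _partition_fn(self, name, module, device_mesh):
        pass                                    # leave parameters unsharded

    def _token_dispatch(self, module, inputs, device_mesh):
        return inputs                           # no routing; experts see raw inputs

    def _token_combine(self, module, routed_output, device_mesh):
        return routed_output                    # outputs pass through unchanged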

257# --------------------------------------------------------------------------- 

258# ExpertParallel — standard all-to-all EP 

259# --------------------------------------------------------------------------- 

260 

261class ExpertParallel(BaseExpertParallel): 

262 """Expert Parallel: shard experts across ranks via all-to-all token routing. 

263 

264 Applies :meth:`apply` to a :class:`GroupedExperts` module: 

265 

266 1. **Partition** — distributes expert weights on dim 0 (``Shard(0)``) so 

267 each rank holds ``num_experts // ep_degree`` local experts. 

268 2. **Token dispatch** (forward pre-hook) — two-step all-to-all: 

269 a. Exchange token counts (non-differentiable). 

270 b. Exchange actual tokens (differentiable, gradient flows back). 

271 Followed by rank-major → expert-major permutation. 

272 3. **Token combine** (forward post-hook) — expert-major → rank-major 

273 unpermute, then reverse all-to-all (differentiable). 

274 

275 All collectives use ``platform.differentiable_all_to_all_single`` / 

276 ``platform.all_to_all_single`` — no direct ``torch.distributed`` calls. 

277 

278 Args: 

279 None 

280 

281 Example:: 

282 >>> ep_style = ExpertParallel() 

283 >>> sharded_experts = ep_style.apply(experts_module, ep_device_mesh) 

284 """ 

285 

286 def __init__(self) -> None: 

287 # State saved between _token_dispatch and _token_combine within one 

288 # forward pass. Safe for standard (non-pipeline) training. 

289 self._input_splits: list = [] 

290 self._output_splits: list = [] 

291 self._input_shape: Optional[tuple] = None 

292 self._permuted_indices = None 

293 

294 def _partition_fn( 

295 self, name: str, module: Module, device_mesh: DeviceMesh 

296 ) -> None: 

297 """Shard all expert parameters along dim 0 (expert dimension). 

298 

299 Args: 

300 name: Submodule name (unused). 

301 module: The module whose parameters are being sharded. 

302 device_mesh: EP device mesh. 

303 """ 

304 del name 

305 for key, param in _distribute_module_iter_params(module): 

306 if param is None: 

307 continue 

308 src = _distribute_module_param_source(param) 

309 requires_grad = bool(getattr(param, "requires_grad", True)) 

310 dt = distribute_tensor(src, device_mesh, [Shard(0)]) 

311 new_param = _distribute_module_new_parameter(key, dt, requires_grad) 

312 _distribute_module_set_param(module, key, new_param) 

313 

314 def _token_dispatch(self, module: Module, inputs, device_mesh: DeviceMesh): 

315 """Dispatch tokens to their assigned ranks via all-to-all. 

316 

317 Called as an ``input_fn`` hook by ``distribute_module``. Receives the 

318 module's forward inputs and returns transformed inputs. 

319 

320 Args: 

321 module: The ``GroupedExperts`` module (unused here). 

322 inputs: Tuple ``(routed_input, num_tokens_per_expert)`` where 

323 ``routed_input`` has shape ``[total_tokens, dim]`` and 

324 ``num_tokens_per_expert`` has shape ``[num_experts]``. 

325 device_mesh: EP device mesh (1-D). 

326 

327 Returns: 

328 Tuple ``(permuted_local_input, local_token_counts)`` ready for 

329 local expert computation. 

330 """ 

331 del module 

332 routed_input, num_tokens_per_expert = inputs[0], inputs[1] 

333 ep_group = device_mesh.get_group() 

334 ep_size = device_mesh.size() 

335 num_local_experts = num_tokens_per_expert.shape[0] // ep_size 

336 

337 # --- Step 1: exchange token counts (no gradient needed) --- 

338 # Each rank needs to know how many tokens it will receive from every 

339 # other rank (for each local expert). Uses ``async_op=True`` + an 

340 # explicit ``handle.wait()`` rather than ``async_op=False`` because 

341 # the implicit cross-stream sync is NCCL-only; on HCCL the compute 

342 # stream may read ``counts_out`` before the collective write is 

343 # visible, producing garbage values that blow up the downstream 

344 # ``torch.empty(sum(output_splits), ...)`` allocation. 

345 counts_out, handle = platform.all_to_all_single( 

346 num_tokens_per_expert, 

347 output_shape=[num_tokens_per_expert.shape[0]], 

348 group=ep_group, 

349 async_op=True, 

350 ) 

351 if handle is not None: 

352 handle.wait() 

353 # counts_out shape: [ep_size * num_local_experts] 

354 # counts_out[r * num_local_experts + e] = tokens from rank r for expert e 

355 

356 # --- Step 2: compute input / output splits --- 

357 # input_splits[r] = tokens this rank sends to rank r 

358 # output_splits[r] = tokens this rank receives from rank r 

359 # Reshape to [ep_size, num_local_experts] and sum per rank on device; 

360 # a single ``tolist()`` drains the rank-sum vector to host, replacing 

361 # ``2 * ep_size`` scalar ``int()`` D2H syncs with 2. 

362 input_splits = num_tokens_per_expert.view(ep_size, num_local_experts).sum(dim=1).tolist() 

363 output_splits = counts_out.view(ep_size, num_local_experts).sum(dim=1).tolist() 

364 self._input_splits = input_splits 

365 self._output_splits = output_splits 

366 

367 # --- Step 3: exchange actual tokens (differentiable) --- 

368 dispatched = platform.differentiable_all_to_all_single( 

369 routed_input, input_splits, output_splits, group=ep_group, 

370 ) 

371 

372 # --- Step 4: rank-major → expert-major permutation --- 

373 self._input_shape, permuted, self._permuted_indices, local_counts = _permute( 

374 dispatched, counts_out, ep_size, num_local_experts 

375 ) 

376 return permuted, local_counts 

377 

378 def _token_combine(self, module: Module, routed_output, device_mesh: DeviceMesh): 

379 """Gather expert outputs back to the originating ranks via all-to-all. 

380 

381 Called as an ``output_fn`` hook by ``distribute_module``. 

382 

383 Args: 

384 module: The ``GroupedExperts`` module (unused). 

385 routed_output: Expert output tensor in expert-major order, 

386 shape ``[sum(local_counts), dim]``. 

387 device_mesh: EP device mesh (1-D). 

388 

389 Returns: 

390 Token tensor in the original token-major layout, 

391 shape ``[sum(input_splits), dim]``. 

392 """ 

393 del module 

394 ep_group = device_mesh.get_group() 

395 

396 # expert-major → rank-major 

397 unpermuted = _unpermute(routed_output, self._input_shape, self._permuted_indices) 

398 

399 # reverse all-to-all (output/input splits are swapped) 

400 combined = platform.differentiable_all_to_all_single( 

401 unpermuted, 

402 self._output_splits, # was output, now becomes input 

403 self._input_splits, # was input, now becomes output 

404 group=ep_group, 

405 ) 

406 return combined 

407 

408 
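# --- Illustration (not part of the library): the split arithmetic from Step 2
# of _token_dispatch, assuming a torch-backed platform, ep_size=2 and
# num_local_experts=2. The token counts are made up for the example.
import torch

_send = torch.tensor([4, 1, 2, 3])      # this rank's tokens per global expert
_recv = torch.tensor([4, 1, 6, 2])      # counts_out: per (source rank, local expert)
_input_splits = _send.view(2, 2).sum(dim=1).tolist()
_output_splits = _recv.view(2, 2).sum(dim=1).tolist()
assert _input_splits == [5, 5]           # send 5 tokens to rank 0 and 5 to rank 1
assert _output_splits == [5, 8]          # receive 5 from rank 0 and 8 from rank 1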

409# --------------------------------------------------------------------------- 

410# TensorParallel — TP-only weight sharding for experts (no token dispatch) 

411# --------------------------------------------------------------------------- 

412 

413class TensorParallel(ParallelStyle): 

414 """Tensor Parallel for expert weights (no token dispatch). 

415 

416 Shards the ``GroupedExperts`` weight tensors in the column/row-wise 

417 pattern used by standard TP: 

418 

419 - ``w1`` / ``w3``: ``Shard(1)`` — column-wise (splits ``hidden_dim``). 

420 - ``w2``: ``Shard(2)`` — row-wise (splits the input dim, which is ``hidden_dim``). 

421 

422 Use this when EP degree is 1 and you want TP across experts without 

423 any all-to-all token dispatch. Typically combined with the standard 

424 :class:`~hyper_parallel.core.tensor_parallel.style.ColwiseParallel` / 

425 :class:`~hyper_parallel.core.tensor_parallel.style.RowwiseParallel` 

426 pattern for attention layers. 

427 

428 Example:: 

429 >>> tp_style = TensorParallel() 

430 >>> sharded_experts = tp_style.apply(experts_module, tp_device_mesh) 

431 """ 

432 

433 def apply(self, module: Module, device_mesh: DeviceMesh) -> Module: 

434 """Apply TP weight sharding to *module*. 

435 

436 Args: 

437 module: A :class:`GroupedExperts` instance. 

438 device_mesh: 1-D TP device mesh (``mesh_dim_names=("tp",)``). 

439 

440 Returns: 

441 The module with TP-sharded expert parameters. 

442 """ 

443 return distribute_module( 

444 module, 

445 device_mesh, 

446 self._partition_fn, 

447 ) 

448 

449 def _partition_fn( 

450 self, name: str, module: Module, device_mesh: DeviceMesh 

451 ) -> None: 

452 """Shard expert weights column-wise (w1/w3) or row-wise (w2). 

453 

454 ``GroupedExperts`` weight layout is ``[num_experts, out_dim, in_dim]`` 

455 so: 

456 

457 - ``w1``/``w3``: shard ``Shard(1)`` → split ``hidden_dim`` 

458 (column-wise analogue). 

459 - ``w2``: shard ``Shard(2)`` → split ``in_dim = hidden_dim`` 

460 (row-wise analogue). 

461 

462 Args: 

463 name: Submodule name (unused). 

464 module: The module whose parameters are being sharded. 

465 device_mesh: TP device mesh. 

466 """ 

467 del name 

468 for key, param in _distribute_module_iter_params(module): 

469 if param is None: 

470 continue 

471 src = _distribute_module_param_source(param) 

472 requires_grad = bool(getattr(param, "requires_grad", True)) 

473 # w1, w3: column-wise → Shard(1); w2: row-wise → Shard(2). 

474 shard_dim = 2 if key == "w2" else 1 

475 dt = distribute_tensor(src, device_mesh, [Shard(shard_dim)]) 

476 new_param = _distribute_module_new_parameter(key, dt, requires_grad) 

477 _distribute_module_set_param(module, key, new_param) 

478 

479 
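# --- Illustration (not part of the library): local shard shapes under this TP
# partitioning, with assumed sizes num_experts=8, dim=1024, hidden_dim=4096 and
# a TP degree of 4. Shapes follow the [num_experts, out_dim, in_dim] layout.
_tp = 4
_full = {
    "w1": (8, 4096, 1024),   # Shard(1): hidden_dim 4096 -> 1024 per TP rank
    "w3": (8, 4096, 1024),   # Shard(1): hidden_dim 4096 -> 1024 per TP rank
    "w2": (8, 1024, 4096),   # Shard(2): in_dim 4096 -> 1024 per TP rank
}
_local = {
    k: tuple(d // _tp if i == (2 if k == "w2" else 1) else d for i, d in enumerate(s))
    for k, s in _full.items()
}
assert all(shape == (8, 1024, 1024) for shape in _local.values())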

480# --------------------------------------------------------------------------- 

481# ExpertTensorParallel — combined EP + TP on a 2-D [ep, tp] mesh 

482# --------------------------------------------------------------------------- 

483 

484class ExpertTensorParallel(ExpertParallel): 

485 """Combined Expert + Tensor Parallel on a 2-D ``[ep, tp]`` device mesh. 

486 

487 Extends :class:`ExpertParallel` to operate on a 2-D mesh with named 

488 dimensions ``"ep"`` and ``"tp"``: 

489 

490 - **Partition**: each expert weight ``[num_experts, out, in]`` is doubly 

491 sharded — ``Shard(0)`` along the EP dim (expert ownership) and 

492 ``Shard(1)``/``Shard(2)`` along the TP dim (column-wise / row-wise). 

493 - **Dispatch / Combine**: use only the 1-D ``device_mesh["ep"]`` sub-mesh 

494 so that token routing uses EP-group collectives, not the full 2-D mesh. 

495 

496 Args: 

497 None 

498 

499 Example:: 

500 >>> etp_style = ExpertTensorParallel() 

501 >>> sharded = etp_style.apply(experts_module, ep_tp_2d_mesh) 

502 """ 

503 

504 def _partition_fn( 

505 self, name: str, module: Module, device_mesh: DeviceMesh 

506 ) -> None: 

507 """Shard expert weights along both EP (dim 0) and TP (dim 1 or 2). 

508 

509 Weight layout ``[num_experts, out_dim, in_dim]``: 

510 

511 - ``w1``/``w3``: ``[Shard(0), Shard(1)]`` — EP shards experts, 

512 TP splits hidden_dim (column-wise). 

513 - ``w2``: ``[Shard(0), Shard(2)]`` — EP shards experts, TP splits 

514 the input dimension (row-wise). 

515 

516 Args: 

517 name: Submodule name (unused). 

518 module: The module whose parameters are being sharded. 

519 device_mesh: 2-D device mesh with dims ``("ep", "tp")``. 

520 """ 

521 del name 

522 for key, param in _distribute_module_iter_params(module): 

523 if param is None: 

524 continue 

525 src = _distribute_module_param_source(param) 

526 requires_grad = bool(getattr(param, "requires_grad", True)) 

527 # EP shards expert ownership (dim 0); TP shards weight dim. 

528 tp_dim = 2 if key == "w2" else 1 

529 dt = distribute_tensor(src, device_mesh, [Shard(0), Shard(tp_dim)]) 

530 new_param = _distribute_module_new_parameter(key, dt, requires_grad) 

531 _distribute_module_set_param(module, key, new_param) 

532 

533 def _token_dispatch(self, module: Module, inputs, device_mesh: DeviceMesh): 

534 """Dispatch tokens using only the EP sub-mesh. 

535 

536 Args: 

537 module: The ``GroupedExperts`` module. 

538 inputs: Forward inputs tuple. 

539 device_mesh: 2-D device mesh with dims ``("ep", "tp")``. 

540 

541 Returns: 

542 Transformed inputs for local expert computation. 

543 """ 

544 return super()._token_dispatch(module, inputs, device_mesh["ep"]) 

545 

546 def _token_combine(self, module: Module, routed_output, device_mesh: DeviceMesh): 

547 """Combine tokens using only the EP sub-mesh. 

548 

549 Args: 

550 module: The ``GroupedExperts`` module. 

551 routed_output: Expert output tensor in expert-major order. 

552 device_mesh: 2-D device mesh with dims ``("ep", "tp")``. 

553 

554 Returns: 

555 Token tensor in the original token-major layout. 

556 """ 

557 return super()._token_combine(module, routed_output, device_mesh["ep"])
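# --- Illustration (not part of the library): local shard shapes under the
# combined [Shard(0), Shard(tp_dim)] placement, with assumed sizes
# num_experts=8, dim=1024, hidden_dim=4096 on a 2-D mesh of ep=2, tp=4.
_ep, _tp = 2, 4
_full_w1 = (8, 4096, 1024)                                          # [num_experts, hidden_dim, dim]
_full_w2 = (8, 1024, 4096)                                          # [num_experts, dim, hidden_dim]
_local_w1 = (_full_w1[0] // _ep, _full_w1[1] // _tp, _full_w1[2])   # Shard(0) + Shard(1)
_local_w2 = (_full_w2[0] // _ep, _full_w2[1], _full_w2[2] // _tp)   # Shard(0) + Shard(2)
assert _local_w1 == (4, 1024, 1024)
assert _local_w2 == (4, 1024, 1024)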