Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/expert_parallel/expert_parallel.py: 95%
112 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Expert Parallelism distributed strategies.

Provides token permutation helpers and four parallel styles that compose with
:class:`~hyper_parallel.core.expert_parallel.moe.GroupedExperts`:

- :class:`BaseExpertParallel` — abstract base for EP strategies with
  all-to-all token dispatch/combine.
- :class:`ExpertParallel` — standard EP: each rank owns a shard of experts;
  tokens are routed via differentiable all-to-all.
- :class:`TensorParallel` — TP-only weight sharding for experts with no token
  dispatch; for use when EP degree = 1.
- :class:`ExpertTensorParallel` — combined EP + TP on a 2-D mesh ``[ep, tp]``;
  weights are doubly sharded, dispatch uses the EP sub-mesh.
"""
from abc import ABC, abstractmethod
from typing import Optional

from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.dtensor.dtensor import (
    distribute_module,
    distribute_tensor,
    _distribute_module_iter_params,
    _distribute_module_new_parameter,
    _distribute_module_param_source,
    _distribute_module_set_param,
)
from hyper_parallel.core.dtensor.placement_types import Shard
from hyper_parallel.core.tensor_parallel.style import ParallelStyle
from hyper_parallel.platform import get_platform

platform = get_platform()
Module = platform.Module

__all__ = [
    "BaseExpertParallel",
    "ExpertParallel",
    "TensorParallel",
    "ExpertTensorParallel",
]


# ---------------------------------------------------------------------------
# Token permutation helpers
# ---------------------------------------------------------------------------
def _generate_permute_indices(
    tokens_per_expert_group,
    experts_per_rank: int,
    num_ranks: int,
):
    """Generate permutation indices for rank-major → expert-major reordering.

    After all-to-all, received tokens are laid out in rank-major order::

        [rank0·expert0 tokens | rank0·expert1 tokens | ... |
         rank1·expert0 tokens | rank1·expert1 tokens | ...]

    Expert computation requires expert-major order::

        [all tokens for local expert 0 | all tokens for local expert 1 | ...]

    Args:
        tokens_per_expert_group: 1-D integer tensor of shape
            ``[num_ranks * experts_per_rank]``. Entry ``[r * E + e]`` is the
            number of tokens received from rank ``r`` for local expert ``e``.
        experts_per_rank: Number of experts owned by each rank.
        num_ranks: EP degree (total number of ranks in the EP group).

    Returns:
        Tuple of:

        - ``permuted_indices``: 1-D long tensor of length
          ``total_received_tokens``. ``permuted_indices[i]`` is the source
          position in the rank-major buffer for destination position ``i`` in
          the expert-major buffer.
        - ``num_tokens_per_expert``: 1-D integer tensor of length
          ``experts_per_rank`` with the token count per local expert.
    """
    counts = tokens_per_expert_group  # [num_ranks * experts_per_rank]

    # num_tokens_per_expert[e] = Σ_r counts[r * E + e]
    counts_2d = counts.view(num_ranks, experts_per_rank)  # [R, E]
    num_tokens_per_expert = counts_2d.sum(dim=0)  # [E]

    # ``total`` must be a host int because ``arange`` needs a scalar size.
    # That single D2H drain is unavoidable. Everything else stays on
    # device — no per-block ``.item()`` in a loop.
    total = int(num_tokens_per_expert.sum())
    if total == 0:
        return counts.new_zeros(0, dtype=counts.dtype), num_tokens_per_expert

    # ---- Vectorized expert-major permutation, no host stalls -----------
    # Source offsets in the rank-major receive buffer for each (r, e) block.
    src_offsets_rm = counts.cumsum(0) - counts  # [R*E], starts of each block
    # Reorder src offsets to expert-major iteration order: block (e, r).
    src_offsets_em = (
        src_offsets_rm.view(num_ranks, experts_per_rank).T.contiguous().view(-1)
    )  # [E*R]
    # Counts in expert-major iteration order.
    counts_em = counts_2d.T.contiguous().view(-1)  # [E*R]

    # ``repeat_interleave`` expands each block's src start to one entry per
    # token in that block — gives the source position of each output token.
    block_src_starts = src_offsets_em.repeat_interleave(counts_em)  # [total]

    # Destination block starts in expert-major order, then expanded. The
    # ``arange(total) - dst_block_starts_per_token`` produces 0..n-1 within
    # each block, i.e. the intra-block offset.
    dst_block_starts = counts_em.cumsum(0) - counts_em  # [E*R]
    dst_block_starts_per_token = dst_block_starts.repeat_interleave(counts_em)
    intra = platform.arange(0, total, device=counts.device) - dst_block_starts_per_token

    permuted_indices = (block_src_starts + intra).long()
    return permuted_indices, num_tokens_per_expert
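

# Illustrative worked example for ``_generate_permute_indices`` (hypothetical
# small sizes, chosen only to make the index arithmetic visible): with
# ``num_ranks = 2`` and ``experts_per_rank = 2``, suppose
# ``tokens_per_expert_group = [1, 2, 3, 4]``, i.e. rank 0 sent 1 token for
# local expert 0 and 2 for expert 1, while rank 1 sent 3 and 4. The rank-major
# receive buffer then holds positions [0 | 1 2 | 3 4 5 | 6 7 8 9] for the
# blocks (r0,e0), (r0,e1), (r1,e0), (r1,e1), and the intermediates are:
#
#     counts_2d             = [[1, 2], [3, 4]]
#     num_tokens_per_expert = [4, 6]        # expert 0: 1+3, expert 1: 2+4
#     src_offsets_rm        = [0, 1, 3, 6]  # rank-major block starts
#     src_offsets_em        = [0, 3, 1, 6]  # order (e0,r0), (e0,r1), (e1,r0), (e1,r1)
#     counts_em             = [1, 3, 2, 4]
#     permuted_indices      = [0, 3, 4, 5, 1, 2, 6, 7, 8, 9]
#
# so the expert-major buffer lists expert 0's tokens (source positions
# 0, 3, 4, 5) followed by expert 1's tokens (source positions 1, 2, 6, 7, 8, 9).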


def _permute(x, tokens_per_expert_group, ep_degree: int, num_local_experts: int):
    """Apply rank-major → expert-major permutation to routed tokens.

    Args:
        x: Received token tensor of shape
            ``[sum(tokens_per_expert_group), *feature_dims]``.
        tokens_per_expert_group: 1-D integer tensor of shape
            ``[ep_degree * num_local_experts]`` (output of the first
            all-to-all that exchanges token counts).
        ep_degree: EP group size (number of ranks).
        num_local_experts: Number of experts owned by this rank.

    Returns:
        Tuple of:

        - ``original_shape``: shape of *x* before permutation.
        - ``permuted_x``: tokens reordered to expert-major layout.
        - ``permuted_indices``: permutation indices (needed for
          :func:`_unpermute`).
        - ``num_tokens_per_expert``: token count per local expert.
    """
    original_shape = x.shape
    permuted_indices, num_tokens_per_expert = _generate_permute_indices(
        tokens_per_expert_group, num_local_experts, ep_degree
    )
    # ``x[permuted_indices]`` works for empty indices too (returns a
    # shape-0 tensor with a real grad_fn). Avoid the early-return with
    # ``new_zeros`` which would produce a leaf tensor without grad_fn and
    # silently break autograd for ranks that happen to receive zero tokens.
    permuted_x = x[permuted_indices]
    return original_shape, permuted_x, permuted_indices, num_tokens_per_expert


def _unpermute(out, original_shape, permuted_indices):
    """Reverse the permutation applied by :func:`_permute`.

    Args:
        out: Expert-major output tensor of shape
            ``[sum(num_tokens_per_expert), *feature_dims]``.
        original_shape: Shape before permutation (from :func:`_permute`).
        permuted_indices: Permutation indices from :func:`_permute`.

    Returns:
        Token tensor restored to the rank-major layout received after
        all-to-all, with shape ``original_shape``.
    """
    # ``result[permuted_indices] = out`` is a differentiable scatter that
    # also handles the empty-index case (no-op assignment, but autograd
    # still connects ``result`` back to ``out``). Do NOT short-circuit
    # with a bare ``new_zeros`` — that returns a leaf tensor without
    # grad_fn and the downstream combine a2a loses its backward path,
    # which manifests as "element 0 of tensors does not require grad".
    result = out.new_zeros(*original_shape)
    result[permuted_indices] = out
    return result
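

# Round-trip sketch (illustrative, continuing the hypothetical example above):
# with ``permuted_indices = [0, 3, 4, 5, 1, 2, 6, 7, 8, 9]``, ``_permute``
# builds the expert-major buffer as ``permuted_x[i] = x[permuted_indices[i]]``
# (so rank-major position 3 lands at expert-major position 1), and
# ``_unpermute`` scatters expert outputs back with
# ``result[permuted_indices[i]] = out[i]`` (so expert-major position 1 is
# written back to rank-major position 3), restoring exactly the ordering the
# combine all-to-all expects.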


# ---------------------------------------------------------------------------
# BaseExpertParallel — abstract base for all-to-all EP strategies
# ---------------------------------------------------------------------------
class BaseExpertParallel(ParallelStyle, ABC):
    """Abstract base class for Expert Parallel strategies with token dispatch.

    Subclasses implement :meth:`_partition_fn`, :meth:`_token_dispatch`, and
    :meth:`_token_combine`; this class wires them into :func:`distribute_module`.
    """

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Apply EP sharding and dispatch/combine hooks to *module*.

        Args:
            module: A :class:`~hyper_parallel.core.expert_parallel.moe.GroupedExperts`
                instance to shard.
            device_mesh: Device mesh for this EP strategy.

        Returns:
            The module with distributed parameters and dispatch/combine hooks.
        """
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            self._token_dispatch,
            self._token_combine,
        )

    @abstractmethod
    def _partition_fn(
        self, name: str, module: Module, device_mesh: DeviceMesh
    ) -> None:
        """Shard module parameters according to this strategy.

        Args:
            name: Submodule name.
            module: The module whose parameters are being sharded.
            device_mesh: Device mesh for this EP strategy.
        """

    @abstractmethod
    def _token_dispatch(self, module: Module, inputs, device_mesh: DeviceMesh):
        """Pre-hook: route input tokens to their assigned ranks.

        Args:
            module: The ``GroupedExperts`` module.
            inputs: Forward inputs tuple.
            device_mesh: Device mesh for this EP strategy.

        Returns:
            Transformed inputs for local expert computation.
        """

    @abstractmethod
    def _token_combine(self, module: Module, routed_output, device_mesh: DeviceMesh):
        """Post-hook: gather expert outputs back to the originating ranks.

        Args:
            module: The ``GroupedExperts`` module.
            routed_output: Expert output tensor in expert-major order.
            device_mesh: Device mesh for this EP strategy.

        Returns:
            Token tensor in the original token-major layout.
        """


# ---------------------------------------------------------------------------
# ExpertParallel — standard all-to-all EP
# ---------------------------------------------------------------------------
class ExpertParallel(BaseExpertParallel):
    """Expert Parallel: shard experts across ranks via all-to-all token routing.

    Calling :meth:`apply` on a :class:`GroupedExperts` module performs three steps:

    1. **Partition** — distributes expert weights on dim 0 (``Shard(0)``) so
       each rank holds ``num_experts // ep_degree`` local experts.
    2. **Token dispatch** (forward pre-hook) — two-step all-to-all:
       a. Exchange token counts (non-differentiable).
       b. Exchange actual tokens (differentiable, gradient flows back).
       Followed by rank-major → expert-major permutation.
    3. **Token combine** (forward post-hook) — expert-major → rank-major
       unpermute, then reverse all-to-all (differentiable).

    All collectives use ``platform.differentiable_all_to_all_single`` /
    ``platform.all_to_all_single`` — no direct ``torch.distributed`` calls.

    Args:
        None

    Example::

        >>> ep_style = ExpertParallel()
        >>> sharded_experts = ep_style.apply(experts_module, ep_device_mesh)
    """

    def __init__(self) -> None:
        # State saved between _token_dispatch and _token_combine within one
        # forward pass. Safe for standard (non-pipeline) training.
        self._input_splits: list = []
        self._output_splits: list = []
        self._input_shape: Optional[tuple] = None
        self._permuted_indices = None

    def _partition_fn(
        self, name: str, module: Module, device_mesh: DeviceMesh
    ) -> None:
        """Shard all expert parameters along dim 0 (expert dimension).

        Args:
            name: Submodule name (unused).
            module: The module whose parameters are being sharded.
            device_mesh: EP device mesh.
        """
        del name
        for key, param in _distribute_module_iter_params(module):
            if param is None:
                continue
            src = _distribute_module_param_source(param)
            requires_grad = bool(getattr(param, "requires_grad", True))
            dt = distribute_tensor(src, device_mesh, [Shard(0)])
            new_param = _distribute_module_new_parameter(key, dt, requires_grad)
            _distribute_module_set_param(module, key, new_param)

    def _token_dispatch(self, module: Module, inputs, device_mesh: DeviceMesh):
        """Dispatch tokens to their assigned ranks via all-to-all.

        Called as an ``input_fn`` hook by ``distribute_module``. Receives the
        module's forward inputs and returns transformed inputs.

        Args:
            module: The ``GroupedExperts`` module (unused here).
            inputs: Tuple ``(routed_input, num_tokens_per_expert)`` where
                ``routed_input`` has shape ``[total_tokens, dim]`` and
                ``num_tokens_per_expert`` has shape ``[num_experts]``.
            device_mesh: EP device mesh (1-D).

        Returns:
            Tuple ``(permuted_local_input, local_token_counts)`` ready for
            local expert computation.
        """
        del module
        routed_input, num_tokens_per_expert = inputs[0], inputs[1]
        ep_group = device_mesh.get_group()
        ep_size = device_mesh.size()
        num_local_experts = num_tokens_per_expert.shape[0] // ep_size

        # --- Step 1: exchange token counts (no gradient needed) ---
        # Each rank needs to know how many tokens it will receive from every
        # other rank (for each local expert). Uses ``async_op=True`` + an
        # explicit ``handle.wait()`` rather than ``async_op=False`` because
        # the implicit cross-stream sync is NCCL-only; on HCCL the compute
        # stream may read ``counts_out`` before the collective write is
        # visible, producing garbage values that blow up the downstream
        # ``torch.empty(sum(output_splits), ...)`` allocation.
        counts_out, handle = platform.all_to_all_single(
            num_tokens_per_expert,
            output_shape=[num_tokens_per_expert.shape[0]],
            group=ep_group,
            async_op=True,
        )
        if handle is not None:
            handle.wait()
        # counts_out shape: [ep_size * num_local_experts]
        # counts_out[r * num_local_experts + e] = tokens from rank r for expert e

        # --- Step 2: compute input / output splits ---
        # input_splits[r] = tokens this rank sends to rank r
        # output_splits[r] = tokens this rank receives from rank r
        # Reshape to [ep_size, num_local_experts] and sum per rank on device;
        # a single ``tolist()`` drains the rank-sum vector to host, replacing
        # ``2 * ep_size`` scalar ``int()`` D2H syncs with 2.
        input_splits = num_tokens_per_expert.view(ep_size, num_local_experts).sum(dim=1).tolist()
        output_splits = counts_out.view(ep_size, num_local_experts).sum(dim=1).tolist()
        self._input_splits = input_splits
        self._output_splits = output_splits

        # --- Step 3: exchange actual tokens (differentiable) ---
        dispatched = platform.differentiable_all_to_all_single(
            routed_input, input_splits, output_splits, group=ep_group,
        )

        # --- Step 4: rank-major → expert-major permutation ---
        self._input_shape, permuted, self._permuted_indices, local_counts = _permute(
            dispatched, counts_out, ep_size, num_local_experts
        )
        return permuted, local_counts
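
    # Illustrative walkthrough of the split bookkeeping above (hypothetical
    # numbers, assuming ``ep_size = 2`` and ``num_local_experts = 2``): if this
    # rank routes ``num_tokens_per_expert = [1, 2, 3, 4]`` tokens to the four
    # global experts, then ``input_splits = [1 + 2, 3 + 4] = [3, 7]`` (tokens
    # sent to ranks 0 and 1). If the count exchange returns
    # ``counts_out = [1, 2, 5, 1]``, then ``output_splits = [1 + 2, 5 + 1] = [3, 6]``,
    # the token all-to-all delivers 9 rows, and ``_permute`` regroups them per
    # local expert as ``local_counts = [1 + 5, 2 + 1] = [6, 3]``.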

    def _token_combine(self, module: Module, routed_output, device_mesh: DeviceMesh):
        """Gather expert outputs back to the originating ranks via all-to-all.

        Called as an ``output_fn`` hook by ``distribute_module``.

        Args:
            module: The ``GroupedExperts`` module (unused).
            routed_output: Expert output tensor in expert-major order,
                shape ``[sum(local_counts), dim]``.
            device_mesh: EP device mesh (1-D).

        Returns:
            Token tensor in the original token-major layout,
            shape ``[sum(input_splits), dim]``.
        """
        del module
        ep_group = device_mesh.get_group()

        # expert-major → rank-major
        unpermuted = _unpermute(routed_output, self._input_shape, self._permuted_indices)

        # reverse all-to-all (output/input splits are swapped)
        combined = platform.differentiable_all_to_all_single(
            unpermuted,
            self._output_splits,  # was output, now becomes input
            self._input_splits,  # was input, now becomes output
            group=ep_group,
        )
        return combined
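
    # Continuing the sketch above: the combine step simply swaps the two split
    # lists. With ``output_splits = [3, 6]`` and ``input_splits = [3, 7]`` from
    # dispatch, this rank sends its 9 expert outputs back to their source ranks
    # and receives ``3 + 7 = 10`` rows, restoring the ``[total_tokens, dim]``
    # shape of the original ``routed_input``.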


# ---------------------------------------------------------------------------
# TensorParallel — TP-only weight sharding for experts (no token dispatch)
# ---------------------------------------------------------------------------
class TensorParallel(ParallelStyle):
    """Tensor Parallel for expert weights (no token dispatch).

    Shards the ``GroupedExperts`` weight tensors in the column/row-wise
    pattern used by standard TP:

    - ``w1`` / ``w3``: ``Shard(1)`` — column-wise (splits ``hidden_dim``, the
      output dimension).
    - ``w2``: ``Shard(2)`` — row-wise (splits ``hidden_dim``, the input
      dimension).

    Use this when EP degree is 1 and you want TP across experts without
    any all-to-all token dispatch. Typically combined with the standard
    :class:`~hyper_parallel.core.tensor_parallel.style.ColwiseParallel` /
    :class:`~hyper_parallel.core.tensor_parallel.style.RowwiseParallel`
    pattern for attention layers.

    Example::

        >>> tp_style = TensorParallel()
        >>> sharded_experts = tp_style.apply(experts_module, tp_device_mesh)
    """

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Apply TP weight sharding to *module*.

        Args:
            module: A :class:`GroupedExperts` instance.
            device_mesh: 1-D TP device mesh (``mesh_dim_names=("tp",)``).

        Returns:
            The module with TP-sharded expert parameters.
        """
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
        )

    def _partition_fn(
        self, name: str, module: Module, device_mesh: DeviceMesh
    ) -> None:
        """Shard expert weights column-wise (w1/w3) or row-wise (w2).

        ``GroupedExperts`` weight layout is ``[num_experts, out_dim, in_dim]``
        so:

        - ``w1``/``w3``: shard ``Shard(1)`` → split ``hidden_dim``
          (column-wise analogue).
        - ``w2``: shard ``Shard(2)`` → split ``in_dim = hidden_dim``
          (row-wise analogue).

        Args:
            name: Submodule name (unused).
            module: The module whose parameters are being sharded.
            device_mesh: TP device mesh.
        """
        del name
        for key, param in _distribute_module_iter_params(module):
            if param is None:
                continue
            src = _distribute_module_param_source(param)
            requires_grad = bool(getattr(param, "requires_grad", True))
            # w1, w3: column-wise → Shard(1); w2: row-wise → Shard(2).
            shard_dim = 2 if key == "w2" else 1
            dt = distribute_tensor(src, device_mesh, [Shard(shard_dim)])
            new_param = _distribute_module_new_parameter(key, dt, requires_grad)
            _distribute_module_set_param(module, key, new_param)
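
    # Sizing sketch (hypothetical shapes, for illustration only): with ``w1``
    # of shape ``[num_experts, hidden_dim, dim] = [8, 1024, 512]`` on a TP mesh
    # of size 4, ``Shard(1)`` leaves each rank a local ``[8, 256, 512]`` shard
    # of ``w1``/``w3``, while ``Shard(2)`` on ``w2`` (shape ``[8, 512, 1024]``)
    # gives a local ``[8, 512, 256]`` shard, i.e. the usual column-/row-parallel
    # pairing with a leading expert dimension left intact.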


# ---------------------------------------------------------------------------
# ExpertTensorParallel — combined EP + TP on a 2-D [ep, tp] mesh
# ---------------------------------------------------------------------------
class ExpertTensorParallel(ExpertParallel):
    """Combined Expert + Tensor Parallel on a 2-D ``[ep, tp]`` device mesh.

    Extends :class:`ExpertParallel` to operate on a 2-D mesh with named
    dimensions ``"ep"`` and ``"tp"``:

    - **Partition**: each expert weight ``[num_experts, out, in]`` is doubly
      sharded — ``Shard(0)`` along the EP dim (expert ownership) and
      ``Shard(1)``/``Shard(2)`` along the TP dim (column-wise / row-wise).
    - **Dispatch / Combine**: use only the 1-D ``device_mesh["ep"]`` sub-mesh
      so that token routing uses EP-group collectives, not the full 2-D mesh.

    Args:
        None

    Example::

        >>> etp_style = ExpertTensorParallel()
        >>> sharded = etp_style.apply(experts_module, ep_tp_2d_mesh)
    """

    def _partition_fn(
        self, name: str, module: Module, device_mesh: DeviceMesh
    ) -> None:
        """Shard expert weights along both EP (dim 0) and TP (dim 1 or 2).

        Weight layout ``[num_experts, out_dim, in_dim]``:

        - ``w1``/``w3``: ``[Shard(0), Shard(1)]`` — EP shards experts,
          TP splits hidden_dim (column-wise).
        - ``w2``: ``[Shard(0), Shard(2)]`` — EP shards experts, TP splits
          the input dimension (row-wise).

        Args:
            name: Submodule name (unused).
            module: The module whose parameters are being sharded.
            device_mesh: 2-D device mesh with dims ``("ep", "tp")``.
        """
        del name
        for key, param in _distribute_module_iter_params(module):
            if param is None:
                continue
            src = _distribute_module_param_source(param)
            requires_grad = bool(getattr(param, "requires_grad", True))
            # EP shards expert ownership (dim 0); TP shards weight dim.
            tp_dim = 2 if key == "w2" else 1
            dt = distribute_tensor(src, device_mesh, [Shard(0), Shard(tp_dim)])
            new_param = _distribute_module_new_parameter(key, dt, requires_grad)
            _distribute_module_set_param(module, key, new_param)
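
    # Sizing sketch (hypothetical shapes, for illustration only): on a
    # ``[ep=2, tp=4]`` mesh, ``w1`` of global shape ``[8, 1024, 512]`` under
    # ``[Shard(0), Shard(1)]`` leaves each rank ``[4, 256, 512]``: 4 locally
    # owned experts (EP) and a quarter of the hidden dimension (TP). ``w2``
    # (``[8, 512, 1024]``) becomes ``[4, 512, 256]`` under
    # ``[Shard(0), Shard(2)]``.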

    def _token_dispatch(self, module: Module, inputs, device_mesh: DeviceMesh):
        """Dispatch tokens using only the EP sub-mesh.

        Args:
            module: The ``GroupedExperts`` module.
            inputs: Forward inputs tuple.
            device_mesh: 2-D device mesh with dims ``("ep", "tp")``.

        Returns:
            Transformed inputs for local expert computation.
        """
        return super()._token_dispatch(module, inputs, device_mesh["ep"])

    def _token_combine(self, module: Module, routed_output, device_mesh: DeviceMesh):
        """Combine tokens using only the EP sub-mesh.

        Args:
            module: The ``GroupedExperts`` module.
            routed_output: Expert output tensor in expert-major order.
            device_mesh: 2-D device mesh with dims ``("ep", "tp")``.

        Returns:
            Token tensor in the original token-major layout.
        """
        return super()._token_combine(module, routed_output, device_mesh["ep"])