# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Two-thread comm/compute overlap orchestrator.
17This module provides :class:`CommComputeOverlap`, a helper that wraps
18MoE-style dispatch / combine phases with four synchronization hooks
19(``A``, ``B``, ``C``, ``D``) and drives a forward + backward pass on two
20threads with deterministic comm-first dispatch ordering via
21:class:`HookCoordinator`.
23The mechanism is independent of any specific pipeline schedule. It is
24typically driven by the ``OVERLAP_B_F`` callback registered on a
25schedule (e.g. ``ScheduleInterleaved1F1B(overlap_b_f=True)``), but the
26same orchestrator could be reused by other concurrent-dispatch overlap
27scenarios (TP+CP, FSDP prefetch, etc.) without modification.
29Every rendezvous is a strict COMM + COMPUTE pair — including layer
30boundaries — so the NCCL kernel is always enqueued before the paired
31compute kernel::
33 [A] ─► dispatch ─► [B] ─► module ─► [C] ─► combine ─► [D] ─► (Attention) ─► [A_next]
35At layer boundaries the D / A hooks coordinate combine (COMM) with the
36other thread's Attention (COMPUTE), preserving overlap across layers.
38Typical integration::
40 overlap = CommComputeOverlap()
42 # Wrap the expert-parallel dispatch / combine callables:
43 wrapped_dispatch = overlap.wrap_dispatch(original_dispatch)
44 wrapped_combine = overlap.wrap_combine(original_combine)
46 # At schedule time, run forward and backward in parallel:
47 overlap.run(
48 fwd_fn=lambda: fwd_stage.forward_one_chunk(mb, *args),
49 bwd_fn=lambda: bwd_stage.backward_one_chunk(mb, loss=loss),
50 )
51"""
import threading
from typing import Callable

from hyper_parallel.platform import get_platform
from hyper_parallel.core.pipeline_parallel.hook_coordinator import HookCoordinator
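
# ``differentiable_sync_hook`` comes from the platform abstraction: it
# attaches a named rendezvous to a tensor so the hook fires on both the
# forward and the backward pass (hence "differentiable").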
platform = get_platform()


class CommComputeOverlap:
    """Orchestrator for two-thread comm/compute overlap.

    Manages a :class:`HookCoordinator` and provides helpers to insert the
    four synchronization hooks (``A``, ``B``, ``C``, ``D``) around MoE
    dispatch / combine phases and to run forward + backward concurrently
    with deterministic comm-first kernel launch ordering.

    Example:
        >>> overlap = CommComputeOverlap()
        >>> wrapped_dispatch = overlap.wrap_dispatch(ep_dispatch_fn)
        >>> wrapped_combine = overlap.wrap_combine(ep_combine_fn, is_last_layer=is_last)
        >>> overlap.run(fwd_fn, bwd_fn)  # doctest: +SKIP
    """

    def __init__(self) -> None:
        self._coordinator = HookCoordinator()

    @property
    def coordinator(self) -> HookCoordinator:
        """The underlying :class:`HookCoordinator` instance."""
        return self._coordinator

    # ------------------------------------------------------------------
    # Wrapping helpers
    # ------------------------------------------------------------------

    def wrap_dispatch(self, dispatch_fn: Callable) -> Callable:
        """Return a wrapped version of ``dispatch_fn`` bracketed by hooks A/B.

        The returned callable inserts synchronization hooks on the **first
        positional tensor argument** before and after the call::

            A ─► dispatch_fn ─► B

        Args:
            dispatch_fn: The original dispatch callable.

        Returns:
            A new callable with the same signature.
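
        Example (illustrative; ``ep_dispatch_fn`` stands in for an
        expert-parallel dispatch callable, ``hidden_states`` and ``probs``
        for its usual arguments):
            >>> dispatch = overlap.wrap_dispatch(ep_dispatch_fn)  # doctest: +SKIP
            >>> output = dispatch(hidden_states, probs)  # A fires before, B after  # doctest: +SKIP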
101 """
102 coordinator = self._coordinator
104 def _wrapped(*args, **kwargs):
105 first, rest = args[0], args[1:]
106 first = platform.differentiable_sync_hook(first, "A", coordinator)
107 result = dispatch_fn(first, *rest, **kwargs)
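            # Hook only the leading output; any remaining tuple elements
            # pass through unchanged.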
            if isinstance(result, tuple):
                hooked = platform.differentiable_sync_hook(result[0], "B", coordinator)
                return (hooked,) + result[1:]
            return platform.differentiable_sync_hook(result, "B", coordinator)

        return _wrapped

    def wrap_combine(self, combine_fn: Callable, is_last_layer: bool = False) -> Callable:
        """Return a wrapped version of ``combine_fn`` bracketed by hooks C/D.

        The returned callable inserts synchronization hooks on the **first
        positional tensor argument** before and after the call::

            C ─► combine_fn ─► D

        Args:
            combine_fn: The original combine callable.
            is_last_layer: If ``True``, the closing D hook is tagged
                ``"D_LAST"`` so the rendezvous is skipped both in
                forward (no Attention follows the last layer) and in
                backward (this is the first BWD hook to fire, and
                combine.bwd has already dispatched freely). Tagging
                this hook statically replaces the old runtime cycle
                counter and BWD-D-skip mechanisms.

        Returns:
            A new callable with the same signature.
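
        Example (illustrative; ``layer_combine_fns`` is a hypothetical
        per-layer list of combine callables):
            >>> wrapped = [
            ...     overlap.wrap_combine(fn, is_last_layer=(i == len(layer_combine_fns) - 1))
            ...     for i, fn in enumerate(layer_combine_fns)
            ... ]  # doctest: +SKIP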
135 """
136 coordinator = self._coordinator
137 d_hook = "D_LAST" if is_last_layer else "D"
139 def _wrapped(*args, **kwargs):
140 first, rest = args[0], args[1:]
141 first = platform.differentiable_sync_hook(first, "C", coordinator)
142 result = combine_fn(first, *rest, **kwargs)
            if isinstance(result, tuple):
                hooked = platform.differentiable_sync_hook(result[0], d_hook, coordinator)
                return (hooked,) + result[1:]
            return platform.differentiable_sync_hook(result, d_hook, coordinator)

        return _wrapped

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    def run(
        self,
        fwd_fn: Callable[[], None],
        bwd_fn: Callable[[], None],
    ) -> None:
        """Run ``fwd_fn`` and ``bwd_fn`` in parallel with comm/compute overlap.

        Enables the coordinator, spawns the backward pass on a daemon thread,
        and waits for both to complete. Layer-boundary handling is encoded
        statically by the ``is_last_layer`` flag passed to
        :meth:`wrap_combine` at wrap time, so no per-call layer count is
        needed here.

        Args:
            fwd_fn: Callable that executes the forward pass.
            bwd_fn: Callable that executes the backward pass. If it needs a
                specific device stream, wrap that logic inside ``bwd_fn``.

        Raises:
            RuntimeError: If the backward thread raised, a ``RuntimeError``
                chained to the original exception is raised on the main
                thread after ``join``. A forward-side exception is
                re-raised as-is.
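
        Example (illustrative; assumes a CUDA backend where the backward
        pass should run on a dedicated ``torch.cuda.Stream``):
            >>> bwd_stream = torch.cuda.Stream()  # doctest: +SKIP
            >>> def bwd_fn():
            ...     with torch.cuda.stream(bwd_stream):
            ...         bwd_stage.backward_one_chunk(mb, loss=loss)
            >>> overlap.run(lambda: fwd_stage.forward_one_chunk(mb), bwd_fn)  # doctest: +SKIP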
175 """
176 self._coordinator.enable()
178 exc_box: list = []
179 coordinator = self._coordinator
181 def _bwd_target():
182 try:
183 bwd_fn()
184 except Exception as exc: # pylint: disable=W0718
185 exc_box.append(exc)
186 # BWD died — disable the coordinator so any FWD rendezvous
187 # waiting on a barrier/event unblocks immediately. Without
188 # this the FWD thread hangs forever at the very first hook
189 # it reaches and the outer ``finally`` never runs.
190 coordinator.disable()
192 thread = threading.Thread(target=_bwd_target, daemon=True)
193 thread.start()

        fwd_exc: list = []
        try:
            fwd_fn()
        except Exception as exc:  # pylint: disable=W0718
            fwd_exc.append(exc)
            # Symmetric: if FWD dies, unblock BWD so it can exit.
            coordinator.disable()
        finally:
            # Idempotent in case either side already disabled.
            coordinator.disable()
            thread.join()

        if exc_box:
            raise RuntimeError(
                "Exception in backward thread during dual-pipe overlap"
            ) from exc_box[0]
        if fwd_exc:
            raise fwd_exc[0]