Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/pipeline_parallel/comm_compute_overlap.py: 21%

58 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Two-thread comm/compute overlap orchestrator.

This module provides :class:`CommComputeOverlap`, a helper that wraps
MoE-style dispatch / combine phases with four synchronization hooks
(``A``, ``B``, ``C``, ``D``) and drives a forward + backward pass on two
threads with deterministic comm-first dispatch ordering via
:class:`HookCoordinator`.

The mechanism is independent of any specific pipeline schedule. It is
typically driven by the ``OVERLAP_B_F`` callback registered on a
schedule (e.g. ``ScheduleInterleaved1F1B(overlap_b_f=True)``), but the
same orchestrator could be reused by other concurrent-dispatch overlap
scenarios (TP+CP, FSDP prefetch, etc.) without modification.

Every rendezvous is a strict COMM + COMPUTE pair — including layer
boundaries — so the NCCL kernel is always enqueued before the paired
compute kernel::

    [A] ─► dispatch ─► [B] ─► module ─► [C] ─► combine ─► [D] ─► (Attention) ─► [A_next]

At layer boundaries the D / A hooks coordinate combine (COMM) with the
other thread's Attention (COMPUTE), preserving overlap across layers.

Typical integration::

    overlap = CommComputeOverlap()

    # Wrap the expert-parallel dispatch / combine callables:
    wrapped_dispatch = overlap.wrap_dispatch(original_dispatch)
    wrapped_combine = overlap.wrap_combine(original_combine)

    # At schedule time, run forward and backward in parallel:
    overlap.run(
        fwd_fn=lambda: fwd_stage.forward_one_chunk(mb, *args),
        bwd_fn=lambda: bwd_stage.backward_one_chunk(mb, loss=loss),
    )
"""

import threading
from typing import Callable

from hyper_parallel.platform import get_platform
from hyper_parallel.core.pipeline_parallel.hook_coordinator import HookCoordinator

platform = get_platform()
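# The wrappers below rely on ``platform.differentiable_sync_hook`` from this
# platform object; it is resolved once at import time so every wrapper shares
# the same backend implementation.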



class CommComputeOverlap:
    """Orchestrator for two-thread comm/compute overlap.

    Manages a :class:`HookCoordinator` and provides helpers to insert the
    four synchronization hooks (``A``, ``B``, ``C``, ``D``) around MoE
    dispatch / combine phases and to run forward + backward concurrently
    with deterministic comm-first kernel launch ordering.

    Example:
        >>> overlap = CommComputeOverlap()
        >>> wrapped_dispatch = overlap.wrap_dispatch(ep_dispatch_fn)
        >>> wrapped_combine = overlap.wrap_combine(ep_combine_fn, is_last_layer=is_last)
        >>> overlap.run(fwd_fn, bwd_fn)  # doctest: +SKIP
    """

    def __init__(self) -> None:
        self._coordinator = HookCoordinator()

    @property
    def coordinator(self) -> HookCoordinator:
        """The underlying :class:`HookCoordinator` instance."""
        return self._coordinator

    # ------------------------------------------------------------------
    # Wrapping helpers
    # ------------------------------------------------------------------

    def wrap_dispatch(self, dispatch_fn: Callable) -> Callable:
        """Return a wrapped version of ``dispatch_fn`` bracketed by hooks A/B.

        The returned callable inserts synchronization hooks on the **first
        positional tensor argument** before and after the call::

            A ─► dispatch_fn ─► B

        Args:
            dispatch_fn: The original dispatch callable.

        Returns:
            A new callable with the same signature.
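
        Example:
            Illustrative sketch only; ``ep_dispatch_fn`` and ``hidden_states``
            are placeholder names for the caller's own expert-parallel
            dispatch callable and its input tensor. Hook ``A`` fires before
            the dispatch comm kernel and hook ``B`` after it.

            >>> overlap = CommComputeOverlap()
            >>> dispatch = overlap.wrap_dispatch(ep_dispatch_fn)  # doctest: +SKIP
            >>> out = dispatch(hidden_states)  # doctest: +SKIP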

        """
        coordinator = self._coordinator

        def _wrapped(*args, **kwargs):
            first, rest = args[0], args[1:]
            first = platform.differentiable_sync_hook(first, "A", coordinator)
            result = dispatch_fn(first, *rest, **kwargs)
            # Only the first output carries the closing hook; any auxiliary
            # outputs (handles, metadata) pass through unchanged.
            if isinstance(result, tuple):
                hooked = platform.differentiable_sync_hook(result[0], "B", coordinator)
                return (hooked,) + result[1:]
            return platform.differentiable_sync_hook(result, "B", coordinator)

        return _wrapped

    def wrap_combine(self, combine_fn: Callable, is_last_layer: bool = False) -> Callable:
        """Return a wrapped version of ``combine_fn`` bracketed by hooks C/D.

        The returned callable inserts synchronization hooks on the **first
        positional tensor argument** before and after the call::

            C ─► combine_fn ─► D

        Args:
            combine_fn: The original combine callable.
            is_last_layer: If ``True``, the closing D hook is tagged
                ``"D_LAST"`` so the rendezvous is skipped both in forward
                (no Attention follows the last layer) and in backward
                (this is the first BWD hook to fire and combine.bwd has
                already dispatched freely). Tagging this hook statically
                replaces the old runtime cycle counter and BWD-D-skip
                mechanisms.

        Returns:
            A new callable with the same signature.
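
        Example:
            Illustrative sketch; ``ep_combine_fn``, ``expert_output``,
            ``layer_idx`` and ``num_layers`` are placeholder names. Only the
            last MoE layer should pass ``is_last_layer=True`` so its closing
            hook is tagged ``"D_LAST"``.

            >>> overlap = CommComputeOverlap()
            >>> combine = overlap.wrap_combine(
            ...     ep_combine_fn, is_last_layer=(layer_idx == num_layers - 1)
            ... )  # doctest: +SKIP
            >>> out = combine(expert_output)  # doctest: +SKIP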

        """
        coordinator = self._coordinator
        d_hook = "D_LAST" if is_last_layer else "D"

        def _wrapped(*args, **kwargs):
            first, rest = args[0], args[1:]
            first = platform.differentiable_sync_hook(first, "C", coordinator)
            result = combine_fn(first, *rest, **kwargs)
            if isinstance(result, tuple):
                hooked = platform.differentiable_sync_hook(result[0], d_hook, coordinator)
                return (hooked,) + result[1:]
            return platform.differentiable_sync_hook(result, d_hook, coordinator)

        return _wrapped

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    def run(
        self,
        fwd_fn: Callable[[], None],
        bwd_fn: Callable[[], None],
    ) -> None:
        """Run ``fwd_fn`` and ``bwd_fn`` in parallel with comm/compute overlap.

        Enables the coordinator, spawns the backward pass on a daemon thread,
        and waits for both to complete. Layer-boundary handling is encoded
        statically by the ``is_last_layer`` flag passed to
        :meth:`wrap_combine` at wrap time, so no per-call layer count is
        needed here.

        Args:
            fwd_fn: Callable that executes the forward pass.
            bwd_fn: Callable that executes the backward pass. If it needs a
                specific device stream, wrap that logic inside ``bwd_fn``.

        Raises:
            RuntimeError: If the backward thread raised an exception; it is
                chained as the cause and raised on the main thread after
                ``join``. An exception from ``fwd_fn`` is re-raised as-is.
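
        Example:
            Illustrative sketch, following the module-level example;
            ``fwd_stage``, ``bwd_stage``, ``fwd_mb``, ``bwd_mb`` and ``loss``
            are placeholder names for the caller's pipeline stages and
            microbatch/loss arguments.

            >>> overlap.run(
            ...     fwd_fn=lambda: fwd_stage.forward_one_chunk(fwd_mb),
            ...     bwd_fn=lambda: bwd_stage.backward_one_chunk(bwd_mb, loss=loss),
            ... )  # doctest: +SKIP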

        """
        self._coordinator.enable()

        exc_box: list = []
        coordinator = self._coordinator

        def _bwd_target():
            try:
                bwd_fn()
            except Exception as exc:  # pylint: disable=W0718
                exc_box.append(exc)
                # BWD died — disable the coordinator so any FWD rendezvous
                # waiting on a barrier/event unblocks immediately. Without
                # this the FWD thread hangs forever at the very first hook
                # it reaches and the outer ``finally`` never runs.
                coordinator.disable()

        thread = threading.Thread(target=_bwd_target, daemon=True)
        thread.start()

        fwd_exc: list = []
        try:
            fwd_fn()
        except Exception as exc:  # pylint: disable=W0718
            fwd_exc.append(exc)
            # Symmetric: if FWD dies, unblock BWD so it can exit.
            coordinator.disable()
        finally:
            # Idempotent in case either side already disabled.
            coordinator.disable()
            thread.join()

        if exc_box:
            raise RuntimeError(
                "Exception in backward thread during dual-pipe overlap"
            ) from exc_box[0]
        if fwd_exc:
            raise fwd_exc[0]