# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""AsyncContextParallel: overlap projection GEMM with all-to-all communication.
17Supports Pure Ulysses, Hybrid CP modes. Falls back to sync ContextParallel
18when q/k/v_proj not provided or in Pure Colossal AI mode.
20Forward: proj hooks launch async A2A → attn pre-hook waits Q/K/V → attn hook gathers output
21Backward: autograd backward launches async A2A → proj pre-hooks wait before GEMMs
22"""
from functools import partial
from typing import Optional, cast

from hyper_parallel.core.context_parallel.context_parallel import (
    ContextParallel,
    _build_2d_mesh,
    _ensure_1d,
    _gather_seq,
    _gather_head_to_seq,
)
from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.core.dtensor.placement_types import Shard, Replicate
from hyper_parallel.platform import get_platform

platform = get_platform()
Module = platform.Module
Tensor = platform.Tensor


# ---------------------------------------------------------------------------
# All-to-all helpers
# ---------------------------------------------------------------------------

def _launch_async_a2a_seq_to_head(
    tensor: Tensor,
    group,
    world_size: int,
    head_dim: int,
) -> tuple:
    """Launch async seq→head A2A (forward)."""
    x = tensor.contiguous()
    shape = list(x.shape)
    num_heads = shape[head_dim]
    if num_heads % world_size != 0:
        raise ValueError(f"num_heads ({num_heads}) must be divisible by world_size ({world_size}).")
    ndim = len(shape) + 1
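    # Shape walkthrough (a sketch assuming BSHD layout with head_dim=2, the
    # sequence already sharded world_size=w ways):
    # [B, S/w, H, D] --reshape--> [B, S/w, w, H/w, D]
    # --permute--> [w, B, S/w, H/w, D], so chunk i along dim 0 holds the heads
    # destined for rank i, ready for all_to_all_single to exchange.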
    x_perm = x.reshape(
        shape[:head_dim] + [world_size, num_heads // world_size] + shape[head_dim + 1:]
    ).permute(
        [head_dim] + list(range(head_dim)) + list(range(head_dim + 1, ndim))
    ).contiguous()
    out_perm, work = platform.all_to_all_single(x_perm, list(x_perm.shape), group, async_op=True)
    return work, out_perm


def _a2a_reconstruct(out_perm: Tensor, concat_dim: int) -> Tensor:
    """Reconstruct A2A result from raw out_perm."""
    new_ndim = out_perm.dim()
    chunk_in_perm = concat_dim + 1
    recon_perm = list(range(1, chunk_in_perm)) + [0] + list(range(chunk_in_perm, new_ndim))
    x_recon = out_perm.permute(recon_perm).contiguous()
    shape = list(x_recon.shape)
    merged = shape[concat_dim] * shape[concat_dim + 1]
    return x_recon.reshape(shape[:concat_dim] + [merged] + shape[concat_dim + 2:])


# ---------------------------------------------------------------------------
# AsyncContextParallel
# ---------------------------------------------------------------------------

class AsyncContextParallel(ContextParallel):
    """Context Parallel with projection–A2A compute overlap.

    Requires ``q_proj``, ``k_proj``, ``v_proj`` in :meth:`apply`; otherwise
    falls back to synchronous :class:`ContextParallel`.

    Pure Colossal AI (``ulysses_degree=1``) automatically falls back to sync
    because K/V AllGather is a barrier collective.

    Args:
        seq_dim: Sequence dimension (1=BSHD, 2=BNSD).
        head_dim: Head dimension (2=BSHD, 1=BNSD).
        ulysses_degree: Ulysses sub-mesh size (see :class:`ContextParallel`).
        qkv_indices: Positional indices of (Q, K, V) in the attention forward.
        qkv_kwarg_names: Keyword names for (Q, K, V).
        load_balance: Load-balance flag forwarded to the base class.
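
    Example (an illustrative sketch; ``block`` and ``cp_mesh`` are hypothetical)::

        cp = AsyncContextParallel(seq_dim=1, head_dim=2, ulysses_degree=4)
        cp.apply(
            block.core_attention, cp_mesh,
            q_proj=block.q_proj, k_proj=block.k_proj, v_proj=block.v_proj,
        )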
100 """

    def __init__(
        self,
        seq_dim: int = 1,
        head_dim: int = 2,
        ulysses_degree: Optional[int] = None,
        qkv_indices: tuple = (0, 1, 2),
        qkv_kwarg_names: tuple = (),
        load_balance: bool = False,
    ):
        super().__init__(
            seq_dim=seq_dim,
            head_dim=head_dim,
            ulysses_degree=ulysses_degree,
            qkv_indices=qkv_indices,
            qkv_kwarg_names=qkv_kwarg_names,
            load_balance=load_balance,
        )

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def apply(  # pylint: disable=arguments-differ
        self,
        module: Module,
        device_mesh: DeviceMesh,
        q_proj: Optional[Module] = None,
        k_proj: Optional[Module] = None,
        v_proj: Optional[Module] = None,
    ) -> Module:
132 """Register async-overlap hooks and return *module*.
134 Falls back to synchronous :class:`ContextParallel` if any of
135 ``q/k/v_proj`` is ``None`` or in Pure Colossal AI mode.
137 Args:
138 module: Core-attention submodule.
139 device_mesh: CP device mesh (1-D or 2-D).
140 q_proj: The last module in the Q path whose output is passed
141 directly to the attention module as Q. Its forward
142 post-hook launches the async Q all-to-all. There
143 must be **no** intermediate ops (view, transpose, …)
144 between this module and attention; such ops would be
145 bypassed by the pre-hook substitution and could cause
146 shape mismatches. For models with QK normalization
147 applied right before attention, pass ``qk_norm_q``
148 here instead of the raw projection.
149 k_proj: Same semantics as ``q_proj``, for the K path. Pass
150 ``qk_norm_k`` when the model applies QK-Norm before
151 attention.
152 v_proj: Value projection module (no norm variant needed).
153 """
        if q_proj is None or k_proj is None or v_proj is None:
            return super().apply(module, device_mesh)

        cp_size = device_mesh.mesh.numel()
        ds = self.ulysses_degree if self.ulysses_degree is not None else cp_size
        if cp_size % ds != 0:
            raise ValueError(
                f"cp_size ({cp_size}) must be divisible by ulysses_degree ({ds})."
            )
        co = cp_size // ds
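        # E.g. cp_size=8 with ulysses_degree=4 gives ds=4 ranks per Ulysses
        # A2A group and co=2 Colossal AI groups.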

        if ds == 1:
            # Pure Colossal AI: K/V AllGather cannot be made async. Fall back.
            return super().apply(module, device_mesh)

        # Per-layer handle slots — local to this apply() call, bound to hooks via partial.
        #
        # fwd_slots is a plain dict. _proj_post_hook and _wait_a2a both receive the
        # same dict reference via partial, so a simple assignment fwd_slots[key] = ...
        # in _proj_post_hook is immediately visible to _wait_a2a — no list wrapper needed.
        #
        # bwd_slots[key] is a list held by both _wait_a2a and the autograd wait function.
        # The autograd function receives the list object itself (as handle_box) and appends
        # to it; _proj_bwd_pre_hook pops from the same list. We cannot use a plain dict
        # value here because the autograd function would hold a stale reference if we later
        # reassigned bwd_slots[key].
        fwd_slots = {"q": None, "k": None, "v": None}
        bwd_slots = {"q": [], "k": [], "v": []}

        if co == 1:
            # Pure Ulysses
            ds_submesh = _ensure_1d(device_mesh)
            group = ds_submesh.get_group()
            self._register_proj_hooks(q_proj, k_proj, v_proj, group=group, world_size=ds,
                                      fwd_slots=fwd_slots, bwd_slots=bwd_slots)
            module.register_forward_pre_hook(
                partial(self._attn_pre_hook_ulysses, group=group, world_size=ds,
                        fwd_slots=fwd_slots, bwd_slots=bwd_slots)
            )
        else:
            # Hybrid: async Ulysses A2A + sync Colossal AllGather
            two_d_mesh = _build_2d_mesh(device_mesh, ds, co)
            dim_names = two_d_mesh.mesh_dim_names
            assert dim_names is not None, "2-D mesh must have mesh_dim_names (guaranteed by _build_2d_mesh)"
            ds_submesh = two_d_mesh[dim_names[1]]
            group = ds_submesh.get_group()
            self._register_proj_hooks(q_proj, k_proj, v_proj, group=group, world_size=ds,
                                      fwd_slots=fwd_slots, bwd_slots=bwd_slots)
            module.register_forward_pre_hook(
                partial(self._attn_pre_hook_hybrid, group=group, world_size=ds,
                        two_d_mesh=two_d_mesh, fwd_slots=fwd_slots, bwd_slots=bwd_slots)
            )

        module.register_forward_hook(
            partial(self._attn_post_hook_ata, ds_submesh=ds_submesh)
        )
        return module

    # ------------------------------------------------------------------
    # Shared: projection hooks registration
    # ------------------------------------------------------------------

    def _register_proj_hooks(self, q_proj, k_proj, v_proj, group, world_size, fwd_slots, bwd_slots):
        """Register forward and backward hooks on all three projection modules."""
        for key, proj in [("q", q_proj), ("k", k_proj), ("v", v_proj)]:
            proj.register_forward_hook(
                partial(self._proj_post_hook, key=key, group=group, world_size=world_size,
                        fwd_slots=fwd_slots)
            )
            platform.register_full_backward_pre_hook(
                proj,
                partial(self._proj_bwd_pre_hook, bwd_slot=bwd_slots[key])
            )

    def _proj_post_hook(self, module, inputs, output, key, group, world_size, fwd_slots):  # pylint: disable=unused-argument,too-many-arguments
        """Launch async seq→head A2A after projection; return original output unchanged."""
        tensor = output.to_local() if isinstance(output, DTensor) else output
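        # Launch immediately; the handle is consumed by _wait_a2a in the
        # attention pre-hook, so this A2A overlaps with the remaining
        # projection GEMMs that run before attention.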
        fwd_slots[key] = _launch_async_a2a_seq_to_head(
            tensor, group, world_size, self.head_dim
        )
        return output

    # ------------------------------------------------------------------
    # Internal: wait for a single pre-launched A2A handle
    # ------------------------------------------------------------------

    def _wait_a2a(self, tensor, group, world_size, fwd_slots, key, bwd_slot):
        """Wait for pre-launched A2A; returns head-scattered tensor (differentiable)."""
        work, out_perm = fwd_slots[key]
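        # Consume the slot so a stale handle can never be waited on twice.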
        fwd_slots[key] = None
        return platform.differentiable_async_a2a_wait(
            tensor, work, out_perm, group, world_size,
            self.seq_dim, self.head_dim,  # concat_dim=seq_dim, split_dim=head_dim
            bwd_slot,
        )

    # ------------------------------------------------------------------
    # Attention pre-hooks
    # ------------------------------------------------------------------

    def _attn_pre_hook_ulysses(self, module, args, group, world_size,  # pylint: disable=unused-argument,too-many-arguments
                               fwd_slots, bwd_slots):
        """Wait Q/K/V A2A; return head-scattered args."""
        q_idx, k_idx, v_idx = self.qkv_indices
        new_args = list(args)

        def _local(t):
            return t.to_local() if isinstance(t, DTensor) else t

        new_args[q_idx] = self._wait_a2a(_local(new_args[q_idx]), group, world_size,
                                         fwd_slots, "q", bwd_slots["q"])
        new_args[k_idx] = self._wait_a2a(_local(new_args[k_idx]), group, world_size,
                                         fwd_slots, "k", bwd_slots["k"])
        new_args[v_idx] = self._wait_a2a(_local(new_args[v_idx]), group, world_size,
                                         fwd_slots, "v", bwd_slots["v"])
        return tuple(new_args)

    def _attn_pre_hook_hybrid(  # pylint: disable=too-many-locals,too-many-arguments
            self, module, args, group, world_size, two_d_mesh,  # pylint: disable=unused-argument
            fwd_slots, bwd_slots
    ):
        """Wait Ulysses A2A for Q/K/V, AllGather K/V on co-submesh, wrap as 2-D DTensors."""
        q_idx, k_idx, v_idx = self.qkv_indices
        new_args = list(args)

        def _local(t):
            return t.to_local() if isinstance(t, DTensor) else t

        # Wait Ulysses A2A for Q and K
        q_ul = cast(Tensor, self._wait_a2a(_local(new_args[q_idx]), group, world_size,
                                           fwd_slots, "q", bwd_slots["q"]))
        k_ul = cast(Tensor, self._wait_a2a(_local(new_args[k_idx]), group, world_size,
                                           fwd_slots, "k", bwd_slots["k"]))

        # AllGather K on co-submesh (while V A2A is still in flight)
        co_submesh = two_d_mesh[two_d_mesh.mesh_dim_names[0]]
        k_full = _gather_seq(k_ul, co_submesh, self.seq_dim)

        # Wait V A2A, then AllGather V
        v_ul = cast(Tensor, self._wait_a2a(_local(new_args[v_idx]), group, world_size,
                                           fwd_slots, "v", bwd_slots["v"]))
        v_full = _gather_seq(v_ul, co_submesh, self.seq_dim)

        def _local_dt(dt):
            return dt.to_local() if isinstance(dt, DTensor) else dt
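
        # Wrap for core attention on the 2-D (co, ds) mesh: Q is
        # (Shard(seq_dim), Shard(head_dim)); the gathered K/V are
        # (Replicate(), Shard(head_dim)) since they are now full along co.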
        new_args[q_idx] = DTensor.from_local(
            q_ul, two_d_mesh, (Shard(self.seq_dim), Shard(self.head_dim))
        )
        new_args[k_idx] = DTensor.from_local(
            _local_dt(k_full), two_d_mesh, (Replicate(), Shard(self.head_dim))
        )
        new_args[v_idx] = DTensor.from_local(
            _local_dt(v_full), two_d_mesh, (Replicate(), Shard(self.head_dim))
        )
        return tuple(new_args)

    # ------------------------------------------------------------------
    # Attention post-hook (Ulysses and Hybrid share the same reverse ATA)
    # ------------------------------------------------------------------

    def _attn_post_hook_ata(self, module, args, output, ds_submesh):  # pylint: disable=unused-argument
        """Reverse head→seq gather on ds_submesh; returns local tensor."""
        def _process(o):
            if isinstance(o, (Tensor, DTensor)):
                if isinstance(o, DTensor):
                    o = o.to_local()
                return _gather_head_to_seq(
                    o, ds_submesh, self.seq_dim, self.head_dim
                ).to_local()
            return o
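
        # Preserve the output container type; attention modules may return
        # auxiliary values (e.g. attention weights) alongside the main tensor.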
        if isinstance(output, (tuple, list)):
            return type(output)(_process(o) for o in output)
        return _process(output)

    # ------------------------------------------------------------------
    # Backward: wait A2A handle (launched by autograd) before proj GEMM
    # ------------------------------------------------------------------

    def _proj_bwd_pre_hook(self, module, grad_output, bwd_slot):  # pylint: disable=unused-argument
        """Wait backward A2A just before proj GEMM; replace grad with seq-form.

        The async head→seq A2A is launched inside _TorchAsyncA2AFunction.backward
        and appended to ``bwd_slot``. Waiting here lets the A2A overlap with the
        preceding proj GEMM.
        """
        work, out_perm = bwd_slot.pop()
        work.wait()
        d_seq = _a2a_reconstruct(out_perm, self.head_dim)
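        # Swap the seq-form grad in for the first element; pass any remaining
        # grad_output entries through unchanged.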
        return (d_seq,) + grad_output[1:] if isinstance(grad_output, tuple) else (d_seq,)