Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / context_parallel / context_parallel.py: 16%
167 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Unified Context Parallel: Pure Ulysses, Pure Colossal AI, and Hybrid CP."""
16from functools import partial
17from typing import Optional
19from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
20from hyper_parallel.core.dtensor.dtensor import DTensor
21from hyper_parallel.core.tensor_parallel.style import ParallelStyle
22from hyper_parallel.core.dtensor.placement_types import Shard, Replicate
23from hyper_parallel.platform import get_platform
# Resolve the active hardware/framework backend once at import time.
# `Module` and `Tensor` are backend-provided base classes used throughout
# this file for isinstance checks and hook registration (exact types depend
# on the platform returned by get_platform()).
platform = get_platform()
Module = platform.Module
Tensor = platform.Tensor
30# ---------------------------------------------------------------------------
31# Low-level communication primitives
32# ---------------------------------------------------------------------------
def _ensure_1d(device_mesh: DeviceMesh) -> DeviceMesh:
    """Flatten *device_mesh* to a 1-D mesh named ``("cp",)``; pass 1-D through."""
    if device_mesh.ndim != 1:
        return DeviceMesh(
            device_mesh.device_type,
            list(device_mesh.rank_list),
            mesh_dim_names=("cp",),
        )
    return device_mesh
def _build_2d_mesh(device_mesh: DeviceMesh, ds: int, co: int) -> DeviceMesh:
    """Build or validate a 2-D ``(co × ds)`` DeviceMesh for Hybrid CP.

    An already 2-D *device_mesh* is returned unchanged, provided it carries
    ``mesh_dim_names``. A 1-D mesh is reshaped by tiling its rank list into
    *co* rows of *ds* consecutive ranks each.
    """
    if device_mesh.ndim == 2:
        if device_mesh.mesh_dim_names:
            return device_mesh
        raise ValueError(
            "2-D device_mesh for Hybrid CP must have mesh_dim_names=(\"co\", \"ds\")."
        )
    flat_ranks = list(device_mesh.rank_list)
    rows = []
    for row in range(co):
        start = row * ds
        rows.append(flat_ranks[start:start + ds])
    return DeviceMesh(device_mesh.device_type, rows, mesh_dim_names=("co", "ds"))
def _scatter_seq_to_head(
    tensor: Tensor,
    submesh: DeviceMesh,
    seq_dim: int,
    head_dim: int,
    submesh_size: int,
) -> "DTensor":
    """All-to-all ``Shard(seq_dim)`` → ``Shard(head_dim)``; result is a DTensor.

    A DTensor input is redistributed directly; a plain tensor is first
    validated (head count must tile evenly across *submesh_size* ranks) and
    wrapped as seq-sharded before the redistribute.
    """
    if isinstance(tensor, DTensor):
        return tensor.redistribute(submesh, (Shard(head_dim),))
    num_heads = tensor.shape[head_dim]
    if num_heads % submesh_size:
        raise ValueError(
            f"num_heads ({num_heads}) must be divisible by "
            f"ulysses_degree ({submesh_size})."
        )
    seq_sharded = DTensor.from_local(tensor, submesh, (Shard(seq_dim),))
    return seq_sharded.redistribute(submesh, (Shard(head_dim),))
def _gather_head_to_seq(
    tensor: Tensor,
    submesh: DeviceMesh,
    seq_dim: int,
    head_dim: int,
) -> "DTensor":
    """Reverse all-to-all ``Shard(head_dim)`` → ``Shard(seq_dim)``; result is a DTensor."""
    if not isinstance(tensor, DTensor):
        tensor = DTensor.from_local(tensor, submesh, (Shard(head_dim),))
    return tensor.redistribute(submesh, (Shard(seq_dim),))
def _gather_seq(
    tensor: Tensor,
    submesh: DeviceMesh,
    seq_dim: int,
) -> "DTensor":
    """All-gather ``Shard(seq_dim)`` → ``Replicate``; result is a DTensor."""
    if not isinstance(tensor, DTensor):
        tensor = DTensor.from_local(tensor, submesh, (Shard(seq_dim),))
    return tensor.redistribute(submesh, (Replicate(),))
112# ---------------------------------------------------------------------------
113# Unified ContextParallel
114# ---------------------------------------------------------------------------
class ContextParallel(ParallelStyle):
    """Unified Context Parallel for core-attention modules.

    Three modes controlled by ``ulysses_degree``:

    +-----------------+--------------------+------------------------------------------+
    | Mode            | ``ulysses_degree`` | Mechanism                                |
    +=================+====================+==========================================+
    | Pure Ulysses    | ``None`` (default) | seq→head A2A before attn;                |
    |                 | (= cp_size)        | head→seq A2A after.                      |
    |                 |                    | Requires ``num_heads % cp_size == 0``.   |
    +-----------------+--------------------+------------------------------------------+
    | Pure Colossal AI| ``1``              | Q stays as local Shard(seq);             |
    |                 |                    | K/V all-gathered (Replicate).            |
    |                 |                    | No head-count constraint.                |
    +-----------------+--------------------+------------------------------------------+
    | Hybrid          | ``1 < k < cp_size``| Q/K/V seq→head A2A on Ulysses sub-mesh   |
    |                 |                    | (size ``k``); K/V then all-gathered on   |
    |                 |                    | Colossal sub-mesh (size ``cp_size // k``)|
    |                 |                    | Requires ``num_heads % k == 0``.         |
    +-----------------+--------------------+------------------------------------------+

    Args:
        seq_dim: Sequence dimension index. 1 for BSHD, 2 for BNSD.
        head_dim: Head dimension index. 2 for BSHD, 1 for BNSD.
        ulysses_degree: Ulysses sub-mesh size (see table above).
        qkv_indices: Positional-argument indices for (Q, K, V).
        qkv_kwarg_names: Keyword-argument names for (Q, K, V).
        load_balance: Enable Head-Tail Q-exchange load balancing.
            Only valid with Pure Colossal AI (``ulysses_degree=1``).

            **Important**: When ``load_balance=True``, ``q.shape[seq_dim]``
            inside ``forward()`` returns ``S / 2`` (global shape / 2)
            rather than the true global ``S``. This is because
            ``DTensor.shape`` returns ``local_tensor_size * mesh_size``,
            and each sub-FA call wraps a half-sized Q shard
            (``S / (2 * cp_size)`` tokens) with a ``co_submesh`` of
            size ``cp_size``, giving a DTensor global shape of
            ``S / (2 * cp_size) * cp_size = S / 2``.
            K/V are always Replicate so ``k.shape[seq_dim]`` always
            returns the true ``S``. **When building the attention mask,
            use ``k.shape[seq_dim]`` (not ``q.shape[seq_dim]``) to
            obtain the correct global sequence length.**
    """

    def __init__(
        self,
        seq_dim: int = 1,
        head_dim: int = 2,
        ulysses_degree: Optional[int] = None,
        qkv_indices: tuple = (0, 1, 2),
        qkv_kwarg_names: tuple = (),
        load_balance: bool = False,
    ):
        # Head-Tail load balancing is only implemented for the Pure
        # Colossal AI path (ulysses_degree == 1); reject other combos early.
        if load_balance and ulysses_degree != 1:
            raise ValueError(
                "load_balance=True requires ulysses_degree=1 (Pure Colossal AI mode)."
            )
        self.seq_dim = seq_dim
        self.head_dim = head_dim
        self.ulysses_degree = ulysses_degree
        self.qkv_indices = qkv_indices
        self.qkv_kwarg_names = qkv_kwarg_names
        self.load_balance = load_balance

    # ------------------------------------------------------------------
    # ParallelStyle interface
    # ------------------------------------------------------------------

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Register forward hooks on *module* and return it.

        Args:
            module: attention submodule to parallelise.
            device_mesh: CP device mesh (1-D or 2-D).
        """
        cp_size = device_mesh.mesh.numel()
        # Default (ulysses_degree=None) means Pure Ulysses over the whole mesh.
        ds = self.ulysses_degree if self.ulysses_degree is not None else cp_size
        if cp_size % ds != 0:
            raise ValueError(
                f"cp_size ({cp_size}) must be divisible by ulysses_degree ({ds})."
            )
        co = cp_size // ds

        if ds == 1:
            # Pure Colossal AI
            co_submesh = _ensure_1d(device_mesh)
            if self.load_balance:
                self._apply_lb_colossal(module, co_submesh)
            else:
                module.register_forward_pre_hook(
                    partial(self._pre_hook_colossal, co_submesh=co_submesh),
                    with_kwargs=True,
                )
                module.register_forward_hook(
                    partial(self._post_hook_colossal, co_submesh=co_submesh)
                )
        elif co == 1:
            # Pure Ulysses
            ds_submesh = _ensure_1d(device_mesh)
            module.register_forward_pre_hook(
                partial(self._pre_hook_ulysses, ds_submesh=ds_submesh, ds_size=ds),
                with_kwargs=True,
            )
            module.register_forward_hook(
                partial(self._post_hook_ata, ds_submesh=ds_submesh)
            )
        else:
            # Hybrid
            two_d_mesh = _build_2d_mesh(device_mesh, ds, co)
            dim_names = two_d_mesh.mesh_dim_names
            assert dim_names is not None, "2-D mesh must have mesh_dim_names (guaranteed by _build_2d_mesh)"
            ds_submesh = two_d_mesh[dim_names[1]]
            module.register_forward_pre_hook(
                partial(
                    self._pre_hook_hybrid,
                    two_d_mesh=two_d_mesh,
                    ds_submesh=ds_submesh,
                    ds_size=ds,
                ),
                with_kwargs=True,
            )
            module.register_forward_hook(
                partial(self._post_hook_ata, ds_submesh=ds_submesh)
            )

        return module

    # ------------------------------------------------------------------
    # Pre-hooks
    # ------------------------------------------------------------------

    def _pre_hook_colossal(self, module, args, kwargs, co_submesh):  # pylint: disable=unused-argument
        """Wrap Q as ``DTensor(co_submesh, Shard(seq))``; all-gather K/V."""
        new_args = list(args)
        new_kwargs = dict(kwargs)

        q_idx = self.qkv_indices[0]
        # Q keeps its local seq shard; only wrap plain tensors (a DTensor
        # passed by the caller is assumed already placed correctly).
        if q_idx < len(new_args) and isinstance(new_args[q_idx], Tensor) \
                and not isinstance(new_args[q_idx], DTensor):
            new_args[q_idx] = DTensor.from_local(
                new_args[q_idx], co_submesh, (Shard(self.seq_dim),)
            )
        # K/V become fully replicated along the sequence (all-gather).
        for idx in self.qkv_indices[1:]:
            if idx < len(new_args) and isinstance(new_args[idx], Tensor):
                new_args[idx] = _gather_seq(new_args[idx], co_submesh, self.seq_dim)

        if self.qkv_kwarg_names:
            q_name = self.qkv_kwarg_names[0]
            if q_name in new_kwargs and isinstance(new_kwargs[q_name], Tensor) \
                    and not isinstance(new_kwargs[q_name], DTensor):
                new_kwargs[q_name] = DTensor.from_local(
                    new_kwargs[q_name], co_submesh, (Shard(self.seq_dim),)
                )
            for name in self.qkv_kwarg_names[1:]:
                if name in new_kwargs and isinstance(new_kwargs[name], Tensor):
                    new_kwargs[name] = _gather_seq(new_kwargs[name], co_submesh, self.seq_dim)

        return tuple(new_args), new_kwargs

    def _pre_hook_ulysses(self, module, args, kwargs, ds_submesh, ds_size):  # pylint: disable=unused-argument
        """Seq→head all-to-all for Q, K, and V."""
        new_args = list(args)
        for idx in self.qkv_indices:
            if idx < len(new_args) and isinstance(new_args[idx], Tensor):
                new_args[idx] = _scatter_seq_to_head(
                    new_args[idx], ds_submesh, self.seq_dim, self.head_dim, ds_size
                )

        new_kwargs = dict(kwargs)
        for name in self.qkv_kwarg_names:
            if name in new_kwargs and isinstance(new_kwargs[name], Tensor):
                new_kwargs[name] = _scatter_seq_to_head(
                    new_kwargs[name], ds_submesh, self.seq_dim, self.head_dim, ds_size
                )

        return tuple(new_args), new_kwargs

    def _ata_scatter_to_2d(self, t, ds_submesh, two_d_mesh, ds_size):
        """ATA scatter: Shard(seq)→Shard(head) on ds_submesh; wrap as 2-D DTensor.

        Args:
            t: Plain local tensor to scatter.
            ds_submesh: 1-D Ulysses sub-mesh.
            two_d_mesh: 2-D mesh (co × ds).
            ds_size: Ulysses degree (world size on ds_submesh).

        Returns:
            DTensor with placements ``(Shard(seq_dim), Shard(head_dim))`` on two_d_mesh.
        """
        if t.shape[self.head_dim] % ds_size != 0:
            raise ValueError(
                f"num_heads ({t.shape[self.head_dim]}) must be divisible by "
                f"ulysses_degree ({ds_size})."
            )
        local = (
            DTensor.from_local(t, ds_submesh, (Shard(self.seq_dim),))
            .redistribute(ds_submesh, (Shard(self.head_dim),))
            .to_local()
        )
        return DTensor.from_local(local, two_d_mesh, (Shard(self.seq_dim), Shard(self.head_dim)))

    def _pre_hook_hybrid(self, module, args, kwargs, two_d_mesh, ds_submesh, ds_size):  # pylint: disable=unused-argument
        """Hybrid: seq→head ATA on ds-submesh, then all-gather K/V on co-submesh.

        After this hook, placements on ``two_d_mesh`` are:
            Q   → ``(Shard(seq_dim), Shard(head_dim))``
            K/V → ``(Replicate(), Shard(head_dim))``
        """
        new_args = list(args)

        # Step 1: ATA on ds_submesh for all of Q/K/V; wrap as 2-D DTensor
        for idx in self.qkv_indices:
            if idx < len(new_args) and isinstance(new_args[idx], Tensor) \
                    and not isinstance(new_args[idx], DTensor):
                new_args[idx] = self._ata_scatter_to_2d(
                    new_args[idx], ds_submesh, two_d_mesh, ds_size
                )

        # Step 2: all-gather K/V on co-dim (Shard(seq)→Replicate)
        for idx in self.qkv_indices[1:]:
            if idx < len(new_args) and isinstance(new_args[idx], DTensor):
                new_args[idx] = new_args[idx].redistribute(
                    two_d_mesh, (Replicate(), Shard(self.head_dim))
                )

        # Same for kwargs.
        # Fix: route through _ata_scatter_to_2d instead of re-inlining its
        # body — the previous inline copy skipped the num_heads divisibility
        # check, so an invalid head count raised an opaque runtime error for
        # kwarg-passed Q/K/V while positional Q/K/V got a clear ValueError.
        new_kwargs = dict(kwargs)
        for name in self.qkv_kwarg_names:
            if name in new_kwargs and isinstance(new_kwargs[name], Tensor) \
                    and not isinstance(new_kwargs[name], DTensor):
                new_kwargs[name] = self._ata_scatter_to_2d(
                    new_kwargs[name], ds_submesh, two_d_mesh, ds_size
                )
        for name in self.qkv_kwarg_names[1:]:
            if name in new_kwargs and isinstance(new_kwargs[name], DTensor):
                new_kwargs[name] = new_kwargs[name].redistribute(
                    two_d_mesh, (Replicate(), Shard(self.head_dim))
                )

        return tuple(new_args), new_kwargs

    # ------------------------------------------------------------------
    # Post-hooks
    # ------------------------------------------------------------------

    def _post_hook_ata(self, module, inputs, outputs, ds_submesh):  # pylint: disable=unused-argument
        """Reverse all-to-all: head→seq on ds-submesh; returns local tensor.

        Handles both Ulysses (1-D DTensor or plain tensor) and Hybrid
        (2-D DTensor — ``to_local()`` first to project onto the 1-D ds-submesh).
        """
        def _process(out):
            if isinstance(out, (Tensor, DTensor)):
                if isinstance(out, DTensor):
                    out = out.to_local()
                return _gather_head_to_seq(
                    out, ds_submesh, self.seq_dim, self.head_dim
                ).to_local()
            return out

        if isinstance(outputs, (tuple, list)):
            return type(outputs)(_process(o) for o in outputs)
        return _process(outputs)

    def _post_hook_colossal(self, module, inputs, outputs, co_submesh):  # pylint: disable=unused-argument
        """Colossal AI: convert any DTensor output to a local tensor."""
        def _process(out):
            return out.to_local() if isinstance(out, DTensor) else out

        if isinstance(outputs, (tuple, list)):
            return type(outputs)(_process(o) for o in outputs)
        return _process(outputs)

    # ------------------------------------------------------------------
    # Load-balance Colossal AI (Head-Tail Q-exchange)
    # ------------------------------------------------------------------

    def _apply_lb_colossal(self, module: Module, co_submesh: DeviceMesh) -> None:
        """Replace ``module.forward`` with the load-balanced two-sub-FA wrapper."""
        ws = co_submesh.mesh.numel()
        rank_list = list(co_submesh.rank_list)
        local_idx = rank_list.index(platform.get_rank())
        # Head-Tail pairing: rank i exchanges with rank (ws - 1 - i).
        target_idx = ws - 1 - local_idx
        module.forward = partial(
            self._lb_colossal_forward,
            original_forward=module.forward,
            co_submesh=co_submesh,
            local_idx=local_idx,
            target_idx=target_idx,
            ws=ws,
            peer_rank=rank_list[target_idx],
        )

    def _lb_colossal_forward(  # pylint: disable=too-many-arguments,too-many-locals
        self,
        *args,
        original_forward,
        co_submesh: DeviceMesh,
        local_idx: int,
        target_idx: int,
        ws: int,
        peer_rank: int,
        **kwargs,
    ):
        """Head-Tail load-balanced forward for Pure Colossal AI CP.

        Splits local Q (shape ``[B, S/ws, H, D]``) into head/tail halves.
        The tail is P2P-exchanged with the paired rank ``(ws - 1 - local_idx)``.
        Two sub-FA calls are issued with adjusted causal-mask offsets:

        - FA1: ``q_keep`` at ``split_id = 2*local_idx``
        - FA2: ``q_peer`` at ``split_id = 2*target_idx + 1``

        FA2's output is exchanged back; final output = ``cat([FA1, FA2_recv])``.
        """
        from hyper_parallel.core.shard.ops.parallel_npu_flash_attention_score import (  # pylint: disable=import-outside-toplevel
            _set_lb_override, _clear_lb_override,
        )

        seq_dim = self.seq_dim
        q_idx, k_idx, v_idx = self.qkv_indices
        new_args = list(args)

        # Assumes the local Q sequence length is even — TODO confirm upstream
        # sharding guarantees this (S divisible by 2 * ws).
        q = new_args[q_idx]
        half = q.shape[seq_dim] // 2
        q_keep = q.narrow(seq_dim, 0, half)
        q_mine = q.narrow(seq_dim, half, half)

        q_peer = platform.p2p_exchange(q_mine, peer_rank)
        k_full = _gather_seq(new_args[k_idx], co_submesh, seq_dim).to_local()
        v_full = _gather_seq(new_args[v_idx], co_submesh, seq_dim).to_local()

        # K/V are Replicate; wrap once and reuse for both FA calls
        k_full_dt = DTensor.from_local(k_full, co_submesh, (Replicate(),))
        v_full_dt = DTensor.from_local(v_full, co_submesh, (Replicate(),))

        def _fa(q_half, split_id):
            new_args[q_idx] = DTensor.from_local(q_half, co_submesh, (Shard(seq_dim),))
            new_args[k_idx] = k_full_dt
            new_args[v_idx] = v_full_dt
            _set_lb_override(split_id=split_id, split_num=2 * ws)
            out = original_forward(*new_args, **kwargs)
            _clear_lb_override()
            return out.to_local() if isinstance(out, DTensor) else out

        fa1_out = _fa(q_keep, split_id=2 * local_idx)
        fa2_out = _fa(q_peer, split_id=2 * target_idx + 1)
        fa2_our = platform.p2p_exchange(fa2_out, peer_rank)
        return platform.cat([fa1_out, fa2_our], dim=seq_dim)