# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unified Context Parallel: Pure Ulysses, Pure Colossal AI, and Hybrid CP."""

from functools import partial
from typing import Optional

from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.core.tensor_parallel.style import ParallelStyle
from hyper_parallel.core.dtensor.placement_types import Shard, Replicate
from hyper_parallel.platform import get_platform

platform = get_platform()
Module = platform.Module
Tensor = platform.Tensor


# ---------------------------------------------------------------------------
# Low-level communication primitives
# ---------------------------------------------------------------------------

def _ensure_1d(device_mesh: DeviceMesh) -> DeviceMesh:
    """Return a 1-D DeviceMesh (flatten if multi-dimensional)."""
    if device_mesh.ndim == 1:
        return device_mesh
    ranks = list(device_mesh.rank_list)
    return DeviceMesh(device_mesh.device_type, ranks, mesh_dim_names=("cp",))

def _build_2d_mesh(device_mesh: DeviceMesh, ds: int, co: int) -> DeviceMesh:
    """Build or validate a 2-D ``(co × ds)`` DeviceMesh for Hybrid CP.

    If *device_mesh* is already 2-D it is returned as-is (it must have
    ``mesh_dim_names`` set). Otherwise the ranks of the 1-D mesh are tiled
    into *co* rows of *ds* adjacent ranks each.
    """
    if device_mesh.ndim == 2:
        if not device_mesh.mesh_dim_names:
            raise ValueError(
                "2-D device_mesh for Hybrid CP must have mesh_dim_names=(\"co\", \"ds\")."
            )
        return device_mesh
    ranks = list(device_mesh.rank_list)
    return DeviceMesh(
        device_mesh.device_type,
        [ranks[i * ds:(i + 1) * ds] for i in range(co)],
        mesh_dim_names=("co", "ds"),
    )
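
# Tiling example for _build_2d_mesh (illustrative values): with cp_size=8,
# ds=4, co=2, the 1-D rank list [0, 1, 2, 3, 4, 5, 6, 7] tiles into
#     [[0, 1, 2, 3],
#      [4, 5, 6, 7]]       # mesh_dim_names=("co", "ds")
# so each row is a Ulysses ("ds") group of adjacent ranks, and the row
# index is the Colossal ("co") coordinate.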

def _scatter_seq_to_head(
    tensor: Tensor,
    submesh: DeviceMesh,
    seq_dim: int,
    head_dim: int,
    submesh_size: int,
) -> "DTensor":
    """All-to-all: ``Shard(seq_dim)`` → ``Shard(head_dim)``. Returns DTensor."""
    if isinstance(tensor, DTensor):
        return tensor.redistribute(submesh, (Shard(head_dim),))
    if tensor.shape[head_dim] % submesh_size != 0:
        raise ValueError(
            f"num_heads ({tensor.shape[head_dim]}) must be divisible by "
            f"ulysses_degree ({submesh_size})."
        )
    return DTensor.from_local(tensor, submesh, (Shard(seq_dim),)).redistribute(
        submesh, (Shard(head_dim),)
    )
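
# Shape sketch for _scatter_seq_to_head (illustrative; BSHD layout with
# seq_dim=1, head_dim=2, submesh_size=4, and the submesh spanning the full
# CP group): a local Shard(seq) input of shape [B, S/4, H, D] becomes a
# local [B, S, H/4, D] shard after the all-to-all, i.e. each rank now holds
# the full sequence for a quarter of the heads.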

def _gather_head_to_seq(
    tensor: Tensor,
    submesh: DeviceMesh,
    seq_dim: int,
    head_dim: int,
) -> "DTensor":
    """Reverse all-to-all: ``Shard(head_dim)`` → ``Shard(seq_dim)``. Returns DTensor."""
    if isinstance(tensor, DTensor):
        return tensor.redistribute(submesh, (Shard(seq_dim),))
    return DTensor.from_local(tensor, submesh, (Shard(head_dim),)).redistribute(
        submesh, (Shard(seq_dim),)
    )

def _gather_seq(
    tensor: Tensor,
    submesh: DeviceMesh,
    seq_dim: int,
) -> "DTensor":
    """All-gather: ``Shard(seq_dim)`` → ``Replicate``. Returns DTensor."""
    if isinstance(tensor, DTensor):
        return tensor.redistribute(submesh, (Replicate(),))
    return DTensor.from_local(tensor, submesh, (Shard(seq_dim),)).redistribute(
        submesh, (Replicate(),)
    )


# ---------------------------------------------------------------------------
# Unified ContextParallel
# ---------------------------------------------------------------------------

class ContextParallel(ParallelStyle):
    """Unified Context Parallel for core-attention modules.

    Three modes, controlled by ``ulysses_degree``:

    +-----------------+--------------------+------------------------------------------+
    | Mode            | ``ulysses_degree`` | Mechanism                                |
    +=================+====================+==========================================+
    | Pure Ulysses    | ``None`` (default) | seq→head A2A before attn;                |
    |                 | (= cp_size)        | head→seq A2A after.                      |
    |                 |                    | Requires ``num_heads % cp_size == 0``.   |
    +-----------------+--------------------+------------------------------------------+
    | Pure Colossal AI| ``1``              | Q stays as local Shard(seq);             |
    |                 |                    | K/V all-gathered (Replicate).            |
    |                 |                    | No head-count constraint.                |
    +-----------------+--------------------+------------------------------------------+
    | Hybrid          | ``1 < k < cp_size``| Q/K/V seq→head A2A on Ulysses sub-mesh   |
    |                 |                    | (size ``k``); K/V then all-gathered on   |
    |                 |                    | Colossal sub-mesh (size ``cp_size // k``)|
    |                 |                    | Requires ``num_heads % k == 0``.         |
    +-----------------+--------------------+------------------------------------------+

    Args:
        seq_dim: Sequence dimension index. 1 for BSHD, 2 for BNSD.
        head_dim: Head dimension index. 2 for BSHD, 1 for BNSD.
        ulysses_degree: Ulysses sub-mesh size (see the table above).
        qkv_indices: Positional-argument indices for (Q, K, V).
        qkv_kwarg_names: Keyword-argument names for (Q, K, V).
        load_balance: Enable Head-Tail Q-exchange load balancing.
            Only valid with Pure Colossal AI (``ulysses_degree=1``).

            **Important**: When ``load_balance=True``, ``q.shape[seq_dim]``
            inside ``forward()`` returns ``S / 2`` (half the global
            sequence length) rather than the true global ``S``. This is
            because ``DTensor.shape`` returns ``local_tensor_size *
            mesh_size``, and each sub-FA call wraps a half-sized Q shard
            (``S / (2 * cp_size)`` tokens) on a ``co_submesh`` of
            size ``cp_size``, giving a DTensor global shape of
            ``S / (2 * cp_size) * cp_size = S / 2``.
            K/V are always Replicate, so ``k.shape[seq_dim]`` always
            returns the true ``S``. **When building the attention mask,
            use ``k.shape[seq_dim]`` (not ``q.shape[seq_dim]``) to
            obtain the correct global sequence length.**
    """
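
    # Usage sketch (illustrative; the "npu" device type and the 8-rank mesh
    # are assumptions, and mesh construction depends on the platform):
    #
    #     mesh = DeviceMesh("npu", list(range(8)), mesh_dim_names=("cp",))
    #     attn = ContextParallel(seq_dim=1, head_dim=2).apply(attn, mesh)  # Pure Ulysses
    #     attn = ContextParallel(ulysses_degree=1).apply(attn, mesh)       # Pure Colossal AI
    #     attn = ContextParallel(ulysses_degree=2).apply(attn, mesh)       # Hybrid (ds=2, co=4)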

    def __init__(
        self,
        seq_dim: int = 1,
        head_dim: int = 2,
        ulysses_degree: Optional[int] = None,
        qkv_indices: tuple = (0, 1, 2),
        qkv_kwarg_names: tuple = (),
        load_balance: bool = False,
    ):
        if load_balance and ulysses_degree != 1:
            raise ValueError(
                "load_balance=True requires ulysses_degree=1 (Pure Colossal AI mode)."
            )
        self.seq_dim = seq_dim
        self.head_dim = head_dim
        self.ulysses_degree = ulysses_degree
        self.qkv_indices = qkv_indices
        self.qkv_kwarg_names = qkv_kwarg_names
        self.load_balance = load_balance

    # ------------------------------------------------------------------
    # ParallelStyle interface
    # ------------------------------------------------------------------

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Register forward hooks on *module* and return it.

        Args:
            module: attention submodule to parallelise.
            device_mesh: CP device mesh (1-D or 2-D).
        """
        cp_size = device_mesh.mesh.numel()
        ds = self.ulysses_degree if self.ulysses_degree is not None else cp_size
        if cp_size % ds != 0:
            raise ValueError(
                f"cp_size ({cp_size}) must be divisible by ulysses_degree ({ds})."
            )
        co = cp_size // ds

        if ds == 1:
            # Pure Colossal AI
            co_submesh = _ensure_1d(device_mesh)
            if self.load_balance:
                self._apply_lb_colossal(module, co_submesh)
            else:
                module.register_forward_pre_hook(
                    partial(self._pre_hook_colossal, co_submesh=co_submesh),
                    with_kwargs=True,
                )
                module.register_forward_hook(
                    partial(self._post_hook_colossal, co_submesh=co_submesh)
                )
        elif co == 1:
            # Pure Ulysses
            ds_submesh = _ensure_1d(device_mesh)
            module.register_forward_pre_hook(
                partial(self._pre_hook_ulysses, ds_submesh=ds_submesh, ds_size=ds),
                with_kwargs=True,
            )
            module.register_forward_hook(
                partial(self._post_hook_ata, ds_submesh=ds_submesh)
            )
        else:
            # Hybrid
            two_d_mesh = _build_2d_mesh(device_mesh, ds, co)
            dim_names = two_d_mesh.mesh_dim_names
            assert dim_names is not None, \
                "2-D mesh must have mesh_dim_names (guaranteed by _build_2d_mesh)"
            ds_submesh = two_d_mesh[dim_names[1]]
            module.register_forward_pre_hook(
                partial(
                    self._pre_hook_hybrid,
                    two_d_mesh=two_d_mesh,
                    ds_submesh=ds_submesh,
                    ds_size=ds,
                ),
                with_kwargs=True,
            )
            module.register_forward_hook(
                partial(self._post_hook_ata, ds_submesh=ds_submesh)
            )

        return module
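
    # Mode arithmetic (illustrative): with cp_size=8,
    #   ulysses_degree=None -> ds=8, co=1  (Pure Ulysses)
    #   ulysses_degree=1    -> ds=1, co=8  (Pure Colossal AI)
    #   ulysses_degree=2    -> ds=2, co=4  (Hybrid: 4 Ulysses groups of 2)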

    # ------------------------------------------------------------------
    # Pre-hooks
    # ------------------------------------------------------------------

    def _pre_hook_colossal(self, module, args, kwargs, co_submesh):  # pylint: disable=unused-argument
        """Wrap Q as ``DTensor(co_submesh, Shard(seq))``; all-gather K/V."""
        new_args = list(args)
        new_kwargs = dict(kwargs)

        q_idx = self.qkv_indices[0]
        if q_idx < len(new_args) and isinstance(new_args[q_idx], Tensor) \
                and not isinstance(new_args[q_idx], DTensor):
            new_args[q_idx] = DTensor.from_local(
                new_args[q_idx], co_submesh, (Shard(self.seq_dim),)
            )
        for idx in self.qkv_indices[1:]:
            if idx < len(new_args) and isinstance(new_args[idx], Tensor):
                new_args[idx] = _gather_seq(new_args[idx], co_submesh, self.seq_dim)

        if self.qkv_kwarg_names:
            q_name = self.qkv_kwarg_names[0]
            if q_name in new_kwargs and isinstance(new_kwargs[q_name], Tensor) \
                    and not isinstance(new_kwargs[q_name], DTensor):
                new_kwargs[q_name] = DTensor.from_local(
                    new_kwargs[q_name], co_submesh, (Shard(self.seq_dim),)
                )
            for name in self.qkv_kwarg_names[1:]:
                if name in new_kwargs and isinstance(new_kwargs[name], Tensor):
                    new_kwargs[name] = _gather_seq(new_kwargs[name], co_submesh, self.seq_dim)

        return tuple(new_args), new_kwargs

    def _pre_hook_ulysses(self, module, args, kwargs, ds_submesh, ds_size):  # pylint: disable=unused-argument
        """Seq→head all-to-all for Q, K, and V."""
        new_args = list(args)
        for idx in self.qkv_indices:
            if idx < len(new_args) and isinstance(new_args[idx], Tensor):
                new_args[idx] = _scatter_seq_to_head(
                    new_args[idx], ds_submesh, self.seq_dim, self.head_dim, ds_size
                )

        new_kwargs = dict(kwargs)
        for name in self.qkv_kwarg_names:
            if name in new_kwargs and isinstance(new_kwargs[name], Tensor):
                new_kwargs[name] = _scatter_seq_to_head(
                    new_kwargs[name], ds_submesh, self.seq_dim, self.head_dim, ds_size
                )

        return tuple(new_args), new_kwargs

    def _ata_scatter_to_2d(self, t, ds_submesh, two_d_mesh, ds_size):
        """ATA scatter: Shard(seq)→Shard(head) on ds_submesh; wrap as 2-D DTensor.

        Args:
            t: Plain local tensor to scatter.
            ds_submesh: 1-D Ulysses sub-mesh.
            two_d_mesh: 2-D mesh (co × ds).
            ds_size: Ulysses degree (world size on ds_submesh).

        Returns:
            DTensor with placements ``(Shard(seq_dim), Shard(head_dim))`` on two_d_mesh.
        """
        if t.shape[self.head_dim] % ds_size != 0:
            raise ValueError(
                f"num_heads ({t.shape[self.head_dim]}) must be divisible by "
                f"ulysses_degree ({ds_size})."
            )
        local = (
            DTensor.from_local(t, ds_submesh, (Shard(self.seq_dim),))
            .redistribute(ds_submesh, (Shard(self.head_dim),))
            .to_local()
        )
        return DTensor.from_local(local, two_d_mesh, (Shard(self.seq_dim), Shard(self.head_dim)))
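
    # Placement sketch for _ata_scatter_to_2d (illustrative; co=2, ds=4 and
    # a global Q of shape [B, S, H, D] are assumptions): the incoming local
    # shard is [B, S/8, H, D]; the ds-submesh all-to-all turns it into a
    # local [B, S/2, H/4, D], which is then declared with placements
    # (Shard(seq_dim), Shard(head_dim)) on the 2-D mesh.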

    def _pre_hook_hybrid(self, module, args, kwargs, two_d_mesh, ds_submesh, ds_size):  # pylint: disable=unused-argument
        """Hybrid: seq→head ATA on ds-submesh, then all-gather K/V on co-submesh.

        After this hook, placements on ``two_d_mesh`` are:
            Q   → ``(Shard(seq_dim), Shard(head_dim))``
            K/V → ``(Replicate(), Shard(head_dim))``
        """
        new_args = list(args)

        # Step 1: ATA on ds_submesh for all of Q/K/V; wrap as 2-D DTensor
        for idx in self.qkv_indices:
            if idx < len(new_args) and isinstance(new_args[idx], Tensor) \
                    and not isinstance(new_args[idx], DTensor):
                new_args[idx] = self._ata_scatter_to_2d(
                    new_args[idx], ds_submesh, two_d_mesh, ds_size
                )

        # Step 2: all-gather K/V on the co-dim (Shard(seq)→Replicate)
        for idx in self.qkv_indices[1:]:
            if idx < len(new_args) and isinstance(new_args[idx], DTensor):
                new_args[idx] = new_args[idx].redistribute(
                    two_d_mesh, (Replicate(), Shard(self.head_dim))
                )

        # Same two steps for kwargs (reusing _ata_scatter_to_2d so the
        # head-divisibility check applies on this path as well)
        new_kwargs = dict(kwargs)
        for name in self.qkv_kwarg_names:
            if name in new_kwargs and isinstance(new_kwargs[name], Tensor) \
                    and not isinstance(new_kwargs[name], DTensor):
                new_kwargs[name] = self._ata_scatter_to_2d(
                    new_kwargs[name], ds_submesh, two_d_mesh, ds_size
                )
        for name in self.qkv_kwarg_names[1:]:
            if name in new_kwargs and isinstance(new_kwargs[name], DTensor):
                new_kwargs[name] = new_kwargs[name].redistribute(
                    two_d_mesh, (Replicate(), Shard(self.head_dim))
                )

        return tuple(new_args), new_kwargs

    # ------------------------------------------------------------------
    # Post-hooks
    # ------------------------------------------------------------------

    def _post_hook_ata(self, module, inputs, outputs, ds_submesh):  # pylint: disable=unused-argument
        """Reverse all-to-all: head→seq on the ds-submesh; returns local tensors.

        Handles both Ulysses (1-D DTensor or plain tensor) and Hybrid
        (2-D DTensor, which is taken ``to_local()`` first to project it
        onto the 1-D ds-submesh).
        """
        def _process(out):
            if isinstance(out, (Tensor, DTensor)):
                if isinstance(out, DTensor):
                    out = out.to_local()
                return _gather_head_to_seq(
                    out, ds_submesh, self.seq_dim, self.head_dim
                ).to_local()
            return out

        if isinstance(outputs, (tuple, list)):
            return type(outputs)(_process(o) for o in outputs)
        return _process(outputs)

    def _post_hook_colossal(self, module, inputs, outputs, co_submesh):  # pylint: disable=unused-argument
        """Colossal AI: convert any DTensor output to a local tensor."""
        def _process(out):
            return out.to_local() if isinstance(out, DTensor) else out

        if isinstance(outputs, (tuple, list)):
            return type(outputs)(_process(o) for o in outputs)
        return _process(outputs)

    # ------------------------------------------------------------------
    # Load-balance Colossal AI (Head-Tail Q-exchange)
    # ------------------------------------------------------------------

    def _apply_lb_colossal(self, module: Module, co_submesh: DeviceMesh) -> None:
        """Replace ``module.forward`` with the load-balanced two-sub-FA wrapper."""
        ws = co_submesh.mesh.numel()
        rank_list = list(co_submesh.rank_list)
        local_idx = rank_list.index(platform.get_rank())
        target_idx = ws - 1 - local_idx
        module.forward = partial(
            self._lb_colossal_forward,
            original_forward=module.forward,
            co_submesh=co_submesh,
            local_idx=local_idx,
            target_idx=target_idx,
            ws=ws,
            peer_rank=rank_list[target_idx],
        )
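
    # Pairing sketch (illustrative, ws=4): local_idx 0 pairs with 3 and
    # 1 pairs with 2; each rank swaps the tail half of its Q shard with its
    # mirror rank, so every rank computes one "early" and one "late" causal
    # chunk instead of two adjacent ones.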

    def _lb_colossal_forward(  # pylint: disable=too-many-arguments,too-many-locals
        self,
        *args,
        original_forward,
        co_submesh: DeviceMesh,
        local_idx: int,
        target_idx: int,
        ws: int,
        peer_rank: int,
        **kwargs,
    ):
        """Head-Tail load-balanced forward for Pure Colossal AI CP.

        Splits the local Q (shape ``[B, S/ws, H, D]``) into head/tail halves.
        The tail is P2P-exchanged with the paired rank ``(ws - 1 - local_idx)``.
        Two sub-FA calls are issued with adjusted causal-mask offsets:

        - FA1: ``q_keep`` at ``split_id = 2*local_idx``
        - FA2: ``q_peer`` at ``split_id = 2*target_idx + 1``

        FA2's output is exchanged back; the final output is
        ``cat([FA1, FA2_recv])``.
        """
        from hyper_parallel.core.shard.ops.parallel_npu_flash_attention_score import (  # pylint: disable=import-outside-toplevel
            _set_lb_override, _clear_lb_override,
        )

        seq_dim = self.seq_dim
        q_idx, k_idx, v_idx = self.qkv_indices
        new_args = list(args)

        q = new_args[q_idx]
        half = q.shape[seq_dim] // 2
        q_keep = q.narrow(seq_dim, 0, half)
        q_mine = q.narrow(seq_dim, half, half)

        q_peer = platform.p2p_exchange(q_mine, peer_rank)
        k_full = _gather_seq(new_args[k_idx], co_submesh, seq_dim).to_local()
        v_full = _gather_seq(new_args[v_idx], co_submesh, seq_dim).to_local()

        # K/V are Replicate; wrap once and reuse for both FA calls
        k_full_dt = DTensor.from_local(k_full, co_submesh, (Replicate(),))
        v_full_dt = DTensor.from_local(v_full, co_submesh, (Replicate(),))

        def _fa(q_half, split_id):
            new_args[q_idx] = DTensor.from_local(q_half, co_submesh, (Shard(seq_dim),))
            new_args[k_idx] = k_full_dt
            new_args[v_idx] = v_full_dt
            _set_lb_override(split_id=split_id, split_num=2 * ws)
            out = original_forward(*new_args, **kwargs)
            _clear_lb_override()
            return out.to_local() if isinstance(out, DTensor) else out

        fa1_out = _fa(q_keep, split_id=2 * local_idx)
        fa2_out = _fa(q_peer, split_id=2 * target_idx + 1)
        fa2_ours = platform.p2p_exchange(fa2_out, peer_rank)
        return platform.cat([fa1_out, fa2_ours], dim=seq_dim)
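

# Worked example for the split-id scheme (illustrative, ws=4): the global
# sequence is viewed as split_num = 8 causal chunks. Rank 1 (local_idx=1,
# target_idx=2) runs FA1 on q_keep at split_id = 2*1 = 2 and FA2 on rank 2's
# exchanged tail at split_id = 2*2 + 1 = 5; rank 2 symmetrically covers
# split_ids 4 and 3. Each rank thus owns one cheap and one expensive chunk
# of the causal mask, balancing attention work across the CP group.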