# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""AsyncContextParallel: overlap projection GEMM with all-to-all communication. 

16 

17Supports Pure Ulysses, Hybrid CP modes. Falls back to sync ContextParallel 

18when q/k/v_proj not provided or in Pure Colossal AI mode. 

19 

20Forward: proj hooks launch async A2A → attn pre-hook waits Q/K/V → attn hook gathers output 

21Backward: autograd backward launches async A2A → proj pre-hooks wait before GEMMs 

22""" 

from functools import partial
from typing import Optional, cast

from hyper_parallel.core.context_parallel.context_parallel import (
    ContextParallel,
    _build_2d_mesh,
    _ensure_1d,
    _gather_seq,
    _gather_head_to_seq,
)
from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.core.dtensor.placement_types import Shard, Replicate
from hyper_parallel.platform import get_platform

platform = get_platform()
Module = platform.Module
Tensor = platform.Tensor


# ---------------------------------------------------------------------------
# All-to-all helpers
# ---------------------------------------------------------------------------

def _launch_async_a2a_seq_to_head(
    tensor: Tensor,
    group,
    world_size: int,
    head_dim: int,
) -> tuple:
    """Launch async seq→head A2A (forward)."""
    x = tensor.contiguous()
    shape = list(x.shape)
    num_heads = shape[head_dim]
    if num_heads % world_size != 0:
        raise ValueError(f"num_heads ({num_heads}) must be divisible by world_size ({world_size}).")
    ndim = len(shape) + 1
    x_perm = x.reshape(
        shape[:head_dim] + [world_size, num_heads // world_size] + shape[head_dim + 1:]
    ).permute(
        [head_dim] + list(range(head_dim)) + list(range(head_dim + 1, ndim))
    ).contiguous()
    out_perm, work = platform.all_to_all_single(x_perm, list(x_perm.shape), group, async_op=True)
    return work, out_perm


def _a2a_reconstruct(out_perm: Tensor, concat_dim: int) -> Tensor:
    """Reconstruct A2A result from raw out_perm."""
    new_ndim = out_perm.dim()
    chunk_in_perm = concat_dim + 1
    recon_perm = list(range(1, chunk_in_perm)) + [0] + list(range(chunk_in_perm, new_ndim))
    x_recon = out_perm.permute(recon_perm).contiguous()
    shape = list(x_recon.shape)
    merged = shape[concat_dim] * shape[concat_dim + 1]
    return x_recon.reshape(shape[:concat_dim] + [merged] + shape[concat_dim + 2:])
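
# Illustrative shape walk-through for the two helpers above, assuming a BSHD
# layout with seq_dim=1, head_dim=2, world_size=4 and an example per-rank
# projection output of shape [B=2, S/cp=2048, H=32, D=128]:
#
#   _launch_async_a2a_seq_to_head:
#     reshape -> [2, 2048, 4, 8, 128]   (split H into world_size x H/ws)
#     permute -> [4, 2, 2048, 8, 128]   (chunk axis moved to dim 0 for all_to_all)
#     A2A     -> out_perm [4, 2, 2048, 8, 128]; dim 0 now indexes peers' sequence chunks
#   _a2a_reconstruct(out_perm, concat_dim=seq_dim=1):
#     permute -> [2, 4, 2048, 8, 128]
#     reshape -> [2, 8192, 8, 128]      (full sequence, 1/world_size of the heads)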

# ---------------------------------------------------------------------------
# AsyncContextParallel
# ---------------------------------------------------------------------------

class AsyncContextParallel(ContextParallel):
    """Context Parallel with projection–A2A compute overlap.

    Requires ``q_proj``, ``k_proj``, ``v_proj`` in :meth:`apply`; otherwise
    falls back to synchronous :class:`ContextParallel`.

    Pure Colossal AI (``ulysses_degree=1``) automatically falls back to sync
    because K/V AllGather is a barrier collective.

    Args:
        seq_dim: Sequence dimension (1=BSHD, 2=BNSD).
        head_dim: Head dimension (2=BSHD, 1=BNSD).
        ulysses_degree: Ulysses sub-mesh size (see :class:`ContextParallel`).
        qkv_indices: Positional indices of (Q, K, V) in attention forward.
        qkv_kwarg_names: Keyword names for (Q, K, V).
        load_balance: Load-balance flag forwarded to base class.
    """

    def __init__(
        self,
        seq_dim: int = 1,
        head_dim: int = 2,
        ulysses_degree: Optional[int] = None,
        qkv_indices: tuple = (0, 1, 2),
        qkv_kwarg_names: tuple = (),
        load_balance: bool = False,
    ):
        super().__init__(
            seq_dim=seq_dim,
            head_dim=head_dim,
            ulysses_degree=ulysses_degree,
            qkv_indices=qkv_indices,
            qkv_kwarg_names=qkv_kwarg_names,
            load_balance=load_balance,
        )

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def apply(  # pylint: disable=arguments-differ
        self,
        module: Module,
        device_mesh: DeviceMesh,
        q_proj: Optional[Module] = None,
        k_proj: Optional[Module] = None,
        v_proj: Optional[Module] = None,
    ) -> Module:
        """Register async-overlap hooks and return *module*.

        Falls back to the synchronous :class:`ContextParallel` if any of
        ``q/k/v_proj`` is ``None``, or in Pure Colossal AI mode.

        Args:
            module: Core-attention submodule.
            device_mesh: CP device mesh (1-D or 2-D).
            q_proj: The last module in the Q path whose output is passed
                directly to the attention module as Q. Its forward
                post-hook launches the async Q all-to-all. There
                must be **no** intermediate ops (view, transpose, …)
                between this module and attention; such ops would be
                bypassed by the pre-hook substitution and could cause
                shape mismatches. For models with QK normalization
                applied right before attention, pass ``qk_norm_q``
                here instead of the raw projection.
            k_proj: Same semantics as ``q_proj``, for the K path. Pass
                ``qk_norm_k`` when the model applies QK-Norm before
                attention.
            v_proj: Value projection module (no norm variant needed).
        """
        if q_proj is None or k_proj is None or v_proj is None:
            return super().apply(module, device_mesh)

        cp_size = device_mesh.mesh.numel()
        ds = self.ulysses_degree if self.ulysses_degree is not None else cp_size
        if cp_size % ds != 0:
            raise ValueError(
                f"cp_size ({cp_size}) must be divisible by ulysses_degree ({ds})."
            )
        co = cp_size // ds

        if ds == 1:
            # Pure Colossal AI: K/V AllGather cannot be made async. Fall back.
            return super().apply(module, device_mesh)

        # Per-layer handle slots — local to this apply() call, bound to hooks via partial.
        #
        # fwd_slots is a plain dict. _proj_post_hook and _wait_a2a both receive the
        # same dict reference via partial, so a simple assignment fwd_slots[key] = ...
        # in _proj_post_hook is immediately visible to _wait_a2a — no list wrapper needed.
        #
        # bwd_slots[key] is a list held by both _wait_a2a and the autograd wait function.
        # The autograd function receives the list object itself (as handle_box) and appends
        # to it; _proj_bwd_pre_hook pops from the same list. We cannot use a plain dict
        # value here because the autograd function would hold a stale reference if we later
        # reassigned bwd_slots[key].
        fwd_slots = {"q": None, "k": None, "v": None}
        bwd_slots = {"q": [], "k": [], "v": []}
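        # Slot lifecycle (summary, for orientation): in forward, a projection's
        # post-hook stores fwd_slots[key] = (work, out_perm) and the attention
        # pre-hook consumes it in _wait_a2a, resetting the slot to None. In
        # backward, the autograd wait-op appends its (work, out_perm) handle to
        # bwd_slots[key], and _proj_bwd_pre_hook pops it.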

        if co == 1:
            # Pure Ulysses
            ds_submesh = _ensure_1d(device_mesh)
            group = ds_submesh.get_group()
            self._register_proj_hooks(q_proj, k_proj, v_proj, group=group, world_size=ds,
                                      fwd_slots=fwd_slots, bwd_slots=bwd_slots)
            module.register_forward_pre_hook(
                partial(self._attn_pre_hook_ulysses, group=group, world_size=ds,
                        fwd_slots=fwd_slots, bwd_slots=bwd_slots)
            )
        else:
            # Hybrid: async Ulysses A2A + sync Colossal AllGather
            two_d_mesh = _build_2d_mesh(device_mesh, ds, co)
            dim_names = two_d_mesh.mesh_dim_names
            assert dim_names is not None, "2-D mesh must have mesh_dim_names (guaranteed by _build_2d_mesh)"
            ds_submesh = two_d_mesh[dim_names[1]]
            group = ds_submesh.get_group()
            self._register_proj_hooks(q_proj, k_proj, v_proj, group=group, world_size=ds,
                                      fwd_slots=fwd_slots, bwd_slots=bwd_slots)
            module.register_forward_pre_hook(
                partial(self._attn_pre_hook_hybrid, group=group, world_size=ds,
                        two_d_mesh=two_d_mesh, fwd_slots=fwd_slots, bwd_slots=bwd_slots)
            )

        module.register_forward_hook(
            partial(self._attn_post_hook_ata, ds_submesh=ds_submesh)
        )
        return module
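
    # Illustrative usage (a sketch only; ``block.attention.*`` and ``cp_mesh``
    # below are hypothetical names, not part of this module):
    #
    #     cp = AsyncContextParallel(seq_dim=1, head_dim=2, ulysses_degree=4)
    #     cp.apply(
    #         block.attention.core_attention,   # core-attention submodule
    #         cp_mesh,                          # 1-D or 2-D CP DeviceMesh
    #         q_proj=block.attention.q_proj,    # or qk_norm_q when QK-Norm runs right before attention
    #         k_proj=block.attention.k_proj,    # or qk_norm_k
    #         v_proj=block.attention.v_proj,
    #     )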

    # ------------------------------------------------------------------
    # Shared: projection hooks registration
    # ------------------------------------------------------------------

    def _register_proj_hooks(self, q_proj, k_proj, v_proj, group, world_size, fwd_slots, bwd_slots):
        """Register forward and backward hooks on all three projection modules."""
        for key, proj in [("q", q_proj), ("k", k_proj), ("v", v_proj)]:
            proj.register_forward_hook(
                partial(self._proj_post_hook, key=key, group=group, world_size=world_size,
                        fwd_slots=fwd_slots)
            )
            platform.register_full_backward_pre_hook(
                proj,
                partial(self._proj_bwd_pre_hook, bwd_slot=bwd_slots[key])
            )

    def _proj_post_hook(self, module, inputs, output, key, group, world_size, fwd_slots):  # pylint: disable=unused-argument,too-many-arguments
        """Launch async seq→head A2A after projection; return original output unchanged."""
        tensor = output.to_local() if isinstance(output, DTensor) else output
        fwd_slots[key] = _launch_async_a2a_seq_to_head(
            tensor, group, world_size, self.head_dim
        )
        return output

    # ------------------------------------------------------------------
    # Internal: wait for a single pre-launched A2A handle
    # ------------------------------------------------------------------

    def _wait_a2a(self, tensor, group, world_size, fwd_slots, key, bwd_slot):
        """Wait for pre-launched A2A; returns head-scattered tensor (differentiable)."""
        work, out_perm = fwd_slots[key]
        fwd_slots[key] = None
        return platform.differentiable_async_a2a_wait(
            tensor, work, out_perm, group, world_size,
            self.seq_dim, self.head_dim,  # concat_dim=seq_dim, split_dim=head_dim
            bwd_slot,
        )

    # ------------------------------------------------------------------
    # Attention pre-hooks
    # ------------------------------------------------------------------

    def _attn_pre_hook_ulysses(self, module, args, group, world_size,  # pylint: disable=unused-argument,too-many-arguments
                               fwd_slots, bwd_slots):
        """Wait Q/K/V A2A; return head-scattered args."""
        q_idx, k_idx, v_idx = self.qkv_indices
        new_args = list(args)

        def _local(t):
            return t.to_local() if isinstance(t, DTensor) else t

        new_args[q_idx] = self._wait_a2a(_local(new_args[q_idx]), group, world_size,
                                         fwd_slots, "q", bwd_slots["q"])
        new_args[k_idx] = self._wait_a2a(_local(new_args[k_idx]), group, world_size,
                                         fwd_slots, "k", bwd_slots["k"])
        new_args[v_idx] = self._wait_a2a(_local(new_args[v_idx]), group, world_size,
                                         fwd_slots, "v", bwd_slots["v"])
        return tuple(new_args)

    def _attn_pre_hook_hybrid(  # pylint: disable=too-many-locals,too-many-arguments
        self, module, args, group, world_size, two_d_mesh,  # pylint: disable=unused-argument
        fwd_slots, bwd_slots
    ):
        """Wait Ulysses A2A for Q/K/V, AllGather K/V on co-submesh, wrap as 2-D DTensors."""
        q_idx, k_idx, v_idx = self.qkv_indices
        new_args = list(args)

        def _local(t):
            return t.to_local() if isinstance(t, DTensor) else t

        # Wait Ulysses A2A for Q and K
        q_ul = cast(Tensor, self._wait_a2a(_local(new_args[q_idx]), group, world_size,
                                           fwd_slots, "q", bwd_slots["q"]))
        k_ul = cast(Tensor, self._wait_a2a(_local(new_args[k_idx]), group, world_size,
                                           fwd_slots, "k", bwd_slots["k"]))

        # AllGather K on co-submesh (while V A2A is still in flight)
        co_submesh = two_d_mesh[two_d_mesh.mesh_dim_names[0]]
        k_full = _gather_seq(k_ul, co_submesh, self.seq_dim)

        # Wait V A2A, then AllGather V
        v_ul = cast(Tensor, self._wait_a2a(_local(new_args[v_idx]), group, world_size,
                                           fwd_slots, "v", bwd_slots["v"]))
        v_full = _gather_seq(v_ul, co_submesh, self.seq_dim)

        def _local_dt(dt):
            return dt.to_local() if isinstance(dt, DTensor) else dt

        new_args[q_idx] = DTensor.from_local(
            q_ul, two_d_mesh, (Shard(self.seq_dim), Shard(self.head_dim))
        )
        new_args[k_idx] = DTensor.from_local(
            _local_dt(k_full), two_d_mesh, (Replicate(), Shard(self.head_dim))
        )
        new_args[v_idx] = DTensor.from_local(
            _local_dt(v_full), two_d_mesh, (Replicate(), Shard(self.head_dim))
        )
        return tuple(new_args)

    # ------------------------------------------------------------------
    # Attention post-hook (Ulysses and Hybrid share the same reverse ATA)
    # ------------------------------------------------------------------

    def _attn_post_hook_ata(self, module, args, output, ds_submesh):  # pylint: disable=unused-argument
        """Reverse head→seq gather on ds_submesh; returns local tensor."""
        def _process(o):
            if isinstance(o, (Tensor, DTensor)):
                if isinstance(o, DTensor):
                    o = o.to_local()
                return _gather_head_to_seq(
                    o, ds_submesh, self.seq_dim, self.head_dim
                ).to_local()
            return o

        if isinstance(output, (tuple, list)):
            return type(output)(_process(o) for o in output)
        return _process(output)

    # ------------------------------------------------------------------
    # Backward: wait A2A handle (launched by autograd) before proj GEMM
    # ------------------------------------------------------------------

    def _proj_bwd_pre_hook(self, module, grad_output, bwd_slot):  # pylint: disable=unused-argument
        """Wait backward A2A just before proj GEMM; replace grad with seq-form.

        The async head→seq A2A is launched inside _TorchAsyncA2AFunction.backward
        and appended to ``bwd_slot``. Waiting here lets the A2A overlap with the
        preceding proj GEMM.
        """
        work, out_perm = bwd_slot.pop()
        work.wait()
        d_seq = _a2a_reconstruct(out_perm, self.head_dim)
        return (d_seq,) + grad_output[1:] if isinstance(grad_output, tuple) else (d_seq,)
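
# Overlap timeline (schematic summary of the hooks above; adds no behavior):
#
#   forward:  each projection's forward post-hook launches its seq→head A2A,
#             which overlaps the next projection GEMM; the attention pre-hook
#             waits on Q/K/V, and the attention post-hook runs the reverse
#             head→seq gather on the Ulysses sub-mesh.
#   backward: the autograd wait-op launches the head→seq A2A for each gradient,
#             and _proj_bwd_pre_hook waits only just before the projection GEMM,
#             so each A2A overlaps the preceding compute in the backward pass.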