Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / torch / fully_shard / param_group.py: 30%

447 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/fsdp/_fully_shard/_fsdp_param.py 

16# enhanced with fully_shard parameter management 

17# ============================================================================ 

18"""HSDP parameter group. 

19 

20This module implements fused communication for HSDP (Hybrid Shard Data Parallel) parameters. 

21Instead of issuing one all-gather / reduce-scatter per parameter, ``HSDPParamGroup`` packs all 

22parameters within a module into a single contiguous buffer and performs one collective operation, 

23which reduces kernel launch overhead and improves bandwidth utilization. 

24 

25Key components: 

26- ``HSDPParamGroup``: Groups all HSDP parameters in a module for fused all-gather (forward) 

27 and fused reduce-scatter + all-reduce (backward). 

28- ``AllGatherMetadata`` / ``AllGatherMetadataCache``: Caches per-group metadata (dtypes, numels, 

29 split sizes) to avoid recomputation across iterations. 

30- ``CommContext``: Global context that tracks the in-flight async communication handle and the 

31 param group that owns it, enabling pipelined overlap between communication and computation. 

32""" 

33from typing import List, Optional, NamedTuple, Any 

34from dataclasses import dataclass, field 

35from contextlib import ExitStack 

36import torch 

37import torch.distributed as dist 

38from torch.distributed import Work 

39from hyper_parallel.core.fully_shard.utils import ( 

40 MixedPrecisionPolicy, 

41 FSDPMeshInfo, 

42 DDPMeshInfo, 

43 HSDPMeshInfo, 

44) 

45from hyper_parallel.platform.torch.fully_shard.pack_utils import ( 

46 build_rs_plan, 

47 pack_for_reduce_scatter, 

48) 

49from hyper_parallel.platform.torch.fully_shard.param import TorchHSDPParamV2 

50 

51 

def get_all_gather_metadata(hsdp_params):
    """Build the ``AllGatherMetadata`` describing the fused all-gather layout.

    Walks every parameter's ``all_gather_inputs`` and records the dtype and
    element count of each local shard. The leading input dtype of every
    parameter must agree, because a single fused buffer of one dtype is
    allocated for the collective.

    Args:
        hsdp_params: Parameters (``TorchHSDPParamV2``) whose
            ``all_gather_inputs`` are inspected.

    Returns:
        AllGatherMetadata: Layout info consumed by ``foreach_all_gather`` to
        allocate the fused buffer and drive copy-in/copy-out.

    Raises:
        ValueError: If parameters do not share a uniform dtype.
    """
    dtypes_per_param = []
    numels_per_param = []
    flat_split_sizes = []
    numel_total = 0
    group_dtype = None

    for param in hsdp_params:
        shard_inputs = param.all_gather_inputs
        leading_dtype = shard_inputs[0].dtype
        if group_dtype is None:
            group_dtype = leading_dtype
        elif group_dtype != leading_dtype:
            raise ValueError("All parameters in the group must have a uniform dtype.")
        numels = [tensor.numel() for tensor in shard_inputs]
        dtypes_per_param.append([tensor.dtype for tensor in shard_inputs])
        numels_per_param.append(numels)
        flat_split_sizes.extend(numels)
        numel_total += sum(numels)

    return AllGatherMetadata(
        dtypes_per_param,
        numels_per_param,
        group_dtype,
        flat_split_sizes,
        numel_total,
    )

96 

97 

@dataclass
class AllGatherMetadata:
    """Layout description for the fused all-gather buffer.

    Attributes:
        param_input_dtypes: For each parameter, the dtypes of its shard inputs.
        param_input_numels: For each parameter, the numels of its shard inputs.
        dtype: Common dtype used to allocate the fused buffer.
        inp_split_sizes: Flat per-input element counts, used to split the fused
            buffer back into per-parameter chunks.
        total_input_numel: Element count of one rank's local contribution; the
            gathered output holds ``total_input_numel * world_size`` elements.
        hash_key: Precomputed hash over all fields, usable as a cache key
            (set in ``__post_init__``).
    """
    param_input_dtypes: list[list[torch.dtype]]
    param_input_numels: list[list[int]]
    dtype: torch.dtype
    inp_split_sizes: list[int]
    total_input_numel: int
    hash_key: int = field(init=False)

    def __post_init__(self):
        # Convert nested lists to tuples so the composite key is hashable.
        dtype_key = tuple(map(tuple, self.param_input_dtypes))
        numel_key = tuple(map(tuple, self.param_input_numels))
        self.hash_key = hash(
            (dtype_key, numel_key, self.dtype, tuple(self.inp_split_sizes), self.total_input_numel)
        )

129 

130 

class AllGatherResult(NamedTuple):
    """Outcome of one fused all-gather operation.

    Attributes:
        all_gather_output: Contiguous buffer with data gathered from all ranks.
        metadata: ``AllGatherMetadata`` describing the buffer layout.
        handle: Async work handle from ``dist.all_gather_into_tensor``;
            ``None`` when the collective ran synchronously or when
            ``shard_world_size == 1``.
    """
    all_gather_output: torch.Tensor
    metadata: AllGatherMetadata
    handle: Optional[Work]

143 

144 

@dataclass
class CommContext:
    """Global communication context for pipelining fused gradient reduction.

    For FSDP (shard-only), the reduce-scatter handle is stored in ``comm_handle``
    and the next module's backward hook waits on it before issuing its own reduction.

    For HSDP (shard + replicate), a two-phase pipeline is used:
        Phase 1 (``wait_reduce_scatter_and_issue_all_reduce``): wait for
            reduce-scatter, then issue one or more async all-reduces stored on
            the owning ``HSDPParamGroup``.
        Phase 2 (``wait_all_reduce_and_apply_grad``): wait for all-reduce and
            write reduced gradients back.

    This allows three-way overlap:
        Layer N reduce_scatter ↔ Layer N-1 backward compute
        Layer N all_reduce ↔ Layer N-1 reduce_scatter

    Attributes:
        comm_handle: In-flight reduce-scatter work handle, if any.
        all_reduce_handle: In-flight all-reduce work handle, if any.
        pre_param_group: Param group that owns the pending ``comm_handle``.
        all_reduce_param_group: Param group whose all_reduce has been issued
            but whose grad has not yet been applied.
    """
    comm_handle: Optional[Work] = None
    all_reduce_handle: Optional[Work] = None
    # NOTE: these two previously lacked annotations, so @dataclass silently
    # treated them as shared class attributes instead of per-instance fields.
    # Annotating them makes each CommContext carry its own pipeline state
    # (the default value is unchanged, so existing callers are unaffected).
    pre_param_group: Any = None
    # Param group whose all_reduce has been issued but grad not yet applied
    all_reduce_param_group: Any = None

168 

169 

# Process-wide singleton shared by every param group.
comm_ctx = CommContext()


def get_comm_ctx():
    """Access the process-wide ``CommContext`` instance."""
    return comm_ctx

176 

177 

@dataclass
class ReplicateBucket:
    """A fused all-reduce bucket for params sharing one replicate process group.

    ``param_indices`` index into the active param list of the current fused
    reduction; ``buffer`` is allocated lazily to hold ``flat_numel`` elements.
    """

    key: int                  # id() of the replicate process group
    group: Any                # the replicate process group itself
    group_size: int           # number of ranks participating in the all-reduce
    param_indices: list[int]  # indices into the active reduce param list
    flat_numel: int           # total sharded elements packed into this bucket
    buffer: Optional[torch.Tensor] = None

188 

189 

@dataclass
class PendingBucketAllReduce:
    """Tracks one async all-reduce currently in flight for a replicate bucket."""

    bucket_key: int  # ReplicateBucket.key this handle belongs to
    handle: Any      # async work handle of the launched all-reduce

196 

197 

class AllGatherMetadataCache:
    """Memoizes ``AllGatherMetadata`` across training iterations.

    Keys are derived from ``(id(param), param.version)`` pairs, so the cache
    invalidates automatically when parameters are re-sharded or replaced.
    """
    _cache: dict[int, AllGatherMetadata] = {}

    @classmethod
    def get_metadata(cls, hsdp_params, fn):
        """Look up cached metadata, computing it via *fn* on a miss."""
        cache_key = hash(tuple((id(p), getattr(p, 'version', 0)) for p in hsdp_params))
        try:
            return cls._cache[cache_key]
        except KeyError:
            computed = fn(hsdp_params)
            cls._cache[cache_key] = computed
            return computed

217 

218 

219def all_gather_copy_in(all_gather_inputs, all_gather_output, inp_split_sizes, all_gather_input_numel, rank): 

220 """Copy per-parameter local shards into the fused all-gather input buffer. 

221 

222 The fused output buffer has shape ``(total_input_numel * world_size,)``. Each rank 

223 writes its local shards into the slice ``[input_numel * rank : input_numel * (rank+1)]`` 

224 using ``torch._foreach_copy_`` for efficient batched copy. 

225 

226 Args: 

227 all_gather_inputs: Flat list of local shard tensors from all parameters. 

228 all_gather_output: The pre-allocated fused output buffer. 

229 inp_split_sizes: Element counts for splitting the rank-local slice. 

230 all_gather_input_numel: Total elements for one rank's local shards. 

231 rank: This rank's index within the shard process group. 

232 

233 Returns: 

234 Tuple of (rank-local input slice, full output buffer). 

235 """ 

236 all_gather_input = all_gather_output.narrow(0, all_gather_input_numel * rank, all_gather_input_numel) 

237 foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) 

238 with torch.no_grad(): 

239 # pylint: disable=W0212 

240 torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs) 

241 return all_gather_input, all_gather_output 

242 

243 

def reduce_scatter_copy_in(
    hsdp_params: List[TorchHSDPParamV2],
    unsharded_grads: List[torch.Tensor],
    reduce_scatter_input: torch.Tensor,
    world_size: int,
) -> None:
    """Pack unsharded gradients into the fused reduce-scatter input buffer.

    The flat buffer is viewed as ``(world_size, total // world_size)``; row *i*
    holds the data destined for rank *i* after reduction. Each gradient is
    packed via ``build_rs_plan`` / ``pack_for_reduce_scatter`` and written into
    the next column range of every row.

    Args:
        hsdp_params: Parameters whose layout defines each gradient's pack plan.
        unsharded_grads: Full (unsharded) gradients, one per parameter.
        reduce_scatter_input: Pre-allocated flat buffer sized to the total
            gradient numel.
        world_size: Number of ranks in the shard process group.

    Raises:
        AssertionError: On a param/grad count mismatch, or when the packed
            element count does not fill the buffer exactly.
    """
    if len(hsdp_params) != len(unsharded_grads):
        raise AssertionError(
            "reduce_scatter_copy_in expects one hsdp_param per unsharded_grad, but got "
            f"{len(hsdp_params)} params and {len(unsharded_grads)} grads"
        )
    rows = reduce_scatter_input.view(world_size, -1)
    cursor = 0
    with torch.no_grad():
        for param, grad in zip(hsdp_params, unsharded_grads):
            contiguous_grad = grad.contiguous()
            pack_plan = build_rs_plan(param, contiguous_grad, world_size)
            packed = pack_for_reduce_scatter(contiguous_grad, pack_plan)
            end = cursor + packed.size(1)
            rows[:, cursor:end].copy_(packed)
            cursor = end
    if cursor != rows.size(1):
        raise AssertionError(
            "reduce_scatter_copy_in packed an unexpected number of elements: "
            f"{cursor} != {rows.size(1)}"
        )

283 

284 

285class HSDPParamGroup: 

286 """Groups all HSDP parameters within a module for fused collective communication. 

287 

288 Instead of issuing per-parameter all-gather (forward) and reduce-scatter (backward), 

289 this class packs all parameter shards into a single contiguous buffer and performs one 

290 fused collective, reducing NCCL/HCCL kernel launch overhead. 

291 

292 Lifecycle within one training iteration: 

293 1. **Forward** — ``unshard()`` → ``foreach_all_gather()`` packs local shards into 

294 ``ag_output`` and issues a single ``all_gather_into_tensor``. 

295 2. **Forward (wait)** — ``wait_for_unshard()`` → ``foreach_all_gather_copy_out()`` 

296 waits on the handle and scatters gathered data back to per-parameter buffers. 

297 3. **Backward** — ``foreach_reduce()`` packs unsharded gradients, issues fused 

298 ``reduce_scatter_tensor`` (+ optional ``all_reduce`` for HSDP replicate dim), 

299 and stores the handle in ``CommContext`` for pipelined overlap. 

300 4. **Backward (apply)** — ``apply_fusion_reduced_grad()`` waits on the handle and 

301 writes reduced gradient slices back to each parameter's ``.grad`` or ``.main_grad``. 

302 

303 Args: 

304 hsdp_params: List of ``TorchHSDPParamV2`` belonging to this module. 

305 mesh_info: Mesh info providing shard/replicate process groups. 

306 device: Target device for buffer allocation. 

307 mp_policy: Mixed-precision policy controlling reduce dtype and grad dtype. 

308 """ 

309 

310 def __init__( 

311 self, 

312 hsdp_params, 

313 mesh_info: FSDPMeshInfo, 

314 device: Optional[torch.device] = None, 

315 mp_policy: Optional[MixedPrecisionPolicy] = None, 

316 enable_zero_copy: bool = True, 

317 ): 

318 self.mesh_info = mesh_info 

319 self.device = device 

320 self.hsdp_params = hsdp_params 

321 if isinstance(self.mesh_info, (FSDPMeshInfo, HSDPMeshInfo)): 

322 self.shard_rank = self.mesh_info.shard_mesh_rank 

323 self.shard_world_size = self.mesh_info.shard_mesh_size 

324 else: 

325 self.shard_rank = 0 

326 self.shard_world_size = 1 

327 self.shard_group = self.mesh_info.shard_process_group 

328 self.replicate_group = None 

329 if isinstance(self.mesh_info, (HSDPMeshInfo, DDPMeshInfo)): 

330 self.replicate_group = self.mesh_info.replicate_process_group 

331 elif isinstance(self.mesh_info, FSDPMeshInfo): 

332 self.replicate_group = self._infer_layout_replicate_group() 

333 self.device = device 

334 self._all_gather_output = torch.empty(0, device=self.device) 

335 self.ag_output = None # Fused all-gather output buffer, lazily allocated 

336 self.metadata_cache = None 

337 self.mp_policy = mp_policy 

338 self.enable_zero_copy = enable_zero_copy 

339 self._result = None # Pending AllGatherResult from async all-gather 

340 self._reduce_output = None # Fused reduce-scatter output, consumed by apply_fusion_reduced_grad 

341 self._reduce_op = None # Reduce op saved from foreach_reduce for use in apply_fusion_reduced_grad 

342 self._needs_avg_div = False # Whether AVG was split into SUM + deferred div 

343 self._reduce_hsdp_params = None 

344 self._active_replicate_buckets: dict[int, ReplicateBucket] = {} 

345 self._active_param_flat_offsets: list[int] = [] 

346 self._pending_all_reduce_handles: list[PendingBucketAllReduce] = [] 

347 self._init_mp_dtypes() 

348 self._flat_param_buffer = None # Contiguous buffer holding all params' sharded data 

349 self._flat_cast_buffer = None # Cast buffer for mixed precision (param_dtype) 

350 if self.enable_zero_copy: 

351 self._init_flat_param_buffer() 

352 

353 def _infer_layout_replicate_group(self): 

354 """Infer a compatibility all-reduce group from params' final DTensor layout when mesh_info has none. 

355 

356 DTENSOR_UNIFIED parameters may still carry replicate axes from the original 

357 DTensor layout, for example a ``(tp, ep)`` mesh where ``ep`` is replicate-only. 

358 The non-fused path derives this group from each param's layout-driven 

359 ``unsharded_group_info``. ``comm_fusion`` now buckets by those groups, so 

360 this helper only preserves the historical ``self.replicate_group`` field 

361 for compatibility with simpler single-group paths. 

362 """ 

363 replicate_groups = [] 

364 for hsdp_param in self.hsdp_params: 

365 group_info = getattr(hsdp_param, "unsharded_group_info", None) 

366 group = getattr(group_info, "group", None) 

367 if group is None or getattr(hsdp_param, "replicate_world_size", 1) <= 1: 

368 continue 

369 replicate_groups.append((group, getattr(hsdp_param, "_param_fqn", "<unknown>"))) 

370 

371 if not replicate_groups: 

372 return None 

373 

374 ref_group, _ = replicate_groups[0] 

375 return ref_group 

376 

377 def _build_active_replicate_buckets(self, hsdp_params): 

378 """Group active params by their layout-driven replicate all-reduce group.""" 

379 buckets: dict[int, ReplicateBucket] = {} 

380 for idx, hsdp_param in enumerate(hsdp_params): 

381 group_info = getattr(hsdp_param, "unsharded_group_info", None) 

382 group = getattr(group_info, "group", None) 

383 group_size = getattr( 

384 group_info, 

385 "rank_size", 

386 getattr(hsdp_param, "replicate_world_size", 1), 

387 ) 

388 if not isinstance(group_size, int): 

389 fallback_group_size = getattr(hsdp_param, "replicate_world_size", 1) 

390 group_size = fallback_group_size if isinstance(fallback_group_size, int) else 1 

391 if group is None or group_size <= 1: 

392 continue 

393 

394 key = id(group) 

395 if key not in buckets: 

396 buckets[key] = ReplicateBucket( 

397 key=key, 

398 group=group, 

399 group_size=group_size, 

400 param_indices=[], 

401 flat_numel=0, 

402 ) 

403 buckets[key].param_indices.append(idx) 

404 buckets[key].flat_numel += hsdp_param.sharded_size.numel() 

405 return buckets 

406 

407 def _allocate_bucket_buffers_if_needed(self, device, dtype): 

408 """Allocate or resize per-bucket temporary all-reduce buffers.""" 

409 for bucket in self._active_replicate_buckets.values(): 

410 if bucket.flat_numel == 0: 

411 continue 

412 needs_new_buffer = ( 

413 bucket.buffer is None 

414 or bucket.buffer.numel() != bucket.flat_numel 

415 or bucket.buffer.device != device 

416 or bucket.buffer.dtype != dtype 

417 ) 

418 if needs_new_buffer: 

419 bucket.buffer = torch.empty(bucket.flat_numel, device=device, dtype=dtype) 

420 

421 def _pack_bucket_from_reduce_output(self, bucket: ReplicateBucket) -> torch.Tensor: 

422 """Pack one replicate bucket's scattered shards into a contiguous all-reduce buffer.""" 

423 if bucket.buffer is None: 

424 raise AssertionError("Bucket buffer must be allocated before packing from reduce output") 

425 if self._reduce_output is None or self._reduce_hsdp_params is None: 

426 raise AssertionError("Bucket packing requires an active fused reduce output") 

427 dst_offset = 0 

428 for idx in bucket.param_indices: 

429 hsdp_param = self._reduce_hsdp_params[idx] 

430 src_offset = self._active_param_flat_offsets[idx] 

431 numel = hsdp_param.sharded_size.numel() 

432 bucket.buffer.narrow(0, dst_offset, numel).copy_( 

433 self._reduce_output.narrow(0, src_offset, numel) 

434 ) 

435 dst_offset += numel 

436 return bucket.buffer 

437 

438 def _unpack_bucket_to_reduce_output(self, bucket: ReplicateBucket) -> None: 

439 """Write one bucket's post-all-reduce data back into the fused reduce output.""" 

440 if bucket.buffer is None: 

441 raise AssertionError("Bucket buffer must exist before unpacking to reduce output") 

442 if self._reduce_output is None or self._reduce_hsdp_params is None: 

443 raise AssertionError("Bucket unpack requires an active fused reduce output") 

444 src_offset = 0 

445 for idx in bucket.param_indices: 

446 hsdp_param = self._reduce_hsdp_params[idx] 

447 dst_offset = self._active_param_flat_offsets[idx] 

448 numel = hsdp_param.sharded_size.numel() 

449 self._reduce_output.narrow(0, dst_offset, numel).copy_( 

450 bucket.buffer.narrow(0, src_offset, numel) 

451 ) 

452 src_offset += numel 

453 

    def _init_flat_param_buffer(self):
        """Initialize a contiguous flat buffer and rebase all params' sharded data into it.

        This enables zero-copy all-gather by making all local shards contiguous in memory,
        so they can be passed directly to ``all_gather_into_tensor`` without ``foreach_copy_``.
        When mixed-precision casting is needed, a separate cast buffer is also allocated.
        """
        # No sharding, no params, or shards not materialized on-device:
        # nothing to flatten (CPU-offloaded and meta params are excluded).
        if self.shard_world_size <= 1:
            return
        if len(self.hsdp_params) == 0:
            return
        if any(p.offload_to_cpu or p.sharded_param.device.type == "meta" for p in self.hsdp_params):
            return

        # NOTE(review): the flat buffer uses the first param's dtype;
        # _init_mp_dtypes enforces a uniform orig dtype for trainable params —
        # confirm frozen params cannot diverge.
        total_numel = sum(p._sharded_param_data.numel() for p in self.hsdp_params)
        orig_dtype = self.hsdp_params[0]._sharded_param_data.dtype
        flat_buffer = torch.empty(total_numel, dtype=orig_dtype, device=self.device)

        offset = 0
        for hsdp_param in self.hsdp_params:
            numel = hsdp_param._sharded_param_data.numel()
            flat_slice = flat_buffer.narrow(0, offset, numel)
            flat_slice.copy_(hsdp_param._sharded_param_data)
            # Rebase _sharded_param_data to be a view into the flat buffer
            hsdp_param._sharded_param_data = flat_slice
            # Rebase DTensor's local tensor so optimizer in-place updates write to flat buffer
            new_local = flat_slice.view(hsdp_param.sharded_size)
            req_grad = hsdp_param.sharded_param.requires_grad
            hsdp_param.sharded_param._local_tensor = new_local
            hsdp_param.sharded_param.data = new_local
            # Re-enable autograd on both the rebasered local tensor and the
            # param wrapper; assigning .data above does not carry requires_grad.
            if req_grad:
                new_local.requires_grad_(True)
                hsdp_param.sharded_param.requires_grad_(True)
            offset += numel

        self._flat_param_buffer = flat_buffer

        # Allocate cast buffer for mixed precision if needed
        has_param_dtype = any(p.param_dtype is not None for p in self.hsdp_params)
        if has_param_dtype:
            cast_dtype = next(p.param_dtype for p in self.hsdp_params if p.param_dtype is not None)
            self._flat_cast_buffer = torch.empty(total_numel, dtype=cast_dtype, device=self.device)

496 

497 def _is_flat_buffer_valid(self): 

498 """Check if the flat buffer is still backing the params' sharded data. 

499 

500 The flat buffer becomes invalid after ``load_state_dict`` triggers 

501 ``reset_sharded_param``, which re-assigns ``_sharded_param_data``. 

502 """ 

503 if self._flat_param_buffer is None or len(self.hsdp_params) == 0: 

504 return False 

505 return self.hsdp_params[0]._sharded_param_data.data_ptr() == self._flat_param_buffer.data_ptr() 

506 

507 def unshard(self, async_op: bool = False): 

508 """Trigger fused all-gather to reconstruct full parameters from shards. 

509 

510 If a prefetch has already been issued (``_result is not None``), this is a no-op. 

511 For ``shard_world_size == 1`` (no sharding), skips the collective entirely. 

512 

513 Args: 

514 async_op: If True, the all-gather runs asynchronously and must be 

515 completed later via ``wait_for_unshard()``. 

516 """ 

517 # Already prefetched — skip 

518 if self._result is not None: 

519 return 

520 if self.shard_world_size == 1: 

521 self._result = AllGatherResult(self._all_gather_output, None, None) 

522 return 

523 self.foreach_all_gather(async_op=async_op) 

524 

525 def _init_mp_dtypes(self): 

526 """Initialize and validate mixed-precision dtypes across all trainable parameters. 

527 

528 All trainable parameters in the group must have a uniform ``orig_dtype`` and 

529 ``reduce_dtype``; heterogeneous dtypes would cause incorrect buffer slicing. 

530 """ 

531 for hsdp_param in self.hsdp_params: 

532 hsdp_param.init_dtype_attrs(self.mp_policy) 

533 trainable_params: list[TorchHSDPParamV2] = [ 

534 p for p in self.hsdp_params if p.sharded_param.requires_grad 

535 ] 

536 orig_dtypes = {p.orig_dtype for p in trainable_params} 

537 reduce_dtypes = {p.reduce_dtype for p in trainable_params} 

538 if len(trainable_params) > 0 and len(orig_dtypes) != 1: 

539 raise AssertionError( 

540 f"hsdp expects uniform original parameter dtype but got {orig_dtypes}" 

541 ) 

542 self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None 

543 if len(trainable_params) > 0 and len(reduce_dtypes) != 1: 

544 raise AssertionError( 

545 f"hsdp expects uniform reduce dtype but got {reduce_dtypes}" 

546 ) 

547 self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None 

548 

    def wait_for_unshard(self):
        """Wait for the async all-gather to complete and scatter data to per-parameter buffers.

        For ``shard_world_size == 1``, simply copies the local shard as the full parameter.
        Otherwise, calls ``foreach_all_gather_copy_out`` to split the fused buffer and
        write each parameter's all-gather output. Finally, initializes unsharded parameters.
        """
        # No pending unshard (neither prefetch nor sync gather) — nothing to do.
        if self._result is None:
            return
        if self.shard_world_size == 1:
            # Degenerate case: the "gathered" parameter is just the local shard.
            for hsdp_param in self.hsdp_params:
                all_gather_input = hsdp_param.all_gather_inputs[0]
                hsdp_param.init_all_gather_outputs(
                    [all_gather_input.numel()],
                    [all_gather_input.dtype],
                    self.shard_world_size,
                    self.device
                )
                hsdp_param.alloc_all_gather_outputs()
                # Preserve the autograd version counter: the output buffer may
                # alias a parameter, and a bare copy_ would bump its version.
                # pylint: disable=W0212
                with torch.autograd._unsafe_preserve_version_counter(hsdp_param.all_gather_outputs[0]):
                    # pylint: disable=W0212
                    hsdp_param.all_gather_outputs[0].copy_(all_gather_input)
        else:
            self.foreach_all_gather_copy_out()
        # Materialize the unsharded params so forward can use the full tensors.
        for hsdp_param in self.hsdp_params:
            hsdp_param.init_unsharded_param()
            hsdp_param.to_unsharded()

577 

578 def alloc_all_gather_output(self, total_output_numel): 

579 """Resize the fused all-gather buffer storage to fit ``total_output_numel`` elements. 

580 

581 Uses ``untyped_storage().resize_()`` to avoid reallocating the tensor object, 

582 enabling storage reuse across iterations. 

583 """ 

584 storage = self.ag_output.untyped_storage() 

585 expected_size = total_output_numel * self.ag_output.itemsize 

586 if storage.size() != expected_size: 

587 storage.resize_(expected_size) 

588 

589 def free_all_gather_output(self): 

590 """Release device memory of the fused all-gather buffer by resizing storage to 0.""" 

591 storage = self.ag_output.untyped_storage() 

592 if storage.size() != 0: 

593 storage.resize_(0) 

594 

    @torch.no_grad()
    def foreach_all_gather(self, async_op=False):
        """Perform a fused all-gather for all parameters in the group.

        When a flat parameter buffer is available (see ``_init_flat_param_buffer``),
        the local shards are already contiguous and can be passed directly to
        ``all_gather_into_tensor`` without any copy-in. Otherwise falls back to
        the ``all_gather_copy_in`` path.

        On success, stores an ``AllGatherResult`` in ``self._result`` for
        ``wait_for_unshard`` / ``foreach_all_gather_copy_out`` to consume.
        Early returns (empty metadata or no inputs) leave ``_result`` unset.

        Args:
            async_op: If True, the collective runs asynchronously.
        """
        if self.metadata_cache is None:
            self.metadata_cache = AllGatherMetadataCache()
        # pylint: disable=W0108
        metadata = self.metadata_cache.get_metadata(self.hsdp_params, lambda p: get_all_gather_metadata(p))
        if metadata.total_input_numel == 0:
            return
        world_size, rank = self.shard_group.size(), self.shard_group.rank()
        total_output_numel = metadata.total_input_numel * world_size
        # Lazily allocate the fused output buffer; afterwards only the backing
        # storage is resized, so the tensor object stays stable across iterations.
        if self.ag_output is None:
            self.ag_output = torch.empty(size=(total_output_numel,),
                                         dtype=metadata.dtype, device=self.device)
        else:
            self.alloc_all_gather_output(total_output_numel)

        # The flat buffer may have been invalidated (e.g. by load_state_dict);
        # rebuild it before deciding which input path to use.
        if self.enable_zero_copy and not self._is_flat_buffer_valid():
            self._init_flat_param_buffer()
        use_flat_buffer = self.enable_zero_copy and self._flat_param_buffer is not None
        if use_flat_buffer:
            # Zero-copy path: flat buffer already holds contiguous shard data
            if self._flat_cast_buffer is not None:
                # Mixed precision: single contiguous cast instead of N small copies
                self._flat_cast_buffer.copy_(self._flat_param_buffer)
                all_gather_input = self._flat_cast_buffer
            else:
                all_gather_input = self._flat_param_buffer
        else:
            # Fallback: collect inputs and copy into the rank-local slice of ag_output
            all_gather_inputs = []
            for hsdp_param in self.hsdp_params:
                all_gather_inputs.extend(hsdp_param.all_gather_inputs)
            if len(all_gather_inputs) == 0:
                return
            all_gather_input, _ = all_gather_copy_in(
                all_gather_inputs,
                self.ag_output,
                metadata.inp_split_sizes,
                metadata.total_input_numel,
                rank
            )
            del all_gather_inputs  # Free references to individual shard tensors

        handle = dist.all_gather_into_tensor(self.ag_output, all_gather_input, self.shard_group, async_op)
        self._result = AllGatherResult(self.ag_output, metadata, handle)

650 

    @torch.no_grad()
    def foreach_all_gather_copy_out(self):
        """Wait for the fused all-gather and scatter results back to per-parameter buffers.

        After the collective completes, the fused output is viewed as ``(world_size, -1)``
        and split along dim=1 according to ``inp_split_sizes``. Each slice is copied into
        the corresponding parameter's ``all_gather_outputs`` buffer using
        ``split_with_sizes_copy`` for zero-extra-allocation copy-out.

        Version counters are preserved via ``_unsafe_preserve_version_counter`` to avoid
        triggering autograd version checks on parameter tensors that alias these buffers.
        """
        (ag_output, metadata, _) = self._result
        # Complete the pending async collective before touching the buffer.
        if self._result.handle is not None:
            self._result.handle.wait()
        device = ag_output.device
        world_size = self.shard_group.size()
        split_with_sizes_out = []
        for input_numels, input_dtypes, hsdp_param in zip(
            metadata.param_input_numels, metadata.param_input_dtypes, self.hsdp_params
        ):
            hsdp_param.init_all_gather_outputs(input_numels, input_dtypes, world_size, device)
            hsdp_param.alloc_all_gather_outputs()
            split_with_sizes_out.extend(hsdp_param.all_gather_outputs)
        # Row i of the 2-D view holds rank i's contribution; splitting along
        # dim=1 routes each input's slices from every rank into one output.
        ag_output = ag_output.view(world_size, -1)
        out = [t.view(world_size, -1) for t in split_with_sizes_out]
        non_inference_outs = [o for o in out if not o.is_inference()]
        if len(non_inference_outs) > 0:
            # Older torch variants only accept one tensor per context manager.
            # Preserve all version counters explicitly for cross-version compatibility.
            # pylint: disable=W0212
            with ExitStack() as stack:
                for tensor in non_inference_outs:
                    stack.enter_context(torch.autograd._unsafe_preserve_version_counter(tensor))
                torch.split_with_sizes_copy(ag_output, metadata.inp_split_sizes, dim=1, out=out)
        else:
            torch.split_with_sizes_copy(ag_output, metadata.inp_split_sizes, dim=1, out=out)
        self._result = None
        self.free_all_gather_output()  # Immediately release fused buffer memory

690 

    @torch.no_grad()
    def foreach_reduce(
        self,
        reduce_scatter_reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG,
        async_op: bool = True
    ) -> None:
        """Perform fused gradient reduction (reduce-scatter + optional all-reduce).

        Collects unsharded gradients from all parameters, packs them into a single
        contiguous buffer, and issues one ``reduce_scatter_tensor``. For HSDP (2D mesh),
        a follow-up ``all_reduce`` across the replicate dimension is also performed.

        When ``async_op=True``, the communication handle is stored in the global
        ``CommContext`` so that the next module's backward hook can overlap computation
        with this reduction. The actual gradient write-back is deferred to
        ``apply_fusion_reduced_grad()``.

        Args:
            reduce_scatter_reduce_op: Reduction operator (default: AVG). AVG is
                rewritten to SUM for the collective, with the division deferred
                until after each communication stage completes.
            async_op: If True, run collectives asynchronously for compute-comm overlap.

        Returns:
            None. All results are delivered later through
            ``apply_fusion_reduced_grad()`` or the phased wait methods.

        Raises:
            ValueError: If the collected unsharded gradients do not all share
                one dtype (the fused buffer requires a uniform dtype).
        """
        # Collect unsharded gradients (from accumulated grad or .grad)
        hsdp_params: List[TorchHSDPParamV2] = []
        unsharded_grads: List[torch.Tensor] = []
        for hsdp_param in self.hsdp_params:
            # Params never unsharded this step have nothing to reduce.
            if not hasattr(hsdp_param, '_unsharded_param'):
                continue
            if hsdp_param.unsharded_accumulated_grad is not None:
                # Gradient-accumulation path: reduce the accumulated buffer.
                hsdp_params.append(hsdp_param)
                unsharded_grads.append(hsdp_param.unsharded_accumulated_grad_data)
            elif hsdp_param._unsharded_param.grad is not None:  # pylint: disable=W0212
                hsdp_params.append(hsdp_param)
                unsharded_grads.append(hsdp_param.unsharded_grad_data)
        if not hsdp_params:
            # Nothing produced a gradient in this group this iteration.
            return
        grad_dtypes = {g.dtype for g in unsharded_grads}
        if len(grad_dtypes) != 1:
            raise ValueError(
                f"FSDP reduce-scatter expects uniform grad dtype but got {grad_dtypes}"
            )
        grad_dtype = unsharded_grads[0].dtype
        # Mixed precision: communicate in _reduce_dtype when configured.
        reduce_dtype = self._reduce_dtype or grad_dtype
        world_size = self.shard_group.size()
        reduce_scatter_input_numel = sum(s.numel() for s in unsharded_grads)
        # NOTE(review): assumes the packed numel is divisible by world_size
        # (presumably padding is handled by reduce_scatter_copy_in) — confirm.
        reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
        device = unsharded_grads[0].device
        # Pack all gradients into a contiguous buffer for fused reduce-scatter
        reduce_scatter_input = torch.empty((reduce_scatter_input_numel,), dtype=reduce_dtype, device=device)
        reduce_scatter_copy_in(hsdp_params, unsharded_grads, reduce_scatter_input, world_size)
        unsharded_grads.clear()  # Release references to full gradients
        reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))
        # AVG is decomposed into SUM + explicit deferred division so that the
        # HSDP two-stage reduction (RS then AR) divides by each group's size.
        self._needs_avg_div = reduce_scatter_reduce_op == dist.ReduceOp.AVG
        comm_op = dist.ReduceOp.SUM if self._needs_avg_div else reduce_scatter_reduce_op
        self._reduce_op = comm_op
        self._reduce_hsdp_params = hsdp_params
        # Record each param's flat offset into the fused reduce output so the
        # write-back phase can carve per-param views.
        self._active_param_flat_offsets = []
        flat_offset = 0
        for hsdp_param in hsdp_params:
            self._active_param_flat_offsets.append(flat_offset)
            flat_offset += hsdp_param.sharded_size.numel()
        self._active_replicate_buckets = self._build_active_replicate_buckets(hsdp_params)
        self._allocate_bucket_buffers_if_needed(reduce_output.device, reduce_output.dtype)
        self._pending_all_reduce_handles = []
        rs_handle = dist.reduce_scatter_tensor(
            output=reduce_output,
            input=reduce_scatter_input,
            group=self.shard_group,
            op=comm_op,
            async_op=async_op
        )
        comm_ctx.comm_handle = rs_handle
        # Step 2 (HSDP only): All-reduce is deferred to apply_fusion_reduced_grad()
        self._reduce_output = reduce_output
        if async_op:
            # Register this group for deferred grad application by the next backward hook
            comm_ctx.pre_param_group = self
        else:
            self.apply_fusion_reduced_grad()

769 

    def wait_reduce_scatter_and_issue_all_reduce(self):
        """Phase 1 of pipelined HSDP gradient reduction.

        Waits for the async reduce-scatter to complete, then issues an async
        all-reduce for each active replicate bucket. The bucket handles are
        stored on this ``HSDPParamGroup`` so they can overlap with the next
        layer's reduce-scatter (Phase 2 is deferred).

        For FSDP (no replicate group), skips the all-reduce and directly
        applies gradients since there is nothing further to pipeline.
        """
        # Drain the in-flight reduce-scatter registered by foreach_reduce().
        if comm_ctx.comm_handle is not None:
            comm_ctx.comm_handle.wait()
            comm_ctx.comm_handle = None
        # Deferred div for AVG: apply after RS completes, before AR
        if self._needs_avg_div:
            self._reduce_output.div_(self.shard_world_size)
        if not self._active_replicate_buckets:
            # No replicate group — no all-reduce needed, apply grads immediately
            self._apply_reduced_grad()
            return

        self._pending_all_reduce_handles = []
        for bucket in self._active_replicate_buckets.values():
            # Pack this bucket's slice of the reduce output into its comm buffer.
            packed = self._pack_bucket_from_reduce_output(bucket)
            ar_handle = dist.all_reduce(
                packed,
                group=bucket.group,
                op=self._reduce_op,
                async_op=True,
            )
            # Keep (bucket key, handle) pairs so Phase 2 can wait per bucket.
            self._pending_all_reduce_handles.append(
                PendingBucketAllReduce(bucket_key=bucket.key, handle=ar_handle)
            )
        # Hand this group to the global context; Phase 2
        # (wait_all_reduce_and_apply_grad) is triggered from there.
        comm_ctx.all_reduce_param_group = self

805 

    def wait_all_reduce_and_apply_grad(self):
        """Phase 2 of pipelined HSDP gradient reduction.

        Waits for the async all-reduce issued in Phase 1 and writes reduced
        gradients back to sharded parameters.
        """
        for pending in self._pending_all_reduce_handles:
            bucket = self._active_replicate_buckets[pending.bucket_key]
            pending.handle.wait()
            # Deferred div for AVG across the replicate dimension.
            # NOTE(review): divides bucket.buffer while Phase 1 all-reduced the
            # tensor returned by _pack_bucket_from_reduce_output — presumably
            # that is a view of bucket.buffer; confirm they alias.
            if self._needs_avg_div:
                bucket.buffer.div_(bucket.group_size)
            # Copy the reduced bucket back into the fused _reduce_output.
            self._unpack_bucket_to_reduce_output(bucket)
        self._pending_all_reduce_handles = []
        # NOTE(review): clears comm_ctx.all_reduce_handle, but Phase 1 set
        # comm_ctx.all_reduce_param_group — verify the intended field is reset.
        comm_ctx.all_reduce_handle = None
        self._apply_reduced_grad()

821 

    def apply_fusion_reduced_grad(self):
        """Full synchronous reduction path (used for final drain and sync mode).

        Waits for reduce-scatter, performs synchronous all-reduce if needed,
        and applies gradients — all in one call without pipelining.
        """
        # Drain the in-flight reduce-scatter, if foreach_reduce() ran async.
        if comm_ctx.comm_handle is not None:
            comm_ctx.comm_handle.wait()
            comm_ctx.comm_handle = None
        # Deferred div for AVG after RS
        if self._needs_avg_div:
            self._reduce_output.div_(self.shard_world_size)
        # HSDP only: synchronous all-reduce per replicate bucket (empty dict
        # for plain FSDP, so this loop is skipped entirely).
        for bucket in self._active_replicate_buckets.values():
            packed = self._pack_bucket_from_reduce_output(bucket)
            dist.all_reduce(
                packed,
                group=bucket.group,
                op=self._reduce_op,
            )
            # Deferred div for AVG after AR
            if self._needs_avg_div:
                packed.div_(bucket.group_size)
            self._unpack_bucket_to_reduce_output(bucket)
        self._apply_reduced_grad()

846 

847 def _apply_reduced_grad(self): 

848 """Write reduced gradients from ``_reduce_output`` back to sharded parameters. 

849 

850 Slices the fused ``_reduce_output`` buffer into per-parameter sharded gradients 

851 using ``torch.as_strided`` (zero-copy view), then either accumulates into the 

852 existing ``.grad`` / ``.main_grad`` or assigns a new DTensor gradient. 

853 

854 Handles: 

855 - Mixed-precision: casts reduced gradient to ``_orig_dtype`` if needed. 

856 - CPU offload: transfers gradient to CPU (``non_blocking`` when possible). 

857 - Gradient accumulation: adds to existing grad when present. 

858 - Memory cleanup: nulls out unsharded grad references to free memory. 

859 """ 

860 flat_grad_offset = 0 

861 if self._reduce_hsdp_params is None: 

862 return 

863 for hsdp_param in self._reduce_hsdp_params: 

864 # Determine target gradient tensor (regular .grad or fp32 main_grad) 

865 sharded_grad = None 

866 if not self.mp_policy.apply_grad_on_fp32_main_grad: 

867 sharded_grad = hsdp_param.sharded_param.grad 

868 else: 

869 if not hasattr(hsdp_param.sharded_param, "main_grad"): 

870 hsdp_param.sharded_param.main_grad = None 

871 sharded_grad = hsdp_param.sharded_param.main_grad 

872 shard_size = hsdp_param.sharded_size 

873 # Zero-copy view into the fused reduce output for this parameter's shard 

874 new_sharded_grad = torch.as_strided( 

875 self._reduce_output, 

876 size=shard_size, 

877 stride=hsdp_param.contiguous_sharded_stride, 

878 storage_offset=flat_grad_offset, 

879 ) 

880 # Cast to original dtype if reduce was done in a different precision 

881 if not self.mp_policy.apply_grad_on_fp32_main_grad and new_sharded_grad.dtype != self._orig_dtype: 

882 new_sharded_grad = new_sharded_grad.to(self._orig_dtype) 

883 need_synchronize = False 

884 if hsdp_param.offload_to_cpu: 

885 non_blocking = hsdp_param.pin_memory and sharded_grad is None 

886 new_sharded_grad = new_sharded_grad.to( 

887 torch.device("cpu"), non_blocking=non_blocking 

888 ) 

889 need_synchronize = True 

890 # Accumulate or assign gradient 

891 if sharded_grad is not None: 

892 if not self.mp_policy.apply_grad_on_fp32_main_grad: 

893 hsdp_param.sharded_param.grad._local_tensor += new_sharded_grad 

894 else: 

895 hsdp_param.sharded_param.main_grad._local_tensor += new_sharded_grad 

896 hsdp_param.sharded_param.grad = None 

897 else: 

898 if not self.mp_policy.apply_grad_on_fp32_main_grad: 

899 hsdp_param.sharded_param.grad = hsdp_param.to_sharded_dtensor(new_sharded_grad) 

900 else: 

901 hsdp_param.sharded_param.main_grad = hsdp_param.to_sharded_dtensor(new_sharded_grad) 

902 hsdp_param.sharded_param.grad = None 

903 flat_grad_offset += shard_size.numel() 

904 # Release unsharded gradient references to free memory 

905 if hsdp_param.unsharded_accumulated_grad is not None: 

906 hsdp_param.unsharded_accumulated_grad = None 

907 elif hsdp_param.unsharded_param.grad is not None: 

908 hsdp_param.unsharded_param.grad = None 

909 

910 if need_synchronize: 

911 if self.device.type == "npu": 

912 torch.npu.current_stream().synchronize() 

913 elif self.device.type == "cuda": 

914 torch.cuda.current_stream().synchronize() 

915 else: 

916 raise NotImplementedError(f"Unsupported device type {self.device} for \ 

917 synchronization after CPU offload.") 

918 self._reduce_output = None # Release fused reduce buffer 

919 self._reduce_hsdp_params = None 

920 self._active_param_flat_offsets = [] 

921 self._active_replicate_buckets = {} 

922 self._pending_all_reduce_handles = []