Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/fully_shard/param_group.py: 30%
447 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/fsdp/_fully_shard/_fsdp_param.py
16# enhanced with fully_shard parameter management
17# ============================================================================
18"""HSDP parameter group.
20This module implements fused communication for HSDP (Hybrid Shard Data Parallel) parameters.
21Instead of issuing one all-gather / reduce-scatter per parameter, ``HSDPParamGroup`` packs all
22parameters within a module into a single contiguous buffer and performs one collective operation,
23which reduces kernel launch overhead and improves bandwidth utilization.
25Key components:
26- ``HSDPParamGroup``: Groups all HSDP parameters in a module for fused all-gather (forward)
27 and fused reduce-scatter + all-reduce (backward).
28- ``AllGatherMetadata`` / ``AllGatherMetadataCache``: Caches per-group metadata (dtypes, numels,
29 split sizes) to avoid recomputation across iterations.
30- ``CommContext``: Global context that tracks the in-flight async communication handle and the
31 param group that owns it, enabling pipelined overlap between communication and computation.
32"""
33from typing import List, Optional, NamedTuple, Any
34from dataclasses import dataclass, field
35from contextlib import ExitStack
36import torch
37import torch.distributed as dist
38from torch.distributed import Work
39from hyper_parallel.core.fully_shard.utils import (
40 MixedPrecisionPolicy,
41 FSDPMeshInfo,
42 DDPMeshInfo,
43 HSDPMeshInfo,
44)
45from hyper_parallel.platform.torch.fully_shard.pack_utils import (
46 build_rs_plan,
47 pack_for_reduce_scatter,
48)
49from hyper_parallel.platform.torch.fully_shard.param import TorchHSDPParamV2
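# Illustrative sketch (added for this write-up, not part of the original module): the
# contrast the module docstring draws between per-parameter collectives and one fused
# collective. The helper name is hypothetical; it relies only on the public
# ``dist.all_gather_into_tensor`` API imported above and assumes an initialized group.
def _example_fused_all_gather(local_shards: List[torch.Tensor]) -> torch.Tensor:
    if not dist.is_initialized():
        raise RuntimeError("This sketch needs an initialized process group.")
    world_size = dist.get_world_size()
    # Pack every local shard into one contiguous input buffer ...
    fused_input = torch.cat([shard.reshape(-1) for shard in local_shards])
    # ... and gather once into a buffer sized for all ranks' contributions,
    # replacing len(local_shards) separate all-gather kernel launches.
    fused_output = torch.empty(
        fused_input.numel() * world_size, dtype=fused_input.dtype, device=fused_input.device
    )
    dist.all_gather_into_tensor(fused_output, fused_input)
    return fused_output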
52def get_all_gather_metadata(hsdp_params):
53 """Collect metadata required for fused all-gather from all HSDP parameters.
55 Iterates over each parameter's local shard inputs and records their dtypes and
56 element counts. All parameters must share the same dtype (heterogeneous dtypes
57 are not yet supported).
59 Args:
60 hsdp_params: List of ``TorchHSDPParamV2`` whose ``all_gather_inputs`` will
61 be inspected.
63 Returns:
64 AllGatherMetadata: Aggregated metadata used by ``foreach_all_gather`` to
65 allocate the fused output buffer and perform copy-in/copy-out.
67 Raises:
68 ValueError: If parameters have different dtypes.
69 """
70 param_input_dtypes = []
71 param_input_numels = []
72 inp_split_sizes = []
73 total_input_numel = 0
74 first_dtype = None
76 for hsdp_param in hsdp_params:
77 inputs = hsdp_param.all_gather_inputs
78 if first_dtype is None:
79 first_dtype = inputs[0].dtype
80 elif first_dtype != inputs[0].dtype:
81 raise ValueError("All parameters in the group must have a uniform dtype.")
82 param_dtypes = [t.dtype for t in inputs]
83 param_numels = [t.numel() for t in inputs]
84 param_input_dtypes.append(param_dtypes)
85 param_input_numels.append(param_numels)
86 inp_split_sizes.extend(param_numels)
87 total_input_numel += sum(param_numels)
89 return AllGatherMetadata(
90 param_input_dtypes,
91 param_input_numels,
92 first_dtype,
93 inp_split_sizes,
94 total_input_numel
95 )
98@dataclass
99class AllGatherMetadata:
100 """Metadata describing the fused all-gather buffer layout.
102 Attributes:
103 param_input_dtypes: Per-parameter list of input tensor dtypes.
104 param_input_numels: Per-parameter list of input tensor element counts.
105 dtype: Uniform dtype of all inputs (used to allocate the fused buffer).
106 inp_split_sizes: Flat list of element counts for each input tensor across
107 all parameters, used by ``torch.split`` / ``split_with_sizes_copy`` to
108 slice the fused buffer back into per-parameter chunks.
109 total_input_numel: Total number of elements from all local shards (one rank's
110 contribution); the full all-gather output has ``total_input_numel * world_size``
111 elements.
112 hash_key: Computed in ``__post_init__`` for use as a cache key.
113 """
114 param_input_dtypes: list[list[torch.dtype]]
115 param_input_numels: list[list[int]]
116 dtype: torch.dtype
117 inp_split_sizes: list[int]
118 total_input_numel: int
119 hash_key: int = field(init=False)
121 def __post_init__(self):
122 self.hash_key = hash((
123 tuple(tuple(d) for d in self.param_input_dtypes),
124 tuple(tuple(n) for n in self.param_input_numels),
125 self.dtype,
126 tuple(self.inp_split_sizes),
127 self.total_input_numel
128 ))
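# Minimal illustration (added, not used by the module): two metadata objects describing
# the same buffer layout produce the same ``hash_key``, matching its stated role as a
# cache key. The dtype and element counts below are arbitrary.
def _example_metadata_hash_key() -> bool:
    meta_a = AllGatherMetadata([[torch.float16]], [[8]], torch.float16, [8], 8)
    meta_b = AllGatherMetadata([[torch.float16]], [[8]], torch.float16, [8], 8)
    return meta_a.hash_key == meta_b.hash_key  # identical layouts hash identically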
131class AllGatherResult(NamedTuple):
132 """Result of a fused all-gather operation.
134 Attributes:
135 all_gather_output: The contiguous output buffer holding gathered data from all ranks.
136 metadata: The ``AllGatherMetadata`` used to interpret the buffer layout.
137 handle: Async work handle from ``dist.all_gather_into_tensor``; ``None`` when
138 the operation was synchronous or when ``shard_world_size == 1``.
139 """
140 all_gather_output: torch.Tensor
141 metadata: AllGatherMetadata
142 handle: Optional[Work]
145@dataclass
146class CommContext:
147 """Global communication context for pipelining fused gradient reduction.
149 For FSDP (shard-only), the reduce-scatter handle is stored in ``comm_handle``
150 and the next module's backward hook waits on it before issuing its own reduction.
152 For HSDP (shard + replicate), a two-phase pipeline is used:
153 Phase 1 (``wait_reduce_scatter_and_issue_all_reduce``): wait for
154 reduce-scatter, then issue one or more async all-reduces stored on
155 the owning ``HSDPParamGroup``.
156 Phase 2 (``wait_all_reduce_and_apply_grad``): wait for all-reduce and
157 write reduced gradients back.
159 This allows three-way overlap:
160 Layer N reduce_scatter ↔ Layer N-1 backward compute
161 Layer N all_reduce ↔ Layer N-1 reduce_scatter
162 """
163 comm_handle: Optional[Work] = None
164 all_reduce_handle: Optional[Work] = None
165 pre_param_group: Any = None  # Param group whose reduce-scatter has been issued, pending phase 1
166 # Param group whose all_reduce has been issued but grad not yet applied
167 all_reduce_param_group: Any = None
170comm_ctx = CommContext()
173def get_comm_ctx():
174 """Return the global ``CommContext`` singleton."""
175 return comm_ctx
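# Hedged usage sketch (the real backward hooks live outside this file; the function name
# is hypothetical): how a hook could drain the two pipeline phases described in the
# ``CommContext`` docstring before issuing its own fused reduction.
def _example_drain_previous_group() -> None:
    ctx = get_comm_ctx()
    # Phase 2 for the group whose replicate all-reduce is already in flight.
    if ctx.all_reduce_param_group is not None:
        ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
        ctx.all_reduce_param_group = None
    # Phase 1 for the group that most recently issued its reduce-scatter.
    if ctx.pre_param_group is not None:
        ctx.pre_param_group.wait_reduce_scatter_and_issue_all_reduce()
        ctx.pre_param_group = None
    # After draining, the caller would launch its own foreach_reduce(...).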
178@dataclass
179class ReplicateBucket:
180 """One fused all-reduce bucket sharing the same replicate process group."""
182 key: int
183 group: Any
184 group_size: int
185 param_indices: list[int]
186 flat_numel: int
187 buffer: Optional[torch.Tensor] = None
190@dataclass
191class PendingBucketAllReduce:
192 """One in-flight async all-reduce launched for a replicate bucket."""
194 bucket_key: int
195 handle: Any
198class AllGatherMetadataCache:
199 """Cache for ``AllGatherMetadata`` to avoid recomputation across iterations.
201 The cache key is derived from ``(id(param), param.version)`` tuples so that
202 it invalidates automatically when parameters are re-sharded or replaced.
203 """
204 _cache: dict[int, AllGatherMetadata] = {}
206 @classmethod
207 def get_metadata(cls, hsdp_params, fn):
208 """Return cached metadata or compute via *fn* and cache the result."""
209 param_key = tuple((id(p), getattr(p, 'version', 0)) for p in hsdp_params)
210 key = hash(param_key)
212 if key in cls._cache:
213 return cls._cache[key]
214 metadata = fn(hsdp_params)
215 cls._cache[key] = metadata
216 return metadata
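# Illustration only (added): a lightweight stand-in "parameter" is enough to exercise the
# cache, because ``get_all_gather_metadata`` reads nothing beyond ``all_gather_inputs``
# and the cache key uses only ``id(param)`` and an optional ``version`` attribute.
# ``types.SimpleNamespace`` is used purely as a stub here.
def _example_metadata_cache_hit() -> bool:
    from types import SimpleNamespace
    fake_param = SimpleNamespace(all_gather_inputs=[torch.zeros(4, dtype=torch.float16)], version=0)
    cache = AllGatherMetadataCache()
    first = cache.get_metadata([fake_param], get_all_gather_metadata)
    second = cache.get_metadata([fake_param], get_all_gather_metadata)
    # Same (id, version) key -> the second call is a cache hit returning the same object.
    return first is second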
219def all_gather_copy_in(all_gather_inputs, all_gather_output, inp_split_sizes, all_gather_input_numel, rank):
220 """Copy per-parameter local shards into the fused all-gather input buffer.
222 The fused output buffer has shape ``(total_input_numel * world_size,)``. Each rank
223 writes its local shards into the slice ``[input_numel * rank : input_numel * (rank+1)]``
224 using ``torch._foreach_copy_`` for efficient batched copy.
226 Args:
227 all_gather_inputs: Flat list of local shard tensors from all parameters.
228 all_gather_output: The pre-allocated fused output buffer.
229 inp_split_sizes: Element counts for splitting the rank-local slice.
230 all_gather_input_numel: Total elements for one rank's local shards.
231 rank: This rank's index within the shard process group.
233 Returns:
234 Tuple of (rank-local input slice, full output buffer).
235 """
236 all_gather_input = all_gather_output.narrow(0, all_gather_input_numel * rank, all_gather_input_numel)
237 foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
238 with torch.no_grad():
239 # pylint: disable=W0212
240 torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
241 return all_gather_input, all_gather_output
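# Small local illustration (added; no process group needed): with world_size 2 and rank 1,
# the fused buffer holds both ranks' shards and ``all_gather_copy_in`` writes this rank's
# shards into the second half. The shard sizes are arbitrary.
def _example_all_gather_copy_in() -> torch.Tensor:
    shards = [torch.ones(3), torch.full((2,), 2.0)]  # two local shard tensors
    numel = sum(t.numel() for t in shards)           # 5 elements per rank
    world_size, rank = 2, 1
    fused = torch.zeros(numel * world_size)          # 10-element fused output buffer
    rank_slice, _ = all_gather_copy_in(shards, fused, [3, 2], numel, rank)
    # rank_slice aliases fused[5:10] and now holds [1, 1, 1, 2, 2]; fused[0:5] stays
    # zero until the collective fills it with rank 0's contribution.
    return rank_slice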
244def reduce_scatter_copy_in(
245 hsdp_params: List[TorchHSDPParamV2],
246 unsharded_grads: List[torch.Tensor],
247 reduce_scatter_input: torch.Tensor,
248 world_size: int,
249) -> None:
250 """Pack unsharded gradients into the fused reduce-scatter input buffer.
252 Uses ``build_rs_plan`` and ``pack_for_reduce_scatter`` to interleave chunks from each
253 gradient tensor so that the buffer layout matches what ``dist.reduce_scatter_tensor``
254 expects: the buffer is viewed as ``(world_size, total_numel // world_size)`` where
255 row *i* contains the slice destined for rank *i* after reduction.
257 Args:
258 hsdp_params: Parameters whose layout determines the pack plan per gradient.
259 unsharded_grads: Full (unsharded) gradients from all parameters.
260 reduce_scatter_input: Pre-allocated flat buffer of size ``sum(g.numel() for g in unsharded_grads)``.
261 world_size: Number of ranks in the shard process group.
262 """
263 if len(hsdp_params) != len(unsharded_grads):
264 raise AssertionError(
265 "reduce_scatter_copy_in expects one hsdp_param per unsharded_grad, but got "
266 f"{len(hsdp_params)} params and {len(unsharded_grads)} grads"
267 )
268 packed_rows = reduce_scatter_input.view(world_size, -1)
269 col_offset = 0
270 with torch.no_grad():
271 for hsdp_param, grad in zip(hsdp_params, unsharded_grads):
272 grad = grad.contiguous()
273 plan = build_rs_plan(hsdp_param, grad, world_size)
274 packed_grad = pack_for_reduce_scatter(grad, plan)
275 next_col_offset = col_offset + packed_grad.size(1)
276 packed_rows[:, col_offset:next_col_offset].copy_(packed_grad)
277 col_offset = next_col_offset
278 if col_offset != packed_rows.size(1):
279 raise AssertionError(
280 "reduce_scatter_copy_in packed an unexpected number of elements: "
281 f"{col_offset} != {packed_rows.size(1)}"
282 )
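# Local illustration (added) of the packed layout this helper produces, using only public
# ops; the real path goes through ``build_rs_plan`` / ``pack_for_reduce_scatter`` and may
# add padding. Row *i* of the packed buffer concatenates every gradient's i-th chunk,
# which is exactly the slice ``dist.reduce_scatter_tensor`` delivers to rank *i*.
def _example_reduce_scatter_layout() -> torch.Tensor:
    world_size = 2
    grads = [torch.arange(4.0), torch.arange(10.0, 16.0)]  # numels 4 and 6
    packed_rows = torch.cat([g.view(world_size, -1) for g in grads], dim=1)
    # packed_rows[0] == [0, 1, 10, 11, 12]  -> reduced shard delivered to rank 0
    # packed_rows[1] == [2, 3, 13, 14, 15]  -> reduced shard delivered to rank 1
    return packed_rows.reshape(-1)  # flat reduce-scatter input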
285class HSDPParamGroup:
286 """Groups all HSDP parameters within a module for fused collective communication.
288 Instead of issuing per-parameter all-gather (forward) and reduce-scatter (backward),
289 this class packs all parameter shards into a single contiguous buffer and performs one
290 fused collective, reducing NCCL/HCCL kernel launch overhead.
292 Lifecycle within one training iteration:
293 1. **Forward** — ``unshard()`` → ``foreach_all_gather()`` packs local shards into
294 ``ag_output`` and issues a single ``all_gather_into_tensor``.
295 2. **Forward (wait)** — ``wait_for_unshard()`` → ``foreach_all_gather_copy_out()``
296 waits on the handle and scatters gathered data back to per-parameter buffers.
297 3. **Backward** — ``foreach_reduce()`` packs unsharded gradients, issues fused
298 ``reduce_scatter_tensor`` (+ optional ``all_reduce`` for HSDP replicate dim),
299 and stores the handle in ``CommContext`` for pipelined overlap.
300 4. **Backward (apply)** — ``apply_fusion_reduced_grad()`` waits on the handle and
301 writes reduced gradient slices back to each parameter's ``.grad`` or ``.main_grad``.
303 Args:
304 hsdp_params: List of ``TorchHSDPParamV2`` belonging to this module.
305 mesh_info: Mesh info providing shard/replicate process groups.
306 device: Target device for buffer allocation.
307 mp_policy: Mixed-precision policy controlling reduce dtype and grad dtype.
308 """
310 def __init__(
311 self,
312 hsdp_params,
313 mesh_info: FSDPMeshInfo,
314 device: Optional[torch.device] = None,
315 mp_policy: Optional[MixedPrecisionPolicy] = None,
316 enable_zero_copy: bool = True,
317 ):
318 self.mesh_info = mesh_info
319 self.device = device
320 self.hsdp_params = hsdp_params
321 if isinstance(self.mesh_info, (FSDPMeshInfo, HSDPMeshInfo)):
322 self.shard_rank = self.mesh_info.shard_mesh_rank
323 self.shard_world_size = self.mesh_info.shard_mesh_size
324 else:
325 self.shard_rank = 0
326 self.shard_world_size = 1
327 self.shard_group = self.mesh_info.shard_process_group
328 self.replicate_group = None
329 if isinstance(self.mesh_info, (HSDPMeshInfo, DDPMeshInfo)):
330 self.replicate_group = self.mesh_info.replicate_process_group
331 elif isinstance(self.mesh_info, FSDPMeshInfo):
332 self.replicate_group = self._infer_layout_replicate_group()
333 self.device = device
334 self._all_gather_output = torch.empty(0, device=self.device)
335 self.ag_output = None # Fused all-gather output buffer, lazily allocated
336 self.metadata_cache = None
337 self.mp_policy = mp_policy
338 self.enable_zero_copy = enable_zero_copy
339 self._result = None # Pending AllGatherResult from async all-gather
340 self._reduce_output = None # Fused reduce-scatter output, consumed by apply_fusion_reduced_grad
341 self._reduce_op = None # Reduce op saved from foreach_reduce for use in apply_fusion_reduced_grad
342 self._needs_avg_div = False # Whether AVG was split into SUM + deferred div
343 self._reduce_hsdp_params = None
344 self._active_replicate_buckets: dict[int, ReplicateBucket] = {}
345 self._active_param_flat_offsets: list[int] = []
346 self._pending_all_reduce_handles: list[PendingBucketAllReduce] = []
347 self._init_mp_dtypes()
348 self._flat_param_buffer = None # Contiguous buffer holding all params' sharded data
349 self._flat_cast_buffer = None # Cast buffer for mixed precision (param_dtype)
350 if self.enable_zero_copy:
351 self._init_flat_param_buffer()
353 def _infer_layout_replicate_group(self):
354 """Infer a compatibility all-reduce group from params' final DTensor layout when mesh_info has none.
356 DTENSOR_UNIFIED parameters may still carry replicate axes from the original
357 DTensor layout, for example a ``(tp, ep)`` mesh where ``ep`` is replicate-only.
358 The non-fused path derives this group from each param's layout-driven
359 ``unsharded_group_info``. ``comm_fusion`` now buckets by those groups, so
360 this helper only preserves the historical ``self.replicate_group`` field
361 for compatibility with simpler single-group paths.
362 """
363 replicate_groups = []
364 for hsdp_param in self.hsdp_params:
365 group_info = getattr(hsdp_param, "unsharded_group_info", None)
366 group = getattr(group_info, "group", None)
367 if group is None or getattr(hsdp_param, "replicate_world_size", 1) <= 1:
368 continue
369 replicate_groups.append((group, getattr(hsdp_param, "_param_fqn", "<unknown>")))
371 if not replicate_groups:
372 return None
374 ref_group, _ = replicate_groups[0]
375 return ref_group
377 def _build_active_replicate_buckets(self, hsdp_params):
378 """Group active params by their layout-driven replicate all-reduce group."""
379 buckets: dict[int, ReplicateBucket] = {}
380 for idx, hsdp_param in enumerate(hsdp_params):
381 group_info = getattr(hsdp_param, "unsharded_group_info", None)
382 group = getattr(group_info, "group", None)
383 group_size = getattr(
384 group_info,
385 "rank_size",
386 getattr(hsdp_param, "replicate_world_size", 1),
387 )
388 if not isinstance(group_size, int):
389 fallback_group_size = getattr(hsdp_param, "replicate_world_size", 1)
390 group_size = fallback_group_size if isinstance(fallback_group_size, int) else 1
391 if group is None or group_size <= 1:
392 continue
394 key = id(group)
395 if key not in buckets:
396 buckets[key] = ReplicateBucket(
397 key=key,
398 group=group,
399 group_size=group_size,
400 param_indices=[],
401 flat_numel=0,
402 )
403 buckets[key].param_indices.append(idx)
404 buckets[key].flat_numel += hsdp_param.sharded_size.numel()
405 return buckets
407 def _allocate_bucket_buffers_if_needed(self, device, dtype):
408 """Allocate or resize per-bucket temporary all-reduce buffers."""
409 for bucket in self._active_replicate_buckets.values():
410 if bucket.flat_numel == 0:
411 continue
412 needs_new_buffer = (
413 bucket.buffer is None
414 or bucket.buffer.numel() != bucket.flat_numel
415 or bucket.buffer.device != device
416 or bucket.buffer.dtype != dtype
417 )
418 if needs_new_buffer:
419 bucket.buffer = torch.empty(bucket.flat_numel, device=device, dtype=dtype)
421 def _pack_bucket_from_reduce_output(self, bucket: ReplicateBucket) -> torch.Tensor:
422 """Pack one replicate bucket's scattered shards into a contiguous all-reduce buffer."""
423 if bucket.buffer is None:
424 raise AssertionError("Bucket buffer must be allocated before packing from reduce output")
425 if self._reduce_output is None or self._reduce_hsdp_params is None:
426 raise AssertionError("Bucket packing requires an active fused reduce output")
427 dst_offset = 0
428 for idx in bucket.param_indices:
429 hsdp_param = self._reduce_hsdp_params[idx]
430 src_offset = self._active_param_flat_offsets[idx]
431 numel = hsdp_param.sharded_size.numel()
432 bucket.buffer.narrow(0, dst_offset, numel).copy_(
433 self._reduce_output.narrow(0, src_offset, numel)
434 )
435 dst_offset += numel
436 return bucket.buffer
438 def _unpack_bucket_to_reduce_output(self, bucket: ReplicateBucket) -> None:
439 """Write one bucket's post-all-reduce data back into the fused reduce output."""
440 if bucket.buffer is None:
441 raise AssertionError("Bucket buffer must exist before unpacking to reduce output")
442 if self._reduce_output is None or self._reduce_hsdp_params is None:
443 raise AssertionError("Bucket unpack requires an active fused reduce output")
444 src_offset = 0
445 for idx in bucket.param_indices:
446 hsdp_param = self._reduce_hsdp_params[idx]
447 dst_offset = self._active_param_flat_offsets[idx]
448 numel = hsdp_param.sharded_size.numel()
449 self._reduce_output.narrow(0, dst_offset, numel).copy_(
450 bucket.buffer.narrow(0, src_offset, numel)
451 )
452 src_offset += numel
454 def _init_flat_param_buffer(self):
455 """Initialize a contiguous flat buffer and rebase all params' sharded data into it.
457 This enables zero-copy all-gather by making all local shards contiguous in memory,
458 so they can be passed directly to ``all_gather_into_tensor`` without ``foreach_copy_``.
459 When mixed-precision casting is needed, a separate cast buffer is also allocated.
460 """
461 if self.shard_world_size <= 1:
462 return
463 if len(self.hsdp_params) == 0:
464 return
465 if any(p.offload_to_cpu or p.sharded_param.device.type == "meta" for p in self.hsdp_params):
466 return
468 total_numel = sum(p._sharded_param_data.numel() for p in self.hsdp_params)
469 orig_dtype = self.hsdp_params[0]._sharded_param_data.dtype
470 flat_buffer = torch.empty(total_numel, dtype=orig_dtype, device=self.device)
472 offset = 0
473 for hsdp_param in self.hsdp_params:
474 numel = hsdp_param._sharded_param_data.numel()
475 flat_slice = flat_buffer.narrow(0, offset, numel)
476 flat_slice.copy_(hsdp_param._sharded_param_data)
477 # Rebase _sharded_param_data to be a view into the flat buffer
478 hsdp_param._sharded_param_data = flat_slice
479 # Rebase DTensor's local tensor so optimizer in-place updates write to flat buffer
480 new_local = flat_slice.view(hsdp_param.sharded_size)
481 req_grad = hsdp_param.sharded_param.requires_grad
482 hsdp_param.sharded_param._local_tensor = new_local
483 hsdp_param.sharded_param.data = new_local
484 if req_grad:
485 new_local.requires_grad_(True)
486 hsdp_param.sharded_param.requires_grad_(True)
487 offset += numel
489 self._flat_param_buffer = flat_buffer
491 # Allocate cast buffer for mixed precision if needed
492 has_param_dtype = any(p.param_dtype is not None for p in self.hsdp_params)
493 if has_param_dtype:
494 cast_dtype = next(p.param_dtype for p in self.hsdp_params if p.param_dtype is not None)
495 self._flat_cast_buffer = torch.empty(total_numel, dtype=cast_dtype, device=self.device)
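    # Local sketch of the rebasing trick above with plain tensors standing in for
    # ``TorchHSDPParamV2`` (added for illustration; not called by the module): once a
    # shard is a view into the flat buffer, in-place optimizer updates on the shard are
    # visible through the buffer, so the buffer can feed all_gather_into_tensor directly.
    @staticmethod
    def _example_flat_buffer_rebase() -> torch.Tensor:
        shard_a, shard_b = torch.ones(3), torch.full((2,), 2.0)
        flat = torch.empty(5)
        view_a = flat.narrow(0, 0, 3)
        view_a.copy_(shard_a)        # rebase: shard_a's data now lives in the flat buffer
        view_b = flat.narrow(0, 3, 2)
        view_b.copy_(shard_b)
        view_a.mul_(0.5)             # an "optimizer step" on the rebased shard
        return flat                  # flat[:3] == 0.5, flat[3:] == 2.0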
497 def _is_flat_buffer_valid(self):
498 """Check if the flat buffer is still backing the params' sharded data.
500 The flat buffer becomes invalid after ``load_state_dict`` triggers
501 ``reset_sharded_param``, which re-assigns ``_sharded_param_data``.
502 """
503 if self._flat_param_buffer is None or len(self.hsdp_params) == 0:
504 return False
505 return self.hsdp_params[0]._sharded_param_data.data_ptr() == self._flat_param_buffer.data_ptr()
507 def unshard(self, async_op: bool = False):
508 """Trigger fused all-gather to reconstruct full parameters from shards.
510 If a prefetch has already been issued (``_result is not None``), this is a no-op.
511 For ``shard_world_size == 1`` (no sharding), skips the collective entirely.
513 Args:
514 async_op: If True, the all-gather runs asynchronously and must be
515 completed later via ``wait_for_unshard()``.
516 """
517 # Already prefetched — skip
518 if self._result is not None:
519 return
520 if self.shard_world_size == 1:
521 self._result = AllGatherResult(self._all_gather_output, None, None)
522 return
523 self.foreach_all_gather(async_op=async_op)
525 def _init_mp_dtypes(self):
526 """Initialize and validate mixed-precision dtypes across all trainable parameters.
528 All trainable parameters in the group must have a uniform ``orig_dtype`` and
529 ``reduce_dtype``; heterogeneous dtypes would cause incorrect buffer slicing.
530 """
531 for hsdp_param in self.hsdp_params:
532 hsdp_param.init_dtype_attrs(self.mp_policy)
533 trainable_params: list[TorchHSDPParamV2] = [
534 p for p in self.hsdp_params if p.sharded_param.requires_grad
535 ]
536 orig_dtypes = {p.orig_dtype for p in trainable_params}
537 reduce_dtypes = {p.reduce_dtype for p in trainable_params}
538 if len(trainable_params) > 0 and len(orig_dtypes) != 1:
539 raise AssertionError(
540 f"hsdp expects uniform original parameter dtype but got {orig_dtypes}"
541 )
542 self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
543 if len(trainable_params) > 0 and len(reduce_dtypes) != 1:
544 raise AssertionError(
545 f"hsdp expects uniform reduce dtype but got {reduce_dtypes}"
546 )
547 self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None
549 def wait_for_unshard(self):
550 """Wait for the async all-gather to complete and scatter data to per-parameter buffers.
552 For ``shard_world_size == 1``, simply copies the local shard as the full parameter.
553 Otherwise, calls ``foreach_all_gather_copy_out`` to split the fused buffer and
554 write each parameter's all-gather output. Finally, initializes unsharded parameters.
555 """
556 if self._result is None:
557 return
558 if self.shard_world_size == 1:
559 for hsdp_param in self.hsdp_params:
560 all_gather_input = hsdp_param.all_gather_inputs[0]
561 hsdp_param.init_all_gather_outputs(
562 [all_gather_input.numel()],
563 [all_gather_input.dtype],
564 self.shard_world_size,
565 self.device
566 )
567 hsdp_param.alloc_all_gather_outputs()
568 # pylint: disable=W0212
569 with torch.autograd._unsafe_preserve_version_counter(hsdp_param.all_gather_outputs[0]):
570 # pylint: disable=W0212
571 hsdp_param.all_gather_outputs[0].copy_(all_gather_input)
572 else:
573 self.foreach_all_gather_copy_out()
574 for hsdp_param in self.hsdp_params:
575 hsdp_param.init_unsharded_param()
576 hsdp_param.to_unsharded()
578 def alloc_all_gather_output(self, total_output_numel):
579 """Resize the fused all-gather buffer storage to fit ``total_output_numel`` elements.
581 Uses ``untyped_storage().resize_()`` to avoid reallocating the tensor object,
582 enabling storage reuse across iterations.
583 """
584 storage = self.ag_output.untyped_storage()
585 expected_size = total_output_numel * self.ag_output.itemsize
586 if storage.size() != expected_size:
587 storage.resize_(expected_size)
589 def free_all_gather_output(self):
590 """Release device memory of the fused all-gather buffer by resizing storage to 0."""
591 storage = self.ag_output.untyped_storage()
592 if storage.size() != 0:
593 storage.resize_(0)
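    # Self-contained sketch (added) of the storage-resizing trick used by the two helpers
    # above: the Python tensor object survives while its backing memory is released and
    # later re-grown, so the same ``ag_output`` handle can be reused every iteration.
    @staticmethod
    def _example_storage_resize() -> int:
        buf = torch.empty(1024, dtype=torch.float32)
        storage = buf.untyped_storage()
        storage.resize_(0)                    # free the backing memory (argument is bytes)
        freed_size = storage.size()           # 0 while freed
        storage.resize_(1024 * buf.itemsize)  # re-grow before the next all-gather
        return freed_size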
595 @torch.no_grad()
596 def foreach_all_gather(self, async_op=False):
597 """Perform a fused all-gather for all parameters in the group.
599 When a flat parameter buffer is available (see ``_init_flat_param_buffer``),
600 the local shards are already contiguous and can be passed directly to
601 ``all_gather_into_tensor`` without any copy-in. Otherwise falls back to
602 the ``all_gather_copy_in`` path.
604 Args:
605 async_op: If True, the collective runs asynchronously.
606 """
607 if self.metadata_cache is None:
608 self.metadata_cache = AllGatherMetadataCache()
610 metadata = self.metadata_cache.get_metadata(self.hsdp_params, get_all_gather_metadata)
611 if metadata.total_input_numel == 0:
612 return
613 world_size, rank = self.shard_group.size(), self.shard_group.rank()
614 total_output_numel = metadata.total_input_numel * world_size
615 if self.ag_output is None:
616 self.ag_output = torch.empty(size=(total_output_numel,),
617 dtype=metadata.dtype, device=self.device)
618 else:
619 self.alloc_all_gather_output(total_output_numel)
621 if self.enable_zero_copy and not self._is_flat_buffer_valid():
622 self._init_flat_param_buffer()
623 use_flat_buffer = self.enable_zero_copy and self._flat_param_buffer is not None
624 if use_flat_buffer:
625 # Zero-copy path: flat buffer already holds contiguous shard data
626 if self._flat_cast_buffer is not None:
627 # Mixed precision: single contiguous cast instead of N small copies
628 self._flat_cast_buffer.copy_(self._flat_param_buffer)
629 all_gather_input = self._flat_cast_buffer
630 else:
631 all_gather_input = self._flat_param_buffer
632 else:
633 # Fallback: collect inputs and copy into the rank-local slice of ag_output
634 all_gather_inputs = []
635 for hsdp_param in self.hsdp_params:
636 all_gather_inputs.extend(hsdp_param.all_gather_inputs)
637 if len(all_gather_inputs) == 0:
638 return
639 all_gather_input, _ = all_gather_copy_in(
640 all_gather_inputs,
641 self.ag_output,
642 metadata.inp_split_sizes,
643 metadata.total_input_numel,
644 rank
645 )
646 del all_gather_inputs # Free references to individual shard tensors
648 handle = dist.all_gather_into_tensor(self.ag_output, all_gather_input, self.shard_group, async_op)
649 self._result = AllGatherResult(self.ag_output, metadata, handle)
651 @torch.no_grad()
652 def foreach_all_gather_copy_out(self):
653 """Wait for the fused all-gather and scatter results back to per-parameter buffers.
655 After the collective completes, the fused output is viewed as ``(world_size, -1)``
656 and split along dim=1 according to ``inp_split_sizes``. Each slice is copied into
657 the corresponding parameter's ``all_gather_outputs`` buffer using
658 ``split_with_sizes_copy`` for zero-extra-allocation copy-out.
660 Version counters are preserved via ``_unsafe_preserve_version_counter`` to avoid
661 triggering autograd version checks on parameter tensors that alias these buffers.
662 """
663 (ag_output, metadata, _) = self._result
664 if self._result.handle is not None:
665 self._result.handle.wait()
666 device = ag_output.device
667 world_size = self.shard_group.size()
668 split_with_sizes_out = []
669 for input_numels, input_dtypes, hsdp_param in zip(
670 metadata.param_input_numels, metadata.param_input_dtypes, self.hsdp_params
671 ):
672 hsdp_param.init_all_gather_outputs(input_numels, input_dtypes, world_size, device)
673 hsdp_param.alloc_all_gather_outputs()
674 split_with_sizes_out.extend(hsdp_param.all_gather_outputs)
675 ag_output = ag_output.view(world_size, -1)
676 out = [t.view(world_size, -1) for t in split_with_sizes_out]
677 non_inference_outs = [o for o in out if not o.is_inference()]
678 if len(non_inference_outs) > 0:
679 # Older torch variants only accept one tensor per context manager.
680 # Preserve all version counters explicitly for cross-version compatibility.
681 # pylint: disable=W0212
682 with ExitStack() as stack:
683 for tensor in non_inference_outs:
684 stack.enter_context(torch.autograd._unsafe_preserve_version_counter(tensor))
685 torch.split_with_sizes_copy(ag_output, metadata.inp_split_sizes, dim=1, out=out)
686 else:
687 torch.split_with_sizes_copy(ag_output, metadata.inp_split_sizes, dim=1, out=out)
688 self._result = None
689 self.free_all_gather_output() # Immediately release fused buffer memory
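    # Local sketch of the copy-out step above with made-up sizes and no collective (added
    # for illustration): the "gathered" buffer is viewed as (world_size, -1) and split
    # column-wise, so each parameter receives its rows from every rank stacked together.
    @staticmethod
    def _example_copy_out_layout() -> List[torch.Tensor]:
        world_size, split_sizes = 2, [3, 2]
        # Pretend this came back from all_gather_into_tensor: rank 0's five elements
        # followed by rank 1's five elements.
        gathered = torch.arange(10.0).view(world_size, -1)
        outs = [torch.empty(world_size, s) for s in split_sizes]
        torch.split_with_sizes_copy(gathered, split_sizes, dim=1, out=outs)
        # outs[0] == [[0, 1, 2], [5, 6, 7]] -> first param's shard from each rank
        # outs[1] == [[3, 4], [8, 9]]       -> second param's shard from each rank
        return outs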
691 @torch.no_grad()
692 def foreach_reduce(
693 self,
694 reduce_scatter_reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG,
695 async_op: bool = True
696 ) -> None:
697 """Perform fused gradient reduction (reduce-scatter + optional all-reduce).
699 Collects unsharded gradients from all parameters, packs them into a single
700 contiguous buffer, and issues one ``reduce_scatter_tensor``. For HSDP (2D mesh),
701 a follow-up ``all_reduce`` across the replicate dimension is also performed.
703 When ``async_op=True``, the communication handle is stored in the global
704 ``CommContext`` so that the next module's backward hook can overlap computation
705 with this reduction. The actual gradient write-back is deferred to
706 ``apply_fusion_reduced_grad()``.
708 Args:
709 reduce_scatter_reduce_op: Reduction operator (default: AVG).
710 async_op: If True, run collectives asynchronously for compute-comm overlap.
711 """
712 # Collect unsharded gradients (from accumulated grad or .grad)
713 hsdp_params: List[TorchHSDPParamV2] = []
714 unsharded_grads: List[torch.Tensor] = []
715 for hsdp_param in self.hsdp_params:
716 if not hasattr(hsdp_param, '_unsharded_param'):
717 continue
718 if hsdp_param.unsharded_accumulated_grad is not None:
719 hsdp_params.append(hsdp_param)
720 unsharded_grads.append(hsdp_param.unsharded_accumulated_grad_data)
721 elif hsdp_param._unsharded_param.grad is not None: # pylint: disable=W0212
722 hsdp_params.append(hsdp_param)
723 unsharded_grads.append(hsdp_param.unsharded_grad_data)
724 if not hsdp_params:
725 return
726 grad_dtypes = {g.dtype for g in unsharded_grads}
727 if len(grad_dtypes) != 1:
728 raise ValueError(
729 f"FSDP reduce-scatter expects uniform grad dtype but got {grad_dtypes}"
730 )
731 grad_dtype = unsharded_grads[0].dtype
732 reduce_dtype = self._reduce_dtype or grad_dtype
733 world_size = self.shard_group.size()
734 reduce_scatter_input_numel = sum(s.numel() for s in unsharded_grads)
735 reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
736 device = unsharded_grads[0].device
737 # Pack all gradients into a contiguous buffer for fused reduce-scatter
738 reduce_scatter_input = torch.empty((reduce_scatter_input_numel,), dtype=reduce_dtype, device=device)
739 reduce_scatter_copy_in(hsdp_params, unsharded_grads, reduce_scatter_input, world_size)
740 unsharded_grads.clear() # Release references to full gradients
741 reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))
742 self._needs_avg_div = reduce_scatter_reduce_op == dist.ReduceOp.AVG
743 comm_op = dist.ReduceOp.SUM if self._needs_avg_div else reduce_scatter_reduce_op
744 self._reduce_op = comm_op
745 self._reduce_hsdp_params = hsdp_params
746 self._active_param_flat_offsets = []
747 flat_offset = 0
748 for hsdp_param in hsdp_params:
749 self._active_param_flat_offsets.append(flat_offset)
750 flat_offset += hsdp_param.sharded_size.numel()
751 self._active_replicate_buckets = self._build_active_replicate_buckets(hsdp_params)
752 self._allocate_bucket_buffers_if_needed(reduce_output.device, reduce_output.dtype)
753 self._pending_all_reduce_handles = []
754 rs_handle = dist.reduce_scatter_tensor(
755 output=reduce_output,
756 input=reduce_scatter_input,
757 group=self.shard_group,
758 op=comm_op,
759 async_op=async_op
760 )
761 comm_ctx.comm_handle = rs_handle
762 # Step 2 (HSDP only): the replicate-dim all-reduce is deferred to the pipelined phase handlers (async) or apply_fusion_reduced_grad() (sync)
763 self._reduce_output = reduce_output
764 if async_op:
765 # Register this group for deferred grad application by the next backward hook
766 comm_ctx.pre_param_group = self
767 else:
768 self.apply_fusion_reduced_grad()
770 def wait_reduce_scatter_and_issue_all_reduce(self):
771 """Phase 1 of pipelined HSDP gradient reduction.
773 Waits for the async reduce-scatter to complete, then issues an async
774 all-reduce for each active replicate bucket. The bucket handles are
775 stored on this ``HSDPParamGroup`` so they can overlap with the next
776 layer's reduce-scatter (Phase 2 is deferred).
778 For FSDP (no replicate group), skips the all-reduce and directly
779 applies gradients since there is nothing further to pipeline.
780 """
781 if comm_ctx.comm_handle is not None:
782 comm_ctx.comm_handle.wait()
783 comm_ctx.comm_handle = None
784 # Deferred div for AVG: apply after RS completes, before AR
785 if self._needs_avg_div:
786 self._reduce_output.div_(self.shard_world_size)
787 if not self._active_replicate_buckets:
788 # No replicate group — no all-reduce needed, apply grads immediately
789 self._apply_reduced_grad()
790 return
792 self._pending_all_reduce_handles = []
793 for bucket in self._active_replicate_buckets.values():
794 packed = self._pack_bucket_from_reduce_output(bucket)
795 ar_handle = dist.all_reduce(
796 packed,
797 group=bucket.group,
798 op=self._reduce_op,
799 async_op=True,
800 )
801 self._pending_all_reduce_handles.append(
802 PendingBucketAllReduce(bucket_key=bucket.key, handle=ar_handle)
803 )
804 comm_ctx.all_reduce_param_group = self
806 def wait_all_reduce_and_apply_grad(self):
807 """Phase 2 of pipelined HSDP gradient reduction.
809 Waits for the async all-reduce issued in Phase 1 and writes reduced
810 gradients back to sharded parameters.
811 """
812 for pending in self._pending_all_reduce_handles:
813 bucket = self._active_replicate_buckets[pending.bucket_key]
814 pending.handle.wait()
815 if self._needs_avg_div:
816 bucket.buffer.div_(bucket.group_size)
817 self._unpack_bucket_to_reduce_output(bucket)
818 self._pending_all_reduce_handles = []
819 comm_ctx.all_reduce_handle = None
820 self._apply_reduced_grad()
822 def apply_fusion_reduced_grad(self):
823 """Full synchronous reduction path (used for final drain and sync mode).
825 Waits for reduce-scatter, performs synchronous all-reduce if needed,
826 and applies gradients — all in one call without pipelining.
827 """
828 if comm_ctx.comm_handle is not None:
829 comm_ctx.comm_handle.wait()
830 comm_ctx.comm_handle = None
831 # Deferred div for AVG after RS
832 if self._needs_avg_div:
833 self._reduce_output.div_(self.shard_world_size)
834 for bucket in self._active_replicate_buckets.values():
835 packed = self._pack_bucket_from_reduce_output(bucket)
836 dist.all_reduce(
837 packed,
838 group=bucket.group,
839 op=self._reduce_op,
840 )
841 # Deferred div for AVG after AR
842 if self._needs_avg_div:
843 packed.div_(bucket.group_size)
844 self._unpack_bucket_to_reduce_output(bucket)
845 self._apply_reduced_grad()
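    # Arithmetic sketch (added) of the deferred-AVG handling above: averaging over the
    # full 2D mesh equals summing over the shard dimension and dividing by
    # shard_world_size, then summing over the replicate dimension and dividing by its size.
    @staticmethod
    def _example_deferred_avg() -> bool:
        shard_world_size, replicate_size = 4, 2
        grads = torch.arange(8.0).view(replicate_size, shard_world_size)  # one value per rank
        direct_avg = grads.mean()
        staged = grads.sum(dim=1).div(shard_world_size)  # "reduce-scatter" as SUM + deferred div
        staged = staged.sum().div(replicate_size)        # "all-reduce" as SUM + deferred div
        return bool(torch.isclose(direct_avg, staged))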
847 def _apply_reduced_grad(self):
848 """Write reduced gradients from ``_reduce_output`` back to sharded parameters.
850 Slices the fused ``_reduce_output`` buffer into per-parameter sharded gradients
851 using ``torch.as_strided`` (zero-copy view), then either accumulates into the
852 existing ``.grad`` / ``.main_grad`` or assigns a new DTensor gradient.
854 Handles:
855 - Mixed-precision: casts reduced gradient to ``_orig_dtype`` if needed.
856 - CPU offload: transfers gradient to CPU (``non_blocking`` when possible).
857 - Gradient accumulation: adds to existing grad when present.
858 - Memory cleanup: nulls out unsharded grad references to free memory.
859 """
860 flat_grad_offset = 0
861 if self._reduce_hsdp_params is None:
862 return
863 for hsdp_param in self._reduce_hsdp_params:
864 # Determine target gradient tensor (regular .grad or fp32 main_grad)
865 sharded_grad = None
866 if not self.mp_policy.apply_grad_on_fp32_main_grad:
867 sharded_grad = hsdp_param.sharded_param.grad
868 else:
869 if not hasattr(hsdp_param.sharded_param, "main_grad"):
870 hsdp_param.sharded_param.main_grad = None
871 sharded_grad = hsdp_param.sharded_param.main_grad
872 shard_size = hsdp_param.sharded_size
873 # Zero-copy view into the fused reduce output for this parameter's shard
874 new_sharded_grad = torch.as_strided(
875 self._reduce_output,
876 size=shard_size,
877 stride=hsdp_param.contiguous_sharded_stride,
878 storage_offset=flat_grad_offset,
879 )
880 # Cast to original dtype if reduce was done in a different precision
881 if not self.mp_policy.apply_grad_on_fp32_main_grad and new_sharded_grad.dtype != self._orig_dtype:
882 new_sharded_grad = new_sharded_grad.to(self._orig_dtype)
883 need_synchronize = False
884 if hsdp_param.offload_to_cpu:
885 non_blocking = hsdp_param.pin_memory and sharded_grad is None
886 new_sharded_grad = new_sharded_grad.to(
887 torch.device("cpu"), non_blocking=non_blocking
888 )
889 need_synchronize = True
890 # Accumulate or assign gradient
891 if sharded_grad is not None:
892 if not self.mp_policy.apply_grad_on_fp32_main_grad:
893 hsdp_param.sharded_param.grad._local_tensor += new_sharded_grad
894 else:
895 hsdp_param.sharded_param.main_grad._local_tensor += new_sharded_grad
896 hsdp_param.sharded_param.grad = None
897 else:
898 if not self.mp_policy.apply_grad_on_fp32_main_grad:
899 hsdp_param.sharded_param.grad = hsdp_param.to_sharded_dtensor(new_sharded_grad)
900 else:
901 hsdp_param.sharded_param.main_grad = hsdp_param.to_sharded_dtensor(new_sharded_grad)
902 hsdp_param.sharded_param.grad = None
903 flat_grad_offset += shard_size.numel()
904 # Release unsharded gradient references to free memory
905 if hsdp_param.unsharded_accumulated_grad is not None:
906 hsdp_param.unsharded_accumulated_grad = None
907 elif hsdp_param.unsharded_param.grad is not None:
908 hsdp_param.unsharded_param.grad = None
910 if need_synchronize:
911 if self.device.type == "npu":
912 torch.npu.current_stream().synchronize()
913 elif self.device.type == "cuda":
914 torch.cuda.current_stream().synchronize()
915 else:
916                raise NotImplementedError(
917                    f"Unsupported device type {self.device} for synchronization after CPU offload.")
918 self._reduce_output = None # Release fused reduce buffer
919 self._reduce_hsdp_params = None
920 self._active_param_flat_offsets = []
921 self._active_replicate_buckets = {}
922 self._pending_all_reduce_handles = []
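# Illustrative sketch (added for clarity; not called by the module): how a fused reduce
# output is sliced into per-parameter gradient views with ``torch.as_strided`` so that no
# extra copies are made, mirroring ``_apply_reduced_grad``. The shapes below are arbitrary.
def _example_as_strided_grad_views() -> List[torch.Tensor]:
    # Fused buffer holding two sharded grads back to back: a (2, 3) grad then a (4,) grad.
    reduce_output = torch.arange(10.0)
    views, offset = [], 0
    for size in (torch.Size((2, 3)), torch.Size((4,))):
        stride = torch.empty(size).stride()  # contiguous strides for this shape
        views.append(
            torch.as_strided(reduce_output, size=size, stride=stride, storage_offset=offset)
        )
        offset += size.numel()
    # views[0] aliases reduce_output[0:6] as a (2, 3) tensor; views[1] aliases [6:10].
    # Writing through either view mutates the fused buffer and vice versa.
    return views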