# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore HSDP parameter group with fused communication."""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Any, List, NamedTuple, Optional

import mindspore as ms
from mindspore import ops
from mindspore.common.api import _no_grad
import mindspore.mint.distributed as dist
from mindspore.ops.function.comm_func import CommHandle

from hyper_parallel.core.fully_shard.utils import DDPMeshInfo, FSDPMeshInfo, HSDPMeshInfo, MixedPrecisionPolicy
from hyper_parallel.platform.mindspore.fully_shard._version_utils import copy_without_bumping_version
from hyper_parallel.platform.mindspore.fully_shard.pack_utils import build_rs_plan, pack_for_reduce_scatter
from hyper_parallel.platform.mindspore.fully_shard.param import MindSporeHSDPParamV2


def _normalize_device(device: Any) -> str:
    """Strip any ":<index>" suffix from a device identifier."""
    if isinstance(device, str):
        return device.split(":", 1)[0]
    return str(device).split(":", 1)[0]


def _shape_numel(shape) -> int:
    """Return the number of elements implied by a shape tuple."""
    return math.prod(int(dim) for dim in shape)
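
# Illustrative values (hypothetical inputs, not exercised by this module):
#   _normalize_device("Ascend:0") -> "Ascend"
#   _shape_numel((2, 3, 4))       -> 24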


def get_all_gather_metadata(hsdp_params):
    """Collect metadata required for fused all-gather."""
    param_input_dtypes = []
    param_input_numels = []
    inp_split_sizes = []
    total_input_numel = 0
    first_dtype = None

    for hsdp_param in hsdp_params:
        inputs = hsdp_param.all_gather_inputs
        if first_dtype is None:
            first_dtype = inputs[0].dtype
        elif first_dtype != inputs[0].dtype:
            raise ValueError("All parameters in the group must have a uniform dtype.")
        param_dtypes = [t.dtype for t in inputs]
        param_numels = [t.numel() for t in inputs]
        param_input_dtypes.append(param_dtypes)
        param_input_numels.append(param_numels)
        inp_split_sizes.extend(param_numels)
        total_input_numel += sum(param_numels)

    return AllGatherMetadata(
        param_input_dtypes,
        param_input_numels,
        first_dtype,
        inp_split_sizes,
        total_input_numel,
    )
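
# A minimal sketch of the resulting layout (hypothetical shard sizes): for two
# parameters whose single all-gather inputs flatten to 4 and 6 elements,
#
#   metadata = get_all_gather_metadata([p0, p1])
#   metadata.inp_split_sizes   == [4, 6]
#   metadata.total_input_numel == 10
#
# so the fused all-gather buffer is laid out as world_size contiguous
# 10-element slices, one per shard rank.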


@dataclass
class AllGatherMetadata:
    """Metadata describing the fused all-gather buffer layout."""

    param_input_dtypes: list[list[Any]]
    param_input_numels: list[list[int]]
    dtype: Any
    inp_split_sizes: list[int]
    total_input_numel: int
    hash_key: int = field(init=False)

    def __post_init__(self):
        self.hash_key = hash(
            (
                tuple(tuple(d) for d in self.param_input_dtypes),
                tuple(tuple(n) for n in self.param_input_numels),
                self.dtype,
                tuple(self.inp_split_sizes),
                self.total_input_numel,
            )
        )


class AllGatherResult(NamedTuple):
    """Result of a fused all-gather operation."""

    all_gather_output: Optional[ms.Tensor]
    metadata: Optional[AllGatherMetadata]
    handle: Optional[CommHandle]


@dataclass
class CommContext:
    """Global communication context for pipelined fused reductions."""

    comm_handle: Optional[CommHandle] = None
    all_reduce_handle: Optional[CommHandle] = None
    # Annotate these so they become per-instance dataclass fields; without
    # annotations they would be class attributes shared across instances.
    pre_param_group: Optional["HSDPParamGroup"] = None
    all_reduce_param_group: Optional["HSDPParamGroup"] = None


comm_ctx = CommContext()


def get_comm_ctx():
    """Return the global communication context singleton."""
    return comm_ctx


@dataclass
class ReplicateBucket:
    """One fused all-reduce bucket sharing the same replicate process group."""

    key: int
    group: Any
    group_size: int
    param_indices: list[int]
    flat_numel: int
    buffer: Optional[ms.Tensor] = None


@dataclass
class PendingBucketAllReduce:
    """One in-flight async all-reduce launched for a replicate bucket."""

    bucket_key: int
    handle: Any


class AllGatherMetadataCache:
    """Cache for all-gather metadata across iterations."""

    _cache: dict[int, AllGatherMetadata] = {}

    @classmethod
    def get_metadata(cls, hsdp_params, fn):
        """Return cached metadata for these params, computing it via fn on a miss."""
        param_key = tuple((id(p), getattr(p, "version", 0)) for p in hsdp_params)
        key = hash(param_key)
        if key in cls._cache:
            return cls._cache[key]
        metadata = fn(hsdp_params)
        cls._cache[key] = metadata
        return metadata
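
# Caching sketch (hypothetical params): repeated calls with the same parameter
# identities and versions hit the shared class-level cache,
#
#   cache = AllGatherMetadataCache()
#   m1 = cache.get_metadata(params, get_all_gather_metadata)
#   m2 = cache.get_metadata(params, get_all_gather_metadata)
#   assert m1 is m2
#
# while bumping a param's `version` attribute yields a new key and forces
# recomputation.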


@_no_grad()
def all_gather_copy_in(all_gather_inputs, all_gather_output, inp_split_sizes, all_gather_input_numel, rank):
    """Copy per-parameter local shards into one fused rank-local all-gather slice."""
    all_gather_input = all_gather_output.narrow(0, all_gather_input_numel * rank, all_gather_input_numel)
    offset = 0
    for src, size in zip(all_gather_inputs, inp_split_sizes):
        src_flat = src.view(-1)
        all_gather_input.narrow(0, offset, size).copy_(src_flat)
        offset += size
    return all_gather_input, all_gather_output
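
# Layout sketch (hypothetical sizes): with inp_split_sizes == [4, 6], rank == 1
# and all_gather_input_numel == 10, the copy fills elements [10, 20) of the
# fused output, i.e. this rank's own slice; the other ranks' slices are filled
# by the collective itself.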


@_no_grad()
def split_with_sizes_copy(all_gather_output, split_sizes, dim, out):
    """Copy split views from a fused all-gather output into pre-allocated outputs."""
    if dim != 1:
        raise NotImplementedError("split_with_sizes_copy currently only supports dim=1")
    offset = 0
    for dst, size in zip(out, split_sizes):
        src = all_gather_output.narrow(dim, offset, size)
        copy_without_bumping_version(dst, src)
        offset += size
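
# Splitting sketch (hypothetical shapes): given a gathered (world_size, 10)
# matrix and split_sizes == [4, 6], the first output receives columns [0, 4)
# and the second columns [4, 10), each later reshaped into its unsharded
# parameter.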


@_no_grad()
def reduce_scatter_copy_in(
    hsdp_params: List[MindSporeHSDPParamV2],
    unsharded_grads: List[ms.Tensor],
    reduce_scatter_input: ms.Tensor,
    world_size: int,
) -> None:
    """Pack all unsharded gradients into one fused reduce-scatter input buffer."""
    if len(hsdp_params) != len(unsharded_grads):
        raise AssertionError(
            "reduce_scatter_copy_in expects one hsdp_param per unsharded_grad, but got "
            f"{len(hsdp_params)} params and {len(unsharded_grads)} grads"
        )
    packed_rows = reduce_scatter_input.view(world_size, -1)
    col_offset = 0
    for hsdp_param, grad in zip(hsdp_params, unsharded_grads):
        grad = grad.contiguous()
        plan = build_rs_plan(hsdp_param, grad, world_size)
        packed_grad = pack_for_reduce_scatter(grad, plan)
        next_col_offset = col_offset + packed_grad.shape[1]
        for row_idx in range(world_size):
            packed_rows[row_idx].narrow(0, col_offset, packed_grad.shape[1]).copy_(
                packed_grad[row_idx].view(-1)
            )
        col_offset = next_col_offset
    if col_offset != packed_rows.shape[1]:
        raise AssertionError(
            "reduce_scatter_copy_in packed an unexpected number of elements: "
            f"{col_offset} != {packed_rows.shape[1]}"
        )
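
# Packing sketch: row r of the (world_size, cols) view holds, for every
# parameter, the gradient slice destined for shard rank r, so the subsequent
# reduce_scatter_tensor leaves each rank holding exactly the reduced gradients
# of its own shards, concatenated in parameter order.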


class HSDPParamGroup:
    """Group HSDP parameters within a module for fused collectives."""

    def __init__(
        self,
        hsdp_params,
        mesh_info: FSDPMeshInfo,
        device: Optional[str] = None,
        mp_policy: Optional[MixedPrecisionPolicy] = None,
        enable_zero_copy_param_buffer: bool = False,
    ):
        self.mesh_info = mesh_info
        self.device = device
        self.hsdp_params = hsdp_params
        self.enable_zero_copy_param_buffer = enable_zero_copy_param_buffer
        if isinstance(self.mesh_info, (FSDPMeshInfo, HSDPMeshInfo)):
            self.shard_rank = self.mesh_info.shard_mesh_rank
            self.shard_world_size = self.mesh_info.shard_mesh_size
        else:
            self.shard_rank = 0
            self.shard_world_size = 1
        self.shard_group = self.mesh_info.shard_process_group
        self.replicate_group = None
        if isinstance(self.mesh_info, (HSDPMeshInfo, DDPMeshInfo)):
            self.replicate_group = self.mesh_info.replicate_process_group
        elif isinstance(self.mesh_info, FSDPMeshInfo):
            self.replicate_group = self._infer_layout_replicate_group()
        self.ag_output: Optional[ms.Tensor] = None
        self.metadata_cache = None
        self.mp_policy = mp_policy
        self._result = None
        self._reduce_output = None
        self._reduce_op = None
        self._needs_avg_div = False
        self._reduce_hsdp_params = None
        self._active_replicate_buckets: dict[int, ReplicateBucket] = {}
        self._active_param_flat_offsets: list[int] = []
        self._pending_all_reduce_handles: list[PendingBucketAllReduce] = []
        self._flat_param_buffer: Optional[ms.Tensor] = None
        self._flat_cast_buffer: Optional[ms.Tensor] = None
        self._init_mp_dtypes()
        if self.enable_zero_copy_param_buffer:
            self._init_flat_param_buffer()

    def _infer_layout_replicate_group(self):
        """Infer a replicate process group from per-parameter layout info, if any."""
        replicate_groups = []
        for hsdp_param in self.hsdp_params:
            group_info = getattr(hsdp_param, "unsharded_group_info", None)
            group = getattr(group_info, "group", None)
            if group is None or getattr(hsdp_param, "replicate_world_size", 1) <= 1:
                continue
            replicate_groups.append(group)
        if not replicate_groups:
            return None
        return replicate_groups[0]

    def _build_active_replicate_buckets(self, hsdp_params):
        """Bucket params by replicate process group for fused all-reduce."""
        buckets: dict[int, ReplicateBucket] = {}
        for idx, hsdp_param in enumerate(hsdp_params):
            group_info = getattr(hsdp_param, "unsharded_group_info", None)
            group = getattr(group_info, "group", None)
            group_size = getattr(group_info, "rank_size", getattr(hsdp_param, "replicate_world_size", 1))
            if group is None or group_size <= 1:
                continue
            key = id(group)
            if key not in buckets:
                buckets[key] = ReplicateBucket(
                    key=key,
                    group=group,
                    group_size=group_size,
                    param_indices=[],
                    flat_numel=0,
                )
            buckets[key].param_indices.append(idx)
            buckets[key].flat_numel += _shape_numel(hsdp_param.sharded_size)
        return buckets
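
    # Bucketing sketch (hypothetical layout): if params 0 and 2 replicate over
    # group A and param 1 over group B, the result is two buckets,
    #   {id(A): ReplicateBucket(param_indices=[0, 2], ...),
    #    id(B): ReplicateBucket(param_indices=[1], ...)}
    # each later all-reduced with a single fused call per group.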

    def _init_flat_param_buffer(self):
        """Rebase local shards into one flat buffer when storage semantics allow it."""
        if not self.enable_zero_copy_param_buffer:
            return
        if self.shard_world_size <= 1 or len(self.hsdp_params) == 0:
            return
        if any(p.offload_to_cpu or str(p.sharded_param.device) == "meta" for p in self.hsdp_params):
            return

        total_numel = sum(hsdp_param._sharded_param_data.numel() for hsdp_param in self.hsdp_params)
        orig_dtype = self.hsdp_params[0]._sharded_param_data.dtype
        flat_buffer = ms.mint.empty((total_numel,), dtype=orig_dtype, device=_normalize_device(self.device))

        offset = 0
        original_locals = []
        try:
            for hsdp_param in self.hsdp_params:
                original_locals.append((hsdp_param, hsdp_param._sharded_param_data, hsdp_param._sharded_local_tensor))
                numel = hsdp_param._sharded_param_data.numel()
                flat_slice = flat_buffer.narrow(0, offset, numel)
                flat_slice.copy_(hsdp_param._sharded_param_data)
                hsdp_param._sharded_param_data = flat_slice
                new_local = flat_slice.view(hsdp_param.sharded_size)
                req_grad = hsdp_param.sharded_param.requires_grad
                hsdp_param.sharded_param.set_data(new_local)
                hsdp_param.sharded_param._local_tensor = new_local
                if req_grad:
                    new_local.requires_grad_(True)
                    hsdp_param.sharded_param.requires_grad_(True)
                offset += numel
        except Exception:  # pylint: disable=W0718
            # Roll every param back to its original per-parameter storage on failure.
            for hsdp_param, orig_flat, orig_local in original_locals:
                hsdp_param._sharded_param_data = orig_flat
                hsdp_param.sharded_param.set_data(orig_local)
                hsdp_param.sharded_param._local_tensor = orig_local
            self._flat_param_buffer = None
            self._flat_cast_buffer = None
            return

        self._flat_param_buffer = flat_buffer
        has_param_dtype = any(p.param_dtype is not None for p in self.hsdp_params)
        if has_param_dtype:
            cast_dtype = next(p.param_dtype for p in self.hsdp_params if p.param_dtype is not None)
            self._flat_cast_buffer = ms.mint.empty(
                (total_numel,), dtype=cast_dtype, device=_normalize_device(self.device)
            )
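
    # Zero-copy sketch (assumed storage semantics): after rebasing, each
    # param's _sharded_param_data is a narrow view into _flat_param_buffer, so
    # foreach_all_gather can hand the whole flat buffer (or its mixed-precision
    # cast copy) to the collective without a per-parameter copy-in pass.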

    def _is_flat_buffer_valid(self):
        """Check if flat buffer still backs the params' sharded data."""
        if self._flat_param_buffer is None or len(self.hsdp_params) == 0:
            return False
        first_param = self.hsdp_params[0]
        return (
            first_param._sharded_param_data.untyped_storage().data_ptr()
            == self._flat_param_buffer.untyped_storage().data_ptr()
        )

    def _allocate_bucket_buffers_if_needed(self, device, dtype):
        """(Re)allocate each bucket's fused buffer if its size or dtype changed."""
        normalized_device = _normalize_device(device)
        for bucket in self._active_replicate_buckets.values():
            if bucket.flat_numel == 0:
                continue
            needs_new_buffer = (
                bucket.buffer is None
                or bucket.buffer.numel() != bucket.flat_numel
                or bucket.buffer.dtype != dtype
            )
            if needs_new_buffer:
                bucket.buffer = ms.mint.empty((bucket.flat_numel,), dtype=dtype, device=normalized_device)

    def _pack_bucket_from_reduce_output(self, bucket: ReplicateBucket) -> ms.Tensor:
        """Gather this bucket's shard gradients out of the fused reduce output."""
        if bucket.buffer is None:
            raise AssertionError("Bucket buffer must be allocated before packing from reduce output")
        if self._reduce_output is None or self._reduce_hsdp_params is None:
            raise AssertionError("Bucket packing requires an active fused reduce output")
        dst_offset = 0
        for idx in bucket.param_indices:
            hsdp_param = self._reduce_hsdp_params[idx]
            src_offset = self._active_param_flat_offsets[idx]
            numel = _shape_numel(hsdp_param.sharded_size)
            bucket.buffer.narrow(0, dst_offset, numel).copy_(
                self._reduce_output.narrow(0, src_offset, numel)
            )
            dst_offset += numel
        return bucket.buffer

    def _unpack_bucket_to_reduce_output(self, bucket: ReplicateBucket) -> None:
        """Scatter this bucket's all-reduced gradients back into the fused output."""
        if bucket.buffer is None:
            raise AssertionError("Bucket buffer must exist before unpacking to reduce output")
        if self._reduce_output is None or self._reduce_hsdp_params is None:
            raise AssertionError("Bucket unpack requires an active fused reduce output")
        src_offset = 0
        for idx in bucket.param_indices:
            hsdp_param = self._reduce_hsdp_params[idx]
            dst_offset = self._active_param_flat_offsets[idx]
            numel = _shape_numel(hsdp_param.sharded_size)
            self._reduce_output.narrow(0, dst_offset, numel).copy_(
                bucket.buffer.narrow(0, src_offset, numel)
            )
            src_offset += numel

    def unshard(self, async_op: bool = False):
        """Trigger fused all-gather for all parameters in this group."""
        if self._result is not None:
            return
        if self.shard_world_size == 1:
            self._result = AllGatherResult(None, None, None)
            return
        self.foreach_all_gather(async_op=async_op)

    def _init_mp_dtypes(self):
        """Validate and record uniform original/reduce dtypes for trainable params."""
        for hsdp_param in self.hsdp_params:
            hsdp_param.init_dtype_attrs(self.mp_policy)
        trainable_params: list[MindSporeHSDPParamV2] = [
            p for p in self.hsdp_params if p.sharded_param.requires_grad
        ]
        orig_dtypes = {p.orig_dtype for p in trainable_params}
        reduce_dtypes = {p.reduce_dtype for p in trainable_params}
        if len(trainable_params) > 0 and len(orig_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform original parameter dtype but got {orig_dtypes}"
            )
        self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
        if len(trainable_params) > 0 and len(reduce_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform reduce dtype but got {reduce_dtypes}"
            )
        self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None

    def wait_for_unshard(self):
        """Wait for fused all-gather and materialize per-parameter unsharded views."""
        if self._result is None:
            return
        if self.shard_world_size == 1:
            for hsdp_param in self.hsdp_params:
                all_gather_input = hsdp_param.all_gather_inputs[0]
                hsdp_param.init_all_gather_outputs(
                    [all_gather_input.numel()],
                    [all_gather_input.dtype],
                    self.shard_world_size,
                    _normalize_device(self.device),
                )
                hsdp_param.alloc_all_gather_outputs()
                copy_without_bumping_version(hsdp_param.all_gather_outputs[0], all_gather_input)
            self._result = None
        else:
            self.foreach_all_gather_copy_out()
        for hsdp_param in self.hsdp_params:
            hsdp_param.init_unsharded_param()
            hsdp_param.to_unsharded()
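
    # Typical call pattern (a sketch of how a pre-forward hook might drive this
    # group; the hook itself lives outside this module):
    #   group.unshard(async_op=True)   # launch the fused all-gather
    #   ...                            # overlap with other work
    #   group.wait_for_unshard()       # block, then expose unsharded params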

    def alloc_all_gather_output(self, total_output_numel, dtype):
        """Allocate or resize the fused all-gather output buffer."""
        normalized_device = _normalize_device(self.device)
        if self.ag_output is None or self.ag_output.dtype != dtype:
            self.ag_output = ms.mint.empty((total_output_numel,), dtype=dtype, device=normalized_device)
            return
        storage = self.ag_output.untyped_storage()
        expected_size = total_output_numel * self.ag_output.itemsize
        if storage.size() != expected_size:
            storage.resize_(expected_size)

    def free_all_gather_output(self):
        """Release the fused all-gather buffer's storage without dropping the tensor."""
        if self.ag_output is None:
            return
        storage = self.ag_output.untyped_storage()
        if storage.size() != 0:
            storage.resize_(0)

    @_no_grad()
    def foreach_all_gather(self, async_op=False):
        """Perform one fused all-gather across all parameters in the group."""
        if self.metadata_cache is None:
            self.metadata_cache = AllGatherMetadataCache()
        metadata = self.metadata_cache.get_metadata(self.hsdp_params, get_all_gather_metadata)
        if metadata.total_input_numel == 0:
            return
        world_size = self.shard_world_size
        rank = self.shard_rank
        total_output_numel = metadata.total_input_numel * world_size
        self.alloc_all_gather_output(total_output_numel, metadata.dtype)
        for hsdp_param in self.hsdp_params:
            hsdp_param.reset_sharded_param()
        if self.enable_zero_copy_param_buffer and not self._is_flat_buffer_valid():
            self._init_flat_param_buffer()

        use_flat_buffer = (
            self.enable_zero_copy_param_buffer
            and self._flat_param_buffer is not None
            and self._is_flat_buffer_valid()
        )
        if use_flat_buffer:
            # Zero-copy path: feed the flat (or mixed-precision cast) buffer
            # directly to the collective.
            if self._flat_cast_buffer is not None:
                self._flat_cast_buffer.copy_(self._flat_param_buffer)
                all_gather_input = self._flat_cast_buffer
            else:
                all_gather_input = self._flat_param_buffer
        else:
            # Fallback path: pack each parameter's local shard into this rank's
            # slice of the fused output buffer.
            all_gather_inputs = []
            for hsdp_param in self.hsdp_params:
                all_gather_inputs.extend(hsdp_param.all_gather_inputs)
            if len(all_gather_inputs) == 0:
                return
            all_gather_input, _ = all_gather_copy_in(
                all_gather_inputs,
                self.ag_output,
                metadata.inp_split_sizes,
                metadata.total_input_numel,
                rank,
            )
        handle = dist.all_gather_into_tensor(self.ag_output, all_gather_input, self.shard_group, async_op)
        self._result = AllGatherResult(self.ag_output, metadata, handle)

    @_no_grad()
    def foreach_all_gather_copy_out(self):
        """Scatter one fused all-gather result back into per-parameter buffers."""
        ag_output, metadata, handle = self._result
        if handle is not None:
            handle.wait()
        world_size = self.shard_world_size
        split_with_sizes_out = []
        for input_numels, input_dtypes, hsdp_param in zip(
            metadata.param_input_numels, metadata.param_input_dtypes, self.hsdp_params
        ):
            hsdp_param.init_all_gather_outputs(
                input_numels,
                input_dtypes,
                world_size,
                _normalize_device(ag_output.device),
            )
            hsdp_param.alloc_all_gather_outputs()
            split_with_sizes_out.extend(hsdp_param.all_gather_outputs)
        ag_output = ag_output.view(world_size, -1)
        out = [t.view(world_size, -1) for t in split_with_sizes_out]
        split_with_sizes_copy(ag_output, metadata.inp_split_sizes, dim=1, out=out)
        self._result = None
        self.free_all_gather_output()

    @_no_grad()
    def foreach_reduce(
        self,
        reduce_scatter_reduce_op: Optional[ops.ReduceOp] = ops.ReduceOp.SUM,
        async_op: bool = True,
        needs_avg_div: bool = False,
    ) -> Optional[ms.Tensor]:
        """Perform fused reduce-scatter and optional bucketed all-reduce."""
        hsdp_params: List[MindSporeHSDPParamV2] = []
        unsharded_grads: List[ms.Tensor] = []
        for hsdp_param in self.hsdp_params:
            if not hasattr(hsdp_param, "_unsharded_param"):
                continue
            if hsdp_param.unsharded_accumulated_grad is not None:
                hsdp_params.append(hsdp_param)
                unsharded_grads.append(hsdp_param.unsharded_accumulated_grad_data)
            elif hsdp_param._unsharded_param.grad is not None:
                hsdp_params.append(hsdp_param)
                unsharded_grads.append(hsdp_param.unsharded_grad_data)
        if not hsdp_params:
            return None
        grad_dtypes = {g.dtype for g in unsharded_grads}
        if len(grad_dtypes) != 1:
            raise ValueError(
                f"FSDP reduce-scatter expects uniform grad dtype but got {grad_dtypes}"
            )
        grad_dtype = unsharded_grads[0].dtype
        reduce_dtype = self._reduce_dtype or grad_dtype
        world_size = self.shard_world_size
        reduce_scatter_input_numel = sum(s.numel() for s in unsharded_grads)
        reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
        device = _normalize_device(unsharded_grads[0].device)
        reduce_scatter_input = ms.mint.empty((reduce_scatter_input_numel,), dtype=reduce_dtype, device=device)
        reduce_scatter_copy_in(hsdp_params, unsharded_grads, reduce_scatter_input, world_size)
        reduce_output = ms.mint.empty((reduce_scatter_output_numel,), dtype=reduce_dtype, device=device)
        self._needs_avg_div = needs_avg_div
        self._reduce_op = reduce_scatter_reduce_op
        self._reduce_hsdp_params = hsdp_params
        self._active_param_flat_offsets = []
        flat_offset = 0
        for hsdp_param in hsdp_params:
            self._active_param_flat_offsets.append(flat_offset)
            flat_offset += _shape_numel(hsdp_param.sharded_size)
        self._active_replicate_buckets = self._build_active_replicate_buckets(hsdp_params)
        self._allocate_bucket_buffers_if_needed(reduce_output.device, reduce_output.dtype)
        self._pending_all_reduce_handles = []
        if self.shard_group is None or world_size <= 1:
            comm_ctx.comm_handle = None
            self._reduce_output = reduce_scatter_input
            if async_op:
                comm_ctx.pre_param_group = self
            else:
                self.apply_fusion_reduced_grad()
            return self._reduce_output
        rs_handle = dist.reduce_scatter_tensor(
            output=reduce_output,
            input=reduce_scatter_input,
            group=self.shard_group,
            op=reduce_scatter_reduce_op,
            async_op=async_op,
        )
        comm_ctx.comm_handle = rs_handle
        self._reduce_output = reduce_output
        if async_op:
            comm_ctx.pre_param_group = self
        else:
            self.apply_fusion_reduced_grad()
        return reduce_output
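
    # Async pipeline sketch (the surrounding backward hooks are assumed, not
    # defined in this module):
    #   group.foreach_reduce(async_op=True)               # launch reduce-scatter
    #   ...                                               # next layer's backward overlaps
    #   group.wait_reduce_scatter_and_issue_all_reduce()  # then bucketed all-reduce
    #   group.wait_all_reduce_and_apply_grad()            # finally write sharded grads
    # The synchronous path (async_op=False) collapses all three stages into
    # apply_fusion_reduced_grad().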

    def wait_reduce_scatter_and_issue_all_reduce(self):
        """Wait for reduce-scatter and issue async all-reduces for active buckets."""
        if comm_ctx.comm_handle is not None:
            comm_ctx.comm_handle.wait()
            comm_ctx.comm_handle = None
        if self._needs_avg_div and self._reduce_output is not None and self.shard_world_size > 1:
            self._reduce_output.div_(self.shard_world_size)
        if not self._active_replicate_buckets:
            self._apply_reduced_grad()
            return
        self._pending_all_reduce_handles = []
        for bucket in self._active_replicate_buckets.values():
            packed = self._pack_bucket_from_reduce_output(bucket)
            ar_handle = dist.all_reduce(
                packed,
                group=bucket.group,
                op=self._reduce_op,
                async_op=True,
            )
            self._pending_all_reduce_handles.append(
                PendingBucketAllReduce(bucket_key=bucket.key, handle=ar_handle)
            )
        comm_ctx.all_reduce_param_group = self

    def wait_all_reduce_and_apply_grad(self):
        """Wait for pending bucket all-reduces and apply reduced grads."""
        for pending in self._pending_all_reduce_handles:
            bucket = self._active_replicate_buckets[pending.bucket_key]
            pending.handle.wait()
            if self._needs_avg_div and bucket.group_size > 1:
                bucket.buffer.div_(bucket.group_size)
            self._unpack_bucket_to_reduce_output(bucket)
        self._pending_all_reduce_handles = []
        comm_ctx.all_reduce_handle = None
        self._apply_reduced_grad()

    def apply_fusion_reduced_grad(self):
        """Synchronous fallback: wait, all-reduce buckets, then apply grads."""
        if comm_ctx.comm_handle is not None:
            comm_ctx.comm_handle.wait()
            comm_ctx.comm_handle = None
        if self._needs_avg_div and self._reduce_output is not None and self.shard_world_size > 1:
            self._reduce_output.div_(self.shard_world_size)
        for bucket in self._active_replicate_buckets.values():
            packed = self._pack_bucket_from_reduce_output(bucket)
            dist.all_reduce(
                packed,
                group=bucket.group,
                op=self._reduce_op,
            )
            if self._needs_avg_div and bucket.group_size > 1:
                packed.div_(bucket.group_size)
            self._unpack_bucket_to_reduce_output(bucket)
        self._apply_reduced_grad()

    def _apply_reduced_grad(self):
        """Write reduced gradients from the fused output buffer back to params."""
        if self._reduce_hsdp_params is None or self._reduce_output is None:
            return
        flat_grad_offset = 0
        for hsdp_param in self._reduce_hsdp_params:
            shard_numel = _shape_numel(hsdp_param.sharded_size)
            new_sharded_grad = self._reduce_output.narrow(0, flat_grad_offset, shard_numel)
            hsdp_param.apply_reduced_grad(new_sharded_grad, self._orig_dtype)
            flat_grad_offset += shard_numel
        self._reduce_output = None
        self._reduce_hsdp_params = None
        self._active_param_flat_offsets = []
        self._active_replicate_buckets = {}
        self._pending_all_reduce_handles = []