Coverage for hyper_parallel / platform / torch / fully_shard / param.py: 65%
340 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/fsdp/_fully_shard/_fsdp_param.py
16# enhanced with fully_shard parameter management
17# ============================================================================
18"""HSDP parameter"""
19from typing import List, Callable, Optional, cast, Sequence, Tuple, Any
20from dataclasses import dataclass, field
21import itertools
22import torch
23import torch.nn as nn
24import torch.distributed as dist
25from torch._prims_common import make_contiguous_strides_for
26from hyper_parallel.platform.torch.fully_shard.utils import (
27 MixedPrecisionPolicy,
28 CPUOffloadPolicy,
29 OffloadPolicy,
30 FSDPMeshInfo,
31 DDPMeshInfo,
32 HSDPMeshInfo,
33)
34from hyper_parallel.core.dtensor import DTensor
35from hyper_parallel.core.layout import Layout
36from hyper_parallel.core.fully_shard.hsdp_param import HSDPParamV2
37from hyper_parallel.core.fully_shard.hsdp_utils import ShardedState
38from hyper_parallel.core.placement_types import Shard, Replicate
39from hyper_parallel.core.fully_shard.hsdp_utils import ParamModuleInfo, ExtensionsData
class TorchHSDPParamV2(HSDPParamV2):
    """
    Torch HSDP parameter.

    Manages the sharded/unsharded lifecycle of a single ``nn.Parameter``
    under FSDP/DDP/HSDP: sharding it at init time, all-gathering it into an
    unsharded view for compute, swapping the live parameter object on the
    owning module(s), and reduce-scattering / all-reducing its gradient.
    """
    def __init__(
        self,
        param: nn.Parameter,
        module_info: ParamModuleInfo,
        mesh_info: FSDPMeshInfo,
        post_forward_mesh_info: Optional[FSDPMeshInfo] = None,
        shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
        mp_policy: Optional[MixedPrecisionPolicy] = None,
        offload_policy: Optional[OffloadPolicy] = None,
        threshold: int = 0,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            param: The original parameter to be managed (and replaced) by
                this object.
            module_info: Owning module plus the attribute name(s) under which
                the parameter is registered (including shared modules).
            mesh_info: Mesh describing the shard/replicate process groups;
                its concrete type (FSDP/DDP/HSDP) selects the placements.
            post_forward_mesh_info: Optional mesh used for resharding to a
                smaller world size after forward.
            shard_placement_fn: Optional callback returning the ``Shard``
                placement for this parameter; defaults to ``Shard(0)``.
            mp_policy: Mixed-precision policy (param/reduce dtypes).
            offload_policy: If a ``CPUOffloadPolicy``, the sharded data is
                kept on CPU (optionally pinned).
            threshold: Size threshold in bytes; parameters smaller than this
                are replicated instead of sharded.
            device: Device the parameter is expected to already live on.
        """
        self._module_info: ParamModuleInfo = module_info
        self.mesh_info = mesh_info
        self.post_forward_mesh_info = post_forward_mesh_info
        self.mp_policy = mp_policy
        self.threshold = threshold
        self.device = device
        self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
        self.pin_memory = (
            self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
        )
        self.grad_offload_event: Optional[torch.Event] = None
        self._init_sharded_param(param, shard_placement_fn)
        if self.post_forward_mesh_info:
            self._init_sharded_post_forward_param_metadata(param)
        self._init_extensions()
        self.all_gather_outputs: List[torch.Tensor] = []
        self.unsharded_accumulated_grad = None
        self._param_fqn: Optional[str] = None
        # Communication attributes for prefetch pattern
        self.prefetch_handle: Optional[dist.Work] = None
        # Re-sync our sharded view after load_state_dict swaps the module's
        # parameter object underneath us.
        self._post_load_hook_handle = (
            module_info.module.register_load_state_dict_post_hook(
                lambda *args, **kwargs: self.reset_sharded_param()
            )
        )

    @torch.no_grad()
    def _init_sharded_param(
        self,
        param: nn.Parameter,
        shard_placement_fn: Optional[Callable],
    ) -> None:
        """
        Shard ``param`` (or replicate it if below ``self.threshold``) and
        install the resulting DTensor-backed ``nn.Parameter`` on the owning
        module(s).

        Raises:
            AssertionError: If the parameter is not on the expected device or
                the placement is not a ``Shard``.
            NotImplementedError: If the shard dimension is not evenly
                divisible by the shard world size.
        """
        if param.device != self.device and param.device.type != "meta":
            raise AssertionError(
                f"Expects the parameter to already be moved to device {self.device} but got {param.device}"
            )
        hsdp_placement = shard_placement_fn(param) if shard_placement_fn else None
        if hsdp_placement is None:
            hsdp_placement = Shard(0)
        elif hsdp_placement.dim < 0:
            # if dim is negative, add the number of dimensions of the parameter
            hsdp_placement = Shard(hsdp_placement.dim + param.ndim)
        if not isinstance(hsdp_placement, Shard):
            raise AssertionError(
                f"Expected Shard, got {type(hsdp_placement)}: {hsdp_placement}"
            )
        self.hsdp_placement = hsdp_placement
        shard_dim = hsdp_placement.dim
        # Non-DTensor parameters have no pre-defined SPMD semantics.
        # FSDP/DDP solely determines the mesh and placements.
        self._spmd_mesh = self.mesh_info.mesh
        if isinstance(self.mesh_info, HSDPMeshInfo):  # HSDP
            self._spmd_placements = (Replicate(), hsdp_placement)
        elif isinstance(self.mesh_info, FSDPMeshInfo):  # FSDP
            self._spmd_placements = (hsdp_placement,)
        elif isinstance(self.mesh_info, DDPMeshInfo):  # DDP
            self._spmd_placements = (Replicate(),)
        param_data = param
        shard_dim = hsdp_placement.dim
        self._orig_size = param_data.size()
        self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
        if isinstance(self.mesh_info, FSDPMeshInfo):  # FSDP or HSDP
            shard_rank = self.mesh_info.shard_mesh_rank
            shard_world_size = self.mesh_info.shard_mesh_size
        else:  # DDP
            shard_rank = 0
            shard_world_size = 1
        # Check if parameter size is below threshold, if so skip sharding
        param_size = param_data.numel() * param_data.element_size()
        if self.threshold > 0 and param_size < self.threshold:
            # Parameter too small, do not shard
            self.is_sharded = False
            self.sharded_size = param_data.size()
            self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size)
            self._sharded_param_data = param_data.view(-1)
            self._sharding_spec = Layout.from_device_mesh(self._spmd_mesh)
            # For unsharded params, use Replicate placement
            if isinstance(self.mesh_info, HSDPMeshInfo):
                self._spmd_placements = (Replicate(), Replicate())
            else:
                self._spmd_placements = (Replicate(),)
            self._sharding_spec.set_placements(self._spmd_placements)
            self._sharding_spec.placement_to_tensor_map(param.ndim)
            self.sharded_param = nn.Parameter(DTensor.from_local(param_data, self._spmd_mesh, self._spmd_placements))
            self.sharded_param.requires_grad_(param.requires_grad)
            self._setattr_on_modules(self.sharded_param)
            self.sharded_state = ShardedState.SHARDED
            return
        self.is_sharded = True
        if param_data.size(shard_dim) % shard_world_size != 0:
            raise NotImplementedError(
                f"Uneven sharding on dim {shard_dim} not supported: "
                f"shape={param_data.shape}, world_size={shard_world_size}"
            )
        chunks = torch.chunk(param_data, shard_world_size, dim=shard_dim)
        sharded_param = chunks[shard_rank].clone().contiguous()
        self.sharded_size = sharded_param.size()
        self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size)
        if self.offload_to_cpu and not sharded_param.is_meta:
            sharded_param = sharded_param.cpu()
            if self.pin_memory:
                sharded_param = sharded_param.pin_memory()
        self._sharded_param_data = sharded_param.view(-1)
        self._sharding_spec = Layout.from_device_mesh(self._spmd_mesh)
        self._sharding_spec.set_placements(self._spmd_placements)
        self._sharding_spec.placement_to_tensor_map(param.ndim)
        self.sharded_param = nn.Parameter(DTensor.from_local(sharded_param, self._spmd_mesh, self._spmd_placements))
        self.sharded_param.requires_grad_(param.requires_grad)
        self._setattr_on_modules(self.sharded_param)
        # After initialization, self.sharded_param replaces the original param;
        # subsequent gradients must be accumulated onto this Parameter's grad.
        self.sharded_param._hsdp_param_initialized = True
        self.sharded_state = ShardedState.SHARDED
        self.param_dtype = None

    def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None:
        """Record the size/stride of this rank's post-forward shard of ``param``."""
        mesh_info = self.post_forward_mesh_info
        param_data = param._local_tensor if isinstance(param, DTensor) else param
        if isinstance(mesh_info, FSDPMeshInfo):
            chunks = torch.chunk(param_data, mesh_info.shard_mesh_size, dim=0)
            self.sharded_post_forward_size = chunks[mesh_info.shard_mesh_rank].size()
        else:  # DDP
            chunks = torch.chunk(param_data, 1, dim=0)
            self.sharded_post_forward_size = chunks[0].size()
        self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for(
            self.sharded_post_forward_size
        )

    def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
        """
        Derive ``param_dtype`` / ``reduce_dtype`` from the mixed-precision
        policy, normalizing a dtype equal to the original (or to each other)
        to ``None`` so later code can skip redundant casts.
        """
        param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype)
        self.orig_dtype = self.sharded_param.dtype
        if reduce_dtype == param_dtype:
            reduce_dtype = None
        if param_dtype == self.orig_dtype:
            param_dtype = None
        self.param_dtype = param_dtype
        self.reduce_dtype = reduce_dtype

    def _init_extensions(self) -> None:
        """
        Validate and record all-gather extension hooks (``fsdp_pre_all_gather``
        / ``fsdp_post_all_gather``) defined on the sharded local tensor, if any.
        """
        inner_tensor = self._sharded_local_tensor
        has_fsdp_pre_all_gather = hasattr(inner_tensor, "fsdp_pre_all_gather")
        has_fsdp_post_all_gather = hasattr(inner_tensor, "fsdp_post_all_gather")
        if has_fsdp_pre_all_gather != has_fsdp_post_all_gather:
            raise AssertionError(
                "Both fsdp_pre_all_gather and fsdp_post_all_gather should be defined "
                f"if using all-gather extensions: {inner_tensor}"
            )
        if has_fsdp_pre_all_gather:
            self._extensions_data = ExtensionsData()
        self._unsharded_inner_tensors: list[torch.Tensor] = []

    def init_all_gather_outputs(
        self,
        all_gather_input_numels: list[int],
        all_gather_input_dtypes: list[torch.dtype],
        world_size: int,
        device: torch.device,
        force_recreate: bool = False,
    ):
        """
        Allocate the flat all-gather output buffers (one per input), each
        sized ``numel * world_size``. No-op if already initialized unless
        ``force_recreate`` is set.
        """
        if not force_recreate and len(self.all_gather_outputs) > 0:
            return  # already initialized
        self.all_gather_outputs = [
            torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device)
            for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes)
        ]

    def init_unsharded_param(self):
        """
        Initialize unsharded parameter from all-gather outputs.

        This reconstructs the full parameter after all-gather by using
        the gathered data and reshaping it to the original size.
        """
        if hasattr(self, "_unsharded_param"):
            return
        # Get unsharded data from all-gather outputs
        if len(self.all_gather_outputs) != 1:
            raise AssertionError(
                f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}"
            )
        unsharded_tensor = self.all_gather_outputs[0]
        # Reinterpret the flat all-gather buffer as the original-sized tensor
        # with contiguous strides, without copying. This keeps the parameter
        # aliased to the buffer so free_unsharded_param() can release it.
        unsharded_param = torch.as_strided(
            unsharded_tensor,
            self._orig_size,
            self._contiguous_orig_stride,
            storage_offset=0,
        )
        self._unsharded_param = nn.Parameter(
            unsharded_param, requires_grad=self.sharded_param.requires_grad
        )

    def to_sharded(self) -> None:
        """Transition to SHARDED: reinstall the sharded param and free buffers."""
        self._setattr_on_modules(self.sharded_param)
        self.free_unsharded_param()
        self.sharded_state = ShardedState.SHARDED

    def to_sharded_post_forward(self) -> None:
        """
        Transition UNSHARDED -> SHARDED_POST_FORWARD by keeping only this
        rank's slice of the all-gather output (on the post-forward mesh) and
        freeing the full buffer.
        """
        if self.sharded_state != ShardedState.UNSHARDED:
            raise AssertionError(f"Expected sharded_state to be UNSHARDED, got {self.sharded_state}")
        shard_world_size = self.post_forward_mesh_info.shard_mesh_size
        numel = self.all_gather_outputs[0].numel()
        if numel % shard_world_size != 0:
            raise AssertionError(
                f"All-gather output size ({numel}) must be divisible by the shard "
                f"world size ({shard_world_size}). Check padding/mesh alignment."
            )
        shard_rank = self.post_forward_mesh_info.shard_mesh_rank
        sharded_numel = numel // shard_world_size
        # clone to be able to free all-gather output
        self._sharded_post_forward_param_data = (
            self.all_gather_outputs[0].narrow(
                0, sharded_numel * shard_rank, sharded_numel
            )
        ).clone()
        sharded_post_forward_tensor = torch.as_strided(
            self._sharded_post_forward_param_data,
            size=self.sharded_post_forward_size,
            stride=self.contiguous_sharded_post_forward_stride,
            storage_offset=0,
        )
        self._sharded_post_forward_param = nn.Parameter(
            self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor)
        )
        self._setattr_on_modules(self._sharded_post_forward_param)
        self.free_unsharded_param()
        self.sharded_state = ShardedState.SHARDED_POST_FORWARD

    def to_unsharded(self) -> None:
        """Transition to UNSHARDED: install the unsharded param on the module(s)."""
        set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)
        self._setattr_on_modules(self._unsharded_param)
        if self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
            self._sharded_post_forward_param = None
            self._sharded_post_forward_param_data = None
        self.sharded_state = ShardedState.UNSHARDED

    def _setattr_on_modules(self, param: nn.Parameter) -> None:
        """
        Install ``param`` on the owning module and on every module sharing it,
        bypassing ``nn.Module.__setattr__`` when no subclass overrides it.
        """
        if getattr(self._module_info.module.__setattr__, "__func__", None) is nn.Module.__setattr__:
            # fast path
            self._module_info.module._parameters[self._module_info.param_name] = param
        else:
            # slow path
            setattr(self._module_info.module, self._module_info.param_name, param)
        # Iterate through all modules that share this parameter to prevent pointer desync.
        for shared_module, shared_param_name in zip(
            self._module_info.shared_modules, self._module_info.shared_param_names
        ):
            if getattr(shared_module.__setattr__, "__func__", None) is nn.Module.__setattr__:
                shared_module._parameters[shared_param_name] = param
            else:
                setattr(shared_module, shared_param_name, param)

    def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor:
        """
        Converts a local tensor representing either the sharded parameter or
        sharded gradient to DTensor.
        """
        return DTensor.from_local(
            tensor,
            self._sharding_spec.mesh,
            self._sharding_spec.placements
        )

    def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor:
        """
        Converts a local tensor to DTensor with post-forward sharding layout.
        """
        post_forward_layout = Layout.from_device_mesh(self.post_forward_mesh_info.mesh)
        post_forward_layout.set_placements((Replicate(), Shard(0)))
        post_forward_layout.placement_to_tensor_map(tensor.ndim)
        return DTensor.from_local(tensor, post_forward_layout.mesh, post_forward_layout.placements)

    def to_accumulated_grad_if_needed(self) -> None:
        """
        Move ``_unsharded_param.grad`` into ``unsharded_accumulated_grad``
        (cast to ``reduce_dtype``) when the dtypes differ, so the gradient
        survives resharding.
        """
        if (
            self._unsharded_param.grad is not None
            and self.reduce_dtype is not None
            and self._unsharded_param.grad.dtype != self.reduce_dtype
        ):
            # need to handle the gradient even after the parameter is resharded
            unsharded_grad = self._unsharded_param.grad
            self._unsharded_param.grad = None
            self.unsharded_accumulated_grad = unsharded_grad.to(self.reduce_dtype)

    def accumulate_unsharded_grad_if_needed(self) -> None:
        """Fold ``unsharded_param.grad`` into the running accumulated gradient."""
        if (
            self.unsharded_accumulated_grad is not None
            and self.unsharded_param.grad is not None
        ):
            # need to handle the gradient
            self.unsharded_accumulated_grad += self.unsharded_param.grad
            self.unsharded_param.grad = None

    def alloc_all_gather_outputs(self) -> None:
        """Re-grow the (possibly freed) storages backing the all-gather outputs."""
        for tensor in self.all_gather_outputs:
            expected_size = tensor.numel() * tensor.itemsize
            storage = tensor.untyped_storage()
            if storage.size() != expected_size:
                storage.resize_(expected_size)

    def free_unsharded_param(self) -> None:
        """
        Release the memory of the all-gather outputs and any unsharded inner
        tensors by resizing their storages to zero (tensor objects stay alive).
        """
        for tensor in itertools.chain(
            self.all_gather_outputs, self._unsharded_inner_tensors
        ):
            storage = tensor.untyped_storage()
            if storage.size() != 0:
                storage.resize_(0)

    @property
    def all_gather_inputs(self) -> list[torch.Tensor]:
        """
        The flat local tensor(s) to feed into all-gather for the current
        sharded state, cast to ``param_dtype`` (and moved off CPU) if needed.
        """
        self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
        if self.sharded_state == ShardedState.SHARDED:
            sharded_param_data = self._sharded_param_data
            if self.offload_to_cpu:
                sharded_param_data = sharded_param_data.to(
                    self.device, non_blocking=True
                )
            if self.param_dtype is not None and self.param_dtype != sharded_param_data.dtype:
                return [sharded_param_data.to(self.param_dtype)]
            else:
                return [sharded_param_data]
        elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
            if self.param_dtype is not None and self.param_dtype != self._sharded_post_forward_param_data.dtype:
                return [self._sharded_post_forward_param_data.to(self.param_dtype)]
            else:
                return [self._sharded_post_forward_param_data]
        # Unreachable given the state assertion above; kept as a safe fallback.
        return [torch.empty(0)]

    @property
    def unsharded_param(self) -> nn.Parameter:  # ND
        # Full (all-gathered) parameter; only valid after init_unsharded_param().
        return self._unsharded_param

    @property
    def unsharded_grad_data(self) -> torch.Tensor:
        """
        Get the unsharded gradient data as a local tensor.
        """
        grad = self.unsharded_param.grad
        if grad is None:
            raise AssertionError("Expects unsharded_param.grad to not be None")
        if isinstance(grad, DTensor):
            raise AssertionError("Expected torch.Tensor, got DTensor")
        return grad

    @property
    def unsharded_accumulated_grad_data(self) -> torch.Tensor:
        """
        Get the unsharded accumulated gradient data as a local tensor.

        NOTE(review): unlike ``unsharded_grad_data`` this does not assert
        non-None; callers in this class only use it after checking
        ``unsharded_accumulated_grad is not None``.
        """
        grad = self.unsharded_accumulated_grad
        return grad

    @property
    def _sharded_local_tensor(self) -> torch.Tensor:
        # Local shard backing the sharded DTensor parameter.
        return cast(DTensor, self.sharded_param)._local_tensor

    @property
    def shard_world_size(self) -> int:
        """Get the world size for shard dimension."""
        if isinstance(self.mesh_info, FSDPMeshInfo):
            return self.mesh_info.shard_mesh_size
        return 1

    @property
    def replicate_world_size(self) -> int:
        """Get the world size for replicate dimension (HSDP only)."""
        if isinstance(self.mesh_info, HSDPMeshInfo):
            return self.mesh_info.replicate_mesh_size
        return 1

    def _assert_in_states(self, *states: ShardedState) -> None:
        """Assert current state is one of expected states."""
        if self.sharded_state not in states:
            raise AssertionError(
                f"Expected sharded_state in {states}, got {self.sharded_state}"
            )

    def reset_sharded_param(self) -> None:
        """Reset sharded param after load_state_dict."""
        module_info = self._module_info
        new_param = getattr(module_info.module, module_info.param_name)
        if new_param is not self.sharded_param:
            # Ensure object identity is preserved after parameter conversion.
            if torch.__future__.get_swap_module_params_on_conversion():
                raise AssertionError(
                    f"Expects swap_tensors to preserve object but got {new_param} "
                    f"instead of {self.sharded_param}"
                )
            self.sharded_param = new_param
        local_tensor = new_param._local_tensor
        if local_tensor.is_meta:
            return
        updated_local_tensor = False
        # local_tensor can be padded twice
        # 1st time in fully_shard(model)
        # 2nd time in model(input) lazy_init
        # 2nd time should be no-op if parameters remain unchanged
        # 2nd time shouldn't be no-op if people call model.load_state_dict(...) before lazy_init
        # this makes it possible for trainer to call `sd = model.state_dict()` before the training loop
        # and use `sd` without calling .state_dict() per iteration
        same_local_tensor = False
        # TODO: need to support tensor subclass
        if type(self._sharded_param_data) is torch.Tensor:
            same_local_tensor = (
                # when sharding param with shape (1, ...) over 2 ranks
                # local_tensor on rank 1 can be size 0, data_ptr() can be 0
                self._sharded_param_data.untyped_storage().data_ptr() > 0
                and self._sharded_param_data.untyped_storage().data_ptr()
                == local_tensor.untyped_storage().data_ptr()
            )
        sharded_size = self.sharded_size
        shard_dim = self.hsdp_placement.dim
        length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
        if local_tensor.size() != sharded_size and not same_local_tensor:
            raise AssertionError(
                f"Expected sharded_size to be {sharded_size}, got {local_tensor.size()}"
            )
        if self.pin_memory and not local_tensor.is_pinned():
            local_tensor = local_tensor.cpu().pin_memory()
            updated_local_tensor = True
        if not same_local_tensor:
            self._sharded_param_data = local_tensor.view(-1)
        if not isinstance(self.sharded_param, DTensor):
            raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}")
        if updated_local_tensor:
            # Only change the local tensor object if needed
            self.sharded_param._local_tensor = local_tensor.narrow(
                dim=shard_dim, start=0, length=length
            )
            if not self.sharded_param._local_tensor.is_contiguous():
                raise AssertionError(
                    "Expected sharded_param._local_tensor to be contiguous"
                )
        self._sharding_spec = cast(DTensor, self.sharded_param).layout

    def _get_unsharded_param_data(self, async_op: bool = False) -> Tuple[torch.Tensor, Optional[dist.Work]]:
        """
        Perform all-gather to get unsharded parameter data.

        Args:
            async_op: Whether to execute asynchronously.

        Returns:
            (unsharded_param, handle): Unsharded parameter data and communication handle.
        """
        # If parameter is not sharded (below threshold), no communication needed
        if not self.is_sharded:
            self.init_all_gather_outputs(
                all_gather_input_numels=[self._sharded_param_data.numel()],
                all_gather_input_dtypes=[self._sharded_param_data.dtype],
                world_size=1,
                device=self.device,
            )
            self.alloc_all_gather_outputs()
            self.all_gather_outputs[0].copy_(self._sharded_param_data)
            return self.all_gather_outputs[0], None
        # Get input data
        all_gather_input = self.all_gather_inputs[0]
        # Initialize output buffer
        self.init_all_gather_outputs(
            all_gather_input_numels=[all_gather_input.numel()],
            all_gather_input_dtypes=[all_gather_input.dtype],
            world_size=self.shard_world_size,
            device=self.device,
        )
        self.alloc_all_gather_outputs()
        # Get communication group
        shard_group = self.mesh_info.shard_process_group if isinstance(self.mesh_info, FSDPMeshInfo) else None
        if shard_group is None or self.shard_world_size <= 1:
            # No communication needed, just copy
            self.all_gather_outputs[0].copy_(all_gather_input)
            return self.all_gather_outputs[0], None
        # Execute all_gather_into_tensor
        handle = dist.all_gather_into_tensor(
            self.all_gather_outputs[0],
            all_gather_input,
            group=shard_group,
            async_op=async_op,
        )
        return self.all_gather_outputs[0], handle

    def unshard(self, async_op: bool = False) -> None:
        """Kick off (or skip, if already prefetched) the parameter all-gather."""
        if self.prefetch_handle is not None:
            # Already triggered by prefetch; return directly.
            return  # no-op
        _, handle = self._get_unsharded_param_data(async_op=async_op)
        self.prefetch_handle = handle

    def wait_for_unshard(self) -> None:
        """Wait for any in-flight all-gather, then switch to the unsharded param."""
        self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
        if self.prefetch_handle is not None:
            self.prefetch_handle.wait()
            self.prefetch_handle = None
        self.init_unsharded_param()
        self.to_unsharded()

    def shard(self) -> None:
        """
        Transition parameter from unsharded back to sharded state.
        """
        self._assert_in_states(ShardedState.UNSHARDED)
        self.to_sharded()

    def reduce_scatter_grad(
        self,
        async_op: bool = False,
        dtype: Optional[torch.dtype] = None,
        reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG
    ) -> Tuple[torch.Tensor, Optional[dist.Work]]:
        """
        Perform reduce-scatter on gradient to reduce and shard the full gradient.

        Args:
            async_op: Whether to execute asynchronously.
            dtype: reduce dtype.
            reduce_op: do reduce-scatter avg or sum.

        Returns:
            (sharded_grad, handle): Sharded gradient and communication handle.
        """
        self._assert_in_states(ShardedState.UNSHARDED)
        # Prefer the accumulated gradient when one exists
        if self.unsharded_accumulated_grad is not None:
            grad = self.unsharded_accumulated_grad_data
        else:
            grad = self.unsharded_grad_data
        reduce_dtype = dtype or grad.dtype
        grad = grad.to(reduce_dtype)
        grad_flat = grad.view(-1)
        # If parameter is not sharded (below threshold), no reduce-scatter needed
        if not self.is_sharded:
            return grad_flat, None
        # Get communication group
        shard_group = self.mesh_info.shard_process_group if isinstance(self.mesh_info, FSDPMeshInfo) else None
        if shard_group is None or self.shard_world_size <= 1:
            # No communication needed
            return grad_flat, None
        # Calculate output size
        output_numel = grad_flat.numel() // self.shard_world_size
        output = torch.empty(output_numel, dtype=reduce_dtype, device=grad.device)
        # Execute reduce_scatter_tensor
        handle = dist.reduce_scatter_tensor(
            output,
            grad_flat,
            op=reduce_op,
            group=shard_group,
            async_op=async_op,
        )
        return output, handle

    def all_reduce_grad(
        self,
        grad: Optional[torch.Tensor] = None,
        async_op: bool = False,
        reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG
    ) -> Tuple[torch.Tensor, Optional[dist.Work]]:
        """
        Perform all-reduce on gradient (across replicate dimension in HSDP mode).

        Args:
            grad: Gradient tensor to reduce. If None, will use unsharded_param.grad
                or unsharded_accumulated_grad based on use_accumulated_grad flag.
            async_op: Whether to execute asynchronously.
            reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG.

        Returns:
            (reduced_grad, handle): Reduced gradient and communication handle.
        """
        # If grad is not provided, get from parameter
        if grad is None:
            if self.unsharded_accumulated_grad is not None:
                grad = self.unsharded_accumulated_grad_data
            else:
                grad = self.unsharded_grad_data
        if not isinstance(self.mesh_info, HSDPMeshInfo):
            # Not HSDP mode, no all-reduce needed
            return grad, None
        replicate_group = self.mesh_info.replicate_process_group
        if replicate_group is None or self.replicate_world_size <= 1:
            return grad, None
        handle = dist.all_reduce(
            grad,
            op=reduce_op,
            group=replicate_group,
            async_op=async_op
        )
        return grad, handle
def set_requires_grad_if_needed(
    src_tensor: torch.Tensor, dst_tensor: torch.Tensor
) -> None:
    """Copy the ``requires_grad`` flag from ``src_tensor`` to ``dst_tensor``.

    The flag is written only when the two tensors disagree, so a matching
    destination is left completely untouched.
    """
    desired = src_tensor.requires_grad
    if dst_tensor.requires_grad is not desired:
        dst_tensor.requires_grad_(desired)