Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/fully_shard/param.py: 69%
452 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/fsdp/_fully_shard/_fsdp_param.py
16# enhanced with fully_shard parameter management
17# ============================================================================
18"""HSDP parameter"""
19# pylint: disable=W0212
20import itertools
21from typing import Callable, List, Optional, Tuple, Union, cast
23import torch
24import torch.distributed as dist
25from torch import nn
26from torch._prims_common import make_contiguous_strides_for
28from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
29from hyper_parallel.core.dtensor.dtensor import DTensor, SkipDTensorDispatch
30from hyper_parallel.core.dtensor.layout import Layout
31from hyper_parallel.core.dtensor.placement_types import Replicate, Shard, StridedShard
32from hyper_parallel.core.fully_shard.hsdp_param import HSDPParamV2
33from hyper_parallel.core.fully_shard.hsdp_utils import (
34 FullyShardParamMode,
35 GroupInfo,
36 ParamModuleInfo,
37 ShardedState,
38 get_rank_list_for_axes,
39 get_split_rank_lists_for_axes,
40)
41from hyper_parallel.core.fully_shard.utils import (
42 CPUOffloadPolicy,
43 DDPMeshInfo,
44 FSDPMeshInfo,
45 MixedPrecisionPolicy,
46 OffloadPolicy,
47)
48from hyper_parallel.platform import get_platform
49from hyper_parallel.platform.torch.fully_shard.pack_utils import (
50 build_rs_plan,
51 pack_for_reduce_scatter,
52 unpack_from_all_gather,
53)
# Process groups keyed by their sorted rank tuple, shared by every parameter in
# this module so a given rank set is materialized at most once
# (populated by _build_group_info_from_rank_list).
_GROUP_INFO_CACHE: dict = {}
# Platform backend, resolved once at import time.
platform = get_platform()
59def _copy_without_bumping_version(dst: torch.Tensor, src: torch.Tensor) -> None:
60 """Copy into ``dst`` while preserving its autograd version counter."""
61 # pylint: disable=W0212
62 with torch.autograd._unsafe_preserve_version_counter(dst):
63 dst.copy_(src)
def _build_group_info_from_rank_list(
    group_name: str,
    rank_list,
) -> GroupInfo:
    """Build ``GroupInfo`` for an explicit collection of global ranks.

    Ranks are normalized to a sorted tuple of ints. That tuple is both the
    resolved group name and the key into ``_GROUP_INFO_CACHE``, so each
    distinct rank set reaches ``platform.create_group`` at most once.

    A rank set with at most one member yields an "invalid" placeholder
    (``group=None``, ``rank_size=1``). If group creation raises
    ``RuntimeError``/``ValueError``, ``None`` is cached for that rank set.
    """
    ranks = tuple(sorted(int(r) for r in rank_list))
    member_count = len(ranks)
    if member_count <= 1:
        return GroupInfo(f"{group_name}_invalid", None, 1)
    if ranks not in _GROUP_INFO_CACHE:
        try:
            created_group = platform.create_group(list(ranks))
        except (RuntimeError, ValueError):  # pragma: no cover - UT may run without dist init
            created_group = None
        _GROUP_INFO_CACHE[ranks] = created_group
    return GroupInfo(str(ranks), _GROUP_INFO_CACHE[ranks], member_count)
def _build_group_info_from_process_group(
    group_name: str,
    process_group,
    rank_size: int,
) -> GroupInfo:
    """Wrap an already-created process group in ``GroupInfo``.

    A missing group or a ``rank_size`` of at most 1 yields an "invalid"
    placeholder (``group=None``, ``rank_size=1``). Otherwise the name is the
    sorted tuple of the group's ranks as reported by
    ``dist.get_process_group_ranks``. If that lookup fails with one of the
    listed exceptions, ``group_name`` is used instead.
    """
    if process_group is None or rank_size <= 1:
        return GroupInfo(f"{group_name}_invalid", None, 1)
    try:
        name = str(tuple(sorted(dist.get_process_group_ranks(process_group))))
    except (AssertionError, AttributeError, KeyError, RuntimeError, TypeError, ValueError):
        # pragma: no cover - best-effort naming / mocked process groups in UT
        name = group_name
    return GroupInfo(name, process_group, rank_size)
102class TorchHSDPParamV2(HSDPParamV2):
103 """
104 Torch HSDP parameter.
105 """
107 def __init__(
108 self,
109 param: nn.Parameter,
110 module_info: ParamModuleInfo,
111 mesh_info: FSDPMeshInfo,
112 shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
113 mp_policy: Optional[MixedPrecisionPolicy] = None,
114 offload_policy: Optional[OffloadPolicy] = None,
115 device: Optional[torch.device] = None,
116 param_mode: Optional[FullyShardParamMode] = None,
117 enable_fsdp_shard: bool = True,
118 ):
119 """
120 Initialize TorchHSDPParamV2 and shard the parameter.
122 Args:
123 param (nn.Parameter): The original full parameter to shard.
124 module_info (ParamModuleInfo): Ownership and shared-weight metadata.
125 mesh_info (FSDPMeshInfo): Mesh topology for shard/replicate dimensions.
126 shard_placement_fn (Callable, optional): Returns a Shard placement for the parameter,
127 or None to use default (Shard(0)).
128 mp_policy (MixedPrecisionPolicy, optional): Mixed precision dtype policy.
129 offload_policy (OffloadPolicy, optional): CPU offload policy.
130 device (torch.device, optional): Target device for the sharded parameter.
131 """
132 self._module_info: ParamModuleInfo = module_info
133 self.mesh_info = mesh_info
134 self.mp_policy = mp_policy
135 self.device = device
136 if param_mode is None:
137 raise AssertionError("param_mode must be resolved before TorchHSDPParamV2 initialization.")
138 self.param_mode = param_mode
139 self.enable_fsdp_shard = enable_fsdp_shard
140 self.orig_dtype = None
141 self.param_dtype = None
142 self.reduce_dtype = None
143 self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
144 self.pin_memory = (
145 self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
146 )
147 self.grad_offload_event: Optional[torch.Event] = None
148 self._orig_param_is_dtensor = isinstance(param, DTensor)
149 self._orig_dtensor_mesh = param.device_mesh if self._orig_param_is_dtensor else None
150 self._orig_dtensor_placements = tuple(param.placements) if self._orig_param_is_dtensor else None
151 self._spmd_shard_mesh_dim = self.mesh_info.shard_mesh_dim
152 self._spmd_replicate_mesh_dim = self.mesh_info.replicate_mesh_dim
153 self._init_sharded_param(param, shard_placement_fn)
154 self._init_group_infos()
155 self.all_gather_outputs: List[torch.Tensor] = []
156 self.unsharded_accumulated_grad = None
157 self._param_fqn: Optional[str] = None
158 # Communication attributes for prefetch pattern
159 self.prefetch_handle: Optional[dist.Work] = None
160 self._post_load_hook_handle = (
161 module_info.module.register_load_state_dict_post_hook(
162 lambda *args, **kwargs: self.reset_sharded_param()
163 )
164 )
165 self._reduce_scatter_output = None
166 self.reduce_scatter_handle = None
167 self._all_reduce_output = None
168 self.all_reduce_handle = None
170 @property
171 def uses_param_shard(self) -> bool:
172 """Whether fully_shard should physically shard parameter storage for this param."""
173 return self.enable_fsdp_shard
175 @property
176 def is_dtensor_compat_mode(self) -> bool:
177 """Whether the parameter is managed through the DTensor compatibility path only."""
178 return self.param_mode == FullyShardParamMode.DTENSOR_COMPAT
180 def _get_base_spmd_placements(self) -> tuple:
181 if self.param_mode == FullyShardParamMode.DTENSOR_UNIFIED and self._orig_param_is_dtensor:
182 # DTENSOR_UNIFIED keeps the original distributed layout and prefixes
183 # explicit DP/FSDP mesh dimensions ahead of it on the unified mesh.
184 self._spmd_mesh = DeviceMesh.concatenate([self.mesh_info.mesh, self._orig_dtensor_mesh])
185 dp_prefix_placements = tuple(Replicate() for _ in range(self.mesh_info.mesh.ndim))
186 return dp_prefix_placements + tuple(self._orig_dtensor_placements)
188 if self.is_dtensor_compat_mode and self._orig_param_is_dtensor:
189 self._spmd_mesh = self._orig_dtensor_mesh
190 return tuple(self._orig_dtensor_placements)
192 self._spmd_mesh = self.mesh_info.mesh
193 return tuple(Replicate() for _ in range(self._spmd_mesh.ndim))
195 def _apply_data_parallel_placements(self, placements: list, shard_placement: Shard) -> tuple:
196 if len(placements) != self._spmd_mesh.ndim:
197 raise AssertionError(
198 f"Expected {self._spmd_mesh.ndim} unified placements, got {len(placements)}: {placements}"
199 )
200 if (
201 isinstance(self.mesh_info, DDPMeshInfo)
202 and self._spmd_replicate_mesh_dim is not None
203 and not self._orig_param_is_dtensor
204 ):
205 placements[self._spmd_replicate_mesh_dim] = Replicate()
206 if (
207 self.uses_param_shard
208 and isinstance(self.mesh_info, FSDPMeshInfo)
209 and self._spmd_shard_mesh_dim is not None
210 ):
211 # If TP/EP already shards the same tensor dimension, fully_shard must
212 # use StridedShard so the unified placement preserves the intended
213 # shard order on the concatenated mesh.
214 split_factor = 1
215 for mesh_idx, placement in enumerate(placements):
216 if mesh_idx == self._spmd_shard_mesh_dim:
217 continue
218 if placement.is_shard(shard_placement.dim):
219 split_factor *= self._spmd_mesh.mesh_shape[mesh_idx]
220 placements[self._spmd_shard_mesh_dim] = (
221 StridedShard(shard_placement.dim, split_factor=split_factor)
222 if split_factor > 1
223 else shard_placement
224 )
225 return tuple(placements)
227 def _init_group_infos(self) -> None:
228 if self.uses_param_shard and self.is_sharded and isinstance(self.mesh_info, FSDPMeshInfo):
229 self.sharded_group_info = _build_group_info_from_process_group(
230 "fully_shard_sharded_group",
231 self.mesh_info.shard_process_group,
232 self.mesh_info.shard_mesh_size,
233 )
234 else:
235 self.sharded_group_info = GroupInfo("fully_shard_sharded_group_invalid", None, 1)
237 # The all-reduce group is always derived from the final materialized layout.
238 # This keeps replicate_params, DTensor compat, and unified multi-dim layouts
239 # on a single source of truth.
240 self.unsharded_group_info = self._build_layout_driven_group_info()
242 self.shard_size = self.sharded_group_info.rank_size
243 self.dp_size = self.unsharded_group_info.rank_size
244 self.rank_size = max(1, self.shard_size * self.dp_size)
246 def _build_layout_driven_group_info(self):
247 group_axes = [
248 axis
249 for axis, placement in enumerate(self._spmd_placements)
250 if placement.is_replicate()
251 ]
252 if self.uses_param_shard and self._spmd_shard_mesh_dim is not None:
253 group_axes = [axis for axis in group_axes if axis != self._spmd_shard_mesh_dim]
254 if not group_axes:
255 return GroupInfo("fully_shard_unsharded_group_invalid", None, 1)
256 group_dim_names = getattr(self._spmd_mesh, "mesh_dim_names", None)
257 if group_dim_names:
258 try:
259 mesh_axis_names = tuple(group_dim_names[axis] for axis in group_axes)
260 if len(mesh_axis_names) == 1:
261 axis_name = mesh_axis_names[0]
262 process_group = self._spmd_mesh.get_group(axis_name)
263 if process_group is not None:
264 rank_size = self._spmd_mesh.mesh_shape[group_dim_names.index(axis_name)]
265 return _build_group_info_from_process_group(
266 "fully_shard_unsharded_group",
267 process_group,
268 rank_size,
269 )
271 split_rank_lists = get_split_rank_lists_for_axes(self._spmd_mesh, group_axes)
272 process_group = platform.split_group(split_ranks=split_rank_lists)
273 if process_group is not None:
274 rank_size = 1
275 for axis in group_axes:
276 rank_size *= self._spmd_mesh.mesh_shape[axis]
277 return _build_group_info_from_process_group(
278 "fully_shard_unsharded_group",
279 process_group,
280 rank_size,
281 )
282 except (
283 AssertionError,
284 AttributeError,
285 KeyError,
286 RuntimeError,
287 TypeError,
288 ValueError,
289 ):
290 # Fall back to the explicit rank-list path for mocked meshes in UT
291 # or when a mesh implementation cannot materialize a reusable group.
292 pass
294 rank_list = get_rank_list_for_axes(self._spmd_mesh, group_axes)
295 return _build_group_info_from_rank_list("fully_shard_unsharded_group", rank_list)
297 def _to_local_unsharded_grad(self, grad):
298 """Normalize a pending gradient to a local tensor expected by fully_shard collectives."""
299 if not isinstance(grad, DTensor):
300 return grad
302 if any(placement.is_partial() for placement in grad.placements):
303 grad = grad.reduce_partial()
305 if (
306 self._orig_dtensor_mesh is not None
307 and grad.device_mesh.to_hash() != self._orig_dtensor_mesh.to_hash()
308 ) or (
309 self._orig_dtensor_placements is not None
310 and tuple(grad.placements) != tuple(self._orig_dtensor_placements)
311 ):
312 grad = grad.redistribute(self._orig_dtensor_mesh, self._orig_dtensor_placements)
313 return grad.to_local()
315 def reduce_scatter_output(self):
316 """
317 Get the reduce-scatter output tensor and wait for asynchronous operation to complete.
319 Returns:
320 torch.Tensor: The sharded gradient tensor after reduce-scatter operation.
321 """
322 if self.reduce_scatter_handle is not None:
323 self.reduce_scatter_handle.wait()
324 self.reduce_scatter_handle = None
325 return self._reduce_scatter_output
327 def clear_reduce_scatter_output(self):
328 """Clear the reduce-scatter output tensor to free memory."""
329 self._reduce_scatter_output = None
331 def all_reduce_output(self):
332 """
333 Get the all-reduce output tensor and wait for asynchronous operation to complete.
335 Returns:
336 torch.Tensor: The reduced gradient tensor after all-reduce operation.
337 """
338 if self.all_reduce_handle is not None:
339 self.all_reduce_handle.wait()
340 self.all_reduce_handle = None
341 return self._all_reduce_output
343 def clear_all_reduce_output(self):
344 """Clear the all-reduce output tensor to free memory."""
345 self._all_reduce_output = None
347 def apply_reduced_grad(self, reduced_grad, param_type):
348 """
349 Apply reduced gradient to the sharded parameter.
351 Reshapes ``reduced_grad`` to match the local shard, optionally
352 offloads to CPU, then accumulates or assigns onto
353 ``hsdp_param.sharded_param.grad``.
355 Args:
356 reduced_grad (torch.Tensor): Gradient after reduce-scatter
357 and/or all-reduce.
358 param_type (Optional[torch.dtype]): Target dtype for the gradient (if conversion is needed).
359 """
360 sharded_grad = None
361 if not self.mp_policy.apply_grad_on_fp32_main_grad:
362 sharded_grad = self.sharded_param.grad
363 else:
364 if not hasattr(self.sharded_param, "main_grad"):
365 self.sharded_param.main_grad = None
366 sharded_grad = self.sharded_param.main_grad
367 sharded_param_local_shape = (
368 self.sharded_param.local_shape
369 if isinstance(self.sharded_param, DTensor)
370 else self.sharded_param.shape
371 )
372 reduced_grad = reduced_grad.view(sharded_param_local_shape)
373 if (not self.mp_policy.apply_grad_on_fp32_main_grad and param_type is not None
374 and reduced_grad.dtype != param_type):
375 reduced_grad = reduced_grad.to(param_type)
376 to_accumulate_grad = sharded_grad is not None
377 need_synchronize = False
378 if self.offload_to_cpu:
379 non_blocking = self.pin_memory and not to_accumulate_grad
380 reduced_grad = reduced_grad.to(
381 torch.device("cpu"), non_blocking=non_blocking
382 )
383 need_synchronize = True
384 if sharded_grad is None:
385 if not self.mp_policy.apply_grad_on_fp32_main_grad:
386 self.sharded_param.grad = self.to_sharded_dtensor(reduced_grad)
387 else:
388 self.sharded_param.main_grad = self.to_sharded_dtensor(reduced_grad)
389 self.sharded_param.grad = None
390 else:
391 with SkipDTensorDispatch():
392 if not self.mp_policy.apply_grad_on_fp32_main_grad:
393 self.sharded_param.grad._local_tensor += reduced_grad
394 else:
395 self.sharded_param.main_grad._local_tensor += reduced_grad
396 self.sharded_param.grad = None
397 if self.unsharded_accumulated_grad_data is not None:
398 self.unsharded_accumulated_grad = None
399 elif self.unsharded_param.grad is not None:
400 self.unsharded_param.grad = None
401 return need_synchronize
403 @torch.no_grad()
404 def _init_sharded_param(
405 self,
406 param: nn.Parameter,
407 shard_placement_fn: Optional[Callable],
408 ) -> None:
409 if param.device != self.device and param.device.type != "meta":
410 raise AssertionError(
411 f"Expects the parameter to already be moved to device {self.device} but got {param.device}"
412 )
414 hsdp_placement = shard_placement_fn(param) if shard_placement_fn else None
415 if hsdp_placement is None:
416 hsdp_placement = Shard(0)
417 elif hsdp_placement.dim < 0:
418 # if dim is negative, add the number of dimensions of the parameter
419 hsdp_placement = Shard(hsdp_placement.dim + param.ndim)
421 if not isinstance(hsdp_placement, Shard):
422 raise AssertionError(
423 f"Expected Shard, got {type(hsdp_placement)}: {hsdp_placement}"
424 )
426 self.hsdp_placement = hsdp_placement
427 base_placements = list(self._get_base_spmd_placements())
428 self._spmd_placements = self._apply_data_parallel_placements(base_placements, hsdp_placement)
429 param_data = param.to_local() if self._orig_param_is_dtensor else param
431 shard_dim = hsdp_placement.dim
432 self._orig_size = param_data.size()
433 self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
435 if self.uses_param_shard and isinstance(self.mesh_info, FSDPMeshInfo):
436 shard_rank = self.mesh_info.shard_mesh_rank
437 shard_world_size = self.mesh_info.shard_mesh_size
438 else:
439 shard_rank = 0
440 shard_world_size = 1
442 if isinstance(param_data, DTensor) and isinstance(self.mesh_info, DDPMeshInfo):
443 param_data.data = param_data.full_tensor()
445 self.is_sharded = bool(self.uses_param_shard and shard_world_size > 1)
447 if param_data.size(shard_dim) % shard_world_size != 0:
448 raise NotImplementedError(
449 f"Uneven sharding on dim {shard_dim} not supported: "
450 f"shape={param_data.shape}, world_size={shard_world_size}"
451 )
452 chunks = torch.chunk(param_data, shard_world_size, dim=shard_dim)
453 sharded_param = chunks[shard_rank].clone().contiguous()
454 self.sharded_size = sharded_param.size()
455 self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size)
456 if self.offload_to_cpu and not sharded_param.is_meta:
457 sharded_param = sharded_param.cpu()
458 if self.pin_memory:
459 sharded_param = sharded_param.pin_memory()
460 self._sharded_param_data = sharded_param.view(-1)
462 self._sharding_spec = Layout.from_device_mesh(self._spmd_mesh)
463 self._sharding_spec.set_placements(self._spmd_placements)
464 self._sharding_spec.placement_to_tensor_map(param.ndim)
466 self.sharded_param = nn.Parameter(DTensor.from_local(sharded_param, self._spmd_mesh, self._spmd_placements))
467 self.sharded_param.requires_grad_(param.requires_grad)
468 self._setattr_on_modules(self.sharded_param)
469 # after init, self.sharded_param replaces original param, gradients must accumulate to this Parameter's grad
470 self.sharded_param._hsdp_param_initialized = True
471 self.sharded_state = ShardedState.SHARDED
472 self.param_dtype = None
474 def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
475 """Initialize param_dtype and reduce_dtype from the mixed precision policy."""
476 param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype)
477 self.orig_dtype = self.sharded_param.dtype
478 if reduce_dtype == param_dtype:
479 reduce_dtype = None
480 if param_dtype == self.orig_dtype:
481 param_dtype = None
482 self.param_dtype = param_dtype
483 self.reduce_dtype = reduce_dtype
485 def init_all_gather_outputs(
486 self,
487 all_gather_input_numels: list[int],
488 all_gather_input_dtypes: list[torch.dtype],
489 world_size: int,
490 device: torch.device,
491 force_recreate: bool = False,
492 ):
493 """
494 Allocate output buffers for all-gather communication.
496 Args:
497 all_gather_input_numels: Number of elements per input shard.
498 all_gather_input_dtypes: Dtype of each input shard.
499 world_size: Number of ranks in the shard process group.
500 device: Device on which to allocate the output buffers.
501 force_recreate: If True, always recreate buffers even if already initialized.
502 """
503 if not force_recreate and len(self.all_gather_outputs) > 0:
504 return # already initialized
505 self.all_gather_outputs = [
506 torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device)
507 for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes)
508 ]
510 def init_unsharded_param(self):
511 """
512 Initialize unsharded parameter from all-gather outputs.
514 This reconstructs the full parameter after all-gather by unpacking the
515 gathered flat buffer back to the original tensor layout.
516 """
517 unsharded_param = self._get_unsharded_param_from_all_gather_output()
518 # Always refresh the unsharded Parameter from the latest all-gather output.
519 # Non-dim0 unpack currently materializes a contiguous tensor copy, so
520 # keeping stale .data would otherwise reuse old weights after optimizer.step()
521 # mutates only the sharded local shard. Preserve the Parameter object identity
522 # so autograd-facing module state stays stable across unshard cycles.
523 if hasattr(self, "_unsharded_param"):
524 # pylint: disable=access-member-before-definition
525 self._unsharded_param.data = unsharded_param
526 self._unsharded_param.requires_grad_(self.sharded_param.requires_grad)
527 self._unsharded_param.grad = None
528 return
529 self._unsharded_param = nn.Parameter(
530 unsharded_param,
531 requires_grad=self.sharded_param.requires_grad,
532 )
534 def _get_unsharded_param_from_all_gather_output(self) -> torch.Tensor:
535 """Reconstruct the full local parameter view from the packed all-gather output."""
536 if len(self.all_gather_outputs) != 1:
537 raise AssertionError(
538 f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}"
539 )
540 unsharded_tensor = self.all_gather_outputs[0]
541 plan = build_rs_plan(
542 self,
543 self._sharded_local_tensor,
544 self.shard_world_size if self.is_sharded else 1,
545 )
546 unsharded_param = unpack_from_all_gather(unsharded_tensor, plan)
547 if self._orig_param_is_dtensor:
548 # Rebuild the original DTensor view after all-gather so gradient
549 # consumers keep seeing the source DTensor layout.
550 unsharded_param = DTensor.from_local(
551 unsharded_param,
552 self._orig_dtensor_mesh,
553 self._orig_dtensor_placements,
554 )
555 return unsharded_param
557 def to_sharded(self) -> None:
558 if not self.uses_param_shard and self._unsharded_param is not None:
559 # Replicate params keep the same local shape across shard/unshard,
560 # so persist forward-time state updates before switching objects.
561 src = self._unsharded_param.to_local() if isinstance(self._unsharded_param, DTensor) \
562 else self._unsharded_param
563 dst = self.sharded_param.to_local() if isinstance(self.sharded_param, DTensor) else self.sharded_param
564 _copy_without_bumping_version(dst, src)
565 self._setattr_on_modules(self.sharded_param)
566 self.free_unsharded_param()
567 self.sharded_state = ShardedState.SHARDED
569 def to_unsharded(self) -> None:
570 set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)
571 self._setattr_on_modules(self._unsharded_param)
572 self.sharded_state = ShardedState.UNSHARDED
574 def _setattr_on_modules(self, param: nn.Parameter) -> None:
575 """Set parameter on module and shared modules, preserving pointer consistency."""
576 if getattr(self._module_info.module.__setattr__, "__func__", None) is nn.Module.__setattr__:
577 # fast path
578 self._module_info.module._parameters[self._module_info.param_name] = param
579 else:
580 # slow path
581 setattr(self._module_info.module, self._module_info.param_name, param)
583 # Iterate through all modules that share this parameter to prevent pointer desync.
584 for shared_module, shared_param_name in zip(
585 self._module_info.shared_modules, self._module_info.shared_param_names
586 ):
587 if getattr(shared_module.__setattr__, "__func__", None) is nn.Module.__setattr__:
588 shared_module._parameters[shared_param_name] = param
589 else:
590 setattr(shared_module, shared_param_name, param)
592 def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor:
593 """
594 Converts a local tensor representing either the sharded parameter or
595 sharded gradient to DTensor.
596 """
597 return DTensor.from_local(
598 tensor,
599 self._sharding_spec.mesh,
600 self._sharding_spec.placements
601 )
603 def to_accumulated_grad_if_needed(self) -> None:
604 if self._unsharded_param.grad is None:
605 return
606 # Keep local gradients alive across no-sync / delayed-sync steps even
607 # after the parameter transitions back to the sharded view.
608 unsharded_grad = self._unsharded_param.grad
609 self._unsharded_param.grad = None
610 if self.reduce_dtype is not None and unsharded_grad.dtype != self.reduce_dtype:
611 unsharded_grad = unsharded_grad.to(self.reduce_dtype)
612 if self.unsharded_accumulated_grad is None:
613 self.unsharded_accumulated_grad = unsharded_grad
614 else:
615 self.unsharded_accumulated_grad += unsharded_grad
617 def accumulate_unsharded_grad_if_needed(self) -> None:
618 if (
619 self.unsharded_accumulated_grad is not None
620 and self.unsharded_param.grad is not None
621 ):
622 grad = self.unsharded_param.grad
623 if self.reduce_dtype is not None and grad.dtype != self.reduce_dtype:
624 grad = grad.to(self.reduce_dtype)
625 self.unsharded_accumulated_grad += grad
626 self.unsharded_param.grad = None
628 def alloc_all_gather_outputs(self) -> None:
629 """Resize all-gather output buffers to their full capacity for communication."""
630 for tensor in self.all_gather_outputs:
631 expected_size = tensor.numel() * tensor.itemsize
632 storage = tensor.untyped_storage()
633 if storage.size() != expected_size:
634 storage.resize_(expected_size)
636 def free_unsharded_param(self) -> None:
637 """Release storage of all-gather outputs to free device memory."""
638 for tensor in self.all_gather_outputs:
639 storage = tensor.untyped_storage()
640 if storage.size() != 0:
641 storage.resize_(0)
643 @property
644 def all_gather_inputs(self) -> list[torch.Tensor]:
645 """Return the local sharded tensor to use as input for all-gather, applying dtype cast if needed."""
646 self._assert_in_states(ShardedState.SHARDED)
647 sharded_param_data = self._sharded_param_data
648 if self.offload_to_cpu:
649 sharded_param_data = sharded_param_data.to(
650 self.device, non_blocking=True
651 )
652 if self.param_dtype is not None and self.param_dtype != sharded_param_data.dtype:
653 return [sharded_param_data.to(self.param_dtype)]
654 return [sharded_param_data]
656 @property
657 def unsharded_param(self) -> nn.Parameter:
658 """Return the full unsharded parameter after all-gather."""
659 return self._unsharded_param
661 @property
662 def unsharded_grad_data(self) -> torch.Tensor:
663 """
664 Get the unsharded gradient data as a local tensor.
665 """
666 grad = self.unsharded_param.grad
667 if grad is None:
668 raise AssertionError("Expects unsharded_param.grad to not be None")
669 return self._to_local_unsharded_grad(grad)
671 @property
672 def unsharded_accumulated_grad_data(self) -> torch.Tensor:
673 """
674 Get the unsharded accumulated gradient data as a local tensor.
675 """
676 grad = self.unsharded_accumulated_grad
677 return self._to_local_unsharded_grad(grad)
679 @property
680 def _sharded_local_tensor(self) -> torch.Tensor:
681 """Return the underlying local tensor of the sharded DTensor parameter."""
682 return cast(DTensor, self.sharded_param)._local_tensor
684 @property
685 def shard_world_size(self) -> int:
686 """Get the world size for shard dimension."""
687 return self.shard_size
689 @property
690 def replicate_world_size(self) -> int:
691 """Get the world size for replicate dimension (HSDP only)."""
692 return self.dp_size
694 def _assert_in_states(self, *states: ShardedState) -> None:
695 """Assert current state is one of expected states."""
696 if self.sharded_state not in states:
697 raise AssertionError(
698 f"Expected sharded_state in {states}, got {self.sharded_state}"
699 )
701 def reset_sharded_param(self) -> None:
702 """Reset sharded param after load_state_dict."""
703 module_info = self._module_info
704 new_param = getattr(module_info.module, module_info.param_name)
705 if new_param is not self.sharded_param:
706 # Ensure object identity is preserved after parameter conversion.
707 if torch.__future__.get_swap_module_params_on_conversion():
708 raise AssertionError(
709 f"Expects swap_tensors to preserve object but got {new_param} "
710 f"instead of {self.sharded_param}"
711 )
712 if isinstance(new_param, DTensor):
713 self.sharded_param = new_param
714 if not getattr(self.sharded_param, "_hsdp_param_initialized", None):
715 # reset _hsdp_param_initialized flag.
716 self.sharded_param._hsdp_param_initialized = True
717 elif isinstance(new_param, torch.Tensor):
718 # if new_param is Tensor, don't change 'self.sharded_param' ref
719 # just update self.sharded_param._local_tensor and self.sharded_param_data.
720 pass
722 local_tensor = new_param._local_tensor if isinstance(new_param, DTensor) else new_param
723 if local_tensor.is_meta:
724 return
725 updated_local_tensor = False
726 # local_tensor can be padded twice
727 # 1st time in fully_shard(model)
728 # 2nd time in model(input) lazy_init
729 # 2nd time should be no-op if parameters remain unchanged
730 # 2nd time shouldn't be no-op if people call model.load_state_dict(...) before lazy_init
731 # this makes it possible for trainer to call `sd = model.state_dict()` before the training loop
732 # and use `sd` without calling .state_dict() per iteration
733 same_local_tensor = False
734 if isinstance(self._sharded_param_data, torch.Tensor):
735 same_local_tensor = (
736 # when sharding param with shape (1, ...) over 2 ranks
737 # local_tensor on rank 1 can be size 0, data_ptr() can be 0
738 self._sharded_param_data.untyped_storage().data_ptr() > 0
739 and self._sharded_param_data.untyped_storage().data_ptr()
740 == local_tensor.untyped_storage().data_ptr()
741 )
742 sharded_size = self.sharded_size
743 shard_dim = self.hsdp_placement.dim
744 length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
745 if not same_local_tensor:
746 if local_tensor.size() != sharded_size:
747 raise AssertionError(
748 f"Expected sharded_size to be {sharded_size}, got {local_tensor.size()}"
749 )
750 updated_local_tensor = True
751 if self.pin_memory and not local_tensor.is_pinned():
752 local_tensor = local_tensor.cpu().pin_memory()
753 updated_local_tensor = True
754 if not same_local_tensor:
755 self._sharded_param_data = local_tensor.view(-1)
756 if not isinstance(self.sharded_param, DTensor):
757 raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}")
758 if updated_local_tensor:
759 # Only change the local tensor object if needed
760 self.sharded_param._local_tensor = local_tensor.narrow(
761 dim=shard_dim, start=0, length=length
762 )
763 if not self.sharded_param._local_tensor.is_contiguous():
764 raise AssertionError(
765 "Expected sharded_param._local_tensor to be contiguous"
766 )
767 self._sharding_spec = cast(DTensor, self.sharded_param).layout
def _get_unsharded_param_data(
    self, async_op: bool = False
) -> Tuple[torch.Tensor, Optional[dist.Work]]:
    """
    All-gather the local shard into the first all-gather output buffer.

    The output buffer is set up through ``init_all_gather_outputs`` and
    ``alloc_all_gather_outputs`` before it is written. When no collective is
    required (the parameter is not sharded, there is no shard process group,
    or the shard world size is at most 1), the local input is copied
    straight into the buffer instead of communicating.

    Args:
        async_op: Forwarded to ``dist.all_gather_into_tensor``.

    Returns:
        (unsharded_param, handle): ``all_gather_outputs[0]`` and the
        communication handle. The handle is ``None`` whenever no collective
        was issued (and, per the torch API, when ``async_op`` is False).
    """
    sharded = self.is_sharded
    local_input = self.all_gather_inputs[0]

    # Parameters below the sharding threshold are not split, so their
    # output buffer is initialized for a single participant.
    gather_world_size = self.shard_world_size if sharded else 1
    self.init_all_gather_outputs(
        all_gather_input_numels=[local_input.numel()],
        all_gather_input_dtypes=[local_input.dtype],
        world_size=gather_world_size,
        device=self.device,
    )
    self.alloc_all_gather_outputs()
    gathered = self.all_gather_outputs[0]

    needs_collective = (
        sharded
        and self.sharded_group_info.group is not None
        and self.shard_world_size > 1
    )
    if not needs_collective:
        # NOTE(review): for a sharded parameter with no group but
        # shard_world_size > 1, the buffer above was initialized with
        # world_size=shard_world_size while only the local input is copied
        # here -- confirm that configuration cannot occur.
        _copy_without_bumping_version(gathered, local_input)
        return gathered, None

    work = dist.all_gather_into_tensor(
        gathered,
        local_input,
        group=self.sharded_group_info.group,
        async_op=async_op,
    )
    return gathered, work
def unshard(self, async_op: bool = False) -> None:
    """
    Kick off gathering of the full parameter data unless it is already pending.

    If ``prefetch_handle`` is set, a prefetch (e.g. ``HSDPState.prefetch()``)
    has already launched the gather and this call does nothing. Otherwise
    ``_get_unsharded_param_data`` runs and its communication handle (which may
    be ``None``) is stored in ``prefetch_handle`` for ``wait_for_unshard``.

    Args:
        async_op: Forwarded to ``_get_unsharded_param_data``.
    """
    if self.prefetch_handle is None:
        _, self.prefetch_handle = self._get_unsharded_param_data(async_op=async_op)
def wait_for_unshard(self) -> None:
    """
    Finish any pending gather and switch the parameter to its unsharded form.

    Requires the SHARDED state (checked via ``_assert_in_states``). If a
    communication handle is outstanding, it is waited on and cleared.
    Afterwards ``init_unsharded_param`` and ``to_unsharded`` are invoked.
    """
    self._assert_in_states(ShardedState.SHARDED)

    pending = self.prefetch_handle
    if pending is not None:
        pending.wait()
        self.prefetch_handle = None

    self.init_unsharded_param()
    self.to_unsharded()
def shard(self) -> None:
    """
    Return the parameter from the UNSHARDED state to the SHARDED state.

    The current state is validated with ``_assert_in_states`` (presumably
    raising when the parameter is not UNSHARDED -- confirm in the base class);
    the transition itself is delegated to ``to_sharded``.
    """
    self._assert_in_states(ShardedState.UNSHARDED)
    self.to_sharded()
def reduce_scatter_grad(
    self,
    async_op: bool = True,
    dtype: Optional[torch.dtype] = None,
    reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG
) -> Union[None, Tuple[torch.Tensor, Optional[dist.Work]]]:
    """
    Reduce the full gradient across the shard group and keep this rank's slice.

    The gradient source is the accumulated-gradient buffer when
    ``unsharded_accumulated_grad`` is set, otherwise ``unsharded_grad_data``.
    It is cast to ``dtype`` (or kept in its own dtype), then laid out with
    ``build_rs_plan`` / ``pack_for_reduce_scatter`` and flattened.

    Args:
        async_op: Forwarded to ``dist.reduce_scatter_tensor``.
        dtype: Dtype in which the reduction is performed; defaults to the
            gradient's own dtype.
        reduce_op: Reduction applied by the collective (AVG by default).

    Returns:
        (sharded_grad, handle). When no collective is needed (parameter not
        sharded, no shard group, or shard world size <= 1) this is the packed,
        flattened full gradient with a ``None`` handle; otherwise it is a newly
        allocated 1-D output buffer of ``numel // shard_world_size`` elements
        together with the collective's handle.
    """
    self._assert_in_states(ShardedState.UNSHARDED)

    # Prefer the accumulated gradient buffer whenever one exists.
    if self.unsharded_accumulated_grad is None:
        source_grad = self.unsharded_grad_data
    else:
        source_grad = self.unsharded_accumulated_grad_data
    reduce_dtype = dtype or source_grad.dtype
    source_grad = source_grad.to(reduce_dtype)

    needs_collective = (
        self.is_sharded
        and self.sharded_group_info.group is not None
        and self.shard_world_size > 1
    )
    # The packing plan is built for a world size of 1 when nothing is scattered.
    rs_world_size = self.shard_world_size if needs_collective else 1
    plan = build_rs_plan(self, source_grad, rs_world_size)
    packed = pack_for_reduce_scatter(source_grad, plan).reshape(-1)

    if not needs_collective:
        return packed, None

    scatter_out = torch.empty(
        packed.numel() // rs_world_size,
        dtype=reduce_dtype,
        device=source_grad.device,
    )
    self._reduce_scatter_output = scatter_out
    self.reduce_scatter_handle = dist.reduce_scatter_tensor(
        scatter_out,
        packed,
        op=reduce_op,
        group=self.sharded_group_info.group,
        async_op=async_op,
    )
    return scatter_out, self.reduce_scatter_handle
def all_reduce_grad(
    self,
    grad: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    async_op: bool = True,
    reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG
) -> Union[None, Tuple[torch.Tensor, Optional[dist.Work]]]:
    """
    All-reduce a gradient across the replicate (unsharded) process group.

    Args:
        grad: Gradient to reduce. When omitted, ``unsharded_accumulated_grad_data``
            is used if ``unsharded_accumulated_grad`` is set, otherwise
            ``unsharded_grad_data``.
        dtype: If given and different from the gradient's dtype, the gradient
            is converted with ``Tensor.to`` before reducing.
        async_op: Forwarded to ``dist.all_reduce``.
        reduce_op: Reduction passed to ``dist.all_reduce`` (AVG by default).

    Returns:
        (reduced_grad, handle). If there is no replicate group or the replicate
        world size is <= 1, the (possibly converted) gradient is returned
        untouched with a ``None`` handle. Otherwise the all-reduce runs in
        place on that tensor, which is also stored as ``_all_reduce_output``,
        and the handle is stored as ``all_reduce_handle``.
    """
    if grad is None:
        grad = (
            self.unsharded_accumulated_grad_data
            if self.unsharded_accumulated_grad is not None
            else self.unsharded_grad_data
        )

    if dtype is not None and grad.dtype != dtype:
        grad = grad.to(dtype)

    replicate_group = self.unsharded_group_info.group
    if replicate_group is None or self.replicate_world_size <= 1:
        return grad, None

    work = dist.all_reduce(grad, op=reduce_op, group=replicate_group, async_op=async_op)
    self.all_reduce_handle = work
    self._all_reduce_output = grad
    return grad, work
def set_requires_grad_if_needed(
    src_tensor: torch.Tensor, dst_tensor: torch.Tensor
) -> None:
    """Make ``dst_tensor.requires_grad`` match ``src_tensor.requires_grad``.

    ``requires_grad_`` is only invoked when the two flags differ, so a tensor
    whose flag already matches is left untouched.
    """
    wanted = src_tensor.requires_grad
    if dst_tensor.requires_grad == wanted:
        return
    dst_tensor.requires_grad_(wanted)