Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/fully_shard/param.py: 69%

452 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/fsdp/_fully_shard/_fsdp_param.py 

16# enhanced with fully_shard parameter management 

17# ============================================================================ 

18"""HSDP parameter""" 

19# pylint: disable=W0212 

20import itertools 

21from typing import Callable, List, Optional, Tuple, Union, cast 

22 

23import torch 

24import torch.distributed as dist 

25from torch import nn 

26from torch._prims_common import make_contiguous_strides_for 

27 

28from hyper_parallel.core.dtensor.device_mesh import DeviceMesh 

29from hyper_parallel.core.dtensor.dtensor import DTensor, SkipDTensorDispatch 

30from hyper_parallel.core.dtensor.layout import Layout 

31from hyper_parallel.core.dtensor.placement_types import Replicate, Shard, StridedShard 

32from hyper_parallel.core.fully_shard.hsdp_param import HSDPParamV2 

33from hyper_parallel.core.fully_shard.hsdp_utils import ( 

34 FullyShardParamMode, 

35 GroupInfo, 

36 ParamModuleInfo, 

37 ShardedState, 

38 get_rank_list_for_axes, 

39 get_split_rank_lists_for_axes, 

40) 

41from hyper_parallel.core.fully_shard.utils import ( 

42 CPUOffloadPolicy, 

43 DDPMeshInfo, 

44 FSDPMeshInfo, 

45 MixedPrecisionPolicy, 

46 OffloadPolicy, 

47) 

48from hyper_parallel.platform import get_platform 

49from hyper_parallel.platform.torch.fully_shard.pack_utils import ( 

50 build_rs_plan, 

51 pack_for_reduce_scatter, 

52 unpack_from_all_gather, 

53) 

54 

# Module-level cache of process groups keyed by the sorted tuple of member
# ranks; populated by _build_group_info_from_rank_list so identical rank lists
# reuse the same group object.
_GROUP_INFO_CACHE = {}
# Active platform backend, used below to create and split process groups.
platform = get_platform()

57 

58 

59def _copy_without_bumping_version(dst: torch.Tensor, src: torch.Tensor) -> None: 

60 """Copy into ``dst`` while preserving its autograd version counter.""" 

61 # pylint: disable=W0212 

62 with torch.autograd._unsafe_preserve_version_counter(dst): 

63 dst.copy_(src) 

64 

65 

def _build_group_info_from_rank_list(
    group_name: str,
    rank_list,
) -> GroupInfo:
    """Create (or reuse) group metadata for an explicit list of ranks.

    Ranks are coerced to ``int`` and sorted. The sorted tuple is both the
    cache key in ``_GROUP_INFO_CACHE`` and the resulting group's name.

    Args:
        group_name: Prefix for the placeholder name used when the list is trivial.
        rank_list: Iterable of rank ids.

    Returns:
        GroupInfo: A placeholder ``f"{group_name}_invalid"`` with no group and
        size 1 when at most one rank is given. Otherwise metadata for the
        cached or newly created group.
    """
    ranks = tuple(sorted(int(r) for r in rank_list))
    size = len(ranks)
    if size <= 1:
        return GroupInfo(f"{group_name}_invalid", None, 1)
    if ranks not in _GROUP_INFO_CACHE:
        try:
            created = platform.create_group(list(ranks))
        except (RuntimeError, ValueError):  # pragma: no cover - UT may run without dist init
            created = None
        # NOTE(review): a failed creation is memoized as None as well, so these
        # ranks keep resolving to "no group" for the rest of the process.
        _GROUP_INFO_CACHE[ranks] = created
    return GroupInfo(str(ranks), _GROUP_INFO_CACHE[ranks], size)

83 

84 

def _build_group_info_from_process_group(
    group_name: str,
    process_group,
    rank_size: int,
) -> GroupInfo:
    """Wrap an existing process group in GroupInfo metadata.

    The name is the sorted tuple of the group's member ranks when
    ``dist.get_process_group_ranks`` can report them; otherwise ``group_name``
    is used verbatim.

    Returns:
        GroupInfo: A placeholder ``f"{group_name}_invalid"`` with no group and
        size 1 when ``process_group`` is None or ``rank_size <= 1``.
    """
    if process_group is None or rank_size <= 1:
        return GroupInfo(f"{group_name}_invalid", None, 1)
    resolved_group_name = group_name
    try:
        members = dist.get_process_group_ranks(process_group)
        resolved_group_name = str(tuple(sorted(members)))
    except (AssertionError, AttributeError, KeyError, RuntimeError, TypeError, ValueError):
        # pragma: no cover - best-effort naming / mocked process groups in UT
        pass
    return GroupInfo(resolved_group_name, process_group, rank_size)

100 

101 

102class TorchHSDPParamV2(HSDPParamV2): 

103 """ 

104 Torch HSDP parameter. 

105 """ 

106 

    def __init__(
        self,
        param: nn.Parameter,
        module_info: ParamModuleInfo,
        mesh_info: FSDPMeshInfo,
        shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
        mp_policy: Optional[MixedPrecisionPolicy] = None,
        offload_policy: Optional[OffloadPolicy] = None,
        device: Optional[torch.device] = None,
        param_mode: Optional[FullyShardParamMode] = None,
        enable_fsdp_shard: bool = True,
    ):
        """
        Initialize TorchHSDPParamV2 and shard the parameter.

        Args:
            param (nn.Parameter): The original full parameter to shard. May be a DTensor,
                in which case its mesh and placements are recorded.
            module_info (ParamModuleInfo): Ownership and shared-weight metadata.
            mesh_info (FSDPMeshInfo): Mesh topology for shard/replicate dimensions.
            shard_placement_fn (Callable, optional): Returns a Shard placement for the parameter,
                or None to use default (Shard(0)).
            mp_policy (MixedPrecisionPolicy, optional): Mixed precision dtype policy.
            offload_policy (OffloadPolicy, optional): CPU offload policy. Offload (and pinned
                memory) is enabled only for a CPUOffloadPolicy.
            device (torch.device, optional): Target device for the sharded parameter.
            param_mode (FullyShardParamMode): How the parameter layout is managed
                (e.g. DTENSOR_UNIFIED or DTENSOR_COMPAT). Despite the ``None`` default, it
                must be provided.
            enable_fsdp_shard (bool): If False, parameter storage is not physically sharded
                (see ``uses_param_shard``).

        Raises:
            AssertionError: If ``param_mode`` is None.
        """
        self._module_info: ParamModuleInfo = module_info
        self.mesh_info = mesh_info
        self.mp_policy = mp_policy
        self.device = device
        if param_mode is None:
            raise AssertionError("param_mode must be resolved before TorchHSDPParamV2 initialization.")
        self.param_mode = param_mode
        self.enable_fsdp_shard = enable_fsdp_shard
        # Filled in later by init_dtype_attrs().
        self.orig_dtype = None
        self.param_dtype = None
        self.reduce_dtype = None
        self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
        self.pin_memory = (
            self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
        )
        self.grad_offload_event: Optional[torch.Event] = None
        # Remember the source DTensor layout (if any) so unsharded views and
        # gradients can be rebuilt / redistributed back into it later.
        self._orig_param_is_dtensor = isinstance(param, DTensor)
        self._orig_dtensor_mesh = param.device_mesh if self._orig_param_is_dtensor else None
        self._orig_dtensor_placements = tuple(param.placements) if self._orig_param_is_dtensor else None
        self._spmd_shard_mesh_dim = self.mesh_info.shard_mesh_dim
        self._spmd_replicate_mesh_dim = self.mesh_info.replicate_mesh_dim
        # Order matters: _init_sharded_param sets is_sharded, _spmd_mesh and
        # _spmd_placements, which _init_group_infos reads.
        self._init_sharded_param(param, shard_placement_fn)
        self._init_group_infos()
        self.all_gather_outputs: List[torch.Tensor] = []
        self.unsharded_accumulated_grad = None
        self._param_fqn: Optional[str] = None
        # Communication attributes for prefetch pattern
        self.prefetch_handle: Optional[dist.Work] = None
        # Re-sync sharded state whenever module.load_state_dict(...) may have
        # replaced or rewritten the parameter.
        self._post_load_hook_handle = (
            module_info.module.register_load_state_dict_post_hook(
                lambda *args, **kwargs: self.reset_sharded_param()
            )
        )
        # Pending async gradient-reduction results and their work handles.
        self._reduce_scatter_output = None
        self.reduce_scatter_handle = None
        self._all_reduce_output = None
        self.all_reduce_handle = None

169 

170 @property 

171 def uses_param_shard(self) -> bool: 

172 """Whether fully_shard should physically shard parameter storage for this param.""" 

173 return self.enable_fsdp_shard 

174 

175 @property 

176 def is_dtensor_compat_mode(self) -> bool: 

177 """Whether the parameter is managed through the DTensor compatibility path only.""" 

178 return self.param_mode == FullyShardParamMode.DTENSOR_COMPAT 

179 

180 def _get_base_spmd_placements(self) -> tuple: 

181 if self.param_mode == FullyShardParamMode.DTENSOR_UNIFIED and self._orig_param_is_dtensor: 

182 # DTENSOR_UNIFIED keeps the original distributed layout and prefixes 

183 # explicit DP/FSDP mesh dimensions ahead of it on the unified mesh. 

184 self._spmd_mesh = DeviceMesh.concatenate([self.mesh_info.mesh, self._orig_dtensor_mesh]) 

185 dp_prefix_placements = tuple(Replicate() for _ in range(self.mesh_info.mesh.ndim)) 

186 return dp_prefix_placements + tuple(self._orig_dtensor_placements) 

187 

188 if self.is_dtensor_compat_mode and self._orig_param_is_dtensor: 

189 self._spmd_mesh = self._orig_dtensor_mesh 

190 return tuple(self._orig_dtensor_placements) 

191 

192 self._spmd_mesh = self.mesh_info.mesh 

193 return tuple(Replicate() for _ in range(self._spmd_mesh.ndim)) 

194 

195 def _apply_data_parallel_placements(self, placements: list, shard_placement: Shard) -> tuple: 

196 if len(placements) != self._spmd_mesh.ndim: 

197 raise AssertionError( 

198 f"Expected {self._spmd_mesh.ndim} unified placements, got {len(placements)}: {placements}" 

199 ) 

200 if ( 

201 isinstance(self.mesh_info, DDPMeshInfo) 

202 and self._spmd_replicate_mesh_dim is not None 

203 and not self._orig_param_is_dtensor 

204 ): 

205 placements[self._spmd_replicate_mesh_dim] = Replicate() 

206 if ( 

207 self.uses_param_shard 

208 and isinstance(self.mesh_info, FSDPMeshInfo) 

209 and self._spmd_shard_mesh_dim is not None 

210 ): 

211 # If TP/EP already shards the same tensor dimension, fully_shard must 

212 # use StridedShard so the unified placement preserves the intended 

213 # shard order on the concatenated mesh. 

214 split_factor = 1 

215 for mesh_idx, placement in enumerate(placements): 

216 if mesh_idx == self._spmd_shard_mesh_dim: 

217 continue 

218 if placement.is_shard(shard_placement.dim): 

219 split_factor *= self._spmd_mesh.mesh_shape[mesh_idx] 

220 placements[self._spmd_shard_mesh_dim] = ( 

221 StridedShard(shard_placement.dim, split_factor=split_factor) 

222 if split_factor > 1 

223 else shard_placement 

224 ) 

225 return tuple(placements) 

226 

227 def _init_group_infos(self) -> None: 

228 if self.uses_param_shard and self.is_sharded and isinstance(self.mesh_info, FSDPMeshInfo): 

229 self.sharded_group_info = _build_group_info_from_process_group( 

230 "fully_shard_sharded_group", 

231 self.mesh_info.shard_process_group, 

232 self.mesh_info.shard_mesh_size, 

233 ) 

234 else: 

235 self.sharded_group_info = GroupInfo("fully_shard_sharded_group_invalid", None, 1) 

236 

237 # The all-reduce group is always derived from the final materialized layout. 

238 # This keeps replicate_params, DTensor compat, and unified multi-dim layouts 

239 # on a single source of truth. 

240 self.unsharded_group_info = self._build_layout_driven_group_info() 

241 

242 self.shard_size = self.sharded_group_info.rank_size 

243 self.dp_size = self.unsharded_group_info.rank_size 

244 self.rank_size = max(1, self.shard_size * self.dp_size) 

245 

    def _build_layout_driven_group_info(self):
        """Build the group used to reduce gradients across replicated mesh axes.

        The group spans every SPMD mesh axis whose final placement is
        ``Replicate()``, excluding the fully_shard shard axis when parameter
        storage is sharded. Resolution is attempted in order:

        1. Exactly one qualifying axis on a named mesh: reuse
           ``self._spmd_mesh.get_group(axis_name)``.
        2. Any named mesh: split a group from the per-axis rank lists via
           ``platform.split_group``.
        3. Fallback (unnamed mesh, a None group, or one of the caught errors):
           build the group from ``get_rank_list_for_axes`` through the
           rank-list cache.

        Returns:
            GroupInfo: A placeholder (no group, size 1) when no axis qualifies,
            otherwise the resolved group metadata.
        """
        group_axes = [
            axis
            for axis, placement in enumerate(self._spmd_placements)
            if placement.is_replicate()
        ]
        if self.uses_param_shard and self._spmd_shard_mesh_dim is not None:
            group_axes = [axis for axis in group_axes if axis != self._spmd_shard_mesh_dim]
        if not group_axes:
            return GroupInfo("fully_shard_unsharded_group_invalid", None, 1)
        group_dim_names = getattr(self._spmd_mesh, "mesh_dim_names", None)
        if group_dim_names:
            try:
                mesh_axis_names = tuple(group_dim_names[axis] for axis in group_axes)
                if len(mesh_axis_names) == 1:
                    # Single axis: the mesh may already own a group for it.
                    axis_name = mesh_axis_names[0]
                    process_group = self._spmd_mesh.get_group(axis_name)
                    if process_group is not None:
                        rank_size = self._spmd_mesh.mesh_shape[group_dim_names.index(axis_name)]
                        return _build_group_info_from_process_group(
                            "fully_shard_unsharded_group",
                            process_group,
                            rank_size,
                        )

                # Multi-axis (or no reusable single-axis group): group size is the
                # product of the participating mesh dims.
                # NOTE(review): unlike the rank-list fallback, this path does not
                # consult _GROUP_INFO_CACHE — confirm split_group dedupes groups.
                split_rank_lists = get_split_rank_lists_for_axes(self._spmd_mesh, group_axes)
                process_group = platform.split_group(split_ranks=split_rank_lists)
                if process_group is not None:
                    rank_size = 1
                    for axis in group_axes:
                        rank_size *= self._spmd_mesh.mesh_shape[axis]
                    return _build_group_info_from_process_group(
                        "fully_shard_unsharded_group",
                        process_group,
                        rank_size,
                    )
            except (
                AssertionError,
                AttributeError,
                KeyError,
                RuntimeError,
                TypeError,
                ValueError,
            ):
                # Fall back to the explicit rank-list path for mocked meshes in UT
                # or when a mesh implementation cannot materialize a reusable group.
                pass

        rank_list = get_rank_list_for_axes(self._spmd_mesh, group_axes)
        return _build_group_info_from_rank_list("fully_shard_unsharded_group", rank_list)

296 

297 def _to_local_unsharded_grad(self, grad): 

298 """Normalize a pending gradient to a local tensor expected by fully_shard collectives.""" 

299 if not isinstance(grad, DTensor): 

300 return grad 

301 

302 if any(placement.is_partial() for placement in grad.placements): 

303 grad = grad.reduce_partial() 

304 

305 if ( 

306 self._orig_dtensor_mesh is not None 

307 and grad.device_mesh.to_hash() != self._orig_dtensor_mesh.to_hash() 

308 ) or ( 

309 self._orig_dtensor_placements is not None 

310 and tuple(grad.placements) != tuple(self._orig_dtensor_placements) 

311 ): 

312 grad = grad.redistribute(self._orig_dtensor_mesh, self._orig_dtensor_placements) 

313 return grad.to_local() 

314 

315 def reduce_scatter_output(self): 

316 """ 

317 Get the reduce-scatter output tensor and wait for asynchronous operation to complete. 

318 

319 Returns: 

320 torch.Tensor: The sharded gradient tensor after reduce-scatter operation. 

321 """ 

322 if self.reduce_scatter_handle is not None: 

323 self.reduce_scatter_handle.wait() 

324 self.reduce_scatter_handle = None 

325 return self._reduce_scatter_output 

326 

327 def clear_reduce_scatter_output(self): 

328 """Clear the reduce-scatter output tensor to free memory.""" 

329 self._reduce_scatter_output = None 

330 

331 def all_reduce_output(self): 

332 """ 

333 Get the all-reduce output tensor and wait for asynchronous operation to complete. 

334 

335 Returns: 

336 torch.Tensor: The reduced gradient tensor after all-reduce operation. 

337 """ 

338 if self.all_reduce_handle is not None: 

339 self.all_reduce_handle.wait() 

340 self.all_reduce_handle = None 

341 return self._all_reduce_output 

342 

343 def clear_all_reduce_output(self): 

344 """Clear the all-reduce output tensor to free memory.""" 

345 self._all_reduce_output = None 

346 

347 def apply_reduced_grad(self, reduced_grad, param_type): 

348 """ 

349 Apply reduced gradient to the sharded parameter. 

350 

351 Reshapes ``reduced_grad`` to match the local shard, optionally 

352 offloads to CPU, then accumulates or assigns onto 

353 ``hsdp_param.sharded_param.grad``. 

354 

355 Args: 

356 reduced_grad (torch.Tensor): Gradient after reduce-scatter 

357 and/or all-reduce. 

358 param_type (Optional[torch.dtype]): Target dtype for the gradient (if conversion is needed). 

359 """ 

360 sharded_grad = None 

361 if not self.mp_policy.apply_grad_on_fp32_main_grad: 

362 sharded_grad = self.sharded_param.grad 

363 else: 

364 if not hasattr(self.sharded_param, "main_grad"): 

365 self.sharded_param.main_grad = None 

366 sharded_grad = self.sharded_param.main_grad 

367 sharded_param_local_shape = ( 

368 self.sharded_param.local_shape 

369 if isinstance(self.sharded_param, DTensor) 

370 else self.sharded_param.shape 

371 ) 

372 reduced_grad = reduced_grad.view(sharded_param_local_shape) 

373 if (not self.mp_policy.apply_grad_on_fp32_main_grad and param_type is not None 

374 and reduced_grad.dtype != param_type): 

375 reduced_grad = reduced_grad.to(param_type) 

376 to_accumulate_grad = sharded_grad is not None 

377 need_synchronize = False 

378 if self.offload_to_cpu: 

379 non_blocking = self.pin_memory and not to_accumulate_grad 

380 reduced_grad = reduced_grad.to( 

381 torch.device("cpu"), non_blocking=non_blocking 

382 ) 

383 need_synchronize = True 

384 if sharded_grad is None: 

385 if not self.mp_policy.apply_grad_on_fp32_main_grad: 

386 self.sharded_param.grad = self.to_sharded_dtensor(reduced_grad) 

387 else: 

388 self.sharded_param.main_grad = self.to_sharded_dtensor(reduced_grad) 

389 self.sharded_param.grad = None 

390 else: 

391 with SkipDTensorDispatch(): 

392 if not self.mp_policy.apply_grad_on_fp32_main_grad: 

393 self.sharded_param.grad._local_tensor += reduced_grad 

394 else: 

395 self.sharded_param.main_grad._local_tensor += reduced_grad 

396 self.sharded_param.grad = None 

397 if self.unsharded_accumulated_grad_data is not None: 

398 self.unsharded_accumulated_grad = None 

399 elif self.unsharded_param.grad is not None: 

400 self.unsharded_param.grad = None 

401 return need_synchronize 

402 

403 @torch.no_grad() 

404 def _init_sharded_param( 

405 self, 

406 param: nn.Parameter, 

407 shard_placement_fn: Optional[Callable], 

408 ) -> None: 

409 if param.device != self.device and param.device.type != "meta": 

410 raise AssertionError( 

411 f"Expects the parameter to already be moved to device {self.device} but got {param.device}" 

412 ) 

413 

414 hsdp_placement = shard_placement_fn(param) if shard_placement_fn else None 

415 if hsdp_placement is None: 

416 hsdp_placement = Shard(0) 

417 elif hsdp_placement.dim < 0: 

418 # if dim is negative, add the number of dimensions of the parameter 

419 hsdp_placement = Shard(hsdp_placement.dim + param.ndim) 

420 

421 if not isinstance(hsdp_placement, Shard): 

422 raise AssertionError( 

423 f"Expected Shard, got {type(hsdp_placement)}: {hsdp_placement}" 

424 ) 

425 

426 self.hsdp_placement = hsdp_placement 

427 base_placements = list(self._get_base_spmd_placements()) 

428 self._spmd_placements = self._apply_data_parallel_placements(base_placements, hsdp_placement) 

429 param_data = param.to_local() if self._orig_param_is_dtensor else param 

430 

431 shard_dim = hsdp_placement.dim 

432 self._orig_size = param_data.size() 

433 self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size) 

434 

435 if self.uses_param_shard and isinstance(self.mesh_info, FSDPMeshInfo): 

436 shard_rank = self.mesh_info.shard_mesh_rank 

437 shard_world_size = self.mesh_info.shard_mesh_size 

438 else: 

439 shard_rank = 0 

440 shard_world_size = 1 

441 

442 if isinstance(param_data, DTensor) and isinstance(self.mesh_info, DDPMeshInfo): 

443 param_data.data = param_data.full_tensor() 

444 

445 self.is_sharded = bool(self.uses_param_shard and shard_world_size > 1) 

446 

447 if param_data.size(shard_dim) % shard_world_size != 0: 

448 raise NotImplementedError( 

449 f"Uneven sharding on dim {shard_dim} not supported: " 

450 f"shape={param_data.shape}, world_size={shard_world_size}" 

451 ) 

452 chunks = torch.chunk(param_data, shard_world_size, dim=shard_dim) 

453 sharded_param = chunks[shard_rank].clone().contiguous() 

454 self.sharded_size = sharded_param.size() 

455 self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size) 

456 if self.offload_to_cpu and not sharded_param.is_meta: 

457 sharded_param = sharded_param.cpu() 

458 if self.pin_memory: 

459 sharded_param = sharded_param.pin_memory() 

460 self._sharded_param_data = sharded_param.view(-1) 

461 

462 self._sharding_spec = Layout.from_device_mesh(self._spmd_mesh) 

463 self._sharding_spec.set_placements(self._spmd_placements) 

464 self._sharding_spec.placement_to_tensor_map(param.ndim) 

465 

466 self.sharded_param = nn.Parameter(DTensor.from_local(sharded_param, self._spmd_mesh, self._spmd_placements)) 

467 self.sharded_param.requires_grad_(param.requires_grad) 

468 self._setattr_on_modules(self.sharded_param) 

469 # after init, self.sharded_param replaces original param, gradients must accumulate to this Parameter's grad 

470 self.sharded_param._hsdp_param_initialized = True 

471 self.sharded_state = ShardedState.SHARDED 

472 self.param_dtype = None 

473 

474 def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy): 

475 """Initialize param_dtype and reduce_dtype from the mixed precision policy.""" 

476 param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype) 

477 self.orig_dtype = self.sharded_param.dtype 

478 if reduce_dtype == param_dtype: 

479 reduce_dtype = None 

480 if param_dtype == self.orig_dtype: 

481 param_dtype = None 

482 self.param_dtype = param_dtype 

483 self.reduce_dtype = reduce_dtype 

484 

485 def init_all_gather_outputs( 

486 self, 

487 all_gather_input_numels: list[int], 

488 all_gather_input_dtypes: list[torch.dtype], 

489 world_size: int, 

490 device: torch.device, 

491 force_recreate: bool = False, 

492 ): 

493 """ 

494 Allocate output buffers for all-gather communication. 

495 

496 Args: 

497 all_gather_input_numels: Number of elements per input shard. 

498 all_gather_input_dtypes: Dtype of each input shard. 

499 world_size: Number of ranks in the shard process group. 

500 device: Device on which to allocate the output buffers. 

501 force_recreate: If True, always recreate buffers even if already initialized. 

502 """ 

503 if not force_recreate and len(self.all_gather_outputs) > 0: 

504 return # already initialized 

505 self.all_gather_outputs = [ 

506 torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device) 

507 for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes) 

508 ] 

509 

510 def init_unsharded_param(self): 

511 """ 

512 Initialize unsharded parameter from all-gather outputs. 

513 

514 This reconstructs the full parameter after all-gather by unpacking the 

515 gathered flat buffer back to the original tensor layout. 

516 """ 

517 unsharded_param = self._get_unsharded_param_from_all_gather_output() 

518 # Always refresh the unsharded Parameter from the latest all-gather output. 

519 # Non-dim0 unpack currently materializes a contiguous tensor copy, so 

520 # keeping stale .data would otherwise reuse old weights after optimizer.step() 

521 # mutates only the sharded local shard. Preserve the Parameter object identity 

522 # so autograd-facing module state stays stable across unshard cycles. 

523 if hasattr(self, "_unsharded_param"): 

524 # pylint: disable=access-member-before-definition 

525 self._unsharded_param.data = unsharded_param 

526 self._unsharded_param.requires_grad_(self.sharded_param.requires_grad) 

527 self._unsharded_param.grad = None 

528 return 

529 self._unsharded_param = nn.Parameter( 

530 unsharded_param, 

531 requires_grad=self.sharded_param.requires_grad, 

532 ) 

533 

534 def _get_unsharded_param_from_all_gather_output(self) -> torch.Tensor: 

535 """Reconstruct the full local parameter view from the packed all-gather output.""" 

536 if len(self.all_gather_outputs) != 1: 

537 raise AssertionError( 

538 f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}" 

539 ) 

540 unsharded_tensor = self.all_gather_outputs[0] 

541 plan = build_rs_plan( 

542 self, 

543 self._sharded_local_tensor, 

544 self.shard_world_size if self.is_sharded else 1, 

545 ) 

546 unsharded_param = unpack_from_all_gather(unsharded_tensor, plan) 

547 if self._orig_param_is_dtensor: 

548 # Rebuild the original DTensor view after all-gather so gradient 

549 # consumers keep seeing the source DTensor layout. 

550 unsharded_param = DTensor.from_local( 

551 unsharded_param, 

552 self._orig_dtensor_mesh, 

553 self._orig_dtensor_placements, 

554 ) 

555 return unsharded_param 

556 

557 def to_sharded(self) -> None: 

558 if not self.uses_param_shard and self._unsharded_param is not None: 

559 # Replicate params keep the same local shape across shard/unshard, 

560 # so persist forward-time state updates before switching objects. 

561 src = self._unsharded_param.to_local() if isinstance(self._unsharded_param, DTensor) \ 

562 else self._unsharded_param 

563 dst = self.sharded_param.to_local() if isinstance(self.sharded_param, DTensor) else self.sharded_param 

564 _copy_without_bumping_version(dst, src) 

565 self._setattr_on_modules(self.sharded_param) 

566 self.free_unsharded_param() 

567 self.sharded_state = ShardedState.SHARDED 

568 

569 def to_unsharded(self) -> None: 

570 set_requires_grad_if_needed(self.sharded_param, self._unsharded_param) 

571 self._setattr_on_modules(self._unsharded_param) 

572 self.sharded_state = ShardedState.UNSHARDED 

573 

574 def _setattr_on_modules(self, param: nn.Parameter) -> None: 

575 """Set parameter on module and shared modules, preserving pointer consistency.""" 

576 if getattr(self._module_info.module.__setattr__, "__func__", None) is nn.Module.__setattr__: 

577 # fast path 

578 self._module_info.module._parameters[self._module_info.param_name] = param 

579 else: 

580 # slow path 

581 setattr(self._module_info.module, self._module_info.param_name, param) 

582 

583 # Iterate through all modules that share this parameter to prevent pointer desync. 

584 for shared_module, shared_param_name in zip( 

585 self._module_info.shared_modules, self._module_info.shared_param_names 

586 ): 

587 if getattr(shared_module.__setattr__, "__func__", None) is nn.Module.__setattr__: 

588 shared_module._parameters[shared_param_name] = param 

589 else: 

590 setattr(shared_module, shared_param_name, param) 

591 

592 def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor: 

593 """ 

594 Converts a local tensor representing either the sharded parameter or 

595 sharded gradient to DTensor. 

596 """ 

597 return DTensor.from_local( 

598 tensor, 

599 self._sharding_spec.mesh, 

600 self._sharding_spec.placements 

601 ) 

602 

603 def to_accumulated_grad_if_needed(self) -> None: 

604 if self._unsharded_param.grad is None: 

605 return 

606 # Keep local gradients alive across no-sync / delayed-sync steps even 

607 # after the parameter transitions back to the sharded view. 

608 unsharded_grad = self._unsharded_param.grad 

609 self._unsharded_param.grad = None 

610 if self.reduce_dtype is not None and unsharded_grad.dtype != self.reduce_dtype: 

611 unsharded_grad = unsharded_grad.to(self.reduce_dtype) 

612 if self.unsharded_accumulated_grad is None: 

613 self.unsharded_accumulated_grad = unsharded_grad 

614 else: 

615 self.unsharded_accumulated_grad += unsharded_grad 

616 

617 def accumulate_unsharded_grad_if_needed(self) -> None: 

618 if ( 

619 self.unsharded_accumulated_grad is not None 

620 and self.unsharded_param.grad is not None 

621 ): 

622 grad = self.unsharded_param.grad 

623 if self.reduce_dtype is not None and grad.dtype != self.reduce_dtype: 

624 grad = grad.to(self.reduce_dtype) 

625 self.unsharded_accumulated_grad += grad 

626 self.unsharded_param.grad = None 

627 

628 def alloc_all_gather_outputs(self) -> None: 

629 """Resize all-gather output buffers to their full capacity for communication.""" 

630 for tensor in self.all_gather_outputs: 

631 expected_size = tensor.numel() * tensor.itemsize 

632 storage = tensor.untyped_storage() 

633 if storage.size() != expected_size: 

634 storage.resize_(expected_size) 

635 

636 def free_unsharded_param(self) -> None: 

637 """Release storage of all-gather outputs to free device memory.""" 

638 for tensor in self.all_gather_outputs: 

639 storage = tensor.untyped_storage() 

640 if storage.size() != 0: 

641 storage.resize_(0) 

642 

643 @property 

644 def all_gather_inputs(self) -> list[torch.Tensor]: 

645 """Return the local sharded tensor to use as input for all-gather, applying dtype cast if needed.""" 

646 self._assert_in_states(ShardedState.SHARDED) 

647 sharded_param_data = self._sharded_param_data 

648 if self.offload_to_cpu: 

649 sharded_param_data = sharded_param_data.to( 

650 self.device, non_blocking=True 

651 ) 

652 if self.param_dtype is not None and self.param_dtype != sharded_param_data.dtype: 

653 return [sharded_param_data.to(self.param_dtype)] 

654 return [sharded_param_data] 

655 

656 @property 

657 def unsharded_param(self) -> nn.Parameter: 

658 """Return the full unsharded parameter after all-gather.""" 

659 return self._unsharded_param 

660 

661 @property 

662 def unsharded_grad_data(self) -> torch.Tensor: 

663 """ 

664 Get the unsharded gradient data as a local tensor. 

665 """ 

666 grad = self.unsharded_param.grad 

667 if grad is None: 

668 raise AssertionError("Expects unsharded_param.grad to not be None") 

669 return self._to_local_unsharded_grad(grad) 

670 

671 @property 

672 def unsharded_accumulated_grad_data(self) -> torch.Tensor: 

673 """ 

674 Get the unsharded accumulated gradient data as a local tensor. 

675 """ 

676 grad = self.unsharded_accumulated_grad 

677 return self._to_local_unsharded_grad(grad) 

678 

679 @property 

680 def _sharded_local_tensor(self) -> torch.Tensor: 

681 """Return the underlying local tensor of the sharded DTensor parameter.""" 

682 return cast(DTensor, self.sharded_param)._local_tensor 

683 

684 @property 

685 def shard_world_size(self) -> int: 

686 """Get the world size for shard dimension.""" 

687 return self.shard_size 

688 

689 @property 

690 def replicate_world_size(self) -> int: 

691 """Get the world size for replicate dimension (HSDP only).""" 

692 return self.dp_size 

693 

694 def _assert_in_states(self, *states: ShardedState) -> None: 

695 """Assert current state is one of expected states.""" 

696 if self.sharded_state not in states: 

697 raise AssertionError( 

698 f"Expected sharded_state in {states}, got {self.sharded_state}" 

699 ) 

700 

    def reset_sharded_param(self) -> None:
        """Reset sharded param after load_state_dict.

        Registered in ``__init__`` as a ``load_state_dict`` post-hook. Steps:

        1. Re-read the parameter from the owning module. If it is a different
           DTensor, adopt it as ``sharded_param``.
        2. Unless the local shard is on the meta device, check its size.
        3. Pin it if needed.
        4. Re-point ``_sharded_param_data`` and
           ``sharded_param._local_tensor`` at it when its storage changed.
        5. Refresh ``_sharding_spec``.

        Raises:
            AssertionError: If swap-on-conversion did not preserve the parameter
                object, if a new local shard has an unexpected size, if
                ``sharded_param`` is not a DTensor, or if the rebuilt local
                tensor is not contiguous.
        """
        module_info = self._module_info
        new_param = getattr(module_info.module, module_info.param_name)
        if new_param is not self.sharded_param:
            # Ensure object identity is preserved after parameter conversion.
            if torch.__future__.get_swap_module_params_on_conversion():
                raise AssertionError(
                    f"Expects swap_tensors to preserve object but got {new_param} "
                    f"instead of {self.sharded_param}"
                )
            if isinstance(new_param, DTensor):
                self.sharded_param = new_param
                if not getattr(self.sharded_param, "_hsdp_param_initialized", None):
                    # reset _hsdp_param_initialized flag.
                    self.sharded_param._hsdp_param_initialized = True
            elif isinstance(new_param, torch.Tensor):
                # if new_param is Tensor, don't change 'self.sharded_param' ref
                # just update self.sharded_param._local_tensor and self.sharded_param_data.
                pass

        local_tensor = new_param._local_tensor if isinstance(new_param, DTensor) else new_param
        # Nothing to re-sync while parameters are still on the meta device.
        if local_tensor.is_meta:
            return
        updated_local_tensor = False
        # local_tensor can be padded twice
        # 1st time in fully_shard(model)
        # 2nd time in model(input) lazy_init
        # 2nd time should be no-op if parameters remain unchanged
        # 2nd time shouldn't be no-op if people call model.load_state_dict(...) before lazy_init
        # this makes it possible for trainer to call `sd = model.state_dict()` before the training loop
        # and use `sd` without calling .state_dict() per iteration
        same_local_tensor = False
        if isinstance(self._sharded_param_data, torch.Tensor):
            same_local_tensor = (
                # when sharding param with shape (1, ...) over 2 ranks
                # local_tensor on rank 1 can be size 0, data_ptr() can be 0
                self._sharded_param_data.untyped_storage().data_ptr() > 0
                and self._sharded_param_data.untyped_storage().data_ptr()
                == local_tensor.untyped_storage().data_ptr()
            )
        sharded_size = self.sharded_size
        shard_dim = self.hsdp_placement.dim
        # Full extent along the shard dim, so the narrow() below is a same-shape view.
        length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
        if not same_local_tensor:
            if local_tensor.size() != sharded_size:
                raise AssertionError(
                    f"Expected sharded_size to be {sharded_size}, got {local_tensor.size()}"
                )
            updated_local_tensor = True
        if self.pin_memory and not local_tensor.is_pinned():
            local_tensor = local_tensor.cpu().pin_memory()
            updated_local_tensor = True
        if not same_local_tensor:
            self._sharded_param_data = local_tensor.view(-1)
        if not isinstance(self.sharded_param, DTensor):
            raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}")
        if updated_local_tensor:
            # Only change the local tensor object if needed
            self.sharded_param._local_tensor = local_tensor.narrow(
                dim=shard_dim, start=0, length=length
            )
            if not self.sharded_param._local_tensor.is_contiguous():
                raise AssertionError(
                    "Expected sharded_param._local_tensor to be contiguous"
                )
        self._sharding_spec = cast(DTensor, self.sharded_param).layout

768 

def _get_unsharded_param_data(
    self, async_op: bool = False
) -> Tuple[torch.Tensor, Optional[dist.Work]]:
    """
    Gather this parameter's shards into the all-gather output buffer.

    The output buffer is initialized for ``shard_world_size`` ranks when the
    parameter is sharded, and for a single rank otherwise. When the parameter
    is not sharded, has no shard process group, or ``shard_world_size <= 1``,
    the input is copied into the buffer locally and no collective is issued.

    Args:
        async_op: Whether to launch ``dist.all_gather_into_tensor`` asynchronously.

    Returns:
        (unsharded_param, handle): ``self.all_gather_outputs[0]`` and the value
        returned by ``dist.all_gather_into_tensor``, or ``None`` when no
        communication took place.
    """
    is_sharded = self.is_sharded
    gather_input = self.all_gather_inputs[0]

    # Parameters below the sharding threshold are not sharded: size the
    # output buffer for a single rank in that case.
    self.init_all_gather_outputs(
        all_gather_input_numels=[gather_input.numel()],
        all_gather_input_dtypes=[gather_input.dtype],
        world_size=self.shard_world_size if is_sharded else 1,
        device=self.device,
    )
    self.alloc_all_gather_outputs()
    gather_output = self.all_gather_outputs[0]

    # Only a sharded parameter has a shard group worth communicating over.
    shard_group = self.sharded_group_info.group if is_sharded else None
    if shard_group is None or self.shard_world_size <= 1:
        # Nothing to exchange across ranks: fill the output buffer locally.
        _copy_without_bumping_version(gather_output, gather_input)
        return gather_output, None

    work = dist.all_gather_into_tensor(
        gather_output,
        gather_input,
        group=shard_group,
        async_op=async_op,
    )
    return gather_output, work

818 

def unshard(self, async_op: bool = False) -> None:
    """
    Launch the all-gather for this parameter unless one is already pending.

    If ``self.prefetch_handle`` is already set (e.g. the gather was triggered
    earlier by ``HSDPState.prefetch()``), this is a no-op. Otherwise the
    handle returned by ``_get_unsharded_param_data`` is stored in
    ``self.prefetch_handle`` for ``wait_for_unshard`` to consume.

    NOTE(review): a gather that needed no handle (synchronous or local copy)
    leaves ``prefetch_handle`` as ``None``, so a later call gathers again.

    Args:
        async_op: Forwarded to ``_get_unsharded_param_data``.
    """
    if self.prefetch_handle is None:
        _, self.prefetch_handle = self._get_unsharded_param_data(async_op=async_op)

826 

def wait_for_unshard(self) -> None:
    """
    Finish a pending all-gather and switch the parameter to its unsharded view.

    Checks the parameter is in ``ShardedState.SHARDED`` via
    ``_assert_in_states``. If ``self.prefetch_handle`` holds a communication
    handle, it is waited on and then cleared. Afterwards
    ``init_unsharded_param`` and ``to_unsharded`` are called.
    """
    self._assert_in_states(ShardedState.SHARDED)

    pending = self.prefetch_handle
    if pending is not None:
        # Wait before clearing, so a failed wait leaves the handle in place.
        pending.wait()
        self.prefetch_handle = None

    self.init_unsharded_param()
    self.to_unsharded()

836 

def shard(self) -> None:
    """
    Transition the parameter from the unsharded state back to the sharded state.

    Checks the parameter is in ``ShardedState.UNSHARDED`` via
    ``_assert_in_states``, then delegates to ``to_sharded``.
    """
    required_state = ShardedState.UNSHARDED
    self._assert_in_states(required_state)
    self.to_sharded()

843 

def reduce_scatter_grad(
    self,
    async_op: bool = True,
    dtype: Optional[torch.dtype] = None,
    reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG,
) -> Union[None, Tuple[torch.Tensor, Optional[dist.Work]]]:
    """
    Reduce-scatter the full (unsharded) gradient into this rank's shard.

    The source buffer is ``unsharded_accumulated_grad_data`` when
    ``unsharded_accumulated_grad`` is set, otherwise ``unsharded_grad_data``.
    It is cast to the reduce dtype, laid out with ``build_rs_plan`` /
    ``pack_for_reduce_scatter``, and flattened before communication.

    Args:
        async_op: Whether to launch the collective asynchronously.
        dtype: Dtype used for the reduction; defaults to the gradient's dtype.
        reduce_op: Reduction applied across the shard group. ``None`` is
            treated as ``dist.ReduceOp.AVG`` (the default).

    Returns:
        (sharded_grad, handle). When no communication is needed (parameter not
        sharded, no shard group, or ``shard_world_size <= 1``), this is the
        packed, flattened gradient and ``None``. Otherwise it is a freshly
        allocated 1-D buffer of ``numel // shard_world_size`` elements (also
        stored in ``self._reduce_scatter_output``) and the value returned by
        ``dist.reduce_scatter_tensor`` (also stored in
        ``self.reduce_scatter_handle``).
    """
    self._assert_in_states(ShardedState.UNSHARDED)

    # The signature advertises Optional; the collective itself rejects None.
    if reduce_op is None:
        reduce_op = dist.ReduceOp.AVG

    # Prefer the accumulated gradient buffer when one exists.
    if self.unsharded_accumulated_grad is not None:
        grad = self.unsharded_accumulated_grad_data
    else:
        grad = self.unsharded_grad_data
    reduce_dtype = dtype if dtype is not None else grad.dtype
    grad = grad.to(reduce_dtype)

    # A real collective is only issued for sharded params with a multi-rank group;
    # the pack plan must use the same world size the collective will use.
    needs_comm = (
        self.is_sharded
        and self.sharded_group_info.group is not None
        and self.shard_world_size > 1
    )
    plan_world_size = self.shard_world_size if needs_comm else 1
    plan = build_rs_plan(self, grad, plan_world_size)
    grad_flat = pack_for_reduce_scatter(grad, plan).reshape(-1)

    if not needs_comm:
        # Below the sharding threshold, or a single-rank / missing group.
        return grad_flat, None

    output_numel = grad_flat.numel() // self.shard_world_size
    self._reduce_scatter_output = torch.empty(
        output_numel, dtype=reduce_dtype, device=grad.device
    )
    self.reduce_scatter_handle = dist.reduce_scatter_tensor(
        self._reduce_scatter_output,
        grad_flat,
        op=reduce_op,
        group=self.sharded_group_info.group,
        async_op=async_op,
    )
    return self._reduce_scatter_output, self.reduce_scatter_handle

901 

def all_reduce_grad(
    self,
    grad: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    async_op: bool = True,
    reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG,
) -> Union[None, Tuple[torch.Tensor, Optional[dist.Work]]]:
    """
    All-reduce a gradient across the replicate group (HSDP mode).

    Args:
        grad: Gradient to reduce. When ``None``,
            ``unsharded_accumulated_grad_data`` is used if
            ``unsharded_accumulated_grad`` is set, otherwise
            ``unsharded_grad_data``.
        dtype: If given and different from ``grad.dtype``, ``grad`` is cast
            to it (producing a new tensor) before the reduction.
        async_op: Whether to launch the collective asynchronously.
        reduce_op: Reduction to apply. ``None`` is treated as
            ``dist.ReduceOp.AVG`` (the default).

    Returns:
        (reduced_grad, handle). ``reduced_grad`` is the (possibly cast) tensor
        passed to ``dist.all_reduce``, which reduces it in place. ``handle`` is
        ``None`` when there is no replicate group or
        ``replicate_world_size <= 1``. Otherwise it is the value returned by
        ``dist.all_reduce`` (also stored in ``self.all_reduce_handle``).
    """
    # The signature advertises Optional; the collective itself rejects None.
    if reduce_op is None:
        reduce_op = dist.ReduceOp.AVG

    if grad is None:
        # Prefer the accumulated gradient buffer when one exists.
        if self.unsharded_accumulated_grad is not None:
            grad = self.unsharded_accumulated_grad_data
        else:
            grad = self.unsharded_grad_data

    if dtype is not None and dtype != grad.dtype:
        grad = grad.to(dtype)

    if self.unsharded_group_info.group is None or self.replicate_world_size <= 1:
        return grad, None

    # In place: when no cast happened, this mutates the caller-visible buffer.
    self.all_reduce_handle = dist.all_reduce(
        grad,
        op=reduce_op,
        group=self.unsharded_group_info.group,
        async_op=async_op,
    )
    self._all_reduce_output = grad
    return grad, self.all_reduce_handle

938 

939 

def set_requires_grad_if_needed(
    src_tensor: torch.Tensor, dst_tensor: torch.Tensor
) -> None:
    """Copy ``src_tensor.requires_grad`` onto ``dst_tensor``, touching it only on mismatch."""
    wanted = src_tensor.requires_grad
    if dst_tensor.requires_grad == wanted:
        # Already matching: leave dst_tensor untouched.
        return
    dst_tensor.requires_grad_(wanted)