Coverage for hyper_parallel / platform / torch / fully_shard / param.py: 65%

340 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/fsdp/_fully_shard/_fsdp_param.py 

16# enhanced with fully_shard parameter management 

17# ============================================================================ 

18"""HSDP parameter""" 

19from typing import List, Callable, Optional, cast, Sequence, Tuple, Any 

20from dataclasses import dataclass, field 

21import itertools 

22import torch 

23import torch.nn as nn 

24import torch.distributed as dist 

25from torch._prims_common import make_contiguous_strides_for 

26from hyper_parallel.platform.torch.fully_shard.utils import ( 

27 MixedPrecisionPolicy, 

28 CPUOffloadPolicy, 

29 OffloadPolicy, 

30 FSDPMeshInfo, 

31 DDPMeshInfo, 

32 HSDPMeshInfo, 

33) 

34from hyper_parallel.core.dtensor import DTensor 

35from hyper_parallel.core.layout import Layout 

36from hyper_parallel.core.fully_shard.hsdp_param import HSDPParamV2 

37from hyper_parallel.core.fully_shard.hsdp_utils import ShardedState 

38from hyper_parallel.core.placement_types import Shard, Replicate 

39from hyper_parallel.core.fully_shard.hsdp_utils import ParamModuleInfo, ExtensionsData 

40 

41 

42class TorchHSDPParamV2(HSDPParamV2): 

43 """ 

44 Torch HSDP parameter. 

45 """ 

46 

    def __init__(
        self,
        param: nn.Parameter,
        module_info: ParamModuleInfo,
        mesh_info: FSDPMeshInfo,
        post_forward_mesh_info: Optional[FSDPMeshInfo] = None,
        shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
        mp_policy: Optional[MixedPrecisionPolicy] = None,
        offload_policy: Optional[OffloadPolicy] = None,
        threshold: int = 0,
        device: Optional[torch.device] = None,
    ):
        """
        Set up sharded-parameter state and install the sharded param on the module.

        Args:
            param: Original parameter to manage.
            module_info: Owning module / attribute-name bookkeeping (including
                modules that share this parameter).
            mesh_info: FSDP/HSDP/DDP mesh description.
            post_forward_mesh_info: Optional mesh for post-forward resharding.
            shard_placement_fn: Optional override returning the Shard placement
                for ``param`` (defaults to ``Shard(0)`` when it returns None).
            mp_policy: Mixed-precision policy (param/reduce dtypes).
            offload_policy: Offload policy; a ``CPUOffloadPolicy`` enables
                CPU offload (and optionally pinned memory).
            threshold: Byte threshold below which the param is kept unsharded.
            device: Device the parameter is expected to already live on.
        """
        self._module_info: ParamModuleInfo = module_info
        self.mesh_info = mesh_info
        self.post_forward_mesh_info = post_forward_mesh_info
        self.mp_policy = mp_policy
        self.threshold = threshold
        self.device = device
        self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
        self.pin_memory = (
            self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
        )
        self.grad_offload_event: Optional[torch.Event] = None
        self._init_sharded_param(param, shard_placement_fn)
        if self.post_forward_mesh_info:
            self._init_sharded_post_forward_param_metadata(param)
        self._init_extensions()
        self.all_gather_outputs: List[torch.Tensor] = []
        self.unsharded_accumulated_grad = None
        self._param_fqn: Optional[str] = None
        # Communication attributes for prefetch pattern
        self.prefetch_handle: Optional[dist.Work] = None
        # Re-derive sharded-param bookkeeping whenever a state_dict is loaded.
        self._post_load_hook_handle = (
            module_info.module.register_load_state_dict_post_hook(
                lambda *args, **kwargs: self.reset_sharded_param()
            )
        )

84 

85 @torch.no_grad() 

86 def _init_sharded_param( 

87 self, 

88 param: nn.Parameter, 

89 shard_placement_fn: Optional[Callable], 

90 ) -> None: 

91 if param.device != self.device and param.device.type != "meta": 

92 raise AssertionError( 

93 f"Expects the parameter to already be moved to device {self.device} but got {param.device}" 

94 ) 

95 

96 hsdp_placement = shard_placement_fn(param) if shard_placement_fn else None 

97 if hsdp_placement is None: 

98 hsdp_placement = Shard(0) 

99 elif hsdp_placement.dim < 0: 

100 # if dim is negative, add the number of dimensions of the parameter 

101 hsdp_placement = Shard(hsdp_placement.dim + param.ndim) 

102 

103 if not isinstance(hsdp_placement, Shard): 

104 raise AssertionError( 

105 f"Expected Shard, got {type(hsdp_placement)}: {hsdp_placement}" 

106 ) 

107 

108 self.hsdp_placement = hsdp_placement 

109 shard_dim = hsdp_placement.dim 

110 

111 # Non-DTensor parameters have no pre-defined SPMD semantics. 

112 # FSDP/DDP solely determines the mesh and placements. 

113 self._spmd_mesh = self.mesh_info.mesh 

114 if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP 

115 self._spmd_placements = (Replicate(), hsdp_placement) 

116 elif isinstance(self.mesh_info, FSDPMeshInfo): # FSDP 

117 self._spmd_placements = (hsdp_placement,) 

118 elif isinstance(self.mesh_info, DDPMeshInfo): # DDP 

119 self._spmd_placements = (Replicate(),) 

120 param_data = param 

121 

122 shard_dim = hsdp_placement.dim 

123 self._orig_size = param_data.size() 

124 self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size) 

125 

126 if isinstance(self.mesh_info, FSDPMeshInfo): # FSDP or HSDP 

127 shard_rank = self.mesh_info.shard_mesh_rank 

128 shard_world_size = self.mesh_info.shard_mesh_size 

129 else: # DDP 

130 shard_rank = 0 

131 shard_world_size = 1 

132 

133 # Check if parameter size is below threshold, if so skip sharding 

134 param_size = param_data.numel() * param_data.element_size() 

135 if self.threshold > 0 and param_size < self.threshold: 

136 # Parameter too small, do not shard 

137 self.is_sharded = False 

138 self.sharded_size = param_data.size() 

139 self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size) 

140 self._sharded_param_data = param_data.view(-1) 

141 

142 self._sharding_spec = Layout.from_device_mesh(self._spmd_mesh) 

143 # For unsharded params, use Replicate placement 

144 if isinstance(self.mesh_info, HSDPMeshInfo): 

145 self._spmd_placements = (Replicate(), Replicate()) 

146 else: 

147 self._spmd_placements = (Replicate(),) 

148 self._sharding_spec.set_placements(self._spmd_placements) 

149 self._sharding_spec.placement_to_tensor_map(param.ndim) 

150 

151 self.sharded_param = nn.Parameter(DTensor.from_local(param_data, self._spmd_mesh, self._spmd_placements)) 

152 self.sharded_param.requires_grad_(param.requires_grad) 

153 self._setattr_on_modules(self.sharded_param) 

154 self.sharded_state = ShardedState.SHARDED 

155 return 

156 

157 self.is_sharded = True 

158 

159 if param_data.size(shard_dim) % shard_world_size != 0: 

160 raise NotImplementedError( 

161 f"Uneven sharding on dim {shard_dim} not supported: " 

162 f"shape={param_data.shape}, world_size={shard_world_size}" 

163 ) 

164 chunks = torch.chunk(param_data, shard_world_size, dim=shard_dim) 

165 sharded_param = chunks[shard_rank].clone().contiguous() 

166 self.sharded_size = sharded_param.size() 

167 self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size) 

168 if self.offload_to_cpu and not sharded_param.is_meta: 

169 sharded_param = sharded_param.cpu() 

170 if self.pin_memory: 

171 sharded_param = sharded_param.pin_memory() 

172 self._sharded_param_data = sharded_param.view(-1) 

173 

174 self._sharding_spec = Layout.from_device_mesh(self._spmd_mesh) 

175 self._sharding_spec.set_placements(self._spmd_placements) 

176 self._sharding_spec.placement_to_tensor_map(param.ndim) 

177 

178 self.sharded_param = nn.Parameter(DTensor.from_local(sharded_param, self._spmd_mesh, self._spmd_placements)) 

179 self.sharded_param.requires_grad_(param.requires_grad) 

180 self._setattr_on_modules(self.sharded_param) 

181 # 初始化后,self.sharded_param替换掉原先的param,后续梯度也需要注意要累加到这个Parameter的grad上 

182 self.sharded_param._hsdp_param_initialized = True 

183 self.sharded_state = ShardedState.SHARDED 

184 self.param_dtype = None 

185 

186 def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None: 

187 mesh_info = self.post_forward_mesh_info 

188 param_data = param._local_tensor if isinstance(param, DTensor) else param 

189 if isinstance(mesh_info, FSDPMeshInfo): 

190 chunks = torch.chunk(param_data, mesh_info.shard_mesh_size, dim=0) 

191 self.sharded_post_forward_size = chunks[mesh_info.shard_mesh_rank].size() 

192 else: # DDP 

193 chunks = torch.chunk(param_data, 1, dim=0) 

194 self.sharded_post_forward_size = chunks[0].size() 

195 

196 self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for( 

197 self.sharded_post_forward_size 

198 ) 

199 

200 def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy): 

201 param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype) 

202 self.orig_dtype = self.sharded_param.dtype 

203 if reduce_dtype == param_dtype: 

204 reduce_dtype = None 

205 if param_dtype == self.orig_dtype: 

206 param_dtype = None 

207 self.param_dtype = param_dtype 

208 self.reduce_dtype = reduce_dtype 

209 

210 def _init_extensions(self) -> None: 

211 inner_tensor = self._sharded_local_tensor 

212 has_fsdp_pre_all_gather = hasattr(inner_tensor, "fsdp_pre_all_gather") 

213 has_fsdp_post_all_gather = hasattr(inner_tensor, "fsdp_post_all_gather") 

214 if has_fsdp_pre_all_gather != has_fsdp_post_all_gather: 

215 raise AssertionError( 

216 "Both fsdp_pre_all_gather and fsdp_post_all_gather should be defined " 

217 f"if using all-gather extensions: {inner_tensor}" 

218 ) 

219 if has_fsdp_pre_all_gather: 

220 self._extensions_data = ExtensionsData() 

221 self._unsharded_inner_tensors: list[torch.Tensor] = [] 

222 

223 def init_all_gather_outputs( 

224 self, 

225 all_gather_input_numels: list[int], 

226 all_gather_input_dtypes: list[torch.dtype], 

227 world_size: int, 

228 device: torch.device, 

229 force_recreate: bool = False, 

230 ): 

231 if not force_recreate and len(self.all_gather_outputs) > 0: 

232 return # already initialized 

233 self.all_gather_outputs = [ 

234 torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device) 

235 for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes) 

236 ] 

237 

238 def init_unsharded_param(self): 

239 """ 

240 Initialize unsharded parameter from all-gather outputs. 

241 

242 This reconstructs the full parameter after all-gather by using 

243 the gathered data and reshaping it to the original size. 

244 """ 

245 if hasattr(self, "_unsharded_param"): 

246 return 

247 

248 # Get unsharded data from all-gather outputs 

249 if len(self.all_gather_outputs) != 1: 

250 raise AssertionError( 

251 f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}" 

252 ) 

253 unsharded_tensor = self.all_gather_outputs[0] 

254 # Use reshape to safely handle both contiguous and non-contiguous memory layouts. 

255 # It acts as a zero-copy view if possible, otherwise it performs a copy. 

256 # unsharded_param = unsharded_tensor.reshape(self._orig_size) 

257 unsharded_param = torch.as_strided( 

258 unsharded_tensor, 

259 self._orig_size, 

260 self._contiguous_orig_stride, 

261 storage_offset=0, 

262 ) 

263 

264 self._unsharded_param = nn.Parameter( 

265 unsharded_param, requires_grad=self.sharded_param.requires_grad 

266 ) 

267 

    def to_sharded(self) -> None:
        """Reinstall the sharded parameter on the module(s) and free gathered buffers."""
        self._setattr_on_modules(self.sharded_param)
        self.free_unsharded_param()
        self.sharded_state = ShardedState.SHARDED

272 

    def to_sharded_post_forward(self) -> None:
        """
        Reshard to the post-forward mesh after forward completes.

        Carves this rank's slice out of the flat all-gather output, clones it
        (so the gathered buffer can be freed), reshapes it to the recorded
        post-forward geometry, and installs it on the module(s) as a
        post-forward-sharded DTensor parameter.

        Raises:
            AssertionError: if not currently UNSHARDED, or the gathered size
                does not divide evenly across the post-forward shard group.
        """
        if self.sharded_state != ShardedState.UNSHARDED:
            raise AssertionError(f"Expected sharded_state to be UNSHARDED, got {self.sharded_state}")
        shard_world_size = self.post_forward_mesh_info.shard_mesh_size
        numel = self.all_gather_outputs[0].numel()
        if numel % shard_world_size != 0:
            raise AssertionError(
                f"All-gather output size ({numel}) must be divisible by the shard "
                f"world size ({shard_world_size}). Check padding/mesh alignment."
            )
        shard_rank = self.post_forward_mesh_info.shard_mesh_rank
        sharded_numel = numel // shard_world_size
        # clone to be able to free all-gather output
        self._sharded_post_forward_param_data = (
            self.all_gather_outputs[0].narrow(
                0, sharded_numel * shard_rank, sharded_numel
            )
        ).clone()
        # Reinterpret the flat slice with the recorded post-forward size/stride.
        sharded_post_forward_tensor = torch.as_strided(
            self._sharded_post_forward_param_data,
            size=self.sharded_post_forward_size,
            stride=self.contiguous_sharded_post_forward_stride,
            storage_offset=0,
        )
        self._sharded_post_forward_param = nn.Parameter(
            self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor)
        )
        self._setattr_on_modules(self._sharded_post_forward_param)
        self.free_unsharded_param()
        self.sharded_state = ShardedState.SHARDED_POST_FORWARD

306 

    def to_unsharded(self) -> None:
        """Expose the unsharded parameter on the owning module(s)."""
        # Keep requires_grad in sync in case it changed on the sharded param.
        set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)
        self._setattr_on_modules(self._unsharded_param)
        if self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
            # Drop the post-forward shard; the unsharded data supersedes it.
            self._sharded_post_forward_param = None
            self._sharded_post_forward_param_data = None
        self.sharded_state = ShardedState.UNSHARDED

314 

315 def _setattr_on_modules(self, param: nn.Parameter) -> None: 

316 if getattr(self._module_info.module.__setattr__, "__func__", None) is nn.Module.__setattr__: 

317 # fast path 

318 self._module_info.module._parameters[self._module_info.param_name] = param 

319 else: 

320 # slow path 

321 setattr(self._module_info.module, self._module_info.param_name, param) 

322 

323 # Iterate through all modules that share this parameter to prevent pointer desync. 

324 for shared_module, shared_param_name in zip( 

325 self._module_info.shared_modules, self._module_info.shared_param_names 

326 ): 

327 if getattr(shared_module.__setattr__, "__func__", None) is nn.Module.__setattr__: 

328 shared_module._parameters[shared_param_name] = param 

329 else: 

330 setattr(shared_module, shared_param_name, param) 

331 

332 def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor: 

333 """ 

334 Converts a local tensor representing either the sharded parameter or 

335 sharded gradient to DTensor. 

336 """ 

337 return DTensor.from_local( 

338 tensor, 

339 self._sharding_spec.mesh, 

340 self._sharding_spec.placements 

341 ) 

342 

    def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor:
        """
        Converts a local tensor to DTensor with post-forward sharding layout.
        """
        # NOTE(review): placements are hardcoded to (Replicate(), Shard(0)),
        # which assumes a 2-D (HSDP-style) post-forward mesh — confirm this
        # holds when post_forward_mesh_info describes a 1-D FSDP mesh.
        post_forward_layout = Layout.from_device_mesh(self.post_forward_mesh_info.mesh)
        post_forward_layout.set_placements((Replicate(), Shard(0)))
        post_forward_layout.placement_to_tensor_map(tensor.ndim)
        return DTensor.from_local(tensor, post_forward_layout.mesh, post_forward_layout.placements)

351 

352 def to_accumulated_grad_if_needed(self) -> None: 

353 if ( 

354 self._unsharded_param.grad is not None 

355 and self.reduce_dtype is not None 

356 and self._unsharded_param.grad.dtype != self.reduce_dtype 

357 ): 

358 # need to handle the gradient even after the parameter is resharded 

359 unsharded_grad = self._unsharded_param.grad 

360 self._unsharded_param.grad = None 

361 self.unsharded_accumulated_grad = unsharded_grad.to(self.reduce_dtype) 

362 

363 def accumulate_unsharded_grad_if_needed(self) -> None: 

364 if ( 

365 self.unsharded_accumulated_grad is not None 

366 and self.unsharded_param.grad is not None 

367 ): 

368 # need to handle the gradient 

369 self.unsharded_accumulated_grad += self.unsharded_param.grad 

370 self.unsharded_param.grad = None 

371 

372 def alloc_all_gather_outputs(self) -> None: 

373 for tensor in self.all_gather_outputs: 

374 expected_size = tensor.numel() * tensor.itemsize 

375 storage = tensor.untyped_storage() 

376 if storage.size() != expected_size: 

377 storage.resize_(expected_size) 

378 

379 def free_unsharded_param(self) -> None: 

380 for tensor in itertools.chain( 

381 self.all_gather_outputs, self._unsharded_inner_tensors 

382 ): 

383 storage = tensor.untyped_storage() 

384 if storage.size() != 0: 

385 storage.resize_(0) 

386 

387 @property 

388 def all_gather_inputs(self) -> list[torch.Tensor]: 

389 self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD) 

390 if self.sharded_state == ShardedState.SHARDED: 

391 sharded_param_data = self._sharded_param_data 

392 if self.offload_to_cpu: 

393 sharded_param_data = sharded_param_data.to( 

394 self.device, non_blocking=True 

395 ) 

396 if self.param_dtype is not None and self.param_dtype != sharded_param_data.dtype: 

397 return [sharded_param_data.to(self.param_dtype)] 

398 else: 

399 return [sharded_param_data] 

400 elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD: 

401 if self.param_dtype is not None and self.param_dtype != self._sharded_post_forward_param_data.dtype: 

402 return [self._sharded_post_forward_param_data.to(self.param_dtype)] 

403 else: 

404 return [self._sharded_post_forward_param_data] 

405 return [torch.empty(0)] 

406 

    @property
    def unsharded_param(self) -> nn.Parameter:  # ND
        """The full (unsharded) parameter; set by ``init_unsharded_param``."""
        return self._unsharded_param

410 

    @property
    def unsharded_grad_data(self) -> torch.Tensor:
        """
        Get the unsharded gradient data as a local tensor.

        Raises:
            AssertionError: if the grad is missing or is (unexpectedly) a DTensor.
        """
        grad = self.unsharded_param.grad
        if grad is None:
            raise AssertionError("Expects unsharded_param.grad to not be None")
        if isinstance(grad, DTensor):
            raise AssertionError("Expected torch.Tensor, got DTensor")
        return grad

422 

    @property
    def unsharded_accumulated_grad_data(self) -> torch.Tensor:
        """
        Get the unsharded accumulated gradient data as a local tensor.

        NOTE(review): unlike ``unsharded_grad_data``, the None/DTensor guards
        here are commented out, so this may return ``None`` despite the
        annotation. Current callers check for None before calling — confirm
        intent before re-enabling the guards.
        """
        grad = self.unsharded_accumulated_grad
        # if grad is None:
        #     raise AssertionError("Expects unsharded_accumulated_grad to not be None")
        # if isinstance(grad, DTensor):
        #     raise AssertionError("Expected torch.Tensor, got DTensor")
        return grad

434 

    @property
    def _sharded_local_tensor(self) -> torch.Tensor:
        """The local shard backing the sharded DTensor parameter."""
        return cast(DTensor, self.sharded_param)._local_tensor

438 

439 @property 

440 def shard_world_size(self) -> int: 

441 """Get the world size for shard dimension.""" 

442 if isinstance(self.mesh_info, FSDPMeshInfo): 

443 return self.mesh_info.shard_mesh_size 

444 return 1 

445 

446 @property 

447 def replicate_world_size(self) -> int: 

448 """Get the world size for replicate dimension (HSDP only).""" 

449 if isinstance(self.mesh_info, HSDPMeshInfo): 

450 return self.mesh_info.replicate_mesh_size 

451 return 1 

452 

453 def _assert_in_states(self, *states: ShardedState) -> None: 

454 """Assert current state is one of expected states.""" 

455 if self.sharded_state not in states: 

456 raise AssertionError( 

457 f"Expected sharded_state in {states}, got {self.sharded_state}" 

458 ) 

459 

    def reset_sharded_param(self) -> None:
        """
        Re-sync internal bookkeeping after ``load_state_dict``.

        Picks up the (possibly replaced) parameter from the module, validates
        its size against the expected shard size, re-pins CPU memory if
        needed, and refreshes ``_sharded_param_data`` and the sharding spec.
        """
        module_info = self._module_info
        new_param = getattr(module_info.module, module_info.param_name)
        if new_param is not self.sharded_param:
            # Ensure object identity is preserved after parameter conversion.
            if torch.__future__.get_swap_module_params_on_conversion():
                raise AssertionError(
                    f"Expects swap_tensors to preserve object but got {new_param} "
                    f"instead of {self.sharded_param}"
                )
            self.sharded_param = new_param

        local_tensor = new_param._local_tensor
        if local_tensor.is_meta:
            # Meta tensors carry no data to validate or repoint.
            return
        updated_local_tensor = False
        # local_tensor can be padded twice
        # 1st time in fully_shard(model)
        # 2nd time in model(input) lazy_init
        # 2nd time should be no-op if parameters remain unchanged
        # 2nd time shouldn't be no-op if people call model.load_state_dict(...) before lazy_init
        # this makes it possible for trainer to call `sd = model.state_dict()` before the training loop
        # and use `sd` without calling .state_dict() per iteration
        same_local_tensor = False
        # TODO: need to support tensor subclass
        if type(self._sharded_param_data) is torch.Tensor:
            same_local_tensor = (
                # when sharding param with shape (1, ...) over 2 ranks
                # local_tensor on rank 1 can be size 0, data_ptr() can be 0
                self._sharded_param_data.untyped_storage().data_ptr() > 0
                and self._sharded_param_data.untyped_storage().data_ptr()
                == local_tensor.untyped_storage().data_ptr()
            )
        sharded_size = self.sharded_size
        shard_dim = self.hsdp_placement.dim
        # Empty local shards (possible on some ranks) report length 0.
        length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
        if local_tensor.size() != sharded_size and not same_local_tensor:
            raise AssertionError(
                f"Expected sharded_size to be {sharded_size}, got {local_tensor.size()}"
            )
        if self.pin_memory and not local_tensor.is_pinned():
            # Loaded tensors may lose pinning; restore it for fast H2D copies.
            local_tensor = local_tensor.cpu().pin_memory()
            updated_local_tensor = True
        if not same_local_tensor:
            self._sharded_param_data = local_tensor.view(-1)
        if not isinstance(self.sharded_param, DTensor):
            raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}")
        if updated_local_tensor:
            # Only change the local tensor object if needed
            self.sharded_param._local_tensor = local_tensor.narrow(
                dim=shard_dim, start=0, length=length
            )
            if not self.sharded_param._local_tensor.is_contiguous():
                raise AssertionError(
                    "Expected sharded_param._local_tensor to be contiguous"
                )
        self._sharding_spec = cast(DTensor, self.sharded_param).layout

518 

    def _get_unsharded_param_data(self, async_op: bool = False) -> Tuple[torch.Tensor, Optional[dist.Work]]:
        """
        Perform all-gather to get unsharded parameter data.

        Args:
            async_op: Whether to execute asynchronously.

        Returns:
            (unsharded_param, handle): Unsharded parameter data and
            communication handle (``None`` when no collective was issued).
        """
        # If parameter is not sharded (below threshold), no communication needed
        if not self.is_sharded:
            self.init_all_gather_outputs(
                all_gather_input_numels=[self._sharded_param_data.numel()],
                all_gather_input_dtypes=[self._sharded_param_data.dtype],
                world_size=1,
                device=self.device,
            )
            self.alloc_all_gather_outputs()
            self.all_gather_outputs[0].copy_(self._sharded_param_data)
            return self.all_gather_outputs[0], None

        # Get input data (already moved/cast by the all_gather_inputs property)
        all_gather_input = self.all_gather_inputs[0]

        # Initialize output buffer
        self.init_all_gather_outputs(
            all_gather_input_numels=[all_gather_input.numel()],
            all_gather_input_dtypes=[all_gather_input.dtype],
            world_size=self.shard_world_size,
            device=self.device,
        )
        self.alloc_all_gather_outputs()

        # Get communication group
        shard_group = self.mesh_info.shard_process_group if isinstance(self.mesh_info, FSDPMeshInfo) else None

        if shard_group is None or self.shard_world_size <= 1:
            # No communication needed, just copy
            self.all_gather_outputs[0].copy_(all_gather_input)
            return self.all_gather_outputs[0], None

        # Execute all_gather_into_tensor
        handle = dist.all_gather_into_tensor(
            self.all_gather_outputs[0],
            all_gather_input,
            group=shard_group,
            async_op=async_op,
        )

        return self.all_gather_outputs[0], handle

570 

571 def unshard(self, async_op: bool = False) -> None: 

572 if self.prefetch_handle is not None: 

573 # 已经被prefetch 触发过了,直接return 

574 return # no-op 

575 

576 _, handle = self._get_unsharded_param_data(async_op=async_op) 

577 self.prefetch_handle = handle 

578 

579 def wait_for_unshard(self) -> None: 

580 self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD) 

581 

582 if self.prefetch_handle is not None: 

583 self.prefetch_handle.wait() 

584 self.prefetch_handle = None 

585 

586 self.init_unsharded_param() 

587 self.to_unsharded() 

588 

    def shard(self) -> None:
        """
        Transition parameter from unsharded back to sharded state.

        Raises:
            AssertionError: if the current state is not UNSHARDED.
        """
        self._assert_in_states(ShardedState.UNSHARDED)
        self.to_sharded()

595 

596 def reduce_scatter_grad( 

597 self, 

598 async_op: bool = False, 

599 dtype: Optional[torch.dtype] = None, 

600 reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG 

601 ) -> Tuple[torch.Tensor, Optional[dist.Work]]: 

602 """ 

603 Perform reduce-scatter on gradient to reduce and shard the full gradient. 

604 

605 Args: 

606 async_op: Whether to execute asynchronously. 

607 dtype: reduce dtype. 

608 reduce_op: do reduce-scatter avg or sum. 

609 

610 Returns: 

611 (sharded_grad, handle): Sharded gradient and communication handle. 

612 """ 

613 self._assert_in_states(ShardedState.UNSHARDED) 

614 

615 # Choose gradient source based on use_accumulated_grad flag 

616 if self.unsharded_accumulated_grad is not None: 

617 grad = self.unsharded_accumulated_grad_data 

618 else: 

619 grad = self.unsharded_grad_data 

620 reduce_dtype = dtype or grad.dtype 

621 grad = grad.to(reduce_dtype) 

622 grad_flat = grad.view(-1) 

623 

624 # If parameter is not sharded (below threshold), no reduce-scatter needed 

625 if not self.is_sharded: 

626 return grad_flat, None 

627 

628 # Get communication group 

629 shard_group = self.mesh_info.shard_process_group if isinstance(self.mesh_info, FSDPMeshInfo) else None 

630 

631 if shard_group is None or self.shard_world_size <= 1: 

632 # No communication needed 

633 return grad_flat, None 

634 

635 # Calculate output size 

636 output_numel = grad_flat.numel() // self.shard_world_size 

637 output = torch.empty(output_numel, dtype=reduce_dtype, device=grad.device) 

638 

639 # Execute reduce_scatter_tensor 

640 handle = dist.reduce_scatter_tensor( 

641 output, 

642 grad_flat, 

643 op=reduce_op, 

644 group=shard_group, 

645 async_op=async_op, 

646 ) 

647 

648 return output, handle 

649 

    def all_reduce_grad(
        self,
        grad: Optional[torch.Tensor] = None,
        async_op: bool = False,
        reduce_op: Optional[dist.ReduceOp] = dist.ReduceOp.AVG
    ) -> Tuple[torch.Tensor, Optional[dist.Work]]:
        """
        Perform all-reduce on gradient (across replicate dimension in HSDP mode).

        Args:
            grad: Gradient tensor to reduce. If None, falls back to the
                accumulated unsharded gradient when present, otherwise to
                ``unsharded_param.grad``.
            async_op: Whether to execute asynchronously.
            reduce_op: Reduction op (defaults to averaging).

        Returns:
            (reduced_grad, handle): Reduced gradient and communication handle
            (``None`` when no communication was needed).
        """
        # If grad is not provided, get from parameter
        if grad is None:
            if self.unsharded_accumulated_grad is not None:
                grad = self.unsharded_accumulated_grad_data
            else:
                grad = self.unsharded_grad_data

        if not isinstance(self.mesh_info, HSDPMeshInfo):
            # Not HSDP mode, no all-reduce needed
            return grad, None

        replicate_group = self.mesh_info.replicate_process_group
        if replicate_group is None or self.replicate_world_size <= 1:
            return grad, None

        handle = dist.all_reduce(
            grad,
            op=reduce_op,
            group=replicate_group,
            async_op=async_op
        )
        return grad, handle

690 

691 

def set_requires_grad_if_needed(
    src_tensor: torch.Tensor, dst_tensor: torch.Tensor
) -> None:
    """Copy ``requires_grad`` from src to dst, touching dst only on mismatch."""
    wants_grad = src_tensor.requires_grad
    if dst_tensor.requires_grad != wants_grad:
        dst_tensor.requires_grad_(wants_grad)