# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Torch HSDP cell state."""
# pylint: disable=protected-access

from typing import Optional

import torch

from hyper_parallel.core.fully_shard.hsdp_state import HSDPState
from hyper_parallel.core.fully_shard.hsdp_utils import (
    FullyShardParamMode,
    _get_param_module_infos,
    infer_fully_shard_param_mode,
)
from hyper_parallel.core.fully_shard.utils import CPUOffloadPolicy
from hyper_parallel.platform.torch.fully_shard.param import TorchHSDPParamV2
from hyper_parallel.platform.torch.fully_shard.pack_utils import build_rs_plan
from hyper_parallel.platform.torch.fully_shard.param_group import get_comm_ctx, HSDPParamGroup

def _to_dtype_if_needed(
    tensor: torch.Tensor, dtype: Optional[torch.dtype]
) -> torch.Tensor:
    """Cast ``tensor`` to ``dtype`` if it differs from the current dtype.

    Args:
        tensor: The input tensor to potentially cast.
        dtype: Target dtype. If None or equal to ``tensor.dtype``, this is a no-op.
    """
    if dtype is not None and tensor.dtype != dtype:
        return tensor.to(dtype)
    return tensor
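# Hedged usage sketch for the helper above (plain torch; ``t`` is hypothetical):
#   t = torch.ones(2, dtype=torch.float32)
#   _to_dtype_if_needed(t, None) is t             # True: None means no-op
#   _to_dtype_if_needed(t, torch.float32) is t    # True: dtype already matches
#   _to_dtype_if_needed(t, torch.bfloat16).dtype  # torch.bfloat16: a real cast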

class TorchHSDPStateV2(HSDPState):
    """Torch HSDP cell state."""

    # DTensor compat parameters in pure-TP mode can accumulate gradients
    # directly on ``sharded_param.grad`` without ever materializing an
    # ``_unsharded_param``. Track their async all-reduce work separately from
    # the standard unsharded-grad queues.
    pre_direct_all_reduce_grads = []

    @staticmethod
    def _get_pending_unsharded_grad(hsdp_param):
        """Return the pending unsharded gradient tensor for all-reduce-based paths."""
        if hsdp_param.unsharded_accumulated_grad is not None:
            return hsdp_param.unsharded_accumulated_grad_data
        return hsdp_param.unsharded_grad_data

    @staticmethod
    def _has_pending_unsharded_grad(hsdp_param):
        """Whether the parameter currently has a gradient waiting for reduction."""
        if hsdp_param.unsharded_accumulated_grad is not None:
            return True
        if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
            return False
        return hsdp_param.unsharded_param.grad is not None

    @staticmethod
    def _get_local_sharded_grad(hsdp_param):
        """Return the local gradient tensor currently stored on ``sharded_param``."""
        grad = hsdp_param.sharded_param.grad
        if grad is None:
            return None
        # DTensor grads expose ``to_local()``; plain tensors are returned as-is.
        to_local = getattr(grad, "to_local", None)
        if callable(to_local):
            return to_local()
        return grad
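    # Hedged shape sketch (hypothetical sizes): for an [8, 4] parameter sharded
    # over 4 ranks on dim 0, a DTensor grad ``g`` yields
    #   g.to_local().shape == torch.Size([2, 4])
    # on each rank, while a plain-tensor grad passes through unchanged.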

    def __init__(self, cell, mesh_info, config, platform, device):
        """
        Initialize TorchHSDPStateV2.

        Args:
            cell (nn.Module): The module whose parameters are managed by this state.
            mesh_info: Mesh topology for shard/replicate dimensions.
            config (HSDPConfigV2): HSDP configuration.
            platform (TorchPlatform): Torch platform abstraction.
            device (torch.device): Target device.
        """
        super().__init__(cell, mesh_info, config, platform, device)
        self.comm_fusion = config.comm_fusion
        self.device = device
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        # Do reduce-scatter/all-reduce for gradients
        self.reduce_grads = True
        # Reshard parameters after backward
        self.reshard_after_backward = True
        # Requires all-reduce for gradients when running HSDP
        self.requires_all_reduce = True
        # Default reduce op is decided at the fully_shard-state level:
        # if any managed parameter is DTensor-backed, use SUM; otherwise AVG.
        self._user_reduce_op_type = None
        self.reduce_op_type = self._resolve_default_reduce_op()
        self._reset_sharded_params = False
        self._init_param_group()

    @staticmethod
    def _comm_fusion_unsupported_reason(hsdp_param) -> Optional[str]:
        """Return the reason why ``hsdp_param`` cannot participate in comm_fusion.

        Returns None when the parameter is supported.
        """
        if not hsdp_param.enable_fsdp_shard:
            return "non-sharded parameters such as replicate_params are not supported"
        if hsdp_param.param_mode not in (
            FullyShardParamMode.LOCAL_PARAM,
            FullyShardParamMode.DTENSOR_UNIFIED,
        ):
            return f"param_mode {hsdp_param.param_mode} is not supported"
        local_shard = getattr(hsdp_param, "_sharded_local_tensor", None)
        if local_shard is None:
            return "missing local shard tensor for comm_fusion plan validation"
        plan_world_size = getattr(hsdp_param, "shard_world_size", None)
        if plan_world_size is None:
            plan_world_size = getattr(hsdp_param, "shard_size", 1)
        try:
            build_rs_plan(hsdp_param, local_shard, plan_world_size)
        except NotImplementedError as exc:
            return str(exc)
        except (AssertionError, ValueError) as exc:
            return f"cannot build comm_fusion pack plan: {exc}"
        return None

    def _init_param_group(self):
        """Initialize the fused parameter group for communication fusion.

        When ``comm_fusion`` is enabled, creates an ``HSDPParamGroup`` that packs all
        parameters into a single buffer for fused all-gather and reduce-scatter,
        replacing the per-parameter communication pattern.
        """
        if self.config.comm_fusion:
            for hsdp_param in self.hsdp_params:
                # Compute the rejection reason once; probing it rebuilds the pack plan.
                reason = self._comm_fusion_unsupported_reason(hsdp_param)
                if reason is None:
                    continue
                param_fqn = getattr(hsdp_param, "_param_fqn", "<unknown>")
                raise NotImplementedError(
                    f"comm_fusion does not support parameter {param_fqn}: {reason}."
                )
        self.param_group = None
        if self.hsdp_params:
            # pylint: disable=E1128
            self.param_group = HSDPParamGroup(
                self.hsdp_params,
                self.mesh_info,
                self.device,
                self.mp_policy,
                self.config.comm_fusion_zero_copy,
            )

    def _move_states_to_device(self):
        """Move parameters and buffers to the target device.

        Skips meta tensors and parameters already initialized by HSDP.
        """
        for mod in self.modules:
            for param in mod.parameters():
                if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
                    continue
                if param.device == self.device or param.device.type == "meta":
                    continue
                param.data = param.to(self.device)
            for buffer in mod.buffers():
                if buffer.device == self.device or buffer.device.type == "meta":
                    continue
                buffer.data = buffer.to(self.device)

    def _init_hsdp_params(self):
        """Init HSDP parameters and replicate parameters for the cell."""
        replicate_params = set(self.config.replicate_params or ())
        ignored_params = set(self.config.ignored_params or ())
        # All parameters in the module tree(s), deduplicated.
        visited_params = set()
        filtered_params = []
        for mod in self.modules:
            for _, param in mod.named_parameters():
                if param in ignored_params:
                    continue
                if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
                    continue
                if param in visited_params:
                    continue
                visited_params.add(param)
                filtered_params.append(param)

        module_infos = _get_param_module_infos(filtered_params, tuple(self.modules))
        for param, module_info in zip(filtered_params, module_infos):
            param_mode = infer_fully_shard_param_mode(self.config.mesh, [param])
            enable_fsdp_shard = param not in replicate_params
            hsdp_param = TorchHSDPParamV2(
                param,
                module_info,
                self.mesh_info,
                shard_placement_fn=self.config.shard_placement_fn,
                mp_policy=self.mp_policy,
                offload_policy=self.offload_policy,
                device=self.device,
                param_mode=param_mode,
                enable_fsdp_shard=enable_fsdp_shard,
            )
            if param in replicate_params:
                self.replicate_params.append(hsdp_param)
            else:
                self.hsdp_params.append(hsdp_param)
            if hsdp_param.is_sharded:
                self.sharded_hsdp_params.append(hsdp_param)

    def _init_mp_dtypes(self):
        """Init mixed-precision dtypes for HSDP parameters and replicate parameters."""
        for hsdp_param in self.hsdp_params:
            hsdp_param.init_dtype_attrs(self.mp_policy)
        for replicate_param in self.replicate_params:
            replicate_param.init_dtype_attrs(self.mp_policy)
        trainable_params: list[TorchHSDPParamV2] = [
            p for p in self._iter_managed_params() if p.sharded_param.requires_grad
        ]
        orig_dtypes = {p.orig_dtype for p in trainable_params}
        reduce_dtypes = {p.reduce_dtype for p in trainable_params}
        if len(trainable_params) > 0 and len(orig_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform original parameter dtype but got {orig_dtypes}"
            )
        self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
        if len(trainable_params) > 0 and len(reduce_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform reduce dtype but got {reduce_dtypes}"
            )
        self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None
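    # Illustrative arithmetic for the uniformity checks above: with trainable
    # params of dtypes {float32, float16}, the set has len 2 != 1 and the
    # AssertionError fires; a homogeneous set like {bfloat16} passes and its
    # single element becomes ``_orig_dtype`` / ``_reduce_dtype``.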

    def _validate_cpu_offload_params(self):
        """Validate that all parameters are on CPU when the CPU offload policy is enabled."""
        if not isinstance(self.offload_policy, CPUOffloadPolicy):
            return
        hsdp_params_not_on_cpu = [
            hsdp_param
            for hsdp_param in self._iter_managed_params()
            if hsdp_param.sharded_param.device.type != "cpu"
        ]
        if hsdp_params_not_on_cpu:
            raise RuntimeError(
                "HSDP parameters should be materialized on CPU when enabling CPU offloading. "
                'For example, load a CPU state dict or call module.to_empty(device="cpu"). '
                "Found the following parameters on a non-CPU device: "
                f"{[(p._param_fqn, p.sharded_param.device) for p in hsdp_params_not_on_cpu]}\n"
            )

    def lazy_init(self):
        """Lazily reset sharded params and validate state before first use."""
        if not self._reset_sharded_params:
            for hsdp_param in self.hsdp_params:
                if hsdp_param.is_sharded:
                    hsdp_param.reset_sharded_param()
            self._reset_sharded_params = True
        self._validate_no_meta_params()
        self._validate_cpu_offload_params()
        self._init_mp_dtypes()

    def _validate_no_meta_params(self):
        """Raise if any managed parameter is still on the meta device."""
        param_names_on_meta = [
            hsdp_param._param_fqn
            for hsdp_param in self._iter_managed_params()
            if hsdp_param.sharded_param.device.type == "meta"
        ]
        if param_names_on_meta:
            raise RuntimeError(
                "HSDP parameters should be materialized from the meta device before training, "
                f"but the following were still on the meta device: {param_names_on_meta}\n"
                "For example, call module.to_empty(device) to materialize to a device and "
                "call module.reset_parameters() on each module to initialize values."
            )

    def post_backward_for_comm_fusion(self):
        """Run the fused post-backward reduction pipeline."""
        # Replicate-only params still use the non-fused compat all-reduce path.
        # Drain any pending side-path reductions before advancing the fused
        # param-group pipeline for sharded params.
        self.reduce_params()
        # Fused gradient reduction path: first apply any pending async reduction
        # from the previous module's backward (pipelined overlap), then issue
        # this module's fused reduce-scatter (+ all-reduce for HSDP).
        comm_ctx = get_comm_ctx()
        # Phase 2: apply grads for the param group whose all_reduce is done.
        if comm_ctx.all_reduce_param_group is not None:
            comm_ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
            comm_ctx.all_reduce_param_group = None
        # Phase 1: wait on reduce_scatter, issue the async all_reduce for the previous layer.
        if comm_ctx.pre_param_group is not None:
            comm_ctx.pre_param_group.wait_reduce_scatter_and_issue_all_reduce()
            comm_ctx.pre_param_group = None
        if self.param_group is not None:
            self.param_group.foreach_reduce(
                reduce_scatter_reduce_op=self.reduce_op_type
            )
        for hsdp_param in self.replicate_params:
            if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
                continue
            if not hsdp_param.sharded_param.requires_grad:
                continue
            if not self._has_pending_unsharded_grad(hsdp_param):
                continue
            reduce_op = self._resolve_reduce_op(hsdp_param)
            self._queue_compat_all_reduce(hsdp_param, reduce_op)
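    # Pipelining sketch (illustrative): at each fused post-backward step,
    #   phase 2 applies grads for the group whose async all-reduce (issued two
    #           steps earlier) has completed,
    #   phase 1 waits on the previous group's reduce-scatter and launches its
    #           async all-reduce,
    #   and finally this module's own fused reduce-scatter is issued,
    # so collectives overlap with the next module's backward compute.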

    def _resolve_default_reduce_op(self):
        """Resolve the default reduce op for the whole fully_shard state."""
        for hsdp_param in self._iter_managed_params():
            if hsdp_param.param_mode in (
                FullyShardParamMode.DTENSOR_COMPAT,
                FullyShardParamMode.DTENSOR_UNIFIED,
            ):
                return torch.distributed.ReduceOp.SUM
        return torch.distributed.ReduceOp.AVG
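    # Semantics reminder (standard torch.distributed): with world size 4 and a
    # per-rank gradient of 1.0, ReduceOp.SUM leaves 4.0 on every rank while
    # ReduceOp.AVG leaves 1.0. DTensor-backed modes get SUM here; everything
    # else gets data-parallel averaging.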

    def _resolve_reduce_op(self, hsdp_param=None):
        """Resolve the gradient reduction op for the current fully_shard state.

        ``hsdp_param`` is accepted for interface symmetry; the op is currently
        resolved at the state level only.
        """
        if self._user_reduce_op_type is not None:
            return self._user_reduce_op_type
        return self.reduce_op_type

    def _should_run_all_reduce(self, hsdp_param) -> bool:
        """Whether the parameter should issue an all-reduce in this backward pass."""
        return self.requires_all_reduce and hsdp_param.dp_size > 1

    def _queue_reduce_scatter_then_all_reduce(self, hsdp_param, reduce_op):
        """Queue the standard FSDP/HSDP reduction path."""
        hsdp_param.reduce_scatter_grad(
            dtype=self._reduce_dtype,
            reduce_op=reduce_op,
        )
        HSDPState.pre_reduce_scatter_params.append((hsdp_param, self._orig_dtype))
        if not self._should_run_all_reduce(hsdp_param):
            return
        reduced_grad = hsdp_param.reduce_scatter_output()
        if (
            HSDPState.pre_reduce_scatter_params
            and HSDPState.pre_reduce_scatter_params[-1][0] == hsdp_param
        ):
            HSDPState.pre_reduce_scatter_params.pop()
        hsdp_param.all_reduce_grad(
            grad=reduced_grad,
            dtype=self._reduce_dtype,
            reduce_op=reduce_op,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype))

    def _queue_compat_all_reduce(self, hsdp_param, reduce_op):
        """Queue the compatibility all-reduce path without FSDP sharding."""
        if not self._should_run_all_reduce(hsdp_param):
            return
        hsdp_param.all_reduce_grad(
            grad=self._get_pending_unsharded_grad(hsdp_param),
            dtype=self._reduce_dtype,
            reduce_op=reduce_op,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype))

    def _can_direct_all_reduce_compat_grad(self, hsdp_param) -> bool:
        """Whether ``hsdp_param`` should reduce its existing ``sharded_param.grad`` directly."""
        return (
            hsdp_param.param_mode == FullyShardParamMode.DTENSOR_COMPAT
            and hsdp_param.enable_fsdp_shard
            and not hsdp_param.is_sharded
            and hsdp_param.shard_size == 1
            and hsdp_param.sharded_param.requires_grad
            and self._should_run_all_reduce(hsdp_param)
            and self._get_local_sharded_grad(hsdp_param) is not None
        )

    def _queue_direct_compat_all_reduce(self, hsdp_param, reduce_op):
        """Queue all-reduce for DTENSOR_COMPAT params whose grad stays on ``sharded_param``."""
        grad = self._get_local_sharded_grad(hsdp_param)
        if grad is None:
            return
        reduced_grad = _to_dtype_if_needed(grad, self._reduce_dtype)
        handle = None
        if hsdp_param.unsharded_group_info.group is not None and hsdp_param.dp_size > 1:
            handle = torch.distributed.all_reduce(
                reduced_grad,
                op=reduce_op,
                group=hsdp_param.unsharded_group_info.group,
                async_op=True,
            )
        TorchHSDPStateV2.pre_direct_all_reduce_grads.append((handle, reduced_grad, grad))
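    # Minimal async-collective sketch for the path above (standard
    # torch.distributed; ``g`` and ``group`` are hypothetical):
    #   work = torch.distributed.all_reduce(g, op=reduce_op, group=group, async_op=True)
    #   ...          # overlap with the remaining backward compute
    #   work.wait()  # done later in reduce_params() before the grad is consumed,
    #                # copying back to the original dtype if a cast copy was made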

    def post_backward(self, *unused):  # pylint: disable=unused-argument
        """Reduce gradients and reshard parameters after backward."""
        for hsdp_param in self._iter_managed_params():
            hsdp_param.accumulate_unsharded_grad_if_needed()
        if not self.reduce_grads:
            if self.reshard_after_backward:
                self.shard()
            for hsdp_param in self._iter_managed_params():
                hsdp_param.to_accumulated_grad_if_needed()
            return
        if not self.comm_fusion:
            self.reduce_params()
            for hsdp_param in self._iter_managed_params():
                if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
                    if self._can_direct_all_reduce_compat_grad(hsdp_param):
                        reduce_op = self._resolve_reduce_op(hsdp_param)
                        self._queue_direct_compat_all_reduce(hsdp_param, reduce_op)
                    continue
                # Frozen parameters produce no gradient, so there is nothing to reduce.
                if not hsdp_param.sharded_param.requires_grad:
                    continue
                if not self._has_pending_unsharded_grad(hsdp_param):
                    continue
                reduce_op = self._resolve_reduce_op(hsdp_param)
                if hsdp_param.shard_size > 1:
                    self._queue_reduce_scatter_then_all_reduce(hsdp_param, reduce_op)
                elif self._should_run_all_reduce(hsdp_param):
                    self._queue_compat_all_reduce(hsdp_param, reduce_op)
        else:
            self.post_backward_for_comm_fusion()
        if self.reshard_after_backward:
            self.shard()
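    # Path-selection sketch for the non-fused branch above (illustrative):
    #   no unsharded param, pure-TP DTENSOR_COMPAT -> direct async all-reduce on
    #                                                 sharded_param.grad
    #   shard_size > 1                             -> reduce-scatter, then
    #                                                 all-reduce across replicas
    #   shard_size == 1 and dp_size > 1            -> compat all-reduce
    #   frozen or no pending grad                  -> skipped entirely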

    def reduce_params(self):
        """Apply reduced gradients from pre-staged HSDP parameters to sharded parameters.

        This function drains the queues of pre-staged HSDP parameters
        (``pre_reduce_scatter_params``, ``pre_all_reduce_params``, and
        ``pre_direct_all_reduce_grads``), retrieves the reduced gradients from
        asynchronous reduce-scatter/all-reduce operations, clears cached
        communication outputs, and applies the reduced gradients to the
        corresponding sharded parameters (including reshaping, dtype conversion,
        optional CPU offloading, and gradient accumulation/assignment).

        Note:
            - Parameters are processed in FIFO order (via ``pop(0)``), so gradients
              are applied in the same order their reductions were issued.
            - After the reduced gradient is retrieved, the cached communication
              output (``reduce_scatter_output`` or ``all_reduce_output``) is
              cleared to free memory and avoid stale data.
            - Gradient application (``apply_reduced_grad``) covers:
              1. Reshaping the flat reduced gradient to match the local shard shape
              2. Optional dtype conversion to ``param_type``
              3. Optional CPU offloading (per the HSDP parameter's offload policy)
              4. Assigning or accumulating the gradient to ``sharded_param.grad``
        """
        need_synchronize = False
        while HSDPState.pre_reduce_scatter_params:
            pre_hsdp_param, pre_orig_dtype = HSDPState.pre_reduce_scatter_params.pop(0)
            reduced_grad = pre_hsdp_param.reduce_scatter_output()
            pre_hsdp_param.clear_reduce_scatter_output()
            need_synchronize = (
                pre_hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype)
                or need_synchronize
            )

        while HSDPState.pre_all_reduce_params:
            pre_hsdp_param, pre_orig_dtype = HSDPState.pre_all_reduce_params.pop(0)
            reduced_grad = pre_hsdp_param.all_reduce_output()
            pre_hsdp_param.clear_all_reduce_output()
            need_synchronize = (
                pre_hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype)
                or need_synchronize
            )

        while TorchHSDPStateV2.pre_direct_all_reduce_grads:
            handle, reduced_grad, target_grad = TorchHSDPStateV2.pre_direct_all_reduce_grads.pop(0)
            if handle is not None:
                handle.wait()
            if reduced_grad is not target_grad:
                if reduced_grad.dtype != target_grad.dtype:
                    reduced_grad = reduced_grad.to(target_grad.dtype)
                target_grad.copy_(reduced_grad)
        if need_synchronize:
            if self.device.type == "npu":
                torch.npu.current_stream().synchronize()
            elif self.device.type == "cuda":
                torch.cuda.current_stream().synchronize()
            else:
                raise NotImplementedError(
                    f"Unsupported device type {self.device.type} for synchronization after CPU offload."
                )
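    # Note on the accumulation idiom above (Python semantics, illustrative):
    #   need_synchronize = apply_reduced_grad(...) or need_synchronize
    # keeps the call on the left so it always runs; the reversed form
    #   need_synchronize = need_synchronize or apply_reduced_grad(...)
    # would short-circuit once True and silently skip applying later grads.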

    def set_requires_grad_sync(self, requires_grad_sync):
        """Set the requires-grad-sync flag that controls gradient synchronization."""
        self.reduce_grads = requires_grad_sync

    def set_reduce_op_type(self, reduce_op_type: str):
        """Set the reduce op type for gradient reduction."""
        fsdp_support_reduce_op = {
            "sum": torch.distributed.ReduceOp.SUM,
            "avg": torch.distributed.ReduceOp.AVG,
        }
        # Normalize before validating so inputs like "SUM" or " avg " are accepted.
        reduce_op = reduce_op_type.lower().strip()
        if reduce_op not in fsdp_support_reduce_op:
            raise ValueError(
                f"Unsupported reduce op type {reduce_op_type}, "
                f"supported types are {list(fsdp_support_reduce_op.keys())}"
            )
        self._user_reduce_op_type = fsdp_support_reduce_op[reduce_op]
        self.reduce_op_type = self._user_reduce_op_type
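    # Hedged usage sketch (``state`` is a hypothetical TorchHSDPStateV2 instance):
    #   state.set_reduce_op_type(" AVG ")  # accepted: normalized to "avg"
    #   state.set_reduce_op_type("max")    # raises ValueError: unsupported type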