Coverage for hyper_parallel / core / fully_shard / api.py: 58%

249 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""hybrid shard data parallel interface""" 

16from typing import Any, Mapping, cast, Optional, Union 

17 

18import torch 

19import torch.distributed as dist 

20from torch import nn 

21from torch.distributed.checkpoint.state_dict import StateDictOptions 

22 

23from hyper_parallel.platform.platform import PlatformType 

24from hyper_parallel import DeviceMesh, init_device_mesh 

25from hyper_parallel.platform import get_platform 

26from hyper_parallel.core.dtensor import DTensor, distribute_tensor 

27 

# Platform adapter resolved once at import time; used throughout this module
# (e.g. for get_cells_and_names / wait_grad_handle dispatch).
platform = get_platform()

# Cache mapping an original module class to its dynamically created
# "HSDP<Name>" subclass, so repeated fully_shard calls on instances of the
# same class reuse a single mixed-in type (see _extend_module_with_hsdp_interface).
origin_class_to_extend_class = {}

31 

32 

33class _UnshardHandle: 

34 def __init__(self, hsdp_state=None): 

35 self._hsdp_state = hsdp_state 

36 

37 def wait(self): 

38 if self._hsdp_state is not None: 

39 self._hsdp_state.wait_for_unshard() 

40 self._hsdp_state = None 

41 

42 

class HSDPModule:
    """
    The hsdp block of neural networks with hsdp interface.

    ``fully_shard`` mixes this class into a user module's class at runtime,
    giving the module HSDP control methods (prefetch configuration,
    reshard/unshard, gradient-sync switches, sharded state-dict loading).
    All methods except ``hsdp_init`` require ``self.hsdp_scheduler`` to have
    been set by a prior ``fully_shard`` call.

    Supported Platforms:
        ``MindSpore`` ``torch``
    """

    # pylint: disable=C0415
    def hsdp_init(self, platform_type, module, mesh, reshard_after_forward,
                  shard_placement_fn, mp_policy, offload_policy, ignored_params, device):
        """init hsdp2 scheduler.

        Selects the platform-specific scheduler class and attaches an
        instance as ``self.hsdp_scheduler``.
        """
        if platform_type == PlatformType.MINDSPORE:
            from hyper_parallel.platform.mindspore.hsdp.scheduler import MindSporeHSDPScheduler
            scheduler_class = MindSporeHSDPScheduler
        else:
            from hyper_parallel.platform.torch.fully_shard.scheduler import TorchHSDPSchedulerV2
            scheduler_class = TorchHSDPSchedulerV2

        self.hsdp_scheduler = scheduler_class(module,
                                              mesh,
                                              reshard_after_forward,
                                              shard_placement_fn,
                                              mp_policy,
                                              offload_policy,
                                              ignored_params,
                                              device,
                                              )

    def set_requires_gradient_sync(self, requires_grad_sync):
        r"""
        set requires grad sync flag.
        Args:
            requires_grad_sync(bool): requires_grad_sync is used to control gradient sync process.
        Raises:
            ValueError: If `requires_grad_sync` is not bool, or if the hsdp
                interface has not been applied to this module yet.
        """
        if not isinstance(requires_grad_sync, bool):
            raise ValueError(f"requires_grad_sync must be bool but got {requires_grad_sync}.")
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")

        # Propagate to every HSDP-wrapped submodule (including self).
        for _, module in platform.get_cells_and_names(self):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_requires_grad_sync(requires_grad_sync)

    def zero_grads(self):
        """Zero the accumulated gradients on every HSDP-wrapped submodule.

        Only meaningful on the MindSpore platform; on torch the optimizer
        owns gradient zeroing.
        """
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")
        # BUGFIX: previously compared the platform *object* (returned by
        # get_platform()) against the PlatformType enum, which was always
        # False; compare its platform_type as hsdp_init/fully_shard do.
        if platform.platform_type == PlatformType.PYTORCH:
            raise RuntimeError("zero_grads shouldn't be called in torch platform, use optimizer.zero_grad() instead.")
        for _, module in platform.get_cells_and_names(self):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.zero_grads()

    def set_modules_to_forward_prefetch(self, modules):
        """set forward prefetch module list to prefetch all gather for unsharded parameters"""
        if not isinstance(modules, (tuple, list)):
            raise ValueError("modules must be HSDPModule list")
        for module in modules:
            if not isinstance(module, HSDPModule):
                raise ValueError(f"modules must be HSDPModule list but got {type(module)} in list.")
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")
        self.hsdp_scheduler.set_forward_prefetch_cells(modules)

    def set_modules_to_backward_prefetch(self, modules):
        """set backward prefetch module list to prefetch all gather for unsharded parameters"""
        if not isinstance(modules, (tuple, list)):
            raise ValueError("modules must be HSDPModule list")
        for module in modules:
            if not isinstance(module, HSDPModule):
                raise ValueError(f"modules must be HSDPModule list but got {type(module)} in list.")
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call fully_shard interface first.")
        self.hsdp_scheduler.set_backward_prefetch_cells(modules)

    def reshard(self) -> None:
        """reshard all sharded parameters"""
        if not self.hsdp_scheduler:
            raise ValueError("hsdp_scheduler is None")
        scheduler_state = self.hsdp_scheduler.scheduler_state
        if scheduler_state:
            scheduler_state.shard()

    def unshard(self, async_op: bool = False):
        """unshard all sharded parameters.

        Args:
            async_op (bool): when True, returns an ``_UnshardHandle`` the
                caller must ``wait()`` on; otherwise returns ``None``.
        """
        if not isinstance(async_op, bool):
            raise ValueError(f"async_op should be a bool, got {type(async_op)}")
        if not self.hsdp_scheduler:
            raise ValueError("hsdp_scheduler is None")
        scheduler_state = self.hsdp_scheduler.scheduler_state
        if scheduler_state:
            scheduler_state.unshard(async_op=async_op)
        if async_op:
            return _UnshardHandle(hsdp_state=scheduler_state)
        return None

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
        assign: bool = False,
    ):
        """
        Load state dict by copying directly into local shards.

        Bypasses ``super().load_state_dict()`` because the standard PyTorch
        implementation triggers ``copy_`` through the DTensor dispatcher, which
        is not registered in the hyper-parallel layout system.

        Each value in ``state_dict`` is dispatched by type:
        - hyper DTensor: extract local shard and copy directly.
        - plain Tensor whose shape == local shard shape: copy as-is.
        - plain Tensor whose shape == global shape: distribute via
          ``distribute_tensor``, then copy the local shard.

        Args:
            state_dict (Mapping[str, Any]): Fully-qualified parameter/buffer
                names mapped to tensors (DTensor or plain Tensor).
            strict (bool): If ``True`` (default), missing or unexpected keys
                raise ``RuntimeError``, matching ``nn.Module.load_state_dict``
                semantics.
            assign (bool): Reserved for API compatibility with
                ``nn.Module.load_state_dict(assign=True)``. Currently unused.

        Raises:
            RuntimeError: When ``strict`` is ``True`` and keys do not match.
            ValueError: When a plain tensor shape matches neither the local
                shard shape nor the global shape of the target DTensor.
        """
        self_module = cast(nn.Module, self)

        # Map fully-qualified names to the live parameter/buffer objects so
        # we can copy into storage in place.
        target_map: dict[str, torch.Tensor] = {}
        for name, p in self_module.named_parameters():
            target_map[name] = p
        for name, b in self_module.named_buffers():
            target_map[name] = b

        if strict:
            expected_keys = set(self_module.state_dict().keys())
            missing = expected_keys - set(state_dict.keys())
            unexpected = set(state_dict.keys()) - expected_keys
            error_msgs: list[str] = []
            if missing:
                error_msgs.append(
                    "Missing key(s): " + ", ".join(repr(k) for k in sorted(missing))
                )
            if unexpected:
                error_msgs.append(
                    "Unexpected key(s): " + ", ".join(repr(k) for k in sorted(unexpected))
                )
            if error_msgs:
                raise RuntimeError(
                    f"Error(s) in loading state_dict for "
                    f"{self_module.__class__.__name__}:\n\t"
                    + "\n\t".join(error_msgs)
                )

        with torch.no_grad():
            for key, val in state_dict.items():
                target = target_map.get(key)
                if target is None:
                    # Non-strict mode: silently skip keys we don't own.
                    continue

                if isinstance(target, DTensor):
                    if isinstance(val, DTensor):
                        local_val = val.to_local()
                    else:
                        local_shape = tuple(target.local_shape)
                        global_shape = tuple(target.shape)
                        val_shape = tuple(val.shape)
                        if val_shape == local_shape:
                            local_val = val
                        elif val_shape == global_shape:
                            wrapped = distribute_tensor(
                                val.detach(), target.device_mesh, target.placements,
                            )
                            local_val = wrapped.to_local()
                        else:
                            raise ValueError(
                                f"load '{key}': plain tensor shape {val_shape} "
                                f"matches neither local shard {local_shape} "
                                f"nor global {global_shape}."
                            )
                    if target.to_local().is_meta:
                        # Meta tensor materialisation: replace the placeholder
                        target._local_tensor = local_val  # pylint: disable=protected-access
                    else:
                        target.to_local().copy_(local_val)
                else:
                    target.copy_(val)

        # Trigger load_state_dict post-hooks so that HSDP internal
        # bookkeeping (e.g. _sharded_param_data) stays in sync.
        for _, module in self_module.named_modules():
            hooks = module._load_state_dict_post_hooks  # pylint: disable=protected-access
            for hook in hooks.values():
                hook(module, None)

    def set_is_last_backward(self, is_last_backward: bool):
        """set is_last_backward flag on the scheduler context"""
        self.hsdp_scheduler.scheduler_ctx.is_last_backward = is_last_backward

    def set_requires_all_reduce(self, requires_all_reduce: bool, *, recurse: bool = True) -> None:
        """set requires_all_reduce flag"""
        if not isinstance(requires_all_reduce, bool):
            raise ValueError(
                f"requires_all_reduce should be a bool, got {type(requires_all_reduce)}"
            )
        if not recurse:
            raise NotImplementedError(f"Currently impl is equal to recurse=True,\
                need support module_param mapping.")
        self_module = cast(nn.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_requires_all_reduce(requires_all_reduce)

    def set_reshard_after_forward(self, reshard_after_forward: bool, recurse: bool = True) -> None:
        """set reshard_after_forward flag"""
        if not isinstance(reshard_after_forward, bool):
            raise ValueError(
                f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}"
            )
        if not recurse:
            raise NotImplementedError(f"Currently impl is equal to recurse=True,\
                need support module_param mapping.")
        self_module = cast(nn.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_reshard_after_forward(reshard_after_forward)

    def set_reshard_after_backward(self, reshard_after_backward: bool, recurse: bool = True) -> None:
        """set reshard_after_backward flag"""
        if not isinstance(reshard_after_backward, bool):
            raise ValueError(
                f"reshard_after_backward should be a bool, got {type(reshard_after_backward)}"
            )
        if not recurse:
            raise NotImplementedError(f"Currently impl is equal to recurse=True,\
                need support module_param mapping.")
        self_module = cast(nn.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_reshard_after_backward(reshard_after_backward)

    def set_reduce_op_type(self, reduce_op_type) -> None:
        """
        set reduce_op_type for all reduce operations in HSDP
        support reduce_op_type "avg" and "sum", default is "avg"
        """
        if hsdp_state := self.hsdp_scheduler.hsdp_state:
            hsdp_state.set_reduce_op_type(reduce_op_type)

301 

302 

def _extend_module_with_hsdp_interface(module):
    """Rebase ``module``'s class onto a cached ``HSDP<Name>`` mixin subclass.

    The new class inherits from both ``HSDPModule`` and the original class,
    so the instance gains the HSDP control methods in place. Mixed classes
    are memoized in ``origin_class_to_extend_class`` per original class.
    """
    base_class = module.__class__
    mixed_class = origin_class_to_extend_class.get(base_class)
    if mixed_class is None:
        mixed_class = type(f"HSDP{base_class.__name__}", (HSDPModule, base_class), {})
        origin_class_to_extend_class[base_class] = mixed_class
    module.__class__ = mixed_class

311 

312 

313# pylint: disable=C0415 

def _check_module_valid(platform_type, module):
    """Validate that ``module`` is the native module type of the platform.

    Raises:
        ValueError: if ``module`` is not a MindSpore ``Cell`` (on MindSpore)
            or a torch ``nn.Module`` (otherwise).
    """
    if platform_type == PlatformType.MINDSPORE:
        from mindspore.nn.cell import Cell
        if isinstance(module, Cell):
            return
        raise ValueError(f"module's type must be nn.cell but got {type(module)}.")
    from torch.nn import Module
    if isinstance(module, Module):
        return
    raise ValueError(f"module's type must be nn.Module but got {type(module)}.")

324 

325 

326# pylint: disable=C0415 

def _check_hsdp_input_valid(platform_type, module, shard_size, threshold, optimizer_level, enable_grad_accumulation,
                            grad_scale, reduce_dtype, comm_async, comm_fusion, bucket_size):
    """check hsdp input valid.

    Validates every user-facing knob of the hsdp interface, raising
    ``ValueError`` with a precise message on the first invalid argument.
    Note both ``shard_size`` and ``bucket_size`` accept ``-1`` as a sentinel
    (presumably "use the default/full size" — confirm with callers).
    """
    _check_module_valid(platform_type, module)
    if not isinstance(shard_size, int) or (shard_size <= 0 and shard_size != -1):
        # BUGFIX(message): -1 is explicitly accepted, so the old
        # "must be a positive integer" message was wrong.
        raise ValueError(f"shard_size must be a positive integer or -1, but got {shard_size}.")
    if not isinstance(threshold, int) or threshold < 0:
        raise ValueError(f"threshold must be a positive integer or 0, but got {threshold}.")
    if optimizer_level not in ["level1", "level2", "level3"]:
        raise ValueError(f"Optimizer level should in ['level1', 'level2', 'level3'], but got {optimizer_level}.")
    if not isinstance(enable_grad_accumulation, bool):
        raise ValueError(f"enable_grad_accumulation must be bool but got {enable_grad_accumulation}.")
    if not isinstance(grad_scale, float):
        raise ValueError(f"grad_scale must be float but got {grad_scale}.")
    if platform_type == PlatformType.MINDSPORE:
        from mindspore._c_expression.typing import Type  # pylint: disable=C0415
        if reduce_dtype is not None and not isinstance(reduce_dtype, Type):
            raise ValueError(f"reduce_dtype must be mindspore.dtype but got {reduce_dtype}.")
    else:
        if reduce_dtype is not None and not isinstance(reduce_dtype, torch.dtype):
            raise ValueError(f"reduce_dtype must be torch.dtype but got {reduce_dtype}.")
    if not isinstance(comm_async, bool):
        raise ValueError(f"comm_async must be bool but got {comm_async}.")
    if not isinstance(comm_fusion, bool):
        raise ValueError(f"comm_fusion must be bool but got {comm_fusion}.")
    if not isinstance(bucket_size, int) or (bucket_size < 0 and bucket_size != -1):
        # BUGFIX(message): the condition also accepts -1, which the old
        # "positive integer or 0" message omitted.
        raise ValueError(f"bucket_size must be a positive integer, 0 or -1, but got {bucket_size}.")

354 

355 

def fully_shard(
    module: nn.Module,
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: Optional[Union[bool, int]] = None,
    shard_placement_fn: None = None,
    mp_policy: None = None,
    offload_policy: None = None,
    ignored_params: Optional[set[nn.Parameter]] = None,
    device = None,
):
    """Apply hybrid-sharded data parallelism to ``module`` in place.

    Extends the module's class with the ``HSDPModule`` mixin, resolves a
    default device and device mesh when not provided, attaches the
    platform-specific scheduler via ``hsdp_init``, and returns the same
    ``module`` instance.
    """
    current_platform_type = platform.platform_type
    _extend_module_with_hsdp_interface(module)
    # TODO: mindspore does not support get_device_handle
    if device is None:
        handle = platform.get_device_handle()  # return torch.npu or torch.cuda
        device = (
            torch.device(handle.current_device())
            if handle.is_available()
            else torch.device("cpu")
        )

    mesh = mesh or init_device_mesh(device_type=device, mesh_shape=(platform.get_world_size(),))

    module.hsdp_init(
        current_platform_type,
        module,
        mesh,
        reshard_after_forward,
        shard_placement_fn,
        mp_policy,
        offload_policy,
        ignored_params,
        device,
    )
    return module

391 

392 

def _gather_full_state_dict(
    state_dict: dict[str, Any], cpu_offload: bool
) -> dict[str, Any]:
    """All-gather every DTensor shard into a full tensor.

    Args:
        state_dict: Model state dict with DTensor or plain tensor values.
        cpu_offload: If True, only rank-0 keeps the result on CPU;
            other ranks return an empty dict to save memory.
    """
    on_rank0 = (not dist.is_initialized()) or (dist.get_rank() == 0)

    if cpu_offload and not on_rank0:
        # Still materialize each full tensor (collective participation),
        # but drop it immediately — non-rank0 returns nothing.
        for value in state_dict.values():
            if isinstance(value, DTensor):
                full = value.full_tensor()
                del full
        return {}

    collected: dict[str, Any] = {}
    for name, value in state_dict.items():
        if isinstance(value, DTensor):
            value = value.full_tensor()
        if cpu_offload and isinstance(value, torch.Tensor):
            value = value.cpu()
        collected[name] = value
    return collected

420 

421 

def _offload_sharded_state_dict(
    state_dict: dict[str, Any],
) -> dict[str, Any]:
    """Move each shard to CPU without all-gathering.

    Args:
        state_dict: Model state dict with DTensor or plain tensor values.
    """
    result: dict[str, Any] = {}
    for name, value in state_dict.items():
        if isinstance(value, DTensor):
            # Rewrap the CPU copy of the local shard with the same layout.
            result[name] = DTensor.from_local(
                value.to_local().cpu(), value.device_mesh, value.placements,
            )
        elif isinstance(value, torch.Tensor):
            result[name] = value.cpu()
        else:
            result[name] = value
    return result

440 

441 

442def get_model_state_dict( 

443 model: nn.Module, 

444 *, 

445 options: StateDictOptions | None = None, 

446) -> dict[str, Any]: 

447 """Return the model state dict with configurable gathering and offloading. 

448 

449 Behaviour matrix: 

450 

451 +-----------------+-------------+--------------------------------------+ 

452 | full_state_dict | cpu_offload | result | 

453 +=================+=============+======================================+ 

454 | False | False | DTensor (sharded, as-is) | 

455 +-----------------+-------------+--------------------------------------+ 

456 | False | True | DTensor local shard offloaded to CPU | 

457 +-----------------+-------------+--------------------------------------+ 

458 | True | False | full Tensor on **every** rank | 

459 +-----------------+-------------+--------------------------------------+ 

460 | True | True | full Tensor on CPU, **rank 0 only** | 

461 +-----------------+-------------+--------------------------------------+ 

462 

463 Args: 

464 model: The model whose state dict to retrieve. 

465 options: Controls full_state_dict, cpu_offload, 

466 ignore_frozen_params, and broadcast_from_rank0 flags. 

467 """ 

468 options = options or StateDictOptions() 

469 

470 if options.broadcast_from_rank0 and not options.full_state_dict: 

471 raise ValueError( 

472 "full_state_dict must be True when broadcast_from_rank0 is True." 

473 ) 

474 

475 state_dict: dict[str, Any] = model.state_dict() 

476 

477 if options.ignore_frozen_params: 

478 frozen_keys = { 

479 name for name, p in model.named_parameters() 

480 if not p.requires_grad 

481 } 

482 for key in frozen_keys: 

483 state_dict.pop(key, None) 

484 

485 if options.full_state_dict: 

486 return _gather_full_state_dict(state_dict, options.cpu_offload) 

487 

488 if options.cpu_offload: 

489 return _offload_sharded_state_dict(state_dict) 

490 

491 return state_dict 

492 

493 

def hsdp_sync_stream():
    """wait for hsdp gradient handle to be completed"""
    if platform is not None:
        platform.wait_grad_handle()

498 platform.wait_grad_handle()