Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/distributed_checkpoint/api.py: 23%

131 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hyper Parallel Checkpoint API"""
import multiprocessing as mp
import queue
import threading
import traceback
from concurrent.futures import Future
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Any, Optional, Union

from hyper_parallel.core.distributed_checkpoint.async_staging import build_staged_state_dict
from hyper_parallel.core.distributed_checkpoint.standard_planner import StandardSavePlanner, StandardLoadPlanner
from hyper_parallel.core.distributed_checkpoint.filesystem_storage import FileSystemReader, FileSystemWriter
from hyper_parallel.core.distributed_checkpoint.metadata import Metadata
from hyper_parallel.core.distributed_checkpoint.planner import SavePlanner, LoadPlanner
from hyper_parallel.core.distributed_checkpoint.storage import StorageReader, StorageWriter
from hyper_parallel.platform import get_platform

platform = get_platform()


class _AsyncPersistStatus(Enum):
    """Queue payload status from :func:`_async_persist_worker` to the parent join thread."""

    SUCCESS = auto()
    FAILURE = auto()


@dataclass
class AsyncSaveResponse:
    """Result of :func:`async_save`.

    Host staging runs synchronously before :func:`async_save` returns; only checkpoint
    **persistence** is asynchronous. ``persist_completion`` completes when the child
    process finishes :func:`_save_impl` (plan, collectives, disk I/O) and supplies
    :class:`Metadata`.
    """

    persist_completion: Future[Metadata]
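
# How callers typically consume ``persist_completion`` -- illustrative sketch only
# (``response`` stands for a hypothetical ``AsyncSaveResponse`` obtained in user code):
#
#     metadata = response.persist_completion.result()  # Metadata on success;
#                                                      # raises RuntimeError if persistence failed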


def _gather_from_all_ranks(
    local_object: Any,
    world_size: int,
    use_collectives: bool,
) -> list[Any]:
    """
    Gather objects from all ranks.

    Args:
        local_object (Any): Local object for current rank.
        world_size (int): Total number of ranks.
        use_collectives (bool): Whether to use collective communication.

    Returns:
        list[Any]: List of all objects from all ranks.
    """
    if use_collectives and world_size > 1:
        all_objects = [None] * world_size
        platform.all_gather_object(all_objects, local_object)
        return all_objects
    return [local_object]


def _save_impl(
    state_dict: dict[str, Any],
    *,
    checkpoint_id: Optional[Union[Path, str]] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
) -> Metadata:
    """Synchronous distributed checkpoint save (shared by :func:`save` and :func:`async_save`)."""
    # Convert checkpoint_id to Path if it's a string
    checkpoint_id = Path(checkpoint_id) if isinstance(checkpoint_id, str) else checkpoint_id

    # Determine if we're in distributed mode
    use_collectives = False if no_dist else use_collectives

    # Set up storage writer
    if storage_writer is None:
        if checkpoint_id is None:
            raise ValueError("Either storage_writer or checkpoint_id must be provided")
        storage_writer = FileSystemWriter(checkpoint_id)
    else:
        if checkpoint_id:
            storage_writer.initialize_writer(checkpoint_id)

    # Set up planner
    planner = StandardSavePlanner() if planner is None else planner

    # Get rank and coordinator info
    rank = platform.get_rank()
    world_size = platform.get_world_size()
    is_coordinator = rank == 0

    # Configure planner
    planner.configure_planner(
        state_dict=state_dict,
        is_coordinator=is_coordinator,
        rank=rank,
        use_collectives=use_collectives
    )

    # Configure storage writer (use_collectives for rank-local metadata when False)
    storage_writer.configure_writer(
        is_coordinator=is_coordinator,
        rank=rank,
        use_collectives=use_collectives
    )

    cached = planner.get_cached_result() if isinstance(planner, StandardSavePlanner) else None
    if cached is not None:
        final_plan, metadata = cached
    else:
        # Build local plan
        local_plan = planner.build_local_plan()
        local_plan = storage_writer.optimize_local_plan(local_plan)

        # Gather all local plans and build global plan
        all_local_plans = _gather_from_all_ranks(local_plan, world_size, use_collectives)
        global_plans, metadata = planner.build_global_plan(all_local_plans)
        global_plans = storage_writer.optimize_global_plan(global_plans)

        # Select central plan for current rank
        if use_collectives and world_size > 1 and global_plans:
            central_plan = global_plans[rank]
        elif global_plans:
            central_plan = global_plans[0]
        else:
            central_plan = local_plan

        # Finalize and cache plan
        final_plan = planner.finalize_plan(central_plan)
        if isinstance(planner, StandardSavePlanner):
            planner.cache_result(final_plan, metadata)

    # Write data
    write_results = storage_writer.execute_write(final_plan, planner)

    # Finalize checkpoint
    all_write_results = _gather_from_all_ranks(write_results, world_size, use_collectives)
    storage_writer.finalize_checkpoint(metadata, all_write_results)

    return metadata


def _async_persist_worker(
    result_queue: mp.Queue,
    staged: dict[str, Any],
    checkpoint_id: Optional[Union[Path, str]],
    storage_writer: Optional[StorageWriter],
    planner: Optional[SavePlanner],
    no_dist: bool,
    use_collectives: bool,
) -> None:
    """Child-process entry: run :func:`_save_impl` and report ``Metadata`` or an error string on ``result_queue``."""
    try:
        meta = _save_impl(
            staged,
            checkpoint_id=checkpoint_id,
            storage_writer=storage_writer,
            planner=planner,
            no_dist=no_dist,
            use_collectives=use_collectives,
        )
        result_queue.put((_AsyncPersistStatus.SUCCESS, meta))
    except Exception:  # pylint: disable=broad-except
        result_queue.put((_AsyncPersistStatus.FAILURE, traceback.format_exc()))


def _async_persist_wait_process(
    proc: mp.Process,
    result_queue: mp.Queue,
    persist_future: Future[Metadata],
) -> None:
    """Join persist ``proc`` and complete ``persist_future`` (runs on a daemon thread)."""
    proc.join()
    if persist_future.done():
        return
    try:
        status, payload = result_queue.get_nowait()
    except queue.Empty:
        persist_future.set_exception(
            RuntimeError(
                f"async_persist process exited with code {proc.exitcode} and no result on queue"
            )
        )
        return
    if status == _AsyncPersistStatus.SUCCESS:
        persist_future.set_result(payload)
    elif status == _AsyncPersistStatus.FAILURE:
        persist_future.set_exception(RuntimeError(payload))
    else:
        persist_future.set_exception(
            RuntimeError(f"async_persist queue returned unexpected status: {status!r}")
        )


def save(
    state_dict: dict[str, Any],
    *,
    checkpoint_id: Optional[Union[Path, str]] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
) -> Metadata:

225 """ 

226 Save a distributed checkpoint in SPMD style. 

227 

228 This function saves a state_dict containing DTensors, where each rank 

229 only saves their local shards. 

230 

231 Args: 

232 state_dict (dict[str, Any]): The state_dict to save. 

233 checkpoint_id (Optional[Union[Path, str]]): The ID/path of this checkpoint instance (can be Path or str). 

234 Default None. 

235 storage_writer (Optional[StorageWriter]): Instance of StorageWriter. If None, FileSystemWriter 

236 will be created based on checkpoint_id. Default None. 

237 planner (Optional[SavePlanner]): Instance of SavePlanner. If None, StandardSavePlanner will be used. 

238 Default None. 

239 no_dist (bool): If True, save in single process mode. Default False. 

240 use_collectives (bool): If True, use collective communication for coordination. 

241 If False, each rank saves its own shard data and rank-local metadata (.metadata_rank{rank}), 

242 with no cross-rank interaction. Default True. 

243 

244 Returns: 

245 Metadata: Metadata object for the saved checkpoint. 

246 """ 

    metadata = _save_impl(
        state_dict,
        checkpoint_id=checkpoint_id,
        storage_writer=storage_writer,
        planner=planner,
        no_dist=no_dist,
        use_collectives=use_collectives,
    )
    platform.barrier()
    return metadata
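
# Usage sketch for ``save`` -- illustrative only; ``model`` and the checkpoint path
# are hypothetical names from user code, not part of this module:
#
#     metadata = save(model.state_dict(), checkpoint_id="/tmp/ckpt/step_100")
#     # Every rank returns only after the barrier, so the checkpoint is complete here.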


def async_save(
    state_dict: dict[str, Any],
    *,
    checkpoint_id: Optional[Union[Path, str]] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
) -> AsyncSaveResponse:
    """
    Asynchronous version of :func:`save` using a **background child process** for persistence.

    **Staging** (tensor / DTensor → host copy) runs **synchronously in the caller
    process** via :func:`build_staged_state_dict`, so no process pool is used for
    staging and the training stack sees a normal Python call path. When this
    function returns successfully, host staging is done and the original
    ``state_dict`` may be mutated.

    **Persistence** (plan, collectives, disk I/O) runs in **one** background
    :class:`multiprocessing.Process` that executes :func:`_save_impl` on the staged
    dict. A small daemon **thread** only joins that process and fills
    ``persist_completion``; it does not perform tensor work.

    The staged dict and ``storage_writer`` / ``planner`` must be picklable, since
    they are passed to the persist child process.

    .. warning::
        Experimental API. Always wait on ``persist_completion`` for a fully persisted checkpoint.

    Args:
        state_dict (dict[str, Any]): The state_dict to save.
        checkpoint_id (Optional[Union[Path, str]]): Same as :func:`save`.
        storage_writer (Optional[StorageWriter]): Same as :func:`save`.
        planner (Optional[SavePlanner]): Same as :func:`save`.
        no_dist (bool): Same as :func:`save`.
        use_collectives (bool): Same as :func:`save`.

    Returns:
        AsyncSaveResponse: Contains ``persist_completion`` only; staging is synchronous.
    """

    persist_completion: Future[Metadata] = Future()

    staged = build_staged_state_dict(state_dict)

    result_queue: mp.Queue = mp.Queue(maxsize=1)
    proc = mp.Process(
        target=_async_persist_worker,
        args=(
            result_queue,
            staged,
            checkpoint_id,
            storage_writer,
            planner,
            no_dist,
            use_collectives,
        ),
        name="HPAsyncCheckpointPersist",
    )
    proc.start()
    join_thread = threading.Thread(
        target=_async_persist_wait_process,
        args=(proc, result_queue, persist_completion),
        daemon=True,
        name="HPAsyncCheckpointPersistJoin",
    )
    join_thread.start()
    return AsyncSaveResponse(persist_completion=persist_completion)
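
# Usage sketch for ``async_save`` -- illustrative only; ``model`` and the checkpoint
# path are hypothetical. The pattern mirrors ``save`` plus a wait on the future:
#
#     response = async_save(model.state_dict(), checkpoint_id="/tmp/ckpt/step_200")
#     # Staging is already finished here; training may continue and mutate the model.
#     ...
#     response.persist_completion.result()  # raises if the child process failed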


def load(
    state_dict: dict[str, Any],
    *,
    checkpoint_id: Optional[Union[Path, str]] = None,
    storage_reader: Optional[StorageReader] = None,
    planner: Optional[LoadPlanner] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
) -> None:
    """
    Load a distributed checkpoint into state_dict in SPMD style.

    Each rank reads only the minimum data needed to fulfill the requested
    state_dict. When loading DTensor instances, each rank reads only the data
    for its local shards.

    Args:
        state_dict (dict[str, Any]): The state_dict to load the checkpoint into (modified in-place).
        checkpoint_id (Optional[Union[Path, str]]): The ID/path of this checkpoint instance (can be Path or str).
            Default None.
        storage_reader (Optional[StorageReader]): Instance of StorageReader. If None, FileSystemReader
            will be created based on checkpoint_id. Default None.
        planner (Optional[LoadPlanner]): Instance of LoadPlanner. If None, StandardLoadPlanner will be used.
            Default None.
        no_dist (bool): If True, load without cross-rank synchronization. Default False.
        use_collectives (bool): If False, load from rank-local metadata (.metadata_rank{rank}),
            for checkpoints saved with save(use_collectives=False). No cross-rank interaction. Default True.

    Returns:
        None. The state_dict is modified in-place.
    """

    # Convert checkpoint_id to Path if it's a string
    checkpoint_id = Path(checkpoint_id) if isinstance(checkpoint_id, str) else checkpoint_id

    # Determine if we're in distributed mode
    use_collectives = False if no_dist else use_collectives

    # Set up storage reader
    if storage_reader is None:
        if checkpoint_id is None:
            raise ValueError("Either storage_reader or checkpoint_id must be provided")
        storage_reader = FileSystemReader(checkpoint_id)
    else:
        if checkpoint_id:
            storage_reader.initialize_reader(checkpoint_id)

    # Set up planner
    planner = StandardLoadPlanner() if planner is None else planner

    # Get rank and coordinator info
    rank = platform.get_rank()
    world_size = platform.get_world_size()
    is_coordinator = rank == 0

    # Load metadata
    if use_collectives:
        try:
            metadata = storage_reader.load_metadata()
        except FileNotFoundError:
            # Fallback to rank-local metadata (e.g. checkpoint saved with use_collectives=False)
            metadata = storage_reader.load_metadata(rank=rank)
            use_collectives = False
    else:
        # Load rank-local metadata directly (no cross-rank interaction)
        metadata = storage_reader.load_metadata(rank=rank)

    # Configure planner
    planner.configure_planner(
        state_dict=state_dict,
        metadata=metadata,
        is_coordinator=is_coordinator,
        rank=rank
    )

    # Configure storage reader
    storage_reader.configure_reader(
        metadata=metadata,
        is_coordinator=is_coordinator,
        rank=rank,
        use_collectives=use_collectives
    )

    # Build local plan
    local_plan = planner.build_local_plan()
    local_plan = storage_reader.optimize_local_plan(local_plan)

    # Gather all local plans and build global plan
    all_local_plans = _gather_from_all_ranks(local_plan, world_size, use_collectives)
    global_plans = planner.build_global_plan(all_local_plans)
    global_plans = storage_reader.optimize_global_plan(global_plans)

    # Select central plan for current rank
    if use_collectives and world_size > 1 and global_plans:
        central_plan = global_plans[rank]
    elif global_plans:
        central_plan = global_plans[0]
    else:
        central_plan = local_plan

    # Finalize plan
    final_plan = planner.finalize_plan(central_plan)

    # Execute read
    storage_reader.execute_read(final_plan, planner)
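
# Usage sketch for ``load`` -- illustrative only; ``model`` is a hypothetical module
# object from user code. ``state_dict`` is filled in-place from the checkpoint:
#
#     state_dict = model.state_dict()
#     load(state_dict, checkpoint_id="/tmp/ckpt/step_100")
#     model.load_state_dict(state_dict)  # hypothetical framework call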