Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/distributed_checkpoint/api.py: 23%
131 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Hyper Parallel Checkpoint API"""
16import multiprocessing as mp
17import queue
18import threading
19import traceback
20from concurrent.futures import Future
21from dataclasses import dataclass
22from enum import Enum, auto
23from pathlib import Path
24from typing import Any, Optional, Union
26from hyper_parallel.core.distributed_checkpoint.async_staging import build_staged_state_dict
27from hyper_parallel.core.distributed_checkpoint.standard_planner import StandardSavePlanner, StandardLoadPlanner
28from hyper_parallel.core.distributed_checkpoint.filesystem_storage import FileSystemReader, FileSystemWriter
29from hyper_parallel.core.distributed_checkpoint.metadata import Metadata
30from hyper_parallel.core.distributed_checkpoint.planner import SavePlanner, LoadPlanner
31from hyper_parallel.core.distributed_checkpoint.storage import StorageReader, StorageWriter
32from hyper_parallel.platform import get_platform
# Module-level platform handle used throughout this file for rank / world-size
# queries and collective operations (all_gather_object, barrier).
# NOTE(review): this name shadows the stdlib ``platform`` module within this file.
platform = get_platform()
class _AsyncPersistStatus(Enum):
    """Queue payload status from :func:`_async_persist_worker` to the parent join thread."""

    # Accompanying payload is the resulting :class:`Metadata`.
    SUCCESS = auto()
    # Accompanying payload is a formatted traceback string from the child process.
    FAILURE = auto()
@dataclass
class AsyncSaveResponse:
    """Result of :func:`async_save`.

    Host staging runs synchronously before :func:`async_save` returns; only checkpoint
    **persistence** is asynchronous. ``persist_completion`` completes when the child
    process finishes :func:`_save_impl` (plan, collectives, disk I/O) and supplies
    :class:`Metadata`.
    """

    # Future resolved (or failed) by the daemon join thread started in async_save.
    persist_completion: Future[Metadata]
57def _gather_from_all_ranks(
58 local_object: Any,
59 world_size: int,
60 use_collectives: bool,
61) -> list[Any]:
62 """
63 Gather objects from all ranks.
65 Args:
66 local_object (Any): Local object for current rank.
67 world_size (int): Total number of ranks.
68 use_collectives (bool): Whether to use collective communication.
70 Returns:
71 list[Any]: List of all objects from all ranks.
72 """
73 if use_collectives and world_size > 1:
74 all_objects = [None] * world_size
75 platform.all_gather_object(all_objects, local_object)
76 return all_objects
77 return [local_object]
def _save_impl(
    state_dict: dict[str, Any],
    *,
    checkpoint_id: Optional[Union[Path, str]] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
) -> Metadata:
    """Synchronous distributed checkpoint save (shared by :func:`save` and :func:`async_save`).

    Args:
        state_dict (dict[str, Any]): The state_dict to save.
        checkpoint_id (Optional[Union[Path, str]]): Checkpoint path; used to build a
            FileSystemWriter or to (re-)initialize a supplied writer. Default None.
        storage_writer (Optional[StorageWriter]): Writer instance; created from
            checkpoint_id when None. Default None.
        planner (Optional[SavePlanner]): Save planner; StandardSavePlanner when None.
        no_dist (bool): If True, forces use_collectives off (single-process mode).
        use_collectives (bool): Whether plans and write results are exchanged across ranks.

    Returns:
        Metadata: Metadata describing the saved checkpoint.

    Raises:
        ValueError: If neither storage_writer nor checkpoint_id is provided.
    """
    # Convert checkpoint_id to Path if it's a string
    checkpoint_id = Path(checkpoint_id) if isinstance(checkpoint_id, str) else checkpoint_id

    # Determine if we're in distributed mode (no_dist overrides use_collectives)
    use_collectives = False if no_dist else use_collectives

    # Set up storage writer
    if storage_writer is None:
        if checkpoint_id is None:
            raise ValueError("Either storage_writer or checkpoint_id must be provided")
        storage_writer = FileSystemWriter(checkpoint_id)
    else:
        # Re-point a caller-supplied writer at the requested checkpoint path.
        if checkpoint_id:
            storage_writer.initialize_writer(checkpoint_id)

    # Set up planner
    planner = StandardSavePlanner() if planner is None else planner

    # Get rank and coordinator info (rank 0 is always the coordinator)
    rank = platform.get_rank()
    world_size = platform.get_world_size()
    is_coordinator = rank == 0

    # Configure planner
    planner.configure_planner(
        state_dict=state_dict,
        is_coordinator=is_coordinator,
        rank=rank,
        use_collectives=use_collectives
    )

    # Configure storage writer (use_collectives for rank-local metadata when False)
    storage_writer.configure_writer(
        is_coordinator=is_coordinator,
        rank=rank,
        use_collectives=use_collectives
    )

    # Skip planning entirely when the StandardSavePlanner holds a cached plan/metadata
    # pair from a previous save (e.g. repeated saves of the same state_dict layout).
    cached = planner.get_cached_result() if isinstance(planner, StandardSavePlanner) else None
    if cached is not None:
        final_plan, metadata = cached
    else:
        # Build local plan
        local_plan = planner.build_local_plan()
        local_plan = storage_writer.optimize_local_plan(local_plan)

        # Gather all local plans and build global plan
        all_local_plans = _gather_from_all_ranks(local_plan, world_size, use_collectives)
        global_plans, metadata = planner.build_global_plan(all_local_plans)
        global_plans = storage_writer.optimize_global_plan(global_plans)

        # Select central plan for current rank: per-rank entry when collectives gathered
        # one plan per rank, otherwise the single (rank-local) global plan.
        if use_collectives and world_size > 1 and global_plans:
            central_plan = global_plans[rank]
        elif global_plans:
            central_plan = global_plans[0]
        else:
            central_plan = local_plan

        # Finalize and cache plan for reuse on subsequent saves
        final_plan = planner.finalize_plan(central_plan)
        if isinstance(planner, StandardSavePlanner):
            planner.cache_result(final_plan, metadata)

    # Write data
    write_results = storage_writer.execute_write(final_plan, planner)

    # Finalize checkpoint (coordinator-side metadata commit uses all ranks' results)
    all_write_results = _gather_from_all_ranks(write_results, world_size, use_collectives)
    storage_writer.finalize_checkpoint(metadata, all_write_results)

    return metadata
def _async_persist_worker(
    result_queue: mp.Queue,
    staged: dict[str, Any],
    checkpoint_id: Optional[Union[Path, str]],
    storage_writer: Optional[StorageWriter],
    planner: Optional[SavePlanner],
    no_dist: bool,
    use_collectives: bool,
) -> None:
    """Child-process entry: run :func:`_save_impl` and report ``Metadata`` or an error string on ``result_queue``."""
    try:
        metadata = _save_impl(
            staged,
            checkpoint_id=checkpoint_id,
            storage_writer=storage_writer,
            planner=planner,
            no_dist=no_dist,
            use_collectives=use_collectives,
        )
        result_queue.put((_AsyncPersistStatus.SUCCESS, metadata))
    except Exception:  # pylint: disable=broad-except
        # Exceptions are not reliably picklable across processes; ship the
        # formatted traceback text instead of the exception object.
        result_queue.put((_AsyncPersistStatus.FAILURE, traceback.format_exc()))
def _async_persist_wait_process(
    proc: mp.Process,
    result_queue: mp.Queue,
    persist_future: Future[Metadata],
) -> None:
    """Wait for persist ``proc`` and complete ``persist_future`` (runs on a daemon thread).

    The result is drained from ``result_queue`` *before* the final ``proc.join()``.
    Joining first can deadlock: a child that has put a large item on a
    ``multiprocessing.Queue`` does not terminate until the parent drains the
    underlying pipe (see "Joining processes that use queues" in the
    multiprocessing programming guidelines), while the parent would be stuck in
    ``join()`` waiting for the child.

    Args:
        proc (mp.Process): Child process running :func:`_async_persist_worker`.
        result_queue (mp.Queue): Queue carrying ``(status, payload)`` from the child.
        persist_future (Future[Metadata]): Future completed with the result or error.
    """
    status = payload = None
    got_result = False
    while not got_result:
        try:
            # Timed poll so a child that dies without reporting is noticed.
            status, payload = result_queue.get(timeout=1.0)
            got_result = True
        except queue.Empty:
            if not proc.is_alive():
                # Child already exited; one final non-blocking attempt in case
                # the item landed between the timeout and the liveness check.
                try:
                    status, payload = result_queue.get_nowait()
                    got_result = True
                except queue.Empty:
                    break
    proc.join()
    if persist_future.done():
        return
    if not got_result:
        persist_future.set_exception(
            RuntimeError(
                f"async_persist process exited with code {proc.exitcode} and no result on queue"
            )
        )
    elif status == _AsyncPersistStatus.SUCCESS:
        persist_future.set_result(payload)
    elif status == _AsyncPersistStatus.FAILURE:
        persist_future.set_exception(RuntimeError(payload))
    else:
        persist_future.set_exception(
            RuntimeError(f"async_persist queue returned unexpected status: {status!r}")
        )
def save(
    state_dict: dict[str, Any],
    *,
    checkpoint_id: Optional[Union[Path, str]] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
) -> Metadata:
    """
    Save a distributed checkpoint in SPMD style.

    Each rank persists only its local shards of any DTensors contained in
    ``state_dict``.

    Args:
        state_dict (dict[str, Any]): The state_dict to save.
        checkpoint_id (Optional[Union[Path, str]]): The ID/path of this checkpoint instance (can be Path or str).
            Default None.
        storage_writer (Optional[StorageWriter]): Instance of StorageWriter. If None, FileSystemWriter
            will be created based on checkpoint_id. Default None.
        planner (Optional[SavePlanner]): Instance of SavePlanner. If None, StandardSavePlanner will be used.
            Default None.
        no_dist (bool): If True, save in single process mode. Default False.
        use_collectives (bool): If True, use collective communication for coordination.
            If False, each rank saves its own shard data and rank-local metadata (.metadata_rank{rank}),
            with no cross-rank interaction. Default True.

    Returns:
        Metadata: Metadata object for the saved checkpoint.
    """
    result = _save_impl(
        state_dict,
        checkpoint_id=checkpoint_id,
        storage_writer=storage_writer,
        planner=planner,
        no_dist=no_dist,
        use_collectives=use_collectives,
    )
    # Synchronize all ranks so the checkpoint is globally complete on return.
    platform.barrier()
    return result
def async_save(
    state_dict: dict[str, Any],
    *,
    checkpoint_id: Optional[Union[Path, str]] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
) -> AsyncSaveResponse:
    """
    Asynchronous version of :func:`save` using a **background child process** for persistence.

    **Staging** (tensor / DTensor → host copy) runs **synchronously in the caller
    process** via :func:`build_staged_state_dict`; no process pool is involved in
    staging, so the training stack sees a normal Python call path. Once this
    function returns, host staging is finished and the original ``state_dict``
    may safely be mutated.

    **Persistence** (plan, collectives, disk I/O) happens in a single background
    :class:`multiprocessing.Process` executing :func:`_save_impl` over the staged
    dict. A small daemon **thread** merely waits for that process and fulfils
    ``persist_completion``; it performs no tensor work.

    The staged dict together with ``storage_writer`` / ``planner`` must be
    picklable, since they cross the process boundary.

    .. warning::
        Experimental API. Always wait on ``persist_completion`` for a fully persisted checkpoint.

    Args:
        state_dict (dict[str, Any]): The state_dict to save.
        checkpoint_id (Optional[Union[Path, str]]): Same as :func:`save`.
        storage_writer (Optional[StorageWriter]): Same as :func:`save`.
        planner (Optional[SavePlanner]): Same as :func:`save`.
        no_dist (bool): Same as :func:`save`.
        use_collectives (bool): Same as :func:`save`.

    Returns:
        AsyncSaveResponse: Contains ``persist_completion`` only; staging is synchronous.
    """
    # Stage to host synchronously, in the caller process.
    staged_state = build_staged_state_dict(state_dict)

    persist_future: Future[Metadata] = Future()
    outcome_queue: mp.Queue = mp.Queue(maxsize=1)

    persist_proc = mp.Process(
        target=_async_persist_worker,
        name="HPAsyncCheckpointPersist",
        args=(
            outcome_queue,
            staged_state,
            checkpoint_id,
            storage_writer,
            planner,
            no_dist,
            use_collectives,
        ),
    )
    persist_proc.start()

    waiter = threading.Thread(
        target=_async_persist_wait_process,
        name="HPAsyncCheckpointPersistJoin",
        args=(persist_proc, outcome_queue, persist_future),
        daemon=True,
    )
    waiter.start()

    return AsyncSaveResponse(persist_completion=persist_future)
def load(
    state_dict: dict[str, Any],
    *,
    checkpoint_id: Optional[Union[Path, str]] = None,
    storage_reader: Optional[StorageReader] = None,
    planner: Optional[LoadPlanner] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
) -> None:
    """
    Load a distributed checkpoint into state_dict in SPMD style.

    Each rank will try to read the least amount of data necessary
    to fulfill the requested state_dict. When loading DTensor instances,
    each rank only reads data for their local shards.

    Args:
        state_dict (dict[str, Any]): The state_dict to load the checkpoint into (modified in-place).
        checkpoint_id (Optional[Union[Path, str]]): The ID/path of this checkpoint instance (can be Path or str).
            Default None.
        storage_reader (Optional[StorageReader]): Instance of StorageReader. If None, FileSystemReader
            will be created based on checkpoint_id. Default None.
        planner (Optional[LoadPlanner]): Instance of LoadPlanner. If None, StandardLoadPlanner will be used.
            Default None.
        no_dist (bool): If True, load without cross-rank synchronization. Default False.
        use_collectives (bool): If False, load from rank-local metadata (.metadata_rank{rank}),
            for checkpoints saved with save(use_collectives=False). No cross-rank interaction. Default True.

    Returns:
        None. The state_dict is modified in-place.

    Raises:
        ValueError: If neither storage_reader nor checkpoint_id is provided.
    """
    # Convert checkpoint_id to Path if it's a string
    checkpoint_id = Path(checkpoint_id) if isinstance(checkpoint_id, str) else checkpoint_id

    # Determine if we're in distributed mode (no_dist overrides use_collectives)
    use_collectives = False if no_dist else use_collectives

    # Set up storage reader
    if storage_reader is None:
        if checkpoint_id is None:
            raise ValueError("Either storage_reader or checkpoint_id must be provided")
        storage_reader = FileSystemReader(checkpoint_id)
    else:
        # Re-point a caller-supplied reader at the requested checkpoint path.
        if checkpoint_id:
            storage_reader.initialize_reader(checkpoint_id)

    # Set up planner
    planner = StandardLoadPlanner() if planner is None else planner

    # Get rank and coordinator info (rank 0 is always the coordinator)
    rank = platform.get_rank()
    world_size = platform.get_world_size()
    is_coordinator = rank == 0

    # Load metadata
    if use_collectives:
        try:
            metadata = storage_reader.load_metadata()
        except FileNotFoundError:
            # Fallback to rank-local metadata (e.g. checkpoint saved with use_collectives=False).
            # Note: use_collectives is downgraded here so plan gathering below also stays local.
            metadata = storage_reader.load_metadata(rank=rank)
            use_collectives = False
    else:
        # Load rank-local metadata directly (no cross-rank interaction)
        metadata = storage_reader.load_metadata(rank=rank)

    # Configure planner
    planner.configure_planner(
        state_dict=state_dict,
        metadata=metadata,
        is_coordinator=is_coordinator,
        rank=rank
    )

    # Configure storage reader
    storage_reader.configure_reader(
        metadata=metadata,
        is_coordinator=is_coordinator,
        rank=rank,
        use_collectives=use_collectives
    )

    # Build local plan
    local_plan = planner.build_local_plan()
    local_plan = storage_reader.optimize_local_plan(local_plan)

    # Gather all local plans and build global plan
    all_local_plans = _gather_from_all_ranks(local_plan, world_size, use_collectives)
    global_plans = planner.build_global_plan(all_local_plans)
    global_plans = storage_reader.optimize_global_plan(global_plans)

    # Select central plan for current rank: per-rank entry when collectives gathered
    # one plan per rank, otherwise the single (rank-local) global plan.
    if use_collectives and world_size > 1 and global_plans:
        central_plan = global_plans[rank]
    elif global_plans:
        central_plan = global_plans[0]
    else:
        central_plan = local_plan

    # Finalize plan
    final_plan = planner.finalize_plan(central_plan)

    # Execute read (fills state_dict in place via the planner)
    storage_reader.execute_read(final_plan, planner)