Coverage for hyper_parallel / core / checkpoint / api.py: 87%
83 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Hyper Parallel Checkpoint API"""
16from pathlib import Path
17from typing import Any, Optional, Union
19from hyper_parallel.core.checkpoint.standard_planner import StandardSavePlanner, StandardLoadPlanner
20from hyper_parallel.core.checkpoint.filesystem_storage import FileSystemReader, FileSystemWriter
21from hyper_parallel.core.checkpoint.metadata import Metadata
22from hyper_parallel.core.checkpoint.planner import SavePlanner, LoadPlanner
23from hyper_parallel.core.checkpoint.storage import StorageReader, StorageWriter
24from hyper_parallel.platform import get_platform
25from hyper_parallel.platform.platform import Platform
def _gather_from_all_ranks(
    platform: Platform,
    local_object: Any,
    world_size: int,
    use_collectives: bool,
) -> list[Any]:
    """
    Collect one object per rank into a single list.

    Args:
        platform (Platform): Communication backend used for the gather.
        local_object (Any): This rank's contribution.
        world_size (int): Total number of participating ranks.
        use_collectives (bool): Whether cross-rank communication is allowed.

    Returns:
        list[Any]: Objects from every rank when gathering, otherwise a
        single-element list holding only the local object.
    """
    # Without collectives (or with a single rank) there is nothing to gather.
    if not use_collectives or world_size <= 1:
        return [local_object]
    gathered: list[Any] = [None] * world_size
    platform.all_gather_object(gathered, local_object)
    return gathered
def save(
    state_dict: dict[str, Any],
    *,
    checkpoint_id: Optional[Union[Path, str]] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
    remove_redundancy: bool = True,
    save_to_minimum_rank: bool = False,
) -> Metadata:
    """
    Save a distributed checkpoint in SPMD style.

    Each rank writes only its local shards of the DTensors contained in
    ``state_dict``.

    Args:
        state_dict (dict[str, Any]): The state_dict to save.
        checkpoint_id (Optional[Union[Path, str]]): ID/path of this checkpoint
            instance; a str is converted to Path. Default None.
        storage_writer (Optional[StorageWriter]): Writer instance. When None, a
            FileSystemWriter is built from checkpoint_id. Default None.
        planner (Optional[SavePlanner]): Planner instance. When None, a
            StandardSavePlanner is used. Default None.
        no_dist (bool): If True, save in single process mode. Default False.
        use_collectives (bool): If True, coordinate ranks via collective
            communication. If False, every rank saves its own shard data plus
            rank-local metadata (.metadata_rank{rank}) with no cross-rank
            interaction. Default True.
        remove_redundancy (bool): If True, deduplicate tensors across ranks.
            Ignored (forced False) when use_collectives is False. Default True.
        save_to_minimum_rank (bool): If True, deduplicated items are saved on
            the minimum rank. Default False.

    Returns:
        Metadata: Metadata object for the saved checkpoint.

    Raises:
        ValueError: If neither storage_writer nor checkpoint_id is given.
    """
    platform = get_platform()

    # Normalize checkpoint_id: accept either str or Path.
    if isinstance(checkpoint_id, str):
        checkpoint_id = Path(checkpoint_id)

    # Single-process mode implies no collective coordination.
    if no_dist:
        use_collectives = False

    # Deduplication needs a cross-rank view, which is unavailable without
    # collectives.
    if not use_collectives:
        remove_redundancy = False

    # Resolve the storage writer.
    if storage_writer is None:
        if checkpoint_id is None:
            raise ValueError("Either storage_writer or checkpoint_id must be provided")
        storage_writer = FileSystemWriter(checkpoint_id)
    elif checkpoint_id:
        storage_writer.initialize_writer(checkpoint_id)

    # Resolve the planner.
    if planner is None:
        planner = StandardSavePlanner()

    # Rank / coordinator bookkeeping.
    rank = platform.get_rank()
    world_size = platform.get_world_size()
    is_coordinator = rank == 0

    planner.configure_planner(
        state_dict=state_dict,
        is_coordinator=is_coordinator,
        rank=rank,
        remove_redundancy=remove_redundancy,
        save_to_minimum_rank=save_to_minimum_rank,
    )

    # use_collectives=False makes the writer emit rank-local metadata.
    storage_writer.configure_writer(
        is_coordinator=is_coordinator,
        rank=rank,
        use_collectives=use_collectives,
    )

    # Build and locally optimize this rank's plan.
    local_plan = storage_writer.optimize_local_plan(planner.build_local_plan())

    # Merge all local plans into the global plan.
    all_local_plans = _gather_from_all_ranks(platform, local_plan, world_size, use_collectives)
    global_plans, metadata = planner.build_global_plan(all_local_plans)
    global_plans = storage_writer.optimize_global_plan(global_plans)

    # Pick the plan entry that belongs to this rank.
    if not global_plans:
        central_plan = local_plan
    elif use_collectives and world_size > 1:
        central_plan = global_plans[rank]
    else:
        central_plan = global_plans[0]

    final_plan = planner.finalize_plan(central_plan)

    # Write shard data, then finalize with the write results of every rank.
    write_results = storage_writer.execute_write(final_plan, planner)
    all_write_results = _gather_from_all_ranks(platform, write_results, world_size, use_collectives)
    storage_writer.finalize_checkpoint(metadata, all_write_results)

    return metadata
def load(
    state_dict: dict[str, Any],
    *,
    checkpoint_id: Optional[Union[Path, str]] = None,
    storage_reader: Optional[StorageReader] = None,
    planner: Optional[LoadPlanner] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
) -> None:
    """
    Load a distributed checkpoint into ``state_dict`` in SPMD style.

    Each rank reads only the data required to fill its part of the requested
    state_dict; DTensor entries are loaded shard-by-shard.

    Args:
        state_dict (dict[str, Any]): The state_dict to load the checkpoint
            into (modified in-place).
        checkpoint_id (Optional[Union[Path, str]]): ID/path of this checkpoint
            instance; a str is converted to Path. Default None.
        storage_reader (Optional[StorageReader]): Reader instance. When None, a
            FileSystemReader is built from checkpoint_id. Default None.
        planner (Optional[LoadPlanner]): Planner instance. When None, a
            StandardLoadPlanner is used. Default None.
        no_dist (bool): If True, load without cross-rank synchronization.
            Default False.
        use_collectives (bool): If False, read rank-local metadata
            (.metadata_rank{rank}) as written by save(use_collectives=False),
            with no cross-rank interaction. Default True.

    Returns:
        None. The state_dict is modified in-place.

    Raises:
        ValueError: If neither storage_reader nor checkpoint_id is given.
    """
    platform = get_platform()

    # Normalize checkpoint_id: accept either str or Path.
    if isinstance(checkpoint_id, str):
        checkpoint_id = Path(checkpoint_id)

    # Single-process mode implies no collective coordination.
    if no_dist:
        use_collectives = False

    # Resolve the storage reader.
    if storage_reader is None:
        if checkpoint_id is None:
            raise ValueError("Either storage_reader or checkpoint_id must be provided")
        storage_reader = FileSystemReader(checkpoint_id)
    elif checkpoint_id:
        storage_reader.initialize_reader(checkpoint_id)

    # Resolve the planner.
    if planner is None:
        planner = StandardLoadPlanner()

    # Rank / coordinator bookkeeping.
    rank = platform.get_rank()
    world_size = platform.get_world_size()
    is_coordinator = rank == 0

    # Load checkpoint metadata.
    if not use_collectives:
        # Rank-local metadata only; no cross-rank interaction.
        metadata = storage_reader.load_metadata(rank=rank)
    else:
        try:
            metadata = storage_reader.load_metadata()
        except FileNotFoundError:
            # Global metadata is missing: assume the checkpoint was written
            # with use_collectives=False and fall back to rank-local metadata.
            metadata = storage_reader.load_metadata(rank=rank)
            use_collectives = False

    planner.configure_planner(
        state_dict=state_dict,
        metadata=metadata,
        is_coordinator=is_coordinator,
        rank=rank,
    )

    storage_reader.configure_reader(
        metadata=metadata,
        is_coordinator=is_coordinator,
        rank=rank,
        use_collectives=use_collectives,
    )

    # Build and locally optimize this rank's plan.
    local_plan = storage_reader.optimize_local_plan(planner.build_local_plan())

    # Merge all local plans into the global plan.
    all_local_plans = _gather_from_all_ranks(platform, local_plan, world_size, use_collectives)
    global_plans = planner.build_global_plan(all_local_plans)
    global_plans = storage_reader.optimize_global_plan(global_plans)

    # Pick the plan entry that belongs to this rank.
    if not global_plans:
        central_plan = local_plan
    elif use_collectives and world_size > 1:
        central_plan = global_plans[rank]
    else:
        central_plan = global_plans[0]

    final_plan = planner.finalize_plan(central_plan)

    # Read shard data into state_dict.
    storage_reader.execute_read(final_plan, planner)