Coverage for hyper_parallel / core / checkpoint / api.py: 87%

83 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Hyper Parallel Checkpoint API""" 

16from pathlib import Path 

17from typing import Any, Optional, Union 

18 

19from hyper_parallel.core.checkpoint.standard_planner import StandardSavePlanner, StandardLoadPlanner 

20from hyper_parallel.core.checkpoint.filesystem_storage import FileSystemReader, FileSystemWriter 

21from hyper_parallel.core.checkpoint.metadata import Metadata 

22from hyper_parallel.core.checkpoint.planner import SavePlanner, LoadPlanner 

23from hyper_parallel.core.checkpoint.storage import StorageReader, StorageWriter 

24from hyper_parallel.platform import get_platform 

25from hyper_parallel.platform.platform import Platform 

26 

27 

def _gather_from_all_ranks(
    platform: Platform,
    local_object: Any,
    world_size: int,
    use_collectives: bool,
) -> list[Any]:
    """
    Collect one object per rank into a single list.

    Args:
        platform (Platform): Communication backend used for the gather.
        local_object (Any): This rank's contribution.
        world_size (int): Number of participating ranks.
        use_collectives (bool): Whether cross-rank communication is allowed.

    Returns:
        list[Any]: Objects from every rank, or a singleton list holding only
        the local object when no gather is performed.
    """
    # Without collectives (or with a single rank) there is nothing to gather.
    if not use_collectives or world_size <= 1:
        return [local_object]
    gathered: list[Any] = [None] * world_size
    platform.all_gather_object(gathered, local_object)
    return gathered

51 

52 

def save(
    state_dict: dict[str, Any],
    *,
    checkpoint_id: Optional[Union[Path, str]] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
    remove_redundancy: bool = True,
    save_to_minimum_rank: bool = False,
) -> Metadata:
    """
    Save a distributed checkpoint in SPMD style.

    Every rank participates: for state_dicts containing DTensors, each rank
    writes only its local shards.

    Args:
        state_dict (dict[str, Any]): The state_dict to save.
        checkpoint_id (Optional[Union[Path, str]]): The ID/path of this checkpoint instance (can be Path or str).
            Default None.
        storage_writer (Optional[StorageWriter]): Instance of StorageWriter. If None, FileSystemWriter
            will be created based on checkpoint_id. Default None.
        planner (Optional[SavePlanner]): Instance of SavePlanner. If None, StandardSavePlanner will be used.
            Default None.
        no_dist (bool): If True, save in single process mode. Default False.
        use_collectives (bool): If True, use collective communication for coordination.
            If False, each rank saves its own shard data and rank-local metadata (.metadata_rank{rank}),
            with no cross-rank interaction. Default True.
        remove_redundancy (bool): If True, deduplicate tensors across ranks. Default True.
        save_to_minimum_rank (bool): If True, deduplicated items are saved on the minimum rank. Default False.

    Returns:
        Metadata: Metadata object for the saved checkpoint.

    Raises:
        ValueError: If neither storage_writer nor checkpoint_id is provided.
    """
    platform = get_platform()

    # Accept str checkpoint ids by normalizing them to Path.
    if isinstance(checkpoint_id, str):
        checkpoint_id = Path(checkpoint_id)

    # no_dist forces single-process behaviour; rank-local saving cannot
    # deduplicate across ranks, so redundancy removal is disabled with it.
    if no_dist:
        use_collectives = False
    if not use_collectives:
        remove_redundancy = False

    # Resolve the storage writer: build a filesystem writer from the
    # checkpoint id, or (re)point a caller-supplied writer at it.
    if storage_writer is None:
        if checkpoint_id is None:
            raise ValueError("Either storage_writer or checkpoint_id must be provided")
        storage_writer = FileSystemWriter(checkpoint_id)
    elif checkpoint_id:
        storage_writer.initialize_writer(checkpoint_id)

    # Resolve the save planner.
    if planner is None:
        planner = StandardSavePlanner()

    # Rank 0 acts as the coordinator.
    rank = platform.get_rank()
    world_size = platform.get_world_size()
    is_coordinator = rank == 0

    planner.configure_planner(
        state_dict=state_dict,
        is_coordinator=is_coordinator,
        rank=rank,
        remove_redundancy=remove_redundancy,
        save_to_minimum_rank=save_to_minimum_rank
    )
    storage_writer.configure_writer(
        is_coordinator=is_coordinator,
        rank=rank,
        use_collectives=use_collectives
    )

    # Plan locally, then merge the per-rank plans into a global plan.
    local_plan = storage_writer.optimize_local_plan(planner.build_local_plan())
    all_local_plans = _gather_from_all_ranks(platform, local_plan, world_size, use_collectives)
    global_plans, metadata = planner.build_global_plan(all_local_plans)
    global_plans = storage_writer.optimize_global_plan(global_plans)

    # Pick this rank's slice of the global plan; outside collective mode the
    # single entry at index 0 is used, and an empty result falls back to the
    # local plan.
    if global_plans:
        plan_index = rank if use_collectives and world_size > 1 else 0
        central_plan = global_plans[plan_index]
    else:
        central_plan = local_plan

    # Write the shard data, then commit the checkpoint metadata.
    final_plan = planner.finalize_plan(central_plan)
    write_results = storage_writer.execute_write(final_plan, planner)
    all_write_results = _gather_from_all_ranks(platform, write_results, world_size, use_collectives)
    storage_writer.finalize_checkpoint(metadata, all_write_results)

    return metadata

161 

162 

def load(
    state_dict: dict[str, Any],
    *,
    checkpoint_id: Optional[Union[Path, str]] = None,
    storage_reader: Optional[StorageReader] = None,
    planner: Optional[LoadPlanner] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
) -> None:
    """
    Load a distributed checkpoint into state_dict in SPMD style.

    Each rank reads only the minimum data needed to fill its requested
    state_dict; for DTensor entries that means only the local shards.

    Args:
        state_dict (dict[str, Any]): The state_dict to load the checkpoint into (modified in-place).
        checkpoint_id (Optional[Union[Path, str]]): The ID/path of this checkpoint instance (can be Path or str).
            Default None.
        storage_reader (Optional[StorageReader]): Instance of StorageReader. If None, FileSystemReader
            will be created based on checkpoint_id. Default None.
        planner (Optional[LoadPlanner]): Instance of LoadPlanner. If None, StandardLoadPlanner will be used.
            Default None.
        no_dist (bool): If True, load without cross-rank synchronization. Default False.
        use_collectives (bool): If False, load from rank-local metadata (.metadata_rank{rank}),
            for checkpoints saved with save(use_collectives=False). No cross-rank interaction. Default True.

    Returns:
        None. The state_dict is modified in-place.

    Raises:
        ValueError: If neither storage_reader nor checkpoint_id is provided.
    """
    platform = get_platform()

    # Accept str checkpoint ids by normalizing them to Path.
    if isinstance(checkpoint_id, str):
        checkpoint_id = Path(checkpoint_id)

    # no_dist forces single-process behaviour.
    if no_dist:
        use_collectives = False

    # Resolve the storage reader: build a filesystem reader from the
    # checkpoint id, or (re)point a caller-supplied reader at it.
    if storage_reader is None:
        if checkpoint_id is None:
            raise ValueError("Either storage_reader or checkpoint_id must be provided")
        storage_reader = FileSystemReader(checkpoint_id)
    elif checkpoint_id:
        storage_reader.initialize_reader(checkpoint_id)

    # Resolve the load planner.
    if planner is None:
        planner = StandardLoadPlanner()

    # Rank 0 acts as the coordinator.
    rank = platform.get_rank()
    world_size = platform.get_world_size()
    is_coordinator = rank == 0

    # Load checkpoint metadata: rank-local directly when collectives are off,
    # otherwise global metadata with a rank-local fallback (for checkpoints
    # written by save(use_collectives=False)).
    if not use_collectives:
        metadata = storage_reader.load_metadata(rank=rank)
    else:
        try:
            metadata = storage_reader.load_metadata()
        except FileNotFoundError:
            metadata = storage_reader.load_metadata(rank=rank)
            use_collectives = False

    planner.configure_planner(
        state_dict=state_dict,
        metadata=metadata,
        is_coordinator=is_coordinator,
        rank=rank
    )
    storage_reader.configure_reader(
        metadata=metadata,
        is_coordinator=is_coordinator,
        rank=rank,
        use_collectives=use_collectives
    )

    # Plan locally, then merge the per-rank plans into a global plan.
    local_plan = storage_reader.optimize_local_plan(planner.build_local_plan())
    all_local_plans = _gather_from_all_ranks(platform, local_plan, world_size, use_collectives)
    global_plans = planner.build_global_plan(all_local_plans)
    global_plans = storage_reader.optimize_global_plan(global_plans)

    # Pick this rank's slice of the global plan; outside collective mode the
    # single entry at index 0 is used, and an empty result falls back to the
    # local plan.
    if global_plans:
        plan_index = rank if use_collectives and world_size > 1 else 0
        central_plan = global_plans[plan_index]
    else:
        central_plan = local_plan

    # Read the shard data into state_dict.
    final_plan = planner.finalize_plan(central_plan)
    storage_reader.execute_read(final_plan, planner)