Coverage for hyper_parallel / core / checkpoint / filesystem_storage.py: 87%

172 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""File system storage implementations for checkpoint save and load.""" 

16import os 

17import pickle 

18from pathlib import Path 

19from typing import Any, Optional, Union 

20 

21from hyper_parallel.core.checkpoint.metadata import Metadata, MetadataIndex 

22from hyper_parallel.core.checkpoint.planner import ( 

23 LoadItemType, 

24 LoadPlan, 

25 LoadPlanner, 

26 ReadItem, 

27 SavePlan, 

28 SavePlanner, 

29 WriteItem, 

30) 

31from hyper_parallel.core.checkpoint.storage import ( 

32 StorageInfo, 

33 StorageReader, 

34 StorageWriter, 

35 WriteResult, 

36 _metadata_file_name, 

37) 

38from hyper_parallel.core.checkpoint.util import narrow_tensor_by_index 

39from hyper_parallel.platform import get_platform 

40 

41 

class FileSystemWriter(StorageWriter):
    """
    File system storage writer implementation.

    Saves checkpoint data to the local file system, organizing tensors
    into safetensors files and bytes into separate files.
    """

    def __init__(self, checkpoint_dir: Union[Path, str]):
        # Normalize to Path and make sure the directory exists up front.
        self.checkpoint_dir = Path(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.rank: int = 0
        self.is_coordinator: bool = False
        self.use_collectives: bool = True

    def initialize_writer(self, checkpoint_id: Optional[Union[Path, str]] = None) -> None:
        """
        Initialize storage writer with new checkpoint directory.

        Args:
            checkpoint_id (Optional[Union[Path, str]]): New checkpoint directory path. Default None.
        """
        if checkpoint_id:
            self.checkpoint_dir = Path(checkpoint_id) if isinstance(checkpoint_id, str) else checkpoint_id
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def configure_writer(self, is_coordinator: bool, **kwargs) -> None:
        """
        Configure storage writer.

        Args:
            is_coordinator (bool): Whether this rank is the coordinator.
            **kwargs: Additional keyword arguments (e.g., rank, use_collectives).
        """
        self.is_coordinator = is_coordinator
        self.rank = kwargs.get("rank", get_platform().get_rank())
        self.use_collectives = kwargs.get("use_collectives", True)

    def optimize_local_plan(self, plan: SavePlan) -> SavePlan:
        """
        Optimize local plan.

        Args:
            plan (SavePlan): Local save plan.

        Returns:
            SavePlan: Optimized local plan (currently returned unchanged).
        """
        return plan

    def optimize_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
        """
        Optimize global plan.

        Args:
            plans (list[SavePlan]): List of local plans from all ranks.

        Returns:
            list[SavePlan]: Optimized global plans (currently returned unchanged).
        """
        return plans

    def _write_bytes_item(self, item: WriteItem) -> WriteResult:
        """
        Write a single bytes item to storage.

        Raw ``bytes`` payloads are written verbatim; any other object is
        serialized with pickle. The file is named ``{fqn}_rank{rank}.bytes``.

        Args:
            item (WriteItem): WriteItem containing bytes data.

        Returns:
            WriteResult: Write result with storage metadata.
        """
        fqn = item.index.fqn
        file_name = f"{fqn}_rank{self.rank}.bytes"
        file_path = self.checkpoint_dir / file_name
        with open(file_path, "wb") as f:
            if isinstance(item.bytes_io_data, bytes):
                f.write(item.bytes_io_data)
            else:
                pickle.dump(item.bytes_io_data, f)
            try:
                length = f.tell()
            except (OSError, IOError):
                # Best effort: fall back to 0 if the stream cannot report
                # its position; length is informational metadata only.
                length = 0
        storage_info = StorageInfo(
            relative_path=file_name,
            offset=0,
            length=length,
        )
        return WriteResult(
            index=item.index,
            storage_data=storage_info,
        )

    def _collect_tensors(self, plan: SavePlan, planner: SavePlanner) -> dict[str, Any]:
        """
        Collect tensor data from planner cache.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            planner (SavePlanner): Save planner.

        Returns:
            dict[str, Any]: Dictionary mapping FQN to tensor data.

        Raises:
            RuntimeError: If tensor data not found in planner cache.
        """
        tensor_dict: dict[str, Any] = {}
        for item in plan.items:
            if item.type.value == "tensor" and item.tensor_data:
                # Get tensor from planner cache instead of tensor_data
                tensor = planner.get_tensor(item.index)
                if tensor is None:
                    raise RuntimeError(
                        f"Tensor data not found in planner cache for index {item.index}. "
                        f"FQN: {item.index.fqn}"
                    )
                fqn = item.index.fqn
                tensor_dict[fqn] = tensor
        return tensor_dict

    def _write_tensors(self, plan: SavePlan, tensor_dict: dict[str, Any]) -> list[WriteResult]:
        """
        Write all tensors to safetensors file and create WriteResults.

        All tensors for this rank go into a single ``_rank{rank}_.safetensors``
        file written via the platform backend.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            tensor_dict (dict[str, Any]): Dictionary mapping FQN to tensor data.

        Returns:
            list[WriteResult]: List of write results for tensor items.
        """
        if not tensor_dict:
            return []

        platform = get_platform()
        file_name = f"_rank{self.rank}_.safetensors"
        file_path = self.checkpoint_dir / file_name
        platform.save_checkpoint(tensor_dict, str(file_path))

        # Record StorageInfo for each tensor
        # Note: we don't know per-tensor byte offsets, so offset=0, length=-1
        results: list[WriteResult] = []
        for item in plan.items:
            if item.type.value == "tensor" and item.tensor_data:
                storage_info = StorageInfo(
                    relative_path=file_name,
                    offset=0,
                    length=-1,
                )
                results.append(
                    WriteResult(
                        index=item.index,
                        storage_data=storage_info,
                    )
                )
        return results

    def execute_write(self, plan: SavePlan, planner: SavePlanner) -> list[WriteResult]:
        """
        Write data to storage and return per-item storage metadata.

        Group tensors into safetensors files and bytes into separate files, recording StorageInfo for each item.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            planner (SavePlanner): Save planner.

        Returns:
            list[WriteResult]: List of write results with storage metadata.
        """
        results: list[WriteResult] = []

        # Write bytes objects first (one file per item)
        for item in plan.items:
            if item.type.value == "byte_io":
                results.append(self._write_bytes_item(item))

        # Collect and write tensors (one safetensors file per rank)
        tensor_dict = self._collect_tensors(plan, planner)
        results.extend(self._write_tensors(plan, tensor_dict))

        return results

    def finalize_checkpoint(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
        """
        Finish writing checkpoint and populate metadata.storage_data.

        When use_collectives=True: only coordinator saves global metadata to .metadata.
        When use_collectives=False: each rank saves its own metadata to .rank{rank}_metadata,
        no cross-rank interaction.

        Args:
            metadata (Metadata): Checkpoint metadata to update.
            results (list[list[WriteResult]]): Write results from all ranks (or single rank when use_collectives=False).
        """
        # Save when this rank coordinates a collective checkpoint, or
        # unconditionally when each rank writes independently. Written with
        # explicit parentheses instead of the precedence-dependent
        # `a and b or not a` form.
        should_save = (not self.use_collectives) or self.is_coordinator

        if should_save:
            # Build storage_data: map MetadataIndex -> StorageInfo
            storage_md: dict[MetadataIndex, StorageInfo] = {}
            for wr_list in results:
                for wr in wr_list:
                    storage_md[wr.index] = wr.storage_data
            metadata.storage_data = storage_md

            # Save metadata file
            if self.use_collectives:
                metadata_file = self.checkpoint_dir / _metadata_file_name
            else:
                metadata_file = self.checkpoint_dir / f".rank{self.rank}_metadata"
            with open(metadata_file, "wb") as f:
                pickle.dump(metadata, f)

256 

257 

def _copy_tensor_to_target(
    req: ReadItem, tensor: Any, target_tensor: Any, planner: LoadPlanner
) -> None:
    """
    Copy tensor data into the target tensor and commit it via the planner.

    Args:
        req (ReadItem): ReadItem request.
        tensor (Any): Source tensor (tensor-like object).
        target_tensor (Any): Target tensor (tensor-like object).
        planner (LoadPlanner): Load planner for committing.
    """
    supports_inplace_copy = hasattr(target_tensor, "copy_")
    if not supports_inplace_copy:
        # mindspore or non-tensor: copy via commit path
        planner.apply_tensor(req, tensor)
        return
    target_tensor.copy_(tensor)
    planner.apply_tensor(req, target_tensor)

276 

277 

def _load_bytes_file(path: str, reqs: list[ReadItem], planner: LoadPlanner) -> None:
    """
    Load bytes from a file.

    The file content is read once and shared across all requests, instead of
    re-opening and re-reading the same file for every ReadItem.

    Args:
        path (str): Path to the bytes file.
        reqs (list[ReadItem]): List of ReadItems for this file.
        planner (LoadPlanner): Load planner for loading bytes.
    """
    with open(path, "rb") as f:
        value = f.read()
    for req in reqs:
        planner.apply_bytes(req, value)

291 

292 

293def _get_tensor_size(tensor: Any) -> Optional[tuple]: 

294 """ 

295 Get size/shape of a tensor. 

296 

297 Args: 

298 tensor (Any): Tensor object (tensor-like with shape/size attribute). 

299 

300 Returns: 

301 Optional[tuple]: Tuple of tensor size or None if not available. 

302 """ 

303 if hasattr(tensor, "size") and callable(tensor.size): 

304 return tuple(tensor.size()) 

305 return getattr(tensor, "shape", None) 

306 

307 

def _load_tensor_file(
    path: str, reqs: list[ReadItem], planner: LoadPlanner
) -> None:
    """
    Load and process tensors from a safetensors file.

    Args:
        path (str): Path to the safetensors file.
        reqs (list[ReadItem]): List of ReadItems for this file.
        planner (LoadPlanner): Load planner for resolving and committing tensors.
    """
    checkpoint = get_platform().load_checkpoint(path)

    for request in reqs:
        key = request.storage_index.fqn
        if key not in checkpoint:
            raise KeyError(f"Key {key} not found in checkpoint file {path}")

        # Slice the shard this request covers (resharding support).
        shard = narrow_tensor_by_index(
            checkpoint[key],
            request.storage_offsets,
            request.lengths,
        )
        destination = planner.acquire_tensor(request)
        if hasattr(destination, "detach"):
            destination = destination.detach()

        # Size check (torch-aligned AssertionError)
        expected = _get_tensor_size(destination)
        actual = _get_tensor_size(shard)
        if expected is not None and actual is not None and expected != actual:
            raise AssertionError(
                f"req {request.storage_index} mismatch sizes "
                f"{expected} vs {actual}"
            )

        # Copy data to target and commit through the planner
        _copy_tensor_to_target(request, shard, destination, planner)

350 

351 

class FileSystemReader(StorageReader):
    """
    File system storage reader implementation.

    Reads checkpoint data from the local file system, loading tensors
    from safetensors files and bytes from separate files.
    """

    def __init__(self, checkpoint_dir: Union[Path, str]):
        self.checkpoint_dir = Path(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
        # Cached storage layout: MetadataIndex -> StorageInfo (torch-aligned)
        self.storage_data: Optional[dict[MetadataIndex, StorageInfo]] = None
        self.rank: int = 0
        self.is_coordinator: bool = False

    def initialize_reader(self, checkpoint_id: Optional[Union[Path, str]] = None) -> None:
        """
        Initialize storage reader with new checkpoint directory.

        Args:
            checkpoint_id (Optional[Union[Path, str]]): New checkpoint directory path. Default None.
        """
        if checkpoint_id:
            self.checkpoint_dir = Path(checkpoint_id) if isinstance(checkpoint_id, str) else checkpoint_id

    def load_metadata(self, **kwargs) -> Metadata:
        """
        Load checkpoint metadata from file.

        When rank is provided in kwargs: load rank-local metadata from .rank{rank}_metadata
        (for checkpoints saved with use_collectives=False).
        Otherwise: load global metadata from .metadata.

        Args:
            **kwargs: Optional arguments (e.g., rank for rank-local metadata).

        Returns:
            Metadata: Metadata object loaded from file.

        Raises:
            FileNotFoundError: If the metadata file does not exist.
        """
        rank = kwargs.get("rank")
        if rank is not None:
            metadata_file = self.checkpoint_dir / f".rank{rank}_metadata"
        else:
            metadata_file = self.checkpoint_dir / _metadata_file_name

        if not metadata_file.exists():
            raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # checkpoint directories must come from a trusted source only.
        with open(metadata_file, "rb") as f:
            metadata = pickle.load(f)
        return metadata

    def configure_reader(self, metadata: Metadata, is_coordinator: bool, **kwargs) -> None:
        """Configure storage reader."""
        # Cache storage_data separately for quick lookup in execute_read.
        # This mirrors torch.filesystem, where reader keeps a storage_data dict.
        self.storage_data = getattr(metadata, "storage_data", None)
        self.is_coordinator = is_coordinator
        self.rank = kwargs.get("rank", get_platform().get_rank())

    def optimize_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Optimize local plan.

        Args:
            plan (LoadPlan): Local load plan.

        Returns:
            LoadPlan: Optimized local plan (currently returned unchanged).
        """
        return plan

    def optimize_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
        """
        Optimize global plan.

        Args:
            plans (list[LoadPlan]): List of local plans from all ranks.

        Returns:
            list[LoadPlan]: Optimized global plans (currently returned unchanged).
        """
        return plans

    def _get_storage_path(self, read_item: ReadItem) -> str:
        """
        Get storage file path for a read item.

        Args:
            read_item (ReadItem): ReadItem to get path for.

        Returns:
            str: Absolute path to the storage file.

        Raises:
            KeyError: If storage_data is present but has no entry for the item.
        """
        storage_data = self.storage_data

        if storage_data is not None:
            storage_info = storage_data.get(read_item.storage_index)
            if storage_info is None:
                raise KeyError(f"StorageInfo not found for index {read_item.storage_index}")
            return str(self.checkpoint_dir / storage_info.relative_path)
        # Fallback: derive path from rank & fqn (legacy format without storage_data)
        # Explicit None check: an index of 0 is a valid value and must not be
        # discarded the way the falsy-zero `index or self.rank` pattern would.
        index = read_item.storage_index.index
        rank = self.rank if index is None else index
        if read_item.type == LoadItemType.TENSOR:
            return str(self.checkpoint_dir / f"_rank{rank}_.safetensors")
        fqn = read_item.storage_index.fqn
        return str(self.checkpoint_dir / f"{fqn}_rank{rank}.bytes")

    def _group_items_by_file(self, plan: LoadPlan) -> dict[str, list]:
        """
        Group ReadItems by storage file path.

        Args:
            plan (LoadPlan): Load plan containing ReadItems.

        Returns:
            dict[str, list[ReadItem]]: Dictionary mapping file paths to lists of ReadItems.
        """
        per_file: dict[str, list] = {}
        for read_item in plan.items:
            path = self._get_storage_path(read_item)
            per_file.setdefault(path, []).append(read_item)
        return per_file

    def execute_read(self, plan: LoadPlan, planner: LoadPlanner) -> None:
        """
        Read data from storage.

        Aligned with torch filesystem read_data: groups ReadItems by file,
        loads each file once, narrows tensors by storage_offsets/lengths for
        resharding, then resolves/copies/commits data.

        Args:
            plan (LoadPlan): Load plan containing ReadItems.
            planner (LoadPlanner): Load planner for resolving and committing tensors.

        Raises:
            FileNotFoundError: If a referenced checkpoint file is missing.
        """
        # Group ReadItems by storage file path (like torch per_file)
        per_file = self._group_items_by_file(plan)

        # Process each file
        for path, reqs in per_file.items():
            if not os.path.exists(path):
                raise FileNotFoundError(f"Checkpoint file not found: {path}")

            if path.endswith(".bytes"):
                # BYTE_IO: one file per (fqn, rank)
                _load_bytes_file(path, reqs, planner)
            else:
                # TENSOR: one safetensors file per rank
                _load_tensor_file(path, reqs, planner)