Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/distributed_checkpoint/filesystem_storage.py: 62%

197 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""File system storage implementations for checkpoint save and load."""
import os
import pickle
from pathlib import Path
from typing import Any, Optional, Union

from safetensors import safe_open

from hyper_parallel.core.distributed_checkpoint.metadata import Metadata, MetadataIndex
from hyper_parallel.core.distributed_checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
)
from hyper_parallel.core.distributed_checkpoint.storage import (
    StorageInfo,
    StorageReader,
    StorageWriter,
    WriteResult,
    METADATA_FILE_NAME,
)
from hyper_parallel.core.distributed_checkpoint.util import narrow_tensor_by_index
from hyper_parallel.platform import get_platform
from hyper_parallel.platform.platform import PlatformType

class FileSystemWriter(StorageWriter):
    """
    File system storage writer implementation.

    Saves checkpoint data to the local file system, organizing tensors
    into safetensors files and bytes into separate files.
    """


    def __init__(self, checkpoint_dir: Union[Path, str]):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.rank: int = 0
        self.is_coordinator: bool = False
        self.use_collectives: bool = True

    def initialize_writer(self, checkpoint_id: Optional[Union[Path, str]] = None) -> None:
        """
        Initialize the storage writer with a new checkpoint directory.

        Args:
            checkpoint_id (Optional[Union[Path, str]]): New checkpoint directory path. Default: None.
        """
        if checkpoint_id:
            self.checkpoint_dir = Path(checkpoint_id)
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def configure_writer(self, is_coordinator: bool, **kwargs) -> None:
        """
        Configure the storage writer.

        Args:
            is_coordinator (bool): Whether this rank is the coordinator.
            **kwargs: Additional keyword arguments (e.g. ``rank``, ``use_collectives``).
        """
        self.is_coordinator = is_coordinator
        self.rank = kwargs["rank"] if "rank" in kwargs else get_platform().get_rank()
        self.use_collectives = kwargs.get("use_collectives", True)

    def optimize_local_plan(self, plan: SavePlan) -> SavePlan:
        """
        Optimize the local save plan.

        Args:
            plan (SavePlan): Local save plan.

        Returns:
            SavePlan: Optimized local plan.
        """
        return plan

    def optimize_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
        """
        Optimize the global save plan.

        Args:
            plans (list[SavePlan]): List of local plans from all ranks.

        Returns:
            list[SavePlan]: Optimized global plans.
        """
        return plans

    def _serialize_bytes_item(self, item: WriteItem, planner: SavePlanner) -> bytes:
        """Serialize a BYTE_IO item payload, passing raw bytes through unchanged."""
        data = planner.get_data(item)
        if isinstance(data, bytes):
            return data
        return pickle.dumps(data)

    def _write_bytes_items(self, plan: SavePlan, planner: SavePlanner) -> list[WriteResult]:
        """
        Write all BYTE_IO items into one per-rank bytes file.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            planner (SavePlanner): Save planner used to resolve runtime data.

        Returns:
            list[WriteResult]: Write results for BYTE_IO items.
        """

        byte_items = [item for item in plan.items if item.type.value == "byte_io"]
        if not byte_items:
            return []

        file_name = f"_rank{self.rank}_.bytes"
        file_path = self.checkpoint_dir / file_name

        results: list[WriteResult] = []

        with open(file_path, "wb") as f:
            for item in byte_items:
                payload = self._serialize_bytes_item(item, planner)
                offset = f.tell()
                f.write(payload)
                length = len(payload)
                storage_info = StorageInfo(
                    relative_path=file_name,
                    offset=offset,
                    length=length,
                )
                results.append(
                    WriteResult(
                        index=item.index,
                        storage_data=storage_info,
                    )
                )

        return results

    def _collect_tensors(self, plan: SavePlan, planner: SavePlanner) -> dict[str, Any]:
        """
        Collect tensor data via the planner's runtime lookup.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            planner (SavePlanner): Save planner.

        Returns:
            dict[str, Any]: Dictionary mapping FQN to tensor data.

        Raises:
            RuntimeError: If tensor data cannot be resolved for an item.
        """
        tensor_dict: dict[str, Any] = {}
        for item in plan.items:
            if item.type.value == "tensor" and item.tensor_data:
                tensor = planner.get_data(item)
                if tensor is None:
                    raise RuntimeError(
                        f"Tensor data could not be resolved for index {item.index}. "
                        f"FQN: {item.index.fqn}"
                    )
                tensor_dict[item.index.fqn] = tensor
        return tensor_dict

    def _write_tensors(self, plan: SavePlan, tensor_dict: dict[str, Any]) -> list[WriteResult]:
        """
        Write all tensors to a safetensors file and create WriteResults.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            tensor_dict (dict[str, Any]): Dictionary mapping FQN to tensor data.

        Returns:
            list[WriteResult]: List of write results for tensor items.
        """
        if not tensor_dict:
            return []

        platform = get_platform()
        file_name = f"_rank{self.rank}_.safetensors"
        file_path = self.checkpoint_dir / file_name
        platform.save_checkpoint(tensor_dict, str(file_path))

        # Record StorageInfo for each tensor.
        # Note: per-tensor byte offsets are unknown here, so offset=0 and length=-1.
        results: list[WriteResult] = []
        for item in plan.items:
            if item.type.value == "tensor" and item.tensor_data:
                storage_info = StorageInfo(
                    relative_path=file_name,
                    offset=0,
                    length=-1,
                )
                results.append(
                    WriteResult(
                        index=item.index,
                        storage_data=storage_info,
                    )
                )
        return results

    def execute_write(self, plan: SavePlan, planner: SavePlanner) -> list[WriteResult]:
        """
        Write data to storage and return per-item storage metadata.

        Groups tensors into a safetensors file and bytes into a separate file,
        recording StorageInfo for each item.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            planner (SavePlanner): Save planner.

        Returns:
            list[WriteResult]: List of write results with storage metadata.
        """
        results: list[WriteResult] = []

        # Write all BYTE_IO items into one file per rank.
        results.extend(self._write_bytes_items(plan, planner))

        # Collect and write tensors.
        tensor_dict = self._collect_tensors(plan, planner)
        results.extend(self._write_tensors(plan, tensor_dict))

        return results

    def finalize_checkpoint(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
        """
        Finish writing the checkpoint and populate metadata.storage_data.

        When use_collectives=True, only the coordinator saves global metadata
        to .metadata. When use_collectives=False, each rank saves its own
        metadata to .rank{rank}_metadata with no cross-rank interaction.

        Args:
            metadata (Metadata): Checkpoint metadata to update.
            results (list[list[WriteResult]]): Write results from all ranks
                (or from a single rank when use_collectives=False).
        """

        should_save = (self.use_collectives and self.is_coordinator) or not self.use_collectives

        if should_save:
            # Build storage_data: map MetadataIndex -> StorageInfo
            storage_md: dict[MetadataIndex, StorageInfo] = {}
            for wr_list in results:
                for wr in wr_list:
                    storage_md[wr.index] = wr.storage_data
            metadata.storage_data = storage_md

            # Save the metadata file
            if self.use_collectives:
                metadata_file = self.checkpoint_dir / METADATA_FILE_NAME
            else:
                metadata_file = self.checkpoint_dir / f".rank{self.rank}_metadata"
            with open(metadata_file, "wb") as f:
                pickle.dump(metadata, f)



def _copy_tensor_to_target(
    req: ReadItem, tensor: Any, target_tensor: Any, planner: LoadPlanner
) -> None:
    """
    Copy tensor data into the target tensor and commit it.

    Args:
        req (ReadItem): ReadItem request.
        tensor (Any): Source tensor (tensor-like object).
        target_tensor (Any): Target tensor (tensor-like object).
        planner (LoadPlanner): Load planner used for committing.
    """
    if hasattr(target_tensor, "copy_"):
        target_tensor.copy_(tensor)
        planner.apply_tensor(req, target_tensor)
    else:
        # MindSpore or non-tensor targets: commit the source tensor directly.
        planner.apply_tensor(req, tensor)



def _load_bytes_file(
    path: str,
    reqs: list[ReadItem],
    planner: LoadPlanner,
    storage_data: dict[MetadataIndex, StorageInfo],
) -> None:
    """
    Load bytes items from a per-rank bytes file.

    Args:
        path (str): Path to the bytes file.
        reqs (list[ReadItem]): List of ReadItems for this file.
        planner (LoadPlanner): Load planner for committing bytes.
        storage_data (dict[MetadataIndex, StorageInfo]): Per-item storage layout
            used to locate each payload's offset and length within the file.
    """
    with open(path, "rb") as f:
        for req in reqs:
            storage_info = storage_data.get(req.storage_index)
            if storage_info is None:
                raise KeyError(
                    f"StorageInfo not found for index {req.storage_index}"
                )
            f.seek(storage_info.offset)
            value = f.read(storage_info.length)
            planner.apply_bytes(req, value)



def _get_tensor_size(tensor: Any) -> Optional[tuple]:
    """
    Get the size/shape of a tensor.

    Args:
        tensor (Any): Tensor-like object with a ``size`` method or ``shape`` attribute.

    Returns:
        Optional[tuple]: Tuple of tensor dimensions, or None if unavailable.
    """
    if hasattr(tensor, "size") and callable(tensor.size):
        return tuple(tensor.size())
    return getattr(tensor, "shape", None)



def _load_tensor_file(
    path: str, reqs: list[ReadItem], planner: LoadPlanner
) -> None:
    """
    Load and process tensors from a safetensors file.

    Args:
        path (str): Path to the safetensors file.
        reqs (list[ReadItem]): List of ReadItems for this file.
        planner (LoadPlanner): Load planner for resolving and committing tensors.
    """

    platform = get_platform()

    if platform.platform_type == PlatformType.PYTORCH:
        with safe_open(path, framework="pt", device="cpu") as tensor_file:
            for req in reqs:
                fqn = req.storage_index.fqn
                if fqn not in tensor_file.keys():
                    raise KeyError(f"Key {fqn} not found in checkpoint file {path}")
                tensor_slices = tuple(
                    slice(int(off), int(off) + int(length))
                    for off, length in zip(req.storage_offsets, req.lengths)
                )
                if tensor_slices:
                    tensor = tensor_file.get_slice(fqn)[tensor_slices]
                else:
                    tensor = narrow_tensor_by_index(
                        tensor_file.get_tensor(fqn),
                        req.storage_offsets,
                        req.lengths,
                    )

                target_tensor = planner.acquire_tensor(req)
                if hasattr(target_tensor, "detach"):
                    target_tensor = target_tensor.detach()

                # Size check (torch-aligned AssertionError)
                target_size = _get_tensor_size(target_tensor)
                tensor_size = _get_tensor_size(tensor)
                if target_size is not None and tensor_size is not None:
                    if target_size != tensor_size:
                        raise AssertionError(
                            f"req {req.storage_index} mismatch sizes "
                            f"{target_size} vs {tensor_size}"
                        )

                # Copy data to target
                _copy_tensor_to_target(req, tensor, target_tensor, planner)
        return

    param_dict = platform.load_checkpoint(path)
    for req in reqs:
        fqn = req.storage_index.fqn
        if fqn not in param_dict:
            raise KeyError(f"Key {fqn} not found in checkpoint file {path}")
        full_tensor = param_dict[fqn]
        tensor = narrow_tensor_by_index(
            full_tensor,
            req.storage_offsets,
            req.lengths,
        )

        target_tensor = planner.acquire_tensor(req)
        if hasattr(target_tensor, "detach"):
            target_tensor = target_tensor.detach()

        # Size check (torch-aligned AssertionError)
        target_size = _get_tensor_size(target_tensor)
        tensor_size = _get_tensor_size(tensor)
        if target_size is not None and tensor_size is not None:
            if target_size != tensor_size:
                raise AssertionError(
                    f"req {req.storage_index} mismatch sizes "
                    f"{target_size} vs {tensor_size}"
                )

        # Copy data to target
        _copy_tensor_to_target(req, tensor, target_tensor, planner)



class FileSystemReader(StorageReader):
    """
    File system storage reader implementation.

    Reads checkpoint data from the local file system, loading tensors
    from safetensors files and bytes from separate files.
    """


    def __init__(self, checkpoint_dir: Union[Path, str]):
        self.checkpoint_dir = Path(checkpoint_dir)
        # Cached storage layout: MetadataIndex -> StorageInfo (torch-aligned)
        self.storage_data: Optional[dict[MetadataIndex, StorageInfo]] = None
        self.rank: int = 0
        self.is_coordinator: bool = False

    def initialize_reader(self, checkpoint_id: Optional[Union[Path, str]] = None) -> None:
        """
        Initialize the storage reader with a new checkpoint directory.

        Args:
            checkpoint_id (Optional[Union[Path, str]]): New checkpoint directory path. Default: None.
        """
        if checkpoint_id:
            self.checkpoint_dir = Path(checkpoint_id)


    def load_metadata(self, **kwargs) -> Metadata:
        """
        Load checkpoint metadata from file.

        When rank is provided in kwargs, loads rank-local metadata from
        .rank{rank}_metadata (for checkpoints saved with use_collectives=False);
        otherwise loads global metadata from .metadata.

        Args:
            **kwargs: Optional arguments (e.g. rank for rank-local metadata).

        Returns:
            Metadata: Metadata object loaded from file.
        """

        rank = kwargs.get("rank")
        if rank is not None:
            metadata_file = self.checkpoint_dir / f".rank{rank}_metadata"
        else:
            metadata_file = self.checkpoint_dir / METADATA_FILE_NAME

        if not metadata_file.exists():
            raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
        with open(metadata_file, "rb") as f:
            metadata = pickle.load(f)
        return metadata


    def configure_reader(self, metadata: Metadata, is_coordinator: bool, **kwargs) -> None:
        """Configure the storage reader."""
        # Cache storage_data separately for quick lookup in execute_read.
        # This mirrors torch's filesystem reader, which keeps a storage_data dict.
        self.storage_data = getattr(metadata, "storage_data", None)
        self.is_coordinator = is_coordinator
        self.rank = kwargs["rank"] if "rank" in kwargs else get_platform().get_rank()

    def optimize_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Optimize the local load plan.

        Args:
            plan (LoadPlan): Local load plan.

        Returns:
            LoadPlan: Optimized local plan.
        """
        return plan

    def optimize_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
        """
        Optimize the global load plan.

        Args:
            plans (list[LoadPlan]): List of local plans from all ranks.

        Returns:
            list[LoadPlan]: Optimized global plans.
        """
        return plans


    def _get_storage_path(self, read_item: ReadItem) -> str:
        """
        Get the storage file path for a read item.

        Args:
            read_item (ReadItem): ReadItem to resolve.

        Returns:
            str: Absolute path to the storage file.
        """
        if self.storage_data is None:
            raise KeyError("Checkpoint metadata.storage_data is required for filesystem read")
        storage_info = self.storage_data.get(read_item.storage_index)
        if storage_info is None:
            raise KeyError(f"StorageInfo not found for index {read_item.storage_index}")
        return str(self.checkpoint_dir / storage_info.relative_path)

    def _group_items_by_file(self, plan: LoadPlan) -> dict[str, list]:
        """
        Group ReadItems by storage file path.

        Args:
            plan (LoadPlan): Load plan containing ReadItems.

        Returns:
            dict[str, list[ReadItem]]: Dictionary mapping file paths to lists of ReadItems.
        """
        per_file: dict[str, list] = {}
        for read_item in plan.items:
            path = self._get_storage_path(read_item)
            per_file.setdefault(path, []).append(read_item)
        return per_file


    def execute_read(self, plan: LoadPlan, planner: LoadPlanner) -> None:
        """
        Read data from storage.

        Aligned with torch's filesystem read_data: groups ReadItems by file,
        loads each file once, narrows tensors by storage_offsets/lengths for
        resharding, then resolves, copies, and commits the data.

        Args:
            plan (LoadPlan): Load plan containing ReadItems.
            planner (LoadPlanner): Load planner for resolving and committing tensors.
        """

        # Group ReadItems by storage file path (like torch's per_file)
        per_file = self._group_items_by_file(plan)

        # Process each file
        for path, reqs in per_file.items():
            if not os.path.exists(path):
                raise FileNotFoundError(f"Checkpoint file not found: {path}")

            if path.endswith(".bytes"):
                # BYTE_IO: one bytes file per rank, with per-item offsets.
                _load_bytes_file(path, reqs, planner, self.storage_data)
            else:
                # TENSOR: one safetensors file per rank
                _load_tensor_file(path, reqs, planner)
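

# Typical end-to-end call sequence (a sketch, not executable as-is: building
# SavePlan/LoadPlan and planner instances depends on planner implementations
# outside this module, and the path below is hypothetical):
#
#   writer = FileSystemWriter("/tmp/ckpt")
#   writer.configure_writer(is_coordinator=True, rank=0)
#   plan = writer.optimize_local_plan(save_plan)
#   results = writer.execute_write(plan, save_planner)
#   writer.finalize_checkpoint(metadata, [results])
#
#   reader = FileSystemReader("/tmp/ckpt")
#   metadata = reader.load_metadata()
#   reader.configure_reader(metadata, is_coordinator=True, rank=0)
#   reader.execute_read(load_plan, load_planner)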