Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / distributed_checkpoint / filesystem_storage.py: 62%
197 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""File system storage implementations for checkpoint save and load."""
16import os
17import pickle
18from pathlib import Path
19from typing import Any, Optional, Union
21from safetensors import safe_open
23from hyper_parallel.core.distributed_checkpoint.metadata import Metadata, MetadataIndex
24from hyper_parallel.core.distributed_checkpoint.planner import (
25 LoadPlan,
26 LoadPlanner,
27 ReadItem,
28 SavePlan,
29 SavePlanner,
30 WriteItem,
31)
32from hyper_parallel.core.distributed_checkpoint.storage import (
33 StorageInfo,
34 StorageReader,
35 StorageWriter,
36 WriteResult,
37 METADATA_FILE_NAME,
38)
39from hyper_parallel.core.distributed_checkpoint.util import narrow_tensor_by_index
40from hyper_parallel.platform import get_platform
41from hyper_parallel.platform.platform import PlatformType
class FileSystemWriter(StorageWriter):
    """
    File system storage writer implementation.

    Saves checkpoint data to the local file system, organizing tensors
    into safetensors files and bytes into separate files.
    """

    def __init__(self, checkpoint_dir: Union[Path, str]):
        # Path() accepts both str and Path, so no isinstance check is needed.
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.rank: int = 0
        self.is_coordinator: bool = False
        self.use_collectives: bool = True

    def initialize_writer(self, checkpoint_id: Optional[Union[Path, str]] = None) -> None:
        """
        Initialize storage writer with new checkpoint directory.

        Args:
            checkpoint_id (Optional[Union[Path, str]]): New checkpoint directory path. Default None.
        """
        if checkpoint_id:
            self.checkpoint_dir = Path(checkpoint_id)
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def configure_writer(self, is_coordinator: bool, **kwargs) -> None:
        """
        Configure storage writer.

        Args:
            is_coordinator (bool): Whether this rank is the coordinator.
            **kwargs: Additional keyword arguments (e.g., rank, use_collectives).
        """
        self.is_coordinator = is_coordinator
        # An explicitly supplied rank takes precedence (single lookup instead of
        # an `in` test followed by `.get`); otherwise query the platform.
        self.rank = kwargs["rank"] if "rank" in kwargs else get_platform().get_rank()
        self.use_collectives = kwargs.get("use_collectives", True)

    def optimize_local_plan(self, plan: SavePlan) -> SavePlan:
        """
        Optimize local plan.

        Args:
            plan (SavePlan): Local save plan.

        Returns:
            SavePlan: Optimized local plan (currently an identity pass-through).
        """
        return plan

    def optimize_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
        """
        Optimize global plan.

        Args:
            plans (list[SavePlan]): List of local plans from all ranks.

        Returns:
            list[SavePlan]: Optimized global plans (currently an identity pass-through).
        """
        return plans

    def _serialize_bytes_item(self, item: WriteItem, planner: SavePlanner) -> bytes:
        """Serialize a BYTE_IO item payload while preserving current behavior."""
        data = planner.get_data(item)
        if isinstance(data, bytes):
            return data
        # Non-bytes payloads are pickled, matching the read path's expectations.
        return pickle.dumps(data)

    def _write_bytes_items(self, plan: SavePlan, planner: SavePlanner) -> list[WriteResult]:
        """
        Write all BYTE_IO items into one per-rank bytes file.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            planner (SavePlanner): Save planner used to resolve runtime data.

        Returns:
            list[WriteResult]: Write results for BYTE_IO items.
        """
        byte_items = [item for item in plan.items if item.type.value == "byte_io"]
        if not byte_items:
            return []

        file_name = f"_rank{self.rank}_.bytes"
        file_path = self.checkpoint_dir / file_name

        results: list[WriteResult] = []

        with open(file_path, "wb") as f:
            for item in byte_items:
                payload = self._serialize_bytes_item(item, planner)
                # Record the byte span of this payload so the reader can seek to it.
                offset = f.tell()
                f.write(payload)
                storage_info = StorageInfo(
                    relative_path=file_name,
                    offset=offset,
                    length=len(payload),
                )
                results.append(
                    WriteResult(
                        index=item.index,
                        storage_data=storage_info,
                    )
                )

        return results

    def _collect_tensors(self, plan: SavePlan, planner: SavePlanner) -> dict[str, Any]:
        """
        Collect tensor data from planner runtime lookup.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            planner (SavePlanner): Save planner.

        Returns:
            dict[str, Any]: Dictionary mapping FQN to tensor data.

        Raises:
            RuntimeError: If tensor data cannot be resolved for an item.
        """
        tensor_dict: dict[str, Any] = {}
        for item in plan.items:
            if item.type.value == "tensor" and item.tensor_data:
                tensor = planner.get_data(item)
                if tensor is None:
                    raise RuntimeError(
                        f"Tensor data could not be resolved for index {item.index}. "
                        f"FQN: {item.index.fqn}"
                    )
                tensor_dict[item.index.fqn] = tensor
        return tensor_dict

    def _write_tensors(self, plan: SavePlan, tensor_dict: dict[str, Any]) -> list[WriteResult]:
        """
        Write all tensors to safetensors file and create WriteResults.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            tensor_dict (dict[str, Any]): Dictionary mapping FQN to tensor data.

        Returns:
            list[WriteResult]: List of write results for tensor items.
        """
        if not tensor_dict:
            return []

        platform = get_platform()
        file_name = f"_rank{self.rank}_.safetensors"
        file_path = self.checkpoint_dir / file_name
        platform.save_checkpoint(tensor_dict, str(file_path))

        # Record StorageInfo for each tensor
        # Note: we don't know per-tensor byte offsets, so offset=0, length=-1
        results: list[WriteResult] = []
        for item in plan.items:
            if item.type.value == "tensor" and item.tensor_data:
                storage_info = StorageInfo(
                    relative_path=file_name,
                    offset=0,
                    length=-1,
                )
                results.append(
                    WriteResult(
                        index=item.index,
                        storage_data=storage_info,
                    )
                )
        return results

    def execute_write(self, plan: SavePlan, planner: SavePlanner) -> list[WriteResult]:
        """
        Write data to storage and return per-item storage metadata.

        Group tensors into safetensors files and bytes into separate files, recording StorageInfo for each item.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            planner (SavePlanner): Save planner.

        Returns:
            list[WriteResult]: List of write results with storage metadata.
        """
        results: list[WriteResult] = []

        # Write all BYTE_IO items into one file per rank
        results.extend(self._write_bytes_items(plan, planner))

        # Collect and write tensors
        tensor_dict = self._collect_tensors(plan, planner)
        results.extend(self._write_tensors(plan, tensor_dict))

        return results

    def finalize_checkpoint(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
        """
        Finish writing checkpoint and populate metadata.storage_data.

        When use_collectives=True: only coordinator saves global metadata to .metadata.
        When use_collectives=False: each rank saves its own metadata to .rank{rank}_metadata,
        no cross-rank interaction.

        Args:
            metadata (Metadata): Checkpoint metadata to update.
            results (list[list[WriteResult]]): Write results from all ranks (or single rank when use_collectives=False).
        """
        # Parenthesized for clarity; semantics are unchanged (`and` binds tighter than `or`).
        should_save = (self.use_collectives and self.is_coordinator) or not self.use_collectives

        if should_save:
            # Build storage_data: map MetadataIndex -> StorageInfo
            storage_md: dict[MetadataIndex, StorageInfo] = {}
            for wr_list in results:
                for wr in wr_list:
                    storage_md[wr.index] = wr.storage_data
            metadata.storage_data = storage_md

            # Save metadata file
            if self.use_collectives:
                metadata_file = self.checkpoint_dir / METADATA_FILE_NAME
            else:
                metadata_file = self.checkpoint_dir / f".rank{self.rank}_metadata"
            with open(metadata_file, "wb") as f:
                pickle.dump(metadata, f)
def _copy_tensor_to_target(
    req: ReadItem, tensor: Any, target_tensor: Any, planner: LoadPlanner
) -> None:
    """
    Copy tensor data into the target tensor and commit it via the planner.

    Args:
        req (ReadItem): ReadItem request.
        tensor (Any): Source tensor (tensor-like object).
        target_tensor (Any): Target tensor (tensor-like object).
        planner (LoadPlanner): Load planner for committing.
    """
    if not hasattr(target_tensor, "copy_"):
        # mindspore or non-tensor: copy via commit path
        planner.apply_tensor(req, tensor)
        return

    # torch-style target: copy in place, then commit the target itself.
    target_tensor.copy_(tensor)
    planner.apply_tensor(req, target_tensor)
def _load_bytes_file(
    path: str,
    reqs: list[ReadItem],
    planner: LoadPlanner,
    storage_data: dict[MetadataIndex, StorageInfo],
) -> None:
    """
    Read raw byte payloads for each request from a per-rank bytes file.

    Args:
        path (str): Path to the bytes file.
        reqs (list[ReadItem]): List of ReadItems for this file.
        planner (LoadPlanner): Load planner used to commit the bytes.
        storage_data (dict[MetadataIndex, StorageInfo]): Per-item storage layout
            giving the byte offset and length of every request.

    Raises:
        KeyError: If a request has no StorageInfo entry in ``storage_data``.
    """
    with open(path, "rb") as stream:
        for req in reqs:
            info = storage_data.get(req.storage_index)
            if info is None:
                raise KeyError(
                    f"StorageInfo not found for index {req.storage_index}"
                )
            # Seek to the recorded span and hand the exact payload to the planner.
            stream.seek(info.offset)
            planner.apply_bytes(req, stream.read(info.length))
320def _get_tensor_size(tensor: Any) -> Optional[tuple]:
321 """
322 Get size/shape of a tensor.
324 Args:
325 tensor (Any): Tensor object (tensor-like with shape/size attribute).
327 Returns:
328 Optional[tuple]: Tuple of tensor size or None if not available.
329 """
330 if hasattr(tensor, "size") and callable(tensor.size):
331 return tuple(tensor.size())
332 return getattr(tensor, "shape", None)
def _apply_read_item(req: ReadItem, tensor: Any, planner: LoadPlanner) -> None:
    """
    Resolve the target tensor for ``req``, validate sizes, and commit ``tensor``.

    Shared tail of both load paths in ``_load_tensor_file`` (previously
    duplicated verbatim in the torch branch and the fallback branch).

    Args:
        req (ReadItem): ReadItem request being satisfied.
        tensor (Any): Source tensor already narrowed to the requested region.
        planner (LoadPlanner): Load planner for resolving and committing tensors.

    Raises:
        AssertionError: If the source and target sizes disagree (torch-aligned).
    """
    target_tensor = planner.acquire_tensor(req)
    if hasattr(target_tensor, "detach"):
        target_tensor = target_tensor.detach()

    # Size check (torch-aligned AssertionError)
    target_size = _get_tensor_size(target_tensor)
    tensor_size = _get_tensor_size(tensor)
    if target_size is not None and tensor_size is not None:
        if target_size != tensor_size:
            raise AssertionError(
                f"req {req.storage_index} mismatch sizes "
                f"{target_size} vs {tensor_size}"
            )

    # Copy data to target
    _copy_tensor_to_target(req, tensor, target_tensor, planner)


def _load_tensor_file(
    path: str, reqs: list[ReadItem], planner: LoadPlanner
) -> None:
    """
    Load and process tensors from a safetensors file.

    Args:
        path (str): Path to the safetensors file.
        reqs (list[ReadItem]): List of ReadItems for this file.
        planner (LoadPlanner): Load planner for resolving and committing tensors.

    Raises:
        KeyError: If a requested FQN is missing from the checkpoint file.
        AssertionError: If a source/target size mismatch is detected.
    """
    platform = get_platform()

    if platform.platform_type == PlatformType.PYTORCH:
        # Torch path: lazily slice tensors out of the safetensors file so only
        # the requested region is materialized.
        with safe_open(path, framework="pt", device="cpu") as tensor_file:
            for req in reqs:
                fqn = req.storage_index.fqn
                if fqn not in tensor_file.keys():
                    raise KeyError(f"Key {fqn} not found in checkpoint file {path}")
                tensor_slices = tuple(
                    slice(int(off), int(off) + int(length))
                    for off, length in zip(req.storage_offsets, req.lengths)
                )
                if tensor_slices:
                    tensor = tensor_file.get_slice(fqn)[tensor_slices]
                else:
                    tensor = narrow_tensor_by_index(
                        tensor_file.get_tensor(fqn),
                        req.storage_offsets,
                        req.lengths,
                    )
                _apply_read_item(req, tensor, planner)
        return

    # Fallback path: load the whole checkpoint dict, then narrow per request.
    param_dict = platform.load_checkpoint(path)
    for req in reqs:
        fqn = req.storage_index.fqn
        if fqn not in param_dict:
            raise KeyError(f"Key {fqn} not found in checkpoint file {path}")
        tensor = narrow_tensor_by_index(
            param_dict[fqn],
            req.storage_offsets,
            req.lengths,
        )
        _apply_read_item(req, tensor, planner)
class FileSystemReader(StorageReader):
    """
    File system storage reader implementation.

    Reads checkpoint data from the local file system, loading tensors
    from safetensors files and bytes from separate files.
    """

    def __init__(self, checkpoint_dir: Union[Path, str]):
        # Path() accepts both str and Path, so no isinstance check is needed.
        self.checkpoint_dir = Path(checkpoint_dir)
        # Cached storage layout: MetadataIndex -> StorageInfo (torch-aligned)
        self.storage_data: Optional[dict[MetadataIndex, StorageInfo]] = None
        self.rank: int = 0
        self.is_coordinator: bool = False

    def initialize_reader(self, checkpoint_id: Optional[Union[Path, str]] = None) -> None:
        """
        Initialize storage reader with new checkpoint directory.

        Args:
            checkpoint_id (Optional[Union[Path, str]]): New checkpoint directory path. Default None.
        """
        if checkpoint_id:
            self.checkpoint_dir = Path(checkpoint_id)

    def load_metadata(self, **kwargs) -> Metadata:
        """
        Load checkpoint metadata from file.

        When rank is provided in kwargs: load rank-local metadata from .rank{rank}_metadata
        (for checkpoints saved with use_collectives=False).
        Otherwise: load global metadata from .metadata.

        Args:
            **kwargs: Optional arguments (e.g., rank for rank-local metadata).

        Returns:
            Metadata: Metadata object loaded from file.

        Raises:
            FileNotFoundError: If the metadata file does not exist.
        """
        rank = kwargs.get("rank")
        if rank is not None:
            metadata_file = self.checkpoint_dir / f".rank{rank}_metadata"
        else:
            metadata_file = self.checkpoint_dir / METADATA_FILE_NAME

        if not metadata_file.exists():
            raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
        # NOTE(security): metadata is unpickled here; only load checkpoints
        # from trusted sources, as pickle can execute arbitrary code.
        with open(metadata_file, "rb") as f:
            return pickle.load(f)

    def configure_reader(self, metadata: Metadata, is_coordinator: bool, **kwargs) -> None:
        """
        Configure storage reader.

        Args:
            metadata (Metadata): Checkpoint metadata whose storage_data is cached.
            is_coordinator (bool): Whether this rank is the coordinator.
            **kwargs: Additional keyword arguments (e.g., rank).
        """
        # Cache storage_data separately for quick lookup in execute_read.
        # This mirrors torch.filesystem, where reader keeps a storage_data dict.
        self.storage_data = getattr(metadata, "storage_data", None)
        self.is_coordinator = is_coordinator
        # An explicitly supplied rank takes precedence (single lookup instead of
        # an `in` test followed by `.get`); otherwise query the platform.
        self.rank = kwargs["rank"] if "rank" in kwargs else get_platform().get_rank()

    def optimize_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Optimize local plan.

        Args:
            plan (LoadPlan): Local load plan.

        Returns:
            LoadPlan: Optimized local plan (currently an identity pass-through).
        """
        return plan

    def optimize_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
        """
        Optimize global plan.

        Args:
            plans (list[LoadPlan]): List of local plans from all ranks.

        Returns:
            list[LoadPlan]: Optimized global plans (currently an identity pass-through).
        """
        return plans

    def _get_storage_path(self, read_item: ReadItem) -> str:
        """
        Get storage file path for a read item.

        Args:
            read_item (ReadItem): ReadItem to get path for.

        Returns:
            str: Absolute path to the storage file.

        Raises:
            KeyError: If storage_data is missing or has no entry for the item.
        """
        if self.storage_data is None:
            raise KeyError("Checkpoint metadata.storage_data is required for filesystem read")
        storage_info = self.storage_data.get(read_item.storage_index)
        if storage_info is None:
            raise KeyError(f"StorageInfo not found for index {read_item.storage_index}")
        return str(self.checkpoint_dir / storage_info.relative_path)

    def _group_items_by_file(self, plan: LoadPlan) -> dict[str, list]:
        """
        Group ReadItems by storage file path.

        Args:
            plan (LoadPlan): Load plan containing ReadItems.

        Returns:
            dict[str, list[ReadItem]]: Dictionary mapping file paths to lists of ReadItems.
        """
        per_file: dict[str, list] = {}
        for read_item in plan.items:
            path = self._get_storage_path(read_item)
            per_file.setdefault(path, []).append(read_item)
        return per_file

    def execute_read(self, plan: LoadPlan, planner: LoadPlanner) -> None:
        """
        Read data from storage.

        Aligned with torch filesystem read_data: groups ReadItems by file,
        loads each file once, narrows tensors by storage_offsets/lengths for
        resharding, then resolves/copies/commits data.

        Args:
            plan (LoadPlan): Load plan containing ReadItems.
            planner (LoadPlanner): Load planner for resolving and committing tensors.

        Raises:
            FileNotFoundError: If a referenced checkpoint file is missing.
        """
        # Group ReadItems by storage file path (like torch per_file).
        # _get_storage_path has already verified self.storage_data is not None.
        per_file = self._group_items_by_file(plan)

        # Process each file
        for path, reqs in per_file.items():
            if not os.path.exists(path):
                raise FileNotFoundError(f"Checkpoint file not found: {path}")

            if path.endswith(".bytes"):
                # BYTE_IO: one bytes file per rank with per-item offsets.
                _load_bytes_file(path, reqs, planner, self.storage_data)
            else:
                # TENSOR: one safetensors file per rank
                _load_tensor_file(path, reqs, planner)