Coverage for hyper_parallel / core / checkpoint / filesystem_storage.py: 87%
172 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""File system storage implementations for checkpoint save and load."""
16import os
17import pickle
18from pathlib import Path
19from typing import Any, Optional, Union
21from hyper_parallel.core.checkpoint.metadata import Metadata, MetadataIndex
22from hyper_parallel.core.checkpoint.planner import (
23 LoadItemType,
24 LoadPlan,
25 LoadPlanner,
26 ReadItem,
27 SavePlan,
28 SavePlanner,
29 WriteItem,
30)
31from hyper_parallel.core.checkpoint.storage import (
32 StorageInfo,
33 StorageReader,
34 StorageWriter,
35 WriteResult,
36 _metadata_file_name,
37)
38from hyper_parallel.core.checkpoint.util import narrow_tensor_by_index
39from hyper_parallel.platform import get_platform
class FileSystemWriter(StorageWriter):
    """
    File system storage writer implementation.

    Saves checkpoint data to the local file system, organizing tensors
    into safetensors files and bytes into separate files.
    """

    def __init__(self, checkpoint_dir: Union[Path, str]):
        # Normalize to Path and make sure the target directory exists.
        self.checkpoint_dir = Path(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.rank: int = 0
        self.is_coordinator: bool = False
        self.use_collectives: bool = True

    def initialize_writer(self, checkpoint_id: Optional[Union[Path, str]] = None) -> None:
        """
        Initialize storage writer with new checkpoint directory.

        Args:
            checkpoint_id (Optional[Union[Path, str]]): New checkpoint directory path. Default None.
        """
        if checkpoint_id:
            self.checkpoint_dir = Path(checkpoint_id) if isinstance(checkpoint_id, str) else checkpoint_id
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def configure_writer(self, is_coordinator: bool, **kwargs) -> None:
        """
        Configure storage writer.

        Args:
            is_coordinator (bool): Whether this rank is the coordinator.
            **kwargs: Additional keyword arguments (e.g., rank, use_collectives).
        """
        self.is_coordinator = is_coordinator
        self.rank = kwargs.get("rank", get_platform().get_rank())
        self.use_collectives = kwargs.get("use_collectives", True)

    def optimize_local_plan(self, plan: SavePlan) -> SavePlan:
        """
        Optimize local plan.

        Args:
            plan (SavePlan): Local save plan.

        Returns:
            SavePlan: Optimized local plan (currently returned unchanged).
        """
        return plan

    def optimize_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
        """
        Optimize global plan.

        Args:
            plans (list[SavePlan]): List of local plans from all ranks.

        Returns:
            list[SavePlan]: Optimized global plans (currently returned unchanged).
        """
        return plans

    def _write_bytes_item(self, item: WriteItem) -> WriteResult:
        """
        Write a single bytes item to storage.

        Raw bytes are written verbatim; any other payload is pickled.

        Args:
            item (WriteItem): WriteItem containing bytes data.

        Returns:
            WriteResult: Write result with storage metadata.
        """
        fqn = item.index.fqn
        file_name = f"{fqn}_rank{self.rank}.bytes"
        file_path = self.checkpoint_dir / file_name
        with open(file_path, "wb") as f:
            if isinstance(item.bytes_io_data, bytes):
                f.write(item.bytes_io_data)
            else:
                pickle.dump(item.bytes_io_data, f)
            # IOError is an alias of OSError in Python 3, so catching
            # OSError alone covers both. Fall back to 0 on failure.
            try:
                length = f.tell()
            except OSError:
                length = 0
        storage_info = StorageInfo(
            relative_path=file_name,
            offset=0,
            length=length,
        )
        return WriteResult(
            index=item.index,
            storage_data=storage_info,
        )

    def _collect_tensors(self, plan: SavePlan, planner: SavePlanner) -> dict[str, Any]:
        """
        Collect tensor data from planner cache.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            planner (SavePlanner): Save planner.

        Returns:
            dict[str, Any]: Dictionary mapping FQN to tensor data.

        Raises:
            RuntimeError: If tensor data not found in planner cache.
        """
        tensor_dict: dict[str, Any] = {}
        for item in plan.items:
            if item.type.value == "tensor" and item.tensor_data:
                # Get tensor from planner cache instead of tensor_data
                tensor = planner.get_tensor(item.index)
                if tensor is None:
                    raise RuntimeError(
                        f"Tensor data not found in planner cache for index {item.index}. "
                        f"FQN: {item.index.fqn}"
                    )
                fqn = item.index.fqn
                tensor_dict[fqn] = tensor
        return tensor_dict

    def _write_tensors(self, plan: SavePlan, tensor_dict: dict[str, Any]) -> list[WriteResult]:
        """
        Write all tensors to safetensors file and create WriteResults.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            tensor_dict (dict[str, Any]): Dictionary mapping FQN to tensor data.

        Returns:
            list[WriteResult]: List of write results for tensor items.
        """
        if not tensor_dict:
            return []

        platform = get_platform()
        file_name = f"_rank{self.rank}_.safetensors"
        file_path = self.checkpoint_dir / file_name
        platform.save_checkpoint(tensor_dict, str(file_path))

        # Record StorageInfo for each tensor
        # Note: we don't know per-tensor byte offsets, so offset=0, length=-1
        results: list[WriteResult] = []
        for item in plan.items:
            if item.type.value == "tensor" and item.tensor_data:
                storage_info = StorageInfo(
                    relative_path=file_name,
                    offset=0,
                    length=-1,
                )
                results.append(
                    WriteResult(
                        index=item.index,
                        storage_data=storage_info,
                    )
                )
        return results

    def execute_write(self, plan: SavePlan, planner: SavePlanner) -> list[WriteResult]:
        """
        Write data to storage and return per-item storage metadata.

        Group tensors into safetensors files and bytes into separate files, recording StorageInfo for each item.

        Args:
            plan (SavePlan): Save plan containing WriteItems.
            planner (SavePlanner): Save planner.

        Returns:
            list[WriteResult]: List of write results with storage metadata.
        """
        results: list[WriteResult] = []

        # Write bytes objects first, one file per (fqn, rank)
        for item in plan.items:
            if item.type.value == "byte_io":
                results.append(self._write_bytes_item(item))

        # Collect and write tensors as one safetensors file per rank
        tensor_dict = self._collect_tensors(plan, planner)
        results.extend(self._write_tensors(plan, tensor_dict))

        return results

    def finalize_checkpoint(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
        """
        Finish writing checkpoint and populate metadata.storage_data.

        When use_collectives=True: only coordinator saves global metadata to .metadata.
        When use_collectives=False: each rank saves its own metadata to .rank{rank}_metadata,
        no cross-rank interaction.

        Args:
            metadata (Metadata): Checkpoint metadata to update.
            results (list[list[WriteResult]]): Write results from all ranks (or single rank when use_collectives=False).
        """
        # Parenthesized for clarity; equivalent to the original
        # `use_collectives and is_coordinator or not use_collectives`
        # under Python's and/or precedence.
        should_save = (self.use_collectives and self.is_coordinator) or not self.use_collectives

        if should_save:
            # Build storage_data: map MetadataIndex -> StorageInfo
            storage_md: dict[MetadataIndex, StorageInfo] = {}
            for wr_list in results:
                for wr in wr_list:
                    storage_md[wr.index] = wr.storage_data
            metadata.storage_data = storage_md

            # Save metadata file
            if self.use_collectives:
                metadata_file = self.checkpoint_dir / _metadata_file_name
            else:
                metadata_file = self.checkpoint_dir / f".rank{self.rank}_metadata"
            with open(metadata_file, "wb") as f:
                pickle.dump(metadata, f)
def _copy_tensor_to_target(
    req: ReadItem, tensor: Any, target_tensor: Any, planner: LoadPlanner
) -> None:
    """
    Copy tensor data into the target tensor and commit it via the planner.

    Args:
        req (ReadItem): ReadItem request.
        tensor (Any): Source tensor (tensor-like object).
        target_tensor (Any): Target tensor (tensor-like object).
        planner (LoadPlanner): Load planner for committing.
    """
    if not hasattr(target_tensor, "copy_"):
        # mindspore or non-tensor: commit the source tensor directly
        planner.apply_tensor(req, tensor)
        return
    # torch-style in-place copy, then commit the filled target
    target_tensor.copy_(tensor)
    planner.apply_tensor(req, target_tensor)
def _load_bytes_file(path: str, reqs: list[ReadItem], planner: LoadPlanner) -> None:
    """
    Load bytes from a file and commit them for every request.

    The file content is the same for every request, so it is read once
    and reused, instead of re-opening and re-reading the file per request.

    Args:
        path (str): Path to the bytes file.
        reqs (list[ReadItem]): List of ReadItems for this file.
        planner (LoadPlanner): Load planner for loading bytes.
    """
    if not reqs:
        # Preserve the original behavior of never opening the file
        # when there is nothing to load.
        return
    with open(path, "rb") as f:
        value = f.read()
    for req in reqs:
        planner.apply_bytes(req, value)
293def _get_tensor_size(tensor: Any) -> Optional[tuple]:
294 """
295 Get size/shape of a tensor.
297 Args:
298 tensor (Any): Tensor object (tensor-like with shape/size attribute).
300 Returns:
301 Optional[tuple]: Tuple of tensor size or None if not available.
302 """
303 if hasattr(tensor, "size") and callable(tensor.size):
304 return tuple(tensor.size())
305 return getattr(tensor, "shape", None)
def _load_tensor_file(
    path: str, reqs: list[ReadItem], planner: LoadPlanner
) -> None:
    """
    Load and process tensors from a safetensors file.

    Args:
        path (str): Path to the safetensors file.
        reqs (list[ReadItem]): List of ReadItems for this file.
        planner (LoadPlanner): Load planner for resolving and committing tensors.

    Raises:
        KeyError: If a requested FQN is missing from the checkpoint file.
        AssertionError: If source and target tensor sizes disagree.
    """
    param_dict = get_platform().load_checkpoint(path)

    for req in reqs:
        fqn = req.storage_index.fqn
        if fqn not in param_dict:
            raise KeyError(f"Key {fqn} not found in checkpoint file {path}")

        # Slice out the shard described by storage_offsets/lengths (resharding)
        src = narrow_tensor_by_index(
            param_dict[fqn],
            req.storage_offsets,
            req.lengths,
        )

        dst = planner.acquire_tensor(req)
        if hasattr(dst, "detach"):
            dst = dst.detach()

        # Size check (torch-aligned AssertionError)
        dst_size = _get_tensor_size(dst)
        src_size = _get_tensor_size(src)
        if dst_size is not None and src_size is not None and dst_size != src_size:
            raise AssertionError(
                f"req {req.storage_index} mismatch sizes "
                f"{dst_size} vs {src_size}"
            )

        # Copy data into the target and commit via the planner
        _copy_tensor_to_target(req, src, dst, planner)
class FileSystemReader(StorageReader):
    """
    File system storage reader implementation.

    Reads checkpoint data from the local file system, loading tensors
    from safetensors files and bytes from separate files.
    """

    def __init__(self, checkpoint_dir: Union[Path, str]):
        if isinstance(checkpoint_dir, str):
            checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir = checkpoint_dir
        # Cached storage layout: MetadataIndex -> StorageInfo (torch-aligned)
        self.storage_data: Optional[dict[MetadataIndex, StorageInfo]] = None
        self.rank: int = 0
        self.is_coordinator: bool = False

    def initialize_reader(self, checkpoint_id: Optional[Union[Path, str]] = None) -> None:
        """
        Initialize storage reader with new checkpoint directory.

        Args:
            checkpoint_id (Optional[Union[Path, str]]): New checkpoint directory path. Default None.
        """
        if checkpoint_id:
            if isinstance(checkpoint_id, str):
                checkpoint_id = Path(checkpoint_id)
            self.checkpoint_dir = checkpoint_id

    def load_metadata(self, **kwargs) -> Metadata:
        """
        Load checkpoint metadata from file.

        When rank is provided in kwargs: load rank-local metadata from .rank{rank}_metadata
        (for checkpoints saved with use_collectives=False).
        Otherwise: load global metadata from .metadata.

        Args:
            **kwargs: Optional arguments (e.g., rank for rank-local metadata).

        Returns:
            Metadata: Metadata object loaded from file.

        Raises:
            FileNotFoundError: If the metadata file does not exist.
        """
        rank = kwargs.get("rank")
        file_name = _metadata_file_name if rank is None else f".rank{rank}_metadata"
        metadata_file = self.checkpoint_dir / file_name

        if not metadata_file.exists():
            raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
        # NOTE(review): pickle.load assumes the checkpoint directory is
        # trusted local data; it is unsafe on untrusted files.
        with open(metadata_file, "rb") as f:
            return pickle.load(f)

    def configure_reader(self, metadata: Metadata, is_coordinator: bool, **kwargs) -> None:
        """Configure storage reader."""
        # Cache storage_data separately for quick lookup in execute_read.
        # This mirrors torch.filesystem, where reader keeps a storage_data dict.
        self.storage_data = getattr(metadata, "storage_data", None)
        self.is_coordinator = is_coordinator
        self.rank = kwargs.get("rank", get_platform().get_rank())

    def optimize_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Optimize local plan.

        Args:
            plan (LoadPlan): Local load plan.

        Returns:
            LoadPlan: Optimized local plan (currently returned unchanged).
        """
        return plan

    def optimize_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
        """
        Optimize global plan.

        Args:
            plans (list[LoadPlan]): List of local plans from all ranks.

        Returns:
            list[LoadPlan]: Optimized global plans (currently returned unchanged).
        """
        return plans

    def _get_storage_path(self, read_item: ReadItem) -> str:
        """
        Get storage file path for a read item.

        Args:
            read_item (ReadItem): ReadItem to get path for.

        Returns:
            str: Absolute path to the storage file.

        Raises:
            KeyError: If storage_data is present but has no entry for the item.
        """
        if self.storage_data is not None:
            info = self.storage_data.get(read_item.storage_index)
            if info is None:
                raise KeyError(f"StorageInfo not found for index {read_item.storage_index}")
            return str(self.checkpoint_dir / info.relative_path)

        # Fallback: derive path from rank & fqn (legacy format without storage_data)
        rank = read_item.storage_index.index or self.rank
        if read_item.type == LoadItemType.TENSOR:
            return str(self.checkpoint_dir / f"_rank{rank}_.safetensors")
        return str(self.checkpoint_dir / f"{read_item.storage_index.fqn}_rank{rank}.bytes")

    def _group_items_by_file(self, plan: LoadPlan) -> dict[str, list]:
        """
        Group ReadItems by storage file path.

        Args:
            plan (LoadPlan): Load plan containing ReadItems.

        Returns:
            dict[str, list[ReadItem]]: Dictionary mapping file paths to lists of ReadItems.
        """
        grouped: dict[str, list] = {}
        for item in plan.items:
            grouped.setdefault(self._get_storage_path(item), []).append(item)
        return grouped

    def execute_read(self, plan: LoadPlan, planner: LoadPlanner) -> None:
        """
        Read data from storage.

        Aligned with torch filesystem read_data: groups ReadItems by file,
        loads each file once, narrows tensors by storage_offsets/lengths for
        resharding, then resolves/copies/commits data.

        Args:
            plan (LoadPlan): Load plan containing ReadItems.
            planner (LoadPlanner): Load planner for resolving and committing tensors.

        Raises:
            FileNotFoundError: If a referenced checkpoint file is missing.
        """
        # Group ReadItems by storage file path (like torch per_file)
        for path, reqs in self._group_items_by_file(plan).items():
            if not os.path.exists(path):
                raise FileNotFoundError(f"Checkpoint file not found: {path}")

            if path.endswith(".bytes"):
                # BYTE_IO: one file per (fqn, rank)
                _load_bytes_file(path, reqs, planner)
            else:
                # TENSOR: one safetensors file per rank
                _load_tensor_file(path, reqs, planner)