Coverage for hyper_parallel / core / checkpoint / metadata.py: 100%
29 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Checkpoint metadata structures for distributed checkpoint save and load."""
16from dataclasses import dataclass, field
17from typing import Any, Optional, Union
@dataclass(frozen=True)
class MetadataIndex:
    """
    Index to identify a specific piece of data in the checkpoint.

    Instances are hashable so they can be used as keys of
    ``Metadata.storage_data``.

    Attributes:
        fqn: Fully qualified name of the tensor/object.
        offset: Offset in the tensor (for sharded tensors). Default ().
            A non-tuple iterable such as a list is converted to a tuple so
            the index stays hashable.
        index: Index for sharded tensors (None for non-sharded). Default None.
    """
    fqn: str
    offset: tuple = field(default_factory=tuple)
    index: Optional[int] = None

    def __post_init__(self) -> None:
        # A frozen dataclass derives __hash__ from all of its fields, so a
        # list-valued offset would make this index unhashable and unusable as
        # a dict key. Tuples (and tuple subclasses) are kept as-is; None is
        # left untouched for backward compatibility.
        if self.offset is not None and not isinstance(self.offset, tuple):
            # object.__setattr__ bypasses the frozen-instance guard.
            object.__setattr__(self, "offset", tuple(self.offset))
@dataclass(frozen=True)
class ChunkStorageMetadata:
    """
    Describe one stored chunk of a distributed tensor.

    A chunk is the part of the global tensor given by a per-dimension start
    offset and a per-dimension size, stored in the checkpoint as one piece.

    Attributes:
        offsets: Per-dimension start position of the chunk in the global tensor.
        sizes: Per-dimension extent of the chunk.
    """

    # Both fields are required and positional: ChunkStorageMetadata(offsets, sizes).
    offsets: tuple
    sizes: tuple
@dataclass(frozen=True)
class TensorProperties:
    """
    Immutable description of a tensor's non-shape attributes.

    Attributes:
        dtype: Data type of the tensor, stored as a string.
        requires_grad: Gradient-tracking flag; False unless given.
        memory_format: Optional memory-format string; None unless given.
    """

    dtype: str
    # Optional fields: explicit field(default=...) equals a plain default.
    requires_grad: bool = field(default=False)
    memory_format: Optional[str] = field(default=None)
@dataclass
class BytesStorageMetadata:
    """
    Marker metadata for a checkpoint entry stored as bytes data.

    It carries no fields; its type alone distinguishes byte entries from
    ``TensorStorageMetadata`` values in ``Metadata.state_dict_metadata``.
    """

    # NOTE(review): unlike the sibling metadata classes this one is not
    # frozen, so instances are unhashable (eq=True sets __hash__ to None).
@dataclass(frozen=True)
class TensorStorageMetadata:
    """
    Metadata for a distributed tensor.

    Contains properties, global size, and list of chunks stored across ranks.

    Attributes:
        properties: Tensor properties (dtype, etc.).
        size: Global size of the tensor.
        chunks: List of chunks stored in the checkpoint. Default [].
            Excluded from ``__hash__`` (a list is unhashable) but still
            compared by ``__eq__``.
    """
    properties: TensorProperties
    size: tuple
    # frozen=True only prevents rebinding ``chunks``; the list itself can still
    # be appended to. Without hash=False the generated __hash__ would hash this
    # list and always raise TypeError. Equality still compares chunks, so equal
    # instances continue to hash equal.
    chunks: list[ChunkStorageMetadata] = field(default_factory=list, hash=False)
@dataclass
class Metadata:
    """
    Top-level description of a whole checkpoint.

    Maps every state_dict entry to its storage metadata and can also carry
    planner-specific and storage-specific data.

    Attributes:
        state_dict_metadata: FQN -> TensorStorageMetadata or BytesStorageMetadata.
        planner_data: Planner-specific payload of any type. Default None.
        storage_data: MetadataIndex -> storage info mapping, if any. Default None.
        version: Checkpoint format version string. Default "1.0".
    """

    # Required: storage metadata for each state_dict entry, keyed by FQN.
    state_dict_metadata: dict[str, Union[TensorStorageMetadata, BytesStorageMetadata]]
    # Planner-specific payload; not inspected by this module.
    planner_data: Any = field(default=None)
    # Storage mapping: MetadataIndex -> StorageInfo (value type not enforced here).
    storage_data: Optional[dict[MetadataIndex, Any]] = field(default=None)
    version: str = field(default="1.0")