Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/distributed_checkpoint/metadata.py: 100%
29 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Checkpoint metadata structures for distributed checkpoint save and load."""
16from dataclasses import dataclass, field
17from typing import Any, Optional, Union
@dataclass(frozen=True)
class MetadataIndex:
    """
    Key that locates one piece of data inside a checkpoint.

    The class is declared frozen, so instances are read-only after
    construction and compare/hash by field value. ``Metadata.storage_data``
    is annotated as a dict keyed by this type.

    Attributes:
        fqn: Fully qualified name of the tensor or object.
        offset: Position inside the tensor, used for sharded tensors.
            Defaults to the empty tuple.
        index: Shard index; ``None`` for non-sharded data. Defaults to None.
    """
    fqn: str
    # A tuple is immutable, so the empty literal can serve as the default directly.
    offset: tuple = ()
    index: Optional[int] = None
@dataclass(frozen=True)
class ChunkStorageMetadata:
    """
    Describes one stored chunk of a distributed tensor.

    A chunk is the portion of the global tensor that a checkpoint entry
    holds. The class is frozen, so instances cannot be modified after
    construction.

    Attributes:
        offsets: Per-dimension start position of the chunk in the global tensor.
        sizes: Per-dimension extent of the chunk.
    """
    offsets: tuple
    sizes: tuple
@dataclass(frozen=True)
class TensorProperties:
    """
    Descriptive properties recorded for a tensor.

    The class is frozen, so instances are read-only after construction.

    Attributes:
        dtype: Data type of the tensor, stored as a string.
        requires_grad: Whether the tensor requires gradients. Defaults to False.
        memory_format: Optional memory format. Defaults to None.
    """
    dtype: str
    requires_grad: bool = False
    memory_format: Optional[str] = None
@dataclass
class BytesStorageMetadata:
    """
    Metadata entry for bytes data stored in a checkpoint.

    Declares no fields, so every instance compares equal to every other.
    """
@dataclass(frozen=True)
class TensorStorageMetadata:
    """
    Metadata for a distributed tensor.

    Holds the tensor's properties, its global size, and the list of chunks
    stored across ranks. The class is frozen, so attributes cannot be
    rebound after construction; the ``chunks`` list itself remains mutable.

    Instances are hashable: the hash is derived from ``properties`` and
    ``size`` only, while equality still compares all three fields.

    Attributes:
        properties: Tensor properties (dtype, etc.).
        size: Global size of the tensor.
        chunks: Chunks stored in the checkpoint. Defaults to a new empty list.
    """
    properties: TensorProperties
    size: tuple
    # hash=False: a frozen dataclass builds __hash__ from every compared field,
    # and a list is unhashable, so without it hash(instance) always raises
    # TypeError. The field still takes part in ==.
    chunks: list[ChunkStorageMetadata] = field(default_factory=list, hash=False)
@dataclass
class Metadata:
    """
    Top-level metadata describing a whole checkpoint.

    Groups the per-item storage metadata for everything in the state_dict
    with optional planner-specific and storage-specific data.

    Attributes:
        state_dict_metadata: Mapping from FQN to that item's storage metadata.
        planner_data: Optional planner-specific data. Defaults to None.
        storage_data: Optional storage-specific data. Defaults to None.
        version: Checkpoint format version. Defaults to "1.0".
    """
    state_dict_metadata: dict[str, Union[TensorStorageMetadata, BytesStorageMetadata]]
    # Planner-specific data (can be any type).
    planner_data: Any = None
    # Storage mapping: MetadataIndex -> StorageInfo.
    storage_data: Optional[dict[MetadataIndex, Any]] = None
    version: str = "1.0"