Coverage for hyper_parallel / core / checkpoint / metadata.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Checkpoint metadata structures for distributed checkpoint save and load.""" 

16from dataclasses import dataclass, field 

17from typing import Any, Optional, Union 

18 

19 

@dataclass(frozen=True)
class MetadataIndex:
    """Key that locates one piece of data inside a checkpoint.

    The dataclass is frozen, which makes instances hashable so they can be
    used as dictionary keys (``Metadata.storage_data`` is typed as
    ``dict[MetadataIndex, Any]``). Hashing only succeeds when ``offset``
    holds hashable items, e.g. a tuple of ints.

    Attributes:
        fqn: Fully qualified name of the tensor or object.
        offset: Offset within the tensor, used for sharded tensors.
            Defaults to an empty tuple.
        index: Index for sharded tensors; ``None`` for non-sharded data.
            Defaults to ``None``.
    """

    # Fully qualified name of the tensor/object being located.
    fqn: str
    # Offset inside the tensor (sharded case); defaults to ().
    offset: tuple = field(default_factory=tuple)
    # Shard index, or None when the data is not sharded.
    index: Optional[int] = field(default=None)

33 

34 

@dataclass(frozen=True)
class ChunkStorageMetadata:
    """Describe where one stored chunk sits inside a distributed tensor.

    Each instance covers a portion of a distributed tensor that is stored
    in the checkpoint. Both fields are required and the instance is
    immutable.

    Attributes:
        offsets: Start of the chunk in the global tensor, one entry per
            dimension.
        sizes: Extent of the chunk, one entry per dimension.
    """

    offsets: tuple  # where the chunk begins in the global tensor, per dim
    sizes: tuple  # how large the chunk is along each dim

48 

49 

@dataclass(frozen=True)
class TensorProperties:
    """Immutable record of a tensor's dtype and related flags.

    Attributes:
        dtype: Data type of the tensor, recorded as a string.
        requires_grad: Whether the tensor requires gradients.
            Defaults to ``False``.
        memory_format: Optional memory-format label. Defaults to ``None``.
    """

    # dtype is stored as a plain string.
    dtype: str
    requires_grad: bool = field(default=False)
    memory_format: Optional[str] = field(default=None)

63 

64 

@dataclass()
class BytesStorageMetadata:
    """Storage metadata for a checkpoint entry saved as bytes data.

    The class carries no fields; an instance's presence in
    ``Metadata.state_dict_metadata`` is the information.

    NOTE(review): unlike the neighbouring metadata classes this dataclass
    is not frozen, so dataclasses sets ``__hash__`` to ``None`` and
    instances are unhashable. Confirm whether that is intentional.
    """

68 

@dataclass(frozen=True)
class TensorStorageMetadata:
    """Metadata for a distributed tensor stored in the checkpoint.

    Contains the tensor's properties, its global size, and the list of
    chunks stored across ranks.

    ``frozen=True`` only prevents rebinding the fields. ``chunks`` is still
    a mutable list whose contents can change after construction. It is
    therefore excluded from the generated ``__hash__`` but still takes part
    in ``==``. Including it made every instance unhashable, because hashing
    a list raises ``TypeError``.

    Attributes:
        properties: Tensor properties (dtype, etc.).
        size: Global size of the tensor.
        chunks: Chunks of this tensor stored in the checkpoint. Defaults to
            a fresh empty list per instance.
    """

    properties: TensorProperties
    size: tuple
    # hash=False: a list cannot be hashed, and hashing mutable contents would
    # let an instance's hash change while it sits in a set or dict. Equal
    # instances still hash equally because the hash uses a subset of the
    # fields compared by __eq__.
    chunks: list[ChunkStorageMetadata] = field(default_factory=list, hash=False)

84 

85 

@dataclass
class Metadata:
    """Global, checkpoint-wide metadata.

    Describes every item of the saved state_dict, together with optional
    planner-specific and storage-specific data. Unlike the per-item
    metadata classes, this dataclass is mutable.

    Attributes:
        state_dict_metadata: Maps each fully qualified name to the storage
            metadata of that entry (tensor or bytes).
        planner_data: Planner-specific payload of any type. Defaults to
            ``None``.
        storage_data: Optional mapping from ``MetadataIndex`` to
            storage-specific information. Defaults to ``None``.
        version: Checkpoint format version. Defaults to ``"1.0"``.
    """

    state_dict_metadata: dict[
        str, Union[TensorStorageMetadata, BytesStorageMetadata]
    ]
    # Opaque, planner-defined data; its type is not constrained here.
    planner_data: Any = field(default=None)
    # Storage mapping: MetadataIndex -> StorageInfo (values are storage-defined).
    storage_data: Optional[dict[MetadataIndex, Any]] = field(default=None)
    version: str = field(default="1.0")