Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / distributed_checkpoint / metadata.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Checkpoint metadata structures for distributed checkpoint save and load.""" 

16from dataclasses import dataclass, field 

17from typing import Any, Optional, Union 

18 

19 

@dataclass(frozen=True)
class MetadataIndex:
    """
    Key that points at one stored piece of data inside a checkpoint.

    The dataclass is frozen, so fields cannot be reassigned once an
    instance has been constructed.

    Attributes:
        fqn: Fully qualified name of the tensor/object.
        offset: Offset within the tensor (used for sharded tensors).
            Defaults to the empty tuple ``()``.
        index: Index used for sharded tensors; ``None`` for non-sharded
            data. Defaults to ``None``.
    """

    fqn: str
    offset: tuple = field(default_factory=tuple)
    index: Optional[int] = None

33 

34 

@dataclass(frozen=True)
class ChunkStorageMetadata:
    """
    Describes a single stored chunk of a distributed tensor.

    A chunk is the portion of a global tensor that was written to the
    checkpoint. The dataclass is frozen, so both fields are fixed after
    construction.

    Attributes:
        offsets: Per-dimension offsets of the chunk in the global tensor.
        sizes: Per-dimension sizes of the chunk.
    """

    offsets: tuple
    sizes: tuple

48 

49 

@dataclass(frozen=True)
class TensorProperties:
    """
    Descriptive properties recorded for a tensor.

    The dataclass is frozen, so the properties cannot be reassigned after
    an instance is created.

    Attributes:
        dtype: Data type of the tensor, given as a string.
        requires_grad: Whether the tensor requires gradients.
            Defaults to ``False``.
        memory_format: Optional memory format. Defaults to ``None``.
    """

    dtype: str
    requires_grad: bool = False
    memory_format: Optional[str] = None

63 

64 

@dataclass
class BytesStorageMetadata:
    """
    Metadata entry for bytes data stored in a checkpoint.

    The class declares no fields; it acts as a marker type, so any two
    instances compare equal.
    """

68 

69 

@dataclass(frozen=True)
class TensorStorageMetadata:
    """
    Metadata for a distributed tensor.

    Contains properties, global size, and list of chunks stored across ranks.

    Instances are frozen: fields cannot be reassigned after construction,
    although the ``chunks`` list object itself can still be mutated in place.

    Attributes:
        properties: Tensor properties (dtype, etc.).
        size: Global size of the tensor.
        chunks: List of chunks stored in the checkpoint. Default []. Takes
            part in ``==`` comparison but is excluded from ``hash()``.
    """

    properties: TensorProperties
    size: tuple
    # A frozen dataclass with eq=True auto-generates __hash__ from its fields.
    # A list is unhashable, so hashing ``chunks`` would make hash(instance)
    # always raise TypeError. hash=False keeps the field in equality checks
    # while leaving the instance hashable via ``properties`` and ``size``.
    chunks: list[ChunkStorageMetadata] = field(default_factory=list, hash=False)

85 

86 

@dataclass
class Metadata:
    """
    Top-level metadata record describing a whole checkpoint.

    Holds the metadata of every item in the state_dict, together with
    optional planner-specific and storage-specific data.

    Attributes:
        state_dict_metadata: Mapping from FQN to that entry's storage
            metadata (tensor or bytes).
        planner_data: Planner-specific data; may be of any type.
            Defaults to ``None``.
        storage_data: Optional storage mapping from ``MetadataIndex`` to
            storage info. Defaults to ``None``.
        version: Checkpoint format version. Defaults to ``"1.0"``.
    """

    state_dict_metadata: dict[str, Union[TensorStorageMetadata, BytesStorageMetadata]]
    planner_data: Any = None
    storage_data: Optional[dict[MetadataIndex, Any]] = None
    version: str = "1.0"