Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/distributed_checkpoint/planner.py: 75%
83 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Planner interfaces and implementations"""
16import abc
17from dataclasses import dataclass, field
18from enum import Enum
19from typing import Any, Optional, Union
21from hyper_parallel.core.distributed_checkpoint.metadata import Metadata, MetadataIndex
class WriteItemType(Enum):
    """Kind of payload a write item carries.

    TENSOR items describe tensor data; BYTE_IO items describe raw bytes
    (or pickle-serializable objects).
    """
    TENSOR = "tensor"
    BYTE_IO = "byte_io"
class LoadItemType(Enum):
    """Kind of payload a load item targets.

    TENSOR items are read into tensors; BYTE_IO items are read as raw bytes.
    """
    TENSOR = "tensor"
    BYTE_IO = "byte_io"
@dataclass(frozen=True)
class WriteItem:
    """
    Item to be written to storage.

    Represents a single logical item (tensor or bytes) to be saved.

    Attributes:
        index: Metadata index identifying this item.
        type: Type of write item (TENSOR or BYTE_IO).
        tensor_data: Dictionary containing tensor data (for TENSOR type). Default None.
        bytes_io_data: Bytes data (for BYTE_IO type). Default None.
    """
    index: MetadataIndex
    type: WriteItemType
    # Keys: 'chunk' (ChunkStorageMetadata), 'properties' (TensorProperties), 'size' (tuple).
    # Actual tensor data is in planner's tensor cache, not here, to avoid all_gather of tensors.
    tensor_data: Optional[dict[str, Any]] = None
    bytes_io_data: Optional[Union[bytes, Any]] = None  # Bytes or pickle-serializable object

    # Per-element byte size keyed by dtype-name substring. Matching below is by
    # substring, so order matters: "bfloat16" must be tested before "float16".
    # Not a dataclass field (no annotation), so it stays a plain class constant.
    _DTYPE_ELEM_SIZES = {
        "bfloat16": 2,
        "float16": 2, "float32": 4, "float64": 8,
        "int16": 2, "int32": 4, "int64": 8,
        "int8": 1, "uint8": 1,
        "bool": 1,
    }

    def tensor_storage_size(self) -> Optional[int]:
        """
        Best-effort storage size estimation in bytes for tensor items.

        Returns:
            Optional[int]: Estimated storage size in bytes for tensor items,
                or None if estimation cannot be performed (e.g., for non-tensor
                items or when chunk/properties metadata is missing).
        """
        if self.type != WriteItemType.TENSOR or not self.tensor_data:
            return None

        chunk = self.tensor_data.get("chunk")
        properties = self.tensor_data.get("properties")
        if chunk is None or properties is None:
            return None

        # Element count from the local chunk size (not the global tensor size).
        num_elems = 1
        for dim in chunk.sizes:
            num_elems *= int(dim)

        dtype_str = getattr(properties, "dtype", None)
        if dtype_str is None:
            # No dtype available: fall back to the raw element count.
            return num_elems

        # Estimate element size by dtype-name substring; unrecognized dtypes
        # default to 4 bytes.
        dtype_str_lower = str(dtype_str).lower()
        elem_size = 4
        for dtype_name, nbytes in self._DTYPE_ELEM_SIZES.items():
            if dtype_name in dtype_str_lower:
                elem_size = nbytes
                break
        return num_elems * elem_size
@dataclass(frozen=True)
class ReadItem:
    """
    Item to be read from storage.

    Represents a single logical read operation, mapping from checkpoint storage
    to destination state_dict location.

    Attributes:
        type: Type of load item (TENSOR or BYTE_IO).
        dest_index: Metadata index identifying the destination in state_dict.
        dest_offsets: Offsets into the destination tensor (for TENSOR type).
        storage_index: Metadata index identifying the source in checkpoint.
        storage_offsets: Offsets into the checkpoint storage data.
        lengths: Size of the hypercube to copy (dimensions of the data region).
    """
    type: LoadItemType
    dest_index: MetadataIndex  # Index into the state_dict
    dest_offsets: tuple  # Offsets into destination tensor (presumably one int per dim — verify)
    storage_index: MetadataIndex  # Index into the checkpoint
    storage_offsets: tuple  # Offset into the checkpoint data
    lengths: tuple  # Size of the hypercube to copy
@dataclass
class SavePlan:
    """
    Plan for saving checkpoint.

    Contains write items and optional storage/planner-specific data.
    Unlike WriteItem/ReadItem this dataclass is not frozen — presumably so
    later planning stages can fill in storage_data/planner_data in place.

    Attributes:
        items: List of WriteItems to be saved. Default [].
        storage_data: Storage-specific data (optional). Default None.
        planner_data: Planner-specific data (optional). Default None.
    """
    items: list[WriteItem] = field(default_factory=list)
    storage_data: Optional[dict[MetadataIndex, Any]] = None  # Storage-specific data mapping
    planner_data: Any = None  # Planner-specific data (can be any type)
@dataclass
class LoadPlan:
    """
    Plan for loading checkpoint.

    Contains read items and optional storage/planner-specific data.
    Unlike WriteItem/ReadItem this dataclass is not frozen — presumably so
    later planning stages can fill in storage_data/planner_data in place.

    Attributes:
        items: List of ReadItems to be loaded. Default [].
        storage_data: Storage-specific data (optional). Default None.
        planner_data: Planner-specific data (optional). Default None.
    """
    items: list[ReadItem] = field(default_factory=list)
    storage_data: Optional[dict[MetadataIndex, Any]] = None  # Storage-specific data mapping
    planner_data: Any = None  # Planner-specific data (can be any type)
class SavePlanner(abc.ABC):
    """Abstract base class for save planners.

    Subclasses implement the save-side planning protocol; the individual
    method docstrings describe the intended sequence (configure, build local
    plan, combine into a global plan, finalize, then fetch data per item).
    """

    @abc.abstractmethod
    def configure_planner(self, state_dict: dict[str, Any], **kwargs) -> None:
        """
        Configure the planner with state dict.

        Args:
            state_dict (dict[str, Any]): The state_dict to save.
            **kwargs: Additional keyword arguments (e.g., is_coordinator, rank, remove_redundancy,
                save_to_minimum_rank).
        """

    @abc.abstractmethod
    def build_local_plan(self) -> SavePlan:
        """
        Build local save plan.

        Creates a plan for saving checkpoint data from the current rank's perspective.
        This plan contains WriteItems for all tensors and bytes that this rank needs to save.

        Returns:
            SavePlan: Local save plan containing WriteItems for this rank.
        """

    @abc.abstractmethod
    def build_global_plan(self, all_plans: list[SavePlan]) -> tuple[list[SavePlan], Metadata]:
        """
        Build global plan from all local plans.

        Combines local plans from all ranks into a global plan and creates checkpoint metadata.
        This method may deduplicate redundant data across ranks and assign storage indices.

        Args:
            all_plans (list[SavePlan]): List of local save plans from all ranks.

        Returns:
            tuple[list[SavePlan], Metadata]: Updated global plans (one per rank) and
                checkpoint metadata containing information about all saved items.
        """

    @abc.abstractmethod
    def finalize_plan(self, plan: SavePlan) -> SavePlan:
        """
        Finalize the plan.

        Performs any final adjustments to the plan before execution, such as updating
        tensor cache keys or performing planner-specific optimizations.

        Args:
            plan (SavePlan): The plan to finalize.

        Returns:
            SavePlan: The finalized plan ready for execution.
        """

    @abc.abstractmethod
    def get_data(self, item: WriteItem) -> Any:
        """
        Get runtime data for a write item from the current state_dict.

        Args:
            item (WriteItem): The write item to get data for.

        Returns:
            Any: Runtime object to be written for this item.
        """
class LoadPlanner(abc.ABC):
    """Abstract base class for load planners.

    Subclasses implement the load-side planning protocol; the individual
    method docstrings describe the intended sequence (configure, build local
    plan, combine into a global plan, finalize, then acquire/apply per item).
    """

    @abc.abstractmethod
    def configure_planner(self, state_dict: dict[str, Any], metadata: Metadata, **kwargs) -> None:
        """
        Configure the planner with state dict and metadata.

        Args:
            state_dict (dict[str, Any]): The state_dict to load into (modified in-place).
            metadata (Metadata): Checkpoint metadata.
            **kwargs: Additional keyword arguments (e.g., is_coordinator, rank).
        """

    @abc.abstractmethod
    def build_local_plan(self) -> LoadPlan:
        """
        Build local load plan.

        Creates a plan for loading checkpoint data from the current rank's perspective.
        This plan contains ReadItems for all tensors and bytes that this rank needs to load.

        Returns:
            LoadPlan: Local load plan containing ReadItems for this rank.
        """

    @abc.abstractmethod
    def build_global_plan(self, all_plans: list[LoadPlan]) -> list[LoadPlan]:
        """
        Build global plan from all local plans.

        Combines local plans from all ranks into a global plan. This method may
        coordinate across ranks or perform optimizations.

        Args:
            all_plans (list[LoadPlan]): List of local load plans from all ranks.

        Returns:
            list[LoadPlan]: Updated global load plans (one per rank).
        """

    @abc.abstractmethod
    def finalize_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Finalize the plan.

        Performs any final adjustments to the plan before execution, such as
        performing planner-specific optimizations or validations.

        Args:
            plan (LoadPlan): The plan to finalize.

        Returns:
            LoadPlan: The finalized plan ready for execution.
        """

    @abc.abstractmethod
    def acquire_tensor(self, read_item: ReadItem) -> Any:
        """
        Acquire tensor for read item.

        Returns a tensor slice/view where data should be written.

        Args:
            read_item (ReadItem): Read item to acquire tensor for.

        Returns:
            Any: Acquired tensor slice/view (tensor-like object).
        """

    @abc.abstractmethod
    def apply_tensor(self, read_item: ReadItem, tensor: Any) -> None:
        """
        Apply tensor after reading.

        Args:
            read_item (ReadItem): Read item.
            tensor (Any): Tensor data to apply (tensor-like object).
        """

    @abc.abstractmethod
    def apply_bytes(self, read_item: ReadItem, value: bytes) -> None:
        """
        Apply bytes data.

        Args:
            read_item (ReadItem): The read item specifying the destination.
            value (bytes): The bytes data to deserialize and apply.
        """