Coverage for hyper_parallel / core / checkpoint / planner.py: 98%
83 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Planner interfaces and implementations"""
16import abc
17from dataclasses import dataclass, field
18from enum import Enum
19from typing import Any, Optional, Union
21from hyper_parallel.core.checkpoint.metadata import (
22 Metadata, MetadataIndex
23)
class WriteItemType(Enum):
    """Kind of payload a write item carries."""

    # Tensor payload.
    TENSOR = "tensor"
    # Raw bytes / serialized-object payload.
    BYTE_IO = "byte_io"
class LoadItemType(Enum):
    """Kind of payload a load item carries."""

    # Tensor payload.
    TENSOR = "tensor"
    # Raw bytes / serialized-object payload.
    BYTE_IO = "byte_io"
@dataclass(frozen=True)
class WriteItem:
    """
    Item to be written to storage.

    Represents a single logical item (tensor or bytes) to be saved.

    Attributes:
        index: Metadata index identifying this item.
        type: Type of write item (TENSOR or BYTE_IO).
        tensor_data: Dictionary containing tensor data (for TENSOR type). Default None.
        bytes_io_data: Bytes data (for BYTE_IO type). Default None.
    """
    index: MetadataIndex
    type: WriteItemType
    # Keys: 'chunk' (ChunkStorageMetadata), 'properties' (TensorProperties), 'size' (tuple).
    # Actual tensor data is in planner's tensor cache, not here, to avoid all_gather of tensors.
    tensor_data: Optional[dict[str, Any]] = None
    bytes_io_data: Optional[Union[bytes, Any]] = None  # Bytes or pickle-serializable object

    # Element sizes (bytes) for common dtype names, hoisted to a class constant
    # so it is not rebuilt on every call. Matching is substring-based, so more
    # specific names (e.g. "bfloat16") must precede their substrings ("float16").
    # No annotation -> not a dataclass field, just a shared class attribute.
    _DTYPE_ELEM_SIZES = (
        ("bfloat16", 2),
        ("float16", 2),
        ("float32", 4),
        ("float64", 8),
        ("int8", 1),
        ("uint8", 1),
        ("int16", 2),
        ("int32", 4),
        ("int64", 8),
        ("bool", 1),
    )

    def tensor_storage_size(self) -> Optional[int]:
        """
        Best-effort storage size estimation in bytes for tensor items.

        Returns:
            Optional[int]: Estimated storage size in bytes for tensor items,
                or None if estimation cannot be performed (e.g., for non-tensor items
                or when chunk/properties metadata is missing).
        """
        if self.type != WriteItemType.TENSOR or not self.tensor_data:
            return None

        # Try to estimate from metadata
        chunk = self.tensor_data.get("chunk")
        properties = self.tensor_data.get("properties")
        if chunk is None or properties is None:
            return None

        # Element count from the local chunk shape (not the global tensor size).
        # Renamed from `size` so the name is not reused below for byte widths.
        num_elems = 1
        for dim in chunk.sizes:
            num_elems *= int(dim)

        # Try to get dtype item size from properties
        dtype_str = getattr(properties, "dtype", None)
        if dtype_str is None:
            return int(num_elems)

        # Substring match against known dtype names; fall back to 4 bytes
        # for anything unrecognized.
        dtype_str_lower = str(dtype_str).lower()
        elem_size = 4  # Default to 4 bytes
        for dtype_name, nbytes in self._DTYPE_ELEM_SIZES:
            if dtype_name in dtype_str_lower:
                elem_size = nbytes
                break
        return int(num_elems) * int(elem_size)
@dataclass(frozen=True)
class ReadItem:
    """
    A single logical read operation from checkpoint storage.

    Describes how one region of saved data maps onto its destination
    location inside the target state_dict.

    Attributes:
        type: Kind of item being loaded (TENSOR or BYTE_IO).
        dest_index: Metadata index of the destination entry in the state_dict.
        dest_offsets: Offsets into the destination tensor (for TENSOR type).
        storage_index: Metadata index of the source entry in the checkpoint.
        storage_offsets: Offsets into the stored checkpoint data.
        lengths: Dimensions of the hypercube region to copy.
    """
    type: LoadItemType
    dest_index: MetadataIndex      # where the data lands in the state_dict
    dest_offsets: tuple            # offsets within the destination tensor
    storage_index: MetadataIndex   # which checkpoint entry to read from
    storage_offsets: tuple         # offsets within the stored data
    lengths: tuple                 # extent of the region to copy
@dataclass
class SavePlan:
    """
    Per-rank plan describing what to save.

    Bundles the WriteItems for one rank together with optional opaque data
    owned by the storage layer and/or the planner.

    Attributes:
        items: WriteItems scheduled for saving. Default [].
        storage_data: Storage-specific mapping keyed by MetadataIndex (optional). Default None.
        planner_data: Planner-specific data of any type (optional). Default None.
    """
    items: list[WriteItem] = field(default_factory=list)
    storage_data: Optional[dict[MetadataIndex, Any]] = None  # opaque storage-layer mapping
    planner_data: Any = None  # opaque planner-private payload
@dataclass
class LoadPlan:
    """
    Per-rank plan describing what to load.

    Bundles the ReadItems for one rank together with optional opaque data
    owned by the storage layer and/or the planner.

    Attributes:
        items: ReadItems scheduled for loading. Default [].
        storage_data: Storage-specific mapping keyed by MetadataIndex (optional). Default None.
        planner_data: Planner-specific data of any type (optional). Default None.
    """
    items: list[ReadItem] = field(default_factory=list)
    storage_data: Optional[dict[MetadataIndex, Any]] = None  # opaque storage-layer mapping
    planner_data: Any = None  # opaque planner-private payload
class SavePlanner(abc.ABC):
    """Contract that checkpoint save planners must implement."""

    @abc.abstractmethod
    def configure_planner(self, state_dict: dict[str, Any], **kwargs) -> None:
        """
        Initialize the planner for the given state dict.

        Args:
            state_dict (dict[str, Any]): The state_dict that will be saved.
            **kwargs: Extra options (e.g., is_coordinator, rank,
                remove_redundancy, save_to_minimum_rank).
        """

    @abc.abstractmethod
    def build_local_plan(self) -> SavePlan:
        """
        Produce this rank's save plan.

        The returned plan lists a WriteItem for every tensor and bytes object
        the current rank is responsible for persisting.

        Returns:
            SavePlan: Local plan holding this rank's WriteItems.
        """

    @abc.abstractmethod
    def build_global_plan(self, all_plans: list[SavePlan]) -> tuple[list[SavePlan], Metadata]:
        """
        Merge every rank's local plan into a coordinated global plan.

        Implementations may deduplicate data repeated across ranks and assign
        storage indices while building the checkpoint metadata.

        Args:
            all_plans (list[SavePlan]): Local save plans collected from all ranks.

        Returns:
            tuple[list[SavePlan], Metadata]: Adjusted per-rank plans plus the
                metadata describing every saved item.
        """

    @abc.abstractmethod
    def finalize_plan(self, plan: SavePlan) -> SavePlan:
        """
        Apply last adjustments to a plan right before execution.

        Typical work includes refreshing tensor cache keys or running
        planner-specific optimizations.

        Args:
            plan (SavePlan): The plan to finalize.

        Returns:
            SavePlan: The plan, ready to execute.
        """

    @abc.abstractmethod
    def get_tensor(self, index: MetadataIndex) -> Any:
        """
        Resolve the tensor data behind a MetadataIndex.

        Storage writers call this to fetch tensors on demand, so tensors never
        need to travel inside WriteItem.tensor_data (which would otherwise be
        transmitted during all_gather operations).

        Args:
            index (MetadataIndex): Index identifying the tensor.

        Returns:
            Any: Tensor-like object, or None if not found.
        """
class LoadPlanner(abc.ABC):
    """Contract that checkpoint load planners must implement."""

    @abc.abstractmethod
    def configure_planner(self, state_dict: dict[str, Any], metadata: Metadata, **kwargs) -> None:
        """
        Initialize the planner for the given state dict and metadata.

        Args:
            state_dict (dict[str, Any]): The state_dict to load into (modified in-place).
            metadata (Metadata): Checkpoint metadata.
            **kwargs: Extra options (e.g., is_coordinator, rank).
        """

    @abc.abstractmethod
    def build_local_plan(self) -> LoadPlan:
        """
        Produce this rank's load plan.

        The returned plan lists a ReadItem for every tensor and bytes object
        the current rank needs to load.

        Returns:
            LoadPlan: Local plan holding this rank's ReadItems.
        """

    @abc.abstractmethod
    def build_global_plan(self, all_plans: list[LoadPlan]) -> list[LoadPlan]:
        """
        Merge every rank's local plan into a coordinated global plan.

        Implementations may coordinate across ranks or apply optimizations.

        Args:
            all_plans (list[LoadPlan]): Local load plans collected from all ranks.

        Returns:
            list[LoadPlan]: Adjusted per-rank load plans.
        """

    @abc.abstractmethod
    def finalize_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Apply last adjustments to a plan right before execution.

        Typical work includes planner-specific optimizations or validations.

        Args:
            plan (LoadPlan): The plan to finalize.

        Returns:
            LoadPlan: The plan, ready to execute.
        """

    @abc.abstractmethod
    def acquire_tensor(self, read_item: ReadItem) -> Any:
        """
        Obtain the tensor slice/view that a read item writes into.

        Args:
            read_item (ReadItem): Read item to acquire a tensor for.

        Returns:
            Any: Acquired tensor slice/view (tensor-like object).
        """

    @abc.abstractmethod
    def apply_tensor(self, read_item: ReadItem, tensor: Any) -> None:
        """
        Commit tensor data after it has been read.

        Args:
            read_item (ReadItem): Read item being fulfilled.
            tensor (Any): Tensor data to apply (tensor-like object).
        """

    @abc.abstractmethod
    def apply_bytes(self, read_item: ReadItem, value: bytes) -> None:
        """
        Commit bytes data after it has been read.

        Args:
            read_item (ReadItem): The read item specifying the destination.
            value (bytes): The bytes data to deserialize and apply.
        """