Coverage for hyper_parallel / core / checkpoint / standard_planner.py: 89%
212 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Standard planner implementations for checkpoint save and load."""
16import dataclasses
17import pickle
18from typing import Any, Optional, Union
20from hyper_parallel.core.checkpoint.metadata import (
21 Metadata, MetadataIndex, ChunkStorageMetadata,
22 TensorStorageMetadata, TensorProperties, BytesStorageMetadata
23)
24from hyper_parallel.core.checkpoint.planner import (
25 SavePlan, SavePlanner, LoadPlan, LoadPlanner,
26 WriteItem, WriteItemType, ReadItem, LoadItemType
27)
28from hyper_parallel.core.checkpoint.reshard import infer_slice_area_by_rank, infer_intersection
29from hyper_parallel.core.checkpoint.util import (
30 narrow_tensor_by_index,
31 chunk_to_area,
32 create_chunk_list_for_tensor,
33 remove_redundant_plans,
34)
35from hyper_parallel.core.dtensor import DTensor, Layout
36from hyper_parallel.platform import get_platform
# Resolve the active compute platform once at module import time; `Tensor` is an
# alias of the platform's native tensor class so the isinstance checks in the
# planners below stay backend-agnostic.
platform = get_platform()
Tensor = platform.Tensor
class StandardSavePlanner(SavePlanner):
    """Standard implementation of SavePlanner for distributed checkpoint saving."""

    def __init__(self):
        # Populated by configure_planner(); stays None until then.
        self.state_dict: Optional[dict[str, Any]] = None
        # True on the rank that assembles the global plan and metadata.
        self.is_coordinator: bool = False
        # Global rank of this process; used to locate this rank's shard offsets.
        self.rank: int = 0
        # When True, duplicate shards held by several ranks are deduplicated
        # in build_global_plan via remove_redundant_plans.
        self.remove_redundancy: bool = True
        # Passed through to remove_redundant_plans to pick which rank keeps
        # a deduplicated shard.
        self.save_to_minimum_rank: bool = True
        self._tensor_cache: dict[MetadataIndex, Any] = {}  # Cache for tensor data

    def configure_planner(self, state_dict: dict[str, Any], **kwargs) -> None:
        """
        Configure planner.

        Args:
            state_dict (dict[str, Any]): The state_dict to save.
            **kwargs: Additional keyword arguments (e.g., is_coordinator, rank, remove_redundancy,
                save_to_minimum_rank).
        """
        self.state_dict = state_dict
        self.is_coordinator = kwargs.get("is_coordinator", False)
        self.rank = kwargs.get("rank", 0)
        self.remove_redundancy = kwargs.get("remove_redundancy", True)
        self.save_to_minimum_rank = kwargs.get("save_to_minimum_rank", True)

    def build_local_plan(self) -> SavePlan:
        """
        Create local save plan.

        Returns:
            SavePlan: Local save plan containing WriteItems for this rank.
        """
        if self.state_dict is None:
            raise RuntimeError("Planner not set up")

        def compute_global_offsets(global_shape: tuple[int, ...], dtensor_layout: Layout) -> tuple[int, ...]:
            """
            Compute the offsets of local tensor in global tensor based on layout.

            Args:
                global_shape (tuple[int, ...]): Global shape of the tensor.
                dtensor_layout (Layout): Layout of the DTensor.

            Returns:
                tuple[int, ...]: Tuple of offsets for each dimension.
            """
            if dtensor_layout is None:
                # If layout is None, return all zeros (no sharding)
                return tuple(0 for _ in global_shape)

            # Validate layout attributes
            if not hasattr(dtensor_layout, 'mesh_shape') or dtensor_layout.mesh_shape is None:
                raise ValueError("Layout must have mesh_shape attribute")
            if not hasattr(dtensor_layout, 'tensor_map') or dtensor_layout.tensor_map is None:
                raise ValueError("Layout must have tensor_map attribute")
            if not hasattr(dtensor_layout, 'rank_list') or dtensor_layout.rank_list is None:
                raise ValueError("Layout must have rank_list attribute")

            current_rank = self.rank
            if current_rank not in dtensor_layout.rank_list:
                raise ValueError(
                    f"Current rank {current_rank} not found in layout's rank_list {dtensor_layout.rank_list}")

            # Position of this rank inside the layout's rank_list, not the global rank id.
            inner_rank_id = dtensor_layout.rank_list.index(current_rank)

            # Calculate slice area using infer_slice_area_by_rank
            slice_area = infer_slice_area_by_rank(
                mesh_shape=dtensor_layout.mesh_shape,
                tensor_map=dtensor_layout.tensor_map,
                rank_id=inner_rank_id,
                full_shape=global_shape
            )

            # Extract offsets (start values) from slice_area
            return tuple(start for start, _ in slice_area)

        items = []
        for fqn, obj in self.state_dict.items():
            # Check if it's a DTensor
            if isinstance(obj, DTensor):
                # Create write item for DTensor
                local_tensor = obj.to_local()
                layout = obj.layout

                # Get chunk metadata with offsets
                if layout:
                    offsets = compute_global_offsets(obj.shape, layout)
                else:
                    # No layout: local tensor occupies the whole global tensor.
                    offsets = (0,) * len(local_tensor.shape)

                sizes = local_tensor.shape
                chunk = ChunkStorageMetadata(offsets=offsets, sizes=sizes)

                # Get tensor properties
                dtype_str = str(local_tensor.dtype) if hasattr(local_tensor, 'dtype') else 'unknown'
                properties = TensorProperties(dtype=dtype_str)

                # Create write item for this tensor; index is filled in later
                # by build_global_plan once all chunks are collected.
                index = MetadataIndex(fqn=fqn, offset=offsets, index=None)
                # Store tensor in cache instead of tensor_data
                self._tensor_cache[index] = local_tensor
                write_item = WriteItem(
                    index=index,
                    type=WriteItemType.TENSOR,
                    tensor_data={
                        'chunk': chunk,
                        'properties': properties,
                        'size': obj.shape,
                    }
                )
                items.append(write_item)
            elif isinstance(obj, Tensor):
                # Create write item for platform.Tensor: build single chunk with tensor's own size
                dtype_str = str(obj.dtype) if hasattr(obj, 'dtype') else 'unknown'
                properties = TensorProperties(dtype=dtype_str)

                # Single chunk covering the whole tensor (offsets=0, sizes=shape)
                chunk = ChunkStorageMetadata(
                    offsets=(0,) * len(obj.shape),
                    sizes=obj.shape,
                )

                index = MetadataIndex(fqn=fqn, offset=(0,) * len(obj.shape), index=None)
                self._tensor_cache[index] = obj
                write_item = WriteItem(
                    index=index,
                    type=WriteItemType.TENSOR,
                    tensor_data={
                        'chunk': chunk,
                        'properties': properties,
                        'size': obj.shape,
                    }
                )
                items.append(write_item)
            else:
                # Handle non-tensor types (bytes, etc.)
                index = MetadataIndex(fqn=fqn)
                write_item = WriteItem(
                    index=index,
                    type=WriteItemType.BYTE_IO,
                    bytes_io_data=obj
                )
                items.append(write_item)

        return SavePlan(items=items)

    def build_global_plan(self, all_plans: list[SavePlan]) -> tuple[list[SavePlan], Metadata]:
        """
        Build global plan from all local plans.

        Collects chunks from all ranks, validates consistency, and creates metadata for the checkpoint.

        Args:
            all_plans (list[SavePlan]): List of local plans from all ranks.

        Returns:
            tuple[list[SavePlan], Metadata]: Updated plans and checkpoint metadata.
        """
        # Deduplicate plans if redundancy removal is enabled
        if self.remove_redundancy and len(all_plans) > 1:
            all_plans = remove_redundant_plans(all_plans, save_to_minimum_rank=self.save_to_minimum_rank)

        # Collect all write items by FQN
        fqn_to_chunks: dict[str, list[ChunkStorageMetadata]] = {}
        fqn_to_properties: dict[str, TensorProperties] = {}
        fqn_to_size: dict[str, tuple] = {}
        state_dict_metadata: dict[str, Union[TensorStorageMetadata, BytesStorageMetadata]] = {}

        final_global_plans: list[SavePlan] = []
        for plan in all_plans:
            with_index_items = []
            for item in plan.items:
                if item.type == WriteItemType.TENSOR and item.tensor_data:
                    fqn = item.index.fqn
                    chunk = item.tensor_data['chunk']
                    properties = item.tensor_data['properties']
                    size = item.tensor_data['size']

                    # Validate consistency across ranks
                    if fqn in fqn_to_chunks and (fqn_to_properties[fqn] != properties or fqn_to_size[fqn] != size):
                        raise ValueError(f"The {fqn} in different rank has different properties and size.")

                    # Initialize FQN entry if not exists
                    if fqn not in fqn_to_chunks:
                        fqn_to_properties[fqn] = properties
                        fqn_to_size[fqn] = size
                        fqn_to_chunks[fqn] = []

                    # Append chunk and set index (platform.Tensor has exactly one chunk)
                    # The chunk's position in fqn_to_chunks becomes its MetadataIndex.index;
                    # dataclasses.replace keeps WriteItem/MetadataIndex immutable.
                    new_index = dataclasses.replace(item.index, index=len(fqn_to_chunks[fqn]))
                    with_index_item = dataclasses.replace(item, index=new_index)
                    with_index_items.append(with_index_item)
                    fqn_to_chunks[fqn].append(chunk)

                elif item.type == WriteItemType.BYTE_IO:
                    with_index_items.append(item)
                    state_dict_metadata[item.index.fqn] = BytesStorageMetadata()
                else:
                    raise ValueError(f"Unsupported write item type: {item.type}")

            final_global_plans.append(dataclasses.replace(plan, items=with_index_items))

        # Create metadata for all tensors
        for fqn, chunks in fqn_to_chunks.items():
            state_dict_metadata[fqn] = TensorStorageMetadata(
                properties=fqn_to_properties[fqn],
                size=fqn_to_size[fqn],
                chunks=chunks
            )

        metadata = Metadata(state_dict_metadata=state_dict_metadata)
        return final_global_plans, metadata

    def _update_tensor_cache(self, plan: SavePlan) -> None:
        """
        Update tensor cache keys to match the finalized plan's MetadataIndex values.

        Updates cache keys for tensors that are in the plan, and removes tensors
        that are not in the plan. The plan's items have been modified by build_global_plan
        (index field added), so we need to update the cache keys accordingly.

        Args:
            plan (SavePlan): Plan with updated MetadataIndex values.
        """
        # Build mapping from (fqn, offset) to updated MetadataIndex from plan.
        # (fqn, offset) is stable across build_global_plan; only .index changed.
        plan_tensor_map: dict[tuple[str, tuple], MetadataIndex] = {}
        for item in plan.items:
            if item.type == WriteItemType.TENSOR and item.tensor_data:
                key_pair = (item.index.fqn, item.index.offset)
                plan_tensor_map[key_pair] = item.index

        # Update tensor cache keys and remove tensors not in plan
        keys_to_remove = []
        for cached_key in list(self._tensor_cache.keys()):
            key_pair = (cached_key.fqn, cached_key.offset)

            if key_pair in plan_tensor_map:
                # Update cache key if index changed (pop + reinsert re-keys the entry)
                new_index = plan_tensor_map[key_pair]
                if cached_key != new_index:
                    tensor = self._tensor_cache.pop(cached_key)
                    self._tensor_cache[new_index] = tensor
            else:
                # Mark for removal if not in plan (e.g. shard deduplicated away)
                keys_to_remove.append(cached_key)

        # Remove tensors not in plan
        for key in keys_to_remove:
            self._tensor_cache.pop(key, None)

    def finalize_plan(self, plan: SavePlan) -> SavePlan:
        """
        Finalize the plan and update tensor cache keys.

        Updates tensor cache keys to match the finalized plan's MetadataIndex values.
        The plan's items have been modified by build_global_plan (index field added),
        so we need to update the cache keys accordingly.
        Also removes tensors from cache that are not in the finalized plan.

        Args:
            plan (SavePlan): Plan to finalize (with updated MetadataIndex values).

        Returns:
            SavePlan: Finalized plan.
        """
        self._update_tensor_cache(plan)
        return plan

    def get_tensor(self, index: MetadataIndex) -> Any:
        """
        Get tensor data for a given MetadataIndex from the cache.

        Args:
            index (MetadataIndex): Metadata index identifying the tensor.

        Returns:
            Any: Tensor data (tensor-like object) or None if not found.
        """
        return self._tensor_cache.get(index)
def create_read_items_for_chunk_list(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_chunks: list[ChunkStorageMetadata],
) -> list[ReadItem]:
    """
    Build ReadItems pairing every locally required chunk with every saved chunk
    it overlaps, so resharded checkpoints can be read back piecewise.

    Mirrors torch create_read_items_for_chunk_list behavior.

    Args:
        fqn (str): Fully qualified name of the tensor.
        checkpoint_md (TensorStorageMetadata): Tensor storage metadata from checkpoint.
        local_chunks (list[ChunkStorageMetadata]): List of local chunks needed by this rank.

    Returns:
        list[ReadItem]: List of ReadItems for loading the required data.
    """
    stored_chunks = checkpoint_md.chunks
    if not local_chunks or not stored_chunks:
        return []

    result: list[ReadItem] = []
    for dest_idx, dest_chunk in enumerate(local_chunks):
        dest_area = chunk_to_area(dest_chunk)
        for src_idx, src_chunk in enumerate(stored_chunks):
            region = infer_intersection(dest_area, chunk_to_area(src_chunk))
            if region is None:
                # Saved chunk contributes nothing to this local chunk.
                continue

            # Per-dimension translation of the overlap into each chunk's
            # local coordinate frame, plus its extent.
            dest_offsets = tuple(lo - base for (lo, _), base in zip(region, dest_chunk.offsets))
            storage_offsets = tuple(lo - base for (lo, _), base in zip(region, src_chunk.offsets))
            lengths = tuple(hi - lo for lo, hi in region)

            result.append(
                ReadItem(
                    type=LoadItemType.TENSOR,
                    dest_index=MetadataIndex(fqn=fqn, offset=dest_chunk.offsets, index=dest_idx),
                    dest_offsets=dest_offsets,
                    storage_index=MetadataIndex(fqn=fqn, offset=src_chunk.offsets, index=src_idx),
                    storage_offsets=storage_offsets,
                    lengths=lengths,
                )
            )
    return result
class StandardLoadPlanner(LoadPlanner):
    """
    Standard implementation of LoadPlanner.

    Walks the target state_dict and derives load plans from chunk lists so that
    resharded checkpoints can be restored.
    """

    def __init__(self, allow_partial_load: bool = False):
        """
        Args:
            allow_partial_load (bool): If True, allow loading when checkpoint has fewer keys than state_dict.
                Default False.
        """
        # Filled in by configure_planner(); None until then.
        self.state_dict: Optional[dict[str, Any]] = None
        self.metadata: Optional[Metadata] = None
        self.is_coordinator: bool = False
        self.rank: int = 0
        self.allow_partial_load = allow_partial_load

    def configure_planner(self, state_dict: dict[str, Any], metadata: Metadata, **kwargs) -> None:
        """
        Configure planner with state dict and metadata.

        Args:
            state_dict (dict[str, Any]): The state_dict to load into (modified in-place).
            metadata (Metadata): Checkpoint metadata.
            **kwargs: Additional keyword arguments (e.g., is_coordinator, rank).
        """
        self.state_dict = state_dict
        self.metadata = metadata
        self.is_coordinator = kwargs.get("is_coordinator", False)
        self.rank = kwargs.get("rank", 0)

    @staticmethod
    def _skip_on_this_rank(obj: Any) -> bool:
        """Return True when a DTensor's layout lists ranks and this rank is not among them."""
        layout = getattr(obj, "layout", None)
        if layout is None:
            return False
        # A falsy-but-non-None layout skips the public attribute and falls
        # through to the private one, matching the original lookup order.
        rank_list = getattr(layout, "rank_list", None) if layout else None
        if rank_list is None:
            rank_list = getattr(layout, "_rank_list", None)
        if rank_list is None:
            return False
        return get_platform().get_rank() not in rank_list

    def build_local_plan(self) -> LoadPlan:
        """
        Build local load plan.

        Emits ReadItems per overlap for tensor entries and a single BYTE_IO
        item for every non-tensor entry.

        Returns:
            LoadPlan: Local load plan containing ReadItems for this rank.
        """
        if self.state_dict is None or self.metadata is None:
            raise RuntimeError("Planner not configured")

        saved_md = self.metadata.state_dict_metadata
        read_items: list[ReadItem] = []
        for fqn, obj in self.state_dict.items():
            if fqn not in saved_md:
                if not self.allow_partial_load:
                    raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
                continue

            entry = saved_md[fqn]
            if not isinstance(entry, TensorStorageMetadata):
                # Non-tensor payloads are restored whole from serialized bytes.
                read_items.append(
                    ReadItem(
                        type=LoadItemType.BYTE_IO,
                        dest_index=MetadataIndex(fqn=fqn),
                        dest_offsets=(0,),
                        storage_index=MetadataIndex(fqn=fqn),
                        storage_offsets=(0,),
                        lengths=(0,),
                    )
                )
                continue

            current_size = getattr(obj, "shape", None)
            if current_size is None or entry.size != tuple(current_size):
                raise ValueError(
                    f"Size mismatch between saved {entry.size} and current: {current_size} for {fqn}",
                )
            if isinstance(obj, DTensor) and self._skip_on_this_rank(obj):
                continue
            # Both DTensor and platform.Tensor: create local chunks and read items
            read_items += create_read_items_for_chunk_list(fqn, entry, create_chunk_list_for_tensor(obj))
        return LoadPlan(items=read_items)

    def build_global_plan(self, all_plans: list[LoadPlan]) -> list[LoadPlan]:
        """
        Build global plan from all local plans.

        No cross-rank coordination is performed; each rank's plan is kept as-is.

        Args:
            all_plans (list[LoadPlan]): List of local plans from all ranks.

        Returns:
            list[LoadPlan]: Global plans (currently returns plans as-is).
        """
        return all_plans

    def finalize_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Finalize the plan (no-op for default implementation).

        Args:
            plan (LoadPlan): Plan to finalize.

        Returns:
            LoadPlan: Finalized plan.
        """
        return plan

    def acquire_tensor(self, read_item: ReadItem) -> Any:
        """
        Acquire the destination slice (narrow view) for this read_item.

        StorageReader uses this to copy loaded data into the correct region.
        Torch-aligned behavior.

        Args:
            read_item (ReadItem): The read item specifying what to load.

        Returns:
            Any: The destination tensor slice where data should be written
                (tensor-like object).
        """
        if self.state_dict is None:
            raise RuntimeError("Planner not configured")

        fqn = read_item.dest_index.fqn
        if fqn not in self.state_dict:
            raise KeyError(f"Key {fqn} not found in state_dict")

        dest = self.state_dict[fqn]
        if isinstance(dest, DTensor):
            dest = dest.to_local()
        return narrow_tensor_by_index(dest, read_item.dest_offsets, read_item.lengths)

    def apply_tensor(self, read_item: ReadItem, tensor: Any) -> None:
        """
        Apply tensor after reading.

        No-op when the incoming tensor already is the destination slice (the
        backend copied in place). Otherwise copies the loaded data into the
        destination slice.

        Args:
            read_item (ReadItem): The read item that was processed.
            tensor (Any): The tensor data to apply (tensor-like object).
        """
        if tensor is None:
            return
        target = self.acquire_tensor(read_item)
        if target is tensor:
            # read_data already wrote through the slice view.
            return
        if hasattr(target, "copy_"):
            target.copy_(tensor)
        else:
            # Fallback for backends without copy_: slice-assign into the view.
            target[...] = tensor

    def apply_bytes(self, read_item: ReadItem, value: bytes) -> None:
        """
        Load bytes data into state_dict.

        Args:
            read_item (ReadItem): The read item specifying the destination.
            value (bytes): The bytes data to deserialize and load.
        """
        if self.state_dict is None:
            raise RuntimeError("Planner not set up")

        # NOTE(review): pickle.loads is only safe on trusted checkpoint files;
        # never feed it untrusted input.
        self.state_dict[read_item.dest_index.fqn] = pickle.loads(value)