Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/distributed_checkpoint/standard_planner.py: 54%
269 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Standard planner implementations for checkpoint save and load."""
from dataclasses import dataclass
import dataclasses
import pickle
from typing import Any, Optional, Union

from hyper_parallel.core.distributed_checkpoint.metadata import (
    Metadata, MetadataIndex, ChunkStorageMetadata,
    TensorStorageMetadata, TensorProperties, BytesStorageMetadata
)
from hyper_parallel.core.distributed_checkpoint.planner import (
    SavePlan, SavePlanner, LoadPlan, LoadPlanner,
    WriteItem, WriteItemType, ReadItem, LoadItemType
)
from hyper_parallel.core.distributed_checkpoint.reshard import infer_slice_area_by_rank, infer_intersection
from hyper_parallel.core.distributed_checkpoint.util import (
    narrow_tensor_by_index,
    chunk_to_area,
    create_chunk_list_for_tensor,
    remove_redundant_plans,
    flatten_state_dict,
    set_element,
)
from hyper_parallel.core.dtensor.dtensor import DTensor, Layout
from hyper_parallel.platform import get_platform

platform = get_platform()
Tensor = platform.Tensor


@dataclass(frozen=True)
class CachedSaveResult:
    """Cached finalized save result keyed by planner cache namespace."""

    final_plan: SavePlan
    metadata: Metadata


class StandardSavePlanner(SavePlanner):
    """Standard implementation of SavePlanner for distributed checkpoint saving."""

    _cached_save_result: dict[str, CachedSaveResult] = {}

    def __init__(
        self,
        enable_plan_caching: bool = True,
        remove_redundancy: bool = True,
        save_to_minimum_rank: bool = False,
    ):
        self.state_dict: Optional[dict[str, Any]] = None
        self.is_coordinator: bool = False
        self.rank: int = 0
        self.remove_redundancy: bool = remove_redundancy
        self.save_to_minimum_rank: bool = save_to_minimum_rank
        self.flatten_state_dict: bool = True
        self._enable_plan_caching: bool = enable_plan_caching
        self._cached_plans_key: str = self.__class__.__name__

    def configure_planner(self, state_dict: dict[str, Any], **kwargs) -> None:
        """
        Configure the planner.

        Args:
            state_dict (dict[str, Any]): The state_dict to save.
            **kwargs: Additional keyword arguments (e.g., is_coordinator, rank, remove_redundancy,
                save_to_minimum_rank).
        """
        self.is_coordinator = kwargs.get("is_coordinator", False)
        self.rank = kwargs.get("rank", 0)
        self.remove_redundancy = kwargs.get("remove_redundancy", self.remove_redundancy)
        self.save_to_minimum_rank = kwargs.get("save_to_minimum_rank", self.save_to_minimum_rank)
        self.flatten_state_dict = kwargs.get("flatten_state_dict", True)

        use_collectives = bool(kwargs.get("use_collectives", True))
        if not use_collectives:
            self.remove_redundancy = False
            self._enable_plan_caching = False
        elif "enable_plan_caching" in kwargs:
            self._enable_plan_caching = bool(kwargs["enable_plan_caching"])

        if self.flatten_state_dict:
            state_dict, self.name_mapping = flatten_state_dict(state_dict)
        self.state_dict = state_dict
        self._cached_plans_key = self._build_cache_key(state_dict)

    def _build_cache_key(self, state_dict: dict[str, Any]) -> str:
        """Build a cache namespace from the joined state_dict keys."""
        return f"{self.__class__.__name__}:{'||'.join(state_dict.keys())}"

    def build_local_plan(self) -> SavePlan:
        """
        Create the local save plan.

        Returns:
            SavePlan: Local save plan containing WriteItems for this rank.
        """
        if self.state_dict is None:
            raise RuntimeError("Planner not set up")

        def compute_global_offsets(global_shape: tuple[int, ...], dtensor_layout: Layout) -> tuple[int, ...]:
            """
            Compute the offsets of the local tensor within the global tensor, based on the layout.

            Args:
                global_shape (tuple[int, ...]): Global shape of the tensor.
                dtensor_layout (Layout): Layout of the DTensor.

            Returns:
                tuple[int, ...]: Tuple of offsets for each dimension.
            """
            if dtensor_layout is None:
                # If layout is None, return all zeros (no sharding)
                return tuple(0 for _ in global_shape)

            # Validate layout attributes
            if not hasattr(dtensor_layout, 'mesh_shape') or dtensor_layout.mesh_shape is None:
                raise ValueError("Layout must have mesh_shape attribute")
            if not hasattr(dtensor_layout, 'tensor_map') or dtensor_layout.tensor_map is None:
                raise ValueError("Layout must have tensor_map attribute")
            if not hasattr(dtensor_layout, 'rank_list') or dtensor_layout.rank_list is None:
                raise ValueError("Layout must have rank_list attribute")

            current_rank = self.rank
            if current_rank not in dtensor_layout.rank_list:
                raise ValueError(
                    f"Current rank {current_rank} not found in layout's rank_list {dtensor_layout.rank_list}")

            inner_rank_id = dtensor_layout.rank_list.index(current_rank)
            # Calculate slice area using infer_slice_area_by_rank
            slice_area = infer_slice_area_by_rank(
                mesh_shape=dtensor_layout.mesh_shape,
                tensor_map=dtensor_layout.tensor_map,
                rank_id=inner_rank_id,
                full_shape=global_shape
            )
            # Extract offsets (start values) from slice_area
            return tuple(start for start, _ in slice_area)
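
        # Illustrative example (values assumed; relies on infer_slice_area_by_rank
        # returning per-dimension (start, end) pairs, which the offset extraction
        # above implies): for global_shape (8, 4) split evenly across 2 ranks along
        # dim 0, the second rank's slice area would be ((4, 8), (0, 4)), giving
        # offsets (4, 0).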

        items = []
        for fqn, obj in self.state_dict.items():
            # Check if it's a DTensor
            if isinstance(obj, DTensor):
                # Create write item for DTensor
                local_tensor = obj.to_local()
                layout = obj.layout

                # Get chunk metadata with offsets
                if layout:
                    offsets = compute_global_offsets(obj.shape, layout)
                else:
                    offsets = (0,) * len(local_tensor.shape)

                sizes = local_tensor.shape
                chunk = ChunkStorageMetadata(offsets=offsets, sizes=sizes)
                # Get tensor properties
                dtype_str = str(local_tensor.dtype) if hasattr(local_tensor, 'dtype') else 'unknown'
                properties = TensorProperties(dtype=dtype_str)
                # Create write item for this tensor
                index = MetadataIndex(fqn=fqn, offset=offsets, index=None)
                write_item = WriteItem(
                    index=index,
                    type=WriteItemType.TENSOR,
                    tensor_data={
                        'chunk': chunk,
                        'properties': properties,
                        'size': obj.shape,
                    }
                )
                items.append(write_item)
            elif isinstance(obj, Tensor):
                # Create write item for platform.Tensor: build single chunk with tensor's own size
                dtype_str = str(obj.dtype) if hasattr(obj, 'dtype') else 'unknown'
                properties = TensorProperties(dtype=dtype_str)
                # Single chunk covering the whole tensor (offsets=0, sizes=shape)
                chunk = ChunkStorageMetadata(
                    offsets=(0,) * len(obj.shape),
                    sizes=obj.shape,
                )
                index = MetadataIndex(fqn=fqn, offset=(0,) * len(obj.shape), index=None)
                write_item = WriteItem(
                    index=index,
                    type=WriteItemType.TENSOR,
                    tensor_data={
                        'chunk': chunk,
                        'properties': properties,
                        'size': obj.shape,
                    }
                )
                items.append(write_item)
            else:
                # Handle non-tensor types (bytes, etc.)
                index = MetadataIndex(fqn=fqn)
                write_item = WriteItem(
                    index=index,
                    type=WriteItemType.BYTE_IO,
                    bytes_io_data=None
                )
                items.append(write_item)

        plan = SavePlan(items=items)
        if self.flatten_state_dict:
            plan.planner_data = self.name_mapping
        return plan
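
    # Illustration (sketch values, not produced output): on the second rank of a
    # 2-way dim-0 sharding of an (8, 4) DTensor named "layer.weight", the emitted
    # WriteItem would carry roughly
    #     chunk      = ChunkStorageMetadata(offsets=(4, 0), sizes=(4, 4))
    #     properties = TensorProperties(dtype=str(local_tensor.dtype))
    #     size       = (8, 4)
    # in its tensor_data, with index=MetadataIndex(fqn="layer.weight", offset=(4, 0)).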

    def build_global_plan(self, all_plans: list[SavePlan]) -> tuple[list[SavePlan], Metadata]:
        """
        Build the global plan from all local plans.

        Collects chunks from all ranks, validates consistency, and creates metadata for the checkpoint.

        Args:
            all_plans (list[SavePlan]): List of local plans from all ranks.

        Returns:
            tuple[list[SavePlan], Metadata]: Updated plans and checkpoint metadata.
        """
        # Deduplicate plans if redundancy removal is enabled
        if self.remove_redundancy and len(all_plans) > 1:
            all_plans = remove_redundant_plans(all_plans, save_to_minimum_rank=self.save_to_minimum_rank)

        # Collect all write items by FQN
        fqn_to_chunks: dict[str, list[ChunkStorageMetadata]] = {}
        fqn_to_properties: dict[str, TensorProperties] = {}
        fqn_to_size: dict[str, tuple] = {}
        state_dict_metadata: dict[str, Union[TensorStorageMetadata, BytesStorageMetadata]] = {}

        final_global_plans: list[SavePlan] = []
        for plan in all_plans:
            with_index_items = []
            for item in plan.items:
                if item.type == WriteItemType.TENSOR and item.tensor_data:
                    fqn = item.index.fqn
                    chunk = item.tensor_data['chunk']
                    properties = item.tensor_data['properties']
                    size = item.tensor_data['size']

                    # Validate consistency across ranks
                    if fqn in fqn_to_chunks and (fqn_to_properties[fqn] != properties or fqn_to_size[fqn] != size):
                        raise ValueError(f"Entry {fqn} has inconsistent properties or size across ranks.")

                    # Initialize FQN entry if not exists
                    if fqn not in fqn_to_chunks:
                        fqn_to_properties[fqn] = properties
                        fqn_to_size[fqn] = size
                        fqn_to_chunks[fqn] = []

                    # Append chunk and set index (platform.Tensor has exactly one chunk)
                    new_index = dataclasses.replace(item.index, index=len(fqn_to_chunks[fqn]))
                    with_index_item = dataclasses.replace(item, index=new_index)
                    with_index_items.append(with_index_item)
                    fqn_to_chunks[fqn].append(chunk)

                elif item.type == WriteItemType.BYTE_IO:
                    with_index_items.append(item)
                    state_dict_metadata[item.index.fqn] = BytesStorageMetadata()
                else:
                    raise ValueError(f"Unsupported write item type: {item.type}")

            final_global_plans.append(dataclasses.replace(plan, items=with_index_items))

        # Create metadata for all tensors
        for fqn, chunks in fqn_to_chunks.items():
            state_dict_metadata[fqn] = TensorStorageMetadata(
                properties=fqn_to_properties[fqn],
                size=fqn_to_size[fqn],
                chunks=chunks
            )

        metadata = Metadata(state_dict_metadata=state_dict_metadata)
        if self.flatten_state_dict:
            merged_mapping = {}
            for p in all_plans:
                merged_mapping.update(p.planner_data)
            metadata.planner_data = merged_mapping
        return final_global_plans, metadata
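
    # Sketch of the merge step (illustrative values): if two ranks each contribute
    # one chunk of "layer.weight" with offsets (0, 0) and (4, 0), the checkpoint
    # metadata ends up with
    #     TensorStorageMetadata(properties=..., size=(8, 4), chunks=[chunk_rank0, chunk_rank1])
    # and the two WriteItems are re-indexed to index=0 and index=1 so each writer
    # can address its own chunk of the saved tensor.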

    def finalize_plan(self, plan: SavePlan) -> SavePlan:
        """
        Finalize the plan.

        Args:
            plan (SavePlan): Plan to finalize.

        Returns:
            SavePlan: Finalized plan.
        """
        return plan

    def get_cached_result(self) -> Optional[tuple[SavePlan, Metadata]]:
        """Return cached finalized plan and metadata when plan caching is enabled."""
        if not self._enable_plan_caching:
            return None
        cached_result = StandardSavePlanner._cached_save_result.get(self._cached_plans_key)
        if cached_result is None:
            return None
        return cached_result.final_plan, cached_result.metadata

    def cache_result(self, final_plan: SavePlan, metadata: Metadata) -> None:
        """Store finalized plan and metadata in the class-level planner cache."""
        if not self._enable_plan_caching:
            return
        StandardSavePlanner._cached_save_result[self._cached_plans_key] = CachedSaveResult(
            final_plan=final_plan,
            metadata=metadata,
        )
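
    # Caching behavior in brief (descriptive note): the cache key is built from the
    # flattened state_dict keys, so a later save of a state_dict with the same keys
    # can reuse the finalized plan and metadata via get_cached_result() instead of
    # rebuilding them, provided the caller stored the first result with
    # cache_result(). The exact call sequence is up to the saving entry point.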

    def get_data(self, item: WriteItem) -> Any:
        """
        Get current runtime data from state_dict for a write item.

        Args:
            item (WriteItem): Write item describing what to write.

        Returns:
            Any: Runtime object to be written.
        """
        if self.state_dict is None:
            raise RuntimeError("Planner not set up")
        fqn = item.index.fqn
        if fqn not in self.state_dict:
            raise KeyError(f"Key {fqn} not found in state_dict")
        obj = self.state_dict[fqn]
        if item.type == WriteItemType.TENSOR:
            if isinstance(obj, DTensor):
                return obj.to_local().detach().cpu()
            if isinstance(obj, Tensor):
                return obj.detach().cpu()
            raise TypeError(f"Write item {fqn} expected tensor-like object, got {type(obj)}")
        if item.type == WriteItemType.BYTE_IO:
            return obj
        raise TypeError(f"Unsupported write item type: {item.type}")
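
    # End-to-end save-side sketch (illustrative; gathering the local plans across
    # ranks and writing the returned data are the caller's and storage writer's
    # responsibility, not this module's):
    #
    #     planner = StandardSavePlanner()
    #     planner.configure_planner(state_dict, rank=rank, is_coordinator=(rank == 0))
    #     local_plan = planner.build_local_plan()
    #     # ... all-gather local plans from every rank into gathered_plans ...
    #     global_plans, metadata = planner.build_global_plan(gathered_plans)
    #     final_plan = planner.finalize_plan(global_plans[rank])
    #     for item in final_plan.items:
    #         data = planner.get_data(item)  # tensor or bytes object to write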


def create_read_items_for_chunk_list(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_chunks: list[ChunkStorageMetadata],
) -> list[ReadItem]:
    """
    Create ReadItems by matching local chunks (what this rank needs) with
    saved chunks (checkpoint_md.chunks), including resharding overlaps.

    Mirrors the behavior of torch's create_read_items_for_chunk_list.

    Args:
        fqn (str): Fully qualified name of the tensor.
        checkpoint_md (TensorStorageMetadata): Tensor storage metadata from the checkpoint.
        local_chunks (list[ChunkStorageMetadata]): List of local chunks needed by this rank.

    Returns:
        list[ReadItem]: List of ReadItems for loading the required data.
    """
    read_items: list[ReadItem] = []
    saved_chunks = checkpoint_md.chunks
    if not local_chunks or not saved_chunks:
        return read_items

    for local_idx, local_chunk in enumerate(local_chunks):
        local_area = chunk_to_area(local_chunk)
        for storage_idx, storage_chunk in enumerate(saved_chunks):
            saved_area = chunk_to_area(storage_chunk)
            overlap = infer_intersection(local_area, saved_area)
            if overlap is None:
                continue

            dest_offsets = tuple(overlap[i][0] - local_chunk.offsets[i] for i in range(len(overlap)))
            storage_offsets = tuple(overlap[i][0] - storage_chunk.offsets[i] for i in range(len(overlap)))
            lengths = tuple(overlap[i][1] - overlap[i][0] for i in range(len(overlap)))

            read_items.append(
                ReadItem(
                    type=LoadItemType.TENSOR,
                    dest_index=MetadataIndex(fqn=fqn, offset=local_chunk.offsets, index=local_idx),
                    dest_offsets=dest_offsets,
                    storage_index=MetadataIndex(fqn=fqn, offset=storage_chunk.offsets, index=storage_idx),
                    storage_offsets=storage_offsets,
                    lengths=lengths,
                )
            )
    return read_items
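
# Worked example of the overlap arithmetic above (illustrative values; assumes
# chunk_to_area and infer_intersection work with per-dimension (start, end) pairs,
# which is what the offset/length computation implies): a 1-D tensor of size (8,)
# was saved as two chunks, offsets (0,) and (4,), sizes (4,) each. A rank that now
# needs rows 2..5 (local chunk offsets=(2,), sizes=(4,)) gets two ReadItems:
#     from saved chunk 0: storage_offsets=(2,), dest_offsets=(0,), lengths=(2,)
#     from saved chunk 1: storage_offsets=(0,), dest_offsets=(2,), lengths=(2,)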


class StandardLoadPlanner(LoadPlanner):
    """
    Standard implementation of LoadPlanner.

    Iterates over the state_dict and creates load plans via a chunk list to support resharding.
    """

    def __init__(self, allow_partial_load: bool = False):
        """
        Args:
            allow_partial_load (bool): If True, allow loading when the checkpoint has fewer keys
                than the state_dict. Default: False.
        """
        self.state_dict: Optional[dict[str, Any]] = None
        self.metadata: Optional[Metadata] = None
        self.is_coordinator: bool = False
        self.rank: int = 0
        self.allow_partial_load = allow_partial_load
        self.flatten_state_dict: bool = True

    def configure_planner(self, state_dict: dict[str, Any], metadata: Metadata, **kwargs) -> None:
        """
        Configure the planner with the state dict and metadata.

        Args:
            state_dict (dict[str, Any]): The state_dict to load into (modified in-place).
            metadata (Metadata): Checkpoint metadata.
            **kwargs: Additional keyword arguments (e.g., is_coordinator, rank).
        """
        self.state_dict = state_dict
        self.metadata = metadata
        self.is_coordinator = kwargs.get("is_coordinator", False)
        self.rank = kwargs.get("rank", 0)
        self.flatten_state_dict = kwargs.get("flatten_state_dict", True)
        self.original_state_dict = state_dict
        if self.flatten_state_dict:
            state_dict, self.name_mapping = flatten_state_dict(state_dict)
            self.state_dict = state_dict

    def build_local_plan(self) -> LoadPlan:
        """
        Build the local load plan.

        Iterates over the state_dict and creates ReadItems via a chunk list to support resharding.

        Returns:
            LoadPlan: Local load plan containing ReadItems for this rank.
        """
        if self.state_dict is None or self.metadata is None:
            raise RuntimeError("Planner not configured")

        requests: list[ReadItem] = []
        strict = not self.allow_partial_load
        for fqn, obj in self.state_dict.items():
            if fqn not in self.metadata.state_dict_metadata:
                if strict:
                    raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
                continue
            md = self.metadata.state_dict_metadata[fqn]
            if isinstance(md, TensorStorageMetadata):
                obj_size = getattr(obj, "shape", None)
                if obj_size is None or md.size != tuple(obj_size):
                    raise ValueError(
                        f"Size mismatch between saved {md.size} and current: {obj_size} for {fqn}",
                    )
                if isinstance(obj, DTensor):
                    layout = getattr(obj, "layout", None)
                    rank_list = getattr(layout, "rank_list", None) if layout else None
                    if rank_list is None and layout is not None:
                        rank_list = getattr(layout, "_rank_list", None)
                    if layout is not None and rank_list is not None:
                        if get_platform().get_rank() not in rank_list:
                            continue
                # Both DTensor and platform.Tensor: create local chunks and read items
                local_chunks = create_chunk_list_for_tensor(obj)
                requests += create_read_items_for_chunk_list(fqn, md, local_chunks)
            else:
                requests.append(
                    ReadItem(
                        type=LoadItemType.BYTE_IO,
                        dest_index=MetadataIndex(fqn=fqn),
                        dest_offsets=(0,),
                        storage_index=MetadataIndex(fqn=fqn),
                        storage_offsets=(0,),
                        lengths=(0,),
                    )
                )
        return LoadPlan(items=requests)

    def build_global_plan(self, all_plans: list[LoadPlan]) -> list[LoadPlan]:
        """
        Build the global plan from all local plans.

        Currently returns the plans as-is; a more sophisticated implementation might coordinate across ranks here.

        Args:
            all_plans (list[LoadPlan]): List of local plans from all ranks.

        Returns:
            list[LoadPlan]: Global plans (currently the input plans, unchanged).
        """
        return all_plans

    def finalize_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Finalize the plan (a no-op in the default implementation).

        Args:
            plan (LoadPlan): Plan to finalize.

        Returns:
            LoadPlan: Finalized plan.
        """
        return plan

    def acquire_tensor(self, read_item: ReadItem) -> Any:
        """
        Acquire the destination slice (a narrow view) for this read_item.

        The StorageReader uses this to copy loaded data into the correct region.
        This mirrors torch's behavior.

        Args:
            read_item (ReadItem): The read item specifying what to load.

        Returns:
            Any: The destination tensor slice where data should be written
                (tensor-like object).
        """
        if self.state_dict is None:
            raise RuntimeError("Planner not configured")

        fqn = read_item.dest_index.fqn
        if fqn not in self.state_dict:
            raise KeyError(f"Key {fqn} not found in state_dict")

        target = self.state_dict[fqn]
        local_tensor = target.to_local().detach() if isinstance(target, DTensor) else target.detach()
        return narrow_tensor_by_index(
            local_tensor,
            read_item.dest_offsets,
            read_item.lengths,
        )

    def apply_tensor(self, read_item: ReadItem, tensor: Any) -> None:
        """
        Apply a tensor after reading.

        After read_data copies into the slice, this is a no-op when tensor is the
        same slice. When the backend has no copy_ (e.g. mindspore), read_data
        passes the loaded slice here; we copy it into the destination slice.

        Args:
            read_item (ReadItem): The read item that was processed.
            tensor (Any): The tensor data to apply (tensor-like object).
        """
        if tensor is None:
            return
        dest_slice = self.acquire_tensor(read_item)
        if dest_slice is tensor:
            return
        if hasattr(dest_slice, "copy_"):
            dest_slice.copy_(tensor)
        else:
            # Fallback: assign into the destination slice in place
            dest_slice[...] = tensor

    def apply_bytes(self, read_item: ReadItem, value: bytes) -> None:
        """
        Load bytes data into state_dict.

        Args:
            read_item (ReadItem): The read item specifying the destination.
            value (bytes): The bytes data to deserialize and load.
        """
        if self.state_dict is None:
            raise RuntimeError("Planner not set up")

        fqn = read_item.dest_index.fqn
        # Deserialize bytes
        obj = pickle.loads(value)
        self.state_dict[fqn] = obj
        if self.flatten_state_dict:
            set_element(self.original_state_dict, self.name_mapping[fqn], obj)
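
    # Typical read path (illustrative sketch; the surrounding storage-reader loop
    # is an assumption about the caller, not part of this class):
    #
    #     planner = StandardLoadPlanner()
    #     planner.configure_planner(state_dict, metadata, rank=rank)
    #     plan = planner.build_local_plan()
    #     for read_item in plan.items:
    #         if read_item.type == LoadItemType.TENSOR:
    #             dest = planner.acquire_tensor(read_item)  # narrow view into the target
    #             # ... the storage reader copies loaded data into `dest` ...
    #             planner.apply_tensor(read_item, dest)     # no-op when dest is the same slice
    #         else:
    #             planner.apply_bytes(read_item, raw_bytes)  # raw bytes read from storage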


class _DcpMergeLoadPlanner(StandardLoadPlanner):
    """Load planner that builds the full ``state_dict`` in place from a distributed checkpoint (dcp)."""

    def __init__(self) -> None:
        super().__init__()

    def configure_planner(self, state_dict: dict[str, Any], metadata: Metadata, **kwargs) -> None:
        if len(state_dict) > 0:
            raise ValueError(
                "state_dict must be empty for _DcpMergeLoadPlanner; "
                "it is populated in-place from checkpoint metadata."
            )

        if metadata is None:
            raise ValueError("metadata must not be None for _DcpMergeLoadPlanner.")

        self.is_coordinator = kwargs.get("is_coordinator", False)
        for k, v in metadata.state_dict_metadata.items():
            if isinstance(v, TensorStorageMetadata):
                v = platform.empty(
                    platform.list_to_size(v.size),
                    dtype=platform.str_to_dtype(v.properties.dtype),
                )

            state_dict[k] = v
            if metadata.planner_data is not None and k in metadata.planner_data:
                set_element(state_dict, metadata.planner_data[k], v)

        super().configure_planner(
            state_dict,
            metadata,
            is_coordinator=self.is_coordinator,
            flatten_state_dict=True,
        )
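
    # Intended use (descriptive sketch): pass an empty dict; configure_planner then
    # fills it with freshly allocated tensors (and nested entries resolved through
    # metadata.planner_data) sized from the checkpoint metadata, so a subsequent
    # load materializes the full, unsharded state_dict, e.g. when merging a
    # distributed checkpoint into a single consolidated one.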