Coverage for hyper_parallel / core / checkpoint / util.py: 85%
82 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Common utility functions."""
16import dataclasses
17from collections import defaultdict
18from pathlib import Path
19from typing import Any, Union
21from hyper_parallel.core.checkpoint.metadata import ChunkStorageMetadata, MetadataIndex
22from hyper_parallel.core.checkpoint.planner import SavePlan, WriteItem
23from hyper_parallel.core.checkpoint.reshard import infer_slice_area_by_rank
24from hyper_parallel.core.dtensor import DTensor
25from hyper_parallel.platform import get_platform
27platform = get_platform()
28Tensor = platform.Tensor
def check_path(path: Union[Path, str]) -> None:
    """
    Ensure that the directory portion of ``path`` exists.

    Args:
        path (Union[Path, str]): path to check. Can be only a file name in the
            current directory, a pure directory, or a file name with a directory.
            When the path contains a directory component that does not exist,
            the directory is created (including parents).
    """
    target = Path(path) if isinstance(path, str) else path
    if target.exists():
        return
    # A non-empty suffix means the path looks like a file: create its parent
    # directory; otherwise treat the whole path as a directory to create.
    directory = target.parent if target.suffix else target
    directory.mkdir(parents=True, exist_ok=True)
def has_valid_filename(path: Path) -> bool:
    """
    Check whether ``path`` ends with a valid filename.

    A valid filename has a non-empty name, stem, and suffix; the suffix is
    longer than a bare dot; and both the stem and the suffix body (after the
    dot) contain at least one letter. Digits and underscores are allowed in
    addition to letters.

    Args:
        path (Path): path to check.

    Return:
        bool: whether path has a valid filename.
    """
    stem = path.stem
    suffix = path.suffix
    # Structural requirements: name, stem, and a suffix longer than ".".
    if not path.name or not stem or not suffix or len(suffix) <= 1:
        return False
    # Both parts must contain at least one alphabetic character.
    if not any(ch.isalpha() for ch in stem):
        return False
    return any(ch.isalpha() for ch in suffix[1:])
def narrow_tensor_by_index(tensor: Any, offsets: tuple, lengths: tuple) -> Any:
    """
    Narrow the tensor by (offsets, lengths) per dimension.

    Used for resharding operations to extract a slice from a tensor.
    Compatible with both torch and mindspore, since it relies only on
    standard slice indexing.

    Args:
        tensor (Any): The tensor to narrow (tensor-like object supporting indexing).
        offsets (tuple): Tuple of offsets per dimension.
        lengths (tuple): Tuple of lengths per dimension.

    Returns:
        Any: The narrowed tensor slice (tensor-like object).
    """
    # Empty offsets or lengths means "take the whole tensor".
    if not offsets or not lengths:
        return tensor
    index = []
    for start, size in zip(offsets, lengths):
        begin = int(start)
        index.append(slice(begin, begin + int(size)))
    return tensor[tuple(index)]
def chunk_to_area(chunk: ChunkStorageMetadata) -> tuple[tuple[int, int], ...]:
    """
    Convert ChunkStorageMetadata to a (start, end) area per dimension.

    Args:
        chunk (ChunkStorageMetadata): ChunkStorageMetadata instance with
            per-dimension ``offsets`` and ``sizes`` sequences, assumed to be
            of equal length.

    Returns:
        tuple[tuple[int, int], ...]: (start, end) tuple for each dimension,
        where end = offset + size (half-open interval, matching slice semantics).
    """
    # zip pairs each offset with its size directly, replacing the
    # range(len(...)) index loop with the idiomatic form.
    return tuple(
        (offset, offset + size)
        for offset, size in zip(chunk.offsets, chunk.sizes)
    )
def create_chunk_list_for_tensor(obj: Union[Tensor, DTensor]) -> list[ChunkStorageMetadata]:
    """
    Create list of local chunks for the given object (DTensor or plain tensor).

    Used to determine what this rank needs to load (resharding).

    Args:
        obj (Union[Tensor, DTensor]): hyper DTensor or platform Tensor.

    Returns:
        list[ChunkStorageMetadata]: List of ChunkStorageMetadata representing
            local chunks needed by this rank.
    """

    def _full_chunk(shape) -> list[ChunkStorageMetadata]:
        # Single chunk covering the entire tensor, anchored at the origin.
        return [ChunkStorageMetadata(offsets=(0,) * len(shape), sizes=tuple(shape))]

    if isinstance(obj, DTensor):
        layout = obj.layout
        if layout is None:
            return _full_chunk(obj.shape if hasattr(obj, "shape") else obj.to_local().shape)
        # Layout attributes may be exposed under public or private names
        # depending on the layout implementation; try both.
        mesh_shape = getattr(layout, "mesh_shape", None) or getattr(layout, "_mesh", None)
        tensor_map = getattr(layout, "tensor_map", None) or getattr(layout, "_tensor_map", None)
        rank_list = getattr(layout, "rank_list", None) or getattr(layout, "_rank_list", None)
        if mesh_shape is None or tensor_map is None or rank_list is None:
            # Incomplete sharding metadata: fall back to the full tensor.
            return _full_chunk(obj.shape if hasattr(obj, "shape") else obj.to_local().shape)
        current_rank = platform.get_rank()
        if current_rank not in rank_list:
            # This rank holds no shard of the tensor.
            return []
        area = infer_slice_area_by_rank(
            mesh_shape=mesh_shape,
            tensor_map=tensor_map,
            rank_id=rank_list.index(current_rank),
            full_shape=obj.shape,
        )
        starts = tuple(start for start, _ in area)
        extents = tuple(end - start for start, end in area)
        return [ChunkStorageMetadata(offsets=starts, sizes=extents)]

    if isinstance(obj, Tensor):
        # platform.Tensor has exactly one chunk in metadata (full tensor)
        return _full_chunk(tuple(obj.shape))

    raise ValueError(f"Not support type {type(obj)} for creating chunk list ")
def remove_redundant_plans(
    all_plans: list[SavePlan],
    save_to_minimum_rank: bool = False,
) -> list[SavePlan]:
    """
    Remove duplicate entries across SavePlans so each item is saved exactly once.

    For every item that appears in more than one plan, a single plan keeps it:
    the plan with the smallest accumulated planned storage size, or simply the
    minimum rank when ``save_to_minimum_rank`` is True.

    Args:
        all_plans (list[SavePlan]): List of save plans to deduplicate.
        save_to_minimum_rank (bool): If True, assign duplicates to the minimum rank;
            else to plan with minimal storage. Default False.
    """
    # item index -> indices of all plans that contain it
    plans_per_item: dict[MetadataIndex, set[int]] = defaultdict(set)
    # item index -> a WriteItem carrying that index (duplicates are equivalent)
    items_by_index: dict[MetadataIndex, WriteItem] = {}
    # per-plan set of item indices that survive deduplication
    kept: list[set[MetadataIndex]] = []

    for plan_idx, plan in enumerate(all_plans):
        kept.append({item.index for item in plan.items})
        for item in plan.items:
            plans_per_item[item.index].add(plan_idx)
            items_by_index[item.index] = item

    planned_bytes = [0] * len(all_plans)

    # Split items into those owned by exactly one plan and true duplicates.
    # Unique items are accounted first so they do not skew load balancing.
    uniques: list[tuple[MetadataIndex, int]] = []
    duplicates: list[tuple[MetadataIndex, set[int]]] = []
    for index, owners in plans_per_item.items():
        if len(owners) > 1:
            duplicates.append((index, owners))
        else:
            uniques.append((index, next(iter(owners))))

    # First pass: charge each unique item to its only plan.
    for index, owner in uniques:
        planned_bytes[owner] += items_by_index[index].tensor_storage_size() or 1

    # Second pass: pick a winner for every duplicate and evict it elsewhere.
    for index, owners in duplicates:
        if save_to_minimum_rank:
            winner = min(owners)
        else:
            winner = min(owners, key=planned_bytes.__getitem__)
        planned_bytes[winner] += items_by_index[index].tensor_storage_size() or 1
        for loser in owners:
            if loser != winner:
                kept[loser].discard(index)

    if len(all_plans) != len(kept):
        raise AssertionError("len(all_plans) != len(remaining_items)")

    # Rebuild each plan with only its surviving items.
    return [
        dataclasses.replace(
            plan, items=[item for item in plan.items if item.index in item_set]
        )
        for plan, item_set in zip(all_plans, kept)
    ]