# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Common utility functions."""
16import dataclasses
17from collections import defaultdict
18from collections.abc import Collection, Mapping
19from pathlib import Path
20from typing import Any, Union
22from hyper_parallel.core.distributed_checkpoint.metadata import ChunkStorageMetadata, MetadataIndex
23from hyper_parallel.core.distributed_checkpoint.planner import SavePlan, WriteItem
24from hyper_parallel.core.distributed_checkpoint.reshard import infer_slice_area_by_rank
25from hyper_parallel.core.dtensor.dtensor import DTensor
26from hyper_parallel.platform import get_platform
29platform = get_platform()
30Tensor = platform.Tensor
def check_path(path: Union[Path, str]) -> None:
    """
    Check whether the path exists, creating any missing directories.

    Args:
        path (Union[Path, str]): Path to check. Can be a bare file name in the current directory, a pure
            directory, or a file name with a directory. When the path contains a directory component, the
            function checks whether that directory exists and creates it if it does not.
    """
    path_obj = Path(path) if isinstance(path, str) else path

    if path_obj.exists():
        return

    if path_obj.suffix:
        path_obj.parent.mkdir(parents=True, exist_ok=True)
    else:
        path_obj.mkdir(parents=True, exist_ok=True)


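# Illustrative sketch (added for documentation, not part of the original module):
# how check_path treats the path shapes its docstring names. Uses a temporary
# directory so the example is side-effect free; the file names are hypothetical.
def _example_check_path() -> None:
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        # File name with a directory: only the parent directory is created.
        check_path(Path(tmp) / "ckpt" / "model.safetensors")
        assert (Path(tmp) / "ckpt").is_dir()
        assert not (Path(tmp) / "ckpt" / "model.safetensors").exists()
        # Pure directory (no suffix): the directory itself is created.
        check_path(Path(tmp) / "logs")
        assert (Path(tmp) / "logs").is_dir()

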
def has_valid_filename(path: Path) -> bool:
    """
    Check whether path has a valid filename. A valid filename has a non-empty stem and a suffix with at least
    one character after the dot, and both the stem and the suffix must contain at least one letter.

    Args:
        path (Path): path to check.

    Returns:
        bool: whether path has a valid filename.
    """
    conditions = (
        path.name,
        path.suffix,
        len(path.suffix) > 1,
        path.stem,
        any(c.isalpha() for c in path.stem),
        any(c.isalpha() for c in path.suffix[1:])
    )
    return all(conditions)


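# Illustrative sketch (not part of the original module): which names pass the
# check above. The specific file names are hypothetical.
def _example_has_valid_filename() -> None:
    assert has_valid_filename(Path("model_0.ckpt"))   # stem and suffix both contain letters
    assert not has_valid_filename(Path("model"))      # no suffix at all
    assert not has_valid_filename(Path("123.456"))    # no letters anywhere
    assert not has_valid_filename(Path("weights."))   # nothing after the dot

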
def narrow_tensor_by_index(tensor: Any, offsets: tuple, lengths: tuple) -> Any:
    """
    Narrow the tensor by (offsets, lengths) per dimension.

    Used for resharding operations to extract a slice from a tensor.
    Compatible with both torch and mindspore (uses slice indexing).

    Args:
        tensor (Any): The tensor to narrow (tensor-like object supporting indexing).
        offsets (tuple): Tuple of offsets per dimension.
        lengths (tuple): Tuple of lengths per dimension.

    Returns:
        Any: The narrowed tensor slice (tensor-like object).
    """
    if not offsets or not lengths:
        return tensor
    slices = tuple(
        slice(int(off), int(off) + int(ln))
        for off, ln in zip(offsets, lengths)
    )
    return tensor[slices]


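# Illustrative sketch (not part of the original module): numpy is used here as
# a stand-in for a platform tensor, since any array type that supports
# tuple-of-slice indexing works with narrow_tensor_by_index.
def _example_narrow_tensor_by_index() -> None:
    import numpy as np

    full = np.arange(16).reshape(4, 4)
    # Take rows 1..2 and columns 2..3, i.e. full[1:3, 2:4].
    part = narrow_tensor_by_index(full, offsets=(1, 2), lengths=(2, 2))
    assert part.shape == (2, 2)
    assert (part == full[1:3, 2:4]).all()

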
def chunk_to_area(chunk: ChunkStorageMetadata) -> tuple[tuple[int, int], ...]:
    """
    Convert ChunkStorageMetadata to (start, end) area per dimension.

    Args:
        chunk (ChunkStorageMetadata): ChunkStorageMetadata instance with offsets and sizes.

    Returns:
        tuple[tuple[int, int], ...]: Tuple of (start, end) tuples for each dimension.
    """
    return tuple(
        (chunk.offsets[i], chunk.offsets[i] + chunk.sizes[i])
        for i in range(len(chunk.offsets))
    )


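# Illustrative sketch (not part of the original module): a chunk with offsets
# (2, 0) and sizes (2, 4) covers rows [2, 4) and columns [0, 4).
def _example_chunk_to_area() -> None:
    chunk = ChunkStorageMetadata(offsets=(2, 0), sizes=(2, 4))
    assert chunk_to_area(chunk) == ((2, 4), (0, 4))

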
def create_chunk_list_for_tensor(obj: Union[Tensor, DTensor]) -> list[ChunkStorageMetadata]:
    """
    Create a list of local chunks for the given object (DTensor or plain tensor).

    Used to determine what this rank needs to load (resharding).

    Args:
        obj (Union[Tensor, DTensor]): hyper DTensor or platform Tensor.

    Returns:
        list[ChunkStorageMetadata]: List of ChunkStorageMetadata representing
            local chunks needed by this rank.
    """
    if isinstance(obj, DTensor):
        layout = obj.layout
        if layout is None:
            shape = obj.shape if hasattr(obj, "shape") else obj.to_local().shape
            return [ChunkStorageMetadata(offsets=(0,) * len(shape), sizes=tuple(shape))]

        mesh_shape = getattr(layout, "mesh_shape", None) or getattr(layout, "_mesh", None)
        tensor_map = getattr(layout, "tensor_map", None) or getattr(layout, "_tensor_map", None)
        rank_list = getattr(layout, "rank_list", None) or getattr(layout, "_rank_list", None)

        if mesh_shape is None or tensor_map is None or rank_list is None:
            shape = obj.shape if hasattr(obj, "shape") else obj.to_local().shape
            return [ChunkStorageMetadata(offsets=(0,) * len(shape), sizes=tuple(shape))]

        current_rank = platform.get_rank()
        if current_rank not in rank_list:
            return []

        inner_rank_id = rank_list.index(current_rank)
        full_shape = obj.shape
        slice_area = infer_slice_area_by_rank(
            mesh_shape=mesh_shape,
            tensor_map=tensor_map,
            rank_id=inner_rank_id,
            full_shape=full_shape,
        )
        offsets = tuple(s for s, _ in slice_area)
        sizes = tuple(e - s for s, e in slice_area)
        return [ChunkStorageMetadata(offsets=offsets, sizes=sizes)]

    if isinstance(obj, Tensor):
        # platform.Tensor has exactly one chunk in metadata (full tensor)
        shape = tuple(obj.shape)
        return [ChunkStorageMetadata(offsets=(0,) * len(shape), sizes=shape)]

    raise ValueError(f"Unsupported type {type(obj)} for creating chunk list")


def remove_redundant_plans(
    all_plans: list[SavePlan],
    save_to_minimum_rank: bool = False,
) -> list[SavePlan]:
    """
    Remove duplicate entries across SavePlans. For each duplicate, only one plan
    keeps the entry. The selection prefers the plan with the smallest planned
    storage size (or the minimum rank when save_to_minimum_rank is True).

    Args:
        all_plans (list[SavePlan]): List of save plans to deduplicate.
        save_to_minimum_rank (bool): If True, assign duplicates to the minimum rank; else to the plan with
            minimal storage. Default False.

    Returns:
        list[SavePlan]: Deduplicated plans, one per input plan, in the same order.
    """
    # Build mapping from item index to set of plan indices containing it
    duplicate_map: dict[MetadataIndex, set[int]] = defaultdict(set)
    # Registry to retrieve WriteItem by its index
    item_registry: dict[MetadataIndex, WriteItem] = {}
    # Track which items remain in each plan after deduplication
    remaining_items: list[set[MetadataIndex]] = [
        {entry.index for entry in plan.items} for plan in all_plans
    ]

    # Collect all items and their plan associations
    for idx, plan in enumerate(all_plans):
        for entry in plan.items:
            duplicate_map[entry.index].add(idx)
            item_registry[entry.index] = entry

    storage_sizes = [0] * len(all_plans)

    # Separate unique items (appear in only one plan) from duplicates.
    # Process unique items first to prevent them from affecting load balancing.
    single_plan_items: list[tuple[MetadataIndex, int]] = []
    multi_plan_items: list[tuple[MetadataIndex, set[int]]] = []

    for item_key, containing_plans in duplicate_map.items():
        if len(containing_plans) == 1:
            single_plan_items.append((item_key, next(iter(containing_plans))))
        else:
            multi_plan_items.append((item_key, containing_plans))

    # First pass: handle items that appear in only one plan
    for item_key, target_idx in single_plan_items:
        entry = item_registry[item_key]
        storage_sizes[target_idx] += entry.tensor_storage_size() or 1

    # Second pass: assign duplicate items to the plan with minimal storage size
    for item_key, containing_plans in multi_plan_items:
        if save_to_minimum_rank:
            target_plan = min(containing_plans)
        else:
            target_plan = min(
                containing_plans, key=lambda p_idx: storage_sizes[p_idx]
            )

        entry = item_registry[item_key]
        storage_sizes[target_plan] += entry.tensor_storage_size() or 1
        # Remove this item from all other plans
        for p_idx in containing_plans - {target_plan}:
            remaining_items[p_idx].discard(item_key)

    if len(all_plans) != len(remaining_items):
        raise AssertionError("len(all_plans) != len(remaining_items)")

    # Generate deduplicated plans with only remaining items
    return [
        dataclasses.replace(
            plan, items=[entry for entry in plan.items if entry.index in item_set]
        )
        for plan, item_set in zip(all_plans, remaining_items)
    ]


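# Illustrative sketch (not part of the original module): the same greedy
# balancing as above, reduced to plain ints so it runs without SavePlan or
# WriteItem objects. A duplicated item lands on the plan whose unique items
# weigh less.
def _example_greedy_duplicate_assignment() -> None:
    unique_sizes = [100, 5]        # per-plan storage accumulated from unique items
    duplicate_plans = {0, 1}       # plans that both contain the duplicated item
    target = min(duplicate_plans, key=lambda p: unique_sizes[p])
    assert target == 1             # the lighter plan 1 keeps the duplicate

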
def traverse_state_dict(
    state_dict: Any,
    visitor: Any,
) -> None:
    """
    Invoke ``visitor`` for each leaf value in ``state_dict``, recursively.

    Mappings are always traversed. A list or tuple is traversed element by element only if it contains
    tensors, mappings, or nested lists/tuples that themselves require traversal; otherwise the whole
    container is treated as a single leaf and passed to ``visitor``.
    """

    def _is_terminal(value: Any) -> bool:
        """Leaf-like container: no nested mappings/lists/tuples/tensors to recurse into."""
        values: Collection
        if isinstance(value, Mapping):
            return False
        if isinstance(value, (list, tuple)):
            values = value
        else:
            return True

        for entry in values:
            if isinstance(entry, (Mapping, list, tuple)) and not _is_terminal(entry):
                return False
            if isinstance(entry, Tensor):
                return False
        return True

    def _traverse_obj(path: tuple[Any, ...], value: Any) -> None:
        if isinstance(value, Mapping):
            for k, v in value.items():
                _traverse_obj(path + (str(k),), v)
        elif _is_terminal(value):
            visitor(path, value)
        elif isinstance(value, (list, tuple)):
            for i, v in enumerate(value):
                _traverse_obj(path + (i,), v)

    for key, value in state_dict.items():
        _traverse_obj((str(key),), value)


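# Illustrative sketch (not part of the original module): a list of scalars is
# one leaf, while mappings are always descended into.
def _example_traverse_state_dict() -> None:
    visited: list[tuple[Any, ...]] = []
    state = {"opt": {"lr": 0.1, "betas": [0.9, 0.999]}, "step": 3}
    traverse_state_dict(state, lambda path, value: visited.append(path))
    # [0.9, 0.999] contains no tensors or mappings, so it stays a single leaf.
    assert visited == [("opt", "lr"), ("opt", "betas"), ("step",)]

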
def flatten_state_dict(state_dict: Any) -> tuple[dict[str, Any], dict[str, tuple[Any, ...]]]:
    """Flatten a nested state dict to dotted FQN keys; returns ``(flat_dict, fqn -> path)``."""
    fqn_names: dict[str, Any] = {}
    mappings: dict[str, tuple[Any, ...]] = {}

    def flat_copy(path: tuple[Any, ...], value: Any) -> None:
        new_fqn = ".".join(map(str, path))
        if new_fqn in fqn_names:
            raise ValueError(
                f"Duplicate flattened FQN {new_fqn!r} when converting nested state_dict; "
                "two different values map to the same dotted name."
            )
        fqn_names[new_fqn] = value
        mappings[new_fqn] = path

    traverse_state_dict(state_dict, flat_copy)
    return fqn_names, mappings


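# Illustrative sketch (not part of the original module): flatten_state_dict
# and set_element (below) round-trip a nested state dict.
def _example_flatten_state_dict() -> None:
    state = {"model": {"layers": [{"w": 1}, {"w": 2}]}}
    flat, paths = flatten_state_dict(state)
    assert flat == {"model.layers.0.w": 1, "model.layers.1.w": 2}
    assert paths["model.layers.0.w"] == ("model", "layers", 0, "w")

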
def set_element(root_dict: Any, path: tuple[Any, ...], value: Any) -> None:
    """Set ``value`` in ``root_dict`` along the ``path`` object path."""
    if not path:
        raise ValueError("path must be non-empty")
    cur_container: Any = root_dict

    def extend_list(lst: list[Any], idx: int) -> None:
        while len(lst) <= idx:
            lst.append(None)

    for i in range(1, len(path)):
        prev_key = path[i - 1]
        next_key = path[i]
        def_val: Any = {} if isinstance(next_key, str) else []

        if isinstance(cur_container, Mapping):
            cur_container = cur_container.setdefault(prev_key, def_val)
        else:
            extend_list(cur_container, prev_key)
            if cur_container[prev_key] is None:
                cur_container[prev_key] = def_val
            cur_container = cur_container[prev_key]

    last_key = path[-1]
    if isinstance(last_key, int):
        extend_list(cur_container, last_key)

    cur_container[last_key] = value


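# Illustrative sketch (not part of the original module): rebuilding a nested
# state dict from paths of the form produced by flatten_state_dict. Integer
# path components create lists, string components create dicts.
def _example_set_element() -> None:
    rebuilt: dict[str, Any] = {}
    set_element(rebuilt, ("model", "layers", 0, "w"), 1)
    set_element(rebuilt, ("model", "layers", 1, "w"), 2)
    assert rebuilt == {"model": {"layers": [{"w": 1}, {"w": 2}]}}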