# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Offline checkpoint conversion: Hugging Face safetensors layout and HyperParallel DCP.

Covers **full checkpoint** (unsharded weights) and **shard checkpoint** (e.g. HF multi-file safetensors),
``state_dict`` validation, HF-style save/load, and DCP read/write (see ``.claude/rules/offline_transform.md``).

Public symbols are listed in ``__all__``; sharding helpers and split types are internal.
"""
from __future__ import annotations

import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Literal

from hyper_parallel.core.distributed_checkpoint.filesystem_storage import FileSystemReader, FileSystemWriter
from hyper_parallel.core.distributed_checkpoint.standard_planner import StandardSavePlanner, _DcpMergeLoadPlanner
from hyper_parallel.platform import get_platform
from hyper_parallel.platform.platform import PlatformType

logger = logging.getLogger(__name__)

__all__ = [
    "save_state_dict_as_huggingface_format",
    "parse_checkpoint_from_huggingface",
    "full_state_dict_to_dcp_format",
    "dcp_to_full_state_dict",
    "convert_full_checkpoint_to_dcp",
]


def _validate_state_dict_for_active_platform(state_dict: dict[str, Any]) -> None:
    """Ensure values are allowed for the active platform before DCP / HF export.

    Applies to a full-weights ``state_dict`` (from a full checkpoint file or merged shard checkpoint content).

    PyTorch: ``torch.Tensor`` (including ``nn.Parameter``), ``bytes`` (e.g. optimizer blobs), or ``str`` placeholders
    (e.g. bitsandbytes serialization per HF ecosystem).
    MindSpore: tensor-like values only (``Parameter`` / ``Tensor``).

    Args:
        state_dict: Flat or nested mapping; rejected here if any value is not an allowed type.

    Raises:
        TypeError: If a value has an unsupported type for the current platform.
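
    Example (illustrative; assumes the active platform is PyTorch, and the key name is hypothetical):

    ```py
    >>> import torch
    >>> _validate_state_dict_for_active_platform({"embedding.weight": torch.zeros(2, 2)})  # accepted
    >>> _validate_state_dict_for_active_platform({"embedding.weight": [1.0, 2.0]})  # raises TypeError
    ```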
    """
    if not state_dict:
        return
    platform = get_platform()
    if platform.platform_type == PlatformType.PYTORCH:
        import torch  # pylint: disable=import-outside-toplevel

        for key, value in state_dict.items():
            if isinstance(value, (torch.Tensor, bytes, str)):
                continue
            logger.warning(
                "Unsupported PyTorch offline checkpoint value type for key %r: %s "
                "(expected torch.Tensor, bytes, or str per project rules).",
                key,
                type(value).__name__,
            )
            raise TypeError(
                f"PyTorch offline checkpoint expects torch.Tensor, bytes, or str per key; "
                f"got {type(value).__name__} for key {key!r}."
            )
    elif platform.platform_type == PlatformType.MINDSPORE:
        for key, value in state_dict.items():
            if platform.is_tensor(value):
                continue
            logger.warning(
                "Unsupported MindSpore offline checkpoint value type for key %r: %s "
                "(expected tensor-like values per project rules).",
                key,
                type(value).__name__,
            )
            raise TypeError(
                f"MindSpore offline checkpoint expects tensor-like values per key; "
                f"got {type(value).__name__} for key {key!r}."
            )


_SIZE_UNITS = {
    "TB": 10**12,
    "GB": 10**9,
    "MB": 10**6,
    "KB": 10**3,
}

_SAFE_WEIGHTS_NAME = "model.safetensors"
_SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
# Hugging Face-style shard filenames: ``model.safetensors`` or ``model-00001-of-00002.safetensors``.
_HF_SAFE_WEIGHTS_FILENAME_PATTERN = _SAFE_WEIGHTS_NAME.replace(".safetensors", "{suffix}.safetensors")


@dataclass
class _StateDictSplit:
    """Result of splitting a state dict into named shard files and an optional weight map."""

    metadata: dict[str, Any]
    filename_to_tensors: dict[str, list[str]]
    tensor_to_filename: dict[str, str]
    is_sharded: bool = False

    def __post_init__(self) -> None:
        self.is_sharded = len(self.filename_to_tensors) > 1


def _parse_size_to_int(size_as_str: str) -> int:
    """
    Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).

    Supported units are "TB", "GB", "MB", "KB".

    Args:
        size_as_str (`str`): The size to convert, e.g. `"5MB"`.

    Example:

    ```py
    >>> _parse_size_to_int("5MB")
    5000000
    ```
    """
    size_as_str = size_as_str.strip()

    # Parse unit
    unit = size_as_str[-2:].upper()
    if unit not in _SIZE_UNITS:
        raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.")
    multiplier = _SIZE_UNITS[unit]

    # Parse value
    try:
        value = float(size_as_str[:-2].strip())
    except ValueError as e:
        raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e

    return int(value * multiplier)


def _build_shard_list_for_safetensors(
    state_dict: dict[str, Any],
    max_cap: int,
    get_storage_size: Any,
) -> list[dict[str, Any]]:
    """Assign tensors to shards in key order (greedy by ``max_cap`` bytes)."""
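    # Illustrative sketch (hypothetical sizes, in bytes): with max_cap=10 and tensors
    # {"a": 4, "b": 4, "c": 4, "d": 12} in that key order, the result is
    # [{"a", "b"}, {"d"}, {"c"}]: "a" and "b" fill the first shard, the oversized "d" is
    # appended immediately as its own shard, and the in-progress shard holding "c" is
    # flushed after the loop.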
    shard_list: list[dict[str, Any]] = []
    current_shard: dict[str, Any] = {}
    current_shard_size = 0

    for key, tensor in state_dict.items():
        tensor_size = get_storage_size(tensor)

        if tensor_size > max_cap:
            shard_list.append({key: tensor})
            continue

        if current_shard_size + tensor_size > max_cap:
            shard_list.append(current_shard)
            current_shard = {}
            current_shard_size = 0

        current_shard[key] = tensor
        current_shard_size += tensor_size

    if len(current_shard) > 0:
        shard_list.append(current_shard)
    return shard_list


def _total_bytes_unique_keys_in_shards(
    shard_list: list[dict[str, Any]],
    get_storage_size: Any,
) -> int:
    """Sum storage bytes once per parameter key (excludes str placeholders)."""
    total_size = 0
    seen_keys: set[str] = set()
    for shard in shard_list:
        for k, tensor in shard.items():
            if k in seen_keys:
                continue
            seen_keys.add(k)
            if isinstance(tensor, str):
                continue
            total_size += get_storage_size(tensor)
    return total_size


def _split_state_dict_into_shards(
    state_dict: dict[str, Any],
    max_shard_size: int | str,
) -> _StateDictSplit:
    """Shard ``state_dict`` for safetensors using the active platform's storage size metric.

    Shards are built by iterating ``state_dict`` in key order (no bin-packing optimization). If a single tensor exceeds
    ``max_shard_size``, it occupies its own shard (possibly larger than the cap).

    Shard filenames follow Hugging Face conventions (``model.safetensors`` or
    ``model-00001-of-00002.safetensors``).
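
    Example (illustrative; ``t_small`` and ``t_big`` are hypothetical tensors whose combined
    storage exceeds ``max_shard_size``):

    ```py
    >>> split = _split_state_dict_into_shards({"a": t_small, "b": t_big}, max_shard_size=10)
    >>> split.is_sharded
    True
    >>> sorted(split.filename_to_tensors)
    ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']
    ```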
    """
    platform = get_platform()
    get_storage_size = platform.get_tensor_storage_size

    filename_pattern = _HF_SAFE_WEIGHTS_FILENAME_PATTERN

    max_cap: int | str = max_shard_size
    if isinstance(max_cap, str):
        max_cap = _parse_size_to_int(max_cap)

    shard_list = _build_shard_list_for_safetensors(state_dict, max_cap, get_storage_size)
    nb_shards = len(shard_list)

    total_size = _total_bytes_unique_keys_in_shards(shard_list, get_storage_size)

    # If we only have one shard, we return it => no need to build the index
    if nb_shards == 1:
        filename = filename_pattern.format(suffix="")
        keys_in_shards = [k for k, t in state_dict.items() if not isinstance(t, str)]
        return _StateDictSplit(
            metadata={"total_size": total_size},
            filename_to_tensors={filename: keys_in_shards},
            tensor_to_filename={key: filename for key in keys_in_shards},
        )

    # Now that each tensor is assigned to a shard, let's assign a filename to each shard
    tensor_name_to_filename = {}
    filename_to_tensors = {}
    for idx, shard in enumerate(shard_list):
        filename = filename_pattern.format(suffix=f"-{idx + 1:05d}-of-{nb_shards:05d}")
        for key in shard:
            tensor_name_to_filename[key] = filename
        filename_to_tensors[filename] = list(shard.keys())

    # Build the index and return
    return _StateDictSplit(
        metadata={"total_size": total_size},
        filename_to_tensors=filename_to_tensors,
        tensor_to_filename=tensor_name_to_filename,
    )


def save_state_dict_as_huggingface_format(
    save_directory: str | os.PathLike[str],
    state_dict: dict[str, Any],
    max_shard_size: str | int = "5GB",
) -> None:
    """Write ``state_dict`` to ``save_directory`` as Hugging Face-style safetensors (optionally sharded).

    Args:
        save_directory: Output directory path.
        state_dict: Parameter tensors for the active platform, keyed by name.
        max_shard_size: Per-shard byte cap (int) or string such as ``"5GB"``.

    Returns:
        None
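
    Example (illustrative sketch; the path and tensor names are hypothetical, the target directory
    is assumed empty, and the active platform is assumed to be PyTorch):

    ```py
    >>> import torch
    >>> weights = {"embedding.weight": torch.zeros(4, 4), "lm_head.weight": torch.zeros(4, 4)}
    >>> save_state_dict_as_huggingface_format("/tmp/hf_export", weights, max_shard_size="5GB")
    >>> sorted(os.listdir("/tmp/hf_export"))  # small model, single shard under the 5GB cap
    ['model.safetensors']
    ```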
    """
    if state_dict is None:
        raise ValueError("state_dict is None.")

    _validate_state_dict_for_active_platform(state_dict)

    if os.path.isfile(save_directory):
        raise ValueError(f"The save_directory {save_directory} should be a directory, not a file.")

    os.makedirs(save_directory, exist_ok=True)

    platform = get_platform()
    weights_name = _SAFE_WEIGHTS_NAME
    state_dict_split = _split_state_dict_into_shards(state_dict, max_shard_size=max_shard_size)

    index = None
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }

    filename_to_tensors = state_dict_split.filename_to_tensors.items()

    for shard_file, tensors in filename_to_tensors:
        shard = {tensor: state_dict[tensor] for tensor in tensors}
        platform.save_checkpoint(shard, os.path.join(save_directory, shard_file), ckpt_format="safetensors")
    if index is None:
        path_to_weights = os.path.join(save_directory, weights_name)
        logger.info("Model weights saved in %s", path_to_weights)
    else:
        save_index_file = os.path.join(save_directory, _SAFE_WEIGHTS_INDEX_NAME)
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        logger.info(
            "Model exceeds max shard size %s; split into %s shards. Weight map index: %s",
            max_shard_size,
            len(state_dict_split.filename_to_tensors),
            save_index_file,
        )


def parse_checkpoint_from_huggingface(resume_from_checkpoint: str | os.PathLike[str]) -> dict[str, Any]:
    """Load a Hugging Face ``model.safetensors`` or sharded safetensors + index from a directory.

    Uses the active platform's :meth:`Platform.load_checkpoint` for safetensors.

    Args:
        resume_from_checkpoint: Directory containing ``model.safetensors`` or index + shard files.

    Returns:
        Flattened str-keyed state dict (framework tensors on CPU when applicable).
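
    Example (illustrative; continues the hypothetical export shown in
    :func:`save_state_dict_as_huggingface_format`):

    ```py
    >>> state_dict = parse_checkpoint_from_huggingface("/tmp/hf_export")
    >>> sorted(state_dict)
    ['embedding.weight', 'lm_head.weight']
    ```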
    """
    platform = get_platform()
    safe_weights_file = os.path.join(resume_from_checkpoint, _SAFE_WEIGHTS_NAME)

    state_dict: dict[str, Any] = {}

    if os.path.isfile(safe_weights_file):
        state_dict = platform.load_checkpoint(safe_weights_file, ckpt_format="safetensors")
        logger.info("Loaded safetensors checkpoint from %s", safe_weights_file)

    else:
        safe_index_file = os.path.join(resume_from_checkpoint, _SAFE_WEIGHTS_INDEX_NAME)
        if not os.path.isfile(safe_index_file):
            raise ValueError(f"Can't find a checkpoint index in {os.path.abspath(resume_from_checkpoint)}.")

        with open(safe_index_file, "r", encoding="utf-8") as f:
            index = json.load(f)
        shard_files = list(set(index["weight_map"].values()))

        total_size = 0
        for shard_file in shard_files:
            shard_path = os.path.join(resume_from_checkpoint, shard_file)
            shard_state_dict = platform.load_checkpoint(shard_path, ckpt_format="safetensors")
            for key, value in shard_state_dict.items():
                if key in state_dict:
                    logger.warning(
                        "Duplicate key %r when merging Hugging Face shards; keeping first occurrence.",
                        key,
                    )
                    continue
                state_dict[key] = value
                if platform.is_tensor(value):
                    size = platform.get_tensor_storage_size(value)
                    total_size += size
                    logger.debug("Loaded tensor %r, size %s bytes", key, size)
                else:
                    logger.debug("Loaded non-tensor entry %r", key)
        logger.info("Merged Hugging Face shards, total tensor bytes (sum per key): %s", total_size)
    _validate_state_dict_for_active_platform(state_dict)
    return state_dict


def _mindspore_full_checkpoint_format_for_path(path: str) -> str:
    """Infer MindSpore ``load_checkpoint`` ``format`` from a full-checkpoint file suffix."""
    lower = path.lower()
    if lower.endswith(".safetensors"):
        return "safetensors"
    return "ckpt"


def _torch_full_checkpoint_format_for_path(path: str) -> str:
    """Infer PyTorch :meth:`Platform.load_checkpoint` ``ckpt_format`` from a full-checkpoint file suffix."""
    lower = path.lower()
    if lower.endswith(".safetensors"):
        return "safetensors"
    return "pickle"


def full_state_dict_to_dcp_format(
    state_dict: dict[str, Any],
    dst_dir: str | os.PathLike[str],
) -> None:
    """Convert a full-weights ``state_dict`` into HyperParallel DCP layout under ``dst_dir``.

    Args:
        state_dict: Merged weights to store; types must match the active platform (see project conventions).
        dst_dir: Output directory for DCP files (``.metadata`` and shard files).

    Returns:
        None
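
    Example (illustrative sketch; the output directory and key name are hypothetical, and the
    active platform is assumed to be PyTorch):

    ```py
    >>> import torch
    >>> full_state_dict_to_dcp_format({"embedding.weight": torch.zeros(4, 4)}, "/tmp/dcp_out")
    >>> ".metadata" in os.listdir("/tmp/dcp_out")  # DCP metadata file plus shard files
    True
    ```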
    """
    _validate_state_dict_for_active_platform(state_dict)
    storage_writer = FileSystemWriter(dst_dir)
    planner = StandardSavePlanner()

    planner.configure_planner(state_dict=state_dict, is_coordinator=True)
    storage_writer.configure_writer(is_coordinator=True, rank=0, use_collectives=True)
    local_plan = planner.build_local_plan()
    local_data = storage_writer.optimize_local_plan(local_plan)

    all_local_plans, global_metadata = planner.build_global_plan([local_data])
    central_plan = storage_writer.optimize_global_plan(all_local_plans)[0]

    final_local_plan = planner.finalize_plan(central_plan)
    all_writes = storage_writer.execute_write(final_local_plan, planner)

    storage_writer.finalize_checkpoint(metadata=global_metadata, results=[all_writes])


def dcp_to_full_state_dict(src_dir: str | os.PathLike[str]) -> dict[str, Any]:
    """Load a **shard** DCP layout from ``src_dir`` into one merged full-weights ``state_dict``.

    Coordinator / single-process merge of the DCP shards on disk.

    Args:
        src_dir: Root directory of a saved DCP checkpoint.

    Returns:
        Merged state dictionary populated by the load planner and storage reader.
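
    Example (illustrative; continues the hypothetical DCP checkpoint written by
    :func:`full_state_dict_to_dcp_format`):

    ```py
    >>> merged = dcp_to_full_state_dict("/tmp/dcp_out")
    >>> "embedding.weight" in merged
    True
    ```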
    """
    state_dict: dict[str, Any] = {}
    planner = _DcpMergeLoadPlanner()
    storage_reader = FileSystemReader(src_dir)

    metadata = storage_reader.load_metadata()
    planner.configure_planner(state_dict, metadata, is_coordinator=True)
    storage_reader.configure_reader(metadata, is_coordinator=True, rank=0)

    local_plan = planner.build_local_plan()
    local_data = storage_reader.optimize_local_plan(local_plan)
    all_data = [local_data]

    all_local_plans = planner.build_global_plan(all_data)
    all_results = storage_reader.optimize_global_plan(all_local_plans)

    final_local_plan = planner.finalize_plan(all_results[0])
    storage_reader.execute_read(final_local_plan, planner)

    return state_dict


def convert_full_checkpoint_to_dcp(
    src_ckpt: str | os.PathLike[str],
    dst_dir: str | os.PathLike[str],
    src_platform: Literal["torch", "huggingface", "mindspore"] = "torch",
) -> None:
    """Load checkpoint weights and write HyperParallel DCP under ``dst_dir``.

    Args:
        src_ckpt: Source path: a Hugging Face safetensors directory (full or shard layout), a PyTorch **full
            checkpoint** file, or a MindSpore **full checkpoint** file (``.ckpt`` / ``.safetensors``).
        dst_dir: Output directory for DCP files.
        src_platform: ``huggingface`` (HF directory), ``torch`` or ``mindspore`` (full-checkpoint file via
            :meth:`Platform.load_checkpoint`). For ``torch`` / ``mindspore``, the active runtime must match
            (``HYPER_PARALLEL_PLATFORM``). ``huggingface`` works on either platform for safetensors.

    Returns:
        None
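
    Example (illustrative sketch; all paths are hypothetical and the active runtime is assumed
    to match ``src_platform``):

    ```py
    >>> # Hugging Face safetensors directory -> DCP layout.
    >>> convert_full_checkpoint_to_dcp("/tmp/hf_export", "/tmp/dcp_out", src_platform="huggingface")
    >>> # Single PyTorch full-checkpoint file -> DCP layout (requires the PyTorch runtime).
    >>> convert_full_checkpoint_to_dcp("/tmp/model.safetensors", "/tmp/dcp_out2", src_platform="torch")
    ```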
    """
    platform = get_platform()
    if src_platform == "huggingface":
        state_dict = parse_checkpoint_from_huggingface(src_ckpt)
    elif src_platform == "torch":
        if platform.platform_type != PlatformType.PYTORCH:
            raise ValueError(
                "src_platform='torch' requires the PyTorch platform; set HYPER_PARALLEL_PLATFORM=torch."
            )
        if not os.path.isfile(src_ckpt):
            raise ValueError(
                "src_platform='torch' requires a single PyTorch checkpoint file path; "
                f"src_ckpt must be an existing file, got {src_ckpt!r}."
            )
        fmt = _torch_full_checkpoint_format_for_path(str(src_ckpt))
        state_dict = platform.load_checkpoint(str(src_ckpt), ckpt_format=fmt)
    elif src_platform == "mindspore":
        if platform.platform_type != PlatformType.MINDSPORE:
            raise ValueError(
                "src_platform='mindspore' requires the MindSpore platform; set HYPER_PARALLEL_PLATFORM=mindspore."
            )
        if not os.path.isfile(src_ckpt):
            raise ValueError(
                "src_platform='mindspore' requires a single checkpoint file path; "
                f"src_ckpt must be an existing file, got {src_ckpt!r}."
            )
        fmt = _mindspore_full_checkpoint_format_for_path(str(src_ckpt))
        state_dict = platform.load_checkpoint(str(src_ckpt), ckpt_format=fmt)
    else:
        raise ValueError(
            f"Unsupported src_platform={src_platform!r}; expected 'huggingface', 'torch', or 'mindspore'."
        )
    full_state_dict_to_dcp_format(state_dict, dst_dir)