Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / distributed_checkpoint / offline_transform.py: 31%

217 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Offline checkpoint conversion: Hugging Face safetensors layout and HyperParallel DCP. 

16 

17Covers **full checkpoint** (unsharded weights) and **shard checkpoint** (e.g. HF multi-file safetensors), 

18``state_dict`` validation, HF-style save/load, and DCP read/write (see ``.claude/rules/offline_transform.md``). 

19 

20Public symbols are listed in ``__all__``; sharding helpers and split types are internal. 

21""" 

22from __future__ import annotations 

23 

24import json 

25import logging 

26import os 

27from dataclasses import dataclass 

28from typing import Any, Literal 

29 

30from hyper_parallel.core.distributed_checkpoint.filesystem_storage import FileSystemReader, FileSystemWriter 

31from hyper_parallel.core.distributed_checkpoint.standard_planner import StandardSavePlanner, _DcpMergeLoadPlanner 

32from hyper_parallel.platform import get_platform 

33from hyper_parallel.platform.platform import PlatformType 

34 

35logger = logging.getLogger(__name__) 

36 

# Public API of this module; sharding helpers and split types stay internal.
__all__ = [
    "save_state_dict_as_huggingface_format",
    "parse_checkpoint_from_huggingface",
    "full_state_dict_to_dcp_format",
    "dcp_to_full_state_dict",
    "convert_full_checkpoint_to_dcp",
]

44 

45 

def _validate_state_dict_for_active_platform(state_dict: dict[str, Any]) -> None:
    """Ensure values are allowed for the active platform before DCP / HF export.

    Applies to a full-weights ``state_dict`` (from a full checkpoint file or
    merged shard checkpoint content).

    PyTorch accepts ``torch.Tensor`` (including ``nn.Parameter``), ``bytes``
    (e.g. optimizer blobs), and ``str`` placeholders (e.g. bitsandbytes
    serialization per the HF ecosystem). MindSpore accepts tensor-like values
    only (``Parameter`` / ``Tensor``).

    Args:
        state_dict: Mapping whose values are checked against the active platform.

    Raises:
        TypeError: If a value has an unsupported type for the current platform.
    """
    if not state_dict:
        return

    platform = get_platform()
    platform_type = platform.platform_type

    if platform_type == PlatformType.PYTORCH:
        import torch  # pylint: disable=import-outside-toplevel

        allowed_types = (torch.Tensor, bytes, str)
        for key, value in state_dict.items():
            if isinstance(value, allowed_types):
                continue
            value_type = type(value).__name__
            # Log before raising so the offending key survives in the logs even
            # if a caller swallows the exception.
            logger.warning(
                "Unsupported PyTorch offline checkpoint value type for key %r: %s "
                "(expected torch.Tensor, bytes, or str per project rules).",
                key,
                value_type,
            )
            raise TypeError(
                f"PyTorch offline checkpoint expects torch.Tensor, bytes, or str per key; "
                f"got {value_type} for key {key!r}."
            )

    if platform_type == PlatformType.MINDSPORE:
        for key, value in state_dict.items():
            if platform.is_tensor(value):
                continue
            value_type = type(value).__name__
            logger.warning(
                "Unsupported MindSpore offline checkpoint value type for key %r: %s "
                "(expected tensor-like values per project rules).",
                key,
                value_type,
            )
            raise TypeError(
                f"MindSpore offline checkpoint expects tensor-like values per key; "
                f"got {value_type} for key {key!r}."
            )

94 

95 

# Decimal (SI) byte multipliers used by _parse_size_to_int (so "1KB" == 1000 bytes).
_SIZE_UNITS = {
    "TB": 10**12,
    "GB": 10**9,
    "MB": 10**6,
    "KB": 10**3,
}

# Canonical Hugging Face safetensors weight filename and its shard-index filename.
_SAFE_WEIGHTS_NAME = "model.safetensors"
_SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
# Hugging Face-style shard filenames: ``model.safetensors`` or ``model-00001-of-00002.safetensors``.
_HF_SAFE_WEIGHTS_FILENAME_PATTERN = _SAFE_WEIGHTS_NAME.replace(".safetensors", "{suffix}.safetensors")

107 

108 

@dataclass
class _StateDictSplit:
    """Result of splitting a state dict into named shard files and an optional weight map."""

    # Index metadata; currently {"total_size": <bytes, counted once per key>}.
    metadata: dict[str, Any]
    # Shard filename -> names of the tensors stored in that file.
    filename_to_tensors: dict[str, list[str]]
    # Tensor name -> shard filename (becomes the HF index "weight_map").
    tensor_to_filename: dict[str, str]
    # Derived in __post_init__: True when more than one shard file is produced.
    is_sharded: bool = False

    def __post_init__(self) -> None:
        # More than one shard file means an index (weight map) must be written.
        self.is_sharded = len(self.filename_to_tensors) > 1

120 

121 

def _parse_size_to_int(size_as_str: str | int) -> int:
    """
    Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).

    Supported units are "TB", "GB", "MB", "KB" (decimal multipliers, so "1KB" == 1000).

    Args:
        size_as_str (`str` or `int`): The size to convert. Will be directly returned if an `int`.

    Returns:
        The size in bytes.

    Raises:
        ValueError: If the unit is unknown or the numeric part cannot be parsed.

    Example:

    ```py
    >>> _parse_size_to_int("5MB")
    5000000
    ```
    """
    # The docstring always promised int passthrough (mirroring huggingface_hub's
    # parse_size); previously an int crashed on .strip(). Honor the contract.
    if isinstance(size_as_str, int):
        return size_as_str

    size_as_str = size_as_str.strip()

    # Parse unit (last two characters, case-insensitive).
    unit = size_as_str[-2:].upper()
    if unit not in _SIZE_UNITS:
        raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.")
    multiplier = _SIZE_UNITS[unit]

    # Parse value (may be fractional, e.g. "0.5GB").
    try:
        value = float(size_as_str[:-2].strip())
    except ValueError as e:
        raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e

    return int(value * multiplier)

153 

154 

155def _build_shard_list_for_safetensors( 

156 state_dict: dict[str, Any], 

157 max_cap: int, 

158 get_storage_size: Any, 

159) -> list[dict[str, Any]]: 

160 """Assign tensors to shards in key order (greedy by ``max_cap`` bytes).""" 

161 shard_list: list[dict[str, Any]] = [] 

162 current_shard: dict[str, Any] = {} 

163 current_shard_size = 0 

164 

165 for key, tensor in state_dict.items(): 

166 tensor_size = get_storage_size(tensor) 

167 

168 if tensor_size > max_cap: 

169 shard_list.append({key: tensor}) 

170 continue 

171 

172 if current_shard_size + tensor_size > max_cap: 

173 shard_list.append(current_shard) 

174 current_shard = {} 

175 current_shard_size = 0 

176 

177 current_shard[key] = tensor 

178 current_shard_size += tensor_size 

179 

180 if len(current_shard) > 0: 

181 shard_list.append(current_shard) 

182 return shard_list 

183 

184 

185def _total_bytes_unique_keys_in_shards( 

186 shard_list: list[dict[str, Any]], 

187 get_storage_size: Any, 

188) -> int: 

189 """Sum storage bytes once per parameter key (excludes str placeholders).""" 

190 total_size = 0 

191 seen_keys: set[str] = set() 

192 for shard in shard_list: 

193 for k, tensor in shard.items(): 

194 if k in seen_keys: 

195 continue 

196 seen_keys.add(k) 

197 if isinstance(tensor, str): 

198 continue 

199 total_size += get_storage_size(tensor) 

200 return total_size 

201 

202 

def _split_state_dict_into_shards(
    state_dict: dict[str, Any],
    max_shard_size: int | str,
) -> _StateDictSplit:
    """Shard ``state_dict`` for safetensors using the active platform's storage size metric.

    Shards are built by iterating ``state_dict`` in key order (no bin-packing optimization). If a single tensor exceeds
    ``max_shard_size``, it occupies its own shard (possibly larger than the cap).

    Shard filenames follow Hugging Face conventions (``model.safetensors`` or
    ``model-00001-of-00002.safetensors``).
    """
    platform = get_platform()
    size_of = platform.get_tensor_storage_size
    pattern = _HF_SAFE_WEIGHTS_FILENAME_PATTERN

    # Normalize a human-readable cap ("5GB") into bytes.
    cap = max_shard_size
    if isinstance(cap, str):
        cap = _parse_size_to_int(cap)

    shards = _build_shard_list_for_safetensors(state_dict, cap, size_of)
    num_shards = len(shards)
    total_size = _total_bytes_unique_keys_in_shards(shards, size_of)

    # Single-file layout: no index is needed; drop str placeholders from the
    # file listing since they are not serialized tensors.
    if num_shards == 1:
        filename = pattern.format(suffix="")
        tensor_keys = [key for key, value in state_dict.items() if not isinstance(value, str)]
        return _StateDictSplit(
            metadata={"total_size": total_size},
            filename_to_tensors={filename: tensor_keys},
            tensor_to_filename=dict.fromkeys(tensor_keys, filename),
        )

    # Multi-file layout: give each shard an HF-style numbered filename and
    # record the key <-> file mapping for the index.
    filename_to_tensors: dict[str, list[str]] = {}
    tensor_to_filename: dict[str, str] = {}
    for position, shard in enumerate(shards, start=1):
        filename = pattern.format(suffix=f"-{position:05d}-of-{num_shards:05d}")
        filename_to_tensors[filename] = list(shard)
        for key in shard:
            tensor_to_filename[key] = filename

    return _StateDictSplit(
        metadata={"total_size": total_size},
        filename_to_tensors=filename_to_tensors,
        tensor_to_filename=tensor_to_filename,
    )

254 

255 

def save_state_dict_as_huggingface_format(
    save_directory: str | os.PathLike[str],
    state_dict: dict[str, Any],
    max_shard_size: str | int = "5GB",
) -> None:
    """Write ``state_dict`` to ``save_directory`` as Hugging Face-style safetensors (optionally sharded).

    Args:
        save_directory: Output directory path (created if missing).
        state_dict: Parameter tensors for the active platform, keyed by name.
        max_shard_size: Per-shard byte cap (int) or string such as ``\"5GB\"``.

    Raises:
        ValueError: If ``state_dict`` is None or ``save_directory`` is an existing file.
        TypeError: If ``state_dict`` holds values unsupported by the active platform.

    Returns:
        None
    """
    if state_dict is None:
        raise ValueError("state_dict is None.")

    _validate_state_dict_for_active_platform(state_dict)

    if os.path.isfile(save_directory):
        # Fixed message grammar (was: "should be a directory, but a file.").
        raise ValueError(f"The save_directory {save_directory} should be a directory, not a file.")

    os.makedirs(save_directory, exist_ok=True)

    platform = get_platform()
    state_dict_split = _split_state_dict_into_shards(state_dict, max_shard_size=max_shard_size)

    # Sharded layouts need a JSON index mapping each tensor name to its shard file.
    index = None
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }

    for shard_file, tensors in state_dict_split.filename_to_tensors.items():
        shard = {tensor: state_dict[tensor] for tensor in tensors}
        platform.save_checkpoint(shard, os.path.join(save_directory, shard_file), ckpt_format="safetensors")

    if index is None:
        path_to_weights = os.path.join(save_directory, _SAFE_WEIGHTS_NAME)
        logger.info("Model weights saved in %s", path_to_weights)
    else:
        save_index_file = os.path.join(save_directory, _SAFE_WEIGHTS_INDEX_NAME)
        with open(save_index_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")
        logger.info(
            "Model exceeds max shard size %s; split into %s shards. Weight map index: %s",
            max_shard_size,
            len(state_dict_split.filename_to_tensors),
            save_index_file,
        )

311 

312 

def parse_checkpoint_from_huggingface(resume_from_checkpoint: str | os.PathLike[str]) -> dict[str, Any]:
    """Load a Hugging Face ``model.safetensors`` or sharded safetensors + index from a directory.

    Uses the active platform's :meth:`Platform.load_checkpoint` for safetensors.

    Args:
        resume_from_checkpoint: Directory containing ``model.safetensors`` or index + shard files.

    Returns:
        Flattened str-keyed state dict (framework tensors on CPU when applicable).
    """
    platform = get_platform()
    state_dict: dict[str, Any] = {}

    single_file = os.path.join(resume_from_checkpoint, _SAFE_WEIGHTS_NAME)
    if os.path.isfile(single_file):
        # Unsharded layout: one safetensors file holds everything.
        state_dict = platform.load_checkpoint(single_file, ckpt_format="safetensors")
        logger.info("Loaded safetensors checkpoint from %s", single_file)
    else:
        # Sharded layout: the index's weight_map names every shard file.
        index_file = os.path.join(resume_from_checkpoint, _SAFE_WEIGHTS_INDEX_NAME)
        if not os.path.isfile(index_file):
            raise ValueError(f"Can't find a checkpoint index in {os.path.abspath(resume_from_checkpoint)}.")

        with open(index_file, "r", encoding="utf-8") as f:
            weight_map = json.load(f)["weight_map"]

        total_size = 0
        for shard_file in set(weight_map.values()):
            shard_path = os.path.join(resume_from_checkpoint, shard_file)
            shard_state = platform.load_checkpoint(shard_path, ckpt_format="safetensors")
            for key, value in shard_state.items():
                # First occurrence wins; later duplicates are dropped with a warning.
                if key in state_dict:
                    logger.warning(
                        "Duplicate key %r when merging Hugging Face shards; keeping first occurrence.",
                        key,
                    )
                    continue
                state_dict[key] = value
                if platform.is_tensor(value):
                    nbytes = platform.get_tensor_storage_size(value)
                    total_size += nbytes
                    logger.debug("Loaded tensor %r, size %s bytes", key, nbytes)
                else:
                    logger.debug("Loaded non-tensor entry %r", key)
        logger.info("Merged Hugging Face shards, total tensor bytes (sum per key): %s", total_size)
    _validate_state_dict_for_active_platform(state_dict)
    return state_dict

363 

364 

365def _mindspore_full_checkpoint_format_for_path(path: str) -> str: 

366 """Infer MindSpore ``load_checkpoint`` ``format`` from a full-checkpoint file suffix.""" 

367 lower = path.lower() 

368 if lower.endswith(".safetensors"): 

369 return "safetensors" 

370 return "ckpt" 

371 

372 

373def _torch_full_checkpoint_format_for_path(path: str) -> str: 

374 """Infer PyTorch :meth:`Platform.load_checkpoint` ``ckpt_format`` from a full-checkpoint file suffix.""" 

375 lower = path.lower() 

376 if lower.endswith(".safetensors"): 

377 return "safetensors" 

378 return "pickle" 

379 

380 

def full_state_dict_to_dcp_format(
    state_dict: dict[str, Any],
    dst_dir: str | os.PathLike[str],
) -> None:
    """Convert a full-weights ``state_dict`` into HyperParallel DCP layout under ``dst_dir``.

    Runs the complete single-process DCP save pipeline in order: local plan,
    global plan, write, then metadata finalization. The step order mirrors the
    planner/writer protocol and must not be rearranged.

    Args:
        state_dict: Merged weights to store; types must match the active platform (see project conventions).
        dst_dir: Output directory for DCP files (``.metadata`` and shard files).

    Returns:
        None
    """
    _validate_state_dict_for_active_platform(state_dict)
    storage_writer = FileSystemWriter(dst_dir)
    planner = StandardSavePlanner()

    # Single-process conversion: this process acts as coordinator rank 0.
    planner.configure_planner(state_dict=state_dict, is_coordinator=True)
    storage_writer.configure_writer(is_coordinator=True, rank=0, use_collectives=True)
    local_plan = planner.build_local_plan()
    local_data = storage_writer.optimize_local_plan(local_plan)

    # With exactly one rank, the global plan is derived from the single local plan.
    all_local_plans, global_metadata = planner.build_global_plan([local_data])
    central_plan = storage_writer.optimize_global_plan(all_local_plans)[0]

    final_local_plan = planner.finalize_plan(central_plan)
    all_writes = storage_writer.execute_write(final_local_plan, planner)

    # Persist the global metadata (.metadata) alongside the written shard files.
    storage_writer.finalize_checkpoint(metadata=global_metadata, results=[all_writes])

410 

411 

def dcp_to_full_state_dict(src_dir: str | os.PathLike[str]) -> dict[str, Any]:
    """Load a **shard** DCP layout from ``src_dir`` into one merged full-weights ``state_dict``.

    Coordinator / single-process merge of the DCP shards on disk. The step order
    mirrors the planner/reader protocol and must not be rearranged.

    Args:
        src_dir: Root directory of a saved DCP checkpoint.

    Returns:
        Merged state dictionary populated by the load planner and storage reader.
    """
    # Start empty; the merge planner fills this dict in place during execute_read.
    state_dict: dict[str, Any] = {}
    planner = _DcpMergeLoadPlanner()
    storage_reader = FileSystemReader(src_dir)

    # Single-process merge: this process acts as coordinator rank 0.
    metadata = storage_reader.load_metadata()
    planner.configure_planner(state_dict, metadata, is_coordinator=True)
    storage_reader.configure_reader(metadata, is_coordinator=True, rank=0)

    local_plan = planner.build_local_plan()
    local_data = storage_reader.optimize_local_plan(local_plan)
    all_data = [local_data]

    # With exactly one rank, the global plan is built from the single local plan.
    all_local_plans = planner.build_global_plan(all_data)
    all_results = storage_reader.optimize_global_plan(all_local_plans)

    final_local_plan = planner.finalize_plan(all_results[0])
    storage_reader.execute_read(final_local_plan, planner)

    return state_dict

442 

443 

444def convert_full_checkpoint_to_dcp( 

445 src_ckpt: str | os.PathLike[str], 

446 dst_dir: str | os.PathLike[str], 

447 src_platform: Literal["torch", "huggingface", "mindspore"] = "torch", 

448) -> None: 

449 """Load checkpoint weights and write HyperParallel DCP under ``dst_dir``. 

450 

451 Args: 

452 src_ckpt: Source path: Hugging Face safetensors directory (full or shard layout), a PyTorch **full checkpoint** 

453 file, or a MindSpore **full checkpoint** file (``.ckpt`` / ``.safetensors``). 

454 dst_dir: Output directory for DCP files. 

455 src_platform: ``huggingface`` (HF directory), ``torch`` or ``mindspore`` (full-checkpoint file via 

456 :meth:`Platform.load_checkpoint`). For ``torch`` / ``mindspore``, the active runtime must match 

457 (``HYPER_PARALLEL_PLATFORM``). ``huggingface`` uses either platform for safetensors. 

458 

459 Returns: 

460 None 

461 """ 

462 platform = get_platform() 

463 if src_platform == "huggingface": 

464 state_dict = parse_checkpoint_from_huggingface(src_ckpt) 

465 elif src_platform == "torch": 

466 if platform.platform_type != PlatformType.PYTORCH: 

467 raise ValueError( 

468 "src_platform='torch' requires the PyTorch platform; set HYPER_PARALLEL_PLATFORM=torch." 

469 ) 

470 if not os.path.isfile(src_ckpt): 

471 raise ValueError( 

472 "src_platform='torch' requires a single PyTorch checkpoint file path; " 

473 f"src_ckpt must be an existing file, got {src_ckpt!r}." 

474 ) 

475 fmt = _torch_full_checkpoint_format_for_path(str(src_ckpt)) 

476 state_dict = platform.load_checkpoint(str(src_ckpt), ckpt_format=fmt) 

477 elif src_platform == "mindspore": 

478 if platform.platform_type != PlatformType.MINDSPORE: 

479 raise ValueError( 

480 "src_platform='mindspore' requires the MindSpore platform; set HYPER_PARALLEL_PLATFORM=mindspore." 

481 ) 

482 if not os.path.isfile(src_ckpt): 

483 raise ValueError( 

484 "src_platform='mindspore' requires a single checkpoint file path; " 

485 f"src_ckpt must be an existing file, got {src_ckpt!r}." 

486 ) 

487 fmt = _mindspore_full_checkpoint_format_for_path(str(src_ckpt)) 

488 state_dict = platform.load_checkpoint(str(src_ckpt), ckpt_format=fmt) 

489 else: 

490 raise ValueError( 

491 f"Unsupported src_platform={src_platform!r}; expected 'huggingface', 'torch', or 'mindspore'." 

492 ) 

493 full_state_dict_to_dcp_format(state_dict, dst_dir)