Coverage for hyper_parallel / core / fully_shard / api.py: 58%
249 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""hybrid shard data parallel interface"""
16from typing import Any, Mapping, cast, Optional, Union
18import torch
19import torch.distributed as dist
20from torch import nn
21from torch.distributed.checkpoint.state_dict import StateDictOptions
23from hyper_parallel.platform.platform import PlatformType
24from hyper_parallel import DeviceMesh, init_device_mesh
25from hyper_parallel.platform import get_platform
26from hyper_parallel.core.dtensor import DTensor, distribute_tensor
platform = get_platform()  # active platform adapter (MindSpore or torch backend), resolved once at import

# Maps an original module class to its dynamically created HSDP-extended
# subclass, so repeated fully_shard calls on same-class modules reuse one type.
origin_class_to_extend_class = {}
33class _UnshardHandle:
34 def __init__(self, hsdp_state=None):
35 self._hsdp_state = hsdp_state
37 def wait(self):
38 if self._hsdp_state is not None:
39 self._hsdp_state.wait_for_unshard()
40 self._hsdp_state = None
class HSDPModule:
    """
    The hsdp block of neural networks with hsdp interface.

    Mixed into a module's class by ``fully_shard``; every method except
    ``hsdp_init`` assumes ``hsdp_init`` has installed ``self.hsdp_scheduler``.

    Supported Platforms:
        ``MindSpore`` ``torch``
    """

    # pylint: disable=C0415
    def hsdp_init(self, platform_type, module, mesh, reshard_after_forward,
                  shard_placement_fn, mp_policy, offload_policy, ignored_params, device):
        """Init the hsdp2 scheduler matching the active platform backend."""
        if platform_type == PlatformType.MINDSPORE:
            from hyper_parallel.platform.mindspore.hsdp.scheduler import MindSporeHSDPScheduler
            scheduler_class = MindSporeHSDPScheduler
        else:
            from hyper_parallel.platform.torch.fully_shard.scheduler import TorchHSDPSchedulerV2
            scheduler_class = TorchHSDPSchedulerV2
        self.hsdp_scheduler = scheduler_class(module,
                                              mesh,
                                              reshard_after_forward,
                                              shard_placement_fn,
                                              mp_policy,
                                              offload_policy,
                                              ignored_params,
                                              device,
                                              )

    def set_requires_gradient_sync(self, requires_grad_sync):
        r"""
        set requires grad sync flag.

        Args:
            requires_grad_sync(bool): requires_grad_sync is used to control gradient sync process.

        Raises:
            ValueError: If `requires_grad_sync` is not bool, or hsdp was not initialized.
        """
        if not isinstance(requires_grad_sync, bool):
            raise ValueError(f"requires_grad_sync must be bool but got {requires_grad_sync}.")
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")
        # Propagate to every HSDP-wrapped submodule (including self).
        for _, module in platform.get_cells_and_names(self):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_requires_grad_sync(requires_grad_sync)

    def zero_grads(self):
        """Zero accumulated grads (MindSpore path only; torch uses the optimizer)."""
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")
        # BUGFIX: compare the platform *type*; the old
        # `platform == PlatformType.PYTORCH` compared the platform object to an
        # enum member and was therefore always False, so this guard never fired.
        if platform.platform_type == PlatformType.PYTORCH:
            raise RuntimeError("zero_grads shouldn't be called in torch platform, use optimizer.zero_grad() instead.")
        for _, module in platform.get_cells_and_names(self):
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.zero_grads()

    def set_modules_to_forward_prefetch(self, modules):
        """set forward prefetch module list to prefetch all gather for unsharded parameters"""
        if not isinstance(modules, (tuple, list)):
            raise ValueError("modules must be HSDPModule list")
        for module in modules:
            if not isinstance(module, HSDPModule):
                raise ValueError(f"modules must be HSDPModule list but got {type(module)} in list.")
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")
        self.hsdp_scheduler.set_forward_prefetch_cells(modules)

    def set_modules_to_backward_prefetch(self, modules):
        """set backward prefetch module list to prefetch all gather for unsharded parameters"""
        if not isinstance(modules, (tuple, list)):
            raise ValueError("modules must be HSDPModule list")
        for module in modules:
            if not isinstance(module, HSDPModule):
                raise ValueError(f"modules must be HSDPModule list but got {type(module)} in list.")
        if not hasattr(self, "hsdp_scheduler"):
            # Message made consistent with the other interfaces
            # (was "call fully_shard interface first.").
            raise ValueError("call hsdp interface first.")
        self.hsdp_scheduler.set_backward_prefetch_cells(modules)

    def reshard(self) -> None:
        """reshard all sharded parameters"""
        if not self.hsdp_scheduler:
            raise ValueError("hsdp_scheduler is None")
        scheduler_state = self.hsdp_scheduler.scheduler_state
        if scheduler_state:
            scheduler_state.shard()

    def unshard(self, async_op: bool = False):
        """Unshard all sharded parameters.

        Args:
            async_op (bool): When True, return an ``_UnshardHandle`` the caller
                must ``wait()`` on; otherwise unshard synchronously and return None.
        """
        if not isinstance(async_op, bool):
            raise ValueError(f"async_op should be a bool, got {type(async_op)}")
        if not self.hsdp_scheduler:
            raise ValueError("hsdp_scheduler is None")
        scheduler_state = self.hsdp_scheduler.scheduler_state
        if scheduler_state:
            scheduler_state.unshard(async_op=async_op)
        if async_op:
            return _UnshardHandle(hsdp_state=scheduler_state)
        return None

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
        assign: bool = False,
    ):
        """
        Load state dict by copying directly into local shards.

        Bypasses ``super().load_state_dict()`` because the standard PyTorch
        implementation triggers ``copy_`` through the DTensor dispatcher, which
        is not registered in the hyper-parallel layout system.

        Each value in ``state_dict`` is dispatched by type:

        - hyper DTensor: extract local shard and copy directly.
        - plain Tensor whose shape == local shard shape: copy as-is.
        - plain Tensor whose shape == global shape: distribute via
          ``distribute_tensor``, then copy the local shard.

        Args:
            state_dict (Mapping[str, Any]): Fully-qualified parameter/buffer
                names mapped to tensors (DTensor or plain Tensor).
            strict (bool): If ``True`` (default), missing or unexpected keys
                raise ``RuntimeError``, matching ``nn.Module.load_state_dict``
                semantics.
            assign (bool): Reserved for API compatibility with
                ``nn.Module.load_state_dict(assign=True)``. Currently unused.

        Raises:
            RuntimeError: When ``strict`` is ``True`` and keys do not match.
            ValueError: When a plain tensor shape matches neither the local
                shard shape nor the global shape of the target DTensor.
        """
        self_module = cast(nn.Module, self)

        # Name -> live parameter/buffer object that will receive the copy.
        target_map: dict[str, torch.Tensor] = {}
        for name, p in self_module.named_parameters():
            target_map[name] = p
        for name, b in self_module.named_buffers():
            target_map[name] = b

        if strict:
            expected_keys = set(self_module.state_dict().keys())
            missing = expected_keys - set(state_dict.keys())
            unexpected = set(state_dict.keys()) - expected_keys
            error_msgs: list[str] = []
            if missing:
                error_msgs.append(
                    "Missing key(s): " + ", ".join(repr(k) for k in sorted(missing))
                )
            if unexpected:
                error_msgs.append(
                    "Unexpected key(s): " + ", ".join(repr(k) for k in sorted(unexpected))
                )
            if error_msgs:
                raise RuntimeError(
                    f"Error(s) in loading state_dict for "
                    f"{self_module.__class__.__name__}:\n\t"
                    + "\n\t".join(error_msgs)
                )

        with torch.no_grad():
            for key, val in state_dict.items():
                target = target_map.get(key)
                if target is None:
                    # Non-strict mode: silently skip keys we do not own.
                    continue

                if isinstance(target, DTensor):
                    if isinstance(val, DTensor):
                        local_val = val.to_local()
                    else:
                        local_shape = tuple(target.local_shape)
                        global_shape = tuple(target.shape)
                        val_shape = tuple(val.shape)
                        if val_shape == local_shape:
                            local_val = val
                        elif val_shape == global_shape:
                            # Full tensor supplied: shard it the same way the
                            # target is sharded, then copy our local piece.
                            wrapped = distribute_tensor(
                                val.detach(), target.device_mesh, target.placements,
                            )
                            local_val = wrapped.to_local()
                        else:
                            raise ValueError(
                                f"load '{key}': plain tensor shape {val_shape} "
                                f"matches neither local shard {local_shape} "
                                f"nor global {global_shape}."
                            )
                    if target.to_local().is_meta:
                        # Meta tensor materialisation: replace the placeholder
                        target._local_tensor = local_val  # pylint: disable=protected-access
                    else:
                        target.to_local().copy_(local_val)
                else:
                    target.copy_(val)

        # Trigger load_state_dict post-hooks so that HSDP internal
        # bookkeeping (e.g. _sharded_param_data) stays in sync.
        for _, module in self_module.named_modules():
            hooks = module._load_state_dict_post_hooks  # pylint: disable=protected-access
            for hook in hooks.values():
                hook(module, None)

    def set_is_last_backward(self, is_last_backward: bool):
        """set is_last_backward flag"""
        self.hsdp_scheduler.scheduler_ctx.is_last_backward = is_last_backward

    def set_requires_all_reduce(self, requires_all_reduce: bool, *, recurse: bool = True) -> None:
        """set requires_all_reduce flag"""
        if not isinstance(requires_all_reduce, bool):
            raise ValueError(
                f"requires_all_reduce should be a bool, got {type(requires_all_reduce)}"
            )
        if not recurse:
            # Plain string (no placeholders) and no stray continuation whitespace.
            raise NotImplementedError(
                "Currently impl is equal to recurse=True, need support module_param mapping."
            )
        self_module = cast(nn.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_requires_all_reduce(requires_all_reduce)

    def set_reshard_after_forward(self, reshard_after_forward: bool, recurse: bool = True) -> None:
        """set reshard_after_forward flag"""
        if not isinstance(reshard_after_forward, bool):
            raise ValueError(
                f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}"
            )
        if not recurse:
            raise NotImplementedError(
                "Currently impl is equal to recurse=True, need support module_param mapping."
            )
        self_module = cast(nn.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_reshard_after_forward(reshard_after_forward)

    def set_reshard_after_backward(self, reshard_after_backward: bool, recurse: bool = True) -> None:
        """set reshard_after_backward flag"""
        if not isinstance(reshard_after_backward, bool):
            raise ValueError(
                f"reshard_after_backward should be a bool, got {type(reshard_after_backward)}"
            )
        if not recurse:
            raise NotImplementedError(
                "Currently impl is equal to recurse=True, need support module_param mapping."
            )
        self_module = cast(nn.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, HSDPModule):
                module.hsdp_scheduler.set_reshard_after_backward(reshard_after_backward)

    def set_reduce_op_type(self, reduce_op_type) -> None:
        """
        set reduce_op_type for all reduce operations in HSDP
        support reduce_op_type "avg" and "sum", default is "avg"
        """
        if hsdp_state := self.hsdp_scheduler.hsdp_state:
            hsdp_state.set_reduce_op_type(reduce_op_type)
def _extend_module_with_hsdp_interface(module):
    """Swap ``module``'s class for an HSDP-extended subclass, in place.

    The subclass ``HSDP<Original>`` mixes :class:`HSDPModule` in front of the
    original class; one subclass per original class is created and cached in
    ``origin_class_to_extend_class`` for reuse.
    """
    base_class = module.__class__
    hsdp_class = origin_class_to_extend_class.get(base_class)
    if hsdp_class is None:
        hsdp_class = type(f"HSDP{base_class.__name__}", (HSDPModule, base_class), {})
        origin_class_to_extend_class[base_class] = hsdp_class
    module.__class__ = hsdp_class
313# pylint: disable=C0415
def _check_module_valid(platform_type, module):
    """Raise ValueError unless ``module`` is the platform's module base type."""
    if platform_type != PlatformType.MINDSPORE:
        from torch.nn import Module  # pylint: disable=C0415
        if not isinstance(module, Module):
            raise ValueError(f"module's type must be nn.Module but got {type(module)}.")
        return
    from mindspore.nn.cell import Cell  # pylint: disable=C0415
    if not isinstance(module, Cell):
        raise ValueError(f"module's type must be nn.cell but got {type(module)}.")
326# pylint: disable=C0415
def _check_hsdp_input_valid(platform_type, module, shard_size, threshold, optimizer_level, enable_grad_accumulation,
                            grad_scale, reduce_dtype, comm_async, comm_fusion, bucket_size):
    """Validate every user-facing hsdp argument, raising ValueError on the first bad one.

    Args:
        platform_type: Active backend (PlatformType.MINDSPORE or torch).
        module: Network to shard; type-checked per platform.
        shard_size (int): Positive shard group size, or -1 for "use world size".
        threshold (int): Non-negative parameter-size threshold.
        optimizer_level (str): One of "level1"/"level2"/"level3".
        enable_grad_accumulation (bool): Gradient accumulation switch.
        grad_scale (float): Gradient scaling factor.
        reduce_dtype: Optional platform dtype for reduction comms.
        comm_async (bool): Asynchronous communication switch.
        comm_fusion (bool): Communication fusion switch.
        bucket_size (int): Non-negative bucket size, or -1 for unbounded.
    """
    _check_module_valid(platform_type, module)
    # -1 is an accepted sentinel for both shard_size and bucket_size; the error
    # messages now say so (previously they claimed only positive/zero values).
    if not isinstance(shard_size, int) or (shard_size <= 0 and shard_size != -1):
        raise ValueError(f"shard_size must be a positive integer or -1, but got {shard_size}.")
    if not isinstance(threshold, int) or threshold < 0:
        raise ValueError(f"threshold must be a positive integer or 0, but got {threshold}.")
    if optimizer_level not in ["level1", "level2", "level3"]:
        raise ValueError(f"Optimizer level should in ['level1', 'level2', 'level3'], but got {optimizer_level}.")
    if not isinstance(enable_grad_accumulation, bool):
        raise ValueError(f"enable_grad_accumulation must be bool but got {enable_grad_accumulation}.")
    if not isinstance(grad_scale, float):
        raise ValueError(f"grad_scale must be float but got {grad_scale}.")
    if platform_type == PlatformType.MINDSPORE:
        from mindspore._c_expression.typing import Type  # pylint: disable=C0415
        if reduce_dtype is not None and not isinstance(reduce_dtype, Type):
            raise ValueError(f"reduce_dtype must be mindspore.dtype but got {reduce_dtype}.")
    else:
        if reduce_dtype is not None and not isinstance(reduce_dtype, torch.dtype):
            raise ValueError(f"reduce_dtype must be torch.dtype but got {reduce_dtype}.")
    if not isinstance(comm_async, bool):
        raise ValueError(f"comm_async must be bool but got {comm_async}.")
    if not isinstance(comm_fusion, bool):
        raise ValueError(f"comm_fusion must be bool but got {comm_fusion}.")
    if not isinstance(bucket_size, int) or (bucket_size < 0 and bucket_size != -1):
        raise ValueError(f"bucket_size must be a positive integer, 0 or -1, but got {bucket_size}.")
def fully_shard(
    module: nn.Module,
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: Optional[Union[bool, int]] = None,
    shard_placement_fn: None = None,
    mp_policy: None = None,
    offload_policy: None = None,
    ignored_params: Optional[set[nn.Parameter]] = None,
    device = None,
):
    """Shard ``module`` in place with the HSDP runtime and return it.

    The module's class is swapped for a dynamically created subclass mixing
    in :class:`HSDPModule`, a default device and 1-D device mesh are derived
    when not supplied, and the platform scheduler is installed via
    ``hsdp_init``.
    """
    backend = platform.platform_type
    _extend_module_with_hsdp_interface(module)
    # TODO: mindspore does not support get_device_handle
    if device is None:
        handle = platform.get_device_handle()  # e.g. torch.npu or torch.cuda
        if handle.is_available():
            device = torch.device(handle.current_device())
        else:
            device = torch.device("cpu")

    mesh = mesh or init_device_mesh(device_type=device, mesh_shape=(platform.get_world_size(),))

    module.hsdp_init(
        backend,
        module,
        mesh,
        reshard_after_forward,
        shard_placement_fn,
        mp_policy,
        offload_policy,
        ignored_params,
        device,
    )
    return module
def _gather_full_state_dict(
    state_dict: dict[str, Any], cpu_offload: bool
) -> dict[str, Any]:
    """Materialise every DTensor entry as a full (unsharded) tensor.

    Args:
        state_dict: Model state dict whose values may be DTensors or plain tensors.
        cpu_offload: When True only rank 0 keeps the gathered result (moved to
            CPU); every other rank returns an empty dict to save memory.
    """
    on_rank0 = (not dist.is_initialized()) or (dist.get_rank() == 0)

    full: dict[str, Any] = {}
    for name, value in state_dict.items():
        if isinstance(value, DTensor):
            value = value.full_tensor()
        if cpu_offload:
            if not on_rank0:
                # Drop the gathered copy immediately to cap peak memory.
                del value
                continue
            if isinstance(value, torch.Tensor):
                value = value.cpu()
        full[name] = value

    if cpu_offload and not on_rank0:
        return {}
    return full
def _offload_sharded_state_dict(
    state_dict: dict[str, Any],
) -> dict[str, Any]:
    """Copy every entry to CPU while keeping shards sharded (no all-gather).

    DTensor values are rebuilt from their CPU-resident local shard with the
    original mesh and placements; plain tensors are simply moved to CPU.
    """
    result: dict[str, Any] = {}
    for name, value in state_dict.items():
        if isinstance(value, DTensor):
            cpu_shard = value.to_local().cpu()
            value = DTensor.from_local(cpu_shard, value.device_mesh, value.placements)
        elif isinstance(value, torch.Tensor):
            value = value.cpu()
        result[name] = value
    return result
442def get_model_state_dict(
443 model: nn.Module,
444 *,
445 options: StateDictOptions | None = None,
446) -> dict[str, Any]:
447 """Return the model state dict with configurable gathering and offloading.
449 Behaviour matrix:
451 +-----------------+-------------+--------------------------------------+
452 | full_state_dict | cpu_offload | result |
453 +=================+=============+======================================+
454 | False | False | DTensor (sharded, as-is) |
455 +-----------------+-------------+--------------------------------------+
456 | False | True | DTensor local shard offloaded to CPU |
457 +-----------------+-------------+--------------------------------------+
458 | True | False | full Tensor on **every** rank |
459 +-----------------+-------------+--------------------------------------+
460 | True | True | full Tensor on CPU, **rank 0 only** |
461 +-----------------+-------------+--------------------------------------+
463 Args:
464 model: The model whose state dict to retrieve.
465 options: Controls full_state_dict, cpu_offload,
466 ignore_frozen_params, and broadcast_from_rank0 flags.
467 """
468 options = options or StateDictOptions()
470 if options.broadcast_from_rank0 and not options.full_state_dict:
471 raise ValueError(
472 "full_state_dict must be True when broadcast_from_rank0 is True."
473 )
475 state_dict: dict[str, Any] = model.state_dict()
477 if options.ignore_frozen_params:
478 frozen_keys = {
479 name for name, p in model.named_parameters()
480 if not p.requires_grad
481 }
482 for key in frozen_keys:
483 state_dict.pop(key, None)
485 if options.full_state_dict:
486 return _gather_full_state_dict(state_dict, options.cpu_offload)
488 if options.cpu_offload:
489 return _offload_sharded_state_dict(state_dict)
491 return state_dict
def hsdp_sync_stream():
    """Block until all outstanding hsdp gradient communication handles finish."""
    if platform is not None:
        platform.wait_grad_handle()