Coverage for hyper_parallel / platform / torch / fully_shard / utils.py: 93%
46 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1import torch
2from dataclasses import dataclass
3from hyper_parallel.core.device_mesh import DeviceMesh
4from typing import Optional
@dataclass
class MixedPrecisionPolicy:
    """
    Configures mixed precision training for HSDP.

    This policy controls data type casting during forward/backward computation
    and gradient reduction, enabling memory savings and potential speedups.

    Attributes:
        param_dtype: Data type for parameter computation. If None, uses original dtype.
        reduce_dtype: Data type for gradient reduction. If None, uses param_dtype.
        output_dtype: Data type for module outputs. If None, no casting applied.
        cast_forward_inputs: Whether forward inputs are cast as part of this
            policy. (Default: True) NOTE(review): the cast logic lives outside
            this module — presumably inputs are cast to ``param_dtype``; confirm
            against the consumer of this flag.
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    cast_forward_inputs: bool = True
@dataclass
class OffloadPolicy:
    """
    Base class for offload policies.

    This represents no offloading and serves as the default policy.
    Subclass this to implement custom offload strategies.
    """

    # Intentionally empty: the base policy carries no configuration.
    pass
@dataclass
class CPUOffloadPolicy(OffloadPolicy):
    """
    Offloads sharded parameters and gradients to CPU memory.

    When enabled, sharded parameters are kept on CPU and copied to device
    before all-gather. Gradients are copied back to CPU after backward.
    This reduces NPU memory usage at the cost of additional data transfers.

    Attributes:
        pin_memory: If True, pins CPU memory for faster H2D/D2H transfers
            and enables overlap with computation. Disable if CPU memory
            is constrained. (Default: True)
    """

    pin_memory: bool = True
@dataclass
class DataParallelMeshInfo:
    """Mesh bookkeeping shared by the data-parallel strategies.

    Records the device mesh together with which mesh dimension (if any) is
    used for sharding and which for replication. A configuration that sets
    neither dimension is rejected, since it would neither shard nor replicate.

    Attributes:
        mesh: The device mesh this data-parallel group operates over.
        shard_mesh_dim: Mesh dimension used for sharding, or None.
        replicate_mesh_dim: Mesh dimension used for replication, or None.
    """

    mesh: DeviceMesh
    shard_mesh_dim: Optional[int] = None
    replicate_mesh_dim: Optional[int] = None

    def __post_init__(self):
        # Equivalent to: shard is None and replicate is None (De Morgan).
        has_any_dim = (
            self.shard_mesh_dim is not None or self.replicate_mesh_dim is not None
        )
        if not has_any_dim:
            raise AssertionError(
                "At least one of shard_mesh_dim and replicate_mesh_dim must not be None"
            )
@dataclass
class FSDPMeshInfo(DataParallelMeshInfo):
    """Mesh info for fully-sharded data parallelism.

    Requires a non-None ``shard_mesh_dim`` and derives the shard group size,
    process group, and this rank's position within that group from the mesh.
    """

    def __post_init__(self):
        super().__post_init__()
        if self.shard_mesh_dim is None:
            raise AssertionError("Expects non-None shard_mesh_dim")
        # Resolve shard-dimension collective state once, up front.
        dim = self.shard_mesh_dim
        self.shard_mesh_size: int = self.mesh.mesh_shape[dim]
        self.shard_process_group = self.mesh.get_group(dim)
        self.shard_mesh_rank: int = self.shard_process_group.rank()
@dataclass
class DDPMeshInfo(DataParallelMeshInfo):
    """Mesh info for replicated (DDP-style) data parallelism.

    Requires a non-None ``replicate_mesh_dim`` and derives the replica group
    size, process group, and this rank's position within that group from the
    mesh.
    """

    def __post_init__(self):
        super().__post_init__()
        if self.replicate_mesh_dim is None:
            raise AssertionError("Expects non-None replicate_mesh_dim")
        # Resolve replicate-dimension collective state once, up front.
        dim = self.replicate_mesh_dim
        self.replicate_mesh_size: int = self.mesh.mesh_shape[dim]
        self.replicate_process_group = self.mesh.get_group(dim)
        self.replicate_mesh_rank: int = self.replicate_process_group.rank()
@dataclass
class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo):
    """Mesh info for hybrid sharded data parallelism (shard + replicate).

    Inherits from both FSDPMeshInfo and DDPMeshInfo, so it requires both
    shard_mesh_dim and replicate_mesh_dim to be non-None and exposes both
    sets of derived attributes (shard_* and replicate_*).
    """

    def __post_init__(self):
        # MRO chains the validations/derivations:
        # `FSDPMeshInfo` -> `DDPMeshInfo` -> `DataParallelMeshInfo`
        super().__post_init__()