1"""Common policy and mesh metadata for fully_shard APIs."""
2from dataclasses import dataclass
3from typing import Optional
5from hyper_parallel.collectives.cc import get_group_local_rank
6from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
7from hyper_parallel.platform import get_platform
9platform = get_platform()


@dataclass
class MixedPrecisionPolicy:
    """
    Configures mixed precision training for HSDP.

    This policy controls data type casting during forward/backward computation
    and gradient reduction, enabling memory savings and potential speedups.

    Attributes:
        param_dtype: Data type for parameter computation. If None, uses original dtype.
        reduce_dtype: Data type for gradient reduction. If None, uses param_dtype.
        output_dtype: Data type for module outputs. If None, no casting applied.
    """
    param_dtype: Optional[platform.dtype] = None
    reduce_dtype: Optional[platform.dtype] = None
    output_dtype: Optional[platform.dtype] = None
    cast_forward_inputs: bool = True
    apply_grad_on_fp32_main_grad: bool = False
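
# A minimal usage sketch (not part of this module): it shows a caller configuring
# bf16 compute with fp32 gradient reduction. The `platform.bfloat16` /
# `platform.float32` attributes and the `fully_shard(..., mp_policy=...)` keyword
# are assumptions named here for illustration only.
#
#     mp_policy = MixedPrecisionPolicy(
#         param_dtype=platform.bfloat16,   # forward/backward compute in bf16
#         reduce_dtype=platform.float32,   # reduce gradients in fp32
#     )
#     fully_shard(model, mp_policy=mp_policy)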


@dataclass
class OffloadPolicy:
    """
    Base class for offload policies.

    This represents no offloading and serves as the default policy.
    Subclass this to implement custom offload strategies.
    """


@dataclass
class CPUOffloadPolicy(OffloadPolicy):
    """
    Offloads sharded parameters and gradients to CPU memory.

    When enabled, sharded parameters are kept on CPU and copied to device
    before all-gather. Gradients are copied back to CPU after backward.
    This reduces NPU memory usage at the cost of additional data transfers.

    Attributes:
        pin_memory: If True, pins CPU memory for faster H2D/D2H transfers
            and enables overlap with computation. Disable if CPU memory
            is constrained. (Default: True)
    """
    pin_memory: bool = True
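
# Sketch of selecting an offload policy (illustrative only; the
# `fully_shard(..., offload_policy=...)` keyword is an assumption). Disabling
# `pin_memory` trades slower, non-overlapped H2D/D2H copies for less pinned
# host memory:
#
#     offload_policy = CPUOffloadPolicy(pin_memory=False)
#     fully_shard(model, offload_policy=offload_policy)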


@dataclass
class DataParallelMeshInfo:
    mesh: DeviceMesh
    shard_mesh_dim: Optional[int] = None
    replicate_mesh_dim: Optional[int] = None

    def __post_init__(self):
        if self.shard_mesh_dim is None and self.replicate_mesh_dim is None:
            raise AssertionError(
                "At least one of shard_mesh_dim and replicate_mesh_dim must not be None"
            )


@dataclass
class FSDPMeshInfo(DataParallelMeshInfo):
    def __post_init__(self):
        super().__post_init__()
        if self.shard_mesh_dim is None:
            raise AssertionError("Expects non-None shard_mesh_dim")
        self.shard_mesh_size: int = self.mesh.mesh_shape[self.shard_mesh_dim]
        self.shard_process_group = self.mesh.get_group(self.shard_mesh_dim)
        self.shard_mesh_rank: int = get_group_local_rank(self.shard_process_group)
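
# Illustrative use of the mesh-info helpers above (how `mesh` is built is
# platform-specific; it is only assumed here to be a 1-D DeviceMesh over the
# data-parallel ranks):
#
#     fsdp_info = FSDPMeshInfo(mesh=mesh, shard_mesh_dim=0)
#     fsdp_info.shard_mesh_size   # number of ranks each parameter is sharded across
#     fsdp_info.shard_mesh_rank   # this rank's position within the shard group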


@dataclass
class DDPMeshInfo(DataParallelMeshInfo):
    def __post_init__(self):
        super().__post_init__()
        if self.replicate_mesh_dim is None:
            raise AssertionError("Expects non-None replicate_mesh_dim")
        self.replicate_mesh_size: int = self.mesh.mesh_shape[self.replicate_mesh_dim]
        self.replicate_process_group = self.mesh.get_group(self.replicate_mesh_dim)
        self.replicate_mesh_rank: int = get_group_local_rank(self.replicate_process_group)


@dataclass
class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo):
    # pylint: disable=W0246
    def __post_init__(self):
        # Calls `FSDPMeshInfo` -> `DDPMeshInfo` -> `DataParallelMeshInfo`
        super().__post_init__()
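
# Note on the MRO: `HSDPMeshInfo.__mro__` is
# (HSDPMeshInfo, FSDPMeshInfo, DDPMeshInfo, DataParallelMeshInfo, object), so the
# single cooperative `super().__post_init__()` call above runs the shard-dim check,
# then the replicate-dim check, then the base validation. A sketch, assuming a 2-D
# DeviceMesh whose dim 0 is replicated (DDP) and dim 1 is sharded (FSDP):
#
#     hsdp_info = HSDPMeshInfo(mesh=mesh_2d, shard_mesh_dim=1, replicate_mesh_dim=0)
#     hsdp_info.replicate_mesh_size  # inter-group replication factor
#     hsdp_info.shard_mesh_size      # intra-group sharding factor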