# Coverage report residue (coverage.py v7.13.1, created 2026-03-01 07:33 +0800):
# 93% of 46 statements covered in
# hyper_parallel/platform/torch/fully_shard/utils.py.

from dataclasses import dataclass
from typing import Optional

import torch

from hyper_parallel.core.device_mesh import DeviceMesh

5 

6 

@dataclass
class MixedPrecisionPolicy:
    """
    Mixed precision configuration for HSDP training.

    Controls the dtypes used during forward/backward computation and
    during gradient reduction, trading numeric precision for memory
    savings and potential throughput gains.

    Attributes:
        param_dtype: dtype used for parameter computation; ``None`` keeps
            the original parameter dtype.
        reduce_dtype: dtype used for gradient reduction; ``None`` falls
            back to ``param_dtype``.
        output_dtype: dtype to cast module outputs to; ``None`` disables
            output casting.
        cast_forward_inputs: if True, forward inputs are cast as well
            (presumably to ``param_dtype`` — enforced by the consumer of
            this policy, not here).
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    cast_forward_inputs: bool = True

24 

25 

@dataclass
class OffloadPolicy:
    """
    Base offload policy, meaning "no offloading".

    Acts as the default policy value; subclass it to describe a concrete
    offload strategy.
    """

35 

36 

@dataclass
class CPUOffloadPolicy(OffloadPolicy):
    """
    Policy that offloads sharded parameters and gradients to CPU memory.

    Sharded parameters live on CPU and are copied to the device just
    before all-gather; gradients travel back to CPU after backward.
    Trades lower NPU memory usage for extra host/device transfers.

    Attributes:
        pin_memory: pin the CPU memory so H2D/D2H copies are faster and
            can overlap with computation; disable when CPU memory is
            constrained. (Default: True)
    """

    pin_memory: bool = True

52 

53 

@dataclass
class DataParallelMeshInfo:
    """
    Base mesh bookkeeping shared by data-parallel strategies.

    Attributes:
        mesh: the device mesh the strategy runs on.
        shard_mesh_dim: mesh dimension used for sharding, if any.
        replicate_mesh_dim: mesh dimension used for replication, if any.

    Raises:
        AssertionError: if neither mesh dimension is provided.
    """

    mesh: DeviceMesh
    shard_mesh_dim: Optional[int] = None
    replicate_mesh_dim: Optional[int] = None

    def __post_init__(self):
        # A data-parallel mesh must shard, replicate, or do both; the
        # chained `is` test fires only when both dims are missing.
        if self.shard_mesh_dim is self.replicate_mesh_dim is None:
            raise AssertionError(
                "At least one of shard_mesh_dim and replicate_mesh_dim must not be None"
            )

65 

66 

@dataclass
class FSDPMeshInfo(DataParallelMeshInfo):
    """
    Mesh info for FSDP; requires a non-None ``shard_mesh_dim``.

    ``__post_init__`` derives and caches, from the shard dimension:
    ``shard_mesh_size`` (dimension size), ``shard_process_group``
    (the group over that dimension), and ``shard_mesh_rank`` (this
    rank's position within the group).

    Raises:
        AssertionError: if ``shard_mesh_dim`` is None.
    """

    def __post_init__(self):
        super().__post_init__()
        if self.shard_mesh_dim is None:
            raise AssertionError("Expects non-None shard_mesh_dim")
        # Hoist the dim once, then derive all shard-side attributes.
        dim = self.shard_mesh_dim
        self.shard_mesh_size: int = self.mesh.mesh_shape[dim]
        self.shard_process_group = self.mesh.get_group(dim)
        self.shard_mesh_rank: int = self.shard_process_group.rank()

76 

77 

@dataclass
class DDPMeshInfo(DataParallelMeshInfo):
    """
    Mesh info for DDP; requires a non-None ``replicate_mesh_dim``.

    ``__post_init__`` derives and caches, from the replicate dimension:
    ``replicate_mesh_size`` (dimension size), ``replicate_process_group``
    (the group over that dimension), and ``replicate_mesh_rank`` (this
    rank's position within the group).

    Raises:
        AssertionError: if ``replicate_mesh_dim`` is None.
    """

    def __post_init__(self):
        super().__post_init__()
        if self.replicate_mesh_dim is None:
            raise AssertionError("Expects non-None replicate_mesh_dim")
        # Hoist the dim once, then derive all replicate-side attributes.
        dim = self.replicate_mesh_dim
        self.replicate_mesh_size: int = self.mesh.mesh_shape[dim]
        self.replicate_process_group = self.mesh.get_group(dim)
        self.replicate_mesh_rank: int = self.replicate_process_group.rank()

87 

88 

@dataclass
class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo):
    """
    Mesh info for HSDP: shards along one mesh dimension and replicates
    along another.

    The MRO is ``FSDPMeshInfo`` -> ``DDPMeshInfo`` ->
    ``DataParallelMeshInfo``, so one cooperative super call validates
    both dims and derives both the shard- and replicate-side attributes.
    """

    def __post_init__(self):
        # Single cooperative call walks the entire MRO chain above.
        super().__post_init__()