Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / fully_shard / utils.py: 94%

48 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1"""Common policy and mesh metadata for fully_shard APIs.""" 

2from dataclasses import dataclass 

3from typing import Optional 

4 

5from hyper_parallel.collectives.cc import get_group_local_rank 

6from hyper_parallel.core.dtensor.device_mesh import DeviceMesh 

7from hyper_parallel.platform import get_platform 

8 

# Platform abstraction resolved once at import time; supplies the `dtype`
# type referenced by the MixedPrecisionPolicy annotations in this module.
# NOTE(review): this name shadows the stdlib `platform` module within this
# file — consider renaming if the stdlib module is ever needed here.
platform = get_platform()

10 

11 

@dataclass
class MixedPrecisionPolicy:
    """
    Configures mixed precision training for HSDP.

    This policy controls data type casting during forward/backward computation
    and gradient reduction, enabling memory savings and potential speedups.

    Attributes:
        param_dtype: Data type for parameter computation. If None, uses original dtype.
        reduce_dtype: Data type for gradient reduction. If None, uses param_dtype.
        output_dtype: Data type for module outputs. If None, no casting applied.
        cast_forward_inputs: If True, forward inputs are cast to param_dtype —
            presumably by the fully_shard runtime that consumes this policy;
            confirm at the call sites. (Default: True)
        apply_grad_on_fp32_main_grad: If True, gradient application is expected
            to operate on an fp32 main-grad buffer; exact semantics are defined
            by the consumers of this policy. (Default: False)
    """
    param_dtype: Optional[platform.dtype] = None
    reduce_dtype: Optional[platform.dtype] = None
    output_dtype: Optional[platform.dtype] = None
    cast_forward_inputs: bool = True
    apply_grad_on_fp32_main_grad: bool = False

30 

31 

@dataclass
class OffloadPolicy:
    """
    Default (no-op) offload policy.

    An instance of this base class means nothing is offloaded — parameters
    stay on the device. Derive from it to implement a custom offload strategy.
    """

40 

41 

@dataclass
class CPUOffloadPolicy(OffloadPolicy):
    """
    Offload policy that keeps sharded parameters and gradients in CPU memory.

    Sharded parameters live on the host and are copied to the device right
    before all-gather; after backward, gradients are copied back to the host.
    The trade-off is lower NPU memory pressure for extra H2D/D2H traffic.

    Attributes:
        pin_memory: Pin the CPU allocations so host<->device copies are faster
            and can overlap with computation; set to False when host memory is
            constrained. (Default: True)
    """
    pin_memory: bool = True

57 

58 

@dataclass
class DataParallelMeshInfo:
    """
    Ties a DeviceMesh to the mesh dimensions used for data parallelism.

    At least one of shard_mesh_dim / replicate_mesh_dim must be provided;
    `__post_init__` enforces this invariant.
    """
    mesh: DeviceMesh
    shard_mesh_dim: Optional[int] = None
    replicate_mesh_dim: Optional[int] = None

    def __post_init__(self):
        # A mesh info with neither sharding nor replication is meaningless.
        no_dim_given = self.shard_mesh_dim is None and self.replicate_mesh_dim is None
        if no_dim_given:
            raise AssertionError(
                "At least one of shard_mesh_dim and replicate_mesh_dim must not be None"
            )

70 

71 

@dataclass
class FSDPMeshInfo(DataParallelMeshInfo):
    """
    Mesh metadata for the sharded data-parallel (FSDP) dimension.

    `__post_init__` requires shard_mesh_dim and caches the derived values
    shard_mesh_size, shard_process_group and shard_mesh_rank.
    """

    def __post_init__(self):
        super().__post_init__()
        shard_dim = self.shard_mesh_dim
        if shard_dim is None:
            raise AssertionError("Expects non-None shard_mesh_dim")
        # Derive shard-dimension metadata once, up front.
        self.shard_mesh_size: int = self.mesh.mesh_shape[shard_dim]
        self.shard_process_group = self.mesh.get_group(shard_dim)
        self.shard_mesh_rank: int = get_group_local_rank(self.shard_process_group)

81 

82 

@dataclass
class DDPMeshInfo(DataParallelMeshInfo):
    """
    Mesh metadata for the replicated data-parallel (DDP) dimension.

    `__post_init__` requires replicate_mesh_dim and caches the derived values
    replicate_mesh_size, replicate_process_group and replicate_mesh_rank.
    """

    def __post_init__(self):
        super().__post_init__()
        replicate_dim = self.replicate_mesh_dim
        if replicate_dim is None:
            raise AssertionError("Expects non-None replicate_mesh_dim")
        # Derive replicate-dimension metadata once, up front.
        self.replicate_mesh_size: int = self.mesh.mesh_shape[replicate_dim]
        self.replicate_process_group = self.mesh.get_group(replicate_dim)
        self.replicate_mesh_rank: int = get_group_local_rank(self.replicate_process_group)

92 

93 

@dataclass
class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo):
    """
    Mesh metadata for HSDP: sharding and replication combined.

    Inherits both the shard_* attributes (from FSDPMeshInfo) and the
    replicate_* attributes (from DDPMeshInfo); both shard_mesh_dim and
    replicate_mesh_dim must therefore be provided.
    """
    # pylint: disable=W0246
    def __post_init__(self):
        # A single super() call walks the full MRO:
        # `FSDPMeshInfo` -> `DDPMeshInfo` -> `DataParallelMeshInfo`,
        # so both the shard and replicate attributes get initialized.
        super().__post_init__()