Coverage for hyper_parallel/core/hsdp/hsdp_state.py: 79%

111 statements  

coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP cell state."""
from hyper_parallel.core.hsdp.hsdp_param_buffer import HSDPParamBuffer
from hyper_parallel.core.hsdp.hsdp_grad_buffer import HSDPGradBuffer

class HSDPState:
    """HSDP state for a cell."""

    def __init__(self, cell, config, platform):
        self.cell = cell
        self.config = config
        self.platform = platform
        self.hsdp_params = []
        self.sharded_hsdp_params = []
        self.param_buffers = []
        self.grad_buffers = []
        # The subclass hook below must populate hsdp_params and
        # sharded_hsdp_params before the two buffer initializers run.
        self._init_hsdp_params()
        self._init_param_buffers()
        self._init_grad_buffers()
        self.is_shard = True

    def _init_hsdp_params(self):
        """Initialize HSDP parameters for the cell (must be overridden)."""
        raise NotImplementedError("HSDPState subclasses must implement _init_hsdp_params")
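    # A concrete subclass is expected to fill in the hook above, roughly as
    # in this hypothetical sketch (`MyHSDPState` and the wrapping logic are
    # illustrations, not part of this module):
    #
    #     class MyHSDPState(HSDPState):
    #         def _init_hsdp_params(self):
    #             for param in self.cell.trainable_params():
    #                 hsdp_param = ...  # wrap `param` into an HSDP parameter
    #                 self.hsdp_params.append(hsdp_param)
    #                 if hsdp_param.is_sharded:  # hypothetical attribute
    #                     self.sharded_hsdp_params.append(hsdp_param)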

    def _init_param_buffers(self):
        """Group sharded parameters into fused communication buffers."""
        if not self.config.comm_fusion:
            return

        # One buffer per (sharded communication group, dtype) combination.
        group_to_buffer = {}
        for hsdp_param in self.sharded_hsdp_params:
            param_buffer_key = hsdp_param.sharded_group_info.group_name + str(hsdp_param.param.dtype)
            buffer = group_to_buffer.get(param_buffer_key)
            if buffer is None:
                buffer = HSDPParamBuffer(self.config, hsdp_param, self.platform)
                group_to_buffer[param_buffer_key] = buffer
            buffer.add_param(hsdp_param)
        self.param_buffers = list(group_to_buffer.values())
        for buffer in self.param_buffers:
            buffer.init()
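    # Key example (illustrative values, not from the source): a parameter in
    # communication group "sdp_group_0" with a float16 dtype yields
    # param_buffer_key "sdp_group_0" + str(float16), so parameters that share
    # both the group and the dtype land in the same fused buffer.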

    def _init_grad_buffers(self):
        """Group gradients into fused buffers, split into size-capped buckets."""
        if not self.config.grad_fusion:
            return

        bucket_infos = {}

        def get_bucket_key(buffer_key, hsdp_param):
            # A negative bucket_size disables bucketing: one bucket per buffer key.
            if self.config.bucket_size < 0:
                return buffer_key
            param_size = hsdp_param.param.numel() * self.platform.get_param_type_size(hsdp_param.param)
            bucket_info = bucket_infos.get(buffer_key)
            if bucket_info is None:
                # [current bucket index, bytes accumulated in that bucket]
                bucket_info = [0, param_size]
                bucket_infos[buffer_key] = bucket_info
            else:
                accumulated_size = bucket_info[1] + param_size
                if accumulated_size > self.config.bucket_size:
                    # The current bucket is full: open a new one that starts
                    # with just this parameter.
                    bucket_info[0] += 1
                    bucket_info[1] = param_size
                else:
                    bucket_info[1] = accumulated_size
            return buffer_key + '_' + str(bucket_info[0])

        self.param_to_buffer = {}
        group_to_buffer = {}
        for hsdp_param in self.hsdp_params:
            if not hsdp_param.param.requires_grad:
                continue
            buffer_key = (hsdp_param.sharded_group_info.group_name
                          + hsdp_param.unsharded_group_info.group_name
                          + str(hsdp_param.param.dtype))
            bucket_key = get_bucket_key(buffer_key, hsdp_param)
            buffer = group_to_buffer.get(bucket_key)
            if buffer is None:
                # The buffer constructor receives the first parameter; later
                # parameters in the same bucket are added explicitly.
                buffer = HSDPGradBuffer(self.config, hsdp_param, self.platform)
                group_to_buffer[bucket_key] = buffer
            else:
                buffer.add_param(hsdp_param)
            self.param_to_buffer[hsdp_param] = buffer
        self.grad_buffers = list(group_to_buffer.values())
        for buffer in self.grad_buffers:
            buffer.init()
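    # Worked example of the bucketing above (illustrative sizes, not from the
    # source): with config.bucket_size = 100 and four parameters of 60, 30,
    # 30 and 60 bytes that share one buffer_key "g":
    #   p1: new entry [0, 60]                 -> bucket key "g_0"
    #   p2: 60 + 30 = 90  <= 100, stay        -> "g_0", info becomes [0, 90]
    #   p3: 90 + 30 = 120 > 100, new bucket   -> "g_1", info becomes [1, 30]
    #   p4: 30 + 60 = 90  <= 100, stay        -> "g_1", info becomes [1, 90]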

    def shard(self):
        """Change parameters to the sharded state."""
        if self.is_shard:
            return

        if self.config.comm_fusion:
            for buffer in self.param_buffers:
                buffer.to_sharded()
        else:
            for param in self.sharded_hsdp_params:
                param.to_sharded()
        self.is_shard = True

    def unshard(self):
        """Change parameters to the unsharded state."""
        if not self.is_shard:
            return

        if self.config.comm_fusion:
            for buffer in self.param_buffers:
                buffer.to_unsharded()
        else:
            for param in self.sharded_hsdp_params:
                param.to_unsharded()
        self.is_shard = False

    def prefetch(self):
        """Prefetch unsharded parameters."""
        if not self.is_shard:
            return
        if self.config.comm_fusion:
            for buffer in self.param_buffers:
                buffer.prefetch_unsharded()
        else:
            for param in self.sharded_hsdp_params:
                param.prefetch_unsharded()

    def zero_grads(self):
        """Zero gradients, either per parameter or per fused buffer."""
        if not self.config.grad_fusion:
            for hsdp_param in self.hsdp_params:
                if not hsdp_param.param.requires_grad:
                    continue
                hsdp_param.zero_acc_grad()
        else:
            for buffer in self.grad_buffers:
                buffer.zero_grads()

    def set_grad_ready(self, hsdp_param):
        """Mark a parameter's gradient as ready in its fused buffer."""
        if not self.config.grad_fusion:
            return
        buffer = self.param_to_buffer.get(hsdp_param)
        if buffer is None:
            raise ValueError(f"param {hsdp_param.param} is not registered to any grad buffer.")
        buffer.set_grad_ready()

    def set_requires_grad_sync(self, requires_grad_sync):
        """Set the requires_grad_sync flag that controls gradient synchronization."""
        if not self.config.grad_fusion:
            return
        for buffer in self.grad_buffers:
            buffer.set_requires_grad_sync(requires_grad_sync)
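

# ----------------------------------------------------------------------------
# Minimal smoke-test sketch (illustration only, not part of the library): a
# trivial subclass with no parameters exercises the shard/unshard state
# machine without touching any communication groups. `_EmptyHSDPState` and
# the SimpleNamespace config are hypothetical stand-ins for real HSDP wiring.
if __name__ == "__main__":
    from types import SimpleNamespace

    class _EmptyHSDPState(HSDPState):
        def _init_hsdp_params(self):
            # A real subclass would populate self.hsdp_params and
            # self.sharded_hsdp_params from self.cell here.
            pass

    config = SimpleNamespace(comm_fusion=False, grad_fusion=False, bucket_size=-1)
    state = _EmptyHSDPState(cell=None, config=config, platform=None)
    assert state.is_shard          # parameters start in the sharded state
    state.unshard()                # empty loop here, but flips the flag
    assert not state.is_shard
    state.prefetch()               # returns early: already unsharded
    state.shard()
    assert state.is_shard
    state.zero_grads()             # per-param path; no params, so a no-op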