# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP gradient buffer."""
from hyper_parallel.core.fully_shard.hsdp_utils import OptimizerLevel


class HSDPGradBuffer:
    """
    HSDP gradient buffer.
    """
    def __init__(self, config, init_hsdp_param, platform):
        self.config = config
        self.platform = platform
        self.shard_size = init_hsdp_param.shard_size
        self.dp_size = init_hsdp_param.dp_size
        self.local_rank = init_hsdp_param.hsdp_rank % init_hsdp_param.shard_size
        self.dtype = init_hsdp_param.param.dtype
        self.sharded_group_info = init_hsdp_param.sharded_group_info
        self.unsharded_group_info = init_hsdp_param.unsharded_group_info
        self.sharded = init_hsdp_param.sharded
        self.fully_sharded = init_hsdp_param.fully_sharded
        self.device = init_hsdp_param.param.device
        self.hsdp_params = []
        self.numel = 0
        self.requires_grad_sync = False
        self.sharded_grad_buffer = None
        self.unsharded_grad_buffer = None
        self.prefetch_handle = None
        self.prefetch_data = None
        # Gradients are communicated in the parameter dtype unless the config
        # specifies a dedicated reduce dtype.
        self.reduce_type = self.dtype
        if self.config.reduce_dtype is not None:
            self.reduce_type = self.config.reduce_dtype
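    # Note on the constructor contract: init_hsdp_param itself is not stored;
    # __init__ only reads shard_size, dp_size, hsdp_rank, param (for its dtype
    # and device), sharded_group_info, unsharded_group_info, sharded, and
    # fully_sharded from it.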

    def add_param(self, hsdp_param):
        """Add a parameter to the buffer."""
        self.hsdp_params.append(hsdp_param)

    def init(self):
        """Initialize the buffer once all parameters have been added."""
        self._init_grad_buffer_index()
        self._init_acc_grad_buffer()
        self.num_grad = len(self.hsdp_params)
        self.num_ready_grad = 0

    def set_requires_grad_sync(self, requires_grad_sync):
        """Set the requires_grad_sync flag that controls gradient synchronization."""
        self.requires_grad_sync = requires_grad_sync

    def _set_handle(self, handle, process=None):
        """Register the handle for async communication, or wait on it for sync communication."""
        if self.config.comm_async:
            if process is not None:
                self.platform.set_grad_reduce_handle(handle, process)
            else:
                self.platform.set_grad_reduce_handle(handle)
        else:
            handle.wait()
            if process is not None:
                process()
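    # With comm_async, the platform keeps the handle (and the optional
    # post-processing callback) and is expected to complete them later; when
    # exactly that happens is a platform detail outside this file. In the
    # synchronous path the wait and the callback run inline.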

    def _init_grad_buffer_index(self):
        """Assign each parameter a contiguous [start, end) slice of the flat gradient buffer."""
        self.numel = 0
        for hsdp_param in self.hsdp_params:
            start_index = self.numel
            end_index = start_index + hsdp_param.param.numel()
            hsdp_param.grad_buffer_start_index = start_index
            hsdp_param.grad_buffer_end_index = end_index
            self.numel = end_index
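    # Worked example with hypothetical sizes: packing parameters of 6 and 4
    # elements assigns the index ranges [0, 6) and [6, 10) and leaves
    # self.numel == 10, so each parameter's gradient occupies one contiguous
    # column slice of every buffer allocated with width self.numel.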

    def _init_acc_grad_buffer(self):
        """Initialize the gradient accumulation buffer."""
        if not self.config.requires_acc_grad:
            return

        # One row per shard when shard_level is SHARD_OPT; otherwise a single row.
        acc_grad_shape = (self.shard_size, self.numel)
        self.acc_grad_buffer_dim0 = self.shard_size
        if self.config.shard_level != OptimizerLevel.SHARD_OPT:
            acc_grad_shape = (1, self.numel)
            self.acc_grad_buffer_dim0 = 1
        self.acc_grad_buffer = self.platform.new_zero_parameter(acc_grad_shape, self.reduce_type, False, self.device)

    def _copy_grad_to_unsharded_buffer(self):
        """Copy each parameter's gradient into the unsharded buffer, casting to the reduce dtype if needed."""
        self.unsharded_grad_buffer = self.platform.new_tensor((self.shard_size, self.numel), self.reduce_type,
                                                              self.device)
        for hsdp_param in self.hsdp_params:
            if hsdp_param.grad is None:
                raise ValueError("HSDP param grad can't be None with comm fusion.")
            start_index = hsdp_param.grad_buffer_start_index
            end_index = hsdp_param.grad_buffer_end_index
            buffer_view = self.unsharded_grad_buffer[:, start_index:end_index]
            if self.dtype != self.reduce_type:
                buffer_view[:] = hsdp_param.grad.view(self.shard_size, -1).to(self.reduce_type)[:]
            else:
                buffer_view[:] = hsdp_param.grad.view(self.shard_size, -1)[:]

    def _add_grad_to_acc_buffer(self):
        """Accumulate the current gradients into the accumulation buffer."""
        self._copy_grad_to_unsharded_buffer()
        self.acc_grad_buffer.add_(self.unsharded_grad_buffer)
        self.unsharded_grad_buffer = None

    def _set_grad_data(self, buffer=None):
        """Rebind each parameter's grad to its slice of the reduced buffer."""
        if buffer is None:
            buffer = self.sharded_grad_buffer

        if self.dtype != self.reduce_type:
            buffer = buffer.to(self.dtype)

        if self.config.grad_scale != 1.0:
            buffer = buffer * self.config.grad_scale

        for hsdp_param in self.hsdp_params:
            if hsdp_param.grad is None:
                raise ValueError("HSDP param grad can't be None with comm fusion.")
            start_index = hsdp_param.grad_buffer_start_index
            end_index = hsdp_param.grad_buffer_end_index
            view_shape = list(hsdp_param.param_shape)
            view_shape[0] = view_shape[0] // self.shard_size
            buffer_view = buffer[:, start_index:end_index].view(view_shape)
            # "+ 0" materializes a fresh tensor so the grad does not alias the
            # communication buffer.
            hsdp_param.grad.data = buffer_view + 0
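    # grad_scale rescales the reduced gradient as a whole. For example, if the
    # platform's reduction sums contributions (an assumption, not visible in
    # this file), a grad_scale of 1.0 / dp_size would turn that sum into a
    # data-parallel mean.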

    def _handle_single_node_grad(self):
        """A single node doesn't need to reduce gradients, only to accumulate them."""
        if not self.config.requires_acc_grad:
            return
        self._add_grad_to_acc_buffer()
        self._set_grad_data(self.acc_grad_buffer)

    def _reduce_no_shard_grad(self):
        """All-reduce the gradient when the parameter is not sharded."""
        if not self.config.requires_acc_grad:
            self._copy_grad_to_unsharded_buffer()
            output, handle = self.platform.all_reduce(self.unsharded_grad_buffer, self.unsharded_group_info,
                                                      async_op=True)
        else:
            self._add_grad_to_acc_buffer()
            if not self.requires_grad_sync:
                return
            output, handle = self.platform.all_reduce(self.acc_grad_buffer, self.unsharded_group_info, async_op=True)
        self.sharded_grad_buffer = output
        self._set_handle(handle, self._set_grad_data)

    def _reduce_fully_shard_grad(self):
        """Reduce-scatter the gradient when the parameter is fully sharded."""
        if not self.config.requires_acc_grad:
            self._copy_grad_to_unsharded_buffer()
            output, handle = self.platform.reduce_scatter_tensor(self.unsharded_grad_buffer, self.sharded_group_info,
                                                                 async_op=True)
            self.sharded_grad_buffer = output
            self._set_handle(handle, self._set_grad_data)
        elif self.config.shard_level != OptimizerLevel.SHARD_OPT:
            self._copy_grad_to_unsharded_buffer()
            output, handle = self.platform.reduce_scatter_tensor(self.unsharded_grad_buffer, self.sharded_group_info,
                                                                 async_op=True)
            self.sharded_grad_buffer = output

            def post_process():
                self.acc_grad_buffer.add_(output)
                self._set_grad_data(self.acc_grad_buffer)

            self._set_handle(handle, post_process)
        else:
            self._add_grad_to_acc_buffer()
            if not self.requires_grad_sync:
                return
            output, handle = self.platform.reduce_scatter_tensor(self.acc_grad_buffer, self.sharded_group_info,
                                                                 async_op=True)
            self.sharded_grad_buffer = output
            self._set_handle(handle, self._set_grad_data)

    def _reduce_partial_shard_grad(self):
        """Reduce-scatter, then all-reduce the gradient when the parameter is partially sharded."""
        if not self.config.requires_acc_grad:
            self._copy_grad_to_unsharded_buffer()
            output, _ = self.platform.reduce_scatter_tensor(self.unsharded_grad_buffer, self.sharded_group_info,
                                                            async_op=False)
            output, handle = self.platform.all_reduce(output, self.unsharded_group_info, async_op=True)
            self.sharded_grad_buffer = output
            self._set_handle(handle, self._set_grad_data)
        elif self.config.shard_level != OptimizerLevel.SHARD_OPT:
            self._copy_grad_to_unsharded_buffer()
            output, _ = self.platform.reduce_scatter_tensor(self.unsharded_grad_buffer, self.sharded_group_info,
                                                            async_op=False)
            self.acc_grad_buffer.add_(output)
            if not self.requires_grad_sync:
                return
            output, handle = self.platform.all_reduce(self.acc_grad_buffer, self.unsharded_group_info, async_op=True)
            self.sharded_grad_buffer = output
            self._set_handle(handle, self._set_grad_data)
        else:
            self._add_grad_to_acc_buffer()
            if not self.requires_grad_sync:
                return
            output, _ = self.platform.reduce_scatter_tensor(self.acc_grad_buffer, self.sharded_group_info,
                                                            async_op=False)
            output, handle = self.platform.all_reduce(output, self.unsharded_group_info, async_op=True)
            self.sharded_grad_buffer = output
            self._set_handle(handle, self._set_grad_data)

    def _reduce_grads(self):
        """Reduce or accumulate the gradient buffer, dispatching on the sharding mode."""
        if not self.sharded:
            if self.dp_size == 1:
                self._handle_single_node_grad()
            else:
                self._reduce_no_shard_grad()
        elif self.fully_sharded:
            self._reduce_fully_shard_grad()
        else:
            self._reduce_partial_shard_grad()
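    # Dispatch summary:
    #   not sharded, dp_size == 1 -> local accumulation only (no communication)
    #   not sharded, dp_size > 1  -> all-reduce over the unsharded group
    #   fully sharded             -> reduce-scatter over the sharded group
    #   partially sharded         -> reduce-scatter, then all-reduce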

    def zero_grads(self):
        """Zero the gradient accumulation buffer."""
        if not self.config.requires_acc_grad:
            return
        self.acc_grad_buffer.zero_()

    def set_grad_ready(self):
        """Mark one gradient as ready; trigger the reduction once every gradient is ready."""
        self.num_ready_grad += 1
        if self.num_ready_grad == self.num_grad:
            self._reduce_grads()
            self.num_ready_grad = 0
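
# Minimal sketch of the intended call pattern, assuming hypothetical `config`,
# `platform`, `init_hsdp_param`, `hsdp_params`, and `is_last_microbatch`
# objects that satisfy the interfaces used above (an illustration, not an API
# defined in this file):
#
#     buffer = HSDPGradBuffer(config, init_hsdp_param, platform)
#     for p in hsdp_params:
#         buffer.add_param(p)
#     buffer.init()
#     buffer.set_requires_grad_sync(is_last_microbatch)
#     # In backward, each produced gradient reports in; the final call
#     # triggers _reduce_grads() and resets the ready counter.
#     for _ in hsdp_params:
#         buffer.set_grad_ready()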