Coverage for hyper_parallel/core/fully_shard/hsdp_state.py: 37% (131 statements)

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP cell state"""
from typing import List
from hyper_parallel.core.fully_shard.hsdp_param import HSDPParamV2
from hyper_parallel.core.fully_shard.hsdp_param_buffer import HSDPParamBuffer
from hyper_parallel.core.fully_shard.hsdp_grad_buffer import HSDPGradBuffer
from hyper_parallel.core.fully_shard.hsdp_utils import HSDPConfigV2

class HSDPState:
    """HSDP state for cell"""
    def __init__(self, cell, mesh_info, config: HSDPConfigV2, platform, device=None):
        self.cell = cell
        self.mesh_info = mesh_info
        self.config = config
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        self.platform = platform
        self.device = device
        self.hsdp_params: List[HSDPParamV2] = []
        self.sharded_hsdp_params: List[HSDPParamV2] = []
        self.param_buffers = []
        self.grad_buffers = []
        self._move_states_to_device()
        self._init_hsdp_params()
        self._init_param_buffers()
        self._init_grad_buffers()
        self.is_shard = True

    def _init_hsdp_params(self):
        """init hsdp parameters for cell"""
        raise NotImplementedError("HSDPState subclasses must implement _init_hsdp_params")

    def _move_states_to_device(self):
        """move states to device"""
        raise NotImplementedError("HSDPState subclasses must implement _move_states_to_device")

    def _init_param_buffers(self):
        """init param buffers"""
        if not self.config.comm_fusion:
            return

        # group sharded params by sharded group name and dtype; one HSDPParamBuffer per group
        group_to_buffer = {}
        for hsdp_param in self.sharded_hsdp_params:
            param_buffer_key = hsdp_param.sharded_group_info.group_name + str(hsdp_param.param.dtype)
            if param_buffer_key not in group_to_buffer:
                buffer = HSDPParamBuffer(self.config, hsdp_param, self.platform)
                buffer.add_param(hsdp_param)
                group_to_buffer[param_buffer_key] = buffer
            else:
                buffer = group_to_buffer[param_buffer_key]
                buffer.add_param(hsdp_param)
        self.param_buffers = list(group_to_buffer.values())
        for buffer in self.param_buffers:
            buffer.init()

    def _init_grad_buffers(self):
        """init grad buffers"""
        if not self.config.grad_fusion:
            return

        # buffer_key -> [current bucket index, bytes accumulated in the current bucket]
        bucket_infos = {}

        def get_bucket_key(buffer_key, hsdp_param):
            if self.config.bucket_size < 0:
                return buffer_key
            param_size = hsdp_param.param.numel() * self.platform.get_param_type_size(hsdp_param.param)
            bucket_info = bucket_infos.get(buffer_key, None)
            if bucket_info is None:
                bucket_info = [0, param_size]
                bucket_infos[buffer_key] = bucket_info
            else:
                bucket_size = bucket_info[1] + param_size
                if bucket_size > self.config.bucket_size:
                    # current bucket is full, open a new bucket starting with this param
                    bucket_info[0] = bucket_info[0] + 1
                    bucket_info[1] = param_size
                else:
                    bucket_info[1] = bucket_size
            return buffer_key + '_' + str(bucket_info[0])

        self.param_to_buffer = {}
        group_to_buffer = {}
        for hsdp_param in self.hsdp_params:
            if not hsdp_param.param.requires_grad:
                continue
            buffer_key = hsdp_param.sharded_group_info.group_name + hsdp_param.unsharded_group_info.group_name \
                + str(hsdp_param.param.dtype)
            bucket_key = get_bucket_key(buffer_key, hsdp_param)
            if bucket_key not in group_to_buffer:
                buffer = HSDPGradBuffer(self.config, hsdp_param, self.platform)
                group_to_buffer[bucket_key] = buffer
            else:
                buffer = group_to_buffer[bucket_key]
                buffer.add_param(hsdp_param)
            self.param_to_buffer[hsdp_param] = buffer
        self.grad_buffers = list(group_to_buffer.values())
        for buffer in self.grad_buffers:
            buffer.init()
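
    # Illustrative bucket walk-through (assumed numbers, not from the source):
    # with config.bucket_size = 100_000_000 bytes and four fp32 parameters of
    # 60 MB, 30 MB, 50 MB and 20 MB sharing one buffer_key, get_bucket_key above
    # yields suffixes _0, _0, _1 and _1: 60 + 30 MB still fits in 100 MB, adding
    # 50 MB overflows and opens bucket 1, and 50 + 20 MB then fits in bucket 1.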

    def shard(self):
        """change parameters to sharded state"""
        if self.is_shard:
            return

        if self.config.comm_fusion:
            for buffer in self.param_buffers:
                buffer.to_sharded()
        else:
            for param in self.sharded_hsdp_params:
                param.to_sharded()
        self.is_shard = True

    def unshard(self, async_op=False):
        """change parameters to unsharded state"""
        if not self.is_shard:
            return

        if self.config.comm_fusion:
            for buffer in self.param_buffers:
                buffer.to_unsharded(async_op=async_op)
        else:
            for param in self.sharded_hsdp_params:
                param.unshard()
                param.wait_for_unsharded()
        self.is_shard = False

    def prefetch(self):
        """prefetch unsharded parameters"""
        if not self.is_shard:
            return
        if self.config.comm_fusion:
            for buffer in self.param_buffers:
                buffer.prefetch_unsharded()
        else:
            for param in self.sharded_hsdp_params:
                param.unshard(async_op=True)

    def wait_for_unsharded(self):
        """wait for all unsharded parameters"""
        if not self.is_shard:
            return
        if self.config.comm_fusion:
            for buffer in self.param_buffers:
                if buffer.prefetch_handle is not None:
                    buffer.wait_for_unsharded()
        else:
            for param in self.sharded_hsdp_params:
                if param.prefetch_handle is not None:
                    param.wait_for_unsharded()

    def zero_grads(self):
        """zero grad or grad buffer"""
        if not self.config.grad_fusion:
            for hsdp_param in self.hsdp_params:
                if not hsdp_param.param.requires_grad:
                    continue
                hsdp_param.zero_acc_grad()
        else:
            for buffer in self.grad_buffers:
                buffer.zero_grads()

    def set_grad_ready(self, hsdp_param):
        """set grad ready"""
        if not self.config.grad_fusion:
            return
        buffer = self.param_to_buffer.get(hsdp_param, None)
        if buffer is not None:
            buffer.set_grad_ready()
        else:
            raise ValueError(f"param {hsdp_param.param} is not registered to any grad buffer.")

    def set_requires_grad_sync(self, requires_grad_sync):
        """set requires grad sync flag to control gradient sync."""
        if not self.config.grad_fusion:
            return
        for buffer in self.grad_buffers:
            buffer.set_requires_grad_sync(requires_grad_sync)
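

# ----------------------------------------------------------------------------
# Illustrative sketch only (not part of the original module): a minimal
# concrete subclass and one shard/unshard lifecycle around a training step.
# `_ExampleHSDPState`, `_example_step`, `cell`, `mesh_info`, `config` and
# `platform` are hypothetical placeholders for whatever the caller provides.
class _ExampleHSDPState(HSDPState):
    """minimal concrete state for illustration only"""
    def _move_states_to_device(self):
        # a real subclass would move the cell's states onto self.device here
        pass

    def _init_hsdp_params(self):
        # a real subclass would wrap the cell's parameters into HSDPParamV2
        # objects and fill self.hsdp_params / self.sharded_hsdp_params here
        pass


def _example_step(cell, mesh_info, config, platform):
    """sketch of the shard/unshard lifecycle around one forward/backward pass"""
    state = _ExampleHSDPState(cell, mesh_info, config, platform)
    state.zero_grads()
    state.unshard()                    # all-gather full parameters before compute
    # ... run forward and backward on `cell` here ...
    for hsdp_param in state.hsdp_params:
        if hsdp_param.param.requires_grad:
            state.set_grad_ready(hsdp_param)   # mark grad ready for fused grad sync
    state.shard()                      # drop unsharded copies, keep local shards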