Coverage for hyper_parallel / core / hsdp / hsdp_param.py: 93%

239 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""HSDP parameter""" 

import functools
import math

from hyper_parallel.core.dtensor import DTensor
from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel, GroupInfo

19 

20 

class HSDPParam:
    """
    HSDP parameter.

    Wraps one model parameter and tracks its sharded/unsharded states,
    gradient accumulation buffer, prefetch bookkeeping, and the two
    communication groups (shard group and data-parallel replica group).
    """

    def __init__(self, cell, param_name, param, config, platform):
        self.cell = cell
        self.param_name = param_name
        self.param = param
        self.config = config
        self.platform = platform
        # Sharding state; refined by the _init_* helpers below.
        self.shard_size = 1
        self.unsharded_param = None
        self.sharded_param = None
        self.acc_grad = None
        self.grad = None
        self.sharded = False
        self.fully_sharded = True
        # Async all-gather prefetch bookkeeping.
        self.prefetch_handle = None
        self.prefetch_data = None
        # Offsets of this param inside the flat param/grad buffers.
        self.param_buffer_start_index = 0
        self.param_buffer_end_index = 0
        self.grad_buffer_start_index = 0
        self.grad_buffer_end_index = 0
        # Order matters: rank info feeds shard-size resolution, which
        # feeds parameter initialization and group construction.
        self._init_rank_info()
        self._init_param_shard_size()
        self._init_param()
        self.dp_size = self.rank_size // self.shard_size
        group_name, group = self._create_sharded_dp_group()
        self.sharded_group_info = GroupInfo(group_name, group, self.shard_size)
        group_name, group = self._create_unsharded_dp_group()
        self.unsharded_group_info = GroupInfo(group_name, group, self.dp_size)

52 

53 def _init_rank_info(self): 

54 """init parameter rank info""" 

55 self.rank_id = self.platform.get_rank() 

56 self.hsdp_rank = self.rank_id 

57 self.local_rank = self.rank_id 

58 self.tp_rank = 0 

59 if not isinstance(self.param, DTensor) or self.param.layout is None: 

60 self.rank_size = self.platform.get_world_size() 

61 return 

62 

63 if len(self.param.layout.rank_list) == 1: 

64 self.rank_size = 1 

65 return 

66 

67 try: 

68 self.local_rank = self.param.layout.rank_list.index(self.rank_id) 

69 except ValueError as e: 

70 raise ValueError(f"HSDP invalid rank {self.rank_id} with rank list {self.param.layout.rank_list}.") from e 

71 

72 tensor_map = self.param.layout.tensor_map 

73 sharded_axis_set = set() 

74 for axis in tensor_map: 

75 if isinstance(axis, int) and axis != -1: 

76 sharded_axis_set.add(axis) 

77 continue 

78 if isinstance(axis, tuple): 

79 for item in axis: 

80 sharded_axis_set.add(item) 

81 self.sharded_axis_set = sharded_axis_set 

82 self.rank_size = 1 

83 self.unsharded_reverse_axis_list = [] 

84 self.global_rank_stride_list = [] 

85 self.hsdp_rank_stride_list = [] 

86 self.tp_rank_stride_list = [] 

87 device_dims = len(self.param.layout.mesh_shape) 

88 stride = 1 

89 hsdp_stride = 1 

90 tp_stride = 1 

91 for axis in range(device_dims): 

92 r_axis = device_dims - 1 - axis 

93 self.global_rank_stride_list.append(stride) 

94 self.hsdp_rank_stride_list.append(hsdp_stride) 

95 self.tp_rank_stride_list.append(tp_stride) 

96 stride = stride * self.param.layout.mesh_shape[r_axis] 

97 if axis in self.sharded_axis_set: 

98 tp_stride = tp_stride * self.param.layout.mesh_shape[r_axis] 

99 continue 

100 

101 hsdp_stride = hsdp_stride * self.param.layout.mesh_shape[r_axis] 

102 self.unsharded_reverse_axis_list.append(r_axis) 

103 self.rank_size = self.rank_size * self.param.layout.mesh_shape[r_axis] 

104 self.global_rank_stride_list.reverse() 

105 self.hsdp_rank_stride_list.reverse() 

106 self.tp_rank_stride_list.reverse() 

107 self.unsharded_reverse_axis_list.reverse() 

108 

109 rank_indices = [] 

110 index = self.local_rank 

111 for stride in self.global_rank_stride_list: 

112 rank_indices.append(index // stride) 

113 index = index % stride 

114 self.rank_indices = rank_indices 

115 hsdp_rank = 0 

116 for axis in self.unsharded_reverse_axis_list: 

117 hsdp_rank = hsdp_rank + rank_indices[axis] * self.hsdp_rank_stride_list[axis] 

118 self.hsdp_rank = hsdp_rank 

119 tp_rank = 0 

120 for axis in range(device_dims): 

121 if axis in self.sharded_axis_set: 

122 r_axis = device_dims - 1 - axis 

123 tp_rank = tp_rank + rank_indices[r_axis] * self.tp_rank_stride_list[r_axis] 

124 self.tp_rank = tp_rank 

125 

126 def _hsdp_rank_to_global_rank(self, hsdp_rank_list): 

127 """transform from hsdp rank to global rank""" 

128 rank_list = [] 

129 for hsdp_rank in hsdp_rank_list: 

130 local_index = hsdp_rank 

131 local_indices_dict = {} 

132 for axis in self.unsharded_reverse_axis_list: 

133 stride = self.hsdp_rank_stride_list[axis] 

134 local_indices_dict[axis] = local_index // stride 

135 local_index = local_index % stride 

136 global_rank = 0 

137 for axis, index in enumerate(self.rank_indices): 

138 index = local_indices_dict.get(axis, index) 

139 global_rank = global_rank + index * self.global_rank_stride_list[axis] 

140 if self.param.layout is not None: 

141 if global_rank >= len(self.param.layout.rank_list): 

142 raise ValueError(f"HSDP invalid index {global_rank} with" 

143 f"rank list len {len(self.param.layout.rank_list)}.") 

144 global_rank = self.param.layout.rank_list[global_rank] 

145 rank_list.append(global_rank) 

146 return rank_list 

147 

148 def _get_op_rank_list(self): 

149 """get data parallel rank list""" 

150 if isinstance(self.param, DTensor): 

151 rank_base = self.hsdp_rank // self.shard_size * self.shard_size 

152 hsdp_rank_list = [i + rank_base for i in range(self.shard_size)] 

153 return self._hsdp_rank_to_global_rank(hsdp_rank_list) 

154 rank_base = self.local_rank // self.shard_size * self.shard_size 

155 rank_list = [i + rank_base for i in range(self.shard_size)] 

156 return rank_list 

157 

158 def _get_dp_rank_list(self): 

159 """get optimizer parallel rank list""" 

160 if isinstance(self.param, DTensor): 

161 rank_stride = self.shard_size 

162 rank_base = self.hsdp_rank % rank_stride 

163 hsdp_rank_list = [i * rank_stride + rank_base for i in range(self.dp_size)] 

164 return self._hsdp_rank_to_global_rank(hsdp_rank_list) 

165 rank_stride = self.shard_size 

166 rank_base = self.local_rank % rank_stride 

167 rank_list = [i * rank_stride + rank_base for i in range(self.dp_size)] 

168 return rank_list 

169 

170 def _init_sharded_param(self): 

171 """add and init sharded param""" 

172 raise NotImplementedError("HSDP param subclasses must implement _init_sharded_param") 

173 

174 def _init_unsharded_param(self): 

175 """add and init unshared param""" 

176 raise NotImplementedError("HSDP param subclasses must implement _init_unsharded_param") 

177 

178 def _get_unsharded_param_data(self, async_op): 

179 """get unsharded param data with async comm""" 

180 local_data = self.platform.get_param_local_data(self.param) 

181 return self.platform.all_gather_into_tensor(local_data, self.sharded_group_info, async_op=async_op) 

182 

183 def _init_param_shard_size(self): 

184 """init parameter dp shard size""" 

185 if hasattr(self.param, "hsdp_shard_size"): 

186 if not isinstance(self.param.hsdp_shard_size, int) or \ 

187 (self.param.hsdp_shard_size <= 0 and self.param.hsdp_shard_size != -1): 

188 raise ValueError(f"param's hsdp_shard_size must be a positive integer, " 

189 f"but got {self.param.hsdp_shard_size}.") 

190 self.shard_size = self.param.hsdp_shard_size 

191 else: 

192 self.shard_size = self.config.shard_size 

193 local_shape = self.platform.get_param_local_shape(self.param) 

194 if len(local_shape) < 1: 

195 self.shard_size = 1 

196 return 

197 param_type_size = self.platform.get_param_type_size(self.param) 

198 param_size = functools.reduce(lambda x, y: x * y, local_shape, param_type_size) 

199 if param_size < self.config.threshold: 

200 self.shard_size = 1 

201 return 

202 if self.shard_size == -1 or local_shape[0] < self.shard_size: 

203 self.shard_size = local_shape[0] 

204 

205 def _gcd(m, n): 

206 if m < n: 

207 m, n = n, m 

208 if n == 0: 

209 raise ValueError("HSDP invalid gcd input 0.") 

210 r = m % n 

211 if r == 0: 

212 return n 

213 return _gcd(n, r) 

214 

215 rank_gcd = _gcd(local_shape[0], self.rank_size) 

216 self.shard_size = min(self.shard_size, rank_gcd) 

217 if rank_gcd % self.shard_size != 0: 

218 self.shard_size = 1 

219 self.param.hsdp_effective_shard_size = self.shard_size 

220 

221 def _create_sharded_dp_group(self): 

222 """create communication group for sharded parameter""" 

223 if self.shard_size <= 1: 

224 return "hsdp_sharded_dp_group_invalid", None 

225 

226 rank_list = self._get_op_rank_list() 

227 rank_list_str = "_".join([str(i) for i in rank_list]) 

228 group_name = "hsdp_sharded_dp_group_" + rank_list_str 

229 group = self.platform.create_group(rank_list, group_name) 

230 return group_name, group 

231 

232 def _create_unsharded_dp_group(self): 

233 """create communication group for unsharded parameter""" 

234 if self.dp_size <= 1: 

235 return "hsdp_unsharded_dp_group_invalid", None 

236 

237 rank_list = self._get_dp_rank_list() 

238 rank_list_str = "_".join([str(i) for i in rank_list]) 

239 group_name = "hsdp_unshared_dp_group_" + rank_list_str 

240 group = self.platform.create_group(rank_list, group_name) 

241 return group_name, group 

242 

243 def _init_param(self): 

244 """init hsdp parameter""" 

245 self.param.acc_grad = None 

246 self.param_shape = self.platform.get_param_local_shape(self.param) 

247 

248 if self.shard_size == 1: 

249 self.sharded = False 

250 self.fully_sharded = False 

251 if self.config.requires_acc_grad and self.param.requires_grad: 

252 acc_grad_type = self.param.dtype 

253 if self.config.reduce_dtype is not None: 

254 acc_grad_type = self.config.reduce_dtype 

255 self.acc_grad = self.platform.new_zero_parameter(self.param_shape, acc_grad_type, False, 

256 self.param.device) 

257 self.param.acc_grad = self.acc_grad 

258 return 

259 

260 origin_param_shape = list(self.param_shape) 

261 self._init_unsharded_param() 

262 self._init_sharded_param() 

263 if self.config.requires_acc_grad and self.param.requires_grad: 

264 acc_grad_shape = origin_param_shape 

265 if self.config.shard_level != OptimizerLevel.SHARD_OPT: 

266 acc_grad_shape = self.sharded_param.shape 

267 acc_grad_type = self.param.dtype 

268 if self.config.reduce_dtype is not None: 

269 acc_grad_type = self.config.reduce_dtype 

270 self.acc_grad = self.platform.new_zero_parameter(acc_grad_shape, acc_grad_type, False, self.param.device) 

271 self.param.acc_grad = self.acc_grad 

272 self.sharded = True 

273 if self.shard_size == self.rank_size: 

274 self.fully_sharded = True 

275 else: 

276 self.fully_sharded = False 

277 

278 def to_sharded(self): 

279 """change parameter to sharded state""" 

280 self.platform.update_param_data(self.param, self.sharded_param) 

281 

282 def prefetch_unsharded(self): 

283 """prefetch unsharded param with async all gather""" 

284 if self.prefetch_handle is not None: 

285 return 

286 unshared_param_data, handle = self._get_unsharded_param_data(async_op=True) 

287 self.prefetch_data = unshared_param_data 

288 self.prefetch_handle = handle 

289 

290 #pylint: disable=W0212 

291 def to_unsharded(self): 

292 """change parameter to unsharded state""" 

293 if self.prefetch_handle is not None: 

294 self.prefetch_handle.wait() 

295 self.platform.update_param_data(self.sharded_param, self.platform.get_param_local_data(self.param)) 

296 self.platform.update_param_data(self.param, self.prefetch_data) 

297 self.prefetch_handle = None 

298 self.prefetch_data = None 

299 return 

300 

301 unshared_param_data, _ = self._get_unsharded_param_data(async_op=False) 

302 self.platform.update_param_data(self.sharded_param, self.platform.get_param_local_data(self.param)) 

303 self.platform.update_param_data(self.param, unshared_param_data) 

304 

305 def zero_acc_grad(self): 

306 """zero accumunication grad""" 

307 if self.param.acc_grad is not None: 

308 self.param.acc_grad.zero_()