Coverage for hyper_parallel / core / hsdp / api.py: 77%

101 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""hybrid shard data parallel interface""" 

16from typing import Optional, Any 

17from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel 

18from hyper_parallel.platform.platform import PlatformType 

19from hyper_parallel.platform import get_platform 

# Platform abstraction resolved once at import time. It is used below for cell
# traversal (get_cells_and_names), for platform_type dispatch, and for waiting
# on gradient handles.
platform = get_platform()

# Cache of original cell class -> dynamically created "HSDP<Name>" subclass that
# mixes in HSDPCell, so each original class is extended at most once.
origin_class_to_extend_class: dict = {}

# User-facing `optimizer_level` strings accepted by hsdp() mapped to the
# OptimizerLevel members handed to the scheduler.
optimizer_level_map: dict = {
    "level1": OptimizerLevel.SHARD_OPT,
    "level2": OptimizerLevel.SHARD_OPT_GRAD,
    "level3": OptimizerLevel.SHARD_OPT_GRAD_PARAM,
}

28 

29 

class HSDPCell:
    """
    The hsdp block of neural networks with hsdp interface.

    ``hsdp()`` mixes this class into a cell/module at runtime: the instance's
    ``__class__`` is swapped for a subclass of ``(HSDPCell, original_class)``,
    and ``hsdp_init`` attaches the platform-specific scheduler as
    ``self.hsdp_scheduler``.

    Supported Platforms:
        ``MindSpore`` ``torch``
    """
    # pylint: disable=C0415
    def hsdp_init(self, platform_type, cell, shard_size, threshold, optimizer_level, enable_grad_accumulation,
                  use_eager_hook, grad_scale, reduce_dtype, comm_async, comm_fusion, bucket_size):
        """Create the platform-specific hsdp scheduler and store it as ``self.hsdp_scheduler``.

        Args:
            platform_type (PlatformType): ``PlatformType.MINDSPORE`` selects ``MindSporeHSDPScheduler``;
                any other value selects ``TorchHSDPScheduler``.
            cell: The cell/module handed to the scheduler.
            shard_size, threshold, optimizer_level, enable_grad_accumulation, use_eager_hook, grad_scale,
            reduce_dtype, comm_async, comm_fusion, bucket_size: Forwarded positionally to the scheduler
                constructor (already validated and unit-converted by ``hsdp()``).
        """
        # Scheduler modules are imported lazily so only the active framework has to be installed.
        if platform_type == PlatformType.MINDSPORE:
            from hyper_parallel.platform.mindspore.hsdp.scheduler import MindSporeHSDPScheduler
            scheduler_class = MindSporeHSDPScheduler
        else:
            from hyper_parallel.platform.torch.hsdp.scheduler import TorchHSDPScheduler
            scheduler_class = TorchHSDPScheduler

        # NOTE(review): grad_scale is passed *before* use_eager_hook, the reverse of this method's own
        # parameter order. The positional order must match the scheduler constructors; confirm against
        # MindSporeHSDPScheduler / TorchHSDPScheduler before touching it.
        self.hsdp_scheduler = scheduler_class(cell,
                                              shard_size,
                                              threshold,
                                              optimizer_level,
                                              enable_grad_accumulation,
                                              grad_scale,
                                              use_eager_hook,
                                              reduce_dtype,
                                              comm_async,
                                              comm_fusion,
                                              bucket_size)

    def set_requires_grad_sync(self, requires_grad_sync):
        r"""
        Set the requires-grad-sync flag on every HSDPCell reachable from this cell.

        The traversal uses ``platform.get_cells_and_names(self)``, which presumably yields this cell and
        its sub-cells. Every HSDPCell found forwards the flag to its own scheduler.

        Args:
            requires_grad_sync (bool): requires_grad_sync is used to control gradient sync process.
        Raises:
            ValueError: If `requires_grad_sync` is not bool.
            ValueError: If hsdp has not been applied to this cell.
        """
        if not isinstance(requires_grad_sync, bool):
            raise ValueError(f"requires_grad_sync must be bool but got {requires_grad_sync}.")
        _check_hsdp_initialized(self)

        for _, cell in platform.get_cells_and_names(self):
            if isinstance(cell, HSDPCell):
                cell.hsdp_scheduler.set_requires_grad_sync(requires_grad_sync)

    def zero_grads(self):
        """Zero accumulated gradients on every HSDPCell reachable from this cell.

        Raises:
            ValueError: If hsdp has not been applied to this cell.
        """
        _check_hsdp_initialized(self)

        for _, cell in platform.get_cells_and_names(self):
            if isinstance(cell, HSDPCell):
                cell.hsdp_scheduler.zero_grads()

    def set_forward_prefetch_cells(self, hsdp_cell_list):
        """Set the cells whose unsharded-parameter all-gather is prefetched during forward.

        Args:
            hsdp_cell_list (tuple|list[HSDPCell]): Cells to prefetch.
        Raises:
            ValueError: If `hsdp_cell_list` is not a tuple/list of HSDPCell.
            ValueError: If hsdp has not been applied to this cell.
        """
        _check_hsdp_cell_list(hsdp_cell_list)
        _check_hsdp_initialized(self)
        self.hsdp_scheduler.set_forward_prefetch_cells(hsdp_cell_list)

    def set_backward_prefetch_cells(self, hsdp_cell_list):
        """Set the cells whose unsharded-parameter all-gather is prefetched during backward.

        Args:
            hsdp_cell_list (tuple|list[HSDPCell]): Cells to prefetch.
        Raises:
            ValueError: If `hsdp_cell_list` is not a tuple/list of HSDPCell.
            ValueError: If hsdp has not been applied to this cell.
        """
        _check_hsdp_cell_list(hsdp_cell_list)
        _check_hsdp_initialized(self)
        self.hsdp_scheduler.set_backward_prefetch_cells(hsdp_cell_list)


# The helpers below are module-level (not HSDPCell methods) on purpose: HSDPCell is mixed into
# arbitrary user model classes, so extra method names there could clash with user attributes.
def _check_hsdp_cell_list(hsdp_cell_list):
    """Raise ValueError unless ``hsdp_cell_list`` is a tuple/list whose elements are all HSDPCell."""
    if not isinstance(hsdp_cell_list, (tuple, list)):
        raise ValueError("hsdp_cell_list must be HSDPCell list")
    for cell in hsdp_cell_list:
        if not isinstance(cell, HSDPCell):
            raise ValueError(f"hsdp_cell_list must be HSDPCell list but got {type(cell)} in list.")


def _check_hsdp_initialized(cell):
    """Raise ValueError if ``cell`` has no ``hsdp_scheduler`` (i.e. ``hsdp_init`` never ran on it)."""
    if not hasattr(cell, "hsdp_scheduler"):
        raise ValueError("call hsdp interface first.")

108 

def _extend_cell_with_hsdp_interface(cell):
    """Mix the HSDPCell interface into ``cell`` by swapping its class in place.

    The replacement class ``HSDP<OriginalName>`` derives from ``(HSDPCell, original_class)``. It is
    cached in ``origin_class_to_extend_class``, so all instances of one original class share a single
    extended class.

    A cell that is already an HSDPCell is left unchanged. This happens when ``hsdp()`` is applied twice
    to the same cell, or when its class already derives from HSDPCell. Re-extending it would ask
    ``type()`` for bases ``(HSDPCell, HSDP<Name>)``, which have no consistent MRO and raise TypeError.

    Args:
        cell: The cell/module instance to extend; mutated in place.
    """
    if isinstance(cell, HSDPCell):
        return
    origin_class = cell.__class__
    extend_class = origin_class_to_extend_class.get(origin_class)
    if extend_class is None:
        extend_class = type(f"HSDP{origin_class.__name__}", (HSDPCell, origin_class), {})
        origin_class_to_extend_class[origin_class] = extend_class
    cell.__class__ = extend_class

117 

# pylint: disable=C0415
def _check_cell_valid(platform_type, cell):
    """Ensure ``cell`` is the framework's network base class for ``platform_type``.

    MindSpore requires ``mindspore.nn.cell.Cell``; every other platform type requires
    ``torch.nn.Module``. The framework is imported lazily so only the active one must be installed.

    Raises:
        ValueError: If ``cell`` is not an instance of the expected base class.
    """
    if platform_type != PlatformType.MINDSPORE:
        from torch.nn import Module
        if isinstance(cell, Module):
            return
        raise ValueError(f"cell's type must be nn.Module but got {type(cell)}.")

    from mindspore.nn.cell import Cell
    if not isinstance(cell, Cell):
        raise ValueError(f"cell's type must be nn.cell but got {type(cell)}.")

129 

def _is_strict_int(value):
    """Return True if ``value`` is an int but not a bool (``bool`` is a subclass of ``int``)."""
    return isinstance(value, int) and not isinstance(value, bool)


# pylint: disable=C0415
def _check_hsdp_input_valid(platform_type, cell, shard_size, threshold, optimizer_level, enable_grad_accumulation,
                            use_eager_hook, grad_scale, reduce_dtype, comm_async, comm_fusion, bucket_size):
    """Validate the user-facing arguments of ``hsdp()``.

    Checks run in parameter order, and the first invalid argument raises. Integer-typed arguments reject
    ``bool`` explicitly, since ``True``/``False`` would otherwise pass as 1/0.

    Raises:
        ValueError: If ``cell`` is not the framework's cell/module type, or any other argument has the
            wrong type or is out of range.
    """
    _check_cell_valid(platform_type, cell)
    if not _is_strict_int(shard_size) or (shard_size <= 0 and shard_size != -1):
        raise ValueError(f"shard_size must be a positive integer or -1, but got {shard_size}.")
    if not _is_strict_int(threshold) or threshold < 0:
        raise ValueError(f"threshold must be a positive integer or 0, but got {threshold}.")
    if optimizer_level not in ["level1", "level2", "level3"]:
        raise ValueError(f"Optimizer level should in ['level1', 'level2', 'level3'], but got {optimizer_level}.")
    if not isinstance(enable_grad_accumulation, bool):
        raise ValueError(f"enable_grad_accumulation must be bool but got {enable_grad_accumulation}.")
    if not isinstance(grad_scale, float):
        raise ValueError(f"grad_scale must be float but got {grad_scale}.")
    if not isinstance(use_eager_hook, bool):
        raise ValueError(f"use_eager_hook must be bool but got {use_eager_hook}.")
    # reduce_dtype is optional (None keeps each gradient's original dtype); otherwise it must be the
    # active framework's dtype object. Imports are lazy so only that framework must be installed.
    if platform_type == PlatformType.MINDSPORE:
        from mindspore._c_expression.typing import Type
        if reduce_dtype is not None and not isinstance(reduce_dtype, Type):
            raise ValueError(f"reduce_dtype must be mindspore.dtype but got {reduce_dtype}.")
    else:
        import torch
        if reduce_dtype is not None and not isinstance(reduce_dtype, torch.dtype):
            raise ValueError(f"reduce_dtype must be torch.dtype but got {reduce_dtype}.")
    if not isinstance(comm_async, bool):
        raise ValueError(f"comm_async must be bool but got {comm_async}.")
    if not isinstance(comm_fusion, bool):
        raise ValueError(f"comm_fusion must be bool but got {comm_fusion}.")
    # -1 is a sentinel ("fuse all gradients into one buffer"); 0 and positive sizes are also valid.
    if not _is_strict_int(bucket_size) or (bucket_size < 0 and bucket_size != -1):
        raise ValueError(f"bucket_size must be a non-negative integer or -1, but got {bucket_size}.")

161 

def hsdp(
    cell,
    shard_size: int = -1,
    threshold: int = 64,
    optimizer_level: str = "level1",
    enable_grad_accumulation: bool = False,
    use_eager_hook: bool = True,
    grad_scale: float = 1.0,
    reduce_dtype: Optional[Any] = None,
    comm_async: bool = False,
    comm_fusion: bool = False,
    bucket_size: int = -1
):
    r"""
    Apply hybrid sharded data parallel to ``cell``.

    The cell's class is swapped in place for a subclass that mixes in ``HSDPCell``, and an hsdp scheduler is
    attached to it. The returned object is the same ``cell`` instance.

    Args:
        cell (Cell|Module): The cell to apply hsdp. Must be a ``mindspore.nn.Cell`` on MindSpore and a
            ``torch.nn.Module`` otherwise.
        shard_size (int, optional): Set the optimizer weight shard group size if you want to specify the
            maximum group size across devices. The numerical range can be (0, device_num] or -1. Default value
            is -1, which means the optimizer weight shard group size will be
            the data parallel group of each parameter.
        threshold (int, optional): Set the threshold of parallel optimizer. When parallel optimizer is
            enabled, parameters with size smaller than this threshold will not be
            sharded across the devices. Parameter size = shape[0] \* ... \*
            shape[n] \* size(dtype). Non-negative. Unit: KB. Default: 64.
        optimizer_level (str, optional): optimizer_level configuration is used to specify
            the splitting level for optimizer sharding. It is important to note that the implementation
            of optimizer sharding in static graph is inconsistent with dynamic graph like megatron,
            but the memory optimization effect is the same.
            It must be one of [ ``level1``, ``level2``, ``level3`` ]. Default: ``level1``.

            - level1:
              Splitting is performed on weights and optimizer state.
            - level2:
              Splitting is performed on weights, optimizer state, and gradients.
            - level3:
              Splitting is performed on weights, optimizer state,
              gradients, additionally, before the backward pass, the weights are further applied with
              allgather communication to release the memory used by the forward pass allgather.
        enable_grad_accumulation (bool, optional): enable gradient accumulation. When gradient accumulation is
            enabled, gradient synchronization should be explicitly controlled via the ``set_requires_grad_sync``
            interface. Default: ``False``.
        use_eager_hook (bool, optional): Controls whether to enable eager hook behavior. Default: ``True``.

            - Set to `True` for **eager mode** (both MindSpore and PyTorch).
            - Set to `False` for **MindSpore Graph mode** (static graph).
        grad_scale (float, optional): gradients are scaled with grad_scale. Must be a float. Default: ``1.0``.
        reduce_dtype (mindspore.dtype|torch.dtype, optional): gradient reduce dtype. Default value is None,
            which means gradients will be reduced with their original dtype.
        comm_async (bool, optional): reduce gradients with async communication ops for communication overlap.
            When comm_async is enabled, ``hsdp_sync_stream`` should be called before using the generated
            gradients. Default value is False, which means gradients will be reduced with sync communication ops.
        comm_fusion (bool, optional): fuse forward parameter allgathers and backward gradient reducescatters or
            allreduces into buffer communication to reduce the number of communication ops. ``bucket_size``
            further controls the size of the backward gradient reduce buffer.
            Default value is False, which means communication ops are not fused and run one by one.
        bucket_size (int, optional): bucket_size is used to control the size of comm fusion buffer. Unit: KB.
            Default value is -1, which means all gradients will be fused into one buffer. When value is 0,
            gradients will not be fused and each gradient acts as its own buffer.

    Returns:
        The same ``cell`` object, now extended with the ``HSDPCell`` interface.

    Raises:
        ValueError: If the `cell` is not a cell (``mindspore.nn.Cell`` / ``torch.nn.Module``).
        ValueError: If the `shard_size` is not a positive integer or -1.
        ValueError: If `threshold` is not a positive integer or 0.
        ValueError: If `optimizer_level` is not one of the [ ``level1``, ``level2``, ``level3`` ].
        ValueError: If `enable_grad_accumulation` is not bool.
        ValueError: If `grad_scale` is not float.
        ValueError: If `use_eager_hook` is not bool.
        ValueError: If `reduce_dtype` is neither None nor the active framework's dtype
            (mindspore.dtype / torch.dtype).
        ValueError: If `comm_async` is not bool.
        ValueError: If `comm_fusion` is not bool.
        ValueError: If the `bucket_size` is not a non-negative integer or -1.
    """
    platform_type = platform.platform_type

    _check_hsdp_input_valid(
        platform_type,
        cell,
        shard_size,
        threshold,
        optimizer_level,
        enable_grad_accumulation,
        use_eager_hook,
        grad_scale,
        reduce_dtype,
        comm_async,
        comm_fusion,
        bucket_size
    )
    # Validation above guarantees the key exists; index directly so any drift between the validator
    # and the map fails loudly instead of silently handing None to the scheduler.
    level = optimizer_level_map[optimizer_level]
    # User-facing sizes are in KB and are multiplied by 1024 here (presumably bytes for the scheduler).
    # The bucket_size sentinel -1 is passed through unscaled so it is not turned into -1024.
    bucket_size_scaled = bucket_size if bucket_size == -1 else bucket_size * 1024

    _extend_cell_with_hsdp_interface(cell)
    cell.hsdp_init(
        platform_type,
        cell,
        shard_size,
        threshold * 1024,
        level,
        enable_grad_accumulation,
        use_eager_hook,
        grad_scale,
        reduce_dtype,
        comm_async,
        comm_fusion,
        bucket_size_scaled
    )
    return cell

267 

def hsdp_sync_stream():
    """Wait for outstanding hsdp gradient handles to complete.

    Delegates to ``platform.wait_grad_handle()``. It does nothing when the module-level ``platform`` is None.
    """
    if platform is not None:
        platform.wait_grad_handle()