Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/fully_shard/hsdp_scheduler.py: 44%

180 statements

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP scheduler"""
import functools
from typing import Any, List, Optional, Tuple, Union

from hyper_parallel.platform import get_platform
from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.fully_shard.hsdp_utils import (
    FSDPSchedulerState,
    HSDPConfigV2,
    get_managed_modules_parameters,
    get_hsdp_state
)

platform = get_platform()


class HSDPSchedulerContext:
    """Per-iteration scheduling context shared by a root HSDP module and all of its sub-modules."""

    def __init__(self) -> None:
        # Currently only record is_last_backward flag for scheduler context.
        self.is_last_backward: bool = True
        # flag to identify "root_module"
        self.root_module = None


class HSDPSchedulerV2:
    """HSDPScheduler schedules HSDP: unshard/reshard and backward hooks for a ``fully_shard`` unit."""
    root_bp_state = False
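    # Class-level flag shared by all scheduler instances; it appears to be set while the
    # root module's backward pass is in flight, and _hsdp_forward_pre_hook checks it to
    # disable forward prefetch for recompute forwards (see
    # _disable_forward_prefetch_for_recompute).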

    def __init__(self, cell: Union[platform.Module, Tuple[platform.Module, ...]], mesh,
                 reshard_after_forward, shard_placement_fn,
                 mp_policy, offload_policy, ignored_params, replicate_params, device, comm_fusion,
                 comm_fusion_zero_copy=False):
        """Initialize the HSDP scheduler.

        Args:
            cell: A single platform.Module or tuple of platform.Module to manage as one FSDP unit.
        """
        self.modules = (cell,) if isinstance(cell, platform.Module) else tuple(cell)
        self.cell = self.modules[0]
        self.mesh: DeviceMesh = mesh
        self.reshard_after_forward = reshard_after_forward
        self.shard_placement_fn = shard_placement_fn
        self.mp_policy = mp_policy
        self.offload_policy = offload_policy
        self.ignored_params = ignored_params
        self.replicate_params = replicate_params
        self.device = device
        self.scheduler_state = None
        self.forward_prefetch_cells = []
        self.backward_prefetch_cells = []
        self._backup_forward_fetch = None
        # Flag to identify root module.
        self._is_root = False
        # A module and all of its sub-modules share one 'HSDPSchedulerContext'.
        self.scheduler_ctx = HSDPSchedulerContext()
        # When ``fully_shard`` is given multiple root modules, forward pre/post hooks coordinate
        # so unshard / PostBackward / reshard run once per forward (aligned with PyTorch FSDP2).
        self._fsdp_group_post_pending: Optional[set] = set() if len(self.modules) > 1 else None
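        # ``None`` means this scheduler manages a single module; otherwise the set tracks
        # which modules of the group have not yet run their post-forward hook
        # (see _grouped_forward_pre_hook / _make_grouped_forward_post_hook).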

        self.config = HSDPConfigV2(
            mesh,
            reshard_after_forward,
            shard_placement_fn,
            mp_policy,
            offload_policy,
            ignored_params,
            replicate_params,
            comm_fusion=comm_fusion,
            comm_fusion_zero_copy=comm_fusion_zero_copy,
        )
        self._init_platform()
        self._new_cell_state()
        self._register_hooks()
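
    # The four NotImplementedError stubs below are platform-specific extension points;
    # concrete subclasses (the MindSpore and Torch schedulers referenced in the
    # grouped-hook docstrings) are expected to implement them.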

    def _init_platform(self):
        """Initialize the platform."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _init_platform")

    def _new_cell_state(self):
        """Create a new cell state."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _new_cell_state")

    def _register_hooks(self):
        """Register hooks."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _register_hooks.")

    def _register_forward_backward_hooks(self):
        """Register module forward and backward hook."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _register_forward_backward_hooks.")

    def _get_managed_params(self):
        """Return deduplicated parameters from all managed modules."""
        return get_managed_modules_parameters(self.modules, self.ignored_params)

    def set_reshard_after_forward(self, reshard_after_forward: bool) -> None:
        """Set reshard_after_forward flag.

        Args:
            reshard_after_forward: Whether to reshard parameters after forward.
        """
        if not isinstance(reshard_after_forward, bool):
            raise ValueError(f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}")
        self.reshard_after_forward = reshard_after_forward
        self.config.reshard_after_forward = reshard_after_forward

    def set_reshard_after_backward(self, reshard_after_backward: bool) -> None:
        """Set reshard_after_backward flag.

        Args:
            reshard_after_backward: Whether to reshard after backward completes.
        """
        if not isinstance(reshard_after_backward, bool):
            raise ValueError(f"reshard_after_backward should be a bool, got {type(reshard_after_backward)}")
        if self.hsdp_state is not None:
            self.hsdp_state.reshard_after_backward = reshard_after_backward

    def set_requires_all_reduce(self, requires_all_reduce: bool) -> None:
        """Set requires_all_reduce flag.

        Args:
            requires_all_reduce: Whether this unit participates in all-reduce.
        """
        if not isinstance(requires_all_reduce, bool):
            raise ValueError(f"requires_all_reduce should be a bool, got {type(requires_all_reduce)}")
        if self.hsdp_state is not None:
            self.hsdp_state.requires_all_reduce = requires_all_reduce

    def set_requires_grad_sync(self, requires_grad_sync: bool) -> None:
        """Set flag controlling whether gradients are synchronized.

        Args:
            requires_grad_sync: When True, enable grad sync for this scheduler.
        """
        if not isinstance(requires_grad_sync, bool):
            raise ValueError(f"requires_grad_sync should be a bool, got {type(requires_grad_sync)}")
        self.hsdp_state.set_requires_grad_sync(requires_grad_sync)

    # pylint: disable=W0613
    def _hsdp_forward_pre_hook(self, cell, args, kwargs):
        """Forward pre hook to unshard parameters for the forward pass."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return args, kwargs
        if HSDPSchedulerV2.root_bp_state:
            self._disable_forward_prefetch_for_recompute()
        if self.scheduler_ctx.root_module is None:
            self.scheduler_ctx.root_module = self.cell
            self._is_root = True
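            # The first HSDP unit whose pre-forward runs becomes the root of the tree;
            # it shares its scheduler context with every nested HSDPModule below.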

            for _, module in platform.get_cells_and_names(self.scheduler_ctx.root_module):
                from hyper_parallel.core.fully_shard.api import HSDPModule  # pylint: disable=C0415
                if isinstance(module, HSDPModule):
                    submod_scheduler = getattr(module, "hsdp_scheduler", None)
                    if submod_scheduler and submod_scheduler.scheduler_ctx is not self.scheduler_ctx:
                        submod_scheduler.scheduler_ctx = self.scheduler_ctx

        if not self._is_root and not self.hsdp_state.module_name:
            for module_name, module in platform.get_cells_and_names(self.scheduler_ctx.root_module):
                if module == self.cell:
                    self.hsdp_state.module_name = module_name
                    break
        self.scheduler_state = FSDPSchedulerState.PRE_FORWARD
        self._init_params_fqn()
        self._lazy_init_all_states()
        if self.mp_policy.cast_forward_inputs and self.mp_policy.param_dtype:
            cast_fn = functools.partial(self.platform.cast_fp_tensor, self.mp_policy.param_dtype)
            args = self.platform.apply_to_tensors(cast_fn, args)
            kwargs = self.platform.apply_to_tensors(cast_fn, kwargs)
        for prefetch_cell in self.forward_prefetch_cells:
            with self.platform.profiler_record(f"pre_forward prefetch:"
                                               f"{prefetch_cell.hsdp_scheduler.hsdp_state.module_name}"):
                prefetch_cell.hsdp_scheduler.hsdp_state.prefetch()
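        # Prefetching upcoming cells here lets their parameter all-gathers overlap with this
        # cell's forward compute (assuming prefetch() issues its communication asynchronously).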

        with self.platform.profiler_record(f"pre_forward unshard:{self.hsdp_state.module_name}"):
            self.hsdp_state.unshard()
        return args, kwargs

    def _lazy_init_all_states(self):
        """Lazily initialize the HSDP state of every module under the root (root scheduler only)."""
        if self._is_root and self.scheduler_ctx.root_module is not None:
            for _, module in platform.get_cells_and_names(self.scheduler_ctx.root_module):
                hsdp_state = get_hsdp_state(module)
                if hsdp_state:
                    hsdp_state.lazy_init()

    def _init_params_fqn(self):  # pylint: disable=W0212
        """Assign fully-qualified names (FQNs) to all managed hsdp_params (root scheduler only)."""
        if not self._is_root or self.scheduler_ctx.root_module is None:
            return
        # Build a map from original (sharded) parameter tensor → hsdp_param wrapper,
        # covering both sharded hsdp_params and replicate_params.
        param_to_hsdp_param = {}
        for _, module in platform.get_cells_and_names(self.scheduler_ctx.root_module):
            hsdp_state = get_hsdp_state(module)
            if hsdp_state is None:
                continue
            for hsdp_param in hsdp_state._iter_managed_params():  # pylint: disable=W0212
                orig_param = hsdp_param.sharded_param
                # Shared parameters: keep only the first mapping to preserve the
                # first-seen FQN (consistent with the deduplication in _init_hsdp_params).
                if orig_param not in param_to_hsdp_param:
                    param_to_hsdp_param[orig_param] = hsdp_param

        # Walk the full parameter tree and assign FQNs; skip params already seen
        # (shared-parameter deduplication: first name wins).
        visited_params = set()
        for param_name, parameter in platform.parameters_dict(self.scheduler_ctx.root_module):
            if parameter in visited_params:
                continue
            visited_params.add(parameter)
            hsdp_param = param_to_hsdp_param.get(parameter)
            if hsdp_param is not None:
                hsdp_param._param_fqn = param_name  # pylint: disable=W0212

    # pylint: disable=W0613, R1710
    def _hsdp_forward_hook(self, cell, inputs, outputs):
        """Forward hook to shard parameters again after forward to save memory."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return
        self.scheduler_state = FSDPSchedulerState.FORWARD
        if self.reshard_after_forward:
            with self.platform.profiler_record(f"forward reshard:{self.hsdp_state.module_name}"):
                self.hsdp_state.shard(shard_replicate=False)
        if self.mp_policy.output_dtype is not None:
            outputs = self.platform.apply_to_tensors(
                functools.partial(self.platform.cast_fp_tensor, self.mp_policy.output_dtype),
                outputs,
            )
        return outputs

    # pylint: disable=W0613
    def _hsdp_backward_pre_hook(self, cell, grad_outputs):
        """Backward pre hook to unshard parameters for the backward pass."""
        self.scheduler_state = FSDPSchedulerState.PRE_BACKWARD
        for prefetch_cell in self.backward_prefetch_cells:
            with self.platform.profiler_record(f"pre_backward prefetch:"
                                               f"{prefetch_cell.hsdp_scheduler.hsdp_state.module_name}"):
                prefetch_cell.hsdp_scheduler.hsdp_state.prefetch(unshard_replicate=False)
        if self.reshard_after_forward:
            with self.platform.profiler_record(f"pre_backward unshard:{self.hsdp_state.module_name}"):
                self.hsdp_state.unshard(unshard_replicate=False)

    # pylint: disable=W0613
    def _hsdp_backward_hook(self, cell, grad_inputs, grad_outputs):
        """Backward hook to shard parameters for the optimizer step or to save memory."""
        self.scheduler_state = FSDPSchedulerState.BACKWARD
        with self.platform.profiler_record(f"post_backward:{self.hsdp_state.module_name}"):
            self.hsdp_state.post_backward()
        if self._fsdp_group_post_pending is not None:
            self._fsdp_group_post_pending.clear()
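        # Clearing the pending set here re-arms the grouped forward hooks for the next
        # iteration; _grouped_forward_pre_hook repopulates it once it is empty.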

    # pylint: disable=W0613
    def _grouped_forward_pre_hook_skip(self, cell, args, kwargs):
        """Return value when grouped pre-forward should not run (first module already did).

        Default matches MindSpore Cell forward pre-hooks (explicit ``(args, kwargs)``).
        ``TorchHSDPSchedulerV2`` overrides this to return ``None`` (``nn.Module`` idiom).
        """
        return args, kwargs

    def _grouped_forward_post_hook_skip(self, outputs):
        """Return value when grouped post-forward is deferred to a later module in the group.

        Default returns ``outputs`` (MindSpore). ``TorchHSDPSchedulerV2`` overrides to ``None``.
        """
        return outputs

    def _grouped_forward_pre_hook(self, cell, args, kwargs):
        """Run FSDP pre-forward only for the first module in the group (PyTorch FSDP2-aligned)."""
        pending = self._fsdp_group_post_pending
        if pending is None:
            return self._forward_pre_hook(cell, args, kwargs)
        if len(pending) == 0:
            pending.update(self.modules)
            return self._forward_pre_hook(cell, args, kwargs)
        return self._grouped_forward_pre_hook_skip(cell, args, kwargs)

    def _make_grouped_forward_post_hook(self, mod):
        """Build post-forward hook: last module in the group runs reshard + output backward hooks."""

        def grouped_post_hook(cell, inputs, outputs):
            pending = self._fsdp_group_post_pending
            if pending is None:
                return self._forward_hook(cell, inputs, outputs)
            pending.discard(mod)
            if len(pending) == 0:
                return self._forward_hook(cell, inputs, outputs)
            return self._grouped_forward_post_hook_skip(outputs)

        return grouped_post_hook
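    # Assumed wiring for a grouped unit such as fully_shard((mod_a, mod_b)):
    # _grouped_forward_pre_hook is registered on every module of the group and
    # _make_grouped_forward_post_hook(mod) builds each module's post hook, so unshard runs
    # once before the group's first forward and reshard/output hooks run once after its last.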

    def set_forward_prefetch_cells(self, hsdp_cell_list: List[Any]) -> None:
        """Set cells prefetched during forward.

        Args:
            hsdp_cell_list: HSDP cells to prefetch ahead of forward.
        """
        self.forward_prefetch_cells = hsdp_cell_list

    def set_backward_prefetch_cells(self, hsdp_cell_list: List[Any]) -> None:
        """Set cells prefetched during backward.

        Args:
            hsdp_cell_list: HSDP cells to prefetch ahead of backward.
        """
        self.backward_prefetch_cells = hsdp_cell_list

    def _disable_forward_prefetch_for_recompute(self) -> None:
        """Temporarily disable forward prefetch during activation recompute."""
        self._backup_forward_fetch = self.forward_prefetch_cells
        self.forward_prefetch_cells = []

    def _restore_forward_prefetch_after_recompute(self) -> bool:
        """Restore forward prefetch list after a recompute forward hook finishes."""
        if self._backup_forward_fetch is None:
            return False
        self.forward_prefetch_cells = self._backup_forward_fetch
        self._backup_forward_fetch = None
        return True
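
# -----------------------------------------------------------------------------
# Usage sketch (illustrative only; names other than HSDPSchedulerV2 and its three
# overridden methods are hypothetical): a platform-specific subclass implements
# the extension points and is then driven entirely by the registered hooks, e.g.
#
#     class MyHSDPScheduler(HSDPSchedulerV2):
#         def _init_platform(self):
#             self.platform = platform   # assumed: expose the platform adapter as self.platform
#         def _new_cell_state(self):
#             self.hsdp_state = ...      # assumed: create the per-cell HSDP state object
#         def _register_hooks(self):
#             ...                        # assumed: attach the _hsdp_*_hook methods to the cell
#
# ``set_forward_prefetch_cells`` / ``set_backward_prefetch_cells`` can then be used to
# overlap parameter all-gathers with compute.
# -----------------------------------------------------------------------------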