Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/fully_shard/hsdp_scheduler.py: 44% (180 statements)
# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP scheduler"""
import functools
from typing import Any, List, Optional, Tuple, Union

from hyper_parallel.platform import get_platform
from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.fully_shard.hsdp_utils import (
    FSDPSchedulerState,
    HSDPConfigV2,
    get_managed_modules_parameters,
    get_hsdp_state
)

platform = get_platform()


class HSDPSchedulerContext:
    """HSDPSchedulerContext"""

    def __init__(self) -> None:
        # Currently only the is_last_backward flag is recorded in the scheduler context.
        self.is_last_backward: bool = True
        # Flag to identify the "root_module".
        self.root_module = None


class HSDPSchedulerV2:
    """HSDPScheduler is used to schedule HSDP."""
    root_bp_state = False

    def __init__(self, cell: Union[platform.Module, Tuple[platform.Module, ...]], mesh,
                 reshard_after_forward, shard_placement_fn,
                 mp_policy, offload_policy, ignored_params, replicate_params, device, comm_fusion,
                 comm_fusion_zero_copy=False):
        """Initialize the HSDP scheduler.

        Args:
            cell: A single platform.Module or tuple of platform.Module to manage as one FSDP unit.
        """
        self.modules = (cell,) if isinstance(cell, platform.Module) else tuple(cell)
        self.cell = self.modules[0]
        self.mesh: DeviceMesh = mesh
        self.reshard_after_forward = reshard_after_forward
        self.shard_placement_fn = shard_placement_fn
        self.mp_policy = mp_policy
        self.offload_policy = offload_policy
        self.ignored_params = ignored_params
        self.replicate_params = replicate_params
        self.device = device
        self.scheduler_state = None
        self.forward_prefetch_cells = []
        self.backward_prefetch_cells = []
        self._backup_forward_fetch = None
        # Flag to identify the root module.
        self._is_root = False
        # A module and all of its sub-modules share the same 'HSDPSchedulerContext'.
        self.scheduler_ctx = HSDPSchedulerContext()
        # When ``fully_shard`` is given multiple root modules, forward pre/post hooks coordinate
        # so unshard / PostBackward / reshard run once per forward (aligned with PyTorch FSDP2).
        self._fsdp_group_post_pending: Optional[set] = set() if len(self.modules) > 1 else None
        self.config = HSDPConfigV2(
            mesh,
            reshard_after_forward,
            shard_placement_fn,
            mp_policy,
            offload_policy,
            ignored_params,
            replicate_params,
            comm_fusion=comm_fusion,
            comm_fusion_zero_copy=comm_fusion_zero_copy,
        )
        self._init_platform()
        self._new_cell_state()
        self._register_hooks()
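
    # Illustrative note (not part of the original source): HSDPSchedulerV2 is an abstract base
    # whose platform hooks (_init_platform, _new_cell_state, _register_hooks) are supplied by
    # subclasses such as ``TorchHSDPSchedulerV2``. A minimal call sketch, assuming the subclass
    # keeps the base __init__ signature and that the policy/mesh objects come from the
    # surrounding fully_shard API, might look like:
    #
    #   scheduler = TorchHSDPSchedulerV2(
    #       cell=module,                    # or a tuple of modules treated as one FSDP unit
    #       mesh=device_mesh,
    #       reshard_after_forward=True,
    #       shard_placement_fn=None,
    #       mp_policy=mp_policy,
    #       offload_policy=offload_policy,
    #       ignored_params=set(),
    #       replicate_params=set(),
    #       device=device,
    #       comm_fusion=None,
    #   )
    #   scheduler.set_reshard_after_forward(True)
    #
    # The argument objects above (device_mesh, mp_policy, offload_policy, ...) are assumptions
    # used only for illustration.
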
    def _init_platform(self):
        """Initialize the platform."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _init_platform")

    def _new_cell_state(self):
        """Create a new cell state."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _new_cell_state")

    def _register_hooks(self):
        """Register hooks."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _register_hooks.")

    def _register_forward_backward_hooks(self):
        """Register module forward and backward hooks."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _register_forward_backward_hooks.")

    def _get_managed_params(self):
        """Return deduplicated parameters from all managed modules."""
        return get_managed_modules_parameters(self.modules, self.ignored_params)

    def set_reshard_after_forward(self, reshard_after_forward: bool) -> None:
        """Set the reshard_after_forward flag.

        Args:
            reshard_after_forward: Whether to reshard parameters after forward.
        """
        if not isinstance(reshard_after_forward, bool):
            raise ValueError(f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}")
        self.reshard_after_forward = reshard_after_forward
        self.config.reshard_after_forward = reshard_after_forward

    def set_reshard_after_backward(self, reshard_after_backward: bool) -> None:
        """Set the reshard_after_backward flag.

        Args:
            reshard_after_backward: Whether to reshard after backward completes.
        """
        if not isinstance(reshard_after_backward, bool):
            raise ValueError(f"reshard_after_backward should be a bool, got {type(reshard_after_backward)}")
        if self.hsdp_state is not None:
            self.hsdp_state.reshard_after_backward = reshard_after_backward

    def set_requires_all_reduce(self, requires_all_reduce: bool) -> None:
        """Set the requires_all_reduce flag.

        Args:
            requires_all_reduce: Whether this unit participates in all-reduce.
        """
        if not isinstance(requires_all_reduce, bool):
            raise ValueError(f"requires_all_reduce should be a bool, got {type(requires_all_reduce)}")
        if self.hsdp_state is not None:
            self.hsdp_state.requires_all_reduce = requires_all_reduce

    def set_requires_grad_sync(self, requires_grad_sync: bool) -> None:
        """Set the flag controlling whether gradients are synchronized.

        Args:
            requires_grad_sync: When True, enable grad sync for this scheduler.
        """
        if not isinstance(requires_grad_sync, bool):
            raise ValueError(f"requires_grad_sync should be a bool, got {type(requires_grad_sync)}")
        self.hsdp_state.set_requires_grad_sync(requires_grad_sync)
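
    # Illustrative note (not part of the original source): in FSDP-style training, grad sync is
    # commonly disabled for intermediate micro-batches during gradient accumulation and
    # re-enabled for the last one. A rough sketch, with all names assumed for illustration:
    #
    #   for i, batch in enumerate(micro_batches):
    #       scheduler.set_requires_grad_sync(i == len(micro_batches) - 1)
    #       loss = model(batch)
    #       loss.backward()
    #
    # Whether the surrounding fully_shard API drives this directly or through a no_sync-style
    # context manager is an assumption here.
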
    # pylint: disable=W0613
    def _hsdp_forward_pre_hook(self, cell, args, kwargs):
        """Forward pre hook to unshard parameters for the forward process."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return args, kwargs
        if HSDPSchedulerV2.root_bp_state:
            self._disable_forward_prefetch_for_recompute()
        if self.scheduler_ctx.root_module is None:
            self.scheduler_ctx.root_module = self.cell
            self._is_root = True
            for _, module in platform.get_cells_and_names(self.scheduler_ctx.root_module):
                from hyper_parallel.core.fully_shard.api import HSDPModule  # pylint: disable=C0415
                if isinstance(module, HSDPModule):
                    submod_scheduler = getattr(module, "hsdp_scheduler", None)
                    if submod_scheduler and submod_scheduler.scheduler_ctx is not self.scheduler_ctx:
                        submod_scheduler.scheduler_ctx = self.scheduler_ctx

        if not self._is_root and not self.hsdp_state.module_name:
            for module_name, module in platform.get_cells_and_names(self.scheduler_ctx.root_module):
                if module == self.cell:
                    self.hsdp_state.module_name = module_name
                    break
        self.scheduler_state = FSDPSchedulerState.PRE_FORWARD
        self._init_params_fqn()
        self._lazy_init_all_states()
        if self.mp_policy.cast_forward_inputs and self.mp_policy.param_dtype:
            cast_fn = functools.partial(self.platform.cast_fp_tensor, self.mp_policy.param_dtype)
            args = self.platform.apply_to_tensors(cast_fn, args)
            kwargs = self.platform.apply_to_tensors(cast_fn, kwargs)
        for prefetch_cell in self.forward_prefetch_cells:
            with self.platform.profiler_record(f"pre_forward prefetch:"
                                               f"{prefetch_cell.hsdp_scheduler.hsdp_state.module_name}"):
                prefetch_cell.hsdp_scheduler.hsdp_state.prefetch()
        with self.platform.profiler_record(f"pre_forward unshard:{self.hsdp_state.module_name}"):
            self.hsdp_state.unshard()
        return args, kwargs

    def _lazy_init_all_states(self):
        if self._is_root and self.scheduler_ctx.root_module is not None:
            for _, module in platform.get_cells_and_names(self.scheduler_ctx.root_module):
                hsdp_state = get_hsdp_state(module)
                if hsdp_state:
                    hsdp_state.lazy_init()

    def _init_params_fqn(self):  # pylint: disable=W0212
        if not self._is_root or self.scheduler_ctx.root_module is None:
            return
        # Build a map from original (sharded) parameter tensor → hsdp_param wrapper,
        # covering both sharded hsdp_params and replicate_params.
        param_to_hsdp_param = {}
        for _, module in platform.get_cells_and_names(self.scheduler_ctx.root_module):
            hsdp_state = get_hsdp_state(module)
            if hsdp_state is None:
                continue
            for hsdp_param in hsdp_state._iter_managed_params():  # pylint: disable=W0212
                orig_param = hsdp_param.sharded_param
                # Shared parameters: keep only the first mapping to preserve the
                # first-seen FQN (consistent with the deduplication in _init_hsdp_params).
                if orig_param not in param_to_hsdp_param:
                    param_to_hsdp_param[orig_param] = hsdp_param

        # Walk the full parameter tree and assign FQNs; skip params already seen
        # (shared-parameter deduplication: first name wins).
        visited_params = set()
        for param_name, parameter in platform.parameters_dict(self.scheduler_ctx.root_module):
            if parameter in visited_params:
                continue
            visited_params.add(parameter)
            hsdp_param = param_to_hsdp_param.get(parameter)
            if hsdp_param is not None:
                hsdp_param._param_fqn = param_name  # pylint: disable=W0212
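
    # Illustrative note (not part of the original source): the FQN assignment above means that
    # for a tied parameter, e.g. an embedding weight shared with an output projection, every
    # alias resolves to the single first-seen name (say "embedding.weight") rather than one FQN
    # per alias. The concrete module and parameter names are assumptions for illustration only.
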
    # pylint: disable=W0613, R1710
    def _hsdp_forward_hook(self, cell, inputs, outputs):
        """Forward hook to shard parameters for saving memory."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return
        self.scheduler_state = FSDPSchedulerState.FORWARD
        if self.reshard_after_forward:
            with self.platform.profiler_record(f"forward reshard:{self.hsdp_state.module_name}"):
                self.hsdp_state.shard(shard_replicate=False)
        if self.mp_policy.output_dtype is not None:
            outputs = self.platform.apply_to_tensors(
                functools.partial(self.platform.cast_fp_tensor, self.mp_policy.output_dtype),
                outputs,
            )
        return outputs

    # pylint: disable=W0613
    def _hsdp_backward_pre_hook(self, cell, grad_outputs):
        """Backward pre hook to unshard parameters for the backward process."""
        self.scheduler_state = FSDPSchedulerState.PRE_BACKWARD
        for prefetch_cell in self.backward_prefetch_cells:
            with self.platform.profiler_record(f"pre_backward prefetch:"
                                               f"{prefetch_cell.hsdp_scheduler.hsdp_state.module_name}"):
                prefetch_cell.hsdp_scheduler.hsdp_state.prefetch(unshard_replicate=False)
        if self.reshard_after_forward:
            with self.platform.profiler_record(f"pre_backward unshard:{self.hsdp_state.module_name}"):
                self.hsdp_state.unshard(unshard_replicate=False)

    # pylint: disable=W0613
    def _hsdp_backward_hook(self, cell, grad_inputs, grad_outputs):
        """Backward hook to shard parameters for the optimizer step or to save memory."""
        self.scheduler_state = FSDPSchedulerState.BACKWARD
        with self.platform.profiler_record(f"post_backward:{self.hsdp_state.module_name}"):
            self.hsdp_state.post_backward()
        if self._fsdp_group_post_pending is not None:
            self._fsdp_group_post_pending.clear()

    # pylint: disable=W0613
    def _grouped_forward_pre_hook_skip(self, cell, args, kwargs):
        """Return value when the grouped pre-forward should not run (the first module already did).

        The default matches MindSpore Cell forward pre-hooks (explicit ``(args, kwargs)``).
        ``TorchHSDPSchedulerV2`` overrides this to return ``None`` (``nn.Module`` idiom).
        """
        return args, kwargs

    def _grouped_forward_post_hook_skip(self, outputs):
        """Return value when the grouped post-forward is deferred to a later module in the group.

        The default returns ``outputs`` (MindSpore). ``TorchHSDPSchedulerV2`` overrides this to return ``None``.
        """
        return outputs

    def _grouped_forward_pre_hook(self, cell, args, kwargs):
        """Run FSDP pre-forward only for the first module in the group (PyTorch FSDP2-aligned)."""
        pending = self._fsdp_group_post_pending
        if pending is None:
            return self._forward_pre_hook(cell, args, kwargs)
        if len(pending) == 0:
            pending.update(self.modules)
            return self._forward_pre_hook(cell, args, kwargs)
        return self._grouped_forward_pre_hook_skip(cell, args, kwargs)

    def _make_grouped_forward_post_hook(self, mod):
        """Build a post-forward hook: the last module in the group runs reshard + output backward hooks."""

        def grouped_post_hook(cell, inputs, outputs):
            pending = self._fsdp_group_post_pending
            if pending is None:
                return self._forward_hook(cell, inputs, outputs)
            pending.discard(mod)
            if len(pending) == 0:
                return self._forward_hook(cell, inputs, outputs)
            return self._grouped_forward_post_hook_skip(outputs)

        return grouped_post_hook
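
    # Illustrative note (not part of the original source): for a group created from two modules
    # (placeholder names mod_a / mod_b) that run one after the other in a single forward pass,
    # the pending set above behaves roughly as:
    #
    #   pre(mod_a)  -> pending is empty     -> pending = {mod_a, mod_b}; unshard runs once
    #   post(mod_a) -> pending = {mod_b}    -> skipped (reshard deferred)
    #   pre(mod_b)  -> pending is non-empty -> skipped
    #   post(mod_b) -> pending is empty     -> reshard / post-forward logic runs once
    #
    # so the whole group is unsharded and resharded once per forward, matching the
    # multi-root-module behaviour described in __init__.
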
    def set_forward_prefetch_cells(self, hsdp_cell_list: List[Any]) -> None:
        """Set cells prefetched during forward.

        Args:
            hsdp_cell_list: HSDP cells to prefetch ahead of forward.
        """
        self.forward_prefetch_cells = hsdp_cell_list

    def set_backward_prefetch_cells(self, hsdp_cell_list: List[Any]) -> None:
        """Set cells prefetched during backward.

        Args:
            hsdp_cell_list: HSDP cells to prefetch ahead of backward.
        """
        self.backward_prefetch_cells = hsdp_cell_list
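
    # Illustrative note (not part of the original source): explicit prefetch scheduling typically
    # points each layer at the neighbour whose parameters should start unsharding early. A rough
    # sketch for a list of transformer blocks (variable names are assumptions):
    #
    #   for i, block in enumerate(blocks[:-1]):
    #       block.hsdp_scheduler.set_forward_prefetch_cells([blocks[i + 1]])
    #   for i, block in enumerate(blocks[1:], start=1):
    #       block.hsdp_scheduler.set_backward_prefetch_cells([blocks[i - 1]])
    #
    # The forward / backward pre hooks above then call prefetch() on each listed cell's
    # hsdp_state before unsharding their own parameters.
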
    def _disable_forward_prefetch_for_recompute(self) -> None:
        """Temporarily disable forward prefetch during activation recompute."""
        self._backup_forward_fetch = self.forward_prefetch_cells
        self.forward_prefetch_cells = []

    def _restore_forward_prefetch_after_recompute(self) -> bool:
        """Restore forward prefetch list after a recompute forward hook finishes."""
        if self._backup_forward_fetch is None:
            return False
        self.forward_prefetch_cells = self._backup_forward_fetch
        self._backup_forward_fetch = None
        return True