Coverage for hyper_parallel/core/fully_shard/hsdp_scheduler.py: 69%
115 statements
coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP scheduler."""
import functools

from hyper_parallel.core.device_mesh import DeviceMesh
from hyper_parallel.core.fully_shard.hsdp_utils import HSDPConfigV2, FSDPSchedulerState
from hyper_parallel.core.fully_shard.hsdp_grad_hook import HSDPGradHook
from hyper_parallel.core.fully_shard.hsdp_async_grad_hook import HSDPAsyncGradHook


class HSDPSchedulerContext:
    """Per-scheduler context tracking backward-pass bookkeeping."""

    def __init__(self) -> None:
        self.post_backward_final_callback_queued: bool = False
        self.is_last_backward: bool = True
        self.post_optim_event = None


class HSDPSchedulerV2:
    """HSDPSchedulerV2 schedules HSDP sharding and gradient hooks for a cell."""

    def __init__(self, cell, mesh, reshard_after_forward, shard_placement_fn,
                 mp_policy, offload_policy, ignored_params, device):
        """Init the HSDP scheduler."""
        self.cell = cell
        self.mesh: DeviceMesh = mesh
        self.reshard_after_forward = reshard_after_forward
        self.shard_placement_fn = shard_placement_fn
        self.mp_policy = mp_policy
        self.offload_policy = offload_policy
        self.ignored_params = ignored_params
        self.device = device
        self.scheduler_state = None
        self.forward_prefetch_cells = []
        self.backward_prefetch_cells = []
        self.scheduler_ctx = HSDPSchedulerContext()
        self.config = HSDPConfigV2(
            mesh,
            reshard_after_forward,
            shard_placement_fn,
            mp_policy,
            offload_policy,
            ignored_params
        )
        # Subclass hooks: set up self.platform and self.hsdp_state, then
        # register the gradient and forward/backward hooks on the cell.
        self._init_platform()
        self._new_cell_state()
        self._new_grad_hook()
        self._register_hooks()

    def _init_platform(self):
        """Initialize the platform; subclasses must set ``self.platform``."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _init_platform.")

    def _new_cell_state(self):
        """Create a new cell state; subclasses must set ``self.hsdp_state``."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _new_cell_state.")

    def _new_grad_hook(self):
        """Create a new grad hook, async or sync depending on the config."""
        if self.config.comm_async:
            self.grad_hook = HSDPAsyncGradHook(self.config, self.platform)
        else:
            self.grad_hook = HSDPGradHook(self.config, self.platform)

    def _register_hooks(self):
        """Register hooks."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _register_hooks.")

    def _register_grad_hook(self):
        """Register a grad hook on each parameter that requires grad."""
        for hsdp_param in self.hsdp_state.hsdp_params:
            if not hsdp_param.param.requires_grad:
                continue
            if self.config.grad_fusion:
                hsdp_param.param.register_hook(self._get_grad_buffer_hook(hsdp_param))
            else:
                hsdp_param.param.register_hook(self.grad_hook.get_hook(hsdp_param))
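
    # With grad_fusion enabled, the hook only records the grad into the fused
    # buffer and marks it ready (see _get_grad_buffer_hook below); otherwise
    # HSDPGradHook / HSDPAsyncGradHook handles each parameter's gradient.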

    def _register_forward_backward_hooks(self):
        """Register the cell's forward and backward hooks."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _register_forward_backward_hooks.")
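
    # A minimal subclass sketch showing what the abstract methods above are
    # expected to provide. ``MyPlatform`` and ``MyCellState`` are hypothetical
    # stand-ins for a concrete backend, not names defined in this module:
    #
    #     class MyHSDPScheduler(HSDPSchedulerV2):
    #         def _init_platform(self):
    #             self.platform = MyPlatform(self.device)
    #
    #         def _new_cell_state(self):
    #             self.hsdp_state = MyCellState(self.cell, self.config, self.platform)
    #
    #         def _register_hooks(self):
    #             self._register_grad_hook()
    #             self._register_forward_backward_hooks()
    #
    #         def _register_forward_backward_hooks(self):
    #             # Wire _hsdp_forward_pre_hook / _hsdp_forward_hook and the
    #             # backward hooks onto self.cell via the backend's hook API.
    #             ...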

    def set_reshard_after_forward(self, reshard_after_forward: bool):
        """Set the reshard_after_forward flag."""
        if not isinstance(reshard_after_forward, bool):
            raise ValueError(f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}")
        self.reshard_after_forward = reshard_after_forward
        self.config.reshard_after_forward = reshard_after_forward

    def set_reshard_after_backward(self, reshard_after_backward: bool):
        """Set the reshard_after_backward flag."""
        if not isinstance(reshard_after_backward, bool):
            raise ValueError(f"reshard_after_backward should be a bool, got {type(reshard_after_backward)}")
        if self.hsdp_state is not None:
            self.hsdp_state.reshard_after_backward = reshard_after_backward

    def set_requires_all_reduce(self, requires_all_reduce: bool):
        """Set the requires_all_reduce flag."""
        if not isinstance(requires_all_reduce, bool):
            raise ValueError(f"requires_all_reduce should be a bool, got {type(requires_all_reduce)}")
        if self.hsdp_state is not None:
            self.hsdp_state.all_reduce_grads = requires_all_reduce

    def set_requires_grad_sync(self, requires_grad_sync: bool):
        """Set the requires_grad_sync flag to control gradient synchronization."""
        if not isinstance(requires_grad_sync, bool):
            raise ValueError(f"requires_grad_sync should be a bool, got {type(requires_grad_sync)}")
        self.requires_grad_sync = requires_grad_sync
        self.hsdp_state.set_requires_grad_sync(requires_grad_sync)
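
    # Hedged usage sketch for gradient accumulation (``scheduler`` and
    # ``is_last_micro_batch`` are illustrative names, not part of this
    # module): disable grad sync on all but the last micro-batch, e.g.
    #
    #     scheduler.set_requires_grad_sync(is_last_micro_batch)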

    def zero_grads(self):
        """Zero gradients when gradient accumulation is enabled."""
        if self.requires_acc_grad:
            self.hsdp_state.zero_grads()

    # pylint: disable=W0613
    def _hsdp_forward_pre_hook(self, cell, args, kwargs):
        """Forward pre-hook: cast inputs and unshard parameters before forward."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            # Forward re-entered during backward (e.g. recompute); parameters
            # are already unsharded, so leave the state untouched.
            return args, kwargs
        self.scheduler_state = FSDPSchedulerState.PRE_FORWARD
        if self.mp_policy.cast_forward_inputs and self.mp_policy.param_dtype:
            cast_fn = functools.partial(self.platform.cast_fp_tensor, self.mp_policy.param_dtype)
            args = self.platform.apply_to_tensors(cast_fn, args)
            kwargs = self.platform.apply_to_tensors(cast_fn, kwargs)
        self.hsdp_state.unshard()
        for prefetch_cell in self.forward_prefetch_cells:
            prefetch_cell.hsdp_scheduler.hsdp_state.prefetch()
        return args, kwargs

    # pylint: disable=W0613
    def _hsdp_forward_hook(self, cell, inputs, outputs):
        """Forward hook: reshard parameters after forward to save memory."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return
        self.scheduler_state = FSDPSchedulerState.FORWARD
        if self.reshard_after_forward:
            self.hsdp_state.shard()
        if self.mp_policy.output_dtype is not None:
            outputs = self.platform.apply_to_tensors(
                functools.partial(self.platform.cast_fp_tensor, self.mp_policy.output_dtype),
                outputs,
            )
        return outputs

    # pylint: disable=W0613
    def _hsdp_backward_pre_hook(self, cell, grad_outputs):
        """Backward pre-hook: unshard parameters before the backward pass."""
        self.scheduler_state = FSDPSchedulerState.PRE_BACKWARD
        if self.reshard_after_forward:
            self.hsdp_state.unshard()
        for prefetch_cell in self.backward_prefetch_cells:
            prefetch_cell.hsdp_scheduler.hsdp_state.prefetch()

    # pylint: disable=W0613
    def _hsdp_backward_hook(self, cell, grad_inputs, grad_outputs):
        """Backward hook: shard parameters for the optimizer step or to save memory."""
        self.scheduler_state = FSDPSchedulerState.BACKWARD
        self.hsdp_state.post_backward()

    def _get_grad_buffer_hook(self, hsdp_param):
        """Return a hook that stores the grad and marks it ready."""

        def hook(grad):
            hsdp_param.grad = grad
            self.hsdp_state.set_grad_ready(hsdp_param)
            return grad

        return hook

    def set_forward_prefetch_cells(self, hsdp_cell_list):
        """Set forward prefetch cells."""
        self.forward_prefetch_cells = hsdp_cell_list

    def set_backward_prefetch_cells(self, hsdp_cell_list):
        """Set backward prefetch cells."""
        self.backward_prefetch_cells = hsdp_cell_list

    def set_requires_allreuce(self, requires_all_reduce):
        """Set requires_all_reduce for HSDP."""
        self.hsdp_state.requires_all_reduce = requires_all_reduce

    def reshard(self):
        """Reshard parameters after forward or backward."""
        self.hsdp_state.reshard()
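
# Hook ordering per training step, as implemented above:
#   1. _hsdp_forward_pre_hook  -- unshard params, cast inputs, prefetch forward cells
#   2. _hsdp_forward_hook      -- reshard (if reshard_after_forward), cast outputs
#   3. _hsdp_backward_pre_hook -- unshard params, prefetch backward cells
#   4. _hsdp_backward_hook     -- post-backward processing (shard for optimizer / memory)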