Coverage for hyper_parallel/core/hsdp/hsdp_scheduler.py: 90%
89 statements

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP scheduler."""
from enum import auto, Enum

from hyper_parallel.core.hsdp.hsdp_utils import HSDPConfig
from hyper_parallel.core.hsdp.hsdp_grad_hook import HSDPGradHook
from hyper_parallel.core.hsdp.hsdp_async_grad_hook import HSDPAsyncGradHook


class FSDPSchedulerState(Enum):
    """
    Scheduler state:
    - PRE_FORWARD:
        the pre-forward hook has already run.
    - FORWARD:
        the forward hook has already run.
    - PRE_BACKWARD:
        the pre-backward hook has already run.
    - BACKWARD:
        the backward hook has already run.
    """
    PRE_FORWARD = auto()
    FORWARD = auto()
    PRE_BACKWARD = auto()
    BACKWARD = auto()
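
# Expected state transitions in one normal training step, as driven by the
# hooks defined on HSDPScheduler below (a sketch derived from this file):
#
#   PRE_FORWARD -> FORWARD -> PRE_BACKWARD -> BACKWARD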


class HSDPScheduler:
    """HSDPScheduler implements optimizer-level parameter sharding for a cell."""

    def __init__(self, cell, shard_size, threshold, shard_level, requires_acc_grad, grad_scale, use_eager_hook,
                 reduce_dtype, comm_async, comm_fusion, bucket_size):
        """Initialize the HSDP scheduler."""
        self.cell = cell
        self.no_param_sharded = shard_size == 1
        self.shard_level = shard_level
        self.requires_acc_grad = requires_acc_grad
        self.requires_grad_sync = False
        self.scheduler_state = None

        self.forward_prefetch_cells = []
        self.backward_prefetch_cells = []
        self.config = HSDPConfig(
            shard_size,
            threshold,
            requires_acc_grad,
            grad_scale,
            shard_level,
            use_eager_hook,
            reduce_dtype,
            comm_async,
            comm_fusion,
            bucket_size
        )
        self._init_platform()
        self._new_cell_state()
        self._new_grad_hook()
        self._register_hooks()
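
    # Construction sketch (hypothetical subclass name and argument values; a
    # concrete platform subclass must provide _init_platform, _new_cell_state
    # and _register_forward_backward_hooks):
    #
    #   scheduler = MyHSDPScheduler(
    #       cell=block, shard_size=8, threshold=1 << 20, shard_level=1,
    #       requires_acc_grad=True, grad_scale=1.0, use_eager_hook=False,
    #       reduce_dtype=None, comm_async=True, comm_fusion=False,
    #       bucket_size=64 << 20,
    #   )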

    def _init_platform(self):
        """Initialize the platform."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _init_platform")

    def _new_cell_state(self):
        """Create a new cell state."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _new_cell_state")

    def _new_grad_hook(self):
        """Create a new grad hook."""
        if self.config.comm_async:
            self.grad_hook = HSDPAsyncGradHook(self.config, self.platform)
        else:
            self.grad_hook = HSDPGradHook(self.config, self.platform)

    def _register_hooks(self):
        """Register hooks."""
        self._register_grad_hook()
        if self.no_param_sharded:
            return
        self._register_forward_backward_hooks()

    def _register_grad_hook(self):
        """Register the parameter grad hooks."""
        for hsdp_param in self.hsdp_state.hsdp_params:
            if not hsdp_param.param.requires_grad:
                continue
            if self.config.grad_fusion:
                hsdp_param.param.register_hook(self._get_grad_buffer_hook(hsdp_param))
            else:
                hsdp_param.param.register_hook(self.grad_hook.get_hook(hsdp_param))

    def _register_forward_backward_hooks(self):
        """Register the module forward and backward hooks."""
        raise NotImplementedError("HSDPScheduler subclasses must implement _register_forward_backward_hooks.")
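
    # A platform subclass might wire the hooks roughly like this (a sketch
    # only; the register_* method names on `cell` are assumptions, not a
    # confirmed API):
    #
    #   def _register_forward_backward_hooks(self):
    #       self.cell.register_forward_pre_hook(self._hsdp_forward_pre_hook)
    #       self.cell.register_forward_hook(self._hsdp_forward_hook)
    #       self.cell.register_backward_pre_hook(self._hsdp_backward_pre_hook)
    #       if self.requires_acc_grad:
    #           self.cell.register_backward_hook(self._hsdp_acc_backward_hook)
    #       else:
    #           self.cell.register_backward_hook(self._hsdp_backward_hook)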

    def set_requires_grad_sync(self, requires_grad_sync):
        """Set the requires_grad_sync flag to control gradient synchronization."""
        self.requires_grad_sync = requires_grad_sync
        self.grad_hook.set_requires_grad_sync(requires_grad_sync)
        self.hsdp_state.set_requires_grad_sync(requires_grad_sync)

    def zero_grads(self):
        """Set accumulated gradients to zero."""
        if self.requires_acc_grad:
            self.hsdp_state.zero_grads()
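
    # Typical gradient-accumulation flow (an illustrative sketch, not a
    # confirmed training loop; `train_step`, `optimizer_step` and
    # `micro_batches` are hypothetical): keep grad sync off for all but the
    # last micro-batch, then clear the accumulated grads after the update.
    #
    #   for i, batch in enumerate(micro_batches):
    #       scheduler.set_requires_grad_sync(i == len(micro_batches) - 1)
    #       loss = train_step(batch)
    #   optimizer_step()
    #   scheduler.zero_grads()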

    # pylint: disable=W0613
    def _hsdp_forward_pre_hook(self, cell, inputs):
        """Forward pre hook: unshard parameters for the forward pass."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return
        self.scheduler_state = FSDPSchedulerState.PRE_FORWARD
        if len(inputs) > 0:
            self.platform.set_tensor_requires_grad(inputs[0])
        self.hsdp_state.unshard()
        for prefetch_cell in self.forward_prefetch_cells:
            prefetch_cell.hsdp_scheduler.hsdp_state.prefetch()

    # pylint: disable=W0613
    def _hsdp_forward_hook(self, cell, inputs, outputs):
        """Forward hook: reshard parameters to save memory."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return
        self.scheduler_state = FSDPSchedulerState.FORWARD
        self.hsdp_state.shard()

    # pylint: disable=W0613
    def _hsdp_backward_pre_hook(self, cell, grad_outputs):
        """Backward pre hook: unshard parameters for the backward pass."""
        self.scheduler_state = FSDPSchedulerState.PRE_BACKWARD
        self.hsdp_state.unshard()
        for prefetch_cell in self.backward_prefetch_cells:
            prefetch_cell.hsdp_scheduler.hsdp_state.prefetch()

    # pylint: disable=W0613
    def _hsdp_backward_hook(self, cell, grad_inputs, grad_outputs):
        """Backward hook: reshard parameters for the optimizer step and to save memory."""
        self.scheduler_state = FSDPSchedulerState.BACKWARD
        self.hsdp_state.shard()

    # pylint: disable=W0613
    def _hsdp_acc_backward_hook(self, cell, grad_inputs, grad_outputs):
        """Backward hook for grad accumulation: only reshard when requires_grad_sync is True."""
        self.scheduler_state = FSDPSchedulerState.BACKWARD
        if self.requires_grad_sync:
            self.hsdp_state.shard()

    def _get_grad_buffer_hook(self, hsdp_param):
        """Return a hook that stores the gradient and marks it ready in the HSDP state."""
        def hook(grad):
            hsdp_param.grad = grad
            self.hsdp_state.set_grad_ready(hsdp_param)
            return grad
        return hook

    def set_forward_prefetch_cells(self, hsdp_cell_list):
        """Set forward prefetch cells."""
        self.forward_prefetch_cells = hsdp_cell_list

    def set_backward_prefetch_cells(self, hsdp_cell_list):
        """Set backward prefetch cells."""
        self.backward_prefetch_cells = hsdp_cell_list
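

# Prefetch wiring sketch (`blocks` is a hypothetical ordered list of HSDP
# cells, each holding its scheduler as `hsdp_scheduler`): prefetching the
# adjacent block's parameters overlaps the all-gather with the current
# block's compute, in execution order for forward and in reverse for backward.
#
#   for prev, nxt in zip(blocks, blocks[1:]):
#       prev.hsdp_scheduler.set_forward_prefetch_cells([nxt])
#       nxt.hsdp_scheduler.set_backward_prefetch_cells([prev])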