# Coverage report residue: hyper_parallel/platform/torch/hsdp/scheduler.py — 82% of 60 statements (coverage.py v7.13.1, created 2026-03-01 07:33 +0800)
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Torch HSDP scheduler"""
16import torch
17from torch.autograd import Variable
18from torch.utils._pytree import tree_flatten, tree_unflatten
19from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel
20from hyper_parallel.core.hsdp.hsdp_scheduler import HSDPScheduler, FSDPSchedulerState
21from hyper_parallel.platform.torch.hsdp.hook_function import PostBackwardFunction
22from hyper_parallel.platform.torch.hsdp.state import TorchHSDPState
23from hyper_parallel.platform.torch.hsdp.grad_hook import TorchHSDPGradHook
24from hyper_parallel.platform.torch.hsdp.async_grad_hook import TorchHSDPAsyncGradHook
25from hyper_parallel.platform import get_platform
class TorchHSDPScheduler(HSDPScheduler):
    """Torch-specific HSDP scheduler implementing the optimizer sharding levels."""

    def _init_platform(self):
        """Resolve and cache the active platform backend."""
        self.platform = get_platform()

    def _new_cell_state(self):
        """Build the torch-flavoured HSDP state object for the wrapped cell."""
        self.hsdp_state = TorchHSDPState(self.cell, self.config, self.platform)

    def _new_grad_hook(self):
        """
        Instantiate the gradient-hook component.

        Selects the asynchronous variant when communication overlap is
        enabled in the configuration, otherwise the synchronous one.
        """
        hook_cls = TorchHSDPAsyncGradHook if self.config.comm_async else TorchHSDPGradHook
        self.grad_hook = hook_cls(self.config, self.platform)

    def _register_backward_hook(self, args, kwargs):
        """Thread the flattened forward inputs through PostBackwardFunction
        so its backward pass fires during autograd, then rebuild the
        original args/kwargs pytree structure from the returned leaves."""
        flat_args, args_spec = tree_flatten(args)
        flat_kwargs, kwargs_spec = tree_flatten(kwargs)
        n_args = len(flat_args)
        combined = PostBackwardFunction.apply(self, *flat_args, *flat_kwargs)
        args = tree_unflatten(combined[:n_args], args_spec)
        kwargs = tree_unflatten(combined[n_args:], kwargs_spec)
        return args, kwargs

    def _forward_pre_hook(self, cell, args, kwargs):
        """Run the HSDP pre-forward logic, then wire up the backward hook."""
        self._hsdp_forward_pre_hook(cell, args)
        return self._register_backward_hook(args, kwargs)

    def _register_backward_pre_hook(self, outputs):
        """Attach a tensor hook to every grad-requiring output leaf so the
        backward pre-hook triggers when autograd reaches this cell."""
        leaves, _ = tree_flatten(outputs)
        for leaf in leaves:
            if isinstance(leaf, torch.Tensor) and leaf.requires_grad:
                leaf.register_hook(self._backward_pre_hook)
        return outputs

    def _forward_hook(self, cell, inputs, outputs):
        """Post-forward hook: register backward triggers and, at the
        SHARD_OPT_GRAD_PARAM level (unless already in PRE_BACKWARD),
        run the HSDP forward hook."""
        self._register_backward_pre_hook(outputs)
        should_run = (
            self.shard_level == OptimizerLevel.SHARD_OPT_GRAD_PARAM
            and self.scheduler_state != FSDPSchedulerState.PRE_BACKWARD
        )
        if should_run:
            self._hsdp_forward_hook(cell, inputs, outputs)

    # pylint: disable=W0212
    def _backward_pre_hook(self, grad):
        """Tensor hook fired before backward reaches this cell; queues the
        post-backward callback on the autograd execution engine. Always
        returns the gradient unchanged."""
        if self.scheduler_state != FSDPSchedulerState.PRE_BACKWARD:
            self._hsdp_backward_pre_hook(self.cell, None)
            Variable._execution_engine.queue_callback(self._backward_hook)
        return grad

    def _backward_hook(self):
        """Execution-engine callback run after backward for this cell."""
        if self.scheduler_state == FSDPSchedulerState.BACKWARD:
            return
        use_acc = (
            self.requires_acc_grad
            and self.shard_level != OptimizerLevel.SHARD_OPT_GRAD_PARAM
        )
        if use_acc:
            self._hsdp_acc_backward_hook(self.cell, None, None)
        else:
            self._hsdp_backward_hook(self.cell, None, None)

    def _register_forward_backward_hooks(self):
        """Install the module-level forward pre/post hooks on the cell."""
        self.cell.register_forward_pre_hook(self._forward_pre_hook, with_kwargs=True)
        self.cell.register_forward_hook(self._forward_hook)