# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Torch HSDP scheduler"""
16import torch
17from torch.autograd import Variable
18from torch.utils._pytree import tree_flatten, tree_unflatten
19from hyper_parallel.core.fully_shard.hsdp_scheduler import HSDPSchedulerV2, FSDPSchedulerState
20from hyper_parallel.platform.torch.fully_shard.hook_function import PostBackwardFunction
21from hyper_parallel.platform.torch.fully_shard.state import TorchHSDPStateV2
22from hyper_parallel.platform.torch.fully_shard.utils import FSDPMeshInfo, HSDPMeshInfo
23from hyper_parallel.platform import get_platform


class TorchHSDPSchedulerV2(HSDPSchedulerV2):
    """TorchHSDPSchedulerV2 implements optimizer-level scheduling on the Torch platform."""

    def _register_hooks(self):
        """Register hooks."""
        self._register_forward_backward_hooks()

    def _init_platform(self):
        """Initialize the platform."""
        from hyper_parallel.platform.torch.platform import TorchPlatform
        self.platform = get_platform()
        if not isinstance(self.platform, TorchPlatform):
            raise ValueError(f"TorchHSDPSchedulerV2 expects TorchPlatform, but got type: {type(self.platform)}")

    def _new_cell_state(self):
        """Create a new cell state for torch."""
        if self.mesh.ndim not in (1, 2):
            raise ValueError("fully_shard only supports 1D and 2D meshes.")
        if self.mesh.ndim == 1:
            # FSDP2: shard parameters along the single mesh dimension.
            self.mesh_info = FSDPMeshInfo(mesh=self.mesh, shard_mesh_dim=0)
        else:
            # HSDP: shard along mesh dim 1, replicate across mesh dim 0.
            self.mesh_info = HSDPMeshInfo(mesh=self.mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
        self.hsdp_state = TorchHSDPStateV2(self.cell, self.mesh_info, self.config, self.platform, self.device)
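
    # A minimal sketch (hypothetical shapes, not part of this module) of the
    # two mesh layouts accepted above, for 8 ranks:
    #
    #     1D mesh, shape (8,)   -> FSDP2: each parameter is sharded 8 ways.
    #     2D mesh, shape (2, 4) -> HSDP: parameters are sharded 4 ways along
    #                              dim 1 and replicated across the 2 groups
    #                              on dim 0.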

    def _new_grad_hook(self):
        """
        Create a new grad hook.

        Intentionally a no-op: this scheduler doesn't need a per-parameter grad
        hook because it reads gradients from ``param.grad`` directly.
        """
        # if self.config.comm_async:
        #     self.grad_hook = TorchHSDPAsyncGradHook(self.config, self.platform)
        # else:
        #     self.grad_hook = TorchHSDPGradHook(self.config, self.platform)

    def _register_post_backward_hook(self, args, kwargs):
        """Register the post-backward hook by routing all inputs through PostBackwardFunction."""
        args_list, args_spec = tree_flatten(args)
        kwargs_list, kwargs_spec = tree_flatten(kwargs)
        num_args = len(args_list)
        args_kwargs_list = list(args_list) + list(kwargs_list)
        args_kwargs_list = PostBackwardFunction.apply(self, *args_kwargs_list)
        args_list = args_kwargs_list[:num_args]
        kwargs_list = args_kwargs_list[num_args:]
        args = tree_unflatten(args_list, args_spec)
        kwargs = tree_unflatten(kwargs_list, kwargs_spec)
        return args, kwargs
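
    # A minimal sketch (hypothetical values) of the flatten/unflatten round
    # trip used above: ``tree_flatten`` turns a nested structure into a flat
    # leaf list plus a spec, and ``tree_unflatten`` reverses it exactly.
    #
    #     leaves, spec = tree_flatten((x, {"mask": m}))   # leaves == [x, m]
    #     tree_unflatten(leaves, spec)                    # == (x, {"mask": m})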

    def _forward_pre_hook(self, cell, args, kwargs):
        """Execute the forward pre-hook, then register the post-backward hook."""
        args, kwargs = self._hsdp_forward_pre_hook(cell, args, kwargs)
        return self._register_post_backward_hook(args, kwargs)

    def _register_backward_pre_hook(self, outputs):
        """Register gradient hooks on output tensors to trigger the backward pre-hook."""
        flat_outputs, _ = tree_flatten(outputs)
        for output in flat_outputs:
            if isinstance(output, torch.Tensor) and output.requires_grad:
                output.register_hook(self._backward_pre_hook)
        return outputs
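
    # ``Tensor.register_hook`` fires when the gradient with respect to that
    # tensor is computed, so hooking every grad-requiring output makes
    # ``_backward_pre_hook`` run at the very start of this cell's backward.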

    def _forward_hook(self, cell, inputs, outputs):
        """Execute forward hook."""
        self._register_backward_pre_hook(outputs)
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            # Returning None from a forward hook leaves `outputs` unchanged.
            return None
        outputs = self._hsdp_forward_hook(cell, inputs, outputs)
        return outputs

    # pylint: disable=W0212
    def _backward_pre_hook(self, grad):
        """Execute backward pre hook."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return grad
        self._hsdp_backward_pre_hook(self.cell, None)
        # Schedule _backward_hook to run once the current backward pass ends.
        Variable._execution_engine.queue_callback(self._backward_hook)
        return grad
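
    # Every grad-requiring output tensor is hooked in
    # ``_register_backward_pre_hook``, so this hook can fire several times per
    # backward pass; the ``scheduler_state`` check above makes every call
    # after the first a no-op.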

    def _backward_hook(self):
        """Execute backward hook."""
        if self.scheduler_state == FSDPSchedulerState.BACKWARD:
            return
        self._hsdp_backward_hook(self.cell, None, None)

    def _register_forward_backward_hooks(self):
        """Register module forward and backward hooks."""
        self.cell.register_forward_pre_hook(self._forward_pre_hook, with_kwargs=True)
        self.cell.register_forward_hook(self._forward_hook)
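
# A minimal sketch (hypothetical driver code, not part of this module) of the
# order in which the hooks above fire during one training step:
#
#     out = scheduler.cell(inputs)   # _forward_pre_hook runs first, then the
#                                    # module forward, then _forward_hook
#     out.sum().backward()           # _backward_pre_hook fires via the output
#                                    # grad hooks; _backward_hook runs at the
#                                    # end of backward via queue_callback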