Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/fully_shard/scheduler.py: 42% (122 statements)
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Torch HSDP scheduler"""
16import inspect
17import torch
18from typing import List
19from torch.autograd import Variable
20from torch.utils._pytree import tree_flatten, tree_unflatten
21from hyper_parallel.core.dtensor.dtensor import DTensor
22from hyper_parallel.core.fully_shard.hsdp_scheduler import HSDPSchedulerV2, FSDPSchedulerState
23from hyper_parallel.core.fully_shard.utils import FSDPMeshInfo, DDPMeshInfo, HSDPMeshInfo
24from hyper_parallel.platform.torch.fully_shard.hook_function import PostBackwardFunction
25from hyper_parallel.platform.torch.fully_shard.state import TorchHSDPStateV2
26from hyper_parallel.platform.torch.fully_shard.param_group import get_comm_ctx
27from hyper_parallel.platform import get_platform


class TorchHSDPSchedulerV2(HSDPSchedulerV2):
    """Torch implementation of the optimizer-level HSDP scheduler."""

    def __init__(self, *args, **kwargs):
        """Initialize TorchHSDPSchedulerV2 and register forward/backward hooks."""
        super().__init__(*args, **kwargs)

    def _register_hooks(self):
        """Register forward and backward hooks."""
        self._register_forward_backward_hooks()

    def _init_platform(self):
        """Initialize the platform."""
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.platform import TorchPlatform
        self.platform = get_platform()
        if not isinstance(self.platform, TorchPlatform):
            raise ValueError(f"TorchHSDPSchedulerV2 expects a TorchPlatform, but got type: {type(self.platform)}")

    def _new_cell_state(self):
        """Create a new cell state for torch."""
        params = self._get_managed_params()
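        # No explicit mesh was given: fall back to a DDP-style compatibility mesh derived
        # from the DTensor parameters (they must all share the same device mesh).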
        if self.mesh is None:
            compat_meshes = [
                param.device_mesh for param in params if isinstance(param, DTensor)
            ]
            compat_mesh = compat_meshes[0] if compat_meshes else None
            if compat_mesh is None:
                raise ValueError(
                    "Cannot build fully_shard compatibility mesh_info "
                    "without a DTensor parameter mesh."
                )
            compat_mesh_hash = compat_mesh.to_hash()
            for param_mesh in compat_meshes[1:]:
                if param_mesh.to_hash() != compat_mesh_hash:
                    raise ValueError(
                        "fully_shard compatibility mode requires all DTensor parameters to share the same mesh."
                    )
            self.mesh_info = DDPMeshInfo(mesh=compat_mesh, replicate_mesh_dim=0)
        elif self.mesh.ndim == 1:
            self.mesh_info = FSDPMeshInfo(mesh=self.mesh, shard_mesh_dim=0)
        elif self.mesh.ndim == 2:
            self.mesh_info = HSDPMeshInfo(mesh=self.mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
        else:
            raise ValueError(
                "fully_shard only supports explicit 1D DP/FSDP meshes or 2D HSDP meshes. "
                f"Got mesh.ndim={self.mesh.ndim}."
            )
        self.hsdp_state = TorchHSDPStateV2(
            self.modules, self.mesh_info, self.config, self.platform, self.device
        )

    def _register_post_backward_hook(self, args, kwargs):
        """Wrap forward args/kwargs through PostBackwardFunction to register backward hook."""
        if not torch.is_grad_enabled():
            return args, kwargs
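        # Flatten args/kwargs so every tensor input that requires grad can be handled
        # uniformly; the original pytree structures are rebuilt below.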
        args_list, args_spec = tree_flatten(args)
        kwargs_list, kwargs_spec = tree_flatten(kwargs)
        args_kwargs_list = list(args_list) + list(kwargs_list)
        inp_tensor_indices: List[int] = []
        inp_tensors: List[torch.Tensor] = []
        for i, obj in enumerate(args_kwargs_list):
            if torch.is_tensor(obj) and obj.requires_grad:
                inp_tensor_indices.append(i)
                inp_tensors.append(obj)
        if len(inp_tensors) == 0:
            return args, kwargs  # no tensors that require gradients
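        # Routing the inputs through PostBackwardFunction registers the post-backward hook:
        # its backward() runs once gradients have flowed back past these inputs.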
        processed_tensors = PostBackwardFunction.apply(self, *inp_tensors)
        for inp_tensor_idx, processed_tensor in zip(inp_tensor_indices, processed_tensors):
            args_kwargs_list[inp_tensor_idx] = processed_tensor
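        # Split the processed flat list back into the args/kwargs partitions and restore
        # the original structures.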
        args_list = args_kwargs_list[: len(args_list)]
        kwargs_list = args_kwargs_list[len(args_list) :]
        args = tree_unflatten(args_list, args_spec)
        kwargs = tree_unflatten(kwargs_list, kwargs_spec)
        return args, kwargs

    def _forward_pre_hook(self, cell, args, kwargs):
        """Execute forward pre hook and set up backward hook."""
        args, kwargs = self._hsdp_forward_pre_hook(cell, args, kwargs)
        return self._register_post_backward_hook(args, kwargs)

    def _register_backward_pre_hook(self, outputs):
        """Register gradient hooks on all requires-grad outputs to trigger backward pre hook."""
        flat_outputs, _ = tree_flatten(outputs)
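        # Tensor.register_hook fires when the gradient w.r.t. that output is computed,
        # i.e. right as backward reaches this module's subgraph.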
        for output in flat_outputs:
            if isinstance(output, torch.Tensor) and output.requires_grad:
                output.register_hook(self._backward_pre_hook)
        return outputs

    def _forward_hook(self, cell, inputs, outputs):  # pylint: disable=R1710
        """Execute forward hook."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return
        self._register_backward_pre_hook(outputs)
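        # root_bp_state is set once backward has begun, so a forward seen here is most
        # likely an activation-recompute pass: only restore the forward-prefetch bookkeeping.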
        if HSDPSchedulerV2.root_bp_state:
            self._restore_forward_prefetch_after_recompute()
            return
        return self._hsdp_forward_hook(cell, inputs, outputs)

    # pylint: disable=W0212
    def _backward_pre_hook(self, grad):
        """Execute backward pre hook."""
        Variable._execution_engine.queue_callback(self._root_backward_hook)
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return grad
        HSDPSchedulerV2.root_bp_state = True
        self._hsdp_backward_pre_hook(self.cell, None)
        return grad

    def _root_backward_hook(self):
        """Root backward hook: finalize gradient reduction for the outermost HSDP module.

        For the root module (the last to finish backward), this hook drains any
        pending fused reduction from ``CommContext`` and then calls ``reduce_params()``
        to apply the final per-parameter gradient reduction.
        """
        apply_final_reduce = self.scheduler_state != FSDPSchedulerState.BACKWARD
        self._backward_hook()
        if apply_final_reduce:
            HSDPSchedulerV2.root_bp_state = False
            with torch.profiler.record_function(f"root_backward reduce:{self.hsdp_state.module_name}"):
                # Drain any pending async fused reduction from the last module's backward
                comm_ctx = get_comm_ctx()
                # Drain any pending pipelined HSDP reductions
                if comm_ctx.all_reduce_param_group is not None:
                    comm_ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
                    comm_ctx.all_reduce_param_group = None
                if comm_ctx.pre_param_group is not None:
                    comm_ctx.pre_param_group.apply_fusion_reduced_grad()
                    comm_ctx.pre_param_group = None
                self.hsdp_state.reduce_params()

    def _backward_hook(self):
        """Execute backward hook."""
        if self.scheduler_state == FSDPSchedulerState.BACKWARD:
            return
        self._hsdp_backward_hook(self.cell, None, None)

    # pylint: disable=W0613
    def _grouped_forward_pre_hook_skip(self, cell, args, kwargs) -> None:
        """Override base ``(args, kwargs)`` return; ``nn.Module`` pre-hook uses ``None`` for no-op."""
        return None

    def _grouped_forward_post_hook_skip(self, outputs) -> None:
        """Override base output pass-through; forward hook uses ``None`` for no-op."""
        return None

    def _register_forward_module_hook(self, mod, hook) -> None:
        """Register forward hook; use ``always_call=True`` when supported (matches PyTorch FSDP)."""
        sig = inspect.signature(mod.register_forward_hook)
        if "always_call" in sig.parameters:
            mod.register_forward_hook(hook, prepend=False, always_call=True)
        else:
            mod.register_forward_hook(hook, prepend=False)

    def _register_forward_backward_hooks(self):
        """Register module forward and backward hooks on all managed modules."""
        if self._fsdp_group_post_pending is None:
            for mod in self.modules:
                mod.register_forward_pre_hook(self._forward_pre_hook, with_kwargs=True)
                mod.register_forward_hook(self._forward_hook)
            return
        for mod in self.modules:
            mod.register_forward_pre_hook(self._grouped_forward_pre_hook, with_kwargs=True)
            self._register_forward_module_hook(mod, self._make_grouped_forward_post_hook(mod))