# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Torch HSDP scheduler."""
import inspect
import torch
from typing import List
from torch.autograd import Variable
from torch.utils._pytree import tree_flatten, tree_unflatten
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.core.fully_shard.hsdp_scheduler import HSDPSchedulerV2, FSDPSchedulerState
from hyper_parallel.core.fully_shard.utils import FSDPMeshInfo, DDPMeshInfo, HSDPMeshInfo
from hyper_parallel.platform.torch.fully_shard.hook_function import PostBackwardFunction
from hyper_parallel.platform.torch.fully_shard.state import TorchHSDPStateV2
from hyper_parallel.platform.torch.fully_shard.param_group import get_comm_ctx
from hyper_parallel.platform import get_platform


class TorchHSDPSchedulerV2(HSDPSchedulerV2):
    """TorchHSDPSchedulerV2 implements the optimizer-level HSDP scheduler for torch."""

    def __init__(self, *args, **kwargs):
        """Initialize TorchHSDPSchedulerV2 and register forward/backward hooks."""
        super().__init__(*args, **kwargs)

    def _register_hooks(self):
        """Register hooks."""
        self._register_forward_backward_hooks()

    def _init_platform(self):
        """Initialize the platform."""
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.platform import TorchPlatform
        self.platform = get_platform()
        if not isinstance(self.platform, TorchPlatform):
            raise ValueError(f"TorchHSDPSchedulerV2 expects TorchPlatform, but got type: {type(self.platform)}")

    def _new_cell_state(self):
        """Create a new cell state for torch."""
        params = self._get_managed_params()
        if self.mesh is None:
            # Compatibility mode: derive the mesh from the DTensor parameters.
            compat_meshes = [
                param.device_mesh for param in params if isinstance(param, DTensor)
            ]
            compat_mesh = compat_meshes[0] if compat_meshes else None
            if compat_mesh is None:
                raise ValueError(
                    "Cannot build fully_shard compatibility mesh_info "
                    "without a DTensor parameter mesh."
                )
            compat_mesh_hash = compat_mesh.to_hash()
            for param_mesh in compat_meshes[1:]:
                if param_mesh.to_hash() != compat_mesh_hash:
                    raise ValueError(
                        "fully_shard compatibility mode requires all DTensor parameters to share the same mesh."
                    )
            self.mesh_info = DDPMeshInfo(mesh=compat_mesh, replicate_mesh_dim=0)
        elif self.mesh.ndim == 1:
            self.mesh_info = FSDPMeshInfo(mesh=self.mesh, shard_mesh_dim=0)
        elif self.mesh.ndim == 2:
            self.mesh_info = HSDPMeshInfo(mesh=self.mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
        else:
            raise ValueError(
                "fully_shard only supports explicit 1D DP/FSDP meshes or 2D HSDP meshes. "
                f"Got mesh.ndim={self.mesh.ndim}."
            )
        self.hsdp_state = TorchHSDPStateV2(
            self.modules, self.mesh_info, self.config, self.platform, self.device
        )
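
    # Illustrative mapping (comment-only sketch; ``init_device_mesh`` is the torch-style
    # helper and may differ from how hyper_parallel builds its own meshes):
    #
    #   from torch.distributed.device_mesh import init_device_mesh
    #   mesh_1d = init_device_mesh("cuda", (8,))    # ndim == 1 -> FSDPMeshInfo (pure sharding)
    #   mesh_2d = init_device_mesh("cuda", (2, 4))  # ndim == 2 -> HSDPMeshInfo
    #                                               #   (dim 0 replicates, dim 1 shards)
    #   # mesh=None -> DDPMeshInfo derived from the shared mesh of the DTensor parameters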


    def _register_post_backward_hook(self, args, kwargs):
        """Wrap forward args/kwargs through PostBackwardFunction to register backward hook."""
        if not torch.is_grad_enabled():
            return args, kwargs
        args_list, args_spec = tree_flatten(args)
        kwargs_list, kwargs_spec = tree_flatten(kwargs)
        args_kwargs_list = list(args_list) + list(kwargs_list)
        inp_tensor_indices: List[int] = []
        inp_tensors: List[torch.Tensor] = []
        for i, obj in enumerate(args_kwargs_list):
            if torch.is_tensor(obj) and obj.requires_grad:
                inp_tensor_indices.append(i)
                inp_tensors.append(obj)
        if len(inp_tensors) == 0:
            return args, kwargs  # no tensors that require gradients
        processed_tensors = PostBackwardFunction.apply(self, *inp_tensors)
        for inp_tensor_idx, processed_tensor in zip(inp_tensor_indices, processed_tensors):
            args_kwargs_list[inp_tensor_idx] = processed_tensor
        args_list = args_kwargs_list[: len(args_list)]
        kwargs_list = args_kwargs_list[len(args_list) :]
        args = tree_unflatten(args_list, args_spec)
        kwargs = tree_unflatten(kwargs_list, kwargs_spec)
        return args, kwargs
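
    # Comment-only sketch: the tree_flatten/tree_unflatten round-trip above preserves the
    # original (args, kwargs) structure, so the wrapped tensors can be spliced back by index:
    #
    #   from torch.utils._pytree import tree_flatten, tree_unflatten
    #   leaves, spec = tree_flatten({"x": (1, 2), "y": [3]})
    #   assert leaves == [1, 2, 3]
    #   assert tree_unflatten(leaves, spec) == {"x": (1, 2), "y": [3]}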


    def _forward_pre_hook(self, cell, args, kwargs):
        """Execute forward pre hook and set up backward hook."""
        args, kwargs = self._hsdp_forward_pre_hook(cell, args, kwargs)
        return self._register_post_backward_hook(args, kwargs)
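
    # Note: with register_forward_pre_hook(..., with_kwargs=True), a hook that returns an
    # (args, kwargs) tuple replaces the module's inputs, which is how the tensors wrapped
    # by PostBackwardFunction above reach the module's forward.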


    def _register_backward_pre_hook(self, outputs):
        """Register gradient hooks on all requires-grad outputs to trigger backward pre hook."""
        flat_outputs, _ = tree_flatten(outputs)
        for output in flat_outputs:
            if isinstance(output, torch.Tensor) and output.requires_grad:
                output.register_hook(self._backward_pre_hook)
        return outputs
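
    # Comment-only sketch: Tensor.register_hook fires when the gradient w.r.t. that tensor
    # is computed, i.e. at the start of backward through the output, which is what makes
    # it usable as a "backward pre" trigger here:
    #
    #   t = torch.ones(2, requires_grad=True)
    #   y = t * 3
    #   y.register_hook(lambda g: print("grad w.r.t. y:", g))  # fires during backward()
    #   y.sum().backward()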


    def _forward_hook(self, cell, inputs, outputs):  # pylint: disable=R1710
        """Execute forward hook."""
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return
        self._register_backward_pre_hook(outputs)
        if HSDPSchedulerV2.root_bp_state:
            self._restore_forward_prefetch_after_recompute()
            return
        return self._hsdp_forward_hook(cell, inputs, outputs)

    # pylint: disable=W0212
    def _backward_pre_hook(self, grad):
        """Execute backward pre hook."""
        Variable._execution_engine.queue_callback(self._root_backward_hook)
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return grad
        HSDPSchedulerV2.root_bp_state = True
        self._hsdp_backward_pre_hook(self.cell, None)
        return grad
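
    # Note: Variable._execution_engine.queue_callback is an internal autograd API that runs
    # the given callable once the current backward pass finishes; upstream PyTorch FSDP uses
    # the same mechanism for its root post-backward finalization. Queuing it from every
    # output-gradient hook is tolerated because _root_backward_hook and _backward_hook both
    # check scheduler_state before doing any work.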


    def _root_backward_hook(self):
        """Root backward hook: finalize gradient reduction for the outermost HSDP module.

        For the root module (the last to finish backward), this hook drains any
        pending fused reduction from ``CommContext`` and then calls ``reduce_params()``
        to apply the final per-parameter gradient reduction.
        """
        apply_final_reduce = self.scheduler_state != FSDPSchedulerState.BACKWARD
        self._backward_hook()
        if apply_final_reduce:
            HSDPSchedulerV2.root_bp_state = False
            with torch.profiler.record_function(f"root_backward reduce:{self.hsdp_state.module_name}"):
                # Drain any pending pipelined HSDP reduction left over from the last
                # module's backward, then apply the final per-parameter reduction.
                comm_ctx = get_comm_ctx()
                if comm_ctx.all_reduce_param_group is not None:
                    comm_ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
                    comm_ctx.all_reduce_param_group = None
                if comm_ctx.pre_param_group is not None:
                    comm_ctx.pre_param_group.apply_fusion_reduced_grad()
                    comm_ctx.pre_param_group = None
                self.hsdp_state.reduce_params()

    def _backward_hook(self):
        """Execute backward hook."""
        if self.scheduler_state == FSDPSchedulerState.BACKWARD:
            return
        self._hsdp_backward_hook(self.cell, None, None)

    # pylint: disable=W0613
    def _grouped_forward_pre_hook_skip(self, cell, args, kwargs) -> None:
        """Override base ``(args, kwargs)`` return; ``nn.Module`` pre-hook uses ``None`` for no-op."""
        return None

    def _grouped_forward_post_hook_skip(self, outputs) -> None:
        """Override base output pass-through; forward hook uses ``None`` for no-op."""
        return None

    def _register_forward_module_hook(self, mod, hook) -> None:
        """Register forward hook; use ``always_call=True`` when supported (matches PyTorch FSDP)."""
        sig = inspect.signature(mod.register_forward_hook)
        if "always_call" in sig.parameters:
            mod.register_forward_hook(hook, prepend=False, always_call=True)
        else:
            mod.register_forward_hook(hook, prepend=False)
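
    # Comment-only sketch: the same inspect.signature feature-detection pattern works for
    # any optional keyword argument (``always_call`` was added to register_forward_hook in
    # newer PyTorch releases; ``call_with_optional`` is a hypothetical helper):
    #
    #   import inspect
    #   def call_with_optional(fn, *args, **maybe_kwargs):
    #       supported = inspect.signature(fn).parameters
    #       return fn(*args, **{k: v for k, v in maybe_kwargs.items() if k in supported})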


    def _register_forward_backward_hooks(self):
        """Register module forward and backward hooks on all managed modules."""
        if self._fsdp_group_post_pending is None:
            for mod in self.modules:
                mod.register_forward_pre_hook(self._forward_pre_hook, with_kwargs=True)
                mod.register_forward_hook(self._forward_hook)
            return
        for mod in self.modules:
            mod.register_forward_pre_hook(self._grouped_forward_pre_hook, with_kwargs=True)
            self._register_forward_module_hook(mod, self._make_grouped_forward_post_hook(mod))
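

# Minimal usage sketch (assumptions: the scheduler is normally constructed by the
# fully_shard machinery with the arguments its base class expects; shown here only to
# illustrate the hook flow, not as a public API):
#
#   scheduler = TorchHSDPSchedulerV2(...)  # created internally by fully_shard
#   scheduler._register_hooks()            # installs forward pre/post hooks on modules
#   loss = model(batch).sum()              # pre-hook wraps inputs via PostBackwardFunction
#   loss.backward()                        # _backward_pre_hook queues _root_backward_hook,
#                                          # which drains fused reductions at backward end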