Coverage for hyper_parallel/platform/torch/hsdp/scheduler.py: 82%

60 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Torch HSDP scheduler""" 

16import torch 

17from torch.autograd import Variable 

18from torch.utils._pytree import tree_flatten, tree_unflatten 

19from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel 

20from hyper_parallel.core.hsdp.hsdp_scheduler import HSDPScheduler, FSDPSchedulerState 

21from hyper_parallel.platform.torch.hsdp.hook_function import PostBackwardFunction 

22from hyper_parallel.platform.torch.hsdp.state import TorchHSDPState 

23from hyper_parallel.platform.torch.hsdp.grad_hook import TorchHSDPGradHook 

24from hyper_parallel.platform.torch.hsdp.async_grad_hook import TorchHSDPAsyncGradHook 

25from hyper_parallel.platform import get_platform 

26 

27 

class TorchHSDPScheduler(HSDPScheduler):
    """Torch-backed HSDP scheduler implementing the optimizer sharding level.

    Wires forward/backward hooks onto ``self.cell`` so that the generic
    ``_hsdp_*`` callbacks inherited from ``HSDPScheduler`` fire at the right
    points of the torch autograd graph.
    """

    def _init_platform(self):
        """Resolve and store the active platform backend."""
        self.platform = get_platform()

    def _new_cell_state(self):
        """Build the torch-specific HSDP state object for the wrapped cell."""
        self.hsdp_state = TorchHSDPState(self.cell, self.config, self.platform)

    def _new_grad_hook(self):
        """Instantiate the gradient hook.

        Picks the asynchronous-communication variant when
        ``self.config.comm_async`` is truthy, the synchronous one otherwise.
        """
        hook_cls = TorchHSDPAsyncGradHook if self.config.comm_async else TorchHSDPGradHook
        self.grad_hook = hook_cls(self.config, self.platform)

    def _register_backward_hook(self, args, kwargs):
        """Route every forward input through ``PostBackwardFunction``.

        Flattens ``args`` and ``kwargs`` into one leaf list, applies the
        autograd function (which installs the post-backward trigger), then
        rebuilds the original pytree structures from the returned leaves.

        Returns:
            tuple: the reconstructed ``(args, kwargs)`` pair.
        """
        flat_args, args_spec = tree_flatten(args)
        flat_kwargs, kwargs_spec = tree_flatten(kwargs)
        n_args = len(flat_args)
        combined = PostBackwardFunction.apply(self, *(list(flat_args) + list(flat_kwargs)))
        # The autograd function returns the leaves unchanged in count, so the
        # first n_args entries belong to args, the rest to kwargs.
        args = tree_unflatten(combined[:n_args], args_spec)
        kwargs = tree_unflatten(combined[n_args:], kwargs_spec)
        return args, kwargs

    def _forward_pre_hook(self, cell, args, kwargs):
        """Run the generic HSDP pre-forward logic, then attach the backward trigger."""
        self._hsdp_forward_pre_hook(cell, args)
        return self._register_backward_hook(args, kwargs)

    def _register_backward_pre_hook(self, outputs):
        """Hook every grad-requiring output tensor to fire the backward pre hook."""
        leaves, _ = tree_flatten(outputs)
        for leaf in leaves:
            if isinstance(leaf, torch.Tensor) and leaf.requires_grad:
                leaf.register_hook(self._backward_pre_hook)
        return outputs

    def _forward_hook(self, cell, inputs, outputs):
        """Post-forward hook: arm the backward pre hook, then (conditionally)
        run the generic HSDP forward hook.

        The generic hook only runs at the SHARD_OPT_GRAD_PARAM level and is
        skipped while the scheduler is already in the PRE_BACKWARD state.
        """
        self._register_backward_pre_hook(outputs)
        if (self.shard_level == OptimizerLevel.SHARD_OPT_GRAD_PARAM
                and self.scheduler_state != FSDPSchedulerState.PRE_BACKWARD):
            self._hsdp_forward_hook(cell, inputs, outputs)

    # pylint: disable=W0212
    def _backward_pre_hook(self, grad):
        """Tensor grad hook: enter the backward phase exactly once.

        On the first invocation (state not yet PRE_BACKWARD) it runs the
        generic pre-backward callback and queues ``_backward_hook`` on the
        autograd execution engine so it fires after the backward pass.
        The gradient is always passed through unmodified.
        """
        if self.scheduler_state != FSDPSchedulerState.PRE_BACKWARD:
            self._hsdp_backward_pre_hook(self.cell, None)
            Variable._execution_engine.queue_callback(self._backward_hook)
        return grad

    def _backward_hook(self):
        """Engine callback run after backward: dispatch the proper HSDP hook.

        No-op if the scheduler already reached the BACKWARD state. Otherwise
        chooses the gradient-accumulation variant when accumulation is
        required and the level is not SHARD_OPT_GRAD_PARAM.
        """
        if self.scheduler_state == FSDPSchedulerState.BACKWARD:
            return
        use_acc = (self.requires_acc_grad
                   and self.shard_level != OptimizerLevel.SHARD_OPT_GRAD_PARAM)
        if use_acc:
            self._hsdp_acc_backward_hook(self.cell, None, None)
        else:
            self._hsdp_backward_hook(self.cell, None, None)

    def _register_forward_backward_hooks(self):
        """Install the module-level forward pre/post hooks on the cell."""
        self.cell.register_forward_pre_hook(self._forward_pre_hook, with_kwargs=True)
        self.cell.register_forward_hook(self._forward_hook)