Coverage for hyper_parallel / platform / torch / fully_shard / scheduler.py: 83%

63 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Torch HSDP scheduler""" 

16import torch 

17from torch.autograd import Variable 

18from torch.utils._pytree import tree_flatten, tree_unflatten 

19from hyper_parallel.core.fully_shard.hsdp_scheduler import HSDPSchedulerV2, FSDPSchedulerState 

20from hyper_parallel.platform.torch.fully_shard.hook_function import PostBackwardFunction 

21from hyper_parallel.platform.torch.fully_shard.state import TorchHSDPStateV2 

22from hyper_parallel.platform.torch.fully_shard.utils import FSDPMeshInfo, HSDPMeshInfo 

23from hyper_parallel.platform import get_platform 

24 

25 

26 

class TorchHSDPSchedulerV2(HSDPSchedulerV2):
    """Torch HSDP scheduler implementing the optimizer level.

    Wires the HSDP forward/backward lifecycle into a torch module by
    registering a forward pre hook and forward hook on ``self.cell``,
    tensor gradient hooks on the module outputs, and an autograd
    execution-engine callback that fires after backward completes.
    """

    def _register_hooks(self):
        """Register all hooks required by the scheduler."""
        self._register_forward_backward_hooks()

    def _init_platform(self):
        """Initialize and validate the active platform.

        Raises:
            ValueError: If the active platform is not a ``TorchPlatform``.
        """
        # Imported locally to avoid a circular import at module load time.
        from hyper_parallel.platform.torch.platform import TorchPlatform
        self.platform = get_platform()
        if not isinstance(self.platform, TorchPlatform):
            raise ValueError(f"TorchHSDPSchedulerV2 expect TorchPlatform, but got type: {type(self.platform)}")

    def _new_cell_state(self):
        """Create a new cell state for torch.

        Builds ``self.mesh_info`` (FSDP info for a 1D mesh, HSDP info for a
        2D mesh) and then constructs ``self.hsdp_state``.

        Raises:
            ValueError: If the mesh is neither 1D nor 2D.
        """
        if self.mesh.ndim not in (1, 2):
            raise ValueError("fully_shard only support 1D and 2D mesh.")
        if self.mesh.ndim == 1:
            # FSDP2: a single shard dimension.
            self.mesh_info = FSDPMeshInfo(mesh=self.mesh, shard_mesh_dim=0)
        else:
            # HSDP: dim 0 replicates, dim 1 shards.
            self.mesh_info = HSDPMeshInfo(mesh=self.mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
        self.hsdp_state = TorchHSDPStateV2(self.cell, self.mesh_info, self.config, self.platform, self.device)

    def _new_grad_hook(self):
        """Intentionally a no-op.

        TorchHSDPSchedulerV2 does not need a parameter gradient hook because
        gradients are read directly from ``param.grad``.
        """

    def _register_post_backward_hook(self, args, kwargs):
        """Register the post-backward hook via a pass-through autograd function.

        Flattens ``args``/``kwargs``, threads every leaf through
        ``PostBackwardFunction.apply`` (which passes values through in forward
        and triggers post-backward logic when gradients flow back), then
        rebuilds the original pytree structures.

        Args:
            args: Positional forward inputs (arbitrary pytree).
            kwargs: Keyword forward inputs (arbitrary pytree).

        Returns:
            Tuple of (args, kwargs) rebuilt from the hooked leaves.
        """
        args_list, args_spec = tree_flatten(args)
        kwargs_list, kwargs_spec = tree_flatten(kwargs)
        # Record the split point before rebinding, so the slices below do not
        # depend on the rebound list coincidentally keeping the same length.
        num_args = len(args_list)
        flat = list(args_list) + list(kwargs_list)
        flat = PostBackwardFunction.apply(self, *flat)
        args = tree_unflatten(flat[:num_args], args_spec)
        kwargs = tree_unflatten(flat[num_args:], kwargs_spec)
        return args, kwargs

    def _forward_pre_hook(self, cell, args, kwargs):
        """Execute the HSDP forward pre hook, then set up the backward hook."""
        args, kwargs = self._hsdp_forward_pre_hook(cell, args, kwargs)
        return self._register_post_backward_hook(args, kwargs)

    def _register_backward_pre_hook(self, outputs):
        """Attach ``_backward_pre_hook`` to each grad-requiring tensor output.

        Non-tensor outputs and tensors that do not require grad are skipped.
        Returns ``outputs`` unchanged.
        """
        flat_outputs, _ = tree_flatten(outputs)
        for output in flat_outputs:
            if isinstance(output, torch.Tensor) and output.requires_grad:
                output.register_hook(self._backward_pre_hook)
        return outputs

    def _forward_hook(self, cell, inputs, outputs):
        """Execute the forward hook.

        Always registers the backward pre hook on the outputs; the HSDP
        forward hook itself is skipped when the scheduler is already in the
        ``PRE_BACKWARD`` state.
        """
        self._register_backward_pre_hook(outputs)
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            # Returning None leaves the module outputs unmodified.
            return None
        return self._hsdp_forward_hook(cell, inputs, outputs)

    # pylint: disable=W0212
    def _backward_pre_hook(self, grad):
        """Execute the backward pre hook and queue the post-backward callback.

        Returns ``grad`` unchanged. Skipped entirely when the scheduler is
        already in the ``PRE_BACKWARD`` state.
        """
        if self.scheduler_state == FSDPSchedulerState.PRE_BACKWARD:
            return grad
        self._hsdp_backward_pre_hook(self.cell, None)
        # Private autograd API: run _backward_hook once the engine finishes
        # the current backward pass (hence the W0212 disable above).
        Variable._execution_engine.queue_callback(self._backward_hook)
        return grad

    def _backward_hook(self):
        """Execute the backward hook, unless already in the BACKWARD state."""
        if self.scheduler_state == FSDPSchedulerState.BACKWARD:
            return
        self._hsdp_backward_hook(self.cell, None, None)

    def _register_forward_backward_hooks(self):
        """Register module forward pre hook and forward hook on the cell."""
        # with_kwargs=True so the pre hook can rewrite keyword inputs too.
        self.cell.register_forward_pre_hook(self._forward_pre_hook, with_kwargs=True)
        self.cell.register_forward_hook(self._forward_hook)