Coverage for hyper_parallel / platform / torch / activation_checkpoint / activation_swap.py: 80%

86 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py 

16# enhanced with activation swap functionality. 

17# ============================================================================ 

18"""Activation Swap implementation for PyTorch.""" 

19# pylint: disable=W0212, W0613 

20 

21import enum 

22from abc import ABC, abstractmethod 

23from collections.abc import Iterator 

24from typing import Optional, Callable, Any 

25import torch 

26from torch import nn 

27from torch.distributed.utils import _replace_by_prefix 

28from hyper_parallel.core.activation_checkpoint.swap import SwapManager, SwapTensor, Storage 

29 

30 

# Attribute name under which the wrapped module is stored on the wrapper class,
# and the matching "name." prefix used to translate state_dict keys between the
# wrapped and unwrapped forms (see ActivationWrapper's state_dict hooks).
_SWAP_WRAPPED_MODULE = "_swap_wrapped_module"
_SWAP_PREFIX = _SWAP_WRAPPED_MODULE + "."

33 

34 

class ActivationPolicy(enum.Enum):
    """
    Enum for activation policies.

    Returned by a user-supplied ``policy_fn`` to decide, per tensor, what the
    pack hook in ``AsyncSaveOnCpu`` does with a saved activation.
    """
    SAVE = 0  # keep the tensor as the saved activation (pack hook returns it unchanged)
    SWAP = 1  # offload the tensor via SwapTensor and save only an integer index

39 

40 

def base_check_fn(tensor) -> bool:
    """
    Return ``True`` if *tensor* is eligible for offloading to CPU.

    A tensor is rejected when it is a ``Parameter`` (or a view of one) —
    parameter memory must stay resident on device — or when its underlying
    storage holds no elements.
    """
    param_cls = torch.nn.parameter.Parameter
    # Views expose the owning tensor through ._base; check both the tensor
    # itself and its base so parameter slices are also skipped.
    is_param_related = (
        isinstance(tensor, param_cls)
        or isinstance(tensor._base, param_cls)  # pylint: disable=W0212
    )
    has_storage = tensor.storage().size() > 0
    return not is_param_related and has_storage

52 

53 

class AsyncSaveOnCpu(torch.autograd.graph.saved_tensors_hooks):
    """
    Context manager to offload tensors to CPU during forward pass.

    Installs autograd saved-tensor hooks: the pack hook hands eligible
    activations to a ``SwapTensor`` held in a per-context ``Storage`` object
    and saves only an integer index in the autograd graph; the unpack hook
    exchanges that index back for the tensor during backward.
    """
    def __init__(self, policy_fn=None) -> None:
        # True once self.storage has been registered with the SwapManager
        # (done lazily on the first packed tensor).
        self.add_to_storage = False
        # Per-forward-pass container for the offloaded tensors.
        self.storage = Storage()
        # Next index to hand out as the autograd-graph placeholder.
        self.count_idx = 0
        # Counters used to detect when every packed tensor has been unpacked.
        self.pack_count = 0
        self.unpack_count = 0
        # Optional callable tensor -> ActivationPolicy; SAVE keeps the tensor
        # on device, anything else lets it be swapped.
        self.policy_fn = policy_fn

        def pack_to_cpu(tensor: torch.Tensor):
            # skip ineligible tensors (Parameters, their views, empty storage);
            # returning the tensor itself means autograd saves it as usual.
            if not base_check_fn(tensor):
                return tensor

            # Respect an explicit SAVE decision from the user policy.
            if (policy_fn is not None) and (policy_fn(tensor)==ActivationPolicy.SAVE):
                return tensor

            # Register this context's storage with the SwapManager exactly once,
            # under the currently active group name.
            if not self.add_to_storage:
                group_name = SwapManager().get_current_group_name()
                SwapManager().add_storage(group_name, self.storage)
                self.add_to_storage = True
            self.storage.swap_storage[self.count_idx].append(SwapTensor(tensor))
            idx = self.count_idx
            self.count_idx += 1
            self.pack_count += 1
            # The integer index is what autograd stores in place of the tensor.
            return idx

        def unpack_from_cpu(idx) -> torch.Tensor:
            # Tensors that were saved directly (not swapped) pass straight through.
            if isinstance(idx, torch.Tensor):
                return idx

            swap_tensor = self.storage.swap_storage[idx].pop(0)
            tensor = swap_tensor.get_val()
            self.unpack_count += 1
            # Once every packed tensor has been unpacked, drop the storage
            # reference so its memory can be reclaimed.
            if self.unpack_count == self.pack_count:
                self.storage = None
            return tensor

        super().__init__(pack_to_cpu, unpack_from_cpu)

96 

97 

class ActivationWrapper(torch.nn.Module, ABC):
    """
    Base class for Activation Swap.

    Not meant to be instantiated directly.
    """

    def __init__(self, module):
        super().__init__()
        self._swap_wrapped_module = module
        # Post-hook: strip the wrapper prefix from state_dict keys so
        # checkpoints load into plain (non-swap-wrapped) modules.
        self._register_state_dict_hook(self._post_state_dict_hook)
        # Pre-hook: re-insert the prefix so plain checkpoints load back
        # into swap-wrapped modules.
        self.register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)

    @abstractmethod
    def forward(self, *args, **kwargs):
        raise ValueError("Subclasses should implement forward().")

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to wrapped module."""
        try:
            # Let nn.Module resolve parameters/buffers/submodules first.
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self._swap_wrapped_module, name)

    def __getitem__(self, key: int) -> Any:
        """Forward indexing calls in case the module is a nn.Sequential."""
        return self._swap_wrapped_module.__getitem__(key)  # type: ignore[operator]

    def named_parameters(
        self,
        *args,
        **kwargs,
    ) -> Iterator[tuple[str, torch.nn.Parameter]]:
        """
        Override :meth:`named_parameters()` to intercept parameter names.

        Strips every occurrence of ``_SWAP_PREFIX`` so parameter names look
        identical to those of the unwrapped module.
        """
        yield from (
            (name.replace(_SWAP_PREFIX, ""), param)
            for name, param in super().named_parameters(*args, **kwargs)
        )

    @staticmethod
    def _post_state_dict_hook(
        module: nn.Module,  # pylint: disable=W0613
        state_dict: dict[str, Any],
        prefix: str,
        *args: Any,  # pylint: disable=W0613
    ) -> dict[str, Any]:
        """
        Run after ``state_dict()`` of this module has executed.

        Strips the swap-wrapped-module prefix from every key so the result
        can be loaded into non-swapped modules. Loading back into a
        swap-wrapped module still works because the pre-load hook below
        re-adds the prefix.
        """
        _replace_by_prefix(state_dict, f"{prefix}{_SWAP_PREFIX}", prefix)
        return state_dict

    @staticmethod
    def _pre_load_state_dict_hook(
        module: nn.Module,
        state_dict: dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> None:
        """
        Run before ``self._load_from_state_dict()`` is called.

        Re-inserts the swap-wrapped-module prefix so state_dicts produced by
        non-swapped modules load correctly into swap-wrapped ones.
        """
        _replace_by_prefix(state_dict, prefix, f"{prefix}{_SWAP_PREFIX}")

176 

177 

class SwapWrapper(ActivationWrapper):
    """
    nn.Module wrapper that runs the target model's forward pass inside an
    :class:`AsyncSaveOnCpu` context, so eligible activations are offloaded
    to CPU and swapped back in during backward.
    """

    def __init__(self, mod: nn.Module, policy_fn: Optional[Callable] = None):
        super().__init__(mod)
        # Optional per-tensor policy deciding SAVE vs SWAP; None means
        # every eligible activation is swapped.
        self.policy_fn = policy_fn

    def forward(self, *args, **kwargs):
        swap_ctx = AsyncSaveOnCpu(policy_fn=self.policy_fn)
        with swap_ctx:
            return self._swap_wrapped_module(*args, **kwargs)

189 

190 

def swap_wrapper(module: nn.Module, policy_fn: Optional[Callable] = None):
    """Factory helper: wrap *module* in a :class:`SwapWrapper` and return it."""
    return SwapWrapper(module, policy_fn=policy_fn)