Coverage for hyper_parallel / platform / torch / activation_checkpoint / activation_swap.py: 80%
86 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
16# enhanced with activation swap functionality.
17# ============================================================================
18"""Activation Swap implementation for PyTorch."""
19# pylint: disable=W0212, W0613
21import enum
22from abc import ABC, abstractmethod
23from collections.abc import Iterator
24from typing import Optional, Callable, Any
25import torch
26from torch import nn
27from torch.distributed.utils import _replace_by_prefix
28from hyper_parallel.core.activation_checkpoint.swap import SwapManager, SwapTensor, Storage
# Attribute name under which the wrapped module is stored on the wrapper.
_SWAP_WRAPPED_MODULE = "_swap_wrapped_module"
# Key prefix the wrapper introduces into parameter / state-dict names.
_SWAP_PREFIX = _SWAP_WRAPPED_MODULE + "."
class ActivationPolicy(enum.Enum):
    """Per-tensor decision a policy function can return for an activation."""
    # Keep the activation on device (do not offload it).
    SAVE = 0
    # Offload the activation to host memory and restore it on demand.
    SWAP = 1
def base_check_fn(tensor) -> bool:
    """
    Decide whether ``tensor`` is eligible for offloading.

    Returns ``False`` for ``nn.Parameter`` objects (or views whose base is a
    Parameter) and for tensors backed by empty storage; ``True`` otherwise.
    """
    param_cls = torch.nn.parameter.Parameter
    # Parameters (and views over them) must never be offloaded.
    is_param = isinstance(tensor, param_cls) or isinstance(tensor._base, param_cls)  # pylint: disable=W0212
    if is_param:
        return False
    # Tensors with no allocated storage carry nothing worth swapping.
    return tensor.storage().size() > 0
class AsyncSaveOnCpu(torch.autograd.graph.saved_tensors_hooks):
    """
    Context manager to offload tensors to CPU during forward pass.

    Installs saved-tensor pack/unpack hooks: eligible activations are wrapped
    in a ``SwapTensor`` and recorded in a ``Storage`` registered with the
    global ``SwapManager``; autograd then saves only an integer index, and
    the real tensor is fetched back during backward.
    """
    def __init__(self, policy_fn=None) -> None:
        # Lazily set to True once self.storage is registered with the
        # SwapManager (done on the first tensor actually offloaded).
        self.add_to_storage = False
        self.storage = Storage()
        # Index assigned to the next packed tensor.
        self.count_idx = 0
        # Counters used to detect when every packed tensor was unpacked.
        self.pack_count = 0
        self.unpack_count = 0
        # Optional callable mapping a tensor to an ActivationPolicy.
        self.policy_fn = policy_fn

        def pack_to_cpu(tensor: torch.Tensor):
            # skip ineligible tensors
            if not base_check_fn(tensor):
                return tensor

            # The policy may force a tensor to stay on device.
            if (policy_fn is not None) and (policy_fn(tensor)==ActivationPolicy.SAVE):
                return tensor

            # Register this context's storage with the current swap group
            # the first time a tensor is actually offloaded.
            if not self.add_to_storage:
                group_name = SwapManager().get_current_group_name()
                SwapManager().add_storage(group_name, self.storage)
                self.add_to_storage = True
            self.storage.swap_storage[self.count_idx].append(SwapTensor(tensor))
            idx = self.count_idx
            self.count_idx += 1
            self.pack_count += 1
            # Autograd stores only the index; the data lives in self.storage.
            return idx

        def unpack_from_cpu(idx) -> torch.Tensor:
            # Tensors that were saved directly come back unchanged.
            if isinstance(idx, torch.Tensor):
                return idx

            swap_tensor = self.storage.swap_storage[idx].pop(0)
            tensor = swap_tensor.get_val()
            self.unpack_count += 1
            # Drop the storage reference once every offloaded tensor has been
            # restored, so the backing memory can be garbage collected.
            if self.unpack_count == self.pack_count:
                self.storage = None
            return tensor

        super().__init__(pack_to_cpu, unpack_from_cpu)
class ActivationWrapper(torch.nn.Module, ABC):
    """
    Abstract base wrapper used by the activation-swap machinery.

    Stores the wrapped module under ``_swap_wrapped_module`` and installs
    state-dict hooks so checkpoints stay interchangeable between wrapped
    and non-wrapped modules. Not meant to be instantiated directly.
    """

    def __init__(self, module):
        super().__init__()
        self._swap_wrapped_module = module
        # After state_dict(): strip the wrapper prefix from keys so the
        # checkpoint loads into a plain (non-wrapped) module.
        self._register_state_dict_hook(self._post_state_dict_hook)
        # Before load_state_dict(): re-insert the prefix so a plain
        # checkpoint loads back into a swap-wrapped module.
        self.register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)

    @abstractmethod
    def forward(self, *args, **kwargs):
        raise ValueError("Subclasses should implement forward().")

    def __getattr__(self, name: str) -> Any:
        """Fall back to the wrapped module for attributes not found here."""
        try:
            # nn.Module resolves parameters/buffers/submodules first.
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self._swap_wrapped_module, name)

    def __getitem__(self, key: int) -> Any:
        """Delegate indexing so wrapping an nn.Sequential stays subscriptable."""
        return self._swap_wrapped_module.__getitem__(key)  # type: ignore[operator]

    def named_parameters(
        self,
        *args,
        **kwargs,
    ) -> Iterator[tuple[str, torch.nn.Parameter]]:
        """
        Override :meth:`named_parameters()` to intercept parameter names.

        Every occurrence of ``_SWAP_PREFIX`` is removed from yielded names.
        """
        for full_name, param in super().named_parameters(*args, **kwargs):
            yield full_name.replace(_SWAP_PREFIX, ""), param

    @staticmethod
    def _post_state_dict_hook(
        module: nn.Module,  # pylint: disable=W0613
        state_dict: dict[str, Any],
        prefix: str,
        *args: Any,  # pylint: disable=W0613
    ) -> dict[str, Any]:
        """
        Run after ``state_dict()`` of this wrapper has executed.

        Strips the swap-wrapped module prefix from keys so the result can be
        loaded into non-wrapped modules. Loading back into swap-wrapped
        modules still works because ``_pre_load_state_dict_hook`` restores
        the prefix before loading.
        """
        _replace_by_prefix(state_dict, prefix + _SWAP_PREFIX, prefix)
        return state_dict

    @staticmethod
    def _pre_load_state_dict_hook(
        module: nn.Module,
        state_dict: dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> None:
        """
        Run before ``self._load_from_state_dict()`` is called.

        Adds the swap-wrapped module prefix back to keys so checkpoints from
        non-wrapped modules load properly into swap-wrapped modules.
        """
        _replace_by_prefix(state_dict, prefix, prefix + _SWAP_PREFIX)
class SwapWrapper(ActivationWrapper):
    """
    nn.Module wrapper that runs the wrapped module's forward pass inside an
    ``AsyncSaveOnCpu`` context, so eligible activations are offloaded to CPU.
    """

    def __init__(self, mod: nn.Module, policy_fn: Optional[Callable] = None):
        super().__init__(mod)
        # Optional per-tensor policy consulted by the pack hook.
        self.policy_fn = policy_fn

    def forward(self, *args, **kwargs):
        # The saved-tensor hooks are active only for the duration of this call.
        swap_context = AsyncSaveOnCpu(policy_fn=self.policy_fn)
        with swap_context:
            return self._swap_wrapped_module(*args, **kwargs)
def swap_wrapper(module: nn.Module, policy_fn: Optional[Callable] = None):
    """Return ``module`` wrapped in a :class:`SwapWrapper` with ``policy_fn``."""
    return SwapWrapper(module, policy_fn)