# hyper_parallel/platform/torch/activation_checkpoint/sac.py
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/utils/checkpoint.py
# and enhanced with swap support for selective activation checkpointing.
# ============================================================================
"""Selective activation checkpointing (SAC), enhanced with swap support."""
# pylint: disable=W0212, W0613, C0115, C0116, C0103, R1705
from typing import Any, Optional, Union

import torch
import torch.fx.traceback as fx_traceback
from torch._functorch._aot_autograd.functional_utils import is_fun
from torch.utils._pytree import tree_map
from torch.utils._python_dispatch import TorchDispatchMode

from hyper_parallel.core.activation_checkpoint import CheckpointPolicy  # patch code
from hyper_parallel.core.activation_checkpoint.swap import SwapManager, SwapTensor, Storage  # patch code


def _is_compiling(func, args, kwargs):
    # Check if we are under AOTAutograd tracing
    # There should probably be a better way to do this...
    # TODO: unify _is_compiling across all compile stacks
    for arg in args:
        if isinstance(arg, torch.Tensor) and is_fun(arg):
            return True
    return False


class _VersionWrapper:
    # Check that cached tensors are not mutated.
    def __init__(self, val):
        self.val: Union[torch.Tensor, Any] = val
        self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None

    def get_val(self, allow_cache_entry_mutation):
        if self.version is not None and not allow_cache_entry_mutation:
            if self.val._version != self.version:
                # Can we give user a stack trace of where the mutation happened?
                raise RuntimeError(
                    "Tensor cached during selective activation checkpoint has been mutated"
                )
        return self.val
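
# A minimal sketch of how the wrapper above detects mutation (illustrative only,
# not executed by this module): in-place ops bump `Tensor._version`, which is what
# get_val() compares against the version recorded at caching time.
#
#     t = torch.ones(3)
#     wrapped = _VersionWrapper(t)
#     t.add_(1)                                          # bumps t._version
#     wrapped.get_val(allow_cache_entry_mutation=True)   # returns the mutated tensor
#     wrapped.get_val(allow_cache_entry_mutation=False)  # raises RuntimeError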


def _maybe_detach(x, any_ret_has_alias_info):
    # We detach for two separate reasons:
    # - For view ops, we need to ensure that when the tensor is returned from
    #   CachedDispatchMode, as_view sees that the AutogradMeta is nullptr
    # - Avoid reference cycles
    # For case 1, it is not enough to check whether x has differentiable dtype
    # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g.
    # when the tensor is a view.
    if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info):
        with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False):
            # Ensure that view performed beneath autograd properly propagates
            # version counter. TODO: Use reentrant_dispatch instead of
            # manually manipulating dispatch keys. Using reentrant_dispatch
            # would respect inference_mode, though that is not relevant for
            # this case.
            x = x.detach()
    return x


class SelectiveCheckpointContext:
    """
    Context passed to policy function during selective checkpointing.

    This class is used to pass relevant metadata to the policy function during
    selective checkpointing. The metadata includes whether the current invocation
    of the policy function is during recomputation or not.

    Example:
        >>> # xdoctest: +SKIP(stub)
        >>>
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>     print(ctx.is_recompute)
        >>>
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        >>>
        >>> out = torch.utils.checkpoint.checkpoint(
        >>>     fn, x, y,
        >>>     use_reentrant=False,
        >>>     context_fn=context_fn,
        >>> )
    """
    def __init__(self, *, is_recompute):
        self.is_recompute = is_recompute


def _policy_from_bool(b):
    # For backward compatibility
    return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE
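
# Note: the dispatch modes below also accept a policy function that returns a plain
# bool, mapped through _policy_from_bool for backward compatibility. A hypothetical
# legacy-style policy, for illustration only:
#
#     def legacy_policy(ctx, op, *args, **kwargs):
#         # True -> CheckpointPolicy.MUST_SAVE, False -> CheckpointPolicy.PREFER_RECOMPUTE
#         return op in (torch.ops.aten.mm.default,)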


SAC_IGNORED_OPS = {
    # AC inserts different number of detach during forward and recompute.
    torch.ops.aten.detach.default,
    # AC's determinism check invokes additional metadata ops during forward.
    # With subclasses involved, these metadata ops become dispatchable, which
    # can result in incorrectness if these ops are selected to be cached.
    torch.ops.prim.device.default,
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)


class _CachingTorchDispatchMode(TorchDispatchMode):
    # Used together with _CachedTorchDispatchMode to implement SAC.
    def __init__(self, policy_fn, storage):
        self.policy_fn = policy_fn
        self.storage = storage
        self.add_to_storage = False

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        kwargs = {} if kwargs is None else kwargs
        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if is_compiling:
            # Overwrite each node's "recompute" tag to add in the user annotation.
            fx_traceback.current_meta["recompute"] = policy

        out = func(*args, **kwargs)

        any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE):
            storage = self.storage.save_storage[func]  # patch code
            storage.append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out))
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            if not self.add_to_storage:
                group_name = SwapManager().get_current_group_name()
                SwapManager().add_storage(group_name, self.storage)
                self.add_to_storage = True
            storage = self.storage.swap_storage[func]
            storage.append(tree_map(lambda x: SwapTensor(_maybe_detach(x, any_ret_has_alias_info)), out))
        return out


class _CachedTorchDispatchMode(TorchDispatchMode):
    # Used together with _CachingTorchDispatchMode to implement SAC.
    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
        self.policy_fn = policy_fn
        self.storage = storage
        self.allow_cache_entry_mutation = allow_cache_entry_mutation

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        kwargs = {} if kwargs is None else kwargs
        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
            storage = self.storage.save_storage.get(func)  # patch code
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            storage = self.storage.swap_storage.get(func)
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            out = tree_map(lambda x: x.get_val(), storage.pop(0))
        else:
            out = func(*args, **kwargs)
        return out


def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    """
    Helper to avoid recomputing certain ops during activation checkpointing.

    Use this with `torch.utils.checkpoint.checkpoint` to control which
    operations are recomputed during the backward pass.

    Args:
        policy_fn_or_list (Callable or List):
          - If a policy function is provided, it should accept a
            :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and
            kwargs to the op, and return a :class:`CheckpointPolicy` enum value
            indicating whether the execution of the op should be recomputed or not.
          - If a list of operations is provided, it is equivalent to a policy
            returning `CheckpointPolicy.MUST_SAVE` for the specified
            operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other
            operations.
        allow_cache_entry_mutation (bool, optional): By default, an error is
            raised if any tensors cached by selective activation checkpoint are
            mutated in order to ensure correctness. If set to `True`, this check
            is disabled.
    Returns:
        A tuple of two context managers.

    Example:
        >>> # xdoctest: +REQUIRES(LINUX)
        >>> import functools
        >>>
        >>> x = torch.rand(10, 10, requires_grad=True)
        >>> y = torch.rand(10, 10, requires_grad=True)
        >>>
        >>> ops_to_save = [
        >>>     torch.ops.aten.mm.default,
        >>> ]
        >>>
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>     if op in ops_to_save:
        >>>         return CheckpointPolicy.MUST_SAVE
        >>>     else:
        >>>         return CheckpointPolicy.PREFER_RECOMPUTE
        >>>
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        >>>
        >>> # or equivalently
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
        >>>
        >>> def fn(x, y):
        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
        >>>
        >>> out = torch.utils.checkpoint.checkpoint(
        >>>     fn, x, y,
        >>>     use_reentrant=False,
        >>>     context_fn=context_fn,
        >>> )
    """
    # NB: If grad_mode is disabled, checkpoint would not run forward under
    # context_fn anyway, so proceed as usual.
    if isinstance(policy_fn_or_list, list):
        for op in policy_fn_or_list:
            if not isinstance(op, torch._ops.OpOverload):
                _extra_msg = (
                    "Please update the OpOverloadPacket to a specific OpOverload. "
                    "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`."
                ) if isinstance(op, torch._ops.OpOverloadPacket) else ""
                raise ValueError(
                    f"Expected op in `op_list` to be an OpOverload but got: {op} "
                    f"of type {type(op)}. {_extra_msg}"
                )

        def policy_fn(ctx, op, *args, **kwargs):
            if op in policy_fn_or_list:
                return CheckpointPolicy.MUST_SAVE
            else:
                return CheckpointPolicy.PREFER_RECOMPUTE
    elif callable(policy_fn_or_list):
        policy_fn = policy_fn_or_list
    else:
        raise TypeError("policy_fn_or_list must be either a function or a list of ops.")

    storage = Storage()  # patch code
    return (
        _CachingTorchDispatchMode(policy_fn, storage),
        _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation),
    )
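
# A hedged usage sketch for the swap policy added in this file. `ops_to_swap`,
# `swap_policy`, and `fn` below are hypothetical names introduced for illustration;
# the wiring of SwapManager (swap groups and device<->host transfers) is assumed to
# be handled by the surrounding hyper_parallel runtime and is not shown here. Ops
# for which the policy returns CheckpointPolicy.MUST_SWAP are wrapped in SwapTensor
# during the forward pass and restored from swap_storage during recomputation.
#
#     import functools
#     import torch
#
#     x = torch.rand(8, 8, requires_grad=True)
#     y = torch.rand(8, 8, requires_grad=True)
#
#     ops_to_swap = {torch.ops.aten.mm.default}
#
#     def swap_policy(ctx, op, *args, **kwargs):
#         if op in ops_to_swap:
#             return CheckpointPolicy.MUST_SWAP
#         return CheckpointPolicy.PREFER_RECOMPUTE
#
#     def fn(x, y):
#         return torch.sigmoid(x @ y) @ y
#
#     context_fn = functools.partial(create_selective_checkpoint_contexts, swap_policy)
#     out = torch.utils.checkpoint.checkpoint(
#         fn, x, y,
#         use_reentrant=False,
#         context_fn=context_fn,
#     )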