Coverage for hyper_parallel / core / activation_checkpoint / activation_checkpoint.py: 100%
17 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Activation checkpointing related interfaces"""
16import contextlib
17import enum
18from functools import partial
19from hyper_parallel.platform import get_platform
20plat = get_platform()
class CheckpointPolicy(enum.Enum):
    """
    Policy choices for how an operation's output is handled between the
    forward and backward passes.

    The first four members mirror PyTorch's selective activation
    checkpointing policies (SAVE / RECOMPUTE semantics, and the MUST vs
    PREFER strength); see:
    https://docs.pytorch.org/docs/2.6/checkpoint.html#torch.utils.checkpoint.CheckpointPolicy

    This enum extends those with one SWAP-based strategy:

    - ``MUST_SWAP``: the operation's output is offloaded to host memory
      during the forward pass and loaded back asynchronously before
      backward computation. The backward pass reuses the loaded
      activations without recomputation. This policy must be used
      together with :class:`SwapManager` to coordinate asynchronous
      offload/load and stream synchronization.

    .. note::
        ``MUST_SWAP`` is typically applied to operations that are either
        computationally expensive or have large memory footprints.
        Swapping very small outputs may introduce additional overhead
        and reduce the effectiveness of asynchronous copy.
    """

    MUST_SAVE = 0
    PREFER_SAVE = 1
    MUST_RECOMPUTE = 2
    PREFER_RECOMPUTE = 3
    # Offload during forward, reload before backward. Requires SwapManager.
    MUST_SWAP = 4
def checkpoint(function, *args, swap_inputs=False, policy_fn=None, **kwargs):
    """
    Apply activation checkpointing to a function with optional input swapping.

    Args:
        function: The function to apply checkpointing to.
        *args: Positional arguments forwarded to ``function``.
        swap_inputs (bool): Whether to enable input swapping via the
            platform's async_save_on_cpu context.
        policy_fn (callable, optional): Function that determines the
            checkpoint policy for operations; when given, it is bound into
            the platform's selective-checkpoint context factory.
        **kwargs: Additional keyword arguments forwarded to ``function``.

    Returns:
        The result of applying the function with checkpointing.
    """
    # Pick the per-call context factory: selective checkpointing when a
    # policy is supplied, otherwise the platform's no-op contexts.
    if policy_fn:
        ctx_factory = partial(plat.create_selective_checkpoint_contexts, policy_fn)
    else:
        ctx_factory = plat.noop_context_fn
    # Optionally wrap the whole call so inputs are saved to host memory.
    swap_ctx = plat.async_save_on_cpu if swap_inputs else contextlib.nullcontext
    with swap_ctx():
        return plat.checkpoint(
            function, *args, context_fn=ctx_factory, use_reentrant=False, **kwargs
        )
79checkpoint_wrapper = partial(plat.ckpt_wrapper, checkpoint_fn=checkpoint)