Coverage for hyper_parallel / core / activation_checkpoint / activation_checkpoint.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Activation checkpointing related interfaces""" 

16import contextlib 

17import enum 

18from functools import partial 

19from hyper_parallel.platform import get_platform 

# Resolve the active platform backend once at import time. `plat` supplies the
# checkpointing primitives used throughout this module (checkpoint,
# create_selective_checkpoint_contexts, noop_context_fn, async_save_on_cpu,
# ckpt_wrapper) — presumably a torch-backed adapter; confirm against
# hyper_parallel.platform.
plat = get_platform()

21 

22 

class CheckpointPolicy(enum.Enum):
    """
    Policy choices for handling an operation's output across forward/backward.

    Mirrors PyTorch's selective activation checkpointing policies (the
    SAVE / RECOMPUTE pairs, each in MUST and PREFER strength — see
    https://docs.pytorch.org/docs/2.6/checkpoint.html#torch.utils.checkpoint.CheckpointPolicy)
    and adds one swap-based policy on top of them.

    Extra policy:

    - ``MUST_SWAP``: during the forward pass the operation's output is
      offloaded to host memory; it is loaded back asynchronously ahead of the
      backward pass, which then reuses the loaded activation instead of
      recomputing it.

    ``MUST_SWAP`` only works in combination with :class:`SwapManager`, which
    coordinates the asynchronous offload/load and the stream synchronization.

    .. note::
        Apply ``MUST_SWAP`` to operations whose outputs are expensive to
        recompute or large in memory. Swapping very small outputs can cost
        more than it saves and undermines the benefit of asynchronous copy.
    """
    MUST_SAVE = 0
    PREFER_SAVE = 1
    MUST_RECOMPUTE = 2
    PREFER_RECOMPUTE = 3

    # Offload during forward, reload before backward. Requires SwapManager.
    MUST_SWAP = 4

57 

58 

def checkpoint(function, *args, swap_inputs=False, policy_fn=None, **kwargs):
    """
    Run ``function`` under non-reentrant activation checkpointing.

    Args:
        function: The callable to execute with checkpointing applied.
        *args: Positional arguments forwarded to ``function``.
        swap_inputs (bool): When True, execute inside the platform's
            ``async_save_on_cpu`` context so inputs are offloaded to host
            memory; otherwise no extra context is used.
        policy_fn (callable, optional): Selective-checkpointing policy
            callback. When provided, it is bound into the platform's
            selective checkpoint context factory; otherwise the platform's
            no-op context function is used.
        **kwargs: Additional keyword arguments forwarded to ``function``.

    Returns:
        The result of executing ``function`` under checkpointing.
    """
    # Pick the per-operation policy context factory for the checkpoint call.
    if policy_fn:
        context_fn = partial(plat.create_selective_checkpoint_contexts, policy_fn)
    else:
        context_fn = plat.noop_context_fn

    # Optionally wrap execution so that inputs are asynchronously saved on CPU.
    if swap_inputs:
        input_ctx = plat.async_save_on_cpu
    else:
        input_ctx = contextlib.nullcontext

    with input_ctx():
        # Non-reentrant checkpointing is required for context_fn support.
        return plat.checkpoint(
            function, *args, context_fn=context_fn, use_reentrant=False, **kwargs
        )

77 

78 

# Module-level convenience: the platform's ckpt_wrapper with this module's
# `checkpoint` pre-bound as its checkpoint_fn; remaining arguments pass
# through to plat.ckpt_wrapper unchanged.
checkpoint_wrapper = partial(plat.ckpt_wrapper, checkpoint_fn=checkpoint)