# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
# for the MindSpore Cell API.
# ============================================================================
"""Activation checkpoint wrapper implementation for MindSpore."""
# pylint: disable=W0613
from typing import Optional, Callable, Union

from mindspore.nn import Cell

from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import ActivationWrapper


class CheckpointWrapper(ActivationWrapper):
28 """
29 Wrap a MindSpore :class:`~mindspore.nn.Cell` with activation recomputation
30 (gradient checkpointing).
31
32 On construction the wrapped cell is marked for recomputation via
33 :meth:`Cell.recompute`, which is effective in semi-auto and
34 auto-parallel graph-mode training on Ascend/GPU.
35
36 When *checkpoint_fn* is supplied it is called in :meth:`construct`
37 instead, which allows callers to inject a custom recompute strategy
38 (e.g. selective activation checkpoint). Any extra keyword arguments
39 passed to the constructor are forwarded to *checkpoint_fn* at every
40 forward call.
41
42 Args:
43 mod (Cell): The cell to wrap.
44 checkpoint_fn (callable, optional): Custom checkpoint/recompute
45 function with signature
46 ``checkpoint_fn(cell, *args, **checkpoint_fn_kwargs, **kwargs)``.
47 When ``None``, MindSpore's native :meth:`Cell.recompute` is used.
48 **checkpoint_fn_kwargs: Extra keyword arguments forwarded to
49 *checkpoint_fn* at every forward call.
50
51 Example:
52 >>> from hyper_parallel.platform.mindspore.activation_checkpoint import checkpoint_wrapper
53 >>> wrapped = checkpoint_wrapper(my_cell)
54 >>> output = wrapped(inputs)
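
        A custom strategy can be injected through *checkpoint_fn*. The
        following is an illustrative sketch only; ``my_checkpoint_fn`` is a
        hypothetical user-defined function, not part of this package:

        >>> def my_checkpoint_fn(cell, *args, **kwargs):
        ...     # decide here whether/how to recompute; a plain pass-through
        ...     # is shown purely for illustration
        ...     return cell(*args, **kwargs)
        >>> wrapped = CheckpointWrapper(my_cell, checkpoint_fn=my_checkpoint_fn)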
55 """
56
57 def __init__(
58 self,
59 mod: Union[Cell, Callable],
60 checkpoint_fn: Optional[Callable] = None,
61 **checkpoint_fn_kwargs,
62 ):
63 super().__init__(mod)
64 self.checkpoint_fn = checkpoint_fn
65 self.checkpoint_fn_kwargs = checkpoint_fn_kwargs
66
67 def construct(self, *args, **kwargs):
68 if self.checkpoint_fn is not None:
69 return self.checkpoint_fn(
70 self._ckpt_wrapped_module,
71 *args,
72 **self.checkpoint_fn_kwargs,
73 **kwargs,
74 )
75 return self._ckpt_wrapped_module(*args, **kwargs)
76
77
78def checkpoint_wrapper(
79 module: Union[Cell, Callable],
80 checkpoint_fn: Optional[Callable] = None,
81 **checkpoint_fn_kwargs,
82) -> CheckpointWrapper:
83 """
84 Wrap *module* with activation recomputation (gradient checkpointing).
85
86 This is the MindSpore counterpart of
87 ``torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper``.
88
89 Args:
90 module (Cell or callable): The cell or plain function to wrap.
91 If a plain callable is passed it is automatically wrapped in a
92 :class:`~hyper_parallel.platform.mindspore.activation_checkpoint\
93.activation_swap.FuncCell` before being stored, and the native
94 :meth:`Cell.recompute` call is skipped (use *checkpoint_fn* for
95 custom recompute logic in that case).
96 checkpoint_fn (callable, optional): Custom recompute function. When
97 ``None`` (default), MindSpore's native :meth:`Cell.recompute` is
98 used (Cell inputs only).
99 **checkpoint_fn_kwargs: Extra keyword arguments forwarded to
100 *checkpoint_fn* on every forward call.
101
102 Returns:
103 CheckpointWrapper: The wrapped cell with activation recomputation
104 enabled.
105
106 Example:
107 >>> from hyper_parallel.platform.mindspore.activation_checkpoint import checkpoint_wrapper
108 >>> model.layers[i] = checkpoint_wrapper(model.layers[i])
109 >>> wrapped_fn = checkpoint_wrapper(lambda x: x * 2)
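
        Extra keyword arguments given to the wrapper are forwarded to
        *checkpoint_fn* on every forward call (illustrative sketch;
        ``my_checkpoint_fn`` and ``verbose`` are hypothetical):

        >>> def my_checkpoint_fn(cell, *args, verbose=False, **kwargs):
        ...     # hypothetical custom recompute strategy; pass-through here
        ...     return cell(*args, **kwargs)
        >>> wrapped = checkpoint_wrapper(my_cell, checkpoint_fn=my_checkpoint_fn, verbose=True)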
110 """
111 return CheckpointWrapper(module, checkpoint_fn=checkpoint_fn, **checkpoint_fn_kwargs)