Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/activation_checkpoint/checkpoint_wrapper.py: 79%

14 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py 

16# adapted for MindSpore Cell API. 

17# ============================================================================ 

18"""Activation Checkpoint Wrapper implementation for MindSpore.""" 

19# pylint: disable=W0613 

20from typing import Optional, Callable, Union 

21 

22from mindspore.nn import Cell 

23 

24from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import ActivationWrapper 

25 

26 

class CheckpointWrapper(ActivationWrapper):
    """
    Activation-recomputation (gradient-checkpointing) wrapper around a
    MindSpore :class:`~mindspore.nn.Cell`.

    The wrapped cell is stored by the :class:`ActivationWrapper` base
    class (which presumably also arranges native recomputation via
    :meth:`Cell.recompute` when no custom strategy is supplied — confirm
    against ``activation_swap.ActivationWrapper``).

    If *checkpoint_fn* is given, every forward pass routes through it
    instead of calling the wrapped cell directly, letting callers plug in
    a custom recompute strategy (e.g. selective activation checkpoint).
    Keyword arguments captured at construction time are re-supplied to
    *checkpoint_fn* on each call.

    Args:
        mod (Cell): The cell to wrap.
        checkpoint_fn (callable, optional): Custom checkpoint/recompute
            function with signature
            ``checkpoint_fn(cell, *args, **checkpoint_fn_kwargs, **kwargs)``.
            When ``None``, MindSpore's native :meth:`Cell.recompute` is used.
        **checkpoint_fn_kwargs: Extra keyword arguments forwarded to
            *checkpoint_fn* at every forward call.

    Example:
        >>> from hyper_parallel.platform.mindspore.activation_checkpoint import checkpoint_wrapper
        >>> wrapped = checkpoint_wrapper(my_cell)
        >>> output = wrapped(inputs)
    """

    def __init__(
        self,
        mod: Union[Cell, Callable],
        checkpoint_fn: Optional[Callable] = None,
        **checkpoint_fn_kwargs,
    ):
        # Base class takes ownership of the wrapped module
        # (exposed afterwards as self._ckpt_wrapped_module).
        super().__init__(mod)
        # Custom recompute strategy and its bound keyword arguments;
        # None means "use the native recompute path".
        self.checkpoint_fn = checkpoint_fn
        self.checkpoint_fn_kwargs = checkpoint_fn_kwargs

    def construct(self, *args, **kwargs):
        # Fast path: no custom strategy, run the wrapped cell directly.
        if self.checkpoint_fn is None:
            return self._ckpt_wrapped_module(*args, **kwargs)
        # Custom strategy: hand it the wrapped cell plus the constructor-
        # captured kwargs merged with the per-call kwargs.
        return self.checkpoint_fn(
            self._ckpt_wrapped_module,
            *args,
            **self.checkpoint_fn_kwargs,
            **kwargs,
        )

76 

77 

def checkpoint_wrapper(
    module: Union[Cell, Callable],
    checkpoint_fn: Optional[Callable] = None,
    **checkpoint_fn_kwargs,
) -> CheckpointWrapper:
    """
    Factory that wraps *module* with activation recomputation
    (gradient checkpointing).

    MindSpore counterpart of
    ``torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper``.

    Args:
        module (Cell or callable): The cell or plain function to wrap.
            If a plain callable is passed it is automatically wrapped in a
            :class:`~hyper_parallel.platform.mindspore.activation_checkpoint\
.activation_swap.FuncCell` before being stored, and the native
            :meth:`Cell.recompute` call is skipped (use *checkpoint_fn* for
            custom recompute logic in that case).
        checkpoint_fn (callable, optional): Custom recompute function. When
            ``None`` (default), MindSpore's native :meth:`Cell.recompute` is
            used (Cell inputs only).
        **checkpoint_fn_kwargs: Extra keyword arguments forwarded to
            *checkpoint_fn* on every forward call.

    Returns:
        CheckpointWrapper: The wrapped cell with activation recomputation
        enabled.

    Example:
        >>> from hyper_parallel.platform.mindspore.activation_checkpoint import checkpoint_wrapper
        >>> model.layers[i] = checkpoint_wrapper(model.layers[i])
        >>> wrapped_fn = checkpoint_wrapper(lambda x: x * 2)
    """
    # Delegate all construction logic to the wrapper class itself.
    wrapped = CheckpointWrapper(
        module,
        checkpoint_fn=checkpoint_fn,
        **checkpoint_fn_kwargs,
    )
    return wrapped