# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Selective activation checkpoint (SAC), enhanced with swap support."""

# pylint: disable=W0212, W0613, C0115, C0116, C0103, R1705
from typing import Any, Optional, Union

import mindspore as ms
from mindspore import MsDispatchMode
from hyper_parallel.core.activation_checkpoint.swap import SwapManager, Storage, SwapTensor
from hyper_parallel.core.activation_checkpoint import CheckpointPolicy
from hyper_parallel.platform import get_platform

platform = get_platform()


class _VersionWrapper:
    # Check that cached tensors are not mutated.
    def __init__(self, val):
        self.val: Union[ms.Tensor, Any] = val
        self.version: Optional[int] = val._version if isinstance(val, ms.Tensor) else None

    def get_val(self, allow_cache_entry_mutation):
        if self.version is not None and not allow_cache_entry_mutation:
            if self.val._version != self.version:
                # Can we give the user a stack trace of where the mutation happened?
                raise RuntimeError(
                    "Tensor cached during selective activation checkpoint has been mutated"
                )
        return self.val
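
# A minimal sketch of the mutation guard above (illustrative only; it assumes
# ms.Tensor exposes a private `_version` counter that in-place updates bump):
#
#     t = ms.Tensor([1.0, 2.0])
#     wrapped = _VersionWrapper(t)
#     # ... an in-place update of `t` bumps t._version ...
#     wrapped.get_val(allow_cache_entry_mutation=False)  # raises RuntimeError
#     wrapped.get_val(allow_cache_entry_mutation=True)   # returns t, skips the check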


def _maybe_detach(x):
    # Only floating-point and complex tensors can carry gradients, so only
    # those need to be detached before caching.
    if isinstance(x, ms.Tensor) and (x.is_floating_point() or x.is_complex()):
        x = ms.ops.stop_gradient(x)
    return x


class SelectiveCheckpointContext:
    def __init__(self, *, is_recompute):
        self.is_recompute = is_recompute


SAC_IGNORED_OPS = {"StopGradient"}
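
# The policy callback consumed by the dispatch modes below is expected to look
# like this sketch (hypothetical example: the op names and PREFER_RECOMPUTE are
# assumptions, not confirmed by this file):
#
#     def policy_fn(ctx: SelectiveCheckpointContext, func, *args, **kwargs):
#         if func.name == "MatMul":
#             return CheckpointPolicy.MUST_SAVE   # cache the output for backward
#         if func.name == "LayerNorm":
#             return CheckpointPolicy.MUST_SWAP   # offload the output via SwapManager
#         return CheckpointPolicy.PREFER_RECOMPUTE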


class _CachingMindSporeDispatchMode(MsDispatchMode):
    # Forward pass: runs every op, caching or swapping outputs per the policy.
    def __init__(self, policy_fn, storage):
        self.policy_fn = policy_fn
        self.storage = storage
        self.add_to_storage = False

    def __ms_dispatch__(self, func, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func.name in SAC_IGNORED_OPS:
            return func(*args, **kwargs)
        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
                                func, *args, **kwargs)

        out = func(*args, **kwargs)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE):
            storage = self.storage.save_storage[func.name]
            storage.append(platform.tree_map(lambda x: _VersionWrapper(_maybe_detach(x)), out))
        elif policy == CheckpointPolicy.MUST_SWAP:
            group_name = SwapManager().get_current_group_name()
            if not self.add_to_storage:
                SwapManager().add_storage(group_name, self.storage)
                self.add_to_storage = True
            storage = self.storage.swap_storage[func.name]
            funcname = f"{group_name}::{func.name}"
            storage.append(platform.tree_map(lambda x: SwapTensor(_maybe_detach(x), funcname), out))
        return out
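
# Note: cached entries are keyed by `func.name` and appended in forward
# execution order; the cached mode below consumes them with `pop(0)`, so
# recompute must replay same-named ops in the same order as the forward pass.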


class _CachedMindSporeDispatchMode(MsDispatchMode):
    # Recompute pass: replays cached/swapped outputs instead of rerunning saved ops.
    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
        self.policy_fn = policy_fn
        self.storage = storage
        self.allow_cache_entry_mutation = allow_cache_entry_mutation

    def __ms_dispatch__(self, func, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func.name in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
                                func, *args, **kwargs)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE):
            storage = self.storage.save_storage.get(func.name)  # patch code
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            out = platform.tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            storage = self.storage.swap_storage.get(func.name)
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            out = platform.tree_map(lambda x: x.get_val(), storage.pop(0))
        else:
            out = func(*args, **kwargs)
        return out
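
# Note: unlike _VersionWrapper.get_val, SwapTensor.get_val is called without a
# mutation flag here; swapped tensors are presumably fetched back from host
# memory by SwapManager before being consumed (not verified in this file).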


def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    if callable(policy_fn_or_list):
        policy_fn = policy_fn_or_list
    else:
        raise TypeError("policy_fn_or_list must be a callable policy function; "
                        "a list of ops is not supported here.")

    storage = Storage()
    return (
        _CachingMindSporeDispatchMode(policy_fn, storage),
        _CachedMindSporeDispatchMode(policy_fn, storage, allow_cache_entry_mutation)
    )
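
# A minimal end-to-end usage sketch (hypothetical; it assumes the two dispatch
# modes are entered as context managers around the original forward pass and
# the checkpoint recompute respectively, which this file does not itself show):
#
#     caching_mode, cached_mode = create_selective_checkpoint_contexts(policy_fn)
#     with caching_mode:
#         out = block(x)   # forward: MUST_SAVE/PREFER_SAVE outputs are cached,
#                          # MUST_SWAP outputs are handed to SwapManager
#     ...
#     with cached_mode:
#         out = block(x)   # recompute: saved outputs are replayed via pop(0)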