Coverage for hyper_parallel / platform / torch / activation_checkpoint / sac.py: 61%

105 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15# Adapted from https://github.com/pytorch/pytorch/blob/release/2.6/torch/utils/checkpoint.py 

16# enhanced with selective checkpoint support swap 

17# ============================================================================ 

18"""enhanced with selective checkpoint support swap""" 

19# pylint: disable=W0212, W0613, C0115, C0116, C0103, R1705 

20from typing import Any, Optional, Union 

21 

22import torch 

23import torch.fx.traceback as fx_traceback 

24from torch._functorch._aot_autograd.functional_utils import is_fun 

25from torch.utils._pytree import tree_map 

26from torch.utils._python_dispatch import TorchDispatchMode 

27from hyper_parallel.core.activation_checkpoint import CheckpointPolicy # patch code 

28from hyper_parallel.core.activation_checkpoint.swap import SwapManager, SwapTensor, Storage # patch code 

29 

30def _is_compiling(func, args, kwargs): 

31 # Check if we are under AOTAutograd tracing 

32 # There should probably be a better way to do this... 

33 # TODO: unify _is_compiling across all compile stacks 

34 for arg in args: 

35 if isinstance(arg, torch.Tensor) and is_fun(arg): 

36 return True 

37 return False 

38 

39 

40class _VersionWrapper: 

41 # Check that cached tensors are not mutated. 

42 def __init__(self, val): 

43 self.val: Union[torch.Tensor, Any] = val 

44 self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None 

45 

46 def get_val(self, allow_cache_entry_mutation): 

47 if self.version is not None and not allow_cache_entry_mutation: 

48 if self.val._version != self.version: 

49 # Can we give user a stack trace of where the mutation happened? 

50 raise RuntimeError( 

51 "Tensor cached during selective activation checkpoint has been mutated" 

52 ) 

53 return self.val 

54 

55 

56def _maybe_detach(x, any_ret_has_alias_info): 

57 # We detach for two separate reasons: 

58 # - For view ops, we need to ensure that when the tensor is returned from 

59 # CachedDispatchMode, as_view sees that the AutogradMeta is nullptr 

60 # - Avoid reference cycles 

61 # For case 1, it is not enough to check whether x has differentiable dtype 

62 # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g. 

63 # when the tensor is a view. 

64 if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info): 

65 with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False): 

66 # Ensure that view performed beneath autograd properly propagates 

67 # version counter. TODO: Use reentrant_dispatch instead of 

68 # manually manipulating dispatch keys. Using reentrant_dispatch 

69 # would respect inference_mode, though that is not relevant for 

70 # this case. 

71 x = x.detach() 

72 return x 

73 

74 

class SelectiveCheckpointContext:
    """
    Metadata object handed to the policy function during selective checkpointing.

    Currently it carries a single flag telling the policy whether it is being
    invoked during the recompute (backward) pass rather than the original
    forward pass.

    Example:
        >>> # xdoctest: +SKIP(stub)
        >>>
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>     print(ctx.is_recompute)
        >>>
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        >>>
        >>> out = torch.utils.checkpoint.checkpoint(
        >>>     fn, x, y,
        >>>     use_reentrant=False,
        >>>     context_fn=context_fn,
        >>> )
    """
    def __init__(self, *, is_recompute):
        # True when the policy is queried during backward recomputation.
        self.is_recompute = is_recompute

99 

100 

def _policy_from_bool(b):
    # For backward compatibility: legacy policy functions returned a bool,
    # where True meant "save this op's output" and False meant "recompute".
    if b:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

104 

105 

106SAC_IGNORED_OPS = { 

107 # AC inserts different number of detach during forward and recompute. 

108 torch.ops.aten.detach.default, 

109 # AC's determinism check invokes additional metadata ops during forward. 

110 # With subclasses involved, these metadata ops become dispatchable, this 

111 # can result in incorrectness if these ops are selected cached. 

112 torch.ops.prim.device.default, 

113} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) 

114 

115 

class _CachingTorchDispatchMode(TorchDispatchMode):
    # Used together with _CachedTorchDispatchMode to implement SAC.
    # Forward-pass half of selective activation checkpoint: runs each op,
    # consults the user policy, and records outputs that the recompute pass
    # may reuse — either version-guarded in save_storage, or wrapped as
    # SwapTensor in swap_storage (MUST_SWAP).
    def __init__(self, policy_fn, storage):
        # policy_fn: callable(ctx, func, *args, **kwargs) -> CheckpointPolicy or bool.
        self.policy_fn = policy_fn
        # Shared Storage; the paired _CachedTorchDispatchMode pops from it.
        self.storage = storage
        # Ensures this storage is registered with SwapManager at most once.
        self.add_to_storage = False

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Ops AC itself introduces (detach, metadata queries) are never cached.
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        kwargs = {} if kwargs is None else kwargs
        # Ask the user policy how this op should be treated in the forward pass.
        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            # Backward compat: bool maps to MUST_SAVE / PREFER_RECOMPUTE.
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if is_compiling:
            # Overwrite each node's "recompute" tag to add in the user annotation.
            fx_traceback.current_meta["recompute"] = policy

        out = func(*args, **kwargs)

        # If any return aliases an input, outputs must be detached even for
        # non-differentiable dtypes (see _maybe_detach).
        any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE):
            # Cache outputs, version-guarded against in-place mutation.
            storage = self.storage.save_storage[func]  # patch code
            storage.append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out))
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            if not self.add_to_storage:
                # Register this storage under the SwapManager's current group
                # (done lazily on first MUST_SWAP op).
                group_name = SwapManager().get_current_group_name()
                SwapManager().add_storage(group_name, self.storage)
                self.add_to_storage = True
            # Wrap outputs as SwapTensor — presumably so SwapManager can
            # offload them; confirm against the swap module.
            storage = self.storage.swap_storage[func]
            storage.append(tree_map(lambda x: SwapTensor(_maybe_detach(x, any_ret_has_alias_info)), out))
        return out

154 

155 

class _CachedTorchDispatchMode(TorchDispatchMode):
    # Used together with _CachingTorchDispatchMode to implement SAC.
    # Recompute-pass half of selective activation checkpoint: replays cached
    # (or swapped) outputs for ops the policy chose to save, and re-executes
    # everything else.
    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
        # Must be the same policy_fn and storage the caching mode used.
        self.policy_fn = policy_fn
        self.storage = storage
        # When True, skip the "cached tensor has been mutated" safety check.
        self.allow_cache_entry_mutation = allow_cache_entry_mutation

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        kwargs = {} if kwargs is None else kwargs
        # Re-query the policy; it must answer consistently with the forward pass
        # for the storage lookups below to line up.
        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
            storage = self.storage.save_storage.get(func)  # patch code
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            # FIFO pop: ops replay in the same order they ran during forward.
            out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            storage = self.storage.swap_storage.get(func)
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            # SwapTensor.get_val takes no mutation flag — presumably it brings
            # the tensor back from swap; confirm against the swap module.
            out = tree_map(lambda x: x.get_val(), storage.pop(0))
        else:
            # Not saved: recompute the op as usual.
            out = func(*args, **kwargs)
        return out

198 

199 

def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    """
    Helper to avoid recomputing certain ops during activation checkpointing.

    Use together with `torch.utils.checkpoint.checkpoint` to control which
    operations are recomputed during the backward pass.

    Args:
        policy_fn_or_list (Callable or List):
            - A policy function accepting a :class:`SelectiveCheckpointContext`,
              the :class:`OpOverload`, and the op's args/kwargs, returning a
              :class:`CheckpointPolicy` enum value that says whether the op
              should be recomputed.
            - Alternatively, a list of ops: equivalent to a policy returning
              `CheckpointPolicy.MUST_SAVE` for those ops and
              `CheckpointPolicy.PREFER_RECOMPUTE` for everything else.
        allow_cache_entry_mutation (bool, optional): By default an error is
            raised if any tensor cached by selective activation checkpoint is
            mutated, to guarantee correctness. Set to `True` to disable the
            check.
    Returns:
        A tuple of two context managers (forward-caching mode, recompute mode).

    Example:
        >>> # xdoctest: +REQUIRES(LINUX)
        >>> import functools
        >>>
        >>> x = torch.rand(10, 10, requires_grad=True)
        >>> y = torch.rand(10, 10, requires_grad=True)
        >>>
        >>> ops_to_save = [
        >>>    torch.ops.aten.mm.default,
        >>> ]
        >>>
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>    if op in ops_to_save:
        >>>        return CheckpointPolicy.MUST_SAVE
        >>>    else:
        >>>        return CheckpointPolicy.PREFER_RECOMPUTE
        >>>
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        >>>
        >>> # or equivalently
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
        >>>
        >>> def fn(x, y):
        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
        >>>
        >>> out = torch.utils.checkpoint.checkpoint(
        >>>     fn, x, y,
        >>>     use_reentrant=False,
        >>>     context_fn=context_fn,
        >>> )
    """
    # NB: If grad_mode is disabled, checkpoint would not run forward under
    # context_fn anyway, so proceed as usual.
    if not isinstance(policy_fn_or_list, list):
        if not callable(policy_fn_or_list):
            raise TypeError("policy_fn_or_list must be either a function or a list of ops.")
        policy_fn = policy_fn_or_list
    else:
        # Validate the op list eagerly so misuse fails at construction time.
        for op in policy_fn_or_list:
            if isinstance(op, torch._ops.OpOverload):
                continue
            _extra_msg = ""
            if isinstance(op, torch._ops.OpOverloadPacket):
                _extra_msg = (
                    "Please update the OpOverloadPacket to a specific OpOverload."
                    "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`."
                )
            raise ValueError(
                f"Expected op in `op_list` to be an OpOverload but got: {op} "
                f"of type {type(op)}. {_extra_msg}"
            )

        def policy_fn(ctx, op, *args, **kwargs):
            # Save listed ops; recompute everything else.
            if op in policy_fn_or_list:
                return CheckpointPolicy.MUST_SAVE
            return CheckpointPolicy.PREFER_RECOMPUTE

    shared_storage = Storage()  # patch code
    caching_mode = _CachingTorchDispatchMode(policy_fn, shared_storage)
    cached_mode = _CachedTorchDispatchMode(policy_fn, shared_storage, allow_cache_entry_mutation)
    return (caching_mode, cached_mode)