# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""enhanced with selective checkpoint support swap"""
# pylint: disable=W0212, W0613, C0115, C0116, C0103, R1705
from typing import Any, Optional, Union

import mindspore as ms
from mindspore import MsDispatchMode
from hyper_parallel.core.activation_checkpoint.swap import SwapManager, Storage, SwapTensor
from hyper_parallel.core.activation_checkpoint import CheckpointPolicy
from hyper_parallel.platform import get_platform

platform = get_platform()


class _VersionWrapper:
    # Check that cached tensors are not mutated.
    def __init__(self, val):
        self.val: Union[ms.Tensor, Any] = val
        self.version: Optional[int] = val._version if isinstance(val, ms.Tensor) else None

    def get_val(self, allow_cache_entry_mutation):
        if self.version is not None and not allow_cache_entry_mutation:
            if self.val._version != self.version:
                # Can we give user a stack trace of where the mutation happened?
                raise RuntimeError(
                    "Tensor cached during selective activation checkpoint has been mutated"
                )
        return self.val
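

# Minimal illustration of the check above (a hedged sketch, not executed by
# this module): wrapping a tensor records its current ``_version``; if the
# tensor is later mutated in place, ``get_val(False)`` raises instead of
# silently returning stale data.
#
#     t = ms.Tensor([1.0, 2.0, 3.0])
#     cached = _VersionWrapper(t)
#     t[0] = 9.0                # in-place update; assumed to bump t._version
#     cached.get_val(False)     # raises RuntimeError
#     cached.get_val(True)      # allow_cache_entry_mutation skips the check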


def _maybe_detach(x):
    # Detach floating-point/complex tensors so cached values do not keep the
    # autograd graph alive; other objects are passed through unchanged.
    if isinstance(x, ms.Tensor) and (x.is_floating_point() or x.is_complex()):
        x = ms.ops.stop_gradient(x)
    return x


class SelectiveCheckpointContext:
    def __init__(self, *, is_recompute):
        self.is_recompute = is_recompute
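

# A policy function receives a SelectiveCheckpointContext as its first
# argument, followed by the op and its arguments, and returns a
# CheckpointPolicy. Minimal sketch (illustrative only; the op name and the
# PREFER_RECOMPUTE member are assumptions, not taken from this module):
#
#     def example_policy(ctx, op, *args, **kwargs):
#         if op.name == "MatMul":
#             return CheckpointPolicy.MUST_SAVE      # cache the output
#         return CheckpointPolicy.PREFER_RECOMPUTE   # fall through to recompute
#
# Any policy value other than MUST_SAVE / PREFER_SAVE / MUST_SWAP is treated
# as "recompute" by the dispatch modes below.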


# Ops that are always dispatched straight through, without consulting the policy.
SAC_IGNORED_OPS = {"StopGradient"}


class _CachingMindSporeDispatchMode(MsDispatchMode):
    """Forward-pass mode: run each op, then cache or swap its output per the policy."""

    def __init__(self, policy_fn, storage):
        self.policy_fn = policy_fn
        self.storage = storage
        self.add_to_storage = False

    def __ms_dispatch__(self, func, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func.name in SAC_IGNORED_OPS:
            return func(*args, **kwargs)
        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
                                func, *args, **kwargs)

        out = func(*args, **kwargs)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE):
            # Keep a detached, version-tracked copy of the output for reuse in backward.
            storage = self.storage.save_storage[func.name]
            storage.append(platform.tree_map(lambda x: _VersionWrapper(_maybe_detach(x)), out))
        elif policy == CheckpointPolicy.MUST_SWAP:
            # Register this storage with the swap manager once, then record the
            # output as a SwapTensor keyed by "<group>::<op>" for later retrieval.
            group_name = SwapManager().get_current_group_name()
            if not self.add_to_storage:
                SwapManager().add_storage(group_name, self.storage)
                self.add_to_storage = True
            storage = self.storage.swap_storage[func.name]
            funcname = f"{group_name}::{func.name}"
            storage.append(platform.tree_map(lambda x: SwapTensor(_maybe_detach(x), funcname), out))
        return out


class _CachedMindSporeDispatchMode(MsDispatchMode):
    """Recompute-pass mode: replay cached or swapped outputs instead of re-running saved ops."""

    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
        self.policy_fn = policy_fn
        self.storage = storage
        self.allow_cache_entry_mutation = allow_cache_entry_mutation

    def __ms_dispatch__(self, func, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func.name in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
                                func, *args, **kwargs)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE):
            storage = self.storage.save_storage.get(func.name)  # patch code
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            out = platform.tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            storage = self.storage.swap_storage.get(func.name)
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            out = platform.tree_map(lambda x: x.get_val(), storage.pop(0))
        else:
            out = func(*args, **kwargs)
        return out


def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    """Build the paired dispatch modes (forward caching, backward replay) sharing one Storage."""
    if callable(policy_fn_or_list):
        policy_fn = policy_fn_or_list
    else:
        raise TypeError("policy_fn_or_list must be a callable policy function.")

    storage = Storage()
    return (
        _CachingMindSporeDispatchMode(policy_fn, storage),
        _CachedMindSporeDispatchMode(policy_fn, storage, allow_cache_entry_mutation)
    )
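

# Usage sketch (hedged; how the two modes are entered is determined by the
# checkpoint wrapper in hyper_parallel, which is not shown in this file).
# ``example_policy`` is the illustrative policy from the sketch above:
#
#     caching_mode, cached_mode = create_selective_checkpoint_contexts(example_policy)
#
# The caching mode is meant to be active while the checkpointed region runs
# its original forward pass, so selected outputs land in the shared storage;
# the cached mode is meant to be active during recomputation in backward,
# where it pops those entries back out instead of re-executing the saved ops.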