1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# Adapted from
16# hyper_parallel/platform/torch/activation_checkpoint/activation_swap.py
17# adapted for MindSpore Cell API.
18# ============================================================================
19"""Activation Swap Wrapper implementation for MindSpore."""
20from abc import ABC, abstractmethod
21from collections.abc import Iterator
22from typing import Optional, Callable, Any, Union
23
24import mindspore as ms
25from mindspore import Tensor
26from mindspore.common.parameter import Parameter
27from mindspore.nn import Cell
28
29from hyper_parallel.core.activation_checkpoint.activation_checkpoint import CheckpointPolicy
30from hyper_parallel.core.activation_checkpoint.swap import Storage, SwapManager, SwapTensor
31
32
33_CKPT_WRAPPED_MODULE = "_ckpt_wrapped_module"
34
35
36def _strip_ckpt_wrapped_module_prefix(name: str) -> str:
37 """Remove the wrapper cell segment from a dotted MindSpore cell name."""
38 return ".".join(part for part in name.split(".") if part != _CKPT_WRAPPED_MODULE)
39
40
41class FuncCell(Cell):
42 """
43 Thin :class:`~mindspore.nn.Cell` adapter that wraps a plain callable.
44
45 Allows ordinary Python functions (or any callable without Cell
46 parameters) to be passed to :func:`checkpoint_wrapper` and
47 :func:`swap_wrapper` in place of a :class:`~mindspore.nn.Cell`.
48 The wrapped function is stored as ``_fn`` and invoked in
49 :meth:`construct`; the cell has no trainable parameters.
50
51 Args:
52 fn (callable): The function to wrap.
53
54 Example:
55 >>> wrapped = checkpoint_wrapper(lambda x: x * 2)
56 """
57
58 def __init__(self, fn: Callable):
59 super().__init__()
60 self._fn = fn
61
62 def construct(self, *args, **kwargs):
63 return self._fn(*args, **kwargs)
64
65
66class ActivationWrapper(Cell, ABC):
67 """
68 Base class for Activation Checkpoint Wrapper in MindSpore.
69
70 Wraps a :class:`mindspore.nn.Cell` and forwards attribute lookups,
71 parameter iteration, and indexing to the inner cell. Concrete
72 sub-classes must implement :meth:`construct`.
73
74 Not meant to be instantiated directly.
75 """
76
77 def __init__(self, module: Union[Cell, Callable]):
78 if callable(module) and not isinstance(module, Cell):
79 module = FuncCell(module)
80 super().__init__(auto_prefix=False)
81 self._ckpt_wrapped_module = module
82 self._wrapped_param_names = {
83 id(param): param.name for _, param in module.parameters_and_names()
84 }
85
86 @abstractmethod
87 def construct(self, *args, **kwargs):
88 raise ValueError("Subclasses should implement construct().")
89
90 def __getattr__(self, name: str) -> Any:
91 """Forward missing attributes to the wrapped cell.
92
93 .. warning::
94 Do **not** call ``super().__getattr__(name)`` here.
95 MindSpore's ``Cell.__init__`` calls ``hasattr(self, "bprop")`` at
96 line 252 of ``cell.py`` *after* ``_cells`` is initialised as an
97 empty ``OrderedDict`` but *before* ``ActivationWrapper.__init__``
98 has registered ``_ckpt_wrapped_module`` into ``_cells``. The
99 PyTorch ``nn.Module.__init__`` is pure Python and never calls
100 ``hasattr`` on ``self``, so this issue does not arise there.
101
102 Using ``super().__getattr__`` here would raise ``AttributeError``
103 (``_ckpt_wrapped_module`` not yet in ``_cells``), the fallback
104 ``getattr(self._ckpt_wrapped_module, name)`` would access
105 ``self._ckpt_wrapped_module`` — triggering another
106 ``__getattr__("_ckpt_wrapped_module")`` — and the cycle repeats
107 as infinite recursion.
108
109 Instead we replicate ``Cell.__getattr__``'s own dict-probe logic
110 and fall through to the wrapped module only when it is already
111 registered.
112 """
113 for attr_dict in ('_params', '_buffers', '_cells', '_params_list'):
114 d = self.__dict__.get(attr_dict)
115 if d is not None and name in d:
116 return d[name]
117 cells = self.__dict__.get('_cells', {})
118 wrapped = cells.get(_CKPT_WRAPPED_MODULE)
119 if wrapped is not None:
120 return getattr(wrapped, name)
121 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
122
123 @property
124 def unwrap_cell(self) -> Cell:
125 """Recursively return the innermost wrapped cell."""
126 return self._ckpt_wrapped_module
127
128 def __getitem__(self, key: int) -> Any:
129 """Forward indexing calls in case the wrapped cell is a SequentialCell."""
130 return self._ckpt_wrapped_module.__getitem__(key) # type: ignore[operator]
131
132 def cells_and_names(self, cells=None, name_prefix=''):
133 """
134 Return wrapped cells without exposing the wrapper storage prefix.
135
136 MindSpore registers ``_ckpt_wrapped_module`` as a real child cell, so
137 the default :meth:`Cell.cells_and_names` would expose names such as
138 ``layer._ckpt_wrapped_module.attn``. Strip that implementation detail
139 so downstream code sees the same names as it would for the unwrapped
140 model.
141 """
142 for cell_name, cell in super().cells_and_names(cells, name_prefix):
143 yield _strip_ckpt_wrapped_module_prefix(cell_name), cell
144
145 def parameters_and_names(
146 self,
147 name_prefix: str = '',
148 expand: bool = True,
149 ) -> Iterator[tuple[str, Parameter]]:
150 """
151 Override :meth:`parameters_and_names` to strip the wrapper prefix.
152
153 Removes all occurrences of ``_ckpt_wrapped_module.`` from parameter
154 names so that a checkpoint saved from this wrapper is compatible with
155 the unwrapped cell.
156
157 Args:
158 name_prefix (str): Prefix prepended to every parameter name.
159 expand (bool): Whether to recursively expand sub-cells.
160
161 Yields:
162 tuple[str, Parameter]: ``(name, parameter)`` pairs with the
163 wrapper prefix removed.
164 """
165 for param_name, param in super().parameters_and_names(name_prefix, expand):
166 yield _strip_ckpt_wrapped_module_prefix(param_name), param
167
168 def update_parameters_name(self, prefix='', recurse=True):
169 """
170 Update wrapped parameter names without collapsing existing full paths.
171
172 When a wrapper replaces an already-registered child cell, the wrapped
173 parameters usually already have globally unique names such as
174 ``0.attn.qkv.weight``. MindSpore will still call
175 ``wrapper.update_parameters_name("attn.")`` during reassignment; if we
176 blindly apply that prefix again through the wrapper view, those names
177 are rewritten to ``attn.qkv.weight`` and collide across layers.
178
179 For parameters that already contain the requested child prefix in their
180 existing full name, keep the current name unchanged. For fresh
181 standalone modules that only have local names like ``qkv.weight``,
182 synthesize the prefixed name as usual.
183 """
184 if prefix is None:
185 prefix = ''
186 for local_name, param in self._ckpt_wrapped_module.parameters_and_names(expand=recurse):
187 original_name = self._wrapped_param_names.get(id(param), param.name)
188 if prefix and (original_name.startswith(prefix) or f".{prefix}" in original_name):
189 new_name = original_name
190 elif prefix:
191 new_name = prefix + local_name
192 else:
193 new_name = local_name
194 if new_name != param.name:
195 param.is_init = False
196 param.name = new_name
197 self._wrapped_param_names[id(param)] = new_name
198
199
200def base_check_fn(tensor: Any) -> bool:
201 """
202 Basic eligibility check: returns ``True`` when *tensor* may be offloaded.
203
204 Skips:
205
206 * Non-tensor objects.
207 * :class:`~mindspore.common.parameter.Parameter` objects.
208 * Empty tensors (zero elements).
209
210 Args:
211 tensor: The value to test.
212
213 Returns:
214 bool: ``True`` if the tensor is eligible for CPU offloading.
215 """
216 if not isinstance(tensor, Tensor):
217 return False
218 if tensor.param_info is not None:
219 return False
220 if tensor.untyped_storage().size() == 0:
221 return False
222 return True
223
224
225def _normalize_device(device: str) -> str:
226 if ":" in device:
227 return device.split(":", maxsplit=1)[0]
228 return device
229
230
231class AsyncSaveOnCpu(ms.saved_tensors_hooks):
232 """
233 Context manager to offload tensors to CPU during forward pass.
234 """
235 def __init__(self, policy_fn=None) -> None:
236 self.add_to_storage = False
237 self.storage = Storage()
238 self.count_idx = 0
239 self.pack_count = 0
240 self.unpack_count = 0
241 self.policy_fn = policy_fn
242
243 def pack_to_cpu(tensor: ms.Tensor):
244 # skip ineligible tensors
245 if not base_check_fn(tensor):
246 return tensor
247
248 if (policy_fn is not None) and (policy_fn(tensor)==CheckpointPolicy.MUST_SAVE):
249 return tensor
250
251 group_name = SwapManager().get_current_group_name()
252 if not self.add_to_storage:
253 SwapManager().add_storage(group_name, self.storage)
254 self.add_to_storage = True
255 funcname = f"{group_name}::{tensor.shape}"
256 self.storage.swap_storage[self.count_idx].append(SwapTensor(tensor, funcname))
257 idx = self.count_idx
258 self.count_idx += 1
259 self.pack_count += 1
260 return idx
261
262 def unpack_from_cpu(idx) -> ms.Tensor:
263 if isinstance(idx, ms.Tensor):
264 return idx
265
266 swap_tensor = self.storage.swap_storage[idx].pop(0)
267 tensor = swap_tensor.get_val()
268 self.unpack_count += 1
269 if self.unpack_count == self.pack_count:
270 self.storage = None
271 return tensor
272
273 super().__init__(pack_to_cpu, unpack_from_cpu)
274
275
276class SwapWrapper(ActivationWrapper):
277 """
278 MindSpore counterpart of :class:`~hyper_parallel.platform.torch
279 .activation_checkpoint.activation_swap.SwapWrapper`.
280
281 Wraps a :class:`~mindspore.nn.Cell` and applies async activation swap
282 during the forward pass via the platform's ``async_save_on_cpu`` context
283 manager. Falls back to a no-op context when that context is not yet
284 available on the current platform.
285
286 Args:
287 mod (Cell): The cell whose intermediate activations should be swapped.
288 policy_fn (callable, optional): Per-tensor swap policy; see
289 :class:`AsyncSaveOnCpu`.
290
291 Example:
292 >>> from hyper_parallel.platform.mindspore.activation_checkpoint import swap_wrapper
293 >>> model.layers[i].attn = swap_wrapper(model.layers[i].attn, policy_fn)
294 """
295
296 def __init__(self, mod: Union[Cell, Callable], policy_fn: Optional[Callable] = None):
297 super().__init__(mod)
298 self.policy_fn = policy_fn
299
300 def construct(self, *args, **kwargs):
301 with AsyncSaveOnCpu(policy_fn=self.policy_fn):
302 return self._ckpt_wrapped_module(*args, **kwargs)
303
304
305def swap_wrapper(module: Union[Cell, Callable], policy_fn: Optional[Callable] = None) -> SwapWrapper:
306 """
307 Wrap *module* with async activation swap.
308
309 Args:
310 module (Cell or callable): The cell or plain function to wrap.
311 If a plain callable is passed it is automatically wrapped in a
312 :class:`FuncCell` before being stored.
313 policy_fn (callable, optional): Per-tensor swap policy; see
314 :class:`AsyncSaveOnCpu`.
315
316 Returns:
317 SwapWrapper: The wrapped cell with activation swap enabled.
318 """
319 return SwapWrapper(module, policy_fn)