Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/activation_checkpoint/activation_swap.py: 50%

117 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15# Adapted from 

16# hyper_parallel/platform/torch/activation_checkpoint/activation_swap.py 

17# adapted for MindSpore Cell API. 

18# ============================================================================ 

19"""Activation Swap Wrapper implementation for MindSpore.""" 

20from abc import ABC, abstractmethod 

21from collections.abc import Iterator 

22from typing import Optional, Callable, Any, Union 

23 

24import mindspore as ms 

25from mindspore import Tensor 

26from mindspore.common.parameter import Parameter 

27from mindspore.nn import Cell 

28 

29from hyper_parallel.core.activation_checkpoint.activation_checkpoint import CheckpointPolicy 

30from hyper_parallel.core.activation_checkpoint.swap import Storage, SwapManager, SwapTensor 

31 

32 

33_CKPT_WRAPPED_MODULE = "_ckpt_wrapped_module" 

34 

35 

36def _strip_ckpt_wrapped_module_prefix(name: str) -> str: 

37 """Remove the wrapper cell segment from a dotted MindSpore cell name.""" 

38 return ".".join(part for part in name.split(".") if part != _CKPT_WRAPPED_MODULE) 

39 

40 

class FuncCell(Cell):
    """
    Adapter :class:`~mindspore.nn.Cell` around a plain Python callable.

    Lets ordinary functions (or any callable that is not already a
    :class:`~mindspore.nn.Cell`) be handed to :func:`checkpoint_wrapper`
    and :func:`swap_wrapper` where a cell is expected. The callable is
    kept on the instance as ``_fn`` and simply invoked from
    :meth:`construct`; no trainable parameters are introduced.

    Args:
        fn (callable): The function to wrap.

    Example:
        >>> wrapped = checkpoint_wrapper(lambda x: x * 2)
    """

    def __init__(self, fn: Callable):
        super().__init__()
        # Stored under the documented name ``_fn``; only called, never traced.
        self._fn = fn

    def construct(self, *args, **kwargs):
        """Delegate the forward pass straight to the wrapped callable."""
        return self._fn(*args, **kwargs)

64 

65 

class ActivationWrapper(Cell, ABC):
    """
    Base class for Activation Checkpoint Wrapper in MindSpore.

    Wraps a :class:`mindspore.nn.Cell` and forwards attribute lookups,
    parameter iteration, and indexing to the inner cell. Concrete
    sub-classes must implement :meth:`construct`.

    Not meant to be instantiated directly.
    """

    def __init__(self, module: Union[Cell, Callable]):
        # Plain callables are promoted to a FuncCell so everything downstream
        # can assume a Cell interface.
        if callable(module) and not isinstance(module, Cell):
            module = FuncCell(module)
        # auto_prefix=False: the wrapper must not inject its own scope name
        # into the wrapped parameters' fully-qualified names.
        super().__init__(auto_prefix=False)
        self._ckpt_wrapped_module = module
        # Snapshot of each wrapped parameter's full name at wrap time, keyed
        # by id(param); consulted by update_parameters_name() to avoid
        # collapsing already-unique names.
        self._wrapped_param_names = {
            id(param): param.name for _, param in module.parameters_and_names()
        }

    @abstractmethod
    def construct(self, *args, **kwargs):
        raise ValueError("Subclasses should implement construct().")

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to the wrapped cell.

        .. warning::
            Do **not** call ``super().__getattr__(name)`` here.
            MindSpore's ``Cell.__init__`` calls ``hasattr(self, "bprop")`` at
            line 252 of ``cell.py`` *after* ``_cells`` is initialised as an
            empty ``OrderedDict`` but *before* ``ActivationWrapper.__init__``
            has registered ``_ckpt_wrapped_module`` into ``_cells``. The
            PyTorch ``nn.Module.__init__`` is pure Python and never calls
            ``hasattr`` on ``self``, so this issue does not arise there.

            Using ``super().__getattr__`` here would raise ``AttributeError``
            (``_ckpt_wrapped_module`` not yet in ``_cells``), the fallback
            ``getattr(self._ckpt_wrapped_module, name)`` would access
            ``self._ckpt_wrapped_module`` — triggering another
            ``__getattr__("_ckpt_wrapped_module")`` — and the cycle repeats
            as infinite recursion.

            Instead we replicate ``Cell.__getattr__``'s own dict-probe logic
            and fall through to the wrapped module only when it is already
            registered.
        """
        # Probe the same storage dicts Cell.__getattr__ consults, via
        # __dict__ so no further __getattr__ recursion can be triggered.
        for attr_dict in ('_params', '_buffers', '_cells', '_params_list'):
            d = self.__dict__.get(attr_dict)
            if d is not None and name in d:
                return d[name]
        # Fall through to the wrapped cell only once it is registered;
        # during Cell.__init__ this is simply absent and we raise cleanly.
        cells = self.__dict__.get('_cells', {})
        wrapped = cells.get(_CKPT_WRAPPED_MODULE)
        if wrapped is not None:
            return getattr(wrapped, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    @property
    def unwrap_cell(self) -> Cell:
        """Return the directly wrapped cell.

        NOTE(review): this is a single-level unwrap — if the wrapped module
        is itself an ActivationWrapper, the inner wrapper is returned, not
        the innermost cell. Confirm against the torch counterpart whether
        recursive unwrapping was intended.
        """
        return self._ckpt_wrapped_module

    def __getitem__(self, key: int) -> Any:
        """Forward indexing calls in case the wrapped cell is a SequentialCell."""
        return self._ckpt_wrapped_module.__getitem__(key)  # type: ignore[operator]

    def cells_and_names(self, cells=None, name_prefix=''):
        """
        Return wrapped cells without exposing the wrapper storage prefix.

        MindSpore registers ``_ckpt_wrapped_module`` as a real child cell, so
        the default :meth:`Cell.cells_and_names` would expose names such as
        ``layer._ckpt_wrapped_module.attn``. Strip that implementation detail
        so downstream code sees the same names as it would for the unwrapped
        model.
        """
        for cell_name, cell in super().cells_and_names(cells, name_prefix):
            yield _strip_ckpt_wrapped_module_prefix(cell_name), cell

    def parameters_and_names(
        self,
        name_prefix: str = '',
        expand: bool = True,
    ) -> Iterator[tuple[str, Parameter]]:
        """
        Override :meth:`parameters_and_names` to strip the wrapper prefix.

        Removes all occurrences of ``_ckpt_wrapped_module.`` from parameter
        names so that a checkpoint saved from this wrapper is compatible with
        the unwrapped cell.

        Args:
            name_prefix (str): Prefix prepended to every parameter name.
            expand (bool): Whether to recursively expand sub-cells.

        Yields:
            tuple[str, Parameter]: ``(name, parameter)`` pairs with the
            wrapper prefix removed.
        """
        for param_name, param in super().parameters_and_names(name_prefix, expand):
            yield _strip_ckpt_wrapped_module_prefix(param_name), param

    def update_parameters_name(self, prefix: str = '', recurse: bool = True):
        """
        Update wrapped parameter names without collapsing existing full paths.

        When a wrapper replaces an already-registered child cell, the wrapped
        parameters usually already have globally unique names such as
        ``0.attn.qkv.weight``. MindSpore will still call
        ``wrapper.update_parameters_name("attn.")`` during reassignment; if we
        blindly apply that prefix again through the wrapper view, those names
        are rewritten to ``attn.qkv.weight`` and collide across layers.

        For parameters that already contain the requested child prefix in their
        existing full name, keep the current name unchanged. For fresh
        standalone modules that only have local names like ``qkv.weight``,
        synthesize the prefixed name as usual.
        """
        if prefix is None:
            prefix = ''
        for local_name, param in self._ckpt_wrapped_module.parameters_and_names(expand=recurse):
            # Prefer the name recorded at wrap time; fall back to the live
            # name for parameters added after construction.
            original_name = self._wrapped_param_names.get(id(param), param.name)
            # Prefix already present in the full name (either at the start or
            # as an interior ".prefix" segment): keep the unique name as-is.
            if prefix and (original_name.startswith(prefix) or f".{prefix}" in original_name):
                new_name = original_name
            elif prefix:
                new_name = prefix + local_name
            else:
                new_name = local_name
            if new_name != param.name:
                # NOTE(review): resetting is_init on rename appears to mirror
                # Cell.update_parameters_name's renaming behavior — confirm.
                param.is_init = False
                param.name = new_name
            # Keep the snapshot current so repeated calls stay idempotent.
            self._wrapped_param_names[id(param)] = new_name

198 

199 

200def base_check_fn(tensor: Any) -> bool: 

201 """ 

202 Basic eligibility check: returns ``True`` when *tensor* may be offloaded. 

203 

204 Skips: 

205 

206 * Non-tensor objects. 

207 * :class:`~mindspore.common.parameter.Parameter` objects. 

208 * Empty tensors (zero elements). 

209 

210 Args: 

211 tensor: The value to test. 

212 

213 Returns: 

214 bool: ``True`` if the tensor is eligible for CPU offloading. 

215 """ 

216 if not isinstance(tensor, Tensor): 

217 return False 

218 if tensor.param_info is not None: 

219 return False 

220 if tensor.untyped_storage().size() == 0: 

221 return False 

222 return True 

223 

224 

225def _normalize_device(device: str) -> str: 

226 if ":" in device: 

227 return device.split(":", maxsplit=1)[0] 

228 return device 

229 

230 

class AsyncSaveOnCpu(ms.saved_tensors_hooks):
    """
    Context manager to offload tensors to CPU during forward pass.

    Installs a pack/unpack hook pair via ``mindspore.saved_tensors_hooks``:
    ``pack_to_cpu`` replaces each eligible saved activation with an integer
    index into ``self.storage``; ``unpack_from_cpu`` resolves that index
    back to the tensor during the backward pass.

    Args:
        policy_fn (callable, optional): Per-tensor policy; tensors for which
            it returns ``CheckpointPolicy.MUST_SAVE`` are kept on device.
    """
    def __init__(self, policy_fn=None) -> None:
        # Whether self.storage has been registered with the SwapManager yet
        # (done lazily on the first packed tensor).
        self.add_to_storage = False
        self.storage = Storage()
        # Monotonically increasing index handed out per packed tensor.
        self.count_idx = 0
        self.pack_count = 0
        self.unpack_count = 0
        self.policy_fn = policy_fn

        def pack_to_cpu(tensor: ms.Tensor):
            # skip ineligible tensors
            if not base_check_fn(tensor):
                return tensor

            # Policy override: MUST_SAVE tensors stay on device untouched.
            if (policy_fn is not None) and (policy_fn(tensor)==CheckpointPolicy.MUST_SAVE):
                return tensor

            group_name = SwapManager().get_current_group_name()
            # Register our storage with the manager exactly once, on first use.
            if not self.add_to_storage:
                SwapManager().add_storage(group_name, self.storage)
                self.add_to_storage = True
            # Human-readable tag for the swapped tensor (group + shape).
            funcname = f"{group_name}::{tensor.shape}"
            # NOTE(review): assumes swap_storage maps index -> list (e.g. a
            # defaultdict(list)) — confirm against Storage's definition.
            self.storage.swap_storage[self.count_idx].append(SwapTensor(tensor, funcname))
            idx = self.count_idx
            self.count_idx += 1
            self.pack_count += 1
            # The integer index stands in for the tensor in the saved graph.
            return idx

        def unpack_from_cpu(idx) -> ms.Tensor:
            # Tensors returned as-is by pack (ineligible / MUST_SAVE) come
            # back unchanged; only integer indices need resolving.
            if isinstance(idx, ms.Tensor):
                return idx

            # FIFO pop in case the same index slot ever holds multiple entries.
            swap_tensor = self.storage.swap_storage[idx].pop(0)
            tensor = swap_tensor.get_val()
            self.unpack_count += 1
            # All packed tensors consumed: drop the storage reference so the
            # host-side copies can be reclaimed.
            if self.unpack_count == self.pack_count:
                self.storage = None
            return tensor

        super().__init__(pack_to_cpu, unpack_from_cpu)

274 

275 

class SwapWrapper(ActivationWrapper):
    """
    MindSpore counterpart of :class:`~hyper_parallel.platform.torch
    .activation_checkpoint.activation_swap.SwapWrapper`.

    Wraps a :class:`~mindspore.nn.Cell` so that, on every forward call, the
    inner cell runs under :class:`AsyncSaveOnCpu`, which offloads eligible
    saved activations to host memory and restores them for backward.

    Args:
        mod (Cell): The cell whose intermediate activations should be swapped.
        policy_fn (callable, optional): Per-tensor swap policy; see
            :class:`AsyncSaveOnCpu`.

    Example:
        >>> from hyper_parallel.platform.mindspore.activation_checkpoint import swap_wrapper
        >>> model.layers[i].attn = swap_wrapper(model.layers[i].attn, policy_fn)
    """

    def __init__(self, mod: Union[Cell, Callable], policy_fn: Optional[Callable] = None):
        super().__init__(mod)
        self.policy_fn = policy_fn

    def construct(self, *args, **kwargs):
        """Run the wrapped cell with activation-swap hooks installed."""
        swap_ctx = AsyncSaveOnCpu(policy_fn=self.policy_fn)
        with swap_ctx:
            return self._ckpt_wrapped_module(*args, **kwargs)

303 

304 

def swap_wrapper(module: Union[Cell, Callable], policy_fn: Optional[Callable] = None) -> SwapWrapper:
    """
    Wrap *module* with async activation swap.

    Args:
        module (Cell or callable): The cell or plain function to wrap.
            A plain callable is automatically promoted to a
            :class:`FuncCell` inside the wrapper.
        policy_fn (callable, optional): Per-tensor swap policy; see
            :class:`AsyncSaveOnCpu`.

    Returns:
        SwapWrapper: The wrapped cell with activation swap enabled.
    """
    return SwapWrapper(module, policy_fn=policy_fn)