Coverage for hyper_parallel / core / activation_checkpoint / swap.py: 65%

252 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Swap tensor and swap manager implementation for activation checkpointing""" 

16# pylint: disable=W0212 

17 

18import functools 

19import threading 

20import warnings 

21import weakref 

22from collections import defaultdict 

23from typing import Any, Dict, List, Optional 

24 

25from hyper_parallel.platform import get_platform 

26 

27platform = get_platform() 

28 

29 

class SwapTensor:
    """Wraps a value so device-resident tensors can migrate to/from pinned host memory.

    Non-tensor values (and CPU tensors) are passed through untouched. Device
    tensors move through a small state machine:
    device -> d2h -> host -> h2d -> device.
    """

    # Lifecycle states of the wrapped value.
    STATE_DEVICE = "device"
    STATE_HOST = "host"
    STATE_D2H = "d2h"
    STATE_H2D = "h2d"
    STATE_NON_TENSOR = "non_tensor"

    def __init__(self, val: Any) -> None:
        self.val = val
        on_device = (
            isinstance(val, platform.Tensor) and str(val.device).lower() != 'cpu'
        )
        if not on_device:
            # Plain Python objects and CPU tensors never swap.
            self._state = self.STATE_NON_TENSOR
            self.val_cpu = None
            return
        self._state = self.STATE_DEVICE
        # A view over a larger storage cannot be copied storage-to-storage;
        # such "slice" tensors are copied element-wise instead.
        self.is_slice_tensor = val.storage().size() != val.numel()
        # Pinned host buffer so D2H/H2D transfers can run asynchronously.
        self.val_cpu = platform.empty_like(
            val, device="cpu", pin_memory=True
        )
        # Remembered so the freed device storage can be re-grown on load.
        self.storage_size = val.storage().size()

    def get_val(self) -> Any:
        """Return the wrapped value; valid only while resident on device."""
        if self._state == self.STATE_NON_TENSOR:
            return self.val
        if self._state == self.STATE_DEVICE:
            return self.val
        raise RuntimeError(
            f"Cannot call get_val(): tensor is in '{self._state}' state. "
            f"Must be in 'device' state."
        )

    def async_load(self):
        """Start an asynchronous host-to-device copy (state: host -> h2d)."""
        if self._state == self.STATE_NON_TENSOR:
            return
        if self._state != self.STATE_HOST:
            warnings.warn(
                f"[SwapTensor.async_load] Invalid state: current={self._state}, "
                f"expected 'host'. Operation skipped."
            )
            return
        assert self.val_cpu is not None
        # Re-grow the device storage that wait_offload() shrank to zero.
        self.val.storage().resize_(self.storage_size)
        if self.is_slice_tensor:
            dst, src = self.val, self.val_cpu
        else:
            dst, src = self.val.storage(), self.val_cpu.storage()
        dst.copy_(src, non_blocking=True)
        self._state = self.STATE_H2D

    def wait_load(self):
        """Finalize a pending H2D copy (state: h2d -> device)."""
        if self._state in (self.STATE_NON_TENSOR, self.STATE_DEVICE):
            # Nothing to do: not a swappable tensor, or already resident.
            return
        if self._state != self.STATE_H2D:
            warnings.warn(
                f"[SwapTensor.wait_load] Called in invalid state: {self._state}. "
                f"Expected 'h2d'. Skipped."
            )
            return
        self._state = self.STATE_DEVICE

    def async_offload(self):
        """Start an asynchronous device-to-host copy (state: device -> d2h)."""
        if self._state == self.STATE_NON_TENSOR:
            return
        if self._state != self.STATE_DEVICE:
            warnings.warn(
                f"[SwapTensor.async_offload] Invalid state: current={self._state}, "
                f"expected 'device'. Operation skipped."
            )
            return
        if self.is_slice_tensor:
            dst, src = self.val_cpu, self.val
        else:
            dst, src = self.val_cpu.storage(), self.val.storage()
        dst.copy_(src, non_blocking=True)
        self._state = self.STATE_D2H

    def wait_offload(self):
        """Finalize a pending D2H copy and free device memory (state: d2h -> host)."""
        if self._state in (self.STATE_NON_TENSOR, self.STATE_HOST):
            # Nothing to do: not a swappable tensor, or already offloaded.
            return
        if self._state != self.STATE_D2H:
            warnings.warn(
                f"[SwapTensor.wait_offload] Called in invalid state: {self._state}. "
                f"Expected 'd2h'. Skipped."
            )
            return
        # Release device memory; async_load() restores it via resize_.
        self.val.storage().resize_(0)
        self._state = self.STATE_HOST

    @property
    def state(self) -> str:
        """Current lifecycle state string."""
        return self._state

    def __repr__(self):
        if self._state == self.STATE_NON_TENSOR:
            return f"<SwapTensor state=non_tensor, val_type={type(self.val).__name__}>"
        return f"<SwapTensor state={self._state}, device_val={'exists' if self.val is not None else 'None'}>"

138 

139 

class Storage:
    """Manage a collection of tensors for swapping operations.

    Attributes:
        save_storage: saved (non-swapped) values, keyed per group.
        swap_storage: values whose SwapTensor leaves are swapped, keyed per group.
    """

    def __init__(self):
        self.save_storage: Dict[Any, List[Any]] = defaultdict(list)
        self.swap_storage: Dict[Any, List[Any]] = defaultdict(list)

    def _for_each_swap_tensor(self, method_name: str) -> None:
        """Invoke SwapTensor.<method_name>() on every SwapTensor leaf in swap_storage.

        Centralizes the tree-walk that all four public operations share.
        Non-SwapTensor leaves are returned unchanged.
        """
        def _visit(x):
            if isinstance(x, SwapTensor):
                getattr(x, method_name)()
            return x

        for storage_list in self.swap_storage.values():
            for item in storage_list:
                platform.tree_map(_visit, item)

    def launch_load(self):
        """launch async load for all tensors in swap storage"""
        self._for_each_swap_tensor("async_load")

    def wait_load(self):
        """wait load for all tensors in swap storage"""
        self._for_each_swap_tensor("wait_load")

    def wait_offload(self):
        """wait offload for all tensors in swap storage"""
        self._for_each_swap_tensor("wait_offload")

    def launch_offload(self):
        """launch async offload for all tensors in swap storage"""
        self._for_each_swap_tensor("async_offload")

190 

191 

class SwapGroup:
    """Manager for a group of storages to coordinate swap operations.

    Ordering between the compute stream and the copy stream is expressed
    with events: each launch_* records an event on the copy stream, and the
    matching wait_* makes the compute stream block on that event.
    """

    def __init__(self, group_name: str):
        self.group_name = group_name
        # WeakSet: registering a storage must not extend its lifetime.
        self._storages = weakref.WeakSet()
        self._load_event = None      # pending H2D event set by launch_load
        self._offload_event = None   # pending D2H event set by launch_offload

    def add(self, storage):
        """Add a storage to the swap group."""
        self._storages.add(storage)

    def launch_offload(self, copy_stream):
        """Launch async offload for all storages in the group.

        Args:
            copy_stream: stream on which the device-to-host copies are issued.
        """
        compute_event = platform.new_event()
        compute_event.record(platform.get_current_stream())
        self._offload_event = platform.new_event()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(copy_stream):
            # The copy stream must wait for the compute kernels that
            # produced the activations before reading them.
            compute_event.wait(copy_stream)
            for storage in self._storages:
                storage.launch_offload()
            self._offload_event.record(copy_stream)

    def wait_offload(self):
        """Wait for offload to complete for all storages in the group.

        Raises:
            RuntimeError: if no launch_offload() is pending (never launched,
                or already waited on).
        """
        if self._offload_event is None:
            # Previously this dereferenced None and raised an opaque
            # AttributeError; fail loudly with the actual contract violation.
            raise RuntimeError(
                f"SwapGroup '{self.group_name}': wait_offload() called "
                f"without a pending launch_offload()."
            )
        compute_stream = platform.get_current_stream()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(compute_stream):
            self._offload_event.wait(compute_stream)
            self._offload_event = None
            for storage in self._storages:
                storage.wait_offload()

    def launch_load(self, copy_stream):
        """Launch async load for all storages in the group.

        Args:
            copy_stream: stream on which the host-to-device copies are issued.
        """
        compute_event = platform.new_event()
        compute_event.record(platform.get_current_stream())
        self._load_event = platform.new_event()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(copy_stream):
            # Keep the copy stream ordered after prior compute work.
            compute_event.wait(copy_stream)
            for storage in self._storages:
                storage.launch_load()
            self._load_event.record(copy_stream)

    def wait_load(self):
        """Wait for load to complete for all storages in the group.

        Raises:
            RuntimeError: if no launch_load() is pending (never launched,
                or already waited on).
        """
        if self._load_event is None:
            raise RuntimeError(
                f"SwapGroup '{self.group_name}': wait_load() called "
                f"without a pending launch_load()."
            )
        compute_stream = platform.get_current_stream()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(compute_stream):
            self._load_event.wait(compute_stream)
            self._load_event = None
            for storage in self._storages:
                storage.wait_load()

248 

249 

class SwapManager:
    """Singleton manager for swap groups and their operations."""
    _instance: Optional["SwapManager"] = None
    _lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking: lock-free fast path once the singleton
        # exists, lock held only around first construction.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._groups = {}
                    cls._instance._current_group_name = ""
                    # Protects _layer_count when layer pairs are registered
                    # concurrently (see set_forward_prefetch_layer).
                    cls._instance._counter_lock = threading.Lock()
                    cls._instance._layer_count = 0
                    # Lazily created by _get_copy_stream(); stored on the
                    # instance like every other singleton field.
                    cls._instance._copy_stream = None
        return cls._instance

    def add_storage(self, group_name: str, storage: "Storage") -> None:
        """Add a storage to a specified swap group, creating the group on demand."""
        if group_name not in self._groups:
            self._groups[group_name] = SwapGroup(group_name)
        self._groups[group_name].add(storage)

    def launch_offload(self, group_name: str, copy_stream=None):
        """Launch async offload for a specified swap group.

        Raises:
            RuntimeError: if the group does not exist.
        """
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        if copy_stream is None:
            copy_stream = self._get_copy_stream()
        group.launch_offload(copy_stream)

    def wait_offload(self, group_name: str):
        """Wait for offload to complete for a specified swap group.

        Raises:
            RuntimeError: if the group does not exist.
        """
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        group.wait_offload()

    def launch_load(self, group_name: str, copy_stream=None):
        """Launch async load for a specified swap group.

        Raises:
            RuntimeError: if the group does not exist.
        """
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        if copy_stream is None:
            copy_stream = self._get_copy_stream()
        group.launch_load(copy_stream)

    def wait_load(self, group_name: str):
        """Wait for load to complete for a specified swap group.

        Raises:
            RuntimeError: if the group does not exist.
        """
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        group.wait_load()

    def get_current_group_name(self):
        """Return the group name most recently set by a forward pre-hook."""
        return self._current_group_name

    def set_current_group_name(self, group_name):
        """Record the group name of the layer currently running forward."""
        self._current_group_name = group_name

    def set_forward_prefetch_layer(self, first_layer, second_layer):
        """
        Configure prefetching and offloading order between two consecutive layers.

        Usage:
            for i in range(len(model.layers) - 1):
                set_forward_prefetch_layer(model.layers[i], model.layers[i + 1])

        Ensures idempotency: safe to call multiple times on the same layer pair.
        """

        def _ensure_group_name(module):
            """Assign a unique swap group name to the module if not already assigned."""
            if not hasattr(module, "_swap_group_name"):
                # Guard the shared counter: the unprotected read-modify-write
                # could hand out duplicate group names under concurrency.
                with self._counter_lock:
                    name = f"swap_group_{self._layer_count}"
                    self._layer_count += 1
                module._swap_group_name = name
                module._swap_group_order = {"prev": None, "next": None}
            return module._swap_group_name

        first_name = _ensure_group_name(first_layer)
        second_name = _ensure_group_name(second_layer)

        # Only the first registration of a neighbor wins (idempotency).
        if first_layer._swap_group_order["next"] is None:
            first_layer._swap_group_order["next"] = second_name
        if second_layer._swap_group_order["prev"] is None:
            second_layer._swap_group_order["prev"] = first_name

        def _forward_pre_hook(group_name, module, _):  # pylint: disable=W0613
            # Skip once backward has begun so recomputation does not re-arm swapping.
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            SwapManager().set_current_group_name(group_name)

        def _forward_hook(group_name, module, args, output):  # pylint: disable=W0613
            """
            Forward post-hook executed immediately after forward computation
            of the current layer finishes.

            Execution timeline (example with 3 layers, forward order: L0 -> L1 -> L2):

                Time ->
                Forward Compute Stream:
                    | Fwd L0 | post(L0) | Fwd L1 | post(L1) | Fwd L2 |

                Copy Stream (offload):
                    | Offload L0 |     -      | Offload L1 |
                          ^                        ^
                offload at post(L0)       offload at post(L1)

            Swap rules:
                1. After forward computation of the current layer completes:
                   - If a next layer exists, asynchronously offload the activations
                     of the current layer (launch_offload).

                2. To limit device memory peak:
                   - If a previous layer exists, wait until its offload operation
                     has completed (wait_offload).

            Notes:
                - Offload operations are issued on the copy stream to overlap data
                  transfer with forward computation of subsequent layers.
                - If the module is already in 'pre_backward' state, this hook is
                  skipped to avoid triggering offload during backward phase.
            """
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            next_name = module._swap_group_order.get('next', None)
            if next_name:
                # Last layer has no next: its activations stay on device.
                SwapManager().launch_offload(group_name)
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().wait_offload(prev_name)

        def _backward_pre_hook(group_name, module, grad_input):  # pylint: disable=W0613
            """
            Pre-backward hook executed immediately before backward computation
            of the current layer starts.

            Execution timeline (example with 3 layers, backward order: L2 -> L1 -> L0):

                Time ->
                Backward Compute Stream:
                    | pre(L2) | Grad L2 | pre(L1) | Grad L1 | pre(L0) | Grad L0 |

                Copy Stream (load):
                    | Load L1 |    -    | Load L0 |
                         ^                  ^
                prefetch at pre(L2)  prefetch at pre(L1)

            Swap rules:
                1. At the beginning of backward for the current layer:
                   - If a previous layer exists in backward order, asynchronously
                     prefetch its activations (launch_load).

                2. Before starting backward computation of the current layer:
                   - Ensure that the activations of the current layer have already
                     been loaded back to device memory (wait_load).

            Notes:
                - Load operations are issued on the copy stream to overlap data
                  transfer with backward computation of the current layer.
                - The swap state is marked as 'pre_backward' to prevent forward
                  hooks from issuing offload operations during backward phase.
            """
            module._swap_state = "pre_backward"
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().launch_load(prev_name)

            next_name = module._swap_group_order.get('next', None)
            if next_name:
                # This layer was offloaded only if it has a next layer (see
                # _forward_hook); its load was launched at the next layer's
                # pre-backward hook, so wait on this layer's own group.
                SwapManager().wait_load(group_name)

        def _backward_hook(group_name, module, grad_input, grad_output):  # pylint: disable=W0613
            # Mark backward as started; forward hooks stay disabled for recompute.
            module._swap_state = "backward"

        def _register_hooks_once(module, group_name):
            """Register each hook exactly once per module (attribute-guarded)."""
            hooks = [
                ("_swap_forward_pre_hook_handle",
                 lambda h: platform.register_forward_pre_hook(module, h, prepend=True),
                 functools.partial(_forward_pre_hook, group_name)),

                ("_swap_forward_hook_handle",
                 module.register_forward_hook,
                 functools.partial(_forward_hook, group_name)),

                ("_swap_backward_pre_hook_handle",
                 lambda h: platform.register_full_backward_pre_hook(module, h, prepend=True),
                 functools.partial(_backward_pre_hook, group_name)),

                ("_swap_backward_hook_handle",
                 lambda h: platform.register_full_backward_hook(module, h),
                 functools.partial(_backward_hook, group_name)),
            ]

            for attr_name, register_func, hook in hooks:
                if not hasattr(module, attr_name):
                    handle = register_func(hook)
                    setattr(module, attr_name, handle)

        # Register for both layers
        _register_hooks_once(first_layer, first_name)
        _register_hooks_once(second_layer, second_name)

    def _get_copy_stream(self):
        """Return a singleton copy stream, created on first access."""
        if self._copy_stream is None:
            self._copy_stream = platform.new_stream()
        return self._copy_stream