# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Swap tensor and swap manager implementation for activation checkpointing."""
# pylint: disable=W0212

import functools
import threading
import warnings
import weakref
from collections import defaultdict
from typing import Any, Dict, List, Optional

from hyper_parallel.platform import get_platform

platform = get_platform()
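
# Typical wiring, as a minimal sketch (``model`` with a ``layers`` list is
# hypothetical caller-side code, mirroring the usage example in
# SwapManager.set_forward_prefetch_layer):
#
#     manager = SwapManager()
#     for i in range(len(model.layers) - 1):
#         manager.set_forward_prefetch_layer(model.layers[i], model.layers[i + 1])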


class SwapTensor:
    """A tensor that can be swapped between device and host memory asynchronously."""
    STATE_DEVICE = "device"
    STATE_HOST = "host"
    STATE_D2H = "d2h"
    STATE_H2D = "h2d"
    STATE_NON_TENSOR = "non_tensor"
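
    # State machine (driven by the methods below):
    #     device --async_offload()--> d2h --wait_offload()--> host
    #     host   --async_load()-----> h2d --wait_load()-----> device
    # Non-tensor values, and tensors flagged _keep_on_device by
    # protect_if_aliases(), are skipped by every transition.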

    def __init__(self, val: Any, funcname: Any) -> None:
        self.val = val
        # _version only exists on tensors; non-tensor values are recorded with
        # ver=None and are never swapped.
        self.ver = getattr(val, "_version", None)
        self.funcname = funcname
        self._keep_on_device = False
        if isinstance(val, platform.Tensor) and str(val.device).lower() != 'cpu':
            self._state = self.STATE_DEVICE
            # A tensor whose storage size differs from numel * element_size is
            # a view/slice; it must be copied tensor-wise, not storage-wise.
            self.is_slice_tensor = val.untyped_storage().size() != val.numel() * platform.get_element_size(val)
            # Pinned host memory enables truly asynchronous D2H/H2D copies.
            self.val_cpu = platform.empty_like(
                val, device="cpu", pin_memory=True
            )
            self.storage_size = val.untyped_storage().size()
        else:
            self._state = self.STATE_NON_TENSOR
            self.val_cpu = None

    def protect_if_aliases(self, output_tensors: List[Any]) -> None:
        """Keep tensors that alias the wrapped module output on device."""
        if self._state == self.STATE_NON_TENSOR:
            return
        self_storage_ptr = self.val.untyped_storage().data_ptr()
        for out in output_tensors:
            if not isinstance(out, platform.Tensor):
                continue
            if str(out.device).lower() == "cpu":
                continue
            if out.untyped_storage().data_ptr() == self_storage_ptr:
                self._keep_on_device = True
                return

    def get_val(self) -> Any:
        if self._state == self.STATE_NON_TENSOR:
            return self.val
        if self._state != self.STATE_DEVICE:
            raise RuntimeError(
                f"Cannot call get_val(): tensor is in '{self._state}' state. "
                f"Must be in 'device' state."
            )
        return self.val

    def resize_device_storage(self):
        """Reallocate device memory (must run on the compute stream)."""
        # Only host-resident tensors need their device storage re-grown before
        # an H2D copy; this check also covers the non-tensor case.
        if self._state != self.STATE_HOST:
            return
        storage = self.val.untyped_storage()
        if storage.size() == self.storage_size:
            return
        storage.resize_(self.storage_size)

    def async_load(self):
        """Asynchronously load the tensor from host to device."""
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device:
            return

        if self._state != self.STATE_HOST:
            warnings.warn(
                f"[SwapTensor.async_load] Invalid state: current={self._state}, "
                f"expected 'host'. Operation skipped."
            )
            return

        if self.val_cpu is None:
            raise ValueError("val_cpu must not be None during async_load")
        if self.is_slice_tensor:
            self.val.data.copy_(self.val_cpu, non_blocking=True)
        else:
            self.val.untyped_storage().copy_(self.val_cpu.untyped_storage(), non_blocking=True)
        self._state = self.STATE_H2D

    def wait_load(self):
        """Change state to 'device' after the async load is done."""
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device:
            return

        if self._state == self.STATE_DEVICE:
            return  # already loaded
        if self._state != self.STATE_H2D:
            warnings.warn(
                f"[SwapTensor.wait_load] Called in invalid state: {self._state}. "
                f"Expected 'h2d'. Skipped."
            )
            return
        self._state = self.STATE_DEVICE

    def async_offload(self):
        """Asynchronously offload the tensor from device to host."""
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device:
            return

        if self._state != self.STATE_DEVICE:
            warnings.warn(
                f"[SwapTensor.async_offload] Invalid state: current={self._state}, "
                f"expected 'device'. Operation skipped."
            )
            return

        if self.storage_size != self.val.untyped_storage().size():
            raise RuntimeError(
                f"A tensor from {self.funcname} cannot be swapped: its storage has been resized. "
                f"Previous size: {self.storage_size}, current size: {self.val.untyped_storage().size()}"
            )
        if self.ver != self.val._version:
            raise RuntimeError(
                f"A tensor from {self.funcname} cannot be swapped: it was modified in place. "
                f"Previous version: {self.ver}, current version: {self.val._version}"
            )

        if self.is_slice_tensor:
            self.val_cpu.copy_(self.val, non_blocking=True)
        else:
            self.val_cpu.untyped_storage().copy_(self.val.untyped_storage(), non_blocking=True)
        self._state = self.STATE_D2H

    def wait_offload(self):
        """Wait for the offload to host and free the device memory."""
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device:
            return

        if self._state == self.STATE_HOST:
            return
        if self._state != self.STATE_D2H:
            warnings.warn(
                f"[SwapTensor.wait_offload] Called in invalid state: {self._state}. "
                f"Expected 'd2h'. Skipped."
            )
            return
        storage = self.val.untyped_storage()
        if storage.size() != 0:
            storage.resize_(0)
        self._state = self.STATE_HOST

    @property
    def state(self) -> str:
        return self._state

    def __repr__(self):
        if self._state == self.STATE_NON_TENSOR:
            return f"<SwapTensor state=non_tensor, val_type={type(self.val).__name__}>"
        return f"<SwapTensor state={self._state}, device_val={'exists' if self.val is not None else 'None'}>"
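

# Lifecycle of a single SwapTensor, as a minimal sketch (``t`` is a
# hypothetical device tensor; real use goes through Storage/SwapGroup, which
# add the stream and event synchronization around these calls):
#
#     st = SwapTensor(t, "layer0.forward")
#     st.async_offload()
#     st.wait_offload()           # device storage freed; data lives in st.val_cpu
#     st.resize_device_storage()  # re-grow device storage on the compute stream
#     st.async_load()
#     st.wait_load()
#     st.get_val() is t           # True: back in the 'device' state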


class Storage:
    """Manage a collection of tensors for swapping operations."""

    def __init__(self):
        self.save_storage: Dict[Any, List[Any]] = defaultdict(list)
        self.swap_storage: Dict[Any, List[Any]] = defaultdict(list)

    def protect_output_tensors(self, outputs: Any):
        """Avoid offloading tensors that alias the wrapped module outputs."""
        output_tensors = []

        def _collect_outputs(x):
            if isinstance(x, platform.Tensor):
                output_tensors.append(x)
            return x

        platform.tree_map(_collect_outputs, outputs)
        if not output_tensors:
            return

        def _protect_tensor(x):
            if isinstance(x, SwapTensor):
                x.protect_if_aliases(output_tensors)
            return x

        for storage_list in self.swap_storage.values():
            for item in storage_list:
                platform.tree_map(_protect_tensor, item)

    def launch_load(self):
        """Launch the async load for all tensors in swap storage."""
        def _async_load(x):
            if isinstance(x, SwapTensor):
                x.async_load()
            return x

        for storage_list in self.swap_storage.values():
            for item in storage_list:
                platform.tree_map(_async_load, item)

    def resize_device_storage(self):
        """Resize device storage for all swap tensors (runs on the compute stream)."""
        def _resize(x):
            if isinstance(x, SwapTensor):
                x.resize_device_storage()
            return x

        for storage_list in self.swap_storage.values():
            for item in storage_list:
                platform.tree_map(_resize, item)

    def wait_load(self):
        """Wait for the load of all tensors in swap storage."""
        def _wait_load(x):
            if isinstance(x, SwapTensor):
                x.wait_load()
            return x

        for storage_list in self.swap_storage.values():
            for item in storage_list:
                platform.tree_map(_wait_load, item)

    def wait_offload(self):
        """Wait for the offload of all tensors in swap storage."""
        def _wait_offload(x):
            if isinstance(x, SwapTensor):
                x.wait_offload()
            return x

        for storage_list in self.swap_storage.values():
            for item in storage_list:
                platform.tree_map(_wait_offload, item)

    def launch_offload(self):
        """Launch the async offload for all tensors in swap storage."""
        def _async_offload(x):
            if isinstance(x, SwapTensor):
                x.async_offload()
            return x

        for storage_list in self.swap_storage.values():
            for item in storage_list:
                platform.tree_map(_async_offload, item)
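

# Note: each item held in swap_storage may be a nested container; tree_map
# visits every leaf, so SwapTensor instances are handled wherever they sit
# and plain values pass through untouched.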


class SwapGroup:
    """Manager for a group of storages to coordinate swap operations."""

    def __init__(self, group_name: str):
        self.group_name = group_name
        self._storages = weakref.WeakSet()
        self._load_event = None
        self._offload_event = None

    def add(self, storage):
        """Add a storage to the swap group."""
        self._storages.add(storage)

    def protect_output_tensors(self, outputs: Any):
        """Protect current module outputs from premature offload."""
        for storage in self._storages:
            storage.protect_output_tensors(outputs)
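
    # Launch/wait pattern used below: launch_*() records an event on the
    # compute stream and makes the copy stream wait on it, so copies never
    # race ahead of the kernels producing or consuming the data; it then
    # records a second event on the copy stream, which wait_*() makes the
    # compute stream block on before the tensors are touched again.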

    def launch_offload(self, copy_stream):
        """Launch async offload for all storages in the group."""
        compute_event = platform.new_event()
        compute_event.record(platform.get_current_stream())
        self._offload_event = platform.new_event()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(copy_stream):
            compute_event.wait(copy_stream)
            for storage in self._storages:
                storage.launch_offload()
            self._offload_event.record(copy_stream)

    def wait_offload(self):
        """Wait for offload to complete for all storages in the group."""
        if self._offload_event is None:
            raise RuntimeError(
                f"SwapGroup '{self.group_name}' wait_offload() called before launch_offload()."
            )
        compute_stream = platform.get_current_stream()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(compute_stream):
            self._offload_event.wait(compute_stream)
            self._offload_event = None
            for storage in self._storages:
                storage.wait_offload()

    def launch_load(self, copy_stream):
        """Prepare storage and launch async load for all storages in the group."""
        with platform.no_grad():
            for storage in self._storages:
                storage.resize_device_storage()

        compute_event = platform.new_event()
        compute_event.record(platform.get_current_stream())
        self._load_event = platform.new_event()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(copy_stream):
            compute_event.wait(copy_stream)
            for storage in self._storages:
                storage.launch_load()  # Only copy, no resize.
            self._load_event.record(copy_stream)

    def wait_load(self):
        """Wait for load to complete for all storages in the group."""
        if self._load_event is None:
            raise RuntimeError(
                f"SwapGroup '{self.group_name}' wait_load() called before launch_load()."
            )
        compute_stream = platform.get_current_stream()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(compute_stream):
            self._load_event.wait(compute_stream)
            self._load_event = None
            for storage in self._storages:
                storage.wait_load()


class SwapManager:
    """Singleton manager for swap groups and their operations."""
    _instance: Optional["SwapManager"] = None
    _lock = threading.Lock()

    def __init__(self):
        if hasattr(self, '_groups'):
            return
        self._groups = {}
        self._current_group_name = ""
        self._counter_lock = threading.Lock()
        self._layer_count = 0
        self._copy_stream = None

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance
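
    # SwapManager() always returns the shared instance from __new__, so
    # __init__ runs on every call; the hasattr guard above keeps
    # re-initialization from clobbering existing groups.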

    def add_storage(self, group_name: str, storage: Storage) -> None:
        """Add a storage to a specified swap group."""
        if group_name not in self._groups:
            self._groups[group_name] = SwapGroup(group_name)
        self._groups[group_name].add(storage)

    def launch_offload(self, group_name: str, copy_stream=None):
        """Launch async offload for a specified swap group."""
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        if copy_stream is None:
            copy_stream = self._get_copy_stream()
        group.launch_offload(copy_stream)

    def protect_output_tensors(self, group_name: str, outputs: Any):
        """Keep tensors that alias the module output on device."""
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        group.protect_output_tensors(outputs)

    def wait_offload(self, group_name: str):
        """Wait for offload to complete for a specified swap group."""
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        group.wait_offload()

    def launch_load(self, group_name: str, copy_stream=None):
        """Launch async load for a specified swap group."""
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        if copy_stream is None:
            copy_stream = self._get_copy_stream()
        group.launch_load(copy_stream)

    def wait_load(self, group_name: str):
        """Wait for load to complete for a specified swap group."""
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        group.wait_load()

    def get_current_group_name(self):
        return self._current_group_name

    def set_current_group_name(self, group_name):
        self._current_group_name = group_name

    def set_forward_prefetch_layer(self, first_layer, second_layer):
        """
        Configure prefetching and offloading order between two consecutive layers.

        Usage:
            for i in range(len(model.layers) - 1):
                set_forward_prefetch_layer(model.layers[i], model.layers[i + 1])

        Ensures idempotency: safe to call multiple times on the same layer pair.
        """

        def _ensure_group_name(module):
            """Assign a unique swap group name to the module if not already assigned."""
            if not hasattr(module, "_swap_group_name"):
                name = f"swap_group_{self._layer_count}"
                self._layer_count += 1
                module._swap_group_name = name
                module._swap_group_order = {"prev": None, "next": None}
            return module._swap_group_name

        first_name = _ensure_group_name(first_layer)
        second_name = _ensure_group_name(second_layer)

        if first_name not in self._groups:
            self._groups[first_name] = SwapGroup(first_name)
        if second_name not in self._groups:
            self._groups[second_name] = SwapGroup(second_name)

        if first_layer._swap_group_order["next"] is None:
            first_layer._swap_group_order["next"] = second_name
        if second_layer._swap_group_order["prev"] is None:
            second_layer._swap_group_order["prev"] = first_name

        def _forward_pre_hook(group_name, module, _):  # pylint: disable=W0613
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            SwapManager().set_current_group_name(group_name)

        def _forward_hook(group_name, module, args, output):  # pylint: disable=W0613
            """
            Forward post-hook executed immediately after the forward computation
            of the current layer finishes.

            Execution timeline (example with 3 layers, forward order: L0 → L1 → L2):

                Time →
                Forward Compute Stream:
                    | Fwd L0 | post(L0) | Fwd L1 | post(L1) | Fwd L2 |

                Copy Stream (offload):
                             | Offload L0 |    -     | Offload L1 |
                             ↑                        ↑
                     offload at post(L0)      offload at post(L1)

            Swap rules:
            1. After the forward computation of the current layer completes:
               - If a next layer exists, asynchronously offload the activations
                 of the current layer (launch_offload).

               Example:
                   - At post-forward of L0, offload activations of L0.
                   - At post-forward of L1, offload activations of L1.

            2. To limit the device memory peak:
               - If a previous layer exists, wait until its offload operation
                 has completed (wait_offload).

            Notes:
            - Offload operations are issued on the copy stream to overlap data transfer
              with the forward computation of subsequent layers.
            - If the module is already in the 'pre_backward' state, this hook is skipped
              to avoid triggering offloads during the backward phase.
            """
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            next_name = module._swap_group_order.get('next', None)
            if next_name:
                SwapManager().protect_output_tensors(group_name, output)
                SwapManager().launch_offload(group_name)
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().wait_offload(prev_name)

        def _backward_pre_hook(group_name, module, grad_input):  # pylint: disable=W0613
            """
            Pre-backward hook executed immediately before the backward computation
            of the current layer starts.

            Execution timeline (example with 3 layers, backward order: L2 → L1 → L0):

                Time →
                Backward Compute Stream:
                    | pre(L2) | Grad L2 | pre(L1) | Grad L1 | pre(L0) | Grad L0 |

                Copy Stream (load):
                    | Load L1 |    -    | Load L0 |
                    ↑                   ↑
             prefetch at pre(L2)  prefetch at pre(L1)

            Swap rules:
            1. At the beginning of backward for the current layer:
               - If a previous layer exists in backward order, asynchronously
                 prefetch its activations (launch_load).

               Example:
                   - At pre-backward of L2, prefetch activations of L1.
                   - At pre-backward of L1, prefetch activations of L0.

            2. Before starting the backward computation of the current layer:
               - Ensure that the activations of the current layer have already
                 been loaded back to device memory (wait_load).

            Notes:
            - Load operations are issued on the copy stream to overlap data transfer
              with the backward computation of the current layer.
            - The swap state is marked as 'pre_backward' to prevent forward hooks
              from issuing offload operations during the backward phase.
            """
            module._swap_state = "pre_backward"
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().launch_load(prev_name)

            next_name = module._swap_group_order.get('next', None)
            if next_name:
                # The current layer was offloaded in forward only if it has a
                # next layer; its prefetch was launched at that layer's
                # pre-backward hook, so wait for it here.
                SwapManager().wait_load(group_name)

        def _backward_hook(group_name, module, grad_input, grad_output):  # pylint: disable=W0613
            module._swap_state = "backward"

        def _register_hooks_once(module, group_name):
            hooks = [
                ("_swap_forward_pre_hook_handle",
                 lambda h: platform.register_forward_pre_hook(module, h, prepend=True),
                 functools.partial(_forward_pre_hook, group_name)),

                ("_swap_forward_hook_handle",
                 module.register_forward_hook,
                 functools.partial(_forward_hook, group_name)),

                ("_swap_backward_pre_hook_handle",
                 lambda h: platform.register_full_backward_pre_hook(module, h, prepend=True),
                 functools.partial(_backward_pre_hook, group_name)),

                ("_swap_backward_hook_handle",
                 lambda h: platform.register_full_backward_hook(module, h),
                 functools.partial(_backward_hook, group_name)),
            ]

            for attr_name, register_func, hook in hooks:
                if not hasattr(module, attr_name):
                    handle = register_func(hook)
                    setattr(module, attr_name, handle)

        # Register for both layers.
        _register_hooks_once(first_layer, first_name)
        _register_hooks_once(second_layer, second_name)

    def _get_copy_stream(self):
        """Return a singleton copy stream, created on first access."""
        if self._copy_stream is None:
            self._copy_stream = platform.new_stream()
        return self._copy_stream
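

# Manual group control sketch (hypothetical driver code; the hooks registered
# by set_forward_prefetch_layer normally issue these calls automatically):
#
#     manager = SwapManager()
#     storage = Storage()
#     manager.add_storage("swap_group_0", storage)
#     manager.launch_offload("swap_group_0")   # after the layer's forward pass
#     manager.wait_offload("swap_group_0")
#     ...
#     manager.launch_load("swap_group_0")      # before the layer's backward pass
#     manager.wait_load("swap_group_0")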