Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / activation_checkpoint / swap.py: 17%
329 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Swap tensor and swap manager implementation for activation checkpointing"""
16# pylint: disable=W0212
18import functools
19import threading
20import warnings
21import weakref
22from collections import defaultdict
23from typing import Any, Dict, List, Optional
25from hyper_parallel.platform import get_platform
# Resolved once at import time: the active hardware backend abstraction.
# Everything below goes through it for Tensor, streams/events, tree_map, etc.
platform = get_platform()
class SwapTensor:
    """A tensor that can be swapped between device and host memory asynchronously.

    State machine:
        device --async_offload--> d2h --wait_offload--> host
        host   --async_load----->  h2d --wait_load---->  device

    Values that are not device tensors (plain Python objects, CPU tensors)
    are held in the terminal 'non_tensor' state and every swap operation
    becomes a no-op for them.
    """

    STATE_DEVICE = "device"          # resident on device, usable by compute
    STATE_HOST = "host"              # resident in pinned host memory, device storage freed
    STATE_D2H = "d2h"                # device-to-host copy in flight
    STATE_H2D = "h2d"                # host-to-device copy in flight
    STATE_NON_TENSOR = "non_tensor"  # not a swappable device tensor; all ops no-op

    def __init__(self, val: Any, funcname: Any) -> None:
        """Wrap *val* for swapping; *funcname* is only used to label errors.

        Fix: the autograd version snapshot is taken with ``getattr`` so that
        wrapping a value without a ``_version`` attribute (the non-tensor
        branch below) no longer raises AttributeError before that branch is
        even reached.
        """
        self.val = val
        # Snapshot of the autograd version counter; async_offload() compares
        # against it to detect in-place modification. None when absent.
        self.ver = getattr(val, "_version", None)
        self.funcname = funcname
        # Set by protect_if_aliases() when this tensor aliases a module
        # output and therefore must never be offloaded.
        self._keep_on_device = False
        if isinstance(val, platform.Tensor) and str(val.device).lower() != 'cpu':
            self._state = self.STATE_DEVICE
            # A view that does not cover its whole storage must be copied
            # tensor-wise rather than via a raw storage copy.
            self.is_slice_tensor = val.untyped_storage().size() != val.numel() * platform.get_element_size(val)
            # Pinned host buffer so D2H/H2D copies can run truly async.
            self.val_cpu = platform.empty_like(
                val, device="cpu", pin_memory=True
            )
            self.storage_size = val.untyped_storage().size()
        else:
            self._state = self.STATE_NON_TENSOR
            self.val_cpu = None

    def protect_if_aliases(self, output_tensors: List[Any]) -> None:
        """Keep tensors that alias the wrapped module output on device."""
        if self._state == self.STATE_NON_TENSOR:
            return
        self_storage_ptr = self.val.untyped_storage().data_ptr()
        for out in output_tensors:
            if not isinstance(out, platform.Tensor):
                continue
            if str(out.device).lower() == "cpu":
                continue
            # Same storage base address => the output aliases this tensor;
            # offloading (resize_(0)) would corrupt the module output.
            if out.untyped_storage().data_ptr() == self_storage_ptr:
                self._keep_on_device = True
                return

    def get_val(self) -> Any:
        """Return the wrapped value.

        Returns:
            The original value for non-tensors, or the device tensor.

        Raises:
            RuntimeError: if the tensor is offloaded or a copy is in flight.
        """
        if self._state == self.STATE_NON_TENSOR:
            return self.val
        if self._state != self.STATE_DEVICE:
            raise RuntimeError(
                f"Cannot call get_val(): tensor is in '{self._state}' state. "
                f"Must be in 'device' state."
            )
        return self.val

    def resize_device_storage(self):
        """Reallocate device memory on compute stream."""
        if self._state == self.STATE_NON_TENSOR:
            return
        # Only meaningful when the data lives on host and device storage
        # was shrunk to zero by wait_offload().
        if self._state != self.STATE_HOST:
            return
        storage = self.val.untyped_storage()
        if storage.size() == self.storage_size:
            return
        storage.resize_(self.storage_size)

    def async_load(self):
        """async load tensor from host to device"""
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device:
            return
        if self._state != self.STATE_HOST:
            warnings.warn(
                f"[SwapTensor.async_load] Invalid state: current={self._state}, "
                f"expected 'host'. Operation skipped."
            )
            return
        if self.val_cpu is None:
            raise ValueError("val_cpu must not be None during async_load")
        if self.is_slice_tensor:
            # Views: element-wise copy through .data to respect strides.
            self.val.data.copy_(self.val_cpu, non_blocking=True)
        else:
            self.val.untyped_storage().copy_(self.val_cpu.untyped_storage(), non_blocking=True)
        self._state = self.STATE_H2D

    def wait_load(self):
        """change state to device after async load is done"""
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device:
            return
        if self._state == self.STATE_DEVICE:
            return  # already loaded
        if self._state != self.STATE_H2D:
            warnings.warn(
                f"[SwapTensor.wait_load] Called in invalid state: {self._state}. "
                f"Expected 'h2d'. Skipped."
            )
            return
        self._state = self.STATE_DEVICE

    def async_offload(self):
        """async offload tensor from device to host"""
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device:
            return
        if self._state != self.STATE_DEVICE:
            warnings.warn(
                f"[SwapTensor.async_offload] Invalid state: current={self._state}, "
                f"expected 'device'. Operation skipped."
            )
            return
        # Refuse to offload a tensor whose storage was reallocated or that
        # was modified in place since capture: the host copy would be stale
        # or the later free would release memory someone else now owns.
        if self.storage_size != self.val.untyped_storage().size():
            raise RuntimeError(
                f"There is a tensor from {self.funcname} cannot be SWAPPED! Its storage has been resized "
                f"presize:{self.storage_size}, current size:{self.val.untyped_storage().size()}"
            )
        if self.ver != self.val._version:
            raise RuntimeError(
                f"There is a tensor from {self.funcname} cannot be SWAPPED! In-place modification happened "
                f"preversion:{self.ver}, current version:{self.val._version}"
            )
        if self.is_slice_tensor:
            self.val_cpu.copy_(self.val, non_blocking=True)
        else:
            self.val_cpu.untyped_storage().copy_(self.val.untyped_storage(), non_blocking=True)
        self._state = self.STATE_D2H

    def wait_offload(self):
        """wait offload to host and free device memory"""
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device:
            return
        if self._state == self.STATE_HOST:
            return
        if self._state != self.STATE_D2H:
            warnings.warn(
                f"[SwapTensor.wait_offload] Called in invalid state: {self._state}. "
                f"Expected 'd2h'. Skipped."
            )
            return
        # Free device memory; resize_device_storage() re-grows it later.
        storage = self.val.untyped_storage()
        if storage.size() != 0:
            storage.resize_(0)
        self._state = self.STATE_HOST

    @property
    def state(self) -> str:
        """Current swap state (one of the STATE_* constants)."""
        return self._state

    def __repr__(self):
        if self._state == self.STATE_NON_TENSOR:
            return f"<SwapTensor state=non_tensor, val_type={type(self.val).__name__}>"
        return f"<SwapTensor state={self._state}, device_val={'exists' if self.val is not None else 'None'}>"
class Storage:
    """Manage a collection of tensors for swapping operations.

    ``swap_storage`` holds (possibly nested) structures containing
    SwapTensor instances; every public operation fans out to each of them.
    """

    def __init__(self):
        # Original saved values, keyed by owner; filled by callers elsewhere.
        self.save_storage: Dict[Any, List[Any]] = defaultdict(list)
        # SwapTensor wrappers, keyed by owner; targets of offload/load.
        self.swap_storage: Dict[Any, List[Any]] = defaultdict(list)

    def _visit_swap_tensors(self, visit):
        """Apply *visit* to every SwapTensor nested inside swap_storage."""
        def _maybe_visit(node):
            if isinstance(node, SwapTensor):
                visit(node)
            return node

        for entries in self.swap_storage.values():
            for entry in entries:
                platform.tree_map(_maybe_visit, entry)

    def protect_output_tensors(self, outputs: Any):
        """Avoid offloading tensors that alias the wrapped module outputs."""
        collected = []

        def _gather(node):
            if isinstance(node, platform.Tensor):
                collected.append(node)
            return node

        platform.tree_map(_gather, outputs)
        if not collected:
            return
        self._visit_swap_tensors(lambda st: st.protect_if_aliases(collected))

    def launch_load(self):
        """launch async load for all tensors in swap storage"""
        self._visit_swap_tensors(lambda st: st.async_load())

    def resize_device_storage(self):
        """Resize device storage for all swap tensors (runs on compute stream)."""
        self._visit_swap_tensors(lambda st: st.resize_device_storage())

    def wait_load(self):
        """wait load for all tensors in swap storage"""
        self._visit_swap_tensors(lambda st: st.wait_load())

    def wait_offload(self):
        """wait offload for all tensors in swap storage"""
        self._visit_swap_tensors(lambda st: st.wait_offload())

    def launch_offload(self):
        """launch async offload for all tensors in swap storage"""
        self._visit_swap_tensors(lambda st: st.async_offload())
class SwapGroup:
    """Manager for a group of storages to coordinate swap operations.

    Owns the stream/event choreography for one group of Storage objects:
    copies are issued on a dedicated copy stream and fenced with events so
    transfers overlap compute without racing it.
    """

    def __init__(self, group_name: str):
        self.group_name = group_name
        # WeakSet: the group must not keep Storage objects (and the tensors
        # they hold) alive once their owner releases them.
        self._storages = weakref.WeakSet()
        # Events recorded on the copy stream by launch_*(); consumed and
        # cleared by the matching wait_*() call.
        self._load_event = None
        self._offload_event = None

    def add(self, storage):
        """Add a storage to the swap group."""
        self._storages.add(storage)

    def protect_output_tensors(self, outputs: Any):
        """Protect current module outputs from premature offload."""
        for storage in self._storages:
            storage.protect_output_tensors(outputs)

    def launch_offload(self, copy_stream):
        """Launch async offload for all storages in the group."""
        # Fence: the copy stream must not read tensors before the compute
        # stream has finished producing them.
        compute_event = platform.new_event()
        compute_event.record(platform.get_current_stream())
        self._offload_event = platform.new_event()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(copy_stream):
            compute_event.wait(copy_stream)
            for storage in self._storages:
                storage.launch_offload()
            # Recorded after every D2H copy was issued; wait_offload()
            # blocks the compute stream on this event.
            self._offload_event.record(copy_stream)

    def wait_offload(self):
        """Wait for offload to complete for all storages in the group."""
        if self._offload_event is None:
            raise RuntimeError(
                f"SwapGroup '{self.group_name}' wait_offload() called before launch_offload()."
            )
        compute_stream = platform.get_current_stream()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(compute_stream):
            # Make the compute stream wait for the D2H copies, then free the
            # device storage (resize_(0)) in each SwapTensor.
            self._offload_event.wait(compute_stream)
            self._offload_event = None
            for storage in self._storages:
                storage.wait_offload()

    def launch_load(self, copy_stream):
        """Prepare storage and launch async load for all storages in the group."""
        # Re-grow device storages on the compute stream first: allocation
        # must be visible to compute, only the copy goes on the copy stream.
        with platform.no_grad():
            for storage in self._storages:
                storage.resize_device_storage()
        compute_event = platform.new_event()
        compute_event.record(platform.get_current_stream())
        self._load_event = platform.new_event()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(copy_stream):
            compute_event.wait(copy_stream)
            for storage in self._storages:
                storage.launch_load()  # Only copy, no resize
            self._load_event.record(copy_stream)

    def wait_load(self):
        """Wait for load to complete for all storages in the group."""
        if self._load_event is None:
            raise RuntimeError(
                f"SwapGroup '{self.group_name}' wait_load() called before launch_load()."
            )
        compute_stream = platform.get_current_stream()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(compute_stream):
            # Block compute on the H2D copies, then flip tensors to 'device'.
            self._load_event.wait(compute_stream)
            self._load_event = None
            for storage in self._storages:
                storage.wait_load()
class SwapManager:
    """Singleton manager for swap groups and their operations.

    Owns the mapping from group name to SwapGroup, a lazily created copy
    stream shared by all groups, and the per-layer hook registration that
    drives activation offload (forward) and prefetch (backward).
    """
    _instance: Optional["SwapManager"] = None
    _lock = threading.Lock()

    def __init__(self):
        # __new__ always hands back the same object, so skip
        # re-initialisation when SwapManager() is called again.
        if hasattr(self, '_groups'):
            return
        self._groups = {}
        self._current_group_name = ""
        # Guards _layer_count (see _ensure_group_name below).
        self._counter_lock = threading.Lock()
        self._layer_count = 0
        self._copy_stream = None

    def __new__(cls):
        # Double-checked locking: lock-free fast path after first creation.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def _get_group(self, group_name: str):
        """Return the SwapGroup registered under *group_name*.

        Raises:
            RuntimeError: if the group was never created.
        """
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        return group

    def add_storage(self, group_name: str, storage: "Storage") -> None:
        """Add a storage to a specified swap group (created on first use)."""
        if group_name not in self._groups:
            self._groups[group_name] = SwapGroup(group_name)
        self._groups[group_name].add(storage)

    def launch_offload(self, group_name: str, copy_stream=None):
        """Launch async offload for a specified swap group."""
        group = self._get_group(group_name)
        if copy_stream is None:
            copy_stream = self._get_copy_stream()
        group.launch_offload(copy_stream)

    def protect_output_tensors(self, group_name: str, outputs: Any):
        """Keep tensors that alias the module output on device."""
        self._get_group(group_name).protect_output_tensors(outputs)

    def wait_offload(self, group_name: str):
        """Wait for offload to complete for a specified swap group."""
        self._get_group(group_name).wait_offload()

    def launch_load(self, group_name: str, copy_stream=None):
        """Launch async load for a specified swap group."""
        group = self._get_group(group_name)
        if copy_stream is None:
            copy_stream = self._get_copy_stream()
        group.launch_load(copy_stream)

    def wait_load(self, group_name: str):
        """Wait for load to complete for a specified swap group."""
        self._get_group(group_name).wait_load()

    def get_current_group_name(self):
        """Return the group name recorded by the latest forward pre-hook."""
        return self._current_group_name

    def set_current_group_name(self, group_name):
        """Record *group_name* as the group currently running forward."""
        self._current_group_name = group_name

    def set_forward_prefetch_layer(self, first_layer, second_layer):
        """
        Configure prefetching and offloading order between two consecutive layers.

        Usage:
            for i in range(len(model.layers) - 1):
                set_forward_prefetch_layer(model.layers[i], model.layers[i + 1])

        Ensures idempotency: safe to call multiple times on the same layer pair.
        """

        def _ensure_group_name(module):
            """Assign a unique swap group name to the module if not already assigned."""
            if not hasattr(module, "_swap_group_name"):
                # _layer_count is shared singleton state; take the counter
                # lock so concurrent registration cannot mint duplicate names
                # (the lock existed but was previously never used).
                with self._counter_lock:
                    name = f"swap_group_{self._layer_count}"
                    self._layer_count += 1
                module._swap_group_name = name
                module._swap_group_order = {"prev": None, "next": None}
            return module._swap_group_name

        first_name = _ensure_group_name(first_layer)
        second_name = _ensure_group_name(second_layer)

        if first_name not in self._groups:
            self._groups[first_name] = SwapGroup(first_name)
        if second_name not in self._groups:
            self._groups[second_name] = SwapGroup(second_name)

        # Link the pair only once so repeated calls stay idempotent.
        if first_layer._swap_group_order["next"] is None:
            first_layer._swap_group_order["next"] = second_name
        if second_layer._swap_group_order["prev"] is None:
            second_layer._swap_group_order["prev"] = first_name

        def _forward_pre_hook(group_name, module, _):  # pylint: disable=W0613
            # Skip during backward recomputation: this is not a fresh forward.
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            SwapManager().set_current_group_name(group_name)

        def _forward_hook(group_name, module, args, output):  # pylint: disable=W0613
            """
            Forward post-hook executed immediately after forward computation
            of the current layer finishes.

            Execution timeline (example with 3 layers, forward order: L0 -> L1 -> L2):

                Forward Compute Stream:
                    | Fwd L0 | post(L0) | Fwd L1 | post(L1) | Fwd L2 |
                Copy Stream (offload):
                    | Offload L0 |   -   | Offload L1 |

            Swap rules:
                1. After forward of the current layer completes, and a next
                   layer exists, asynchronously offload this layer's
                   activations (launch_offload).
                2. To bound the device memory peak, if a previous layer
                   exists, wait until its offload has completed
                   (wait_offload).

            Notes:
                - Offloads run on the copy stream so the transfer overlaps
                  forward computation of subsequent layers.
                - Skipped when the module is in 'pre_backward' state, to
                  avoid issuing offloads during the backward phase.
            """
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            next_name = module._swap_group_order.get('next', None)
            if next_name:
                SwapManager().protect_output_tensors(group_name, output)
                SwapManager().launch_offload(group_name)
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().wait_offload(prev_name)

        def _backward_pre_hook(group_name, module, grad_input):  # pylint: disable=W0613
            """
            Pre-backward hook executed immediately before backward computation
            of the current layer starts.

            Execution timeline (example with 3 layers, backward order: L2 -> L1 -> L0):

                Backward Compute Stream:
                    | pre(L2) | Grad L2 | pre(L1) | Grad L1 | pre(L0) | Grad L0 |
                Copy Stream (load):
                    | Load L1 |    -    | Load L0 |

            Swap rules:
                1. At the beginning of backward for the current layer, if a
                   previous layer exists in backward order, asynchronously
                   prefetch its activations (launch_load).
                2. Before backward of the current layer runs, ensure its own
                   activations are back on device (wait_load) — they were
                   only offloaded if a next layer exists.

            Notes:
                - Loads run on the copy stream to overlap the current layer's
                  backward computation.
                - The 'pre_backward' state stops the forward hooks from
                  issuing offloads during recomputation.
            """
            module._swap_state = "pre_backward"
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().launch_load(prev_name)
            next_name = module._swap_group_order.get('next', None)
            if next_name:
                SwapManager().wait_load(group_name)

        def _backward_hook(group_name, module, grad_input, grad_output):  # pylint: disable=W0613
            # Mark the layer as past its pre-backward phase.
            module._swap_state = "backward"

        def _register_hooks_once(module, group_name):
            """Register the four swap hooks on *module*, at most once each."""
            hooks = [
                ("_swap_forward_pre_hook_handle",
                 lambda h: platform.register_forward_pre_hook(module, h, prepend=True),
                 functools.partial(_forward_pre_hook, group_name)),

                ("_swap_forward_hook_handle",
                 module.register_forward_hook,
                 functools.partial(_forward_hook, group_name)),

                ("_swap_backward_pre_hook_handle",
                 lambda h: platform.register_full_backward_pre_hook(module, h, prepend=True),
                 functools.partial(_backward_pre_hook, group_name)),

                ("_swap_backward_hook_handle",
                 lambda h: platform.register_full_backward_hook(module, h),
                 functools.partial(_backward_hook, group_name)),
            ]
            for attr_name, register_func, hook in hooks:
                # Presence of the handle attribute marks the hook installed.
                if not hasattr(module, attr_name):
                    handle = register_func(hook)
                    setattr(module, attr_name, handle)

        # Register for both layers
        _register_hooks_once(first_layer, first_name)
        _register_hooks_once(second_layer, second_name)

    def _get_copy_stream(self):
        """Return a singleton copy stream, created on first access."""
        if self._copy_stream is None:
            self._copy_stream = platform.new_stream()
        return self._copy_stream