Coverage for hyper_parallel / core / activation_checkpoint / swap.py: 65%
252 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Swap tensor and swap manager implementation for activation checkpointing"""
16# pylint: disable=W0212
18import functools
19import threading
20import warnings
21import weakref
22from collections import defaultdict
23from typing import Any, Dict, List, Optional
25from hyper_parallel.platform import get_platform
# Resolve the platform backend once at import time; every swap operation below
# (tensors, pinned memory, streams, events, hooks) goes through this object.
platform = get_platform()
class SwapTensor:
    """A tensor that can be swapped between device and host memory asynchronously.

    State machine:

        device --async_offload()--> d2h --wait_offload()--> host
        host   --async_load()-----> h2d --wait_load()-----> device

    Values that are not device tensors (non-tensors, or tensors already on CPU)
    are wrapped in the 'non_tensor' state; every swap operation is a no-op for
    them and ``get_val()`` returns them unchanged.
    """

    STATE_DEVICE = "device"
    STATE_HOST = "host"
    STATE_D2H = "d2h"
    STATE_H2D = "h2d"
    STATE_NON_TENSOR = "non_tensor"

    def __init__(self, val: Any) -> None:
        """Wrap ``val``; device tensors get a pinned host-side shadow buffer.

        Args:
            val: Any value. Only non-CPU platform tensors participate in
                swapping; everything else is passed through untouched.
        """
        self.val = val
        if isinstance(val, platform.Tensor) and str(val.device).lower() != 'cpu':
            self._state = self.STATE_DEVICE
            # A view/slice does not cover its whole storage; such tensors must
            # be copied element-wise (copy_) instead of storage-to-storage.
            self.is_slice_tensor = val.storage().size() != val.numel()
            # Pinned host memory is required for genuinely async copies.
            self.val_cpu = platform.empty_like(
                val, device="cpu", pin_memory=True
            )
            # Remember the storage size so async_load() can restore it after
            # wait_offload() shrank the device storage to zero.
            self.storage_size = val.storage().size()
        else:
            self._state = self.STATE_NON_TENSOR
            self.val_cpu = None

    def get_val(self) -> Any:
        """Return the wrapped value.

        Raises:
            RuntimeError: if a swappable tensor is not resident on device.
        """
        if self._state == self.STATE_NON_TENSOR:
            return self.val
        if self._state != self.STATE_DEVICE:
            raise RuntimeError(
                f"Cannot call get_val(): tensor is in '{self._state}' state. "
                f"Must be in 'device' state."
            )
        return self.val

    def async_load(self):
        """async load tensor from host to device"""
        if self._state == self.STATE_NON_TENSOR:
            return

        if self._state != self.STATE_HOST:
            warnings.warn(
                f"[SwapTensor.async_load] Invalid state: current={self._state}, "
                f"expected 'host'. Operation skipped."
            )
            return

        # Explicit raise instead of `assert`: asserts are stripped under -O,
        # and this invariant must hold for the copy below to be valid.
        if self.val_cpu is None:
            raise RuntimeError(
                "[SwapTensor.async_load] Host buffer is missing for a device tensor."
            )
        # Re-grow the device storage that wait_offload() released.
        self.val.storage().resize_(self.storage_size)
        if self.is_slice_tensor:
            self.val.copy_(self.val_cpu, non_blocking=True)
        else:
            self.val.storage().copy_(self.val_cpu.storage(), non_blocking=True)
        self._state = self.STATE_H2D

    def wait_load(self):
        """change state to device after async load is done"""
        if self._state == self.STATE_NON_TENSOR:
            return

        if self._state == self.STATE_DEVICE:
            return  # already loaded
        if self._state != self.STATE_H2D:
            warnings.warn(
                f"[SwapTensor.wait_load] Called in invalid state: {self._state}. "
                f"Expected 'h2d'. Skipped."
            )
            return
        self._state = self.STATE_DEVICE

    def async_offload(self):
        """async offload tensor from device to host"""
        if self._state == self.STATE_NON_TENSOR:
            return

        if self._state != self.STATE_DEVICE:
            warnings.warn(
                f"[SwapTensor.async_offload] Invalid state: current={self._state}, "
                f"expected 'device'. Operation skipped."
            )
            return

        if self.is_slice_tensor:
            self.val_cpu.copy_(self.val, non_blocking=True)
        else:
            self.val_cpu.storage().copy_(self.val.storage(), non_blocking=True)
        self._state = self.STATE_D2H

    def wait_offload(self):
        """wait offload to host and free device memory"""
        if self._state == self.STATE_NON_TENSOR:
            return

        if self._state == self.STATE_HOST:
            return
        if self._state != self.STATE_D2H:
            warnings.warn(
                f"[SwapTensor.wait_offload] Called in invalid state: {self._state}. "
                f"Expected 'd2h'. Skipped."
            )
            return
        # Release device memory; the data now lives only in the pinned
        # host buffer until async_load() restores it.
        self.val.storage().resize_(0)
        self._state = self.STATE_HOST

    @property
    def state(self) -> str:
        """Current swap state (one of the STATE_* constants)."""
        return self._state

    def __repr__(self):
        if self._state == self.STATE_NON_TENSOR:
            return f"<SwapTensor state=non_tensor, val_type={type(self.val).__name__}>"
        return f"<SwapTensor state={self._state}, device_val={'exists' if self.val is not None else 'None'}>"
class Storage:
    """Manage a collection of tensors for swapping operations.

    Attributes:
        save_storage: General per-key saved values (not traversed by the
            swap operations below).
        swap_storage: Per-key lists of (possibly nested) values whose
            SwapTensor leaves are driven through the swap state machine.
    """

    def __init__(self):
        self.save_storage: Dict[Any, List[Any]] = defaultdict(list)
        self.swap_storage: Dict[Any, List[Any]] = defaultdict(list)

    def _apply_to_swap_tensors(self, method_name: str) -> None:
        """Invoke ``SwapTensor.<method_name>()`` on every SwapTensor leaf.

        Walks each stored item with ``platform.tree_map`` so SwapTensors
        nested inside containers are reached; non-SwapTensor leaves are
        returned untouched.
        """
        def _visit(x):
            if isinstance(x, SwapTensor):
                getattr(x, method_name)()
            return x

        for storage_list in self.swap_storage.values():
            for item in storage_list:
                platform.tree_map(_visit, item)

    def launch_load(self):
        """launch async load for all tensors in swap storage"""
        self._apply_to_swap_tensors("async_load")

    def wait_load(self):
        """wait load for all tensors in swap storage"""
        self._apply_to_swap_tensors("wait_load")

    def wait_offload(self):
        """wait offload for all tensors in swap storage"""
        self._apply_to_swap_tensors("wait_offload")

    def launch_offload(self):
        """launch async offload for all tensors in swap storage"""
        self._apply_to_swap_tensors("async_offload")
class SwapGroup:
    """Manager for a group of storages to coordinate swap operations.

    Launch operations are issued on a caller-supplied copy stream so that
    data transfer overlaps compute; wait operations make the compute stream
    wait on the recorded copy-completion event before finalizing states.
    """

    def __init__(self, group_name: str):
        self.group_name = group_name
        # WeakSet: dropping a Storage elsewhere removes it from the group
        # without requiring explicit unregistration.
        self._storages = weakref.WeakSet()
        self._load_event = None
        self._offload_event = None

    def add(self, storage):
        """Add a storage to the swap group."""
        self._storages.add(storage)

    def _launch(self, op_name: str, copy_stream):
        """Run ``Storage.<op_name>()`` for every storage on ``copy_stream``.

        Records an event on the current compute stream first so the copy
        stream only starts after previously issued compute work; returns the
        event marking completion of the copies on ``copy_stream``.
        """
        compute_event = platform.new_event()
        compute_event.record(platform.get_current_stream())
        done_event = platform.new_event()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(copy_stream):
            compute_event.wait(copy_stream)
            for storage in self._storages:
                getattr(storage, op_name)()
            done_event.record(copy_stream)
        return done_event

    def launch_offload(self, copy_stream):
        """Launch async offload for all storages in the group."""
        self._offload_event = self._launch("launch_offload", copy_stream)

    def wait_offload(self):
        """Wait for offload to complete for all storages in the group.

        Raises:
            RuntimeError: if no offload was launched first (previously this
                surfaced as an opaque AttributeError on None).
        """
        if self._offload_event is None:
            raise RuntimeError(
                "wait_offload() called without a matching launch_offload()."
            )
        compute_stream = platform.get_current_stream()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(compute_stream):
            self._offload_event.wait(compute_stream)
            self._offload_event = None
            for storage in self._storages:
                storage.wait_offload()

    def launch_load(self, copy_stream):
        """Launch async load for all storages in the group."""
        self._load_event = self._launch("launch_load", copy_stream)

    def wait_load(self):
        """Wait for load to complete for all storages in the group.

        Raises:
            RuntimeError: if no load was launched first.
        """
        if self._load_event is None:
            raise RuntimeError(
                "wait_load() called without a matching launch_load()."
            )
        compute_stream = platform.get_current_stream()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(compute_stream):
            self._load_event.wait(compute_stream)
            self._load_event = None
            for storage in self._storages:
                storage.wait_load()
class SwapManager:
    """Singleton manager for swap groups and their operations."""
    # Shared singleton instance, guarded by _lock (double-checked locking).
    _instance: Optional["SwapManager"] = None
    _lock = threading.Lock()

    def __new__(cls):
        # Fast path avoids lock contention once the singleton exists; the
        # second check inside the lock closes the creation race.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._groups = {}
                    cls._instance._current_group_name = ""
                    # NOTE(review): _counter_lock is never acquired anywhere in
                    # this file; _layer_count increments in
                    # set_forward_prefetch_layer are unguarded — confirm whether
                    # concurrent layer registration is a supported use case.
                    cls._instance._counter_lock = threading.Lock()
                    cls._instance._layer_count = 0
                    # NOTE(review): assigned on the *class*, unlike the instance
                    # attributes above; _get_copy_stream() later shadows it with
                    # an instance attribute. Harmless for a singleton, but
                    # inconsistent.
                    cls._copy_stream = None
        return cls._instance

    def add_storage(self, group_name: str, storage: Storage) -> None:
        """Add a storage to a specified swap group."""
        # Groups are created lazily on first use.
        if group_name not in self._groups:
            self._groups[group_name] = SwapGroup(group_name)
        self._groups[group_name].add(storage)

    def launch_offload(self, group_name: str, copy_stream=None):
        """Launch async offload for a specified swap group.

        Raises:
            RuntimeError: if the group was never registered.
        """
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        if copy_stream is None:
            copy_stream = self._get_copy_stream()
        group.launch_offload(copy_stream)

    def wait_offload(self, group_name: str):
        """Wait for offload to complete for a specified swap group."""
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        group.wait_offload()

    def launch_load(self, group_name: str, copy_stream=None):
        """Launch async load for a specified swap group."""
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        if copy_stream is None:
            copy_stream = self._get_copy_stream()
        group.launch_load(copy_stream)

    def wait_load(self, group_name: str):
        """Wait for load to complete for a specified swap group."""
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        group.wait_load()

    def get_current_group_name(self):
        """Return the group name set by the most recent forward pre-hook."""
        return self._current_group_name

    def set_current_group_name(self, group_name):
        """Record which swap group's forward pass is currently executing."""
        self._current_group_name = group_name

    def set_forward_prefetch_layer(self, first_layer, second_layer):
        """
        Configure prefetching and offloading order between two consecutive layers.

        Usage:
            for i in range(len(model.layers) - 1):
                set_forward_prefetch_layer(model.layers[i], model.layers[i + 1])

        Ensures idempotency: safe to call multiple times on the same layer pair.
        """

        def _ensure_group_name(module):
            """Assign a unique swap group name to the module if not already assigned."""
            if not hasattr(module, "_swap_group_name"):
                name = f"swap_group_{self._layer_count}"
                # NOTE(review): unguarded increment — see _counter_lock note
                # in __new__.
                self._layer_count += 1
                module._swap_group_name = name
                module._swap_group_order = {"prev": None, "next": None}
            return module._swap_group_name

        first_name = _ensure_group_name(first_layer)
        second_name = _ensure_group_name(second_layer)

        # Only the first linkage wins, keeping repeated calls idempotent.
        if first_layer._swap_group_order["next"] is None:
            first_layer._swap_group_order["next"] = second_name
        if second_layer._swap_group_order["prev"] is None:
            second_layer._swap_group_order["prev"] = first_name

        def _forward_pre_hook(group_name, module, _):  # pylint: disable=W0613
            # Skip during backward-triggered recomputation so the backward
            # phase does not disturb the "current group" bookkeeping.
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            SwapManager().set_current_group_name(group_name)

        def _forward_hook(group_name, module, args, output):  # pylint: disable=W0613
            """
            Forward post-hook executed immediately after forward computation
            of the current layer finishes.

            Execution timeline (example with 3 layers, forward order: L0 -> L1 -> L2):

            Time ->
            Forward Compute Stream:
            | Fwd L0 | post(L0) | Fwd L1 | post(L1) | Fwd L2 |

            Copy Stream (offload):
                     | Offload L0 |   -   | Offload L1 |
                     ^                    ^
                offload at post(L0)  offload at post(L1)

            Swap rules:
            1. After forward computation of the current layer completes:
               - If a next layer exists, asynchronously offload the activations
                 of the current layer (launch_offload).

               Example:
               - At post-forward of L0, offload activations of L0.
               - At post-forward of L1, offload activations of L1.

            2. To limit device memory peak:
               - If a previous layer exists, wait until its offload operation
                 has completed (wait_offload).

            Notes:
            - Offload operations are issued on the copy stream to overlap data transfer
              with forward computation of subsequent layers.
            - If the module is already in 'pre_backward' state, this hook is skipped
              to avoid triggering offload during backward phase.
            """
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            # Offload this layer only when a later layer exists (the last
            # layer's activations are needed immediately by backward).
            next_name = module._swap_group_order.get('next', None)
            if next_name:
                SwapManager().launch_offload(group_name)
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().wait_offload(prev_name)

        def _backward_pre_hook(group_name, module, grad_input):  # pylint: disable=W0613
            """
            Pre-backward hook executed immediately before backward computation
            of the current layer starts.

            Execution timeline (example with 3 layers, backward order: L2 -> L1 -> L0):

            Time ->
            Backward Compute Stream:
            | pre(L2) | Grad L2 | pre(L1) | Grad L1 | pre(L0) | Grad L0 |

            Copy Stream (load):
                      | Load L1 |    -    | Load L0 |
                      ^                   ^
                prefetch at pre(L2)  prefetch at pre(L1)

            Swap rules:
            1. At the beginning of backward for the current layer:
               - If a previous layer exists in backward order, asynchronously
                 prefetch its activations (launch_load).

               Example:
               - At pre-backward of L2, prefetch activations of L1.
               - At pre-backward of L1, prefetch activations of L0.

            2. Before starting backward computation of the current layer:
               - Ensure that the activations of the current layer have already
                 been loaded back to device memory (wait_load).

            Notes:
            - Load operations are issued on the copy stream to overlap data transfer
              with backward computation of the current layer.
            - The swap state is marked as 'pre_backward' to prevent forward hooks
              from issuing offload operations during backward phase.
            """
            module._swap_state = "pre_backward"
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().launch_load(prev_name)

            next_name = module._swap_group_order.get('next', None)
            if next_name:
                # The load of the *current* group was launched at the
                # pre-backward hook of the next layer (which runs earlier in
                # backward order), so a wait is needed exactly when a next
                # layer exists — hence wait_load(group_name), not next_name.
                SwapManager().wait_load(group_name)

        def _backward_hook(group_name, module, grad_input, grad_output):  # pylint: disable=W0613
            # Leaving 'pre_backward' re-enables the forward hooks for the
            # next iteration.
            module._swap_state = "backward"

        def _register_hooks_once(module, group_name):
            # Idempotent: each registration stores its handle under a sentinel
            # attribute on the module; a repeated call finds the attribute and
            # skips re-registration.
            hooks = [
                ("_swap_forward_pre_hook_handle",
                 lambda h: platform.register_forward_pre_hook(module, h, prepend=True),
                 functools.partial(_forward_pre_hook, group_name)),

                ("_swap_forward_hook_handle",
                 module.register_forward_hook,
                 functools.partial(_forward_hook, group_name)),

                ("_swap_backward_pre_hook_handle",
                 lambda h: platform.register_full_backward_pre_hook(module, h, prepend=True),
                 functools.partial(_backward_pre_hook, group_name)),

                ("_swap_backward_hook_handle",
                 lambda h: platform.register_full_backward_hook(module, h),
                 functools.partial(_backward_hook, group_name)),
            ]

            for attr_name, register_func, hook in hooks:
                if not hasattr(module, attr_name):
                    handle = register_func(hook)
                    setattr(module, attr_name, handle)
        # Register for both layers
        _register_hooks_once(first_layer, first_name)
        _register_hooks_once(second_layer, second_name)

    def _get_copy_stream(self):
        """Return a singleton copy stream, created on first access."""
        if self._copy_stream is None:
            self._copy_stream = platform.new_stream()
        return self._copy_stream