Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/activation_checkpoint/swap.py 13.5% 44,52,122-124,126,212,215,218,221,232,269,280,290,301,304,313,324,333-336,341,351,355,366,381,387,396,406-413,424-427,438-439,494,638
hyper_parallel/platform/mindspore/activation_checkpoint/activation_swap.py 0.0% 254,256,258-260,337
hyper_parallel/platform/mindspore/activation_checkpoint/sac.py 28.0% 52-53,72,86,92,95-96,99,108,111,121-123,126,136,144,156-157
hyper_parallel/platform/mindspore/platform.py 50.0% 1458-1459
hyper_parallel/platform/platform.py 66.7% 1098
hyper_parallel/platform/torch/activation_checkpoint/activation_swap.py 0.0% 98,100,102-104,236
hyper_parallel/platform/torch/activation_checkpoint/sac.py 0.0% 20-21,29,62,65-67,132,134,159,165,168,172-173,179,181,184,198-200,203,213,221,306-307
hyper_parallel/platform/torch/platform.py 66.7% 1206
hyper_parallel/core/activation_checkpoint/swap.py
40
41
42
43
44
45
46
47
48
        self.funcname = funcname
        self._keep_on_device = False
        self._duplicate_swap = False
        if isinstance(val, platform.Tensor) and str(val.device).lower() != 'cpu':
            self.ver = val._version
            self._state = self.STATE_DEVICE
            self.is_slice_tensor = val.untyped_storage().size() != val.numel() * platform.get_element_size(val)
            self.val_cpu = platform.empty_like(
                val, device="cpu", pin_memory=True
48
49
50
51
52
53
54
55
56
                val, device="cpu", pin_memory=True
            )
            self.storage_size = val.untyped_storage().size()
        else:
            self.ver = None
            self._state = self.STATE_NON_TENSOR
            self.val_cpu = None

    def dedup_key(self):
118
119
120
121
122
123
124
125
126
127
128
129
130
            return

        if self.val_cpu is None:
            raise ValueError("val_cpu must not be None during async_load")
        with platform.preserve_version_counter(self.val):
            if self.is_slice_tensor:
                self.val.data.copy_(self.val_cpu, non_blocking=True)
            else:
                self.val.untyped_storage().copy_(self.val_cpu.untyped_storage(), non_blocking=True)
        self._state = self.STATE_H2D

    def wait_load(self):
        """change state to device after async load is done"""
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
    ``for batch in storage.values(): ...``.
    """

    def __init__(self):
        self._data: Dict[Any, List[Any]] = defaultdict(list)

    def __getitem__(self, key: Any) -> List[Any]:
        return self._data[key]

    def values(self):
        return self._data.values()

    def clear(self):
        self._data.clear()

    def iter_swap_tensors(self):
        """Iterate all SwapTensor objects stored in this storage."""
        collected = []
228
229
230
231
232
233
234
235
            if isinstance(x, SwapTensor):
                collected.append(x)
            return x

        for storage_list in self.values():
            for item in storage_list:
                platform.tree_map(_collect, item)
        return collected
265
266
267
268
269
270
271
272
273
            if isinstance(x, SwapTensor):
                x.protect_if_aliases(output_tensors)
            return x

        for storage_list in self.values():
            for item in storage_list:
                platform.tree_map(_protect_tensor, item)

    def launch_load(self):
276
277
278
279
280
281
282
283
284
            if isinstance(x, SwapTensor):
                x.async_load()
            return x

        for storage_list in self.values():
            for item in storage_list:
                platform.tree_map(_async_load, item)

    def resize_device_storage(self):
286
287
288
289
290
291
292
293
294
        def _resize(x):
            if isinstance(x, SwapTensor):
                x.resize_device_storage()
            return x
        for storage_list in self.values():
            for item in storage_list:
                platform.tree_map(_resize, item)

    def wait_load(self):
297
298
299
300
301
302
303
304
305
306
307
308
            if isinstance(x, SwapTensor):
                x.wait_load()
            return x

        for storage_list in self.values():
            for item in storage_list:
                platform.tree_map(_wait_load, item)
        self.clear()

    def wait_offload(self):
        """wait offload for all tensors in swap storage"""
        def _wait_offload(x):
309
310
311
312
313
314
315
316
317
            if isinstance(x, SwapTensor):
                x.wait_offload()
            return x

        for storage_list in self.values():
            for item in storage_list:
                platform.tree_map(_wait_offload, item)

    def launch_offload(self):
320
321
322
323
324
325
326
327
328
            if isinstance(x, SwapTensor):
                x.async_offload()
            return x

        for storage_list in self.values():
            for item in storage_list:
                platform.tree_map(_async_offload, item)

class SwapGroup:
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
    """Manager for a group of storages to coordinate swap operations."""

    def __init__(self, group_name: str):
        self.group_name = group_name
        self.is_last_group: bool = False
        self._storages: List[Storage] = []
        self._load_event: Optional[Any] = None
        self._offload_event: Optional[Any] = None

    def add(self, storage):
        """Add a storage to the swap group."""
        seen_keys = set()
        for existing_storage in self._storages:
            for swap_tensor in existing_storage.iter_swap_tensors():
                dedup_key = swap_tensor.dedup_key()
                if dedup_key is not None:
                    seen_keys.add(dedup_key)
347
348
349
350
351
352
353
354
355
356
357
358
359
        if duplicate_count > 0:
            warnings.warn(
                f"SwapGroup '{self.group_name}' skipped {duplicate_count} duplicate tensor swap registration(s)."
            )
        self._storages.append(storage)

    def protect_output_tensors(self, outputs: Any):
        """Protect current module outputs from premature offload."""
        for storage in self._storages:
            storage.protect_output_tensors(outputs)

    def launch_offload(self, copy_stream):
        """Launch async offload for all storages in the group."""
362
363
364
365
366
367
368
369
370
        self._offload_event = platform.new_event()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(copy_stream):
            compute_event.wait(copy_stream)
            for storage in self._storages:
                storage.launch_offload()
            self._offload_event.record(copy_stream)

    def wait_offload(self):
377
378
379
380
381
382
383
384
385
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(compute_stream):
            self._offload_event.wait(compute_stream)
            self._offload_event = None
            for storage in self._storages:
                storage.wait_offload()

    def launch_load(self, copy_stream):
        """Prepare storage and launch async load for all storages in the group."""
383
384
385
386
387
388
389
390
391

    def launch_load(self, copy_stream):
        """Prepare storage and launch async load for all storages in the group."""
        with platform.no_grad():
            for storage in self._storages:
                storage.resize_device_storage()

        compute_event = platform.new_event()
        compute_event.record(platform.get_current_stream())
392
393
394
395
396
397
398
399
400
        self._load_event = platform.new_event()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(copy_stream):
            compute_event.wait(copy_stream)
            for storage in self._storages:
                storage.launch_load()    # Only copy, no resize
            self._load_event.record(copy_stream)

    def wait_load(self):
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
        if self._load_event is None:
            raise RuntimeError(
                f"SwapGroup '{self.group_name}' wait_load() called before launch_load()."
            )
        compute_stream = platform.get_current_stream()
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(compute_stream):
            self._load_event.wait(compute_stream)
            self._load_event = None
            for storage in self._storages:
                storage.wait_load()
        self._storages.clear()


class SwapManager:
    """Singleton manager for swap groups and their operations."""
420
421
422
423
424
425
426
427
428
429
430
431

    def __init__(self):
        if hasattr(self, '_groups'):
            return
        self._groups: Dict[str, SwapGroup] = {}
        self._current_group_name: str = ""
        self._layer_count: int = 0
        self._copy_stream: Optional[Any] = None

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
434
435
436
437
438
439
440
441
442
443
        return cls._instance

    def add_storage(self, group_name: str, storage: Storage) -> None:
        """Add a storage to a specified swap group."""
        self.ensure_group(group_name)
        self._groups[group_name].add(storage)

    def ensure_group(self, group_name: str) -> None:
        """Create the swap group if it does not exist yet."""
        if group_name not in self._groups:
490
491
492
493
494
495
496
497
        and therefore never goes through the offload-load cycle).
        """
        group = self._groups.get(group_name)
        if group is not None:
            group._storages.clear()

    def get_current_group_name(self) -> str:
        return self._current_group_name
634
635
636
637
638
639
640
641

            next_name = module._swap_group_order.get('next', None)
            if next_name:
                SwapManager().wait_load(group_name)
            SwapManager().release_group_storage(group_name)

        def _backward_hook(group_name, module, grad_input, grad_output):  # pylint: disable=W0613
            module._swap_state = "backward"
hyper_parallel/platform/mindspore/activation_checkpoint/activation_swap.py
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
            if not self.add_to_storage:
                SwapManager().add_storage(group_name, self.storage)
                self.add_to_storage = True
            funcname = f"{group_name}::{tensor.shape}"
            self.storage[self.count_idx].append(SwapTensor(tensor, funcname))
            self.count_idx += 1
            return tensor

        def unpack_from_cpu(tensor) -> ms.Tensor:
            if self.storage is not None:
                self.storage.clear()
                self.storage = None
            return tensor

        super().__init__(pack_to_cpu, unpack_from_cpu)
333
334
335
336
337
338
339
340
341
        nonlocal count_idx
        if isinstance(x, Tensor) and base_check_fn(x):
            tensor_tag = tag or f"{group_name}_swap_tensor"
            funcname = f"{tensor_tag}::{tuple(x.shape)}"
            storage[count_idx].append(SwapTensor(x, funcname))
            count_idx += 1
        return x

    def _map(tree):
hyper_parallel/platform/mindspore/activation_checkpoint/sac.py
48
49
50
51
52
53
54
55
56
57
class _SwapCacheEntry:
    """Pair the recompute cache and swap record around the same tensor object."""

    def __init__(self, val, funcname):
        self.save = _VersionWrapper(val)
        self.swap = SwapTensor(val, funcname)


def _maybe_detach(x):
    if isinstance(x, ms.Tensor) and (x.is_floating_point() or x.is_complex()):
68
69
70
71
72
73
74
75
76

class _CachingMindSporeDispatchMode(MsDispatchMode):
    def __init__(self, policy_fn, swap_storage, storage):
        self.policy_fn = policy_fn
        self.swap_storage = swap_storage
        self.storage = storage
        self.add_to_storage = False

    def __ms_dispatch__(self, func, args=(), kwargs=None):
82
83
84
85
86
87
88
89
90

        out = func(*args, **kwargs)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE):
            self.storage[func.name].append(
                platform.tree_map(lambda x: _VersionWrapper(_maybe_detach(x)), out)
            )
        elif policy == CheckpointPolicy.MUST_SWAP:
            group_name = SwapManager().get_current_group_name()
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
            )
        elif policy == CheckpointPolicy.MUST_SWAP:
            group_name = SwapManager().get_current_group_name()
            if not self.add_to_storage:
                SwapManager().add_storage(group_name, self.swap_storage)
                self.add_to_storage = True
            funcname = f"{group_name}::{func.name}"
            entries = platform.tree_map(lambda x: _SwapCacheEntry(_maybe_detach(x), funcname), out)
            self.storage[func.name].append(
                platform.tree_map(lambda x: x.save, entries)
            )
            self.swap_storage[func.name].append(
                platform.tree_map(lambda x: x.swap, entries)
            )
        return out
104
105
106
107
108
109
110
111
112
113
114
115

class _CachedMindSporeDispatchMode(MsDispatchMode):
    def __init__(self, policy_fn, swap_storage, storage, allow_cache_entry_mutation):
        self.policy_fn = policy_fn
        self.swap_storage = swap_storage
        self.storage = storage
        self.allow_cache_entry_mutation = allow_cache_entry_mutation
        self._swap_cleared = False

    def __ms_dispatch__(self, func, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func.name in SAC_IGNORED_OPS:
117
118
119
120
121
122
123
124
125
126
127
128
129
130

        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
                                func, *args, **kwargs)

        if not self._swap_cleared:
            self.swap_storage.clear()
            self._swap_cleared = True

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE):
            storage = self.storage.get(func.name)  # patch code
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
132
133
134
135
136
137
138
139
140
                    "on any region computed under selective activation checkpoint."
                )
            out = platform.tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            storage = self.storage.get(func.name)
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
140
141
142
143
144
145
146
147
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            out = platform.tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
        else:
            out = func(*args, **kwargs)
        return out
152
153
154
155
156
157
158
159
160
161
        policy_fn = policy_fn_or_list
    else:
        raise TypeError("policy_fn_or_list must be either a function or a list of ops.")

    swap_storage = Storage()
    storage: Dict[Any, List[Any]] = defaultdict(list)
    return (
        _CachingMindSporeDispatchMode(policy_fn, swap_storage, storage),
        _CachedMindSporeDispatchMode(policy_fn, swap_storage, storage, allow_cache_entry_mutation)
    )
hyper_parallel/platform/mindspore/platform.py
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
        return _no_grad()

    @staticmethod
    def preserve_version_counter(tensor):
        from mindspore.common.api import _unsafe_preserve_version_counter  # pylint: disable=C0415
        return _unsafe_preserve_version_counter(tensor)

    @staticmethod
    def relu(tensor):
        return mint.nn.functional.relu(tensor)
hyper_parallel/platform/platform.py
1094
1095
1096
1097
1098
1099
1100
1101
1102

    @staticmethod
    def preserve_version_counter(tensor):
        """Get a context manager that preserves version for an internal tensor update."""
        raise NotImplementedError("Platform subclasses must implement preserve_version_counter")

    @staticmethod
    def relu(tensor):
        """Apply ReLU activation element-wise.
hyper_parallel/platform/torch/activation_checkpoint/activation_swap.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
            if not self.add_to_storage:
                SwapManager().add_storage(group_name, self.storage)
                self.add_to_storage = True
            funcname = f"{group_name}::{tensor.shape}"
            self.storage[self.count_idx].append(SwapTensor(tensor, funcname))
            self.count_idx += 1
            return tensor

        def unpack_from_cpu(tensor) -> torch.Tensor:
            if self.storage is not None:
                self.storage.clear()
                self.storage = None
            return tensor

        super().__init__(pack_to_cpu, unpack_from_cpu)
232
233
234
235
236
237
238
239
240
            return tensor

        tensor_tag = tag or f"{group_name}_swap_tensor"
        funcname = f"{tensor_tag}::{tuple(tensor.shape)}"
        storage[count_idx].append(SwapTensor(tensor, funcname))
        count_idx += 1
        return tensor

    wrapped = torch.utils._pytree.tree_map(  # pylint: disable=protected-access
hyper_parallel/platform/torch/activation_checkpoint/sac.py
16
17
18
19
20
21
22
23
24
25
# enhanced with selective checkpoint support swap
# ============================================================================
"""enhanced with selective checkpoint support swap"""
# pylint: disable=W0212, W0613, C0115, C0116, C0103, R1705
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import torch
import torch.fx.traceback as fx_traceback
from torch._functorch._aot_autograd.functional_utils import is_fun
25
26
27
28
29
30
31
32
33
from torch._functorch._aot_autograd.functional_utils import is_fun
from torch.utils._pytree import tree_map
from torch.utils._python_dispatch import TorchDispatchMode
from hyper_parallel.core.activation_checkpoint import CheckpointPolicy  # patch code
from hyper_parallel.core.activation_checkpoint.swap import (  # patch code
    SwapManager,
    SwapTensor,
    Storage,
)
58
59
60
61
62
63
64
65
66
67
68
69
70
71
                )
        return self.val


class _SwapCacheEntry:
    """Pair the recompute cache and swap record around the same tensor object."""

    def __init__(self, val, funcname):
        self.save = _VersionWrapper(val)
        self.swap = SwapTensor(val, funcname)


def _maybe_detach(x, any_ret_has_alias_info):
    # We detach for two separate reasons:
128
129
130
131
132
133
134
135
136
137
138


class _CachingTorchDispatchMode(TorchDispatchMode):
    # Used together with _CachedTorchDispatchMode to implement SAC.
    def __init__(self, policy_fn, swap_storage, storage):
        self.policy_fn = policy_fn
        self.swap_storage = swap_storage
        self.storage = storage
        self.add_to_storage = False

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
155
156
157
158
159
160
161
162
163

        any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE):
            self.storage[func].append(
                tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out)
            )
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            group_name = SwapManager().get_current_group_name()
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
            )
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            group_name = SwapManager().get_current_group_name()
            if not self.add_to_storage:
                SwapManager().add_storage(group_name, self.swap_storage)
                self.add_to_storage = True
            funcname = f"{group_name}::{func}"
            entries = tree_map(
                lambda x: _SwapCacheEntry(_maybe_detach(x, any_ret_has_alias_info), funcname),
                out,
            )
            self.storage[func].append(tree_map(lambda x: x.save, entries))
            self.swap_storage[func].append(tree_map(lambda x: x.swap, entries))
        return out


class _CachedTorchDispatchMode(TorchDispatchMode):
175
176
177
178
179
180
181
182
183
184
185
186
187
188


class _CachedTorchDispatchMode(TorchDispatchMode):
    # Used together with _CachedTorchDispatchMode to implement SAC.
    def __init__(self, policy_fn, swap_storage, storage, allow_cache_entry_mutation):
        self.policy_fn = policy_fn
        self.swap_storage = swap_storage
        self.storage = storage
        self.allow_cache_entry_mutation = allow_cache_entry_mutation
        self._swap_cleared = False

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)
194
195
196
197
198
199
200
201
202
203
204
205
206
207
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if not self._swap_cleared:
            self.swap_storage.clear()
            self._swap_cleared = True

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
            storage = self.storage.get(func)  # patch code
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
209
210
211
212
213
214
215
216
217
                    "on any region computed under selective activation checkpoint."
                )
            out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            storage = self.storage.get(func)
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
217
218
219
220
221
222
223
224
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
        else:
            out = func(*args, **kwargs)
        return out
302
303
304
305
306
307
308
309
310
311
        policy_fn = policy_fn_or_list
    else:
        raise TypeError("policy_fn_or_list must be either a function or a list of ops.")

    swap_storage = Storage()  # patch code
    storage: Dict[Any, List[Any]] = defaultdict(list)
    return (
        _CachingTorchDispatchMode(policy_fn, swap_storage, storage),
        _CachedTorchDispatchMode(policy_fn, swap_storage, storage, allow_cache_entry_mutation),
    )
hyper_parallel/platform/torch/platform.py
1202
1203
1204
1205
1206
1207
1208
1209
1210
        return torch.no_grad()

    @staticmethod
    def preserve_version_counter(tensor):
        return torch.autograd._unsafe_preserve_version_counter(tensor)  # pylint: disable=W0212

    @staticmethod
    def relu(tensor):
        return torch.relu(tensor)