Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/activation_checkpoint/activation_checkpoint.py 61.9% 76,148-152,161,164
hyper_parallel/core/activation_checkpoint/swap.py 4.3% 57-64,66-68,73-75,91-92,96-99,104-105,111,152-153,166-167,204-205,225-226,409-412,416,419,423,458-462,464-467,470-472,476-477,481-483,485,494-498,500-503,505-511,517-524,526-529,538-544,555,557-564,586,609-616,620-621,623-624,650-662,672-675,680-684
hyper_parallel/platform/mindspore/activation_checkpoint/activation_swap.py 26.1% 85,238-239,246,250,252,254,257,300,303,325,336-338,344,355,370
hyper_parallel/platform/mindspore/activation_checkpoint/sac.py 21.1% 53,75,77-78,95-97,99-101,134,150-153
hyper_parallel/platform/mindspore/platform.py 37.5% 1650-1651,1656-1657,1694-1701,1705-1707
hyper_parallel/platform/mindspore/platform_graph.py 0.0% 40,48,55-57
hyper_parallel/platform/platform.py 77.8% 1327,1372
hyper_parallel/platform/torch/activation_checkpoint/activation_swap.py 0.0% 70,79,86,91,93,95,98,132-134,203,211,214,218,223,226,233-234,240,253,262
hyper_parallel/platform/torch/activation_checkpoint/sac.py 0.0% 65,67,132,137,139-140,168-170,172-173,209,225,282-285
hyper_parallel/platform/torch/platform.py 37.5% 1318-1319,1324-1325,1330-1331,1341,1357-1359
hyper_parallel/core/activation_checkpoint/activation_checkpoint.py
72
73
74
75
76
77
78
79
80

    Returns:
        The result of applying the function with checkpointing.
    """
    context_fn = (
        partial(plat.create_selective_checkpoint_contexts, policy_fn, group_swap=group_swap)
        if policy_fn or group_swap else plat.noop_context_fn
    )
    context = plat.async_save_on_cpu if swap_inputs else contextlib.nullcontext
144
145
146
147
148
149
150
151
152
153
154
155
156
        self.checkpoint_kwargs = checkpoint_kwargs

    def _do_checkpoint(self, wrapped_module: Any, *args: Any, **kwargs: Any) -> Any:
        # Checkpoint may save inputs before the wrapped cell's pre-hook runs.
        group_name = getattr(self, "_swap_group_name", None)
        if group_name is not None:
            from hyper_parallel.core.activation_checkpoint.swap import SwapManager  # pylint: disable=C0415
            SwapManager().set_current_group_name(group_name)
        return checkpoint(
            wrapped_module,
            *args,
            group_swap=self.group_swap,
            **self.checkpoint_kwargs,
157
158
159
160
161
162
163
164
165
166
167
168
            **kwargs,
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped module through the checkpointed call path.

        Positional and keyword arguments are passed through unchanged to
        ``_do_checkpoint`` together with ``self._wrapped_module``.

        Returns:
            Whatever ``_do_checkpoint`` returns.
        """
        run_checkpointed = self._do_checkpoint
        target = self._wrapped_module
        return run_checkpointed(target, *args, **kwargs)

    def construct(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped module through the checkpointed call path.

        Same delegation as ``forward``: every argument is handed unchanged to
        ``_do_checkpoint`` along with ``self._wrapped_module``.

        Returns:
            Whatever ``_do_checkpoint`` returns.
        """
        run_checkpointed = self._do_checkpoint
        target = self._wrapped_module
        return run_checkpointed(target, *args, **kwargs)


def ckpt_wrapper(module, group_swap: bool = False, **checkpoint_kwargs):
    """Wrap *module* with activation checkpointing.
hyper_parallel/core/activation_checkpoint/swap.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    ``buf[:total_numel]`` for the actual copy so the returned reference can be
    passed back to :func:`_return_cpu_pinned_buf` without any platform-specific
    introspection.
    """
    pool = _CPU_PINNED_POOL[dtype_key]
    best_i = -1
    for i, buf in enumerate(pool):
        if buf.numel() >= total_numel:
            if best_i == -1 or buf.numel() < pool[best_i].numel():
                best_i = i
    if best_i != -1:
        return pool.pop(best_i)
    # No suitable buffer — discard one stale undersized entry.
    if pool:
        pool.pop()
    return platform.alloc_tensor_buffer(total_numel, dtype, device='cpu', pin_memory=True)


def _return_cpu_pinned_buf(buf):
    """Return a full pinned CPU buffer to the pool for reuse."""
    if buf is None:
        return
    _CPU_PINNED_POOL[str(buf.dtype)].append(buf)


class SwapTensor:
    """A tensor that can be swapped between device and host memory asynchronously."""
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
        self.val = val
        self.funcname = funcname
        self._keep_on_device = False
        self._duplicate_swap = False
        self._group_managed = False # True when this tensor is handled by SwapGroup bulk copy
        self.group_swap = group_swap # opt-in for group copy fusion (MUST_SWAP tensors only)
        if isinstance(val, platform.Tensor) and str(val.device).lower() != 'cpu':
            self.ver = val._version
            self._state = self.STATE_DEVICE
            val_storage = val.untyped_storage()
            self.storage_size = val_storage.size()
            self.is_slice_tensor = self.storage_size != val.numel() * platform.get_element_size(val)
            self.val_cpu = None
        else:
            self.ver = None
            self._state = self.STATE_NON_TENSOR
            self.val_cpu = None
            self.is_slice_tensor = False
            self.storage_size = 0

    def dedup_key(self):
        """Return a stable identity key for duplicate-swap detection."""
        if self._state == self.STATE_NON_TENSOR:
107
108
109
110
111
112
113
114
115
    def dedup_key(self):
        """Return a stable identity key for duplicate-swap detection."""
        if self._state == self.STATE_NON_TENSOR:
            return None
        val_storage = self.val.untyped_storage()
        return (
            str(self.val.device),
            val_storage.data_ptr(),
            self.val.storage_offset(),
148
149
150
151
152
153
154
155
156
157
    def resize_device_storage(self):
        """Reallocate device memory on compute stream."""
        if self._state == self.STATE_NON_TENSOR or self._duplicate_swap:
            return
        if self._group_managed:
            return

        if self._state != self.STATE_HOST:
            return
        storage = self.val.untyped_storage()
162
163
164
165
166
167
168
169
170
171
    def async_load(self):
        """async load tensor from host to device"""
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device or self._duplicate_swap:
            return
        if self._group_managed:
            return

        if self._state != self.STATE_HOST:
            warnings.warn(
                f"[SwapTensor.async_load] Invalid state: current={self._state}, "
200
201
202
203
204
205
206
207
208
209
    def async_offload(self):
        """async offload tensor from device to host"""
        if self._state == self.STATE_NON_TENSOR or self._keep_on_device or self._duplicate_swap:
            return
        if self._group_managed:
            return

        if self._state != self.STATE_DEVICE:
            warnings.warn(
                f"[SwapTensor.async_offload] Invalid state: current={self._state}, "
221
222
223
224
225
226
227
228
229
230
                f"There is a tensor from {self.funcname} cannot be SWAPPED! In-place modification happened "
                f"preversion:{self.ver}, current version:{self.val._version}"
            )

        if self.val_cpu is None:
            self.val_cpu = platform.empty_like(
                self.val, device="cpu", pin_memory=True
            )
        if self.is_slice_tensor:
            self.val_cpu.copy_(self.val, non_blocking=True)
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
        self._storages: List[Storage] = []
        self._load_event: Optional[Any] = None
        self._offload_event: Optional[Any] = None
        # Group-level contiguous buffers for non-slice tensors.
        self._packed_tensor_info: List = []   # [(SwapTensor, bucket_key, element_offset), ...]
        self._packed_buckets: Dict[str, Dict[str, Any]] = {}
        self._group_cpu_buf = None            # pinned CPU bufs; live offload→load
        self._group_device_buf = None         # temp device bufs; cleared after each phase
        # Persistent dedup set accumulated across add() calls; avoids O(N²) rebuild.
        # mark_duplicate_swaps mutates it in-place, so new keys are added automatically.
        # Reset at wait_load() so stale data_ptrs don't leak into the next iteration.
        self._seen_dedup_keys: set = set()
        # Per-bucket SwapTensor lists built in _collect_packable_tensors and consumed
        # in launch_offload, eliminating a redundant pass over _packed_tensor_info.
        self._packed_by_bucket: Dict[str, List] = {}

    def add(self, storage):
        """Add a storage to the swap group."""
        duplicate_count = storage.mark_duplicate_swaps(self._seen_dedup_keys)
        if duplicate_count > 0:
            warnings.warn(
                f"SwapGroup '{self.group_name}' skipped {duplicate_count} duplicate tensor swap registration(s)."
            )
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489

        Returns:
            Total byte count of all packable tensors.
        """
        candidate_buckets: Dict[str, List[Dict[str, Any]]] = {}
        packed_info: List = []
        packed_buckets: Dict[str, Dict[str, Any]] = {}
        packed_by_bucket: Dict[str, List] = {}
        total_bytes = 0

        def _try_pack(x):
            if not isinstance(x, SwapTensor):
                return x
            if (not x.group_swap or x._state != SwapTensor.STATE_DEVICE or x._keep_on_device or x.is_slice_tensor
                    or x._duplicate_swap or x.storage_size >= _GROUP_SWAP_MAX_BULK_COPY_BYTES
                    or not x.val.is_contiguous()):
                return x
            if x.storage_size != x.val.untyped_storage().size():
                raise RuntimeError(
                    f"There is a tensor from {x.funcname} cannot be SWAPPED! Its storage has been resized "
                    f"presize:{x.storage_size}, current size:{x.val.untyped_storage().size()}"
                )
            if x.ver != x.val._version:
                raise RuntimeError(
                    f"There is a tensor from {x.funcname} cannot be SWAPPED! In-place modification happened "
                    f"preversion:{x.ver}, current version:{x.val._version}"
                )
            dtype_key = str(x.val.dtype)
            dtype_buckets = candidate_buckets.setdefault(dtype_key, [])
            if (not dtype_buckets or
                    dtype_buckets[-1]["total_bytes"] + x.storage_size > _GROUP_SWAP_MAX_BULK_COPY_BYTES):
                dtype_buckets.append({
                    "bucket_key": f"{dtype_key}#{len(dtype_buckets)}",
                    "dtype": x.val.dtype,
                    "dtype_key": dtype_key,
                    "device": x.val.device,
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
                    "tensors": [],
                    "total_bytes": 0,
                    "total_numel": 0,
                })
            bucket = dtype_buckets[-1]
            bucket["tensors"].append(x)
            bucket["total_bytes"] += x.storage_size
            bucket["total_numel"] += x.val.numel()
            return x

        for storage in self._storages:
            for storage_list in storage.values():
                for item in storage_list:
                    platform.tree_map(_try_pack, item)

        for dtype_bucket_list in candidate_buckets.values():
            for candidate_bucket in dtype_bucket_list:
                tensors = candidate_bucket["tensors"]
                if len(tensors) < 2:
                    continue
                bucket_key = candidate_bucket["bucket_key"]
                packed_buckets[bucket_key] = {
                    "dtype": candidate_bucket["dtype"],
                    "dtype_key": candidate_bucket["dtype_key"],
                    "device": candidate_bucket["device"],
                    "total_numel": candidate_bucket["total_numel"],
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
                    "dtype_key": candidate_bucket["dtype_key"],
                    "device": candidate_bucket["device"],
                    "total_numel": candidate_bucket["total_numel"],
                }
                element_offset = 0
                for tensor in tensors:
                    tensor._group_managed = True
                    tensor._state = SwapTensor.STATE_D2H
                    packed_info.append((tensor, bucket_key, element_offset))
                    element_offset += tensor.val.numel()
                packed_by_bucket[bucket_key] = tensors
                total_bytes += candidate_bucket["total_bytes"]

        self._packed_tensor_info = packed_info
        self._packed_buckets = packed_buckets
        self._packed_by_bucket = packed_by_bucket
        return total_bytes

    def launch_offload(self, copy_stream):
        """Launch async offload for all storages in the group.
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
        Non-slice tensors are first packed into bounded contiguous device
        buffers, then transferred to pinned CPU memory.  Slice tensors are
        offloaded individually via the existing per-tensor path.
        """
        total_bytes = self._collect_packable_tensors()
        with platform.no_grad():
            if total_bytes > 0:
                group_device_bufs = {}
                group_cpu_bufs = {}
                for bucket_key, swap_tensors in self._packed_by_bucket.items():
                    group_device_bufs[bucket_key] = platform.cat(
                        [st.val.reshape(-1) for st in swap_tensors], dim=0
                    )

        compute_event = platform.new_event()
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(copy_stream):
            compute_event.wait(copy_stream)

            if total_bytes > 0:
                # One-shot D2H per packed bucket. MindSpore requires tensor/storage dtype consistency.
                for bucket_key, bucket in self._packed_buckets.items():
                    dtype_key = bucket["dtype_key"]
                    numel = bucket["total_numel"]
                    cpu_buf = _get_cpu_pinned_buf(dtype_key, numel, bucket["dtype"])
                    group_cpu_bufs[bucket_key] = cpu_buf
                    cpu_buf[:numel].copy_(group_device_bufs[bucket_key], non_blocking=True)
                self._group_device_buf = group_device_bufs
                self._group_cpu_buf = group_cpu_bufs

            # Slice tensors use the existing per-tensor path.
            # Group-managed tensors are already STATE_D2H so async_offload is a no-op.
            for storage in self._storages:
582
583
584
585
586
587
588
589
            self._offload_event = None
            for storage in self._storages:
                storage.wait_offload()
        # Release the temporary device packing buffer; _group_cpu_buf persists until launch_load.
        self._group_device_buf = None

    def launch_load(self, copy_stream):
        """Prepare storage and launch async load for all storages in the group.
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
        stream_context = platform.get_stream_context()
        with platform.no_grad(), stream_context(copy_stream):
            compute_event.wait(copy_stream)

            if self._packed_tensor_info and self._group_cpu_buf is not None:
                group_device_bufs = {}
                for bucket_key, bucket in self._packed_buckets.items():
                    cpu_buf = self._group_cpu_buf.get(bucket_key)
                    if cpu_buf is None:
                        continue
                    numel = bucket["total_numel"]
                    group_device_bufs[bucket_key] = platform.alloc_tensor_buffer(
                        numel, bucket["dtype"], bucket["device"]
                    )
                    # One-shot H2D per packed bucket.
                    group_device_bufs[bucket_key].copy_(cpu_buf[:numel], non_blocking=True)
                self._group_device_buf = group_device_bufs
                # Mirror async_load's STATE_H2D transition: H2D is in flight.
                for st, _, _ in self._packed_tensor_info:
                    st._state = SwapTensor.STATE_H2D

            # Slice tensors use the existing per-tensor path.
            # Group-managed tensors skip async_load via _group_managed flag.
            for storage in self._storages:
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
        with platform.no_grad(), stream_context(compute_stream):
            self._load_event.wait(compute_stream)
            self._load_event = None
            # Restore group-managed tensors: alias into the contiguous device buffer.
            if self._group_device_buf is not None:
                prev_key = None
                group_storage = None
                for st, bucket_key, element_offset in self._packed_tensor_info:
                    if bucket_key != prev_key:
                        group_device_buf = self._group_device_buf.get(bucket_key)
                        group_storage = group_device_buf.untyped_storage() if group_device_buf is not None else None
                        prev_key = bucket_key
                    if group_storage is None:
                        continue
                    with platform.preserve_version_counter(st.val):
                        st.val.set_(group_storage, element_offset, st.val.shape, st.val.stride())
                    st._state = SwapTensor.STATE_DEVICE
            for storage in self._storages:
                storage.wait_load()
        self._storages.clear()
        # Return CPU pinned buffers to the pool.  By the time wait_load
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
        # means the copy stream's H2D transfer has completed and the CPU
        # buffer is no longer being read by the DMA engine.  The next
        # launch_offload (start of the following iteration) will pop these
        # buffers from the pool, well after the current H2D is done.
        if self._group_cpu_buf is not None:
            for buf in self._group_cpu_buf.values():
                _return_cpu_pinned_buf(buf)
        self._group_cpu_buf = None
        # Device buffer: the pool holds the staging reference; just drop
        # the local reference.  Tensors aliasing _group_device_buf's
        # storage keep it alive via their own storage references until
        # they are consumed in backward.
        self._group_device_buf = None
        self._packed_tensor_info = []
        self._packed_buckets = {}
        self._packed_by_bucket = {}
        self._seen_dedup_keys = set()


class SwapManager:
    """Singleton manager for swap groups and their operations."""
hyper_parallel/platform/mindspore/activation_checkpoint/activation_swap.py
81
82
83
84
85
86
87
88
89
        }

    @property
    def _wrapped_module(self):
        return self._ckpt_wrapped_module

    @abstractmethod
    def construct(self, *args, **kwargs):
        """Abstract entry point; subclasses are expected to override it.

        Raises:
            ValueError: Always, when this base implementation is called.
        """
        message = "Subclasses should implement construct()."
        raise ValueError(message)
234
235
236
237
238
239
240
241
242
243
    Context manager to offload tensors to CPU during forward pass.
    """
    def __init__(self, policy_fn=None, group_swap: bool = False) -> None:
        # pylint: disable=C0415
        from hyper_parallel.core.activation_checkpoint.activation_checkpoint import CheckpointPolicy
        from hyper_parallel.core.activation_checkpoint.swap import Storage, SwapManager, SwapTensor
        self.add_to_storage = False
        self.storage = Storage()
        self.count_idx = 0
        self.policy_fn = policy_fn
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
        self.count_idx = 0
        self.policy_fn = policy_fn

        # Cache per-context-manager state once to avoid per-tensor singleton lookups.
        swap_manager = SwapManager()
        def pack_to_cpu(tensor: ms.Tensor):
            if not base_check_fn(tensor):
                return tensor
            if (policy_fn is not None) and (policy_fn(tensor) == CheckpointPolicy.MUST_SAVE):
                return tensor
            group_name = swap_manager.get_current_group_name()
            if not self.add_to_storage:
                swap_manager.add_storage(group_name, self.storage)
                self.add_to_storage = True
            funcname = f"{group_name}::{tensor.shape}"
            self.storage[self.count_idx].append(
                SwapTensor(tensor, funcname, group_swap=group_swap)
            )
            self.count_idx += 1
            return tensor
296
297
298
299
300
301
302
303
304
305
306
307
        group_swap: bool = False,
    ):
        super().__init__(mod)
        self.policy_fn = policy_fn
        self.group_swap = group_swap

    def construct(self, *args, **kwargs):
        """Call the wrapped module inside an ``AsyncSaveOnCpu`` context.

        The context is built from this wrapper's ``policy_fn`` and
        ``group_swap`` settings; all call arguments are forwarded unchanged
        to ``self._ckpt_wrapped_module``.

        Returns:
            The wrapped module's return value.
        """
        swap_context = AsyncSaveOnCpu(policy_fn=self.policy_fn, group_swap=self.group_swap)
        with swap_context:
            return self._ckpt_wrapped_module(*args, **kwargs)


def swap_wrapper(
321
322
323
324
325
326
327
328
329

    Returns:
        SwapWrapper: The wrapped cell with activation swap enabled.
    """
    return SwapWrapper(module, policy_fn, group_swap)


def swap_tensor_wrapper(target, tag: Optional[str] = None, group_swap: bool = False):
    """Register selected tensors into the current swap group.
332
333
334
335
336
337
338
339
340
341
342
    participates in the existing swap scheduling managed by ``SwapManager``.
    It preserves the input structure and returns the original tensors.
    """
    # pylint: disable=C0415
    from hyper_parallel.core.activation_checkpoint.swap import Storage, SwapManager, SwapTensor
    swap_manager = SwapManager()
    group_name = swap_manager.get_current_group_name()
    if not group_name:
        warnings.warn(
            f"Tensor {tag} cannot be swapped, for its group is unregistered."
        )
340
341
342
343
344
345
346
347
348
        warnings.warn(
            f"Tensor {tag} cannot be swapped, for its group is unregistered."
        )
        return target
    if swap_manager.is_last_group(group_name):
        return target

    storage = Storage()
    count_idx = 0
351
352
353
354
355
356
357
358
359
        nonlocal count_idx
        if isinstance(x, Tensor) and base_check_fn(x):
            tensor_tag = tag or f"{group_name}_swap_tensor"
            funcname = f"{tensor_tag}::{tuple(x.shape)}"
            storage[count_idx].append(SwapTensor(x, funcname, group_swap=group_swap))
            count_idx += 1
        return x

    def _map(tree):
366
367
368
369
370
371
        return _apply(tree)

    wrapped = _map(target)
    if count_idx > 0:
        _manager.add_storage(group_name, storage)
    return wrapped
hyper_parallel/platform/mindspore/activation_checkpoint/sac.py
49
50
51
52
53
54
55
56
57
    """Pair the recompute cache and swap record around the same tensor object."""

    def __init__(self, val, funcname, group_swap=False):
        """Build both records from the same ``val`` object.

        Args:
            val: Object wrapped for the recompute cache and registered for swap.
            funcname: Label forwarded to ``SwapTensor``.
            group_swap: Forwarded to ``SwapTensor`` as ``group_swap``.
        """
        # ``save`` and ``swap`` are constructed from the very same ``val``
        # reference, so both records refer to one object.
        self.save = _VersionWrapper(val)
        self.swap = SwapTensor(val, funcname, group_swap=group_swap)


def _maybe_detach(x):
    if isinstance(x, ms.Tensor) and (x.is_floating_point() or x.is_complex()):
71
72
73
74
75
76
77
78
79
80
81
82
        self.policy_fn = policy_fn
        self.swap_storage = swap_storage
        self.storage = storage
        self.add_to_storage = False
        self.group_swap = group_swap
        # Cache context and singleton to avoid per-dispatch allocation / lookup.
        self._swap_manager = SwapManager()
        self._group_prefix = ""

    def __ms_dispatch__(self, func, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func.name in SAC_IGNORED_OPS:
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
                platform.tree_map(lambda x: _VersionWrapper(_maybe_detach(x)), out)
            )
        elif policy == CheckpointPolicy.MUST_SWAP:
            if not self.add_to_storage:
                group_name = self._swap_manager.get_current_group_name()
                self._group_prefix = f"{group_name}::"
                self._swap_manager.add_storage(group_name, self.swap_storage)
                self.add_to_storage = True
            funcname = f"{self._group_prefix}{func.name}"
            group_swap = self.group_swap
            entries = platform.tree_map(
                lambda x: _SwapCacheEntry(_maybe_detach(x), funcname, group_swap=group_swap), out
            )
            self.storage[func.name].append(
                platform.tree_map(lambda x: x.save, entries)
130
131
132
133
134
135
136
137
138
            self.swap_storage.clear()
            self._swap_cleared = True

        # MUST_SAVE and MUST_SWAP both restore from storage identically.
        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE, CheckpointPolicy.MUST_SWAP):
            storage = self.storage.get(func.name)
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
146
147
148
149
150
151
152
153
154
155
156
        return out


def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False, group_swap=False):
    if policy_fn_or_list is None:
        def policy_fn(_ctx, _op, *_args, **_kwargs):
            return CheckpointPolicy.PREFER_RECOMPUTE
    elif callable(policy_fn_or_list):
        policy_fn = policy_fn_or_list
    else:
        raise TypeError("policy_fn_or_list must be either a function or a list of ops.")
hyper_parallel/platform/mindspore/platform.py
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661

    @staticmethod
    def swap_wrapper(module, policy_fn=None, group_swap=False):
        """Delegate to the MindSpore ``activation_swap.swap_wrapper`` helper.

        Args:
            module: Object to wrap; forwarded positionally.
            policy_fn: Forwarded as ``policy_fn``.
            group_swap: Forwarded as ``group_swap``.

        Returns:
            The value returned by the underlying ``swap_wrapper``.
        """
        # The import is done at call time rather than at module import time.
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import (
            swap_wrapper as _swap_wrapper_impl,
        )
        return _swap_wrapper_impl(module, policy_fn=policy_fn, group_swap=group_swap)

    @staticmethod
    def swap_tensor_wrapper(target, tag=None, group_swap=False):
        """Delegate to the MindSpore ``activation_swap.swap_tensor_wrapper`` helper.

        Args:
            target: Forwarded positionally.
            tag: Forwarded as ``tag``.
            group_swap: Forwarded as ``group_swap``.

        Returns:
            The value returned by the underlying ``swap_tensor_wrapper``.
        """
        # The import is done at call time rather than at module import time.
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import (
            swap_tensor_wrapper as _swap_tensor_wrapper_impl,
        )
        return _swap_tensor_wrapper_impl(target, tag=tag, group_swap=group_swap)

    @staticmethod
    def get_class_activation_wrapper():
        # pylint: disable=C0415
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711

    @staticmethod
    def alloc_tensor_buffer(numel: int, dtype, device, pin_memory: bool = False):
        """Allocate an uninitialized 1-D tensor buffer.

        Args:
            numel: Number of elements; the buffer has shape ``(numel,)``.
            dtype: Element dtype passed to ``mint.empty``.
            device: Target device.  ``None`` lets ``mint.empty`` pick its
                default; otherwise the part of ``str(device)`` before the
                first ``":"`` is lower-cased and looked up in
                ``MindSporePlatform._MS_DEVICE_MAP``.
            pin_memory: When true, a pinned CPU buffer is allocated and
                ``device`` is not consulted.

        Returns:
            The tensor produced by ``mint.empty``.

        Raises:
            ValueError: If the device type has no entry in ``_MS_DEVICE_MAP``.
        """
        shape = (numel,)
        if pin_memory:
            return mint.empty(shape, dtype=dtype, device="cpu", pin_memory=True)
        if device is None:
            return mint.empty(shape, dtype=dtype)

        # Only the device family matters here; any ":<index>" suffix is dropped.
        device_type = str(device).partition(":")[0].lower()
        ms_device = MindSporePlatform._MS_DEVICE_MAP.get(device_type)
        if ms_device is None:
            raise ValueError(
                f"Unsupported device type '{device_type}' for MindSpore; "
                f"supported: {sorted(MindSporePlatform._MS_DEVICE_MAP)}"
            )
        # A mapped value of "cpu" takes the same call as every other device.
        return mint.empty(shape, dtype=dtype, device=ms_device)

    @staticmethod
    def get_element_size(tensor):
        """Get Tensor Element Size"""
hyper_parallel/platform/mindspore/platform_graph.py
36
37
38
39
40
41
42
43
44
        output = ops.ReduceScatter(group=group_info.group_name)(data)
        return output, None

    @staticmethod
    def swap_wrapper(module, policy_fn=None, group_swap=False):
        raise NotImplementedError("swap_wrapper is not supported on MindSpore Graph platform")

    @property
    def noop_context_fn(self):
44
45
46
47
48
49
50
51
52
    def noop_context_fn(self):
        raise NotImplementedError("noop_context_fn is not supported on MindSpore Graph platform")

    @staticmethod
    def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False, group_swap=False):
        raise NotImplementedError("create_selective_checkpoint_contexts is not supported on MindSpore Graph platform")

    @staticmethod
    def async_save_on_cpu(policy_fn=None):
51
52
53
54
55
56
57
    @staticmethod
    def async_save_on_cpu(policy_fn=None):
        raise NotImplementedError("async_save_on_cpu is not supported on MindSpore Graph platform")

    @staticmethod
    def get_class_activation_wrapper():
        raise NotImplementedError("get_class_activation_wrapper is not supported on MindSpore Graph platform")
hyper_parallel/platform/platform.py
1323
1324
1325
1326
1327
1328
1329
1330
1331
        raise NotImplementedError("Platform subclasses must implement swap_tensor_wrapper")

    @staticmethod
    def get_class_activation_wrapper():
        raise NotImplementedError("Platform subclasses must implement get_class_activation_wrapper")

    @property
    def noop_context_fn(self):
        """Get a no-op context function for checkpointing.
1368
1369
1370
1371
1372
1373
1374
1375
1376

    @staticmethod
    def alloc_tensor_buffer(numel: int, dtype, device, pin_memory: bool = False):
        """Allocate an uninitialized 1-D tensor buffer."""
        raise NotImplementedError("Platform subclasses must implement alloc_tensor_buffer")

    @staticmethod
    def tensor_to_numpy(tensor) -> np.ndarray:
        """Convert a framework tensor to a NumPy array.
hyper_parallel/platform/torch/activation_checkpoint/activation_swap.py
66
67
68
69
70
71
72
73
    - Skip empty storage tensors.
    """
    if isinstance(tensor._base, torch.nn.parameter.Parameter) or isinstance(tensor, torch.nn.parameter.Parameter):  # pylint: disable=W0212
        return False
    if tensor.untyped_storage().size() == 0:
        return False
    return True

75
76
77
78
79
80
81
82
83
class AsyncSaveOnCpu(torch.autograd.graph.saved_tensors_hooks):
    """
    Context manager to offload tensors to CPU during forward pass.
    """
    def __init__(self, policy_fn=None, group_swap: bool = False) -> None:
        self.add_to_storage = False
        self.storage = Storage()
        self.count_idx = 0
        self.policy_fn = policy_fn
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
        self.count_idx = 0
        self.policy_fn = policy_fn

        # Cache per-context-manager state once to avoid per-tensor singleton lookups.
        swap_manager = SwapManager()

        def pack_to_cpu(tensor: torch.Tensor):
            if not base_check_fn(tensor):
                return tensor
            if (policy_fn is not None) and (policy_fn(tensor) == CheckpointPolicy.MUST_SAVE):
                return tensor
            group_name = swap_manager.get_current_group_name()
            if not self.add_to_storage:
                swap_manager.add_storage(group_name, self.storage)
                self.add_to_storage = True
            funcname = f"{group_name}::{tensor.shape}"
            self.storage[self.count_idx].append(
                SwapTensor(tensor, funcname, group_swap=group_swap)
            )
            self.count_idx += 1
            return tensor
128
129
130
131
132
133
134
135
136
137
138
        # load_state_dict pre-hook to allow loading back into
        # swap-wrapped module.
        self.register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)

    @property
    def _wrapped_module(self):
        return self._swap_wrapped_module

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Abstract entry point; subclasses are expected to override it.

        Raises:
            ValueError: Always, when this base implementation is called.
        """
        message = "Subclasses should implement forward()."
        raise ValueError(message)
199
200
201
202
203
204
205
206
207
class SwapWrapper(ActivationWrapper):
    """
    Customize an nn.Module wrapper class to add an AsyncSaveOnCpu context manager for the target model.
    """
    def __init__(
        self,
        mod: Union[nn.Module, Callable],
        policy_fn: Optional[Callable] = None,
        group_swap: bool = False,
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
        group_swap: bool = False,
    ):
        super().__init__(mod)
        self.policy_fn = policy_fn
        self.group_swap = group_swap

    def forward(self, *args, **kwargs):
        """Call the wrapped module inside an ``AsyncSaveOnCpu`` context.

        The context is created from this wrapper's ``policy_fn`` and
        ``group_swap`` attributes. Positional and keyword arguments are handed
        to the wrapped module unchanged, and its result is returned.
        """
        offload_ctx = AsyncSaveOnCpu(policy_fn=self.policy_fn, group_swap=self.group_swap)
        with offload_ctx:
            output = self._swap_wrapped_module(*args, **kwargs)
        return output


def swap_wrapper(
    module: Union[nn.Module, Callable],
    policy_fn: Optional[Callable] = None,
    group_swap: bool = False,
) -> SwapWrapper:
    """Wrap ``module`` in a :class:`SwapWrapper`.

    Args:
        module: Module or callable to wrap.
        policy_fn: Optional policy callable handed to the wrapper.
        group_swap: Flag handed to the wrapper unchanged.

    Returns:
        The newly constructed ``SwapWrapper`` instance.
    """
    wrapped = SwapWrapper(module, policy_fn=policy_fn, group_swap=group_swap)
    return wrapped


def swap_tensor_wrapper(target, tag: Optional[str] = None, group_swap: bool = False):
    """Register selected tensors into the current swap group.

    This helper is intended to be used inside a forward path that already
    participates in the existing swap scheduling managed by ``SwapManager``.
229
230
231
232
233
234
235
236
237
238
    This helper is intended to be used inside a forward path that already
    participates in the existing swap scheduling managed by ``SwapManager``.
    It preserves the input structure and returns the original tensors.
    """
    swap_manager = SwapManager()
    group_name = swap_manager.get_current_group_name()
    if not group_name:
        warnings.warn(
            f"Tensor {tag} cannot be swapped, for its group is unregistered."
        )
236
237
238
239
240
241
242
243
244
        warnings.warn(
            f"Tensor {tag} cannot be swapped, for its group is unregistered."
        )
        return target
    if swap_manager.is_last_group(group_name):
        return target

    storage = Storage()
    count_idx = 0
249
250
251
252
253
254
255
256
257
            return tensor

        tensor_tag = tag or f"{group_name}_swap_tensor"
        funcname = f"{tensor_tag}::{tuple(tensor.shape)}"
        storage[count_idx].append(SwapTensor(tensor, funcname, group_swap=group_swap))
        count_idx += 1
        return tensor

    wrapped = torch.utils._pytree.tree_map(  # pylint: disable=protected-access
258
259
260
261
262
263
        lambda x: _register_tensor(x) if isinstance(x, torch.Tensor) else x,
        target,
    )
    if count_idx > 0:
        _manager.add_storage(group_name, storage)
    return wrapped
hyper_parallel/platform/torch/activation_checkpoint/sac.py
61
62
63
64
65
66
67
68
69
70
71

class _SwapCacheEntry:
    """Hold a recompute-cache wrapper and a swap record built from one tensor object."""

    def __init__(self, val, funcname, group_swap=False):
        """Build both members from the same ``val``.

        Args:
            val: Object wrapped by ``_VersionWrapper`` and passed to ``SwapTensor``.
            funcname: Label passed to ``SwapTensor``.
            group_swap: Flag forwarded to ``SwapTensor``.
        """
        # The version wrapper is created first, then the swap record, both on
        # the very same ``val`` object.
        cached = _VersionWrapper(val)
        swap_record = SwapTensor(val, funcname, group_swap=group_swap)
        self.save, self.swap = cached, swap_record


def _maybe_detach(x, any_ret_has_alias_info):
    # We detach for two separate reasons:
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144


class _CachingTorchDispatchMode(TorchDispatchMode):
    # Used together with _CachedTorchDispatchMode to implement SAC.
    def __init__(self, policy_fn, swap_storage, storage, group_swap=False):
        """Record the policy, the two storages and the swap bookkeeping state.

        Args:
            policy_fn: Policy callable kept on the instance.
            swap_storage: Storage that is registered with the swap manager the
                first time a ``MUST_SWAP`` result is dispatched.
            storage: Storage object kept on the instance.
            group_swap: Flag forwarded to each swap cache entry.
        """
        # Keep the manager handle and the group prefix on the instance so that
        # each dispatch call can reuse them instead of looking them up again.
        self._swap_manager = SwapManager()
        self._group_prefix = ""
        # Set to True once ``swap_storage`` has been added to the swap manager.
        self.add_to_storage = False
        self.policy_fn, self.group_swap = policy_fn, group_swap
        self.swap_storage, self.storage = swap_storage, storage

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)
164
165
166
167
168
169
170
171
172
173
174
175
176
177
                tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out)
            )
        elif policy == CheckpointPolicy.MUST_SWAP:  # patch code
            if not self.add_to_storage:
                group_name = self._swap_manager.get_current_group_name()
                self._group_prefix = f"{group_name}::"
                self._swap_manager.add_storage(group_name, self.swap_storage)
                self.add_to_storage = True
            funcname = f"{self._group_prefix}{func}"
            group_swap = self.group_swap
            entries = tree_map(
                lambda x: _SwapCacheEntry(_maybe_detach(x, any_ret_has_alias_info), funcname, group_swap=group_swap),
                out,
            )
205
206
207
208
209
210
211
212
213
            self.swap_storage.clear()
            self._swap_cleared = True

        # MUST_SAVE, PREFER_SAVE, and MUST_SWAP all restore from storage identically.
        if (policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE, CheckpointPolicy.MUST_SWAP)
           or is_compiling):
            storage = self.storage.get(func)  # patch code
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
221
222
223
224
225
226
227
228
229
            out = func(*args, **kwargs)
        return out


def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False, group_swap=False):
    """
    Helper to avoid recomputing certain ops during activation checkpointing.

    Use this with `torch.utils.checkpoint.checkpoint` to control which
278
279
280
281
282
283
284
285
286
287
288
289
        >>> )
    """
    # NB: If grad_mode is disabled, checkpoint would not run forward under
    #     context_fn anyway, so proceed as usual.
    if policy_fn_or_list is None:
        def policy_fn(_ctx, _op, *_args, **_kwargs):
            return CheckpointPolicy.PREFER_RECOMPUTE
    elif isinstance(policy_fn_or_list, list):
        for op in policy_fn_or_list:
            if not isinstance(op, torch._ops.OpOverload):
                _extra_msg = (
                    "Please update the OpOverloadPacket to a specific OpOverload."
hyper_parallel/platform/torch/platform.py
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335

    @staticmethod
    def swap_wrapper(module, policy_fn=None, group_swap=False):
        """Wrap ``module`` with the torch activation-swap ``swap_wrapper``.

        The implementation is imported at call time; ``policy_fn`` and
        ``group_swap`` are forwarded as keyword arguments and its result is
        returned unchanged.
        """
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import (
            swap_wrapper as torch_swap_wrapper,
        )
        wrapped = torch_swap_wrapper(module, policy_fn=policy_fn, group_swap=group_swap)
        return wrapped

    @staticmethod
    def swap_tensor_wrapper(target, tag=None, group_swap=False):
        """Forward ``target`` to the torch activation-swap ``swap_tensor_wrapper``.

        The implementation is imported at call time; ``tag`` and ``group_swap``
        are forwarded as keyword arguments and its result is returned unchanged.
        """
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import (
            swap_tensor_wrapper as torch_swap_tensor_wrapper,
        )
        result = torch_swap_tensor_wrapper(target, tag=tag, group_swap=group_swap)
        return result

    @staticmethod
    def get_class_activation_wrapper():
        """Return the ``ActivationWrapper`` class from the torch activation-swap module.

        The import happens at call time rather than at module import.
        """
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import (
            ActivationWrapper as wrapper_cls,
        )
        return wrapper_cls

    @property
    def noop_context_fn(self):
        """Return the ``noop_context_fn`` object resolved from the enclosing module scope."""
        return noop_context_fn
1337
1338
1339
1340
1341
1342
1343
1344
1345
    @staticmethod
    def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False, group_swap=False):
        """Forward to the torch SAC ``create_selective_checkpoint_contexts`` helper.

        The helper is imported at call time and receives the three arguments
        unchanged; whatever it returns is returned to the caller.
        """
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.sac import (
            create_selective_checkpoint_contexts as build_sac_contexts,
        )
        return build_sac_contexts(
            policy_fn_or_list,
            allow_cache_entry_mutation=allow_cache_entry_mutation,
            group_swap=group_swap,
        )

    @staticmethod
    def async_save_on_cpu(policy_fn=None):
        # pylint: disable=C0415
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363

    @staticmethod
    def alloc_tensor_buffer(numel: int, dtype, device, pin_memory: bool = False):
        """Allocate an uninitialized 1-D tensor buffer.

        Args:
            numel: Number of elements in the buffer.
            dtype: Element dtype passed to ``torch.empty``.
            device: Target device; only used when ``pin_memory`` is falsy.
            pin_memory: When truthy, the buffer is requested as pinned memory
                on the CPU and ``device`` is ignored.

        Returns:
            The tensor produced by ``torch.empty``.
        """
        if pin_memory:
            alloc_options = {"dtype": dtype, "device": 'cpu', "pin_memory": True}
        else:
            alloc_options = {"dtype": dtype, "device": device}
        return torch.empty(numel, **alloc_options)

    @staticmethod
    def tensor_to_numpy(tensor) -> np.ndarray:
        """Convert PyTorch tensor to numpy array."""