Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/platform/torch/activation_checkpoint/__init__.py 100%  
hyper_parallel/platform/torch/activation_checkpoint/recompute_session.py 90.3% 178,206,250,315,325,327-329,353,380,400,408,416,480,489-490,534
hyper_parallel/platform/torch/platform.py 47.1% 1344-1345,1396-1397,1409,1424-1425,1435-1436
hyper_parallel/platform/torch/activation_checkpoint/recompute_session.py
174
175
176
177
178
179
180
181
182
                activations.
        """
        frame = self._frame
        if frame.is_recomputed[session_id]:
            return

        args, kwargs = frame.get_inputs()

        try:
202
203
204
205
206
207
208
209
210
        # Null out handles in all live holders for this session.
        for weak_holder in frame.weak_holders:
            holder = weak_holder()
            if holder is not None and session_id in holder.handles:
                holder.handles[session_id] = None


# ---------------------------------------------------------------------------
# Context variables
246
247
248
249
250
251
252
253
254
    Yields:
        The ``session_id`` string that is active for the scope.
    """
    if session_id is None:
        session_id = str(uuid.uuid4())
    session = _RecomputeSession(session_id=session_id, retain_on_unpack=retain_on_unpack)
    token = _recompute_session.set(session)
    try:
        yield session_id
311
312
313
314
315
316
317
318
319
    ) -> None:
        def pack_hook(x: torch.Tensor) -> torch.Tensor:
            frame = target_frame_ref()
            if frame is None:
                raise RuntimeError(
                    "CheckpointFrame has been garbage collected during recomputation."
                )

            frame.recomp_counter[session_id] += 1
321
322
323
324
325
326
327
328
329
330
331
332

            # If recomputation produces more tensors than the original forward
            # saved, either silently ignore or error.
            if recomp_idx >= len(frame.weak_holders):
                if not frame.early_stop and not frame.forward_completed:
                    # Allow the extra tensor through without caching.
                    frame.ignore_saved_mismatch = True
                    return x
                raise RuntimeError(
                    "Recompute session: more tensors were saved during "
                    "recomputation than during the original forward pass."
                )
349
350
351
352
353
354
355
356

            return x

        def unpack_hook(x: torch.Tensor) -> torch.Tensor:
            return x

        super().__init__(pack_hook, unpack_hook)

376
377
378
379
380
381
382
383
384

        def unpack_hook(holder: _Holder) -> torch.Tensor:
            session = _recompute_session.get()
            if session is None:
                raise RuntimeError(
                    "checkpoint_with_session: unpack triggered outside a "
                    "recompute session context.  Wrap backward in "
                    "_recompute_session_ctx()."
                )
396
397
398
399
400
401
402
403
404
                    pass
                frame.is_recomputed[key] = True

            if key not in holder.handles:
                raise RuntimeError(
                    f"checkpoint_with_session: session '{key}' has no handle "
                    "for this holder.  The recomputation may have saved a "
                    "different number of tensors than the original forward."
                )
404
405
406
407
408
409
410
411
412
                )

            handle = holder.handles[key]
            if handle is None:
                raise RuntimeError(
                    "checkpoint_with_session: unpack triggered for a tensor "
                    "that has already been unpacked once in this session.  "
                    "If you need to access the tensor multiple times, use "
                    "retain_on_unpack=True."
412
413
414
415
416
417
418
419
                    "retain_on_unpack=True."
                )

            if handle not in frame.recomputed[key]:
                raise RuntimeError(
                    "checkpoint_with_session: handle not found in recomputed "
                    f"cache for session '{key}'."
                )
476
477
478
479
480
481
482
483
    Raises:
        ValueError: If ``use_reentrant=True``.
    """
    if use_reentrant:
        raise ValueError(
            "checkpoint_with_session does not support use_reentrant=True.  "
            "Session-based checkpointing requires the non-reentrant path."
        )
485
486
487
488
489
490
491
492
493
494
    session = _recompute_session.get()

    # If no session is active, fall back to vanilla torch checkpoint.
    if session is None:
        from torch.utils.checkpoint import checkpoint as _torch_checkpoint  # pylint: disable=C0415
        return _torch_checkpoint(
            function, *args, use_reentrant=False, **kwargs
        )

    # -- Session is active: use our custom machinery. ------------------------
530
531
532
533
534
    if isinstance(value, torch.Tensor):
        detached = value.detach()
        detached.requires_grad = value.requires_grad
        return detached
    return value
hyper_parallel/platform/torch/platform.py
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349

    @property
    def checkpoint(self):
        """Return the ``checkpoint_with_session`` callable.

        The callable is imported from the ``recompute_session`` module each
        time the property is read, rather than at module import time, which
        is why the import-outside-toplevel check (C0415) is suppressed here.

        Returns:
            The ``checkpoint_with_session`` function object itself; the
            property does not call it.
        """
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.recompute_session import (
            checkpoint_with_session as _session_checkpoint,
        )
        return _session_checkpoint

    @staticmethod
    def checkpoint_wrapper(module, **checkpoint_kwargs):
        # pylint: disable=C0415
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
            A list populated with one opaque recompute handle per checkpointed
            block executed during the forward pass within the context.
        """
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.recompute_session import _recompute_handle_collector_ctx
        return _recompute_handle_collector_ctx()

    @staticmethod
    def recompute_handle(handle, session_id):
        """Eagerly fire one checkpointed block's forward re-run.
1405
1406
1407
1408
1409
1410
1411
1412
1413
                :meth:`recompute_handle_collector_ctx`.
            session_id: Stable key shared by the producing re-run and the
                consuming backward.
        """
        return handle.recompute(session_id)

    @staticmethod
    def recompute_session_ctx(session_id=None, retain_on_unpack=False):
        """Context manager binding recompute unpack to a caller-provided session.
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
        Returns:
            A context manager activating the session for its scope.
        """
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.recompute_session import _recompute_session_ctx
        return _recompute_session_ctx(session_id=session_id, retain_on_unpack=retain_on_unpack)

    @staticmethod
    def clear_recompute_session(session_id):
        """Release retained recompute data for a session.
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
        Args:
            session_id: The session key whose cached recompute data is cleared.
        """
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.recompute_session import _clear_recompute_session
        return _clear_recompute_session(session_id)

    @staticmethod
    def get_element_size(tensor):
        """Get Tensor Element Size"""