Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

| Source File | Diff Coverage (%) | Missing Lines |
|---|---|---|
| hyper_parallel/platform/mindspore/activation_checkpoint/activation_swap.py | 50.0% | 102-103, 145-147 |
| hyper_parallel/platform/torch/activation_checkpoint/activation_swap.py | 80.0% | 101, 141 |
hyper_parallel/platform/mindspore/activation_checkpoint/activation_swap.py
 98
 99
100
101
102
103
104
105
106
107
def _get_wrapped_callable(cell: Cell) -> Optional[Callable]:
    """Return the callable stored as ``_fn`` on a FuncCell, if one is found.

    The object stored on ``cell`` under ``_CKPT_WRAPPED_MODULE`` is checked
    first; if it is not a FuncCell, ``cell`` itself is checked. For the first
    FuncCell found, its ``_fn`` attribute is returned (``None`` when that
    attribute is absent). ``None`` is returned when neither is a FuncCell.
    """
    # Order matters: the wrapped inner cell takes precedence over the outer one.
    for candidate in (getattr(cell, _CKPT_WRAPPED_MODULE, None), cell):
        if isinstance(candidate, FuncCell):
            return getattr(candidate, "_fn", None)
    return None


def _raise_callable_already_wrapped(callable_obj: Callable) -> None:
141
142
143
144
145
146
147
148
149
150
151
        )
    for _, submodule in module.cells_and_names():
        if submodule is module:
            continue
        wrapped_callable = _get_wrapped_callable(submodule)
        if wrapped_callable is not None and _is_shared_function_callable(wrapped_callable):
            continue
        if getattr(submodule, '_is_wrapped', False):
            if wrapped_callable is not None:
                _raise_callable_already_wrapped(wrapped_callable)
            raise ValueError(
hyper_parallel/platform/torch/activation_checkpoint/activation_swap.py
 97
 98
 99
100
101
102
103
104
105
    wrapped_module = getattr(module, _SWAP_WRAPPED_MODULE, None)
    if isinstance(wrapped_module, FuncModule):
        return getattr(wrapped_module, "_fn", None)
    if isinstance(module, FuncModule):
        return getattr(module, "_fn", None)
    return None


def _raise_callable_already_wrapped(callable_obj: Callable) -> None:
137
138
139
140
141
142
143
144
145
        if submodule is module:
            continue
        wrapped_callable = _get_wrapped_callable(submodule)
        if wrapped_callable is not None and _is_callable_exempt_from_overlap_check(wrapped_callable):
            continue
        if getattr(submodule, '_is_wrapped', False):
            if wrapped_callable is not None:
                _raise_callable_already_wrapped(wrapped_callable)
            raise ValueError(