Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/activation_checkpoint/swap.py 26.1% 80,82-85,87-88,141-142,321,326,430-432,434,732,861
hyper_parallel/core/pipeline_parallel/pipeline_swap.py 0.0% 17-18,20,22,25-28,31-33,36,38,41-42,44,47-48,50,58-59,61,68,71,73-76,79,81-87,90,92-103,106-107,110-111,114,116-117,120,122-129,132,134,136-141,143,145-147,152-153,156,163,165-174,176-184,188-190,192,196-197,199,204,209,213,218-223,226,234-236,238,243,248,251-255,258,260-263,266,268-271,273,276,278,281,283,286,288
hyper_parallel/core/pipeline_parallel/scheduler.py 25.0% 247,292-295,298,314,388,395,397,449,457-466
hyper_parallel/platform/mindspore/activation_checkpoint/activation_swap.py 0.0% 253-254
hyper_parallel/platform/torch/activation_checkpoint/activation_swap.py 0.0% 94-95
hyper_parallel/core/activation_checkpoint/swap.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92


def _collect_device_storage_ptrs(tensors: Any) -> Set[int]:
    """Gather the storage addresses of every non-CPU tensor in ``tensors``.

    Args:
        tensors: Arbitrary (possibly nested) structure; it is walked with
            ``platform.tree_map``.

    Returns:
        The set of ``untyped_storage().data_ptr()`` values of all visited
        ``platform.Tensor`` leaves whose device string is not ``"cpu"``.
        Empty when no such leaf exists.
    """
    device_ptrs = set()

    def _record(leaf):
        # Only platform tensors that do not live on the host are recorded.
        on_device = isinstance(leaf, platform.Tensor) and str(leaf.device).lower() != "cpu"
        if on_device:
            device_ptrs.add(leaf.untyped_storage().data_ptr())
        # The leaf is handed back unchanged; the mapped result is discarded.
        return leaf

    platform.tree_map(_record, tensors)
    return device_ptrs


class SwapTensor:
    """A tensor that can be swapped between device and host memory asynchronously."""
137
138
139
140
141
142
143
144
145
146
    def protect_if_aliases(self, alias_storage_ptrs: Set[int]) -> None:
        """Pin this entry on device when its storage is in ``alias_storage_ptrs``.

        Entries in the ``STATE_NON_TENSOR`` state are ignored.  For any other
        entry, ``_keep_on_device`` is set to ``True`` if the storage address
        of ``self.val`` is one of the given pointers; otherwise nothing changes.

        Args:
            alias_storage_ptrs: Storage addresses owned outside the swap entry.
        """
        holds_tensor = self._state != self.STATE_NON_TENSOR
        # Short-circuit keeps ``self.val`` untouched for non-tensor entries.
        if holds_tensor and self.val.untyped_storage().data_ptr() in alias_storage_ptrs:
            self._keep_on_device = True

    def get_val(self) -> Any:
        if self._state == self.STATE_NON_TENSOR:
            return self.val
317
318
319
320
321
322
323
324
325
326
327
328
329
330
        return duplicate_count

    def protect_alias_storage_ptrs(self, alias_storage_ptrs: Set[int]):
        """Avoid offloading swap entries that alias externally-owned storage."""
        if not alias_storage_ptrs:
            return

        def _protect_tensor(x):
            if isinstance(x, SwapTensor):
                x.protect_if_aliases(alias_storage_ptrs)
            return x

        for storage_list in self.values():
            for item in storage_list:
426
427
428
429
430
431
432
433
434
435
436
437
        self._storages.append(storage)

    def protect_alias_tensors(self, tensors: Any):
        """Shield swap entries that share storage with ``tensors`` from offload.

        The device storage pointers of ``tensors`` are collected once and
        forwarded to every storage registered in ``self._storages``.  Nothing
        happens when ``tensors`` contains no device tensor.

        Args:
            tensors: Nested structure of externally-owned tensors.
        """
        protected_ptrs = _collect_device_storage_ptrs(tensors)
        if protected_ptrs:
            for swap_storage in self._storages:
                swap_storage.protect_alias_storage_ptrs(protected_ptrs)

    def _collect_packable_tensors(self) -> int:
        """Identify tensors eligible for group packing and mark them for bulk copy.
728
729
730
731
732
733
734
735
736
        """Keep tensors that alias externally-owned tensors on device."""
        group = self._groups.get(group_name)
        if group is None:
            raise RuntimeError(f"Group {group_name} does not exist.")
        group.protect_alias_tensors(tensors)

    def wait_offload(self, group_name: str):
        """Wait for offload to complete for a specified swap group."""
        group = self._groups.get(group_name)
857
858
859
860
861
862
863
864
865
            if getattr(module, "_swap_state", None) == "pre_backward":
                return
            next_name = module._swap_group_order.get('next', None)
            if next_name:
                SwapManager().protect_alias_tensors(group_name, output)
                SwapManager().launch_offload(group_name)
            prev_name = module._swap_group_order.get('prev', None)
            if prev_name:
                SwapManager().wait_offload(prev_name)
hyper_parallel/core/pipeline_parallel/pipeline_swap.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# limitations under the License.
# ============================================================================
"""Pipeline-parallel activation swap scheduling helpers."""

from collections import defaultdict
from enum import IntEnum

from hyper_parallel.core.activation_checkpoint.swap import SwapManager

# Minimum distance, in compute slots, between a chunk's FWD and its BWD for
# the chunk to be swapped; inject_pipeline_swap_steps() skips closer pairs.
MIN_SWAP_GAP = 4


class _BeforeActionPriority(IntEnum):
    """Ordering keys for swap steps injected *before* a schedule entry.

    _iter_steps_by_priority() emits larger values first, so steps queued
    before the same entry come out as LAUNCH_LOAD, then WAIT_LOAD, then
    SET_GROUP.
    """

    SET_GROUP = 10
    WAIT_LOAD = 20
    LAUNCH_LOAD = 30


class _AfterActionPriority(IntEnum):
    """Ordering keys for swap steps injected *after* a schedule entry.

    _iter_steps_by_priority() emits larger values first, so steps queued
    after the same entry come out as LAUNCH_OFFLOAD, then WAIT_OFFLOAD.
    """

    WAIT_OFFLOAD = 10
    LAUNCH_OFFLOAD = 20


def pp_swap_group_name(stage_index: int, micro_index: int) -> str:
    """Build the swap group name that identifies one pipeline chunk.

    Args:
        stage_index: Index of the pipeline stage.
        micro_index: Index of the micro-batch.

    Returns:
        A string of the form ``pp_swap_s<stage_index>_m<micro_index>``.
    """
    group_name = f"pp_swap_s{stage_index}_m{micro_index}"
    return group_name


def _is_compute_step(step) -> bool:
    """Tell whether ``step`` is a plain FWD or BWD step.

    Args:
        step: A schedule step or ``None``.

    Returns:
        ``False`` for ``None``; otherwise whether ``step.type`` is
        ``MetaStepType.FWD`` or ``MetaStepType.BWD``.
    """
    # Imported inside the function, as in the rest of this module.
    from hyper_parallel.core.pipeline_parallel.scheduler import MetaStepType  # pylint: disable=C0415

    if step is None:
        return False
    return step.type in (MetaStepType.FWD, MetaStepType.BWD)


def _is_comm_step(step) -> bool:
    from hyper_parallel.core.pipeline_parallel.scheduler import MetaStepType  # pylint: disable=C0415

    return step is not None and step.type in (
        MetaStepType.FWD_RECV,
        MetaStepType.FWD_SEND,
        MetaStepType.BWD_RECV,
        MetaStepType.BWD_SEND,
54
55
56
57
58
59
60
61
62
63
64
65
        MetaStepType.BWD_SEND,
    )


def _is_composite_compute_step(step) -> bool:
    """Tell whether ``step`` is an overlap container that holds sub-steps.

    Args:
        step: A schedule step or ``None``.

    Returns:
        ``True`` only when ``step`` is not ``None``, its type is
        ``MetaStepType.OVERLAP_F_B`` or ``MetaStepType.OVERLAP_B_F``, and
        ``step.sub_steps`` is truthy; ``False`` otherwise.
    """
    from hyper_parallel.core.pipeline_parallel.scheduler import MetaStepType  # pylint: disable=C0415

    if step is None:
        return False
    if step.type not in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F):
        return False
    # A bare ``and`` chain would return the sub_steps object itself; coerce
    # so the declared ``bool`` return type actually holds.
    return bool(step.sub_steps)
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
        and step.sub_steps
    )


class _ComputeLeaf:
    """A real FWD/BWD leaf and the top-level container that owns it."""

    __slots__ = ("step", "container_index", "compute_index")

    def __init__(self, step, container_index, compute_index):
        self.step = step
        self.container_index = container_index
        self.compute_index = compute_index


def _iter_compute_leaf_steps(step):
    """Generate the FWD/BWD steps represented by one schedule entry.

    A plain compute step yields itself.  A composite (overlap) step yields
    those of its ``sub_steps`` that are compute steps, in their stored order.
    Any other entry yields nothing.
    """
    if _is_compute_step(step):
        yield step
    elif _is_composite_compute_step(step):
        yield from filter(_is_compute_step, step.sub_steps)


def _collect_compute_leaves(order):
    """Flatten ``order`` into compute leaves, one compute slot per entry.

    Entries without any FWD/BWD leaf are skipped and consume no slot.  All
    leaves of one composite entry share that entry's slot number.

    Returns:
        A pair ``(leaves, slot_to_container)``: the list of ``_ComputeLeaf``
        objects in schedule order, and a dict mapping each compute-slot
        number to the index in ``order`` of the entry that occupies it.
    """
    collected = []
    slot_to_container = {}
    slot = 0
    for position, entry in enumerate(order):
        expanded = list(_iter_compute_leaf_steps(entry))
        if expanded:
            slot_to_container[slot] = position
            collected.extend(_ComputeLeaf(leaf, position, slot) for leaf in expanded)
            slot += 1
    return collected, slot_to_container


def _append_after(after_steps, index, priority, step):
    after_steps[index].append((priority, step))


def _append_before(before_steps, index, priority, step):
    before_steps[index].append((priority, step))


def _iter_steps_by_priority(priority_steps):
    """Yield steps from high priority to low priority."""
    for _, step in sorted(priority_steps, key=lambda item: item[0], reverse=True):
        yield step


def _comm_block_anchor(order, index):
    """Find the end of the communication run that directly follows ``index``.

    Returns:
        The index of the last step in the unbroken run of communication
        steps starting at ``index + 1``, or ``index`` itself when the next
        entry is missing or is not a communication step.
    """
    anchor = index
    # Advance while the very next entry is still a communication step.
    while anchor + 1 < len(order) and _is_comm_step(order[anchor + 1]):
        anchor += 1
    return anchor


def _post_compute_anchor(order, index, leaf_step=None):
    """Pick the index after which post-compute swap steps may be placed.

    Args:
        order: The rank's schedule.
        index: Position in ``order`` of the compute entry.
        leaf_step: The actual FWD/BWD leaf when the entry at ``index`` is a
            container; defaults to ``order[index]``.

    Returns:
        For a FWD (BWD) step: the index of the FWD_SEND (BWD_SEND) with the
        same stage and micro-batch index inside the communication run that
        directly follows ``index``.  When no such send exists, or the step is
        neither FWD nor BWD, the last index of that communication run.
    """
    from hyper_parallel.core.pipeline_parallel.scheduler import MetaStepType  # pylint: disable=C0415

    anchor_step = order[index] if leaf_step is None else leaf_step
    comm_tail = _comm_block_anchor(order, index)
    if anchor_step.type == MetaStepType.FWD:
        matching_send = MetaStepType.FWD_SEND
    elif anchor_step.type == MetaStepType.BWD:
        matching_send = MetaStepType.BWD_SEND
    else:
        return comm_tail

    def _is_own_send(candidate):
        # The send must belong to the same stage and micro-batch as the step.
        return (
            candidate is not None
            and candidate.type == matching_send
            and candidate.stage_index == anchor_step.stage_index
            and candidate.micro_index == anchor_step.micro_index
        )

    candidates = range(index + 1, comm_tail + 1)
    return next((pos for pos in candidates if _is_own_send(order[pos])), comm_tail)


def inject_pipeline_swap_steps(order):
    """Inject SWAP_* steps into one rank's pipeline order.

    The injected order preserves the required lifecycle:
    SET_GROUP -> FWD -> LAUNCH_OFFLOAD -> WAIT_OFFLOAD ->
    LAUNCH_LOAD -> WAIT_LOAD -> BWD.

    A (stage, micro-batch) chunk receives swap steps only when both its FWD
    and its BWD leaf appear in ``order`` and they are at least
    ``MIN_SWAP_GAP`` compute slots apart; other chunks are left untouched.

    Args:
        order: The schedule steps of one rank.  It is only read.

    Returns:
        A new list containing every entry of ``order`` in its original
        sequence, with the generated ``MetaStep`` swap entries placed before
        or after their anchor entries.
    """
    from hyper_parallel.core.pipeline_parallel.scheduler import MetaStep, MetaStepType  # pylint: disable=C0415

    # Index the FWD and BWD leaf of every (stage, micro-batch) chunk.
    fwd_index = {}
    bwd_index = {}
    compute_leaves, container_by_compute_index = _collect_compute_leaves(order)
    for leaf in compute_leaves:
        step = leaf.step
        key = (step.stage_index, step.micro_index)
        if step.type == MetaStepType.FWD:
            fwd_index[key] = leaf
        elif step.type == MetaStepType.BWD:
            bwd_index[key] = leaf

    # Swap steps to emit before / after the entry at a given index of order.
    before_steps = defaultdict(list)
    after_steps = defaultdict(list)
    for key, fwd_leaf in fwd_index.items():
        bwd_leaf = bwd_index.get(key)
        if bwd_leaf is None:
            continue
        # Too few compute slots between FWD and BWD: do not swap this chunk.
        if bwd_leaf.compute_index - fwd_leaf.compute_index < MIN_SWAP_GAP:
            continue
        # Indices in order of the compute entries strictly between FWD and BWD.
        compute_between = [
            container_by_compute_index[index]
            for index in range(fwd_leaf.compute_index + 1, bwd_leaf.compute_index)
        ]
        if not compute_between:
            continue
        stage_index, micro_index = key

        _append_before(
            before_steps, fwd_leaf.container_index, _BeforeActionPriority.SET_GROUP,
            MetaStep(micro_index, MetaStepType.SWAP_SET_GROUP, stage_index),
        )
        # Offload is launched after the FWD's anchor and awaited after the
        # anchor of the first compute entry that follows it.
        fwd_anchor = _post_compute_anchor(order, fwd_leaf.container_index, fwd_leaf.step)
        first_between_anchor = _post_compute_anchor(order, compute_between[0])

        _append_after(
            after_steps, fwd_anchor, _AfterActionPriority.LAUNCH_OFFLOAD,
            MetaStep(micro_index, MetaStepType.SWAP_LAUNCH_OFFLOAD, stage_index),
        )

        _append_after(
            after_steps, first_between_anchor, _AfterActionPriority.WAIT_OFFLOAD,
            MetaStep(micro_index, MetaStepType.SWAP_WAIT_OFFLOAD, stage_index),
        )

        # Load is launched before the last compute entry preceding the BWD
        # and awaited right before the BWD's own entry.
        _append_before(
            before_steps, compute_between[-1], _BeforeActionPriority.LAUNCH_LOAD,
            MetaStep(micro_index, MetaStepType.SWAP_LAUNCH_LOAD, stage_index),
        )
        _append_before(
            before_steps, bwd_leaf.container_index, _BeforeActionPriority.WAIT_LOAD,
            MetaStep(micro_index, MetaStepType.SWAP_WAIT_LOAD, stage_index),
        )

    # Rebuild the order, surrounding each entry with its queued swap steps.
    injected = []
    for index, step in enumerate(order):
        injected.extend(_iter_steps_by_priority(before_steps[index]))
        injected.append(step)
        injected.extend(_iter_steps_by_priority(after_steps[index]))
    return injected


def _protect_pipeline_owned_tensors(step, schedule, arg_mbs, kwarg_mbs) -> None:
    """Keep pipeline-owned boundary tensors alive on device.

    Swap offload clears the device storage of saved tensors after D2H copy.
    If a saved tensor aliases a pipeline boundary tensor, clearing it would
    also invalidate the object still held by the pipeline runtime.  The alias
    protection below marks those saved tensors as keep-on-device.

    Args:
        step: Swap step carrying ``stage_index`` and ``micro_index``.
        schedule: Schedule whose ``_stage_dict`` holds the stage objects.
        arg_mbs: Per-micro-batch positional inputs.
        kwarg_mbs: Per-micro-batch keyword inputs.
    """
    stage = schedule._stage_dict[step.stage_index]  # pylint: disable=protected-access
    group_name = pp_swap_group_name(step.stage_index, step.micro_index)
    manager = SwapManager()

    if stage.is_first_stage:
        # First-stage inputs come from split_microbatches(), outside the
        # wrapped stage.  They are not stage outputs, but they can alias
        # tensors saved by the first layer and must not have their storage
        # resized by the swap group.
        boundary_inputs = (arg_mbs[step.micro_index], kwarg_mbs[step.micro_index])
        manager.protect_alias_tensors(group_name, boundary_inputs)

    if not stage.is_last_stage:
        return

    # Last-stage outputs are consumed by the schedule as losses / sens
    # roots, so they must stay device-valid until backward has used them.
    boundary_outputs = stage.fwd_outputs_cache.get(step.micro_index)
    if boundary_outputs is None:
        # No cached entry for this micro-batch: use the stage-wide outputs.
        boundary_outputs = stage.last_stage_outputs
    if boundary_outputs is not None:
        manager.protect_alias_tensors(group_name, boundary_outputs)


def swap_set_group(step) -> None:
    """Make the chunk's swap group the active SwapManager group.

    The group named after ``step.stage_index`` / ``step.micro_index`` is
    ensured to exist and then set as the current group, ahead of the next
    pipeline forward chunk.
    """
    swap_manager = SwapManager()
    target_group = pp_swap_group_name(step.stage_index, step.micro_index)
    swap_manager.ensure_group(target_group)
    swap_manager.set_current_group_name(target_group)


def swap_launch_offload(step, schedule, arg_mbs, kwarg_mbs) -> None:
    """Start the D2H transfer for the chunk's swap group.

    Pipeline-owned boundary tensors are protected first, then the offload is
    launched, and finally the current group name is reset to ``""``.

    Args:
        step: Swap step carrying ``stage_index`` and ``micro_index``.
        schedule: Schedule passed through to the boundary-tensor protection.
        arg_mbs: Per-micro-batch positional inputs.
        kwarg_mbs: Per-micro-batch keyword inputs.
    """
    target_group = pp_swap_group_name(step.stage_index, step.micro_index)
    swap_manager = SwapManager()
    _protect_pipeline_owned_tensors(step, schedule, arg_mbs, kwarg_mbs)
    swap_manager.launch_offload(target_group)
    # Only the immediately preceding forward belongs to this swap group.
    swap_manager.set_current_group_name("")


def swap_wait_offload(step) -> None:
    """Block on the D2H transfer of the chunk's swap group.

    Delegates to ``SwapManager.wait_offload`` for the group named after
    ``step.stage_index`` / ``step.micro_index``.
    """
    swap_manager = SwapManager()
    target_group = pp_swap_group_name(step.stage_index, step.micro_index)
    swap_manager.wait_offload(target_group)


def swap_launch_load(step) -> None:
    """Start the H2D transfer for the chunk's swap group.

    Delegates to ``SwapManager.launch_load`` for the group named after
    ``step.stage_index`` / ``step.micro_index``.
    """
    swap_manager = SwapManager()
    target_group = pp_swap_group_name(step.stage_index, step.micro_index)
    swap_manager.launch_load(target_group)


def swap_wait_load(step) -> None:
    """Block on the H2D transfer of the chunk's swap group.

    Runs before backward uses the activations; delegates to
    ``SwapManager.wait_load`` for the group named after
    ``step.stage_index`` / ``step.micro_index``.
    """
    swap_manager = SwapManager()
    target_group = pp_swap_group_name(step.stage_index, step.micro_index)
    swap_manager.wait_load(target_group)
hyper_parallel/core/pipeline_parallel/scheduler.py
243
244
245
246
247
248
249
250
        self._build_stage_to_rank_index()
        self.fwd_handle_cache = {}
        self.bwd_handle_cache = {}
        self._custom_fn_map = {}
        self._pp_swap_enabled = swap

    def register_custom_function(self, step_type: MetaStepType, fn) -> None:
        """Register a custom execution function for the given step type.
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
        )

    def _inject_local_pp_swap_actions(self):
        """Annotate the local rank schedule with pipeline activation-swap actions."""
        if not self._pp_swap_enabled:
            return
        current_rank = self._stage_to_rank_index[self.stages[0].stage_index]
        from hyper_parallel.core.pipeline_parallel.pipeline_swap import (  # pylint: disable=C0415
            inject_pipeline_swap_steps,
        )
        self.exec_order[current_rank] = inject_pipeline_swap_steps(self.exec_order[current_rank])

    @abstractmethod
    def _build_stage_to_rank_index(self) -> None:
        """
310
311
312
313
314
315
316
317
318

    def build_exec_order(self) -> None:
        """Assemble the final execution order for this schedule.

        Three phases run in a fixed sequence: the base order is constructed,
        then PP-swap steps are injected, then FSDP actions are injected.
        """
        # Phase 1: base execution order.
        self.construct_exec_order()
        # Phase 2: optional pipeline activation-swap steps.
        self._inject_local_pp_swap_actions()
        # Phase 3: optional FSDP actions.
        self._inject_local_fsdp_actions()

    def convert_stages_dict(self):
        """convert stages to dict."""
384
385
386
387
388
389
390
391
392
        stage = self._stage_dict[cur_step.stage_index]
        stage_index = cur_step.stage_index
        micro_index = cur_step.micro_index

        if cur_step.type in (
            MetaStepType.SWAP_SET_GROUP,
            MetaStepType.SWAP_LAUNCH_OFFLOAD,
            MetaStepType.SWAP_WAIT_OFFLOAD,
            MetaStepType.SWAP_LAUNCH_LOAD,
391
392
393
394
395
396
397
398
399
400
401
            MetaStepType.SWAP_WAIT_OFFLOAD,
            MetaStepType.SWAP_LAUNCH_LOAD,
            MetaStepType.SWAP_WAIT_LOAD,
        ):
            self._exec_pipeline_swap_step(cur_step, arg_mbs, kwarg_mbs)

        elif cur_step.type == MetaStepType.FWD_RECV:
            comm_handle = stage.exec_fwd_recv_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
445
446
447
448
449
450
451
452
453
            self._exec_fsdp_step(cur_step, stage)

    def _exec_pipeline_swap_step(self, cur_step, arg_mbs, kwarg_mbs):
        """Execute a pipeline activation-swap control step."""
        from hyper_parallel.core.pipeline_parallel.pipeline_swap import (  # pylint: disable=C0415
            swap_launch_load,
            swap_launch_offload,
            swap_set_group,
            swap_wait_load,
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
            swap_wait_load,
            swap_wait_offload,
        )

        if cur_step.type == MetaStepType.SWAP_SET_GROUP:
            swap_set_group(cur_step)
        elif cur_step.type == MetaStepType.SWAP_LAUNCH_OFFLOAD:
            swap_launch_offload(cur_step, self, arg_mbs, kwarg_mbs)
        elif cur_step.type == MetaStepType.SWAP_WAIT_OFFLOAD:
            swap_wait_offload(cur_step)
        elif cur_step.type == MetaStepType.SWAP_LAUNCH_LOAD:
            swap_launch_load(cur_step)
        elif cur_step.type == MetaStepType.SWAP_WAIT_LOAD:
            swap_wait_load(cur_step)

    def _exec_fsdp_step(self, cur_step, stage):
        """Execute an FSDP control step (unshard, reshard, or reduce-grad)."""
        if cur_step.type == MetaStepType.FSDP_UNSHARD:
hyper_parallel/platform/mindspore/activation_checkpoint/activation_swap.py
249
250
251
252
253
254
255
256
257
258
                return tensor
            if (policy_fn is not None) and (policy_fn(tensor) == CheckpointPolicy.MUST_SAVE):
                return tensor
            group_name = swap_manager.get_current_group_name()
            if not group_name:
                return tensor
            if not self.add_to_storage:
                swap_manager.add_storage(group_name, self.storage)
                self.add_to_storage = True
            funcname = f"{group_name}::{tensor.shape}"
hyper_parallel/platform/torch/activation_checkpoint/activation_swap.py
90
91
92
93
94
95
96
97
98
99
                return tensor
            if (policy_fn is not None) and (policy_fn(tensor) == CheckpointPolicy.MUST_SAVE):
                return tensor
            group_name = swap_manager.get_current_group_name()
            if not group_name:
                return tensor
            if not self.add_to_storage:
                swap_manager.add_storage(group_name, self.storage)
                self.add_to_storage = True
            funcname = f"{group_name}::{tensor.shape}"