Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

| Source File | Diff Coverage (%) | Missing Lines |
| --- | --- | --- |
| hyper_parallel/core/pipeline_parallel/__init__.py | 100% | |
| hyper_parallel/core/pipeline_parallel/overlap_callbacks.py | 18.4% | 82-87,89-91,94,96,102-107,111-112,114-115,119,121-122,124-125,128,135-136,138,140 |
| hyper_parallel/core/pipeline_parallel/scheduler.py | 46.0% | 195-197,202-204,209,275,414-416,418,422-424,426,430-432,436-438,442-443,446,448-449,458-460,462-463,485,487,496-497,499,501,505-506,508-509,511,513-514,516-519,521-523,526-527,532-534,575,578-580,585,589,598,609,612,614,619-620,1208-1210,1215-1216,1275 |
| hyper_parallel/platform/mindspore/pipeline_parallel/stage.py | 5.0% | 182-193,195-197,199-201,203,214-223,225-228,230-231,234-236 |
| hyper_parallel/platform/mindspore/platform.py | 0.0% | 1782 |
hyper_parallel/core/pipeline_parallel/overlap_callbacks.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
        Callable ``(step, ctx) -> None`` suitable for
        :meth:`PipelineScheduleRuntime.register_custom_function`.
    """

    def _callback(step, ctx):
        bwd_step, fwd_step = step.sub_steps
        schedule = ctx.schedule
        fwd_stage = schedule._stage_dict[fwd_step.stage_index]  # pylint: disable=protected-access
        bwd_stage = schedule._stage_dict[bwd_step.stage_index]  # pylint: disable=protected-access
        fwd_mi, bwd_mi = fwd_step.micro_index, bwd_step.micro_index

        def fwd_fn():
            schedule.wait_fwd_recv(fwd_stage.stage_index, fwd_mi)
            out = fwd_stage.forward_one_chunk(
                fwd_mi, ctx.arg_mbs[fwd_mi], ctx.kwarg_mbs[fwd_mi],
            )
            schedule.update_losses(fwd_stage, out, ctx.losses)

        def bwd_fn():
            # MS PyNative's grad-enable flag is thread-local; the daemon BWD
            # thread does not inherit the main thread's enabled state, so
            # ``value_and_grad`` would otherwise raise "In no_grad context"
            # on first call.  Torch's autograd is process-wide so this
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
            # thread does not inherit the main thread's enabled state, so
            # ``value_and_grad`` would otherwise raise "In no_grad context"
            # on first call.  Torch's autograd is process-wide so this
            # import is a MS-only no-op when unavailable.
            try:
                from mindspore.common.api import _pynative_executor  # pylint: disable=C0415
                _pynative_executor.set_enable_grad(True)
            except ImportError:
                pass
            schedule.wait_bwd_recv(bwd_stage.stage_index, bwd_mi)
            # First-stage degeneration (no input grad -> dx no-op, dw does the
            # full backward) is handled inside the stage methods, so the same
            # dx -> send -> dw sequence applies to every stage.
            if schedule.enable_dxdw_split:
                logger.debug("dxdw: stage=%d mi=%d dx start",
                             bwd_stage.stage_index, bwd_mi)
                bwd_stage.backward_input_one_chunk(bwd_mi)
                logger.debug("dxdw: stage=%d mi=%d dx done; send_bwd",
                             bwd_stage.stage_index, bwd_mi)
                # dx / dw are profiler-tagged inside the stage methods; tag the
                # grad send so the dx -> send -> dw split is visible in traces.
                with platform.profiler_record(
                        f"dxdw/send/stage_{bwd_stage.stage_index}/mi_{bwd_mi}"):
                    schedule.send_bwd(bwd_stage, bwd_mi)
                logger.debug("dxdw: stage=%d mi=%d dw start",
                             bwd_stage.stage_index, bwd_mi)
                bwd_stage.backward_weight_one_chunk(bwd_mi)
                logger.debug("dxdw: stage=%d mi=%d dw done",
                             bwd_stage.stage_index, bwd_mi)
            else:
                bwd_stage.backward_one_chunk(bwd_mi)

            # Pair-8 BWD partner: explicit rendezvous so the FWD thread's
            # CHUNK_END hook always has a partner.  See
            # ``pp_overlap_moe_poc.py`` for the protocol rationale — MS
131
132
133
134
135
136
137
138
139
140
            # CHUNK_END hook always has a partner.  See
            # ``pp_overlap_moe_poc.py`` for the protocol rationale — MS
            # autograd may skip ``CHUNK_START.bwd`` when the chunk input
            # has no ``requires_grad`` (its ``x.grad`` is unused).
            if overlap.coordinator.is_enabled():
                overlap.coordinator.rendezvous(HookRole.COMPUTE)

        overlap.run(fwd_fn=fwd_fn, bwd_fn=bwd_fn)

    return _callback
hyper_parallel/core/pipeline_parallel/scheduler.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213


def _exec_fsdp_unshard(stage):
    """Call ``unshard()`` on each HSDPModule found under ``stage.submodule``.

    The cell tree is walked with ``platform.get_cells_and_names``; cells that
    are not ``HSDPModule`` instances are skipped.
    """
    hsdp_cells = (
        cell
        for _, cell in platform.get_cells_and_names(stage.submodule)
        if isinstance(cell, HSDPModule)
    )
    for hsdp_cell in hsdp_cells:
        hsdp_cell.unshard()


def _exec_fsdp_reshard(stage):
    """Call ``reshard()`` on each HSDPModule found under ``stage.submodule``.

    The cell tree is walked with ``platform.get_cells_and_names``; cells that
    are not ``HSDPModule`` instances are skipped.
    """
    hsdp_cells = (
        cell
        for _, cell in platform.get_cells_and_names(stage.submodule)
        if isinstance(cell, HSDPModule)
    )
    for hsdp_cell in hsdp_cells:
        hsdp_cell.reshard()


def _exec_fsdp_reduce_grad(stage):
    """Trigger the stage's post-backward FSDP gradient reduction.

    Delegates entirely to ``stage.execute_reduce_grad()``; any value it
    returns is discarded.
    """
    reduce_grad = stage.execute_reduce_grad
    reduce_grad()


# FSDP control MetaStep -> handler(stage).  Membership also marks which
# MetaStepTypes are FSDP control steps, so the runtime loop dispatches with a
271
272
273
274
275
276
277
278
        self._custom_fn_map = {}
        self._pp_swap_enabled = swap
        # Outstanding async send handle groups for the in-flight
        # ``run_microbatches`` call; reset per run and drained at its end.
        self._send_handles = []

    def register_custom_function(self, step_type: MetaStepType, fn) -> None:
        """Register a custom execution function for the given step type.
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
    # op waits inline.

    def recv_fwd(self, stage: "hyper_parallel.PipelineStage", micro_index: int) -> None:
        """Post the FWD recv for ``micro_index``; cache it (overlap_p2p) or wait now.

        With ``overlap_p2p`` enabled the handle group is stored in
        ``fwd_handle_cache`` under ``(stage.stage_index, micro_index)`` so a later
        :meth:`wait_fwd_recv` can block on it; otherwise it is waited on inline.

        Args:
            stage: Stage whose forward receive ops are posted.
            micro_index: Micro-batch index being received.
        """
        # Normalize a ``None`` result to an empty group (as ``send_fwd`` does) so
        # ``_wait_p2p`` is never handed ``None`` on the inline-wait path.
        handles = stage.exec_fwd_recv_ops(micro_index) or []
        if self._overlap_p2p:
            self.fwd_handle_cache[(stage.stage_index, micro_index)] = handles
        else:
            self._wait_p2p(handles)

    def recv_bwd(self, stage: "hyper_parallel.PipelineStage", micro_index: int) -> None:
        """Post the BWD recv for ``micro_index``; cache it (overlap_p2p) or wait now.

        With ``overlap_p2p`` enabled the handle group is stored in
        ``bwd_handle_cache`` under ``(stage.stage_index, micro_index)`` so a later
        :meth:`wait_bwd_recv` can block on it; otherwise it is waited on inline.

        Args:
            stage: Stage whose backward (gradient) receive ops are posted.
            micro_index: Micro-batch index being received.
        """
        # Normalize a ``None`` result to an empty group (as ``send_bwd`` does) so
        # ``_wait_p2p`` is never handed ``None`` on the inline-wait path.
        handles = stage.exec_bwd_recv_ops(micro_index) or []
        if self._overlap_p2p:
            self.bwd_handle_cache[(stage.stage_index, micro_index)] = handles
        else:
            self._wait_p2p(handles)

    def wait_fwd_recv(self, stage_index: int, micro_index: int) -> None:
        """Block on the FWD recv handles cached by :meth:`recv_fwd`.

        The ``(stage_index, micro_index)`` entry is removed from
        ``fwd_handle_cache``.  A missing or empty entry (e.g. when ``recv_fwd``
        already waited inline because ``overlap_p2p`` is off) makes this a no-op.
        """
        cached_handles = self.fwd_handle_cache.pop((stage_index, micro_index), None)
        if not cached_handles:
            return
        self._wait_p2p(cached_handles)

    def wait_bwd_recv(self, stage_index: int, micro_index: int) -> None:
        """Block on the BWD recv handles cached by :meth:`recv_bwd`.

        The ``(stage_index, micro_index)`` entry is removed from
        ``bwd_handle_cache``.  A missing or empty entry (e.g. when ``recv_bwd``
        already waited inline because ``overlap_p2p`` is off) makes this a no-op.
        """
        cached_handles = self.bwd_handle_cache.pop((stage_index, micro_index), None)
        if not cached_handles:
            return
        self._wait_p2p(cached_handles)

    def send_fwd(self, stage: "hyper_parallel.PipelineStage", micro_index: int) -> list:
        """Send this stage's forward output for ``micro_index`` to the next stage.

        A falsy result from ``exec_fwd_send_ops`` is replaced by an empty list.
        With ``overlap_p2p`` the handle group is deferred to ``_send_handles``;
        otherwise it is waited on immediately.

        Returns:
            The (possibly empty) handle group that was posted.
        """
        send_group = stage.exec_fwd_send_ops(micro_index)
        if not send_group:
            send_group = []
        if not self._overlap_p2p:
            self._wait_p2p(send_group)
        else:
            # Append the whole handle group: run_microbatches drains _send_handles
            # group by group, so a bare handle would be wrongly iterated as a list.
            self._send_handles.append(send_group)
        return send_group

    def send_bwd(self, stage: "hyper_parallel.PipelineStage", micro_index: int) -> list:
        """Send this stage's input-gradient for ``micro_index`` to the previous stage.
454
455
456
457
458
459
460
461
462
463
464
465
466
467
        Call from an OVERLAP callback only when ``enable_dxdw_split=True``: the
        scheduler then carries no ``BWD_SEND`` step for it, so the callback owns
        the send.  With the flag off this double-sends the gradient.
        """
        handles = stage.exec_bwd_send_ops(micro_index) or []
        if self._overlap_p2p:
            self._send_handles.append(handles)
        else:
            self._wait_p2p(handles)
        return handles

    def _assert_in_unshard_if_needed(self, stage, check_step):
        if not isinstance(stage.submodule, HSDPModule):
            return
481
482
483
484
485
486
487
488
489
490
491
        then runs.
        """
        stage = self._stage_dict[cur_step.stage_index]
        micro_index = cur_step.micro_index
        step_type = cur_step.type

        if step_type in (
            MetaStepType.SWAP_SET_GROUP,
            MetaStepType.SWAP_LAUNCH_OFFLOAD,
            MetaStepType.SWAP_WAIT_OFFLOAD,
            MetaStepType.SWAP_LAUNCH_LOAD,
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
            MetaStepType.SWAP_WAIT_LOAD,
        ):
            self._exec_pipeline_swap_step(cur_step, arg_mbs, kwarg_mbs)

        elif step_type == MetaStepType.FWD_RECV:
            self.recv_fwd(stage, micro_index)

        elif step_type == MetaStepType.FWD:
            self._assert_in_unshard_if_needed(stage, cur_step)
            self.wait_fwd_recv(stage.stage_index, micro_index)
            out = stage.forward_one_chunk(micro_index, arg_mbs[micro_index], kwarg_mbs[micro_index])
            self.update_losses(stage, out, losses)

        elif step_type == MetaStepType.FWD_SEND:
            self.send_fwd(stage, micro_index)

        elif step_type == MetaStepType.BWD_RECV:
            self.recv_bwd(stage, micro_index)

        elif step_type == MetaStepType.BWD_INPUT:
            self._assert_in_unshard_if_needed(stage, cur_step)
            self.wait_bwd_recv(stage.stage_index, micro_index)
            stage.backward_input_one_chunk(micro_index)

        elif step_type == MetaStepType.BWD_WEIGHT:
            self._assert_in_unshard_if_needed(stage, cur_step)
            self.wait_bwd_recv(stage.stage_index, micro_index)
            stage.backward_weight_one_chunk(micro_index)

        elif step_type == MetaStepType.BWD:
            self._assert_in_unshard_if_needed(stage, cur_step)
            self.wait_bwd_recv(stage.stage_index, micro_index)
            stage.backward_one_chunk(micro_index)

        elif step_type == MetaStepType.BWD_SEND:
            self.send_bwd(stage, micro_index)

        else:
            # FSDP control steps dispatch via the handler table; any other type
            # is a no-op here (composite/custom types are handled upstream).
            fsdp_handler = _FSDP_STEP_HANDLERS.get(step_type)
            if fsdp_handler is not None:
                fsdp_handler(stage)

    def _exec_pipeline_swap_step(self, cur_step, arg_mbs, kwarg_mbs):
        """Execute a pipeline activation-swap control step."""
        from hyper_parallel.core.pipeline_parallel.pipeline_swap import (  # pylint: disable=C0415
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
        .setLevel(logging.DEBUG)`` to trace per-rank schedule advancement
        (handy when diagnosing deadlocks or callback ordering issues).
        """
        real_stage_index = self.stages[0].stage_index % self.real_stage_num
        self._send_handles = []
        ctx = None  # lazily created

        ordered = self.exec_order[real_stage_index]
        total_steps = len(ordered)
        logger.debug(
            "run_microbatches start: rank=%d total_steps=%d micro_batch_num=%d",
            real_stage_index, total_steps, self.micro_batch_num,
        )

        for step_idx, cur_step in enumerate(ordered):
            if cur_step is None:
                continue

            logger.debug(
                "rank=%d step=%d/%d %s",
                real_stage_index, step_idx, total_steps, cur_step,
            )
594
595
596
597
598
599
600
601
602
            # Check for registered custom function
            custom_fn = self._custom_fn_map.get(cur_step.type)
            if custom_fn is not None:
                if ctx is None:
                    ctx = PipelineContext(self, arg_mbs, kwarg_mbs, losses)
                custom_fn(cur_step, ctx)
                continue

            # Default for composite OVERLAP steps: run sub_steps sequentially.
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
            # semantically equivalent to non-overlapped 1F1B.
            if (cur_step.type in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F)
                    and cur_step.sub_steps):
                for sub in cur_step.sub_steps:
                    self._exec_step(sub, arg_mbs, kwarg_mbs, losses)
                continue

            self._exec_step(cur_step, arg_mbs, kwarg_mbs, losses)

        logger.debug(
            "run_microbatches end: rank=%d pending_send_handles=%d",
            real_stage_index, len(self._send_handles),
        )
        self.sync_shared_parameters_grad()
        while self._send_handles:
            self._wait_p2p(self._send_handles.pop())


class _OverlapPhantom:
    """Internal marker used by :func:`add_send_recv` to expand an
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
        # pairing in the 1F1B steady-state phase.  Must be set before
        # ``construct_stage_exec_order`` is called below.
        self._overlap_b_f = overlap_b_f
        # enable dx_dw split in overlap phase.
        self._enable_dxdw_split = enable_dxdw_split
        if enable_dxdw_split and not overlap_b_f:
            raise ValueError(
                "enable_dxdw_split=True requires overlap_b_f=True; the split "
                "is only applied to BWD sub-steps inside OVERLAP_B_F composite steps."
            )

        self._init_round_layout()
        self.build_exec_order()

    def _init_round_layout(self):
        """Compute per-round micro-batch counts used by stage-order emission.
1271
1272
1273
1274
1275
1276
1277
1278
1279
            for op in ops:
                if op is None or op.type != MetaStepType.OVERLAP_B_F:
                    continue
                if not op.sub_steps:
                    continue
                for sub in op.sub_steps:
                    if sub.type == MetaStepType.BWD:
                        obf_bwd_keys.add((sub.stage_index, sub.micro_index))
            kept = []
hyper_parallel/platform/mindspore/pipeline_parallel/stage.py
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
        is no dx to compute and no gradient to send upstream.  The paired
        :meth:`backward_weight_one_chunk` runs the full backward instead, so
        ``grad_fn`` is left untouched in the cache for it to pop.
        """
        from hyper_parallel.core.fully_shard.api import HSDPModule  # pylint: disable=C0415
        from hyper_parallel.platform import get_platform  # pylint: disable=C0415
        if not self._has_backward:
            return
        with get_platform().profiler_record(f"backward_input_one_chunk: stage_{self.stage_index}/mi_{micro_index}"):
            for _, mod in self.submodule.cells_and_names():
                if not isinstance(mod, HSDPModule):
                    continue
                mod.set_reshard_after_backward(False)
                mod.set_requires_gradient_sync(False)
            if self.is_first_stage:
                return
            # Index, NOT pop: backward_weight_one_chunk performs the terminal pop.
            grad_fn = self.fwd_grad_fn_cache[micro_index]
            if self.is_last_stage:
                sens = self.get_last_stage_sens(self.last_stage_outputs)
            else:
                sens = self._build_padded_sens(micro_index)
            _ = grad_fn.compute_input_grad(sens=sens)
            input_grads = [recv_info.buffer.grad for recv_info in self.args_recv_info[micro_index]
                           if recv_info.requires_grad]
            self.bwd_cache[micro_index] = input_grads

    def backward_weight_one_chunk(self, micro_index):
        """dw-only backward; pops grad_fn (terminal) and clears recv buffers.
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
        intermediate gradients were saved.  The full backward ``grad_fn(sens)``
        runs here instead, which yields only weight gradients (the stage has no
        input grad to compute).
        """
        from hyper_parallel.core.fully_shard.api import HSDPModule  # pylint: disable=C0415
        from hyper_parallel.platform import get_platform  # pylint: disable=C0415
        if not self._has_backward:
            return
        with get_platform().profiler_record(f"backward_weight_one_chunk: stage_{self.stage_index}/mi_{micro_index}"):
            for _, mod in self.submodule.cells_and_names():
                if not isinstance(mod, HSDPModule):
                    continue
                mod.set_reshard_after_backward(False)
                mod.set_requires_gradient_sync(False)

            grad_fn = self.fwd_grad_fn_cache.pop(micro_index)
            if self.is_first_stage:
                sens = self._build_padded_sens(micro_index)
                _ = grad_fn(sens=sens)
            else:
                if not grad_fn._saved_intermediates:  # pylint: disable=protected-access
                    raise RuntimeError(
                        f"stage: {self.stage_index} micro_{micro_index} dw called before dx."
                    )
                grad_fn.compute_weight_grad()
            self._clear_recv_buffer(self.grad_recv_info, micro_index)
            self._clear_recv_buffer(self.args_recv_info, micro_index)

    def _construct_backward_func(self):
        """construct backward func."""
        enable_mindspore_backward_compat()
hyper_parallel/platform/mindspore/platform.py
1778
1779
1780
1781
1782
1783
1784
1785
1786

    @staticmethod
    def profiler_record(name):
        """Return a MindSpore profiler ``RecordFunction`` context manager tagged ``name``.

        Intended for ``with platform.profiler_record("tag"): ...`` so the
        enclosed work is recorded under ``name`` by ``mindspore.profiler``.
        """
        record_function = ms.profiler.common.record_function
        return record_function.RecordFunction(name)

    def str_to_dtype(self, dtype_str: str) -> Any:
        """Resolve checkpoint dtype strings (``mindspore.*`` or short ``str(Tensor.dtype)`` e.g. ``Float32``)."""
        if "." in dtype_str: