Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/platform/mindspore/fully_shard/param.py 51.2% 201,205,573,800-802,804-806,809-811,813-815,820-821,824-825,828
hyper_parallel/platform/mindspore/fully_shard/param_group.py 75.0% 683,686,690,710-713,715,750,762,769,771,773-774,776-783,785,790
hyper_parallel/platform/mindspore/fully_shard/scheduler.py 57.1% 184-186
hyper_parallel/platform/mindspore/fully_shard/state.py 72.8% 121,212-213,215-217,220-223,228,356,418-428,460,474,483,486-487,492,498,523-525,532,535,580,644
hyper_parallel/platform/torch/fully_shard/param_group.py 41.7% 970-976
hyper_parallel/platform/mindspore/fully_shard/param.py
197
198
199
200
201
202
203
204
205
206
207
208
209
        self.gradient_scaling_factor = None

    @property
    def accumulated_allreduced_grad(self) -> bool:
        """Return the flag currently stored in ``_accumulated_allreduced_grad``."""
        # NOTE(review): the name suggests this marks a grad that already holds an
        # all-reduced, accumulated value -- confirm against the callers.
        return self._accumulated_allreduced_grad

    @accumulated_allreduced_grad.setter
    def accumulated_allreduced_grad(self, value: bool) -> None:
        """Store ``value`` in ``_accumulated_allreduced_grad`` (no validation is done)."""
        self._accumulated_allreduced_grad = value

    @property
    def uses_param_shard(self) -> bool:
        """Whether FSDP sharding is enabled for this parameter."""
569
570
571
572
573
574
575
576
577
    def _sharded_param_storage_dtype(self) -> Optional[ms.Type]:
        """Return the dtype of the sharded parameter's on-device storage."""
        sharded_param = self.sharded_param
        if isinstance(sharded_param, DTensor):
            return sharded_param._local_tensor.dtype
        if hasattr(sharded_param, "dtype"):
            dtype = sharded_param.dtype
            if isinstance(dtype, ms.Type):
                return dtype
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
        # apply gradient_scaling_factor (reduce-scatter leg)
        apply_gradient_scaling_factor(grad_flat, self.gradient_scaling_factor)
        # If parameter is not sharded (below threshold), no reduce-scatter needed
        if not self.is_sharded:
            if output_buffer is not None:
                copy_without_bumping_version(output_buffer, grad_flat)
                self._reduce_scatter_output = output_buffer
            else:
                self._reduce_scatter_output = grad_flat
            self.reduce_scatter_handle = None
            return self._reduce_scatter_output, None

        if shard_group is None or shard_group_size <= 1:
            if output_buffer is not None:
                copy_without_bumping_version(output_buffer, grad_flat)
                self._reduce_scatter_output = output_buffer
            else:
                self._reduce_scatter_output = grad_flat
            self.reduce_scatter_handle = None
            return self._reduce_scatter_output, None

        # Calculate output size
        output_numel = grad_flat.numel() // shard_group_size
        if output_buffer is not None:
            if output_buffer.numel() != output_numel:
                raise ValueError(
                    f"output_buffer size mismatch: expected {output_numel}, got {output_buffer.numel()}"
                )
            if output_buffer.dtype != reduce_dtype:
                raise ValueError(
                    f"output_buffer dtype mismatch: expected {reduce_dtype}, got {output_buffer.dtype}"
                )
            self._reduce_scatter_output = output_buffer
        else:
            self._reduce_scatter_output = ms.mint.empty(
                output_numel, dtype=reduce_dtype, device=grad.device.split(":")[0]
            )
hyper_parallel/platform/mindspore/fully_shard/param_group.py
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
        orig_dtypes: List[Any],
    ) -> Any:
        """Resolve None reduce_dtype to match ``reduce_scatter_grad``'s ``dtype or grad.dtype``."""
        if reduce_dtype is not None:
            return reduce_dtype
        for hsdp_param in hsdp_params:
            if getattr(hsdp_param, "unsharded_accumulated_grad", None) is not None:
                return hsdp_param.unsharded_accumulated_grad_data.dtype
            unsharded_param = getattr(hsdp_param, "unsharded_param", None)
            if unsharded_param is not None and getattr(unsharded_param, "grad", None) is not None:
                return hsdp_param.unsharded_grad_data.dtype
        return orig_dtypes[0] if orig_dtypes else None

    def __init__(
        self,
        replicate_group,
706
707
708
709
710
711
712
713
714
715
716
717
718
719
        self.reduce_op = reduce_op
        self.mp_policy = mp_policy
        if replicate_world_size is not None:
            self.replicate_world_size = replicate_world_size
        elif replicate_group is not None and hasattr(replicate_group, "rank_size"):
            self.replicate_world_size = replicate_group.rank_size
        elif hsdp_params:
            self.replicate_world_size = hsdp_params[0].unsharded_group_info.rank_size
        else:
            self.replicate_world_size = 1
        self.fused_buffer: Optional[ms.Tensor] = None
        self.param_offsets: List[int] = []
        self.param_numels: List[int] = []
        self.all_reduce_handle: Optional[CommHandle] = None
746
747
748
749
750
751
752
753

    def get_param_buffer_view(self, idx: int) -> ms.Tensor:
        """Slice the fused buffer down to the region recorded for one parameter.

        Args:
            idx: Index into ``param_offsets`` / ``param_numels`` selecting the
                parameter whose region is wanted.

        Returns:
            ``fused_buffer.narrow(0, offset, numel)`` for the selected parameter.

        Raises:
            RuntimeError: If ``fused_buffer`` is still ``None``.
        """
        fused = self.fused_buffer
        if fused is None:
            raise RuntimeError("Fused buffer not allocated. Call allocate_fused_buffer first.")
        # Look up the start and the length first, then narrow along dim 0.
        start, length = self.param_offsets[idx], self.param_numels[idx]
        return fused.narrow(0, start, length)
758
759
760
761
762
763
764
765
766

    def accumulate_existing_grads_to_buffer(self) -> None:
        """Accumulate existing sharded grads into fused_buffer before all-reduce."""
        if self.fused_buffer is None:
            return
        from hyper_parallel.core.dtensor.dtensor import DTensor

        for idx, hsdp_param in enumerate(self.hsdp_params):
            existing_grad = None
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
        for idx, hsdp_param in enumerate(self.hsdp_params):
            existing_grad = None
            if self.mp_policy is not None and self.mp_policy.apply_grad_on_fp32_main_grad:
                if hasattr(hsdp_param.sharded_param, "main_grad"):
                    existing_grad = hsdp_param.sharded_param.main_grad
            else:
                existing_grad = hsdp_param.sharded_param.grad
            if existing_grad is not None and not hsdp_param.accumulated_allreduced_grad:
                if isinstance(existing_grad, DTensor):
                    existing_grad_local = existing_grad._local_tensor
                else:
                    existing_grad_local = existing_grad
                buffer_view = self.get_param_buffer_view(idx)
                if existing_grad_local.dtype != self.reduce_dtype:
                    existing_grad_local = existing_grad_local.to(self.reduce_dtype)
                buffer_view.add_(existing_grad_local.view_as(buffer_view))
                if self.mp_policy is not None and self.mp_policy.apply_grad_on_fp32_main_grad:
                    if hasattr(hsdp_param.sharded_param, "main_grad"):
                        hsdp_param.sharded_param.main_grad = None
                else:
                    hsdp_param.sharded_param.grad = None

    def issue_async_allreduce(self) -> None:
        """Issue async all_reduce on the fused buffer (SUM for padding correctness)."""
        if self.fused_buffer is None:
            raise RuntimeError("Fused buffer not allocated.")
        self.all_reduce_handle = dist.all_reduce(
            self.fused_buffer,
            op=ops.ReduceOp.SUM,
            group=self.replicate_group,
hyper_parallel/platform/mindspore/fully_shard/scheduler.py
180
181
182
183
184
185
186
187
188
189
190
        # Step 1: Wait for previous reduce-scatter groups and get them for all-reduce
        prev_groups = self.hsdp_state._wait_prev_reduce_scatter()
        # Step 2: Accumulate and issue async all-reduce for previous groups
        for group in prev_groups:
            group.accumulate_existing_grads_to_buffer()
            group.issue_async_allreduce()
            MindSporeHSDPStateV2.pending_all_reduce_groups.append(group)
        # Step 3: Wait/apply any remaining reduce-scatter for pure FSDP params
        self.hsdp_state.reduce_scattered_params()
        # Step 4: Wait for pending all-reduce groups and apply grads
        MindSporeHSDPStateV2.delay_apply_reduce_grads()
hyper_parallel/platform/mindspore/fully_shard/state.py
117
118
119
120
121
122
123
124
125
        # Gradients require an all-reduce when running HSDP
        self.requires_all_reduce = True
        # Default reduce op is decided at the fully_shard-state level:
        # if any managed parameter is DTensor-backed, use SUM; otherwise AVG.
        self.reduce_op_type = self._resolve_default_reduce_op()
        self._reset_sharded_params = False
        self._init_param_group()

    def _iter_managed_params(self):
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
        x.div_(divisor)

    def _finish_ignored_allreduce(self) -> None:
        """Drain the queued all-reduce works and hand each reduced grad to its param.

        Entries are popped from the front of
        ``MindSporeHSDPStateV2._ignored_allreduce_works`` one at a time. For every
        entry the param's all-reduce handle (when truthy) is waited on, the reduced
        grad goes through ``_div_if_needed`` and is then passed to
        ``param.apply_reduced_grad``. The combined result of the apply calls is
        forwarded once to ``_synchronize_current_stream_if_needed``. Nothing is
        done when the queue is empty on entry.
        """
        if not MindSporeHSDPStateV2._ignored_allreduce_works:
            return

        sync_required = False
        while MindSporeHSDPStateV2._ignored_allreduce_works:
            # Pop before processing so a failing entry is not retried.
            work_item = MindSporeHSDPStateV2._ignored_allreduce_works.pop(0)
            param, reduced_grad, reduce_group_size, orig_dtype, need_div = work_item

            handle = param.all_reduce_handle
            if handle:
                handle.wait()

            self._div_if_needed(reduced_grad, reduce_group_size, need_div)
            applied = param.apply_reduced_grad(reduced_grad, orig_dtype)
            sync_required = applied or sync_required

        self._synchronize_current_stream_if_needed(sync_required)

    def _move_states_to_device(self):
        """move states to device"""
        for mod in self.modules:
352
353
354
355
356
357
358
359
360
                continue
            if not hsdp_param.sharded_param.requires_grad:
                continue
            if not self._has_pending_unsharded_grad(hsdp_param):
                continue
            if self._should_run_all_reduce(hsdp_param):
                self._queue_compat_all_reduce(hsdp_param)
            else:
                need_synchronize = self._apply_pending_unsharded_grad_locally(hsdp_param)
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432

    def _wait_prev_reduce_scatter(self) -> List:
        """Step 1: take over the previously queued HSDP groups and finish their RS.

        Returns:
            A snapshot of ``MindSporeHSDPStateV2.pre_all_reduce_groups`` taken before
            that list is cleared, or an empty list when nothing was queued.
        """
        if not MindSporeHSDPStateV2.pre_all_reduce_groups:
            return []

        # Copy first: the shared class-level list is emptied right away.
        drained_groups = list(MindSporeHSDPStateV2.pre_all_reduce_groups)
        MindSporeHSDPStateV2.pre_all_reduce_groups.clear()

        for drained_group in drained_groups:
            for hsdp_param in drained_group.hsdp_params:
                hsdp_param.reduce_scatter_output()
                hsdp_param.clear_reduce_scatter_output()
                # Drop whichever unsharded gradient is still attached to the param.
                if hsdp_param.unsharded_accumulated_grad_data is not None:
                    hsdp_param.unsharded_accumulated_grad = None
                elif hsdp_param.unsharded_param.grad is not None:
                    hsdp_param.unsharded_param.grad = None

        return drained_groups

    def _wait_and_apply_prev_no_allreduce_params(self):
        """Step 2: wait/apply previous reduce-scatter for pure FSDP params."""
456
457
458
459
460
461
462
463
464
        """Whether the 4-step RS/AR overlap pipeline has pending work this hook."""
        if MindSporeHSDPStateV2.pre_all_reduce_groups:
            return True
        if HSDPState.pre_reduce_scatter_params:
            return True
        return bool(self._collect_params_for_reduce_scatter())

    def _run_overlap_post_backward_steps(self) -> None:
        """Run the 4-step HSDP RS/AR overlap pipeline for the current module."""
470
471
472
473
474
475
476
477
478
    def _issue_reduce_scatter_for_current_module(self):
        """Issue reduce_scatter for current module with fused all-reduce when needed."""
        params_to_reduce = self._collect_params_for_reduce_scatter()
        if not params_to_reduce:
            return

        groups_by_comm = defaultdict(list)
        for hsdp_param in params_to_reduce:
            if self._should_run_all_reduce(hsdp_param):
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
                replicate_group = hsdp_param.unsharded_group_info.group
                key = id(replicate_group) if replicate_group is not None else None
                groups_by_comm[key].append(hsdp_param)
            else:
                groups_by_comm[None].append(hsdp_param)

        if None in groups_by_comm:
            for hsdp_param in groups_by_comm[None]:
                hsdp_param.reduce_scatter_grad(
                    async_op=True,
                    dtype=self._reduce_dtype,
                    reduce_op=self._resolve_reduce_op(),
                )
                HSDPState.pre_reduce_scatter_params.append(
                    (hsdp_param, self._orig_dtype)
                )

        for key, hsdp_params in groups_by_comm.items():
494
495
496
497
498
499
500
501
502
                )

        for key, hsdp_params in groups_by_comm.items():
            if key is None:
                continue
            group_info = hsdp_params[0].unsharded_group_info
            group = AllReduceParamGroup(
                replicate_group=group_info.group,
                hsdp_params=hsdp_params,
519
520
521
522
523
524
525
526
527
528
529

    def _issue_prev_fused_allreduce(self, prev_groups: List) -> None:
        """Step 4: issue async all-reduce for previous HSDP groups (no-op without fusion groups)."""
        for prev_group in prev_groups:
            prev_group.accumulate_existing_grads_to_buffer()
            prev_group.issue_async_allreduce()
            MindSporeHSDPStateV2.pending_all_reduce_groups.append(prev_group)

    @classmethod
    def delay_apply_reduce_grads(cls) -> None:
        """Wait pending fused all-reduce groups at root backward."""
528
529
530
531
532
533
534
535
536
537
538
539
    def delay_apply_reduce_grads(cls) -> None:
        """Finish every pending fused all-reduce group, then empty the pending list.

        ``wait_and_apply_grads`` is called on each group in
        ``cls.pending_all_reduce_groups``; if any call returns a truthy value the
        current stream is synchronized once after the list has been cleared.
        """
        # Every group is visited; the results are only combined afterwards.
        sync_flags = [
            pending_group.wait_and_apply_grads()
            for pending_group in cls.pending_all_reduce_groups
        ]
        cls.pending_all_reduce_groups.clear()
        if any(sync_flags):
            ms.runtime.current_stream().synchronize()

    def post_backward_for_comm_fusion(self):
        """Drive the fused gradient-reduction pipeline for sharded params."""
        self.reduce_params()
576
577
578
579
580
581
582
583
584

    def _can_direct_all_reduce_compat_grad(self, hsdp_param) -> bool:
        """Whether ``hsdp_param`` should reduce its existing ``sharded_param.grad`` directly."""
        if not hasattr(hsdp_param, "param_mode"):
            return False
        return (
            hsdp_param.param_mode == FullyShardParamMode.DTENSOR_COMPAT
            and hsdp_param.enable_fsdp_shard
            and not hsdp_param.is_sharded
640
641
642
643
644
645
646
647
648
                if not self._has_pending_unsharded_grad(hsdp_param):
                    continue
                if hsdp_param.shard_size <= 1:
                    if self._should_run_all_reduce(hsdp_param):
                        self._queue_compat_all_reduce(hsdp_param)
                    else:
                        # No-communication path (shard_size == 1, no all-reduce):
                        # this leg owns the scaling since the grad never goes through
                        # reduce_scatter_grad / all_reduce_grad.
hyper_parallel/platform/torch/fully_shard/param_group.py
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
    ) -> Optional[torch.dtype]:
        """Resolve None reduce_dtype to match ``reduce_scatter_grad``'s ``dtype or grad.dtype``."""
        if reduce_dtype is not None:
            return reduce_dtype
        for hsdp_param in hsdp_params:
            if getattr(hsdp_param, "unsharded_accumulated_grad", None) is not None:
                return hsdp_param.unsharded_accumulated_grad_data.dtype
            unsharded_param = getattr(hsdp_param, "unsharded_param", None)
            if unsharded_param is not None and getattr(unsharded_param, "grad", None) is not None:
                return hsdp_param.unsharded_grad_data.dtype
        return orig_dtypes[0] if orig_dtypes else None

    def __init__(
        self,
        replicate_group: dist.ProcessGroup,