Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/platform/mindspore/fully_shard/param.py 51.2% 201,205,573,800-802,804-806,809-811,813-815,820-821,824-825,828
hyper_parallel/platform/mindspore/fully_shard/param_group.py 69.1% 683,686,690,710-713,715,750,761-763,765-768,770-773,775-781,783,788
hyper_parallel/platform/mindspore/fully_shard/scheduler.py 57.1% 177-179
hyper_parallel/platform/mindspore/fully_shard/state.py 78.2% 120,326,388-398,430,444,453,456-457,462,468,493-495,502,505,550,614
hyper_parallel/platform/torch/fully_shard/param_group.py 41.7% 970-976
hyper_parallel/platform/mindspore/fully_shard/param.py
197
198
199
200
201
202
203
204
205
206
207
208
209
        self.gradient_scaling_factor = None

    @property
    def accumulated_allreduced_grad(self) -> bool:
        """Flag stored on ``self._accumulated_allreduced_grad``.

        Returns:
            The value most recently assigned through the setter (or however the
            backing attribute was initialized elsewhere).
        """
        flag = self._accumulated_allreduced_grad
        return flag

    @accumulated_allreduced_grad.setter
    def accumulated_allreduced_grad(self, value: bool) -> None:
        """Store ``value`` verbatim in ``self._accumulated_allreduced_grad`` (no coercion)."""
        self._accumulated_allreduced_grad = value

    @property
    def uses_param_shard(self) -> bool:
        """Whether FSDP sharding is enabled for this parameter."""
569
570
571
572
573
574
575
576
577
    def _sharded_param_storage_dtype(self) -> Optional[ms.Type]:
        """Return the dtype of the sharded parameter's on-device storage."""
        sharded_param = self.sharded_param
        if isinstance(sharded_param, DTensor):
            return sharded_param._local_tensor.dtype
        if hasattr(sharded_param, "dtype"):
            dtype = sharded_param.dtype
            if isinstance(dtype, ms.Type):
                return dtype
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
        # apply gradient_scaling_factor (reduce-scatter leg)
        apply_gradient_scaling_factor(grad_flat, self.gradient_scaling_factor)
        # If parameter is not sharded (below threshold), no reduce-scatter needed
        if not self.is_sharded:
            if output_buffer is not None:
                copy_without_bumping_version(output_buffer, grad_flat)
                self._reduce_scatter_output = output_buffer
            else:
                self._reduce_scatter_output = grad_flat
            self.reduce_scatter_handle = None
            return self._reduce_scatter_output, None

        if shard_group is None or shard_group_size <= 1:
            if output_buffer is not None:
                copy_without_bumping_version(output_buffer, grad_flat)
                self._reduce_scatter_output = output_buffer
            else:
                self._reduce_scatter_output = grad_flat
            self.reduce_scatter_handle = None
            return self._reduce_scatter_output, None

        # Calculate output size
        output_numel = grad_flat.numel() // shard_group_size
        if output_buffer is not None:
            if output_buffer.numel() != output_numel:
                raise ValueError(
                    f"output_buffer size mismatch: expected {output_numel}, got {output_buffer.numel()}"
                )
            if output_buffer.dtype != reduce_dtype:
                raise ValueError(
                    f"output_buffer dtype mismatch: expected {reduce_dtype}, got {output_buffer.dtype}"
                )
            self._reduce_scatter_output = output_buffer
        else:
            self._reduce_scatter_output = ms.mint.empty(
                output_numel, dtype=reduce_dtype, device=grad.device.split(":")[0]
            )
hyper_parallel/platform/mindspore/fully_shard/param_group.py
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
        orig_dtypes: List[Any],
    ) -> Any:
        """Resolve None reduce_dtype to match ``reduce_scatter_grad``'s ``dtype or grad.dtype``."""
        if reduce_dtype is not None:
            return reduce_dtype
        for hsdp_param in hsdp_params:
            if getattr(hsdp_param, "unsharded_accumulated_grad", None) is not None:
                return hsdp_param.unsharded_accumulated_grad_data.dtype
            unsharded_param = getattr(hsdp_param, "unsharded_param", None)
            if unsharded_param is not None and getattr(unsharded_param, "grad", None) is not None:
                return hsdp_param.unsharded_grad_data.dtype
        return orig_dtypes[0] if orig_dtypes else None

    def __init__(
        self,
        replicate_group,
706
707
708
709
710
711
712
713
714
715
716
717
718
719
        self.reduce_op = reduce_op
        self.mp_policy = mp_policy
        if replicate_world_size is not None:
            self.replicate_world_size = replicate_world_size
        elif replicate_group is not None and hasattr(replicate_group, "rank_size"):
            self.replicate_world_size = replicate_group.rank_size
        elif hsdp_params:
            self.replicate_world_size = hsdp_params[0].unsharded_group_info.rank_size
        else:
            self.replicate_world_size = 1
        self.fused_buffer: Optional[ms.Tensor] = None
        self.param_offsets: List[int] = []
        self.param_numels: List[int] = []
        self.all_reduce_handle: Optional[CommHandle] = None
746
747
748
749
750
751
752
753

    def get_param_buffer_view(self, idx: int) -> ms.Tensor:
        """Return the flat slice of ``fused_buffer`` reserved for parameter ``idx``.

        The slice is taken with ``narrow`` along dim 0, starting at
        ``param_offsets[idx]`` and spanning ``param_numels[idx]`` elements. It
        serves as the reduce_scatter output for that parameter.

        Raises:
            RuntimeError: if the fused buffer has not been allocated yet.
        """
        buffer = self.fused_buffer
        if buffer is None:
            raise RuntimeError("Fused buffer not allocated. Call allocate_fused_buffer first.")
        start, length = self.param_offsets[idx], self.param_numels[idx]
        return buffer.narrow(0, start, length)
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
        return self.get_param_buffer_view(idx).view(target_shape)

    def accumulate_existing_grads_to_buffer(self) -> None:
        """Accumulate existing sharded grads into ``fused_buffer`` before all-reduce.

        The grad attribute used on each ``sharded_param`` is ``main_grad`` when the
        mixed-precision policy sets ``apply_grad_on_fp32_main_grad``, and ``grad``
        otherwise.

        A parameter's grad is folded in only when it is set and the parameter's
        ``accumulated_allreduced_grad`` flag is false. Folding in means:

        - cast the grad to ``self.reduce_dtype`` if needed, using the local tensor
          of a DTensor;
        - add it into that parameter's slice of ``fused_buffer``;
        - reset the grad attribute to ``None``.

        Does nothing when the fused buffer has not been allocated.
        """
        if self.fused_buffer is None:
            return
        from hyper_parallel.core.dtensor.dtensor import DTensor

        # The policy is the same for every param in the group, so decide the
        # attribute to read/clear once instead of re-checking it per param.
        use_main_grad = (
            self.mp_policy is not None and self.mp_policy.apply_grad_on_fp32_main_grad
        )
        grad_attr = "main_grad" if use_main_grad else "grad"

        for idx, hsdp_param in enumerate(self.hsdp_params):
            sharded_param = hsdp_param.sharded_param
            existing_grad = getattr(sharded_param, grad_attr)
            if existing_grad is None or hsdp_param.accumulated_allreduced_grad:
                continue
            if isinstance(existing_grad, DTensor):
                existing_grad_local = existing_grad._local_tensor
            else:
                existing_grad_local = existing_grad
            if existing_grad_local.dtype != self.reduce_dtype:
                existing_grad_local = existing_grad_local.to(self.reduce_dtype)
            buffer_view = self.get_param_buffer_view(idx)
            buffer_view.add_(existing_grad_local.view_as(buffer_view))
            # The grad's contribution now lives in the fused buffer slice.
            setattr(sharded_param, grad_attr, None)

    def issue_async_allreduce(self) -> None:
        """Issue async all_reduce on the fused buffer (SUM for padding correctness)."""
        if self.fused_buffer is None:
            raise RuntimeError("Fused buffer not allocated.")
        self.all_reduce_handle = dist.all_reduce(
            self.fused_buffer,
            op=ops.ReduceOp.SUM,
            group=self.replicate_group,
hyper_parallel/platform/mindspore/fully_shard/scheduler.py
173
174
175
176
177
178
179
180
181
182
183
            # Step 1: Wait for previous reduce-scatter groups and get them for all-reduce
            prev_groups = self.hsdp_state._wait_prev_reduce_scatter()
            # Step 2: Accumulate and issue async all-reduce for previous groups
            for group in prev_groups:
                group.accumulate_existing_grads_to_buffer()
                group.issue_async_allreduce()
                MindSporeHSDPStateV2.pending_all_reduce_groups.append(group)
            # Step 3: Wait/apply any remaining reduce-scatter for pure FSDP params
            self.hsdp_state.reduce_scattered_params()
            # Step 4: Wait for pending all-reduce groups and apply grads
            MindSporeHSDPStateV2.delay_apply_reduce_grads()
hyper_parallel/platform/mindspore/fully_shard/state.py
116
117
118
119
120
121
122
123
124
        # Requires AllReduce for grad When HSDP
        self.requires_all_reduce = True
        # Default reduce op is decided at the fully_shard-state level:
        # if any managed parameter is DTensor-backed, use SUM; otherwise AVG.
        self.reduce_op_type = self._resolve_default_reduce_op()
        self._reset_sharded_params = False
        self._init_param_group()

    def _iter_managed_params(self):
322
323
324
325
326
327
328
329
330
                continue
            if not hsdp_param.sharded_param.requires_grad:
                continue
            if not self._has_pending_unsharded_grad(hsdp_param):
                continue
            if self._should_run_all_reduce(hsdp_param):
                self._queue_compat_all_reduce(hsdp_param)
            else:
                need_synchronize = self._apply_pending_unsharded_grad_locally(hsdp_param)
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402

    def _wait_prev_reduce_scatter(self) -> List:
        """Step 1: wait previous module RS for HSDP fused all-reduce groups.

        Drains ``MindSporeHSDPStateV2.pre_all_reduce_groups`` and processes every
        parameter of each drained group:

        - call ``reduce_scatter_output()``;
        - call ``clear_reduce_scatter_output()``;
        - drop the unsharded grad: the accumulated grad if present, otherwise
          ``unsharded_param.grad``.

        Returns:
            The drained groups, in their original order, or ``[]`` if none were
            pending.
        """
        queued = MindSporeHSDPStateV2.pre_all_reduce_groups
        if not queued:
            return []
        drained = list(queued)
        queued.clear()
        for group in drained:
            for param in group.hsdp_params:
                # Result is unused here; the call is made for its effect of
                # finishing the previous reduce-scatter.
                param.reduce_scatter_output()
                param.clear_reduce_scatter_output()
                if param.unsharded_accumulated_grad_data is not None:
                    param.unsharded_accumulated_grad = None
                elif param.unsharded_param.grad is not None:
                    param.unsharded_param.grad = None
        return drained

    def _wait_and_apply_prev_no_allreduce_params(self):
        """Step 2: wait/apply previous reduce-scatter for pure FSDP params."""
426
427
428
429
430
431
432
433
434
        """Whether the 4-step RS/AR overlap pipeline has pending work this hook."""
        if MindSporeHSDPStateV2.pre_all_reduce_groups:
            return True
        if HSDPState.pre_reduce_scatter_params:
            return True
        return bool(self._collect_params_for_reduce_scatter())

    def _run_overlap_post_backward_steps(self) -> None:
        """Run the 4-step HSDP RS/AR overlap pipeline for the current module."""
440
441
442
443
444
445
446
447
448
    def _issue_reduce_scatter_for_current_module(self):
        """Issue reduce_scatter for current module with fused all-reduce when needed."""
        params_to_reduce = self._collect_params_for_reduce_scatter()
        if not params_to_reduce:
            return

        groups_by_comm = defaultdict(list)
        for hsdp_param in params_to_reduce:
            if self._should_run_all_reduce(hsdp_param):
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
                replicate_group = hsdp_param.unsharded_group_info.group
                key = id(replicate_group) if replicate_group is not None else None
                groups_by_comm[key].append(hsdp_param)
            else:
                groups_by_comm[None].append(hsdp_param)

        if None in groups_by_comm:
            for hsdp_param in groups_by_comm[None]:
                hsdp_param.reduce_scatter_grad(
                    async_op=True,
                    dtype=self._reduce_dtype,
                    reduce_op=self._resolve_reduce_op(),
                )
                HSDPState.pre_reduce_scatter_params.append(
                    (hsdp_param, self._orig_dtype)
                )

        for key, hsdp_params in groups_by_comm.items():
464
465
466
467
468
469
470
471
472
                )

        for key, hsdp_params in groups_by_comm.items():
            if key is None:
                continue
            group_info = hsdp_params[0].unsharded_group_info
            group = AllReduceParamGroup(
                replicate_group=group_info.group,
                hsdp_params=hsdp_params,
489
490
491
492
493
494
495
496
497
498
499

    def _issue_prev_fused_allreduce(self, prev_groups: List) -> None:
        """Step 4: issue async all-reduce for previous HSDP groups (no-op without fusion groups).

        Each group in ``prev_groups`` first folds any existing sharded grads into
        its fused buffer. It then launches its async all-reduce and is queued on
        ``MindSporeHSDPStateV2.pending_all_reduce_groups`` for a later wait.
        """
        for group in prev_groups:
            group.accumulate_existing_grads_to_buffer()
            group.issue_async_allreduce()
            # Only groups whose all-reduce was issued are queued for the wait.
            MindSporeHSDPStateV2.pending_all_reduce_groups.append(group)

    @classmethod
    def delay_apply_reduce_grads(cls) -> None:
        """Wait pending fused all-reduce groups at root backward."""
498
499
500
501
502
503
504
505
506
507
508
509
    def delay_apply_reduce_grads(cls) -> None:
        """Wait pending fused all-reduce groups at root backward.

        Calls ``wait_and_apply_grads()`` on every queued group, in queue order and
        without short-circuiting, then empties the queue. If any group reported
        a truthy result, the current MindSpore stream is synchronized once.
        """
        results = [group.wait_and_apply_grads() for group in cls.pending_all_reduce_groups]
        cls.pending_all_reduce_groups.clear()
        if any(results):
            ms.runtime.current_stream().synchronize()

    def post_backward_for_comm_fusion(self):
        """Drive the fused gradient-reduction pipeline for sharded params."""
        self.reduce_params()
546
547
548
549
550
551
552
553
554

    def _can_direct_all_reduce_compat_grad(self, hsdp_param) -> bool:
        """Whether ``hsdp_param`` should reduce its existing ``sharded_param.grad`` directly."""
        if not hasattr(hsdp_param, "param_mode"):
            return False
        return (
            hsdp_param.param_mode == FullyShardParamMode.DTENSOR_COMPAT
            and hsdp_param.enable_fsdp_shard
            and not hsdp_param.is_sharded
610
611
612
613
614
615
616
617
618
                if not self._has_pending_unsharded_grad(hsdp_param):
                    continue
                if hsdp_param.shard_size <= 1:
                    if self._should_run_all_reduce(hsdp_param):
                        self._queue_compat_all_reduce(hsdp_param)
                    else:
                        # No-communication path (shard_size == 1, no all-reduce):
                        # this leg owns the scaling since the grad never goes through
                        # reduce_scatter_grad / all_reduce_grad.
hyper_parallel/platform/torch/fully_shard/param_group.py
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
    ) -> Optional[torch.dtype]:
        """Resolve None reduce_dtype to match ``reduce_scatter_grad``'s ``dtype or grad.dtype``."""
        if reduce_dtype is not None:
            return reduce_dtype
        for hsdp_param in hsdp_params:
            if getattr(hsdp_param, "unsharded_accumulated_grad", None) is not None:
                return hsdp_param.unsharded_accumulated_grad_data.dtype
            unsharded_param = getattr(hsdp_param, "unsharded_param", None)
            if unsharded_param is not None and getattr(unsharded_param, "grad", None) is not None:
                return hsdp_param.unsharded_grad_data.dtype
        return orig_dtypes[0] if orig_dtypes else None

    def __init__(
        self,
        replicate_group: dist.ProcessGroup,