Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

| Source File | Diff Coverage (%) | Missing Lines |
| --- | --- | --- |
| hyper_parallel/platform/mindspore/fully_shard/param.py | 72.7% | 641-642,671-672,680-681 |
| hyper_parallel/platform/mindspore/fully_shard/state.py | 100% | |
hyper_parallel/platform/mindspore/fully_shard/param.py
637
638
639
640
641
642
643
644
645
646
            # NPU, so reuse it directly; the device path keeps it for Torch parity.
            if self.offload_to_cpu:
                local_view = local_tensor
            else:
                with _no_grad():
                    local_view = local_tensor.narrow(dim=shard_dim, start=0, length=length)
            set_requires_grad_if_needed(self.sharded_param, local_view)
            self.sharded_param._local_tensor = local_view
            if not self.sharded_param._local_tensor.is_contiguous():
                raise AssertionError(
667
668
669
670
671
672
673
674
675
676
            # wait here or the collective reads a half-written shard (verified racy).
            ms.runtime.synchronize()
            updated_local_tensor = True
        if self.pin_memory and not local_tensor.is_pinned():
            local_tensor = local_tensor.pin_memory()
            updated_local_tensor = True
        if self.offload_to_cpu:
            # Offloaded shards stay on host (``view(-1)`` would relocate them to
            # NPU); all_gather_inputs stages the device copy per unshard. Rebind
            # unconditionally: the to("cpu")/pin above can produce a new host
676
677
678
679
680
681
682
683
684
685
            # unconditionally: the to("cpu")/pin above can produce a new host
            # tensor even when same_local_tensor was True, leaving the old NPU
            # storage referenced here.
            self._sharded_param_data = local_tensor
        elif not same_local_tensor:
            self._sharded_param_data = local_tensor.view(-1)
        return local_tensor, updated_local_tensor

    def _get_unsharded_param_data(self, async_op: bool = False) -> Tuple[ms.Tensor, Optional[CommHandle]]:
        """