Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/dtensor/_collective_utils.py 36.0% 29-31,47-49,53-54,56-57,68-73
hyper_parallel/core/dtensor/dtensor.py 80.6% 529,536,544,547,553,559
hyper_parallel/core/tensor_parallel/style.py 100%  
hyper_parallel/platform/mindspore/platform.py 26.3% 1317-1319,1327-1332,1334-1337,1670
hyper_parallel/platform/platform.py 71.4% 177,520
hyper_parallel/platform/torch/platform.py 33.3% 650,986-987,989-990,994-997,999
hyper_parallel/core/dtensor/_collective_utils.py
25
26
27
28
29
30
31
32
33
34
35


def _ensure_mesh_process_groups(mesh: DeviceMesh) -> None:
    """Create the per-axis process groups of ``mesh`` on first use.

    Meshes built with ``init_backend=False`` have no ``_dim_group_names`` yet
    (the attribute is missing or ``None``). In that case the groups are built
    from the mesh shape, dimension names and rank list, and the result is
    stored on the mesh. When the attribute is already populated this is a
    no-op.

    Args:
        mesh: Device mesh whose ``_dim_group_names`` attribute is checked and,
            if necessary, filled in place.
    """
    # pylint: disable=protected-access
    if getattr(mesh, "_dim_group_names", None) is None:
        mesh._dim_group_names = DeviceMesh._init_process_groups(
            mesh._mesh_shape,
            mesh.mesh_dim_names,
            mesh._rank_list,
        )
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
    *,
    group_src: int = 0,
) -> Tensor:
    """Scatter tensor chunks along one mesh dimension (PyTorch ``mesh_scatter`` parity)."""
    _ensure_mesh_process_groups(mesh)
    group = mesh.get_group(mesh_dim)
    contiguous_list = [
        chunk.contiguous() if hasattr(chunk, "is_contiguous") and not chunk.is_contiguous() else chunk
        for chunk in scatter_list
    ]
    if platform.get_group_rank(group) == group_src:
        platform.scatter(output, list(contiguous_list), group=group, group_src=group_src)
    else:
        platform.scatter(output, None, group=group, group_src=group_src)
    return output


def mesh_broadcast(
    tensor: Tensor,
64
65
66
67
68
69
70
71
72
73
    *,
    group_src: int = 0,
) -> Tensor:
    """Broadcast a tensor along one mesh dimension (PyTorch ``mesh_broadcast`` parity)."""
    _ensure_mesh_process_groups(mesh)
    group = mesh.get_group(mesh_dim)
    if hasattr(tensor, "is_contiguous") and not tensor.is_contiguous():
        tensor = tensor.contiguous()
    platform.broadcast(tensor, group=group, group_src=group_src)
    return tensor
hyper_parallel/core/dtensor/dtensor.py
525
526
527
528
529
530
531
532
533
    from hyper_parallel.core.dtensor.placement_types import Partial, StridedShard

    local = tensor
    if len(placements) < device_mesh.ndim:
        raise ValueError(
            f"placements length ({len(placements)}) must be at least device_mesh.ndim "
            f"({device_mesh.ndim}) when src_data_rank is set"
        )
    for mesh_dim in range(device_mesh.ndim):
532
533
534
535
536
537
538
539
540
        )
    for mesh_dim in range(device_mesh.ndim):
        placement = placements[mesh_dim]
        if isinstance(placement, StridedShard):
            raise NotImplementedError(
                "distribute_tensor with src_data_rank does not support StridedShard yet; "
                "pass src_data_rank=None for local-only sharding."
            )
        if placement.is_shard():
540
541
542
543
544
545
546
547
548
549
550
551
        if placement.is_shard():
            shard_dim = _normalize_shard_dim(placement.dim, local.ndim)
            num_chunks = device_mesh.size(mesh_dim)
            if num_chunks <= 0:
                raise ValueError(f"invalid mesh dim size {num_chunks} on mesh_dim={mesh_dim}")
            chunks = tuple(local.chunk(num_chunks, dim=shard_dim))
            if not chunks:
                raise ValueError(f"cannot shard dim {shard_dim} into {num_chunks} chunks")
            output = platform.empty_like(chunks[0])
            local = mesh_scatter(output, chunks, device_mesh, mesh_dim, group_src=src_data_rank)
        elif placement.is_replicate() or placement.is_partial():
            local = mesh_broadcast(local, device_mesh, mesh_dim, group_src=src_data_rank)
549
550
551
552
553
554
555
556
557
            local = mesh_scatter(output, chunks, device_mesh, mesh_dim, group_src=src_data_rank)
        elif placement.is_replicate() or placement.is_partial():
            local = mesh_broadcast(local, device_mesh, mesh_dim, group_src=src_data_rank)
            if isinstance(placement, Partial):
                warnings.warn(
                    f"Partial placement {placement} during distribute_tensor: "
                    "broadcast only; partial partition is not applied yet.",
                    stacklevel=3,
                )
555
556
557
558
559
560
561
562
                    "broadcast only; partial partition is not applied yet.",
                    stacklevel=3,
                )
        else:
            raise RuntimeError(
                f"unsupported placement {placement} on device mesh dimension {mesh_dim}"
            )
    return local
hyper_parallel/platform/mindspore/platform.py
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
        return data, handle

    @staticmethod
    def broadcast(data, src=None, group=None, async_op=False, group_src=None):
        """Broadcast ``data`` from a source rank to the ranks of ``group``.

        Args:
            data: Tensor handed to ``dist.broadcast``; it is also the return
                value.
            src: Source rank forwarded to ``dist.broadcast``. Overridden when
                ``group_src`` is given.
            group: Process group forwarded to ``dist.broadcast`` and used to
                look up the group's rank list.
            async_op: Forwarded to ``dist.broadcast``. When true, the returned
                handle is waited on before this method returns, so the call
                still completes before returning.
            group_src: Optional index into the rank list of ``group``. When
                set, ``src`` is replaced by the rank found at that index.

        Returns:
            The ``data`` object that was passed in.
        """
        if group_src is not None:
            # Translate the position inside the group into the entry of the
            # group's rank list (presumably the global rank - confirm against
            # get_process_group_ranks).
            ranks = MindSporePlatform.get_process_group_ranks(group)
            src = ranks[group_src]
        handle = dist.broadcast(data, src, group, async_op)
        # Same guard as scatter: only wait when a handle was actually returned,
        # so a None handle cannot raise AttributeError here.
        if async_op and handle is not None:
            handle.wait()
        return data
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
        return data

    @staticmethod
    def scatter(output, scatter_list, src=None, group=None, async_op=False, group_src=None):
        """Scatter the chunks of ``scatter_list`` from a source rank into ``output``.

        Args:
            output: Tensor handed to ``dist.scatter`` as the receive buffer; it
                is also the return value.
            scatter_list: Chunks to send, or ``None``. Chunks that report
                themselves as non-contiguous are replaced by contiguous copies.
            src: Source rank forwarded to ``dist.scatter``. Overridden when
                ``group_src`` is given.
            group: Either a group name string or an object; for objects the
                ``group_name`` attribute is used when present, otherwise the
                object itself is forwarded.
            async_op: Forwarded to ``dist.scatter``. When true and a handle is
                returned, it is waited on before this method returns.
            group_src: Optional index into the rank list of ``group``. When
                set, ``src`` is replaced by the rank found at that index.

        Returns:
            The ``output`` object that was passed in.
        """
        # dist.scatter is called with a group name; resolve it from the argument.
        if isinstance(group, str):
            comm_group = group
        else:
            comm_group = getattr(group, "group_name", group)

        if group_src is not None:
            src = MindSporePlatform.get_process_group_ranks(group)[group_src]

        if scatter_list is not None:
            prepared = []
            for chunk in scatter_list:
                if hasattr(chunk, "is_contiguous") and not chunk.is_contiguous():
                    chunk = chunk.contiguous()
                prepared.append(chunk)
            scatter_list = prepared

        work = dist.scatter(output, scatter_list, src, comm_group, async_op=async_op)
        if async_op and work is not None:
            work.wait()
        return output

    @staticmethod
    def reduce_scatter_tensor(data, group_info, async_op=False):
        group_name = group_info if isinstance(group_info, str) else group_info.group_name
1666
1667
1668
1669
1670
1671
1672
1673
1674
        return dist.get_group_rank(group, MindSporePlatform.get_rank())

    @staticmethod
    def get_group_rank(group=None) -> int:
        """Return the value of ``get_group_local_rank`` for *group*.

        Thin alias: the call is delegated unchanged to
        ``MindSporePlatform.get_group_local_rank``.
        """
        group_rank = MindSporePlatform.get_group_local_rank(group)
        return group_rank

    @staticmethod
    def no_grad():
        """Return the object produced by calling ``_no_grad()``.

        NOTE(review): presumably a context manager that disables gradient
        tracking - confirm against the definition of ``_no_grad``.
        """
        grad_guard = _no_grad()
        return grad_guard
hyper_parallel/platform/platform.py
173
174
175
176
177
178
179
180
181

    @staticmethod
    def get_group_rank(group):
        """Report the calling process's rank inside *group*.

        The base class provides no implementation and always raises.

        Raises:
            NotImplementedError: Always.
        """
        message = "Platform subclasses must implement get_group_rank"
        raise NotImplementedError(message)

    @staticmethod
    def get_world_size():
        """Get the total number of processes in the default process group.
516
517
518
519
520
521
522
523
524

    @staticmethod
    def scatter(output, scatter_list, src=None, group=None, async_op=False, group_src=None):
        """Scatter a list of tensors from the source rank to every rank in the group.

        The base class provides no implementation and always raises.

        Raises:
            NotImplementedError: Always.
        """
        message = "Platform subclasses must implement scatter"
        raise NotImplementedError(message)

    @staticmethod
    def isend(tensor, dst=None, group=None, tag=0):
        """Send tensor asynchronously to destination rank.
hyper_parallel/platform/torch/platform.py
646
647
648
649
650
651
652
653
654

    @staticmethod
    def get_group_rank(group):
        """Return the current process's rank relative to *group*.

        The process's rank from ``dist.get_rank()`` is passed to
        ``dist.get_group_rank`` together with *group*.
        """
        global_rank = dist.get_rank()
        return dist.get_group_rank(group, global_rank)

    @staticmethod
    def get_world_size():
        """
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
        return data, handle

    @staticmethod
    def broadcast(data, src=None, group=None, async_op=False, group_src=None):
        """Broadcast ``data`` from a source rank to the ranks of ``group``.

        Args:
            data: Tensor handed to ``dist.broadcast``; it is also the return
                value.
            src: Source rank forwarded to ``dist.broadcast``. Overridden when
                ``group_src`` is given.
            group: Process group forwarded to ``dist.broadcast``.
            async_op: Forwarded to ``dist.broadcast``. When true and a handle
                is returned, it is waited on before this method returns.
            group_src: Optional rank of the source within ``group``. When set,
                it is converted to a global rank with
                ``dist.get_global_rank`` and used as ``src``.

        Returns:
            The ``data`` object that was passed in.
        """
        if group_src is not None:
            src = dist.get_global_rank(group, group_src)
        handle = dist.broadcast(data, src, group, async_op)
        if async_op and handle is not None:
            handle.wait()
        # Return the tensor, matching scatter's return of its output, instead
        # of implicitly returning None.
        return data

    @staticmethod
    def scatter(output, scatter_list, src=None, group=None, async_op=False, group_src=None):
        """Scatter the chunks of ``scatter_list`` from a source rank into ``output``.

        Args:
            output: Tensor handed to ``dist.scatter`` as the receive buffer; it
                is also the return value.
            scatter_list: Chunks to send, forwarded unchanged.
            src: Source rank forwarded to ``dist.scatter``. Overridden when
                ``group_src`` is given.
            group: Process group forwarded to ``dist.scatter``.
            async_op: Forwarded to ``dist.scatter``. When true and a handle is
                returned, it is waited on before this method returns.
            group_src: Optional rank of the source within ``group``. When set,
                it is converted to a global rank with
                ``dist.get_global_rank`` and used as ``src``.

        Returns:
            The ``output`` object that was passed in.
        """
        source_rank = src if group_src is None else dist.get_global_rank(group, group_src)
        work = dist.scatter(output, scatter_list, src=source_rank, group=group, async_op=async_op)
        if async_op and work is not None:
            work.wait()
        return output

    @staticmethod
    def isend(tensor, dst=None, group=None, tag=0):
        """Start a send of ``tensor`` via ``dist.isend`` and return its result.

        All arguments are forwarded positionally, in order, to ``dist.isend``.
        """
        send_request = dist.isend(tensor, dst, group, tag)
        return send_request