Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/optimizer/__init__.py 0.0% 86,97,104,108
hyper_parallel/core/optimizer/adamw.py 0.0% 41-42
hyper_parallel/core/optimizer/dtensor_compat.py 0.0% 22,24-25,27,29-30,33,36,38,42,45,47,50,57,59-60,62-66,68-70,72,74-76,79,81-84,86-87,89,93,95-98,101,103-107,110,112-116,119,121-125,128,130-131,133-134,138,140-144,147,149,152,155-156,160,169,171-178,181,183
hyper_parallel/core/optimizer/muon.py 0.0% 20,25-26,125,128-133,341-342,344,346,348,429,549
hyper_parallel/core/optimizer/optimizer.py 0.0% 20,29,181,190,196,205,212-216,218-219,222-229,231,233-236,239,242-243,245,247-248,250-252,254,256,258-259,262-270,272,274-280,282-283,285-287,289-290,294-295,297,303,305,307-308,310-311,313,315,321-322,325-334,336-337,341-345,349-350,354-355,358-359,361-363,366-369,371,373,375-376,381,388-389,392-393,429-432,447,481-483,485-486,489,560,745,747,904
hyper_parallel/core/optimizer/sharding_category.py 0.0% 25,276-278
hyper_parallel/core/optimizer/__init__.py
82
83
84
85
86
87
88
89
90
    }
    allowed_keys_adamw = inspect.signature(AdamW.__init__).parameters.keys() - {'self', 'params'}
    filtered_adamw_config = {k: v for k, v in adamw_config.items() if k in allowed_keys_adamw}
    if excluded_adamw_keys := adamw_config.keys() - allowed_keys_adamw:
        logger.info_rank0("Excluded adamw config: %s", list(excluded_adamw_keys))

    # 1.2 muon
    muon_raw = muon_kwargs or {}
    muon_config = {
 93
 94
 95
 96
 97
 98
 99
100
    }
    allowed_keys_muon = inspect.signature(Muon.__init__).parameters.keys() - {'self', 'params'}
    filtered_muon_config = {k: v for k, v in muon_config.items() if k in allowed_keys_muon}
    if excluded_muon_keys := muon_config.keys() - allowed_keys_muon:
        logger.info_rank0("Excluded muon config: %s", list(excluded_muon_keys))

    # 2. Optimizer Creation
    optimizers = {}
100
101
102
103
104
105
106
107
108
109
110
111
112
    optimizers = {}

    if adamw_params:
        optimizers["adamw"] = AdamW(adamw_params, **filtered_adamw_config)
        logger.info_rank0("Using adamw config: %s", filtered_adamw_config)

    if muon_params:
        optimizers["muon"] = Muon(muon_params, **filtered_muon_config)
        logger.info_rank0("Using muon config: %s", filtered_muon_config)

    flatten = bool(adamw_params and muon_params)

    return ChainedOptimizer(model, optimizers=optimizers, flatten=flatten)
hyper_parallel/core/optimizer/adamw.py
37
38
39
40
41
42
43
44
45
46
) -> None:
    r"""Functional API that performs AdamW algorithm computation.
    See :class:`~torch.optim.AdamW` for details.
    """
    device = torch.npu.current_device() if torch.npu.is_available() else torch.cuda.current_device()
    step_tensor = torch.tensor(step, dtype=torch.int64, device=device)
    state_steps = [step_tensor] * len(params)

    torch._fused_adamw_(  # pylint: disable=protected-access
        params,
hyper_parallel/core/optimizer/dtensor_compat.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
Provides lazy exports (PEP 562) for DTensor, DeviceMesh, Shard, Replicate, 
and StridedShard based on the detected backend ('torch' or 'hyper').
"""

from __future__ import annotations

import logging
from typing import Any, List

import torch.distributed._tensor as torch_dt

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global backend flag
_DTENSOR_BACKEND: str = "hyper"  # "hyper" or "torch"


class _NeverMatch:
    """Placeholder class handed out where a backend has no real class to offer.

    The class is not meant to be instantiated, so ``isinstance(obj, _NeverMatch)``
    evaluates to False for ordinary objects.
    """

    # Empty slots: instances carry no per-instance ``__dict__``.
    __slots__ = ()


# Lazy-export cache
_LAZY_CACHE: dict = {}


def _invalidate_lazy_cache() -> None:
    """Drop every cached lazy export so the next lookup resolves it again."""
    # Cleared in place: the module-level dict object itself is kept.
    _LAZY_CACHE.clear()


def detect_dtensor_backend(
        adamw_params: List[Any],
        muon_params: List[Any],
) -> str:
    """Detect and set the DTensor backend ('torch' or 'hyper') from parameter lists."""
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
) -> str:
    """Detect and set the DTensor backend ('torch' or 'hyper') from parameter lists."""
    global _DTENSOR_BACKEND  # pylint: disable=global-statement

    sample_param = _extract_first_param(muon_params)

    if sample_param is None:
        sample_param = _extract_first_param(adamw_params)

    if sample_param is None:
        logger.info("No parameters found for backend detection; defaulting to 'hyper'.")
        _DTENSOR_BACKEND = "hyper"
        _invalidate_lazy_cache()
        return _DTENSOR_BACKEND

    param_cls_module = type(sample_param).__module__
    if param_cls_module.startswith("torch.distributed"):
        _DTENSOR_BACKEND = "torch"
    else:
        _DTENSOR_BACKEND = "hyper"

    logger.info("Detected DTensor backend: '%s'.", _DTENSOR_BACKEND)
    _invalidate_lazy_cache()
    return _DTENSOR_BACKEND


def _extract_first_param(param_groups: List[Any]) -> Any:
    """Return the first parameter found in ``param_groups``, or None.

    Args:
        param_groups: Either a list of param-group dicts (each holding its
            parameters under the ``"params"`` key) or a flat list of
            parameters. ``None`` and empty lists are accepted.

    Returns:
        The first entry of the first non-empty ``"params"`` list; failing
        that, the first element that is not a param-group dict; otherwise
        ``None``.
    """
    if param_groups is None:
        return None

    # Pass 1: prefer parameters stored inside param-group dicts.
    for group in param_groups:
        if isinstance(group, dict):
            for p in group.get("params", []):
                return p

    # Pass 2: flat list of parameters. Dicts are skipped so that an empty
    # param group is never handed back as if it were a parameter.
    for p in param_groups:
        if not isinstance(p, dict):
            return p

    return None


# Accessor functions
def get_dtensor_cls():
    """Look up the DTensor class that belongs to the currently selected backend."""
    if _DTENSOR_BACKEND != "torch":
        # Function-scope import: only executed on the non-torch path.
        from hyper_parallel.core.dtensor.dtensor import DTensor  # pylint: disable=import-outside-toplevel
        return DTensor
    return torch_dt.DTensor


def get_device_mesh_cls():
    """Look up the DeviceMesh class that belongs to the currently selected backend."""
    # Exactly one of the two imports runs, chosen by the backend flag.
    if _DTENSOR_BACKEND != "torch":
        from hyper_parallel.core.dtensor.device_mesh import DeviceMesh  # pylint: disable=import-outside-toplevel
    else:
        from torch.distributed.device_mesh import DeviceMesh  # pylint: disable=import-outside-toplevel
    return DeviceMesh


def get_shard_cls():
    """Look up the Shard placement class that belongs to the currently selected backend."""
    # Exactly one of the two imports runs, chosen by the backend flag.
    if _DTENSOR_BACKEND != "torch":
        from hyper_parallel.core.dtensor.placement_types import Shard  # pylint: disable=import-outside-toplevel
    else:
        from torch.distributed._tensor.placement_types import Shard  # pylint: disable=import-outside-toplevel
    return Shard


def get_replicate_cls():
    """Look up the Replicate placement class that belongs to the currently selected backend."""
    # Exactly one of the two imports runs, chosen by the backend flag.
    if _DTENSOR_BACKEND != "torch":
        from hyper_parallel.core.dtensor.placement_types import Replicate  # pylint: disable=import-outside-toplevel
    else:
        from torch.distributed._tensor.placement_types import Replicate  # pylint: disable=import-outside-toplevel
    return Replicate


def get_strided_shard_cls():
    """Look up the StridedShard placement class for the currently selected backend.

    For the 'torch' backend the ``_NeverMatch`` placeholder class is returned
    instead of a real placement class.
    """
    if _DTENSOR_BACKEND != "torch":
        # Function-scope import: only executed on the non-torch path.
        from hyper_parallel.core.dtensor.placement_types import StridedShard  # pylint: disable=import-outside-toplevel
        return StridedShard

    return _NeverMatch


# DTensor union type resolver
def _import_hyper_dtensor():
    """Return the hyper DTensor class, or the torch DTensor class if the import fails."""
    try:
        from hyper_parallel.core.dtensor.dtensor import DTensor  # pylint: disable=import-outside-toplevel
    except ImportError:
        # hyper DTensor could not be imported: substitute the torch class.
        return torch_dt.DTensor
    else:
        return DTensor


def _resolve_dtensor_union():
    """Construct the ``torch DTensor | hyper DTensor`` union each time it is called."""
    torch_cls = torch_dt.DTensor
    hyper_cls = _import_hyper_dtensor()
    return torch_cls | hyper_cls


def to_local_if_dtensor(tensor: Any) -> Any:
    """Unwrap a DTensor via ``to_local()``; hand any other object back untouched."""
    # Reuse the cached "DTensor" export when one exists; otherwise call the
    # resolver directly rather than going through the lazy module attribute.
    dtensor_type = _LAZY_CACHE.get("DTensor")
    if not dtensor_type:
        dtensor_type = _resolve_dtensor_union()

    if isinstance(tensor, dtensor_type):
        return tensor.to_local()
    return tensor


# lazy exports
_LAZY_RESOLVERS = {
    "DTensor": _resolve_dtensor_union,
    "DeviceMesh": get_device_mesh_cls,
    "Shard": get_shard_cls,
    "Replicate": get_replicate_cls,
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
    "StridedShard": get_strided_shard_cls,
}


def __getattr__(name):  # type: ignore[no-untyped-def]  # pylint: disable=invalid-name
    """Module-level attribute hook that serves the lazily exported names.

    Raises:
        AttributeError: If ``name`` has no entry in ``_LAZY_RESOLVERS``.
    """
    resolver = _LAZY_RESOLVERS.get(name)
    if resolver is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    cached = _LAZY_CACHE.get(name)
    if cached is not None:
        return cached

    # Not cached yet (or cached as None): resolve now and remember the result.
    resolved = resolver()
    _LAZY_CACHE[name] = resolved
    return resolved


def __dir__():  # type: ignore[no-untyped-def]  # pylint: disable=invalid-name
    """List the module globals followed by the lazily exported names."""
    return [*globals(), *_LAZY_RESOLVERS]
hyper_parallel/core/optimizer/muon.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""Muon optimizer with HSDP shard-group-aware communication."""

import math
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from hyper_parallel.core.optimizer.optimizer import AsyncReplicateBroadcaster, BaseDistributedOptimizer
from hyper_parallel.core.optimizer.dtensor_compat import to_local_if_dtensor
from hyper_parallel.core.optimizer.sharding_category import (
    HSDPGroupAssignment,
    fused_allgather_dtensor_params,
    build_owner_by_size,
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
            "momentum": momentum,
            "nesterov": nesterov,
            "ns_steps": ns_steps,
        }
        super().__init__(params, defaults, is_muon=True, hsdp_replica_count=hsdp_replica_count)

        self._group_dtensor_by_mesh()
        deduced_count = self._auto_deduce_replica_count()
        if deduced_count is None:
            self.hsdp_replica_count = None
        elif self.hsdp_replica_count is None:
            self.hsdp_replica_count = deduced_count
        self._split_replicate_groups()
        self._build_hsdp_batch()
        self._build_param_broadcast_info()
        self._classify_parameters_for_step()
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
                )

                # Fused batched apply — all params in the same sub_batch share
                # the same adjusted_lr, so we can use foreach ops.
                local_params = [to_local_if_dtensor(p.data) for p in sub_batch]
                local_updates = [updates_dict[p].view(lp.shape) for p, lp in zip(sub_batch, local_params)]

                if weight_decay != 0.0:
                    # pylint: disable=protected-access
                    torch._foreach_mul_(local_params, 1 - lr * weight_decay)
                # pylint: disable=protected-access
                torch._foreach_add_(local_params, local_updates, alpha=-adjusted_lr)

    def _gather_and_compute_shard_updates(
            self,
            valid_params: List[torch.nn.Parameter],
425
426
427
428
429
430
431
432
433
            buffer_cache: Optional[Dict] = None,
    ) -> None:
        """Process sharded params with greedy shard-group compute assignment."""
        platform = get_platform()
        device = torch.npu.current_device() if torch.npu.is_available() else torch.cuda.current_device()

        lr = group["lr"]
        weight_decay = group["weight_decay"]
        rms = group["matched_adamw_rms"]
545
546
547
548
549
550
551
552
553
        slice_sizes = []
        shapes_info = []

        for p in p_list:
            origin_shape = tuple(getattr(p, 'local_shape', None) or p.to_local().shape) if no_shard else tuple(p.shape)
            ns_input = ns_inputs[p].view(origin_shape)

            is_conv = False
            if len(origin_shape) == 2:
hyper_parallel/core/optimizer/optimizer.py
16
17
18
19
20
21
22
23
24
"""Base distributed optimizer and chain optimizer composition."""

from collections import defaultdict
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
25
26
27
28
29
30
31
32
33
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)
from hyper_parallel.core.optimizer.dtensor_compat import to_local_if_dtensor
from hyper_parallel.core.optimizer.sharding_category import (
    HSDPGroupAssignment,
    build_owner_by_size,
    get_multi_dim_logical_info,
177
178
179
180
181
182
183
184
185

    Provides fused hierarchical broadcast for parameters and optimizer states.
    """

    def __init__(
            self,
            params: Any,
            defaults: Dict[str, Any],
            is_muon: bool,
186
187
188
189
190
191
192
193
            hsdp_replica_count: Optional[Union[int, Tuple[int, ...]]] = None,
    ) -> None:
        super().__init__(params, defaults)
        self.is_muon = is_muon
        self.hsdp_replica_count = hsdp_replica_count
        self._param_to_broadcast_info: Dict[
            torch.nn.Parameter, Tuple[Tuple[int, ...], Tuple[dist.ProcessGroup, ...]]
        ] = {}
192
193
194
195
196
197
198
199
200
            torch.nn.Parameter, Tuple[Tuple[int, ...], Tuple[dist.ProcessGroup, ...]]
        ] = {}

        # Cache: (parent_ranks_tuple, sub_size) -> {sub_idx: sub_pg}
        self._split_sub_pg_cache: Dict[Tuple[Tuple[int, ...], int], Dict[int, dist.ProcessGroup]] = {}

    def _group_dtensor_by_mesh(self):
        """Group DTensor parameters by mesh topology and shard layout."""
        self._hsdp_grouping: Dict[int, Tuple[List, List]] = {}
201
202
203
204
205
206
207
208
209
        for group_key, group in enumerate(self.param_groups):
            no_comm_params, hsdp_groups = group_parameters_for_hsdp(group["params"])
            self._hsdp_grouping[group_key] = (no_comm_params, hsdp_groups)

    def _auto_deduce_replica_count(self) -> Optional[Union[int, Tuple[int, ...]]]:
        """Deduce hsdp_replica_count based on cluster topology.

        - Intra-node PGs: Full dedup (no split), high bandwidth makes broadcast cheap.
        - Inter-node PGs: Split at node boundaries to restrict communication domains 
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
        - Intra-node PGs: Full dedup (no split), high bandwidth makes broadcast cheap.
        - Inter-node PGs: Split at node boundaries to restrict communication domains 
          within a single node, bypassing cross-node bottlenecks.
        """
        devices_per_node = 1
        if torch.npu.is_available():
            devices_per_node = torch.npu.device_count()
        elif torch.cuda.is_available():
            devices_per_node = torch.cuda.device_count()

        dedup_per_dim: Dict[int, int] = {}
        needs_split = False

        # group_key, (no_comm_params, hsdp_groups)
        for _, (_, hsdp_groups) in self._hsdp_grouping.items():
            for hsdp_group in hsdp_groups:
                for dim_idx, pg in enumerate(hsdp_group.replicate_pgs):
                    if pg is None:
                        continue
                    pg_size = dist.get_world_size(pg)
                    if pg_size <= 1:
                        continue

                    if pg_size > devices_per_node:
                        # Inter-node: Find largest divisor safe for node boundary
                        dedup = min(pg_size, devices_per_node)
                        while pg_size % dedup != 0:
                            dedup -= 1
                        needs_split = True
                    else:
                        # Intra-node (pg_size <= devices_per_node): full dedup, no split
                        dedup = pg_size

                    # Enforce conservative (smallest) dedup across shared mesh axes
                    if dim_idx not in dedup_per_dim:
                        dedup_per_dim[dim_idx] = dedup
                    else:
                        dedup_per_dim[dim_idx] = min(dedup_per_dim[dim_idx], dedup)

        if not needs_split:
            return None

        sorted_dedups = [dedup_per_dim[k] for k in sorted(dedup_per_dim.keys())]
        if len(sorted_dedups) == 1:
            return sorted_dedups[0]

        return tuple(sorted_dedups)

    def _split_replicate_groups(self) -> None:
        """Split replicate ProcessGroups into smaller sub-groups based on hsdp_replica_count."""
        if self.hsdp_replica_count is None:
            return

        # Argument validation
        if isinstance(self.hsdp_replica_count, int):
            if self.hsdp_replica_count <= 0:
                raise ValueError(f"hsdp_replica_count must be positive, got {self.hsdp_replica_count}")
        elif isinstance(self.hsdp_replica_count, tuple):
            if not self.hsdp_replica_count:
                raise ValueError("hsdp_replica_count tuple must not be empty")
            for i, v in enumerate(self.hsdp_replica_count):
                if not isinstance(v, int) or v <= 0:
                    raise ValueError(f"hsdp_replica_count[{i}] must be positive, got {v}")
        else:
            raise TypeError(f"Unsupported hsdp_replica_count type: {type(self.hsdp_replica_count).__name__}")

        for group_key, (_, hsdp_groups) in self._hsdp_grouping.items():
            for hsdp_group in hsdp_groups:
                new_replicate_pgs: List[dist.ProcessGroup] = []
                for dim_idx, pg in enumerate(hsdp_group.replicate_pgs):
                    if pg is None:
                        new_replicate_pgs.append(pg)
                        continue

                    pg_size = dist.get_world_size(pg)
                    dedup_size = self._get_dedup_size_for_dim(dim_idx, pg_size)

                    if pg_size <= dedup_size:
                        new_replicate_pgs.append(pg)
                        continue

                    if pg_size % dedup_size != 0:
                        raise ValueError(
                            f"hsdp_replica_count {dedup_size} must evenly divide replicate group size {pg_size}"
                        )

                    sub_pg = self._get_or_create_sub_pg(pg, dedup_size)
                    new_replicate_pgs.append(sub_pg)

                    logger.info_rank0(
                        "[HSDP Split] group_key=%s, dim_idx=%s, original_size=%s, sub_size=%s",
                        group_key, dim_idx, pg_size, dedup_size
                    )
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
                        group_key, dim_idx, pg_size, dedup_size
                    )

                # replace new sub groups of hsdp groups
                hsdp_group.replicate_pgs = tuple(new_replicate_pgs)

    def _get_dedup_size_for_dim(self, dim_idx: int, pg_size: int) -> int:
        """Resolve the dedup size that applies to one replicate dimension.

        Args:
            dim_idx: Index of the replicate dimension being queried.
            pg_size: Size of that dimension's process group; returned as-is
                when ``hsdp_replica_count`` is a sequence with no entry at
                ``dim_idx``.

        Returns:
            ``hsdp_replica_count`` itself when it is an int, otherwise its
            ``dim_idx``-th entry, otherwise ``pg_size``.
        """
        replica_count = self.hsdp_replica_count
        if isinstance(replica_count, int):
            # A single int applies to every dimension.
            return replica_count

        # Per-dimension form: dimensions past the end keep their full size.
        return replica_count[dim_idx] if dim_idx < len(replica_count) else pg_size

    def _get_or_create_sub_pg(
            self,
            parent_pg: dist.ProcessGroup,
            sub_size: int,
    ) -> dist.ProcessGroup:
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
            parent_pg: dist.ProcessGroup,
            sub_size: int,
    ) -> dist.ProcessGroup:
        """Retrieve or collectively create a synchronized sub-ProcessGroup."""
        local_parent_ranks = tuple(sorted(list(dist.get_process_group_ranks(parent_pg))))
        cache_key = (local_parent_ranks, sub_size)

        # Cache Hit Check
        if cache_key in self._split_sub_pg_cache:
            sub_pg_map = self._split_sub_pg_cache[cache_key]
            for sub_idx, sub_pg in sub_pg_map.items():
                if sub_pg is not None:
                    try:
                        dist.get_rank(group=sub_pg)
                        return sub_pg
                    except RuntimeError:
                        continue
            raise RuntimeError(f"Current rank not found in cached sub-groups for parent_pg={local_parent_ranks}")

        global_rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Global Rendezvous via GPU all_gather_into_tensor (NCCL/HCCL fast-path).
        # Far more scalable than all_gather_object for large world_size.
        device = torch.npu.current_device() if torch.npu.is_available() else torch.cuda.current_device()
        num_parent_ranks = len(local_parent_ranks)
        local_tensor = torch.tensor(local_parent_ranks, dtype=torch.long, device=device)
        gathered_tensor = torch.empty((world_size, num_parent_ranks), dtype=torch.long, device=device)
        dist.all_gather_into_tensor(gathered_tensor.view(-1), local_tensor)

        # Deduplicate: each row is one rank's parent_ranks; convert to
        # sorted set of tuples for deterministic iteration order.
        gathered_cpu_list = gathered_tensor.cpu().tolist()
        unique_all_parent_ranks = sorted(set(
            tuple(row) for row in gathered_cpu_list
        ))

        sub_pg_map: Dict[int, dist.ProcessGroup] = {}
        my_sub_pg: Optional[dist.ProcessGroup] = None

        # Synchronized collective creation loop
        for parent_ranks in unique_all_parent_ranks:
            num_sub_groups = len(parent_ranks) // sub_size

            for sub_idx in range(num_sub_groups):
                sub_ranks = parent_ranks[sub_idx * sub_size: (sub_idx + 1) * sub_size]
                sub_pg = dist.new_group(sub_ranks)

                # Map results only if this parent ranks list matches the current rank's context
                if parent_ranks == local_parent_ranks:
                    if global_rank in sub_ranks:
                        my_sub_pg = sub_pg
                        sub_pg_map[sub_idx] = sub_pg
                    else:
                        sub_pg_map[sub_idx] = None

        self._split_sub_pg_cache[cache_key] = sub_pg_map

        if my_sub_pg is None:
            raise RuntimeError(
                f"Rank {global_rank} not found in any sub-group of parent_pg "
                f"with ranks {local_parent_ranks} and sub_size={sub_size}"
            )

        return my_sub_pg

    def _build_hsdp_batch(
            self,
            max_batch_numel: Optional[int] = None,
384
385
386
387
388
389
390
391
392
393
394
395
396
397
            self,
            max_batch_numel: Optional[int] = None,
    ) -> None:
        """Split HSDP groups into memory-capped batches for compute-broadcast overlap."""
        if max_batch_numel is None:
            broadcast_max_bytes = getattr(
                self, "replicate_broadcast_max_bytes", 512 * 1024 * 1024,
            )
            hsdp_size = self.hsdp_replica_count if self.hsdp_replica_count is not None else 1
            max_batch_numel = broadcast_max_bytes * hsdp_size if hsdp_size > 1 else float('inf')

        self._hsdp_batches: Dict[int, List[Dict]] = {}

        for group_key, (_, hsdp_groups) in self._hsdp_grouping.items():
425
426
427
428
429
430
431
432
433
434
435

                    # Soft limit: allow the bucket to slightly exceed the cap
                    # so that symmetric structures stay together and fragmentation
                    # is reduced (same approach as PyTorch FSDP/DDP bucketing).
                    if current_numel >= max_batch_numel:
                        sub_batches.append(current_batch)
                        current_batch = []
                        current_numel = 0

                if current_batch:
                    sub_batches.append(current_batch)
443
444
445
446
447
448
449
450
451
                })

            # log batch split info.
            total_sub_batches = sum(len(bg["sub_batches"]) for bg in batch_groups)
            logger.info_rank0(
                "[HSDP Batch] group_key=%s, num_hsdp_groups=%s, num_batch_groups=%s, "
                "total_sub_batches=%s, group_numels=%s, max_batch_numel=%s",
                group_key,
                len(hsdp_groups),
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
        # When hsdp_replica_count is set, remap coordinates into
        # sub-groups: each original coord maps to (coord % sub_size)
        # within its sub-group, and the effective group size shrinks.
        # Supports per-dimension control via Tuple[int, ...].
        if self.hsdp_replica_count is not None:
            if isinstance(self.hsdp_replica_count, int):
                dedup_per_dim = (self.hsdp_replica_count,) * len(replicate_group_ranks)
            else:
                dedup_per_dim = self.hsdp_replica_count
            replicate_group_ranks = tuple(
                r % dedup_per_dim[i] for i, r in enumerate(replicate_group_ranks)
            )
            replicate_sizes = tuple(
                min(s, dedup_per_dim[i]) for i, s in enumerate(replicate_sizes)
            )

        # Greedy owner assignment on this sub-batch's records.
556
557
558
559
560
561
562
563
564
                hsdp_group = bg["hsdp_group"]
                sub_batch_assigns: List[HSDPGroupAssignment] = []

                for sub_batch_entries in bg["sub_batches"]:
                    hsdp_assign = self._build_sub_batch_assignment(sub_batch_entries, hsdp_group)
                    if hsdp_assign is not None:
                        sub_batch_assigns.append(hsdp_assign)

                assignment_batch_groups.append({
741
742
743
744
745
746
747
748
749
750
751
        Args:
            target: "param" or "state".
            state_keys: State dict keys to broadcast when target="state".
        """
        device = torch.npu.current_device() if torch.npu.is_available() else torch.cuda.current_device()

        alignment = 512  # bytes
        rank_dtype_tensors = self._collect_broadcast_tensors(target, state_keys)

        for (src_coord, dtype, replicate_pgs), tensors in rank_dtype_tensors.items():
            if not tensors:
900
901
902
903
904
905
906
907
908
            tensors: List[torch.Tensor],
            async_op: bool = True,
    ) -> None:
        """Pack and broadcast tensors for one broadcast key."""
        device = torch.npu.current_device() if torch.npu.is_available() else torch.cuda.current_device()
        src_coord, dtype, replicate_pgs = key
        alignment = 512  # bytes

        local_coord = tuple(
hyper_parallel/core/optimizer/sharding_category.py
21
22
23
24
25
26
27
28
29

import torch
import torch.distributed as dist

from hyper_parallel.core.optimizer.dtensor_compat import (
    DTensor,
    DeviceMesh,
    Shard,
    StridedShard,
272
273
274
275
276
277
278
279
280
281
    no_comm_params: List[DTensor] = []
    groups: Dict[HSDPGroupKey, HSDPCommGroup] = {}

    for param_index, param in enumerate(params):
        if not isinstance(param, DTensor):
            no_comm_params.append(param)
            continue

        shard_spec = extract_param_shard_spec(param)
        comm_key = build_comm_domain_key(shard_spec)