Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/shard/_op_dispatch.py 73.3% 564,566,569,960,962,1023-1025,1051,1053-1054,1056,1058-1059,1075,1114,1131-1133,1206
hyper_parallel/core/tensor_parallel/__init__.py 100%  
hyper_parallel/core/tensor_parallel/_ce_op_registry.py 85.7% 92,101,130
hyper_parallel/core/tensor_parallel/loss_parallel.py 95.7% 172,181
hyper_parallel/core/tensor_parallel/loss_parallel_ops_common.py 38.6% 55,58,67-69,74-80,85-87,92-94,123,125-127,131,136-138,142,176,178-180,185,190-192,197-198,204-205,210-211,223-224
hyper_parallel/core/tensor_parallel/style.py 100%  
hyper_parallel/platform/mindspore/dtensor.py 31.6% 23-24,33-35,37,264-266,282-285
hyper_parallel/platform/mindspore/loss_parallel_ops.py 0.0% 20,22,24-27,29-30,39-40,42,44,52,54,57,72-73,76,96,98-99,101,103,105,107,109,112,135,137,139,141-142,144-145,147,149,151-152,154,156,158,160-164,166,170,172-177,179,181-183,185-186,189,192-193,205-209,211,215,225-228,232,241-248,250-253,255,264-271,273-274,283-290,292,294-295,297,307-313,315-316,318,320,322-323,325,327-330,332,334,336-340,342-344,346,348-349,351-353,355,357,359-360,362-365,367-368,370,373,398-400,402-404,408-410,412-413,415,427-429,431-433,435,440-441,443,445,456,459,476-483,485
hyper_parallel/platform/torch/dtensor.py 27.3% 277,369-371,387-390
hyper_parallel/platform/torch/loss_parallel_ops.py 0.0% 20,22,24-25,27-28,37-38,40,42,50,52,55,70-71,74,94,96-97,99,101,103,105,107,110,133,135,137,139-140,142-143,145,147,149-150,152,154,156,158-162,164,168,170-175,177,179-180,183,185,190,193-194,206-210,212,216,226-229,233,242-249,251-256,258,267-274,276-277,286-293,295,297-298,300,310-316,318-319,321,323,325-326,328,330-333,335,337,339-343,345-347,349,351-352,354-356,358,360,362-363,365-367,369-370,372,375,400-402,404-406,410-412,414-415,417,429-431,433-435,437,442-443,445,447,458,461,478-485,487
hyper_parallel/core/shard/_op_dispatch.py
560
561
562
563
564
565
566
567
568
569
570
571
572
        # Can be called as:
        #   reshape(-1, 1024) -> args = (tensor, -1, 1024)
        #   reshape((-1, 1024)) -> args = (tensor, (-1, 1024))
        #   reshape([4, 16, 1024]) -> args = (tensor, [4, 16, 1024])
        if len(args) > 2:
            # Multiple int arguments: reshape(-1, 1024)
            shape = args[1:]
        else:
            # Single argument: reshape((4, 16, 1024)) or reshape(4)
            shape = args[1]

        layout = input_tensor.layout
        input_layouts = [layout]
956
957
958
959
960
961
962
963
964
965
966
        def gather(value: object) -> object:
            """Recursively swap every DTensor inside ``value`` for its full tensor.

            A DTensor is replaced by the result of ``full_tensor()``. Lists and
            tuples are rebuilt element by element (as a plain ``list`` or a plain
            ``tuple`` respectively); any other object is handed back untouched.
            """
            if isinstance(value, DTensor):
                return value.full_tensor()
            if isinstance(value, list):
                return list(map(gather, value))
            if isinstance(value, tuple):
                return tuple(map(gather, value))
            return value

        gathered_args = [gather(arg) for arg in args]
        gathered_kwargs = {k: gather(v) for k, v in kwargs.items()}
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
                has_vocab_sharded_dtensor = True
                break
        if not has_vocab_sharded_dtensor:
            for val in kwargs.values():
                if isinstance(val, DTensor) and _is_shard_on_last_dim(val):
                    has_vocab_sharded_dtensor = True
                    break

        if has_vocab_sharded_dtensor:
            raise ValueError(
                f"Operator '{op_name}' is a decomposed component of cross_entropy and should not be called "
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062

        Returns:
            Result of the distributed cross_entropy computation.
        """
        if platform.platform_type == PlatformType.PYTORCH:
            # pylint: disable=C0415
            from hyper_parallel.platform.torch.loss_parallel_ops import distributed_cross_entropy_from_op_call
        elif platform.platform_type == PlatformType.MINDSPORE:
            # pylint: disable=C0415
            from hyper_parallel.platform.mindspore.loss_parallel_ops import distributed_cross_entropy_from_op_call
        else:
            raise RuntimeError(f"Unsupported platform for loss_parallel: {platform.platform_type}")
        return distributed_cross_entropy_from_op_call(op_call, args, kwargs)

    def _check_ce_op_without_loss_parallel_context(self, op_name: str, args: tuple):
        """Check if CE op is called with Shard(-1) DTensor outside loss_parallel context.
1071
1072
1073
1074
1075
1076
1077
1078
1079
        if is_loss_parallel_active() or not is_loss_parallel_op(op_name):
            return

        if len(args) == 0 or not isinstance(args[0], DTensor):
            return

        logits = args[0]
        if _is_shard_on_last_dim(logits):
            raise ValueError(
1110
1111
1112
1113
1114
1115
1116
1117
1118
            if has_dtensor:
                self._check_ce_op_without_loss_parallel_context(op_name, args)

                if not is_loss_parallel_op(op_name):
                    raise RuntimeError(
                        f"Operator {op_name} does not contain parallel layout infer func. "
                        f"DTensor dispatch requires explicit layout inference registration. "
                        f"Please register a distributed operator for '{op_name}' or use local tensors."
                    )
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
                    logits = gathered_args[0]
                    targets = gathered_args[1]
                    if isinstance(logits, Tensor) and isinstance(targets, Tensor):
                        if logits.ndim > 2 and targets.ndim > 1 and targets.ndim == logits.ndim - 1:
                            vocab_size = logits.shape[-1]
                            gathered_args[0] = logits.reshape(-1, vocab_size)
                            gathered_args[1] = targets.reshape(-1)

                return op_call(*gathered_args, **gathered_kwargs)
            raise RuntimeError(f"Operator {op_name} does not contain parallel layout infer func.")
1202
1203
1204
1205
1206
1207
1208
1209

        self._check_decomposed_ce_op_in_loss_parallel(op_name, args, kwargs)

        if self._should_dispatch_loss_parallel(op_name):
            return self._dispatch_loss_parallel(op_call, args, kwargs)

        if op_name not in self.layout_infer_ops and get_distributed_op(op_name) is not None:
            self.layout_infer_ops[op_name] = {}
hyper_parallel/core/tensor_parallel/_ce_op_registry.py
88
89
90
91
92
93
94
95
96

    Args:
        *names: Operator names to register (platform.get_op_name results).
    """
    _EXTENDED_CE_OP_NAMES.update(names)


def unregister_loss_parallel_op_names(*names: str) -> None:
    """Remove registered CE operator names (for test cleanup).
 97
 98
 99
100
101
102
103
104
105

    Args:
        *names: Operator names to remove.
    """
    _EXTENDED_CE_OP_NAMES.difference_update(names)


def is_loss_parallel_op(op_name: str) -> bool:
    """Check if operator is a CE entry point (e.g., cross_entropy).
126
127
128
129
130


def clear_extended_op_names() -> None:
    """Empty the set of extended CE operator names in place.

    Intended for test cleanup: the module-level set object is kept and only
    its contents are dropped.
    """
    _EXTENDED_CE_OP_NAMES.clear()
hyper_parallel/core/tensor_parallel/loss_parallel.py
168
169
170
171
172
173
174
175
176

    Returns:
        Optional[DeviceMesh]: Explicitly specified mesh, or None (infer from DTensor).
    """
    return _loss_parallel_mesh.get()


def _get_loss_parallel_strict() -> bool:
    """Get the strict setting in current context.
177
178
179
180
181

    Returns:
        bool: Whether in strict mode.
    """
    return _loss_parallel_strict.get()
hyper_parallel/core/tensor_parallel/loss_parallel_ops_common.py
51
52
53
54
55
56
57
58
59
60
61
62

def _is_shard_on_last_dim(dtensor: DTensor) -> bool:
    """Tell whether the final placement of ``dtensor`` shards its last dimension.

    Only the last entry of ``dtensor.placements`` is inspected. The answer is
    ``True`` when that entry is a ``Shard`` whose ``dim`` is either ``-1`` or the
    index of the last dimension of ``dtensor.shape``.

    Returns:
        bool: ``False`` for non-DTensor inputs, for empty placements, and for a
        final placement that is not such a ``Shard``.
    """
    if not _is_dtensor(dtensor):
        return False
    entries = dtensor.placements
    tail = entries[-1] if entries else None
    if not isinstance(tail, Shard):
        return False
    # Accept both the negative and the positive spelling of the last dimension.
    return tail.dim in (-1, len(dtensor.shape) - 1)
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98


def _is_replicate(dtensor: DTensor) -> bool:
    """Tell whether every placement of ``dtensor`` is ``Replicate``.

    Returns:
        bool: ``False`` for non-DTensor inputs or as soon as one placement is
        not a ``Replicate``; ``True`` otherwise (including empty placements).
    """
    if not _is_dtensor(dtensor):
        return False
    for placement in dtensor.placements:
        if not isinstance(placement, Replicate):
            return False
    return True


def _get_mesh_and_dim(dtensor: DTensor) -> tuple[Optional["DeviceMesh"], Optional[int]]:
    """Return the device mesh of ``dtensor`` and its first sharded dimension.

    The placements are scanned in order and the first ``Shard`` wins. A negative
    ``dim`` is normalised to a non-negative index using ``len(dtensor.shape)``.

    Returns:
        tuple: ``(dtensor.device_mesh, shard_dim)`` where ``shard_dim`` is
        ``None`` when no placement is a ``Shard``.
    """
    mesh = dtensor.device_mesh
    first_shard = next(
        (entry for entry in dtensor.placements if isinstance(entry, Shard)),
        None,
    )
    if first_shard is None:
        return mesh, None
    axis = first_shard.dim
    if axis < 0:
        axis = len(dtensor.shape) + axis
    return mesh, axis


def _get_local_tensor(tensor):
    """Unwrap a DTensor to its ``_local_tensor``; pass anything else through."""
    # Private attribute access is deliberate, hence the lint suppressions.
    return tensor._local_tensor if _is_dtensor(tensor) else tensor  # type: ignore  # pylint: disable=W0212


def _get_full_tensor(tensor):
    """Return ``tensor.full_tensor()`` for a DTensor; pass anything else through."""
    return tensor.full_tensor() if _is_dtensor(tensor) else tensor  # type: ignore


def _validate_target_type_base(is_floating: bool) -> None:
    """Validate target type (base check).
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146

    Raises:
        ValueError: If mesh is not 1D or input is not Shard(-1) in strict mode.
    """
    mesh = dtensor.device_mesh

    if mesh.ndim != 1:
        if strict:
            raise ValueError(
                f"Expected 1D TP mesh, got {mesh.ndim}D mesh. "
                "Slice a 1D sub-mesh first: mesh['tp']"
            )
        warnings.warn(
            f"Expected 1D TP mesh, got {mesh.ndim}D mesh. "
            "This may cause incorrect results."
        )

    if not _is_shard_on_last_dim(dtensor):
        if strict:
            raise ValueError(
                "Expected Shard(-1) on class dimension. "
                f"Got placements: {dtensor.placements}"
            )
        warnings.warn(
            f"Expected Shard(-1) on class dimension. "
            f"Got placements: {dtensor.placements}. "
            "This may cause incorrect results."
        )
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202

    Raises:
        ValueError: If parameters are invalid.
    """
    _validate_target_type_base(is_floating_fn(target))

    if _is_dtensor(input_tensor):
        if not _is_shard_on_last_dim(input_tensor):
            raise ValueError(
                "input must be Shard(-1) on class dimension. "
                f"Got placements: {input_tensor.placements}"
            )
    else:
        raise ValueError(
            "input must be a DTensor when using loss_parallel. "
            f"Got type: {type(input)}"
        )

    if weight is not None and _is_dtensor(weight):
        if not _is_replicate(weight):
            raise ValueError(
                "weight must be Replicate when it's a DTensor. "
                f"Got placements: {weight.placements}"
            )

    if size_average is not None or reduce is not None:
        warnings.warn(
            "size_average and reduce arguments are deprecated. "
            "Please use reduction='mean' or reduction='sum' instead.",
            DeprecationWarning,
        )
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
            "Please use reduction='mean' or reduction='sum' instead.",
            DeprecationWarning,
        )

    if label_smoothing != 0.0:
        raise ValueError(
            "label_smoothing is not supported in loss_parallel. "
            "Please set label_smoothing=0.0 or disable loss_parallel."
        )

    if reduction not in ("none", "mean", "sum"):
        raise ValueError(f"Invalid reduction: {reduction}. Must be 'none', 'mean', or 'sum'.")


def _check_context_and_layout(dtensor: DTensor) -> None:  # pylint: disable=W0613
    """Check context and layout.
219
220
221
222
223
224
225
226
227

    Raises:
        ValueError: If not in loss_parallel context.
    """
    if not is_loss_parallel_active():
        raise ValueError(
            "Shard logits detected but not in loss_parallel context. "
            "Please wrap with loss_parallel() context manager."
        )
hyper_parallel/platform/mindspore/dtensor.py
19
20
21
22
23
24
25
26
27
28
from mindspore.common.initializer import initializer

try:
    from mindspore._c_expression import NoFallbackGuard  # pylint: disable=C0415
except ImportError:
    warnings.warn(
        "mindspore._c_expression.NoFallbackGuard not available; "
        "using no-op fallback guard. This may allow recursive fallback dispatch "
        "in some edge cases. Please upgrade MindSpore to a version that provides "
        "NoFallbackGuard for proper protection against recursive dispatch.",
29
30
31
32
33
34
35
36
37
38
39
40
41
        RuntimeWarning,
        stacklevel=2,
    )

    @contextmanager
    def _no_fallback_guard():
        yield

    NoFallbackGuard = _no_fallback_guard


class DTensorBase(Tensor):
    """
260
261
262
263
264
265
266
267
268
269
        self._local_tensor.param_info = local_param_info_value

    def _alias_placements(self):
        """Return alias_placements from layout, falling back to _placements."""
        if hasattr(self, '_layout') and self._layout is not None:
            return self._layout.alias_placements
        return self._placements

    def to(self, *args, **kwargs):
        """Move the DTensor to a different device or dtype.
278
279
280
281
282
283
284
285

        Returns:
            DTensorBase: A new DTensor with the converted local tensor.
        """
        new_local = self._local_tensor.to(*args, **kwargs)
        new_dt = Tensor._make_subclass(type(self), new_local)
        new_dt.__init_data__(new_local, self._device_mesh, self._alias_placements())
        return new_dt
hyper_parallel/platform/mindspore/loss_parallel_ops.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

Distributed cross-entropy kernel implementation using mindspore.common._grad_function._Function.
"""

from __future__ import annotations

from typing import Any, Optional, Tuple

import mindspore as ms  # pylint: disable=C0415
from mindspore import mint  # pylint: disable=C0415
from mindspore.common._grad_function import _Function  # pylint: disable=C0415
from mindspore.common.tensor import Tensor  # pylint: disable=C0415

from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.tensor_parallel.loss_parallel_ops_common import (
    _is_dtensor,
    _is_shard_on_last_dim,
    _get_mesh_and_dim,
    _get_local_tensor,
35
36
37
38
39
40
41
42
43
44
45
46
47
48
    _validate_cross_entropy_params,
    _check_context_and_layout,
    _validate_mesh_and_shard,
)
from hyper_parallel.core.tensor_parallel.loss_parallel import _get_loss_parallel_strict
from hyper_parallel.platform import get_platform

platform = get_platform()

__all__ = [
    "distributed_cross_entropy",
    "distributed_log_softmax",
    "distributed_nll_loss_forward",
    "DistributedCrossEntropyFunction",
48
49
50
51
52
53
54
55
56
57
58
59
60
61
    "DistributedCrossEntropyFunction",
]


def _is_floating_ms(tensor: Tensor) -> bool:
    """Check if a MindSpore tensor has a floating-point dtype.

    Args:
        tensor: Tensor whose ``dtype`` is inspected.

    Returns:
        bool: ``True`` when the dtype is float16, bfloat16, float32 or float64.
    """
    # bfloat16 is included so the result agrees with the PyTorch counterpart,
    # whose ``Tensor.is_floating_point()`` also reports bfloat16 as floating.
    return tensor.dtype in (ms.float16, ms.bfloat16, ms.float32, ms.float64)


def _compute_vocab_start(vocab_size: int, tp_size: int, rank: int) -> int:
    """Compute the starting index for this rank's vocab shard.

    Args:
        vocab_size: Total vocabulary size.
68
69
70
71
72
73
74
75
76
77
78
79
80
    Note:
        This follows torch.chunk behavior: chunk_size = ceil(vocab_size/tp_size),
        and each rank's start = rank * chunk_size. The last rank may have fewer elements.
    """
    chunk_size = (vocab_size + tp_size - 1) // tp_size  # ceil division
    return rank * chunk_size


def distributed_log_softmax(
    logits_local: Tensor,
    dim: int,
    mesh: DeviceMesh,
    mesh_dim: int = 0,
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116

    Communication:
        MAX + SUM all_reduce
    """
    max_local = logits_local.max(axis=dim, keepdims=True)

    group = mesh.get_group(mesh_dim)
    max_global = platform.differentiable_all_reduce(max_local, op="max", group=group)

    exp_local = (logits_local - max_global).exp()

    sum_local = exp_local.sum(axis=dim, keepdims=True)

    sum_global = platform.differentiable_all_reduce(sum_local, op="sum", group=group)

    log_softmax = logits_local - max_global - sum_global.log()

    return log_softmax


def distributed_nll_loss_forward(
    log_probs: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    ignore_index: int,
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197

    Returns:
        Tuple of (loss, total_weight, target_mask, vocab_start_tensor).
    """
    batch_size = target.numel()

    target_flat = target.flatten()

    target_mask = (target_flat >= vocab_start) & (target_flat < vocab_end)

    ignore_mask = target_flat != ignore_index
    target_mask = target_mask & ignore_mask

    if reduction == "none":
        loss = mint.zeros((batch_size,), dtype=log_probs.dtype)
    else:
        loss = mint.zeros((1,), dtype=log_probs.dtype)

    total_weight = mint.zeros((1,), dtype=log_probs.dtype)

    if target_mask.any():
        local_target = target_flat[target_mask] - vocab_start

        log_probs_2d = log_probs.reshape(-1, log_probs.shape[-1])

        row_indices = mint.nonzero(target_mask).flatten()

        selected_log_probs = log_probs_2d[row_indices, local_target]

        if weight is not None:
            global_target = target_flat[target_mask]
            sample_weights = weight[global_target]
            selected_log_probs = selected_log_probs * sample_weights
            total_weight = sample_weights.sum().reshape(1)
        else:
            total_weight = ms.Tensor(
                target_mask.sum().asnumpy().item(), dtype=log_probs.dtype
            ).reshape(1)

        nll = -selected_log_probs

        if reduction == "none":
            loss_flat = mint.zeros((batch_size,), dtype=log_probs.dtype)
            loss_flat[target_mask] = nll
            loss = loss_flat.reshape(target.shape)
        elif reduction == "sum":
            loss = nll.sum().unsqueeze(0)
        else:
            loss = nll.sum().unsqueeze(0)
    else:
        if reduction == "none":
            loss = mint.zeros((batch_size,), dtype=log_probs.dtype).reshape(target.shape)
        total_weight = mint.zeros((1,), dtype=log_probs.dtype)

    vocab_start_tensor = ms.Tensor(vocab_start, dtype=ms.int64)
    return loss, total_weight, target_mask, vocab_start_tensor


class DistributedCrossEntropyFunction(_Function):
    """K3: Fused backward for distributed cross_entropy (MindSpore version)."""

    @staticmethod
    def forward(
        ctx: Any,
        input_local: Tensor,
        target: Tensor,
        weight: Optional[Tensor],
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
        mesh: DeviceMesh,
        mesh_dim: int,
    ) -> Tensor:
        """Forward pass."""
        local_vocab_size = input_local.shape[-1]
        rank = mesh.get_local_rank(mesh_dim)
        tp_size = mesh.size(mesh_dim)
        vocab_start = _compute_vocab_start(vocab_size, tp_size, rank)
        vocab_end = vocab_start + local_vocab_size

        log_probs_local = distributed_log_softmax(
            input_local, dim=-1, mesh=mesh, mesh_dim=mesh_dim
        )

        loss, total_weight, target_mask, vocab_start_tensor = distributed_nll_loss_forward(
            log_probs_local,
            target,
            weight,
            ignore_index,
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
            vocab_start,
            vocab_end,
        )

        if reduction == "mean":
            group = mesh.get_group(mesh_dim)
            total_loss = platform.differentiable_all_reduce(loss, op="sum", group=group)
            total_weight_sum = platform.differentiable_all_reduce(
                total_weight, op="sum", group=group
            )

            ctx.save_for_backward(
                input_local,
                log_probs_local,
                target,
                weight,
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
                total_weight_sum,
                target_mask,
                vocab_start_tensor,
            )
            ctx.reduction = reduction
            ctx.ignore_index = ignore_index
            ctx.vocab_size = vocab_size
            ctx.local_vocab_size = local_vocab_size
            ctx.mesh = mesh
            ctx.mesh_dim = mesh_dim
            ctx.vocab_start = vocab_start
            ctx.vocab_end = vocab_end

            return total_loss / total_weight_sum.clamp(min=1e-12)
        if reduction == "sum":
            group = mesh.get_group(mesh_dim)
            total_loss = platform.differentiable_all_reduce(loss, op="sum", group=group)

            ctx.save_for_backward(
                input_local,
                log_probs_local,
                target,
                weight,
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
                mint.zeros((1,), dtype=loss.dtype),
                target_mask,
                vocab_start_tensor,
            )
            ctx.reduction = reduction
            ctx.ignore_index = ignore_index
            ctx.vocab_size = vocab_size
            ctx.local_vocab_size = local_vocab_size
            ctx.mesh = mesh
            ctx.mesh_dim = mesh_dim
            ctx.vocab_start = vocab_start
            ctx.vocab_end = vocab_end

            return total_loss
        ctx.save_for_backward(
            input_local,
            log_probs_local,
            target,
            weight,
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
            mint.zeros((1,), dtype=loss.dtype),
            target_mask,
            vocab_start_tensor,
        )
        ctx.reduction = reduction
        ctx.ignore_index = ignore_index
        ctx.vocab_size = vocab_size
        ctx.local_vocab_size = local_vocab_size
        ctx.mesh = mesh
        ctx.mesh_dim = mesh_dim
        ctx.vocab_start = vocab_start
        ctx.vocab_end = vocab_end

        return loss

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...]:
        """Backward pass (vectorized implementation)."""
        (
            _,
            log_probs_local,
            target,
            weight,
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
            _,
            _,
        ) = ctx.saved_tensors

        reduction = ctx.reduction
        ignore_index = ctx.ignore_index
        _ = ctx.local_vocab_size
        vocab_start = ctx.vocab_start
        vocab_end = ctx.vocab_end
        _ = ctx.mesh
        _ = ctx.mesh_dim

        batch_size = target.numel()
        target_flat = target.flatten()

        softmax_local = log_probs_local.exp()

        ignore_mask = target_flat != ignore_index

        if weight is not None:
            sample_weights = weight[target_flat]
        else:
            sample_weights = None

        if reduction == "mean":
            grad_scale = grad_output / total_weight.clamp(min=1e-12)
        elif reduction == "sum":
            grad_scale = grad_output
        else:
            grad_scale = grad_output.flatten()

        in_vocab_mask = (target_flat >= vocab_start) & (target_flat < vocab_end) & ignore_mask

        if reduction == "none":
            grad_scale_expanded = grad_scale.unsqueeze(-1)
            if sample_weights is not None:
                grad_scale_expanded = grad_scale_expanded * sample_weights.unsqueeze(-1)
            grad_input = softmax_local * grad_scale_expanded
        else:
            if sample_weights is not None:
                grad_scale = grad_scale * sample_weights.unsqueeze(-1)
            grad_input = softmax_local * grad_scale.unsqueeze(-1)

        local_targets = mint.where(in_vocab_mask, target_flat - vocab_start, mint.zeros_like(target_flat))

        if in_vocab_mask.any():
            row_indices = mint.arange(batch_size, dtype=ms.int64)

            if reduction == "none":
                if sample_weights is not None:
                    grad_values = -grad_scale * sample_weights
                else:
                    grad_values = -grad_scale
            else:
                grad_values = -grad_scale.expand_as(target_flat)

            grad_input = grad_input.contiguous()
            grad_input[row_indices[in_vocab_mask], local_targets[in_vocab_mask]] += grad_values[in_vocab_mask]

        if not ignore_mask.all():
            ignore_indices = ~ignore_mask
            if reduction == "none":
                grad_input[ignore_indices] = 0.0
            else:
                ignore_indices_expanded = ignore_indices.unsqueeze(-1).expand_as(grad_input)
                grad_input = mint.where(ignore_indices_expanded, grad_input, mint.zeros_like(grad_input))

        return grad_input, None, None, None, None, None, None, None


def distributed_cross_entropy(
    input_tensor: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419

    Returns:
        Loss tensor.
    """
    input_dtensor = None
    mesh = None
    vocab_size = None

    if _is_dtensor(input_tensor):
        if not _is_shard_on_last_dim(input_tensor):
            raise ValueError(
                "input must be Shard(-1) on class dimension. "
                f"Got placements: {input_tensor.placements}"
            )
        input_dtensor = input_tensor
        mesh, _ = _get_mesh_and_dim(input_tensor)
        vocab_size = input_tensor.shape[-1]

    input_for_check = input_dtensor if input_dtensor is not None else input_tensor
    _check_context_and_layout(input_for_check)  # type: ignore

    _validate_cross_entropy_params(
        input_tensor,
        target,
        weight,
        size_average,
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
        label_smoothing,
        _is_floating_ms,
    )

    if input_dtensor is not None:
        input_local = _get_local_tensor(input_dtensor)
        local_vocab_size = input_local.shape[-1]

        if input_dtensor.ndim > 2:
            input_local = input_local.reshape(-1, local_vocab_size)
            target = target.reshape(-1)
    else:
        raise ValueError(
            "input must be a DTensor when using loss_parallel. "
            f"Got type: {type(input_tensor)}"
        )

    strict = _get_loss_parallel_strict()
    _validate_mesh_and_shard(input_dtensor, strict)  # type: ignore

    mesh_dim = 0

    loss = DistributedCrossEntropyFunction.apply(
        input_local,
        target,
        weight,
        ignore_index,
452
453
454
455
456
457
458
459
460
461
462
463
        mesh,
        mesh_dim,
    )

    return loss


def distributed_cross_entropy_from_op_call(
    op_call: Any,  # pylint: disable=W0613
    args: tuple,
    kwargs: dict,
) -> Tensor:
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489

    Returns:
        Loss tensor.
    """
    input_tensor = args[0] if len(args) > 0 else kwargs.get("input")
    target = args[1] if len(args) > 1 else kwargs.get("target")
    weight = args[2] if len(args) > 2 else kwargs.get("weight")
    size_average = args[3] if len(args) > 3 else kwargs.get("size_average")
    ignore_index = args[4] if len(args) > 4 else kwargs.get("ignore_index", -100)
    reduce = args[5] if len(args) > 5 else kwargs.get("reduce")
    reduction = args[6] if len(args) > 6 else kwargs.get("reduction", "mean")
    label_smoothing = args[7] if len(args) > 7 else kwargs.get("label_smoothing", 0.0)

    return distributed_cross_entropy(
        input_tensor=input_tensor,
        target=target,
        weight=weight,
        size_average=size_average,
hyper_parallel/platform/torch/dtensor.py
273
274
275
276
277
278
279
280
281
        """
        if dtype is None:
            return self._local_tensor.type()
        new_local = self._local_tensor.to(dtype=dtype, non_blocking=non_blocking)
        return self.__class__(new_local, device_mesh=self._device_mesh, placements=self._alias_placements())

    def size(self, dim: Optional[int] = None):
        """
        Get the size of this tensor.
365
366
367
368
369
370
371
372
373
374

    # ====================== Auxiliary print ======================
    def _alias_placements(self):
        """Return alias_placements from layout, falling back to _placements."""
        if hasattr(self, '_layout') and self._layout is not None:
            return self._layout.alias_placements
        return self._placements

    def to(self, *args, **kwargs):
        """Move the DTensor to a different device or dtype.
383
384
385
386
387
388
389
390
391
392
393
394

        Returns:
            DTensorBase: A new DTensor with the converted local tensor.
        """
        new_local = self._local_tensor.to(*args, **kwargs)
        new_dt = Tensor._make_subclass(type(self), new_local, new_local.requires_grad)
        new_dt.__init_data__(new_local, self._device_mesh, self._alias_placements())
        return new_dt

    def __repr__(self) -> str:
        return (
            f"DTensor(\n"
hyper_parallel/platform/torch/loss_parallel_ops.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

Distributed cross-entropy kernel implementation using torch.autograd.Function.
"""

from __future__ import annotations

from typing import Any, Optional, Tuple

import torch  # pylint: disable=C0415
from torch import Tensor  # pylint: disable=C0415

from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.tensor_parallel.loss_parallel_ops_common import (
    _is_dtensor,
    _is_shard_on_last_dim,
    _get_mesh_and_dim,
    _get_local_tensor,
33
34
35
36
37
38
39
40
41
42
43
44
45
46
    _validate_cross_entropy_params,
    _check_context_and_layout,
    _validate_mesh_and_shard,
)
from hyper_parallel.core.tensor_parallel.loss_parallel import _get_loss_parallel_strict
from hyper_parallel.platform import get_platform

platform = get_platform()

__all__ = [
    "distributed_cross_entropy",
    "distributed_log_softmax",
    "distributed_nll_loss_forward",
    "DistributedCrossEntropyFunction",
46
47
48
49
50
51
52
53
54
55
56
57
58
59
    "DistributedCrossEntropyFunction",
]


def _is_floating_torch(tensor: Tensor) -> bool:
    """Check if PyTorch tensor is floating point."""
    return tensor.is_floating_point()


def _compute_vocab_start(vocab_size: int, tp_size: int, rank: int) -> int:
    """Compute the starting index for this rank's vocab shard.

    Args:
        vocab_size: Total vocabulary size.
66
67
68
69
70
71
72
73
74
75
76
77
78
    Note:
        This follows torch.chunk behavior: chunk_size = ceil(vocab_size/tp_size),
        and each rank's start = rank * chunk_size. The last rank may have fewer elements.
    """
    chunk_size = (vocab_size + tp_size - 1) // tp_size  # ceil division
    return rank * chunk_size


def distributed_log_softmax(
    logits_local: Tensor,
    dim: int,
    mesh: DeviceMesh,
    mesh_dim: int = 0,
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

    Communication:
        MAX + SUM all_reduce
    """
    max_local = logits_local.max(dim=dim, keepdim=True).values

    group = mesh.get_group(mesh_dim)
    max_global = platform.differentiable_all_reduce(max_local, op="max", group=group)

    exp_local = (logits_local - max_global).exp()

    sum_local = exp_local.sum(dim=dim, keepdim=True)

    sum_global = platform.differentiable_all_reduce(sum_local, op="sum", group=group)

    log_softmax = logits_local - max_global - sum_global.log()

    return log_softmax


def distributed_nll_loss_forward(
    log_probs: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    ignore_index: int,
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198

    Returns:
        Tuple of (loss, total_weight, target_mask, vocab_start_tensor).
    """
    batch_size = target.numel()

    target_flat = target.flatten()

    target_mask = (target_flat >= vocab_start) & (target_flat < vocab_end)

    ignore_mask = target_flat != ignore_index
    target_mask = target_mask & ignore_mask

    if reduction == "none":
        loss = torch.zeros(batch_size, dtype=log_probs.dtype, device=log_probs.device)
    else:
        loss = torch.zeros(1, dtype=log_probs.dtype, device=log_probs.device)

    total_weight = torch.zeros(1, dtype=log_probs.dtype, device=log_probs.device)

    if target_mask.any():
        local_target = target_flat[target_mask] - vocab_start

        log_probs_2d = log_probs.reshape(-1, log_probs.shape[-1])

        row_indices = torch.where(target_mask)[0]

        selected_log_probs = log_probs_2d[row_indices, local_target]

        if weight is not None:
            global_target = target_flat[target_mask]
            sample_weights = weight[global_target]
            selected_log_probs = selected_log_probs * sample_weights
            total_weight = sample_weights.sum().reshape(1)
        else:
            total_weight = torch.tensor(
                target_mask.sum().item(), dtype=log_probs.dtype, device=log_probs.device
            ).reshape(1)

        nll = -selected_log_probs

        if reduction == "none":
            loss_flat = torch.zeros(batch_size, dtype=log_probs.dtype, device=log_probs.device)
            loss_flat[target_mask] = nll
            loss = loss_flat.reshape(target.shape)
        elif reduction == "sum":
            loss = nll.sum().unsqueeze(0)
        else:
            loss = nll.sum().unsqueeze(0)
    else:
        if reduction == "none":
            loss = torch.zeros(
                batch_size, dtype=log_probs.dtype, device=log_probs.device
            ).reshape(target.shape)
        total_weight = torch.zeros(1, dtype=log_probs.dtype, device=log_probs.device)

    return loss, total_weight, target_mask, torch.tensor(
        vocab_start, dtype=torch.long, device=log_probs.device
    )


class DistributedCrossEntropyFunction(torch.autograd.Function):
    """K3: Fused backward for distributed cross_entropy."""

    @staticmethod
    def forward(
        ctx: Any,
        input_local: Tensor,
        target: Tensor,
        weight: Optional[Tensor],
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
        mesh: DeviceMesh,
        mesh_dim: int,
    ) -> Tensor:
        """Forward pass."""
        local_vocab_size = input_local.shape[-1]
        rank = mesh.get_local_rank(mesh_dim)
        tp_size = mesh.size(mesh_dim)
        vocab_start = _compute_vocab_start(vocab_size, tp_size, rank)
        vocab_end = vocab_start + local_vocab_size

        log_probs_local = distributed_log_softmax(
            input_local, dim=-1, mesh=mesh, mesh_dim=mesh_dim
        )

        loss, total_weight, target_mask, vocab_start_tensor = distributed_nll_loss_forward(
            log_probs_local,
            target,
            weight,
            ignore_index,
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
            vocab_start,
            vocab_end,
        )

        if reduction == "mean":
            group = mesh.get_group(mesh_dim)
            total_loss = platform.differentiable_all_reduce(loss, op="sum", group=group)
            total_weight_sum = platform.differentiable_all_reduce(
                total_weight, op="sum", group=group
            )

            ctx.save_for_backward(
                input_local,
                log_probs_local,
                target,
                weight,
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
                total_weight_sum,
                target_mask,
                vocab_start_tensor,
            )
            ctx.reduction = reduction
            ctx.ignore_index = ignore_index
            ctx.vocab_size = vocab_size
            ctx.local_vocab_size = local_vocab_size
            ctx.mesh = mesh
            ctx.mesh_dim = mesh_dim
            ctx.vocab_start = vocab_start
            ctx.vocab_end = vocab_end

            if total_weight_sum.item() == 0:
                return torch.tensor(float('nan'), dtype=total_loss.dtype, device=total_loss.device)
            return total_loss / total_weight_sum
        if reduction == "sum":
            group = mesh.get_group(mesh_dim)
            total_loss = platform.differentiable_all_reduce(loss, op="sum", group=group)

            ctx.save_for_backward(
                input_local,
                log_probs_local,
                target,
                weight,
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
                torch.zeros(1, dtype=loss.dtype, device=loss.device),
                target_mask,
                vocab_start_tensor,
            )
            ctx.reduction = reduction
            ctx.ignore_index = ignore_index
            ctx.vocab_size = vocab_size
            ctx.local_vocab_size = local_vocab_size
            ctx.mesh = mesh
            ctx.mesh_dim = mesh_dim
            ctx.vocab_start = vocab_start
            ctx.vocab_end = vocab_end

            return total_loss
        ctx.save_for_backward(
            input_local,
            log_probs_local,
            target,
            weight,
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
            torch.zeros(1, dtype=loss.dtype, device=loss.device),
            target_mask,
            vocab_start_tensor,
        )
        ctx.reduction = reduction
        ctx.ignore_index = ignore_index
        ctx.vocab_size = vocab_size
        ctx.local_vocab_size = local_vocab_size
        ctx.mesh = mesh
        ctx.mesh_dim = mesh_dim
        ctx.vocab_start = vocab_start
        ctx.vocab_end = vocab_end

        return loss

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...]:
        """Backward pass (vectorized implementation)."""
        (
            _,
            log_probs_local,
            target,
            weight,
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
            _,
            _,
        ) = ctx.saved_tensors

        reduction = ctx.reduction
        ignore_index = ctx.ignore_index
        _ = ctx.local_vocab_size
        vocab_start = ctx.vocab_start
        vocab_end = ctx.vocab_end
        _ = ctx.mesh
        _ = ctx.mesh_dim

        batch_size = target.numel()
        target_flat = target.flatten()

        softmax_local = log_probs_local.exp()

        ignore_mask = target_flat != ignore_index

        if weight is not None:
            sample_weights = weight[target_flat]
        else:
            sample_weights = None

        if reduction == "mean":
            grad_scale = grad_output / total_weight.clamp(min=1e-12)
        elif reduction == "sum":
            grad_scale = grad_output
        else:
            grad_scale = grad_output.flatten()

        in_vocab_mask = (target_flat >= vocab_start) & (target_flat < vocab_end) & ignore_mask

        if reduction == "none":
            grad_scale_expanded = grad_scale.unsqueeze(-1)
            if sample_weights is not None:
                grad_scale_expanded = grad_scale_expanded * sample_weights.unsqueeze(-1)
            grad_input = softmax_local * grad_scale_expanded
        else:
            if sample_weights is not None:
                grad_scale = grad_scale * sample_weights.unsqueeze(-1)
            grad_input = softmax_local * grad_scale.unsqueeze(-1)

        local_targets = torch.where(in_vocab_mask, target_flat - vocab_start, torch.zeros_like(target_flat))

        if in_vocab_mask.any():
            row_indices = torch.arange(batch_size, device=target.device, dtype=torch.long)

            if reduction == "none":
                if sample_weights is not None:
                    grad_values = -grad_scale * sample_weights
                else:
                    grad_values = -grad_scale
            else:
                grad_values = -grad_scale.expand_as(target_flat)

            grad_input = grad_input.contiguous()
            grad_input[row_indices[in_vocab_mask], local_targets[in_vocab_mask]] += grad_values[in_vocab_mask]

        if not ignore_mask.all():
            if reduction == "none":
                grad_input[~ignore_mask] = 0.0
            else:
                ignore_indices_expanded = (~ignore_mask).unsqueeze(-1).expand_as(grad_input)
                grad_input[ignore_indices_expanded] = 0.0

        return grad_input, None, None, None, None, None, None, None


def distributed_cross_entropy(
    input_tensor: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421

    Returns:
        Loss tensor.
    """
    input_dtensor = None
    mesh = None
    vocab_size = None

    if _is_dtensor(input_tensor):
        if not _is_shard_on_last_dim(input_tensor):
            raise ValueError(
                "input must be Shard(-1) on class dimension. "
                f"Got placements: {input_tensor.placements}"
            )
        input_dtensor = input_tensor
        mesh, _ = _get_mesh_and_dim(input_tensor)
        vocab_size = input_tensor.shape[-1]

    input_for_check = input_dtensor if input_dtensor is not None else input_tensor
    _check_context_and_layout(input_for_check)  # type: ignore

    _validate_cross_entropy_params(
        input_tensor,
        target,
        weight,
        size_average,
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
        label_smoothing,
        _is_floating_torch,
    )

    if input_dtensor is not None:
        input_local = _get_local_tensor(input_dtensor)
        local_vocab_size = input_local.shape[-1]

        if input_dtensor.ndim > 2:
            input_local = input_local.reshape(-1, local_vocab_size)
            target = target.reshape(-1)
    else:
        raise ValueError(
            "input must be a DTensor when using loss_parallel. "
            f"Got type: {type(input_tensor)}"
        )

    strict = _get_loss_parallel_strict()
    _validate_mesh_and_shard(input_dtensor, strict)  # type: ignore

    mesh_dim = 0

    loss = DistributedCrossEntropyFunction.apply(
        input_local,
        target,
        weight,
        ignore_index,
454
455
456
457
458
459
460
461
462
463
464
465
        mesh,
        mesh_dim,
    )

    return loss


def distributed_cross_entropy_from_op_call(
    op_call: Any,  # pylint: disable=W0613
    args: tuple,
    kwargs: dict,
) -> Tensor:
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491

    Returns:
        Loss tensor.
    """
    input_tensor = args[0] if len(args) > 0 else kwargs.get("input")
    target = args[1] if len(args) > 1 else kwargs.get("target")
    weight = args[2] if len(args) > 2 else kwargs.get("weight")
    size_average = args[3] if len(args) > 3 else kwargs.get("size_average")
    ignore_index = args[4] if len(args) > 4 else kwargs.get("ignore_index", -100)
    reduce = args[5] if len(args) > 5 else kwargs.get("reduce")
    reduction = args[6] if len(args) > 6 else kwargs.get("reduction", "mean")
    label_smoothing = args[7] if len(args) > 7 else kwargs.get("label_smoothing", 0.0)

    return distributed_cross_entropy(
        input_tensor=input_tensor,
        target=target,
        weight=weight,
        size_average=size_average,