Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/shard/ops/parallel_mhc_pre_sinkhorn.py 29.2% 78-80,83-94,296-297
hyper_parallel/platform/mindspore/custom_ops/custom_op_impl.py 40.5% 67,71,80-93,435,442,444-446,452,456-457,464
hyper_parallel/core/shard/ops/parallel_mhc_pre_sinkhorn.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98


def _normalize_mhc_pre_clamp_sinkhorn_args(*args, **kwargs):
    """Fold positional and keyword arguments into one canonical positional tuple.

    Values are resolved in three layers: the entries of
    ``_MHC_PRE_CLAMP_DEFAULTS`` first, then positionals matched left to right
    against ``_MHC_PRE_CLAMP_ARG_NAMES``, then keywords.

    Returns:
        A ``(args, kwargs)`` pair. ``args`` holds one value per entry of
        ``_MHC_PRE_CLAMP_ARG_NAMES``, in that order; ``kwargs`` is always an
        empty dict.

    Raises:
        TypeError: if more positionals are given than there are argument
            names, if a keyword is not a known argument name, if a keyword
            repeats an argument already given positionally, or if any of the
            first four argument names ends up without a value.
    """
    if len(args) > len(_MHC_PRE_CLAMP_ARG_NAMES):
        raise TypeError(
            f"npu_mhc_pre_clamp_sinkhorn expected at most {len(_MHC_PRE_CLAMP_ARG_NAMES)} arguments"
        )

    # Defaults go in first; positionals then overwrite them left to right.
    resolved = dict(_MHC_PRE_CLAMP_DEFAULTS)
    resolved.update(zip(_MHC_PRE_CLAMP_ARG_NAMES, args))

    # Names already consumed by positionals; a keyword hitting one is a duplicate.
    taken_positionally = _MHC_PRE_CLAMP_ARG_NAMES[:len(args)]
    for name, supplied in kwargs.items():
        if name not in _MHC_PRE_CLAMP_ARG_NAMES:
            raise TypeError(f"npu_mhc_pre_clamp_sinkhorn got an unexpected keyword argument '{name}'")
        if name in taken_positionally:
            raise TypeError(f"npu_mhc_pre_clamp_sinkhorn got multiple values for argument '{name}'")
        resolved[name] = supplied

    # The first four names are the required ones: they must have a value by now.
    missing = [name for name in _MHC_PRE_CLAMP_ARG_NAMES[:4] if name not in resolved]
    if missing:
        raise TypeError(f"npu_mhc_pre_clamp_sinkhorn missing required arguments: {missing}")

    return tuple(map(resolved.__getitem__, _MHC_PRE_CLAMP_ARG_NAMES)), {}


# Validation rules table for npu_mhc_pre_sinkhorn
# Key: tensor_map length (format identifier)
292
293
294
295
296
297
298
299
300
301
        ]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, layouts: list, extra_args=None) -> Tuple[tuple, None]:
        del extra_args
        x_layout, phi_layout, alpha_layout, bias_layout = layouts

        self._check_partial_inputs([x_layout, phi_layout, alpha_layout, bias_layout])
        _validate_input_layouts_mhc_pre_sinkhorn(
            x_layout, phi_layout, alpha_layout, bias_layout
hyper_parallel/platform/mindspore/custom_ops/custom_op_impl.py
63
64
65
66
67
68
69
70
71
72
73
74
75


def _bind_mhc_pre_clamp_args(args, kwargs):
    """Bind npu_mhc_pre_clamp_sinkhorn arguments with Python defaults."""
    # Parameter names in positional order; the first four are the required
    # ones (see the names[:4] check near the end of this function).
    names = (
        "x", "phi", "alpha", "bias", "hc_mult", "num_iters",
        "hc_eps", "norm_eps", "out_flag", "clamp_min", "clamp_max",
    )
    # Defaults for the optional parameters only. The four required names are
    # left out on purpose so their absence can be detected after binding.
    values = {
        "hc_mult": 4,
        "num_iters": 20,
        "hc_eps": 1e-6,
        "norm_eps": 1e-6,
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
        "out_flag": True,
        "clamp_min": 0.0,
        "clamp_max": 0.0,
    }
    # Reject surplus positionals explicitly: zip() below stops at the shorter
    # sequence and would otherwise drop the extras without an error.
    if len(args) > len(names):
        raise TypeError(f"npu_mhc_pre_clamp_sinkhorn expected at most {len(names)} arguments")
    # Positional arguments fill the parameters left to right.
    for name, value in zip(names, args):
        values[name] = value
    for name, value in kwargs.items():
        # A keyword naming a parameter that a positional already filled is a duplicate.
        if name in values and name in names[:len(args)]:
            raise TypeError(f"npu_mhc_pre_clamp_sinkhorn got multiple values for argument '{name}'")
        # Any keyword outside the known parameter list is rejected.
        if name not in names:
            raise TypeError(f"npu_mhc_pre_clamp_sinkhorn got an unexpected keyword argument '{name}'")
        values[name] = value
    # x, phi, alpha and bias have no default, so the caller must have supplied
    # each of them either positionally or by keyword.
    missing = [name for name in names[:4] if name not in values]
    if missing:
        raise TypeError(f"npu_mhc_pre_clamp_sinkhorn missing required arguments: {missing}")
    # Pack the bound values, in `names` order, into _MhcPreClampArgs
    # (declared outside this excerpt; callers read fields such as `.x` from it).
    return _MhcPreClampArgs(*(values[name] for name in names))


def _build_custom_ops():
    return ms.ops.CustomOpBuilder(
431
432
433
434
435
436
437
438
439

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """Forward pass: delegates to the clamp-enabled Ascend custom kernel."""
        bound = _bind_mhc_pre_clamp_args(args, kwargs)
        result = _custom_ops.npu_mhc_pre_clamp_sinkhorn(
            bound.x, bound.phi, bound.alpha, bound.bias,
            bound.hc_mult, bound.num_iters, bound.hc_eps, bound.norm_eps,
            bound.out_flag, bound.clamp_min, bound.clamp_max
438
439
440
441
442
443
444
445
446
447
448
449
450
            bound.hc_mult, bound.num_iters, bound.hc_eps, bound.norm_eps,
            bound.out_flag, bound.clamp_min, bound.clamp_max
        )
        _, _, _, h_pre, hc_before_norm, inv_rms, sum_out, norm_out, h_res_logits = result
        ctx.save_for_backward(bound.x, bound.phi, bound.alpha, bound.bias,
                              h_pre, hc_before_norm, inv_rms, sum_out, norm_out, h_res_logits)
        ctx.hc_eps = bound.hc_eps
        ctx.clamp_min = bound.clamp_min
        ctx.clamp_max = bound.clamp_max
        return result

    @staticmethod
    def backward(ctx, *grad_outputs):
448
449
450
451
452
453
454
455
456
457
458
459
460
461

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Backward pass: calls npu_mhc_pre_clamp_sinkhorn_backward kernel."""
        tensors = _ensure_contiguous(
            grad_outputs[0], grad_outputs[1], grad_outputs[2],
            *ctx.saved_tensors
        )
        n = tensors[1].shape[-1]
        grad_h_res = ms.ops.reshape(tensors[2], tuple(tensors[2].shape[:-1]) + (n, n))

        grads = _custom_ops.npu_mhc_pre_clamp_sinkhorn_backward(
            tensors[0], tensors[1], grad_h_res,
            tensors[3], tensors[4], tensors[5], tensors[6],
460
461
462
463
464
            tensors[0], tensors[1], grad_h_res,
            tensors[3], tensors[4], tensors[5], tensors[6],
            tensors[7], tensors[8], tensors[9], tensors[10], tensors[11], tensors[12],
            ctx.hc_eps, ctx.clamp_min, ctx.clamp_max)
        return tuple(grads[:4]) + _MHC_PRE_CLAMP_NONE_GRADS