Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File | Diff Coverage (%) | Missing Lines
hyper_parallel/core/shard/ops/parallel_lightning_indexer.py | 28.6% | 50-52, 173, 175
hyper_parallel/core/shard/ops/parallel_lightning_indexer.py
46
47
48
49
50
51
52
53
54
55
56


def _to_local(t):
    """Unwrap a DTensor into its local tensor; return any other value untouched.

    Args:
        t: A ``DTensor`` or any other object.

    Returns:
        ``t.to_local()`` when ``t`` is a ``DTensor``; otherwise ``t`` itself,
        so the helper can be applied to arguments that may or may not be
        distributed.
    """
    return t.to_local() if isinstance(t, DTensor) else t


def _normalize_lightning_indexer_args(
        query,
169
170
171
172
173
174
175
176
177
178
179

        qlen_kw = local_kwargs.get('actual_seq_lengths_query')
        klen_kw = local_kwargs.get('actual_seq_lengths_key')
        if qlen_kw is not None:
            local_kwargs['actual_seq_lengths_query'] = _to_local(qlen_kw)
        if klen_kw is not None:
            local_kwargs['actual_seq_lengths_key'] = _to_local(klen_kw)

        local_args = (query_index.to_local(), key_index.to_local(), weights.to_local())

        cache_values = [query_index.layout, key_index.layout, weights.layout, layout_str]