Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

| Source File | Diff Coverage (%) | Missing Lines |
|---|---|---|
| hyper_parallel/core/shard/ops/parallel_tril.py | 90.7% | 188-192 |
hyper_parallel/core/shard/ops/parallel_tril.py, lines 184–194 (lines 188–192, the body of `_tril_expand_impl`, are not covered by tests):
        if col_map != -1:
            col_coord = mesh.get_local_rank(len(mesh.mesh_shape) - 1 - col_map)

        def _tril_expand_impl(*args, **kwargs):
            """Apply ``func`` to the local shard using a globally-correct diagonal.

            Each rank holds only a slice of the global matrix. A lower-triangular
            mask keeps global element (i, j) when ``j - i <= diagonal``. Locally
            that element sits at (i - row_offset, j - col_offset), so the
            equivalent local condition is
            ``j_local - i_local <= diagonal + row_offset - col_offset``.

            Args:
                *args: ``args[0]`` is the local input shard. ``args[1]``, when
                    present, is the global ``diagonal``.
                **kwargs: only ``diagonal`` is read, and only when it was not
                    passed positionally. Other keywords are ignored, as before.

            Returns:
                Whatever ``func(local_input, adjusted_diagonal)`` returns.
            """
            local_input = args[0]
            # Previously only args[1] was consulted, so omitting the diagonal or
            # passing it as ``diagonal=`` raised IndexError. Fall back to the
            # keyword, then to 0.
            # NOTE(review): 0 is the documented default of torch.tril /
            # mindspore.ops.tril; this assumes ``func`` follows that contract.
            if len(args) > 1:
                runtime_diagonal = args[1]
            else:
                runtime_diagonal = kwargs.get("diagonal", 0)
            # Global start index of this shard along the last two dims. This
            # assumes every shard along that dim has the same local size as this
            # one (even split).
            # TODO(review): confirm for uneven splits.
            row_offset = row_coord * local_input.shape[ndim - 2] if row_map != -1 else 0
            col_offset = col_coord * local_input.shape[ndim - 1] if col_map != -1 else 0
            return func(local_input, runtime_diagonal + row_offset - col_offset)

        return _tril_expand_impl