Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/shard/ops/parallel_tril.py 90.2% 186,199-203
hyper_parallel/core/shard/ops/parallel_tril.py
182
183
184
185
186
187
188
189
                mapping = (mapping,)
            split_id, coef = 0, 1
            for dim in reversed(mapping):
                if dim == -1:
                    continue
                split_id += dev_id_list[-dim - 1] * coef
                coef *= mesh_shape[-dim - 1]
            return split_id
195
196
197
198
199
200
201
202
203
204
205
        if row_split_id == 0 and col_split_id == 0:
            return None

        def _expand_impl(*args, **_kwargs):
            """Apply ``func`` to the local shard with a shard-adjusted diagonal.

            Shifts the caller's diagonal by this shard's global row/column
            offset so that the local result lines up with the corresponding
            slice of the unsharded output.

            Args:
                *args: ``args[0]`` is the local input shard; ``args[1]`` is the
                    diagonal value supplied by the caller.
                **_kwargs: Accepted but ignored.

            Returns:
                Whatever ``func(local_input, adjusted_diagonal)`` returns.
            """
            # NOTE(review): the diagonal is read positionally only. If a caller
            # omits it or passes it as a keyword (e.g. diagonal=...), args[1]
            # raises IndexError and the keyword value is silently dropped.
            # Confirm that all call sites pass it positionally.
            local_input = args[0]
            cur_diagonal = args[1]
            # Global start index of this shard along the last two dims.
            # Assumes every shard along a split dim has the same size as this
            # local one (uniform split) — TODO confirm uneven splits can't occur.
            row_offset = row_split_id * local_input.shape[ndim - 2]
            col_offset = col_split_id * local_input.shape[ndim - 1]
            # For a tril-style mask (presumably, per the file name), an element
            # is kept when global_col - global_row <= diagonal. The local
            # element (i, j) sits at global (i + row_offset, j + col_offset), so
            # the equivalent local diagonal is diagonal + row_offset - col_offset.
            return func(local_input, cur_diagonal + row_offset - col_offset)

        return _expand_impl