if col_map != -1:
col_coord = mesh.get_local_rank(len(mesh.mesh_shape) - 1 - col_map)
def _tril_expand_impl(*args, **_kwargs):
    """Call ``func`` on the local shard with its diagonal re-based to local coordinates.

    ``args[0]`` is the local input and ``args[1]`` is the diagonal argument
    (``IndexError`` if fewer than two positionals are given). Keyword
    arguments are accepted but ignored.

    When a matrix axis is mapped onto the mesh (``row_map`` / ``col_map`` not
    ``-1``), the shard's starting index along that axis is ``coord * local
    size``. This assumes every shard has the same local size — TODO confirm
    with the caller. The diagonal is then shifted by ``row_start - col_start``.
    Assuming ``func`` is a ``tril``-style op (as the name suggests), local
    element ``(i, j)`` is global element ``(i + row_start, j + col_start)``.
    So a global diagonal ``k`` becomes the local diagonal
    ``k + row_start - col_start``.
    """
    shard = args[0]
    diagonal = args[1]

    # Global index of this shard's first row; 0 when rows are not split.
    # NOTE(review): assumes ndim is the rank of the local input, so
    # ndim - 2 / ndim - 1 are its last two axes — confirm in the factory.
    row_start = 0
    if row_map != -1:
        row_start = row_coord * shard.shape[ndim - 2]

    # Global index of this shard's first column; 0 when columns are not split.
    col_start = 0
    if col_map != -1:
        col_start = col_coord * shard.shape[ndim - 1]

    shifted_diagonal = diagonal + row_start - col_start
    return func(shard, shifted_diagonal)
return _tril_expand_impl