Diff: origin/master...HEAD, staged and unstaged changes
182 183 184 185 186 187 188 189
mapping = (mapping,) split_id, coef = 0, 1 for dim in reversed(mapping): if dim == -1: continue split_id += dev_id_list[-dim - 1] * coef coef *= mesh_shape[-dim - 1] return split_id
195 196 197 198 199 200 201 202 203 204 205
if row_split_id == 0 and col_split_id == 0: return None def _expand_impl(*args, **_kwargs): local_input = args[0] cur_diagonal = args[1] row_offset = row_split_id * local_input.shape[ndim - 2] col_offset = col_split_id * local_input.shape[ndim - 1] return func(local_input, cur_diagonal + row_offset - col_offset) return _expand_impl