Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/context_parallel/__init__.py 100%  
hyper_parallel/core/context_parallel/async_context_parallel.py 63.3% 50-51,72,94-98,105,120-124,233-235,237,248-252,260,272,275,294,307,316,325,363-364,368,377-378,381,387,392,398-399,421,455,513,651-654
hyper_parallel/core/context_parallel/async_dsa_context_parallel.py 85.7% 61-63,83-84,92,98,103,106,110-111,116,122,131
hyper_parallel/core/context_parallel/dsa_context_parallel.py 98.1% 507
hyper_parallel/platform/mindspore/platform.py 21.2% 83,88-92,97-104,121-123,127-128,133-135,139-140,155-160,167-172,690-696,701-704,708-709,715-719,1233-1237,1241,1260-1264,1268,1277
hyper_parallel/platform/platform.py 66.7% 592,608,657
hyper_parallel/platform/torch/platform.py 85.2% 75-76,84-88,153
hyper_parallel/core/context_parallel/async_context_parallel.py
46
47
48
49
50
51
52
53
54
55
# ---------------------------------------------------------------------------

def _detach_if_available(tensor: Tensor) -> Tensor:
    """Detach the communication buffer when the backend tensor exposes ``detach``."""
    detach = getattr(tensor, "detach", None)
    return detach() if detach is not None else tensor


def _launch_async_a2a_seq_to_head(
    tensor: Tensor,
68
69
70
71
72
73
74
75
76
        shape[:head_dim] + [world_size, num_heads // world_size] + shape[head_dim + 1:]
    ).permute(
        [head_dim] + list(range(head_dim)) + list(range(head_dim + 1, ndim))
    ).contiguous()
    out_perm, work = platform.all_to_all_single(_detach_if_available(x_perm), list(x_perm.shape), group, async_op=True)
    return work, out_perm


def _a2a_reconstruct(out_perm: Tensor, concat_dim: int) -> Tensor:
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102


def _move_dim_to_front(tensor: Tensor, dim: int) -> Tensor:
    """Return a contiguous tensor whose axis 0 is ``tensor``'s axis ``dim``.

    Used before all-gather/reduce-scatter. The remaining axes keep their
    relative order; :func:`_move_dim_from_front` undoes the move. ``dim`` may
    be negative.
    """
    axis = _normalize_dim(dim, tensor.dim())
    if axis != 0:
        rest = [i for i in range(tensor.dim()) if i != axis]
        return tensor.permute([axis, *rest]).contiguous()
    return tensor.contiguous()


def _move_dim_from_front(tensor: Tensor, dim: int) -> Tensor:
    """Inverse of :func:`_move_dim_to_front`."""
101
102
103
104
105
106
107
108
109
def _move_dim_from_front(tensor: Tensor, dim: int) -> Tensor:
    """Inverse of :func:`_move_dim_to_front`."""
    dim = _normalize_dim(dim, tensor.dim())
    if dim == 0:
        return tensor.contiguous()
    perm = [dim] + [i for i in range(tensor.dim()) if i != dim]
    inverse = [0] * len(perm)
    for idx, value in enumerate(perm):
        inverse[value] = idx
116
117
118
119
120
121
122
123
124
125
126
127
128
    world_size: int,
    gather_dim: int,
) -> tuple:
    """Launch async all-gather along ``gather_dim``."""
    x_perm = _move_dim_to_front(tensor.contiguous(), gather_dim)
    output_shape = list(x_perm.shape)
    output_shape[0] *= world_size
    out_perm, work = platform.all_gather_single(_detach_if_available(x_perm), output_shape, group, async_op=True)
    return work, out_perm


def _allgather_reconstruct(out_perm: Tensor, gather_dim: int) -> Tensor:
    """Move the leading communication buffer dimension back to ``gather_dim``.
229
230
231
232
233
234
235
236
237
238
239
240
241
            )
        co = cp_size // ds

        if ds == 1:
            if self.load_balance:
                return super().apply(module, device_mesh)
            return self._apply_colossal_async(module, device_mesh, cp_size, k_proj, v_proj)

        return self._apply_a2a_async(module, device_mesh, ds, co, q_proj, k_proj, v_proj)

    def _apply_colossal_async(
        self,
        module: Module,
244
245
246
247
248
249
250
251
252
253
254
255
256
        k_proj: Module,
        v_proj: Module,
    ) -> Module:
        """Register Pure Colossal async K/V AllGather hooks."""
        co_submesh = _ensure_1d(device_mesh)
        group = co_submesh.get_group()
        fwd_ag_slots = {"k": None, "v": None}
        bwd_ag_slots = {"k": [], "v": []}
        self._register_ag_proj_hooks(
            k_proj,
            v_proj,
            group=group,
            world_size=cp_size,
256
257
258
259
260
261
262
263
264
            world_size=cp_size,
            fwd_slots=fwd_ag_slots,
            bwd_slots=bwd_ag_slots,
        )
        platform.register_forward_pre_hook(
            module,
            partial(
                self._attn_pre_hook_colossal,
                co_submesh=co_submesh,
268
269
270
271
272
273
274
275
276
277
278
279
                bwd_slots=bwd_ag_slots,
            ),
            with_kwargs=True,
        )
        module.register_forward_hook(
            partial(self._post_hook_colossal, co_submesh=co_submesh)
        )
        return module

    def _apply_a2a_async(  # pylint: disable=too-many-arguments
        self,
        module: Module,
290
291
292
293
294
295
296
297
298

        if co == 1:
            ds_submesh = _ensure_1d(device_mesh)
            group = ds_submesh.get_group()
            pre_hook = partial(
                self._attn_pre_hook_ulysses,
                group=group,
                world_size=ds,
                fwd_slots=fwd_slots,
303
304
305
306
307
308
309
310
311
            dim_names = two_d_mesh.mesh_dim_names
            assert dim_names is not None, "2-D mesh must have mesh_dim_names (guaranteed by _build_2d_mesh)"
            ds_submesh = two_d_mesh[dim_names[1]]
            group = ds_submesh.get_group()
            pre_hook = partial(
                self._attn_pre_hook_hybrid,
                group=group,
                world_size=ds,
                two_d_mesh=two_d_mesh,
312
313
314
315
316
317
318
319
320
                fwd_slots=fwd_slots,
                bwd_slots=bwd_slots,
            )

        self._register_proj_hooks(
            q_proj,
            k_proj,
            v_proj,
            group=group,
321
322
323
324
325
326
327
328
329
            world_size=ds,
            fwd_slots=fwd_slots,
            bwd_slots=bwd_slots,
        )
        platform.register_forward_pre_hook(
            module,
            pre_hook,
            with_kwargs=True,
        )
359
360
361
362
363
364
365
366
367
368
369
370
371
        return output

    def _register_ag_proj_hooks(self, k_proj, v_proj, group, world_size, fwd_slots, bwd_slots):
        """Attach the async all-gather hooks to the K and V projection modules.

        For each of ``"k"`` and ``"v"`` two hooks are registered:
          * a forward hook (``_proj_ag_post_hook``) that launches the all-gather
            of the projection output and records it in ``fwd_slots[key]``;
          * a full backward pre-hook (``_proj_ag_bwd_pre_hook``) bound to
            ``bwd_slots[key]``, which waits on the deferred backward
            reduce-scatter before the projection's own backward runs.
        """
        projections = {"k": k_proj, "v": v_proj}
        for slot_key, projection in projections.items():
            forward_hook = partial(
                self._proj_ag_post_hook,
                key=slot_key,
                group=group,
                world_size=world_size,
                fwd_slots=fwd_slots,
            )
            projection.register_forward_hook(forward_hook)
            backward_hook = partial(self._proj_ag_bwd_pre_hook, bwd_slot=bwd_slots[slot_key])
            platform.register_full_backward_pre_hook(projection, backward_hook)
373
374
375
376
377
378
379
380
381
382
383
384
385
    def _proj_ag_post_hook(  # pylint: disable=unused-argument,too-many-arguments
        self, module, inputs, output, key, group, world_size, fwd_slots
    ):
        """Forward hook: start the async K/V all-gather along ``self.seq_dim``.

        A ``DTensor`` output is unwrapped to its local tensor first. The
        ``(work, out_perm)`` pair from ``_launch_async_allgather_seq`` is stored
        in ``fwd_slots[key]`` so the attention pre-hook can wait on it later.
        The projection output itself is returned unchanged.
        """
        if isinstance(output, DTensor):
            local_output = output.to_local()
        else:
            local_output = output
        fwd_slots[key] = _launch_async_allgather_seq(local_output, group, world_size, self.seq_dim)
        return output

    def _get_qkv_value(self, args, kwargs, qkv_pos: int):
        """Return Q/K/V value from positional args or configured kwargs."""
        idx = self.qkv_indices[qkv_pos]
383
384
385
386
387
388
389
390
391
392
393
394
395
396
    def _get_qkv_value(self, args, kwargs, qkv_pos: int):
        """Return Q/K/V value from positional args or configured kwargs."""
        idx = self.qkv_indices[qkv_pos]
        if idx < len(args):
            return args[idx]
        if qkv_pos < len(self.qkv_kwarg_names):
            name = self.qkv_kwarg_names[qkv_pos]
            if name in kwargs:
                return kwargs[name]
        return None

    def _set_qkv_value(self, args, kwargs, qkv_pos: int, value):
        """Set Q/K/V value in positional args or configured kwargs."""
        idx = self.qkv_indices[qkv_pos]
394
395
396
397
398
399
400
401
402
403
    def _set_qkv_value(self, args, kwargs, qkv_pos: int, value):
        """Set Q/K/V value in positional args or configured kwargs."""
        idx = self.qkv_indices[qkv_pos]
        if idx < len(args):
            args[idx] = value
            return
        if qkv_pos < len(self.qkv_kwarg_names):
            name = self.qkv_kwarg_names[qkv_pos]
            if name in kwargs:
                kwargs[name] = value
417
418
419
420
421
422
423
424
425
        )

    def _wait_allgather(self, tensor, group, world_size, work, out_perm, bwd_slot=None):
        """Wait for pre-launched AllGather and return gathered tensor."""
        return platform.differentiable_async_allgather_wait(
            tensor,
            work,
            out_perm,
            group,
451
452
453
454
455
456
457
458
459
        transforms = (transform_q, transform_k, transform_v)
        for pos, transform in enumerate(transforms):
            value = self._get_qkv_value(new_args, new_kwargs, pos)
            if value is None:
                continue
            self._set_qkv_value(
                new_args,
                new_kwargs,
                pos,
509
510
511
512
513
514
515
516
517

        for pos, key in enumerate(("k", "v"), start=1):
            value = self._get_qkv_value(new_args, new_kwargs, pos)
            if value is None:
                continue
            local = _to_local(value)
            work, out_perm = fwd_slots[key]
            fwd_slots[key] = None
            gathered = self._wait_allgather(
647
648
649
650
651
652
653
654
        return (d_seq,) + grad_output[1:] if isinstance(grad_output, tuple) else (d_seq,)

    def _proj_ag_bwd_pre_hook(self, module, grad_output, bwd_slot):  # pylint: disable=unused-argument
        """Wait backward reduce-scatter just before K/V projection GEMM."""
        work, out_perm, gather_dim = bwd_slot.pop()
        work.wait()
        d_local = _allgather_reconstruct(out_perm, gather_dim)
        return (d_local,) + grad_output[1:] if isinstance(grad_output, tuple) else (d_local,)
hyper_parallel/core/context_parallel/async_dsa_context_parallel.py
57
58
59
60
61
62
63
64
65
66
67
    @staticmethod
    def _extract_tensor_output(output: Any) -> Any:
        """Return the single tensor carried by a module output, or ``None``.

        Accepts a bare tensor/DTensor, or a one-element tuple/list wrapping one.
        Any other structure yields ``None``.
        """
        if _is_tensor_or_dtensor(output):
            return output
        is_singleton_seq = isinstance(output, (tuple, list)) and len(output) == 1
        if is_singleton_seq and _is_tensor_or_dtensor(output[0]):
            return output[0]
        return None

    @staticmethod
    def _local_tensor(value: Any) -> Any:
        """Unwrap a ``DTensor`` to its local tensor; return any other value as-is."""
        if isinstance(value, DTensor):
            return value.to_local()
        return value
79
80
81
82
83
84
85
86
87
                self.launch(slot_name, tensor)
            return output

        def _backward_pre_hook(hook_module, grad_output):
            del hook_module
            return self._producer_bwd_pre_hook(grad_output, bwd_slot)

        module.register_forward_hook(_post_hook)
        platform.register_full_backward_pre_hook(module, _backward_pre_hook)
88
89
90
91
92
93
94
95
96

    def _producer_bwd_pre_hook(self, grad_output: Any, bwd_slot: list) -> Any:
        """Wait deferred reduce-scatter before gradients cross the producer boundary."""
        if not bwd_slot:
            return grad_output
        work, out_perm, gather_dim = bwd_slot.pop()
        work.wait()
        d_local = _allgather_reconstruct(out_perm, gather_dim)
        if isinstance(grad_output, tuple):
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
        work.wait()
        d_local = _allgather_reconstruct(out_perm, gather_dim)
        if isinstance(grad_output, tuple):
            return (d_local,) + grad_output[1:]
        return (d_local,)

    def launch(self, slot_name: str, value: Any) -> None:
        """Start an all-gather of ``value`` and queue its handle under ``slot_name``.

        Values that are not tensors/DTensors are ignored, as are DTensors whose
        local part is not a platform tensor. When ``world_size <= 1`` no
        communication is issued. Instead, a ``(local, None, None)`` placeholder
        is queued so that ``wait`` still pops exactly one entry per launch.
        """
        if not _is_tensor_or_dtensor(value):
            return
        local = self._local_tensor(value)
        if not platform.is_tensor(local):
            return
        if self.world_size > 1:
            work, out_perm = _launch_async_allgather_seq(local, self.group, self.world_size, self.seq_dim)
            entry = (local, work, out_perm)
        else:
            entry = (local, None, None)
        self._slots.setdefault(slot_name, []).append(entry)

    def wait(self, slot_name: str, value: Any) -> Any:
        """Wait on a pre-launched gather, or fall back to consumer-local launch."""
        if not _is_tensor_or_dtensor(value):
            return value
        slot = self._slots.get(slot_name)
        if slot:
            local, work, out_perm = slot.pop(0)
            if work is None:
118
119
120
121
122
123
124
125
126
        if slot:
            local, work, out_perm = slot.pop(0)
            if work is None:
                return DTensor.from_local(local, self.device_mesh, (Replicate(),))
            gathered = platform.differentiable_async_allgather_wait(
                local,
                work,
                out_perm,
                self.group,
127
128
129
130
131
132
133
134
135
                self.world_size,
                self.seq_dim,
                self._bwd_slots.setdefault(slot_name, []),
            )
            return DTensor.from_local(gathered, self.device_mesh, (Replicate(),))
        return _to_sequence_replicate(value, self.device_mesh, self.seq_dim)


class AsyncDSAIndexerContextParallel(DSAIndexerContextParallel):
hyper_parallel/core/context_parallel/dsa_context_parallel.py
503
504
505
506
507
508
509
510
511
        if self.key_indexer_index is not None and self.key_indexer_index < len(args):
            return self._local_shape(args[self.key_indexer_index])
        if self.key_indexer_kwarg_name and self.key_indexer_kwarg_name in kwargs:
            return self._local_shape(kwargs[self.key_indexer_kwarg_name])
        return None

    @staticmethod
    def _get_local_idx(cp_mesh: DeviceMesh) -> int:
        """Return current rank's index in the CP mesh rank list."""
hyper_parallel/platform/mindspore/platform.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108


def _normalize_dim(dim: int, ndim: int) -> int:
    """Normalize a possibly-negative dimension index."""
    return dim + ndim if dim < 0 else dim


def _move_dim_to_front(tensor: Tensor, dim: int) -> Tensor:
    """Return a contiguous tensor whose leading axis is ``tensor``'s axis ``dim``.

    The other axes keep their original relative order. ``dim`` may be
    negative. :func:`_move_dim_from_front` is the inverse.
    """
    ndim = tensor.dim()
    front = _normalize_dim(dim, ndim)
    if front == 0:
        return tensor.contiguous()
    order = [front]
    order.extend(axis for axis in range(ndim) if axis != front)
    return tensor.permute(order).contiguous()


def _move_dim_from_front(tensor: Tensor, dim: int) -> Tensor:
    """Undo :func:`_move_dim_to_front`: send axis 0 back to position ``dim``.

    ``dim`` is interpreted relative to ``tensor.dim()`` and may be negative.
    The result is always contiguous.
    """
    ndim = tensor.dim()
    target = _normalize_dim(dim, ndim)
    if target == 0:
        return tensor.contiguous()
    forward_order = [target] + [axis for axis in range(ndim) if axis != target]
    # Invert the forward permutation: the axis that went to slot ``pos`` comes back from it.
    inverse_order = [0] * len(forward_order)
    for pos, axis in enumerate(forward_order):
        inverse_order[axis] = pos
    return tensor.permute(inverse_order).contiguous()


def _normalize_all_to_all_single_result(result, output: Tensor) -> tuple[Tensor, object]:
    """Normalize MindSpore all_to_all_single return values to ``(output, handle)``."""
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144


def _normalize_all_gather_single_result(result, output: Tensor) -> tuple[Tensor, object]:
    """Normalize MindSpore all_gather_into_tensor return values to ``(output, handle)``."""
    if isinstance(result, tuple):
        if len(result) != 2:
            raise ValueError(
                "mindspore all_gather_into_tensor returned an unexpected tuple "
                f"with length {len(result)}"
            )
        return result
    return output, result


def _normalize_reduce_scatter_single_result(result, output: Tensor) -> tuple[Tensor, object]:
    """Normalize MindSpore reduce_scatter_tensor return values to ``(output, handle)``."""
    if isinstance(result, tuple):
        if len(result) != 2:
            raise ValueError(
                "mindspore reduce_scatter_tensor returned an unexpected tuple "
                f"with length {len(result)}"
            )
        return result
    return output, result


def _mindspore_all_to_all_single(input_tensor: Tensor, output_shape, group, async_op=False) -> tuple[Tensor, object]:
    """Launch MindSpore all_to_all_single and normalize return values."""
151
152
153
154
155
156
157
158
159
160
161
162
163
164


def _mindspore_all_gather_single(input_tensor: Tensor, output_shape, group, async_op=False) -> tuple[Tensor, object]:
    """Run ``all_gather_into_tensor`` into a freshly allocated buffer.

    Args:
        input_tensor: local tensor to gather.
        output_shape: shape of the gathered result. The buffer is allocated
            with ``mint.empty`` using ``input_tensor.dtype``.
        group: communication group forwarded to ``ops_comm``.
        async_op: when truthy, the collective's handle is returned; otherwise
            the handle slot is ``None``.

    Returns:
        ``(gathered_tensor, handle_or_None)``.
    """
    buffer = mint.empty(tuple(output_shape), dtype=input_tensor.dtype)
    gathered, handle = _normalize_all_gather_single_result(
        ops_comm.all_gather_into_tensor(buffer, input_tensor, group=group, async_op=async_op),
        buffer,
    )
    return gathered, (handle if async_op else None)


def _mindspore_reduce_scatter_single(
        input_tensor: Tensor, output_shape, group, async_op=False
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def _mindspore_reduce_scatter_single(
        input_tensor: Tensor, output_shape, group, async_op=False
) -> tuple[Tensor, object]:
    """Run ``reduce_scatter_tensor`` into a freshly allocated buffer.

    Args:
        input_tensor: full tensor to reduce and scatter.
        output_shape: shape of this rank's shard. The buffer is allocated with
            ``mint.empty`` using ``input_tensor.dtype``.
        group: communication group forwarded to ``ops_comm``.
        async_op: when truthy, the collective's handle is returned; otherwise
            the handle slot is ``None``.

    Returns:
        ``(local_shard, handle_or_None)``.
    """
    buffer = mint.empty(tuple(output_shape), dtype=input_tensor.dtype)
    shard, handle = _normalize_reduce_scatter_single_result(
        ops_comm.reduce_scatter_tensor(buffer, input_tensor, group=group, async_op=async_op),
        buffer,
    )
    return shard, (handle if async_op else None)


class AsyncCollectiveTensor(Tensor):
    """MindSpore Tensor subclass that defers ``CommHandle.wait()`` to
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713

    @staticmethod
    def forward(ctx, x, work, out_perm, group, world_size, gather_dim, handle_box):  # pylint: disable=arguments-differ
        """Block on the already-launched all-gather and return the gathered tensor.

        ``out_perm`` is the communication buffer with the gathered axis in front.
        It is moved back to ``gather_dim``. The values the backward pass reads
        are stashed on ``ctx``: the group, world size, axis, the optional
        deferred-handle list and ``x``'s shape.
        """
        ctx.x_shape = tuple(x.shape)
        ctx.handle_box = handle_box
        ctx.gather_dim = gather_dim
        ctx.world_size = world_size
        ctx.group = group
        work.wait()
        return _move_dim_from_front(out_perm, gather_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Launch reverse reduce-scatter for the all-gather."""
        grad_perm = _move_dim_to_front(grad_output.contiguous(), ctx.gather_dim)
        output_shape = list(grad_perm.shape)
        if output_shape[0] % ctx.world_size != 0:
            raise ValueError(
                "all_gather backward expected gathered dimension to be divisible by world_size, "
                f"got {output_shape[0]} and {ctx.world_size}."
            )
        output_shape[0] //= ctx.world_size
        output, work = _mindspore_reduce_scatter_single(
            grad_perm,
            output_shape,
            ctx.group,
            async_op=True,
711
712
713
714
715
716
717
718
719
720
721
722
723
            output_shape,
            ctx.group,
            async_op=True,
        )
        if ctx.handle_box is not None:
            ctx.handle_box.append((work, output, ctx.gather_dim))
            return mint.zeros(ctx.x_shape, dtype=grad_output.dtype), None, None, None, None, None, None
        work.wait()
        return _move_dim_from_front(output, ctx.gather_dim), None, None, None, None, None, None


class MindSporePlatform(Platform):
    """MindSpore platform api"""
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
        return group_name

    @staticmethod
    def all_gather_into_tensor(data, group_info, async_op=False):
        """All-gather ``data`` along dim 0.

        ``group_info`` is either a group name, whose size is looked up with
        ``get_group_size``, or an object exposing ``group_name`` and
        ``rank_size``. Returns ``(output, handle_or_None)``.
        """
        if isinstance(group_info, str):
            group_name, rank_size = group_info, get_group_size(group_info)
        else:
            group_name, rank_size = group_info.group_name, group_info.rank_size
        gathered_shape = list(data.shape)
        gathered_shape[0] = gathered_shape[0] * rank_size
        return _mindspore_all_gather_single(data, gathered_shape, group_name, async_op=async_op)

    @staticmethod
    def all_gather_single(input_tensor, output_shape, group, async_op=False):
        """Platform hook: gather into ``output_shape``; returns ``(output, handle_or_None)``."""
        gather = _mindspore_all_gather_single
        return gather(input_tensor, output_shape, group, async_op)

    @staticmethod
    def all_reduce(data, group_info, async_op=False):
        if isinstance(group_info, str):
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
        return data

    @staticmethod
    def reduce_scatter_tensor(data, group_info, async_op=False):
        """Reduce-scatter ``data`` along dim 0 across the group in ``group_info``.

        ``group_info`` is either a group name, whose size is looked up with
        ``get_group_size``, or an object exposing ``group_name`` and
        ``rank_size``.

        Returns:
            ``(local_shard, handle_or_None)`` from ``_mindspore_reduce_scatter_single``.

        Raises:
            ValueError: if ``data.shape[0]`` is not divisible by the group size.
                Without this check, floor division would request a truncated
                output buffer that does not match the input.
        """
        group_name = group_info if isinstance(group_info, str) else group_info.group_name
        rank_size = get_group_size(group_name) if isinstance(group_info, str) else group_info.rank_size
        output_shape = list(data.shape)
        # Same validation as the all-gather backward's reduce-scatter.
        if output_shape[0] % rank_size != 0:
            raise ValueError(
                "reduce_scatter_tensor expected dim 0 to be divisible by the group size, "
                f"got {output_shape[0]} and {rank_size}."
            )
        output_shape[0] //= rank_size
        return _mindspore_reduce_scatter_single(data, output_shape, group_name, async_op=async_op)

    @staticmethod
    def reduce_scatter_single(input_tensor, output_shape, group, async_op=False):
        """Platform hook: reduce-scatter into ``output_shape``; returns ``(shard, handle_or_None)``."""
        scatter = _mindspore_reduce_scatter_single
        return scatter(input_tensor, output_shape, group, async_op)

    @staticmethod
    def all_to_all_single(input_tensor, output_shape, group, async_op=False):
        """Platform hook: all-to-all into ``output_shape``; returns ``(output, handle_or_None)``."""
        exchange = _mindspore_all_to_all_single
        return exchange(input_tensor, output_shape, group, async_op)
1273
1274
1275
1276
1277
1278
1279
1280
1281

    @staticmethod
    def differentiable_async_allgather_wait(x, work, out_perm, group, world_size, gather_dim,
                                            handle_box=None):
        """Wait on a pre-launched all-gather through ``_MSAsyncAllGatherFunction``.

        ``x`` is the pre-gather local tensor that links the result into the
        autograd graph. When ``handle_box`` is a list, the backward pass appends
        its reduce-scatter handle to it instead of waiting immediately.
        """
        wait_args = (x, work, out_perm, group, world_size, gather_dim, handle_box)
        return _MSAsyncAllGatherFunction.apply(*wait_args)

    @staticmethod
hyper_parallel/platform/platform.py
588
589
590
591
592
593
594
595
596
        Returns:
            Tuple ``(output, work)`` where *output* is the gathered tensor and
            *work* is the async handle (``None`` when ``async_op=False``).
        """
        raise NotImplementedError("Platform subclasses must implement all_gather_single")

    @staticmethod
    def reduce_scatter_single(input_tensor, output_shape, group, async_op=False):
        """Reduce-scatter a tensor with optional async execution.
604
605
606
607
608
609
610
611
612
        Returns:
            Tuple ``(output, work)`` where *output* is the local shard and
            *work* is the async handle (``None`` when ``async_op=False``).
        """
        raise NotImplementedError("Platform subclasses must implement reduce_scatter_single")

    @staticmethod
    def all_to_all_single(input_tensor, output_shape, group, async_op=False):
        """All-to-all single collective with optional async execution.
653
654
655
656
657
658
659
660
661

        Returns:
            Gathered tensor connected to the autograd graph through *x*.
        """
        raise NotImplementedError("Platform subclasses must implement differentiable_async_allgather_wait")

    @staticmethod
    def differentiable_async_a2a_wait(x, work, out_perm, group, world_size, concat_dim, split_dim,
                                      handle_box=None):
hyper_parallel/platform/torch/platform.py
71
72
73
74
75
76
77
78
79
80
    """Move ``dim`` to the front while keeping the other dimensions ordered."""
    dim = _normalize_dim(dim, tensor.dim())
    if dim == 0:
        return tensor.contiguous()
    perm = [dim] + [i for i in range(tensor.dim()) if i != dim]
    return tensor.permute(perm).contiguous()


def _move_dim_from_front(tensor: torch.Tensor, dim: int) -> torch.Tensor:
    """Inverse of :func:`_move_dim_to_front`."""
80
81
82
83
84
85
86
87
88
89
90
91
92
    """Inverse of :func:`_move_dim_to_front`."""
    dim = _normalize_dim(dim, tensor.dim())
    if dim == 0:
        return tensor.contiguous()
    perm = [dim] + [i for i in range(tensor.dim()) if i != dim]
    inverse = [0] * len(perm)
    for idx, value in enumerate(perm):
        inverse[value] = idx
    return tensor.permute(inverse).contiguous()


class _TorchAsyncA2AFunction(torch.autograd.Function):
    """Differentiable wrapper for pre-launched async all-to-all.
149
150
151
152
153
154
155
156
157
        """Launch reverse reduce-scatter for the all-gather."""
        grad_perm = _move_dim_to_front(grad_output.contiguous(), ctx.gather_dim)
        output_shape = list(grad_perm.shape)
        if output_shape[0] % ctx.world_size != 0:
            raise ValueError(
                "all_gather backward expected gathered dimension to be divisible by world_size, "
                f"got {output_shape[0]} and {ctx.world_size}."
            )
        output_shape[0] //= ctx.world_size