Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/shard/ops/parallel_activation_with_axis.py 100%  
hyper_parallel/core/shard/ops/parallel_concat.py 96.4% 100
hyper_parallel/core/shard/ops/parallel_conv3d.py 67.7% 80,156,161,208-210,213-214,216-218,220-227,231-233,235,238-239,243-244,247,250,252
hyper_parallel/core/shard/ops/parallel_gather.py 80.0% 257,259,298-300,331,339,352,490,500,544,550,559,564,574,577
hyper_parallel/core/shard/ops/parallel_pad.py 87.5% 73,79,87
hyper_parallel/core/shard/ops/parallel_repeat.py 100%  
hyper_parallel/core/shard/ops/parallel_concat.py
 96
 97
 98
 99
100
101
102
103
104
                    f"Expected layout: {base_layout}, Mismatched layout: {layout}"
                )

        if not isinstance(dim, int):
            raise ValueError(
                f"For {self.op_name}, dimension should be int, but got {type(dim)}"
            )

        ndim = len(base_layout.alias_tensor_map)
hyper_parallel/core/shard/ops/parallel_conv3d.py
76
77
78
79
80
81
82
83
                for axis_name in axes:
                    dev_num *= w_layout.mesh.get_device_num_along_axis(axis_name)

                if groups % dev_num != 0:
                    raise ValueError(
                        f"For {self.op_name}, groups ({groups}) "
                        f"must be divisible by tp_size ({dev_num})."
                    )
152
153
154
155
156
157
158
159
160
161
162
163
164
        if b_layout is not None:
            self._check_partial_inputs([b_layout])

        if in_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError(
                f"For {self.op_name}, input and weight must have the same mesh_shape, "
                f"but got input: {in_layout.mesh_shape} and weight: {w_layout.mesh_shape}"
            )
        if b_layout is not None and b_layout.mesh_shape != in_layout.mesh_shape:
            raise ValueError(
                f"For {self.op_name}, bias and input must have the same mesh_shape, "
                f"but got bias: {b_layout.mesh_shape} and input: {in_layout.mesh_shape}"
            )
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
        """
        Get expand implementation for the operator.
        Intercepts the execution to handle Grouped Convolution with Column Parallelism.
        """
        w_layout = cache_values[1]
        w_map = w_layout.alias_tensor_map
        w_map_0 = w_map[0]

        # If Weight is NOT sharded on C_out (dim=0), native conv3d works fine.
        if w_map_0 == "None":
            return None

        parsed_groups = cache_values[6]
        if parsed_groups == 1:
            return None

        mesh = w_layout.mesh
        axes = w_map_0 if isinstance(w_map_0, tuple) else (w_map_0,)
        dev_num = 1
        local_rank = 0
        for axis_name in axes:
            axis_size = mesh.get_device_num_along_axis(axis_name)
            dev_num *= axis_size
            local_rank = local_rank * axis_size + mesh.get_local_rank(axis_name)

        # Pre-calculate local groups and group boundaries for the current device ahead of time.
        # This hoisting optimization avoids redundant calculations during every forward pass.
        local_groups = parsed_groups // dev_num
        start_group = local_rank * local_groups
        end_group = start_group + local_groups

        def distributed_conv3d_impl(input_tensor, weight_tensor, bias=None, stride=1, padding=0, dilation=1, groups=1):
            # --- Handling Groups > 1 with Column Parallelism ---
            # Calculate the input channel chunk size
            c_in = input_tensor.shape[1]
            c_in_per_group = c_in // groups

            # Map the pre-calculated groups to the actual input channels
            # Uses start_group and end_group captured from the outer scope
            start_channel = start_group * c_in_per_group
            end_channel = end_group * c_in_per_group

            # Slice the replicated input to match the local groups
            sliced_input = input_tensor[:, start_channel:end_channel, ...]

            # Execute native conv3d with the sliced input and adjusted local groups
            return func(sliced_input, weight_tensor, bias, stride, padding, dilation, local_groups)

        return distributed_conv3d_impl
hyper_parallel/core/shard/ops/parallel_gather.py
253
254
255
256
257
258
259
260
261
262
263

        input_layout, index_layout, dim = cache_values[0], cache_values[1], cache_values[2]
        # Validate layouts exist
        if input_layout is None or not hasattr(input_layout, "tensor_map"):
            raise ValueError(f"For {self.op_name}, input layout cannot be None")
        if index_layout is None or not hasattr(index_layout, "tensor_map"):
            raise ValueError(f"For {self.op_name}, index layout cannot be None")
        input_tensor_map = input_layout.alias_tensor_map
        index_tensor_map = index_layout.alias_tensor_map
        # Validate same rank
        if len(input_tensor_map) != len(index_tensor_map):
294
295
296
297
298
299
300
301
302
303
304
            # pylint: disable=protected-access
            # Inherit current partial state from index layout
            output_layout._partial = list(index_layout.partial)
            if isinstance(dim_axis_name, tuple):
                for axis_name in dim_axis_name:
                    if axis_name != "None":
                        output_layout.set_partial_by_dev_axis(axis_name, 'sum')
            else:
                output_layout.set_partial_by_dev_axis(dim_axis_name, 'sum')
        # pylint: disable=protected-access
        # Rebuild readable alias tensor map
327
328
329
330
331
332
333
334
335
        """
        input_layout = cache_values[0]
        dim = cache_values[2]
        if dim < 0:
            dim += len(input_layout.tensor_map)
        input_alias_map = input_layout.alias_tensor_map
        # Check if dim axis is sharded (enhanced MP)
        if input_alias_map[dim] == "None": # native sharding, no need for custom implementation
            return None
335
336
337
338
339
340
341
342
343
            return None

        dim_axis_name = input_alias_map[dim]
        if isinstance(dim_axis_name, tuple):
            dim_axis_name = next(axis for axis in dim_axis_name if axis != "None")

        def distributed_gatherd_impl(*args, **kwargs):
            """
            Distributed GatherD implementation for sharded dim axis.
348
349
350
351
352
353
354
355
356
            input_tensor = args[0]
            index_tensor = args[2]
            # Calculate local partition offset for the dim axis
            mesh = input_layout.mesh
            mesh_dim_idx = input_layout.alias_name.index(dim_axis_name)
            # Get the coordinate of current rank along the mesh dimension
            dim_coord = mesh.get_local_rank(mesh_dim_idx)
            # Calculate the size of input tensor's dim dimension per partition
            input_dim_size = input_tensor.shape[dim]
486
487
488
489
490
491
492
493
494
                )

        # For GatherNd: input_layout can be None (treated as fully replicated), but indices_layout must exist.
        if indices_layout is None or not hasattr(indices_layout, "alias_tensor_map"):
            raise ValueError(f"For {self.op_name}, indices layout cannot be None")

        return input_layout, indices_layout

    def _validate_tensor_maps(self, input_layout, indices_layout, k):
496
497
498
499
500
501
502
503
504
        indices_tensor_map = indices_layout.alias_tensor_map

        # Validate: indices tensor_map must exist and last dimension cannot be split.
        if not indices_tensor_map:
            raise ValueError(f"For {self.op_name}, indices tensor_map cannot be empty")

        last_axis = indices_tensor_map[-1]
        if not self._is_none_axis(last_axis):
            raise ValueError(
540
541
542
543
544
545
546
547

        input_shape = input_shapes[0]
        indices_shape = input_shapes[1]
        if input_shape is None or indices_shape is None:
            raise ValueError(f"For {self.op_name}, input_shapes contains None: {input_shapes}")

        input_shape = self._normalize_shape(input_shape, "input")
        indices_shape = self._normalize_shape(indices_shape, "indices")
546
547
548
549
550
551
552
553
554
        input_shape = self._normalize_shape(input_shape, "input")
        indices_shape = self._normalize_shape(indices_shape, "indices")

        if len(indices_shape) < 1:
            raise ValueError(f"For {self.op_name}, indices shape invalid: {indices_shape}")

        return input_shape, indices_shape

    def _normalize_shape(self, shape, name):
555
556
557
558
559
560
561
562
563
564
565
566
567
568
        """Normalize shape-like object to tuple of int."""
        try:
            norm = tuple(shape)
        except TypeError as err:
            raise ValueError(f"For {self.op_name}, {name} shape is not iterable: {shape}") from err

        try:
            norm = tuple(int(dim) for dim in norm)
        except (TypeError, ValueError) as err:
            raise ValueError(f"For {self.op_name}, {name} shape contains non-integer dims: {norm}") from err

        return norm

    def _get_k_and_trailing_rank(self, input_shape, indices_shape):
570
571
572
573
574
575
576
577
578
579
580
581
        k = indices_shape[-1]
        try:
            k = int(k)
        except (TypeError, ValueError) as err:
            raise ValueError(f"For {self.op_name}, indices last dim (K) is invalid: {k}") from err

        if k <= 0:
            raise ValueError(f"For {self.op_name}, indices last dim (K) must be positive, but got {k}")

        trail_rank = len(input_shape) - k
        if trail_rank < 0:
            raise ValueError(
hyper_parallel/core/shard/ops/parallel_pad.py
69
70
71
72
73
74
75
76
77
            ValueError: If input has Partial status, pad is invalid, or padding is
                attempted on a sharded dimension.
        """
        if len(cache_values) != 2:
            raise ValueError(
                f"For {self.op_name}, cache_values length should be 2, but got {len(cache_values)}"
            )

        input_layout, pad = cache_values[0], cache_values[1]
75
76
77
78
79
80
81
82
83
            )

        input_layout, pad = cache_values[0], cache_values[1]
        if input_layout is None:
            raise ValueError(f"For {self.op_name}, pad requires a valid input tensor layout.")

        self._check_partial_inputs([input_layout])

        tensor_map = input_layout.alias_tensor_map
83
84
85
86
87
88
89
90
91
        tensor_map = input_layout.alias_tensor_map
        ndim = len(tensor_map)

        if not isinstance(pad, (tuple, list)):
            raise ValueError(
                f"For {self.op_name}, expected pad tuple or list, but got {type(pad)}"
            )

        pad_len = len(pad)