Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/shard/ops/parallel_argmax_with_value_ops.py 96.9% 101
hyper_parallel/core/shard/ops/parallel_argsort.py 100%  
hyper_parallel/core/shard/ops/parallel_atleast_1d.py 42.9% 26,46-47,49,53,55,60,142
hyper_parallel/core/shard/ops/parallel_expand.py 75.7% 26,30,47-52,94,101,107,197,225-230
hyper_parallel/core/shard/ops/parallel_isin.py 100%  
hyper_parallel/core/shard/ops/parallel_masked_scatter.py 61.1% 25,45-47,52-53,58
hyper_parallel/core/shard/ops/parallel_nonzero.py 100%  
hyper_parallel/core/shard/ops/parallel_norm.py 73.0% 30,39,56-61,101,143-144,149,152,159,161-162,193
hyper_parallel/core/shard/ops/parallel_one_hot_ext.py 59.4% 29,46-47,49-51,53-54,56-59,103
hyper_parallel/core/shard/ops/parallel_reduce.py 89.4% 135,141,162,167,379,414,428
hyper_parallel/core/shard/ops/parallel_slice_ext.py 74.1% 26,43-48
hyper_parallel/core/shard/ops/parallel_split.py 47.7% 27,31,35,56-62,93,95,123-129,149-151,153-154,156-157,159-162,167-168,173,190-191,193-194,196,198-201,233,263-269,300,302,330-336,356-358,360-361,363-364,366-369,374-375,380,412-414,416-419
hyper_parallel/core/shard/ops/parallel_topk.py 97.2% 86
hyper_parallel/core/shard/ops/parallel_unbind.py 97.1% 83
hyper_parallel/core/shard/ops/parallel_argmax_with_value_ops.py
 97
 98
 99
100
101
102
103
104
105
        # Use alias_tensor_map to support StridedShard multi-axis mappings.
        alias_map = input_layout.alias_tensor_map
        mapping = alias_map[axis]
        if isinstance(mapping, tuple):
            is_sharded = any(m != "None" for m in mapping)
        else:
            is_sharded = mapping != "None"

        if is_sharded:
hyper_parallel/core/shard/ops/parallel_atleast_1d.py
22
23
24
25
26
27
28
29
30
from .parallel_ops import DistributedOp


def _normalize_atleast_1d_args(*tensors):
    return tensors, {}


class Atleast1DDistributedOp(DistributedOp):
    """Distributed implementation for torch.atleast_1d."""
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_atleast_1d_args(*args, **kwargs)
        tensors = args

        local_args = tuple(
            t.to_local() if hasattr(t, 'to_local') else t
            for t in tensors
        )
        local_kwargs = {}

        cache_values = [
            t.layout if hasattr(t, 'layout') else None
            for t in tensors
        ]

        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layouts for atleast_1d operator.
138
139
140
141
142
        # If there are multiple inputs, return a tuple of Layouts.
        if len(output_layouts) == 1:
            return ((output_layouts[0],), None)

        return (tuple(output_layouts), None)
hyper_parallel/core/shard/ops/parallel_expand.py
22
23
24
25
26
27
28
29
30
31
32
33
34
from .parallel_ops import DistributedOp


def _normalize_expand_args(input_tensor, *sizes):
    return (input_tensor, *sizes), {}


def _normalize_expand_as_args(input_tensor, target_tensor):
    return (input_tensor, target_tensor), {}


class ExpandDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.expand."""
43
44
45
46
47
48
49
50
51
52
53
54
55
56

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_expand_args(*args, **kwargs)
        input_tensor = args[0]
        sizes = tuple(args[1:])
        local_args = (input_tensor.to_local(), *sizes)
        cache_values = [input_tensor.layout, input_tensor.shape, sizes]
        return local_args, {}, cache_values

    @staticmethod
    def _validate_input_layouts(
        cache_values: list,
90
91
92
93
94
95
96
97
            )

        input_shape = cache_values[1] if len(cache_values) > 1 else None
        if not isinstance(input_shape, tuple):
            raise ValueError(
                f"For {op_name}, input_shape should be a tuple, "
                f"but got {type(input_shape)}."
            )
 97
 98
 99
100
101
102
103
104
105
            )

        sizes = cache_values[2] if len(cache_values) > 2 else None
        if sizes is None or len(sizes) < 1:
            raise ValueError(
                f"For {op_name}, sizes should be a non-empty tuple of ints, "
                f"but got {sizes}."
            )
        for i, sz in enumerate(sizes):
103
104
105
106
107
108
109
110
                f"but got {sizes}."
            )
        for i, sz in enumerate(sizes):
            if not isinstance(sz, int):
                raise ValueError(
                    f"For {op_name}, elements in sizes should be int, "
                    f"but got {type(sz)} at position {i}."
                )
193
194
195
196
197
198
199
200
                        f"got mapping {in_alias_map[i]}."
                    )
                output_map.append("None")
            else:
                raise ValueError(
                    f"For {self.op_name}, cannot expand dimension {i} "
                    f"from size {input_size} to {requested_size}."
                )
221
222
223
224
225
226
227
228
229
230
231
232
233
234

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_expand_as_args(*args, **kwargs)
        input_tensor = args[0]
        target_tensor = args[1]
        local_args = (input_tensor.to_local(), target_tensor.to_local())
        cache_values = [input_tensor.layout, input_tensor.shape, target_tensor.shape]
        return local_args, {}, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for expand_as.
hyper_parallel/core/shard/ops/parallel_masked_scatter.py
21
22
23
24
25
26
27
28
29
from .parallel_ops import DistributedOp


def _normalize_masked_scatter_args(input_tensor, mask, source):
    return (input_tensor, mask, source), {}


class MaskedScatterDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.masked_scatter."""
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_masked_scatter_args(*args, **kwargs)
        input_tensor, mask, source = args[0], args[1], args[2]
        local_args = (
            input_tensor.to_local(),
            mask.to_local(),
            source.to_local(),
        )
        local_kwargs = {}
        cache_values = [
            input_tensor.layout,
            mask.layout,
            source.layout,
        ]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for torch.Tensor.masked_scatter.
hyper_parallel/core/shard/ops/parallel_norm.py
26
27
28
29
30
31
32
33
34
    """Normalize RmsNorm args to positional form.

    MindSpore Primitive RmsNorm receives (x, gamma, epsilon) as positional arguments.
    """
    return (x, gamma, epsilon), {}


def _normalize_layernorm_args(input_tensor, normalized_shape, weight=None, bias=None, eps=1e-5):
    """Normalize layer_norm args to positional form.
35
36
37
38
39
40
41
42
43

    torch.nn.functional.layer_norm(input_tensor, normalized_shape, weight=None, bias=None, eps=1e-5)
    has no keyword-only parameters, so everything stays positional.
    """
    return (input_tensor, normalized_shape, weight, bias, eps), {}


class NormDistributedOp(DistributedOp):
    """Distributed implementation for RmsNorm operator."""
52
53
54
55
56
57
58
59
60
61
62
63
64
65

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_rmsnorm_args(*args, **kwargs)
        x, gamma, epsilon = args
        local_args = (x.to_local(), gamma.to_local(), epsilon)
        local_kwargs = {}
        cache_values = [x.layout, gamma.layout]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layouts for RmsNorm operator.
 97
 98
 99
100
101
102
103
104
105
            raise ValueError(f"{self.op_name} inputs must have same mesh_shape")
        x_alias_map = x_layout.alias_tensor_map
        gamma_alias_map = gamma_layout.alias_tensor_map
        if len(gamma_alias_map) > len(x_alias_map):
            raise ValueError(
                f"For {self.op_name}, gamma ndim {len(gamma_alias_map)} cannot exceed "
                f"input ndim {len(x_alias_map)}."
            )
        begin_norm_axis = len(x_alias_map) - len(gamma_alias_map)
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_layernorm_args(*args, **kwargs)
        input_tensor, normalized_shape, weight, bias, eps = args

        # Normalize normalized_shape: int → (int,), list → tuple
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        elif isinstance(normalized_shape, list):
            normalized_shape = tuple(normalized_shape)

        local_args = [
            input_tensor.to_local(),
            normalized_shape,
            weight.to_local() if weight is not None and hasattr(weight, 'to_local') else weight,
            bias.to_local() if bias is not None and hasattr(bias, 'to_local') else bias,
155
156
157
158
159
160
161
162
163
164
165
166
            weight.to_local() if weight is not None and hasattr(weight, 'to_local') else weight,
            bias.to_local() if bias is not None and hasattr(bias, 'to_local') else bias,
            eps,
        ]
        local_kwargs = {}

        cache_values = [input_tensor.layout, normalized_shape]
        return tuple(local_args), local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for layer_norm operator.
189
190
191
192
193
194
195
196
        if not self._allow_partial_inputs:
            self._check_partial_inputs([input_layout])

        if normalized_shape is None:
            raise ValueError(f"{self.op_name} requires normalized_shape.")

        if not isinstance(normalized_shape, tuple):
            raise ValueError(f"normalized_shape must be int, list, or tuple, got {type(normalized_shape)}")
hyper_parallel/core/shard/ops/parallel_one_hot_ext.py
25
26
27
28
29
30
31
32
33
platform = get_platform()


def _normalize_one_hot_ext_args(indices, num_classes, on_value, off_value, axis):
    return (indices, num_classes, on_value, off_value, axis), {}


class OneHotExtDistributedOp(DistributedOp):
    """Distributed implementation for OneHotExt operator."""
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_one_hot_ext_args(*args, **kwargs)
        indices, num_classes, on_value, off_value, axis = args

        indices_local = indices.to_local()
        on_value_local = on_value.to_local() if hasattr(on_value, '_layout') else on_value
        off_value_local = off_value.to_local() if hasattr(off_value, '_layout') else off_value

        on_value_layout = on_value.layout if hasattr(on_value, '_layout') else None
        off_value_layout = off_value.layout if hasattr(off_value, '_layout') else None

        local_args = (indices_local, num_classes, on_value_local, off_value_local, axis)
        local_kwargs = {}
        cache_values = [indices.layout, on_value_layout, off_value_layout, num_classes, axis]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
 99
100
101
102
103
104
105
106
107
        axis = self._validate_axis(axis)

        in_tensor_map = indices_layout.tensor_map
        if not in_tensor_map:
            raise ValueError(
                f"For {self.op_name}, indices tensor_map is empty."
            )

        self._validate_multi_dim_restriction(in_tensor_map, axis, indices_layout)
hyper_parallel/core/shard/ops/parallel_reduce.py
131
132
133
134
135
136
137
138
139
        dim = cache_values[1]
        keepdim = cache_values[2]

        if x_layout is None or x_layout.mesh_shape is None:
            raise ValueError(
                f"For {self.op_name}, input layout cannot be None."
            )

        # Check partial inputs
137
138
139
140
141
142
143
144
145
            )

        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs([x_layout])

        if dim is not None and not isinstance(dim, (int, tuple, list)):
            raise TypeError(
                f"For {self.op_name}, the `dim` argument should be `None`, `int`, "
158
159
160
161
162
163
164
165
166
167
168
169
170
171
            return self._handle_all_axis_reduce(x_layout, keepdim)

        # Case 2: dim is an empty tuple/list — reduce no dimensions, output layout equals input.
        if isinstance(dim, (tuple, list)) and len(dim) == 0:
            output_layout = Layout(
                mesh_shape=x_layout.mesh_shape,
                alias_name=x_layout.alias_name,
                rank_list=x_layout.rank_list
            )
            return output_layout(*x_layout.alias_tensor_map)

        # Case 3: dim is int, tuple, or list with at least one element.
        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
375
376
377
378
379
380
381
382
            )
        # torch.max(x) global reduction: only pass the tensor so the call
        # remains torch.max(local_x), not torch.max(local_x, None, False).
        if dim is None:
            local_args = (input_tensor.to_local(),)
        else:
            local_args = (input_tensor.to_local(), dim, keepdim)
        local_kwargs = {}
410
411
412
413
414
415
416
417
418
        # Element-wise mode: two Layout objects in cache_values.
        if len(cache_values) == 2 and hasattr(cache_values[1], "mesh_shape"):
            # Check partial inputs
            if not self._allow_partial_inputs:
                self._check_partial_inputs(cache_values)
            return ((deepcopy(cache_values[0]),), None)

        x_layout = cache_values[0]
        dim = cache_values[1]
424
425
426
427
428
429
430
431
432
            )

        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs([x_layout])

        if dim is not None and not isinstance(dim, (int, tuple, list)):
            raise TypeError(
                f"For {self.op_name}, the `dim` argument should be `None`, `int`, "
hyper_parallel/core/shard/ops/parallel_slice_ext.py
22
23
24
25
26
27
28
29
30
from .parallel_ops import DistributedOp


def _normalize_slice_ext_args(x, axis, begin, end, step):
    return (x, axis, begin, end, step), {}


class SliceExtDistributedOp(DistributedOp):
    """Distributed implementation for SliceExt operator."""
39
40
41
42
43
44
45
46
47
48
49
50
51
52

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_slice_ext_args(*args, **kwargs)
        input_tensor, axis, begin, end, step = args
        local_args = (input_tensor.to_local(), axis, begin, end, step)
        local_kwargs = {}
        cache_values = [input_tensor.layout, axis]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
hyper_parallel/core/shard/ops/parallel_split.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from .parallel_ops import DistributedOp


def _normalize_split_with_size_args(x, split_sections, dim):
    return (x, split_sections, dim), {}


def _normalize_split_args(x, split_size_or_sections, dim=0):
    return (x, split_size_or_sections, dim), {}


def _normalize_split_tensor_args(x, split_size, dim):
    return (x, split_size, dim), {}


def _normalize_tensor_split_args(x, indices_or_sections, dim=0):
    return (x, indices_or_sections, dim), {}
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_split_with_size_args(*args, **kwargs)
        input_tensor, split_sections, dim = args
        output_num = len(split_sections)
        local_args = (input_tensor.to_local(), split_sections, dim)
        local_kwargs = {}
        cache_values = [input_tensor.layout, dim, output_num]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
89
90
91
92
93
94
95
96
97
98
        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)

        if dim < 0:
            dim = ndim + dim
        if not 0 <= dim < ndim:
            raise ValueError(
                f"For {self.op_name}, dimension should be in range [0, {ndim}), "
                f"but got {dim}."
            )
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_split_with_size_args(*args, **kwargs)
        input_tensor, split_sections, dim = args
        output_num = len(split_sections)
        local_args = (input_tensor.to_local(), split_sections, dim)
        local_kwargs = {}
        cache_values = [input_tensor.layout, dim, output_num]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177

        Raises:
            ValueError: If any rule above is violated.
        """
        layout = cache_values[0]
        dim = cache_values[1]
        output_num = cache_values[2]

        if not self._allow_partial_inputs:
            self._check_partial_inputs([layout])

        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)

        if dim < 0:
            dim = ndim + dim
        if not 0 <= dim < ndim:
            raise ValueError(
                f"For {self.op_name}, dimension should be in range [0, {ndim}), "
                f"but got {dim}."
            )

        if in_tensor_map[dim] != "None":
            raise ValueError(
                f"For {self.op_name}, can not split tensor at sharded axis[{dim}], "
                f"but got layout: {layout}."
            )

        return (tuple(copy.deepcopy(layout) for _ in range(output_num)), None)


class SplitDistributedOp(DistributedOp):
    """Distributed implementation for Split operator (MindSpore Split and torch.split)."""
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_split_args(*args, **kwargs)
        input_tensor, split_size_or_sections, dim = args

        if isinstance(split_size_or_sections, int):
            output_num = math.ceil(input_tensor.shape[dim] / split_size_or_sections)
        else:
            output_num = len(split_size_or_sections)

        local_args = (input_tensor.to_local(), split_size_or_sections, dim)
        local_kwargs = {}
        cache_values = [input_tensor.layout, dim, output_num]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
229
230
231
232
233
234
235
236
237
        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)

        if dim < 0:
            dim = ndim + dim
        if not 0 <= dim < ndim:
            raise ValueError(
                f"For {self.op_name}, dimension should be in range [0, {ndim}), "
                f"but got {dim}."
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_split_tensor_args(*args, **kwargs)
        input_tensor, split_size, dim = args
        output_num = math.ceil(input_tensor.shape[dim] / split_size)
        local_args = (input_tensor.to_local(), split_size, dim)
        local_kwargs = {}
        cache_values = [input_tensor.layout, dim, output_num]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
296
297
298
299
300
301
302
303
304
305
        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)

        if dim < 0:
            dim = ndim + dim
        if not 0 <= dim < ndim:
            raise ValueError(
                f"For {self.op_name}, dimension should be in range [0, {ndim}), "
                f"but got {dim}."
            )
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_split_tensor_args(*args, **kwargs)
        input_tensor, split_size, dim = args
        output_num = math.ceil(input_tensor.shape[dim] / split_size)
        local_args = (input_tensor.to_local(), split_size, dim)
        local_kwargs = {}
        cache_values = [input_tensor.layout, dim, output_num]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384

        Raises:
            ValueError: If any rule above is violated.
        """
        layout = cache_values[0]
        dim = cache_values[1]
        output_num = cache_values[2]

        if not self._allow_partial_inputs:
            self._check_partial_inputs([layout])

        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)

        if dim < 0:
            dim = ndim + dim
        if not 0 <= dim < ndim:
            raise ValueError(
                f"For {self.op_name}, dimension should be in range [0, {ndim}), "
                f"but got {dim}."
            )

        if in_tensor_map[dim] != "None":
            raise ValueError(
                f"For {self.op_name}, can not split tensor at sharded axis[{dim}], "
                f"but got layout: {layout}."
            )

        return (tuple(copy.deepcopy(layout) for _ in range(output_num)), None)


class TensorSplitDistributedOp(DistributedOp):
    """Distributed implementation for tensor_split operator."""
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
                f"For {self.op_name}, indices_or_sections must be an integer, "
                f"list, tuple, or 1D tensor."
            )

        local_indices = indices_or_sections
        if hasattr(indices_or_sections, "_layout"):
            local_indices = indices_or_sections.to_local()

        local_args = (input_tensor.to_local(), local_indices, dim)
        local_kwargs = {}
        cache_values = [input_tensor.layout, dim, output_num]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
hyper_parallel/core/shard/ops/parallel_topk.py
82
83
84
85
86
87
88
89
90

        if dim is None:
            dim = -1
        if not isinstance(dim, int):
            raise ValueError(
                f"For {self.op_name}, dimension should be int, but got {type(dim)}"
            )

        alias_map = layout.alias_tensor_map
hyper_parallel/core/shard/ops/parallel_unbind.py
79
80
81
82
83
84
85
86
87
        alias_tensor_map = layout.alias_tensor_map
        ndim = len(shape)

        if not isinstance(dim, int):
            raise ValueError(
                f"For {self.op_name}, dimension should be int, but got {type(dim)}"
            )

        if dim < -ndim or dim >= ndim: