Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/shard/ops/parallel_argsort.py 100%  
hyper_parallel/core/shard/ops/parallel_expand.py 75.7% 26,30,47-52,94,101,107,197,225-230
hyper_parallel/core/shard/ops/parallel_isin.py 100%  
hyper_parallel/core/shard/ops/parallel_nonzero.py 100%  
hyper_parallel/core/shard/ops/parallel_one_hot_ext.py 59.4% 29,46-47,49-51,53-54,56-59,103
hyper_parallel/core/shard/ops/parallel_slice_ext.py 74.1% 26,43-48
hyper_parallel/core/shard/ops/parallel_split.py 47.7% 27,31,35,56-62,93,95,123-129,149-151,153-154,156-157,159-162,167-168,173,190-191,193-194,196,198-201,233,263-269,300,302,330-336,356-358,360-361,363-364,366-369,374-375,380,412-414,416-419
hyper_parallel/core/shard/ops/parallel_unbind.py 97.1% 83
hyper_parallel/core/shard/ops/parallel_expand.py
22
23
24
25
26
27
28
29
30
31
32
33
34
from .parallel_ops import DistributedOp


def _normalize_expand_args(input_tensor, *sizes):
    return (input_tensor, *sizes), {}


def _normalize_expand_as_args(input_tensor, target_tensor):
    return (input_tensor, target_tensor), {}


class ExpandDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.expand."""
43
44
45
46
47
48
49
50
51
52
53
54
55
56

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_expand_args(*args, **kwargs)
        input_tensor = args[0]
        sizes = tuple(args[1:])
        local_args = (input_tensor.to_local(), *sizes)
        cache_values = [input_tensor.layout, input_tensor.shape, sizes]
        return local_args, {}, cache_values

    @staticmethod
    def _validate_input_layouts(
        cache_values: list,
90
91
92
93
94
95
96
97
            )

        input_shape = cache_values[1] if len(cache_values) > 1 else None
        if not isinstance(input_shape, tuple):
            raise ValueError(
                f"For {op_name}, input_shape should be a tuple, "
                f"but got {type(input_shape)}."
            )
 97
 98
 99
100
101
102
103
104
105
            )

        sizes = cache_values[2] if len(cache_values) > 2 else None
        if sizes is None or len(sizes) < 1:
            raise ValueError(
                f"For {op_name}, sizes should be a non-empty tuple of ints, "
                f"but got {sizes}."
            )
        for i, sz in enumerate(sizes):
103
104
105
106
107
108
109
110
                f"but got {sizes}."
            )
        for i, sz in enumerate(sizes):
            if not isinstance(sz, int):
                raise ValueError(
                    f"For {op_name}, elements in sizes should be int, "
                    f"but got {type(sz)} at position {i}."
                )
193
194
195
196
197
198
199
200
                        f"got mapping {in_alias_map[i]}."
                    )
                output_map.append("None")
            else:
                raise ValueError(
                    f"For {self.op_name}, cannot expand dimension {i} "
                    f"from size {input_size} to {requested_size}."
                )
221
222
223
224
225
226
227
228
229
230
231
232
233
234

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_expand_as_args(*args, **kwargs)
        input_tensor = args[0]
        target_tensor = args[1]
        local_args = (input_tensor.to_local(), target_tensor.to_local())
        cache_values = [input_tensor.layout, input_tensor.shape, target_tensor.shape]
        return local_args, {}, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for expand_as.
hyper_parallel/core/shard/ops/parallel_one_hot_ext.py
25
26
27
28
29
30
31
32
33
platform = get_platform()


def _normalize_one_hot_ext_args(indices, num_classes, on_value, off_value, axis):
    return (indices, num_classes, on_value, off_value, axis), {}


class OneHotExtDistributedOp(DistributedOp):
    """Distributed implementation for OneHotExt operator."""
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_one_hot_ext_args(*args, **kwargs)
        indices, num_classes, on_value, off_value, axis = args

        indices_local = indices.to_local()
        on_value_local = on_value.to_local() if hasattr(on_value, '_layout') else on_value
        off_value_local = off_value.to_local() if hasattr(off_value, '_layout') else off_value

        on_value_layout = on_value.layout if hasattr(on_value, '_layout') else None
        off_value_layout = off_value.layout if hasattr(off_value, '_layout') else None

        local_args = (indices_local, num_classes, on_value_local, off_value_local, axis)
        local_kwargs = {}
        cache_values = [indices.layout, on_value_layout, off_value_layout, num_classes, axis]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
 99
100
101
102
103
104
105
106
107
        axis = self._validate_axis(axis)

        in_tensor_map = indices_layout.tensor_map
        if not in_tensor_map:
            raise ValueError(
                f"For {self.op_name}, indices tensor_map is empty."
            )

        self._validate_multi_dim_restriction(in_tensor_map, axis, indices_layout)
hyper_parallel/core/shard/ops/parallel_slice_ext.py
22
23
24
25
26
27
28
29
30
from .parallel_ops import DistributedOp


def _normalize_slice_ext_args(x, axis, begin, end, step):
    return (x, axis, begin, end, step), {}


class SliceExtDistributedOp(DistributedOp):
    """Distributed implementation for SliceExt operator."""
39
40
41
42
43
44
45
46
47
48
49
50
51
52

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_slice_ext_args(*args, **kwargs)
        input_tensor, axis, begin, end, step = args
        local_args = (input_tensor.to_local(), axis, begin, end, step)
        local_kwargs = {}
        cache_values = [input_tensor.layout, axis]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
hyper_parallel/core/shard/ops/parallel_split.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from .parallel_ops import DistributedOp


def _normalize_split_with_size_args(x, split_sections, dim):
    return (x, split_sections, dim), {}


def _normalize_split_args(x, split_size_or_sections, dim=0):
    return (x, split_size_or_sections, dim), {}


def _normalize_split_tensor_args(x, split_size, dim):
    return (x, split_size, dim), {}


def _normalize_tensor_split_args(x, indices_or_sections, dim=0):
    return (x, indices_or_sections, dim), {}
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_split_with_size_args(*args, **kwargs)
        input_tensor, split_sections, dim = args
        output_num = len(split_sections)
        local_args = (input_tensor.to_local(), split_sections, dim)
        local_kwargs = {}
        cache_values = [input_tensor.layout, dim, output_num]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
89
90
91
92
93
94
95
96
97
98
        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)

        if dim < 0:
            dim = ndim + dim
        if not 0 <= dim < ndim:
            raise ValueError(
                f"For {self.op_name}, dimension should be in range [0, {ndim}), "
                f"but got {dim}."
            )
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_split_with_size_args(*args, **kwargs)
        input_tensor, split_sections, dim = args
        output_num = len(split_sections)
        local_args = (input_tensor.to_local(), split_sections, dim)
        local_kwargs = {}
        cache_values = [input_tensor.layout, dim, output_num]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177

        Raises:
            ValueError: If any rule above is violated.
        """
        layout = cache_values[0]
        dim = cache_values[1]
        output_num = cache_values[2]

        if not self._allow_partial_inputs:
            self._check_partial_inputs([layout])

        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)

        if dim < 0:
            dim = ndim + dim
        if not 0 <= dim < ndim:
            raise ValueError(
                f"For {self.op_name}, dimension should be in range [0, {ndim}), "
                f"but got {dim}."
            )

        if in_tensor_map[dim] != "None":
            raise ValueError(
                f"For {self.op_name}, can not split tensor at sharded axis[{dim}], "
                f"but got layout: {layout}."
            )

        return (tuple(copy.deepcopy(layout) for _ in range(output_num)), None)


class SplitDistributedOp(DistributedOp):
    """Distributed implementation for Split operator (MindSpore Split and torch.split)."""
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_split_args(*args, **kwargs)
        input_tensor, split_size_or_sections, dim = args

        if isinstance(split_size_or_sections, int):
            output_num = math.ceil(input_tensor.shape[dim] / split_size_or_sections)
        else:
            output_num = len(split_size_or_sections)

        local_args = (input_tensor.to_local(), split_size_or_sections, dim)
        local_kwargs = {}
        cache_values = [input_tensor.layout, dim, output_num]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
229
230
231
232
233
234
235
236
237
        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)

        if dim < 0:
            dim = ndim + dim
        if not 0 <= dim < ndim:
            raise ValueError(
                f"For {self.op_name}, dimension should be in range [0, {ndim}), "
                f"but got {dim}."
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_split_tensor_args(*args, **kwargs)
        input_tensor, split_size, dim = args
        output_num = math.ceil(input_tensor.shape[dim] / split_size)
        local_args = (input_tensor.to_local(), split_size, dim)
        local_kwargs = {}
        cache_values = [input_tensor.layout, dim, output_num]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
296
297
298
299
300
301
302
303
304
305
        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)

        if dim < 0:
            dim = ndim + dim
        if not 0 <= dim < ndim:
            raise ValueError(
                f"For {self.op_name}, dimension should be in range [0, {ndim}), "
                f"but got {dim}."
            )
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340

        Returns:
            tuple: (local_args, local_kwargs, cache_values)
        """
        args, kwargs = _normalize_split_tensor_args(*args, **kwargs)
        input_tensor, split_size, dim = args
        output_num = math.ceil(input_tensor.shape[dim] / split_size)
        local_args = (input_tensor.to_local(), split_size, dim)
        local_kwargs = {}
        cache_values = [input_tensor.layout, dim, output_num]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384

        Raises:
            ValueError: If any rule above is violated.
        """
        layout = cache_values[0]
        dim = cache_values[1]
        output_num = cache_values[2]

        if not self._allow_partial_inputs:
            self._check_partial_inputs([layout])

        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)

        if dim < 0:
            dim = ndim + dim
        if not 0 <= dim < ndim:
            raise ValueError(
                f"For {self.op_name}, dimension should be in range [0, {ndim}), "
                f"but got {dim}."
            )

        if in_tensor_map[dim] != "None":
            raise ValueError(
                f"For {self.op_name}, can not split tensor at sharded axis[{dim}], "
                f"but got layout: {layout}."
            )

        return (tuple(copy.deepcopy(layout) for _ in range(output_num)), None)


class TensorSplitDistributedOp(DistributedOp):
    """Distributed implementation for tensor_split operator."""
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
                f"For {self.op_name}, indices_or_sections must be an integer, "
                f"list, tuple, or 1D tensor."
            )

        local_indices = indices_or_sections
        if hasattr(indices_or_sections, "_layout"):
            local_indices = indices_or_sections.to_local()

        local_args = (input_tensor.to_local(), local_indices, dim)
        local_kwargs = {}
        cache_values = [input_tensor.layout, dim, output_num]
        return local_args, local_kwargs, cache_values

    # pylint: disable=W0237
    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
hyper_parallel/core/shard/ops/parallel_unbind.py
79
80
81
82
83
84
85
86
87
        alias_tensor_map = layout.alias_tensor_map
        ndim = len(shape)

        if not isinstance(dim, int):
            raise ValueError(
                f"For {self.op_name}, dimension should be int, but got {type(dim)}"
            )

        if dim < -ndim or dim >= ndim: