Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/shard/_op_dispatch.py 78.6% 579,746,903,1173,1263,1266
hyper_parallel/core/shard/ops/parallel_cumsum.py 100%  
hyper_parallel/core/shard/ops/parallel_elementwise.py 80.9% 29-33,134,161,627,632,661-662,665-666
hyper_parallel/core/shard/ops/parallel_embedding.py 95.3% 102,109
hyper_parallel/core/shard/ops/parallel_expand_dims.py 70.0% 26,71,81,89,91,117-120
hyper_parallel/core/shard/ops/parallel_flatten.py 85.2% 75,79,84,89
hyper_parallel/core/shard/ops/parallel_gather.py 100%  
hyper_parallel/core/shard/ops/parallel_repeat_interleave.py 92.9% 111,118,158,164
hyper_parallel/core/shard/ops/parallel_reshape.py 85.7% 32-33,35,37,39,319
hyper_parallel/core/shard/ops/parallel_slice.py 80.0% 58-60,62,103,108,132,138
hyper_parallel/core/shard/ops/parallel_squeeze.py 87.5% 159,161,167,213
hyper_parallel/core/shard/ops/parallel_tuple_elementwise.py 90.9% 55-56
hyper_parallel/core/shard/_op_dispatch.py
575
576
577
578
579
580
581
582
583

        if op_impl is None:
            op_impl = func

        py_output = OpDispatcher._call_op_impl(op_impl, _packed_call, input_args, input_kwargs)
        return distribute_op.wrap_output(py_output, output_layout)

    @staticmethod
    def _with_layout_infer_reshape(func: callable, *args) -> Tensor:
742
743
744
745
746
747
748
749
750

        if op_impl is None:
            op_impl = func

        py_output = OpDispatcher._call_op_impl(op_impl, _packed_call, input_args, input_kwargs)

        # set output layout
        if isinstance(py_output, (tuple, list)):
            output = ()
899
900
901
902
903
904
905
906
                    local_args = _apply_shard_offset_to_rng_args(local_args, offset_incr)
                local_results = op_call(*local_args, **local_kwargs)
        else:
            if maybe_user_generator is not None:
                local_kwargs["generator"] = maybe_user_generator
            local_results = op_call(*local_args, **local_kwargs)

        return self._wrap_random_result(op_name, local_results, first_arg, args, kwargs)
1169
1170
1171
1172
1173
1174
1175
1176
1177
                  the packed format was detected, otherwise the original args).
        """
        if op_name in unpack_ops and len(args) == 3 and \
            isinstance(args[1], str) and isinstance(args[2], (tuple, list)):
            return (args[0], args[1]), tuple(args[2])
        return None, args

    @staticmethod
    def _call_op_impl(op_impl: callable, packed_call, args, kwargs: dict):
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
        if not suffix:
            return self._with_layout_infer(op_call, *args, _packed_call=packed_call, **kwargs)

        if suffix == 'WithShape':
            return self._with_layout_infer_with_shape(op_call, *args, _packed_call=packed_call, **kwargs)

        if suffix == 'WithTupleExpand':
            return self._with_layout_infer_with_tuple_expand(op_call, *args, _packed_call=packed_call, **kwargs)

        handler_name = self._suffix_dispatch.get(suffix)
        if handler_name is None:
            raise RuntimeError(f"Operator {op_name} specified wrong suffix in parallel yaml.")
hyper_parallel/core/shard/ops/parallel_elementwise.py
25
26
27
28
29
30
31
32
33
34
35
36
37
def _unwrap_local_value(value):
    """Convert DTensor-like values to local tensors while preserving containers."""
    if hasattr(value, "_layout"):
        return value.to_local()
    if isinstance(value, tuple):
        return tuple(_unwrap_local_value(item) for item in value)
    if isinstance(value, list):
        return [_unwrap_local_value(item) for item in value]
    return value


def _collect_layout_and_shape(value):
    """Collect layout and shape from one argument for layout inference cache."""
130
131
132
133
134
135
136
137
138

        aligned_layouts, aligned_shapes = self._align_layouts_and_shapes(layouts, input_shapes)

        if len(aligned_layouts) <= 1 or len(aligned_layouts) != len(aligned_shapes):
            return ((copy.deepcopy(valid_layouts[0]),), None)

        output_shape = self._compute_output_shape(aligned_shapes)
        merged_tensor_map, merged_partial = self._merge_all_layouts(
            aligned_layouts,
157
158
159
160
161
162
163
164
165
                raise ValueError(
                    f"For {self.op_name}, cannot infer layout without shapes: "
                    f"mismatched alias_tensor_map {first_alias_map} vs {layout.alias_tensor_map}."
                )
        return copy.deepcopy(first_layout)

    def _align_layouts_and_shapes(self, layouts, input_shapes):
        """
        Align layouts with shapes by position, skipping None layouts.
623
624
625
626
627
628
629
630
631
632
633
634
635
            ValueError: If propagated Partial status is not "sum" or None.
        """
        infer_result = super().infer_layout(cache_values)
        if infer_result is None:
            return infer_result

        output_layout = infer_result[0][0]
        for i, partial_type in enumerate(output_layout.partial):
            if partial_type is not None and partial_type != "sum":
                raise ValueError(
                    f"For {self.op_name}, inputs partial status should be 'sum' or None, "
                    f"but got {partial_type} at index {i}."
                )
657
658
659
660
661
662
663
664
665
666
667
668
                scaling_factor *= output_layout.mesh_shape[i]

        # use expand_impl only when one of x1 and x2 is with partial placement.
        def _expand_impl1(x1, x2, *args, **kwargs):
            add_out = func(x1 / scaling_factor, x2, *args, **kwargs)
            return add_out

        def _expand_impl2(x1, x2, *args, **kwargs):
            add_out = func(x1, x2 / scaling_factor, *args, **kwargs)
            return add_out

        return _expand_impl1 if not x1_partial else _expand_impl2
hyper_parallel/core/shard/ops/parallel_embedding.py
 98
 99
100
101
102
103
104
105

        self._check_partial_inputs([input_layout, weight_layout])

        if input_layout.mesh_shape != weight_layout.mesh_shape:
            raise ValueError(
                f"For {self.op_name}, input and weight must have the same mesh_shape, "
                f"but got input: {input_layout.mesh_shape} and weight: {weight_layout.mesh_shape}"
            )
105
106
107
108
109
110
111
112
            )

        weight_tensor_map = weight_layout.tensor_map
        if len(weight_tensor_map) != 2:
            raise ValueError(
                f"For {self.op_name}, weight should be 2D [vocab_size, embedding_dim], "
                f"but got {len(weight_tensor_map)}D"
            )
hyper_parallel/core/shard/ops/parallel_expand_dims.py
22
23
24
25
26
27
28
29
30


def _normalize_expand_dims_args(x, axis=None, dim=None):
    if axis is None:
        axis = dim
    return (x, axis), {}


class ExpandDimsDistributedOp(DistributedOp):
67
68
69
70
71
72
73
74
            ValueError: If input has Partial status, input layout is missing,
                axis is missing or invalid, or axis is out of range.
        """
        if not cache_values:
            raise ValueError(
                f"For {self.op_name}, cache_values should contain input layout, "
                f"but got empty cache_values."
            )
77
78
79
80
81
82
83
84
        if not self._allow_partial_inputs:
            self._check_partial_inputs([x_layout])

        if x_layout.mesh_shape is None:
            raise ValueError(
                f"For {self.op_name}, input layout mesh_shape should not be None, "
                f"but got None."
            )
85
86
87
88
89
90
91
92
93
94
95

        axis = cache_values[1] if len(cache_values) > 1 else None

        if axis is None:
            raise ValueError(f"For {self.op_name}, axis parameter is required.")
        if not isinstance(axis, int):
            raise ValueError(
                f"For {self.op_name}, axis should be int, but got {type(axis)}."
            )

        in_rank = len(x_layout.alias_tensor_map)
113
114
115
116
117
118
119
120
121
122
        )
        output_layout = output_layout(*x_map)

        if self._allow_partial_inputs:
            for i, partial_op in enumerate(x_layout.partial):
                if partial_op is not None:
                    dev_axis_name = x_layout.alias_name[i]
                    output_layout.set_partial_by_dev_axis(dev_axis_name, partial_op)

        return ((output_layout,), None)
hyper_parallel/core/shard/ops/parallel_flatten.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
        input_layout, start_dim, end_dim, input_shape = (
            cache_values[0], cache_values[1], cache_values[2], cache_values[3]
        )
        if input_layout is None:
            raise ValueError(
                f"For {self.op_name}, flatten requires a valid input tensor layout."
            )
        if not isinstance(input_shape, (list, tuple)):
            raise ValueError(
                f"For {self.op_name}, input_shape should be list or tuple, "
                f"but got {type(input_shape)}."
            )
        if len(input_shape) != len(input_layout.tensor_map):
            raise ValueError(
                f"For {self.op_name}, input shape rank should match layout rank, "
                f"but got {len(input_shape)} and {len(input_layout.tensor_map)}."
            )
        if not isinstance(start_dim, int) or not isinstance(end_dim, int):
            raise ValueError(
                f"For {self.op_name}, start_dim and end_dim should be int, "
                f"but got {type(start_dim)} and {type(end_dim)}."
            )
hyper_parallel/core/shard/ops/parallel_repeat_interleave.py
107
108
109
110
111
112
113
114
115
        local_input = input_tensor.to_local()

        # repeats can be int or Tensor; handle DTensor defensively
        if hasattr(repeats, 'to_local'):
            local_repeats = repeats.to_local()
        else:
            local_repeats = repeats

        local_args = (local_input, local_repeats, dim)
114
115
116
117
118
119
120
121

        local_args = (local_input, local_repeats, dim)
        local_kwargs = {}
        if output_size is not None:
            local_kwargs['output_size'] = output_size

        cache_values = [input_tensor.layout, dim]
        return local_args, local_kwargs, cache_values
154
155
156
157
158
159
160
161
162
            # Flatten mode: output is 1-D.
            sharded_dims = []
            for i, shard in enumerate(in_tensor_map):
                if isinstance(shard, (list, tuple)):
                    is_sharded = any(axis != "None" for axis in shard)
                else:
                    is_sharded = shard != "None"
                if is_sharded:
                    sharded_dims.append(i)
160
161
162
163
164
165
166
167
168
                    is_sharded = shard != "None"
                if is_sharded:
                    sharded_dims.append(i)
            if not sharded_dims:
                output_tensor_map = ("None",)
            elif sharded_dims == [0]:
                output_tensor_map = (in_tensor_map[0],)
            else:
                raise ValueError(
hyper_parallel/core/shard/ops/parallel_reshape.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def _normalize_reshape_args(x, *shape, **kwargs):
    """Normalize reshape/view arguments into positional args and empty kwargs."""
    unexpected_kwargs = set(kwargs) - {'shape'}
    if unexpected_kwargs:
        unexpected = next(iter(unexpected_kwargs))
        raise TypeError(f"reshape got an unexpected keyword argument '{unexpected}'.")
    if shape and 'shape' in kwargs:
        raise TypeError("reshape got shape from both args and kwargs.")
    if not shape and 'shape' in kwargs:
        shape = (kwargs['shape'],)
    if not shape:
        raise TypeError("reshape missing required shape argument.")
    return (x,) + shape, {}


def _filter_none_split_tensor_map(tensor_map, mesh_shape):
315
316
317
318
319
320
321
322
            )

        x_layout, dst_shape, input_shape = cache_values[0], cache_values[1], cache_values[2]
        if x_layout is None:
            raise ValueError(f"For {self.op_name}, reshape requires a valid input tensor layout.")

        out_layout, local_dst_shape = self._infer_reshape_layout(x_layout, dst_shape, input_shape)
        return ((out_layout,), local_dst_shape)
hyper_parallel/core/shard/ops/parallel_slice.py
54
55
56
57
58
59
60
61
62
63
64
65
66
                shard_dim.append(1)
                continue
            if isinstance(axis_name, (tuple, list)):
                shard_num = 1
                for axis in axis_name:
                    if axis != "None":
                        shard_num *= layout.mesh.get_device_num_along_axis(axis)
                shard_dim.append(shard_num)
                continue
            shard_dim.append(layout.mesh.get_device_num_along_axis(axis_name))
        return shard_dim

    def _check_layout(self, layout, begin, end, shape):
 99
100
101
102
103
104
105
106
107
108
109
110
111
        layout, begin, end, global_shape = cache_values
        self._check_partial_inputs([layout])

        if len(begin) != len(end) or len(begin) != len(global_shape):
            raise ValueError(
                f"For {self.op_name}, begin, end and global_shape must have the same length, "
                f"but got begin: {len(begin)}, end: {len(end)}, global_shape: {len(global_shape)}"
            )
        if len(begin) != len(layout.alias_tensor_map):
            raise ValueError(
                f"For {self.op_name}, slice arguments rank must match input layout rank, "
                f"but got args rank: {len(begin)} and layout rank: {len(layout.alias_tensor_map)}"
            )
128
129
130
131
132
133
134
135
136
        Returns:
            callable | None: expand_impl closure when local slice bounds differ, else None.
        """
        if func is None:
            return None

        begin = cache_values[1]
        end = cache_values[2]
        new_begin, new_end = infer_result[1]
134
135
136
137
138
139
140
141
142
        begin = cache_values[1]
        end = cache_values[2]
        new_begin, new_end = infer_result[1]
        if begin == new_begin and end == new_end:
            return None

        def expand_impl(input_tensor: object, *_unused_args: object) -> object:
            """Call Slice with local slice bounds."""
            return func(input_tensor, new_begin, new_end)
hyper_parallel/core/shard/ops/parallel_squeeze.py
155
156
157
158
159
160
161
162
163
164
        # Convert axis to list if it's a single integer
        if isinstance(axis, int):
            axis = [axis]
        elif isinstance(axis, tuple):
            axis = list(axis)
        elif not isinstance(axis, list):
            raise ValueError(
                f"For {self.op_name}, axis should be int, list or tuple, "
                f"but got {type(axis)}."
            )
163
164
165
166
167
168
169
170
                f"but got {type(axis)}."
            )

        if not all(isinstance(ax, int) for ax in axis):
            raise ValueError(
                f"For {self.op_name}, every axis value should be int, "
                f"but got {axis}."
            )
209
210
211
212
213
214
215
216

    def _create_output_layout(self, x_layout, dims_to_squeeze):
        """Create output layout after squeezing dimensions."""
        if not dims_to_squeeze:
            return deepcopy(x_layout)

        # Get current alias tensor map
        x_map = list(x_layout.alias_tensor_map)
hyper_parallel/core/shard/ops/parallel_tuple_elementwise.py
51
52
53
54
55
56
57
58
59
60
            if isinstance(arg, (tuple, list)):
                expanded_args.extend(arg)
                local_args.append(tuple(_unwrap_local_value(item) for item in arg))
            else:
                expanded_args.append(arg)
                local_args.append(_unwrap_local_value(arg))

        local_kwargs = {key: _unwrap_local_value(value) for key, value in kwargs.items()}
        cache_values = [getattr(arg, "layout", None) for arg in expanded_args]
        cache_values.extend(getattr(value, "layout", None) for value in kwargs.values())