Coverage for hyper_parallel / core / shard / ops / parallel_repeat_interleave.py: 85%
33 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
16"""
17Distributed implementation for RepeatInterleave operator.
18"""
19from hyper_parallel.core.layout import Layout
20from .parallel_ops import DistributedOp
class RepeatInterleaveDistributedOp(DistributedOp):
    """Distributed implementation for torch.repeat_interleave."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for RepeatInterleave operator.

        RepeatInterleave: output = repeat_interleave(input, repeats, dim)

        Rules:
            1. dim = None if not specified; the input is then flattened and the
               output is 1-D (matching torch.repeat_interleave semantics).
            2. The dimension `dim` MUST be unsharded to ensure global
               repeat_interleave correctness.
            3. Output layout is usually the same as the input layout; only the
               shape changes.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple, optional): Contains repeats and dim. Expected:
                extra_args[0] (int or Tensor): Number of repeats or Tensor.
                extra_args[1] (int, optional): Dimension to repeat. Defaults to
                    None, which flattens the input.

        Returns:
            Layout: Layout for the output value.

        Raises:
            ValueError: If the input layout is missing, if `dim` is out of
                range, if the chosen `dim` is sharded, or if a flatten
                (dim=None) is requested while a non-leading dim is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("repeat_interleave requires a valid input tensor layout.")

        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map

        # Default: flatten the input and return a flat output array, matching
        # torch.repeat_interleave when dim is not given.
        dim = None
        # Guard against extra_args being None, as the docstring allows it to be
        # optional; the original `len(extra_args)` would raise TypeError.
        if extra_args is not None and len(extra_args) >= 2 and extra_args[1] is not None:
            dim = extra_args[1]

        if dim is None:
            sharded_dims = [i for i, shard in enumerate(in_tensor_map) if shard != -1]
            if not sharded_dims:  # not sharded at all
                output_tensor_map = [-1]
            elif sharded_dims == [0]:
                # Only the first dimension may be sharded; the flattened 1-D
                # output inherits its sharding. (sharded_dims == [0] already
                # implies in_tensor_map[0] != -1, so no extra check is needed.)
                output_tensor_map = [in_tensor_map[0]]
            else:
                # Other dims must NOT be sharded.
                raise ValueError(
                    f"Operation {self.op_name}: Cannot flatten tensor when dim=None."
                )

            def idx_to_alias(idx, aliases):
                # Map a mesh-axis index back to its alias name; -1 means the
                # tensor dim is replicated ("None" is the Layout convention).
                if idx == -1:
                    return "None"
                return aliases[len(aliases) - idx - 1]

            output_map = tuple(idx_to_alias(idx, input_layout.alias_name) for idx in output_tensor_map)
            output_layout = Layout(
                mesh_shape=input_layout.mesh_shape,
                alias_name=input_layout.alias_name,
                rank_list=input_layout.rank_list
            )
            return output_layout(*output_map)

        # dim was given explicitly: normalize negative indices, then validate.
        input_dim = len(in_tensor_map)
        if dim < 0:
            dim = input_dim + dim
        # Check if dimension is within valid range
        if not 0 <= dim < input_dim:
            raise ValueError(f"Dimension out of range (expected to be in [0, {input_dim}), but got {dim}).")
        # The chosen dim must NOT be sharded
        if in_tensor_map[dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on params along the chosen dim"
            )
        # Output layout same as input layout (shape change does not affect sharding pattern)
        return input_layout