# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
16"""
17Distributed implementation for RepeatInterleave operator.
18"""
19from hyper_parallel.core.dtensor.layout import Layout
20from .parallel_ops import DistributedOp
23class RepeatInterleaveDistributedOp(DistributedOp):
24 """Distributed implementation for torch.repeat_interleave."""
    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for the RepeatInterleave operator.

        RepeatInterleave: output = repeat_interleave(input, repeats, dim)

        Rules:
            1. dim defaults to None if not specified, in which case the input is flattened.
            2. The dimension `dim` MUST be unsharded to ensure global repeat_interleave correctness.
            3. With an explicit `dim`, the output layout equals the input layout (only the
               shape along `dim` changes); with dim=None the output is a flattened 1-D layout.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple, optional): Contains repeats and dim. Expected:
                extra_args[0] (int or Tensor): Number of repeats for each element.
                extra_args[1] (int, optional): Dimension along which to repeat. Defaults to None.

        Returns:
            Layout: Layout of the output tensor.
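
        Example (illustrative; the mesh and alias names are hypothetical):
            On a 1-D device mesh with alias ("dp",), take an input layout whose
            dim 0 is sharded on "dp" and dim 1 is replicated. Requesting dim=1
            returns the input layout unchanged, requesting dim=0 raises because
            the repeated dimension must stay unsharded, and dim=None returns a
            flattened 1-D layout that keeps the "dp" sharding since dim 0 was
            the only sharded dimension.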
46 """
        if not layouts or layouts[0] is None:
            raise ValueError("repeat_interleave requires a valid input tensor layout.")
        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map
        dim = None  # The dimension along which to repeat values.
        # By default, use the flattened input array, and return a flat output array.
        if extra_args is not None and len(extra_args) >= 2 and extra_args[1] is not None:
            dim = extra_args[1]
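        # When dim is None the input is flattened, so the inferred output layout is
        # 1-D. For example (illustrative tensor maps): [-1, -1] (fully replicated)
        # flattens to [-1]; [0, -1] (only dim 0 sharded) flattens to [0] and keeps
        # that sharding; [-1, 0] or [0, 1] cannot be flattened without resharding
        # and raise below.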
        if dim is None:
            sharded_dims = [i for i, shard in enumerate(in_tensor_map) if shard != -1]
            if not sharded_dims:  # no dimension is sharded
                output_tensor_map = [-1]
            elif sharded_dims == [0]:
                # Only the first dimension may be sharded when flattening.
                output_tensor_map = [in_tensor_map[0]]
            else:
                # Any other sharded dimension cannot be flattened locally.
                raise ValueError(
                    f"Operation {self.op_name}: Cannot flatten tensor when dim=None."
                )
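
            # Convert a tensor-map entry to an alias name: the entry is a mesh-axis
            # index counted from the end of alias_name, and -1 (unsharded) maps to
            # the placeholder string "None".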
            def idx_to_alias(idx, aliases):
                if idx == -1:
                    return "None"
                return aliases[len(aliases) - idx - 1]

            output_map = tuple(
                idx_to_alias(idx, input_layout.alias_name) for idx in output_tensor_map
            )
            output_layout = Layout(
                mesh_shape=input_layout.mesh_shape,
                alias_name=input_layout.alias_name,
                rank_list=input_layout.rank_list,
            )
            return output_layout(*output_map)
        input_dim = len(in_tensor_map)
        if dim < 0:
            dim = input_dim + dim
        # Check if dimension is within valid range
        if not 0 <= dim < input_dim:
            raise ValueError(
                f"Dimension out of range (expected to be in [0, {input_dim}), but got {dim})."
            )
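        # Repeating along a sharded dim would interleave only each rank's local
        # slice and break the global element order.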
        # The chosen dim therefore must NOT be sharded.
        if in_tensor_map[dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: the input must not be sharded along the chosen dim {dim}."
            )
        # Output layout same as input layout (the shape change does not affect the sharding pattern).
        return input_layout