# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16Distributed implementation for Repeat operator.
17"""
19from hyper_parallel.core.dtensor.layout import Layout
20from .parallel_ops import DistributedOp


class RepeatDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.repeat."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for torch.Tensor.repeat.

        PyTorch semantics:
        - Repeats this tensor along the specified dimensions.
        - If the number of repeat dimensions is larger than the number of
          tensor dimensions, the tensor is implicitly unsqueezed at the front.
        - The number of repeat dimensions cannot be smaller than the number
          of tensor dimensions.
        - Dimensions being repeated (repeat size != 1, including 0) MUST be
          unsharded.

        Args:
            layouts (tuple): Layouts of the inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple/list): Should contain the repeat sizes.

        Returns:
            Layout: Output tensor layout with:
            - New prepended dimensions: unsharded (-1)
            - Repeated existing dimensions (repeat size != 1): unsharded (-1)
            - Preserved existing dimensions (repeat size == 1): original
              sharding preserved
        """
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: repeat requires a valid input tensor layout."
            )

        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map
        input_ndim = len(in_tensor_map)

        if not extra_args:
            raise ValueError(
                f"Operation {self.op_name}: repeat requires repeat sizes in extra_args."
            )

        # Robustly handle sizes unpacking (e.g., if the sizes are packed as a
        # single tuple)
        if len(extra_args) == 1 and isinstance(extra_args[0], (tuple, list)):
            flat_args = extra_args[0]
        else:
            flat_args = extra_args

        # Normalize the repeat sizes to a tuple of ints
        repeats = tuple(int(arg) for arg in flat_args)
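        # For example, a packed torch.Size([2, 1, 3]) (torch.Size is a tuple
        # subclass) and a plain list [2, 1, 3] both normalize to (2, 1, 3).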
        output_ndim = len(repeats)
        if output_ndim < input_ndim:
            raise ValueError(
                f"Operation {self.op_name}: the number of repeat sizes "
                f"({output_ndim}) cannot be smaller than the number of tensor "
                f"dimensions ({input_ndim})."
            )

        num_new_dims = output_ndim - input_ndim
        output_map = []

        # Rule 1: new prepended dimensions are always unsharded
        for _ in range(num_new_dims):
            output_map.append(-1)

        # Rule 2: process the existing dimensions
        for i in range(input_ndim):
            repeat_idx = num_new_dims + i
            repeat_times = repeats[repeat_idx]

            if repeat_times == 1:
                # The dimension is not repeated, so keep its original sharding
                output_map.append(in_tensor_map[i])
            else:
                # A repeated (or zeroed) dimension cannot currently be sharded
                if in_tensor_map[i] != -1:
                    raise ValueError(
                        f"Operation {self.op_name}: cannot repeat dimension {i}, "
                        f"which is sharded. Please redistribute (unshard) the "
                        f"tensor along this dimension first."
                    )
                # The repeated dimension remains unsharded in the output
                output_map.append(-1)

        # Construct the output layout mapping
        mesh_shape = input_layout.mesh_shape
        alias_name = input_layout.alias_name
        rank_list = input_layout.rank_list

        def idx_to_alias(idx, aliases):
            """Convert a tensor_map device index back to its alias string."""
            if idx == -1:
                return "None"
            # tensor_map indices count mesh axes from the last (innermost) one,
            # so index 0 refers to the final entry of `aliases`.
            return aliases[len(aliases) - idx - 1]
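
        # For example, with alias_name ("dp", "tp") (hypothetical names), an
        # output_map of [-1, 0, -1] converts to ("None", "tp", "None").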
        output_alias_map = tuple(idx_to_alias(idx, alias_name) for idx in output_map)

        # Instantiate the new layout
        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list,
        )
        output_layout = output_layout(*output_alias_map)

        return output_layout
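

# Usage sketch (illustrative only: the mesh shape, alias names, and the
# RepeatDistributedOp constructor arguments below are assumptions, not part of
# this module's verified API):
#
#     base = Layout(mesh_shape=(2, 4), alias_name=("dp", "tp"),
#                   rank_list=list(range(8)))
#     in_layout = base("dp", "None")  # dim 0 sharded over "dp", dim 1 replicated
#     op = RepeatDistributedOp(...)   # constructor args depend on DistributedOp
#     out = op.infer_layout((in_layout,), extra_args=((1, 1, 3),))
#     # out's alias map would be ("None", "dp", "None"): one new leading dim,
#     # dim 0's sharding kept (repeat size 1), and the repeated last dim
#     # unsharded.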