Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_expand.py: 92% (88 statements)
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16Distributed implementation for Expand operator.
17"""
19from hyper_parallel.core.dtensor.layout import Layout
20from .parallel_ops import DistributedOp


class ExpandDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.expand."""

    def infer_layout(self, layouts, extra_args=None):
27 """
28 Infer output layout for torch.Tensor.expand.
30 PyTorch semantics:
31 - Expands singleton dimensions (size 1) to larger sizes
32 - Passing -1 preserves the original size of that dimension
33 - Only dimensions with global size 1 can be expanded
34 - Existing dimensions being expanded MUST be unsharded:
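
        For example, a tensor of shape (1, 4) expanded with sizes (3, -1)
        yields shape (3, 4): dimension 0 is broadcast and dimension 1 is
        preserved.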

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple): Should contain 'sizes'. Expected:
                extra_args[0] (int): One element of the desired expanded sizes (required).
                ...
                extra_args[n] (int): One element of the desired expanded sizes (required).

        Returns:
            Layout: Output tensor layout with:
                - New dimensions: unsharded (-1)
                - Expanded existing dimensions: unsharded (-1)
                - Preserved dimensions (-1 in sizes): original sharding preserved
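
        Example (illustrative; the mesh aliases "dp" and "tp" below are
        hypothetical):
            On a mesh with alias_name ("dp", "tp"), a 2-D input with
            tensor_map (-1, 0) (dim 1 sharded over "tp") expanded with
            sizes (3, -1, -1) prepends one new unsharded dimension and
            preserves the existing ones, giving output tensor_map
            (-1, -1, 0), i.e. alias map ("None", "None", "tp").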
49 """
50 if not layouts or layouts[0] is None:
51 raise ValueError(
52 f"Operation {self.op_name}: expand requires a valid input tensor layout."
53 )
54 input_layout = layouts[0]
55 in_tensor_map = input_layout.tensor_map
56 input_ndim = len(in_tensor_map)
58 if not extra_args or len(extra_args) < 1:
59 raise ValueError(
60 f"Operation {self.op_name}: expand requires 'sizes' parameter in extra_args."
61 )
62 output_ndim = len(extra_args)

        # Normalize sizes to a tuple, validating element types along the way.
        sizes = []
        for i in range(output_ndim):
            if not isinstance(extra_args[i], int):
                raise ValueError(
                    f"Operation {self.op_name}: elements in 'sizes' parameter must be int."
                )
            sizes.append(extra_args[i])
        sizes = tuple(sizes)

        num_new_dims = output_ndim - input_ndim

        # PyTorch only allows prepending new dimensions (not inserting in the middle).
        if num_new_dims < 0:
            raise ValueError(
                f"Operation {self.op_name}: Cannot reduce dimensions with expand. "
                f"Input has {input_ndim} dims, requested {output_ndim} dims."
            )

        # Build the output tensor map.
        output_map = []

        # Rule 1: for new dimensions, the size cannot be set to -1.
        for i in range(num_new_dims):
            if sizes[i] == -1:
                raise ValueError(
                    f"Operation {self.op_name}: Cannot use -1 for new dimension at position {i}."
                )
            output_map.append(-1)  # New dimensions are always unsharded.

        # Rule 2: process existing dimensions.
        for i in range(input_ndim):
            output_dim_idx = num_new_dims + i
            requested_size = sizes[output_dim_idx]

            if requested_size == -1:
                # Keep the original sharding.
                output_map.append(in_tensor_map[i])
            else:
                # A sharded dimension cannot be expanded.
                if in_tensor_map[i] != -1:
                    raise ValueError(
                        f"Operation {self.op_name}: Cannot expand dimension {i} which is sharded."
                    )
                # The expanded dimension becomes unsharded in the output.
                output_map.append(-1)

        # Construct the output layout on the same mesh configuration.
        mesh_shape = input_layout.mesh_shape
        alias_name = input_layout.alias_name
        rank_list = input_layout.rank_list

        # Convert tensor_map indices to alias strings for the Layout constructor.
        def idx_to_alias(idx, aliases):
            if idx == -1:
                return "None"
            return aliases[len(aliases) - idx - 1]
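        # For example, with hypothetical mesh aliases ("dp", "tp"):
        # idx_to_alias maps 1 -> "dp", 0 -> "tp", and -1 -> "None".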

        output_alias_map = tuple(idx_to_alias(idx, alias_name) for idx in output_map)

        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list
        )
        output_layout = output_layout(*output_alias_map)
        return output_layout


class ExpandAsDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.expand_as."""

    def infer_layout(self, layouts, extra_args=None):
137 """
138 Infer output layout for expand_as.
140 PyTorch semantics:
141 - Only dimensions with global size == 1 can be expanded to larger sizes
142 - Dimensions with size > 1 must exactly match between input and target
143 - Broadcast replicates a single value across the expanded dimension
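
        For example, a tensor of shape (1, 4) can be expanded as a target of
        shape (2, 3, 4): a new leading dimension of size 2 is added and the
        size-1 dimension is broadcast to 3.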

        Critical sharding constraints:
        - Input dimensions with global size == 1 MUST be unsharded (-1).
        - When a dimension is expanded (size 1 → N), it must be unsharded in the input layout.
        - Expanded dimensions become unsharded in the output.
        - Non-expanded dimensions preserve their input sharding pattern.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
                layouts[1] (Layout): Target tensor layout (not used).
            extra_args (tuple): Must contain shape information. Expected:
                extra_args[0][0] (tuple of int): Input global shape.
                extra_args[0][1] (tuple of int): Target global shape.

        Returns:
            Layout: Output tensor layout with sharding preserved for non-expanded
            dimensions and unsharded for expanded dimensions.
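
        Example (illustrative; the mesh aliases "dp" and "tp" below are
        hypothetical):
            input_global_shape (1, 4) with tensor_map (-1, 0) and target_shape
            (2, 3, 4): the two broadcast dimensions become unsharded and the
            matching dimension keeps its sharding, giving output tensor_map
            (-1, -1, 0), i.e. alias map ("None", "None", "tp").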
162 """
163 # Validate input layout
164 if not layouts or layouts[0] is None:
165 raise ValueError(
166 f"Operation {self.op_name}: expand requires a valid input tensor layout."
167 )
168 input_layout = layouts[0]
169 in_tensor_map = input_layout.tensor_map
170 input_ndim = len(in_tensor_map)

        # Extract shape information from extra_args.
        if not extra_args or extra_args[0] is None or len(extra_args[0]) < 2:
            raise ValueError(
                f"Operation {self.op_name}: expand_as requires (input_global_shape, target_shape) "
                f"in extra_args."
            )
        input_global_shape = extra_args[0][0]
        target_shape = extra_args[0][1]

        if not isinstance(target_shape, (tuple, list)):
            raise ValueError(
                f"Operation {self.op_name}: target_shape must be tuple/list, got {type(target_shape)}."
            )
        if not isinstance(input_global_shape, (tuple, list)):
            raise ValueError(
                f"Operation {self.op_name}: input_global_shape must be tuple/list, got {type(input_global_shape)}."
            )

        target_shape = tuple(target_shape)
        input_global_shape = tuple(input_global_shape)
        target_ndim = len(target_shape)

        # PyTorch rule: the target rank cannot be smaller than the input rank.
        if target_ndim < input_ndim:
            raise ValueError(
                f"Operation {self.op_name}: target shape {target_shape} (ndim={target_ndim}) cannot be "
                f"smaller than input shape {input_global_shape} (ndim={input_ndim})."
            )

        # Align dimensions (input to target) by prepending implicit leading dims.
        num_leading_implicit = target_ndim - input_ndim
        aligned_input_shape = (1,) * num_leading_implicit + input_global_shape
        aligned_tensor_map = (-1,) * num_leading_implicit + in_tensor_map

        # Validate expansion rules and build the output tensor_map.
        output_tensor_map = []
        for i, (in_size, tgt_size, shard_spec) in enumerate(
            zip(aligned_input_shape, target_shape, aligned_tensor_map)
        ):
            if in_size == tgt_size:
                # Dimension unchanged - preserve the sharding pattern.
                output_tensor_map.append(shard_spec)
            elif in_size == 1 and tgt_size > 1:
                # Dimension is expanded (broadcast) - must be unsharded.
                if shard_spec != -1:
                    raise ValueError(
                        f"Operation {self.op_name}: Cannot expand sharded dimension {i}, which is "
                        f"going to be broadcast (global size 1 → {tgt_size})."
                    )
                output_tensor_map.append(-1)
            else:
                raise ValueError(
                    f"Operation {self.op_name}: Cannot expand dimension {i} from size {in_size} "
                    f"to {tgt_size}."
                )

        # Construct the output layout with the same mesh configuration.
        mesh_shape = input_layout.mesh_shape
        alias_name = input_layout.alias_name
        rank_list = input_layout.rank_list

        # Convert tensor_map indices to alias strings for the Layout constructor.
        def idx_to_alias(idx, aliases):
            if idx == -1:
                return "None"
            return aliases[len(aliases) - idx - 1]

        output_map = tuple(idx_to_alias(idx, alias_name) for idx in output_tensor_map)

        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list
        )
        output_layout = output_layout(*output_map)
        return output_layout