Coverage for hyper_parallel / core / shard / ops / parallel_expand.py: 86%
88 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Expand operator.
17"""
19from hyper_parallel.core.layout import Layout
20from .parallel_ops import DistributedOp
class ExpandDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.expand."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for torch.Tensor.expand.

        PyTorch semantics:
            - Expands singleton dimensions (size 1) to larger sizes
            - Passing -1 preserves the original size of that dimension
            - Only dimensions with global size 1 can be expanded
            - Existing dimensions being expanded MUST be unsharded

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple): The desired expanded sizes, one int per
                output dimension (required).

        Returns:
            Layout: Output tensor layout with:
                - New dimensions: unsharded (-1)
                - Expanded existing dimensions: unsharded (-1)
                - Preserved dimensions (-1 in sizes): original sharding preserved

        Raises:
            ValueError: If the input layout or sizes are missing/invalid, if
                the requested rank is smaller than the input rank, if -1 is
                given for a new leading dimension, or if a sharded dimension
                is given an explicit (non -1) size.
        """
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: expand requires a valid input tensor layout."
            )
        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map
        input_ndim = len(in_tensor_map)

        if not extra_args:
            raise ValueError(
                f"Operation {self.op_name}: expand requires 'sizes' parameter in extra_args."
            )
        output_ndim = len(extra_args)

        # Validate and normalize sizes to a tuple of ints.
        for size in extra_args:
            if not isinstance(size, int):
                raise ValueError(
                    f"Operation {self.op_name}: elements in 'sizes' parameter must be int."
                )
        sizes = tuple(extra_args)

        # PyTorch only allows prepending new dimensions (not inserting in middle).
        num_new_dims = output_ndim - input_ndim
        if num_new_dims < 0:
            raise ValueError(
                f"Operation {self.op_name}: Cannot reduce dimensions with expand. "
                f"Input has {input_ndim} dims, requested {output_ndim} dims."
            )

        # Build output tensor map.
        output_map = []

        # Rule 1: For the new (prepended) dimensions, the size cannot be -1.
        for i in range(num_new_dims):
            if sizes[i] == -1:
                raise ValueError(
                    f"Operation {self.op_name}: Cannot use -1 for new dimension at position {i}. "
                )
            output_map.append(-1)  # New dimensions are always unsharded.

        # Rule 2: Process existing dimensions.
        for i in range(input_ndim):
            requested_size = sizes[num_new_dims + i]
            if requested_size == -1:
                # -1 keeps both the original size and the original sharding.
                output_map.append(in_tensor_map[i])
            else:
                # NOTE(review): the global shape is unavailable here, so any
                # explicit size is treated as a broadcast of a size-1 dim; a
                # sharded dim therefore cannot take an explicit size — even
                # one that happens to equal its current global size. Confirm
                # callers always pass -1 for unchanged sharded dims.
                if in_tensor_map[i] != -1:
                    raise ValueError(
                        f"Operation {self.op_name}: Cannot expand dimension {i} which is sharded."
                    )
                output_map.append(-1)  # Expanded dimension becomes unsharded.

        # Construct the output layout on the same device mesh.
        mesh_shape = input_layout.mesh_shape
        alias_name = input_layout.alias_name
        rank_list = input_layout.rank_list

        def idx_to_alias(idx, aliases):
            # Map a mesh-dim index to its alias name; -1 means "not sharded".
            # Aliases are stored in reverse order relative to tensor_map indices.
            if idx == -1:
                return "None"
            return aliases[len(aliases) - idx - 1]

        output_alias_map = tuple(idx_to_alias(idx, alias_name) for idx in output_map)

        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list
        )
        return output_layout(*output_alias_map)
class ExpandAsDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.expand_as."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for expand_as.

        PyTorch semantics:
            - Only dimensions with global size == 1 can be expanded to larger sizes
            - Dimensions with size > 1 must exactly match between input and target
            - Broadcast replicates a single value across the expanded dimension

        Critical sharding constraints:
            - When expanding a dimension (size 1 to N), that dimension must be
              unsharded (-1) in the input layout
            - Expanded dimensions become unsharded in the output
            - Non-expanded dimensions preserve their input sharding pattern

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
                layouts[1] (Layout): Target tensor layout (unused here).
            extra_args (tuple): Must contain shape information. Expected:
                extra_args[0][0] (tuple of int): Input global shape.
                extra_args[0][1] (tuple of int): Target global shape.

        Returns:
            Layout: Output tensor layout with sharding preserved for
            non-expanded dimensions and unsharded for expanded dimensions.

        Raises:
            ValueError: If the input layout or shape info is missing, the
                shapes are not tuples/lists, the target rank is smaller than
                the input rank, or an expansion rule is violated.
        """
        # Guard: an input layout is mandatory.
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: expand requires a valid input tensor layout."
            )
        src_layout = layouts[0]
        src_map = src_layout.tensor_map
        src_ndim = len(src_map)

        # Guard: both global shapes must arrive via extra_args[0].
        if not extra_args or extra_args[0] is None or len(extra_args[0]) < 2:
            raise ValueError(
                f"Operation {self.op_name}: expand requires (input_global_shape, target_shape) "
                f"in extra_args."
            )
        input_global_shape = extra_args[0][0]
        target_shape = extra_args[0][1]

        if not isinstance(target_shape, (tuple, list)):
            raise ValueError(
                f"Operation {self.op_name}: target_shape must be tuple/list, got {type(target_shape)}."
            )
        if not isinstance(input_global_shape, (tuple, list)):
            raise ValueError(
                f"Operation {self.op_name}: input_global_shape must be tuple/list, got {type(input_global_shape)}."
            )

        target_shape = tuple(target_shape)
        input_global_shape = tuple(input_global_shape)
        tgt_ndim = len(target_shape)

        # PyTorch rule: the target rank can never be below the input rank.
        if tgt_ndim < src_ndim:
            raise ValueError(
                f"Operation {self.op_name}: target shape {target_shape} (ndim={tgt_ndim}) cannot be "
                f"smaller than input shape {input_global_shape} (ndim={src_ndim})."
            )

        # Left-pad the input with implicit size-1, unsharded dimensions so it
        # lines up with the target rank.
        pad = tgt_ndim - src_ndim
        padded_shape = (1,) * pad + input_global_shape
        padded_map = (-1,) * pad + src_map

        # Walk the aligned dimensions, validating expansion rules as we go.
        result_map = []
        dim = 0
        for in_size, tgt_size, spec in zip(padded_shape, target_shape, padded_map):
            if in_size == tgt_size:
                # Unchanged dimension: keep its sharding spec.
                result_map.append(spec)
            elif in_size == 1 and tgt_size > 1:
                # Broadcast dimension: must start (and end) unsharded.
                if spec != -1:
                    raise ValueError(
                        f"Operation {self.op_name}: Cannot expand sharded dimension {dim} which is going to broadcast "
                        f"(global size 1 → {tgt_size})."
                    )
                result_map.append(-1)
            else:
                raise ValueError(
                    f"Operation {self.op_name}: Cannot expand dimension {dim} from size {in_size} "
                    f"to {tgt_size}."
                )
            dim += 1

        # Translate mesh-dim indices into alias names for the Layout call.
        # Aliases are stored in reverse order relative to tensor_map indices;
        # -1 means "not sharded".
        def to_alias(idx):
            aliases = src_layout.alias_name
            return "None" if idx == -1 else aliases[len(aliases) - idx - 1]

        # Rebuild the layout on the same mesh configuration as the input.
        out_layout = Layout(
            mesh_shape=src_layout.mesh_shape,
            alias_name=src_layout.alias_name,
            rank_list=src_layout.rank_list
        )
        return out_layout(*(to_alias(idx) for idx in result_map))