Coverage for hyper_parallel / core / shard / ops / parallel_matmul.py: 77%
193 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for MatMul operator.
17"""
19from hyper_parallel.core.layout import Layout
20from .parallel_ops import DistributedOp
class MatMulExtDistributedOp(DistributedOp):
    """Distributed implementation for MatMul operator."""
    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for MatMul operator.

        MatMul: output = x @ w

        Rules:
        1. Batch dimensions should have same layout
        2. Contracting dimensions should have same layout
        3. Output dimensions inherit layouts from non-contracting dimensions

        Args:
            layouts (tuple): Layouts of the two inputs (x_layout, w_layout)
            extra_args: Unused by this operator.

        Returns:
            Layout: Layout for output tensor

        Raises:
            ValueError: On wrong layout count, missing layouts, mismatched
                mesh shapes, or conflicting contracting-dim sharding.
        """
        if len(layouts) != 2:
            raise ValueError(f"MatMul layout length is not 2, but {len(layouts)}")
        x_layout, w_layout = layouts
        if not x_layout or not w_layout:
            raise ValueError(f"x_layout : {x_layout}, w_layout : {w_layout}")
        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError("MatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map
        # x contracts on its last dim, w on its second-to-last dim.
        x_contract = x_map[len(x_map) - 1]
        w_contract = w_map[len(w_map) - 2]
        if x_contract != w_contract:
            raise ValueError(f"Contracting dimensions must have same layout. "
                             f"Got {x_contract} and {w_contract}")

        # Output keeps x's leading dims and takes w's last dim.
        output_map = x_map[:-1] + (w_map[len(w_map) - 1],)

        base = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        out_layout = base(*output_map)

        # A sharded contracting dim leaves each rank holding a partial sum.
        if x_contract != "None":
            axes = x_contract if isinstance(x_contract, tuple) else (x_contract,)
            for axis in axes:
                out_layout.set_partial_by_dev_axis(axis, 'sum')

        return out_layout
class MatMulDistributedOp(DistributedOp):
    """Distributed implementation for MatMul operator."""
    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for MatMul operator.

        MatMul: output = x @ w, with possible transpose

        Args:
            layouts (tuple): Layouts of input tensors (x_layout, w_layout)
            extra_args (tuple): Additional arguments (transpose_a, transpose_b)

        Returns:
            Layout: Layout for output tensor

        Raises:
            ValueError: If fewer than two layouts or not exactly two transpose
                flags are supplied, if mesh shapes differ, or if the
                contracting dimensions are sharded differently.
        """
        if len(layouts) < 2:
            raise ValueError("MatMul requires at least two input layouts")

        x_layout, w_layout = layouts[:2]

        if len(extra_args) != 2:
            raise ValueError("MatMul requires two transpose input")
        transpose_a, transpose_b = extra_args[0], extra_args[1]

        # Compare mesh shapes directly, consistent with the other ops in this
        # file; materializing full dicts via to_dict() was unnecessary here.
        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError("MatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map

        # Determine contracting dimensions based on transpose flags
        if transpose_a:
            x_input_dim = len(x_map) - 1
            x_contract_dim = len(x_map) - 2  # Second to last dimension
        else:
            x_input_dim = len(x_map) - 2
            x_contract_dim = len(x_map) - 1  # Last dimension

        if transpose_b:
            w_output_dim = len(w_map) - 2
            w_contract_dim = len(w_map) - 1  # Last dimension
        else:
            w_output_dim = len(w_map) - 1
            w_contract_dim = len(w_map) - 2  # Second to last dimension

        # Validate contracting dimensions
        if x_map[x_contract_dim] != w_map[w_contract_dim]:
            raise ValueError(f"Contracting dimensions must have same layout. "
                             f"Got {x_map[x_contract_dim]} and {w_map[w_contract_dim]}")

        # Create output layout: x's batch dims, then x's kept dim, then w's kept dim.
        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_map = list(x_map[:-2]) + [x_map[x_input_dim]] + [w_map[w_output_dim]]
        output_layout = output_layout(*output_map)

        # A sharded contracting dim leaves each rank holding a partial sum.
        if x_map[x_contract_dim] != "None":
            if isinstance(x_map[x_contract_dim], tuple):
                for axis in x_map[x_contract_dim]:
                    output_layout.set_partial_by_dev_axis(axis, 'sum')
            else:
                output_layout.set_partial_by_dev_axis(x_map[x_contract_dim], 'sum')

        return output_layout
class BaseBatchMatMulDistributedOp(DistributedOp):
    """Base class for BatchMatMul distributed implementations."""

    def _merge_batch_entry(self, x_dims, w_dims):
        """
        Merge two batch tensor_map entries with broadcasting:
        - none vs X -> X
        - X vs none -> X
        - X vs X (exact same after normalization) -> X
        - otherwise -> conflict
        """
        x_unsharded = self._is_none_entry(x_dims)
        w_unsharded = self._is_none_entry(w_dims)
        if x_unsharded:
            return "None" if w_unsharded else w_dims
        if w_unsharded or x_dims == w_dims:
            return x_dims
        raise ValueError(f"Incompatible batch sharding between inputs: {x_dims} vs {w_dims}")

    def _is_none_entry(self, entry):
        """An entry is 'none' (no sharding) if it is 'None' or tuple of all 'None'."""
        if not isinstance(entry, tuple):
            return entry == "None"
        return all(part == "None" for part in entry)

    def _merge_batches(self, x_map, w_map):
        """Right-align and merge batch dims from x_map and w_map."""
        x_batch = list(x_map[:-2])
        w_batch = list(w_map[:-2])
        width = max(len(x_batch), len(w_batch))
        # Left-pad the shorter side with "None" so entries line up right-aligned.
        x_batch = ["None"] * (width - len(x_batch)) + x_batch
        w_batch = ["None"] * (width - len(w_batch)) + w_batch
        return [self._merge_batch_entry(xb, wb) for xb, wb in zip(x_batch, w_batch)]

    def _build_output_layout(self, x_layout, merged_batch, x_n, w_p, x_contract):
        """Construct output layout from merged dims and set partial status if needed."""
        dims = tuple(merged_batch) + (x_n, w_p)

        base = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        result = base(*dims)

        # A sharded contracting dim leaves each rank holding a partial sum.
        if x_contract != "None":
            axes = x_contract if isinstance(x_contract, tuple) else (x_contract,)
            for axis in axes:
                result.set_partial_by_dev_axis(axis, 'sum')

        return result
class BatchMatMulExtDistributedOp(BaseBatchMatMulDistributedOp):
    """Distributed implementation for BatchMatMulExt operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for BatchMatMulExt operator. Inputs shape are x=[b, n, m] and w=[b, m, p].

        BatchMatMulExt: output = x @ w.

        Rules:
        - Mesh shape must match.
        - Contracting K dims must have identical layout: x[-1] == w[-2].
        - Batch dims are right-aligned broadcast:
          none vs shard -> shard
          shard vs none -> shard
          shard vs shard (different) -> error
        - Output batch dims = merged batch dims
        - Output N inherits x[-2], Output P inherits w[-1]

        Args:
            layouts (tuple): Layouts of the two inputs (x_layout, w_layout)
            extra_args: Unused by this operator.

        Returns:
            Layout: Layout for output tensor

        Examples:
            layout = Layout((2, 2, 2), ("dp", "cp", "mp"))
            x_layout = layout("dp", "cp", "mp")
            w_layout = layout("dp", "mp", "None")
            out_layout = layout("dp", "cp", "None")
        """
        if len(layouts) < 2:
            raise ValueError("BatchMatMul requires at least two input layouts")
        x_layout, w_layout = layouts[:2]

        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError("BatchMatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map

        # Contracting K dims: x's last against w's second-to-last.
        x_contract = x_map[-1]
        w_contract = w_map[-2]
        if x_contract != w_contract:
            raise ValueError(f"Contracting (M) dim layouts must match, got {x_contract} (x) vs {w_contract} (w)")

        return self._build_output_layout(
            x_layout,
            self._merge_batches(x_map, w_map),
            x_map[-2],  # output N dim
            w_map[-1],  # output P dim
            x_contract,
        )
class BatchMatMulDistributedOp(BaseBatchMatMulDistributedOp):
    """Distributed implementation for BatchMatMul operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for BatchMatMul operator. Inputs shape are x=[b, n, m] and w=[b, m, p].

        BatchMatMul: output = x @ w, with possible transpose.

        Rules:
        - Mesh shape must match.
        - Contracting K dims must have identical layout: x[-1] == w[-2].
        - Batch dims are right-aligned broadcast:
          none vs shard -> shard
          shard vs none -> shard
          shard vs shard (different) -> error
        - Output batch dims = merged batch dims
        - Output N inherits x[-2], Output P inherits w[-1]

        Args:
            layouts (tuple): Layouts of input tensors (x_layout, w_layout)
            extra_args (tuple): Additional arguments (transpose_a, transpose_b)

        Returns:
            Layout: Layout for output tensor

        Examples:
            ms.mint.bmm((x_layout, w_layout),(transpose_a=True, transpose_b=False))
            layout = Layout((2, 2, 2), ("dp", "cp", "mp"))
            x_layout = layout("dp", "mp", "cp")
            w_layout = layout("dp", "mp", "None")
            out_layout = layout("dp", "cp", "None")
        """
        if len(layouts) < 2:
            raise ValueError("BatchMatMul requires at least two input layouts")
        if len(extra_args) != 2:
            raise ValueError("BatchMatMul requires two transpose input")

        x_layout, w_layout = layouts[:2]
        transpose_a, transpose_b = extra_args

        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError("BatchMatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map

        # Transposing an input swaps which of its last two dims contracts.
        x_n, x_contract = (x_map[-1], x_map[-2]) if transpose_a else (x_map[-2], x_map[-1])
        w_p, w_contract = (w_map[-2], w_map[-1]) if transpose_b else (w_map[-1], w_map[-2])

        if x_contract != w_contract:
            raise ValueError(f"Contracting (M) dim layouts must match, got {x_contract} (x) vs {w_contract} (w)")

        merged_batch = self._merge_batches(x_map, w_map)

        return self._build_output_layout(x_layout, merged_batch, x_n, w_p, x_contract)
class LinearDistributedOp(DistributedOp):
    """Distributed implementation for Linear operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for Linear operator.

        Linear: output = x @ w, where both x and w contract over their last
        dim and the output feature dim comes from w's first dim.

        Rules:
        1. Batch dimensions should have same layout
        2. Contracting dimensions should have same layout
        3. Output dimensions inherit layouts from non-contracting dimensions

        Args:
            layouts (tuple): Layouts of (x, weight, bias); the bias layout may be falsy.
            extra_args: Unused by this operator.

        Returns:
            Layout: Layout for output tensor

        Raises:
            ValueError: On wrong layout count, missing layouts, mismatched mesh
                shapes, or conflicting contracting/output-dim sharding.
        """
        if len(layouts) != 3:
            raise ValueError(f"Linear layout length is not 3, but {len(layouts)}")
        x_layout, w_layout, bias_layout = layouts
        if not x_layout or not w_layout:
            raise ValueError(f"x_layout : {x_layout}, w_layout : {w_layout}")
        mesh_shape = x_layout.mesh_shape
        if mesh_shape != w_layout.mesh_shape:
            raise ValueError("Linear inputs must have same mesh_shape")
        if bias_layout and bias_layout.mesh_shape != mesh_shape:
            raise ValueError("Linear bias and x must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map
        # Both x and w contract over their last dim (weight stored [out, in]).
        x_contract = x_map[len(x_map) - 1]
        w_contract = w_map[len(w_map) - 1]
        if x_contract != w_contract:
            raise ValueError(f"Contracting dimensions must have same layout. "
                             f"Got {x_contract} and {w_contract}")

        out_dim_sharding = w_map[0]
        output_map = x_map[:-1] + (out_dim_sharding,)
        # Bias, when present, must be sharded like the weight's output dim.
        if bias_layout and bias_layout.alias_tensor_map[0] != out_dim_sharding:
            raise ValueError(f"Output dimensions must have same sharding. "
                             f"Got weight output dim sharding size: {out_dim_sharding}"
                             f" and bias output dim sharding size : {bias_layout.alias_tensor_map[0]}")

        base = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        out_layout = base(*output_map)

        # A sharded contracting dim leaves each rank holding a partial sum.
        if x_contract != "None":
            axes = x_contract if isinstance(x_contract, tuple) else (x_contract,)
            for axis in axes:
                out_layout.set_partial_by_dev_axis(axis, 'sum')

        return out_layout

    def get_expand_impl(self, func, output_layout, layouts, extra_args):
        """
        Get expand implementation for the operator.

        NOTE(review): when the contracting dim is sharded and a bias exists,
        the bias is pre-divided by the number of devices along the sharded
        axes — presumably so the bias contributes exactly once after the
        partial sums are reduced; confirm against the reduction path.
        """
        x_map = layouts[0].alias_tensor_map
        bias_layout = layouts[2]
        x_contract = x_map[len(x_map) - 1]

        scaling_factor = 1
        if x_contract != "None":
            axes = x_contract if isinstance(x_contract, tuple) else (x_contract,)
            for axis in axes:
                scaling_factor *= output_layout.mesh.get_device_num_along_axis(axis)

        # Only needed when the contracting dim is sharded AND a bias exists.
        if x_contract == "None" or not bias_layout:
            return None

        def expand_impl(x, w, bias):
            return func(x, w, bias / scaling_factor)

        return expand_impl