Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_matmul.py: 70% (205 statements)
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16Distributed implementation for MatMul operator.
17"""
19from typing import Callable, Optional, Tuple
21from hyper_parallel.core.dtensor.layout import Layout
22from .parallel_ops import DistributedOp


class MatMulExtDistributedOp(DistributedOp):
    """Distributed implementation for the MatMulExt operator."""

    def infer_layout(self, layouts: tuple, extra_args: Optional[tuple] = None) -> tuple:
        """
        Infer output layout for the MatMulExt operator.

        MatMulExt: output = x @ w

        Rules:
            1. Batch dimensions should have the same layout.
            2. Contracting dimensions should have the same layout.
            3. Output dimensions inherit layouts from the non-contracting dimensions.

        Args:
            layouts (tuple): Layouts of the input tensors (x_layout, w_layout).
            extra_args (tuple, optional): Unused for this operator.

        Returns:
            Layout: Layout for the output tensor.
        """
        if len(layouts) != 2:
            raise ValueError(f"MatMul layout length is not 2, but {len(layouts)}")
        x_layout = layouts[0]
        w_layout = layouts[1]
        if not x_layout or not w_layout:
            raise ValueError(f"x_layout : {x_layout}, w_layout : {w_layout}")
        x_mesh_shape = x_layout.mesh_shape
        w_mesh_shape = w_layout.mesh_shape
        if x_mesh_shape != w_mesh_shape:
            raise ValueError("MatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map
        contract_dim = len(x_map) - 1
        w_contract_dim = len(w_map) - 2
        if x_map[contract_dim] != w_map[w_contract_dim]:
            raise ValueError(f"Contracting dimensions must have same layout. "
                             f"Got {x_map[contract_dim]} and {w_map[w_contract_dim]}")

        output_dim = len(w_map) - 1
        output_map = x_map[:-1] + (w_map[output_dim],)

        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        out_layout = output_layout(*output_map)

        # Set partial status when the contracting dimension is sharded
        if x_map[contract_dim] != "None":
            if isinstance(x_map[contract_dim], tuple):
                for axis in x_map[contract_dim]:
                    out_layout.set_partial_by_dev_axis(axis, 'sum')
            else:
                out_layout.set_partial_by_dev_axis(x_map[contract_dim], 'sum')

        return out_layout
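
# Layout-inference sketch for MatMulExtDistributedOp (illustrative values, assuming a
# 2-D mesh with alias names ("dp", "mp")):
#   layout   = Layout((2, 4), ("dp", "mp"))
#   x_layout = layout("dp", "mp")     # x: [n, k], with k sharded along "mp"
#   w_layout = layout("mp", "None")   # w: [k, p], with k sharded along "mp"
#   # inferred output: layout("dp", "None") marked partial(sum) over "mp",
#   # i.e. each rank holds a partial n x p result that still needs an AllReduce.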


class MatMulDistributedOp(DistributedOp):
    """Distributed implementation for the MatMul operator."""

    def infer_layout(self, layouts: tuple, extra_args: Optional[tuple] = None) -> tuple:
        """
        Infer output layout for the MatMul operator.

        MatMul: output = x @ w, with possible transposes.

        Args:
            layouts (tuple): Layouts of input tensors (x_layout, w_layout).
            extra_args (tuple): Additional arguments (transpose_a, transpose_b).

        Returns:
            Layout: Layout for the output tensor.
        """
        if len(layouts) < 2:
            raise ValueError("MatMul requires at least two input layouts")

        x_layout, w_layout = layouts[:2]

        if not extra_args or len(extra_args) != 2:
            raise ValueError("MatMul requires two transpose flags (transpose_a, transpose_b)")
        transpose_a, transpose_b = extra_args[0], extra_args[1]

        x_dict = x_layout.to_dict()
        w_dict = w_layout.to_dict()

        if x_dict["mesh_shape"] != w_dict["mesh_shape"]:
            raise ValueError("MatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map

        # Determine contracting dimensions based on transpose flags
        if transpose_a:
            x_input_dim = len(x_map) - 1
            x_contract_dim = len(x_map) - 2  # Second to last dimension
        else:
            x_input_dim = len(x_map) - 2
            x_contract_dim = len(x_map) - 1  # Last dimension

        if transpose_b:
            w_output_dim = len(w_map) - 2
            w_contract_dim = len(w_map) - 1  # Last dimension
        else:
            w_output_dim = len(w_map) - 1
            w_contract_dim = len(w_map) - 2  # Second to last dimension

        # Validate contracting dimensions
        if x_map[x_contract_dim] != w_map[w_contract_dim]:
            raise ValueError(f"Contracting dimensions must have same layout. "
                             f"Got {x_map[x_contract_dim]} and {w_map[w_contract_dim]}")

        # Create output layout
        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_map = list(x_map[:-2]) + [x_map[x_input_dim]] + [w_map[w_output_dim]]
        output_layout = output_layout(*output_map)

        # Set partial status when the contracting dimension is sharded
        if x_map[x_contract_dim] != "None":
            if isinstance(x_map[x_contract_dim], tuple):
                for axis in x_map[x_contract_dim]:
                    output_layout.set_partial_by_dev_axis(axis, 'sum')
            else:
                output_layout.set_partial_by_dev_axis(x_map[x_contract_dim], 'sum')

        return output_layout
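
# Sketch of the transpose handling above, with illustrative layouts on a ("dp", "mp") mesh:
#   layout   = Layout((2, 4), ("dp", "mp"))
#   x_layout = layout("mp", "dp")     # x: [k, n], transpose_a=True, so k is dim -2
#   w_layout = layout("mp", "None")   # w: [k, p], transpose_b=False
#   # contracting layouts match on "mp" -> inferred output layout("dp", "None"),
#   # marked partial(sum) over "mp".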


class BaseBatchMatMulDistributedOp(DistributedOp):
    """Base class for BatchMatMul distributed implementations."""

    def _merge_batch_entry(self, x_dims, w_dims):
        """
        Merge two batch tensor_map entries with broadcasting:
            - none vs X -> X
            - X vs none -> X
            - X vs X (exact same after normalization) -> X
            - otherwise -> conflict
        """
        if self._is_none_entry(x_dims) and self._is_none_entry(w_dims):
            return "None"
        if self._is_none_entry(x_dims):
            return w_dims
        if self._is_none_entry(w_dims):
            return x_dims
        if x_dims == w_dims:
            return x_dims
        raise ValueError(f"Incompatible batch sharding between inputs: {x_dims} vs {w_dims}")

    def _is_none_entry(self, entry):
        """An entry is 'none' (no sharding) if it is 'None' or tuple of all 'None'."""
        if isinstance(entry, tuple):
            return all(i == "None" for i in entry)
        return entry == "None"
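
    # Illustrative values for _is_none_entry (sketch):
    #   "None"             -> True
    #   ("None", "None")   -> True
    #   "dp"               -> False
    #   ("dp", "None")     -> False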

    def _merge_batches(self, x_map, w_map):
        """Right-align and merge batch dims from x_map and w_map."""
        x_batch = list(x_map[:-2])
        w_batch = list(w_map[:-2])
        max_b = max(len(x_batch), len(w_batch))
        x_batch = ["None"] * (max_b - len(x_batch)) + x_batch
        w_batch = ["None"] * (max_b - len(w_batch)) + w_batch
        merged_batch = []
        for xb, wb in zip(x_batch, w_batch):
            merged_batch.append(self._merge_batch_entry(xb, wb))
        return merged_batch
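
    # Right-aligned batch merge, sketched with illustrative entries:
    #   x batch dims ("dp", "None") vs w batch dims ("None",)
    #   -> w is padded on the left to ("None", "None")
    #   -> merged batch ("dp", "None"): a sharded entry wins over "None",
    #      while two different sharded entries raise the "Incompatible batch sharding" error.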

    def _build_output_layout(self, x_layout, merged_batch, x_n, w_p, x_contract):
        """Construct output layout from merged dims and set partial status if needed."""
        output_map = tuple(merged_batch) + (x_n, w_p)

        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_layout = output_layout(*output_map)

        # Set partial status
        if x_contract != "None":
            if isinstance(x_contract, tuple):
                for axis in x_contract:
                    output_layout.set_partial_by_dev_axis(axis, 'sum')
            else:
                output_layout.set_partial_by_dev_axis(x_contract, 'sum')

        return output_layout


class BatchMatMulExtDistributedOp(BaseBatchMatMulDistributedOp):
    """Distributed implementation for BatchMatMulExt operator."""

    def infer_layout(self, layouts: tuple, extra_args: Optional[tuple] = None) -> tuple:
        """
        Infer output layout for the BatchMatMulExt operator. Input shapes are
        x = [b, n, m] and w = [b, m, p].

        BatchMatMulExt: output = x @ w.

        Rules:
            - Mesh shapes must match.
            - Contracting dims (m) must have identical layout: x[-1] == w[-2].
            - Batch dims are right-aligned and broadcast:
                none vs shard -> shard
                shard vs none -> shard
                shard vs shard (different) -> error
            - Output batch dims = merged batch dims.
            - Output n inherits x[-2], output p inherits w[-1].

        Args:
            layouts (tuple): Layouts of input tensors (x_layout, w_layout).
            extra_args (tuple, optional): Unused for this operator.

        Returns:
            Layout: Layout for the output tensor.

        Examples:
            layout = Layout((2, 2, 2), ("dp", "cp", "mp"))
            x_layout = layout("dp", "cp", "mp")
            w_layout = layout("dp", "mp", "None")
            out_layout = layout("dp", "cp", "None")
        """
        if len(layouts) < 2:
            raise ValueError("BatchMatMul requires at least two input layouts")
        x_layout, w_layout = layouts[:2]

        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError("BatchMatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map

        # Contracting dims
        x_contract = x_map[-1]
        w_contract = w_map[-2]
        if x_contract != w_contract:
            raise ValueError(f"Contracting (M) dim layouts must match, got {x_contract} (x) vs {w_contract} (w)")

        merged_batch = self._merge_batches(x_map, w_map)
        x_n = x_map[-2]
        w_p = w_map[-1]

        return self._build_output_layout(x_layout, merged_batch, x_n, w_p, x_contract)
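
# Partial-sum case for BatchMatMulExt, sketched with illustrative layouts:
#   layout   = Layout((2, 2, 2), ("dp", "cp", "mp"))
#   x_layout = layout("dp", "None", "mp")   # x: [b, n, m], m sharded on "mp"
#   w_layout = layout("dp", "mp", "None")   # w: [b, m, p], m sharded on "mp"
#   # -> inferred output layout("dp", "None", "None"), marked partial(sum) over "mp",
#   #    so each rank holds a partial [b, n, p] result pending an AllReduce.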


class BatchMatMulDistributedOp(BaseBatchMatMulDistributedOp):
    """Distributed implementation for BatchMatMul operator."""

    def infer_layout(self, layouts: tuple, extra_args: Optional[tuple] = None) -> tuple:
        """
        Infer output layout for the BatchMatMul operator. Input shapes are
        x = [b, n, m] and w = [b, m, p].

        BatchMatMul: output = x @ w, with possible transposes.

        Rules:
            - Mesh shapes must match.
            - Contracting dims (after accounting for the transpose flags) must have
              identical layout.
            - Batch dims are right-aligned and broadcast:
                none vs shard -> shard
                shard vs none -> shard
                shard vs shard (different) -> error
            - Output batch dims = merged batch dims.
            - Output n inherits the non-contracting dim of x, output p inherits the
              non-contracting dim of w.

        Args:
            layouts (tuple): Layouts of input tensors (x_layout, w_layout).
            extra_args (tuple): Additional arguments (transpose_a, transpose_b).

        Returns:
            Layout: Layout for the output tensor.

        Examples:
            # ms.mint.bmm call with transpose_a=True, transpose_b=False
            layout = Layout((2, 2, 2), ("dp", "cp", "mp"))
            x_layout = layout("dp", "mp", "cp")
            w_layout = layout("dp", "mp", "None")
            out_layout = layout("dp", "cp", "None")
        """
        if len(layouts) < 2:
            raise ValueError("BatchMatMul requires at least two input layouts")
        if not extra_args or len(extra_args) != 2:
            raise ValueError("BatchMatMul requires two transpose flags (transpose_a, transpose_b)")

        x_layout, w_layout = layouts[:2]
        transpose_a, transpose_b = extra_args

        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError("BatchMatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map

        # Handle transpose
        if transpose_a:
            x_n = x_map[-1]
            x_contract = x_map[-2]
        else:
            x_n = x_map[-2]
            x_contract = x_map[-1]

        if transpose_b:
            w_contract = w_map[-1]
            w_p = w_map[-2]
        else:
            w_contract = w_map[-2]
            w_p = w_map[-1]

        if x_contract != w_contract:
            raise ValueError(f"Contracting (M) dim layouts must match, got {x_contract} (x) vs {w_contract} (w)")

        merged_batch = self._merge_batches(x_map, w_map)

        return self._build_output_layout(x_layout, merged_batch, x_n, w_p, x_contract)


def _normalize_linear_args(x, weight, bias=None):
    """Normalize linear call arguments into a positional (x, weight, bias) tuple."""
    return (x, weight, bias), {}


class LinearDistributedOp(DistributedOp):
    """Distributed implementation for Linear operator."""

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """
        Preprocess arguments for Linear operator.

        Args:
            args (tuple): Input arguments containing x and weight tensors.
            kwargs (dict): Keyword arguments, may contain bias.

        Returns:
            tuple: (local_args, local_kwargs, cache_values) where local_args contains
                local tensors for x, weight, and bias; local_kwargs is empty; and
                cache_values contains layouts and a None sentinel for an absent bias.
        """
        args, kwargs = _normalize_linear_args(*args, **kwargs)
        x_tensor, w_tensor, bias = args[0], args[1], args[2]
        local_args = (
            x_tensor.to_local(),
            w_tensor.to_local(),
            bias.to_local() if hasattr(bias, '_layout') else bias,
        )
        local_kwargs = {}
        cache_values = [
            x_tensor.layout,
            w_tensor.layout,
            bias.layout if hasattr(bias, '_layout') else None,
        ]
        return local_args, local_kwargs, cache_values
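
    # Preprocess sketch (illustrative): for a distributed linear(x, weight, bias) call,
    # local_args carries the per-rank shards (bias is passed through unchanged when it
    # has no '_layout', i.e. it is a plain tensor or None), while cache_values keeps the
    # global layouts so infer_layout and get_expand_impl can reason about the sharding.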

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for Linear operator (output = x @ weight.T + bias).

        Rules:
            1. x and weight must share the same mesh_shape.
            2. weight must be 2D [out_features, in_features].
            3. Contracting dimensions (in_features) must have the same layout.
            4. Output batch dimensions inherit from x; output feature dim inherits from weight dim 0.
            5. Partial state is set on the output when the contracting dimension is sharded.

        Args:
            cache_values (list): [x_layout, w_layout, bias_layout] where bias_layout may be None.

        Returns:
            tuple: ((out_layout,), None)

        Raises:
            ValueError: If cache_values length is not 3, layouts are invalid, mesh shapes differ,
                weight is not 2D, contracting dims mismatch, or bias sharding is inconsistent.
        """
        if len(cache_values) != 3:
            raise ValueError(
                f"For {self.op_name}, cache_values length should be 3, but got {len(cache_values)}"
            )
        x_layout = cache_values[0]
        w_layout = cache_values[1]
        bias_layout = cache_values[2]

        if not x_layout or not w_layout:
            raise ValueError(f"x_layout : {x_layout}, w_layout : {w_layout}")

        self._check_partial_inputs([x_layout, w_layout])

        x_mesh_shape = x_layout.mesh_shape
        w_mesh_shape = w_layout.mesh_shape
        if x_mesh_shape != w_mesh_shape:
            raise ValueError(
                f"For {self.op_name}, x and weight must have the same mesh_shape, "
                f"but got x: {x_mesh_shape} and weight: {w_mesh_shape}"
            )
        if bias_layout and bias_layout.mesh_shape != x_mesh_shape:
            raise ValueError(
                f"For {self.op_name}, bias and x must have the same mesh_shape, "
                f"but got bias: {bias_layout.mesh_shape} and x: {x_mesh_shape}"
            )

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map

        if len(w_map) != 2:
            raise ValueError(
                f"For {self.op_name}, weight should be 2D [out_features, in_features], "
                f"but got {len(w_map)}D"
            )

        x_contract_dim = len(x_map) - 1
        w_contract_dim = len(w_map) - 1
        if x_map[x_contract_dim] != w_map[w_contract_dim]:
            raise ValueError(
                f"For {self.op_name}, contracting dimensions must have the same layout, "
                f"but got x: {x_map[x_contract_dim]} and weight: {w_map[w_contract_dim]}"
            )

        output_dim = 0
        output_map = x_map[:-1] + (w_map[output_dim],)
        if bias_layout and bias_layout.alias_tensor_map[0] != w_map[output_dim]:
            raise ValueError(
                f"For {self.op_name}, bias output dim sharding must match weight output dim sharding, "
                f"but got weight: {w_map[output_dim]} and bias: {bias_layout.alias_tensor_map[0]}"
            )

        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list,
        )
        out_layout = output_layout(*output_map)

        # Set partial status when contracting dimension is sharded
        if x_map[x_contract_dim] != "None":
            if isinstance(x_map[x_contract_dim], tuple):
                for axis in x_map[x_contract_dim]:
                    out_layout.set_partial_by_dev_axis(axis, 'sum')
            else:
                out_layout.set_partial_by_dev_axis(x_map[x_contract_dim], 'sum')

        return ((out_layout,), None)
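
    # Layout-inference sketch for Linear, with illustrative layouts on a ("dp", "mp") mesh:
    #   layout   = Layout((2, 4), ("dp", "mp"))
    #   # Row-parallel: in_features sharded on "mp" for both x and weight
    #   x_layout = layout("dp", "mp")        # x: [batch, in_features]
    #   w_layout = layout("None", "mp")      # weight: [out_features, in_features]
    #   # -> output layout("dp", "None"), marked partial(sum) over "mp"
    #   # Column-parallel: out_features sharded on "mp", contracting dim replicated
    #   x_layout = layout("dp", "None")
    #   w_layout = layout("mp", "None")
    #   # -> output layout("dp", "mp"), no partial state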

    def get_expand_impl(self, func: Callable, infer_result: tuple,
                        cache_values: list) -> Optional[Callable]:
        """
        Return a custom expand implementation when bias scaling is needed.

        When the contracting dimension is sharded, each rank computes a partial sum
        (x_shard @ w_shard.T + bias). After AllReduce the bias would accumulate
        scaling_factor times. The returned closure pre-divides bias by scaling_factor
        to keep the result numerically correct.

        Args:
            func: Original operator callable.
            infer_result (tuple): ((out_layout,), None) from infer_layout.
            cache_values (list): [x_layout, w_layout, bias_layout].

        Returns:
            callable | None: expand_impl closure when scaling is required, else None.
        """
        x_layout = cache_values[0]
        bias_layout = cache_values[2]
        x_map = x_layout.alias_tensor_map
        x_contract_dim = len(x_map) - 1

        # Guard: scaling only needed when contract dim is sharded AND bias is present
        if x_map[x_contract_dim] == "None" or not bias_layout:
            return None

        output_layout = infer_result[0][0]
        scaling_factor = 1
        if isinstance(x_map[x_contract_dim], tuple):
            for axis in x_map[x_contract_dim]:
                scaling_factor *= output_layout.mesh.get_device_num_along_axis(axis)
        else:
            scaling_factor *= output_layout.mesh.get_device_num_along_axis(x_map[x_contract_dim])

        def expand_impl(x: object, w: object, bias: object) -> object:
            """Pre-scale bias to counteract the AllReduce accumulation over shards.

            Args:
                x (object): Local input activation tensor.
                w (object): Local weight tensor.
                bias (object): Local bias tensor to be pre-scaled.

            Returns:
                object: Result of the linear operation with pre-scaled bias.
            """
            return func(x, w, bias / scaling_factor)

        return expand_impl
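
    # Bias-scaling sketch (illustrative): with in_features sharded over an "mp" axis of
    # size 4, every rank computes x_shard @ w_shard.T + bias, and the AllReduce that
    # resolves the partial(sum) output would otherwise add bias four times; dividing
    # bias by scaling_factor = 4 before the local call keeps the reduced result equal
    # to x @ weight.T + bias.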