Coverage for hyper_parallel / core / shard / ops / parallel_elementwise.py: 74%
245 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Element-wise operator.
17"""
19import copy
20from .parallel_ops import DistributedOp
class ElementWiseDistributedOp(DistributedOp):
    """
    Base class for distributed element-wise operators.

    Infers output layouts under NumPy-style broadcasting rules and
    validates sharding / partial-status compatibility between inputs.

    Args:
        op_name (str): Name of the operator to register.
    """

    def infer_layout(self, layouts, extra_args):
        """
        Infer the output layout of an element-wise op with broadcasting.

        Input shapes are broadcast following NumPy rules; per-dimension
        sharding strategies (including tuple-type tensor_map entries) are
        merged, and partial statuses are combined.

        Args:
            layouts (tuple): Layouts of the input tensors (entries may be None).
            extra_args: Either a dict carrying 'input_shapes' or a
                list/tuple whose last element is the shapes (WithShape path).

        Returns:
            Layout: Output layout with the merged sharding strategy, or
            None when no layout information is available.

        Raises:
            ValueError: If the input layouts are incompatible for broadcasting.
        """
        if not layouts:
            return None

        present = [entry for entry in layouts if entry is not None]
        if not present:
            return None

        # Partial inputs are rejected unless a subclass opts in.
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        if len(present) == 1:
            return present[0]

        input_shapes = self._extract_input_shapes(extra_args)
        if not input_shapes:
            # Without shapes, fall back to requiring identical layouts.
            return self._handle_no_input_shapes(present)

        aligned_layouts, aligned_shapes = self._align_layouts_and_shapes(layouts, input_shapes)
        if len(aligned_layouts) <= 1 or len(aligned_layouts) != len(aligned_shapes):
            return present[0]

        output_shape = self._compute_output_shape(aligned_shapes)
        merged_tensor_map, merged_partial = self._merge_all_layouts(
            aligned_layouts,
            aligned_shapes,
            output_shape,
            layouts
        )

        # A Partial input that also broadcasts would need communication.
        self._check_all_inputs_broadcasts_and_partial(aligned_layouts, aligned_shapes, output_shape)

        return self._create_output_layout(aligned_layouts[0], merged_tensor_map, merged_partial)
94 def _handle_no_input_shapes(self, valid_layouts):
95 """
96 Handle the case when input shapes are not available.
97 """
98 first_layout = valid_layouts[0]
99 for layout in valid_layouts[1:]:
100 if layout.tensor_map != first_layout.tensor_map:
101 raise ValueError(
102 f"For {self.op_name}, cannot infer layout without shapes: "
103 f"mismatched tensor_map {first_layout.tensor_map} vs {layout.tensor_map}."
104 )
105 return first_layout
107 def _align_layouts_and_shapes(self, layouts, input_shapes):
108 """
109 Align layouts with shapes by position, skipping None layouts.
110 """
111 aligned_layouts = []
112 aligned_shapes = []
113 for layout, shape in zip(layouts, input_shapes):
114 if layout is None:
115 continue
116 aligned_layouts.append(layout)
117 aligned_shapes.append(shape)
118 return aligned_layouts, aligned_shapes
120 def _compute_output_shape(self, aligned_shapes):
121 """
122 Compute broadcasted output shape from all input shapes.
123 """
124 output_shape = aligned_shapes[0]
125 for shape in aligned_shapes[1:]:
126 output_shape = self._broadcast_shapes(output_shape, shape)
127 return output_shape
129 def _merge_all_layouts(self, aligned_layouts, aligned_shapes, output_shape, layouts):
130 """
131 Merge all input layouts sequentially to get final tensor_map and partial status.
132 """
133 base_layout = aligned_layouts[0]
135 merged_tensor_map = self._merge_tensor_maps_for_broadcast(
136 aligned_layouts[0],
137 aligned_layouts[1],
138 aligned_shapes[0],
139 aligned_shapes[1],
140 output_shape
141 )
143 merged_partial = self._merge_partial_status(
144 base_layout.partial,
145 aligned_layouts[1].partial,
146 merged_tensor_map,
147 aligned_layouts[0].tensor_map if aligned_layouts[0].tensor_map else tuple(),
148 aligned_layouts[1].tensor_map if aligned_layouts[1].tensor_map else tuple(),
149 layouts
150 )
152 for i in range(2, len(aligned_layouts)):
153 temp_layout = self._create_output_layout(base_layout, merged_tensor_map, merged_partial)
154 merged_tensor_map = self._merge_tensor_maps_for_broadcast(
155 temp_layout,
156 aligned_layouts[i],
157 output_shape,
158 aligned_shapes[i],
159 output_shape
160 )
161 merged_partial = self._merge_partial_status(
162 merged_partial,
163 aligned_layouts[i].partial,
164 merged_tensor_map,
165 temp_layout.tensor_map if temp_layout.tensor_map else tuple(),
166 aligned_layouts[i].tensor_map if aligned_layouts[i].tensor_map else tuple(),
167 layouts
168 )
170 return merged_tensor_map, merged_partial
172 def _extract_input_shapes(self, extra_args):
173 """
174 Extract input_shapes from extra_args.
176 Compatible with:
177 - dict: {"input_shapes": [...]}
178 - list/tuple (WithShape dispatcher): extra_args = [..., input_shapes]
179 """
180 if isinstance(extra_args, dict):
181 return extra_args.get("input_shapes", None)
183 if isinstance(extra_args, (list, tuple)) and extra_args:
184 maybe_shapes = extra_args[-1]
185 if isinstance(maybe_shapes, (list, tuple)):
186 return maybe_shapes
188 return None
190 def _merge_partial_status(self, partial1, partial2, merged_tensor_map, tensor_map1, tensor_map2, layouts):
191 """
192 Merge partial status from two inputs.
194 Rules:
195 1. Both None → None
196 2. One None → Use the other
197 3. Both not None and same → Use it
198 4. Both not None and different → Error
199 5. Check Shard + Partial conflicts for each input
201 Args:
202 partial1: Partial status list from first input
203 partial2: Partial status list from second input
204 merged_tensor_map: Merged tensor map for output
205 tensor_map1: Tensor map of first input
206 tensor_map2: Tensor map of second input
208 Returns:
209 List: Merged partial status
211 Raises:
212 ValueError: If partial operations conflict or Shard+Partial conflict found
213 """
214 # Check Shard + Partial conflicts for input1
215 self._check_shard_partial_conflict(tensor_map1, partial1, layouts)
217 # Check Shard + Partial conflicts for input2
218 self._check_shard_partial_conflict(tensor_map2, partial2, layouts)
220 # Determine mesh dimension from partial lists
221 mesh_dim = max(len(partial1) if partial1 else 0, len(partial2) if partial2 else 0)
223 merged_partial = [None] * mesh_dim
225 for i in range(mesh_dim):
226 op1 = partial1[i] if partial1 and i < len(partial1) else None
227 op2 = partial2[i] if partial2 and i < len(partial2) else None
229 # Both have partial status with different operations
230 if op1 is not None and op2 is not None and op1 != op2:
231 raise ValueError(
232 f"For {self.op_name}, partial operations should be same for device axis {i}, "
233 f"but got {op1} and {op2}"
234 )
236 # Merge: prefer non-None, or either if both same
237 if op1 is not None:
238 merged_partial[i] = op1
239 elif op2 is not None:
240 merged_partial[i] = op2
242 # Check final output for Shard + Partial conflicts
243 self._check_shard_partial_conflict(merged_tensor_map, merged_partial, layouts)
245 return merged_partial
247 def _check_shard_partial_conflict(self, tensor_map, partial_list, layouts):
248 """
249 Check for conflicts between Shard and Partial on same device axis.
251 Args:
252 tensor_map: Tensor map to check
253 partial_list: Partial status list
255 Raises:
256 ValueError: If Shard and Partial conflict found
257 """
258 if not partial_list:
259 return
261 mesh_dim = len(partial_list)
263 # Collect all device axis used for sharding
264 sharded_axis = set()
265 if tensor_map:
266 for map_val in tensor_map:
267 if isinstance(map_val, tuple):
268 for sub_val in map_val:
269 if sub_val != -1:
270 # Convert to device axis index
271 axis_idx = mesh_dim - 1 - sub_val
272 sharded_axis.add(axis_idx)
273 elif map_val != -1:
274 axis_idx = mesh_dim - 1 - map_val
275 sharded_axis.add(axis_idx)
277 # Check if any sharded axis has partial status
278 for axis_idx in sharded_axis:
279 if 0 <= axis_idx < len(partial_list) and partial_list[axis_idx] is not None:
280 raise ValueError(
281 f"For {self.op_name}, Shard and Partial should not coexist on same device axis "
282 f"{axis_idx}, but got Partial({partial_list[axis_idx]}). "
283 f"Please check layouts: {layouts}."
284 )
286 def _check_all_inputs_broadcasts_and_partial(self, layouts, input_shapes, output_shape):
287 """
288 Check if any input broadcasts and has Partial status.
289 """
290 for i, (layout, input_shape) in enumerate(zip(layouts, input_shapes)):
291 if layout is None:
292 continue
294 input_name = f"input{i+1}"
296 input_len = len(input_shape)
297 output_len = len(output_shape)
299 if input_len < output_len:
300 aligned_input_shape = (1,) * (output_len - input_len) + tuple(input_shape)
301 else:
302 aligned_input_shape = input_shape
304 broadcasts = False
305 for in_dim, out_dim in zip(aligned_input_shape, output_shape):
306 if in_dim == 1 and out_dim > 1:
307 broadcasts = True
308 break
310 if broadcasts and layout.is_partial():
311 raise ValueError(
312 f"For {self.op_name}, {input_name} has Partial status and broadcasts. "
313 f"Should be without Partial status for broadcasting without communication"
314 )
316 def _merge_tensor_maps_without_shape(self, layout1, layout2):
317 """
318 Merge tensor_maps without shape information (for broadcasting scenarios).
320 Merging rules without shape:
321 - If both dimensions are not sharded: use -1
322 - If one is sharded and one is not: use the sharded one (assume broadcasting)
323 - If both are sharded: they must be identical, otherwise raise error
325 Args:
326 layout1: Layout of the first input
327 layout2: Layout of the second input
329 Returns:
330 tuple: Merged tensor_map
332 Raises:
333 ValueError: If sharding strategies conflict
334 """
335 map1 = layout1.tensor_map if layout1.tensor_map else tuple()
336 map2 = layout2.tensor_map if layout2.tensor_map else tuple()
338 # Align ranks by padding with -1
339 max_len = max(len(map1), len(map2))
340 padded_map1 = (-1,) * (max_len - len(map1)) + map1
341 padded_map2 = (-1,) * (max_len - len(map2)) + map2
343 merged_map = []
344 for i, (m1, m2) in enumerate(zip(padded_map1, padded_map2)):
345 m1_axis = self._normalize_tensor_map_element(m1)
346 m2_axis = self._normalize_tensor_map_element(m2)
348 m1_axis_for_compare = frozenset(m1_axis)
349 m2_axis_for_compare = frozenset(m2_axis)
351 m1_is_sharded = bool(m1_axis)
352 m2_is_sharded = bool(m2_axis)
354 if not m1_is_sharded and not m2_is_sharded:
355 merged_map.append(-1)
356 elif not m1_is_sharded:
357 merged_map.append(self._denormalize_tensor_map_element(m2_axis))
358 elif not m2_is_sharded:
359 merged_map.append(self._denormalize_tensor_map_element(m1_axis))
360 else:
361 if m1_axis_for_compare != m2_axis_for_compare:
362 raise ValueError(
363 f"For {self.op_name}, inputs should have same sharding pattern, "
364 f"but got confilcting sharding at dimension {i}, "
365 f"input1 shaded on {m1_axis} and input2 shaded on {m2_axis}."
366 )
367 merged_map.append(self._denormalize_tensor_map_element(m1_axis))
369 return tuple(merged_map)
371 def _broadcast_shapes(self, shape1, shape2):
372 """
373 Calculate the broadcasted shape of two shapes according to broadcasting rules.
375 Broadcasting rules:
376 1. If two arrays have different numbers of dimensions, pad the shape of the
377 lower-dimensional array with 1s on the left until both shapes have the same length.
378 2. If two arrays have the same number of dimensions but different lengths in some
379 dimensions, dimensions with length 1 will be expanded to match the other array's
380 dimension length.
381 3. If two arrays have the same number of dimensions but any dimension has different
382 lengths and neither is 1, raise an error.
384 Args:
385 shape1 (tuple): Shape of the first tensor, e.g., (3, 1, 5)
386 shape2 (tuple): Shape of the second tensor, e.g., (4, 5)
388 Returns:
389 tuple: Broadcasted shape, e.g., (3, 4, 5)
391 Raises:
392 ValueError: If shapes cannot be broadcast together.
393 """
394 # Rule 1: Right-align, pad with 1s on the left to make dimensions equal
395 len1, len2 = len(shape1), len(shape2)
396 max_len = max(len1, len2)
398 padded_shape1 = (1,) * (max_len - len1) + tuple(shape1)
399 padded_shape2 = (1,) * (max_len - len2) + tuple(shape2)
401 # Rules 2 and 3: Check if each dimension can be broadcast
402 result_shape = []
403 for dim1, dim2 in zip(padded_shape1, padded_shape2):
404 if dim1 == dim2:
405 # Dimensions are the same, use directly
406 result_shape.append(dim1)
407 elif dim1 == 1:
408 # First shape has 1 in this dimension, expand to dim2
409 result_shape.append(dim2)
410 elif dim2 == 1:
411 # Second shape has 1 in this dimension, expand to dim1
412 result_shape.append(dim1)
413 else:
414 # Rule 3: Dimensions are different and neither is 1, cannot broadcast
415 raise ValueError(
416 f"For {self.op_name}, shapes {shape1} and {shape2} cannot be broadcast together. "
417 f"Dimension mismatch: {dim1} vs {dim2}"
418 )
420 return tuple(result_shape)
422 def _align_tensor_maps_for_broadcast(self, layout1, layout2, shape1, shape2):
423 """
424 Align tensor_maps of two layouts to support broadcasting.
426 When two tensors have different dimensions, the tensor_map of the
427 lower-dimensional tensor is padded with -1 (indicating no sharding) at the front.
429 Args:
430 layout1: Layout of the first tensor
431 layout2: Layout of the second tensor
432 shape1 (tuple): Global shape of the first tensor
433 shape2 (tuple): Global shape of the second tensor
435 Returns:
436 tuple: (aligned_map1, aligned_map2) - Aligned tensor_maps
437 """
438 len1, len2 = len(shape1), len(shape2)
439 max_len = max(len1, len2)
441 map1 = layout1.tensor_map if layout1.tensor_map else tuple([-1] * len1)
442 map2 = layout2.tensor_map if layout2.tensor_map else tuple([-1] * len2)
444 aligned_map1 = (-1,) * (max_len - len1) + map1
445 aligned_map2 = (-1,) * (max_len - len2) + map2
447 return aligned_map1, aligned_map2
449 def _normalize_tensor_map_element(self, map_element):
450 """
451 Normalize a tensor_map element to a tuple of device axis for unified processing.
453 Args:
454 map_element: Element from tensor_map, can be:
455 - int: -1 (no sharding) or device axis index
456 - tuple: multiple device axis
458 Returns:
459 tuple: Tuple of device axis (empty tuple if not sharded)
460 """
461 if map_element == -1:
462 return ()
463 if isinstance(map_element, int):
464 return (map_element,)
465 if isinstance(map_element, tuple):
466 return tuple(dim for dim in map_element if dim != -1)
467 return ()
469 def _denormalize_tensor_map_element(self, device_axis_tuple):
470 """
471 Convert a tuple of device axis back to tensor_map element format.
473 Args:
474 device_axis_tuple (tuple): Tuple of device axis
476 Returns:
477 int or tuple: -1 if empty, single int if one element, tuple if multiple elements
478 """
479 if not device_axis_tuple:
480 return -1
481 if len(device_axis_tuple) == 1:
482 return device_axis_tuple[0]
483 return device_axis_tuple
485 def _merge_tensor_maps_for_broadcast(self, layout1, layout2, shape1, shape2, output_shape):
486 """
487 Merge tensor_maps of two inputs to generate output tensor_map.
489 This method handles both simple int-type and complex tuple-type tensor_map elements,
490 ensuring correct sharding strategy for the broadcasted output.
492 Args:
493 layout1: Layout of the first input
494 layout2: Layout of the second input
495 shape1 (tuple): Global shape of the first input
496 shape2 (tuple): Global shape of the second input
497 output_shape (tuple): Global shape of the output
499 Returns:
500 tuple: Merged tensor_map for the output
502 Raises:
503 ValueError: If sharding strategies conflict or broadcasting dimension is sharded
504 """
505 map1, map2 = self._align_tensor_maps_for_broadcast(layout1, layout2, shape1, shape2)
507 len1, len2 = len(shape1), len(shape2)
508 max_len = len(output_shape)
509 padded_shape1 = (1,) * (max_len - len1) + tuple(shape1)
510 padded_shape2 = (1,) * (max_len - len2) + tuple(shape2)
512 merged_map = []
513 for i, (dim1, dim2, out_dim) in enumerate(zip(padded_shape1, padded_shape2, output_shape)):
514 m1, m2 = map1[i], map2[i]
516 m1_axis = self._normalize_tensor_map_element(m1)
517 m2_axis = self._normalize_tensor_map_element(m2)
519 m1_axis_for_compare = frozenset(m1_axis)
520 m2_axis_for_compare = frozenset(m2_axis)
522 m1_is_sharded = bool(m1_axis)
523 m2_is_sharded = bool(m2_axis)
525 if not m1_is_sharded and not m2_is_sharded:
526 merged_map.append(-1)
528 elif not m1_is_sharded:
529 if dim2 == 1 and out_dim > 1:
530 raise ValueError(
531 f"For {self.op_name}, dimension {i} of second input has size 1 "
532 f"but is sharded on device axis {m2_axis}. "
533 f"Broadcasting dimension cannot be sharded."
534 )
535 merged_map.append(self._denormalize_tensor_map_element(m2_axis))
537 elif not m2_is_sharded:
538 if dim1 == 1 and out_dim > 1:
539 raise ValueError(
540 f"For {self.op_name}, dimension {i} of first input has size 1 "
541 f"but is sharded on device axis {m1_axis}. "
542 f"Broadcasting dimension cannot be sharded."
543 )
544 merged_map.append(self._denormalize_tensor_map_element(m1_axis))
546 else:
547 if m1_axis_for_compare != m2_axis_for_compare:
548 raise ValueError(
549 f"For {self.op_name}, inputs should have same sharding pattern, "
550 f"but got confilcting sharding at dimension {i}, "
551 f"input1 shaded on {m1_axis} and input2 shaded on {m2_axis}."
552 )
554 if (dim1 == 1 or dim2 == 1) and dim1 != dim2:
555 raise ValueError(
556 f"For {self.op_name}, dimension {i} is broadcast from size 1 "
557 f"to {out_dim} but is sharded on device axis {m1_axis}. "
558 f"Broadcasting dimension cannot be sharded."
559 )
561 merged_map.append(self._denormalize_tensor_map_element(m1_axis))
563 return tuple(merged_map)
565 def _create_output_layout(self, base_layout, output_tensor_map, partial_list=None):
566 """
567 Create output layout based on input layout.
569 Args:
570 base_layout: Base layout (usually from the first input)
571 output_tensor_map (tuple): Tensor_map for the output
572 partial_list (list): Partial status list for the output
574 Returns:
575 Layout: New Layout object with updated tensor_map and alias_tensor_map
576 """
577 new_layout = copy.deepcopy(base_layout)
578 new_layout.set_tensor_map(output_tensor_map)
580 alias_tensor_map = []
581 for tensor_dim in output_tensor_map:
582 if tensor_dim == -1:
583 alias_tensor_map.append("None")
584 elif isinstance(tensor_dim, tuple):
585 alias_tuple = tuple(
586 base_layout.alias_name[len(base_layout.alias_name) - 1 - dim]
587 for dim in tensor_dim
588 if dim != -1
589 )
590 alias_tensor_map.append(alias_tuple if alias_tuple else "None")
591 else:
592 alias_tensor_map.append(
593 base_layout.alias_name[len(base_layout.alias_name) - 1 - tensor_dim]
594 )
596 new_layout.set_alias_tensor_map(tuple(alias_tensor_map))
598 # Set partial status if provided
599 if partial_list:
600 for i, partial_op in enumerate(partial_list):
601 if partial_op is not None and i < len(new_layout.alias_name):
602 new_layout.set_partial_by_dev_axis(new_layout.alias_name[i], partial_op)
604 return new_layout
class ElementWiseWithPartialDistributedOp(ElementWiseDistributedOp):
    """
    Base class for element-wise operators whose partial status is allowed
    to flow from inputs to the output.
    """

    def __init__(self, op_name):
        """Register the operator and opt in to partial-status inputs."""
        super().__init__(op_name)
        # Disables the partial-input rejection in the base infer_layout.
        self._allow_partial_inputs = True
class AddDistributedOp(ElementWiseWithPartialDistributedOp):
    """
    Distributed implementation for Add operator.

    Partial status may propagate from inputs to the output, which matters
    for patterns such as gradient accumulation where pending reductions
    must survive the computation graph.
    """

    def get_expand_impl(self, func, output_layout, layouts, extra_args):
        """
        Return a rescaling wrapper for Add when exactly one input is Partial.

        The non-partial operand is divided by the product of the mesh sizes
        of every 'sum'-partial device axis in the output layout, so the
        eventual reduction over the output produces the correct sum.
        Returns None when both inputs share the same partial status.
        """
        x1_layout, x2_layout = layouts[0], layouts[1]
        x1_partial = x1_layout.is_partial() if x1_layout is not None else None
        x2_partial = x2_layout.is_partial() if x2_layout is not None else None

        if x1_partial == x2_partial:
            # Identical placement on both sides: no expansion is needed.
            return None

        scaling_factor = 1
        for i, partial_type in enumerate(output_layout.partial):
            if partial_type == "sum":
                scaling_factor *= output_layout.mesh_shape[i]
            elif partial_type is not None:
                raise ValueError(
                    f"For {self.op_name}, inputs partial status should be 'sum' or None, "
                    f"but got {partial_type} at index {i}."
                )

        if x1_partial:
            # x1 carries partial results: scale the other operand.
            def expand_impl(x1, x2):
                return func(x1, x2 / scaling_factor)
        else:
            # x2 carries partial results: scale the other operand.
            def expand_impl(x1, x2):
                return func(x1 / scaling_factor, x2)

        return expand_impl