Coverage for hyper_parallel / core / shard / ops / parallel_elementwise.py: 74%

245 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Element-wise operator. 

17""" 

18 

19import copy 

20from .parallel_ops import DistributedOp 

21 

22 

23class ElementWiseDistributedOp(DistributedOp): 

24 """ 

25 Base class for distributed element-wise operators. 

26 

27 Supports broadcasting following broadcasting rules and handles 

28 distributed tensor layouts with proper sharding strategy inference. 

29 

30 Args: 

31 op_name (str): Name of the operator to register. 

32 """ 

33 

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for element-wise operations with broadcasting support.

        For element-wise operations:
        - Supports broadcasting following NumPy broadcasting rules
        - All inputs must have compatible shapes for broadcasting
        - Output will have the broadcasted shape and appropriate sharding strategy
        - Handles both simple and complex sharding patterns (including tuple-type tensor_maps)

        Args:
            layouts (tuple): Tuple of layouts for input tensors (entries may be None).
            extra_args: Extra arguments for the operation. It can be:
                - dict containing 'input_shapes'
                - list/tuple where the last element is input_shapes (WithShape path)

        Returns:
            Layout: Layout for output tensor with merged sharding strategy,
                or None when no layout information is available.

        Raises:
            ValueError: If input layouts are not compatible for broadcasting.
        """
        # No layout information at all: nothing to infer.
        if not layouts:
            return None

        valid_layouts = [layout for layout in layouts if layout is not None]

        if not valid_layouts:
            return None

        # Check partial inputs - ElementWiseDistributedOp does not support partial by default.
        # Subclasses opt in by setting self._allow_partial_inputs (see
        # ElementWiseWithPartialDistributedOp); the check runs after basic validation.
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # A single distributed input fully determines the output layout.
        if len(valid_layouts) == 1:
            return valid_layouts[0]

        input_shapes = self._extract_input_shapes(extra_args)

        # Without shapes, broadcasting cannot be resolved; fall back to
        # requiring identical tensor_maps across inputs.
        if not input_shapes:
            return self._handle_no_input_shapes(valid_layouts)

        aligned_layouts, aligned_shapes = self._align_layouts_and_shapes(layouts, input_shapes)

        # Fall back to the first layout when pairing failed (e.g. fewer
        # shapes than layouts, or all but one layout was None).
        if len(aligned_layouts) <= 1 or len(aligned_layouts) != len(aligned_shapes):
            return valid_layouts[0]

        output_shape = self._compute_output_shape(aligned_shapes)
        merged_tensor_map, merged_partial = self._merge_all_layouts(
            aligned_layouts,
            aligned_shapes,
            output_shape,
            layouts
        )

        # A broadcast input carrying Partial status is rejected (would need
        # communication to stay correct).
        self._check_all_inputs_broadcasts_and_partial(aligned_layouts, aligned_shapes, output_shape)

        return self._create_output_layout(aligned_layouts[0], merged_tensor_map, merged_partial)

93 

94 def _handle_no_input_shapes(self, valid_layouts): 

95 """ 

96 Handle the case when input shapes are not available. 

97 """ 

98 first_layout = valid_layouts[0] 

99 for layout in valid_layouts[1:]: 

100 if layout.tensor_map != first_layout.tensor_map: 

101 raise ValueError( 

102 f"For {self.op_name}, cannot infer layout without shapes: " 

103 f"mismatched tensor_map {first_layout.tensor_map} vs {layout.tensor_map}." 

104 ) 

105 return first_layout 

106 

107 def _align_layouts_and_shapes(self, layouts, input_shapes): 

108 """ 

109 Align layouts with shapes by position, skipping None layouts. 

110 """ 

111 aligned_layouts = [] 

112 aligned_shapes = [] 

113 for layout, shape in zip(layouts, input_shapes): 

114 if layout is None: 

115 continue 

116 aligned_layouts.append(layout) 

117 aligned_shapes.append(shape) 

118 return aligned_layouts, aligned_shapes 

119 

120 def _compute_output_shape(self, aligned_shapes): 

121 """ 

122 Compute broadcasted output shape from all input shapes. 

123 """ 

124 output_shape = aligned_shapes[0] 

125 for shape in aligned_shapes[1:]: 

126 output_shape = self._broadcast_shapes(output_shape, shape) 

127 return output_shape 

128 

    def _merge_all_layouts(self, aligned_layouts, aligned_shapes, output_shape, layouts):
        """
        Merge all input layouts sequentially to get final tensor_map and partial status.

        The first two inputs are merged directly; each remaining input is then
        merged against a temporary layout built from the running result, whose
        shape is already the full broadcasted output_shape.

        Args:
            aligned_layouts (list): Non-None input layouts (at least two).
            aligned_shapes (list): Global shapes paired with aligned_layouts.
            output_shape (tuple): Broadcasted output shape.
            layouts (tuple): Original layouts, passed through for error messages.

        Returns:
            tuple: (merged_tensor_map, merged_partial) for the output.
        """
        base_layout = aligned_layouts[0]

        # Seed the merge with the first pair of inputs.
        merged_tensor_map = self._merge_tensor_maps_for_broadcast(
            aligned_layouts[0],
            aligned_layouts[1],
            aligned_shapes[0],
            aligned_shapes[1],
            output_shape
        )

        merged_partial = self._merge_partial_status(
            base_layout.partial,
            aligned_layouts[1].partial,
            merged_tensor_map,
            aligned_layouts[0].tensor_map if aligned_layouts[0].tensor_map else tuple(),
            aligned_layouts[1].tensor_map if aligned_layouts[1].tensor_map else tuple(),
            layouts
        )

        # Fold in any remaining inputs one at a time.
        for i in range(2, len(aligned_layouts)):
            # Materialize the running result as a layout so it can be merged
            # like a regular input; its shape is output_shape on both sides.
            temp_layout = self._create_output_layout(base_layout, merged_tensor_map, merged_partial)
            merged_tensor_map = self._merge_tensor_maps_for_broadcast(
                temp_layout,
                aligned_layouts[i],
                output_shape,
                aligned_shapes[i],
                output_shape
            )
            merged_partial = self._merge_partial_status(
                merged_partial,
                aligned_layouts[i].partial,
                merged_tensor_map,
                temp_layout.tensor_map if temp_layout.tensor_map else tuple(),
                aligned_layouts[i].tensor_map if aligned_layouts[i].tensor_map else tuple(),
                layouts
            )

        return merged_tensor_map, merged_partial

171 

172 def _extract_input_shapes(self, extra_args): 

173 """ 

174 Extract input_shapes from extra_args. 

175 

176 Compatible with: 

177 - dict: {"input_shapes": [...]} 

178 - list/tuple (WithShape dispatcher): extra_args = [..., input_shapes] 

179 """ 

180 if isinstance(extra_args, dict): 

181 return extra_args.get("input_shapes", None) 

182 

183 if isinstance(extra_args, (list, tuple)) and extra_args: 

184 maybe_shapes = extra_args[-1] 

185 if isinstance(maybe_shapes, (list, tuple)): 

186 return maybe_shapes 

187 

188 return None 

189 

190 def _merge_partial_status(self, partial1, partial2, merged_tensor_map, tensor_map1, tensor_map2, layouts): 

191 """ 

192 Merge partial status from two inputs. 

193 

194 Rules: 

195 1. Both None → None 

196 2. One None → Use the other 

197 3. Both not None and same → Use it 

198 4. Both not None and different → Error 

199 5. Check Shard + Partial conflicts for each input 

200 

201 Args: 

202 partial1: Partial status list from first input 

203 partial2: Partial status list from second input 

204 merged_tensor_map: Merged tensor map for output 

205 tensor_map1: Tensor map of first input 

206 tensor_map2: Tensor map of second input 

207 

208 Returns: 

209 List: Merged partial status 

210 

211 Raises: 

212 ValueError: If partial operations conflict or Shard+Partial conflict found 

213 """ 

214 # Check Shard + Partial conflicts for input1 

215 self._check_shard_partial_conflict(tensor_map1, partial1, layouts) 

216 

217 # Check Shard + Partial conflicts for input2 

218 self._check_shard_partial_conflict(tensor_map2, partial2, layouts) 

219 

220 # Determine mesh dimension from partial lists 

221 mesh_dim = max(len(partial1) if partial1 else 0, len(partial2) if partial2 else 0) 

222 

223 merged_partial = [None] * mesh_dim 

224 

225 for i in range(mesh_dim): 

226 op1 = partial1[i] if partial1 and i < len(partial1) else None 

227 op2 = partial2[i] if partial2 and i < len(partial2) else None 

228 

229 # Both have partial status with different operations 

230 if op1 is not None and op2 is not None and op1 != op2: 

231 raise ValueError( 

232 f"For {self.op_name}, partial operations should be same for device axis {i}, " 

233 f"but got {op1} and {op2}" 

234 ) 

235 

236 # Merge: prefer non-None, or either if both same 

237 if op1 is not None: 

238 merged_partial[i] = op1 

239 elif op2 is not None: 

240 merged_partial[i] = op2 

241 

242 # Check final output for Shard + Partial conflicts 

243 self._check_shard_partial_conflict(merged_tensor_map, merged_partial, layouts) 

244 

245 return merged_partial 

246 

247 def _check_shard_partial_conflict(self, tensor_map, partial_list, layouts): 

248 """ 

249 Check for conflicts between Shard and Partial on same device axis. 

250 

251 Args: 

252 tensor_map: Tensor map to check 

253 partial_list: Partial status list 

254 

255 Raises: 

256 ValueError: If Shard and Partial conflict found 

257 """ 

258 if not partial_list: 

259 return 

260 

261 mesh_dim = len(partial_list) 

262 

263 # Collect all device axis used for sharding 

264 sharded_axis = set() 

265 if tensor_map: 

266 for map_val in tensor_map: 

267 if isinstance(map_val, tuple): 

268 for sub_val in map_val: 

269 if sub_val != -1: 

270 # Convert to device axis index 

271 axis_idx = mesh_dim - 1 - sub_val 

272 sharded_axis.add(axis_idx) 

273 elif map_val != -1: 

274 axis_idx = mesh_dim - 1 - map_val 

275 sharded_axis.add(axis_idx) 

276 

277 # Check if any sharded axis has partial status 

278 for axis_idx in sharded_axis: 

279 if 0 <= axis_idx < len(partial_list) and partial_list[axis_idx] is not None: 

280 raise ValueError( 

281 f"For {self.op_name}, Shard and Partial should not coexist on same device axis " 

282 f"{axis_idx}, but got Partial({partial_list[axis_idx]}). " 

283 f"Please check layouts: {layouts}." 

284 ) 

285 

286 def _check_all_inputs_broadcasts_and_partial(self, layouts, input_shapes, output_shape): 

287 """ 

288 Check if any input broadcasts and has Partial status. 

289 """ 

290 for i, (layout, input_shape) in enumerate(zip(layouts, input_shapes)): 

291 if layout is None: 

292 continue 

293 

294 input_name = f"input{i+1}" 

295 

296 input_len = len(input_shape) 

297 output_len = len(output_shape) 

298 

299 if input_len < output_len: 

300 aligned_input_shape = (1,) * (output_len - input_len) + tuple(input_shape) 

301 else: 

302 aligned_input_shape = input_shape 

303 

304 broadcasts = False 

305 for in_dim, out_dim in zip(aligned_input_shape, output_shape): 

306 if in_dim == 1 and out_dim > 1: 

307 broadcasts = True 

308 break 

309 

310 if broadcasts and layout.is_partial(): 

311 raise ValueError( 

312 f"For {self.op_name}, {input_name} has Partial status and broadcasts. " 

313 f"Should be without Partial status for broadcasting without communication" 

314 ) 

315 

316 def _merge_tensor_maps_without_shape(self, layout1, layout2): 

317 """ 

318 Merge tensor_maps without shape information (for broadcasting scenarios). 

319 

320 Merging rules without shape: 

321 - If both dimensions are not sharded: use -1 

322 - If one is sharded and one is not: use the sharded one (assume broadcasting) 

323 - If both are sharded: they must be identical, otherwise raise error 

324 

325 Args: 

326 layout1: Layout of the first input 

327 layout2: Layout of the second input 

328 

329 Returns: 

330 tuple: Merged tensor_map 

331 

332 Raises: 

333 ValueError: If sharding strategies conflict 

334 """ 

335 map1 = layout1.tensor_map if layout1.tensor_map else tuple() 

336 map2 = layout2.tensor_map if layout2.tensor_map else tuple() 

337 

338 # Align ranks by padding with -1 

339 max_len = max(len(map1), len(map2)) 

340 padded_map1 = (-1,) * (max_len - len(map1)) + map1 

341 padded_map2 = (-1,) * (max_len - len(map2)) + map2 

342 

343 merged_map = [] 

344 for i, (m1, m2) in enumerate(zip(padded_map1, padded_map2)): 

345 m1_axis = self._normalize_tensor_map_element(m1) 

346 m2_axis = self._normalize_tensor_map_element(m2) 

347 

348 m1_axis_for_compare = frozenset(m1_axis) 

349 m2_axis_for_compare = frozenset(m2_axis) 

350 

351 m1_is_sharded = bool(m1_axis) 

352 m2_is_sharded = bool(m2_axis) 

353 

354 if not m1_is_sharded and not m2_is_sharded: 

355 merged_map.append(-1) 

356 elif not m1_is_sharded: 

357 merged_map.append(self._denormalize_tensor_map_element(m2_axis)) 

358 elif not m2_is_sharded: 

359 merged_map.append(self._denormalize_tensor_map_element(m1_axis)) 

360 else: 

361 if m1_axis_for_compare != m2_axis_for_compare: 

362 raise ValueError( 

363 f"For {self.op_name}, inputs should have same sharding pattern, " 

364 f"but got confilcting sharding at dimension {i}, " 

365 f"input1 shaded on {m1_axis} and input2 shaded on {m2_axis}." 

366 ) 

367 merged_map.append(self._denormalize_tensor_map_element(m1_axis)) 

368 

369 return tuple(merged_map) 

370 

371 def _broadcast_shapes(self, shape1, shape2): 

372 """ 

373 Calculate the broadcasted shape of two shapes according to broadcasting rules. 

374 

375 Broadcasting rules: 

376 1. If two arrays have different numbers of dimensions, pad the shape of the 

377 lower-dimensional array with 1s on the left until both shapes have the same length. 

378 2. If two arrays have the same number of dimensions but different lengths in some 

379 dimensions, dimensions with length 1 will be expanded to match the other array's 

380 dimension length. 

381 3. If two arrays have the same number of dimensions but any dimension has different 

382 lengths and neither is 1, raise an error. 

383 

384 Args: 

385 shape1 (tuple): Shape of the first tensor, e.g., (3, 1, 5) 

386 shape2 (tuple): Shape of the second tensor, e.g., (4, 5) 

387 

388 Returns: 

389 tuple: Broadcasted shape, e.g., (3, 4, 5) 

390 

391 Raises: 

392 ValueError: If shapes cannot be broadcast together. 

393 """ 

394 # Rule 1: Right-align, pad with 1s on the left to make dimensions equal 

395 len1, len2 = len(shape1), len(shape2) 

396 max_len = max(len1, len2) 

397 

398 padded_shape1 = (1,) * (max_len - len1) + tuple(shape1) 

399 padded_shape2 = (1,) * (max_len - len2) + tuple(shape2) 

400 

401 # Rules 2 and 3: Check if each dimension can be broadcast 

402 result_shape = [] 

403 for dim1, dim2 in zip(padded_shape1, padded_shape2): 

404 if dim1 == dim2: 

405 # Dimensions are the same, use directly 

406 result_shape.append(dim1) 

407 elif dim1 == 1: 

408 # First shape has 1 in this dimension, expand to dim2 

409 result_shape.append(dim2) 

410 elif dim2 == 1: 

411 # Second shape has 1 in this dimension, expand to dim1 

412 result_shape.append(dim1) 

413 else: 

414 # Rule 3: Dimensions are different and neither is 1, cannot broadcast 

415 raise ValueError( 

416 f"For {self.op_name}, shapes {shape1} and {shape2} cannot be broadcast together. " 

417 f"Dimension mismatch: {dim1} vs {dim2}" 

418 ) 

419 

420 return tuple(result_shape) 

421 

422 def _align_tensor_maps_for_broadcast(self, layout1, layout2, shape1, shape2): 

423 """ 

424 Align tensor_maps of two layouts to support broadcasting. 

425 

426 When two tensors have different dimensions, the tensor_map of the 

427 lower-dimensional tensor is padded with -1 (indicating no sharding) at the front. 

428 

429 Args: 

430 layout1: Layout of the first tensor 

431 layout2: Layout of the second tensor 

432 shape1 (tuple): Global shape of the first tensor 

433 shape2 (tuple): Global shape of the second tensor 

434 

435 Returns: 

436 tuple: (aligned_map1, aligned_map2) - Aligned tensor_maps 

437 """ 

438 len1, len2 = len(shape1), len(shape2) 

439 max_len = max(len1, len2) 

440 

441 map1 = layout1.tensor_map if layout1.tensor_map else tuple([-1] * len1) 

442 map2 = layout2.tensor_map if layout2.tensor_map else tuple([-1] * len2) 

443 

444 aligned_map1 = (-1,) * (max_len - len1) + map1 

445 aligned_map2 = (-1,) * (max_len - len2) + map2 

446 

447 return aligned_map1, aligned_map2 

448 

449 def _normalize_tensor_map_element(self, map_element): 

450 """ 

451 Normalize a tensor_map element to a tuple of device axis for unified processing. 

452 

453 Args: 

454 map_element: Element from tensor_map, can be: 

455 - int: -1 (no sharding) or device axis index 

456 - tuple: multiple device axis 

457 

458 Returns: 

459 tuple: Tuple of device axis (empty tuple if not sharded) 

460 """ 

461 if map_element == -1: 

462 return () 

463 if isinstance(map_element, int): 

464 return (map_element,) 

465 if isinstance(map_element, tuple): 

466 return tuple(dim for dim in map_element if dim != -1) 

467 return () 

468 

469 def _denormalize_tensor_map_element(self, device_axis_tuple): 

470 """ 

471 Convert a tuple of device axis back to tensor_map element format. 

472 

473 Args: 

474 device_axis_tuple (tuple): Tuple of device axis 

475 

476 Returns: 

477 int or tuple: -1 if empty, single int if one element, tuple if multiple elements 

478 """ 

479 if not device_axis_tuple: 

480 return -1 

481 if len(device_axis_tuple) == 1: 

482 return device_axis_tuple[0] 

483 return device_axis_tuple 

484 

485 def _merge_tensor_maps_for_broadcast(self, layout1, layout2, shape1, shape2, output_shape): 

486 """ 

487 Merge tensor_maps of two inputs to generate output tensor_map. 

488 

489 This method handles both simple int-type and complex tuple-type tensor_map elements, 

490 ensuring correct sharding strategy for the broadcasted output. 

491 

492 Args: 

493 layout1: Layout of the first input 

494 layout2: Layout of the second input 

495 shape1 (tuple): Global shape of the first input 

496 shape2 (tuple): Global shape of the second input 

497 output_shape (tuple): Global shape of the output 

498 

499 Returns: 

500 tuple: Merged tensor_map for the output 

501 

502 Raises: 

503 ValueError: If sharding strategies conflict or broadcasting dimension is sharded 

504 """ 

505 map1, map2 = self._align_tensor_maps_for_broadcast(layout1, layout2, shape1, shape2) 

506 

507 len1, len2 = len(shape1), len(shape2) 

508 max_len = len(output_shape) 

509 padded_shape1 = (1,) * (max_len - len1) + tuple(shape1) 

510 padded_shape2 = (1,) * (max_len - len2) + tuple(shape2) 

511 

512 merged_map = [] 

513 for i, (dim1, dim2, out_dim) in enumerate(zip(padded_shape1, padded_shape2, output_shape)): 

514 m1, m2 = map1[i], map2[i] 

515 

516 m1_axis = self._normalize_tensor_map_element(m1) 

517 m2_axis = self._normalize_tensor_map_element(m2) 

518 

519 m1_axis_for_compare = frozenset(m1_axis) 

520 m2_axis_for_compare = frozenset(m2_axis) 

521 

522 m1_is_sharded = bool(m1_axis) 

523 m2_is_sharded = bool(m2_axis) 

524 

525 if not m1_is_sharded and not m2_is_sharded: 

526 merged_map.append(-1) 

527 

528 elif not m1_is_sharded: 

529 if dim2 == 1 and out_dim > 1: 

530 raise ValueError( 

531 f"For {self.op_name}, dimension {i} of second input has size 1 " 

532 f"but is sharded on device axis {m2_axis}. " 

533 f"Broadcasting dimension cannot be sharded." 

534 ) 

535 merged_map.append(self._denormalize_tensor_map_element(m2_axis)) 

536 

537 elif not m2_is_sharded: 

538 if dim1 == 1 and out_dim > 1: 

539 raise ValueError( 

540 f"For {self.op_name}, dimension {i} of first input has size 1 " 

541 f"but is sharded on device axis {m1_axis}. " 

542 f"Broadcasting dimension cannot be sharded." 

543 ) 

544 merged_map.append(self._denormalize_tensor_map_element(m1_axis)) 

545 

546 else: 

547 if m1_axis_for_compare != m2_axis_for_compare: 

548 raise ValueError( 

549 f"For {self.op_name}, inputs should have same sharding pattern, " 

550 f"but got confilcting sharding at dimension {i}, " 

551 f"input1 shaded on {m1_axis} and input2 shaded on {m2_axis}." 

552 ) 

553 

554 if (dim1 == 1 or dim2 == 1) and dim1 != dim2: 

555 raise ValueError( 

556 f"For {self.op_name}, dimension {i} is broadcast from size 1 " 

557 f"to {out_dim} but is sharded on device axis {m1_axis}. " 

558 f"Broadcasting dimension cannot be sharded." 

559 ) 

560 

561 merged_map.append(self._denormalize_tensor_map_element(m1_axis)) 

562 

563 return tuple(merged_map) 

564 

    def _create_output_layout(self, base_layout, output_tensor_map, partial_list=None):
        """
        Create output layout based on input layout.

        Args:
            base_layout: Base layout (usually from the first input); it is
                deep-copied, so the caller's layout is never mutated.
            output_tensor_map (tuple): Tensor_map for the output.
            partial_list (list): Partial status list for the output, one entry
                per device axis; None entries mean "not partial".

        Returns:
            Layout: New Layout object with updated tensor_map and alias_tensor_map
        """
        new_layout = copy.deepcopy(base_layout)
        new_layout.set_tensor_map(output_tensor_map)

        # Rebuild the alias map to mirror the new tensor_map. tensor_map values
        # are translated to alias names via reverse indexing (len-1-dim);
        # presumably alias_name and tensor_map count device axes from opposite
        # ends — confirm against the Layout class.
        alias_tensor_map = []
        for tensor_dim in output_tensor_map:
            if tensor_dim == -1:
                # Unsharded dims are represented by the literal string "None".
                alias_tensor_map.append("None")
            elif isinstance(tensor_dim, tuple):
                alias_tuple = tuple(
                    base_layout.alias_name[len(base_layout.alias_name) - 1 - dim]
                    for dim in tensor_dim
                    if dim != -1
                )
                # An all--1 tuple collapses to the unsharded marker.
                alias_tensor_map.append(alias_tuple if alias_tuple else "None")
            else:
                alias_tensor_map.append(
                    base_layout.alias_name[len(base_layout.alias_name) - 1 - tensor_dim]
                )

        new_layout.set_alias_tensor_map(tuple(alias_tensor_map))

        # Set partial status if provided.
        # NOTE(review): partial_list is indexed by alias_name position directly
        # here (no reverse indexing as above) — assumed intentional; verify.
        if partial_list:
            for i, partial_op in enumerate(partial_list):
                if partial_op is not None and i < len(new_layout.alias_name):
                    new_layout.set_partial_by_dev_axis(new_layout.alias_name[i], partial_op)

        return new_layout

605 

606 

class ElementWiseWithPartialDistributedOp(ElementWiseDistributedOp):
    """
    Base class for elementwise operations that support partial status propagation.

    Behaves exactly like ElementWiseDistributedOp except that inputs carrying
    Partial status are accepted: the _check_partial_inputs rejection inside
    infer_layout is skipped.

    Args:
        op_name (str): Name of the operator to register.
    """
    def __init__(self, op_name):
        super().__init__(op_name)
        # Opt out of the partial-input rejection performed in infer_layout.
        self._allow_partial_inputs = True

614 

615 

class AddDistributedOp(ElementWiseWithPartialDistributedOp):
    """
    Distributed implementation for Add operator.

    This operator supports partial status propagation from inputs to output,
    which is useful for operations like gradient accumulation where partial
    results need to be preserved through the computation graph.
    """

    def get_expand_impl(self, func, output_layout, layouts, extra_args):
        """
        Get expand implementation for the operator.

        When the two inputs disagree on Partial status, the non-partial operand
        is divided by the product of mesh sizes over all Partial('sum') axes of
        the output layout before the add. NOTE(review): presumably this keeps
        the sum correct after the eventual cross-device reduction, since the
        non-partial operand is replicated on every device of those axes —
        confirm against DistributedOp's reduction semantics.

        Args:
            func: The underlying add callable to wrap.
            output_layout: Inferred output layout; its partial and mesh_shape
                fields determine the scaling factor.
            layouts: Input layouts; index 0 and 1 are the two addends.
            extra_args: Unused by this implementation.

        Returns:
            callable or None: A scaled add implementation when exactly one input
            is partial, otherwise None (default implementation is used).

        Raises:
            ValueError: If a partial status other than 'sum' (or None) is found.
        """
        x1_layout = layouts[0]
        x2_layout = layouts[1]
        # A None layout yields None here, which compares as "not partial" below.
        x1_partial = x1_layout.is_partial() if x1_layout is not None else None
        x2_partial = x2_layout.is_partial() if x2_layout is not None else None

        if x1_partial != x2_partial:
            # Product of mesh sizes over every axis marked Partial('sum').
            scaling_factor = 1
            for i, partial_type in enumerate(output_layout.partial):
                if partial_type == "sum":
                    scaling_factor *= output_layout.mesh_shape[i]
                elif partial_type is not None:
                    raise ValueError(
                        f"For {self.op_name}, inputs partial status should be 'sum' or None, "
                        f"but got {partial_type} at index {i}."
                    )

            # use expand_impl only when one of x1 and x2 is with partial placement.
            def expand_impl1(x1, x2):
                # x1 is the non-partial operand: pre-scale it down.
                add_out = func(x1 / scaling_factor, x2)
                return add_out

            def expand_impl2(x1, x2):
                # x2 is the non-partial operand: pre-scale it down.
                add_out = func(x1, x2 / scaling_factor)
                return add_out
            return expand_impl1 if not x1_partial else expand_impl2
        return None