Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_matmul.py: 70% (205 statements)
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementations for MatMul, BatchMatMul, and Linear operators.
"""

from typing import Callable, Optional, Tuple

from hyper_parallel.core.dtensor.layout import Layout
from .parallel_ops import DistributedOp


class MatMulExtDistributedOp(DistributedOp):
    """Distributed implementation for the MatMulExt operator."""
    def infer_layout(self, layouts: tuple, extra_args: Optional[tuple] = None) -> tuple:
        """
        Infer output layout for MatMul operator.

        MatMul: output = x @ w

        Rules:
        1. Batch dimensions should have the same layout
        2. Contracting dimensions should have the same layout
        3. Output dimensions inherit layouts from non-contracting dimensions

        Args:
            layouts (tuple): Layouts of input tensors (x_layout, w_layout)
            extra_args (tuple, optional): Not used by this operator

        Returns:
            Layout: Layout for output tensor
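
        Examples (illustrative; assumes 2-D inputs x=[n, m] and w=[m, p]):
            layout = Layout((2, 2, 2), ("dp", "cp", "mp"))
            x_layout = layout("dp", "mp")      # n sharded on dp, m on mp
            w_layout = layout("mp", "None")    # m sharded on mp, p replicated
            out_layout = layout("dp", "None")  # result is partial ('sum') along "mp"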

44 """ 

45 if len(layouts) != 2: 

46 raise ValueError(f"MatMul layout length is not 2, but {len(layouts)}") 

47 x_layout = layouts[0] 

48 w_layout = layouts[1] 

49 if not x_layout or not w_layout: 

50 raise ValueError(f"x_layout : {x_layout}, w_layout : {w_layout}") 

51 x_mesh_shape = x_layout.mesh_shape 

52 w_mesh_shape = w_layout.mesh_shape 

53 if x_mesh_shape != w_mesh_shape: 

54 raise ValueError("MatMul inputs must have same mesh_shape") 

55 

56 x_map = x_layout.alias_tensor_map 

57 w_map = w_layout.alias_tensor_map 

58 contract_dim = len(x_map) - 1 

59 w_contract_dim = len(w_map) - 2 

60 if x_map[contract_dim] != w_map[w_contract_dim]: 

61 raise ValueError(f"Contracting dimensions must have same layout. " 

62 f"Got {x_map[contract_dim]} and {w_map[w_contract_dim]}") 

63 

64 output_dim = len(w_map) - 1 

65 output_map = x_map[:-1] + (w_map[output_dim],) 

66 

67 output_layout = Layout( 

68 mesh_shape=x_layout.mesh_shape, 

69 alias_name=x_layout.alias_name, 

70 rank_list=x_layout.rank_list 

71 ) 

72 out_layout = output_layout(*output_map) 

73 

74 # Set partial status 

75 if x_map[contract_dim] != "None": 

76 if isinstance(x_map[contract_dim], tuple): 

77 for axis in x_map[contract_dim]: 

78 out_layout.set_partial_by_dev_axis(axis, 'sum') 

79 else: 

80 out_layout.set_partial_by_dev_axis(x_map[contract_dim], 'sum') 

81 

82 return out_layout 

83 

84 

class MatMulDistributedOp(DistributedOp):
    """Distributed implementation for the MatMul operator."""
    def infer_layout(self, layouts: tuple, extra_args: Optional[tuple] = None) -> tuple:
        """
        Infer output layout for MatMul operator.

        MatMul: output = x @ w, with possible transpose

        Args:
            layouts (tuple): Layouts of input tensors (x_layout, w_layout)
            extra_args (tuple): Additional arguments (transpose_a, transpose_b)

        Returns:
            Layout: Layout for output tensor
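
        Examples (illustrative; transpose_a=False, transpose_b=True, so w is [p, m]):
            layout = Layout((2, 2, 2), ("dp", "cp", "mp"))
            x_layout = layout("dp", "mp")      # x: [n, m]
            w_layout = layout("None", "mp")    # w: [p, m]
            out_layout = layout("dp", "None")  # result is partial ('sum') along "mp"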

99 """ 

100 if len(layouts) < 2: 

101 raise ValueError("MatMul requires at least two input layouts") 

102 

103 x_layout, w_layout = layouts[:2] 

104 

105 if len(extra_args) != 2: 

106 raise ValueError("MatMul requires two transpose input") 

107 transpose_a, transpose_b = extra_args[0], extra_args[1] 

108 

109 x_dict = x_layout.to_dict() 

110 w_dict = w_layout.to_dict() 

111 

112 if x_dict["mesh_shape"] != w_dict["mesh_shape"]: 

113 raise ValueError("MatMul inputs must have same mesh_shape") 

114 

115 x_map = x_layout.alias_tensor_map 

116 w_map = w_layout.alias_tensor_map 

117 

118 # Determine contracting dimensions based on transpose flags 

119 if transpose_a: 

120 x_input_dim = len(x_map) - 1 

121 x_contract_dim = len(x_map) - 2 # Second to last dimension 

122 else: 

123 x_input_dim = len(x_map) - 2 

124 x_contract_dim = len(x_map) - 1 # Last dimension 

125 

126 if transpose_b: 

127 w_output_dim = len(w_map) - 2 

128 w_contract_dim = len(w_map) - 1 # Last dimension 

129 else: 

130 w_output_dim = len(w_map) - 1 

131 w_contract_dim = len(w_map) - 2 # Second to last dimension 

132 

133 # Validate contracting dimensions 

134 if x_map[x_contract_dim] != w_map[w_contract_dim]: 

135 raise ValueError(f"Contracting dimensions must have same layout. " 

136 f"Got {x_map[x_contract_dim]} and {w_map[w_contract_dim]}") 

137 

138 # Create output layout 

139 output_layout = Layout( 

140 mesh_shape=x_layout.mesh_shape, 

141 alias_name=x_layout.alias_name, 

142 rank_list=x_layout.rank_list 

143 ) 

144 output_map = list(x_map[:-2]) + [x_map[x_input_dim]] + [w_map[w_output_dim]] 

145 output_layout = output_layout(*output_map) 

146 

147 # Set partial status 

148 if x_map[x_contract_dim] != "None": 

149 if isinstance(x_map[x_contract_dim], tuple): 

150 for axis in x_map[x_contract_dim]: 

151 output_layout.set_partial_by_dev_axis(axis, 'sum') 

152 else: 

153 output_layout.set_partial_by_dev_axis(x_map[x_contract_dim], 'sum') 

154 

155 return output_layout 

156 

157 

class BaseBatchMatMulDistributedOp(DistributedOp):
    """Base class for BatchMatMul distributed implementations."""

    def _merge_batch_entry(self, x_dims, w_dims):
        """
        Merge two batch tensor_map entries with broadcasting:
        - none vs X -> X
        - X vs none -> X
        - X vs X (exact same after normalization) -> X
        - otherwise -> conflict
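
        Example (illustrative): merging "None" with "dp" gives "dp", "dp" with
        "dp" gives "dp", and "dp" with "cp" raises a ValueError.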

168 """ 

169 if self._is_none_entry(x_dims) and self._is_none_entry(w_dims): 

170 return "None" 

171 if self._is_none_entry(x_dims): 

172 return w_dims 

173 if self._is_none_entry(w_dims): 

174 return x_dims 

175 if x_dims == w_dims: 

176 return x_dims 

177 raise ValueError(f"Incompatible batch sharding between inputs: {x_dims} vs {w_dims}") 

178 

179 def _is_none_entry(self, entry): 

180 """An entry is 'none' (no sharding) if it is 'None' or tuple of all 'None'.""" 

181 if isinstance(entry, tuple): 

182 return all(i == "None" for i in entry) 

183 return entry == "None" 

184 

185 def _merge_batches(self, x_map, w_map): 

186 """Right-align and merge batch dims from x_map and w_map.""" 

187 x_batch = list(x_map[:-2]) 

188 w_batch = list(w_map[:-2]) 

189 max_b = max(len(x_batch), len(w_batch)) 

190 x_batch = ["None"] * (max_b - len(x_batch)) + x_batch 

191 w_batch = ["None"] * (max_b - len(w_batch)) + w_batch 

192 merged_batch = [] 

193 for xb, wb in zip(x_batch, w_batch): 

194 merged_batch.append(self._merge_batch_entry(xb, wb)) 

195 return merged_batch 

196 

197 def _build_output_layout(self, x_layout, merged_batch, x_n, w_p, x_contract): 

198 """Construct output layout from merged dims and set partial status if needed.""" 

199 output_map = tuple(merged_batch) + (x_n, w_p) 

200 

201 output_layout = Layout( 

202 mesh_shape=x_layout.mesh_shape, 

203 alias_name=x_layout.alias_name, 

204 rank_list=x_layout.rank_list 

205 ) 

206 output_layout = output_layout(*output_map) 

207 

208 # Set partial status 

209 if x_contract != "None": 

210 if isinstance(x_contract, tuple): 

211 for axis in x_contract: 

212 output_layout.set_partial_by_dev_axis(axis, 'sum') 

213 else: 

214 output_layout.set_partial_by_dev_axis(x_contract, 'sum') 

215 

216 return output_layout 

217 

218 

class BatchMatMulExtDistributedOp(BaseBatchMatMulDistributedOp):
    """Distributed implementation for BatchMatMulExt operator."""

    def infer_layout(self, layouts: tuple, extra_args: Optional[tuple] = None) -> tuple:
        """
        Infer output layout for BatchMatMulExt operator. Input shapes are x=[b, n, m] and w=[b, m, p].

        BatchMatMulExt: output = x @ w.

        Rules:
        - Mesh shape must match.
        - Contracting (M) dims must have identical layout: x[-1] == w[-2].
        - Batch dims are right-aligned broadcast:
              none vs shard -> shard
              shard vs none -> shard
              shard vs shard (different) -> error
        - Output batch dims = merged batch dims
        - Output N inherits x[-2], Output P inherits w[-1]

        Args:
            layouts (tuple): Layouts of input tensors (x_layout, w_layout)
            extra_args (tuple, optional): Not used by this operator

        Returns:
            Layout: Layout for output tensor

        Examples:
            layout = Layout((2, 2, 2), ("dp", "cp", "mp"))
            x_layout = layout("dp", "cp", "mp")
            w_layout = layout("dp", "mp", "None")
            out_layout = layout("dp", "cp", "None")
        """

        if len(layouts) < 2:
            raise ValueError("BatchMatMul requires at least two input layouts")
        x_layout, w_layout = layouts[:2]

        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError("BatchMatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map

        # contracting dims
        x_contract = x_map[-1]
        w_contract = w_map[-2]
        if x_contract != w_contract:
            raise ValueError(f"Contracting (M) dim layouts must match, got {x_contract} (x) vs {w_contract} (w)")

        merged_batch = self._merge_batches(x_map, w_map)
        x_n = x_map[-2]
        w_p = w_map[-1]

        return self._build_output_layout(x_layout, merged_batch, x_n, w_p, x_contract)


class BatchMatMulDistributedOp(BaseBatchMatMulDistributedOp):
    """Distributed implementation for BatchMatMul operator."""

    def infer_layout(self, layouts: tuple, extra_args: Optional[tuple] = None) -> tuple:
        """
        Infer output layout for BatchMatMul operator. Input shapes are x=[b, n, m] and
        w=[b, m, p] before any transpose is applied.

        BatchMatMul: output = x @ w, with possible transpose.

        Rules:
        - Mesh shape must match.
        - Contracting (M) dims must have identical layout; which axis contracts
          depends on transpose_a and transpose_b.
        - Batch dims are right-aligned broadcast:
              none vs shard -> shard
              shard vs none -> shard
              shard vs shard (different) -> error
        - Output batch dims = merged batch dims
        - Output N inherits the non-contracting dim of x, Output P the non-contracting dim of w

        Args:
            layouts (tuple): Layouts of input tensors (x_layout, w_layout)
            extra_args (tuple): Additional arguments (transpose_a, transpose_b)

        Returns:
            Layout: Layout for output tensor

        Examples:
            # layouts = (x_layout, w_layout), extra_args = (True, False), i.e.
            # transpose_a=True, transpose_b=False
            layout = Layout((2, 2, 2), ("dp", "cp", "mp"))
            x_layout = layout("dp", "mp", "cp")
            w_layout = layout("dp", "mp", "None")
            out_layout = layout("dp", "cp", "None")
        """

        if len(layouts) < 2:
            raise ValueError("BatchMatMul requires at least two input layouts")
        if len(extra_args) != 2:
            raise ValueError("BatchMatMul requires two transpose inputs")

        x_layout, w_layout = layouts[:2]
        transpose_a, transpose_b = extra_args

        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError("BatchMatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map

        # handle transpose
        if transpose_a:
            x_n = x_map[-1]
            x_contract = x_map[-2]
        else:
            x_n = x_map[-2]
            x_contract = x_map[-1]

        if transpose_b:
            w_contract = w_map[-1]
            w_p = w_map[-2]
        else:
            w_contract = w_map[-2]
            w_p = w_map[-1]

        if x_contract != w_contract:
            raise ValueError(f"Contracting (M) dim layouts must match, got {x_contract} (x) vs {w_contract} (w)")

        merged_batch = self._merge_batches(x_map, w_map)

        return self._build_output_layout(x_layout, merged_batch, x_n, w_p, x_contract)


def _normalize_linear_args(x, weight, bias=None):
    """Normalize Linear call arguments into positional (x, weight, bias) and empty kwargs."""
    return (x, weight, bias), {}


class LinearDistributedOp(DistributedOp):
    """Distributed implementation for Linear operator."""

    def preprocess(self, args: tuple, kwargs: dict) -> tuple:
        """
        Preprocess arguments for Linear operator.

        Args:
            args (tuple): Input arguments containing x and weight tensors.
            kwargs (dict): Keyword arguments, may contain bias.

        Returns:
            tuple: (local_args, local_kwargs, cache_values) where local_args contains
                local tensors for x, weight, and bias; local_kwargs is empty; and
                cache_values contains layouts and None-sentinel for absent bias.
        """
        args, kwargs = _normalize_linear_args(*args, **kwargs)
        x_tensor, w_tensor, bias = args[0], args[1], args[2]
        local_args = (
            x_tensor.to_local(),
            w_tensor.to_local(),
            bias.to_local() if hasattr(bias, '_layout') else bias,
        )
        local_kwargs = {}
        cache_values = [
            x_tensor.layout,
            w_tensor.layout,
            bias.layout if hasattr(bias, '_layout') else None,
        ]
        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
        """
        Infer output layout for Linear operator (output = x @ weight.T + bias).

        Rules:
        1. x and weight must share the same mesh_shape.
        2. weight must be 2D [out_features, in_features].
        3. Contracting dimensions (in_features) must have the same layout.
        4. Output batch dimensions inherit from x; output feature dim inherits from weight dim 0.
        5. Partial state is set on the output when the contracting dimension is sharded.

        Args:
            cache_values (list): [x_layout, w_layout, bias_layout] where bias_layout may be None.

        Returns:
            tuple: ((out_layout,), None)

        Raises:
            ValueError: If cache_values length is not 3, layouts are invalid, mesh shapes differ,
                weight is not 2D, contracting dims mismatch, or bias sharding is inconsistent.
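
        Examples (illustrative; x=[batch, in_features], weight=[out_features, in_features]):
            layout = Layout((2, 2, 2), ("dp", "cp", "mp"))
            x_layout = layout("dp", "mp")      # in_features sharded on mp
            w_layout = layout("None", "mp")    # out_features replicated, in_features on mp
            bias_layout = layout("None")       # bias replicated, matching weight dim 0
            out_layout = layout("dp", "None")  # output is partial ('sum') along "mp"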

401 """ 

402 if len(cache_values) != 3: 

403 raise ValueError( 

404 f"For {self.op_name}, cache_values length should be 3, but got {len(cache_values)}" 

405 ) 

406 x_layout = cache_values[0] 

407 w_layout = cache_values[1] 

408 bias_layout = cache_values[2] 

409 

410 if not x_layout or not w_layout: 

411 raise ValueError(f"x_layout : {x_layout}, w_layout : {w_layout}") 

412 

413 self._check_partial_inputs([x_layout, w_layout]) 

414 

415 x_mesh_shape = x_layout.mesh_shape 

416 w_mesh_shape = w_layout.mesh_shape 

417 if x_mesh_shape != w_mesh_shape: 

418 raise ValueError( 

419 f"For {self.op_name}, x and weight must have the same mesh_shape, " 

420 f"but got x: {x_mesh_shape} and weight: {w_mesh_shape}" 

421 ) 

422 if bias_layout and bias_layout.mesh_shape != x_mesh_shape: 

423 raise ValueError( 

424 f"For {self.op_name}, bias and x must have the same mesh_shape, " 

425 f"but got bias: {bias_layout.mesh_shape} and x: {x_mesh_shape}" 

426 ) 

427 

428 x_map = x_layout.alias_tensor_map 

429 w_map = w_layout.alias_tensor_map 

430 

431 if len(w_map) != 2: 

432 raise ValueError( 

433 f"For {self.op_name}, weight should be 2D [out_features, in_features], " 

434 f"but got {len(w_map)}D" 

435 ) 

436 

437 x_contract_dim = len(x_map) - 1 

438 w_contract_dim = len(w_map) - 1 

439 if x_map[x_contract_dim] != w_map[w_contract_dim]: 

440 raise ValueError( 

441 f"For {self.op_name}, contracting dimensions must have the same layout, " 

442 f"but got x: {x_map[x_contract_dim]} and weight: {w_map[w_contract_dim]}" 

443 ) 

444 

445 output_dim = 0 

446 output_map = x_map[:-1] + (w_map[output_dim],) 

447 if bias_layout and bias_layout.alias_tensor_map[0] != w_map[output_dim]: 

448 raise ValueError( 

449 f"For {self.op_name}, bias output dim sharding must match weight output dim sharding, " 

450 f"but got weight: {w_map[output_dim]} and bias: {bias_layout.alias_tensor_map[0]}" 

451 ) 

452 

453 output_layout = Layout( 

454 mesh_shape=x_layout.mesh_shape, 

455 alias_name=x_layout.alias_name, 

456 rank_list=x_layout.rank_list, 

457 ) 

458 out_layout = output_layout(*output_map) 

459 

460 # Set partial status when contracting dimension is sharded 

461 if x_map[x_contract_dim] != "None": 

462 if isinstance(x_map[x_contract_dim], tuple): 

463 for axis in x_map[x_contract_dim]: 

464 out_layout.set_partial_by_dev_axis(axis, 'sum') 

465 else: 

466 out_layout.set_partial_by_dev_axis(x_map[x_contract_dim], 'sum') 

467 

468 return ((out_layout,), None) 

469 

    def get_expand_impl(self, func: Callable, infer_result: tuple,
                        cache_values: list) -> Optional[Callable]:
        """
        Return a custom expand implementation when bias scaling is needed.

        When the contracting dimension is sharded, each rank computes a partial sum
        (x_shard @ w_shard.T + bias). After AllReduce the bias would be accumulated
        scaling_factor times. The returned closure pre-divides bias by scaling_factor
        to keep the result numerically correct.

        Args:
            func: Original operator callable.
            infer_result (tuple): ((out_layout,), None) from infer_layout.
            cache_values (list): [x_layout, w_layout, bias_layout].

        Returns:
            callable | None: expand_impl closure when scaling is required, else None.
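
        Example (illustrative): if in_features is sharded across a 2-device "mp"
        axis, scaling_factor is 2; each rank computes x_shard @ w_shard.T + bias / 2,
        and the subsequent AllReduce sums the partials so bias contributes exactly once.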

487 """ 

488 x_layout = cache_values[0] 

489 bias_layout = cache_values[2] 

490 x_map = x_layout.alias_tensor_map 

491 x_contract_dim = len(x_map) - 1 

492 

493 # Guard: scaling only needed when contract dim is sharded AND bias is present 

494 if x_map[x_contract_dim] == "None" or not bias_layout: 

495 return None 

496 

497 output_layout = infer_result[0][0] 

498 scaling_factor = 1 

499 if isinstance(x_map[x_contract_dim], tuple): 

500 for axis in x_map[x_contract_dim]: 

501 scaling_factor *= output_layout.mesh.get_device_num_along_axis(axis) 

502 else: 

503 scaling_factor *= output_layout.mesh.get_device_num_along_axis(x_map[x_contract_dim]) 

504 

505 def expand_impl(x: object, w: object, bias: object) -> object: 

506 """Pre-scale bias to counteract the AllReduce accumulation over shards. 

507 

508 Args: 

509 x (object): Local input activation tensor. 

510 w (object): Local weight tensor. 

511 bias (object): Local bias tensor to be pre-scaled. 

512 

513 Returns: 

514 object: Result of the linear operation with pre-scaled bias. 

515 """ 

516 return func(x, w, bias / scaling_factor) 

517 

518 return expand_impl