Coverage for hyper_parallel / core / shard / ops / parallel_matmul.py: 77%

193 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for MatMul operator. 

17""" 

18 

19from hyper_parallel.core.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

class MatMulExtDistributedOp(DistributedOp):
    """Distributed implementation for the MatMulExt operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for MatMulExt: output = x @ w.

        Rules:
            1. Inputs must share the same mesh_shape.
            2. Contracting dimensions (x[-1] and w[-2]) must have the same layout.
            3. Output dimensions inherit layouts from the non-contracting dimensions.

        Args:
            layouts (tuple): Layouts of the two inputs, (x_layout, w_layout).
            extra_args: Unused for MatMulExt; kept for interface uniformity.

        Returns:
            Layout: Layout for the output tensor. If the contracting dimension
            is sharded, the output is marked partial ('sum') on those axes.

        Raises:
            ValueError: If the number of layouts is not 2, a layout is missing,
                mesh shapes differ, or contracting dim layouts mismatch.
        """
        if len(layouts) != 2:
            raise ValueError(f"MatMul layout length is not 2, but {len(layouts)}")
        x_layout = layouts[0]
        w_layout = layouts[1]
        if not x_layout or not w_layout:
            raise ValueError(f"x_layout : {x_layout}, w_layout : {w_layout}")
        x_mesh_shape = x_layout.mesh_shape
        w_mesh_shape = w_layout.mesh_shape
        if x_mesh_shape != w_mesh_shape:
            raise ValueError("MatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map
        contract_dim = len(x_map) - 1     # contracting dim of x (last)
        w_contract_dim = len(w_map) - 2   # contracting dim of w (second to last)
        if x_map[contract_dim] != w_map[w_contract_dim]:
            raise ValueError(f"Contracting dimensions must have same layout. "
                             f"Got {x_map[contract_dim]} and {w_map[w_contract_dim]}")

        output_dim = len(w_map) - 1       # output (N) dim of w (last)
        output_map = x_map[:-1] + (w_map[output_dim],)

        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        out_layout = output_layout(*output_map)

        # A sharded contracting dim leaves each device with a partial product,
        # so mark the output partial ('sum') along those device axes.
        if x_map[contract_dim] != "None":
            if isinstance(x_map[contract_dim], tuple):
                for axis in x_map[contract_dim]:
                    out_layout.set_partial_by_dev_axis(axis, 'sum')
            else:
                out_layout.set_partial_by_dev_axis(x_map[contract_dim], 'sum')

        return out_layout

80 

81 

class MatMulDistributedOp(DistributedOp):
    """Distributed implementation for the MatMul operator (with transpose flags)."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for MatMul: output = x @ w, with possible transpose.

        Args:
            layouts (tuple): Layouts of input tensors (x_layout, w_layout).
            extra_args (tuple): (transpose_a, transpose_b) flags.

        Returns:
            Layout: Layout for the output tensor. If the contracting dimension
            is sharded, the output is marked partial ('sum') on those axes.

        Raises:
            ValueError: If fewer than two layouts are given, extra_args does
                not hold exactly two flags, mesh shapes differ, or contracting
                dim layouts mismatch.
        """
        if len(layouts) < 2:
            raise ValueError("MatMul requires at least two input layouts")

        x_layout, w_layout = layouts[:2]

        if len(extra_args) != 2:
            raise ValueError("MatMul requires two transpose input")
        transpose_a, transpose_b = extra_args[0], extra_args[1]

        # Compare mesh shapes directly, consistent with the sibling operators
        # (the previous to_dict() round-trip was an unnecessary detour).
        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError("MatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map

        # Determine contracting dimensions based on transpose flags
        if transpose_a:
            x_input_dim = len(x_map) - 1
            x_contract_dim = len(x_map) - 2  # Second to last dimension
        else:
            x_input_dim = len(x_map) - 2
            x_contract_dim = len(x_map) - 1  # Last dimension

        if transpose_b:
            w_output_dim = len(w_map) - 2
            w_contract_dim = len(w_map) - 1  # Last dimension
        else:
            w_output_dim = len(w_map) - 1
            w_contract_dim = len(w_map) - 2  # Second to last dimension

        # Validate contracting dimensions
        if x_map[x_contract_dim] != w_map[w_contract_dim]:
            raise ValueError(f"Contracting dimensions must have same layout. "
                             f"Got {x_map[x_contract_dim]} and {w_map[w_contract_dim]}")

        # Create output layout
        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_map = list(x_map[:-2]) + [x_map[x_input_dim]] + [w_map[w_output_dim]]
        output_layout = output_layout(*output_map)

        # A sharded contracting dim leaves each device with a partial product,
        # so mark the output partial ('sum') along those device axes.
        if x_map[x_contract_dim] != "None":
            if isinstance(x_map[x_contract_dim], tuple):
                for axis in x_map[x_contract_dim]:
                    output_layout.set_partial_by_dev_axis(axis, 'sum')
            else:
                output_layout.set_partial_by_dev_axis(x_map[x_contract_dim], 'sum')

        return output_layout

153 

154 

class BaseBatchMatMulDistributedOp(DistributedOp):
    """Shared helpers for the BatchMatMul family of distributed ops."""

    def _is_none_entry(self, entry):
        """Return True when *entry* carries no sharding ('None' or an all-'None' tuple)."""
        if isinstance(entry, tuple):
            return all(item == "None" for item in entry)
        return entry == "None"

    def _merge_batch_entry(self, x_dims, w_dims):
        """
        Broadcast-merge one batch tensor_map entry from each input.

        none vs X -> X; X vs none -> X; identical entries -> that entry;
        two different shardings -> ValueError.
        """
        x_unsharded = self._is_none_entry(x_dims)
        w_unsharded = self._is_none_entry(w_dims)
        if x_unsharded:
            return "None" if w_unsharded else w_dims
        if w_unsharded or x_dims == w_dims:
            return x_dims
        raise ValueError(f"Incompatible batch sharding between inputs: {x_dims} vs {w_dims}")

    def _merge_batches(self, x_map, w_map):
        """Right-align the batch portions of both maps and merge them pairwise."""
        x_batch = list(x_map[:-2])
        w_batch = list(w_map[:-2])
        width = max(len(x_batch), len(w_batch))
        x_batch = ["None"] * (width - len(x_batch)) + x_batch
        w_batch = ["None"] * (width - len(w_batch)) + w_batch
        return [self._merge_batch_entry(xb, wb) for xb, wb in zip(x_batch, w_batch)]

    def _build_output_layout(self, x_layout, merged_batch, x_n, w_p, x_contract):
        """Assemble the output Layout and mark partial-sum axes for a sharded contracting dim."""
        base_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        result = base_layout(*(tuple(merged_batch) + (x_n, w_p)))

        # A sharded contracting dim leaves each device with a partial product.
        if x_contract != "None":
            axes = x_contract if isinstance(x_contract, tuple) else (x_contract,)
            for axis in axes:
                result.set_partial_by_dev_axis(axis, 'sum')

        return result

214 

215 

class BatchMatMulExtDistributedOp(BaseBatchMatMulDistributedOp):
    """Distributed implementation for BatchMatMulExt operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for BatchMatMulExt. Input shapes are x=[b, n, m] and w=[b, m, p].

        BatchMatMulExt: output = x @ w.

        Rules:
            - Mesh shape must match.
            - Contracting (M) dims must have identical layout: x[-1] == w[-2].
            - Batch dims are right-aligned broadcast:
                none vs shard -> shard
                shard vs none -> shard
                shard vs shard (different) -> error
            - Output batch dims = merged batch dims
            - Output N inherits x[-2], Output P inherits w[-1]

        Args:
            layouts (tuple): Layouts of input tensors (x_layout, w_layout).
            extra_args: Unused for BatchMatMulExt; kept for interface uniformity.

        Returns:
            Layout: Layout for the output tensor (partial 'sum' when the
            contracting dim is sharded).

        Raises:
            ValueError: If fewer than two layouts are given, mesh shapes
                differ, or contracting/batch dim layouts are incompatible.

        Examples:
            layout = Layout((2, 2, 2), ("dp", "cp", "mp"))
            x_layout = layout("dp", "cp", "mp")
            w_layout = layout("dp", "mp", "None")
            out_layout = layout("dp", "cp", "None")
        """
        if len(layouts) < 2:
            raise ValueError("BatchMatMul requires at least two input layouts")
        x_layout, w_layout = layouts[:2]

        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError("BatchMatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map

        # contracting dims
        x_contract = x_map[-1]
        w_contract = w_map[-2]
        if x_contract != w_contract:
            raise ValueError(f"Contracting (M) dim layouts must match, got {x_contract} (x) vs {w_contract} (w)")

        merged_batch = self._merge_batches(x_map, w_map)
        x_n = x_map[-2]
        w_p = w_map[-1]

        return self._build_output_layout(x_layout, merged_batch, x_n, w_p, x_contract)

270 

271 

class BatchMatMulDistributedOp(BaseBatchMatMulDistributedOp):
    """Distributed implementation for BatchMatMul operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for BatchMatMul. Input shapes are x=[b, n, m] and w=[b, m, p].

        BatchMatMul: output = x @ w, with possible transpose.

        Rules:
            - Mesh shape must match.
            - The contracting dims selected by the transpose flags must have
              identical layout (x[-1] vs w[-2] when neither input is transposed).
            - Batch dims are right-aligned broadcast:
                none vs shard -> shard
                shard vs none -> shard
                shard vs shard (different) -> error
            - Output batch dims = merged batch dims
            - Output N and P inherit the non-contracting dims of x and w.

        Args:
            layouts (tuple): Layouts of input tensors (x_layout, w_layout).
            extra_args (tuple): (transpose_a, transpose_b) flags.

        Returns:
            Layout: Layout for the output tensor (partial 'sum' when the
            contracting dim is sharded).

        Raises:
            ValueError: If fewer than two layouts are given, extra_args does
                not hold exactly two flags, mesh shapes differ, or
                contracting/batch dim layouts are incompatible.

        Examples:
            ms.mint.bmm((x_layout, w_layout),(transpose_a=True, transpose_b=False))
            layout = Layout((2, 2, 2), ("dp", "cp", "mp"))
            x_layout = layout("dp", "mp", "cp")
            w_layout = layout("dp", "mp", "None")
            out_layout = layout("dp", "cp", "None")
        """
        if len(layouts) < 2:
            raise ValueError("BatchMatMul requires at least two input layouts")
        if len(extra_args) != 2:
            raise ValueError("BatchMatMul requires two transpose input")

        x_layout, w_layout = layouts[:2]
        transpose_a, transpose_b = extra_args

        if x_layout.mesh_shape != w_layout.mesh_shape:
            raise ValueError("BatchMatMul inputs must have same mesh_shape")

        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map

        # handle transpose: pick the N/contract dims of x and contract/P dims of w
        if transpose_a:
            x_n = x_map[-1]
            x_contract = x_map[-2]
        else:
            x_n = x_map[-2]
            x_contract = x_map[-1]

        if transpose_b:
            w_contract = w_map[-1]
            w_p = w_map[-2]
        else:
            w_contract = w_map[-2]
            w_p = w_map[-1]

        if x_contract != w_contract:
            raise ValueError(f"Contracting (M) dim layouts must match, got {x_contract} (x) vs {w_contract} (w)")

        merged_batch = self._merge_batches(x_map, w_map)

        return self._build_output_layout(x_layout, merged_batch, x_n, w_p, x_contract)

341 

class LinearDistributedOp(DistributedOp):
    """Distributed implementation for Linear operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for the Linear operator.

        Linear: output = x @ w^T + bias, with the weight's out-features dim
        first (the code contracts x[-1] with w[-1] and emits w[0] as the
        output dim, which implies weight shape [out_features, in_features]).

        Rules:
            1. All inputs must share the same mesh_shape.
            2. Contracting dims (x[-1] and w[-1]) must have the same layout.
            3. The output's last dim inherits w[0]; the bias, when present,
               must be sharded identically to w[0].

        Args:
            layouts (tuple): Layouts of (x_layout, w_layout, bias_layout);
                bias_layout may be falsy when the layer has no bias.
            extra_args: Unused for Linear; kept for interface uniformity.

        Returns:
            Layout: Layout for the output tensor. If the contracting dimension
            is sharded, the output is marked partial ('sum') on those axes.

        Raises:
            ValueError: On wrong layout count, missing x/w layout, mesh-shape
                mismatch, contracting-dim mismatch, or bias/weight shard
                mismatch.
        """
        if len(layouts) != 3:
            raise ValueError(f"Linear layout length is not 3, but {len(layouts)}")
        x_layout = layouts[0]
        w_layout = layouts[1]
        bias_layout = layouts[2]
        if not x_layout or not w_layout:
            raise ValueError(f"x_layout : {x_layout}, w_layout : {w_layout}")
        x_mesh_shape = x_layout.mesh_shape
        w_mesh_shape = w_layout.mesh_shape
        if x_mesh_shape != w_mesh_shape:
            raise ValueError("Linear inputs must have same mesh_shape")
        if bias_layout and bias_layout.mesh_shape != x_mesh_shape:
            raise ValueError("Linear bias and x must have same mesh_shape")
        x_map = x_layout.alias_tensor_map
        w_map = w_layout.alias_tensor_map
        # Linear contracts x's last dim with w's last dim (weight is [out, in]).
        x_contract_dim = len(x_map) - 1
        w_contract_dim = len(w_map) - 1
        if x_map[x_contract_dim] != w_map[w_contract_dim]:
            raise ValueError(f"Contracting dimensions must have same layout. "
                             f"Got {x_map[x_contract_dim]} and {w_map[w_contract_dim]}")

        output_dim = 0  # weight's out_features dim
        output_map = x_map[:-1] + (w_map[output_dim],)
        if bias_layout and bias_layout.alias_tensor_map[0] != w_map[output_dim]:
            raise ValueError(f"Output dimensions must have same sharding. "
                             f"Got weight output dim sharding size: {w_map[output_dim]}"
                             f" and bias output dim sharding size : {bias_layout.alias_tensor_map[0]}")
        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        out_layout = output_layout(*output_map)

        # A sharded contracting dim leaves each device with a partial product,
        # so mark the output partial ('sum') along those device axes.
        if x_map[x_contract_dim] != "None":
            if isinstance(x_map[x_contract_dim], tuple):
                for axis in x_map[x_contract_dim]:
                    out_layout.set_partial_by_dev_axis(axis, 'sum')
            else:
                out_layout.set_partial_by_dev_axis(x_map[x_contract_dim], 'sum')

        return out_layout

    def get_expand_impl(self, func, output_layout, layouts, extra_args):
        """
        Build a replacement implementation that pre-scales the bias.

        When the contracting dim is sharded over N devices each device holds a
        partial matmul sum, and each device's call would also add the full
        bias — so the eventual reduction would accumulate the bias N times.
        Dividing the bias by N up front means the reduced result carries the
        bias exactly once. (N is the product of device counts along the
        sharded contracting axes.)

        Args:
            func: The original Linear primitive to wrap.
            output_layout (Layout): Inferred output layout; supplies the mesh
                used to size the scaling factor.
            layouts (tuple): Layouts of (x, w, bias).
            extra_args: Unused.

        Returns:
            callable or None: The wrapping implementation when the contracting
            dim is sharded and a bias layout exists, otherwise None (no
            expansion needed).
        """
        x_layout = layouts[0]
        bias_layout = layouts[2]
        x_map = x_layout.alias_tensor_map
        x_contract_dim = len(x_map) - 1
        scaling_factor = 1

        if x_map[x_contract_dim] != "None":
            if isinstance(x_map[x_contract_dim], tuple):
                for axis in x_map[x_contract_dim]:
                    scaling_factor *= output_layout.mesh.get_device_num_along_axis(axis)
            else:
                scaling_factor *= output_layout.mesh.get_device_num_along_axis(x_map[x_contract_dim])

        def expand_impl(x, w, bias):
            linear_out = func(x, w, bias / scaling_factor)
            return linear_out

        if x_map[x_contract_dim] != "None" and bias_layout:
            return expand_impl
        return None