Coverage for hyper_parallel / core / shard / _op_dispatch.py: 78%

441 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""_op_dispatch""" 

16import os 

17import sys 

18import atexit 

19import glob 

20import importlib 

21from typing import Any, List, Dict, Optional 

22import yaml 

23 

24from hyper_parallel.core.shard.ops.parallel_ops_register import get_distributed_op 

25from hyper_parallel.core.dtensor import DTensor 

26from hyper_parallel.platform import get_platform 

27 

# Resolve the active backend platform once at import time; all dispatch code
# below goes through this abstraction (op-name lookup, tensor type).
platform = get_platform()
# Backend-native tensor class, used for isinstance checks during dispatch.
Tensor = platform.Tensor

30 

# Module-wide switch: when True, operator calls go through DTensor-aware
# dispatch; when False, ops run directly on unwrapped local tensors.
_dtensor_dispatch = True


def _set_dtensor_dispatch(value):
    """Internal helper: update the module-level dispatch flag."""
    global _dtensor_dispatch
    _dtensor_dispatch = value


def enable_dtensor_dispatch():
    """Globally enable DTensor-aware operator dispatch."""
    _set_dtensor_dispatch(True)


def disable_dtensor_dispatch():
    """Globally disable DTensor-aware operator dispatch."""
    _set_dtensor_dispatch(False)


def get_dtensor_dispatch():
    """Return whether DTensor-aware dispatch is currently enabled."""
    return _dtensor_dispatch

43 

44 

class LayoutCacheKey:
    """Hashable cache key built from the ordered layout-id strings of a call.

    Two keys compare equal iff their id lists are equal (order matters).
    The hash folds the per-id hashes together with a boost-style
    hash_combine, so permuted id lists generally hash differently.
    """

    def __init__(self, layout_ids: List[str]):
        self.layout_ids = layout_ids

    def __eq__(self, other):
        return isinstance(other, LayoutCacheKey) and other.layout_ids == self.layout_ids

    def __hash__(self):
        combined = 0
        for ident in self.layout_ids:
            # boost::hash_combine: xor with hash + golden-ratio constant
            # plus shifted accumulator bits.
            combined ^= hash(ident) + 0x9e3779b9 + (combined << 6) + (combined >> 2)
        return combined

63 

class LayoutCacheManager:
    """Process-wide singleton holding the per-op layout-inference cache.

    The cache maps op name -> {LayoutCacheKey: cached inference result} and
    is cleared automatically at interpreter exit.
    """
    _instance = None

    def __init__(self):
        # op name -> {LayoutCacheKey: cached (layout, impl) entry}
        self.layout_cache = {}
        # Ensure cached entries are dropped when the interpreter exits.
        atexit.register(self.clear_cache)

    @classmethod
    def get_instance(cls):
        """Return the singleton instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = LayoutCacheManager()
        return cls._instance

    def get_layout_cache(self) -> "Dict[str, Dict[LayoutCacheKey, Any]]":
        """Expose the shared, mutable cache dictionary."""
        return self.layout_cache

    def distributed_op(self, op_name: str) -> Any:
        """Look up the registered distributed implementation for op_name."""
        return get_distributed_op(op_name)

    def clear_cache(self):
        """Drop every cached layout-inference entry."""
        self.layout_cache.clear()

89 

90 

class OpDispatcher:
    """
    Dispatcher that routes backend operator calls on DTensor inputs to their
    registered distributed implementations.

    On construction it loads per-op configuration from YAML files, registers
    the configured distributed op classes, and then serves dispatch() calls
    by running layout inference (cached via LayoutCacheManager) and wrapping
    local outputs back into DTensor.
    """
    def __init__(self):
        # Optional environment overrides: a custom YAML config directory and
        # extra python paths for externally provided distributed op modules.
        self._env_yaml_dir: Optional[str] = os.environ.get("HYPER_PARALLEL_OPS_YAML_DIR")
        self._env_python_path: Optional[str] = os.environ.get("HYPER_PARALLEL_OPS_PYTHON_PATH")

        self._setup_paths_from_env()

        # op name -> yaml config describing its distributed implementation.
        self.layout_infer_ops = self.safe_load_yaml_from_dir()
        # Ops dispatched straight to the local backend implementation (no
        # layout inference); DTensor args are unwrapped to local tensors.
        self.whitelist = ["InplaceAddExt", "InplaceSubExt", "InplaceMul", "InplaceDiv", "typeof", "DistCommIsend",
                          "DistCommIrecv", "DistCommBroadcast", "DistCommAllReduce", "DistCommAllGather",
                          "DistCommReduceScatter", "requires_grad_", "item", "__get__", "__set__", "register_hook",
                          "is_complex", "chunk", "__bool__", "__len__", "__format__"]

        # Ops requiring args unpacking for layout inference (packed as prim, name, real_args).
        self.unpack_ops = ["ScatterUpdate", "Mod", "GatherNd"]

        self._register_distributed_ops()

    def _setup_paths_from_env(self):
        """Apply both environment overrides (YAML dir and sys.path entries)."""
        self._setup_yaml_dir(self._env_yaml_dir)
        self._extend_sys_path(self._env_python_path)

    def _setup_yaml_dir(self, env_yaml_dir: Optional[str]):
        """
        Feature: Configure yaml_dir/work_dir for OpDispatcher
        Description: Resolve the YAML directory used to load distributed op definitions.
                     If env_yaml_dir is an absolute path, use it directly; otherwise treat it
                     as a path relative to the project work_dir. If env_yaml_dir is not set,
                     fall back to the default 'shard/ops/yaml' under work_dir.
        Expectation: self.yaml_dir and self.work_dir are set to valid values used later by
                     safe_load_yaml_from_dir(); no functional behavior is changed.
        """
        if env_yaml_dir:
            if os.path.isabs(env_yaml_dir):
                # Absolute override: use as-is, no work_dir prefix needed.
                self.yaml_dir = env_yaml_dir
                self.work_dir = ""
            else:
                # Relative override: resolved against the package directory
                # one level above this file.
                self.work_dir = os.path.normpath(
                    os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
                )
                self.yaml_dir = env_yaml_dir
        else:
            # Default location of the built-in op YAML definitions.
            self.yaml_dir = "shard/ops/yaml"
            self.work_dir = os.path.normpath(
                os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
            )

    def _extend_sys_path(self, env_python_path: Optional[str]):
        """Prepend each existing, not-yet-listed directory from the
        colon-separated env path to sys.path (for external op modules)."""
        if not env_python_path:
            return
        python_paths = env_python_path.split(":")
        for path in python_paths:
            if path and os.path.isdir(path) and path not in sys.path:
                sys.path.insert(0, path)

    def _register_distributed_ops(self):
        """Register every op listed in the loaded YAML configuration."""
        for op_name, config in self.layout_infer_ops.items():
            self._register_single_distributed_op(op_name, config)

    def _register_single_distributed_op(self, op_name: str, config: dict):
        """
        Feature: Register a single distributed op implementation
        Description: Import the distributed op class specified by config and instantiate it
                     with op_name to trigger registration in the distributed op registry.
                     Prefer 'distributed_op_module' when provided; otherwise import from
                     built-in module prefix 'hyper_parallel.core.shard.ops.' plus
                     'distributed_op_file'. If import fails and an external python path is
                     provided via env, fall back to importing 'distributed_op_file' directly.
        Expectation: The distributed op class is imported and instantiated successfully,
                     or the original import error is raised; no functional behavior is changed.
        """
        class_name = config["distributed_op_class"]

        # Fully-qualified module name takes precedence over the file form.
        if "distributed_op_module" in config:
            module_name = config["distributed_op_module"]
            module = importlib.import_module(module_name)
            op_class = getattr(module, class_name)
            # Instantiation registers the op; the instance itself is unused.
            _ = op_class(op_name)
            return

        module_file = config["distributed_op_file"]
        try:
            module_name = "hyper_parallel.core.shard.ops." + module_file
            module = importlib.import_module(module_name)
            op_class = getattr(module, class_name)
            _ = op_class(op_name)
        except (ModuleNotFoundError, ImportError):
            if self._env_python_path:
                # External op module reachable via the env-provided sys.path.
                module = importlib.import_module(module_file)
                op_class = getattr(module, class_name)
                _ = op_class(op_name)
            else:
                # No fallback available: surface the original import error.
                raise

    def _process_args_and_kwargs(
            self, args, kwargs, cache_key: "LayoutCacheKey"
    ) -> tuple[list, list, list, dict]:
        """Split call arguments for layout inference.

        Returns (input_layouts, extra_args, input_args, input_kwargs); also
        appends one id string per argument to cache_key so equal calls hit
        the same cache entry.
        """
        # input_layouts contain parameters which have a layout, extra_args contain other parameters
        input_layouts = []
        extra_args = []
        # input_args are positional parameters, input_kwargs are keyword parameters
        input_args = []
        input_kwargs = kwargs.copy()

        # Normal ops pass real inputs directly (e.g. SumExt: args = (dtensor, axis: list, keep_dims: bool, dtype: None)).
        for arg in args:
            if arg is None:
                input_layouts.append(None)
                input_args.append(arg)
                continue

            if not hasattr(arg, "_layout"):
                # Non-layout argument: scalars key on their value, plain
                # tensors on the generic "scalar" tag.
                id_str = "scalar"
                if not isinstance(arg, Tensor):
                    id_str = str(arg)
                cache_key.layout_ids.append(id_str)
                extra_args.append(arg)
                input_layouts.append(None)
                input_args.append(arg)
            else:
                layout = arg.layout
                layout_id = layout.compact_str
                cache_key.layout_ids.append(str(layout_id))
                input_layouts.append(layout)
                # DTensor inputs are unwrapped so the local op sees a
                # backend-native tensor.
                if isinstance(arg, DTensor):
                    input_args.append(arg.to_local())
                else:
                    input_args.append(arg)

        for k, val in kwargs.items():
            if val is None:
                input_layouts.append(None)
                continue
            if not hasattr(val, "_layout"):
                id_str = "scalar"
                if not isinstance(val, Tensor):
                    id_str = str(val)
                cache_key.layout_ids.append(id_str)
                extra_args.append(val)
                input_layouts.append(None)
            else:
                layout = val.layout
                layout_id = layout.compact_str
                cache_key.layout_ids.append(str(layout_id))
                input_layouts.append(layout)
                if isinstance(val, DTensor):
                    input_kwargs[k] = val.to_local()

        return input_layouts, extra_args, input_args, input_kwargs

    def _with_layout_infer(self, func: callable, *args, **kwargs) -> Tensor:
        """Default dispatch path: infer the output layout (with caching),
        run the expanded or original op on local inputs, and wrap the result
        back into DTensor(s)."""
        func_name = platform.get_op_name(func)
        packed_call = None
        # Ops in unpack_ops use packed fallback args (e.g. ScatterUpdate: (prim_obj, op_name: str, (input_x, indices, updates))).
        if(func_name in self.unpack_ops and len(args) == 3 and
           isinstance(args[1], str) and isinstance(args[2],(tuple,list))):
            packed_call = (args[0], args[1])
            # From here on, treat the real inputs as the positional args.
            args = tuple(args[2])

        cache_key = LayoutCacheKey([])
        input_layouts, extra_args, input_args, input_kwargs = self._process_args_and_kwargs(
            args, kwargs, cache_key
        )
        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        if func_name not in layout_cache:
            layout_cache[func_name] = {}

        op_layout_cache = layout_cache[func_name]

        distribute_op = cache_manager.distributed_op(func_name)
        if cache_key in op_layout_cache:
            output_layout, op_impl = op_layout_cache[cache_key]
        else:
            all_args = (input_layouts, extra_args)
            output_layout = distribute_op.infer_layout(*all_args)
            # Not every distributed op provides get_expand_impl; default to
            # None, which means "call the original func".
            op_impl = getattr(
                distribute_op, "get_expand_impl", lambda *args, **kwargs: None
            )(func, output_layout, input_layouts, extra_args)
            op_layout_cache[cache_key] = (output_layout, op_impl)

        if op_impl is None:
            op_impl = func

        if packed_call is not None:
            # Restore the packed (prim, name, inputs) calling convention.
            py_output = op_impl(packed_call[0], packed_call[1], tuple(input_args), **input_kwargs)
        else:
            py_output = op_impl(*input_args, **input_kwargs)

        # Wrap the raw output(s) with the inferred layout(s).
        if isinstance(py_output, (tuple, list)):
            output = ()
            if isinstance(output_layout, (tuple, list)):
                if len(py_output) == len(output_layout):
                    for i, output_item in enumerate(py_output):
                        output += (DTensor.from_local(output_item, output_layout[i].mesh, output_layout[i].placements),)
                else:
                    raise RuntimeError(f"Output tuple size ({len(py_output)}) "
                                       f"does not match layout tuple size ({len(output_layout)})")
            else:
                raise RuntimeError("Output is a tuple but layout is not")
            return output

        return DTensor.from_local(py_output, output_layout.mesh, output_layout.placements)

    def _with_layout_infer_with_tuple_expand(self, func: callable, *args, **kwargs) -> Tensor:
        """Dispatch path for ops whose positional args may be tuples/lists of
        tensors: tuple elements are flattened for layout inference while the
        local call keeps the original tuple structure."""
        expanded_args = []
        input_args = []
        for arg in args:
            if isinstance(arg, (tuple, list)):
                # Flatten for inference; unwrap any layout-carrying elements
                # for the local call.
                expanded_args.extend(arg)
                # pylint: disable=R1728
                input_args.append(tuple(item.to_local() if hasattr(item, "_layout") else item for item in arg))
            else:
                expanded_args.append(arg)
                input_args.append(arg.to_local() if isinstance(arg, DTensor) else arg)

        cache_key = LayoutCacheKey([])
        input_layouts = []
        extra_args = []

        for arg in expanded_args:
            if arg is None:
                input_layouts.append(None)
                continue

            if not hasattr(arg, "_layout"):
                id_str = "scalar"
                if not isinstance(arg, Tensor):
                    id_str = str(arg)
                cache_key.layout_ids.append(id_str)
                extra_args.append(arg)
                input_layouts.append(None)
            else:
                layout = arg.layout
                layout_id = layout.compact_str
                cache_key.layout_ids.append(str(layout_id))
                input_layouts.append(layout)

        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        func_name = platform.get_op_name(func)
        if func_name not in layout_cache:
            layout_cache[func_name] = {}

        op_layout_cache = layout_cache[func_name]

        distribute_op = cache_manager.distributed_op(func_name)
        if cache_key in op_layout_cache:
            output_layout, op_impl = op_layout_cache[cache_key]
        else:
            all_args = (input_layouts, extra_args)
            output_layout = distribute_op.infer_layout(*all_args)
            op_impl = distribute_op.get_expand_impl(func, output_layout, input_layouts, extra_args)
            op_layout_cache[cache_key] = (output_layout, op_impl)

        if op_impl is None:
            op_impl = func

        py_output = op_impl(*input_args, **kwargs)

        # Wrap the raw output(s) with the inferred layout(s).
        if isinstance(py_output, (tuple, list)):
            output = ()
            if isinstance(output_layout, (tuple, list)):
                if len(py_output) == len(output_layout):
                    for i, output_item in enumerate(py_output):
                        output += (DTensor.from_local(output_item, output_layout[i].mesh, output_layout[i].placements),)
                else:
                    raise RuntimeError(f"Output tuple size ({len(py_output)}) "
                                       f"does not match layout tuple size ({len(output_layout)})")
            else:
                raise RuntimeError("Output is a tuple but layout is not")
            return output

        return DTensor.from_local(py_output, output_layout.mesh, output_layout.placements)

    def _with_layout_infer_reshape(self, func: callable, *args) -> Tensor:
        """Dispatch path for reshape-style ops. infer_layout here returns a
        tuple: [0] the output layout, [1] the local (per-shard) target shape
        to pass to the backend op."""
        input_tensor = args[0]
        shape = args[1]
        cache_key = LayoutCacheKey([])
        input_layouts = []

        layout = input_tensor.layout
        input_layouts.append(layout)
        layout_id = layout.compact_str
        cache_key.layout_ids.append(str(layout_id))

        # Target shape and current global shape both feed inference and the
        # cache key.
        extra_args = []
        extra_args.append(shape)
        cache_key.layout_ids.append(str(shape))

        input_shape = input_tensor.shape
        extra_args.append(input_shape)
        cache_key.layout_ids.append(str(input_shape))

        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        func_name = platform.get_op_name(func)
        if func_name not in layout_cache:
            layout_cache[func_name] = {}

        op_layout_cache = layout_cache[func_name]

        distribute_op = cache_manager.distributed_op(func_name)
        if cache_key in op_layout_cache:
            infer_output, op_impl = op_layout_cache[cache_key]
        else:
            all_args = (input_layouts, extra_args)
            infer_output = distribute_op.infer_layout(*all_args)
            op_impl = distribute_op.get_expand_impl(func, infer_output, input_layouts, extra_args)
            op_layout_cache[cache_key] = (infer_output, op_impl)

        infer_output_tuple = infer_output
        # [1] is the shard-local shape to reshape to.
        local_shape = infer_output_tuple[1]

        if op_impl is None:
            op_impl = func

        py_output = op_impl(input_tensor.to_local(), local_shape)

        return DTensor.from_local(py_output, infer_output_tuple[0].mesh, infer_output_tuple[0].placements)

    def _process_args_and_kwargs_with_shape(
            self, args, kwargs, cache_key: "LayoutCacheKey"
    ) -> tuple[list, list, list, list, dict]:
        """Variant of _process_args_and_kwargs that additionally collects the
        shape of every shape-bearing argument (and folds the shapes into the
        cache key).

        Returns (input_layouts, input_shapes, extra_args, input_args,
        input_kwargs).
        """
        input_layouts = []
        extra_args = []
        input_shapes = []
        input_args = []
        input_kwargs = kwargs.copy()
        for arg in args:
            if arg is None:
                input_layouts.append(None)
                input_shapes.append(None)
                input_args.append(arg)
                continue

            if not hasattr(arg, "_layout"):
                id_str = "scalar"
                if not isinstance(arg, Tensor):
                    id_str = str(arg)
                cache_key.layout_ids.append(id_str)
                extra_args.append(arg)
                input_layouts.append(None)
                input_args.append(arg)
            else:
                layout = arg.layout
                layout_id = layout.compact_str
                cache_key.layout_ids.append(str(layout_id))
                input_layouts.append(layout)
                if isinstance(arg, DTensor):
                    input_args.append(arg.to_local())
                else:
                    input_args.append(arg)

            # Record the shape (or None) for every non-None argument.
            if not hasattr(arg, "shape"):
                input_shapes.append(None)
            else:
                input_shape = arg.shape
                input_shapes.append(input_shape)
                cache_key.layout_ids.append(str(input_shape))

        for k, val in kwargs.items():
            if val is None:
                input_layouts.append(None)
                continue
            if not hasattr(val, "_layout"):
                id_str = "scalar"
                if not isinstance(val, Tensor):
                    id_str = str(val)
                cache_key.layout_ids.append(id_str)
                extra_args.append(val)
                input_layouts.append(None)
            else:
                layout = val.layout
                layout_id = layout.compact_str
                cache_key.layout_ids.append(str(layout_id))
                input_layouts.append(layout)
                if isinstance(val, DTensor):
                    input_kwargs[k] = val.to_local()

            if not hasattr(val, "shape"):
                input_shapes.append(None)
            else:
                input_shape = val.shape
                input_shapes.append(input_shape)
                cache_key.layout_ids.append(str(input_shape))

        return input_layouts, input_shapes, extra_args, input_args, input_kwargs

    def _with_layout_infer_with_shape(self, func: callable, *args, **kwargs) -> Tensor:
        """Dispatch path for ops whose layout inference also needs the input
        shapes; the collected shape list is appended to extra_args before
        infer_layout is called."""
        func_name = platform.get_op_name(func)
        packed_call = None
        # Packed fallback args for some ops (e.g. Mod: (prim_obj, "Mod", (x, y))).
        if (func_name in self.unpack_ops and len(args) == 3 and
                isinstance(args[1], str) and isinstance(args[2], (tuple, list))):
            packed_call = (args[0], args[1])
            args = tuple(args[2])

        cache_key = LayoutCacheKey([])
        input_layouts, input_shapes, extra_args, input_args, input_kwargs = \
            self._process_args_and_kwargs_with_shape(args, kwargs, cache_key)

        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        if func_name not in layout_cache:
            layout_cache[func_name] = {}

        op_layout_cache = layout_cache[func_name]

        distribute_op = cache_manager.distributed_op(func_name)
        if cache_key in op_layout_cache:
            output_layout, op_impl = op_layout_cache[cache_key]
        else:
            # Shapes ride along as the last extra arg for inference.
            extra_args.append(input_shapes)
            all_args = (input_layouts, extra_args)
            output_layout = distribute_op.infer_layout(*all_args)
            op_impl = distribute_op.get_expand_impl(func, output_layout, input_layouts, extra_args)
            op_layout_cache[cache_key] = (output_layout, op_impl)

        if op_impl is None:
            op_impl = func

        if packed_call is not None:
            py_output = op_impl(packed_call[0], packed_call[1], tuple(input_args), **input_kwargs)
        else:
            py_output = op_impl(*input_args, **input_kwargs)

        # Wrap the raw output(s) with the inferred layout(s).
        if isinstance(py_output, (tuple, list)):
            output = ()
            if isinstance(output_layout, (tuple, list)):
                if len(py_output) == len(output_layout):
                    for i, output_item in enumerate(py_output):
                        output += (DTensor.from_local(output_item, output_layout[i].mesh, output_layout[i].placements),)
                else:
                    raise RuntimeError(f"Output tuple size ({len(py_output)}) "
                                       f"does not match layout tuple size ({len(output_layout)})")
            else:
                raise RuntimeError("Output is a tuple but layout is not")
            return output

        return DTensor.from_local(py_output, output_layout.mesh, output_layout.placements)

    def _with_layout_infer_slice(self, func: callable, *args) -> Tensor:
        """Dispatch path for slice-style ops. infer_layout returns a tuple:
        [0] the output layout, [1]/[2] the begin/end indices translated to
        shard-local coordinates."""
        input_tensor = args[0]
        begin = args[1]
        end = args[2]

        # Input layout and global shape feed both inference and the cache key.
        cache_key = LayoutCacheKey([])
        input_layouts = []

        layout = input_tensor.layout
        global_shape = input_tensor.shape
        input_layouts.append(layout)
        layout_id = layout.compact_str
        cache_key.layout_ids.append(str(layout_id))

        extra_args = []
        extra_args.append(begin)
        extra_args.append(end)
        extra_args.append(global_shape)
        cache_key.layout_ids.append(str(begin))
        cache_key.layout_ids.append(str(end))
        cache_key.layout_ids.append(str(global_shape))

        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        func_name = platform.get_op_name(func)
        if func_name not in layout_cache:
            layout_cache[func_name] = {}

        op_layout_cache = layout_cache[func_name]

        distribute_op = cache_manager.distributed_op(func_name)
        if cache_key in op_layout_cache:
            infer_output, op_impl = op_layout_cache[cache_key]
        else:
            all_args = (input_layouts, extra_args)
            infer_output = distribute_op.infer_layout(*all_args)
            op_impl = distribute_op.get_expand_impl(func, infer_output, input_layouts, extra_args)
            op_layout_cache[cache_key] = (infer_output, op_impl)

        infer_output_tuple = infer_output
        # Shard-local begin/end produced by inference.
        new_begin = infer_output_tuple[1]
        new_end = infer_output_tuple[2]

        if op_impl is None:
            op_impl = func

        py_output = op_impl(input_tensor.to_local(), new_begin, new_end)

        return DTensor.from_local(py_output, infer_output_tuple[0].mesh, infer_output_tuple[0].placements)

    def _merge_default(self, config: dict):
        """Apply __default__ values to all ops in this YAML file."""
        if "__default__" not in config:
            return config

        default_cfg = config["__default__"]
        merged = {}

        for op_name, op_cfg in config.items():
            if op_name == "__default__":
                continue

            # Per-op keys override the shared defaults.
            new_cfg = default_cfg.copy()
            new_cfg.update(op_cfg)
            merged[op_name] = new_cfg

        return merged

    def safe_load_yaml_from_dir(self):
        """
        Load yaml dictionary from directory.

        Reads every *.yaml file under the resolved yaml directory, applies
        per-file __default__ merging, and returns one combined dict mapping
        op name -> config. Raises ValueError on a missing directory or a
        duplicate op name across files.
        """
        yaml_dict = {}
        yaml_path = os.path.join(self.work_dir, self.yaml_dir) if self.work_dir else self.yaml_dir
        if not os.path.isdir(yaml_path):
            raise ValueError(f"Invalid yaml directory path: {yaml_path}")

        for yaml_file_path in glob.glob(os.path.join(yaml_path, '*.yaml')):
            with open(yaml_file_path, 'r', encoding="utf-8") as f:
                yaml_data = yaml.safe_load(f)

            yaml_data = self._merge_default(yaml_data)
            for name, data in yaml_data.items():
                if name in yaml_dict:
                    raise ValueError(f"Duplicate yaml object with name '{name}'.")
                yaml_dict[name] = data

        return yaml_dict

    def dispatch(self, op_call: callable, args: tuple[object, ...], kwargs: dict[str, object]):
        """
        Route an operator call on (possibly) DTensor arguments.

        Whitelisted ops, and all ops while dtensor dispatch is disabled, run
        the local implementation directly on unwrapped tensors. Otherwise the
        op must have a yaml config; its 'infer_layout_suffix' selects which
        layout-inference variant handles the call.

        :param op_call: backend operator callable.
        :param args: positional arguments (may contain DTensor).
        :param kwargs: keyword arguments.
        :return: result of the local call, or DTensor-wrapped output(s).
        :raises RuntimeError: unknown op or unrecognized suffix in the yaml.
        """
        op_name = platform.get_op_name(op_call)
        if op_name in self.whitelist or get_dtensor_dispatch() is False:
            input_args = [arg.to_local() if isinstance(arg, DTensor) else arg for arg in args]
            return op_call(*input_args, **kwargs)
        if op_name not in self.layout_infer_ops:
            # NOTE(review): message typo "dose" -> "does"; string left
            # unchanged here to preserve behavior.
            raise RuntimeError(f"Operator {op_name} dose not contain parallel layout infer func.")

        layout_infer_info = self.layout_infer_ops[op_name]
        suffix = layout_infer_info.get('infer_layout_suffix', '')
        if not suffix:
            return self._with_layout_infer(op_call, *args, **kwargs)

        if suffix == "WithShape":
            return self._with_layout_infer_with_shape(op_call, *args, **kwargs)
        if suffix == "Reshape":
            return self._with_layout_infer_reshape(op_call, *args)
        if suffix == "WithTupleExpand":
            return self._with_layout_infer_with_tuple_expand(op_call, *args, **kwargs)
        if suffix == "Slice":
            return self._with_layout_infer_slice(op_call, *args)
        raise RuntimeError(f"Operator {op_name} specified wrong suffix in parallel yaml.")

663 

# Module-level singleton dispatcher, constructed at import time (loads the
# YAML op configs and registers distributed ops as a side effect).
_OP_DISPATCHER = OpDispatcher()