
# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operator dispatch for DTensor: infers output layouts, caches inference
results, and routes ops to their distributed implementations."""

import os
import sys
import atexit
import glob
import importlib
from typing import Any, List, Dict, Optional, Set
from itertools import chain

import yaml

from hyper_parallel.core.shard.ops.parallel_ops_register import get_distributed_op
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.core.dtensor.random import OffsetBasedRNGTracker, is_rng_supported_mesh
from hyper_parallel.platform import get_platform
from hyper_parallel.platform.platform import PlatformType

platform = get_platform()
Tensor = platform.Tensor

def _apply_shard_offset_to_rng_args(args, offset_incr):
    """Apply per-shard offset increment to seed/offset tensors in MindSpore random op args.

    MindSpore random ops (e.g. ``randn_like_``) receive ``(seed, offset)`` as
    explicit int64 scalar tensors from ``default_generator._step()`` in the
    Python wrapper *before* the C++ dispatch triggers ``__fallback__``. By the
    time ``_dispatch_random_op`` is called, the kernel will use whatever
    ``(seed, offset)`` values are in the args; it does **not** read the
    generator again. This function finds the offset tensor and adds the
    per-rank offset increment so each shard gets a unique random stream.

    The (seed, offset) pair is identified as the last two consecutive int64
    0-dim tensors in *args* (scanning from the end to skip trailing dtype /
    device arguments).

    Args:
        args: The list of local args for the random op.
        offset_incr (int): Per-shard offset increment.

    Returns:
        list: Modified args with the offset tensor adjusted.
    """
    int64_dtype = platform.tensor_dtype.int64
    last_int64_idx = -1
    for i in range(len(args) - 1, -1, -1):
        arg = args[i]
        if isinstance(arg, Tensor) and arg.dtype == int64_dtype and arg.ndim == 0:
            if last_int64_idx == i + 1:
                offset_idx = i + 1
                new_args = list(args)
                new_offset = int(new_args[offset_idx].item()) + offset_incr
                new_args[offset_idx] = platform.tensor([new_offset], dtype=int64_dtype).reshape(())
                return new_args
            last_int64_idx = i
    return args
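
# Illustrative sketch of the seed/offset detection above (hypothetical values;
# uses only ``platform.tensor`` exactly as the function body does):
#
#     seed = platform.tensor([7], dtype=platform.tensor_dtype.int64).reshape(())
#     offset = platform.tensor([0], dtype=platform.tensor_dtype.int64).reshape(())
#     args = [some_local_tensor, seed, offset]
#     new_args = _apply_shard_offset_to_rng_args(args, offset_incr=16)
#     # new_args[2] is a fresh 0-dim int64 tensor holding 0 + 16; the seed and
#     # the input tensor are left untouched.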


_dtensor_dispatch = True
_no_skip_ops: Set[str] = set()


def get_no_skip_ops() -> Set[str]:
    """Return the set of op names that are exempt from SkipDTensorDispatch."""
    return _no_skip_ops


def add_no_skip_ops(op_names: Set[str]) -> None:
    """Add op names to the no-skip set so they are always dispatched through DTensor.

    Args:
        op_names: Set of canonical op name strings to register as no-skip.
    """
    global _no_skip_ops
    _no_skip_ops = _no_skip_ops | op_names


def remove_no_skip_ops(op_names: Set[str]) -> None:
    """Remove op names from the no-skip set.

    Args:
        op_names: Set of canonical op name strings to remove.
    """
    global _no_skip_ops
    _no_skip_ops = _no_skip_ops - op_names


def enable_dtensor_dispatch() -> None:
    """
    Enable DTensor dispatch for distributed tensor operations.

    When enabled, tensor operations will be dispatched through the
    distributed operator dispatcher for layout inference and redistribution.
    """
    global _dtensor_dispatch
    _dtensor_dispatch = True


def disable_dtensor_dispatch() -> None:
    """
    Disable DTensor dispatch for distributed tensor operations.

    When disabled, tensor operations will bypass the distributed operator
    dispatcher and use native implementations directly.
    """
    global _dtensor_dispatch
    _dtensor_dispatch = False


def get_dtensor_dispatch() -> bool:
    """
    Get the current DTensor dispatch status.

    Returns:
        bool: True if DTensor dispatch is enabled, False otherwise.
    """
    return _dtensor_dispatch
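
# Usage sketch (the op name and the try/finally pattern are illustrative, not
# an API of this module): temporarily run on local tensors while still forcing
# selected ops through DTensor dispatch.
#
#     add_no_skip_ops({"MatMulExt"})        # always dispatched, even when disabled
#     disable_dtensor_dispatch()
#     try:
#         ...                               # other ops here bypass DTensor dispatch
#     finally:
#         enable_dtensor_dispatch()
#         remove_no_skip_ops({"MatMulExt"})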


class LayoutCacheKey:
    """Immutable layout cache key."""
    __slots__ = ('_tuple', '_hash')

    def __init__(self, layout_ids: List[str]):
        self._tuple = tuple(layout_ids)
        self._hash = hash(self._tuple)

    @classmethod
    def from_cache_values(cls, cache_values: list) -> "LayoutCacheKey":
        """Build a LayoutCacheKey from a cache_values list.

        Args:
            cache_values (list): Mixed list of Layout objects (with compact_str) and raw scalars.

        Returns:
            LayoutCacheKey: Immutable key derived from the string representation of each value.
        """
        key_values = []
        for v in cache_values:
            if hasattr(v, 'compact_str'):
                key_values.append(str(v.compact_str))
            else:
                key_values.append(str(v))
        return cls(key_values)

    def __eq__(self, other):
        if not isinstance(other, LayoutCacheKey):
            return False
        return self._tuple == other._tuple

    def __hash__(self):
        return self._hash

    def __repr__(self):
        return f"LayoutCacheKey({self._tuple})"
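
# Behavior sketch (the strings stand in for Layout.compact_str values):
#
#     k1 = LayoutCacheKey(["S(0)", "scalar"])
#     k2 = LayoutCacheKey.from_cache_values(["S(0)", "scalar"])
#     assert k1 == k2 and hash(k1) == hash(k2)
#     cache = {k1: "cached entry"}           # usable as a dict key
#     assert cache[k2] == "cached entry"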


class LayoutCacheManager:
    """
    Layout cache for inferred layouts.

    A singleton class that manages layout caches for distributed operations.
    It caches the inferred layouts and operation implementations to avoid
    redundant computation during repeated calls with the same input layouts.
    """
    _instance = None

    def __init__(self):
        self.layout_cache: Dict[str, Dict[LayoutCacheKey, Any]] = {}
        atexit.register(self.clear_cache)

    @classmethod
    def get_instance(cls) -> "LayoutCacheManager":
        """
        Get the singleton instance of LayoutCacheManager.

        Returns:
            LayoutCacheManager: The singleton instance.
        """
        if cls._instance is None:
            cls._instance = LayoutCacheManager()
        return cls._instance

    def get_layout_cache(self) -> Dict[str, Dict[LayoutCacheKey, Any]]:
        """
        Get the layout cache dictionary.

        Returns:
            Dict[str, Dict[LayoutCacheKey, Any]]: The nested dictionary mapping
                operation names to their layout caches.
        """
        return self.layout_cache

    def distributed_op(self, op_name: str) -> Any:
        """
        Get the distributed operation implementation by name.

        Args:
            op_name (str): The name of the distributed operation.

        Returns:
            Any: The distributed operation class or implementation.
        """
        return get_distributed_op(op_name)

    def clear_cache(self) -> None:
        """
        Clear all cached layouts.

        This method is automatically registered with atexit to ensure the
        cache is cleared when the program exits.
        """
        self.layout_cache.clear()
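
# Access pattern used by the dispatch paths below (the op name is illustrative;
# output_layout and op_impl come from the distributed op's infer_layout /
# get_expand_impl):
#
#     manager = LayoutCacheManager.get_instance()
#     layout_cache = manager.get_layout_cache()
#     if "Transpose" not in layout_cache:
#         layout_cache["Transpose"] = {}
#     layout_cache["Transpose"][cache_key] = (output_layout, op_impl)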


class OpDispatcher:
    """
    Dispatches tensor ops on DTensor inputs: loads distributed op definitions
    from YAML, infers and caches output layouts, and routes special cases such
    as random ops and whitelisted local ops.
    """

    def __init__(self):
        self._env_yaml_dir: Optional[str] = os.environ.get("HYPER_PARALLEL_OPS_YAML_DIR")
        self._env_python_path: Optional[str] = os.environ.get("HYPER_PARALLEL_OPS_PYTHON_PATH")
        # The following attributes are initialized in _setup_yaml_dir()
        self.work_dir = ""
        self.yaml_dir = ""

        self._setup_paths_from_env()

        self.layout_infer_ops = self.safe_load_yaml_from_dir()
        self.whitelist = ["InplaceAddExt", "InplaceSubExt", "InplaceMul", "InplaceDiv", "typeof", "DistCommIsend",
                          "DistCommIrecv", "DistCommBroadcast", "DistCommAllReduce", "DistCommAllGather",
                          "DistCommReduceScatter", "requires_grad_", "item", "__get__", "__set__", "register_hook",
                          "is_complex", "chunk", "__bool__", "__len__", "__format__", "dim",
                          "_has_compatible_shallow_copy_type", "is_floating_point", "is_contiguous"]

        # Ops requiring args unpacking for layout inference (packed as prim, name, real_args).
        self.unpack_ops = ["ScatterUpdate", "Mod", "GatherNd"]

        self._random_ops = {
            "normal_", "uniform_", "bernoulli", "bernoulli_",
            "native_dropout", "rand", "rand_like", "randn",
            "randn_like", "randint_like", "kaiming_uniform_",
        }
        # Only mint random ops are supported.
        # MindSpore uses the actual kernel names.
        self._random_ms_ops = {
            "BernoulliExt", "MultinomialExt", "InplaceNormal", "InplaceUniform",
            "RandpermExt", "Randn", "RandLikeExt", "RandnLike", "RandInt", "RandIntLike", "RandExt",
            "FuncDropoutExt"
        }
        self._rng_tracker: Optional[OffsetBasedRNGTracker] = None

        self._suffix_dispatch: Dict[str, str] = {
            "WithShape": "_with_layout_infer_with_shape",
            "Reshape": "_with_layout_infer_reshape",
            "WithTupleExpand": "_with_layout_infer_with_tuple_expand",
            "Slice": "_with_layout_infer_slice",
        }

        self._register_distributed_ops()

    def _setup_paths_from_env(self):
        """
        Set up the YAML directory and Python path from environment variables.

        This method initializes the YAML directory and extends sys.path based on
        the environment variables HYPER_PARALLEL_OPS_YAML_DIR and
        HYPER_PARALLEL_OPS_PYTHON_PATH.
        """
        self._setup_yaml_dir(self._env_yaml_dir)
        self._extend_sys_path(self._env_python_path)

    def _setup_yaml_dir(self, env_yaml_dir: Optional[str]):
        """
        Feature: Configure yaml_dir/work_dir for OpDispatcher
        Description: Resolve the YAML directory used to load distributed op definitions.
            If env_yaml_dir is an absolute path, use it directly; otherwise treat it
            as a path relative to the project work_dir. If env_yaml_dir is not set,
            fall back to the default 'shard/ops/yaml' under work_dir.
        Expectation: self.yaml_dir and self.work_dir are set to valid values used later by
            safe_load_yaml_from_dir(); no functional behavior is changed.
        """
        if env_yaml_dir:
            if os.path.isabs(env_yaml_dir):
                self.yaml_dir = env_yaml_dir
                self.work_dir = ""
            else:
                self.work_dir = os.path.normpath(
                    os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
                )
                self.yaml_dir = env_yaml_dir
        else:
            self.yaml_dir = "shard/ops/yaml"
            self.work_dir = os.path.normpath(
                os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
            )

    def _extend_sys_path(self, env_python_path: Optional[str]):
        """Prepend colon-separated directories from env_python_path to sys.path."""
        if not env_python_path:
            return
        python_paths = env_python_path.split(":")
        for path in python_paths:
            if path and os.path.isdir(path) and path not in sys.path:
                sys.path.insert(0, path)
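
    # Configuration sketch (paths are placeholders); both variables are read once
    # at OpDispatcher construction:
    #
    #     export HYPER_PARALLEL_OPS_YAML_DIR=/path/to/custom/yaml
    #     export HYPER_PARALLEL_OPS_PYTHON_PATH=/path/to/ops:/another/path
    #
    # An absolute YAML dir is used as-is; a relative one is resolved against this
    # package's parent directory (see _setup_yaml_dir above).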


    def _register_distributed_ops(self):
        """Instantiate every op listed in the loaded YAML to register it."""
        for op_name, config in self.layout_infer_ops.items():
            self._register_single_distributed_op(op_name, config)

    def _register_single_distributed_op(self, op_name: str, config: dict):
        """
        Feature: Register a single distributed op implementation
        Description: Import the distributed op class specified by config and instantiate it
            with op_name to trigger registration in the distributed op registry.
            Prefer 'distributed_op_module' when provided; otherwise import from the
            built-in module prefix 'hyper_parallel.core.shard.ops.' plus
            'distributed_op_file'. If that import fails and an external python path
            was provided via env, fall back to importing 'distributed_op_file' directly.
        Expectation: The distributed op class is imported and instantiated successfully,
            or the original import error is raised; no functional behavior is changed.
        """
        class_name = config["distributed_op_class"]

        if "distributed_op_module" in config:
            module_name = config["distributed_op_module"]
            module = importlib.import_module(module_name)
            op_class = getattr(module, class_name)
            _ = op_class(op_name)
            return

        module_file = config["distributed_op_file"]
        try:
            module_name = "hyper_parallel.core.shard.ops." + module_file
            module = importlib.import_module(module_name)
            op_class = getattr(module, class_name)
            _ = op_class(op_name)
        except (ModuleNotFoundError, ImportError):
            if self._env_python_path:
                module = importlib.import_module(module_file)
                op_class = getattr(module, class_name)
                _ = op_class(op_name)
            else:
                raise


    @staticmethod
    def _process_args_and_kwargs(
        args, kwargs
    ) -> tuple[list, list, list, dict, list]:
        """Split args/kwargs into layouts, scalar extras, local tensors, and cache key parts."""
        input_layouts = []
        extra_args = []
        input_args = []
        input_kwargs = kwargs.copy()
        cache_key_values = []

        for arg in args:
            if arg is None:
                input_layouts.append(None)
                input_args.append(arg)
                continue

            if not hasattr(arg, "_layout"):
                id_str = "scalar"
                if not isinstance(arg, Tensor):
                    id_str = str(arg)
                cache_key_values.append(id_str)
                extra_args.append(arg)
                input_layouts.append(None)
                input_args.append(arg)
            else:
                layout = arg.layout
                layout_id = layout.compact_str
                cache_key_values.append(str(layout_id))
                input_layouts.append(layout)
                if isinstance(arg, DTensor):
                    input_args.append(arg.to_local())
                else:
                    input_args.append(arg)

        for k, val in kwargs.items():
            if val is None:
                input_layouts.append(None)
                continue
            if not hasattr(val, "_layout"):
                id_str = "scalar"
                if not isinstance(val, Tensor):
                    id_str = str(val)
                cache_key_values.append(id_str)
                extra_args.append(val)
                input_layouts.append(None)
            else:
                layout = val.layout
                layout_id = layout.compact_str
                cache_key_values.append(str(layout_id))
                input_layouts.append(layout)
                if isinstance(val, DTensor):
                    input_kwargs[k] = val.to_local()

        return input_layouts, extra_args, input_args, input_kwargs, cache_key_values

    def _with_layout_infer(self, func: callable, *args, **kwargs) -> Tensor:
        """Generic layout-infer dispatch: infer the output layout (cached), run the op, wrap the result."""
        func_name = platform.get_op_name(func)
        packed_call = None
        if (func_name in self.unpack_ops and len(args) == 3 and
                isinstance(args[1], str) and isinstance(args[2], (tuple, list))):
            packed_call = (args[0], args[1])
            args = tuple(args[2])

        input_layouts, extra_args, input_args, input_kwargs, cache_key_values = \
            OpDispatcher._process_args_and_kwargs(args, kwargs)
        cache_key = LayoutCacheKey(cache_key_values)
        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        if func_name not in layout_cache:
            layout_cache[func_name] = {}

        op_layout_cache = layout_cache[func_name]

        distribute_op = cache_manager.distributed_op(func_name)
        if cache_key in op_layout_cache:
            output_layout, op_impl = op_layout_cache[cache_key]
        else:
            all_args = (input_layouts, extra_args)
            output_layout = distribute_op.infer_layout(*all_args)
            op_impl = distribute_op.get_expand_impl(func, output_layout, input_layouts, extra_args)
            op_layout_cache[cache_key] = (output_layout, op_impl)

        if op_impl is None:
            op_impl = func

        if packed_call is not None:
            py_output = op_impl(packed_call[0], packed_call[1], tuple(input_args), **input_kwargs)
        else:
            py_output = op_impl(*input_args, **input_kwargs)

        return self._pack_infer_output(py_output, output_layout)


    def _extract_single_arg_layout(self, expanded_args, kwargs_value):
        """Extract cache-key parts, layouts, and scalar extras from expanded args and kwargs values."""
        cache_key_values = []
        input_layouts = []
        extra_args = []

        for arg in chain(expanded_args, kwargs_value):
            if arg is None:
                input_layouts.append(None)
                continue

            if not hasattr(arg, "_layout"):
                id_str = "scalar" if isinstance(arg, Tensor) else str(arg)
                cache_key_values.append(id_str)
                extra_args.append(arg)
                input_layouts.append(None)
            else:
                layout = arg.layout
                cache_key_values.append(str(layout.compact_str))
                input_layouts.append(layout)
        return cache_key_values, input_layouts, extra_args

    def _pack_infer_output(self, py_output, output_layout):
        """Helper to pack py_output into DTensors using output_layout."""
        if isinstance(py_output, (tuple, list)):
            if not isinstance(output_layout, (tuple, list)):
                raise RuntimeError("Output is a tuple but layout is not")
            if len(py_output) != len(output_layout):
                raise RuntimeError(f"Output tuple size ({len(py_output)}) "
                                   f"does not match layout tuple size ({len(output_layout)})")

            return tuple(
                DTensor.from_local(item, layout.mesh, layout.alias_placements)
                for item, layout in zip(py_output, output_layout)
            )

        return DTensor.from_local(py_output, output_layout.mesh, output_layout.alias_placements)


    def _with_layout_infer_with_tuple_expand(self, func: callable, *args, **kwargs) -> Tensor:
        """Dispatch path for ops whose positional args may contain tuples/lists of tensors."""
        expanded_args = []
        input_args = []
        for arg in args:
            if isinstance(arg, (tuple, list)):
                expanded_args.extend(arg)
                # pylint: disable=R1728
                input_args.append(tuple(item.to_local() if hasattr(item, "_layout") else item for item in arg))
            else:
                expanded_args.append(arg)
                input_args.append(arg.to_local() if isinstance(arg, DTensor) else arg)

        # Process kwargs into local tensors
        input_kwargs = {k: (v.to_local() if isinstance(v, DTensor) else v) for k, v in kwargs.items()}

        # Extract layouts for positional args and kwargs values
        cache_key_values, input_layouts, extra_args = self._extract_single_arg_layout(expanded_args, kwargs.values())

        cache_key = LayoutCacheKey(cache_key_values)

        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        func_name = platform.get_op_name(func)
        if func_name not in layout_cache:
            layout_cache[func_name] = {}

        op_layout_cache = layout_cache[func_name]
        distribute_op = cache_manager.distributed_op(func_name)

        if cache_key in op_layout_cache:
            output_layout, op_impl = op_layout_cache[cache_key]
        else:
            all_args = (input_layouts, extra_args)
            output_layout = distribute_op.infer_layout(*all_args)
            op_impl = distribute_op.get_expand_impl(func, output_layout, input_layouts, extra_args)
            op_layout_cache[cache_key] = (output_layout, op_impl)

        if op_impl is None:
            op_impl = func

        py_output = op_impl(*input_args, **input_kwargs)
        return distribute_op.wrap_output(py_output, output_layout)


    @staticmethod
    def _with_layout_infer_reshape(func: callable, *args) -> Tensor:
        """Dispatch path for Reshape: infer_layout returns (output_layout, local_shape)."""
        input_tensor = args[0]
        shape = args[1]

        layout = input_tensor.layout
        input_layouts = [layout]

        extra_args = [shape, input_tensor.shape]

        cache_key_values = [str(layout.compact_str), str(shape), str(input_tensor.shape)]
        cache_key = LayoutCacheKey(cache_key_values)

        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        func_name = platform.get_op_name(func)
        if func_name not in layout_cache:
            layout_cache[func_name] = {}

        op_layout_cache = layout_cache[func_name]

        distribute_op = cache_manager.distributed_op(func_name)
        if cache_key in op_layout_cache:
            infer_output, op_impl = op_layout_cache[cache_key]
        else:
            all_args = (input_layouts, extra_args)
            infer_output = distribute_op.infer_layout(*all_args)
            op_impl = distribute_op.get_expand_impl(func, infer_output, input_layouts, extra_args)
            op_layout_cache[cache_key] = (infer_output, op_impl)

        output_layout, local_shape = infer_output[0], infer_output[1]

        if op_impl is None:
            op_impl = func

        py_output = op_impl(input_tensor.to_local(), local_shape)

        return DTensor.from_local(py_output, output_layout.mesh, output_layout.alias_placements)


    @staticmethod
    def _process_args_and_kwargs_with_shape(args, kwargs):
        """Process args and kwargs with input shapes for WithShape suffix operators.

        Args:
            args: Positional arguments from dispatch.
            kwargs: Keyword arguments from dispatch.

        Returns:
            tuple: (input_layouts, input_shapes, extra_args, input_args, input_kwargs, cache_key_values)
        """
        input_layouts = []
        extra_args = []
        input_shapes = []
        input_args = []
        input_kwargs = kwargs.copy()
        cache_key_values = []

        for arg in args:
            if arg is None:
                input_layouts.append(None)
                input_shapes.append(None)
                input_args.append(arg)
                continue

            if not hasattr(arg, "_layout"):
                id_str = "scalar"
                if not isinstance(arg, Tensor):
                    id_str = str(arg)
                cache_key_values.append(id_str)
                extra_args.append(arg)
                input_layouts.append(None)
                input_args.append(arg)
            else:
                layout = arg.layout
                layout_id = layout.compact_str
                cache_key_values.append(str(layout_id))
                input_layouts.append(layout)
                if isinstance(arg, DTensor):
                    input_args.append(arg.to_local())
                else:
                    input_args.append(arg)

            if not hasattr(arg, "shape"):
                input_shapes.append(None)
            else:
                input_shape = arg.shape
                input_shapes.append(input_shape)
                cache_key_values.append(str(input_shape))

        for k, val in kwargs.items():
            if val is None:
                input_layouts.append(None)
                continue
            if not hasattr(val, "_layout"):
                id_str = "scalar"
                if not isinstance(val, Tensor):
                    id_str = str(val)
                cache_key_values.append(id_str)
                extra_args.append(val)
                input_layouts.append(None)
            else:
                layout = val.layout
                layout_id = layout.compact_str
                cache_key_values.append(str(layout_id))
                input_layouts.append(layout)
                if isinstance(val, DTensor):
                    input_kwargs[k] = val.to_local()

            if not hasattr(val, "shape"):
                input_shapes.append(None)
            else:
                input_shape = val.shape
                input_shapes.append(input_shape)
                cache_key_values.append(str(input_shape))

        return input_layouts, input_shapes, extra_args, input_args, input_kwargs, cache_key_values


    def _with_layout_infer_with_shape(self, func: callable, *args, **kwargs) -> Tensor:
        """Dispatch path for ops whose layout inference also needs the global input shapes."""
        func_name = platform.get_op_name(func)
        packed_call = None
        # Packed fallback args for some ops (e.g. Mod: (prim_obj, "Mod", (x, y))).
        if (func_name in self.unpack_ops and len(args) == 3 and
                isinstance(args[1], str) and isinstance(args[2], (tuple, list))):
            packed_call = (args[0], args[1])
            args = tuple(args[2])

        (input_layouts, input_shapes, extra_args, input_args,
         input_kwargs, cache_key_values) = OpDispatcher._process_args_and_kwargs_with_shape(args, kwargs)
        cache_key = LayoutCacheKey(cache_key_values)

        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        if func_name not in layout_cache:
            layout_cache[func_name] = {}

        op_layout_cache = layout_cache[func_name]

        distribute_op = cache_manager.distributed_op(func_name)
        if cache_key in op_layout_cache:
            output_layout, op_impl = op_layout_cache[cache_key]
        else:
            extra_args.append(input_shapes)
            all_args = (input_layouts, extra_args)
            output_layout = distribute_op.infer_layout(*all_args)
            op_impl = distribute_op.get_expand_impl(func, output_layout, input_layouts, extra_args)
            op_layout_cache[cache_key] = (output_layout, op_impl)

        if op_impl is None:
            op_impl = func

        if packed_call is not None:
            py_output = op_impl(packed_call[0], packed_call[1], tuple(input_args), **input_kwargs)
        else:
            py_output = op_impl(*input_args, **input_kwargs)

        # set output layout
        return self._pack_infer_output(py_output, output_layout)


    def _with_layout_infer_slice(self, func: callable, *args) -> Tensor:
        """Dispatch path for Slice: infer_layout returns (output_layout, new_begin, new_end)."""
        input_tensor = args[0]
        begin = args[1]
        end = args[2]

        # input layout
        layout = input_tensor.layout
        global_shape = input_tensor.shape
        input_layouts = [layout]

        extra_args = [begin, end, global_shape]
        cache_key_values = [str(layout.compact_str), str(begin), str(end), str(global_shape)]
        cache_key = LayoutCacheKey(cache_key_values)

        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        func_name = platform.get_op_name(func)
        if func_name not in layout_cache:
            layout_cache[func_name] = {}

        op_layout_cache = layout_cache[func_name]

        distribute_op = cache_manager.distributed_op(func_name)
        if cache_key in op_layout_cache:
            infer_output, op_impl = op_layout_cache[cache_key]
        else:
            all_args = (input_layouts, extra_args)
            infer_output = distribute_op.infer_layout(*all_args)
            op_impl = distribute_op.get_expand_impl(func, infer_output, input_layouts, extra_args)
            op_layout_cache[cache_key] = (infer_output, op_impl)

        output_layout, new_begin, new_end = infer_output[0], infer_output[1], infer_output[2]

        if op_impl is None:
            op_impl = func

        py_output = op_impl(input_tensor.to_local(), new_begin, new_end)

        return DTensor.from_local(py_output, output_layout.mesh, output_layout.alias_placements)

    @staticmethod
    def _merge_default(config: dict):
        """Apply __default__ values to all ops in this YAML file."""
        if "__default__" not in config:
            return config

        default_cfg = config["__default__"]
        merged = {}

        for op_name, op_cfg in config.items():
            if op_name == "__default__":
                continue

            new_cfg = default_cfg.copy()
            new_cfg.update(op_cfg)
            merged[op_name] = new_cfg

        return merged
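
    # Merge sketch (field names follow _register_single_distributed_op; the op
    # names are illustrative):
    #
    #     {"__default__": {"distributed_op_file": "math_ops"},
    #      "Add": {"distributed_op_class": "DistributedAdd"}}
    # becomes
    #     {"Add": {"distributed_op_file": "math_ops",
    #              "distributed_op_class": "DistributedAdd"}}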


    def safe_load_yaml_from_dir(self) -> dict:
        """
        Load yaml dictionary from directory.

        Returns:
            dict: Merged dictionary of all operator configurations loaded from YAML files.
        """
        yaml_dict = {}
        yaml_path = os.path.join(self.work_dir, self.yaml_dir) if self.work_dir else self.yaml_dir
        if not os.path.isdir(yaml_path):
            raise ValueError(f"Invalid yaml directory path: {yaml_path}")

        for yaml_file_path in glob.glob(os.path.join(yaml_path, '*.yaml')):
            with open(yaml_file_path, 'r', encoding="utf-8") as f:
                yaml_data = yaml.safe_load(f)

            yaml_data = OpDispatcher._merge_default(yaml_data)
            for name, data in yaml_data.items():
                if name in yaml_dict:
                    raise ValueError(f"Duplicate yaml object with name '{name}'.")
                yaml_dict[name] = data

        return yaml_dict
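
    # Shape of one YAML entry, inferred from the fields this module reads
    # (names are placeholders; 'infer_layout_suffix' is optional and must match
    # a key of self._suffix_dispatch when present):
    #
    #     Transpose:
    #       distributed_op_class: DistributedTranspose
    #       distributed_op_file: transpose_ops
    #       infer_layout_suffix: WithShape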


    def _dispatch_random_op(self, op_name: str, op_call: callable, args, kwargs):
        """Handle dispatch for random ops that operate on DTensors."""
        first_arg = next(
            (x for x in chain(args, kwargs.values()) if isinstance(x, DTensor)),
            None,
        )
        # Fall back to the default op if no DTensor is found.
        if first_arg is None:
            return op_call(*args, **kwargs)

        local_args = [arg.to_local() if isinstance(arg, DTensor) else arg for arg in args]
        local_kwargs = {k: v.to_local() if isinstance(v, DTensor) else v for k, v in kwargs.items()}
        first_local_arg = first_arg.to_local()

        if self._rng_tracker is None and is_rng_supported_mesh():
            self._rng_tracker = OffsetBasedRNGTracker()

        maybe_user_generator = local_kwargs.pop("generator", None)
        if (
            self._rng_tracker is not None
            and not first_local_arg.is_meta
            and self._rng_tracker.distribute_region_enabled
        ):
            # pylint: disable=W0212
            with self._rng_tracker._distribute_region(
                device_mesh=first_arg.device_mesh,
                placements=first_arg.placements,
                global_shape=first_arg.shape,
                generator=maybe_user_generator,
            ):
                # MindSpore random ops (e.g. mint.randn_like) extract (seed, offset)
                # from default_generator._step() in the Python wrapper *before* the
                # C++ dispatch triggers __fallback__. The callback reuses these
                # pre-fetched tensor args, so set_rng_state inside _distribute_region
                # has no effect on the kernel. Fix: apply the per-shard offset
                # increment directly to the offset tensor in the args.
                if platform.platform_type == PlatformType.MINDSPORE:
                    offset_incr = self._rng_tracker.compute_offset_incr(
                        first_arg.device_mesh, first_arg.placements, first_arg.shape,
                    )
                    local_args = _apply_shard_offset_to_rng_args(local_args, offset_incr)
                local_results = op_call(*local_args, **local_kwargs)
        else:
            local_results = op_call(*local_args, **local_kwargs)

        # In-place ops mutate the local tensor, so return the original DTensor.
        if op_name.endswith('_'):
            return first_arg
        # Non-in-place ops:
        # some ops return tuple/list, e.g. native_dropout returns (output, mask).
        if isinstance(local_results, (tuple, list)):
            return tuple(
                DTensor.from_local(r, first_arg.device_mesh, first_arg.layout.alias_placements)
                if isinstance(r, Tensor) else r
                for r in local_results
            )
        if isinstance(local_results, Tensor):
            return DTensor.from_local(local_results, first_arg.device_mesh, first_arg.layout.alias_placements)
        # Fallback: return as-is for non-Tensor results (currently unreachable with existing _random_ops).
        return local_results


    @staticmethod
    def _unwrap_args(args: tuple) -> list:
        """Strip DTensor wrappers from args, preserving tuple/list container structure.

        Args:
            args: Op call positional arguments, may contain DTensor instances.

        Returns:
            List of args with DTensor replaced by their local tensors.
        """
        def unwrap(arg: object) -> object:
            """Replace DTensor with its local tensor; pass scalars and plain tensors through.

            Args:
                arg (object): An element of the operator's argument list.

            Returns:
                object: The local tensor if arg is a DTensor, otherwise arg unchanged.
            """
            if isinstance(arg, DTensor):
                return arg.to_local()
            if isinstance(arg, tuple):
                return tuple(e.to_local() if isinstance(e, DTensor) else e for e in arg)
            if isinstance(arg, list):
                return [e.to_local() if isinstance(e, DTensor) else e for e in arg]
            return arg
        return [unwrap(arg) for arg in args]
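
    # Behavior sketch (dt is a hypothetical DTensor, t a plain local tensor):
    #
    #     _unwrap_args((dt, 3, (dt, t), [dt]))
    #     # -> [dt.to_local(), 3, (dt.to_local(), t), [dt.to_local()]]
    #
    # Containers are unwrapped one level deep; deeper nesting passes through.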


    def _should_bypass_dispatch(self, op_name: str) -> bool:
        """Return True if the op should bypass DTensor dispatch and run locally.

        Args:
            op_name: Canonical operator name from platform.get_op_name().

        Returns:
            True when the op is whitelisted, or when DTensor dispatch is globally
            disabled and the op is not in the no-skip set.
        """
        skip_dispatch = get_dtensor_dispatch() is False and op_name not in get_no_skip_ops()
        return op_name in self.whitelist or skip_dispatch


    def _dispatch_layout_infer(
        self, op_name: str, op_call: callable, args: tuple, kwargs: dict
    ):
        """Dispatch an op through the layout-inference path.

        Args:
            op_name: Canonical operator name.
            op_call: The raw operator callable.
            args: Positional arguments for op_call.
            kwargs: Keyword arguments for op_call.

        Returns:
            Result of the layout-infer dispatch.

        Raises:
            RuntimeError: If op_name is not registered or has an unknown suffix.
        """
        if op_name not in self.layout_infer_ops:
            raise RuntimeError(f"Operator {op_name} does not contain parallel layout infer func.")

        cache_manager = LayoutCacheManager.get_instance()
        distribute_op = cache_manager.distributed_op(op_name)

        result = distribute_op.preprocess(args, kwargs)
        if result is not None:
            return self._dispatch_new(op_call, distribute_op, result)

        suffix = self.layout_infer_ops[op_name].get('infer_layout_suffix', '')
        if not suffix:
            return self._with_layout_infer(op_call, *args, **kwargs)

        handler_name = self._suffix_dispatch.get(suffix)
        if handler_name is None:
            raise RuntimeError(f"Operator {op_name} specified an unknown infer_layout_suffix in the parallel yaml.")
        return getattr(self, handler_name)(op_call, *args, **kwargs)


    def _dispatch_new(self, func, distribute_op, result) -> Tensor:
        """New dispatch flow using preprocess result.

        Args:
            func: Original function.
            distribute_op: Distributed operation instance.
            result: Preprocessed result (local_args, local_kwargs, cache_values).

        Returns:
            Tensor: Dispatched result as DTensor.
        """
        local_args, local_kwargs, cache_values = result
        cache_key = LayoutCacheKey.from_cache_values(cache_values)
        func_name = platform.get_op_name(func)
        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        if func_name not in layout_cache:
            layout_cache[func_name] = {}
        op_layout_cache = layout_cache[func_name]
        if cache_key in op_layout_cache:
            infer_result, op_impl = op_layout_cache[cache_key]
        else:
            infer_result = distribute_op.infer_layout(cache_values)
            op_impl = distribute_op.get_expand_impl(func, infer_result, cache_values)
            op_layout_cache[cache_key] = (infer_result, op_impl)
        output_layouts, _ = infer_result
        op_impl = func if op_impl is None else op_impl
        py_output = op_impl(*local_args, **local_kwargs)
        return distribute_op.wrap_output(py_output, output_layouts)


    def dispatch(self, op_call: callable, args: tuple, kwargs: dict) -> object:
        """Route an op call through the appropriate DTensor dispatch path.

        Args:
            op_call: The raw operator callable.
            args: Positional arguments for op_call.
            kwargs: Keyword arguments for op_call.

        Returns:
            Result of the dispatched op call.
        """
        op_name = platform.get_op_name(op_call)

        if self._should_bypass_dispatch(op_name):
            return op_call(*self._unwrap_args(args), **kwargs)

        if op_name in self._random_ops or op_name in self._random_ms_ops:
            return self._dispatch_random_op(op_name, op_call, args, kwargs)

        # Auto-register ops that were registered programmatically via DistributedOp
        # (e.g. through DFunction) without a corresponding YAML entry.
        if op_name not in self.layout_infer_ops and get_distributed_op(op_name) is not None:
            self.layout_infer_ops[op_name] = {}

        return self._dispatch_layout_infer(op_name, op_call, args, kwargs)


_OP_DISPATCHER = OpDispatcher()
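
# Routing summary for _OP_DISPATCHER.dispatch(op_call, args, kwargs), in order:
#   1. whitelisted op, or dispatch globally disabled and op not in the no-skip
#      set -> run op_call locally on unwrapped args;
#   2. random op (mint name or MindSpore kernel name) -> _dispatch_random_op;
#   3. everything else -> _dispatch_layout_infer (preprocess hook first, then
#      the suffix-selected layout-inference path).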