Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / _op_dispatch.py: 31%
536 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""_op_dispatch"""
16import os
17import sys
18import atexit
19import glob
20import importlib
21from typing import Any, List, Dict, Optional, Set
22from itertools import chain
24import yaml
26from hyper_parallel.core.shard.ops.parallel_ops_register import get_distributed_op
27from hyper_parallel.core.dtensor.dtensor import DTensor
28from hyper_parallel.core.dtensor.random import OffsetBasedRNGTracker, is_rng_supported_mesh
29from hyper_parallel.platform import get_platform
30from hyper_parallel.platform.platform import PlatformType
# Resolve the active backend platform once at import time; ``Tensor`` is that
# backend's native tensor class, used for isinstance checks throughout.
platform = get_platform()
Tensor = platform.Tensor
def _apply_shard_offset_to_rng_args(args, offset_incr):
    """Apply per-shard offset increment to seed/offset tensors in MindSpore random op args.

    MindSpore random ops (e.g. ``randn_like_``) receive ``(seed, offset)`` as
    explicit int64 scalar tensors from ``default_generator._step()`` in the
    Python wrapper *before* the C++ dispatch triggers ``__fallback__``. By the
    time ``_dispatch_random_op`` is called, the kernel will use whatever
    ``(seed, offset)`` values are in the args—it does **not** read the
    generator again. This function finds the offset tensor and adds the
    per-rank offset increment so each shard gets a unique random stream.

    The (seed, offset) pair is identified as the last two consecutive int64
    0-dim tensors in *args* (scanning from the end to skip trailing dtype /
    device arguments).

    Args:
        args: The list of local args for the random op.
        offset_incr (int): Per-shard offset increment.

    Returns:
        list: Modified args with the offset tensor adjusted.
    """
    i64 = platform.tensor_dtype.int64

    def _is_scalar_i64(candidate):
        # 0-dim int64 tensors are the (seed, offset) carriers.
        return isinstance(candidate, Tensor) and candidate.dtype == i64 and candidate.ndim == 0

    prev_match = -1
    for idx in range(len(args) - 1, -1, -1):
        if not _is_scalar_i64(args[idx]):
            continue
        if prev_match == idx + 1:
            # Two consecutive matches: idx is the seed, idx + 1 the offset.
            adjusted = list(args)
            bumped = int(adjusted[prev_match].item()) + offset_incr
            adjusted[prev_match] = platform.tensor([bumped], dtype=i64).reshape(())
            return adjusted
        prev_match = idx
    # No (seed, offset) pair found: leave args untouched.
    return args
# Module-level dispatch switches: `_dtensor_dispatch` gates whether ops go
# through the DTensor dispatcher at all; `_no_skip_ops` lists op names that
# must still be dispatched while the global switch is off.
_dtensor_dispatch = True
_no_skip_ops: Set[str] = set()
def get_no_skip_ops() -> Set[str]:
    """Get the op names that are exempt from SkipDTensorDispatch.

    Returns:
        Set[str]: Names that are always dispatched through DTensor even
        while global dispatch is disabled.
    """
    return _no_skip_ops
def add_no_skip_ops(op_names: Set[str]) -> None:
    """Register op names that must always go through DTensor dispatch.

    Args:
        op_names: Canonical op name strings to add to the no-skip set.
    """
    global _no_skip_ops
    # Rebind to a fresh set so previously handed-out references stay stable.
    _no_skip_ops = _no_skip_ops.union(op_names)
def remove_no_skip_ops(op_names: Set[str]) -> None:
    """Drop op names from the no-skip set.

    Args:
        op_names: Canonical op name strings to remove.
    """
    global _no_skip_ops
    # Rebind to a fresh set so previously handed-out references stay stable.
    _no_skip_ops = _no_skip_ops.difference(op_names)
def enable_dtensor_dispatch() -> None:
    """Turn DTensor dispatch on.

    While enabled, tensor operations are routed through the distributed
    operator dispatcher for layout inference and redistribution.
    """
    global _dtensor_dispatch
    _dtensor_dispatch = True
def disable_dtensor_dispatch() -> None:
    """Turn DTensor dispatch off.

    While disabled, tensor operations bypass the distributed operator
    dispatcher and use native implementations directly.
    """
    global _dtensor_dispatch
    _dtensor_dispatch = False
def get_dtensor_dispatch() -> bool:
    """Report whether DTensor dispatch is currently enabled.

    Returns:
        bool: True when ops are routed through the DTensor dispatcher.
    """
    return _dtensor_dispatch
class LayoutCacheKey:
    """Immutable, hashable layout cache key.

    Wraps a sequence of layout identifier strings in a tuple and precomputes
    the hash so repeated dictionary lookups stay cheap.
    """
    __slots__ = ('_tuple', '_hash')

    def __init__(self, layout_ids: List[str]):
        # Freeze the ids and cache the hash once.
        self._tuple = tuple(layout_ids)
        self._hash = hash(self._tuple)

    @classmethod
    def from_cache_values(cls, cache_values: list) -> "LayoutCacheKey":
        """Build a LayoutCacheKey from a cache_values list.

        Args:
            cache_values (list): Mixed list of Layout objects (with compact_str) and raw scalars.

        Returns:
            LayoutCacheKey: Immutable key derived from the string representation of each value.
        """
        key_values = [
            str(v.compact_str) if hasattr(v, 'compact_str') else str(v)
            for v in cache_values
        ]
        return cls(key_values)

    def __eq__(self, other):
        # Fix: return NotImplemented (not False) for foreign types so Python
        # can try the reflected comparison; `key == other` still evaluates to
        # False when the other side doesn't handle it.
        if not isinstance(other, LayoutCacheKey):
            return NotImplemented
        return self._tuple == other._tuple

    def __hash__(self):
        return self._hash

    def __repr__(self):
        return f"LayoutCacheKey({self._tuple})"
class LayoutCacheManager:
    """
    Cache layout in infer layout.

    Process-wide singleton holding one cache per operation name; each entry
    maps a LayoutCacheKey to the inferred output layout and the expanded op
    implementation, so repeated calls with the same input layouts skip
    re-inference.
    """
    _instance = None

    def __init__(self):
        # op name -> {cache key -> (layout, impl)}; cleared at interpreter exit.
        self.layout_cache: Dict[str, Dict[LayoutCacheKey, Any]] = {}
        atexit.register(self.clear_cache)

    @classmethod
    def get_instance(cls) -> "LayoutCacheManager":
        """Return the process-wide singleton, creating it on first use.

        Returns:
            LayoutCacheManager: The singleton instance.
        """
        if cls._instance is None:
            cls._instance = LayoutCacheManager()
        return cls._instance

    def get_layout_cache(self) -> Dict[str, Dict[LayoutCacheKey, Any]]:
        """Return the mutable nested cache dictionary (op name -> per-op cache).

        Returns:
            Dict[str, Dict[LayoutCacheKey, Any]]: The nested cache mapping.
        """
        return self.layout_cache

    def distributed_op(self, op_name: str) -> Any:
        """Look up the registered distributed implementation for *op_name*.

        Args:
            op_name (str): The name of the distributed operation.

        Returns:
            Any: The distributed operation class or implementation.
        """
        return get_distributed_op(op_name)

    def clear_cache(self) -> None:
        """Drop every cached layout; registered with atexit for shutdown."""
        self.layout_cache.clear()
class OpDispatcher:
    """
    OpDispatcher

    Routes tensor op calls through the DTensor machinery: loads distributed-op
    configs from YAML, registers their implementations, and (in ``dispatch``)
    decides per call whether an op bypasses dispatch, takes the random-op
    path, or goes through layout inference.
    """
    def __init__(self):
        # Optional environment overrides for the YAML config dir / extra op code.
        self._env_yaml_dir: Optional[str] = os.environ.get("HYPER_PARALLEL_OPS_YAML_DIR")
        self._env_python_path: Optional[str] = os.environ.get("HYPER_PARALLEL_OPS_PYTHON_PATH")
        # The following attributes are initialized in _setup_yaml_dir()
        self.work_dir = ""  # Initialized in _setup_yaml_dir()
        self.yaml_dir = ""  # Initialized in _setup_yaml_dir()

        self._setup_paths_from_env()

        # op name -> YAML config for every op with layout-inference support.
        self.layout_infer_ops = self.safe_load_yaml_from_dir()
        # Ops that always bypass DTensor dispatch and run on local tensors.
        self.whitelist = ["InplaceAddExt", "InplaceSubExt", "InplaceMul", "InplaceDiv", "typeof", "DistCommIsend",
                          "DistCommIrecv", "DistCommBroadcast", "DistCommAllReduce", "DistCommAllGather",
                          "DistCommReduceScatter", "requires_grad_", "item", "__get__", "__set__", "register_hook",
                          "is_complex", "chunk", "__bool__", "__len__", "__format__", "dim",
                          "_has_compatible_shallow_copy_type", "is_floating_point", "is_contiguous"]

        # Ops requiring args unpacking for layout inference (packed as prim, name, real_args).
        self.unpack_ops = ["ScatterUpdate", "Mod", "GatherNd"]

        # Random ops addressed by their Python-level API name.
        self._random_ops = {
            "normal_", "uniform_", "bernoulli", "bernoulli_",
            "native_dropout", "rand", "rand_like", "randn",
            "randn_like", "randint_like", "kaiming_uniform_",
        }
        # Only mint random op support
        # MindSpore use the actual kernel name.
        self._random_ms_ops = {
            "BernoulliExt", "MultinomialExt", "InplaceNormal", "InplaceUniform",
            "RandpermExt", "Randn", "RandLikeExt", "RandnLike", "RandInt", "RandIntLike", "RandExt",
            "FuncDropoutExt"
        }
        # Lazily created per-shard RNG tracker (see _dispatch_random_op).
        self._rng_tracker: Optional[OffsetBasedRNGTracker] = None

        # YAML 'infer_layout_suffix' value -> handler method name.
        self._suffix_dispatch: Dict[str, str] = {
            "WithShape": "_with_layout_infer_with_shape",
            "Reshape": "_with_layout_infer_reshape",
            "WithTupleExpand": "_with_layout_infer_with_tuple_expand",
            "Slice": "_with_layout_infer_slice",
        }

        self._register_distributed_ops()
275 def _setup_paths_from_env(self):
276 """
277 Setup YAML directory and Python path from environment variables.
279 This method initializes the YAML directory and extends sys.path based on
280 environment variables HYPER_PARALLEL_OPS_YAML_DIR and HYPER_PARALLEL_OPS_PYTHON_PATH.
281 """
282 self._setup_yaml_dir(self._env_yaml_dir)
283 self._extend_sys_path(self._env_python_path)
285 def _setup_yaml_dir(self, env_yaml_dir: Optional[str]):
286 """
287 Feature: Configure yaml_dir/work_dir for OpDispatcher
288 Description: Resolve the YAML directory used to load distributed op definitions.
289 If env_yaml_dir is an absolute path, use it directly; otherwise treat it
290 as a path relative to the project work_dir. If env_yaml_dir is not set,
291 fall back to the default 'shard/ops/yaml' under work_dir.
292 Expectation: self.yaml_dir and self.work_dir are set to valid values used later by
293 safe_load_yaml_from_dir(); no functional behavior is changed.
294 """
295 if env_yaml_dir:
296 if os.path.isabs(env_yaml_dir):
297 self.yaml_dir = env_yaml_dir
298 self.work_dir = ""
299 else:
300 self.work_dir = os.path.normpath(
301 os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
302 )
303 self.yaml_dir = env_yaml_dir
304 else:
305 self.yaml_dir = "shard/ops/yaml"
306 self.work_dir = os.path.normpath(
307 os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
308 )
310 def _extend_sys_path(self, env_python_path: Optional[str]):
311 if not env_python_path:
312 return
313 python_paths = env_python_path.split(":")
314 for path in python_paths:
315 if path and os.path.isdir(path) and path not in sys.path:
316 sys.path.insert(0, path)
318 def _register_distributed_ops(self):
319 for op_name, config in self.layout_infer_ops.items():
320 self._register_single_distributed_op(op_name, config)
322 def _register_single_distributed_op(self, op_name: str, config: dict):
323 """
324 Feature: Register a single distributed op implementation
325 Description: Import the distributed op class specified by config and instantiate it
326 with op_name to trigger registration in the distributed op registry.
327 Prefer 'distributed_op_module' when provided; otherwise import from
328 built-in module prefix 'hyper_parallel.core.shard.ops.' plus
329 'distributed_op_file'. If import fails and an external python path is
330 provided via env, fall back to importing 'distributed_op_file' directly.
331 Expectation: The distributed op class is imported and instantiated successfully,
332 or the original import error is raised; no functional behavior is changed.
333 """
334 class_name = config["distributed_op_class"]
336 if "distributed_op_module" in config:
337 module_name = config["distributed_op_module"]
338 module = importlib.import_module(module_name)
339 op_class = getattr(module, class_name)
340 _ = op_class(op_name)
341 return
343 module_file = config["distributed_op_file"]
344 try:
345 module_name = "hyper_parallel.core.shard.ops." + module_file
346 module = importlib.import_module(module_name)
347 op_class = getattr(module, class_name)
348 _ = op_class(op_name)
349 except (ModuleNotFoundError, ImportError):
350 if self._env_python_path:
351 module = importlib.import_module(module_file)
352 op_class = getattr(module, class_name)
353 _ = op_class(op_name)
354 else:
355 raise
    @staticmethod
    def _process_args_and_kwargs(
            args, kwargs
    ) -> tuple[list, list, list, dict, list]:
        """Split op arguments into layouts, local tensors and cache-key parts.

        Args:
            args: Positional arguments, possibly containing DTensors.
            kwargs: Keyword arguments, possibly containing DTensors.

        Returns:
            tuple: (input_layouts, extra_args, input_args, input_kwargs,
            cache_key_values). ``input_args`` / ``input_kwargs`` mirror the
            inputs with DTensors replaced by local tensors; ``input_layouts``
            holds one entry per argument (None for values without a layout);
            ``extra_args`` collects the non-layout values; and
            ``cache_key_values`` the strings hashed into the layout cache key.
        """
        input_layouts = []
        extra_args = []
        input_args = []
        input_kwargs = kwargs.copy()
        cache_key_values = []

        for arg in args:
            if arg is None:
                # None carries no layout and never contributes to the cache key.
                input_layouts.append(None)
                input_args.append(arg)
                continue

            if not hasattr(arg, "_layout"):
                # Plain tensors collapse to the constant "scalar" key part;
                # non-tensor scalars key on their string value.
                id_str = "scalar"
                if not isinstance(arg, Tensor):
                    id_str = str(arg)
                cache_key_values.append(id_str)
                extra_args.append(arg)
                input_layouts.append(None)
                input_args.append(arg)
            else:
                layout = arg.layout
                layout_id = layout.compact_str
                cache_key_values.append(str(layout_id))
                input_layouts.append(layout)
                if isinstance(arg, DTensor):
                    input_args.append(arg.to_local())
                else:
                    input_args.append(arg)

        # kwargs follow the same rules; locals are substituted in-place in
        # the copied dict so keyword names are preserved.
        for k, val in kwargs.items():
            if val is None:
                input_layouts.append(None)
                continue
            if not hasattr(val, "_layout"):
                id_str = "scalar"
                if not isinstance(val, Tensor):
                    id_str = str(val)
                cache_key_values.append(id_str)
                extra_args.append(val)
                input_layouts.append(None)
            else:
                layout = val.layout
                layout_id = layout.compact_str
                cache_key_values.append(str(layout_id))
                input_layouts.append(layout)
                if isinstance(val, DTensor):
                    input_kwargs[k] = val.to_local()

        return input_layouts, extra_args, input_args, input_kwargs, cache_key_values
413 def _with_layout_infer(self, func: callable, *args, **kwargs) -> Tensor:
414 """_with_layout_infer"""
415 func_name = platform.get_op_name(func)
416 packed_call = None
417 if(func_name in self.unpack_ops and len(args) == 3 and
418 isinstance(args[1], str) and isinstance(args[2],(tuple,list))):
419 packed_call = (args[0], args[1])
420 args = tuple(args[2])
422 input_layouts, extra_args, input_args, input_kwargs, cache_key_values = \
423 OpDispatcher._process_args_and_kwargs(args, kwargs)
424 cache_key = LayoutCacheKey(cache_key_values)
425 cache_manager = LayoutCacheManager.get_instance()
426 layout_cache = cache_manager.get_layout_cache()
427 if func_name not in layout_cache:
428 layout_cache[func_name] = {}
430 op_layout_cache = layout_cache[func_name]
432 distribute_op = cache_manager.distributed_op(func_name)
433 if cache_key in op_layout_cache:
434 output_layout, op_impl = op_layout_cache[cache_key]
435 else:
436 all_args = (input_layouts, extra_args)
437 output_layout = distribute_op.infer_layout(*all_args)
438 op_impl = distribute_op.get_expand_impl(func, output_layout, input_layouts, extra_args)
439 op_layout_cache[cache_key] = (output_layout, op_impl)
441 if op_impl is None:
442 op_impl = func
444 if packed_call is not None:
445 py_output = op_impl(packed_call[0], packed_call[1], tuple(input_args), **input_kwargs)
446 else:
447 py_output = op_impl(*input_args, **input_kwargs)
449 if isinstance(py_output, (tuple, list)):
450 output = ()
451 if isinstance(output_layout, (tuple, list)):
452 if len(py_output) == len(output_layout):
453 for i, output_item in enumerate(py_output):
454 output += (DTensor.from_local(
455 output_item, output_layout[i].mesh,
456 output_layout[i].alias_placements),)
457 else:
458 raise RuntimeError(f"Output tuple size ({len(py_output)}) "
459 f"does not match layout tuple size ({len(output_layout)})")
460 else:
461 raise RuntimeError("Output is a tuple but layout is not")
462 return output
464 return DTensor.from_local(
465 py_output, output_layout.mesh, output_layout.alias_placements)
467 def _extract_single_arg_layout(self, expanded_args, kwargs_value):
468 """Helper to extract layout and cache info for a single argument."""
469 cache_key_values = []
470 input_layouts = []
471 extra_args = []
473 for arg in chain(expanded_args, kwargs_value):
474 if arg is None:
475 input_layouts.append(None)
476 continue
478 if not hasattr(arg, "_layout"):
479 id_str = "scalar" if isinstance(arg, Tensor) else str(arg)
480 cache_key_values.append(id_str)
481 extra_args.append(arg)
482 input_layouts.append(None)
483 else:
484 layout = arg.layout
485 cache_key_values.append(str(layout.compact_str))
486 input_layouts.append(layout)
487 return cache_key_values, input_layouts, extra_args
489 def _pack_infer_output(self, py_output, output_layout):
490 """Helper to pack py_output into DTensors using output_layout."""
491 if isinstance(py_output, (tuple, list)):
492 if not isinstance(output_layout, (tuple, list)):
493 raise RuntimeError("Output is a tuple but layout is not")
494 if len(py_output) != len(output_layout):
495 raise RuntimeError(f"Output tuple size ({len(py_output)}) "
496 f"does not match layout tuple size ({len(output_layout)})")
498 return tuple(
499 DTensor.from_local(item, layout.mesh, layout.alias_placements)
500 for item, layout in zip(py_output, output_layout)
501 )
503 return DTensor.from_local(py_output, output_layout.mesh, output_layout.alias_placements)
    def _with_layout_infer_with_tuple_expand(self, func: callable, *args, **kwargs) -> Tensor:
        """Layout-infer path for ops whose args include tuples/lists of tensors.

        Tuple/list elements are flattened for layout inference, while the
        actual call keeps the original container structure (with DTensors
        replaced by local tensors).
        """
        expanded_args = []  # flat view used only for layout extraction
        input_args = []     # structure-preserving args with locals substituted
        for arg in args:
            if isinstance(arg, (tuple, list)):
                expanded_args.extend(arg)
                # pylint: disable=R1728
                input_args.append(tuple(item.to_local() if hasattr(item, "_layout") else item for item in arg))
            else:
                expanded_args.append(arg)
                input_args.append(arg.to_local() if isinstance(arg, DTensor) else arg)

        # Process kwargs into local tensors
        input_kwargs = {k: (v.to_local() if isinstance(v, DTensor) else v) for k, v in kwargs.items()}

        # Extract layouts for positional args
        cache_key_values, input_layouts, extra_args = self._extract_single_arg_layout(expanded_args, kwargs.values())

        cache_key = LayoutCacheKey(cache_key_values)

        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        func_name = platform.get_op_name(func)
        if func_name not in layout_cache:
            layout_cache[func_name] = {}

        op_layout_cache = layout_cache[func_name]
        distribute_op = cache_manager.distributed_op(func_name)

        if cache_key in op_layout_cache:
            output_layout, op_impl = op_layout_cache[cache_key]
        else:
            all_args = (input_layouts, extra_args)
            output_layout = distribute_op.infer_layout(*all_args)
            op_impl = distribute_op.get_expand_impl(func, output_layout, input_layouts, extra_args)
            op_layout_cache[cache_key] = (output_layout, op_impl)

        if op_impl is None:
            op_impl = func

        py_output = op_impl(*input_args, **input_kwargs)
        # The distributed op decides how to wrap its own (possibly structured) output.
        return distribute_op.wrap_output(py_output, output_layout)
    @staticmethod
    def _with_layout_infer_reshape(func: callable, *args) -> Tensor:
        """Layout-infer path for reshape-style ops.

        Args:
            func: Raw reshape callable.
            *args: (input_tensor, shape) where ``shape`` is the *global*
                target shape.

        Returns:
            Tensor: DTensor wrapping the locally reshaped output.
        """
        input_tensor = args[0]
        shape = args[1]

        layout = input_tensor.layout
        input_layouts = [layout]

        # Inference needs both the requested global shape and the current one.
        extra_args = [shape, input_tensor.shape]

        cache_key_values = [str(layout.compact_str), str(shape), str(input_tensor.shape)]
        cache_key = LayoutCacheKey(cache_key_values)

        cache_manager = LayoutCacheManager.get_instance()
        layout_cache = cache_manager.get_layout_cache()
        func_name = platform.get_op_name(func)
        if func_name not in layout_cache:
            layout_cache[func_name] = {}

        op_layout_cache = layout_cache[func_name]

        distribute_op = cache_manager.distributed_op(func_name)
        if cache_key in op_layout_cache:
            infer_output, op_impl = op_layout_cache[cache_key]
        else:
            all_args = (input_layouts, extra_args)
            infer_output = distribute_op.infer_layout(*all_args)
            op_impl = distribute_op.get_expand_impl(func, infer_output, input_layouts, extra_args)
            op_layout_cache[cache_key] = (infer_output, op_impl)

        # For reshape, infer_layout returns a tuple: (output_layout, local_shape).
        infer_output_tuple = infer_output
        local_shape = infer_output_tuple[1]

        if op_impl is None:
            op_impl = func

        # Run the local reshape with the shard-local target shape.
        py_output = op_impl(input_tensor.to_local(), local_shape)

        return DTensor.from_local(py_output, infer_output_tuple[0].mesh, infer_output_tuple[0].alias_placements)
    @staticmethod
    def _process_args_and_kwargs_with_shape(args, kwargs):
        """Process args and kwargs with input shapes for WithShape suffix operators.

        Same contract as ``_process_args_and_kwargs`` but additionally records
        each layout-bearing argument's global shape (both in ``input_shapes``
        and in the cache key, since the inferred layout may depend on shape).

        Args:
            args: Positional arguments from dispatch.
            kwargs: Keyword arguments from dispatch.

        Returns:
            tuple: (input_layouts, input_shapes, extra_args, input_args, input_kwargs, cache_key_values)
        """
        input_layouts = []
        extra_args = []
        input_shapes = []
        input_args = []
        input_kwargs = kwargs.copy()
        cache_key_values = []

        for arg in args:
            if arg is None:
                input_layouts.append(None)
                input_shapes.append(None)
                input_args.append(arg)
                continue

            if not hasattr(arg, "_layout"):
                # Plain tensors collapse to "scalar"; other scalars key on str().
                id_str = "scalar"
                if not isinstance(arg, Tensor):
                    id_str = str(arg)
                cache_key_values.append(id_str)
                extra_args.append(arg)
                input_layouts.append(None)
                input_args.append(arg)
            else:
                layout = arg.layout
                layout_id = layout.compact_str
                cache_key_values.append(str(layout_id))
                input_layouts.append(layout)
                if isinstance(arg, DTensor):
                    input_args.append(arg.to_local())
                else:
                    input_args.append(arg)

            # Shapes are tracked for every non-None arg so inference can see
            # global shapes; shape also contributes to the cache key.
            if not hasattr(arg, "shape"):
                input_shapes.append(None)
            else:
                input_shape = arg.shape
                input_shapes.append(input_shape)
                cache_key_values.append(str(input_shape))

        for k, val in kwargs.items():
            if val is None:
                input_layouts.append(None)
                continue
            if not hasattr(val, "_layout"):
                id_str = "scalar"
                if not isinstance(val, Tensor):
                    id_str = str(val)
                cache_key_values.append(id_str)
                extra_args.append(val)
                input_layouts.append(None)
            else:
                layout = val.layout
                layout_id = layout.compact_str
                cache_key_values.append(str(layout_id))
                input_layouts.append(layout)
                if isinstance(val, DTensor):
                    input_kwargs[k] = val.to_local()

            if not hasattr(val, "shape"):
                input_shapes.append(None)
            else:
                input_shape = val.shape
                input_shapes.append(input_shape)
                cache_key_values.append(str(input_shape))

        return input_layouts, input_shapes, extra_args, input_args, input_kwargs, cache_key_values
668 def _with_layout_infer_with_shape(self, func: callable, *args, **kwargs) -> Tensor:
669 """_with_layout_infer_with_shape"""
670 func_name = platform.get_op_name(func)
671 packed_call = None
672 # Packed fallback args for some ops (e.g. Mod: (prim_obj, "Mod", (x, y))).
673 if (func_name in self.unpack_ops and len(args) == 3 and
674 isinstance(args[1], str) and isinstance(args[2], (tuple, list))):
675 packed_call = (args[0], args[1])
676 args = tuple(args[2])
678 (input_layouts, input_shapes, extra_args, input_args,
679 input_kwargs, cache_key_values) = OpDispatcher._process_args_and_kwargs_with_shape(args, kwargs)
680 cache_key = LayoutCacheKey(cache_key_values)
682 cache_manager = LayoutCacheManager.get_instance()
683 layout_cache = cache_manager.get_layout_cache()
684 if func_name not in layout_cache:
685 layout_cache[func_name] = {}
687 op_layout_cache = layout_cache[func_name]
689 distribute_op = cache_manager.distributed_op(func_name)
690 if cache_key in op_layout_cache:
691 output_layout, op_impl = op_layout_cache[cache_key]
692 else:
693 extra_args.append(input_shapes)
694 all_args = (input_layouts, extra_args)
695 output_layout = distribute_op.infer_layout(*all_args)
696 op_impl = distribute_op.get_expand_impl(func, output_layout, input_layouts, extra_args)
697 op_layout_cache[cache_key] = (output_layout, op_impl)
699 if op_impl is None:
700 op_impl = func
702 if packed_call is not None:
703 py_output = op_impl(packed_call[0], packed_call[1], tuple(input_args), **input_kwargs)
704 else:
705 py_output = op_impl(*input_args, **input_kwargs)
707 # set output layout
708 if isinstance(py_output, (tuple, list)):
709 output = ()
710 if isinstance(output_layout, (tuple, list)):
711 if len(py_output) == len(output_layout):
712 for i, output_item in enumerate(py_output):
713 output += (DTensor.from_local(
714 output_item, output_layout[i].mesh,
715 output_layout[i].alias_placements),)
716 else:
717 raise RuntimeError(f"Output tuple size ({len(py_output)}) "
718 f"does not match layout tuple size ({len(output_layout)})")
719 else:
720 raise RuntimeError("Output is a tuple but layout is not")
721 return output
723 return DTensor.from_local(
724 py_output, output_layout.mesh, output_layout.alias_placements)
726 def _with_layout_infer_slice(self, func: callable, *args) -> Tensor:
727 """_with_layout_infer_slice"""
728 input_tensor = args[0]
729 begin = args[1]
730 end = args[2]
732 # input layout
733 input_layouts = []
735 layout = input_tensor.layout
736 global_shape = input_tensor.shape
737 input_layouts.append(layout)
738 layout_id = layout.compact_str
740 extra_args = []
741 extra_args.append(begin)
742 extra_args.append(end)
743 extra_args.append(global_shape)
744 cache_key_values = [str(layout_id), str(begin), str(end), str(global_shape)]
745 cache_key = LayoutCacheKey(cache_key_values)
747 cache_manager = LayoutCacheManager.get_instance()
748 layout_cache = cache_manager.get_layout_cache()
749 func_name = platform.get_op_name(func)
750 if func_name not in layout_cache:
751 layout_cache[func_name] = {}
753 op_layout_cache = layout_cache[func_name]
755 distribute_op = cache_manager.distributed_op(func_name)
756 if cache_key in op_layout_cache:
757 infer_output, op_impl = op_layout_cache[cache_key]
758 else:
759 all_args = (input_layouts, extra_args)
760 infer_output = distribute_op.infer_layout(*all_args)
761 op_impl = distribute_op.get_expand_impl(func, infer_output, input_layouts, extra_args)
762 op_layout_cache[cache_key] = (infer_output, op_impl)
764 infer_output_tuple = infer_output
765 new_begin = infer_output_tuple[1]
766 new_end = infer_output_tuple[2]
768 if op_impl is None:
769 op_impl = func
771 py_output = op_impl(input_tensor.to_local(), new_begin, new_end)
773 return DTensor.from_local(py_output, infer_output_tuple[0].mesh, infer_output_tuple[0].alias_placements)
775 @staticmethod
776 def _merge_default(config: dict):
777 """Apply __default__ values to all ops in this YAML file."""
778 if "__default__" not in config:
779 return config
781 default_cfg = config["__default__"]
782 merged = {}
784 for op_name, op_cfg in config.items():
785 if op_name == "__default__":
786 continue
788 new_cfg = default_cfg.copy()
789 new_cfg.update(op_cfg)
790 merged[op_name] = new_cfg
792 return merged
794 def safe_load_yaml_from_dir(self) -> dict:
795 """
796 Load yaml dictionary from directory.
798 Returns:
799 dict: Merged dictionary of all operator configurations loaded from YAML files.
800 """
801 yaml_dict = {}
802 yaml_path = os.path.join(self.work_dir, self.yaml_dir) if self.work_dir else self.yaml_dir
803 if not os.path.isdir(yaml_path):
804 raise ValueError(f"Invalid yaml directory path: {yaml_path}")
806 for yaml_file_path in glob.glob(os.path.join(yaml_path, '*.yaml')):
807 with open(yaml_file_path, 'r', encoding="utf-8") as f:
808 yaml_data = yaml.safe_load(f)
810 yaml_data = OpDispatcher._merge_default(yaml_data)
811 for name, data in yaml_data.items():
812 if name in yaml_dict:
813 raise ValueError(f"Duplicate yaml object with name '{name}'.")
814 yaml_dict[name] = data
816 return yaml_dict
    def _dispatch_random_op(self, op_name: str, op_call: callable, args, kwargs):
        """Handle dispatch for random ops that operate on DTensors.

        The first DTensor argument supplies the mesh/placements used both to
        seed the per-shard RNG region and to rebuild DTensor outputs.

        Args:
            op_name: Canonical op name; a trailing underscore marks it in-place.
            op_call: The raw random-op callable.
            args: Positional arguments, possibly containing DTensors.
            kwargs: Keyword arguments, possibly containing DTensors.

        Returns:
            The op result; tensor outputs are rebuilt as DTensors from the
            first DTensor argument's mesh and alias placements.
        """
        first_arg = next(
            (x for x in chain(args, kwargs.values()) if isinstance(x, DTensor)),
            None,
        )
        # Fall back to the default op if no DTensor is found.
        if first_arg is None:
            return op_call(*args, **kwargs)

        local_args = [arg.to_local() if isinstance(arg, DTensor) else arg for arg in args]
        local_kwargs = {k: v.to_local() if isinstance(v, DTensor) else v for k, v in kwargs.items()}
        first_local_arg = first_arg.to_local()

        # Create the shard-aware RNG tracker lazily, only on supported meshes.
        if self._rng_tracker is None and is_rng_supported_mesh():
            self._rng_tracker = OffsetBasedRNGTracker()

        # A user-supplied generator is consumed by the RNG region, not the kernel.
        maybe_user_generator = local_kwargs.pop("generator", None)
        if (
            self._rng_tracker is not None
            and not first_local_arg.is_meta
            and self._rng_tracker.distribute_region_enabled
        ):
            # pylint: disable=W0212
            with self._rng_tracker._distribute_region(
                device_mesh=first_arg.device_mesh,
                placements=first_arg.placements,
                global_shape=first_arg.shape,
                generator=maybe_user_generator,
            ):
                # MindSpore random ops (e.g. mint.randn_like) extract (seed, offset)
                # from default_generator._step() in the Python wrapper *before* the
                # C++ dispatch triggers __fallback__. The callback reuses these
                # pre-fetched tensor args, so set_rng_state inside _distribute_region
                # has no effect on the kernel. Fix: apply the per-shard offset
                # increment directly to the offset tensor in the args.
                if platform.platform_type == PlatformType.MINDSPORE:
                    offset_incr = self._rng_tracker.compute_offset_incr(
                        first_arg.device_mesh, first_arg.placements, first_arg.shape,
                    )
                    local_args = _apply_shard_offset_to_rng_args(local_args, offset_incr)
                local_results = op_call(*local_args, **local_kwargs)
        else:
            local_results = op_call(*local_args, **local_kwargs)

        # in-place ops
        # NOTE(review): in-place detection relies on the trailing-underscore
        # convention; MindSpore kernel names such as "InplaceNormal" do not end
        # with '_' — confirm in-place MS kernels are handled correctly here.
        if op_name.endswith('_'):
            return first_arg
        # non-in-place ops
        # Some ops return tuple/list, e.g. native_dropout returns (output, mask).
        if isinstance(local_results, (tuple, list)):
            return tuple(
                DTensor.from_local(r, first_arg.device_mesh, first_arg.layout.alias_placements)
                if isinstance(r, Tensor) else r
                for r in local_results
            )
        if isinstance(local_results, Tensor):
            return DTensor.from_local(local_results, first_arg.device_mesh, first_arg.layout.alias_placements)
        # Fallback: return as-is for non-Tensor results (currently unreachable with existing _random_ops).
        return local_results
879 @staticmethod
880 def _unwrap_args(args: tuple) -> list:
881 """Strip DTensor wrappers from args, preserving tuple/list container structure.
883 Args:
884 args: Op call positional arguments, may contain DTensor instances.
886 Returns:
887 List of args with DTensor replaced by their local tensors.
888 """
889 def unwrap(arg: object) -> object:
890 """Replace DTensor with its local tensor; pass scalars and plain tensors through.
892 Args:
893 arg (object): An element of the operator's argument list.
895 Returns:
896 object: The local tensor if arg is a DTensor, otherwise arg unchanged.
897 """
898 if isinstance(arg, DTensor):
899 return arg.to_local()
900 if isinstance(arg, tuple):
901 return tuple(e.to_local() if isinstance(e, DTensor) else e for e in arg)
902 if isinstance(arg, list):
903 return [e.to_local() if isinstance(e, DTensor) else e for e in arg]
904 return arg
905 return [unwrap(arg) for arg in args]
907 def _should_bypass_dispatch(self, op_name: str) -> bool:
908 """Return True if the op should bypass DTensor dispatch and run locally.
910 Args:
911 op_name: Canonical operator name from platform.get_op_name().
913 Returns:
914 True when the op is whitelisted or DTensor dispatch is globally disabled.
915 """
916 skip_dispatch = get_dtensor_dispatch() is False and op_name not in get_no_skip_ops()
917 return op_name in self.whitelist or skip_dispatch
919 def _dispatch_layout_infer(
920 self, op_name: str, op_call: callable, args: tuple, kwargs: dict
921 ):
922 """Dispatch an op through the layout-inference path.
924 Args:
925 op_name: Canonical operator name.
926 op_call: The raw operator callable.
927 args: Positional arguments for op_call.
928 kwargs: Keyword arguments for op_call.
930 Returns:
931 Result of the layout-infer dispatch.
933 Raises:
934 RuntimeError: If op_name is not registered or has an unknown suffix.
935 """
936 if op_name not in self.layout_infer_ops:
937 raise RuntimeError(f"Operator {op_name} does not contain parallel layout infer func.")
939 cache_manager = LayoutCacheManager.get_instance()
940 distribute_op = cache_manager.distributed_op(op_name)
942 result = distribute_op.preprocess(args, kwargs)
943 if result is not None:
944 return self._dispatch_new(op_call, distribute_op, result)
946 suffix = self.layout_infer_ops[op_name].get('infer_layout_suffix', '')
947 if not suffix:
948 return self._with_layout_infer(op_call, *args, **kwargs)
950 handler_name = self._suffix_dispatch.get(suffix)
951 if handler_name is None:
952 raise RuntimeError(f"Operator {op_name} specified wrong suffix in parallel yaml.")
953 return getattr(self, handler_name)(op_call, *args, **kwargs)
955 def _dispatch_new(self, func, distribute_op, result) -> Tensor:
956 """New dispatch flow using preprocess result.
958 Args:
959 func: Original function.
960 distribute_op: Distributed operation instance.
961 result: Preprocessed result (local_args, local_kwargs, cache_values).
963 Returns:
964 Tensor: Dispatched result as DTensor.
965 """
966 local_args, local_kwargs, cache_values = result
967 cache_key = LayoutCacheKey.from_cache_values(cache_values)
968 func_name = platform.get_op_name(func)
969 cache_manager = LayoutCacheManager.get_instance()
970 layout_cache = cache_manager.get_layout_cache()
971 if func_name not in layout_cache:
972 layout_cache[func_name] = {}
973 op_layout_cache = layout_cache[func_name]
974 if cache_key in op_layout_cache:
975 infer_result, op_impl = op_layout_cache[cache_key]
976 else:
977 infer_result = distribute_op.infer_layout(cache_values)
978 op_impl = distribute_op.get_expand_impl(func, infer_result, cache_values)
979 op_layout_cache[cache_key] = (infer_result, op_impl)
980 output_layouts, _ = infer_result
981 op_impl = func if op_impl is None else op_impl
982 py_output = op_impl(*local_args, **local_kwargs)
983 return distribute_op.wrap_output(py_output, output_layouts)
    def dispatch(self, op_call: callable, args: tuple, kwargs: dict) -> object:
        """Route an op call through the appropriate DTensor dispatch path.

        Routing order: whitelist/global-off bypass -> random-op path ->
        layout-inference path.

        Args:
            op_call: The raw operator callable.
            args: Positional arguments for op_call.
            kwargs: Keyword arguments for op_call.

        Returns:
            Result of the dispatched op call.
        """
        op_name = platform.get_op_name(op_call)

        # Bypassed ops run natively on local tensors (kwargs are passed through
        # unchanged; only positional args are unwrapped here).
        if self._should_bypass_dispatch(op_name):
            return op_call(*self._unwrap_args(args), **kwargs)

        if op_name in self._random_ops or op_name in self._random_ms_ops:
            return self._dispatch_random_op(op_name, op_call, args, kwargs)

        # Auto-register ops that were registered programmatically via DistributedOp
        # (e.g. through DFunction) without a corresponding YAML entry.
        if op_name not in self.layout_infer_ops and get_distributed_op(op_name) is not None:
            self.layout_infer_ops[op_name] = {}

        return self._dispatch_layout_infer(op_name, op_call, args, kwargs)
# Module-level singleton; presumably consumed by the platform's dispatch
# fallback hook — confirm against callers. Constructing it loads the YAML
# configs and registers all distributed ops at import time.
_OP_DISPATCHER = OpDispatcher()