Coverage for hyper_parallel / core / shard / _op_dispatch.py: 78%
441 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""_op_dispatch"""
16import os
17import sys
18import atexit
19import glob
20import importlib
21from typing import Any, List, Dict, Optional
22import yaml
24from hyper_parallel.core.shard.ops.parallel_ops_register import get_distributed_op
25from hyper_parallel.core.dtensor import DTensor
26from hyper_parallel.platform import get_platform
28platform = get_platform()
29Tensor = platform.Tensor
31_dtensor_dispatch = True
33def enable_dtensor_dispatch():
34 global _dtensor_dispatch
35 _dtensor_dispatch = True
37def disable_dtensor_dispatch():
38 global _dtensor_dispatch
39 _dtensor_dispatch = False
41def get_dtensor_dispatch():
42 return _dtensor_dispatch
45class LayoutCacheKey:
46 """
47 Layout cache key
48 """
49 def __init__(self, layout_ids: List[str]):
50 self.layout_ids = layout_ids
52 def __eq__(self, other):
53 if not isinstance(other, LayoutCacheKey):
54 return False
55 return self.layout_ids == other.layout_ids
57 def __hash__(self):
58 seed = 0
59 for id_str in self.layout_ids:
60 h = hash(id_str)
61 seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2)
62 return seed
64class LayoutCacheManager:
65 """
66 Cache layout in infer layout.
67 """
68 _instance = None
70 def __init__(self):
71 self.layout_cache: Dict[str, Dict[LayoutCacheKey, Any]] = {}
72 atexit.register(self.clear_cache)
74 @classmethod
75 def get_instance(cls):
76 if cls._instance is None:
77 cls._instance = LayoutCacheManager()
78 return cls._instance
80 def get_layout_cache(self) -> Dict[str, Dict[LayoutCacheKey, Any]]:
81 return self.layout_cache
83 def distributed_op(self, op_name: str) -> Any:
84 op = get_distributed_op(op_name)
85 return op
87 def clear_cache(self):
88 self.layout_cache.clear()
91class OpDispatcher:
92 """
93 OpDispatcher
94 """
95 def __init__(self):
96 self._env_yaml_dir: Optional[str] = os.environ.get("HYPER_PARALLEL_OPS_YAML_DIR")
97 self._env_python_path: Optional[str] = os.environ.get("HYPER_PARALLEL_OPS_PYTHON_PATH")
99 self._setup_paths_from_env()
101 self.layout_infer_ops = self.safe_load_yaml_from_dir()
102 self.whitelist = ["InplaceAddExt", "InplaceSubExt", "InplaceMul", "InplaceDiv", "typeof", "DistCommIsend",
103 "DistCommIrecv", "DistCommBroadcast", "DistCommAllReduce", "DistCommAllGather",
104 "DistCommReduceScatter", "requires_grad_", "item", "__get__", "__set__", "register_hook",
105 "is_complex", "chunk", "__bool__", "__len__", "__format__"]
107 # Ops requiring args unpacking for layout inference (packed as prim, name, real_args).
108 self.unpack_ops = ["ScatterUpdate", "Mod", "GatherNd"]
110 self._register_distributed_ops()
112 def _setup_paths_from_env(self):
113 self._setup_yaml_dir(self._env_yaml_dir)
114 self._extend_sys_path(self._env_python_path)
116 def _setup_yaml_dir(self, env_yaml_dir: Optional[str]):
117 """
118 Feature: Configure yaml_dir/work_dir for OpDispatcher
119 Description: Resolve the YAML directory used to load distributed op definitions.
120 If env_yaml_dir is an absolute path, use it directly; otherwise treat it
121 as a path relative to the project work_dir. If env_yaml_dir is not set,
122 fall back to the default 'shard/ops/yaml' under work_dir.
123 Expectation: self.yaml_dir and self.work_dir are set to valid values used later by
124 safe_load_yaml_from_dir(); no functional behavior is changed.
125 """
126 if env_yaml_dir:
127 if os.path.isabs(env_yaml_dir):
128 self.yaml_dir = env_yaml_dir
129 self.work_dir = ""
130 else:
131 self.work_dir = os.path.normpath(
132 os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
133 )
134 self.yaml_dir = env_yaml_dir
135 else:
136 self.yaml_dir = "shard/ops/yaml"
137 self.work_dir = os.path.normpath(
138 os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
139 )
141 def _extend_sys_path(self, env_python_path: Optional[str]):
142 if not env_python_path:
143 return
144 python_paths = env_python_path.split(":")
145 for path in python_paths:
146 if path and os.path.isdir(path) and path not in sys.path:
147 sys.path.insert(0, path)
149 def _register_distributed_ops(self):
150 for op_name, config in self.layout_infer_ops.items():
151 self._register_single_distributed_op(op_name, config)
153 def _register_single_distributed_op(self, op_name: str, config: dict):
154 """
155 Feature: Register a single distributed op implementation
156 Description: Import the distributed op class specified by config and instantiate it
157 with op_name to trigger registration in the distributed op registry.
158 Prefer 'distributed_op_module' when provided; otherwise import from
159 built-in module prefix 'hyper_parallel.core.shard.ops.' plus
160 'distributed_op_file'. If import fails and an external python path is
161 provided via env, fall back to importing 'distributed_op_file' directly.
162 Expectation: The distributed op class is imported and instantiated successfully,
163 or the original import error is raised; no functional behavior is changed.
164 """
165 class_name = config["distributed_op_class"]
167 if "distributed_op_module" in config:
168 module_name = config["distributed_op_module"]
169 module = importlib.import_module(module_name)
170 op_class = getattr(module, class_name)
171 _ = op_class(op_name)
172 return
174 module_file = config["distributed_op_file"]
175 try:
176 module_name = "hyper_parallel.core.shard.ops." + module_file
177 module = importlib.import_module(module_name)
178 op_class = getattr(module, class_name)
179 _ = op_class(op_name)
180 except (ModuleNotFoundError, ImportError):
181 if self._env_python_path:
182 module = importlib.import_module(module_file)
183 op_class = getattr(module, class_name)
184 _ = op_class(op_name)
185 else:
186 raise
188 def _process_args_and_kwargs(
189 self, args, kwargs, cache_key: "LayoutCacheKey"
190 ) -> tuple[list, list, list, dict]:
191 """_process_args_and_kwargs"""
192 # input_layouts contain prarmeters which have layout, extra_args contain other parameters
193 input_layouts = []
194 extra_args = []
195 # input_args are position prarmeters, input_kwargs are keyword parameters
196 input_args = []
197 input_kwargs = kwargs.copy()
199 # Normal ops pass real inputs directly (e.g. SumExt: args = (dtensor, axis: list, keep_dims: bool, dtype: None)).
200 for arg in args:
201 if arg is None:
202 input_layouts.append(None)
203 input_args.append(arg)
204 continue
206 if not hasattr(arg, "_layout"):
207 id_str = "scalar"
208 if not isinstance(arg, Tensor):
209 id_str = str(arg)
210 cache_key.layout_ids.append(id_str)
211 extra_args.append(arg)
212 input_layouts.append(None)
213 input_args.append(arg)
214 else:
215 layout = arg.layout
216 layout_id = layout.compact_str
217 cache_key.layout_ids.append(str(layout_id))
218 input_layouts.append(layout)
219 if isinstance(arg, DTensor):
220 input_args.append(arg.to_local())
221 else:
222 input_args.append(arg)
224 for k, val in kwargs.items():
225 if val is None:
226 input_layouts.append(None)
227 continue
228 if not hasattr(val, "_layout"):
229 id_str = "scalar"
230 if not isinstance(val, Tensor):
231 id_str = str(val)
232 cache_key.layout_ids.append(id_str)
233 extra_args.append(val)
234 input_layouts.append(None)
235 else:
236 layout = val.layout
237 layout_id = layout.compact_str
238 cache_key.layout_ids.append(str(layout_id))
239 input_layouts.append(layout)
240 if isinstance(val, DTensor):
241 input_kwargs[k] = val.to_local()
243 return input_layouts, extra_args, input_args, input_kwargs
245 def _with_layout_infer(self, func: callable, *args, **kwargs) -> Tensor:
246 """_with_layout_infer"""
247 func_name = platform.get_op_name(func)
248 packed_call = None
249 # Ops in unpack_ops use packed fallback args (e.g. ScatterUpdate: (prim_obj, op_name: str, (input_x, indices, updates))).
250 if(func_name in self.unpack_ops and len(args) == 3 and
251 isinstance(args[1], str) and isinstance(args[2],(tuple,list))):
252 packed_call = (args[0], args[1])
253 args = tuple(args[2])
255 cache_key = LayoutCacheKey([])
256 input_layouts, extra_args, input_args, input_kwargs = self._process_args_and_kwargs(
257 args, kwargs, cache_key
258 )
259 cache_manager = LayoutCacheManager.get_instance()
260 layout_cache = cache_manager.get_layout_cache()
261 if func_name not in layout_cache:
262 layout_cache[func_name] = {}
264 op_layout_cache = layout_cache[func_name]
266 distribute_op = cache_manager.distributed_op(func_name)
267 if cache_key in op_layout_cache:
268 output_layout, op_impl = op_layout_cache[cache_key]
269 else:
270 all_args = (input_layouts, extra_args)
271 output_layout = distribute_op.infer_layout(*all_args)
272 op_impl = getattr(
273 distribute_op, "get_expand_impl", lambda *args, **kwargs: None
274 )(func, output_layout, input_layouts, extra_args)
275 op_layout_cache[cache_key] = (output_layout, op_impl)
277 if op_impl is None:
278 op_impl = func
280 if packed_call is not None:
281 py_output = op_impl(packed_call[0], packed_call[1], tuple(input_args), **input_kwargs)
282 else:
283 py_output = op_impl(*input_args, **input_kwargs)
285 if isinstance(py_output, (tuple, list)):
286 output = ()
287 if isinstance(output_layout, (tuple, list)):
288 if len(py_output) == len(output_layout):
289 for i, output_item in enumerate(py_output):
290 output += (DTensor.from_local(output_item, output_layout[i].mesh, output_layout[i].placements),)
291 else:
292 raise RuntimeError(f"Output tuple size ({len(py_output)}) "
293 f"does not match layout tuple size ({len(output_layout)})")
294 else:
295 raise RuntimeError("Output is a tuple but layout is not")
296 return output
298 return DTensor.from_local(py_output, output_layout.mesh, output_layout.placements)
300 def _with_layout_infer_with_tuple_expand(self, func: callable, *args, **kwargs) -> Tensor:
301 """_with_layout_infer_with_tuple_expand"""
302 expanded_args = []
303 input_args = []
304 for arg in args:
305 if isinstance(arg, (tuple, list)):
306 expanded_args.extend(arg)
307 # pylint: disable=R1728
308 input_args.append(tuple(item.to_local() if hasattr(item, "_layout") else item for item in arg))
309 else:
310 expanded_args.append(arg)
311 input_args.append(arg.to_local() if isinstance(arg, DTensor) else arg)
313 cache_key = LayoutCacheKey([])
314 input_layouts = []
315 extra_args = []
317 for arg in expanded_args:
318 if arg is None:
319 input_layouts.append(None)
320 continue
322 if not hasattr(arg, "_layout"):
323 id_str = "scalar"
324 if not isinstance(arg, Tensor):
325 id_str = str(arg)
326 cache_key.layout_ids.append(id_str)
327 extra_args.append(arg)
328 input_layouts.append(None)
329 else:
330 layout = arg.layout
331 layout_id = layout.compact_str
332 cache_key.layout_ids.append(str(layout_id))
333 input_layouts.append(layout)
335 cache_manager = LayoutCacheManager.get_instance()
336 layout_cache = cache_manager.get_layout_cache()
337 func_name = platform.get_op_name(func)
338 if func_name not in layout_cache:
339 layout_cache[func_name] = {}
341 op_layout_cache = layout_cache[func_name]
343 distribute_op = cache_manager.distributed_op(func_name)
344 if cache_key in op_layout_cache:
345 output_layout, op_impl = op_layout_cache[cache_key]
346 else:
347 all_args = (input_layouts, extra_args)
348 output_layout = distribute_op.infer_layout(*all_args)
349 op_impl = distribute_op.get_expand_impl(func, output_layout, input_layouts, extra_args)
350 op_layout_cache[cache_key] = (output_layout, op_impl)
352 if op_impl is None:
353 op_impl = func
355 py_output = op_impl(*input_args, **kwargs)
357 if isinstance(py_output, (tuple, list)):
358 output = ()
359 if isinstance(output_layout, (tuple, list)):
360 if len(py_output) == len(output_layout):
361 for i, output_item in enumerate(py_output):
362 output += (DTensor.from_local(output_item, output_layout[i].mesh, output_layout[i].placements),)
363 else:
364 raise RuntimeError(f"Output tuple size ({len(py_output)}) "
365 f"does not match layout tuple size ({len(output_layout)})")
366 else:
367 raise RuntimeError("Output is a tuple but layout is not")
368 return output
370 return DTensor.from_local(py_output, output_layout.mesh, output_layout.placements)
372 def _with_layout_infer_reshape(self, func: callable, *args) -> Tensor:
373 """_with_layout_infer_reshape"""
374 input_tensor = args[0]
375 shape = args[1]
376 cache_key = LayoutCacheKey([])
377 input_layouts = []
379 layout = input_tensor.layout
380 input_layouts.append(layout)
381 layout_id = layout.compact_str
382 cache_key.layout_ids.append(str(layout_id))
384 extra_args = []
385 extra_args.append(shape)
386 cache_key.layout_ids.append(str(shape))
388 input_shape = input_tensor.shape
389 extra_args.append(input_shape)
390 cache_key.layout_ids.append(str(input_shape))
392 cache_manager = LayoutCacheManager.get_instance()
393 layout_cache = cache_manager.get_layout_cache()
394 func_name = platform.get_op_name(func)
395 if func_name not in layout_cache:
396 layout_cache[func_name] = {}
398 op_layout_cache = layout_cache[func_name]
400 distribute_op = cache_manager.distributed_op(func_name)
401 if cache_key in op_layout_cache:
402 infer_output, op_impl = op_layout_cache[cache_key]
403 else:
404 all_args = (input_layouts, extra_args)
405 infer_output = distribute_op.infer_layout(*all_args)
406 op_impl = distribute_op.get_expand_impl(func, infer_output, input_layouts, extra_args)
407 op_layout_cache[cache_key] = (infer_output, op_impl)
409 infer_output_tuple = infer_output
410 local_shape = infer_output_tuple[1]
412 if op_impl is None:
413 op_impl = func
415 py_output = op_impl(input_tensor.to_local(), local_shape)
417 return DTensor.from_local(py_output, infer_output_tuple[0].mesh, infer_output_tuple[0].placements)
419 def _process_args_and_kwargs_with_shape(
420 self, args, kwargs, cache_key: "LayoutCacheKey"
421 ) -> tuple[list, list, list, list, dict]:
422 """_process_args_and_kwargs_with_shape"""
423 input_layouts = []
424 extra_args = []
425 input_shapes = []
426 input_args = []
427 input_kwargs = kwargs.copy()
428 for arg in args:
429 if arg is None:
430 input_layouts.append(None)
431 input_shapes.append(None)
432 input_args.append(arg)
433 continue
435 if not hasattr(arg, "_layout"):
436 id_str = "scalar"
437 if not isinstance(arg, Tensor):
438 id_str = str(arg)
439 cache_key.layout_ids.append(id_str)
440 extra_args.append(arg)
441 input_layouts.append(None)
442 input_args.append(arg)
443 else:
444 layout = arg.layout
445 layout_id = layout.compact_str
446 cache_key.layout_ids.append(str(layout_id))
447 input_layouts.append(layout)
448 if isinstance(arg, DTensor):
449 input_args.append(arg.to_local())
450 else:
451 input_args.append(arg)
453 if not hasattr(arg, "shape"):
454 input_shapes.append(None)
455 else:
456 input_shape = arg.shape
457 input_shapes.append(input_shape)
458 cache_key.layout_ids.append(str(input_shape))
460 for k, val in kwargs.items():
461 if val is None:
462 input_layouts.append(None)
463 continue
464 if not hasattr(val, "_layout"):
465 id_str = "scalar"
466 if not isinstance(val, Tensor):
467 id_str = str(val)
468 cache_key.layout_ids.append(id_str)
469 extra_args.append(val)
470 input_layouts.append(None)
471 else:
472 layout = val.layout
473 layout_id = layout.compact_str
474 cache_key.layout_ids.append(str(layout_id))
475 input_layouts.append(layout)
476 if isinstance(val, DTensor):
477 input_kwargs[k] = val.to_local()
479 if not hasattr(val, "shape"):
480 input_shapes.append(None)
481 else:
482 input_shape = val.shape
483 input_shapes.append(input_shape)
484 cache_key.layout_ids.append(str(input_shape))
486 return input_layouts, input_shapes, extra_args, input_args, input_kwargs
488 def _with_layout_infer_with_shape(self, func: callable, *args, **kwargs) -> Tensor:
489 """_with_layout_infer_with_shape"""
490 func_name = platform.get_op_name(func)
491 packed_call = None
492 # Packed fallback args for some ops (e.g. Mod: (prim_obj, "Mod", (x, y))).
493 if (func_name in self.unpack_ops and len(args) == 3 and
494 isinstance(args[1], str) and isinstance(args[2], (tuple, list))):
495 packed_call = (args[0], args[1])
496 args = tuple(args[2])
498 cache_key = LayoutCacheKey([])
499 input_layouts, input_shapes, extra_args, input_args, input_kwargs = \
500 self._process_args_and_kwargs_with_shape(args, kwargs, cache_key)
502 cache_manager = LayoutCacheManager.get_instance()
503 layout_cache = cache_manager.get_layout_cache()
504 if func_name not in layout_cache:
505 layout_cache[func_name] = {}
507 op_layout_cache = layout_cache[func_name]
509 distribute_op = cache_manager.distributed_op(func_name)
510 if cache_key in op_layout_cache:
511 output_layout, op_impl = op_layout_cache[cache_key]
512 else:
513 extra_args.append(input_shapes)
514 all_args = (input_layouts, extra_args)
515 output_layout = distribute_op.infer_layout(*all_args)
516 op_impl = distribute_op.get_expand_impl(func, output_layout, input_layouts, extra_args)
517 op_layout_cache[cache_key] = (output_layout, op_impl)
519 if op_impl is None:
520 op_impl = func
522 if packed_call is not None:
523 py_output = op_impl(packed_call[0], packed_call[1], tuple(input_args), **input_kwargs)
524 else:
525 py_output = op_impl(*input_args, **input_kwargs)
527 # 设置输出布局
528 if isinstance(py_output, (tuple, list)):
529 output = ()
530 if isinstance(output_layout, (tuple, list)):
531 if len(py_output) == len(output_layout):
532 for i, output_item in enumerate(py_output):
533 output += (DTensor.from_local(output_item, output_layout[i].mesh, output_layout[i].placements),)
534 else:
535 raise RuntimeError(f"Output tuple size ({len(py_output)}) "
536 f"does not match layout tuple size ({len(output_layout)})")
537 else:
538 raise RuntimeError("Output is a tuple but layout is not")
539 return output
541 return DTensor.from_local(py_output, output_layout.mesh, output_layout.placements)
543 def _with_layout_infer_slice(self, func: callable, *args) -> Tensor:
544 """_with_layout_infer_slice"""
545 input_tensor = args[0]
546 begin = args[1]
547 end = args[2]
549 # 输入布局
550 cache_key = LayoutCacheKey([])
551 input_layouts = []
553 layout = input_tensor.layout
554 global_shape = input_tensor.shape
555 input_layouts.append(layout)
556 layout_id = layout.compact_str
557 cache_key.layout_ids.append(str(layout_id))
559 extra_args = []
560 extra_args.append(begin)
561 extra_args.append(end)
562 extra_args.append(global_shape)
563 cache_key.layout_ids.append(str(begin))
564 cache_key.layout_ids.append(str(end))
565 cache_key.layout_ids.append(str(global_shape))
567 cache_manager = LayoutCacheManager.get_instance()
568 layout_cache = cache_manager.get_layout_cache()
569 func_name = platform.get_op_name(func)
570 if func_name not in layout_cache:
571 layout_cache[func_name] = {}
573 op_layout_cache = layout_cache[func_name]
575 distribute_op = cache_manager.distributed_op(func_name)
576 if cache_key in op_layout_cache:
577 infer_output, op_impl = op_layout_cache[cache_key]
578 else:
579 all_args = (input_layouts, extra_args)
580 infer_output = distribute_op.infer_layout(*all_args)
581 op_impl = distribute_op.get_expand_impl(func, infer_output, input_layouts, extra_args)
582 op_layout_cache[cache_key] = (infer_output, op_impl)
584 infer_output_tuple = infer_output
585 new_begin = infer_output_tuple[1]
586 new_end = infer_output_tuple[2]
588 if op_impl is None:
589 op_impl = func
591 py_output = op_impl(input_tensor.to_local(), new_begin, new_end)
593 return DTensor.from_local(py_output, infer_output_tuple[0].mesh, infer_output_tuple[0].placements)
595 def _merge_default(self, config: dict):
596 """Apply __default__ values to all ops in this YAML file."""
597 if "__default__" not in config:
598 return config
600 default_cfg = config["__default__"]
601 merged = {}
603 for op_name, op_cfg in config.items():
604 if op_name == "__default__":
605 continue
607 new_cfg = default_cfg.copy()
608 new_cfg.update(op_cfg)
609 merged[op_name] = new_cfg
611 return merged
613 def safe_load_yaml_from_dir(self):
614 """
615 Load yaml dictionary from directory.
616 """
617 yaml_dict = {}
618 yaml_path = os.path.join(self.work_dir, self.yaml_dir) if self.work_dir else self.yaml_dir
619 if not os.path.isdir(yaml_path):
620 raise ValueError(f"Invalid yaml directory path: {yaml_path}")
622 for yaml_file_path in glob.glob(os.path.join(yaml_path, '*.yaml')):
623 with open(yaml_file_path, 'r', encoding="utf-8") as f:
624 yaml_data = yaml.safe_load(f)
626 yaml_data = self._merge_default(yaml_data)
627 for name, data in yaml_data.items():
628 if name in yaml_dict:
629 raise ValueError(f"Duplicate yaml object with name '{name}'.")
630 yaml_dict[name] = data
632 return yaml_dict
634 def dispatch(self, op_call: callable, args: tuple[object, ...], kwargs: dict[str, object]):
635 """
636 dispatch
637 :param op_call:
638 :param args:
639 :param kwargs:
640 :return:
641 """
642 op_name = platform.get_op_name(op_call)
643 if op_name in self.whitelist or get_dtensor_dispatch() is False:
644 input_args = [arg.to_local() if isinstance(arg, DTensor) else arg for arg in args]
645 return op_call(*input_args, **kwargs)
646 if op_name not in self.layout_infer_ops:
647 raise RuntimeError(f"Operator {op_name} dose not contain parallel layout infer func.")
649 layout_infer_info = self.layout_infer_ops[op_name]
650 suffix = layout_infer_info.get('infer_layout_suffix', '')
651 if not suffix:
652 return self._with_layout_infer(op_call, *args, **kwargs)
654 if suffix == "WithShape":
655 return self._with_layout_infer_with_shape(op_call, *args, **kwargs)
656 if suffix == "Reshape":
657 return self._with_layout_infer_reshape(op_call, *args)
658 if suffix == "WithTupleExpand":
659 return self._with_layout_infer_with_tuple_expand(op_call, *args, **kwargs)
660 if suffix == "Slice":
661 return self._with_layout_infer_slice(op_call, *args)
662 raise RuntimeError(f"Operator {op_name} specified wrong suffix in parallel yaml.")
664_OP_DISPATCHER = OpDispatcher()