Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/dfunction.py: 89%
35 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Custom distributed autograd function base class."""
from __future__ import annotations

from collections.abc import Callable

from hyper_parallel.platform import get_platform
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.core.shard._op_dispatch import _OP_DISPATCHER
22platform = get_platform()
25class _LocalCallable:
26 """Named callable wrapper that exposes the op name to both platform dispatchers.
28 PyTorch's ``get_op_name`` inspects ``__name__``; MindSpore's inspects ``.name``.
29 Setting both attributes here lets ``_OP_DISPATCHER`` look up the correct
30 ``DistributedOp`` without modifying either platform implementation.
32 Args:
33 fn: The underlying callable to invoke.
34 op_name: Canonical op name matching the registered ``DistributedOp``.
35 """
37 def __init__(self, fn: callable, op_name: str) -> None:
38 self._fn = fn
39 self.__name__ = op_name
40 self.name = op_name
42 def __call__(self, *args, **kwargs):
43 return self._fn(*args, **kwargs)
class DFunction(platform.Function):
    """Base class for user-defined distributed autograd functions.

    Subclass this class and implement ``forward`` and ``backward`` as
    ``@staticmethod`` methods that operate on **local** tensors. To enable
    multi-device support, also set ``_op_name`` to a string that matches the
    ``op_name`` of a registered ``DistributedOp`` subclass.

    Dispatch behaviour:

    * **No DTensor inputs** — calls ``super().apply()`` directly, going straight
      into the platform autograd mechanism (single-device path).
    * **DTensor inputs + ``_op_name`` set** — delegates to
      ``_OP_DISPATCHER.dispatch``. The dispatcher extracts local tensors, calls
      ``DistributedOp.infer_layout`` to derive the output layout, invokes the
      local callable (which re-enters ``apply`` with plain tensors, triggering
      the single-device path), and wraps the result as a ``DTensor``.

    Note on non-tensor positional arguments:
        The legacy ``_with_layout_infer`` path does not forward non-tensor
        positional arguments to the local callable. Pass such values as keyword
        arguments, or implement ``preprocess()`` in your ``DistributedOp`` to
        take the ``_dispatch_new`` path which preserves them in ``cache_values``.

    Example::

        class MyAddDistOp(DistributedOp):
            def __init__(self):
                super().__init__("MyAdd")

            def infer_layout(self, layouts, extra_args=None):
                return layouts[0]

        MyAddDistOp()  # instantiation triggers registration

        class MyAdd(DFunction):
            _op_name = "MyAdd"

            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                return x + y

            @staticmethod
            def backward(ctx, grad):
                return grad, grad

        result = MyAdd.apply(x, y)                  # plain tensors → single-device
        result = MyAdd.apply(dtensor_x, dtensor_y)  # DTensors → distributed
    """

    # Canonical op name of the registered DistributedOp for this subclass;
    # None disables the distributed dispatch path (annotation fixed from
    # ``str`` to ``str | None`` to match the default).
    _op_name: str | None = None

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """Override in subclass to define the forward computation on local tensors."""
        raise NotImplementedError("Subclasses must implement forward()")

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Override in subclass to define the backward computation on local tensors."""
        raise NotImplementedError("Subclasses must implement backward()")

    @classmethod
    def apply(cls, *args, **kwargs):
        """Execute the function, routing to distributed dispatch when DTensor inputs are present.

        Args:
            *args: Positional arguments forwarded to ``forward``. May mix
                ``DTensor`` and plain ``Tensor``.
            **kwargs: Keyword arguments forwarded to ``forward``.

        Returns:
            Output of the forward computation. Wrapped as ``DTensor`` when the
            distributed dispatch path is taken and the ``DistributedOp`` infers a
            valid output layout.

        Raises:
            ValueError: If ``DTensor`` inputs are detected but ``_op_name`` is not
                set on the subclass.
        """
        # NOTE(review): only positional args are scanned — a DTensor passed as a
        # keyword argument would silently take the single-device path. Confirm
        # this matches the dispatcher's contract before changing.
        has_dtensor = any(isinstance(a, DTensor) for a in args)
        if has_dtensor:
            if cls._op_name is None:
                raise ValueError(
                    f"{cls.__name__} received DTensor inputs but '_op_name' is not set. "
                    "Set '_op_name' on the subclass and register a matching DistributedOp."
                )
            return _OP_DISPATCHER.dispatch(cls._get_local_callable(), args, kwargs)
        # Single-device path: hand straight to the platform autograd machinery.
        return super().apply(*args, **kwargs)

    @classmethod
    def _get_local_callable(cls) -> _LocalCallable:
        """Return (and lazily create) the named local callable for this subclass.

        The callable re-enters ``cls.apply`` with plain tensors so the single-device
        autograd path is taken without recursion. Cached per subclass in
        ``cls.__dict__`` to avoid repeated object creation.

        Returns:
            A ``_LocalCallable`` whose ``__name__`` and ``.name`` equal
            ``cls._op_name``, enabling ``platform.get_op_name`` to resolve the
            correct ``DistributedOp``.
        """
        # Check cls.__dict__ rather than hasattr() so each subclass caches its
        # own callable instead of inheriting (and reusing) a parent's.
        if '_local_callable' not in cls.__dict__:
            def _local_fn(*a, **kw):
                # Called by _OP_DISPATCHER with extracted local (non-DTensor) tensors.
                # has_dtensor will be False here, so super().apply() is taken directly.
                return cls.apply(*a, **kw)

            cls._local_callable = _LocalCallable(_local_fn, cls._op_name)
        return cls._local_callable