Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/dfunction.py: 89% (35 statements)


# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom distributed autograd function base class."""
from __future__ import annotations

from hyper_parallel.platform import get_platform
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.core.shard._op_dispatch import _OP_DISPATCHER

platform = get_platform()


class _LocalCallable:
    """Named callable wrapper that exposes the op name to both platform dispatchers.

    PyTorch's ``get_op_name`` inspects ``__name__``; MindSpore's inspects ``.name``.
    Setting both attributes here lets ``_OP_DISPATCHER`` look up the correct
    ``DistributedOp`` without modifying either platform implementation.

    Args:
        fn: The underlying callable to invoke.
        op_name: Canonical op name matching the registered ``DistributedOp``.
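
    Example (illustrative; ``"MyAdd"`` is an arbitrary op name)::

        wrapped = _LocalCallable(lambda x, y: x + y, "MyAdd")
        wrapped.__name__   # "MyAdd" -- read by the PyTorch-side dispatcher
        wrapped.name       # "MyAdd" -- read by the MindSpore-side dispatcher
        wrapped(1, 2)      # 3, forwarded to the underlying callable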

    """

    def __init__(self, fn: callable, op_name: str) -> None:
        self._fn = fn
        self.__name__ = op_name
        self.name = op_name

    def __call__(self, *args, **kwargs):
        return self._fn(*args, **kwargs)


class DFunction(platform.Function):
    """Base class for user-defined distributed autograd functions.

    Subclass this class and implement ``forward`` and ``backward`` as
    ``@staticmethod`` methods that operate on **local** tensors. To enable
    multi-device support, also set ``_op_name`` to a string that matches the
    ``op_name`` of a registered ``DistributedOp`` subclass.

    Dispatch behaviour:

    * **No DTensor inputs**: calls ``super().apply()`` directly, going straight
      into the platform autograd mechanism (single-device path).
    * **DTensor inputs with ``_op_name`` set**: delegates to
      ``_OP_DISPATCHER.dispatch``. The dispatcher extracts local tensors, calls
      ``DistributedOp.infer_layout`` to derive the output layout, invokes the
      local callable (which re-enters ``apply`` with plain tensors, triggering
      the single-device path), and wraps the result as a ``DTensor``.

    Note on non-tensor positional arguments:
        The legacy ``_with_layout_infer`` path does not forward non-tensor
        positional arguments to the local callable. Pass such values as keyword
        arguments, or implement ``preprocess()`` in your ``DistributedOp`` to
        take the ``_dispatch_new`` path, which preserves them in ``cache_values``.
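
        For example (an illustrative sketch; ``MyScaledOp`` and ``scale`` are
        hypothetical names, not part of this module)::

            MyScaledOp.apply(dt_x, dt_y, scale=2.0)  # keyword: reaches ``forward``
            MyScaledOp.apply(dt_x, dt_y, 2.0)        # positional non-tensor: dropped on the legacy path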

    Example::

        class MyAddDistOp(DistributedOp):
            def __init__(self):
                super().__init__("MyAdd")

            def infer_layout(self, layouts, extra_args=None):
                return layouts[0]

        MyAddDistOp()  # instantiation triggers registration

        class MyAdd(DFunction):
            _op_name = "MyAdd"

            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                return x + y

            @staticmethod
            def backward(ctx, grad):
                return grad, grad

        result = MyAdd.apply(x, y)                  # plain tensors → single-device
        result = MyAdd.apply(dtensor_x, dtensor_y)  # DTensors → distributed
    """

    _op_name: str = None

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """Override in subclass to define the forward computation on local tensors."""
        raise NotImplementedError("Subclasses must implement forward()")

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Override in subclass to define the backward computation on local tensors."""
        raise NotImplementedError("Subclasses must implement backward()")

    @classmethod
    def apply(cls, *args, **kwargs):
        """Execute the function, routing to distributed dispatch when DTensor inputs are present.

        Args:
            *args: Positional arguments forwarded to ``forward``. May mix
                ``DTensor`` and plain ``Tensor``.
            **kwargs: Keyword arguments forwarded to ``forward``.

        Returns:
            Output of the forward computation. Wrapped as ``DTensor`` when the
            distributed dispatch path is taken and the ``DistributedOp`` infers a
            valid output layout.

        Raises:
            ValueError: If ``DTensor`` inputs are detected but ``_op_name`` is not
                set on the subclass.
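
        Note:
            Only positional arguments are inspected for ``DTensor``; a ``DTensor``
            passed solely as a keyword argument takes the single-device path.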

        """
        has_dtensor = any(isinstance(a, DTensor) for a in args)
        if has_dtensor:
            if cls._op_name is None:
                raise ValueError(
                    f"{cls.__name__} received DTensor inputs but '_op_name' is not set. "
                    "Set '_op_name' on the subclass and register a matching DistributedOp."
                )
            return _OP_DISPATCHER.dispatch(cls._get_local_callable(), args, kwargs)
        return super().apply(*args, **kwargs)

    @classmethod
    def _get_local_callable(cls) -> _LocalCallable:
        """Return (and lazily create) the named local callable for this subclass.

        The callable re-enters ``cls.apply`` with plain tensors so the single-device
        autograd path is taken without recursion. Cached per subclass in
        ``cls.__dict__`` to avoid repeated object creation.
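        Checking ``cls.__dict__`` (rather than using ``hasattr``) keeps the cache
        per subclass: a wrapper cached on a parent class, which would carry the
        parent's ``_op_name``, is never reused by a child.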

        Returns:
            A ``_LocalCallable`` whose ``__name__`` and ``.name`` equal
            ``cls._op_name``, enabling ``platform.get_op_name`` to resolve the
            correct ``DistributedOp``.
        """
        if '_local_callable' not in cls.__dict__:
            def _local_fn(*a, **kw):
                # Called by _OP_DISPATCHER with extracted local (non-DTensor) tensors.
                # has_dtensor will be False here, so super().apply() is taken directly.
                return cls.apply(*a, **kw)

            cls._local_callable = _LocalCallable(_local_fn, cls._op_name)
        return cls._local_callable