Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/autograd_compat.py: 70%
118 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore backward-style autograd compatibility helpers."""
# pylint: disable=protected-access,import-outside-toplevel
from __future__ import annotations

import warnings

from mindspore import ops
from mindspore._c_expression import TensorPy, pyboost_detach, run_backward
from mindspore._c_expression import typing
from mindspore.graph.api import _pynative_executor

_BACKWARD_COMPAT_ENABLED = False


@property
def requires_grad(self):
    """Return whether the tensor requires gradient."""
    return self._requires_grad


@requires_grad.setter
def requires_grad(self, value=True):
    if not isinstance(value, bool):
        raise TypeError("The argument `requires_grad` must be a bool")
    self._requires_grad = value


@property
def grad(self):
    """Return the current accumulated gradient."""
    if not self.is_leaf and self.requires_grad:
        warnings.warn(
            "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. "
            "Its .grad attribute won't be populated during autograd.backward(). "
            "If you indeed want the .grad field to be populated for a non-leaf Tensor, "
            "use .retain_grad() on the non-leaf Tensor.",
            stacklevel=2,
        )
    dtensor_grad = getattr(self, "_dtensor_grad", None)
    if dtensor_grad is not None:
        return dtensor_grad
    return self._grad


@grad.setter
def grad(self, value):
    """Assign a gradient, mirroring DTensor values onto the local tensor."""
    try:
        from hyper_parallel.core.dtensor.dtensor import DTensor
    except ImportError:
        DTensor = ()

    if value is None:
        # Clear both the distributed and the local gradient views.
        self._dtensor_grad = None
        self._grad = None
        return

    if DTensor and isinstance(value, DTensor):
        # Keep the DTensor for .grad reads, but store its local shard so the
        # backward engine can accumulate into a plain tensor.
        self._dtensor_grad = value
        self._grad = value._local_tensor
        return

    self._dtensor_grad = None
    self._grad = value
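

# Illustrative sketch (not part of the original module) of how the setter
# above handles the distributed case. ``t``, ``dt``, and ``plain`` are
# hypothetical; a DTensor is assumed to expose a ``_local_tensor`` shard,
# which is exactly what the setter unpacks:
#
#   t.grad = dt       # t._dtensor_grad is dt, t._grad is dt._local_tensor
#   t.grad = None     # clears both views
#   t.grad = plain    # plain tensor: only t._grad is set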


@property
def is_leaf(self):
    """Return whether the tensor is a leaf."""
    return self._is_leaf


@property
def retains_grad(self):
    """Return whether the tensor retains gradient."""
    return self._retains_grad


@property
def grad_fn(self):
    """Return the grad node that produced this tensor, or None for leaves."""
    if self._grad_node and self._grad_node.is_leaf():
        return None
    return self._grad_node


@property
def output_nr(self):
    """Return this tensor's output index on its producing grad node."""
    return self._output_index


def retain_grad(self):
    """Mark the tensor so its gradient is retained during backward."""
    return self._retain_grad()


def detach(self):
    """Return a tensor detached from the autograd graph."""
    detached = pyboost_detach(self)
    detached._dtensor_grad = None
    return detached


def _is_same_size(output, grad_tensor):
    """Return whether ``output`` and ``grad_tensor`` have the same shape."""
    return tuple(output.shape) == tuple(grad_tensor.shape)


def _calculate_shape(output, grad_tensor):
    """Return both shapes, for use in error messages."""
    return output.shape, grad_tensor.shape


def _tensor_or_tensors_to_tuple(tensors, length):
    """Normalize a tensor, a sequence of tensors, or None into a tuple."""
    if tensors is None:
        return (None,) * length
    if isinstance(tensors, TensorPy):
        return (tensors,)
    return tuple(tensors)


def _make_grads(outputs, grads):
    """Validate backward gradients and materialize implicit scalar grads."""
    new_grads = []
    for index, (out, grad_tensor) in enumerate(zip(outputs, grads)):
        if isinstance(grad_tensor, TensorPy):
            if not _is_same_size(out, grad_tensor):
                out_shape, grad_shape = _calculate_shape(out, grad_tensor)
                raise RuntimeError(
                    f"Mismatch in shape: grad_output[{index}] has a shape of "
                    f"{grad_shape} and output[{index}] has a shape of {out_shape}."
                )
            if out.dtype.is_complex != grad_tensor.dtype.is_complex:
                raise RuntimeError(
                    "For complex Tensors, both grad_output and output are required "
                    f"to have the same dtype. Mismatch in dtype: grad_output[{index}] "
                    f"has a dtype of {grad_tensor.dtype} and output[{index}] has a "
                    f"dtype of {out.dtype}."
                )
            new_grads.append(grad_tensor)
        elif grad_tensor is None:
            if out.numel() != 1:
                raise RuntimeError("grad can be implicitly created only for scalar outputs")
            if not isinstance(out.dtype, (typing.Float, typing.BFloat)):
                raise RuntimeError(
                    f"grad can be implicitly created only for real scalar outputs but got {out.dtype}"
                )
            new_grads.append(ops.ones_like(out))
        else:
            raise TypeError(
                f"gradients can be either Tensors or None, but got {type(grad_tensor).__name__}"
            )
    return tuple(new_grads)
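

# Illustrative behavior sketch for _make_grads (hypothetical values, shown as
# comments so nothing runs on import):
#
#   loss = ops.ones(())                  # scalar, float -> implicit grad OK
#   _make_grads((loss,), (None,))        # -> (ops.ones_like(loss),)
#
#   vec = ops.ones((3,))
#   _make_grads((vec,), (None,))         # RuntimeError: grad can be implicitly
#                                        # created only for scalar outputs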


def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None):
    """Run torch-style backward on a MindSpore tensor."""
    outputs = (self,)
    has_explicit_inputs = inputs is not None
    if isinstance(inputs, list):
        inputs = tuple(inputs)
    elif isinstance(inputs, TensorPy):
        inputs = (inputs,)
    elif inputs is None:
        inputs = ()
    else:
        inputs = tuple(inputs)
    if has_explicit_inputs and len(inputs) == 0:
        raise RuntimeError("'inputs' argument to backward() cannot be empty.")

    grad_tensors = _tensor_or_tensors_to_tuple(gradient, len(outputs))
    grad_tensors = _make_grads(outputs, grad_tensors)
    if retain_graph is None:
        # Match torch semantics: building a higher-order graph implies
        # keeping the current one alive.
        retain_graph = create_graph

    return run_backward(
        outputs,
        grad_tensors,
        retain_graph,
        create_graph,
        inputs,
        allow_unreachable=True,
        accumulate_grad=True,
    )


def enable_mindspore_backward_compat() -> None:
    """Enable torch-like ``Tensor.backward()`` semantics for MindSpore PyNative."""
    global _BACKWARD_COMPAT_ENABLED
    if _BACKWARD_COMPAT_ENABLED:
        return

    # Keep grad recording on so PyNative builds the graph that backward() walks.
    _pynative_executor.set_grad_flag(True)
    TensorPy.requires_grad = requires_grad
    TensorPy.grad = grad
    TensorPy.backward = backward
    TensorPy.is_leaf = is_leaf
    TensorPy.retains_grad = retains_grad
    TensorPy.retain_grad = retain_grad
    TensorPy.grad_fn = grad_fn
    TensorPy.output_nr = output_nr
    TensorPy.detach = detach
    _BACKWARD_COMPAT_ENABLED = True
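

# Hedged usage sketch: how the compat layer is meant to be driven. Whether a
# plain `mindspore.Tensor` picks up the patched TensorPy attributes depends on
# the build, so treat this as illustrative rather than guaranteed:
#
#   import mindspore as ms
#
#   enable_mindspore_backward_compat()
#   x = ms.Tensor([1.0, 2.0, 3.0])
#   x.requires_grad = True
#   loss = (x * x).sum()
#   loss.backward()        # scalar output, so the gradient is implicit ones
#   print(x.grad)          # d(loss)/dx == 2 * x, accumulated on the leaf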