Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / mindspore / dtensor.py: 31%
107 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""mindspore dtensor base"""
16from mindspore.common.tensor import Tensor
17from mindspore.common.initializer import initializer
18from mindspore._c_expression import NoFallbackGuard
class DTensorBase(Tensor):
    """
    DTensorBase - Base class for distributed tensors in MindSpore.

    This class extends Tensor to support distributed tensor operations with
    device mesh and placement specifications. Each instance wraps a local
    tensor shard (``_local_tensor``) together with the mesh/placement
    metadata describing how the global tensor is distributed.
    """

    def __new__(cls, local_tensor, device_mesh=None, placements=None):
        """
        Create a new DTensorBase instance.

        Args:
            local_tensor: The local tensor shard, or another DTensorBase
                instance (in which case mesh/placements are copied from it).
            device_mesh: The device mesh describing the device topology.
                Required when ``local_tensor`` is a raw tensor.
            placements: The placement strategy for each mesh dimension.
                Required when ``local_tensor`` is a raw tensor.

        Raises:
            ValueError: If constructing from a raw tensor and any of
                ``local_tensor``, ``device_mesh`` or ``placements`` is None.
        """
        npu_device = "Ascend"
        if isinstance(local_tensor, DTensorBase):
            # Copy-construct: reuse the source's local shard and metadata.
            device_local_tensor = local_tensor.to_local()
            # Deferred-init ("has_init") and meta tensors have no materialized
            # data yet, so only real tensors are moved to the NPU here.
            if device_local_tensor.device != "meta" and not device_local_tensor.has_init:
                device_local_tensor = device_local_tensor.to(npu_device)
            t = Tensor._make_subclass(cls, device_local_tensor)
            # Prefer alias_placements from the layout to preserve multi-axis ordering.
            copy_placements = local_tensor.layout.alias_placements if local_tensor.layout else local_tensor.placements
            t.__init_data__(device_local_tensor, local_tensor.device_mesh, copy_placements)
            return t

        if local_tensor is None:
            raise ValueError(
                "DTensorBase: local_tensor must not be None when constructing from a raw tensor."
            )
        if device_mesh is None:
            raise ValueError(
                "DTensorBase: device_mesh must be a DeviceMesh instance, got None."
            )
        if placements is None:
            raise ValueError(
                "DTensorBase: placements must be a sequence of Placement objects, got None."
            )
        device_local_tensor = local_tensor
        if device_local_tensor.device != "meta" and not device_local_tensor.has_init:
            device_local_tensor = device_local_tensor.to(npu_device)
        if local_tensor.has_init:
            # Deferred initialization: record the target device so the data
            # is materialized on the NPU when initialization actually runs.
            local_tensor.init_device = npu_device
        t = Tensor._make_subclass(cls, device_local_tensor)
        t.__init_data__(device_local_tensor, device_mesh, placements)
        return t

    def asnumpy(self):
        """
        Numpy value of the local tensor shard (not the global tensor).
        """
        return self._local_tensor.asnumpy()

    def __str__(self):
        # String form shows the local shard only.
        return str(self._local_tensor)

    def __copy__(self):
        """
        Create a shallow copy of the DTensorBase instance.

        This method ensures that device_mesh and placements are correctly
        propagated when creating a copy (e.g., for optimizer states).

        Raises:
            ValueError: If neither a layout nor explicit device_mesh and
                placements are available on this instance.
        """
        # Get device_mesh and placements from layout (prefer alias_placements
        # to preserve multi-axis ordering).
        device_mesh = getattr(self, '_device_mesh', None)
        placements = None

        if hasattr(self, '_layout') and self._layout is not None:
            if device_mesh is None:
                device_mesh = self._layout.mesh
            placements = self._layout.alias_placements

        if placements is None:
            placements = getattr(self, '_placements', None)

        if device_mesh is None or placements is None:
            raise ValueError(
                "DTensorBase.__copy__: cannot copy without device_mesh and placements; "
                f"device_mesh={device_mesh!r}, placements={placements!r}. "
                "Ensure the tensor was constructed with a valid layout."
            )

        if self._local_tensor.has_init:
            # Not yet materialized: copy the initializer spec instead of data.
            obj = DTensorBase.__new__(
                type(self),
                initializer(self._local_tensor.init, self._local_tensor.shape, self._local_tensor.dtype),
                device_mesh,
                placements
            )
        else:
            obj = DTensorBase.__new__(
                type(self),
                self._local_tensor.clone(),
                device_mesh,
                placements
            )
        # __new__ installed a fresh _local_tensor for obj; keep it and carry
        # over every other attribute from the source instance.
        filtered_dict = {k: v for k, v in self.__dict__.items() if k != '_local_tensor'}
        obj.__dict__.update(filtered_dict)
        return obj

    # pylint: disable=W0211
    # pylint: disable=C0415
    def __fallback__(self, func, args=None, kwargs=None):
        """
        Route an operation through the DTensor op dispatcher.

        Args:
            func: The operator being dispatched.
            args: Argument container passed to the dispatcher. Defaults to a
                fresh empty dict per call.
            kwargs: Keyword arguments for ``func``. Defaults to a fresh
                empty dict per call.
        """
        # Use None sentinels instead of mutable default arguments so the
        # default containers are not shared across calls.
        if args is None:
            args = {}
        if kwargs is None:
            kwargs = {}
        from hyper_parallel.core.shard._op_dispatch import _OP_DISPATCHER
        with NoFallbackGuard():
            out = _OP_DISPATCHER.dispatch(func, args, kwargs)
        return out

    # pylint: disable=W0212
    def _need_contiguous(self):
        """Whether the local tensor shard requires a contiguous copy."""
        return self._local_tensor._need_contiguous()

    @property
    def device(self):
        """Device info for dtensor"""
        # Strip any device index suffix (e.g. "Ascend:0" -> "Ascend").
        device_info = self._local_tensor.device
        return device_info.split(':', 1)[0]

    @property
    # pylint: disable=C2801
    def data(self):
        """Underlying tensor data (delegates to the base Tensor property)."""
        return Tensor.data.__get__(self, type(self))

    @data.setter
    # pylint: disable=C2801
    def data(self, value):
        # Keep the DTensor view and its local shard in sync: both get the
        # incoming value's local representation.
        local_value = value.to_local() if isinstance(value, DTensorBase) else value
        Tensor.data.__set__(self, local_value)
        Tensor.data.__set__(self._local_tensor, local_value)

    # pylint: disable=W0212
    def set_data(self, data, slice_shape=False):
        """
        Set shape/dtype/storage for dtensor and local tensor.

        Args:
            data (Tensor): New tensor payload.
            slice_shape (bool): Kept for MindSpore `Parameter.set_data` API
                compatibility. Static-graph slicing semantics are not used by
                hyper_parallel, so this flag is accepted but ignored.

        Raises:
            ValueError: If ``data`` is not a Tensor.
        """
        _ = slice_shape
        if not isinstance(data, Tensor):
            raise ValueError(f"The data type {type(data)} is not Tensor")
        if data.has_init:
            # Materialize deferred-init data before moving it to this device.
            data.init_data()
        data = data.to(self.device)
        if isinstance(data, DTensorBase):
            # Adopt the incoming DTensor's shard and distribution metadata.
            self._local_tensor._update_data(data.to_local())
            self._device_mesh = data.device_mesh
            self._placements = data.placements
            self._layout = data.layout
            self._update_data(self._local_tensor)
            return

        self._local_tensor._update_data(data)
        self._update_data(data)

    @property
    def has_init(self):
        """
        Property to check if the initialization state is set in the local tensor.

        Returns:
            bool: True if the local tensor reports deferred initialization;
            False if the attribute is absent or False.
        """
        if not hasattr(self._local_tensor, "has_init"):
            return False
        return self._local_tensor.has_init

    @property
    def init(self):
        """
        Property to get the initialization value from the local tensor.

        Returns:
            Any: The initialization value stored in the local tensor if the
            'init' attribute exists; None if the 'init' attribute is not
            present in the local tensor.
        """
        if not hasattr(self._local_tensor, "init"):
            return None
        return self._local_tensor.init

    @init.setter
    def init(self, init_value):
        """
        Setter for the initialization value, which assigns the value to the
        local tensor's 'init' attribute.

        Args:
            init_value: The value to be set as the initialization value in
                the local tensor.
        """
        self._local_tensor.init = init_value

    @property
    def local_param_info(self):
        """
        Property to get the param_info value from the local tensor.

        Returns:
            Any: The param_info value stored in the local tensor if the
            'param_info' attribute exists; None if the 'param_info' attribute
            is not present in the local tensor.
        """
        if not hasattr(self._local_tensor, "param_info"):
            return None
        return self._local_tensor.param_info

    @local_param_info.setter
    def local_param_info(self, local_param_info_value):
        """
        Setter for local_param_info value, which assigns the value to the
        local tensor's 'param_info' attribute.

        Args:
            local_param_info_value: The value to be set as the param_info
                value in the local tensor.
        """
        self._local_tensor.param_info = local_param_info_value