Coverage for hyper_parallel / platform / mindspore / dtensor.py: 79%
89 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mindspore dtensor base"""
16from mindspore.common.tensor import Tensor
17from mindspore.common.initializer import initializer
18from mindspore._c_expression import NoFallbackGuard
class DTensorBase(Tensor):
    """
    DTensorBase - Base class for distributed tensors in MindSpore.

    Extends Tensor to support distributed tensor operations: each instance
    wraps a local shard (``self._local_tensor``) together with a device mesh
    describing the device topology and a placement specification for each
    mesh dimension.
    """

    def __new__(cls, local_tensor, device_mesh=None, placements=None, device="Ascend"):
        """
        Create a new DTensorBase instance.

        Args:
            local_tensor: The local tensor shard, or another DTensorBase
                instance whose device_mesh and placements are reused.
            device_mesh: The device mesh describing the device topology.
                Required when ``local_tensor`` is a plain tensor.
            placements: The placement strategy for each mesh dimension.
                Required when ``local_tensor`` is a plain tensor.
            device: The device type (default: "Ascend").

        Raises:
            ValueError: If ``local_tensor`` is a plain tensor and
                ``device_mesh`` or ``placements`` is None.
        """
        if isinstance(local_tensor, DTensorBase):
            # Re-wrap an existing distributed tensor: reuse its mesh and
            # placements. Fetch the shard once instead of calling
            # to_local() three times.
            shard = local_tensor.to_local()
            # A shard that still carries a deferred initializer is not moved
            # yet; a materialized shard is moved to the target device.
            device_local_tensor = shard if shard.has_init else shard.to(device)
            t = Tensor._make_subclass(cls, device_local_tensor)
            t.__init_data__(device_local_tensor, local_tensor.device_mesh, local_tensor.placements)
            t._device = device
            return t
        if device_mesh is None:
            raise ValueError("device_mesh is None")
        if placements is None:
            raise ValueError("placements is None")
        device_local_tensor = local_tensor if local_tensor.has_init else local_tensor.to(device)
        if local_tensor.has_init:
            # Record the target device so deferred initialization later
            # materializes the data there.
            local_tensor.init_device = device
        t = Tensor._make_subclass(cls, device_local_tensor)
        t.__init_data__(device_local_tensor, device_mesh, placements)
        t._device = device
        return t

    def asnumpy(self):
        """
        Numpy value of local tensor.
        """
        return self._local_tensor.asnumpy()

    def __str__(self):
        # Show the local shard rather than the logical distributed tensor.
        return str(self._local_tensor)

    def __copy__(self):
        """
        Create a shallow copy of the DTensorBase instance.

        This method ensures that device_mesh and placements are correctly
        propagated when creating a copy (e.g., for optimizer states).

        Raises:
            ValueError: If neither the direct attributes nor the layout
                supply a device_mesh and placements.
        """
        # Get device_mesh and placements from either direct attributes or,
        # failing that, from the layout object.
        device_mesh = getattr(self, '_device_mesh', None)
        placements = getattr(self, '_placements', None)
        layout = getattr(self, '_layout', None)
        if layout is not None:
            if device_mesh is None:
                device_mesh = layout.mesh
            if placements is None:
                placements = layout.placements

        if device_mesh is None or placements is None:
            raise ValueError("Cannot copy DTensorBase: device_mesh or placements is None")

        if self._local_tensor.has_init:
            # Preserve the deferred initializer instead of materializing it.
            obj = DTensorBase.__new__(
                type(self),
                initializer(self._local_tensor.init, self._local_tensor.shape, self._local_tensor.dtype),
                device_mesh,
                placements
            )
        else:
            obj = DTensorBase.__new__(
                type(self),
                self._local_tensor.clone(),
                device_mesh,
                placements
            )
        # Carry over every attribute except the local shard, rebuilt above.
        filtered_dict = {k: v for k, v in self.__dict__.items() if k != '_local_tensor'}
        obj.__dict__.update(filtered_dict)
        return obj

    # pylint: disable=W0211
    # pylint: disable=C0415
    def __fallback__(self, func, args=None, kwargs=None):
        """
        Dispatch ``func`` with ``args``/``kwargs`` through the distributed
        operator dispatcher with fallback disabled.

        ``args`` and ``kwargs`` default to empty containers; fresh ones are
        created per call (the previous ``args={}`` default shared one dict
        across all calls — the classic mutable-default pitfall).
        """
        if args is None:
            args = {}
        if kwargs is None:
            kwargs = {}
        # Imported lazily to avoid a circular import at module load time.
        from hyper_parallel.core.shard._op_dispatch import _OP_DISPATCHER
        with NoFallbackGuard():
            out = _OP_DISPATCHER.dispatch(func, args, kwargs)
        return out

    # pylint: disable=W0212
    def _need_contiguous(self):
        """Delegate the contiguity check to the local tensor shard."""
        return self._local_tensor._need_contiguous()

    @property
    def device(self):
        """Device info for dtensor"""
        return self._device

    # pylint: disable=W0212
    def set_data(self, data):
        """
        Set shape/dtype/storage for dtensor and local tensor.

        Args:
            data (Tensor): Source tensor; a DTensorBase source also has its
                distributed metadata (mesh/placements/layout) adopted.

        Raises:
            ValueError: If ``data`` is not a Tensor.
        """
        if not isinstance(data, Tensor):
            raise ValueError(f"The data type {type(data)} is not Tensor")
        if data.has_init:
            # Materialize deferred initialization before adopting the storage.
            data.init_data()
        data = data.to(self._device)
        if isinstance(data, DTensorBase):
            # Adopt the incoming shard together with its distributed metadata.
            self._local_tensor._update_data(data.to_local())
            self._device_mesh = data.device_mesh
            self._placements = data.placements
            self._layout = data.layout
            self._update_data(self._local_tensor)
            return

        self._local_tensor._update_data(data)
        self._update_data(data)

    @property
    def has_init(self):
        """
        Property to check if the initialization state is set in the local tensor.

        Returns:
            bool: True if the local tensor has the 'has_init' attribute set,
            False otherwise.
        """
        if not hasattr(self._local_tensor, "has_init"):
            return False
        return self._local_tensor.has_init

    @property
    def init(self):
        """
        Property to get the initialization value from the local tensor.

        Returns:
            Any: The initialization value stored in the local tensor if the
            'init' attribute exists; None if it is not present.
        """
        if not hasattr(self._local_tensor, "init"):
            return None
        return self._local_tensor.init

    @init.setter
    def init(self, init_value):
        """
        Setter for the initialization value, assigned to the local tensor's
        'init' attribute.

        Args:
            init_value: The value to be set as the initialization value in
                the local tensor.
        """
        self._local_tensor.init = init_value

    @property
    def local_param_info(self):
        """
        Property to get the param_info value from the local tensor.

        Returns:
            Any: The param_info value stored in the local tensor if the
            'param_info' attribute exists; None if it is not present.
        """
        if not hasattr(self._local_tensor, "param_info"):
            return None
        return self._local_tensor.param_info

    @local_param_info.setter
    def local_param_info(self, local_param_info_value):
        """
        Setter for local_param_info, assigned to the local tensor's
        'param_info' attribute.

        Args:
            local_param_info_value: The value to be set as the param_info
                value in the local tensor.
        """
        self._local_tensor.param_info = local_param_info_value