Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/dtensor.py: 31% (107 statements)

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mindspore dtensor base"""
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore._c_expression import NoFallbackGuard


class DTensorBase(Tensor):
    """
    DTensorBase - Base class for distributed tensors in MindSpore.

    This class extends Tensor to support distributed tensor operations with
    device mesh and placement specifications.
    """

    def __new__(cls, local_tensor, device_mesh=None, placements=None):
        """
        Create a new DTensorBase instance.

        Args:
            local_tensor: The local tensor shard or another DTensorBase instance.
            device_mesh: The device mesh describing the device topology.
            placements: The placement strategy for each mesh dimension.
        """
        npu_device = "Ascend"
        if isinstance(local_tensor, DTensorBase):
            device_local_tensor = local_tensor.to_local()
            if device_local_tensor.device != "meta" and not device_local_tensor.has_init:
                device_local_tensor = device_local_tensor.to(npu_device)
            t = Tensor._make_subclass(cls, device_local_tensor)
            copy_placements = local_tensor.layout.alias_placements if local_tensor.layout else local_tensor.placements
            t.__init_data__(device_local_tensor, local_tensor.device_mesh, copy_placements)
            return t

        if local_tensor is None:
            raise ValueError(
                "DTensorBase: local_tensor must not be None when constructing from a raw tensor."
            )
        if device_mesh is None:
            raise ValueError(
                "DTensorBase: device_mesh must be a DeviceMesh instance, got None."
            )
        if placements is None:
            raise ValueError(
                "DTensorBase: placements must be a sequence of Placement objects, got None."
            )
        device_local_tensor = local_tensor
        if device_local_tensor.device != "meta" and not device_local_tensor.has_init:
            device_local_tensor = device_local_tensor.to(npu_device)
        if local_tensor.has_init:
            local_tensor.init_device = npu_device
        t = Tensor._make_subclass(cls, device_local_tensor)
        t.__init_data__(device_local_tensor, device_mesh, placements)
        return t
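
    # A minimal construction sketch. Hypothetical names: `DeviceMesh` and
    # `Shard` are assumed to come from hyper_parallel's public API, which
    # this file does not import; adjust to the actual API.
    #
    #     import numpy as np
    #     mesh = DeviceMesh("Ascend", (2,))             # two-device mesh (assumed ctor)
    #     local = Tensor(np.ones((4, 4), np.float32))   # this rank's shard
    #     dt = DTensorBase(local, device_mesh=mesh, placements=[Shard(0)])
    #     dt2 = DTensorBase(dt)   # re-wrapping reuses mesh and placements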

    def asnumpy(self):
        """
        Return the NumPy value of the local tensor.
        """
        return self._local_tensor.asnumpy()

    def __str__(self):
        return str(self._local_tensor)

    def __copy__(self):
        """
        Create a shallow copy of the DTensorBase instance.

        This method ensures that device_mesh and placements are correctly
        propagated when creating a copy (e.g., for optimizer states).
        """
        # Get device_mesh and placements from the layout (prefer
        # alias_placements to preserve multi-axis ordering).
        device_mesh = getattr(self, '_device_mesh', None)
        placements = None

        if hasattr(self, '_layout') and self._layout is not None:
            if device_mesh is None:
                device_mesh = self._layout.mesh
            placements = self._layout.alias_placements

        if placements is None:
            placements = getattr(self, '_placements', None)

        if device_mesh is None or placements is None:
            raise ValueError(
                "DTensorBase.__copy__: cannot copy without device_mesh and placements; "
                f"device_mesh={device_mesh!r}, placements={placements!r}. "
                "Ensure the tensor was constructed with a valid layout."
            )

        if self._local_tensor.has_init:
            obj = DTensorBase.__new__(
                type(self),
                initializer(self._local_tensor.init, self._local_tensor.shape, self._local_tensor.dtype),
                device_mesh,
                placements
            )
        else:
            obj = DTensorBase.__new__(
                type(self),
                self._local_tensor.clone(),
                device_mesh,
                placements
            )
        filtered_dict = {k: v for k, v in self.__dict__.items() if k != '_local_tensor'}
        obj.__dict__.update(filtered_dict)
        return obj
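
    # Copy-path sketch (e.g. when optimizer state is cloned); `dt` is any
    # DTensorBase built with a valid layout. The device_mesh/placements
    # properties are assumed to be defined on the concrete subclass:
    #
    #     import copy
    #     dt_copy = copy.copy(dt)   # routes through __copy__ above
    #     # mesh and placements travel with the copy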

    # pylint: disable=W0211
    # pylint: disable=W0102
    # pylint: disable=C0415
    def __fallback__(self, func, args={}, kwargs=None):
        """Dispatch `func` with `args`/`kwargs` through the hyper_parallel op
        dispatcher while NoFallbackGuard is active."""
        if kwargs is None:
            kwargs = {}
        from hyper_parallel.core.shard._op_dispatch import _OP_DISPATCHER
        with NoFallbackGuard():
            out = _OP_DISPATCHER.dispatch(func, args, kwargs)
        return out

    # pylint: disable=W0212
    def _need_contiguous(self):
        """Whether the local tensor needs to be made contiguous."""
        return self._local_tensor._need_contiguous()

    @property
    def device(self):
        """Device type of the dtensor, without the device id suffix."""
        device_info = self._local_tensor.device
        # e.g. "Ascend:0" -> "Ascend"
        return device_info.split(':', 1)[0]

    @property
    # pylint: disable=C2801
    def data(self):
        return Tensor.data.__get__(self, type(self))

    @data.setter
    # pylint: disable=C2801
    def data(self, value):
        # Keep the wrapper and its local shard in sync; DTensorBase values
        # are first unwrapped to their local tensor.
        local_value = value.to_local() if isinstance(value, DTensorBase) else value
        Tensor.data.__set__(self, local_value)
        Tensor.data.__set__(self._local_tensor, local_value)
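
    # Assignment sketch (hypothetical variables: `dt` is an existing
    # DTensorBase, `other` another DTensorBase, `np` is numpy):
    #
    #     dt.data = Tensor(np.zeros((4, 4), np.float32))  # plain tensor
    #     dt.data = other                                 # unwrapped via to_local()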

    # pylint: disable=W0212
    def set_data(self, data, slice_shape=False):
        """
        Set shape/dtype/storage for the dtensor and its local tensor.

        Args:
            data (Tensor): New tensor payload.
            slice_shape (bool): Kept for MindSpore `Parameter.set_data` API
                compatibility. Static-graph slicing semantics are not used by
                hyper_parallel, so this flag is accepted but ignored.
        """
        _ = slice_shape
        if not isinstance(data, Tensor):
            raise ValueError(f"The data type {type(data)} is not Tensor")
        if data.has_init:
            data.init_data()
        data = data.to(self.device)
        if isinstance(data, DTensorBase):
            self._local_tensor._update_data(data.to_local())
            self._device_mesh = data.device_mesh
            self._placements = data.placements
            self._layout = data.layout
            self._update_data(self._local_tensor)
            return

        self._local_tensor._update_data(data)
        self._update_data(data)
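
    # Usage sketch (hypothetical variables `dt` and `other_dtensor`; mirrors
    # Parameter.set_data, with slice_shape accepted only for compatibility):
    #
    #     dt.set_data(Tensor(np.zeros((4, 4), np.float32)))
    #     dt.set_data(other_dtensor)  # also refreshes mesh/placements/layout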

    @property
    def has_init(self):
        """
        Property to check the initialization state of the local tensor.

        Returns:
            bool: The local tensor's `has_init` value, or False if the local
            tensor has no 'has_init' attribute.
        """
        if not hasattr(self._local_tensor, "has_init"):
            return False
        return self._local_tensor.has_init

    @property
    def init(self):
        """
        Property to get the initialization value from the local tensor.

        Returns:
            Any: The initialization value stored in the local tensor if the
            'init' attribute exists; None if the 'init' attribute is not
            present in the local tensor.
        """
        if not hasattr(self._local_tensor, "init"):
            return None
        return self._local_tensor.init

    @init.setter
    def init(self, init_value):
        """
        Setter for the initialization value, which assigns the value to the
        local tensor's 'init' attribute.

        Args:
            init_value: The value to be set as the initialization value in the local tensor.
        """
        self._local_tensor.init = init_value

    @property
    def local_param_info(self):
        """
        Property to get the param_info value from the local tensor.

        Returns:
            Any: The param_info value stored in the local tensor if the
            'param_info' attribute exists; None if the 'param_info' attribute
            is not present in the local tensor.
        """
        if not hasattr(self._local_tensor, "param_info"):
            return None
        return self._local_tensor.param_info

    @local_param_info.setter
    def local_param_info(self, local_param_info_value):
        """
        Setter for local_param_info, which assigns the value to the local
        tensor's 'param_info' attribute.

        Args:
            local_param_info_value: The value to be set as param_info on the local tensor.
        """
        self._local_tensor.param_info = local_param_info_value