Coverage for hyper_parallel / platform / mindspore / dtensor.py: 79%

89 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mindspore dtensor base"""

16from mindspore.common.tensor import Tensor 

17from mindspore.common.initializer import initializer 

18from mindspore._c_expression import NoFallbackGuard 

19 

20 

class DTensorBase(Tensor):
    """
    DTensorBase - Base class for distributed tensors in MindSpore.

    This class extends Tensor to support distributed tensor operations with
    device mesh and placement specifications. Each instance wraps a local
    tensor shard (``_local_tensor``) plus the mesh/placement metadata that
    describes how the global tensor is distributed.
    """

    def __new__(cls, local_tensor, device_mesh=None, placements=None, device="Ascend"):
        """
        Create a new DTensorBase instance.

        Args:
            local_tensor: The local tensor shard or another DTensorBase instance.
            device_mesh: The device mesh describing the device topology.
            placements: The placement strategy for each mesh dimension.
            device: The device type (default: "Ascend").

        Raises:
            ValueError: If ``local_tensor`` is not a DTensorBase and either
                ``device_mesh`` or ``placements`` is None.
        """
        if isinstance(local_tensor, DTensorBase):
            # Re-wrap an existing DTensor, reusing its shard, mesh and
            # placements. Call to_local() once (the original called it up to
            # three times). A lazily-initialized shard (has_init) is kept
            # as-is; a materialized one is moved to the target device.
            inner = local_tensor.to_local()
            device_local_tensor = inner if inner.has_init else inner.to(device)
            t = Tensor._make_subclass(cls, device_local_tensor)
            t.__init_data__(device_local_tensor, local_tensor.device_mesh, local_tensor.placements)
            t._device = device
            return t
        if device_mesh is None:
            raise ValueError("device_mesh is None")
        if placements is None:
            raise ValueError("placements is None")
        # Lazily-initialized tensors are not moved yet; record the intended
        # device so materialization happens there later.
        device_local_tensor = local_tensor if local_tensor.has_init else local_tensor.to(device)
        if local_tensor.has_init:
            local_tensor.init_device = device
        t = Tensor._make_subclass(cls, device_local_tensor)
        t.__init_data__(device_local_tensor, device_mesh, placements)
        t._device = device
        return t

    def asnumpy(self):
        """
        Return the numpy value of the LOCAL shard (not the global tensor).
        """
        return self._local_tensor.asnumpy()

    def __str__(self):
        # String form shows only the local shard.
        return str(self._local_tensor)

    def __copy__(self):
        """
        Create a shallow copy of the DTensorBase instance.

        This method ensures that device_mesh and placements are correctly
        propagated when creating a copy (e.g., for optimizer states).

        Raises:
            ValueError: If neither the direct attributes nor the layout
                provide a device mesh and placements.
        """
        # Get device_mesh and placements from either direct attributes or from layout
        device_mesh = getattr(self, '_device_mesh', None)
        placements = getattr(self, '_placements', None)

        # If not found directly, try to get from layout
        if device_mesh is None and hasattr(self, '_layout') and self._layout is not None:
            device_mesh = self._layout.mesh
        if placements is None and hasattr(self, '_layout') and self._layout is not None:
            placements = self._layout.placements

        if device_mesh is None or placements is None:
            raise ValueError("Cannot copy DTensorBase: device_mesh or placements is None")

        if self._local_tensor.has_init:
            # Lazy shard: rebuild an equivalent (uninitialized) initializer
            # description instead of cloning concrete data.
            obj = DTensorBase.__new__(
                type(self),
                initializer(self._local_tensor.init, self._local_tensor.shape, self._local_tensor.dtype),
                device_mesh,
                placements
            )
        else:
            obj = DTensorBase.__new__(
                type(self),
                self._local_tensor.clone(),
                device_mesh,
                placements
            )
        # Carry over every attribute except the freshly created local shard.
        filtered_dict = {k: v for k, v in self.__dict__.items() if k != '_local_tensor'}
        obj.__dict__.update(filtered_dict)
        return obj

    # pylint: disable=W0211
    # pylint: disable=C0415
    def __fallback__(self, func, args=None, kwargs=None):
        """
        Dispatch ``func`` with ``args``/``kwargs`` through the distributed op
        dispatcher, with Python fallback disabled.

        Note: the previous signature used a mutable default (``args={}``); a
        single dict would then be shared across all calls, leaking state if
        the dispatcher ever mutated it. A fresh dict is now created per call,
        which is behavior-compatible for all existing callers.
        """
        if args is None:
            args = {}
        if kwargs is None:
            kwargs = {}
        # Local import to avoid a circular dependency (hence the C0415 disable).
        from hyper_parallel.core.shard._op_dispatch import _OP_DISPATCHER
        with NoFallbackGuard():
            out = _OP_DISPATCHER.dispatch(func, args, kwargs)
        return out

    # pylint: disable=W0212
    def _need_contiguous(self):
        """Delegate the contiguity check to the local shard."""
        return self._local_tensor._need_contiguous()

    @property
    def device(self):
        """Device info for dtensor"""
        return self._device

    # pylint: disable=W0212
    def set_data(self, data):
        """
        Set shape/dtype/storage for dtensor and local tensor.

        Args:
            data: A Tensor (or DTensorBase) providing the new data. Lazy
                tensors are materialized and moved to this dtensor's device
                first.

        Raises:
            ValueError: If ``data`` is not a Tensor.
        """
        if not isinstance(data, Tensor):
            raise ValueError(f"The data type {type(data)} is not Tensor")
        if data.has_init:
            data.init_data()
            data = data.to(self._device)
        if isinstance(data, DTensorBase):
            # Take over the incoming dtensor's shard and distribution
            # metadata, then refresh this tensor's own storage.
            self._local_tensor._update_data(data.to_local())
            self._device_mesh = data.device_mesh
            self._placements = data.placements
            self._layout = data.layout
            self._update_data(self._local_tensor)
            return

        self._local_tensor._update_data(data)
        self._update_data(data)

    @property
    def has_init(self):
        """
        Property to check if the initialization state is set in the local tensor.

        Returns:
            bool: True if the local tensor has the 'has_init' attribute, False otherwise.
        """
        if not hasattr(self._local_tensor, "has_init"):
            return False
        return self._local_tensor.has_init

    @property
    def init(self):
        """
        Property to get the initialization value from the local tensor.

        Returns:
            Any: The initialization value stored in the local tensor if the 'init' attribute exists;
                None if the 'init' attribute is not present in the local tensor.
        """
        if not hasattr(self._local_tensor, "init"):
            return None
        return self._local_tensor.init

    @init.setter
    def init(self, init_value):
        """
        Setter for the initialization value, which assigns the value to the local tensor's 'init' attribute.

        Args:
            init_value: The value to be set as the initialization value in the local tensor.
        """
        self._local_tensor.init = init_value

    @property
    def local_param_info(self):
        """
        Property to get the param_info value from the local tensor.

        Returns:
            Any: The param_info value stored in the local tensor if the 'param_info' attribute exists;
                None if the 'param_info' attribute is not present in the local tensor.
        """
        if not hasattr(self._local_tensor, "param_info"):
            return None
        return self._local_tensor.param_info

    @local_param_info.setter
    def local_param_info(self, local_param_info_value):
        """
        Setter for local_param_info value, which assigns the value to the local tensor's 'param_info' attribute.

        Args:
            local_param_info_value: The value to be set as the param_info value in the local tensor.
        """
        self._local_tensor.param_info = local_param_info_value