# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""torch dtensor base"""
from typing import Tuple, Dict, Any, Optional

import torch
from torch import Tensor


class DTensorBase(Tensor):
    """torch dtensor base"""

    def __new__(cls, local_tensor, device_mesh=None, placements=None):
        """
        Create a new DTensorBase instance.

        Args:
            local_tensor: The local tensor shard, or another DTensorBase instance to copy from.
            device_mesh: The device mesh describing the device topology.
            placements: The placement strategy for each mesh dimension.
        """
        if isinstance(local_tensor, DTensorBase):
            # Copy from an existing DTensorBase; use alias_placements to preserve multi-axis ordering.
            t = Tensor._make_subclass(cls, local_tensor._local_tensor, local_tensor._local_tensor.requires_grad)
            copy_placements = local_tensor.layout.alias_placements if local_tensor.layout else local_tensor.placements
            # __init_data__ is assumed to be provided by the concrete subclass; it records
            # the local tensor, device mesh, and placements on the wrapper.
            t.__init_data__(local_tensor._local_tensor, local_tensor.device_mesh, copy_placements)
            return t

        if device_mesh is None:
            raise ValueError("device_mesh is None, must provide a DeviceMesh instance")
        if placements is None:
            raise ValueError("placements is None, must provide placements")

        # Create the Tensor subclass instance, sharing local_tensor's underlying storage.
        t = Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)
        t.__init_data__(local_tensor, device_mesh, placements)
        return t

    # pylint: disable=W0613
    @classmethod
    def __torch_function__(
            cls,
            func: torch._C._FunctionBase,
            types: Tuple[type, ...],
            args: Tuple[Any, ...] = (),
            kwargs: Optional[Dict[str, Any]] = None
    ) -> Any:
        """
        Override PyTorch's __torch_function__ to intercept tensor operations.

        This method dispatches operations through the distributed operator dispatcher
        to handle DTensor-specific layout inference and redistribution.

        Args:
            func (torch._C._FunctionBase): The PyTorch function being called.
            types (Tuple[type, ...]): The types of tensors involved in the operation.
            args (Tuple[Any, ...]): Positional arguments passed to the function.
            kwargs (Optional[Dict[str, Any]]): Keyword arguments passed to the function.

        Returns:
            Any: The result of the dispatched operation, typically a DTensor or a tuple of DTensors.
        """
        kwargs = kwargs or {}
        # pylint: disable=C0415
        from hyper_parallel.core.shard._op_dispatch import _OP_DISPATCHER
        return _OP_DISPATCHER.dispatch(func, args, kwargs)
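
    # Dispatch sketch: every torch-level call that involves a DTensorBase reaches
    # this hook first. For example, `dt + other` arrives here with func identifying
    # the add operation, and _OP_DISPATCHER (defined elsewhere) can infer output
    # layouts and insert redistribution before the op runs on the local shards.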

    @property
    def grad(self) -> Optional[Tensor]:
        """
        Get the gradient tensor of the local tensor.

        Returns:
            Optional[Tensor]: The gradient tensor, or None if no gradient is set.
        """
        return self._local_tensor.grad

    @grad.setter
    def grad(self, value: Optional[Tensor]) -> None:
        """
        Set the gradient tensor of the local tensor.

        Args:
            value (Optional[Tensor]): The gradient tensor to set, or None to clear it.
        """
        self._local_tensor.grad = value

    @property
    def requires_grad(self) -> bool:
        """
        Check whether gradient computation is enabled for this tensor.

        Returns:
            bool: True if gradients should be computed for this tensor.
        """
        return self._local_tensor.requires_grad

    @requires_grad.setter
    def requires_grad(self, value: bool) -> None:
        """
        Enable or disable gradient computation for this tensor.

        Args:
            value (bool): True to enable gradient computation, False to disable.
        """
        self._local_tensor.requires_grad_(value)
        # Keep the DTensor wrapper's requires_grad in sync.
        super().requires_grad_(value)

    def requires_grad_(self, requires_grad: bool = True):
        """
        Enable or disable gradient computation in-place.

        Args:
            requires_grad (bool): True to enable gradient computation. Default: True.

        Returns:
            DTensorBase: Self, for method chaining.
        """
        self._local_tensor.requires_grad_(requires_grad)
        super().requires_grad_(requires_grad)
        return self

    @property
    def grad_fn(self):
        """
        Get the autograd graph node that created this tensor.

        Returns:
            Optional[torch.autograd.graph.Node]: The grad_fn node of the local tensor,
                or None for leaf tensors and tensors that do not require gradients.
        """
        return self._local_tensor.grad_fn

    def grad_zero_(self):
        """
        Zero out the gradient tensor in-place, if one exists.

        Returns:
            DTensorBase: Self, for method chaining.
        """
        if self._local_tensor.grad is not None:
            self._local_tensor.grad.zero_()
        return self

    def detach(self):
        """
        Create a detached DTensor that does not require gradients.

        Returns:
            DTensorBase: A new DTensor with the same data, detached from the computation graph.
        """
        detached_local = self._local_tensor.detach()
        alias_p = self._layout.alias_placements if hasattr(self, '_layout') and self._layout else self._placements
        return self.__class__(detached_local, device_mesh=self._device_mesh, placements=alias_p)

    def detach_(self):
        """
        Detach this tensor from the computation graph in-place.

        Returns:
            DTensorBase: Self, for method chaining.
        """
        self._local_tensor.detach_()
        super().detach_()
        return self

    # ====================== Computation graph related overrides ======================
    @property
    def is_leaf(self) -> bool:
        """
        Check whether this tensor is a leaf node in the computation graph.

        Returns:
            bool: True if this is a leaf tensor (created by the user, not by an operation).
        """
        return self._local_tensor.is_leaf

    @property
    def retains_grad(self) -> bool:
        """
        Check whether this tensor retains its gradient during the backward pass.

        Returns:
            bool: True if gradients are retained for this non-leaf tensor.
        """
        return self._local_tensor.retains_grad

    @retains_grad.setter
    def retains_grad(self, value: bool) -> None:
        """
        Enable gradient retention for this tensor.

        Args:
            value (bool): True to enable gradient retention. Passing False is a
                no-op, since PyTorch provides no way to un-retain a gradient.
        """
        # Tensor has retain_grad() but no retains_grad_() method, so only the
        # enabling direction can be delegated to the local tensor.
        if value:
            self._local_tensor.retain_grad()

    def backward(self, gradient=None, retain_graph=None, create_graph=False) -> None:
        """
        Compute the gradients of this tensor.

        Args:
            gradient (Optional[Tensor]): The gradient of the loss w.r.t. this tensor.
            retain_graph (Optional[bool]): Whether to retain the computation graph.
            create_graph (bool): Whether to create a graph of the gradient computation.
        """
        self._local_tensor.backward(gradient, retain_graph=retain_graph, create_graph=create_graph)
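
    # Autograd sketch (illustrative only; assumes `dt` was built as in the
    # construction sketch above and that the dispatcher returns DTensors):
    #
    #     loss = (dt * dt).sum()  # routed through __torch_function__
    #     loss.backward()         # gradients accumulate on dt's local shard (dt.grad)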

    # ====================== Metadata related overrides (sync with local_tensor) ======================
    @property
    def device(self) -> torch.device:
        """
        Get the device on which this tensor is stored.

        Returns:
            torch.device: The device object (e.g., 'cuda:0' or 'cpu').
        """
        return self._local_tensor.device

    @property
    # pylint: disable=C2801
    def data(self):
        """Get the wrapper's underlying data via the base Tensor descriptor."""
        return Tensor.data.__get__(self, type(self))

    @data.setter
    # pylint: disable=C2801
    def data(self, value):
        """Set the data of both the wrapper and the local tensor, keeping them in sync."""
        # to_local() is assumed to be provided by the concrete DTensor subclass.
        local_value = value.to_local() if isinstance(value, DTensorBase) else value
        Tensor.data.__set__(self, local_value)
        Tensor.data.__set__(self._local_tensor, local_value)

    @property
    def dtype(self) -> torch.dtype:
        """
        Get the data type of this tensor.

        Returns:
            torch.dtype: The data type (e.g., torch.float32 or torch.int64).
        """
        return self._local_tensor.dtype

    @property
    def shape(self) -> torch.Size:
        """
        Get the shape of this tensor.

        Returns:
            torch.Size: The shape of the local tensor shard.
        """
        return self._local_tensor.shape

    def type(self, dtype=None, non_blocking=False):
        """
        Convert this tensor to the specified dtype, or report its current type.

        Args:
            dtype (Optional[torch.dtype]): The target dtype. If None, the current type string is returned.
            non_blocking (bool): Whether to perform the conversion asynchronously. Default: False.

        Returns:
            Union[str, DTensorBase]: The type string if dtype is None, otherwise a new DTensor.
        """
        if dtype is None:
            return self._local_tensor.type()
        new_local = self._local_tensor.to(dtype=dtype, non_blocking=non_blocking)
        alias_p = self._layout.alias_placements if hasattr(self, '_layout') and self._layout else self._placements
        return self.__class__(new_local, device_mesh=self._device_mesh, placements=alias_p)
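
    # Conversion sketch: type() and detach() rebuild the wrapper around a new local
    # tensor while reusing the recorded mesh and (alias) placements, e.g.
    #
    #     dt_fp16 = dt.type(torch.float16)  # new DTensorBase, same mesh and placements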

    def size(self, dim: Optional[int] = None):
        """
        Get the size of this tensor.

        Args:
            dim (Optional[int]): The dimension to query. If None, the full shape is returned.

        Returns:
            Union[torch.Size, int]: The shape, or the size along the given dimension.
        """
        return self._local_tensor.size(dim)

    @property
    def ndim(self) -> int:
        """
        Get the number of dimensions of this tensor.

        Returns:
            int: The number of dimensions.
        """
        return self._local_tensor.ndim

    def data_ptr(self) -> int:
        """
        Get the pointer to the data storage of the local tensor.

        Returns:
            int: The memory address of the tensor's data.
        """
        # Always return local_tensor's data pointer, so the wrapper and the shard report a consistent address.
        return self._local_tensor.data_ptr()

    def numel(self) -> int:
        """
        Get the total number of elements in this tensor.

        Returns:
            int: The total number of elements.
        """
        return self._local_tensor.numel()

    # ====================== Data operation overrides (sync storage + fix in-place ops) ======================
    def zero_(self):
        """Set the tensor to zeros in-place."""
        if self._local_tensor.requires_grad and self._local_tensor.is_leaf:
            # Create a new tensor and rebind it as the DTensor's local shard.
            new_local = torch.zeros_like(self._local_tensor, requires_grad=True)
            # Key step: sync the DTensor wrapper's storage with the new local tensor.
            super().copy_(new_local)  # sync the underlying data
            self._local_tensor = new_local  # replace the internal attribute
        else:
            self._local_tensor.zero_()
            super().zero_()  # sync the wrapper's in-place zeroing
        return self
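
    # Why the leaf branch above rebinds instead of mutating: PyTorch rejects
    # in-place ops on leaf tensors that require grad, e.g.
    #
    #     w = torch.ones(2, requires_grad=True)
    #     w.zero_()  # RuntimeError: a leaf Variable that requires grad is being
    #                # used in an in-place operation
    #
    # so zero_/copy_/fill_ create a fresh tensor and swap it in as the new local shard.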

    def copy_(self, src: Tensor, non_blocking: bool = False):
        """Copy data from the src tensor."""
        if self._local_tensor.requires_grad and self._local_tensor.is_leaf:
            new_local = src.to(self._local_tensor.device, non_blocking=non_blocking).detach().clone()
            new_local.requires_grad = self._local_tensor.requires_grad
            super().copy_(new_local)
            self._local_tensor = new_local
        else:
            self._local_tensor.copy_(src, non_blocking=non_blocking)
            super().copy_(src, non_blocking=non_blocking)
        return self

    def fill_(self, value):
        """Fill the tensor with the given value."""
        if self._local_tensor.requires_grad and self._local_tensor.is_leaf:
            # Step 1: Create a new tensor (non-in-place).
            new_local = torch.full_like(
                self._local_tensor,
                fill_value=value,
                requires_grad=True,
                device=self._local_tensor.device
            )
            # Step 2: Sync the DTensor wrapper's underlying storage with the new local tensor.
            super().copy_(new_local)  # key step: point the DTensor wrapper at the new data
            # Step 3: Replace the internal local_tensor (keep the attribute consistent).
            self._local_tensor = new_local
        else:
            # Non-leaf tensor: fill in-place directly and sync the wrapper.
            self._local_tensor.fill_(value)
            super().fill_(value)  # sync the DTensor wrapper's fill
        return self

    # ====================== Auxiliary print ======================
    def __repr__(self) -> str:
        return (
            "DTensor(\n"
            f"  local_tensor={self._local_tensor},\n"
            f"  device_mesh={self._device_mesh},\n"
            f"  placements={self._placements},\n"
            f"  layout={getattr(self, '_layout', None)},\n"
            f"  device={self.device},\n"
            f"  dtype={self.dtype},\n"
            f"  requires_grad={self.requires_grad},\n"
            f"  grad={self.grad},\n"
            f"  is_leaf={self.is_leaf},\n"
            f"  data_ptr={self.data_ptr()}\n"
            ")"
        )