# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""torch dtensor base"""

from typing import Any, Callable, Dict, Optional, Tuple

import torch
from torch import Tensor


class DTensorBase(Tensor):
    """Tensor subclass that pairs a local shard with its device mesh and placements."""

    def __new__(cls, local_tensor, device_mesh=None, placements=None):
        """
        Create a new DTensorBase instance.

        Args:
            local_tensor: The local tensor shard or another DTensorBase instance.
            device_mesh: The device mesh describing the device topology.
            placements: The placement strategy for each mesh dimension.
        """
        if isinstance(local_tensor, DTensorBase):
            # Copy from an existing DTensorBase
            t = Tensor._make_subclass(cls, local_tensor._local_tensor, local_tensor._local_tensor.requires_grad)
            t.__init_data__(local_tensor._local_tensor, local_tensor.device_mesh, local_tensor.placements)
            return t

        if device_mesh is None:
            raise ValueError("device_mesh is None, must provide a DeviceMesh instance")
        if placements is None:
            raise ValueError("placements is None, must provide placements")

        # Create the Tensor subclass instance, sharing local_tensor's underlying storage
        t = Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)
        t.__init_data__(local_tensor, device_mesh, placements)
        return t
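
    # NOTE: __init_data__ is called by __new__ but its definition was missing
    # from this excerpt; the sketch below is inferred from how _local_tensor,
    # _device_mesh and _placements are used by the rest of the class.
    def __init_data__(self, local_tensor, device_mesh, placements):
        self._local_tensor = local_tensor
        self._device_mesh = device_mesh
        self._placements = placements

    @property
    def device_mesh(self):
        """Device mesh of this DTensor (assumed accessor; __new__ reads it on copy-construction)."""
        return self._device_mesh

    @property
    def placements(self):
        """Placements of this DTensor (assumed accessor; __new__ reads it on copy-construction)."""
        return self._placements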

    # pylint: disable=W0613
    @classmethod
    def __torch_function__(
        cls,
        func: Callable,
        types: Tuple[type, ...],
        args: Tuple[Any, ...] = (),
        kwargs: Optional[Dict[str, Any]] = None
    ) -> Any:
        """Route every torch-level call on a DTensorBase through the shard op dispatcher."""
        kwargs = kwargs or {}
        # pylint: disable=C0415
        from hyper_parallel.core.shard._op_dispatch import _OP_DISPATCHER
        return _OP_DISPATCHER.dispatch(func, args, kwargs)
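    # For example, torch.add(dt, other) with a DTensorBase argument is
    # intercepted by __torch_function__ above and forwarded as
    # _OP_DISPATCHER.dispatch(torch.add, (dt, other), {}) instead of running
    # the dense torch.add kernel directly.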

    def to(self, *args, **kwargs):
        """Return a new DTensorBase whose local shard has been moved to a different device or dtype."""
        new_local = self._local_tensor.to(*args, **kwargs)
        return self.__class__(new_local, device_mesh=self._device_mesh, placements=self._placements)

    @property
    def grad(self) -> Optional[Tensor]:
        return self._local_tensor.grad

    @grad.setter
    def grad(self, value: Optional[Tensor]) -> None:
        self._local_tensor.grad = value

    @property
    def requires_grad(self) -> bool:
        return self._local_tensor.requires_grad

    @requires_grad.setter
    def requires_grad(self, value: bool) -> None:
        self._local_tensor.requires_grad_(value)
        # Keep the DTensor wrapper's requires_grad flag in sync
        super().requires_grad_(value)

    def requires_grad_(self, requires_grad: bool = True):
        self._local_tensor.requires_grad_(requires_grad)
        super().requires_grad_(requires_grad)
        return self

    @property
    def grad_fn(self) -> Optional[torch.autograd.graph.Node]:
        return self._local_tensor.grad_fn

    def grad_zero_(self):
        """Zero the local shard's gradient in place, if one exists."""
        if self._local_tensor.grad is not None:
            self._local_tensor.grad.zero_()
        return self

    def detach(self):
        """Return a new DTensorBase detached from the current computation graph."""
        detached_local = self._local_tensor.detach()
        return self.__class__(detached_local, device_mesh=self._device_mesh, placements=self._placements)

    def detach_(self):
        """Detach both the local shard and the wrapper from the computation graph in place."""
        self._local_tensor.detach_()
        super().detach_()
        return self

    # ====================== Computation graph related overrides ======================
    @property
    def is_leaf(self) -> bool:
        return self._local_tensor.is_leaf

    @property
    def retains_grad(self) -> bool:
        return self._local_tensor.retains_grad

    @retains_grad.setter
    def retains_grad(self, value: bool) -> None:
        # Tensor has no retains_grad_ method: retention can only be switched
        # on via retain_grad(), and torch offers no API to switch it off again.
        if value:
            self._local_tensor.retain_grad()
        elif self._local_tensor.retains_grad:
            raise RuntimeError("grad retention cannot be disabled once enabled")

    def backward(self, gradient=None, retain_graph=None, create_graph=False) -> None:
        """Run backward on the local shard."""
        self._local_tensor.backward(gradient=gradient, retain_graph=retain_graph, create_graph=create_graph)

    # ====================== Metadata related overrides (sync with local_tensor) ======================
    @property
    def device(self) -> torch.device:
        return self._local_tensor.device

    @property
    def dtype(self) -> torch.dtype:
        return self._local_tensor.dtype

    @property
    def shape(self) -> torch.Size:
        return self._local_tensor.shape

    def type(self, dtype=None, non_blocking=False):
        """Return the dtype string when called without arguments; otherwise cast the local shard."""
        if dtype is None:
            return self._local_tensor.type()
        new_local = self._local_tensor.to(dtype=dtype, non_blocking=non_blocking)
        return self.__class__(new_local, device_mesh=self._device_mesh, placements=self._placements)

    def size(self, dim: Optional[int] = None):
        # Tensor.size() rejects an explicit None argument, so branch here
        if dim is None:
            return self._local_tensor.size()
        return self._local_tensor.size(dim)

    @property
    def ndim(self) -> int:
        return self._local_tensor.ndim

    def data_ptr(self) -> int:
        # Always report local_tensor's data pointer (ensure address consistency)
        return self._local_tensor.data_ptr()

    def numel(self) -> int:
        return self._local_tensor.numel()

    # ====================== Data operation overrides (sync storage + fix in-place ops) ======================
    def zero_(self):
        """Zero the tensor in place."""
        if self._local_tensor.requires_grad and self._local_tensor.is_leaf:
            # A leaf tensor that requires grad cannot be zeroed in place;
            # create a fresh zero tensor and rebind it instead
            new_local = torch.zeros_like(self._local_tensor, requires_grad=True)
            with torch.no_grad():
                super().copy_(new_local)  # sync the wrapper's underlying data
            self._local_tensor = new_local  # replace the internal attribute
        else:
            self._local_tensor.zero_()
            super().zero_()  # sync the wrapper's in-place zero
        return self

    def copy_(self, src: Tensor, non_blocking: bool = False):
        """Copy data from the src tensor."""
        if self._local_tensor.requires_grad and self._local_tensor.is_leaf:
            new_local = src.to(self._local_tensor.device, non_blocking=non_blocking).detach().clone()
            new_local.requires_grad = self._local_tensor.requires_grad
            with torch.no_grad():
                super().copy_(new_local)
            self._local_tensor = new_local
        else:
            self._local_tensor.copy_(src, non_blocking=non_blocking)
            super().copy_(src, non_blocking=non_blocking)
        return self

    def fill_(self, value):
        """Fill the tensor with value."""
        if self._local_tensor.requires_grad and self._local_tensor.is_leaf:
            # Step 1: build the filled tensor out of place (leaf tensors that
            # require grad cannot be filled in place)
            new_local = torch.full_like(
                self._local_tensor,
                fill_value=value,
                requires_grad=True,
                device=self._local_tensor.device
            )
            # Step 2: sync the DTensor wrapper's underlying data to the new local tensor
            with torch.no_grad():
                super().copy_(new_local)
            # Step 3: replace the internal local_tensor (ensure attribute consistency)
            self._local_tensor = new_local
        else:
            # Non-leaf tensor: direct in-place fill + sync wrapper
            self._local_tensor.fill_(value)
            super().fill_(value)  # sync the DTensor wrapper's fill
        return self
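    # The leaf/requires_grad branches in zero_/copy_/fill_ above exist because
    # plain torch forbids in-place updates of leaf tensors that require grad:
    #   p = torch.ones(2, requires_grad=True)
    #   p.zero_()  # RuntimeError: a leaf Variable that requires grad is being
    #              # used in an in-place operation.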

    # ====================== Auxiliary print ======================
    def __repr__(self) -> str:
        return (
            f"DTensor(\n"
            f" local_tensor={self._local_tensor},\n"
            f" device_mesh={self._device_mesh},\n"
            f" placements={self._placements},\n"
            f" layout={getattr(self, '_layout', None)},\n"
            f" device={self.device},\n"
            f" dtype={self.dtype},\n"
            f" requires_grad={self.requires_grad},\n"
            f" grad={self.grad},\n"
            f" is_leaf={self.is_leaf},\n"
            f" data_ptr={self.data_ptr()}\n"
            f")"
        )
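

# ---------------------------------------------------------------------------
# Minimal single-rank usage sketch (an illustration, not part of the original
# module). Assumes torch >= 2.4, where the mesh helper lives in
# torch.distributed.device_mesh and Shard in torch.distributed.tensor, and
# assumes the gloo backend is available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    mesh = init_device_mesh("cpu", (1,))
    local = torch.randn(4, 4, requires_grad=True)
    # On a 1-device mesh the "shard" is the whole tensor
    dt = DTensorBase(local, device_mesh=mesh, placements=(Shard(0),))
    print(dt.shape, dt.dtype, dt.is_leaf)

    dist.destroy_process_group()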