Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / torch / dtensor.py: 57%
125 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""torch dtensor base"""
16from typing import Tuple, Dict, Any, Optional
17import torch
18from torch import Tensor
21class DTensorBase(Tensor):
22 """torch dtensor base"""
24 def __new__(cls, local_tensor, device_mesh=None, placements=None):
25 """
26 Create a new DTensorBase instance.
28 Args:
29 local_tensor: The local tensor shard or another DTensorBase instance.
30 device_mesh: The device mesh describing the device topology.
31 placements: The placement strategy for each mesh dimension.
32 """
33 if isinstance(local_tensor, DTensorBase):
34 # Copy from existing DTensorBase — use alias_placements to preserve multi-axis ordering
35 t = Tensor._make_subclass(cls, local_tensor._local_tensor, local_tensor._local_tensor.requires_grad)
36 copy_placements = local_tensor.layout.alias_placements if local_tensor.layout else local_tensor.placements
37 t.__init_data__(local_tensor._local_tensor, local_tensor.device_mesh, copy_placements)
38 return t
40 if device_mesh is None:
41 raise ValueError("device_mesh is None, must provide a DeviceMesh instance")
42 if placements is None:
43 raise ValueError("placements is None, must provide placements")
45 # Create Tensor subclass instance, sharing local_tensor's underlying storage
46 t = Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)
47 t.__init_data__(local_tensor, device_mesh, placements)
48 return t
50 # pylint: disable=W0613
51 @classmethod
52 def __torch_function__(
53 cls,
54 func: torch._C._FunctionBase,
55 types: Tuple[type, ...],
56 args: Tuple[Any, ...] = (),
57 kwargs: Optional[Dict[str, Any]] = None
58 ) -> Any:
59 """
60 Override PyTorch's __torch_function__ to intercept tensor operations.
62 This method dispatches operations through the distributed operator dispatcher
63 to handle DTensor-specific layout inference and redistribution.
65 Args:
66 func (torch._C._FunctionBase): The PyTorch function being called.
67 types (Tuple[type, ...]): The types of tensors involved in the operation.
68 args (Tuple[Any, ...]): Positional arguments passed to the function.
69 kwargs (Optional[Dict[str, Any]]): Keyword arguments passed to the function.
71 Returns:
72 Any: The result of the dispatched operation, typically a DTensor or tuple of DTensors.
73 """
74 kwargs = kwargs or {}
75 # pylint: disable=C0415
76 from hyper_parallel.core.shard._op_dispatch import _OP_DISPATCHER
77 out = _OP_DISPATCHER.dispatch(func, args, kwargs)
78 return out
80 @property
81 def grad(self) -> Optional[Tensor]:
82 """
83 Get the gradient tensor of the local tensor.
85 Returns:
86 Optional[Tensor]: The gradient tensor, or None if no gradient is set.
87 """
88 return self._local_tensor.grad
90 @grad.setter
91 def grad(self, value: Optional[Tensor]) -> None:
92 """
93 Set the gradient tensor for the local tensor.
95 Args:
96 value (Optional[Tensor]): The gradient tensor to set, or None to clear.
97 """
98 self._local_tensor.grad = value
100 @property
101 def requires_grad(self) -> bool:
102 """
103 Check if gradient computation is enabled for this tensor.
105 Returns:
106 bool: True if gradients should be computed for this tensor.
107 """
108 return self._local_tensor.requires_grad
110 @requires_grad.setter
111 def requires_grad(self, value: bool) -> None:
112 """
113 Enable or disable gradient computation for this tensor.
115 Args:
116 value (bool): True to enable gradient computation, False to disable.
117 """
118 self._local_tensor.requires_grad_(value)
119 # Sync DTensor wrapper's requires_grad
120 super().requires_grad_(value)
122 def requires_grad_(self, requires_grad: bool = True):
123 """
124 Enable or disable gradient computation in-place.
126 Args:
127 requires_grad (bool): True to enable gradient computation. Default: True.
129 Returns:
130 DTensorBase: Self for method chaining.
131 """
132 self._local_tensor.requires_grad_(requires_grad)
133 super().requires_grad_(requires_grad)
134 return self
136 @property
137 def grad_fn(self) -> Optional[torch.autograd.Function]:
138 """
139 Get the gradient function that created this tensor.
141 Returns:
142 Optional[torch.autograd.Function]: The gradient function, or None if not applicable.
143 """
144 return self._local_tensor.grad_fn
146 def grad_zero_(self):
147 """
148 Zero out the gradient tensor in-place.
150 Returns:
151 DTensorBase: Self for method chaining.
152 """
153 if self._local_tensor.grad is not None:
154 self._local_tensor.grad.zero_()
155 return self
157 def detach(self):
158 """
159 Create a detached DTensor that does not require gradient.
161 Returns:
162 DTensorBase: A new DTensor with the same data but detached from the computation graph.
163 """
164 detached_local = self._local_tensor.detach()
165 alias_p = self._layout.alias_placements if hasattr(self, '_layout') and self._layout else self._placements
166 return self.__class__(detached_local, device_mesh=self._device_mesh, placements=alias_p)
168 def detach_(self):
169 """
170 Detach this tensor from the computation graph in-place.
172 Returns:
173 DTensorBase: Self for method chaining.
174 """
175 self._local_tensor.detach_()
176 super().detach_()
177 return self
179 # ====================== Computation graph related overrides ======================
180 @property
181 def is_leaf(self) -> bool:
182 """
183 Check if this tensor is a leaf node in the computation graph.
185 Returns:
186 bool: True if this is a leaf tensor (created by user, not by any operation).
187 """
188 return self._local_tensor.is_leaf
190 @property
191 def retains_grad(self) -> bool:
192 """
193 Check if this tensor retains its gradient during backward pass.
195 Returns:
196 bool: True if gradients are retained for non-leaf tensors.
197 """
198 return self._local_tensor.retains_grad
200 @retains_grad.setter
201 def retains_grad(self, value: bool) -> None:
202 """
203 Enable or disable gradient retention for this tensor.
205 Args:
206 value (bool): True to enable gradient retention.
207 """
208 self._local_tensor.retains_grad_(value)
210 def backward(self, gradient=None, retain_graph=None, create_graph=False) -> None:
211 """
212 Compute the gradients for this tensor.
214 Args:
215 gradient (Optional[Tensor]): The gradient of the loss w.r.t. this tensor.
216 retain_graph (Optional[bool]): Whether to retain the computation graph.
217 create_graph (bool): Whether to create a graph of the gradient computation.
218 """
219 self._local_tensor.backward(gradient, retain_graph, create_graph)
221 # ====================== Metadata related overrides (sync with local_tensor) ======================
222 @property
223 def device(self) -> torch.device:
224 """
225 Get the device on which this tensor is stored.
227 Returns:
228 torch.device: The device object (e.g., 'cuda:0', 'cpu').
229 """
230 return self._local_tensor.device
232 @property
233 # pylint: disable=C2801
234 def data(self):
235 return Tensor.data.__get__(self, type(self))
237 @data.setter
238 # pylint: disable=C2801
239 def data(self, value):
240 local_value = value.to_local() if isinstance(value, DTensorBase) else value
241 Tensor.data.__set__(self, local_value)
242 Tensor.data.__set__(self._local_tensor, local_value)
244 @property
245 def dtype(self) -> torch.dtype:
246 """
247 Get the data type of this tensor.
249 Returns:
250 torch.dtype: The data type (e.g., torch.float32, torch.int64).
251 """
252 return self._local_tensor.dtype
254 @property
255 def shape(self) -> torch.Size:
256 """
257 Get the shape of this tensor.
259 Returns:
260 torch.Size: The shape of the tensor.
261 """
262 return self._local_tensor.shape
264 def type(self, dtype=None, non_blocking=False):
265 """
266 Convert this tensor to the specified dtype.
268 Args:
269 dtype (Optional[torch.dtype]): The target dtype. If None, returns the current type string.
270 non_blocking (bool): Whether to perform the operation asynchronously. Default: False.
272 Returns:
273 Union[str, DTensorBase]: The type string if dtype is None, otherwise a new DTensor.
274 """
275 if dtype is None:
276 return self._local_tensor.type()
277 new_local = self._local_tensor.to(dtype=dtype, non_blocking=non_blocking)
278 alias_p = self._layout.alias_placements if hasattr(self, '_layout') and self._layout else self._placements
279 return self.__class__(new_local, device_mesh=self._device_mesh, placements=alias_p)
281 def size(self, dim: Optional[int] = None):
282 """
283 Get the size of this tensor.
285 Args:
286 dim (Optional[int]): The dimension to query. If None, returns the full shape.
288 Returns:
289 Union[torch.Size, int]: The shape or size along a specific dimension.
290 """
291 return self._local_tensor.size(dim)
293 @property
294 def ndim(self) -> int:
295 """
296 Get the number of dimensions of this tensor.
298 Returns:
299 int: The number of dimensions.
300 """
301 return self._local_tensor.ndim
303 def data_ptr(self) -> int:
304 """
305 Get the pointer to the data storage of the local tensor.
307 Returns:
308 int: The memory address of the tensor's data.
309 """
310 # Force return local_tensor's data pointer (ensure address consistency)
311 return self._local_tensor.data_ptr()
313 def numel(self) -> int:
314 """
315 Get the total number of elements in this tensor.
317 Returns:
318 int: The total number of elements.
319 """
320 return self._local_tensor.numel()
322 # ====================== Data operation overrides (sync storage + fix in-place ops) ======================
323 def zero_(self):
324 """Set tensor zeros"""
325 if self._local_tensor.requires_grad and self._local_tensor.is_leaf:
326 # Create new tensor + rebind DTensor (ensure storage sharing)
327 new_local = torch.zeros_like(self._local_tensor, requires_grad=True)
328 # Key: sync DTensor wrapper's storage to new local_tensor
329 super().copy_(new_local) # sync underlying data
330 self._local_tensor = new_local # replace internal attribute
331 else:
332 self._local_tensor.zero_()
333 super().zero_() # sync wrapper's in-place zero
334 return self
336 def copy_(self, src: Tensor, non_blocking: bool = False):
337 """Copy data from src tensor"""
338 if self._local_tensor.requires_grad and self._local_tensor.is_leaf:
339 new_local = src.to(self._local_tensor.device, non_blocking=non_blocking).detach().clone()
340 new_local.requires_grad = self._local_tensor.requires_grad
341 super().copy_(new_local)
342 self._local_tensor = new_local
343 else:
344 self._local_tensor.copy_(src, non_blocking=non_blocking)
345 super().copy_(src, non_blocking=non_blocking)
346 return self
348 def fill_(self, value):
349 """Fill tensor with value"""
350 if self._local_tensor.requires_grad and self._local_tensor.is_leaf:
351 # Step 1: Create new tensor (non-in-place)
352 new_local = torch.full_like(
353 self._local_tensor,
354 fill_value=value,
355 requires_grad=True,
356 device=self._local_tensor.device
357 )
358 # Step 2: Sync DTensor wrapper's underlying storage to new local_tensor
359 super().copy_(new_local) # Key: make DTensor wrapper point to new address
360 # Step 3: Replace internal local_tensor (ensure attribute consistency)
361 self._local_tensor = new_local
362 else:
363 # Non-leaf tensor: direct in-place fill + sync wrapper
364 self._local_tensor.fill_(value)
365 super().fill_(value) # sync DTensor wrapper's fill
366 return self
368 # ====================== Auxiliary print ======================
369 def __repr__(self) -> str:
370 return (
371 f"DTensor(\n"
372 f" local_tensor={self._local_tensor},\n"
373 f" device_mesh={self._device_mesh},\n"
374 f" placements={self._placements},\n"
375 f" layout={getattr(self, '_layout', None)},\n"
376 f" device={self.device},\n"
377 f" dtype={self.dtype},\n"
378 f" requires_grad={self.requires_grad},\n"
379 f" grad={self.grad},\n"
380 f" is_leaf={self.is_leaf},\n"
381 f" data_ptr={self.data_ptr()}\n"
382 f")"
383 )