Coverage for hyper_parallel/platform/torch/dtensor.py: 62% (118 statements)
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""torch dtensor base"""
16from typing import Tuple, Dict, Any, Optional
17import torch
18from torch import Tensor


class DTensorBase(Tensor):
    """torch dtensor base"""

    def __new__(cls, local_tensor, device_mesh=None, placements=None):
        """
        Create a new DTensorBase instance.

        Args:
            local_tensor: The local tensor shard or another DTensorBase instance.
            device_mesh: The device mesh describing the device topology.
            placements: The placement strategy for each mesh dimension.
        """
        if isinstance(local_tensor, DTensorBase):
            # Copy from existing DTensorBase
            t = Tensor._make_subclass(cls, local_tensor._local_tensor, local_tensor._local_tensor.requires_grad)
            t.__init_data__(local_tensor._local_tensor, local_tensor.device_mesh, local_tensor.placements)
            return t

        if device_mesh is None:
            raise ValueError("device_mesh is None, must provide a DeviceMesh instance")
        if placements is None:
            raise ValueError("placements is None, must provide placements")

        # Create Tensor subclass instance, sharing local_tensor's underlying storage
        t = Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)
        t.__init_data__(local_tensor, device_mesh, placements)
        return t
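
    # Minimal construction sketch (hedged: the `mesh` and `placements` objects are
    # assumed to come from the surrounding hyper_parallel shard API and are
    # illustrative only, not defined in this file):
    #
    #     local = torch.randn(4, 8, requires_grad=True)    # this rank's shard
    #     dt = DTensorBase(local, device_mesh=mesh, placements=placements)
    #     dt_copy = DTensorBase(dt)                         # re-wraps the same local shard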

    # pylint: disable=W0613
    @classmethod
    def __torch_function__(
            cls,
            func: torch._C._FunctionBase,
            types: Tuple[type, ...],
            args: Tuple[Any, ...] = (),
            kwargs: Optional[Dict[str, Any]] = None
    ) -> Any:
        kwargs = kwargs or {}
        # pylint: disable=C0415
        from hyper_parallel.core.shard._op_dispatch import _OP_DISPATCHER
        out = _OP_DISPATCHER.dispatch(func, args, kwargs)
        return out
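
    # For example, torch.matmul(dt_a, dt_b) with DTensorBase operands does not run
    # the dense matmul directly; __torch_function__ intercepts the call and forwards
    # it as _OP_DISPATCHER.dispatch(torch.matmul, (dt_a, dt_b), {}), leaving the
    # sharding-aware execution to the dispatcher (dt_a/dt_b are illustrative names,
    # not defined in this file).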

    def to(self, *args, **kwargs):
        """Move the DTensor to a different device or dtype."""
        src_local = self._local_tensor
        new_local = src_local.to(*args, **kwargs)
        return self.__class__(new_local, device_mesh=self._device_mesh, placements=self._placements)

    @property
    def grad(self) -> Optional[Tensor]:
        return self._local_tensor.grad

    @grad.setter
    def grad(self, value: Optional[Tensor]) -> None:
        self._local_tensor.grad = value

    @property
    def requires_grad(self) -> bool:
        return self._local_tensor.requires_grad

    @requires_grad.setter
    def requires_grad(self, value: bool) -> None:
        self._local_tensor.requires_grad_(value)
        # Sync DTensor wrapper's requires_grad
        super().requires_grad_(value)

    def requires_grad_(self, requires_grad: bool = True):
        self._local_tensor.requires_grad_(requires_grad)
        super().requires_grad_(requires_grad)
        return self

    @property
    def grad_fn(self) -> Optional[torch.autograd.Function]:
        return self._local_tensor.grad_fn

    def grad_zero_(self):
        if self._local_tensor.grad is not None:
            self._local_tensor.grad.zero_()
        return self

    def detach(self):
        detached_local = self._local_tensor.detach()
        return self.__class__(detached_local, device_mesh=self._device_mesh, placements=self._placements)

    def detach_(self):
        self._local_tensor.detach_()
        super().detach_()
        return self

    # ====================== Computation graph related overrides ======================
    @property
    def is_leaf(self) -> bool:
        return self._local_tensor.is_leaf

    @property
    def retains_grad(self) -> bool:
        return self._local_tensor.retains_grad

    @retains_grad.setter
    def retains_grad(self, value: bool) -> None:
        # torch.Tensor only exposes retain_grad(); retention cannot be switched off once enabled
        if value:
            self._local_tensor.retain_grad()

    def backward(self, gradient=None, retain_graph=None, create_graph=False) -> None:
        self._local_tensor.backward(gradient, retain_graph, create_graph)

    # ====================== Metadata related overrides (sync with local_tensor) ======================
    @property
    def device(self) -> torch.device:
        return self._local_tensor.device

    @property
    def dtype(self) -> torch.dtype:
        return self._local_tensor.dtype

    @property
    def shape(self) -> torch.Size:
        return self._local_tensor.shape

    def type(self, dtype=None, non_blocking=False):
        if dtype is None:
            return self._local_tensor.type()
        new_local = self._local_tensor.to(dtype=dtype, non_blocking=non_blocking)
        return self.__class__(new_local, device_mesh=self._device_mesh, placements=self._placements)

    def size(self, dim: Optional[int] = None):
        if dim is None:
            return self._local_tensor.size()
        return self._local_tensor.size(dim)

    @property
    def ndim(self) -> int:
        return self._local_tensor.ndim

    def data_ptr(self) -> int:
        # Force return local_tensor's data pointer (ensure address consistency)
        return self._local_tensor.data_ptr()

    def numel(self) -> int:
        return self._local_tensor.numel()

    # ====================== Data operation overrides (sync storage + fix in-place ops) ======================
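    # Rationale: PyTorch raises "a leaf Variable that requires grad is being used in
    # an in-place operation" when a leaf tensor with requires_grad=True is mutated in
    # place. zero_/copy_/fill_ below therefore create a fresh local tensor for that
    # case and rebind it, instead of mutating the leaf directly.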

    def zero_(self):
        """Zero out the tensor in place."""
        if self._local_tensor.requires_grad and self._local_tensor.is_leaf:
            # Create new tensor + rebind DTensor (ensure storage sharing)
            new_local = torch.zeros_like(self._local_tensor, requires_grad=True)
            # Key: sync DTensor wrapper's storage to new local_tensor
            super().copy_(new_local)  # sync underlying data
            self._local_tensor = new_local  # replace internal attribute
        else:
            self._local_tensor.zero_()
            super().zero_()  # sync wrapper's in-place zero
        return self

    def copy_(self, src: Tensor, non_blocking: bool = False):
        """Copy data from src tensor"""
        if self._local_tensor.requires_grad and self._local_tensor.is_leaf:
            new_local = src.to(self._local_tensor.device, non_blocking=non_blocking).detach().clone()
            new_local.requires_grad = self._local_tensor.requires_grad
            super().copy_(new_local)
            self._local_tensor = new_local
        else:
            self._local_tensor.copy_(src, non_blocking=non_blocking)
            super().copy_(src, non_blocking=non_blocking)
        return self

    def fill_(self, value):
        """Fill tensor with value"""
        if self._local_tensor.requires_grad and self._local_tensor.is_leaf:
            # Step 1: Create new tensor (non-in-place)
            new_local = torch.full_like(
                self._local_tensor,
                fill_value=value,
                requires_grad=True,
                device=self._local_tensor.device
            )
            # Step 2: Sync DTensor wrapper's underlying storage to new local_tensor
            super().copy_(new_local)  # Key: make DTensor wrapper point to new address
            # Step 3: Replace internal local_tensor (ensure attribute consistency)
            self._local_tensor = new_local
        else:
            # Non-leaf tensor: direct in-place fill + sync wrapper
            self._local_tensor.fill_(value)
            super().fill_(value)  # sync DTensor wrapper's fill
        return self

    # ====================== Auxiliary print ======================
    def __repr__(self) -> str:
        return (
            f"DTensor(\n"
            f" local_tensor={self._local_tensor},\n"
            f" device_mesh={self._device_mesh},\n"
            f" placements={self._placements},\n"
            f" layout={getattr(self, '_layout', None)},\n"
            f" device={self.device},\n"
            f" dtype={self.dtype},\n"
            f" requires_grad={self.requires_grad},\n"
            f" grad={self.grad},\n"
            f" is_leaf={self.is_leaf},\n"
            f" data_ptr={self.data_ptr()}\n"
            f")"
        )
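

# Hedged usage sketch of the delegation behaviour above (the `mesh` and `placements`
# objects are assumptions about the surrounding hyper_parallel API, and the op result
# depends on _OP_DISPATCHER; nothing below is defined in this file):
#
#     local = torch.randn(4, 8, requires_grad=True)     # per-rank shard
#     dt = DTensorBase(local, device_mesh=mesh, placements=placements)
#     out = torch.sum(dt)      # intercepted by __torch_function__, handled by _OP_DISPATCHER
#     dt.requires_grad         # reads the wrapped local tensor's flag
#     dt.detach()              # returns a new DTensorBase wrapping local.detach()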