# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DTensor: distributed tensor type and factory functions."""
import copy as cp
from typing import Sequence, Tuple
import numpy as np
from hyper_parallel.core.layout import Layout, DeviceMesh, _get_slice_tensor_by_layout
from hyper_parallel.core.placement_types import Placement, Replicate
from hyper_parallel.platform import get_platform
from hyper_parallel.core.utils import compute_local_shape_and_global_offset

platform = get_platform()
DTensorBase = platform.DTensorBase
Tensor = platform.Tensor


class SkipDTensorDispatch:
    """Context manager that temporarily disables DTensor operator dispatch."""

    def __enter__(self):
        # pylint: disable=C0415
        from hyper_parallel.core.shard._op_dispatch import disable_dtensor_dispatch
        disable_dtensor_dispatch()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # pylint: disable=C0415
        from hyper_parallel.core.shard._op_dispatch import enable_dtensor_dispatch
        enable_dtensor_dispatch()
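
# Illustrative usage (a sketch, not executed here; `some_op` and `dtensor` are
# assumed names): run an op directly on the raw local shard, bypassing
# DTensor dispatch for the duration of the block:
#
#     with SkipDTensorDispatch():
#         local_out = some_op(dtensor.to_local())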


def _build_layout(
    device_mesh: DeviceMesh,
    placements: Sequence[Placement],
    tensor_dim: int
) -> Layout:
    """
    Build Layout from device_mesh and placements.

    Args:
        device_mesh: The device mesh describing the device topology.
        placements: Sequence of Placement objects (Shard, Replicate, etc.).
        tensor_dim: Number of dimensions in the tensor.

    Returns:
        Layout: The built layout object.
    """
    layout = Layout.from_device_mesh(device_mesh)
    result = layout(placements)
    result.placement_to_tensor_map(tensor_dim)
    return result


class DTensor(DTensorBase):
    """
    DTensor - Distributed Tensor

    A DTensor represents a tensor that is distributed across multiple devices
    according to a DeviceMesh and placement specifications.

    Args:
        local_tensor (Tensor): The local tensor shard on this device.
        device_mesh (DeviceMesh): The device mesh describing the device topology.
        placements (Sequence[Placement]): The placement strategy for each mesh dimension.
            Each element should be a Placement object (Shard, Replicate, Partial, etc.).

    Example:
        >>> from hyper_parallel.core.placement_types import Shard, Replicate
        >>> mesh = init_device_mesh(device_type="npu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))
        >>> local_tensor = Tensor(np.ones((4, 4)))
        >>> dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0), Replicate()])
    """
    _local_tensor: Tensor
    _device_mesh: DeviceMesh
    _placements: Sequence[Placement]

    def __init_data__(
        self,
        local_tensor: Tensor,
        device_mesh: DeviceMesh,
        placements: Sequence[Placement]
    ):
        self._local_tensor = local_tensor
        self._device_mesh = device_mesh
        self._placements = tuple(placements)
        # Build internal layout for redistribution operations
        self._layout = _build_layout(
            device_mesh, placements, len(local_tensor.shape)
        )

    @property
    def device_mesh(self) -> DeviceMesh:
        """The device mesh of this DTensor."""
        return self._device_mesh

    @property
    def placements(self) -> Sequence[Placement]:
        """The placements of this DTensor."""
        return self._placements

    @property
    def layout(self) -> Layout:
        """Internal layout for redistribution (for backward compatibility)."""
        if not hasattr(self, '_layout'):
            return None
        return self._layout

    @staticmethod
    def from_local(
        local_tensor: Tensor,
        device_mesh: DeviceMesh,
        placements: Sequence[Placement]
    ) -> 'DTensor':
        """
        Create a DTensor from a local tensor with device mesh and placements.

        Args:
            local_tensor (Tensor): The local tensor shard on this device.
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements (Sequence[Placement]): The placement strategy. Each element
                should be a Placement object (Shard, Replicate, Partial, etc.).

        Returns:
            DTensor: A new DTensor instance.

        Example:
            >>> from hyper_parallel.core.placement_types import Shard, Replicate
            >>> mesh = init_device_mesh(device_type="npu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))
            >>> local_tensor = Tensor(np.ones((4, 4)))
            >>> dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0), Replicate()])
        """
        return DTensor(local_tensor, device_mesh, placements)

    def to_local(self) -> Tensor:
        """
        Convert DTensor to local tensor.

        Returns:
            Tensor: The local tensor shard on this device.
        """
        return self._local_tensor

    @property
    def shape(self) -> Tuple[int, ...]:
        """
        The global shape of this DTensor.

        Returns:
            Tuple[int, ...]: The global tensor shape.
        """
        return self._layout.get_global_shape(self._local_tensor.shape)

    def size(self, dim=None):
        """Return the global shape, consistent with ``self.shape``.

        Without ``dim``, returns the full global shape as a tuple; with
        ``dim``, returns the size of that single dimension.
        """
        global_shape = self.shape
        if dim is not None:
            return global_shape[dim]
        return global_shape
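
    # Illustrative behavior (hypothetical shapes): for a DTensor with global
    # shape (4, 4),
    #     d.size()   ->  (4, 4)
    #     d.size(0)  ->  4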

    def numel(self) -> int:
        """Return the number of elements in this DTensor."""
        return int(np.prod(self.shape))

    @property
    def local_shape(self) -> Tuple[int, ...]:
        """
        The local shape of this DTensor on this device.

        Returns:
            Tuple[int, ...]: The local tensor shape.
        """
        return self._local_tensor.shape

    def redistribute(
        self,
        device_mesh: DeviceMesh,
        placements: Sequence[Placement]
    ) -> 'DTensor':
        """
        Redistribute this DTensor to a new device mesh and placements.

        Args:
            device_mesh (DeviceMesh): The target device mesh.
            placements (Sequence[Placement]): The target placements. Each element
                should be a Placement object (Shard, Replicate, Partial, etc.).

        Returns:
            DTensor: A new DTensor with the specified distribution.

        Example:
            >>> from hyper_parallel.core.placement_types import Shard, Replicate
            >>> new_dtensor = dtensor.redistribute(mesh, [Replicate(), Shard(1)])
        """
        # Build dst_layout from device_mesh and placements
        dst_layout = _build_layout(
            device_mesh, placements, len(self._local_tensor.shape)
        )

        # pylint: disable=C0415
        from hyper_parallel.core.tensor_redistribution import _tensor_redistribution
        out = _tensor_redistribution.redistribution(self, dst_layout)
        return out
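
    # Typical resharding pattern (a sketch; `mesh` and `t` are assumed names,
    # with `mesh` a (2, 2) DeviceMesh): switch a row-sharded tensor to column
    # sharding. This generally involves collective communication across ranks.
    #
    #     d = distribute_tensor(t, mesh, [Shard(0), Replicate()])
    #     d2 = d.redistribute(mesh, [Shard(1), Replicate()])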

    def reduce_partial(self) -> 'DTensor':
        """
        Reduce any partial (pending-reduction) state of this DTensor.

        Returns:
            DTensor: A new DTensor with the partial state reduced.
        """
        if not self._layout:
            return self
        to_layout = cp.deepcopy(self._layout)
        to_layout.reset_partial()
        # pylint: disable=C0415
        from hyper_parallel.core.tensor_redistribution import _tensor_redistribution
        out = _tensor_redistribution.reduce_partial(self, to_layout)
        return out
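
    # Context (a hedged note, based on the usual Partial semantics): a DTensor
    # can carry a partial placement, e.g. after a sharded matmul where each
    # rank holds an unreduced addend of the result. `reduce_partial` clears
    # that state, typically via a cross-rank reduction:
    #
    #     d = d.reduce_partial()  # afterwards, no mesh dim is marked partial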

    def full_tensor(self) -> Tensor:
        """
        Return the full tensor of this DTensor.

        Returns:
            Tensor: A Tensor object that represents the full tensor of this DTensor.
                The returned tensor contains the complete data gathered from
                all ranks.

        Note:
            This operation involves communication across all ranks in the DeviceMesh,
            which may be expensive for large tensors. Use with caution in
            performance-critical code paths.

        Example:
            >>> # Assume dtensor is sharded across multiple devices
            >>> local_tensor = dtensor.to_local()  # Returns only the local shard
            >>> full_tensor = dtensor.full_tensor()  # Returns the complete tensor
        """
        if not self._layout:
            return self._local_tensor

        # Create a fully replicated layout
        replicated_layout = cp.deepcopy(self._layout)

        # Set all placements to Replicate and convert to tensor_map
        replicated_placements = [Replicate()] * len(replicated_layout.mesh_shape)
        replicated_layout.set_placements(replicated_placements)
        replicated_layout.placement_to_tensor_map(len(self._local_tensor.shape))

        # Clear partial status from original layout since Replicate has no partial
        replicated_layout.reset_partial()

        # Redistribute to the replicated layout and return local tensor
        # pylint: disable=C0415
        from hyper_parallel.core.tensor_redistribution import _tensor_redistribution
        out = _tensor_redistribution.redistribution(self, replicated_layout)
        return out.to_local()


def distribute_tensor(
    tensor: Tensor,
    device_mesh: DeviceMesh,
    placements: Sequence[Placement]
) -> DTensor:
    """
    Distribute a global tensor to the device mesh according to the placements.

    Args:
        tensor (Tensor): The global tensor to be distributed. All ranks
            should have the same tensor data.
        device_mesh (DeviceMesh): The device mesh describing the device topology.
        placements (Sequence[Placement]): The placement strategy for distribution.
            Each element should be a Placement object (Shard, Replicate, etc.).

    Returns:
        DTensor: A new DTensor with the local shard on each rank.

    Note:
        This method assumes all ranks have the same global tensor. It slices
        the tensor locally without communication. If ranks have different
        data, use `from_local` instead.

    Example:
        >>> from hyper_parallel.core.placement_types import Shard, Replicate
        >>> mesh = init_device_mesh(device_type="npu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))
        >>> global_tensor = Tensor(np.arange(16).reshape(4, 4))
        >>> dtensor = distribute_tensor(global_tensor, mesh, [Shard(0), Replicate()])
        >>> # ranks 0 and 1 get: [[0,1,2,3], [4,5,6,7]]
        >>> # ranks 2 and 3 get: [[8,9,10,11], [12,13,14,15]]
    """
    # Build layout from device_mesh and placements
    layout = _build_layout(device_mesh, placements, len(tensor.shape))

    # Slice the global tensor to get local shard based on layout
    local_tensor = _get_slice_tensor_by_layout(tensor, layout)

    return DTensor(local_tensor, device_mesh, placements)


def _dtensor_init_helper(
    init_op,
    size,
    device_mesh,
    placements,
    **kwargs,
) -> DTensor:
    """
    Helper function to create and initialize a distributed tensor.

    Args:
        init_op: The platform factory op (e.g. ``platform.ones``) used to
            create the local shard.
        size: Global shape of the tensor.
        device_mesh: The device mesh describing the device topology.
        placements: The placement strategy for each mesh dimension.
        kwargs: Extra arguments forwarded to ``init_op`` (e.g. ``fill_value``).

    Returns:
        DTensor: The initialized distributed tensor.
    """
    # get local tensor shape
    local_shape = compute_local_shape_and_global_offset(
        size, device_mesh, placements
    )

    # initialize the local tensor
    if init_op is platform.full:
        fill_value = kwargs.pop("fill_value", 0)
        local_tensor = init_op(local_shape, fill_value, **kwargs)
    else:
        local_tensor = init_op(local_shape, **kwargs)

    return DTensor.from_local(
        local_tensor,
        device_mesh,
        placements,
    )
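
# Worked shape example (illustrative, assuming a (2, 2) mesh and the usual
# even-split sharding semantics): a global size of (8, 4) with placements
# [Shard(0), Replicate()] splits dim 0 across the 2 devices of mesh dim 0,
# so each rank materializes a local shard of shape (4, 4).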


def ones(
    size,
    *,
    device_mesh,
    placements,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined
    by the argument ``size``.

    Args:
        size (Union[tuple[int], list[int], int, Tensor]): The specified shape of the output tensor. Only a positive
            integer, or a tuple or Tensor containing positive integers, is allowed. If it is a Tensor,
            it must be a 0-D or 1-D Tensor with int32 or int64 dtype.

    Keyword args:
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    ones_ = platform.ones
    return _dtensor_init_helper(
        ones_,
        size,
        device_mesh=device_mesh,
        placements=placements,
    )
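
# Example usage (a sketch; `mesh` is an assumed (2, 2) DeviceMesh and Shard
# comes from hyper_parallel.core.placement_types):
#
#     d = ones((4, 4), device_mesh=mesh, placements=[Shard(0), Replicate()])
#     d.to_local().shape  # (2, 4): dim 0 is split across the 2 devices of mesh dim 0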


def empty(
    size,
    *,
    device_mesh,
    placements,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor`
    is defined by the argument ``size``.

    Args:
        size (Union[tuple[int], list[int], int]): The specified shape of the output tensor: a positive
            integer, or a tuple or list containing positive integers.

    Keyword args:
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    empty_ = platform.empty
    return _dtensor_init_helper(
        empty_,
        size,
        device_mesh=device_mesh,
        placements=placements,
    )
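
# Note (general `empty` semantics, not specific to this backend): the local
# shards hold arbitrary memory contents, so write every element before reading:
#
#     d = empty((4, 4), device_mesh=mesh, placements=[Replicate(), Replicate()])
#     # fill d (e.g. via an in-place op on d.to_local()) before using its values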


def full(
    size,
    fill_value,
    *,
    device_mesh,
    placements,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and
    ``placements``, with the shape defined by the argument ``size``.

    Args:
        size (Union[tuple[int], list[int]]): The specified shape of output tensor.
        fill_value (Union[numbers.Number, Tensor]): Value to fill the returned tensor. It can be a scalar number, a 0-D
            Tensor, or a 1-D Tensor with only one element.

    Keyword args:
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    full_ = platform.full
    return _dtensor_init_helper(
        full_,
        size,
        fill_value=fill_value,
        device_mesh=device_mesh,
        placements=placements,
    )
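
# Example usage (a sketch; `mesh` is an assumed DeviceMesh):
#
#     d = full((2, 2), 7.0, device_mesh=mesh, placements=[Replicate(), Replicate()])
#     d.to_local()  # [[7., 7.], [7., 7.]] on every rank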


def zeros(
    size,
    *,
    device_mesh,
    placements,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with the scalar value 0.

    Args:
        size (Union[tuple[int], list[int], int, Tensor]): The specified shape of the output tensor. Only a positive
            integer, or a tuple or Tensor containing positive integers, is allowed. If it is a Tensor,
            it must be a 0-D or 1-D Tensor with int32 or int64 dtype.

    Keyword args:
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    zeros_ = platform.zeros
    return _dtensor_init_helper(
        zeros_,
        size,
        device_mesh=device_mesh,
        placements=placements,
    )
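
# Example usage (a sketch; `mesh` is an assumed (2, 2) DeviceMesh):
#
#     d = zeros((8, 8), device_mesh=mesh, placements=[Shard(1), Replicate()])
#     d.to_local().shape  # (8, 4): dim 1 is split across the 2 devices of mesh dim 0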