# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dtensor"""
import copy as cp
import inspect
import warnings
from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union

import numpy as np

from hyper_parallel.core.dtensor.device_mesh import _mesh_resources
from hyper_parallel.core.dtensor.layout import Layout, DeviceMesh, _get_slice_tensor_by_layout
from hyper_parallel.core.dtensor.placement_types import Placement, Replicate
from hyper_parallel.platform import get_platform
from hyper_parallel.platform.platform import PlatformType
from hyper_parallel.core.utils import compute_local_shape_and_global_offset

platform = get_platform()
DTensorBase = platform.DTensorBase
Tensor = platform.Tensor


class SkipDTensorDispatch:
    """Context manager that disables DTensor op dispatch for the enclosed block.

    Args:
        no_skip: Optional set of op callables or canonical op name strings that
            should still be dispatched through DTensor even within this context.
            All other ops bypass DTensor dispatch and operate on local tensors.

    Example:
        >>> import torch
        >>> with SkipDTensorDispatch(no_skip={torch.zeros_like}):
        ...     # zeros_like still goes through DTensor dispatch;
        ...     # everything else uses the local tensor path.
        ...     result = torch.zeros_like(dtensor)
    """

    def __init__(self, no_skip: Optional[Set] = None):
        self._no_skip_names: Set[str] = set()
        if no_skip:
            for op in no_skip:
                if isinstance(op, str):
                    self._no_skip_names.add(op)
                else:
                    self._no_skip_names.add(platform.get_op_name(op))

    def __enter__(self):
        # pylint: disable=C0415
        from hyper_parallel.core.shard._op_dispatch import disable_dtensor_dispatch, add_no_skip_ops
        disable_dtensor_dispatch()
        if self._no_skip_names:
            add_no_skip_ops(self._no_skip_names)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # pylint: disable=C0415
        from hyper_parallel.core.shard._op_dispatch import enable_dtensor_dispatch, remove_no_skip_ops
        if self._no_skip_names:
            remove_no_skip_ops(self._no_skip_names)
        enable_dtensor_dispatch()
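

# Ops in ``no_skip`` may also be given as canonical name strings, which are
# added verbatim (callables go through ``platform.get_op_name``). A minimal
# sketch; the exact name format is platform-specific and the string below is
# illustrative only:
#
#     with SkipDTensorDispatch(no_skip={"aten.zeros_like.default"}):
#         out = torch.zeros_like(dtensor)  # still dispatched through DTensor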


# Cache for _build_layout to avoid redundant Layout computations.
# Key: (device_mesh.to_hash(), tuple(placements), tensor_dim)
# Value: Layout
_LAYOUT_CACHE = {}


def _is_alias_placements(placements) -> bool:
    """
    Check whether placements use alias strings rather than Placement objects.

    Alias placements use mesh dimension names (strings) to specify
    the sharding strategy, e.g., ("dp", "tp") or (("dp", "tp"), "None").
    All elements must be strings or non-empty tuples of strings for the
    sequence to be recognized as alias-style.

    Args:
        placements: A sequence of placement specifications.

    Returns:
        bool: True if all elements are alias strings or tuples of strings.
    """
    if len(placements) == 0:
        return False
    for p in placements:
        if isinstance(p, str):
            continue
        if isinstance(p, tuple) and len(p) > 0 and all(isinstance(x, str) for x in p):
            continue
        return False
    return True
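
# A quick illustration of the alias check (placements are hypothetical;
# ``Shard`` lives in placement_types):
#
#     _is_alias_placements(("dp", "None"))           # True: all strings
#     _is_alias_placements((("dp", "tp"), "None"))   # True: tuple of strings + string
#     _is_alias_placements([Shard(0), Replicate()])  # False: Placement objects
#     _is_alias_placements(["dp", Replicate()])      # False: mixed styles
#     _is_alias_placements(())                       # False: empty sequence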


def _build_layout(
    device_mesh: DeviceMesh,
    placements: Union[Sequence[Placement], Sequence[Union[str, Tuple[str, ...]]]],
    tensor_dim: int
) -> Layout:
    """
    Build a Layout from device_mesh and placements.

    This function uses a cache to avoid redundant Layout computations
    for the same (device_mesh, placements, tensor_dim) combination.

    Args:
        device_mesh: The device mesh describing the device topology.
        placements: Supports two styles:
            - Placement objects (Shard, Replicate, etc.)
            - Alias strings ("dp", "None", ("dp", "tp"), etc.); the length
              must equal the number of tensor dimensions (``tensor_dim``).
        tensor_dim: Number of dimensions in the tensor.

    Returns:
        Layout: The built layout object.

    Raises:
        ValueError: If the alias placements length does not match the tensor dimensions.
    """
    mesh_key = device_mesh.to_hash()
    placements_key = tuple(placements)
    cache_key = (mesh_key, placements_key, tensor_dim)

    if cache_key in _LAYOUT_CACHE:
        return _LAYOUT_CACHE[cache_key]

    layout = Layout.from_device_mesh(device_mesh)

    if _is_alias_placements(placements):
        if len(placements) != tensor_dim:
            raise ValueError(
                f"Alias placements length ({len(placements)}) must equal "
                f"tensor dimensions ({tensor_dim})."
            )
        result = layout(*placements)
    else:
        result = layout(placements)
        result.placement_to_tensor_map(tensor_dim)

    _LAYOUT_CACHE[cache_key] = result

    return result
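
# Layout construction is memoized. Repeated calls with an identical
# (mesh hash, placements, tensor rank) triple return the cached object,
# as in this sketch (``mesh`` is assumed to be a 2x2 ("dp", "tp") mesh):
#
#     l1 = _build_layout(mesh, ("dp", "None"), tensor_dim=2)
#     l2 = _build_layout(mesh, ("dp", "None"), tensor_dim=2)
#     assert l1 is l2  # served from _LAYOUT_CACHE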


class DTensor(DTensorBase):
    """
    DTensor - Distributed Tensor.

    A DTensor represents a tensor that is distributed across multiple devices
    according to a DeviceMesh and placement specifications.

    Args:
        local_tensor (Tensor): The local tensor shard on this device.
        device_mesh (DeviceMesh): The device mesh describing the device topology.
        placements: The placement strategy. Supports two styles:
            - Placement objects (e.g., ``[Shard(0), Replicate()]``).
            - Alias strings (e.g., ``("dp", "None")`` or
              ``(("dp", "tp"), "None")``); the length must equal the number of
              tensor dimensions.

    Example:
        >>> mesh = init_device_mesh(device_type="npu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))
        >>> local_tensor = Tensor(np.ones((4, 4)))
        >>> # Placement style
        >>> dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0), Replicate()])
        >>> # Alias style; length matches tensor dims
        >>> dtensor = DTensor.from_local(local_tensor, mesh, ("dp", "None"))
    """
    _local_tensor: Tensor
    _device_mesh: DeviceMesh
    _placements: Sequence[Placement]

    def __init_data__(
        self,
        local_tensor: Tensor,
        device_mesh: DeviceMesh,
        placements: Union[Sequence[Placement], Sequence[Union[str, Tuple[str, ...]]]]
    ):
        self._local_tensor = local_tensor
        self._device_mesh = device_mesh
        self._layout = _build_layout(
            device_mesh, placements, len(local_tensor.shape)
        )
        self._placements = tuple(self._layout.placements)

    @property
    def device_mesh(self) -> DeviceMesh:
        """The device mesh of this DTensor."""
        return self._device_mesh

    @property
    def placements(self) -> Sequence[Placement]:
        """The placements of this DTensor."""
        return self._placements

    @property
    def layout(self) -> Layout:
        """Internal layout for redistribution (kept for backward compatibility)."""
        if not hasattr(self, '_layout'):
            return None
        return self._layout

    @staticmethod
    def from_local(
        local_tensor: Tensor,
        device_mesh: DeviceMesh,
        placements: Union[Sequence[Placement], Sequence[Union[str, Tuple[str, ...]]]]
    ) -> 'DTensor':
        """
        Create a DTensor from a local tensor with device mesh and placements.

        Args:
            local_tensor (Tensor): The local tensor shard on this device.
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements: The placement strategy. Supports two styles:
                - Placement objects (e.g., ``[Shard(0), Replicate()]``).
                - Alias strings (e.g., ``("dp", "None")`` or
                  ``(("dp", "tp"), "None")``); the length must equal the
                  number of tensor dimensions.

        Returns:
            DTensor: A new DTensor instance.

        Example:
            >>> mesh = init_device_mesh(device_type="npu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))
            >>> local_tensor = Tensor(np.ones((4, 4)))
            >>> dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0), Replicate()])
            >>> dtensor = DTensor.from_local(local_tensor, mesh, ("dp", "None"))
        """
        return DTensor(local_tensor, device_mesh, placements)

    def _alias_placements(self) -> Sequence[Placement]:
        """Return alias_placements from the layout, falling back to _placements."""
        if hasattr(self, '_layout') and self._layout:
            return self._layout.alias_placements
        return self._placements

    def to(self, *args, **kwargs):
        """Move the DTensor to a different device or dtype.

        Delegates to the underlying local tensor's ``to`` method and
        reconstructs a DTensor preserving device_mesh and placements.

        Args:
            *args (tuple): Arguments passed to the underlying tensor's ``to``
                method (e.g., device or dtype).
            **kwargs (dict): Keyword arguments for the tensor conversion
                (e.g., dtype, device, non_blocking).

        Returns:
            DTensor: A new DTensor with the converted local tensor.
        """
        new_local = self._local_tensor.to(*args, **kwargs)
        return self.__class__(new_local, device_mesh=self._device_mesh,
                              placements=self._alias_placements())

    def float(self):
        """Convert the DTensor to float dtype.

        Returns:
            DTensor: A new DTensor with a float32 local tensor.
        """
        new_local = self._local_tensor.float()
        return self.__class__(new_local, device_mesh=self._device_mesh,
                              placements=self._alias_placements())

    def to_local(self) -> Tensor:
        """
        Convert DTensor to a local tensor.

        Returns:
            Tensor: The local tensor shard on this device.
        """
        return self._local_tensor

    @property
    def shape(self) -> Tuple[int, ...]:
        """
        The global shape of this DTensor.

        Returns:
            Tuple[int, ...]: The global tensor shape.
        """
        return self._layout.get_global_shape(self._local_tensor.shape)

    def size(self, dim=None):
        """Return the global shape, consistent with .shape.

        Without ``dim``, returns a tuple matching ``self.shape``.
        With ``dim``, returns the size of that dimension.
        """
        global_shape = self.shape
        if dim is not None:
            return global_shape[dim]
        return global_shape

    def numel(self) -> int:
        """Return the number of elements in this DTensor."""
        return int(np.prod(self.shape))

    @property
    def local_shape(self) -> Tuple[int, ...]:
        """
        The local shape of this DTensor on this device.

        Returns:
            Tuple[int, ...]: The local tensor shape.
        """
        return self._local_tensor.shape
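
    # Global vs. local shape in a sketch: with a (4, 4) global tensor sharded
    # on dim 0 across a 2-way mesh axis, each rank would see (values assume
    # that hypothetical setup):
    #
    #     dtensor.shape        # (4, 4) -- global shape
    #     dtensor.local_shape  # (2, 4) -- this rank's shard
    #     dtensor.numel()      # 16     -- counts global elements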

    def redistribute(
        self,
        device_mesh: DeviceMesh,
        placements: Union[Sequence[Placement], Sequence[Union[str, Tuple[str, ...]]]]
    ) -> 'DTensor':
        """
        Redistribute this DTensor to a new device mesh and placements.

        Args:
            device_mesh (DeviceMesh): The target device mesh.
            placements: The target placements. Supports Placement objects
                or alias strings.

        Returns:
            DTensor: A new DTensor with the specified distribution.

        Example:
            >>> new_dtensor = dtensor.redistribute(mesh, [Replicate(), Shard(1)])
            >>> new_dtensor = dtensor.redistribute(mesh, ("None", "tp"))
        """
        # Build dst_layout from device_mesh and placements
        dst_layout = _build_layout(
            device_mesh, placements, len(self._local_tensor.shape)
        )

        # pylint: disable=C0415
        from hyper_parallel.core.dtensor.tensor_redistribution import _tensor_redistribution
        out = _tensor_redistribution.redistribution(self, dst_layout)
        return out

    def reduce_partial(self) -> 'DTensor':
        """
        Reduce the partial sharding state of this DTensor.

        Returns:
            DTensor: A new DTensor with the partial state reduced.
        """
        if not self._layout:
            return self
        to_layout = cp.deepcopy(self._layout)
        to_layout.reset_partial()
        # pylint: disable=C0415
        from hyper_parallel.core.dtensor.tensor_redistribution import _tensor_redistribution
        out = _tensor_redistribution.reduce_partial(self, to_layout)
        return out

    def full_tensor(self) -> Tensor:
        """
        Return the full tensor of this DTensor.

        Returns:
            Tensor: A Tensor object that represents the full tensor of this DTensor.
                The returned tensor contains the complete data gathered from
                all ranks.

        Note:
            This operation involves communication across all ranks in the DeviceMesh,
            which may be expensive for large tensors. Use with caution in
            performance-critical code paths.

        Example:
            >>> # Assume dtensor is sharded across multiple devices
            >>> local_tensor = dtensor.to_local()    # Returns only the local shard
            >>> full_tensor = dtensor.full_tensor()  # Returns the complete tensor
        """
        if not self._layout:
            return self._local_tensor

        # Create a fully replicated layout
        replicated_layout = cp.deepcopy(self._layout)

        # Set all placements to Replicate and convert to tensor_map
        replicated_placements = [Replicate()] * len(replicated_layout.mesh_shape)
        replicated_layout.set_placements(replicated_placements)
        replicated_layout.placement_to_tensor_map(len(self._local_tensor.shape))

        # Clear partial status from the original layout since Replicate has no partial
        replicated_layout.reset_partial()

        # Redistribute to the replicated layout and return the local tensor
        # pylint: disable=C0415
        from hyper_parallel.core.dtensor.tensor_redistribution import _tensor_redistribution
        out = _tensor_redistribution.redistribution(self, replicated_layout)
        return out.to_local()


def distribute_tensor(
    tensor: Tensor,
    device_mesh: DeviceMesh,
    placements: Union[Sequence[Placement], Sequence[Union[str, Tuple[str, ...]]]]
) -> DTensor:
    """
    Distribute a global tensor to the device mesh according to the placements.

    Args:
        tensor (Tensor): The global tensor to be distributed. All ranks
            should have the same tensor data.
        device_mesh (DeviceMesh): The device mesh describing the device topology.
        placements: The placement strategy. Supports two styles:
            - Placement objects (e.g., ``[Shard(0), Replicate()]``).
            - Alias strings (e.g., ``("dp", "None")`` or
              ``(("dp", "tp"), "None")``); the length must equal the number of
              tensor dimensions.

    Returns:
        DTensor: A new DTensor with the local shard on each rank.

    Note:
        This method assumes all ranks have the same global tensor. It slices
        the tensor locally without communication. If ranks have different
        data, use `from_local` instead.

    Example:
        >>> mesh = init_device_mesh(device_type="npu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))
        >>> global_tensor = Tensor(np.arange(16).reshape(4, 4))
        >>> dtensor = distribute_tensor(global_tensor, mesh, [Shard(0), Replicate()])
        >>> dtensor = distribute_tensor(global_tensor, mesh, ("dp", "None"))
    """
    layout = _build_layout(device_mesh, placements, len(tensor.shape))
    local_tensor = _get_slice_tensor_by_layout(tensor, layout)
    return DTensor(local_tensor, device_mesh, layout.alias_placements)
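
# A sketch of what each rank holds after the docstring example above
# (2x2 ("dp", "tp") mesh, [Shard(0), Replicate()]); the rank assignment is
# illustrative:
#
#     global tensor: shape (4, 4), rows 0..3
#     ranks with dp=0 -> rows 0..1, local shape (2, 4)
#     ranks with dp=1 -> rows 2..3, local shape (2, 4)
#
# No communication happens: every rank slices its own shard out of the same
# global tensor.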


def _distribute_module_param_source(param: Any) -> Tensor:
    """Tensor data used as the global tensor for :func:`distribute_tensor` (PyTorch uses ``param.data``)."""
    if hasattr(param, "data"):
        return param.data
    return platform.get_param_local_data(param)


def _distribute_module_new_parameter(key: str, dtensor: DTensor, requires_grad: bool) -> Any:
    """Build a framework :class:`Parameter` holding *dtensor* (Torch vs MindSpore kwargs differ)."""
    if platform.platform_type == PlatformType.MINDSPORE:
        return platform.Parameter(dtensor, name=key, requires_grad=requires_grad)
    return platform.Parameter(dtensor, requires_grad=requires_grad)


def _distribute_module_set_param(module: Any, key: str, new_param: Any) -> None:
    """Register or assign a parameter on *module* (``nn.Module`` or MindSpore ``Cell``)."""
    if hasattr(module, "register_parameter"):
        module.register_parameter(key, new_param)
        return
    if hasattr(module, "_params"):
        module._params[key] = new_param
        if hasattr(module, "_params_list"):
            module._params_list[key] = new_param
        if key in module.__dict__:
            module.__dict__[key] = new_param
        return
    raise TypeError(
        f"distribute_module expects nn.Module-like objects with register_parameter or _params; "
        f"got {type(module)}."
    )


def _distribute_module_iter_params(module: Any) -> list:
    """Return ``[(name, param), ...]`` for direct parameters (``_parameters`` or ``_params``)."""
    if hasattr(module, "_parameters"):
        return list(module._parameters.items())
    if hasattr(module, "_params"):
        return list(module._params.items())
    return []


def _distribute_module_iter_buffers(module: Any) -> list:
    """Return ``[(name, buffer), ...]`` if the module has ``_buffers`` (PyTorch ``nn.Module``)."""
    if hasattr(module, "_buffers"):
        return list(module._buffers.items())
    return []


def _distribute_module_named_modules(module: Any):
    """``nn.Module.named_modules`` or MindSpore ``Cell.cells_and_names`` (submodule FQNs)."""
    if hasattr(module, "named_modules"):
        return module.named_modules()
    if hasattr(module, "cells_and_names"):
        return module.cells_and_names()
    raise TypeError(
        f"distribute_module expects module-like objects with named_modules or cells_and_names; "
        f"got {type(module)}."
    )


def _replicate_submodule_params_buffers(
    sub_mod: Any,
    device_mesh: DeviceMesh,
    *,
    module_prefix: str = "",
) -> None:
    """Convert plain params/buffers on *sub_mod* to fully replicated :class:`DTensor`."""
    full_replicate = [Replicate()] * device_mesh.ndim
    for key, param in _distribute_module_iter_params(sub_mod):
        if param is None or isinstance(param, DTensorBase):
            continue
        src = _distribute_module_param_source(param)
        requires_grad = bool(getattr(param, "requires_grad", True))
        dt = distribute_tensor(src, device_mesh, full_replicate)
        param_name = f"{module_prefix}.{key}" if module_prefix else key
        new_param = _distribute_module_new_parameter(param_name, dt, requires_grad)
        _distribute_module_set_param(sub_mod, key, new_param)
    for key, buffer in _distribute_module_iter_buffers(sub_mod):
        if buffer is None or isinstance(buffer, DTensorBase):
            continue
        sub_mod._buffers[key] = distribute_tensor(buffer, device_mesh, full_replicate)


def _distribute_module_run_partition_and_replicate(
    module: Any,
    device_mesh: DeviceMesh,
    partition_fn: Optional[Callable[[str, Any, DeviceMesh], None]],
) -> None:
    """Call the optional ``partition_fn`` per ``named_modules`` entry and replicate remaining tensors."""
    if partition_fn is None:
        for mod_name, submod in _distribute_module_named_modules(module):
            _replicate_submodule_params_buffers(submod, device_mesh, module_prefix=mod_name)
        return
    for mod_name, submod in _distribute_module_named_modules(module):
        partition_fn(mod_name, submod, device_mesh)
        _replicate_submodule_params_buffers(submod, device_mesh, module_prefix=mod_name)


def _distribute_module_register_input_fn(
    module: Any,
    device_mesh: DeviceMesh,
    input_fn: Callable[..., Any],
) -> None:
    """Register *input_fn* as a forward pre-hook on *module* (2- or 3-arg, PyTorch-compatible)."""
    num_args = len(inspect.signature(input_fn).parameters)
    if num_args == 2:
        warnings.warn(
            "Deprecating input_fn that takes two arguments (inputs, device_mesh), "
            "please use input_fn that takes in (module, inputs, device_mesh) instead!",
            FutureWarning,
            stacklevel=3,
        )
        module.register_forward_pre_hook(
            lambda _, inputs: input_fn(inputs, device_mesh)
        )
    elif num_args == 3:
        module.register_forward_pre_hook(
            lambda mod, inputs: input_fn(mod, inputs, device_mesh)
        )
    else:
        raise ValueError(
            f"input_fn should take in 2 or 3 arguments, but got {num_args} arguments!"
        )


def _distribute_module_register_output_fn(
    module: Any,
    device_mesh: DeviceMesh,
    output_fn: Callable[..., Any],
) -> None:
    """Register *output_fn* as a forward hook on *module* (2- or 3-arg, PyTorch-compatible)."""
    num_args = len(inspect.signature(output_fn).parameters)
    if num_args == 2:
        warnings.warn(
            "Deprecating output_fn that takes two arguments (outputs, device_mesh), "
            "please use output_fn that takes in (module, outputs, device_mesh) instead!",
            FutureWarning,
            stacklevel=3,
        )
        module.register_forward_hook(
            lambda mod, inputs, outputs: output_fn(outputs, device_mesh)
        )
    elif num_args == 3:
        module.register_forward_hook(
            lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
        )
    else:
        raise ValueError(
            f"output_fn should take in 2 or 3 arguments, but got {num_args} arguments!"
        )


def distribute_module(
    module: Any,
    device_mesh: Optional[DeviceMesh] = None,
    partition_fn: Optional[Callable[[str, Any, DeviceMesh], None]] = None,
    input_fn: Optional[Callable[..., Any]] = None,
    output_fn: Optional[Callable[..., Any]] = None,
) -> Any:
    """Shard or replicate module parameters, with optional input/output hooks
    (parity with PyTorch ``distribute_module``).

    Parameters and buffers left unsharded by ``partition_fn`` become fully
    replicated :class:`DTensor`. ``input_fn`` / ``output_fn`` attach only to
    the root *module*.

    Args:
        module: Root ``nn.Module`` or MindSpore ``Cell`` with compatible APIs.
        device_mesh: Placement mesh; if ``None``, uses
            ``_mesh_resources.get_current_mesh()``.
        partition_fn: Callback invoked per ``named_modules`` entry before the
            replicate pass; ``None`` replicates everything.
        input_fn: ``(module, inputs, mesh)`` or deprecated ``(inputs, mesh)``
            forward pre-hook.
        output_fn: ``(module, outputs, mesh)`` or deprecated ``(outputs, mesh)``
            forward hook.

    Returns:
        *module*, modified in place, with distributed tensors where applied.

    Raises:
        RuntimeError: If called twice on the same *module*.
        ValueError: If ``input_fn`` / ``output_fn`` arity is not 2 or 3.

    Note:
        XLA / ``torch_xla`` is not supported; only strided-device
        :class:`DTensor` is produced.
    """
    if getattr(module, "_distribute_module_applied", False):
        raise RuntimeError(
            "distribute_module should only be called once on a module, "
            "but it has already been called on this module!"
        )
    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    _distribute_module_run_partition_and_replicate(module, device_mesh, partition_fn)
    if input_fn is not None:
        _distribute_module_register_input_fn(module, device_mesh, input_fn)
    if output_fn is not None:
        _distribute_module_register_output_fn(module, device_mesh, output_fn)
    module._distribute_module_applied = True
    return module
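
# A minimal usage sketch, assuming a PyTorch-style module. ``MyModel``, the
# mesh, and the sharding choice are illustrative, not part of this module's
# API (``Shard`` lives in placement_types):
#
#     def shard_linear(name, submod, mesh):
#         # Shard linear weights on dim 0; the replicate pass afterwards
#         # converts every remaining plain parameter/buffer to Replicate.
#         if hasattr(submod, "weight") and submod.weight is not None:
#             dt = distribute_tensor(submod.weight.data, mesh, [Shard(0)])
#             submod.register_parameter(
#                 "weight", platform.Parameter(dt, requires_grad=True))
#
#     model = distribute_module(MyModel(), mesh, partition_fn=shard_linear)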


def _dtensor_init_helper(
    init_op,
    size,
    device_mesh,
    placements,
    **kwargs,
) -> DTensor:
    """
    Helper function to create and initialize a distributed tensor.

    Args:
        init_op: Platform initializer (e.g., ``platform.ones`` or ``platform.full``).
        size: Global shape of the tensor.
        device_mesh: The device mesh describing the device topology.
        placements: The placement strategy for the new tensor.
        **kwargs: Extra arguments forwarded to ``init_op`` (e.g., ``fill_value``).

    Returns:
        DTensor: The initialized distributed tensor.
    """
    # get the local tensor shape for this rank
    local_shape = compute_local_shape_and_global_offset(
        size, device_mesh, placements
    )

    # initialize the local tensor
    if init_op is platform.full:
        fill_value = kwargs.pop("fill_value", 0)
        local_tensor = init_op(local_shape, fill_value, **kwargs)
    else:
        local_tensor = init_op(local_shape, **kwargs)

    return DTensor.from_local(
        local_tensor,
        device_mesh,
        placements,
    )


def ones(
    size,
    device_mesh,
    placements,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined
    by the argument ``size``.

    Args:
        size (Union[tuple[int], list[int], int, Tensor]): The specified shape of the
            output tensor. Only a positive integer, or a tuple or Tensor containing
            positive integers, is allowed. If it is a Tensor, it must be a 0-D or
            1-D Tensor with int32 or int64 dtype.
        device_mesh (DeviceMesh): Contains the mesh info of ranks.
        placements: A sequence of :class:`Placement` type: ``Shard``, ``Replicate``.

    Returns:
        A :class:`DTensor` object on each rank.
    """
    ones_ = platform.ones
    return _dtensor_init_helper(
        ones_,
        size,
        device_mesh=device_mesh,
        placements=placements,
    )


def empty(
    size,
    device_mesh,
    placements,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with uninitialized data. The shape of the
    :class:`DTensor` is defined by the argument ``size``.

    Args:
        size (Union[tuple[int], list[int], int]): The specified shape of the output
            tensor: a positive integer, or a tuple or list containing positive
            integers.
        device_mesh (DeviceMesh): Contains the mesh info of ranks.
        placements: A sequence of :class:`Placement` type: ``Shard``, ``Replicate``.

    Returns:
        A :class:`DTensor` object on each rank.
    """
    empty_ = platform.empty
    return _dtensor_init_helper(
        empty_,
        size,
        device_mesh=device_mesh,
        placements=placements,
    )


def full(
    size,
    fill_value,
    *,
    device_mesh,
    placements,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh``
    and ``placements``, with the shape defined by the argument ``size``.

    Args:
        size (Union[tuple[int], list[int]]): The specified shape of the output tensor.
        fill_value (Union[numbers.Number, Tensor]): Value to fill the returned tensor.
            It can be a scalar number, a 0-D Tensor, or a 1-D Tensor with only one
            element.

    Keyword args:
        device_mesh (DeviceMesh): Contains the mesh info of ranks.
        placements: A sequence of :class:`Placement` type: ``Shard``, ``Replicate``.

    Returns:
        A :class:`DTensor` object on each rank.
    """
    full_ = platform.full
    return _dtensor_init_helper(
        full_,
        size,
        fill_value=fill_value,
        device_mesh=device_mesh,
        placements=placements,
    )


def zeros(
    size,
    device_mesh,
    placements,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with the scalar value 0.

    Args:
        size (Union[tuple[int], list[int], int, Tensor]): The specified shape of the
            output tensor. Only a positive integer, or a tuple or Tensor containing
            positive integers, is allowed. If it is a Tensor, it must be a 0-D or
            1-D Tensor with int32 or int64 dtype.
        device_mesh (DeviceMesh): Contains the mesh info of ranks.
        placements: A sequence of :class:`Placement` type: ``Shard``, ``Replicate``.

    Returns:
        A :class:`DTensor` object on each rank.
    """
    zeros_ = platform.zeros
    return _dtensor_init_helper(
        zeros_,
        size,
        device_mesh=device_mesh,
        placements=placements,
    )
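

# A usage sketch for the factory functions; the mesh and placements are
# illustrative (``Shard`` lives in placement_types):
#
#     dt_ones = ones((4, 4), device_mesh=mesh, placements=[Shard(0)])
#     dt_full = full((4, 4), 3.14, device_mesh=mesh, placements=[Replicate()])
#     dt_ones.shape        # (4, 4): the global shape
#     dt_ones.local_shape  # e.g. (2, 4) per rank on a 2-way sharded axis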