Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/dtensor/dtensor.py: 73%

239 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dtensor"""
import copy as cp
import inspect
import warnings
from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union

import numpy as np

from hyper_parallel.core.dtensor.device_mesh import _mesh_resources
from hyper_parallel.core.dtensor.layout import Layout, DeviceMesh, _get_slice_tensor_by_layout
from hyper_parallel.core.dtensor.placement_types import Placement, Replicate
from hyper_parallel.platform import get_platform
from hyper_parallel.platform.platform import PlatformType
from hyper_parallel.core.utils import compute_local_shape_and_global_offset

platform = get_platform()
DTensorBase = platform.DTensorBase
Tensor = platform.Tensor



class SkipDTensorDispatch:
    """Context manager that disables DTensor op dispatch for the enclosed block.

    Args:
        no_skip: Optional set of op callables or canonical op name strings that
            should still be dispatched through DTensor even within this context.
            All other ops bypass DTensor dispatch and operate on local tensors.

    Example:
        >>> import torch
        >>> with SkipDTensorDispatch(no_skip={torch.zeros_like}):
        ...     # zeros_like still goes through DTensor dispatch;
        ...     # everything else uses the local tensor path.
        ...     result = torch.zeros_like(dtensor)
    """

    def __init__(self, no_skip: Optional[Set] = None):
        self._no_skip_names: Set[str] = set()
        if no_skip:
            for op in no_skip:
                if isinstance(op, str):
                    self._no_skip_names.add(op)
                else:
                    self._no_skip_names.add(platform.get_op_name(op))

    def __enter__(self):
        # pylint: disable=C0415
        from hyper_parallel.core.shard._op_dispatch import disable_dtensor_dispatch, add_no_skip_ops
        disable_dtensor_dispatch()
        if self._no_skip_names:
            add_no_skip_ops(self._no_skip_names)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # pylint: disable=C0415
        from hyper_parallel.core.shard._op_dispatch import enable_dtensor_dispatch, remove_no_skip_ops
        if self._no_skip_names:
            remove_no_skip_ops(self._no_skip_names)
        enable_dtensor_dispatch()



# Cache for _build_layout to avoid redundant Layout computations
# Key: (device_mesh.to_hash(), tuple(placements), tensor_dim)
# Value: Layout
_LAYOUT_CACHE = {}


def _is_alias_placements(placements) -> bool:
    """
    Check if placements use alias strings rather than Placement objects.

    Alias placements use mesh dimension names (strings) to specify
    the sharding strategy, e.g., ("dp", "tp") or (("dp", "tp"), "None").
    All elements must be strings or tuples of strings for the sequence
    to be recognized as alias-style.

    Args:
        placements: A sequence of placement specifications.

    Returns:
        bool: True if all elements are alias strings or tuples of strings.
    """
    if len(placements) == 0:
        return False
    for p in placements:
        if isinstance(p, str):
            continue
        if isinstance(p, tuple) and len(p) > 0 and all(isinstance(x, str) for x in p):
            continue
        return False
    return True
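
# A minimal illustration of _is_alias_placements (doctest-style sketch for
# readers; not executed here). Replicate comes from placement_types above;
# the literal inputs are hypothetical examples, not values used by this module.
# >>> _is_alias_placements(("dp", "None"))
# True
# >>> _is_alias_placements((("dp", "tp"), "None"))
# True
# >>> _is_alias_placements([Replicate(), Replicate()])
# False
# >>> _is_alias_placements([])
# False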



def _build_layout(
    device_mesh: DeviceMesh,
    placements: Union[Sequence[Placement], Sequence[Union[str, Tuple[str, ...]]]],
    tensor_dim: int
) -> Layout:
    """
    Build Layout from device_mesh and placements.

    This function uses a cache to avoid redundant Layout computations
    for the same (device_mesh, placements, tensor_dim) combination.

    Args:
        device_mesh: The device mesh describing the device topology.
        placements: Supports two styles:
            - Placement objects (Shard, Replicate, etc.)
            - Alias strings ("dp", "None", ("dp", "tp"), etc.), length must
              equal the number of tensor dimensions (``tensor_dim``).
        tensor_dim: Number of dimensions in the tensor.

    Returns:
        Layout: The built layout object.

    Raises:
        ValueError: If alias placements length does not match tensor dimensions.
    """
    mesh_key = device_mesh.to_hash()
    placements_key = tuple(placements)
    cache_key = (mesh_key, placements_key, tensor_dim)

    if cache_key in _LAYOUT_CACHE:
        return _LAYOUT_CACHE[cache_key]

    layout = Layout.from_device_mesh(device_mesh)

    if _is_alias_placements(placements):
        if len(placements) != tensor_dim:
            raise ValueError(
                f"Alias placements length ({len(placements)}) must equal "
                f"tensor dimensions ({tensor_dim})."
            )
        result = layout(*placements)
    else:
        result = layout(placements)
        result.placement_to_tensor_map(tensor_dim)

    _LAYOUT_CACHE[cache_key] = result

    return result
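
# Sketch of the two placement styles accepted above (illustrative only, not
# executed). `mesh` stands for a hypothetical 2x2 DeviceMesh with mesh dims
# ("dp", "tp"), and `Shard` is assumed to be importable from placement_types
# alongside Replicate; neither name is defined in this module.
# >>> layout_a = _build_layout(mesh, [Shard(0), Replicate()], tensor_dim=2)  # one entry per mesh dim
# >>> layout_b = _build_layout(mesh, ("dp", "None"), tensor_dim=2)           # one alias per tensor dim
# >>> _build_layout(mesh, ("dp", "None"), 2) is layout_b                     # repeated alias call hits _LAYOUT_CACHE
# True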



class DTensor(DTensorBase):
    """
    DTensor - Distributed Tensor

    A DTensor represents a tensor that is distributed across multiple devices
    according to a DeviceMesh and placement specifications.

    Args:
        local_tensor (Tensor): The local tensor shard on this device.
        device_mesh (DeviceMesh): The device mesh describing the device topology.
        placements: The placement strategy. Supports two styles:
            - Placement objects (e.g., ``[Shard(0), Replicate()]``).
            - Alias strings (e.g., ``("dp", "None")`` or
              ``(("dp", "tp"), "None")``), length must equal the number of
              tensor dimensions.

    Example:
        >>> mesh = init_device_mesh(device_type="npu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))
        >>> local_tensor = Tensor(np.ones((4, 4)))
        >>> # Placement style
        >>> dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0), Replicate()])
        >>> # Alias style, length matches tensor dims
        >>> dtensor = DTensor.from_local(local_tensor, mesh, ("dp", "None"))
    """
    _local_tensor: Tensor
    _device_mesh: DeviceMesh
    _placements: Sequence[Placement]

    def __init_data__(
        self,
        local_tensor: Tensor,
        device_mesh: DeviceMesh,
        placements: Union[Sequence[Placement], Sequence[Union[str, Tuple[str, ...]]]]
    ):
        self._local_tensor = local_tensor
        self._device_mesh = device_mesh
        self._layout = _build_layout(
            device_mesh, placements, len(local_tensor.shape)
        )
        self._placements = tuple(self._layout.placements)

    @property
    def device_mesh(self) -> DeviceMesh:
        """The device mesh of this DTensor."""
        return self._device_mesh

    @property
    def placements(self) -> Sequence[Placement]:
        """The placements of this DTensor."""
        return self._placements

    @property
    def layout(self) -> Layout:
        """Internal layout for redistribution (for backward compatibility)."""
        if not hasattr(self, '_layout'):
            return None
        return self._layout


    @staticmethod
    def from_local(
        local_tensor: Tensor,
        device_mesh: DeviceMesh,
        placements: Union[Sequence[Placement], Sequence[Union[str, Tuple[str, ...]]]]
    ) -> 'DTensor':
        """
        Create a DTensor from a local tensor with device mesh and placements.

        Args:
            local_tensor (Tensor): The local tensor shard on this device.
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements: The placement strategy. Supports two styles:
                - Placement objects (e.g., ``[Shard(0), Replicate()]``).
                - Alias strings (e.g., ``("dp", "None")`` or
                  ``(("dp", "tp"), "None")``), length must equal the number
                  of tensor dimensions.

        Returns:
            DTensor: A new DTensor instance.

        Example:
            >>> mesh = init_device_mesh(device_type="npu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))
            >>> local_tensor = Tensor(np.ones((4, 4)))
            >>> dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0), Replicate()])
            >>> dtensor = DTensor.from_local(local_tensor, mesh, ("dp", "None"))
        """
        return DTensor(local_tensor, device_mesh, placements)

    def _alias_placements(self) -> Sequence[Placement]:
        """Return alias_placements from layout, falling back to _placements."""
        if hasattr(self, '_layout') and self._layout:
            return self._layout.alias_placements
        return self._placements

    def to(self, *args, **kwargs):
        """Move the DTensor to a different device or dtype.

        Delegates to the underlying local tensor's ``to`` method and
        reconstructs a DTensor preserving device_mesh and placements.

        Args:
            *args (tuple): Arguments passed to the underlying tensor's ``to``
                method (e.g., device or dtype).
            **kwargs (dict): Keyword arguments for the tensor conversion
                (e.g., dtype, device, non_blocking).

        Returns:
            DTensor: A new DTensor with the converted local tensor.
        """
        new_local = self._local_tensor.to(*args, **kwargs)
        return self.__class__(new_local, device_mesh=self._device_mesh,
                              placements=self._alias_placements())

    def float(self):
        """Convert the DTensor to float dtype.

        Returns:
            DTensor: A new DTensor with float32 local tensor.
        """
        new_local = self._local_tensor.float()
        return self.__class__(new_local, device_mesh=self._device_mesh,
                              placements=self._alias_placements())

    def to_local(self) -> Tensor:
        """
        Convert DTensor to local tensor.

        Returns:
            Tensor: The local tensor shard on this device.
        """
        return self._local_tensor


    @property
    def shape(self) -> Tuple[int, ...]:
        """
        The global shape of this DTensor.

        Returns:
            Tuple[int, ...]: The global tensor shape.
        """
        return self._layout.get_global_shape(self._local_tensor.shape)

    def size(self, dim=None):
        """Return the global shape, consistent with .shape.

        Without ``dim`` returns a tuple matching ``self.shape``.
        With ``dim`` returns the size of that dimension.
        """
        global_shape = self.shape
        if dim is not None:
            return global_shape[dim]
        return global_shape

    def numel(self) -> int:
        """Return the number of elements in this DTensor."""
        return int(np.prod(self.shape))

    @property
    def local_shape(self) -> Tuple[int, ...]:
        """
        The local shape of this DTensor on this device.

        Returns:
            Tuple[int, ...]: The local tensor shape.
        """
        return self._local_tensor.shape


    def redistribute(
        self,
        device_mesh: DeviceMesh,
        placements: Union[Sequence[Placement], Sequence[Union[str, Tuple[str, ...]]]]
    ) -> 'DTensor':
        """
        Redistribute this DTensor to a new device mesh and placements.

        Args:
            device_mesh (DeviceMesh): The target device mesh.
            placements: The target placements. Supports Placement objects
                or alias strings.

        Returns:
            DTensor: A new DTensor with the specified distribution.

        Example:
            >>> new_dtensor = dtensor.redistribute(mesh, [Replicate(), Shard(1)])
            >>> new_dtensor = dtensor.redistribute(mesh, ("None", "tp"))
        """
        # Build dst_layout from device_mesh and placements
        dst_layout = _build_layout(
            device_mesh, placements, len(self._local_tensor.shape)
        )

        # pylint: disable=C0415
        from hyper_parallel.core.dtensor.tensor_redistribution import _tensor_redistribution
        out = _tensor_redistribution.redistribution(self, dst_layout)
        return out

    def reduce_partial(self) -> 'DTensor':
        """
        Reduce partial sharding state for this DTensor.

        Returns:
            DTensor: A new DTensor with partial state reduced.
        """
        if not self._layout:
            return self
        to_layout = cp.deepcopy(self._layout)
        to_layout.reset_partial()
        # pylint: disable=C0415
        from hyper_parallel.core.dtensor.tensor_redistribution import _tensor_redistribution
        out = _tensor_redistribution.reduce_partial(self, to_layout)
        return out


    def full_tensor(self) -> Tensor:
        """
        Return the full tensor of this DTensor.

        Returns:
            Tensor: A Tensor object that represents the full tensor of this DTensor.
                The returned tensor contains the complete data gathered from
                all ranks.

        Note:
            This operation involves communication across all ranks in the DeviceMesh,
            which may be expensive for large tensors. Use with caution in
            performance-critical code paths.

        Example:
            >>> # Assume dtensor is sharded across multiple devices
            >>> local_tensor = dtensor.to_local()  # Returns only the local shard
            >>> full_tensor = dtensor.full_tensor()  # Returns the complete tensor
        """
        if not self._layout:
            return self._local_tensor

        # Create a fully replicated layout
        replicated_layout = cp.deepcopy(self._layout)

        # Set all placements to Replicate and convert to tensor_map
        replicated_placements = [Replicate()] * len(replicated_layout.mesh_shape)
        replicated_layout.set_placements(replicated_placements)
        replicated_layout.placement_to_tensor_map(len(self._local_tensor.shape))

        # Clear partial status from original layout since Replicate has no partial
        replicated_layout.reset_partial()

        # Redistribute to the replicated layout and return local tensor
        # pylint: disable=C0415
        from hyper_parallel.core.dtensor.tensor_redistribution import _tensor_redistribution
        out = _tensor_redistribution.redistribution(self, replicated_layout)
        return out.to_local()
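
    # A small usage sketch (comments only, not executed): conversions return new
    # DTensors on the same mesh with the same placements, and ``shape``/``size``
    # report the global shape while ``local_shape`` reports this rank's shard.
    # `dtensor` refers to the hypothetical instance built in the class docstring.
    # >>> dt_fp32 = dtensor.float()            # still a DTensor on the same mesh/placements
    # >>> dtensor.shape                        # global shape across the mesh
    # >>> dtensor.local_shape                  # this rank's shard shape
    # >>> gathered = dtensor.full_tensor()     # all-gather into a plain Tensor (communicates)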



def distribute_tensor(
    tensor: Tensor,
    device_mesh: DeviceMesh,
    placements: Union[Sequence[Placement], Sequence[Union[str, Tuple[str, ...]]]]
) -> DTensor:
    """
    Distribute a global tensor to the device mesh according to the placements.

    Args:
        tensor (Tensor): The global tensor to be distributed. All ranks
            should have the same tensor data.
        device_mesh (DeviceMesh): The device mesh describing the device topology.
        placements: The placement strategy. Supports two styles:
            - Placement objects (e.g., ``[Shard(0), Replicate()]``).
            - Alias strings (e.g., ``("dp", "None")`` or
              ``(("dp", "tp"), "None")``), length must equal the number of
              tensor dimensions.

    Returns:
        DTensor: A new DTensor with the local shard on each rank.

    Note:
        This method assumes all ranks have the same global tensor. It slices
        the tensor locally without communication. If ranks have different
        data, use ``from_local`` instead.

    Example:
        >>> mesh = init_device_mesh(device_type="npu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))
        >>> global_tensor = Tensor(np.arange(16).reshape(4, 4))
        >>> dtensor = distribute_tensor(global_tensor, mesh, [Shard(0), Replicate()])
        >>> dtensor = distribute_tensor(global_tensor, mesh, ("dp", "None"))
    """
    layout = _build_layout(device_mesh, placements, len(tensor.shape))
    local_tensor = _get_slice_tensor_by_layout(tensor, layout)
    return DTensor(local_tensor, device_mesh, layout.alias_placements)



def _distribute_module_param_source(param: Any) -> Tensor:
    """Tensor data used as the global tensor for :func:`distribute_tensor` (PyTorch uses ``param.data``)."""
    if hasattr(param, "data"):
        return param.data
    return platform.get_param_local_data(param)


def _distribute_module_new_parameter(key: str, dtensor: DTensor, requires_grad: bool) -> Any:
    """Build a framework :class:`Parameter` holding *dtensor* (Torch vs MindSpore kwargs differ)."""
    if platform.platform_type == PlatformType.MINDSPORE:
        return platform.Parameter(dtensor, name=key, requires_grad=requires_grad)
    return platform.Parameter(dtensor, requires_grad=requires_grad)


def _distribute_module_set_param(module: Any, key: str, new_param: Any) -> None:
    """Register or assign a parameter on *module* (``nn.Module`` or MindSpore ``Cell``)."""
    if hasattr(module, "register_parameter"):
        module.register_parameter(key, new_param)
        return
    if hasattr(module, "_params"):
        module._params[key] = new_param
        if hasattr(module, "_params_list"):
            module._params_list[key] = new_param
        if key in module.__dict__:
            module.__dict__[key] = new_param
        return
    raise TypeError(
        f"distribute_module expects nn.Module-like objects with register_parameter or _params; "
        f"got {type(module)}."
    )


def _distribute_module_iter_params(module: Any) -> list:
    """Return ``[(name, param), ...]`` for direct parameters (``_parameters`` or ``_params``)."""
    if hasattr(module, "_parameters"):
        return list(module._parameters.items())
    if hasattr(module, "_params"):
        return list(module._params.items())
    return []


def _distribute_module_iter_buffers(module: Any) -> list:
    """Return ``[(name, buffer), ...]`` if the module has ``_buffers`` (PyTorch ``nn.Module``)."""
    if hasattr(module, "_buffers"):
        return list(module._buffers.items())
    return []


def _distribute_module_named_modules(module: Any):
    """``nn.Module.named_modules`` or MindSpore ``Cell.cells_and_names`` (submodule FQNs)."""
    if hasattr(module, "named_modules"):
        return module.named_modules()
    if hasattr(module, "cells_and_names"):
        return module.cells_and_names()
    raise TypeError(
        f"distribute_module expects module-like objects with named_modules or cells_and_names; "
        f"got {type(module)}."
    )



def _replicate_submodule_params_buffers(
    sub_mod: Any,
    device_mesh: DeviceMesh,
    *,
    module_prefix: str = "",
) -> None:
    """Convert plain params/buffers on *sub_mod* to fully replicated :class:`DTensor`."""
    full_replicate = [Replicate()] * device_mesh.ndim
    for key, param in _distribute_module_iter_params(sub_mod):
        if param is None or isinstance(param, DTensorBase):
            continue
        src = _distribute_module_param_source(param)
        requires_grad = bool(getattr(param, "requires_grad", True))
        dt = distribute_tensor(src, device_mesh, full_replicate)
        param_name = f"{module_prefix}.{key}" if module_prefix else key
        new_param = _distribute_module_new_parameter(param_name, dt, requires_grad)
        _distribute_module_set_param(sub_mod, key, new_param)
    for key, buffer in _distribute_module_iter_buffers(sub_mod):
        if buffer is None or isinstance(buffer, DTensorBase):
            continue
        sub_mod._buffers[key] = distribute_tensor(buffer, device_mesh, full_replicate)


def _distribute_module_run_partition_and_replicate(
    module: Any,
    device_mesh: DeviceMesh,
    partition_fn: Optional[Callable[[str, Any, DeviceMesh], None]],
) -> None:
    """Call optional ``partition_fn`` per ``named_modules`` and replicate remaining tensors."""
    if partition_fn is None:
        for mod_name, submod in _distribute_module_named_modules(module):
            _replicate_submodule_params_buffers(submod, device_mesh, module_prefix=mod_name)
        return
    for mod_name, submod in _distribute_module_named_modules(module):
        partition_fn(mod_name, submod, device_mesh)
        _replicate_submodule_params_buffers(submod, device_mesh, module_prefix=mod_name)


def _distribute_module_register_input_fn(
    module: Any,
    device_mesh: DeviceMesh,
    input_fn: Callable[..., Any],
) -> None:
    """Register *input_fn* as a forward pre-hook on *module* (2- or 3-arg, PyTorch-compatible)."""
    num_args = len(inspect.signature(input_fn).parameters)
    if num_args == 2:
        warnings.warn(
            "Deprecating input_fn that takes two arguments (inputs, device_mesh), "
            "please use input_fn that takes in (module, inputs, device_mesh) instead!",
            FutureWarning,
            stacklevel=3,
        )
        module.register_forward_pre_hook(
            lambda _, inputs: input_fn(inputs, device_mesh)
        )
    elif num_args == 3:
        module.register_forward_pre_hook(
            lambda mod, inputs: input_fn(mod, inputs, device_mesh)
        )
    else:
        raise ValueError(
            f"input_fn should take in 2 or 3 arguments, but got {num_args} arguments!"
        )



def _distribute_module_register_output_fn(
    module: Any,
    device_mesh: DeviceMesh,
    output_fn: Callable[..., Any],
) -> None:
    """Register *output_fn* as a forward hook on *module* (2- or 3-arg, PyTorch-compatible)."""
    num_args = len(inspect.signature(output_fn).parameters)
    if num_args == 2:
        warnings.warn(
            "Deprecating output_fn that takes two arguments (outputs, device_mesh), "
            "please use output_fn that takes in (module, outputs, device_mesh) instead!",
            FutureWarning,
            stacklevel=3,
        )
        module.register_forward_hook(
            lambda mod, inputs, outputs: output_fn(outputs, device_mesh)
        )
    elif num_args == 3:
        module.register_forward_hook(
            lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
        )
    else:
        raise ValueError(
            f"output_fn should take in 2 or 3 arguments, but got {num_args} arguments!"
        )



def distribute_module(
    module: Any,
    device_mesh: Optional[DeviceMesh] = None,
    partition_fn: Optional[Callable[[str, Any, DeviceMesh], None]] = None,
    input_fn: Optional[Callable[..., Any]] = None,
    output_fn: Optional[Callable[..., Any]] = None,
) -> Any:
    """PyTorch ``distribute_module`` parity: shard/replicate params and optional I/O hooks.

    Unsharded parameters and buffers become fully replicated :class:`DTensor` after
    ``partition_fn``. ``input_fn`` / ``output_fn`` attach only to the root *module*.

    Args:
        module: Root ``nn.Module`` or MindSpore ``Cell`` with compatible APIs.
        device_mesh: Placement mesh; if ``None``, uses ``_mesh_resources.get_current_mesh()``.
        partition_fn: Per ``named_modules`` callback invoked before the replicate pass;
            ``None`` replicates everything.
        input_fn: ``(module, inputs, mesh)`` or deprecated ``(inputs, mesh)`` pre-hook.
        output_fn: ``(module, outputs, mesh)`` or deprecated ``(outputs, mesh)`` forward hook.

    Returns:
        *module* in place, with distributed tensors where applied.

    Raises:
        RuntimeError: If called twice on the same *module*.
        ValueError: If ``input_fn`` / ``output_fn`` arity is not 2 or 3.

    Note:
        XLA / ``torch_xla`` is not supported; strided device :class:`DTensor` only.
    """
    if getattr(module, "_distribute_module_applied", False):
        raise RuntimeError(
            "distribute_module should only be called once on a module, "
            "but it has already been called on this module!"
        )
    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    _distribute_module_run_partition_and_replicate(module, device_mesh, partition_fn)
    if input_fn is not None:
        _distribute_module_register_input_fn(module, device_mesh, input_fn)
    if output_fn is not None:
        _distribute_module_register_output_fn(module, device_mesh, output_fn)
    module._distribute_module_applied = True
    return module
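
# A hedged usage sketch for distribute_module (comments only, not executed).
# `MyModule` and `mesh` are hypothetical stand-ins. With partition_fn=None every
# parameter and buffer is converted to a fully replicated DTensor; input_fn and
# output_fn attach as forward pre-/post-hooks on the root module only.
# >>> def input_fn(mod, inputs, mesh):
# ...     # e.g., convert plain tensor inputs to DTensors before forward
# ...     return inputs
# >>> sharded = distribute_module(MyModule(), mesh, partition_fn=None, input_fn=input_fn)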



def _dtensor_init_helper(
    init_op,
    size,
    device_mesh,
    placements,
    **kwargs,
) -> DTensor:
    """
    Helper function to create and initialize a distributed tensor.

    Args:
        init_op: Platform factory op (e.g., ``platform.ones``) used to build the local shard.
        size: Global shape of the tensor.
        device_mesh: The device mesh describing the device topology.
        placements: The placement strategy for the tensor.
        **kwargs: Extra arguments forwarded to ``init_op`` (e.g., ``fill_value`` for ``platform.full``).

    Returns:
        DTensor: The initialized distributed tensor.
    """
    # get local tensor shape
    local_shape = compute_local_shape_and_global_offset(
        size, device_mesh, placements
    )

    # initialize the local tensor
    if init_op is platform.full:
        fill_value = kwargs.pop("fill_value", 0)
        local_tensor = init_op(local_shape, fill_value, **kwargs)
    else:
        local_tensor = init_op(local_shape, **kwargs)

    return DTensor.from_local(
        local_tensor,
        device_mesh,
        placements,
    )



def ones(
    size,
    device_mesh,
    placements,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined
    by the variable argument ``size``.

    Args:
        size (Union[tuple[int], list[int], int, Tensor]): The specified shape of the output tensor. Only positive
            integers, or a tuple or Tensor containing positive integers, are allowed. If it is a Tensor,
            it must be a 0-D or 1-D Tensor with int32 or int64 dtype.

    Keyword args:
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    ones_ = platform.ones
    return _dtensor_init_helper(
        ones_,
        size,
        device_mesh=device_mesh,
        placements=placements,
    )


def empty(
    size,
    device_mesh,
    placements,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor`
    is defined by the variable argument ``size``.

    Args:
        size (Union[tuple[int], list[int], int]): The specified shape of the output tensor. Can be a variable
            number of positive integers, or a tuple or list containing positive integers.

    Keyword args:
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    empty_ = platform.empty
    return _dtensor_init_helper(
        empty_,
        size,
        device_mesh=device_mesh,
        placements=placements,
    )


def full(
    size,
    fill_value,
    *,
    device_mesh,
    placements,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and
    ``placements``, with the shape defined by the argument ``size``.

    Args:
        size (Union[tuple[int], list[int]]): The specified shape of the output tensor.
        fill_value (Union[numbers.Number, Tensor]): Value to fill the returned tensor. It can be a scalar number,
            a 0-D Tensor, or a 1-D Tensor with only one element.

    Keyword args:
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    full_ = platform.full
    return _dtensor_init_helper(
        full_,
        size,
        fill_value=fill_value,
        device_mesh=device_mesh,
        placements=placements,
    )


def zeros(
    size,
    device_mesh,
    placements,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with the scalar value 0.

    Args:
        size (Union[tuple[int], list[int], int, Tensor]): The specified shape of the output tensor. Only positive
            integers, or a tuple or Tensor containing positive integers, are allowed. If it is a Tensor,
            it must be a 0-D or 1-D Tensor with int32 or int64 dtype.

    Keyword args:
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    zeros_ = platform.zeros
    return _dtensor_init_helper(
        zeros_,
        size,
        device_mesh=device_mesh,
        placements=placements,
    )
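

# A hedged end-to-end sketch of the factory helpers above (comments only, not
# executed). `mesh` is a hypothetical 2x2 DeviceMesh with dims ("dp", "tp") and
# `Shard` is assumed to be importable from placement_types alongside Replicate.
# Each call materializes only the local shard implied by the placements and
# wraps it with DTensor.from_local.
# >>> dt_ones = ones((8, 8), device_mesh=mesh, placements=[Shard(0), Replicate()])
# >>> dt_zero = zeros((8, 8), device_mesh=mesh, placements=[Shard(0), Replicate()])
# >>> dt_pi = full((8, 8), 3.14, device_mesh=mesh, placements=[Replicate(), Replicate()])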