Coverage for hyper_parallel / core / dtensor.py: 93%

107 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""dtensor""" 

16import copy as cp 

17from typing import Sequence, Tuple 

18import numpy as np 

19from hyper_parallel.core.layout import Layout, DeviceMesh, _get_slice_tensor_by_layout 

20from hyper_parallel.core.placement_types import Placement, Replicate 

21from hyper_parallel.platform import get_platform 

22from hyper_parallel.core.utils import compute_local_shape_and_global_offset 

23 

# Resolve the active backend platform once at import time; every tensor
# primitive used below (DTensorBase, Tensor, and the ones/zeros/empty/full
# factories) is looked up on this object.
platform = get_platform()
DTensorBase = platform.DTensorBase
Tensor = platform.Tensor

27 

28 

class SkipDTensorDispatch():
    """Context manager that temporarily disables DTensor op dispatch.

    Dispatch is switched off on entry and switched back on on exit,
    regardless of whether the body raised. Exceptions are not suppressed
    (``__exit__`` returns ``None``).
    """

    def __enter__(self):
        # Imported lazily to avoid a circular import with _op_dispatch.
        # pylint: disable=C0415
        from hyper_parallel.core.shard._op_dispatch import disable_dtensor_dispatch
        disable_dtensor_dispatch()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # pylint: disable=C0415
        from hyper_parallel.core.shard._op_dispatch import enable_dtensor_dispatch
        enable_dtensor_dispatch()

39 

40 

def _build_layout(
        device_mesh: DeviceMesh,
        placements: Sequence[Placement],
        tensor_dim: int
) -> Layout:
    """Construct a :class:`Layout` for a *tensor_dim*-dimensional tensor.

    Args:
        device_mesh: Mesh describing the device topology.
        placements: One :class:`Placement` (Shard, Replicate, ...) per
            mesh dimension.
        tensor_dim: Rank of the tensor the layout applies to.

    Returns:
        Layout: A layout bound to the mesh and placements, with its
        placement-to-tensor-dimension map already computed.
    """
    mesh_layout = Layout.from_device_mesh(device_mesh)
    built = mesh_layout(placements)
    built.placement_to_tensor_map(tensor_dim)
    return built

61 

62 

class DTensor(DTensorBase):
    """Distributed Tensor.

    A DTensor pairs the local shard held on this device with the metadata
    (a :class:`DeviceMesh` plus one :class:`Placement` per mesh dimension)
    that describes how the global tensor is spread across devices.

    Args:
        local_tensor (Tensor): The local tensor shard on this device.
        device_mesh (DeviceMesh): The device mesh describing the device topology.
        placements (Sequence[Placement]): The placement strategy for each mesh
            dimension; each element is a Placement object (Shard, Replicate,
            Partial, etc.).

    Example:
        >>> from hyper_parallel.core.placement_types import Shard, Replicate
        >>> mesh = init_device_mesh(device_type="npu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))
        >>> local_tensor = Tensor(np.ones((4, 4)))
        >>> dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0), Replicate()])
    """
    _local_tensor: Tensor
    _device_mesh: DeviceMesh
    _placements: Sequence[Placement]

    def __init_data__(
            self,
            local_tensor: Tensor,
            device_mesh: DeviceMesh,
            placements: Sequence[Placement]
    ):
        self._local_tensor = local_tensor
        self._device_mesh = device_mesh
        self._placements = tuple(placements)
        # Cache a Layout up front so redistribution operations don't have
        # to rebuild it on every call.
        self._layout = _build_layout(
            device_mesh, placements, len(local_tensor.shape)
        )

    @property
    def device_mesh(self) -> DeviceMesh:
        """The device mesh this DTensor is distributed over."""
        return self._device_mesh

    @property
    def placements(self) -> Sequence[Placement]:
        """The per-mesh-dimension placements of this DTensor."""
        return self._placements

    @property
    def layout(self) -> Layout:
        """Internal layout used for redistribution (backward compatibility).

        Returns ``None`` when no layout has been built yet.
        """
        return getattr(self, '_layout', None)

    @staticmethod
    def from_local(
            local_tensor: Tensor,
            device_mesh: DeviceMesh,
            placements: Sequence[Placement]
    ) -> 'DTensor':
        """Build a DTensor from a per-device local shard.

        Args:
            local_tensor (Tensor): The local tensor shard on this device.
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements (Sequence[Placement]): One Placement object (Shard,
                Replicate, Partial, etc.) per mesh dimension.

        Returns:
            DTensor: A new DTensor instance.

        Example:
            >>> from hyper_parallel.core.placement_types import Shard, Replicate
            >>> mesh = init_device_mesh(device_type="npu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))
            >>> local_tensor = Tensor(np.ones((4, 4)))
            >>> dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0), Replicate()])
        """
        return DTensor(local_tensor, device_mesh, placements)

    def to_local(self) -> Tensor:
        """Return the local tensor shard held on this device."""
        return self._local_tensor

    @property
    def shape(self) -> Tuple[int, ...]:
        """The global shape of this DTensor, derived from the local shard."""
        return self._layout.get_global_shape(self._local_tensor.shape)

    def size(self, dim=None):
        """Return the global shape, consistent with ``.shape``.

        Without ``dim`` this returns a tuple matching ``self.shape``; with
        ``dim`` it returns the size of that single dimension.
        """
        full = self.shape
        return full if dim is None else full[dim]

    def numel(self) -> int:
        """Return the total number of elements in the global tensor."""
        return int(np.prod(self.shape))

    @property
    def local_shape(self) -> Tuple[int, ...]:
        """The shape of the local shard on this device."""
        return self._local_tensor.shape

    def redistribute(
            self,
            device_mesh: DeviceMesh,
            placements: Sequence[Placement]
    ) -> 'DTensor':
        """Redistribute this DTensor to a new device mesh and placements.

        Args:
            device_mesh (DeviceMesh): The target device mesh.
            placements (Sequence[Placement]): The target placements, one
                Placement object per mesh dimension.

        Returns:
            DTensor: A new DTensor with the requested distribution.

        Example:
            >>> from hyper_parallel.core.placement_types import Shard, Replicate
            >>> new_dtensor = dtensor.redistribute(mesh, [Replicate(), Shard(1)])
        """
        target = _build_layout(
            device_mesh, placements, len(self._local_tensor.shape)
        )
        # Imported lazily to avoid a circular import.
        # pylint: disable=C0415
        from hyper_parallel.core.tensor_redistribution import _tensor_redistribution
        return _tensor_redistribution.redistribution(self, target)

    def reduce_partial(self) -> 'DTensor':
        """Reduce any partial sharding state held by this DTensor.

        Returns:
            DTensor: A new DTensor with partial state reduced, or ``self``
            unchanged when no layout is present.
        """
        if not self._layout:
            return self
        goal = cp.deepcopy(self._layout)
        goal.reset_partial()
        # pylint: disable=C0415
        from hyper_parallel.core.tensor_redistribution import _tensor_redistribution
        return _tensor_redistribution.reduce_partial(self, goal)

    def full_tensor(self) -> Tensor:
        """Return the complete global tensor of this DTensor.

        Returns:
            Tensor: The full tensor, containing the data gathered from
            all ranks.

        Note:
            This involves communication across every rank in the DeviceMesh
            and may be expensive for large tensors; avoid it on
            performance-critical paths.

        Example:
            >>> # Assume dtensor is sharded across multiple devices
            >>> local_tensor = dtensor.to_local()    # only the local shard
            >>> full_tensor = dtensor.full_tensor()  # the complete tensor
        """
        if not self._layout:
            return self._local_tensor

        # Target layout: fully replicated on every mesh dimension.
        target = cp.deepcopy(self._layout)
        target.set_placements([Replicate()] * len(target.mesh_shape))
        target.placement_to_tensor_map(len(self._local_tensor.shape))
        # Replicate carries no partial state, so clear it from the copy.
        target.reset_partial()

        # Redistribute to the replicated layout, then unwrap the local data.
        # pylint: disable=C0415
        from hyper_parallel.core.tensor_redistribution import _tensor_redistribution
        return _tensor_redistribution.redistribution(self, target).to_local()

271 

272 

def distribute_tensor(
        tensor: Tensor,
        device_mesh: DeviceMesh,
        placements: Sequence[Placement]
) -> DTensor:
    """Distribute a global tensor over a device mesh per the placements.

    Args:
        tensor (Tensor): The global tensor to distribute. All ranks are
            expected to hold the same tensor data.
        device_mesh (DeviceMesh): The device mesh describing the device topology.
        placements (Sequence[Placement]): One Placement object (Shard,
            Replicate, etc.) per mesh dimension.

    Returns:
        DTensor: A new DTensor holding this rank's local shard.

    Note:
        The tensor is sliced locally without communication, so every rank
        must start from the same global tensor. If ranks hold different
        data, use ``from_local`` instead.

    Example:
        >>> from hyper_parallel.core.placement_types import Shard, Replicate
        >>> mesh = init_device_mesh(mesh_shape=(2, 2), alias_name=("dp", "tp"))
        >>> global_tensor = Tensor(np.arange(16).reshape(4, 4))
        >>> dtensor = distribute_tensor(global_tensor, mesh, [Shard(0), Replicate()])
        >>> # rank 0 and rank1 gets: [[0,1,2,3], [4,5,6,7]]
        >>> # rank 2 and rank3 gets: [[8,9,10,11], [12,13,14,15]]
    """
    # Describe the target distribution, then carve out this rank's slice.
    target_layout = _build_layout(device_mesh, placements, len(tensor.shape))
    shard = _get_slice_tensor_by_layout(tensor, target_layout)
    return DTensor(shard, device_mesh, placements)

311 

def _dtensor_init_helper(
    init_op,
    size,
    device_mesh,
    placements,
    **kwargs,
) -> DTensor:
    """
    Create a DTensor by materializing only this rank's local shard.

    Args:
        init_op: Platform tensor factory (e.g. ``platform.ones``,
            ``platform.zeros``, ``platform.empty`` or ``platform.full``)
            used to allocate the local tensor.
        size: Global shape of the distributed tensor.
        device_mesh: The device mesh describing the device topology.
        placements: Placement strategy, one entry per mesh dimension.
        **kwargs: Extra arguments forwarded to ``init_op``. For
            ``platform.full``, ``fill_value`` (default 0) is popped out and
            passed positionally.

    Returns:
        DTensor: The initialized distributed tensor.
    """
    # get local tensor shape
    # NOTE(review): the helper's name suggests it returns a
    # (local_shape, global_offset) pair, but the whole result is used as
    # the local shape here — confirm against the util's actual return value.
    local_shape = compute_local_shape_and_global_offset(
        size, device_mesh, placements
    )

    # initialize the local tensor
    # platform.full takes fill_value positionally, so extract it from kwargs.
    if init_op is platform.full:
        fill_value = kwargs.pop("fill_value", 0)
        local_tensor = init_op(local_shape, fill_value, **kwargs)
    else:
        local_tensor = init_op(local_shape, **kwargs)

    return DTensor.from_local(
        local_tensor,
        device_mesh,
        placements,
    )

348 

def ones(
    size,
    device_mesh,
    placements,
) -> DTensor:
    """
    Return a :class:`DTensor` of shape ``size`` filled with the scalar value 1.

    Args:
        size (Union[tuple[int], list[int], int, Tensor]): The specified shape of output tensor. Only positive integer or
            tuple or Tensor containing positive integers are allowed. If it is a Tensor,
            it must be a 0-D or 1-D Tensor with int32 or int64 dtypes.

    Keyword args:
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    return _dtensor_init_helper(
        platform.ones,
        size,
        device_mesh=device_mesh,
        placements=placements,
    )

377 

def empty(
    size,
    device_mesh,
    placements,
) -> DTensor:
    """
    Return a :class:`DTensor` of shape ``size`` with uninitialized data.

    Args:
        size (Union[tuple[int], list[int], int]): The specified shape of output tensor. Can be variable numbers of
            positive integers or tuple or list containing positive integers.

    Keyword args:
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    return _dtensor_init_helper(
        platform.empty,
        size,
        device_mesh=device_mesh,
        placements=placements,
    )

405 

406 

def full(
    size,
    fill_value,
    *,
    device_mesh,
    placements,
) -> DTensor:
    """
    Return a :class:`DTensor` of shape ``size`` filled with ``fill_value``,
    distributed according to ``device_mesh`` and ``placements``.

    Args:
        size (Union[tuple[int], list[int]]): The specified shape of output tensor.
        fill_value (Union[numbers.Number, Tensor]): Value to fill the returned tensor. It can be a scalar number, a 0-D
            Tensor, or a 1-D Tensor with only one element.

    Keyword args:
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    return _dtensor_init_helper(
        platform.full,
        size,
        fill_value=fill_value,
        device_mesh=device_mesh,
        placements=placements,
    )

438 

def zeros(
    size,
    device_mesh,
    placements,
) -> DTensor:
    """
    Return a :class:`DTensor` of shape ``size`` filled with the scalar value 0.

    Args:
        size (Union[tuple[int], list[int], int, Tensor]): The specified shape of output tensor. Only positive integer or
            tuple or Tensor containing positive integers are allowed. If it is a Tensor,
            it must be a 0-D or 1-D Tensor with int32 or int64 dtypes.
    Keyword args:
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    return _dtensor_init_helper(
        platform.zeros,
        size,
        device_mesh=device_mesh,
        placements=placements,
    )