Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / platform.py: 64%

347 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""framework platform api""" 

16import os 

17from datetime import timedelta 

18from enum import auto, Enum 

19from typing import Optional, Any, Union 

20 

21import numpy as np 

22 

# Name of the environment variable that selects the AI framework platform.
HYPER_PARALLEL_PLATFORM = "HYPER_PARALLEL_PLATFORM"

# Value identifying the MindSpore backend.
HYPER_PARALLEL_PLATFORM_MINDSPORE = "mindspore"

# Value identifying the PyTorch backend.
HYPER_PARALLEL_PLATFORM_TORCH = "torch"

31 

32 

class PlatformType(Enum):
    """Enumeration of supported AI framework platform types.

    Distinguishes the deep-learning framework backing a Platform instance.
    """

    MINDSPORE = auto()
    PYTORCH = auto()

40 

41 

# Module-level cache holding the singleton platform object once created.
platform = None

44 

45 

def get_mindspore_platform():
    """Create a MindSpore platform instance and cache it in the module global.

    Returns:
        MindSporePlatform: A MindSpore platform instance.
    """
    # Deferred import: MindSpore is only required when actually selected.
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.platform import MindSporePlatform
    global platform
    instance = MindSporePlatform()
    platform = instance
    return instance

57 

58 

def get_torch_platform():
    """Create a PyTorch platform instance and cache it in the module global.

    Returns:
        TorchPlatform: A PyTorch platform instance.
    """
    # Deferred import: PyTorch is only required when actually selected.
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.platform import TorchPlatform
    global platform
    instance = TorchPlatform()
    platform = instance
    return instance

70 

71 

def get_platform():
    """Obtain a framework platform instance.

    Returns the appropriate AI framework platform instance based on
    environment variables or a default priority order. The lookup
    priority is:

    1. Platform specified by the ``HYPER_PARALLEL_PLATFORM`` environment variable
    2. MindSpore platform (default preferred choice)
    3. PyTorch platform (fallback option)

    Returns:
        Platform: An instance of the framework platform.

    Raises:
        ImportError: Raised when none of the supported frameworks are available.
    """
    if platform is not None:
        # Reuse the cached singleton created by an earlier call.
        return platform
    # os.environ.get returns either None or str, so no isinstance check is needed.
    platform_type = os.environ.get(HYPER_PARALLEL_PLATFORM)
    if platform_type is not None:
        platform_type = platform_type.lower()
        if platform_type == HYPER_PARALLEL_PLATFORM_MINDSPORE:
            return get_mindspore_platform()
        if platform_type == HYPER_PARALLEL_PLATFORM_TORCH:
            return get_torch_platform()
        # Unrecognized values intentionally fall through to auto-detection.
    try:
        return get_mindspore_platform()
    except ImportError:
        return get_torch_platform()

100 

101 

# Registry of already-created communication groups, keyed by the stringified
# sorted rank tuple (see Platform.create_group).
EXISTING_COMM_GROUPS = {}

103 

104 

105class Platform: 

106 """Platform api""" 

107 current_grad_handle = None 

108 post_grad_handle_process = None 

109 grad_sync_stream = None 

110 

111 @staticmethod 

112 def get_rank(): 

113 """Get the rank of the current process in the default process group. 

114 

115 Returns: 

116 int: The rank of the current process. 

117 """ 

118 raise NotImplementedError("Platform subclasses must implement get_rank") 

119 

120 @staticmethod 

121 def get_global_rank(group, group_rank): 

122 """Convert a group rank to its global rank. 

123 

124 Args: 

125 group: The process group to query. 

126 group_rank (int): The rank within the group. 

127 

128 Returns: 

129 int: The global rank corresponding to the group rank. 

130 """ 

131 raise NotImplementedError("Platform subclasses must implement get_global_rank") 

132 

133 @staticmethod 

134 def get_world_size(): 

135 """Get the total number of processes in the default process group. 

136 

137 Returns: 

138 int: The world size (total number of processes). 

139 """ 

140 raise NotImplementedError("Platform subclasses must implement get_world_size") 

141 

142 @staticmethod 

143 def get_op_name(func): 

144 """Get the canonical name of an operator function. 

145 

146 Args: 

147 func: The operator function to query. 

148 

149 Returns: 

150 str: The canonical name of the operator. 

151 """ 

152 raise NotImplementedError("Platform subclasses must implement get_op_name") 

153 

154 @staticmethod 

155 def differentiable_all_gather_concat(data, group, concat_size, concat_dim): 

156 """Perform differentiable all-gather and concatenate tensors along a dimension. 

157 

158 Args: 

159 data: The input tensor to gather. 

160 group: The process group for collective communication. 

161 concat_size (int): The size to concatenate along concat_dim. 

162 concat_dim (int): The dimension along which to concatenate. 

163 

164 Returns: 

165 The concatenated tensor after all-gather operation. 

166 """ 

167 raise NotImplementedError("Platform subclasses must implement differentiable_all_gather_concat") 

168 

169 @staticmethod 

170 def chunk(data, split_dim, split_size, index): 

171 """Split tensor along a dimension and return the chunk at the given index. 

172 

173 Args: 

174 data: The input tensor to split. 

175 split_dim (int): The dimension along which to split. 

176 split_size (int): The size of each split chunk. 

177 index (int): The index of the chunk to return. 

178 

179 Returns: 

180 The tensor chunk at the specified index. 

181 """ 

182 raise NotImplementedError("Platform subclasses must implement chunk") 

183 

184 @staticmethod 

185 def differentiable_all_to_all(input_data, output_shape, group): 

186 """Perform differentiable all-to-all communication. 

187 

188 Args: 

189 input_data: The input tensor to redistribute. 

190 output_shape: The shape of the output tensor. 

191 group: The process group for collective communication. 

192 

193 Returns: 

194 The output tensor after all-to-all operation. 

195 """ 

196 raise NotImplementedError("Platform subclasses must implement differentiable_all_to_all") 

197 

198 @staticmethod 

199 def tensor_type_cast(input_data, cast_type): 

200 """Cast tensor to a specified dtype. 

201 

202 Args: 

203 input_data: The input tensor to cast. 

204 cast_type: The target dtype to cast to. 

205 

206 Returns: 

207 The tensor cast to the specified dtype. 

208 """ 

209 raise NotImplementedError("Platform subclasses must implement tensor_type_cast") 

210 

211 @staticmethod 

212 def is_tensor(obj: Any) -> bool: 

213 """Return True if ``obj`` is this framework's tensor type.""" 

214 raise NotImplementedError("Platform subclasses must implement is_tensor") 

215 

216 @staticmethod 

217 def get_tensor_storage_size(tensor: Any) -> int: 

218 """Return serialized byte size (numel * element size) for this framework's tensor.""" 

219 raise NotImplementedError("Platform subclasses must implement get_tensor_storage_size") 

220 

221 @staticmethod 

222 def differentiable_all_reduce(data, op, group): 

223 """Perform differentiable all-reduce operation. 

224 

225 Args: 

226 data: The input tensor to reduce. 

227 op: The reduction operation (e.g., sum, max, min). 

228 group: The process group for collective communication. 

229 

230 Returns: 

231 The reduced tensor with gradients supported. 

232 """ 

233 raise NotImplementedError("Platform subclasses must implement differentiable_all_reduce") 

234 

235 @staticmethod 

236 def differentiable_reduce_scatter(data, dev_num, axis, op, group): 

237 """Perform differentiable reduce-scatter operation. 

238 

239 Args: 

240 data: The input tensor to reduce and scatter. 

241 dev_num (int): The number of devices to scatter across. 

242 axis (int): The axis along which to scatter. 

243 op: The reduction operation (e.g., sum, max, min). 

244 group: The process group for collective communication. 

245 

246 Returns: 

247 The scattered tensor chunk with gradients supported. 

248 """ 

249 raise NotImplementedError("Platform subclasses must implement differentiable_reduce_scatter") 

250 

251 @staticmethod 

252 def init_parameters(module, stage_index): 

253 """Initialize parameters for a module at a specific pipeline stage. 

254 

255 This method is primarily needed for MindSpore platform which requires 

256 explicit parameter initialization interface. 

257 

258 Args: 

259 module: The module whose parameters need to be initialized. 

260 stage_index (int): The pipeline stage index for the module. 

261 

262 Raises: 

263 ValueError: If module is None or stage_index is negative. 

264 """ 

265 if module is None: 

266 raise ValueError("input module must not be none.") 

267 if stage_index < 0: 

268 raise ValueError("input stage_index must be positive.") 

269 

270 @staticmethod 

271 def get_cell_construct(cell): 

272 """Get the construct (forward) function of a cell/module. 

273 

274 Args: 

275 cell: The cell or module to get the construct function from. 

276 

277 Returns: 

278 The construct/forward callable of the cell. 

279 """ 

280 raise NotImplementedError("Platform subclasses must implement get_cell_construct") 

281 

282 @staticmethod 

283 def get_cells_and_names(cell): 

284 """Get all nested cells/modules and their names. 

285 

286 Args: 

287 cell: The root cell or module to traverse. 

288 

289 Returns: 

290 list: A list of tuples containing (name, cell) pairs. 

291 """ 

292 raise NotImplementedError("Platform subclasses must implement get_cells_and_names") 

293 

294 @staticmethod 

295 def search_parameter_by_name(cell, param_name: str): 

296 """Search for a parameter by name within a cell/module. 

297 

298 Args: 

299 cell: The cell or module to search in. 

300 param_name (str): The name of the parameter to find. 

301 

302 Returns: 

303 The parameter if found, otherwise None. 

304 """ 

305 raise NotImplementedError("Platform subclasses must implement search_parameter_by_name") 

306 

307 @staticmethod 

308 def update_parameter_by_name(cell, result: tuple, new_param) -> bool: 

309 """Update a parameter by name within a cell/module. 

310 

311 Args: 

312 cell: The cell or module containing the parameter. 

313 result (tuple): A tuple containing (param_name, parameter) to update. 

314 new_param: The new parameter value to set. 

315 

316 Returns: 

317 bool: True if update was successful, False otherwise. 

318 """ 

319 raise NotImplementedError("Platform subclasses must implement update_parameter_by_name") 

320 

321 @staticmethod 

322 def set_layout_into_parameter(param, layout): 

323 """Attach a DTensor layout to a parameter. 

324 

325 Args: 

326 param: The parameter to attach the layout to. 

327 layout: The DTensor layout describing tensor distribution. 

328 """ 

329 raise NotImplementedError("Platform subclasses must implement set_layout_into_parameter") 

330 

331 @staticmethod 

332 def get_param_local_shape(param): 

333 """Get the local shape of a distributed parameter. 

334 

335 Args: 

336 param: The parameter to query. 

337 

338 Returns: 

339 tuple: The local shape of the parameter shard. 

340 """ 

341 raise NotImplementedError("Platform subclasses must implement get_param_local_shape") 

342 

343 @staticmethod 

344 def get_param_local_data(param): 

345 """Get the local data tensor of a distributed parameter. 

346 

347 Args: 

348 param: The parameter to query. 

349 

350 Returns: 

351 The local tensor data of the parameter shard. 

352 """ 

353 raise NotImplementedError("Platform subclasses must implement get_param_local_data") 

354 

355 @staticmethod 

356 def update_param_data(param, data): 

357 """Update the data of a parameter with new tensor data. 

358 

359 Args: 

360 param: The parameter to update. 

361 data: The new tensor data to assign. 

362 """ 

363 raise NotImplementedError("Platform subclasses must implement update_param_data") 

364 

365 @staticmethod 

366 def get_param_type_size(param): 

367 """Get the size in bytes of a parameter's dtype. 

368 

369 Args: 

370 param: The parameter to query. 

371 

372 Returns: 

373 int: The size in bytes of the parameter's data type. 

374 """ 

375 raise NotImplementedError("Platform subclasses must implement get_param_type_size") 

376 

377 @staticmethod 

378 def new_zero_parameter(param_shape, param_type, requires_grad, device): 

379 """Create a new parameter initialized with zeros. 

380 

381 Args: 

382 param_shape (tuple): The shape of the parameter. 

383 param_type: The dtype of the parameter. 

384 requires_grad (bool): Whether the parameter requires gradients. 

385 device: The device on which to create the parameter. 

386 

387 Returns: 

388 A new parameter tensor filled with zeros. 

389 """ 

390 raise NotImplementedError("Platform subclasses must implement new_zero_parameter") 

391 

392 @staticmethod 

393 def new_tensor(tensor_shape, tensor_type, device): 

394 """Create a new tensor with the specified shape, dtype, and device. 

395 

396 Args: 

397 tensor_shape (tuple): The shape of the tensor. 

398 tensor_type: The dtype of the tensor. 

399 device: The device on which to create the tensor. 

400 

401 Returns: 

402 A new tensor with uninitialized values. 

403 """ 

404 raise NotImplementedError("Platform subclasses must implement new_tensor") 

405 

406 @staticmethod 

407 def full_like(tensor, fill_value, dtype=None): 

408 """Create a tensor filled with a value, with same shape as input. 

409 

410 Args: 

411 tensor: The input tensor to copy shape from. 

412 fill_value: The value to fill the new tensor with. 

413 dtype: Optional dtype for the new tensor. If None, uses input tensor's dtype. 

414 

415 Returns: 

416 A new tensor filled with the specified value. 

417 """ 

418 raise NotImplementedError("Platform subclasses must implement full_like") 

419 

420 @staticmethod 

421 def set_tensor_requires_grad(input_tensor): 

422 """Enable gradient tracking for a tensor in-place. 

423 

424 Args: 

425 input_tensor: The tensor to enable gradients for. 

426 

427 Returns: 

428 The same tensor with requires_grad set to True. 

429 """ 

430 raise NotImplementedError("Platform subclasses must implement set_tensor_requires_grad") 

431 

432 @staticmethod 

433 def all_gather_into_tensor(data, group_info, async_op=False): 

434 """Gather tensors from all ranks into a single output tensor. 

435 

436 Args: 

437 data: The input tensor to gather. 

438 group_info: The process group for collective communication. 

439 async_op (bool): If True, returns a work handle for async operation. 

440 

441 Returns: 

442 The gathered tensor, or a tuple of (tensor, handle) if async_op is True. 

443 """ 

444 raise NotImplementedError("Platform subclasses must implement all_gather_into_tensor") 

445 

446 @staticmethod 

447 def all_reduce(data, group_info, async_op=False): 

448 """Reduce tensors across all ranks using specified operation. 

449 

450 Args: 

451 data: The input tensor to reduce. 

452 group_info: The process group for collective communication. 

453 async_op (bool): If True, returns a work handle for async operation. 

454 

455 Returns: 

456 The reduced tensor, or a tuple of (tensor, handle) if async_op is True. 

457 """ 

458 raise NotImplementedError("Platform subclasses must implement all_reduce") 

459 

460 @staticmethod 

461 def broadcast(data, src, group, async_op=False): 

462 """Broadcast tensor from source rank to all ranks in group. 

463 

464 Args: 

465 data: The tensor to broadcast (only valid on source rank). 

466 src (int): The source rank to broadcast from. 

467 group: The process group for collective communication. 

468 async_op (bool): If True, returns a work handle for async operation. 

469 

470 Returns: 

471 The broadcasted tensor, or a tuple of (tensor, handle) if async_op is True. 

472 """ 

473 raise NotImplementedError("Platform subclasses must implement broadcast") 

474 

475 @staticmethod 

476 def isend(tensor, dst=None, group=None, tag=0): 

477 """Send tensor asynchronously to destination rank. 

478 

479 Args: 

480 tensor: The tensor to send. 

481 dst (int, optional): The destination rank. Defaults to None. 

482 group: The process group for communication. Defaults to None. 

483 tag (int): A tag to identify the send operation. Defaults to 0. 

484 

485 Returns: 

486 A work handle that can be waited on. 

487 """ 

488 raise NotImplementedError("Platform subclasses must implement isend") 

489 

490 @staticmethod 

491 def irecv(tensor, src=None, group=None, tag=0): 

492 """Receive tensor asynchronously from source rank. 

493 

494 Args: 

495 tensor: The tensor buffer to receive data into. 

496 src (int, optional): The source rank. Defaults to None. 

497 group: The process group for communication. Defaults to None. 

498 tag (int): A tag to identify the receive operation. Defaults to 0. 

499 

500 Returns: 

501 A work handle that can be waited on. 

502 """ 

503 raise NotImplementedError("Platform subclasses must implement irecv") 

504 

505 @staticmethod 

506 def p2p_exchange(tensor, peer_rank: int, group=None): 

507 """Differentiable symmetric P2P exchange (send local tensor, receive peer's tensor). 

508 

509 Sends ``tensor`` to ``peer_rank`` and simultaneously receives the peer's 

510 tensor. The operation is differentiable: the backward pass performs the 

511 same symmetric exchange on the upstream gradient. 

512 

513 Args: 

514 tensor: Local tensor to send. 

515 peer_rank (int): Global rank of the communication peer. 

516 group: Process group. ``None`` uses the default group. 

517 

518 Returns: 

519 Tensor received from ``peer_rank``, with the same shape and dtype as 

520 the input ``tensor``. 

521 """ 

522 raise NotImplementedError("Platform subclasses must implement p2p_exchange") 

523 

524 @staticmethod 

525 def send_object_list(obj_list, dst=None, group=None): 

526 """Send a list of Python objects to destination rank. 

527 

528 Args: 

529 obj_list (list): The list of Python objects to send. 

530 dst (int, optional): The destination rank. Defaults to None. 

531 group: The process group for communication. Defaults to None. 

532 """ 

533 raise NotImplementedError("Platform subclasses must implement send_object_list") 

534 

535 @staticmethod 

536 def recv_object_list(obj_list, src=None, group=None): 

537 """Receive a list of Python objects from source rank. 

538 

539 Args: 

540 obj_list (list): The list buffer to receive objects into. 

541 src (int, optional): The source rank. Defaults to None. 

542 group: The process group for communication. Defaults to None. 

543 """ 

544 raise NotImplementedError("Platform subclasses must implement recv_object_list") 

545 

546 @staticmethod 

547 def reduce_scatter_tensor(data, group_info, async_op=False): 

548 """Reduce and scatter tensor across all ranks in group. 

549 

550 Args: 

551 data: The input tensor to reduce and scatter. 

552 group_info: The process group for collective communication. 

553 async_op (bool): If True, returns a work handle for async operation. 

554 

555 Returns: 

556 The scattered tensor chunk, or a tuple of (tensor, handle) if async_op is True. 

557 """ 

558 raise NotImplementedError("Platform subclasses must implement reduce_scatter_tensor") 

559 

560 @staticmethod 

561 def all_to_all_single(input_tensor, output_shape, group, async_op=False): 

562 """All-to-all single collective with optional async execution. 

563 

564 Args: 

565 input_tensor: Input tensor to scatter. 

566 output_shape: Shape of the pre-allocated output tensor. 

567 group: Process group (ProcessGroup for torch, group name string for mindspore). 

568 async_op: If True, returns a work handle; the output tensor is 

569 filled only after ``work.wait()`` is called. 

570 

571 Returns: 

572 Tuple ``(output, work)`` where *output* is the result tensor and 

573 *work* is the async handle (``None`` when ``async_op=False``). 

574 

575 Raises: 

576 NotImplementedError: Must be implemented by platform subclasses. 

577 """ 

578 raise NotImplementedError("Platform subclasses must implement all_to_all_single") 

579 

580 @staticmethod 

581 def differentiable_async_a2a_wait(x, work, out_perm, group, world_size, concat_dim, split_dim, 

582 handle_box=None): 

583 """Differentiable wrapper that waits for a pre-launched async A2A. 

584 

585 Wraps the wait-and-reconstruct step in the platform autograd mechanism 

586 so gradients flow correctly through the all-to-all communication. 

587 

588 The A2A direction is seq→head (forward): the output gathers along 

589 ``concat_dim`` (sequence grows from S/cp to S) and scatters along 

590 ``split_dim`` (heads shrink from H to H/ws). 

591 

592 In backward, launches an async head→seq A2A on the incoming gradient 

593 and appends ``(work, out_perm)`` to ``handle_box`` so the caller can 

594 wait just before the projection GEMM, achieving GEMM–A2A overlap. 

595 

596 Args: 

597 x: Original projection output tensor; anchors the op 

598 in the autograd graph. 

599 work: Async work handle from ``all_to_all_single(async_op=True)``. 

600 out_perm: Output buffer filled once ``work.wait()`` completes 

601 (shape ``[ws, ...]``). 

602 group: Process group for the reverse A2A in backward. 

603 world_size: CP/Ulysses degree. 

604 concat_dim: Dimension that is gathered (concatenated) in forward; 

605 typically the sequence dimension. 

606 split_dim: Dimension that is scattered (split) in forward; 

607 typically the head dimension. 

608 handle_box: Optional mutable list ``[]``. In backward, ``(work, out_perm)`` 

609 for the reverse A2A is appended here so the pre-hook can wait. 

610 

611 Returns: 

612 Result tensor with ``concat_dim`` gathered and ``split_dim`` split, 

613 connected to the autograd graph through *x*. 

614 

615 Raises: 

616 NotImplementedError: Must be implemented by platform subclasses. 

617 """ 

618 raise NotImplementedError("Platform subclasses must implement differentiable_async_a2a_wait") 

619 

620 @staticmethod 

621 def differentiable_sync_hook(x, hook_name: str, coordinator): 

622 """Identity operation that intercepts both forward and backward to call 

623 coordinator rendezvous, enabling deterministic comm/compute overlap. 

624 

625 This is the differentiable building block for dual-pipe schedules. 

626 In the forward pass the coordinator is invoked with the forward-side 

627 roles for ``hook_name``; in the backward pass it is invoked with the 

628 backward-side roles. The tensor value and gradient flow through 

629 unchanged. 

630 

631 Args: 

632 x: Input tensor. Returned as-is; gradients flow through. 

633 hook_name: One of ``"A"``, ``"B"``, ``"C"``, ``"D"`` identifying 

634 the position relative to MoE dispatch/combine. 

635 coordinator: A :class:`HookCoordinator` instance shared between the 

636 forward and backward threads. 

637 

638 Returns: 

639 The same tensor *x*, attached to the autograd graph so that the 

640 backward hook will fire. 

641 """ 

642 raise NotImplementedError("Platform subclasses must implement differentiable_sync_hook") 

643 

644 @staticmethod 

645 def differentiable_all_to_all_single(input_tensor, input_splits, output_splits, group): 

646 """Variable-split all-to-all single that supports gradient flow. 

647 

648 Unlike ``all_to_all_single`` (which is not differentiable), this method 

649 wraps the collective in an autograd function so gradients are correctly 

650 routed back through the reverse all-to-all in the backward pass. 

651 Intended for Expert Parallelism token dispatch / combine. 

652 

653 Args: 

654 input_tensor: Input tensor to scatter. Shape ``[sum(input_splits), *feature_dims]``. 

655 input_splits: Per-rank sizes of data sent from this rank (list of ints, 

656 length equal to ep_degree). 

657 output_splits: Per-rank sizes of data received by this rank (list of ints, 

658 length equal to ep_degree). 

659 group: Process group (ProcessGroup for torch, group name str for mindspore). 

660 

661 Returns: 

662 Output tensor of shape ``[sum(output_splits), *feature_dims]``. 

663 

664 Raises: 

665 NotImplementedError: Must be implemented by platform subclasses. 

666 """ 

667 raise NotImplementedError("Platform subclasses must implement differentiable_all_to_all_single") 

668 

669 @staticmethod 

670 def differentiable_all_to_all_single_async(input_tensor, input_splits, output_splits, group): 

671 """Async variant of :meth:`differentiable_all_to_all_single`. 

672 

673 Same semantics but launches the collective with ``async_op=True`` and 

674 only performs a stream-level ``wait`` — the host returns immediately 

675 after dispatching the kernel. Intended for dual-pipe comm/compute 

676 overlap paths where the paired COMPUTE side's rendezvous notify must 

677 fire right after kernel launch (not after the collective actually 

678 completes on device). 

679 

680 Args: 

681 input_tensor: Input tensor to scatter. Shape ``[sum(input_splits), *feature_dims]``. 

682 input_splits: Per-rank sizes of data sent from this rank. 

683 output_splits: Per-rank sizes of data received by this rank. 

684 group: Process group. 

685 

686 Returns: 

687 Output tensor of shape ``[sum(output_splits), *feature_dims]``. 

688 

689 Raises: 

690 NotImplementedError: Must be implemented by platform subclasses. 

691 """ 

692 raise NotImplementedError( 

693 "Platform subclasses must implement differentiable_all_to_all_single_async" 

694 ) 

695 

696 @staticmethod 

697 def arange(start, end=None, step=1, dtype=None, device=None): 

698 """Create a 1-D tensor with evenly spaced values. 

699 

700 Args: 

701 start: Start of interval (inclusive). If *end* is ``None``, 

702 treated as the stop value and *start* defaults to 0. 

703 end: End of interval (exclusive). Defaults to ``None``. 

704 step: Step size. Defaults to ``1``. 

705 dtype: Data type. ``None`` uses the framework default (int64). 

706 device: Target device. 

707 

708 Returns: 

709 1-D tensor ``[start, start+step, ..., end)``. 

710 

711 Raises: 

712 NotImplementedError: Must be implemented by platform subclasses. 

713 """ 

714 raise NotImplementedError("Platform subclasses must implement arange") 

715 

716 @staticmethod 

717 def zeros(size, dtype=None, device=None): 

718 """Create a zero-filled tensor of the given shape. 

719 

720 Args: 

721 size: Shape of the tensor (a single tuple/list). 

722 dtype: Desired data type. ``None`` uses the framework default (float32). 

723 device: Target device. ``None`` uses the framework default. 

724 

725 Returns: 

726 Zero-filled tensor of the specified shape. 

727 

728 Raises: 

729 NotImplementedError: Must be implemented by platform subclasses. 

730 """ 

731 raise NotImplementedError("Platform subclasses must implement zeros") 

732 

733 @staticmethod 

734 def parameters_dict(cell): 

735 """Get the parameters dictionary of a cell/module. 

736 

737 Args: 

738 cell: The cell or module to get parameters from. 

739 

740 Returns: 

741 dict: A dictionary mapping parameter names to parameters. 

742 """ 

743 raise NotImplementedError("Platform subclasses must implement parameters_dict") 

744 

745 @staticmethod 

746 def get_model_state_dict(model, *, options=None): 

747 """Get the state dictionary of a model. 

748 

749 Args: 

750 model: The model to extract state from. 

751 options: Optional configuration for state dict extraction. 

752 

753 Returns: 

754 dict: The state dictionary containing model parameters and buffers. 

755 """ 

756 raise NotImplementedError( 

757 "Platform subclasses must implement get_model_state_dict" 

758 ) 

759 

760 @staticmethod 

761 def save_checkpoint(cell, file_path: str, ckpt_format: str = "safetensors") -> None: 

762 """Save a cell/module checkpoint to file. 

763 

764 Args: 

765 cell: The cell or module to save. 

766 file_path (str): The path to save the checkpoint to. 

767 ckpt_format (str): The file format. 

768 """ 

769 raise NotImplementedError("Platform subclasses must implement save_checkpoint") 

770 

771 @staticmethod 

772 def load_checkpoint(file_path: str, ckpt_format: str = "safetensors") -> dict: 

773 """Load a checkpoint from file. 

774 

775 Args: 

776 file_path (str): The path to load the checkpoint from. 

777 ckpt_format (str): The file format. 

778 

779 Returns: 

780 dict: The loaded checkpoint state dictionary. 

781 """ 

782 raise NotImplementedError("Platform subclasses must implement load_checkpoint") 

783 

784 def _create_group(self, rank_list): 

785 """Create a new process group with the specified ranks. 

786 

787 Internal method to be implemented by subclasses. 

788 

789 Args: 

790 rank_list (list): List of ranks to include in the group. 

791 

792 Returns: 

793 The newly created process group. 

794 """ 

795 raise NotImplementedError("Platform subclasses must implement _create_group") 

796 

797 def new_stream(self): 

798 """Create a new compute stream for asynchronous operations. 

799 

800 Returns: 

801 A new stream object for the current device. 

802 """ 

803 raise NotImplementedError("Platform subclasses must implement new_stream") 

804 

805 def get_stream_context(self): 

806 """Get a context manager for executing operations on a specific stream. 

807 

808 Returns: 

809 A context manager that can be used with 'with' statement to set stream. 

810 """ 

811 raise NotImplementedError("Platform subclasses must implement get_stream_context") 

812 

813 @staticmethod 

814 def get_tensor_transform(): 

815 """Get the tensor transformation utilities for the current framework. 

816 

817 Returns: 

818 A module or object containing tensor transformation functions. 

819 """ 

820 raise NotImplementedError("Platform subclasses must implement get_tensor_transform") 

821 

822 @staticmethod 

823 def construct_strided_slice(x, begin, end, stride): 

824 """Construct a strided slice operation on a tensor. 

825 

826 Args: 

827 x: The input tensor to slice. 

828 begin: The starting indices for each dimension. 

829 end: The ending indices for each dimension. 

830 stride: The stride for each dimension. 

831 

832 Returns: 

833 The sliced tensor. 

834 """ 

835 raise NotImplementedError("Platform subclasses must implement construct_strided_slice") 

836 

837 @staticmethod 

838 def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None): 

839 """Split inputs into micro-batches for pipeline parallelism. 

840 

841 Args: 

842 micro_batch_num (int): The number of micro-batches to create. 

843 args_batch_dim (list, optional): Batch dimension for each positional arg. 

844 kwargs_batch_dim (dict, optional): Batch dimension for each keyword arg. 

845 

846 Returns: 

847 A decorator that splits function inputs into micro-batches. 

848 """ 

849 raise NotImplementedError("Platform subclasses must implement micro_batch") 

850 

851 @staticmethod 

852 def get_symmetric_memory_handler(): 

853 raise NotImplementedError("Platform subclasses must implement get_symmetric_memory_handler") 

854 

855 @staticmethod 

856 def load_into_param(param, data): 

857 raise NotImplementedError("Platform subclasses must implement load_into_param") 

858 

def create_group(self, rank_list):
    """Return a communication group for *rank_list*, creating it on demand.

    Groups are cached globally under their sorted rank tuple, so a second
    request for the same set of ranks reuses the existing group.

    Args:
        rank_list (list): Ranks that make up the group.

    Returns:
        The cached or newly created process group.
    """
    cache_key = str(tuple(sorted(rank_list)))
    if cache_key in EXISTING_COMM_GROUPS:
        return EXISTING_COMM_GROUPS[cache_key]

    new_group = self._create_group(rank_list)
    EXISTING_COMM_GROUPS[cache_key] = new_group
    return new_group

878 

@staticmethod
def _process_current_handle():
    """Complete the pending gradient handle and run its post-process hook.

    No-op when no handle is pending; the callback is skipped when unset.
    """
    handle = Platform.current_grad_handle
    if handle is None:
        return

    handle.wait()
    callback = Platform.post_grad_handle_process
    if callback is None:
        return
    # pylint: disable=E1102
    callback()

893 

def set_grad_reduce_handle(self, handle, post_process=None):
    """Install a new gradient-reduction handle after draining the old one.

    Completes any pending handle on the dedicated grad-sync stream
    (created lazily), then records *handle* and its optional callback as
    the new pending state.

    Args:
        handle: Async work handle for the gradient reduction.
        post_process (callable, optional): Callback run once the handle completes.
    """
    if Platform.grad_sync_stream is None:
        Platform.grad_sync_stream = self.new_stream()
    ctx = self.get_stream_context()
    with ctx(Platform.grad_sync_stream):
        Platform._process_current_handle()
    Platform.current_grad_handle = handle
    Platform.post_grad_handle_process = post_process

911 

def wait_grad_handle(self):
    """Block until the pending gradient handle completes, then clear state.

    Drains the handle on the grad-sync stream, records an event there, and
    waits on that event so the caller observes completion before the
    handle bookkeeping is reset.
    """
    if Platform.current_grad_handle is None:
        return
    if Platform.grad_sync_stream is None:
        Platform.grad_sync_stream = self.new_stream()
    ctx = self.get_stream_context()
    with ctx(Platform.grad_sync_stream):
        Platform._process_current_handle()
    # NOTE(review): event recording/wait ordering reconstructed from a
    # formatting-mangled source — confirm against the original file.
    sync_event = Platform.grad_sync_stream.record_event()
    sync_event.wait()
    Platform.current_grad_handle = None
    Platform.post_grad_handle_process = None

929 

930 @staticmethod 

931 def all_gather_object(object_list, obj, group=None) -> None: 

932 """Gather Python objects from all ranks into a list. 

933 

934 Each rank contributes its object, and all ranks receive the complete list. 

935 

936 Args: 

937 object_list (list): List to store gathered objects (output parameter). 

938 obj: The Python object from this rank to contribute. 

939 group: The process group for communication. Defaults to None (default group). 

940 """ 

941 raise NotImplementedError("Platform subclasses must implement all_gather_object") 

942 

943 @staticmethod 

944 def barrier(group=None, async_op: bool = False, device_ids=None) -> Any: 

945 """Synchronize all processes in the given process group. 

946 

947 Each rank blocks until every rank in the group enters this collective (when ``async_op`` 

948 is False), or returns an async handle that must be completed before proceeding. 

949 

950 Args: 

951 group: The process group or communication group. ``None`` uses the default group. 

952 async_op (bool): If True, returns a backend-specific async work handle. Default: False. 

953 device_ids: Optional device id list; semantics depend on the backend. 

954 

955 Returns: 

956 Async work handle when ``async_op`` is True; otherwise ``None`` (unless the rank 

957 is not in the group, in which case the backend may return ``None``). 

958 """ 

959 raise NotImplementedError("Platform subclasses must implement barrier") 

960 

961 @staticmethod 

962 def init_process_group( 

963 backend: Optional[str] = None, 

964 *, 

965 init_method: Optional[str] = None, 

966 timeout: Optional[timedelta] = None, 

967 world_size: int = -1, 

968 rank: int = -1, 

969 store: Any = None, 

970 pg_options: Any = None, 

971 device_id: Any = None 

972 ) -> None: 

973 """ 

974 Initialize the default distributed process group. 

975 

976 Args: 

977 backend: The backend to use for distributed communication 

978 init_method: URL specifying how to initialize the process group 

979 timeout: Timeout for operations executed against the process group 

980 world_size: Number of processes participating in the job 

981 rank: Rank of the current process 

982 store: Key/value store for exchanging connection information 

983 pg_options: Process group options for backend-specific configurations 

984 device_id: Specific device this process will work on 

985 

986 Raises: 

987 NotImplementedError: This method must be implemented by subclasses 

988 """ 

989 raise NotImplementedError("Platform subclasses must implement init_process_group") 

990 

991 @staticmethod 

992 def destroy_process_group(group=None) -> None: 

993 """ 

994 Destroy a given process group. 

995 

996 Args: 

997 group: The process group to be destroyed. If None, destroys the default group. 

998 

999 Raises: 

1000 NotImplementedError: This method must be implemented by subclasses 

1001 """ 

1002 raise NotImplementedError("Platform subclasses must implement destroy_process_group") 

1003 

1004 @staticmethod 

1005 def get_process_group_ranks(group=None) -> list[int]: 

1006 """ 

1007 Get rank list of the given process group. 

1008 

1009 Args: 

1010 group: The process group to get ranks from. If None, uses the default group. 

1011 

1012 Returns: 

1013 List of ranks in the specified process group. 

1014 

1015 Raises: 

1016 NotImplementedError: This method must be implemented by subclasses 

1017 """ 

1018 raise NotImplementedError("Platform subclasses must implement get_process_group_ranks") 

1019 

1020 @staticmethod 

1021 def get_backend(group=None): 

1022 """ 

1023 Get the backend of the given process group. 

1024 Args: 

1025 group: The process group to get backend from. If None, uses the default group. 

1026 

1027 Returns: 

1028 The backend name of the specified process group. 

1029 

1030 Raises: 

1031 NotImplementedError: This method must be implemented by subclasses 

1032 """ 

1033 raise NotImplementedError("Platform subclasses must implement get_backend") 

1034 

1035 @staticmethod 

1036 def split_group(parent_pg: Any = None, 

1037 split_ranks: Optional[list] = None, 

1038 timeout: Optional[timedelta] = None, 

1039 pg_options: Optional[Any] = None, 

1040 group_desc: Optional[str] = None, 

1041 ) -> Any: 

1042 """Create a split group relative to the parent process group. 

1043 

1044 Args: 

1045 parent_pg: The parent process group to split from. 

1046 split_ranks (list, optional): Ranks to include in the split group. 

1047 timeout (timedelta, optional): Timeout for operations. 

1048 pg_options: Process group options for backend-specific configurations. 

1049 group_desc (str, optional): Description of the group. 

1050 

1051 Returns: 

1052 The new split process group. 

1053 """ 

1054 raise NotImplementedError("Platform subclasses must implement split_group") 

1055 

1056 @staticmethod 

1057 def get_group_local_rank(group=None) -> int: 

1058 """Get the local rank within the given process group. 

1059 

1060 Args: 

1061 group: The process group to query. If None, uses the default group. 

1062 

1063 Returns: 

1064 int: The local rank within the group. 

1065 """ 

1066 raise NotImplementedError("Platform subclasses must implement get_group_local_rank") 

1067 

1068 @staticmethod 

1069 def no_grad(): 

1070 """Get a context manager to disable gradient computation. 

1071 

1072 Returns: 

1073 A context manager that disables gradient tracking. 

1074 """ 

1075 raise NotImplementedError("Platform subclasses must implement no_grad") 

1076 

1077 @staticmethod 

1078 def cat(tensors, dim=0): 

1079 """Concatenate tensors along a dimension.""" 

1080 raise NotImplementedError("Platform subclasses must implement cat") 

1081 

1082 @staticmethod 

1083 def empty_like(tensor, *, dtype=None, device=None, pin_memory=False): 

1084 """Create an uninitialized tensor with the same shape as input. 

1085 

1086 Args: 

1087 tensor: The input tensor to copy shape from. 

1088 dtype: Optional dtype for the new tensor. If None, uses input tensor's dtype. 

1089 device: Optional device for the new tensor. If None, uses input tensor's device. 

1090 pin_memory (bool): If True, allocate pinned memory for faster CPU-GPU transfer. 

1091 

1092 Returns: 

1093 An uninitialized tensor with the same shape as input. 

1094 """ 

1095 raise NotImplementedError("Platform subclasses must implement empty_like") 

1096 

def get_current_stream(self):
    """Return the device's currently active compute stream.

    Returns:
        The current stream object.
    """
    raise NotImplementedError("Platform subclasses must implement get_current_stream")

1104 

def new_event(self):
    """Create a fresh event object for stream synchronization.

    Returns:
        A new event.
    """
    raise NotImplementedError("Platform subclasses must implement new_event")

1112 

def tree_map(self, fn, tree):
    """Map *fn* over every tensor in a nested container.

    Args:
        fn (callable): Function applied to each tensor leaf.
        tree: Nested structure (list, tuple, dict) holding tensors.

    Returns:
        A structure of the same shape with *fn* applied to each tensor.
    """
    raise NotImplementedError("Platform subclasses must implement tree_map")

1124 

1125 @staticmethod 

1126 def is_linear_module(module) -> bool: 

1127 """Check whether *module* is a linear/dense layer for the current framework. 

1128 

1129 Args: 

1130 module: The module instance to check. 

1131 

1132 Returns: 

1133 True if *module* is the framework's linear layer type. 

1134 """ 

1135 raise NotImplementedError("Platform subclasses must implement is_linear_module") 

1136 

1137 @staticmethod 

1138 def is_embedding_module(module) -> bool: 

1139 """Check whether *module* is an embedding layer for the current framework. 

1140 

1141 Args: 

1142 module: The module instance to check. 

1143 

1144 Returns: 

1145 True if *module* is the framework's embedding layer type. 

1146 """ 

1147 raise NotImplementedError("Platform subclasses must implement is_embedding_module") 

1148 

1149 @staticmethod 

1150 def register_forward_pre_hook(module, hook, prepend=False, with_kwargs=False): 

1151 """Register a forward pre-hook on a module. 

1152 

1153 Args: 

1154 module: The module to register the hook on. 

1155 hook (callable): The hook function to register. 

1156 prepend (bool): If True, prepend the hook to existing hooks. 

1157 with_kwargs (bool): If True, hook receives both args and kwargs. 

1158 

1159 Returns: 

1160 A handle that can be used to remove the hook. 

1161 """ 

1162 return module.register_forward_pre_hook(hook, prepend=prepend, with_kwargs=with_kwargs) 

1163 

1164 @staticmethod 

1165 def register_full_backward_hook(module, hook, prepend=False): 

1166 """Register a full backward hook on a module. 

1167 

1168 Args: 

1169 module: The module to register the hook on. 

1170 hook (callable): The hook function to register. 

1171 prepend (bool): If True, prepend the hook to existing hooks. 

1172 

1173 Returns: 

1174 A handle that can be used to remove the hook. 

1175 """ 

1176 return module.register_full_backward_hook(hook, prepend) 

1177 

1178 @staticmethod 

1179 def register_full_backward_pre_hook(module, hook, prepend=False): 

1180 """Register a full backward pre-hook on a module. 

1181 

1182 Args: 

1183 module: The module to register the hook on. 

1184 hook (callable): The hook function to register. 

1185 prepend (bool): If True, prepend the hook to existing hooks. 

1186 

1187 Returns: 

1188 A handle that can be used to remove the hook. 

1189 """ 

1190 return module.register_full_backward_pre_hook(hook, prepend) 

1191 

@property
def checkpoint(self):
    """The framework's activation-checkpointing function.

    Returns:
        The checkpoint callable for the current framework.
    """
    raise NotImplementedError("Platform subclasses must implement checkpoint")

1200 

1201 @staticmethod 

1202 def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs): 

1203 """Wrap a module with checkpoint functionality. 

1204 

1205 Args: 

1206 module: The module to wrap with checkpointing. 

1207 checkpoint_fn: Optional custom checkpoint function. 

1208 **checkpoint_fn_kwargs: Additional kwargs for checkpoint function. 

1209 

1210 Returns: 

1211 The wrapped module with checkpointing enabled. 

1212 """ 

1213 raise NotImplementedError("Platform subclasses must implement ckpt_wrapper") 

1214 

1215 @staticmethod 

1216 def swap_wrapper(module, policy_fn=None): 

1217 """Wrap a module with activation swap functionality. 

1218 

1219 Args: 

1220 module: The module to wrap with activation swap. 

1221 policy_fn: Optional per-tensor swap policy function. 

1222 

1223 Returns: 

1224 The wrapped module with activation swap enabled. 

1225 """ 

1226 raise NotImplementedError("Platform subclasses must implement swap_wrapper") 

1227 

@property
def noop_context_fn(self):
    """A do-nothing context function used by checkpointing.

    Returns:
        A context function that performs no operation.
    """
    raise NotImplementedError("Platform subclasses must implement noop_context_fn")

1236 

1237 @staticmethod 

1238 def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False): 

1239 """Create contexts for selective activation checkpointing. 

1240 

1241 Args: 

1242 policy_fn_or_list: A policy function or list of layer names to checkpoint. 

1243 allow_cache_entry_mutation (bool): Whether to allow cache entry mutation. 

1244 

1245 Returns: 

1246 Context functions for selective checkpointing. 

1247 """ 

1248 raise NotImplementedError("Platform subclasses must implement create_selective_checkpoint_contexts") 

1249 

1250 @staticmethod 

1251 def async_save_on_cpu(policy_fn=None): 

1252 """Create an async CPU offload context for activation checkpointing. 

1253 

1254 Args: 

1255 policy_fn: Optional policy function to determine which activations to offload. 

1256 

1257 Returns: 

1258 Context manager for async CPU offloading during checkpointing. 

1259 """ 

1260 raise NotImplementedError("Platform subclasses must implement async_save_on_cpu") 

1261 

1262 @staticmethod 

1263 def get_element_size(tensor): 

1264 """Get Tensor Element Size""" 

1265 raise NotImplementedError("Platform subclasses must implement get_element_size") 

1266 

1267 @staticmethod 

1268 def tensor_to_numpy(tensor) -> np.ndarray: 

1269 """Convert a framework tensor to a NumPy array. 

1270 

1271 Args: 

1272 tensor: The tensor to convert. 

1273 

1274 Returns: 

1275 np.ndarray: The tensor data as a NumPy array. 

1276 """ 

1277 raise NotImplementedError("Platform subclasses must implement tensor_to_numpy") 

1278 

1279 @staticmethod 

1280 def profiler_record(name): 

1281 """Record a profiler event with the given name. 

1282 

1283 Args: 

1284 name (str): The name of the profiler event. 

1285 

1286 Returns: 

1287 A context manager or decorator for profiling a code region. 

1288 """ 

1289 raise NotImplementedError("Platform subclasses must implement profiler_record") 

1290 

def cast_fp_tensor(self, dtype, x):
    """Cast *x* to *dtype* when it is a floating-point tensor.

    Args:
        dtype: Target dtype.
        x: Input tensor.

    Returns:
        The cast tensor, or *x* unchanged when it is not floating-point.
    """
    raise NotImplementedError("Platform subclasses must implement cast_fp_tensor")

1302 

def apply_to_tensors(self, fn, container):
    """Recursively apply *fn* to every tensor inside *container*.

    Handles arbitrarily nested lists, tuples, and dicts.

    Args:
        fn (callable): Function applied to each tensor.
        container: Nested structure holding tensors.

    Returns:
        The same structure with *fn* applied to every tensor.
    """
    raise NotImplementedError("Platform subclasses must implement apply_to_tensors")

1316 

1317 @staticmethod 

1318 def clip_grad_norm_( 

1319 parameters, max_norm: float, norm_type: float = 2.0, 

1320 error_if_nonfinite: bool = False, foreach=None, 

1321 ): 

1322 """Compute and clip gradient norms for distributed models. 

1323 

1324 Communication is derived from each parameter's DTensor spec. 

1325 Subclasses must implement this method. 

1326 

1327 Args: 

1328 parameters: An ``nn.Module``, a single ``Tensor``, or an 

1329 iterable of ``Tensor`` s whose gradients to clip. 

1330 max_norm: Maximum allowed gradient norm. 

1331 norm_type: Type of the norm (default ``2.0``). 

1332 error_if_nonfinite: If ``True``, raise when total norm is 

1333 non-finite. Default ``False``. 

1334 foreach: Unused, accepted for API compatibility. 

1335 

1336 Returns: 

1337 The total (unclipped) gradient norm. 

1338 """ 

1339 raise NotImplementedError( 

1340 "Platform subclasses must implement clip_grad_norm_" 

1341 ) 

1342 

@staticmethod
def get_created_group(rank_list: Union[list[int], tuple[int]]):
    """Look up a previously created process group by its rank list.

    Args:
        rank_list (Union[list[int], tuple[int]]): Ranks of the group.

    Returns:
        The cached process group, or ``None`` when no such group exists.
    """
    # dict.get returns None for a missing key, matching the explicit
    # membership check it replaces.
    return EXISTING_COMM_GROUPS.get(str(tuple(sorted(rank_list))))

1357 

@classmethod
def mark_created_groups(cls, process_group: Union[Any, list[Any]]) -> None:
    """Record process groups in the global cache so they can be reused.

    Args:
        process_group (Union[Any, list[Any]]): One group or a list of groups.
    """
    groups = process_group if isinstance(process_group, list) else [process_group]
    for pg in groups:
        ranks = cls.get_process_group_ranks(pg)
        EXISTING_COMM_GROUPS[str(tuple(sorted(ranks)))] = pg

1371 

@property
def meta_device(self):
    """The framework's meta device for storage-free tensor creation.

    Tensors on the meta device carry shape/dtype only, which is useful for
    shape inference and deferred model initialization.

    Returns:
        The meta device object of the current framework.
    """
    raise NotImplementedError("Platform subclasses must implement meta_device")

1383 

def init_on_device(self, device, include_buffers=False):
    """Return a context manager that initializes module parameters on *device*.

    Args:
        device: Target device for parameter initialization.
        include_buffers (bool): Also place buffers on the device when True.

    Returns:
        A context manager for device-scoped initialization.
    """
    raise NotImplementedError("Platform subclasses must implement init_on_device")

1395 

def str_to_dtype(self, dtype_str: str) -> Any:
    """Resolve a serialized dtype string into the backend dtype object.

    Args:
        dtype_str (str): Dtype identifier from checkpoint metadata,
            e.g. ``torch.float32``.

    Returns:
        The framework dtype object (e.g. ``torch.dtype`` or MindSpore dtype).
    """
    raise NotImplementedError("Platform subclasses must implement str_to_dtype")

1407 

def list_to_size(self, size_list: list[int]) -> Any:
    """Convert a checkpoint-metadata shape list to the framework's size type.

    Args:
        size_list (list[int]): Global tensor shape as a list of ints.

    Returns:
        The framework-specific size object (e.g. ``torch.Size``).
    """
    raise NotImplementedError("Platform subclasses must implement list_to_size")