Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / platform.py: 64%
347 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""framework platform api"""
16import os
17from datetime import timedelta
18from enum import auto, Enum
19from typing import Optional, Any, Union
21import numpy as np
23# Environment variable name used to specify the AI framework platform to use
24HYPER_PARALLEL_PLATFORM = "HYPER_PARALLEL_PLATFORM"
26# Identifier for the MindSpore framework
27HYPER_PARALLEL_PLATFORM_MINDSPORE = "mindspore"
29# Identifier for the PyTorch framework
30HYPER_PARALLEL_PLATFORM_TORCH = "torch"
class PlatformType(Enum):
    """Enumeration of supported deep-learning framework platform types."""

    # Explicit values: identical to what enum.auto() assigns in declaration order.
    MINDSPORE = 1
    PYTORCH = 2
# Global platform instance, used to cache the created platform object.
# Populated lazily by get_mindspore_platform() / get_torch_platform().
platform = None
def get_mindspore_platform():
    """Build a MindSpore platform instance and store it in the module cache.

    Returns:
        MindSporePlatform: The freshly created MindSpore platform instance.
    """
    global platform
    # Deferred import: MindSpore is only required when this platform is chosen.
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.platform import MindSporePlatform
    platform = MindSporePlatform()
    return platform
def get_torch_platform():
    """Build a PyTorch platform instance and store it in the module cache.

    Returns:
        TorchPlatform: The freshly created PyTorch platform instance.
    """
    global platform
    # Deferred import: PyTorch is only required when this platform is chosen.
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.platform import TorchPlatform
    platform = TorchPlatform()
    return platform
def get_platform():
    """Obtain (and cache) a framework platform instance.

    Returns the appropriate AI framework platform instance based on environment
    variables or a default priority order. The lookup priority is:

    1. Platform specified by the ``HYPER_PARALLEL_PLATFORM`` environment variable
    2. MindSpore platform (default preferred choice)
    3. PyTorch platform (fallback option)

    An unrecognized environment-variable value falls through to the default
    priority order (same behavior as before).

    Returns:
        Platform: An instance of the framework platform.

    Raises:
        ImportError: Raised when none of the supported frameworks are available.
    """
    if platform is not None:
        return platform
    # os.environ values are always str, so the previous
    # `is not None and isinstance(..., str)` guard was redundant;
    # defaulting to "" lets us lower-case unconditionally.
    platform_type = os.environ.get(HYPER_PARALLEL_PLATFORM, "").lower()
    if platform_type == HYPER_PARALLEL_PLATFORM_MINDSPORE:
        return get_mindspore_platform()
    if platform_type == HYPER_PARALLEL_PLATFORM_TORCH:
        return get_torch_platform()
    try:
        return get_mindspore_platform()
    except ImportError:
        return get_torch_platform()
102EXISTING_COMM_GROUPS = {}
class Platform:
    """Abstract framework platform API.

    Framework-specific subclasses (MindSpore / PyTorch) implement the methods
    below; most base implementations raise ``NotImplementedError``.
    """

    # Pending async gradient-reduction work handle (class-wide state).
    current_grad_handle = None
    # Optional callback executed after current_grad_handle completes.
    post_grad_handle_process = None
    # Dedicated stream used to synchronize gradient reduction.
    grad_sync_stream = None
    @staticmethod
    def get_rank():
        """Get the rank of the current process in the default process group.

        Returns:
            int: The rank of the current process.
        """
        raise NotImplementedError("Platform subclasses must implement get_rank")
    @staticmethod
    def get_global_rank(group, group_rank):
        """Convert a group rank to its global rank.

        Args:
            group: The process group to query.
            group_rank (int): The rank within the group.

        Returns:
            int: The global rank corresponding to the group rank.
        """
        raise NotImplementedError("Platform subclasses must implement get_global_rank")
    @staticmethod
    def get_world_size():
        """Get the total number of processes in the default process group.

        Returns:
            int: The world size (total number of processes).
        """
        raise NotImplementedError("Platform subclasses must implement get_world_size")
    @staticmethod
    def get_op_name(func):
        """Get the canonical name of an operator function.

        Args:
            func: The operator function to query.

        Returns:
            str: The canonical name of the operator.
        """
        raise NotImplementedError("Platform subclasses must implement get_op_name")
    @staticmethod
    def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
        """Perform differentiable all-gather and concatenate tensors along a dimension.

        Args:
            data: The input tensor to gather.
            group: The process group for collective communication.
            concat_size (int): The size to concatenate along concat_dim.
            concat_dim (int): The dimension along which to concatenate.

        Returns:
            The concatenated tensor after all-gather operation.
        """
        raise NotImplementedError("Platform subclasses must implement differentiable_all_gather_concat")
    @staticmethod
    def chunk(data, split_dim, split_size, index):
        """Split tensor along a dimension and return the chunk at the given index.

        Args:
            data: The input tensor to split.
            split_dim (int): The dimension along which to split.
            split_size (int): The size of each split chunk.
            index (int): The index of the chunk to return.

        Returns:
            The tensor chunk at the specified index.
        """
        raise NotImplementedError("Platform subclasses must implement chunk")
    @staticmethod
    def differentiable_all_to_all(input_data, output_shape, group):
        """Perform differentiable all-to-all communication.

        Args:
            input_data: The input tensor to redistribute.
            output_shape: The shape of the output tensor.
            group: The process group for collective communication.

        Returns:
            The output tensor after all-to-all operation.
        """
        raise NotImplementedError("Platform subclasses must implement differentiable_all_to_all")
    @staticmethod
    def tensor_type_cast(input_data, cast_type):
        """Cast tensor to a specified dtype.

        Args:
            input_data: The input tensor to cast.
            cast_type: The target dtype to cast to.

        Returns:
            The tensor cast to the specified dtype.
        """
        raise NotImplementedError("Platform subclasses must implement tensor_type_cast")
    @staticmethod
    def is_tensor(obj: Any) -> bool:
        """Return True if ``obj`` is this framework's tensor type."""
        raise NotImplementedError("Platform subclasses must implement is_tensor")
    @staticmethod
    def get_tensor_storage_size(tensor: Any) -> int:
        """Return serialized byte size (numel * element size) for this framework's tensor."""
        raise NotImplementedError("Platform subclasses must implement get_tensor_storage_size")
    @staticmethod
    def differentiable_all_reduce(data, op, group):
        """Perform differentiable all-reduce operation.

        Args:
            data: The input tensor to reduce.
            op: The reduction operation (e.g., sum, max, min).
            group: The process group for collective communication.

        Returns:
            The reduced tensor with gradients supported.
        """
        raise NotImplementedError("Platform subclasses must implement differentiable_all_reduce")
    @staticmethod
    def differentiable_reduce_scatter(data, dev_num, axis, op, group):
        """Perform differentiable reduce-scatter operation.

        Args:
            data: The input tensor to reduce and scatter.
            dev_num (int): The number of devices to scatter across.
            axis (int): The axis along which to scatter.
            op: The reduction operation (e.g., sum, max, min).
            group: The process group for collective communication.

        Returns:
            The scattered tensor chunk with gradients supported.
        """
        raise NotImplementedError("Platform subclasses must implement differentiable_reduce_scatter")
251 @staticmethod
252 def init_parameters(module, stage_index):
253 """Initialize parameters for a module at a specific pipeline stage.
255 This method is primarily needed for MindSpore platform which requires
256 explicit parameter initialization interface.
258 Args:
259 module: The module whose parameters need to be initialized.
260 stage_index (int): The pipeline stage index for the module.
262 Raises:
263 ValueError: If module is None or stage_index is negative.
264 """
265 if module is None:
266 raise ValueError("input module must not be none.")
267 if stage_index < 0:
268 raise ValueError("input stage_index must be positive.")
    @staticmethod
    def get_cell_construct(cell):
        """Get the construct (forward) function of a cell/module.

        Args:
            cell: The cell or module to get the construct function from.

        Returns:
            The construct/forward callable of the cell.
        """
        raise NotImplementedError("Platform subclasses must implement get_cell_construct")
    @staticmethod
    def get_cells_and_names(cell):
        """Get all nested cells/modules and their names.

        Args:
            cell: The root cell or module to traverse.

        Returns:
            list: A list of tuples containing (name, cell) pairs.
        """
        raise NotImplementedError("Platform subclasses must implement get_cells_and_names")
    @staticmethod
    def search_parameter_by_name(cell, param_name: str):
        """Search for a parameter by name within a cell/module.

        Args:
            cell: The cell or module to search in.
            param_name (str): The name of the parameter to find.

        Returns:
            The parameter if found, otherwise None.
        """
        raise NotImplementedError("Platform subclasses must implement search_parameter_by_name")
    @staticmethod
    def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
        """Update a parameter by name within a cell/module.

        Args:
            cell: The cell or module containing the parameter.
            result (tuple): A tuple containing (param_name, parameter) to update.
            new_param: The new parameter value to set.

        Returns:
            bool: True if update was successful, False otherwise.
        """
        raise NotImplementedError("Platform subclasses must implement update_parameter_by_name")
    @staticmethod
    def set_layout_into_parameter(param, layout):
        """Attach a DTensor layout to a parameter.

        Args:
            param: The parameter to attach the layout to.
            layout: The DTensor layout describing tensor distribution.
        """
        raise NotImplementedError("Platform subclasses must implement set_layout_into_parameter")
    @staticmethod
    def get_param_local_shape(param):
        """Get the local shape of a distributed parameter.

        Args:
            param: The parameter to query.

        Returns:
            tuple: The local shape of the parameter shard.
        """
        raise NotImplementedError("Platform subclasses must implement get_param_local_shape")
    @staticmethod
    def get_param_local_data(param):
        """Get the local data tensor of a distributed parameter.

        Args:
            param: The parameter to query.

        Returns:
            The local tensor data of the parameter shard.
        """
        raise NotImplementedError("Platform subclasses must implement get_param_local_data")
    @staticmethod
    def update_param_data(param, data):
        """Update the data of a parameter with new tensor data.

        Args:
            param: The parameter to update.
            data: The new tensor data to assign.
        """
        raise NotImplementedError("Platform subclasses must implement update_param_data")
    @staticmethod
    def get_param_type_size(param):
        """Get the size in bytes of a parameter's dtype.

        Args:
            param: The parameter to query.

        Returns:
            int: The size in bytes of the parameter's data type.
        """
        raise NotImplementedError("Platform subclasses must implement get_param_type_size")
    @staticmethod
    def new_zero_parameter(param_shape, param_type, requires_grad, device):
        """Create a new parameter initialized with zeros.

        Args:
            param_shape (tuple): The shape of the parameter.
            param_type: The dtype of the parameter.
            requires_grad (bool): Whether the parameter requires gradients.
            device: The device on which to create the parameter.

        Returns:
            A new parameter tensor filled with zeros.
        """
        raise NotImplementedError("Platform subclasses must implement new_zero_parameter")
    @staticmethod
    def new_tensor(tensor_shape, tensor_type, device):
        """Create a new tensor with the specified shape, dtype, and device.

        Args:
            tensor_shape (tuple): The shape of the tensor.
            tensor_type: The dtype of the tensor.
            device: The device on which to create the tensor.

        Returns:
            A new tensor with uninitialized values.
        """
        raise NotImplementedError("Platform subclasses must implement new_tensor")
    @staticmethod
    def full_like(tensor, fill_value, dtype=None):
        """Create a tensor filled with a value, with same shape as input.

        Args:
            tensor: The input tensor to copy shape from.
            fill_value: The value to fill the new tensor with.
            dtype: Optional dtype for the new tensor. If None, uses input tensor's dtype.

        Returns:
            A new tensor filled with the specified value.
        """
        raise NotImplementedError("Platform subclasses must implement full_like")
    @staticmethod
    def set_tensor_requires_grad(input_tensor):
        """Enable gradient tracking for a tensor in-place.

        Args:
            input_tensor: The tensor to enable gradients for.

        Returns:
            The same tensor with requires_grad set to True.
        """
        raise NotImplementedError("Platform subclasses must implement set_tensor_requires_grad")
    @staticmethod
    def all_gather_into_tensor(data, group_info, async_op=False):
        """Gather tensors from all ranks into a single output tensor.

        Args:
            data: The input tensor to gather.
            group_info: The process group for collective communication.
            async_op (bool): If True, returns a work handle for async operation.

        Returns:
            The gathered tensor, or a tuple of (tensor, handle) if async_op is True.
        """
        raise NotImplementedError("Platform subclasses must implement all_gather_into_tensor")
    @staticmethod
    def all_reduce(data, group_info, async_op=False):
        """Reduce tensors across all ranks using specified operation.

        Args:
            data: The input tensor to reduce.
            group_info: The process group for collective communication.
            async_op (bool): If True, returns a work handle for async operation.

        Returns:
            The reduced tensor, or a tuple of (tensor, handle) if async_op is True.
        """
        raise NotImplementedError("Platform subclasses must implement all_reduce")
    @staticmethod
    def broadcast(data, src, group, async_op=False):
        """Broadcast tensor from source rank to all ranks in group.

        Args:
            data: The tensor to broadcast (only valid on source rank).
            src (int): The source rank to broadcast from.
            group: The process group for collective communication.
            async_op (bool): If True, returns a work handle for async operation.

        Returns:
            The broadcasted tensor, or a tuple of (tensor, handle) if async_op is True.
        """
        raise NotImplementedError("Platform subclasses must implement broadcast")
    @staticmethod
    def isend(tensor, dst=None, group=None, tag=0):
        """Send tensor asynchronously to destination rank.

        Args:
            tensor: The tensor to send.
            dst (int, optional): The destination rank. Defaults to None.
            group: The process group for communication. Defaults to None.
            tag (int): A tag to identify the send operation. Defaults to 0.

        Returns:
            A work handle that can be waited on.
        """
        raise NotImplementedError("Platform subclasses must implement isend")
    @staticmethod
    def irecv(tensor, src=None, group=None, tag=0):
        """Receive tensor asynchronously from source rank.

        Args:
            tensor: The tensor buffer to receive data into.
            src (int, optional): The source rank. Defaults to None.
            group: The process group for communication. Defaults to None.
            tag (int): A tag to identify the receive operation. Defaults to 0.

        Returns:
            A work handle that can be waited on.
        """
        raise NotImplementedError("Platform subclasses must implement irecv")
    @staticmethod
    def p2p_exchange(tensor, peer_rank: int, group=None):
        """Differentiable symmetric P2P exchange (send local tensor, receive peer's tensor).

        Sends ``tensor`` to ``peer_rank`` and simultaneously receives the peer's
        tensor. The operation is differentiable: the backward pass performs the
        same symmetric exchange on the upstream gradient.

        Args:
            tensor: Local tensor to send.
            peer_rank (int): Global rank of the communication peer.
            group: Process group. ``None`` uses the default group.

        Returns:
            Tensor received from ``peer_rank``, with the same shape and dtype as
            the input ``tensor``.
        """
        raise NotImplementedError("Platform subclasses must implement p2p_exchange")
    @staticmethod
    def send_object_list(obj_list, dst=None, group=None):
        """Send a list of Python objects to destination rank.

        Args:
            obj_list (list): The list of Python objects to send.
            dst (int, optional): The destination rank. Defaults to None.
            group: The process group for communication. Defaults to None.
        """
        raise NotImplementedError("Platform subclasses must implement send_object_list")
    @staticmethod
    def recv_object_list(obj_list, src=None, group=None):
        """Receive a list of Python objects from source rank.

        Args:
            obj_list (list): The list buffer to receive objects into.
            src (int, optional): The source rank. Defaults to None.
            group: The process group for communication. Defaults to None.
        """
        raise NotImplementedError("Platform subclasses must implement recv_object_list")
    @staticmethod
    def reduce_scatter_tensor(data, group_info, async_op=False):
        """Reduce and scatter tensor across all ranks in group.

        Args:
            data: The input tensor to reduce and scatter.
            group_info: The process group for collective communication.
            async_op (bool): If True, returns a work handle for async operation.

        Returns:
            The scattered tensor chunk, or a tuple of (tensor, handle) if async_op is True.
        """
        raise NotImplementedError("Platform subclasses must implement reduce_scatter_tensor")
    @staticmethod
    def all_to_all_single(input_tensor, output_shape, group, async_op=False):
        """All-to-all single collective with optional async execution.

        Args:
            input_tensor: Input tensor to scatter.
            output_shape: Shape of the pre-allocated output tensor.
            group: Process group (ProcessGroup for torch, group name string for mindspore).
            async_op: If True, returns a work handle; the output tensor is
                filled only after ``work.wait()`` is called.

        Returns:
            Tuple ``(output, work)`` where *output* is the result tensor and
            *work* is the async handle (``None`` when ``async_op=False``).

        Raises:
            NotImplementedError: Must be implemented by platform subclasses.
        """
        raise NotImplementedError("Platform subclasses must implement all_to_all_single")
    @staticmethod
    def differentiable_async_a2a_wait(x, work, out_perm, group, world_size, concat_dim, split_dim,
                                      handle_box=None):
        """Differentiable wrapper that waits for a pre-launched async A2A.

        Wraps the wait-and-reconstruct step in the platform autograd mechanism
        so gradients flow correctly through the all-to-all communication.

        The A2A direction is seq→head (forward): the output gathers along
        ``concat_dim`` (sequence grows from S/cp to S) and scatters along
        ``split_dim`` (heads shrink from H to H/ws).

        In backward, launches an async head→seq A2A on the incoming gradient
        and appends ``(work, out_perm)`` to ``handle_box`` so the caller can
        wait just before the projection GEMM, achieving GEMM–A2A overlap.

        Args:
            x: Original projection output tensor; anchors the op
                in the autograd graph.
            work: Async work handle from ``all_to_all_single(async_op=True)``.
            out_perm: Output buffer filled once ``work.wait()`` completes
                (shape ``[ws, ...]``).
            group: Process group for the reverse A2A in backward.
            world_size: CP/Ulysses degree.
            concat_dim: Dimension that is gathered (concatenated) in forward;
                typically the sequence dimension.
            split_dim: Dimension that is scattered (split) in forward;
                typically the head dimension.
            handle_box: Optional mutable list ``[]``. In backward, ``(work, out_perm)``
                for the reverse A2A is appended here so the pre-hook can wait.

        Returns:
            Result tensor with ``concat_dim`` gathered and ``split_dim`` split,
            connected to the autograd graph through *x*.

        Raises:
            NotImplementedError: Must be implemented by platform subclasses.
        """
        raise NotImplementedError("Platform subclasses must implement differentiable_async_a2a_wait")
    @staticmethod
    def differentiable_sync_hook(x, hook_name: str, coordinator):
        """Identity operation that intercepts both forward and backward to call
        coordinator rendezvous, enabling deterministic comm/compute overlap.

        This is the differentiable building block for dual-pipe schedules.
        In the forward pass the coordinator is invoked with the forward-side
        roles for ``hook_name``; in the backward pass it is invoked with the
        backward-side roles. The tensor value and gradient flow through
        unchanged.

        Args:
            x: Input tensor. Returned as-is; gradients flow through.
            hook_name: One of ``"A"``, ``"B"``, ``"C"``, ``"D"`` identifying
                the position relative to MoE dispatch/combine.
            coordinator: A :class:`HookCoordinator` instance shared between the
                forward and backward threads.

        Returns:
            The same tensor *x*, attached to the autograd graph so that the
            backward hook will fire.
        """
        raise NotImplementedError("Platform subclasses must implement differentiable_sync_hook")
    @staticmethod
    def differentiable_all_to_all_single(input_tensor, input_splits, output_splits, group):
        """Variable-split all-to-all single that supports gradient flow.

        Unlike ``all_to_all_single`` (which is not differentiable), this method
        wraps the collective in an autograd function so gradients are correctly
        routed back through the reverse all-to-all in the backward pass.
        Intended for Expert Parallelism token dispatch / combine.

        Args:
            input_tensor: Input tensor to scatter. Shape ``[sum(input_splits), *feature_dims]``.
            input_splits: Per-rank sizes of data sent from this rank (list of ints,
                length equal to ep_degree).
            output_splits: Per-rank sizes of data received by this rank (list of ints,
                length equal to ep_degree).
            group: Process group (ProcessGroup for torch, group name str for mindspore).

        Returns:
            Output tensor of shape ``[sum(output_splits), *feature_dims]``.

        Raises:
            NotImplementedError: Must be implemented by platform subclasses.
        """
        raise NotImplementedError("Platform subclasses must implement differentiable_all_to_all_single")
    @staticmethod
    def differentiable_all_to_all_single_async(input_tensor, input_splits, output_splits, group):
        """Async variant of :meth:`differentiable_all_to_all_single`.

        Same semantics but launches the collective with ``async_op=True`` and
        only performs a stream-level ``wait`` — the host returns immediately
        after dispatching the kernel. Intended for dual-pipe comm/compute
        overlap paths where the paired COMPUTE side's rendezvous notify must
        fire right after kernel launch (not after the collective actually
        completes on device).

        Args:
            input_tensor: Input tensor to scatter. Shape ``[sum(input_splits), *feature_dims]``.
            input_splits: Per-rank sizes of data sent from this rank.
            output_splits: Per-rank sizes of data received by this rank.
            group: Process group.

        Returns:
            Output tensor of shape ``[sum(output_splits), *feature_dims]``.

        Raises:
            NotImplementedError: Must be implemented by platform subclasses.
        """
        raise NotImplementedError(
            "Platform subclasses must implement differentiable_all_to_all_single_async"
        )
    @staticmethod
    def arange(start, end=None, step=1, dtype=None, device=None):
        """Create a 1-D tensor with evenly spaced values.

        Args:
            start: Start of interval (inclusive). If *end* is ``None``,
                treated as the stop value and *start* defaults to 0.
            end: End of interval (exclusive). Defaults to ``None``.
            step: Step size. Defaults to ``1``.
            dtype: Data type. ``None`` uses the framework default (int64).
            device: Target device.

        Returns:
            1-D tensor ``[start, start+step, ..., end)``.

        Raises:
            NotImplementedError: Must be implemented by platform subclasses.
        """
        raise NotImplementedError("Platform subclasses must implement arange")
    @staticmethod
    def zeros(size, dtype=None, device=None):
        """Create a zero-filled tensor of the given shape.

        Args:
            size: Shape of the tensor (a single tuple/list).
            dtype: Desired data type. ``None`` uses the framework default (float32).
            device: Target device. ``None`` uses the framework default.

        Returns:
            Zero-filled tensor of the specified shape.

        Raises:
            NotImplementedError: Must be implemented by platform subclasses.
        """
        raise NotImplementedError("Platform subclasses must implement zeros")
    @staticmethod
    def parameters_dict(cell):
        """Get the parameters dictionary of a cell/module.

        Args:
            cell: The cell or module to get parameters from.

        Returns:
            dict: A dictionary mapping parameter names to parameters.
        """
        raise NotImplementedError("Platform subclasses must implement parameters_dict")
    @staticmethod
    def get_model_state_dict(model, *, options=None):
        """Get the state dictionary of a model.

        Args:
            model: The model to extract state from.
            options: Optional configuration for state dict extraction.

        Returns:
            dict: The state dictionary containing model parameters and buffers.
        """
        raise NotImplementedError(
            "Platform subclasses must implement get_model_state_dict"
        )
    @staticmethod
    def save_checkpoint(cell, file_path: str, ckpt_format: str = "safetensors") -> None:
        """Save a cell/module checkpoint to file.

        Args:
            cell: The cell or module to save.
            file_path (str): The path to save the checkpoint to.
            ckpt_format (str): The file format.
        """
        raise NotImplementedError("Platform subclasses must implement save_checkpoint")
    @staticmethod
    def load_checkpoint(file_path: str, ckpt_format: str = "safetensors") -> dict:
        """Load a checkpoint from file.

        Args:
            file_path (str): The path to load the checkpoint from.
            ckpt_format (str): The file format.

        Returns:
            dict: The loaded checkpoint state dictionary.
        """
        raise NotImplementedError("Platform subclasses must implement load_checkpoint")
    def _create_group(self, rank_list):
        """Create a new process group with the specified ranks.

        Internal method to be implemented by subclasses; callers should use
        :meth:`create_group`, which adds caching.

        Args:
            rank_list (list): List of ranks to include in the group.

        Returns:
            The newly created process group.
        """
        raise NotImplementedError("Platform subclasses must implement _create_group")
    def new_stream(self):
        """Create a new compute stream for asynchronous operations.

        Returns:
            A new stream object for the current device.
        """
        raise NotImplementedError("Platform subclasses must implement new_stream")
    def get_stream_context(self):
        """Get a context manager for executing operations on a specific stream.

        Returns:
            A context manager that can be used with 'with' statement to set stream.
        """
        raise NotImplementedError("Platform subclasses must implement get_stream_context")
    @staticmethod
    def get_tensor_transform():
        """Get the tensor transformation utilities for the current framework.

        Returns:
            A module or object containing tensor transformation functions.
        """
        raise NotImplementedError("Platform subclasses must implement get_tensor_transform")
    @staticmethod
    def construct_strided_slice(x, begin, end, stride):
        """Construct a strided slice operation on a tensor.

        Args:
            x: The input tensor to slice.
            begin: The starting indices for each dimension.
            end: The ending indices for each dimension.
            stride: The stride for each dimension.

        Returns:
            The sliced tensor.
        """
        raise NotImplementedError("Platform subclasses must implement construct_strided_slice")
    @staticmethod
    def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        """Split inputs into micro-batches for pipeline parallelism.

        Args:
            micro_batch_num (int): The number of micro-batches to create.
            args_batch_dim (list, optional): Batch dimension for each positional arg.
            kwargs_batch_dim (dict, optional): Batch dimension for each keyword arg.

        Returns:
            A decorator that splits function inputs into micro-batches.
        """
        raise NotImplementedError("Platform subclasses must implement micro_batch")
    @staticmethod
    def get_symmetric_memory_handler():
        """Return the framework's symmetric-memory handler.

        Returns:
            A handler object for symmetric-memory operations (exact type is
            defined by the platform subclass — see subclass implementations).
        """
        raise NotImplementedError("Platform subclasses must implement get_symmetric_memory_handler")
    @staticmethod
    def load_into_param(param, data):
        """Load ``data`` into ``param``.

        Args:
            param: The target parameter (presumably updated in place —
                confirm against subclass implementations).
            data: The tensor data to load.
        """
        raise NotImplementedError("Platform subclasses must implement load_into_param")
859 def create_group(self, rank_list):
860 """Create or retrieve a communication group with the specified ranks.
862 If a group with the same rank list already exists, returns the existing
863 group instead of creating a new one.
865 Args:
866 rank_list (list): List of ranks to include in the group.
868 Returns:
869 The process group for the specified ranks.
870 """
871 group_key = str(tuple(sorted(rank_list)))
872 if group_key in EXISTING_COMM_GROUPS:
873 return EXISTING_COMM_GROUPS[group_key]
875 group = self._create_group(rank_list)
876 EXISTING_COMM_GROUPS[group_key] = group
877 return group
879 @staticmethod
880 def _process_current_handle():
881 """Wait for the current gradient handle and execute post-process callback.
883 Internal method to synchronize pending gradient operations.
884 """
885 if Platform.current_grad_handle is None:
886 return
888 Platform.current_grad_handle.wait()
889 if Platform.post_grad_handle_process is None:
890 return
891 # pylint: disable=E1102
892 Platform.post_grad_handle_process()
    def set_grad_reduce_handle(self, handle, post_process=None):
        """Set a new gradient reduction handle after waiting for the current one.

        Waits for any pending gradient handle on the grad sync stream, then
        sets the new handle and optional post-process callback.

        Args:
            handle: The async work handle for gradient reduction.
            post_process (callable, optional): Callback to run after handle completes.
        """
        # Lazily create the dedicated gradient-sync stream on first use.
        if Platform.grad_sync_stream is None:
            Platform.grad_sync_stream = self.new_stream()
        stream_context = self.get_stream_context()
        # Drain the previously registered handle on the sync stream before
        # installing the new one, so handles complete in submission order.
        with stream_context(Platform.grad_sync_stream):
            Platform._process_current_handle()
        Platform.current_grad_handle = handle
        Platform.post_grad_handle_process = post_process
    def wait_grad_handle(self):
        """Wait for the current gradient handle to complete.

        Blocks until the current gradient reduction handle completes and
        clears the handle state.
        """
        if Platform.current_grad_handle is None:
            return
        # Lazily create the dedicated gradient-sync stream on first use.
        if Platform.grad_sync_stream is None:
            Platform.grad_sync_stream = self.new_stream()
        stream_context = self.get_stream_context()
        with stream_context(Platform.grad_sync_stream):
            Platform._process_current_handle()
        # Record an event on the sync stream and block on it so all queued
        # work (wait + post-process) has finished before state is cleared.
        sync_event = Platform.grad_sync_stream.record_event()
        sync_event.wait()
        Platform.current_grad_handle = None
        Platform.post_grad_handle_process = None
930 @staticmethod
931 def all_gather_object(object_list, obj, group=None) -> None:
932 """Gather Python objects from all ranks into a list.
934 Each rank contributes its object, and all ranks receive the complete list.
936 Args:
937 object_list (list): List to store gathered objects (output parameter).
938 obj: The Python object from this rank to contribute.
939 group: The process group for communication. Defaults to None (default group).
940 """
941 raise NotImplementedError("Platform subclasses must implement all_gather_object")
943 @staticmethod
944 def barrier(group=None, async_op: bool = False, device_ids=None) -> Any:
945 """Synchronize all processes in the given process group.
947 Each rank blocks until every rank in the group enters this collective (when ``async_op``
948 is False), or returns an async handle that must be completed before proceeding.
950 Args:
951 group: The process group or communication group. ``None`` uses the default group.
952 async_op (bool): If True, returns a backend-specific async work handle. Default: False.
953 device_ids: Optional device id list; semantics depend on the backend.
955 Returns:
956 Async work handle when ``async_op`` is True; otherwise ``None`` (unless the rank
957 is not in the group, in which case the backend may return ``None``).
958 """
959 raise NotImplementedError("Platform subclasses must implement barrier")
961 @staticmethod
962 def init_process_group(
963 backend: Optional[str] = None,
964 *,
965 init_method: Optional[str] = None,
966 timeout: Optional[timedelta] = None,
967 world_size: int = -1,
968 rank: int = -1,
969 store: Any = None,
970 pg_options: Any = None,
971 device_id: Any = None
972 ) -> None:
973 """
974 Initialize the default distributed process group.
976 Args:
977 backend: The backend to use for distributed communication
978 init_method: URL specifying how to initialize the process group
979 timeout: Timeout for operations executed against the process group
980 world_size: Number of processes participating in the job
981 rank: Rank of the current process
982 store: Key/value store for exchanging connection information
983 pg_options: Process group options for backend-specific configurations
984 device_id: Specific device this process will work on
986 Raises:
987 NotImplementedError: This method must be implemented by subclasses
988 """
989 raise NotImplementedError("Platform subclasses must implement init_process_group")
991 @staticmethod
992 def destroy_process_group(group=None) -> None:
993 """
994 Destroy a given process group.
996 Args:
997 group: The process group to be destroyed. If None, destroys the default group.
999 Raises:
1000 NotImplementedError: This method must be implemented by subclasses
1001 """
1002 raise NotImplementedError("Platform subclasses must implement destroy_process_group")
1004 @staticmethod
1005 def get_process_group_ranks(group=None) -> list[int]:
1006 """
1007 Get rank list of the given process group.
1009 Args:
1010 group: The process group to get ranks from. If None, uses the default group.
1012 Returns:
1013 List of ranks in the specified process group.
1015 Raises:
1016 NotImplementedError: This method must be implemented by subclasses
1017 """
1018 raise NotImplementedError("Platform subclasses must implement get_process_group_ranks")
1020 @staticmethod
1021 def get_backend(group=None):
1022 """
1023 Get the backend of the given process group.
1024 Args:
1025 group: The process group to get backend from. If None, uses the default group.
1027 Returns:
1028 The backend name of the specified process group.
1030 Raises:
1031 NotImplementedError: This method must be implemented by subclasses
1032 """
1033 raise NotImplementedError("Platform subclasses must implement get_backend")
1035 @staticmethod
1036 def split_group(parent_pg: Any = None,
1037 split_ranks: Optional[list] = None,
1038 timeout: Optional[timedelta] = None,
1039 pg_options: Optional[Any] = None,
1040 group_desc: Optional[str] = None,
1041 ) -> Any:
1042 """Create a split group relative to the parent process group.
1044 Args:
1045 parent_pg: The parent process group to split from.
1046 split_ranks (list, optional): Ranks to include in the split group.
1047 timeout (timedelta, optional): Timeout for operations.
1048 pg_options: Process group options for backend-specific configurations.
1049 group_desc (str, optional): Description of the group.
1051 Returns:
1052 The new split process group.
1053 """
1054 raise NotImplementedError("Platform subclasses must implement split_group")
1056 @staticmethod
1057 def get_group_local_rank(group=None) -> int:
1058 """Get the local rank within the given process group.
1060 Args:
1061 group: The process group to query. If None, uses the default group.
1063 Returns:
1064 int: The local rank within the group.
1065 """
1066 raise NotImplementedError("Platform subclasses must implement get_group_local_rank")
1068 @staticmethod
1069 def no_grad():
1070 """Get a context manager to disable gradient computation.
1072 Returns:
1073 A context manager that disables gradient tracking.
1074 """
1075 raise NotImplementedError("Platform subclasses must implement no_grad")
1077 @staticmethod
1078 def cat(tensors, dim=0):
1079 """Concatenate tensors along a dimension."""
1080 raise NotImplementedError("Platform subclasses must implement cat")
1082 @staticmethod
1083 def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
1084 """Create an uninitialized tensor with the same shape as input.
1086 Args:
1087 tensor: The input tensor to copy shape from.
1088 dtype: Optional dtype for the new tensor. If None, uses input tensor's dtype.
1089 device: Optional device for the new tensor. If None, uses input tensor's device.
1090 pin_memory (bool): If True, allocate pinned memory for faster CPU-GPU transfer.
1092 Returns:
1093 An uninitialized tensor with the same shape as input.
1094 """
1095 raise NotImplementedError("Platform subclasses must implement empty_like")
def get_current_stream(self):
    """Return the device's current compute stream.

    Returns:
        The current stream object.

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError("Platform subclasses must implement get_current_stream")
def new_event(self):
    """Create a fresh event object for stream synchronization.

    Returns:
        A new event object.

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError("Platform subclasses must implement new_event")
def tree_map(self, fn, tree):
    """Apply ``fn`` to every tensor inside a nested container.

    Args:
        fn (callable): Function applied to each tensor.
        tree: Nested structure (list, tuple, dict) holding tensors.

    Returns:
        The same nested structure with ``fn`` applied to all tensors.

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError("Platform subclasses must implement tree_map")
1125 @staticmethod
1126 def is_linear_module(module) -> bool:
1127 """Check whether *module* is a linear/dense layer for the current framework.
1129 Args:
1130 module: The module instance to check.
1132 Returns:
1133 True if *module* is the framework's linear layer type.
1134 """
1135 raise NotImplementedError("Platform subclasses must implement is_linear_module")
1137 @staticmethod
1138 def is_embedding_module(module) -> bool:
1139 """Check whether *module* is an embedding layer for the current framework.
1141 Args:
1142 module: The module instance to check.
1144 Returns:
1145 True if *module* is the framework's embedding layer type.
1146 """
1147 raise NotImplementedError("Platform subclasses must implement is_embedding_module")
1149 @staticmethod
1150 def register_forward_pre_hook(module, hook, prepend=False, with_kwargs=False):
1151 """Register a forward pre-hook on a module.
1153 Args:
1154 module: The module to register the hook on.
1155 hook (callable): The hook function to register.
1156 prepend (bool): If True, prepend the hook to existing hooks.
1157 with_kwargs (bool): If True, hook receives both args and kwargs.
1159 Returns:
1160 A handle that can be used to remove the hook.
1161 """
1162 return module.register_forward_pre_hook(hook, prepend=prepend, with_kwargs=with_kwargs)
1164 @staticmethod
1165 def register_full_backward_hook(module, hook, prepend=False):
1166 """Register a full backward hook on a module.
1168 Args:
1169 module: The module to register the hook on.
1170 hook (callable): The hook function to register.
1171 prepend (bool): If True, prepend the hook to existing hooks.
1173 Returns:
1174 A handle that can be used to remove the hook.
1175 """
1176 return module.register_full_backward_hook(hook, prepend)
1178 @staticmethod
1179 def register_full_backward_pre_hook(module, hook, prepend=False):
1180 """Register a full backward pre-hook on a module.
1182 Args:
1183 module: The module to register the hook on.
1184 hook (callable): The hook function to register.
1185 prepend (bool): If True, prepend the hook to existing hooks.
1187 Returns:
1188 A handle that can be used to remove the hook.
1189 """
1190 return module.register_full_backward_pre_hook(hook, prepend)
@property
def checkpoint(self):
    """The framework's activation-checkpointing function.

    Returns:
        The checkpoint function for the current framework.

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError("Platform subclasses must implement checkpoint")
1201 @staticmethod
1202 def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
1203 """Wrap a module with checkpoint functionality.
1205 Args:
1206 module: The module to wrap with checkpointing.
1207 checkpoint_fn: Optional custom checkpoint function.
1208 **checkpoint_fn_kwargs: Additional kwargs for checkpoint function.
1210 Returns:
1211 The wrapped module with checkpointing enabled.
1212 """
1213 raise NotImplementedError("Platform subclasses must implement ckpt_wrapper")
1215 @staticmethod
1216 def swap_wrapper(module, policy_fn=None):
1217 """Wrap a module with activation swap functionality.
1219 Args:
1220 module: The module to wrap with activation swap.
1221 policy_fn: Optional per-tensor swap policy function.
1223 Returns:
1224 The wrapped module with activation swap enabled.
1225 """
1226 raise NotImplementedError("Platform subclasses must implement swap_wrapper")
@property
def noop_context_fn(self):
    """A do-nothing context function usable with checkpointing.

    Returns:
        A context function that performs no operation.

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError("Platform subclasses must implement noop_context_fn")
1237 @staticmethod
1238 def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
1239 """Create contexts for selective activation checkpointing.
1241 Args:
1242 policy_fn_or_list: A policy function or list of layer names to checkpoint.
1243 allow_cache_entry_mutation (bool): Whether to allow cache entry mutation.
1245 Returns:
1246 Context functions for selective checkpointing.
1247 """
1248 raise NotImplementedError("Platform subclasses must implement create_selective_checkpoint_contexts")
1250 @staticmethod
1251 def async_save_on_cpu(policy_fn=None):
1252 """Create an async CPU offload context for activation checkpointing.
1254 Args:
1255 policy_fn: Optional policy function to determine which activations to offload.
1257 Returns:
1258 Context manager for async CPU offloading during checkpointing.
1259 """
1260 raise NotImplementedError("Platform subclasses must implement async_save_on_cpu")
1262 @staticmethod
1263 def get_element_size(tensor):
1264 """Get Tensor Element Size"""
1265 raise NotImplementedError("Platform subclasses must implement get_element_size")
1267 @staticmethod
1268 def tensor_to_numpy(tensor) -> np.ndarray:
1269 """Convert a framework tensor to a NumPy array.
1271 Args:
1272 tensor: The tensor to convert.
1274 Returns:
1275 np.ndarray: The tensor data as a NumPy array.
1276 """
1277 raise NotImplementedError("Platform subclasses must implement tensor_to_numpy")
1279 @staticmethod
1280 def profiler_record(name):
1281 """Record a profiler event with the given name.
1283 Args:
1284 name (str): The name of the profiler event.
1286 Returns:
1287 A context manager or decorator for profiling a code region.
1288 """
1289 raise NotImplementedError("Platform subclasses must implement profiler_record")
def cast_fp_tensor(self, dtype, x):
    """Cast a floating-point tensor to ``dtype``; leave other tensors alone.

    Args:
        dtype: Target dtype for the cast.
        x: Input tensor.

    Returns:
        The cast tensor, or the input unchanged when it is not floating-point.

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError("Platform subclasses must implement cast_fp_tensor")
def apply_to_tensors(self, fn, container):
    """Recursively apply ``fn`` to all tensors inside ``container``.

    Handles nested lists, tuples, and dicts.

    Args:
        fn (callable): Function applied to each tensor.
        container: Nested structure containing tensors.

    Returns:
        The same structure with ``fn`` applied to all tensors.

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError("Platform subclasses must implement apply_to_tensors")
1317 @staticmethod
1318 def clip_grad_norm_(
1319 parameters, max_norm: float, norm_type: float = 2.0,
1320 error_if_nonfinite: bool = False, foreach=None,
1321 ):
1322 """Compute and clip gradient norms for distributed models.
1324 Communication is derived from each parameter's DTensor spec.
1325 Subclasses must implement this method.
1327 Args:
1328 parameters: An ``nn.Module``, a single ``Tensor``, or an
1329 iterable of ``Tensor`` s whose gradients to clip.
1330 max_norm: Maximum allowed gradient norm.
1331 norm_type: Type of the norm (default ``2.0``).
1332 error_if_nonfinite: If ``True``, raise when total norm is
1333 non-finite. Default ``False``.
1334 foreach: Unused, accepted for API compatibility.
1336 Returns:
1337 The total (unclipped) gradient norm.
1338 """
1339 raise NotImplementedError(
1340 "Platform subclasses must implement clip_grad_norm_"
1341 )
@staticmethod
def get_created_group(rank_list: Union[list[int], tuple[int, ...]]):
    """Look up a cached process group by its rank list.

    The cache key is the sorted rank tuple, so the order of ``rank_list``
    does not matter.

    Args:
        rank_list (Union[list[int], tuple[int, ...]]): Ranks of the group.
            (Annotation fixed: ``tuple[int]`` denotes a 1-tuple; a
            variable-length tuple is ``tuple[int, ...]``.)

    Returns:
        The cached process group for the rank list if it exists, else None.
    """
    group_key = str(tuple(sorted(rank_list)))
    # dict.get returns None on a miss, matching the documented contract.
    return EXISTING_COMM_GROUPS.get(group_key)
@classmethod
def mark_created_groups(cls, process_group: Union[Any, list[Any]]) -> None:
    """Register process groups in the global cache for reuse.

    Args:
        process_group (Union[Any, list[Any]]): A process group or a list of
            process groups; each is cached under its sorted-rank key.
    """
    groups = process_group if isinstance(process_group, list) else [process_group]
    for pg in groups:
        ranks = cls.get_process_group_ranks(pg)
        EXISTING_COMM_GROUPS[str(tuple(sorted(ranks)))] = pg
@property
def meta_device(self):
    """The framework-specific meta device for shape-only tensors.

    Tensors created on the meta device carry shape/dtype but no storage,
    which is useful for shape inference and deferred initialization.

    Returns:
        The meta device object for the current framework.

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError("Platform subclasses must implement meta_device")
def init_on_device(self, device, include_buffers=False):
    """Return a context manager that initializes module parameters on ``device``.

    Args:
        device: Target device for parameter initialization.
        include_buffers (bool): Also place buffers on the device.

    Returns:
        A context manager for device-specific initialization.

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError("Platform subclasses must implement init_on_device")
def str_to_dtype(self, dtype_str: str) -> Any:
    """Resolve a serialized dtype string to the backend dtype object.

    Args:
        dtype_str (str): Dtype identifier from checkpoint metadata
            (e.g. ``torch.float32``).

    Returns:
        The framework dtype object (e.g. ``torch.dtype`` or MindSpore dtype).

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError("Platform subclasses must implement str_to_dtype")
def list_to_size(self, size_list: list[int]) -> Any:
    """Convert a checkpoint-metadata shape list to the framework size type.

    Args:
        size_list (list[int]): Tensor global shape as a list of ints.

    Returns:
        The framework-specific size object (e.g. ``torch.Size``).

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError("Platform subclasses must implement list_to_size")