Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/platform.py: 44%
548 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""MindSpore platform api"""
16from datetime import timedelta
17from typing import Any, Optional, Union
18import dataclasses
19from collections import OrderedDict
21import contextlib
22import numpy as np
23import mindspore as ms
24import mindspore.common.dtype as mstype
25from mindspore.mint.distributed import TCPStore
27from mindspore.nn import Cell
28from mindspore import mint
29from mindspore.common.api import _no_grad
30from mindspore.common._grad_function import _Function
31from mindspore.common.dtype import type_size_in_bytes
32from mindspore.common.parameter import Parameter
33from mindspore.common.tensor import Tensor
34from mindspore.common.initializer import initializer
35from mindspore.common.recompute import null_context_fn
36from mindspore.communication import GlobalComm
37from mindspore.communication import get_group_size
38from mindspore.communication import create_group as new_group
39from mindspore.communication import get_rank as get_rank_id
40from mindspore.ops import communication as ops_comm
41from mindspore.ops.function import comm_func
42from mindspore._c_expression import TensorTransform
43import mindspore.mint.distributed as dist
45from hyper_parallel.platform.platform import Platform, PlatformType, EXISTING_COMM_GROUPS
46from hyper_parallel.platform.mindspore.dtensor import DTensorBase
47from hyper_parallel.platform.mindspore.pipeline_parallel.stage import PipelineStageBase
48from hyper_parallel.platform.mindspore.parameter_init import init_parameters as _init_parameters
49from hyper_parallel.platform.mindspore.init_weights import (
50 init_on_device as _init_on_device,
51 _install_cell_to_empty_patch,
52)
54comm_func.set_comm_ops_inplace(False)
55_tensor_transform = TensorTransform.get_instance()


# pylint: disable=C0103


def _a2a_reconstruct_ms(out_perm: Tensor, concat_dim: int) -> Tensor:
    """Reconstruct A2A result from raw out_perm buffer."""
    new_ndim = out_perm.dim()
    chunk_in_perm = concat_dim + 1
    recon_perm = list(range(1, chunk_in_perm)) + [0] + list(range(chunk_in_perm, new_ndim))
    x_recon = out_perm.permute(recon_perm).contiguous()
    shape = list(x_recon.shape)
    merged = shape[concat_dim] * shape[concat_dim + 1]
    return x_recon.reshape(shape[:concat_dim] + [merged] + shape[concat_dim + 2:])
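
# Illustrative shape walk-through (a sketch, not executed here): assume
# world_size 2 and an ``out_perm`` buffer of shape (2, 4, 3, 5) whose axis 0
# indexes the chunks received from each rank. With ``concat_dim`` = 1, the
# permute moves axis 0 next to the concat axis, (2, 4, 3, 5) -> (4, 2, 3, 5),
# and the final reshape folds the pair (2, 3) into 6, giving (4, 6, 5): the
# per-rank chunks concatenated along ``concat_dim``.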


def _normalize_all_to_all_single_result(result, output: Tensor) -> tuple[Tensor, object]:
    """Normalize MindSpore all_to_all_single return values to ``(output, handle)``."""
    if isinstance(result, tuple):
        if len(result) != 2:
            raise ValueError(
                "mindspore all_to_all_single returned an unexpected tuple "
                f"with length {len(result)}"
            )
        return result
    return output, result


def _mindspore_all_to_all_single(input_tensor: Tensor, output_shape, group, async_op=False) -> tuple[Tensor, object]:
    """Launch MindSpore all_to_all_single and normalize return values."""
    output = mint.empty(tuple(output_shape), dtype=input_tensor.dtype)
    result = ops_comm.all_to_all_single(output, input_tensor, group=group, async_op=async_op)
    normalized_output, handle = _normalize_all_to_all_single_result(result, output)
    if not async_op:
        return normalized_output, None
    return normalized_output, handle
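
# Minimal usage sketch (illustrative, assuming an already-created communication
# group name ``group`` and a tensor ``x`` evenly divisible across the group):
#
#     out, handle = _mindspore_all_to_all_single(x, list(x.shape), group, async_op=True)
#     ...                # overlap independent compute with the collective
#     handle.wait()      # ``out`` is only valid after the wait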


class _MSAsyncA2AFunction(_Function):
    """Differentiable wrapper for pre-launched async all-to-all."""

    @staticmethod
    def forward(ctx, x, work, out_perm, group, world_size, concat_dim, split_dim, handle_box):  # pylint: disable=arguments-differ
        """Wait for pre-launched async A2A and return reconstructed output."""
        ctx.group = group
        ctx.world_size = world_size
        ctx.concat_dim = concat_dim
        ctx.split_dim = split_dim
        ctx.handle_box = handle_box
        ctx.x_shape = tuple(x.shape)
        work.wait()
        return _a2a_reconstruct_ms(out_perm, concat_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Launch async head->seq A2A for backward overlap, or return zero grad."""
        if ctx.handle_box is not None:
            g = grad_output.contiguous()
            shape = list(g.shape)
            seq_dim = ctx.concat_dim
            s_full = shape[seq_dim]
            ndim = len(shape) + 1
            x_perm = g.reshape(
                shape[:seq_dim] + [ctx.world_size, s_full // ctx.world_size] + shape[seq_dim + 1:]
            ).permute(
                [seq_dim] + list(range(seq_dim)) + list(range(seq_dim + 1, ndim))
            ).contiguous()
            out_perm, work = _mindspore_all_to_all_single(
                x_perm,
                list(x_perm.shape),
                ctx.group,
                async_op=True,
            )
            ctx.handle_box.append((work, out_perm))
        return mint.zeros(ctx.x_shape, dtype=grad_output.dtype), None, None, None, None, None, None, None
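
# Note on the overlap protocol sketched above: ``forward`` only waits on an
# all-to-all that the caller has already launched, while ``backward`` launches
# a fresh async all-to-all on the gradient and parks ``(work, out_perm)`` in
# ``handle_box`` for the caller to wait on later; the zeros returned for ``x``
# appear to be a placeholder so autograd keeps flowing, not the final gradient.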


class MindSporePlatform(Platform):
    """MindSpore platform API"""
    Tensor = Tensor
    tensor = Tensor
    Parameter = Parameter
    Module = Cell
    DTensorBase = DTensorBase
    PipelineStageBase = PipelineStageBase
    platform_type = PlatformType.MINDSPORE
    tensor_dtype = mstype
    dtype = ms.Type
    Function = _Function

    def __init__(self):
        # Ensure MindSpore ``nn.Cell.to_empty`` is patched as soon as the
        # MindSpore platform instance is created.
        _install_cell_to_empty_patch()

    @staticmethod
    def is_linear_module(module) -> bool:
        """Check whether *module* is a MindSpore ``Dense`` (linear) or ``mint.nn.Linear`` layer."""
        return isinstance(module, (ms.nn.Dense, mint.nn.Linear))

    @staticmethod
    def is_embedding_module(module) -> bool:
        """Check whether *module* is a MindSpore ``Embedding`` or ``mint.nn.Embedding`` layer."""
        return isinstance(module, (ms.nn.Embedding, mint.nn.Embedding))

    def device_count(self, device_handle):
        """
        Get the number of available devices.

        Args:
            device_handle: The device handle (e.g., the ``mindspore`` module).

        Returns:
            int: The number of available devices.
        """
        device_type = self.device_type()
        if device_type == "cpu":
            return device_handle.device_context.cpu.device_count()
        if device_type == "gpu":
            return device_handle.device_context.gpu.device_count()
        return device_handle.device_context.ascend.device_count()

    @staticmethod
    def get_rng_state(device=None, device_handle=None):
        """
        Get the random number generator state.

        Args:
            device (Optional): The device to get RNG state from (not used in MindSpore).
            device_handle (Optional): The device handle (not used in MindSpore).

        Returns:
            Tensor: The RNG state as a tensor.
        """
        _ = device, device_handle
        return ms.get_rng_state()

    @staticmethod
    def set_rng_state(state, device=None, device_handle=None):
        """
        Set the random number generator state.

        Args:
            state (Tensor): The RNG state to set.
            device (Optional): The device to set RNG state for (not used in MindSpore).
            device_handle (Optional): The device handle (not used in MindSpore).
        """
        _ = device, device_handle
        return ms.set_rng_state(state)

    def device_type(self):
        """
        Get the current device type.

        Returns:
            str: The device type string ("npu" for Ascend, "gpu" for GPU, "cpu" for CPU).
        """
        device_type = ms.get_context("device_target")
        if device_type == "Ascend":
            return "npu"
        return device_type.lower()

    def device(self, device_idx=None):
        """
        Get the device type string.

        Args:
            device_idx (Optional[int]): The device index (not used in MindSpore).

        Returns:
            str: The device type string.
        """
        _ = device_idx
        device_type = self.device_type()
        return device_type

    @staticmethod
    def get_device_handle():
        """
        Get the MindSpore module as the device handle.

        Returns:
            module: The mindspore module.
        """
        return ms

    @staticmethod
    def manual_seed(seed):
        """
        Set the random seed for reproducibility.

        Args:
            seed (int): The random seed value.

        Returns:
            None
        """
        return ms.manual_seed(seed)

    @staticmethod
    def ones(size, dtype=None):
        """
        Create a tensor filled with ones.

        Args:
            size (tuple): The shape of the output tensor.
            dtype (Optional[ms.Type]): The desired data type.

        Returns:
            Tensor: A tensor filled with ones.
        """
        return mint.ones(size, dtype=dtype)

    @staticmethod
    def zeros(size, dtype=None, device=None):
        """
        Create a tensor filled with zeros.

        Args:
            size (tuple): The shape of the output tensor.
            dtype (Optional[ms.Type]): The desired data type.
            device (Optional[ms.device]): The device to create the tensor on.

        Returns:
            Tensor: A tensor filled with zeros.
        """
        tensor = mint.zeros(size, dtype=dtype)
        if device in ("GPU", "Ascend"):
            return tensor.to(device)
        return tensor

    @staticmethod
    def full(size, fill_value, dtype=None):
        """
        Create a tensor filled with a scalar value.

        Args:
            size (tuple): The shape of the output tensor.
            fill_value (scalar): The value to fill the tensor with.
            dtype (Optional[ms.Type]): The desired data type.

        Returns:
            Tensor: A tensor filled with the specified value.
        """
        return mint.full(size, fill_value, dtype=dtype)

    @staticmethod
    def empty(size, dtype=None):
        """
        Create an uninitialized tensor.

        Args:
            size (tuple): The shape of the output tensor.
            dtype (Optional[ms.Type]): The desired data type.

        Returns:
            Tensor: An uninitialized tensor.
        """
        return mint.empty(size, dtype=dtype)

    @staticmethod
    def get_rank():
        """
        Get the rank of the current process in the distributed group.

        Returns:
            int: The rank of the current process.
        """
        return get_rank_id()

    @staticmethod
    def get_global_rank(group, group_rank):
        """
        Get the global rank from a group rank.

        Args:
            group (str): The process group name.
            group_rank (int): The rank within the group.

        Returns:
            int: The global rank.
        """
        return dist.get_global_rank(group, group_rank)

    @staticmethod
    def get_world_size():
        """
        Get the total number of processes in the distributed group.

        Returns:
            int: The world size.
        """
        return get_group_size()

    @staticmethod
    def get_op_name(func):
        """
        Extract the operation name from a function.

        Args:
            func: The function to extract the name from.

        Returns:
            str: The operation name.
        """
        return func.name

    @staticmethod
    def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
        """All-gather ``data`` across ``group`` and concatenate the result along ``concat_dim``."""
        output, _ = comm_func.all_gather_into_tensor(None, data, group=group)
        if concat_dim == 0:
            return output
        output_tensors = ms.ops.Split(output_num=concat_size)(output)
        return ms.mint.concat(output_tensors, concat_dim)

    @staticmethod
    def chunk(data, split_dim, split_size, index):
        """Split ``data`` into ``split_size`` chunks along ``split_dim`` and return chunk ``index``."""
        return ms.ops.Split(axis=split_dim, output_num=split_size)(data)[index]

    @staticmethod
    def differentiable_all_to_all(input_data, output_shape, group):
        """Differentiable synchronous all-to-all across ``group``."""
        output_tensor, _ = comm_func.all_to_all_single(
            output_shape,
            input_data,
            group=group,
            async_op=False
        )
        return output_tensor

    @staticmethod
    def tensor_type_cast(input_data, cast_type):
        """Cast tensor to specified data type."""
        type_mapping = {
            'float32': ms.float32,
            'float16': ms.float16,
            'int64': ms.int64,
            'int32': ms.int32
        }
        if cast_type not in type_mapping:
            raise ValueError(f"Unknown cast type: {cast_type}. Supported types: {list(type_mapping.keys())}")
        return input_data.to(type_mapping[cast_type])
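
    # Example (illustrative, run under an initialized MindSpore context):
    #
    #     >>> x = Tensor(np.ones((2, 2)), dtype=ms.float16)
    #     >>> MindSporePlatform.tensor_type_cast(x, 'float32').dtype
    #     Float32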

    @staticmethod
    def differentiable_all_reduce(data, op, group):
        """Differentiable all-reduce of ``data`` with reduction ``op`` over ``group``."""
        output, _ = comm_func.all_reduce(data, op, group)
        return output

    @staticmethod
    def differentiable_reduce_scatter(data, dev_num, axis, op, group):
        """Differentiable reduce-scatter of ``data`` along ``axis`` over ``dev_num`` devices."""
        if axis > 0:
            data = ms.mint.concat(ms.ops.Split(axis=axis, output_num=dev_num)(data), dim=0)
        output_tensor, _ = comm_func.reduce_scatter_tensor(None, data, 'sum', group)
        if op == 'avg':
            output_tensor = output_tensor / dev_num
        return output_tensor

    @staticmethod
    def init_parameters(module, stage_index):
        """Initialize the parameters of ``module`` for the given pipeline stage."""
        return _init_parameters(module, stage_index)

    # pylint: disable=W0212
    @staticmethod
    def update_param_data(param, data):
        """Update parameter data in place."""
        if isinstance(param, DTensorBase):
            param.set_data(data)
        else:
            param._update_data(data)

    @staticmethod
    def load_into_param(param, data):
        """Copy ``data`` into a fresh buffer and load it into ``param``."""
        copy_tensor = MindSporePlatform.empty_like(data)
        copy_tensor.copy_(data)
        if isinstance(param, DTensorBase):
            param.set_data(copy_tensor)
        else:
            param._update(copy_tensor)

    @staticmethod
    def get_cell_construct(cell):
        return cell.construct

    @staticmethod
    def get_cells_and_names(cell):
        return cell.cells_and_names()

    @staticmethod
    def search_parameter_by_name(cell, param_name: str):
        """
        Find the parent Module of the parameter, the parameter's name in the parent Module, and the parameter.
        Return value: (parent Module instance, parameter's name in parent Module, parameter object).
        Returns None if not found.
        """
        # Remove the "self." prefix from param_name (to maintain compatibility with original logic)
        param_name = param_name.replace("self.", "")
        # Case 1: The parameter is a direct parameter of the current Module (not in any sub-Module)
        if param_name in cell._params:
            return (cell, param_name, cell._params[param_name])

        # Case 2: The parameter is in a sub-Module (supports multi-level nesting, e.g., "net_b.dense1.weight")
        if "." in param_name:
            # Split into: sub-Module path + parameter name (e.g., "net_b.dense1" + "weight")
            cell_path, param_key = param_name.rsplit(".", 1)
            try:
                # Locate the sub-Module where the parameter resides (supports multi-level paths)
                target_cell = cell.get_sub_cell(cell_path)
                # Check if the sub-Module directly contains this parameter
                if param_key in target_cell._params:
                    return target_cell, param_key, target_cell._params[param_key]
            except AttributeError:
                # Sub-Module path does not exist or the parameter is not in that sub-Module
                pass

        # Traverse all sub-Modules (recursively) to search for the parameter
        for _, child_cell in cell._cells.items():
            if isinstance(child_cell, Cell):
                # Recursively search within the sub-Module
                result = MindSporePlatform.search_parameter_by_name(child_cell, param_name)
                if result is not None:
                    return result

        return None
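
    # Example (illustrative; assumes ``net`` is a Cell with a ``dense1``
    # sub-cell owning a ``weight`` parameter, and ``new_param`` is a Parameter):
    #
    #     >>> hit = MindSporePlatform.search_parameter_by_name(net, "dense1.weight")
    #     >>> parent_cell, key, param = hit
    #     >>> MindSporePlatform.update_parameter_by_name(net, hit, new_param)
    #     True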

    @staticmethod
    def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
        """
        Modify the original parameter in a Module or sub-Module using the search result.

        Args:
            cell: The cell whose parameter is to be updated.
            result: A tuple containing the parent Module, the parameter key, and the old parameter.
            new_param: New Parameter object (used to replace the original parameter).
        """
        parent_cell, param_key, _ = result
        # Key operation: directly modify the _params dictionary of the parent Module (original storage location)
        parent_cell._params[param_key] = new_param

        if param_key in parent_cell.__dict__:
            parent_cell.__dict__[param_key] = new_param
            parent_cell._params_list[param_key] = new_param
        return True

    @staticmethod
    def set_layout_into_parameter(param, layout):
        """Set layout into parameter."""
        from hyper_parallel.core.dtensor.dtensor import DTensor  # pylint: disable=import-outside-toplevel
        from hyper_parallel.core.dtensor.layout import _infer_slice_shape_by_layout, \
            _get_slice_tensor_by_layout  # pylint: disable=import-outside-toplevel
        if isinstance(param, DTensor):
            raise ValueError(f"Parameter {param.name} already has a layout configured; it cannot be set repeatedly.")
        param_info = param.param_info
        requires_grad = param.requires_grad
        name = param.name
        slice_shape = _infer_slice_shape_by_layout(param.shape, layout)

        if not param.has_init:
            # Data already initialized: slice the existing tensor by the layout.
            param_dtensor = DTensor.from_local(
                _get_slice_tensor_by_layout(param, layout).value(), layout.mesh, layout.alias_placements
            )
            param = Parameter(param_dtensor, name=name, requires_grad=requires_grad)
            param.param_info = param_info
        else:
            # Not yet initialized: adjust the lazy-init shape instead.
            param.init_mode.shape = slice_shape
            param_dtensor = DTensor.from_local(param.init_mode, layout.mesh, layout.alias_placements)
            param = Parameter(param_dtensor, name=name, requires_grad=requires_grad)
            param.param_info = param_info
        return param

    @staticmethod
    def get_param_local_shape(param):
        """Get param local shape."""
        if isinstance(param, DTensorBase):
            return param.local_shape
        return param.shape

    @staticmethod
    def get_param_local_data(param):
        """Get param local data."""
        if isinstance(param, DTensorBase):
            return param.to_local()
        return param

    @staticmethod
    def get_param_type_size(param):
        return type_size_in_bytes(param.dtype)

    @staticmethod
    def is_tensor(obj: Any) -> bool:
        """Return True if ``obj`` is a ``mindspore.Tensor``."""
        return isinstance(obj, Tensor)

    @staticmethod
    def get_tensor_storage_size(tensor: Any) -> int:
        """Return serialized byte size (numel * itemsize) for a MindSpore tensor."""
        if not MindSporePlatform.is_tensor(tensor):
            raise TypeError(
                f"MindSporePlatform.get_tensor_storage_size expects mindspore.Tensor, got {type(tensor)!r}"
            )
        return int(tensor.numel()) * int(tensor.itemsize)

    @staticmethod
    def new_zero_parameter(param_shape, param_type, requires_grad, device):
        param = Parameter(initializer("zeros", param_shape, param_type), requires_grad=requires_grad)
        if device in ("GPU", "Ascend"):
            return param.to(device)
        return param

    @staticmethod
    def new_tensor(tensor_shape, tensor_type, device):
        tensor = Tensor(shape=tensor_shape, dtype=tensor_type)
        if device in ("GPU", "Ascend"):
            return tensor.to(device)
        return tensor

    @staticmethod
    def full_like(tensor, fill_value, dtype=None):
        return mint.full_like(tensor, fill_value, dtype=dtype)

    @staticmethod
    def isend(tensor, dst=None, group=None, tag=0):
        return dist.isend(tensor, dst, group, tag)

    @staticmethod
    def irecv(tensor, src=None, group=None, tag=0):
        return dist.irecv(tensor, src, group, tag)

    @staticmethod
    def p2p_exchange(tensor, peer_rank: int, group=None):  # pylint: disable=unused-argument
        raise NotImplementedError(
            "p2p_exchange is not yet supported on the MindSpore platform."
        )

    @staticmethod
    def send_object_list(obj_list, dst=None, group=None):
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.pipeline_parallel._utils import send_object_list
        send_object_list(obj_list, dst, group)

    @staticmethod
    def recv_object_list(obj_list, src=None, group=None):
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.pipeline_parallel._utils import recv_object_list
        recv_object_list(obj_list, src, group)

    @staticmethod
    def set_tensor_requires_grad(input_tensor):
        """
        Set the requires-grad flag on the input tensor.
        """
        input_tensor.requires_grad_()

    def _create_group(self, rank_list):
        world_group = self._maybe_reuse_world_group(rank_list)
        if world_group is not None:
            return world_group

        group_name = str(tuple(sorted(rank_list)))
        new_group(rank_ids=rank_list, group=group_name)
        EXISTING_COMM_GROUPS[group_name] = group_name
        return group_name

    @staticmethod
    def all_gather_into_tensor(data, group_info, async_op=False):
        return comm_func.all_gather_into_tensor(None, data, group=group_info.group_name, async_op=async_op)

    @staticmethod
    def all_reduce(data, group_info, async_op=False):
        if isinstance(group_info, str):
            handle = dist.all_reduce(data, group=group_info, async_op=async_op)
        else:
            handle = dist.all_reduce(data, group=group_info.group_name, async_op=async_op)
        return data, handle

    @staticmethod
    def broadcast(data, src, group=None, async_op=False):
        handle = dist.broadcast(data, src, group, async_op)
        if async_op:
            handle.wait()
        return data

    @staticmethod
    def reduce_scatter_tensor(data, group_info, async_op=False):
        return comm_func.reduce_scatter_tensor(None, data, group=group_info.group_name, async_op=async_op)

    @staticmethod
    def all_to_all_single(input_tensor, output_shape, group, async_op=False):
        return _mindspore_all_to_all_single(input_tensor, output_shape, group, async_op=async_op)

    @staticmethod
    def differentiable_async_a2a_wait(x, work, out_perm, group, world_size, concat_dim, split_dim,  # pylint: disable=unused-argument
                                      handle_box=None):
        return _MSAsyncA2AFunction.apply(
            x, work, out_perm, group, world_size, concat_dim, split_dim, handle_box
        )

    @staticmethod
    def parameters_dict(cell: Cell):
        return cell.parameters_and_names()

    @staticmethod
    def get_tensor_transform():
        return _tensor_transform

    @staticmethod
    def construct_strided_slice(x, begin, end, stride):
        return ms.ops.strided_slice(x, begin, end, stride)

    @staticmethod
    def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.pipeline_parallel._utils import _MicroBatch
        return _MicroBatch(micro_batch_num, args_batch_dim, kwargs_batch_dim)

    @staticmethod
    def get_model_state_dict(model, *, options=None):
        raise NotImplementedError(
            "get_model_state_dict is not yet supported on MindSpore"
        )

    @staticmethod
    def save_checkpoint(cell: Union[Cell, dict], file_path: str, ckpt_format: str = "safetensors") -> None:
        if isinstance(cell, dict):
            save_dict = {}
            for k, v in cell.items():
                if isinstance(v, Parameter):
                    save_dict[k] = v
                elif isinstance(v, Tensor):
                    save_dict[k] = Parameter(v, name=k)
                else:
                    save_dict[k] = v
        else:
            save_dict = cell._params
        ms.save_checkpoint(save_obj=save_dict, ckpt_file_name=file_path, format=ckpt_format)

    @staticmethod
    def load_checkpoint(file_path: str, ckpt_format: str = "safetensors") -> dict:
        return ms.load_checkpoint(ckpt_file_name=file_path, format=ckpt_format)

    @staticmethod
    def get_symmetric_memory_handler():
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.symmetric_memory import MSSymmetricMemoryHandler
        symmetric_memory = MSSymmetricMemoryHandler()
        return symmetric_memory

    @staticmethod
    def get_multicore_handler():
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.multicore import MSMulticoreHandler
        return MSMulticoreHandler()

    def new_stream(self):
        return ms.runtime.Stream()

    def get_stream_context(self):
        return ms.runtime.StreamCtx

    @staticmethod
    def all_gather_object(object_list, obj, group=None) -> None:
        """
        Gather objects from the given group into the object list.

        Args:
            object_list (list[Any]): The output list, whose size equals the group size.
            obj (Any): The object on the current rank in the given process group.
            group (ProcessGroup, optional): The process group to gather from. Default is ``None``,
                which means the global group.

        Returns:
            None. Objects are gathered into ``object_list``.
        """
        dist.all_gather_object(object_list, obj, group)

    @staticmethod
    def barrier(group=None, async_op: bool = False, device_ids=None) -> Any:
        """
        Synchronize all processes in the given communication group.

        Args:
            group (str, optional): The communication group to work on. Default is ``None``,
                meaning the default world group.
            async_op (bool, optional): Whether this op should be asynchronous. Default: ``False``.
            device_ids (list[int], optional): Reserved parameter on Ascend. Default: ``None``.

        Returns:
            CommHandle if ``async_op`` is True; otherwise ``None``.
        """
        return dist.barrier(group, async_op, device_ids)

    @staticmethod
    def init_process_group(
            backend: str = None,
            *,
            init_method: Optional[str] = None,
            timeout: Optional[timedelta] = None,
            world_size: int = -1,
            rank: int = -1,
            store: TCPStore = None,
            pg_options=None,
            device_id=None
    ) -> None:
        """
        Initialize the global process group.

        Args:
            backend (str): The backend used to init the process group. Defaults to ``"hccl"``;
                currently only hccl is supported.
            init_method (str, optional): URL specifying how to initialize the process group. Default is ``None``.
            timeout (timedelta, optional): Timeout for the API to execute. Default is ``None``.
            world_size (int): Number of processes. Default is ``-1``.
            rank (int, optional): Rank of the current process. Default is ``-1``.
            store (Store, optional): An object that stores key/value data, facilitating the exchange of
                inter-process communication addresses and connection information. Default is ``None``.
                Currently, only the ``TCPStore`` type is supported.
            pg_options (ProcessGroupOptions, optional): Reserved parameter; currently has no effect.
            device_id (int, optional): Reserved parameter; currently has no effect.
        """
        if backend is None:
            backend = "hccl"
        try:
            if dist.is_initialized():
                return
        except AttributeError:
            pass
        dist.init_process_group(backend=backend, init_method=init_method, timeout=timeout, world_size=world_size,
                                rank=rank, store=store, pg_options=pg_options, device_id=device_id)

    @staticmethod
    def destroy_process_group(group: Optional[str] = None) -> None:
        """
        Destroy the given process group.

        Args:
            group (str, optional): Specify the group to destroy. Default: ``None`` means ``hccl_world_group``.
                If group is None or "hccl_world_group", destroy the global process group and all process
                groups relative to the global process group.
        """
        if group in EXISTING_COMM_GROUPS.values():
            keys_to_destroy = [k for k, v in EXISTING_COMM_GROUPS.items() if v == group]
            for k in keys_to_destroy:
                del EXISTING_COMM_GROUPS[k]
        dist.destroy_process_group(group)

    @staticmethod
    def get_process_group_ranks(group: Optional[str] = None) -> list[int]:
        """
        Get all ranks in the given process group.

        Args:
            group (str, optional): Specify the process group to work on. Default: ``None`` means ``hccl_world_group``.

        Returns:
            List[int]: List of ranks in the given process group.
        """
        return dist.get_process_group_ranks(group)

    @staticmethod
    def get_backend(group: Optional[str] = None) -> str:
        """
        Get the backend of the given process group.

        Args:
            group (str, optional): Specify the process group to work on. Default: ``None`` means ``hccl_world_group``.

        Returns:
            str: The backend of the group.
        """
        return dist.get_backend(group)

    @staticmethod
    def split_group(parent_pg: Optional[str] = None,
                    split_ranks: Optional[list] = None,
                    timeout: Optional[timedelta] = None,
                    pg_options: Optional[str] = None,
                    group_desc: Optional[str] = None,
                    ) -> str:
        """
        Create the split group from ``split_ranks`` whose ranks contain the current rank id.

        Args:
            parent_pg (str, optional): The process group from which the split group is derived.
            split_ranks (Optional[list]): A list like ``list[list[int]]``.
            timeout (Optional[timedelta]): Timeout for the API to execute. Default is ``None``.
            pg_options (Optional[str]): Reserved parameter; currently has no effect.
            group_desc (Optional[str]): Description of the process group.

        Returns:
            str: The split group name.
        """
        if split_ranks is None or len(split_ranks) == 0:
            raise ValueError("split_ranks cannot be None or empty")

        rank_id = MindSporePlatform.get_rank()
        for split_rank in split_ranks:
            if rank_id in split_rank:
                world_group = MindSporePlatform._maybe_reuse_world_group(split_rank)
                if world_group is not None:
                    return world_group
                split_group = MindSporePlatform.get_created_group(split_rank)
                if split_group:
                    return split_group
                group_name = str(tuple(sorted(split_rank)))
                new_group(rank_ids=split_rank, group=group_name)
                EXISTING_COMM_GROUPS[group_name] = group_name
                return group_name
        raise ValueError(f"Invalid split_ranks: {split_ranks} does not contain the current rank"
                         f" {rank_id}")

    @staticmethod
    def get_group_local_rank(group=None) -> int:
        """Get the local rank id within the group."""
        return dist.get_group_rank(group, MindSporePlatform.get_rank())

    @staticmethod
    def no_grad():
        return _no_grad()

    @staticmethod
    def cat(tensors, dim=0):
        return mint.cat(tensors, dim=dim)

    @staticmethod
    def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
        return mint.empty_like(tensor, dtype=dtype, device=device, pin_memory=pin_memory)

    def get_current_stream(self):
        return ms.runtime.current_stream()

    def new_event(self):
        return ms.runtime.Event()

    def tree_map(self, fn, tree):
        """
        Apply fn to each leaf in a nested structure (list / tuple / dict),
        preserving the original structure.
        """
        if isinstance(tree, dict):
            return type(tree)(
                (k, self.tree_map(fn, v)) for k, v in tree.items()
            )

        if isinstance(tree, tuple):
            return tuple(self.tree_map(fn, v) for v in tree)

        if isinstance(tree, list):
            return [self.tree_map(fn, v) for v in tree]

        # leaf
        return fn(tree)
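
    # Example (illustrative; ``platform`` is a ``MindSporePlatform()`` instance):
    #
    #     >>> platform.tree_map(lambda v: v * 2, {"a": 1, "b": (2, [3])})
    #     {'a': 2, 'b': (4, [6])}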

    @staticmethod
    def register_forward_pre_hook(module, hook, prepend=False, with_kwargs=False):
        return module.register_forward_pre_hook(hook, with_kwargs=with_kwargs)

    @staticmethod
    def register_full_backward_hook(module, hook, prepend=False):
        return module.register_backward_hook(hook)

    @staticmethod
    def register_full_backward_pre_hook(module, hook, prepend=False):
        return module.register_backward_pre_hook(hook)

    @property
    def checkpoint(self):
        return ms.recompute

    @staticmethod
    def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.activation_checkpoint.checkpoint_wrapper import checkpoint_wrapper
        return checkpoint_wrapper(module, checkpoint_fn=checkpoint_fn, **checkpoint_fn_kwargs)

    @staticmethod
    def swap_wrapper(module, policy_fn=None):
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import swap_wrapper
        return swap_wrapper(module, policy_fn=policy_fn)

    @property
    def noop_context_fn(self):
        return null_context_fn

    @staticmethod
    def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.activation_checkpoint.sac import create_selective_checkpoint_contexts
        return create_selective_checkpoint_contexts(policy_fn_or_list,
                                                    allow_cache_entry_mutation=allow_cache_entry_mutation)

    @staticmethod
    def async_save_on_cpu(policy_fn=None):
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import AsyncSaveOnCpu
        return AsyncSaveOnCpu(policy_fn=policy_fn)

    @staticmethod
    def get_element_size(tensor):
        """Get tensor element size in bytes."""
        return tensor.itemsize

    @staticmethod
    def tensor_to_numpy(tensor) -> np.ndarray:
        """Convert MindSpore tensor to numpy array."""
        return tensor.asnumpy()

    @staticmethod
    def clip_grad_norm_(
            parameters, max_norm, norm_type=2.0,
            error_if_nonfinite=False, foreach=None,
    ):
        raise NotImplementedError(
            "clip_grad_norm_ is not yet supported on MindSpore"
        )

    @property
    def meta_device(self):
        return "meta"

    def init_on_device(self, device, include_buffers=False):
        return _init_on_device(device, include_buffers=include_buffers)

    def cast_fp_tensor(self, dtype, x):
        """
        Cast floating-point tensor to target dtype if applicable.
        """
        if (
                not isinstance(x, ms.Tensor)
                or not ms.ops.is_floating_point(x)
                or x.dtype == dtype
        ):
            return x
        return x.to(dtype)

    def apply_to_tensors(self, fn, container):
        """Recursively apply ``fn`` to every tensor in various kinds of container types."""

        def apply(x):
            if isinstance(x, ms.Tensor):
                return fn(x)
            if hasattr(x, "__dataclass_fields__"):
                dc = dataclasses.replace(x)
                changes = {
                    f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc)
                }
                return dataclasses.replace(dc, **changes)
            if isinstance(x, OrderedDict):
                od = x.__class__()
                for key, value in x.items():
                    od[key] = apply(value)
                return od
            if isinstance(x, dict):
                return {key: apply(value) for key, value in x.items()}
            if isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields"):
                res = (apply(el) for el in x)
                return type(x)(*res)
            if isinstance(x, (list, tuple, set)):
                return type(x)(apply(el) for el in x)
            return x

        return apply(container)
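
    # Example (illustrative): zeroing every tensor in a nested batch while
    # preserving the dict / list structure and non-tensor leaves around them:
    #
    #     >>> batch = {"ids": Tensor([1, 2]), "extras": [Tensor([3.0]), "tag"]}
    #     >>> zeroed = platform.apply_to_tensors(lambda t: t * 0, batch)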

    @staticmethod
    def profiler_record(name):
        """Profiler record context for named operations; currently a no-op (``contextlib.nullcontext``)."""
        return contextlib.nullcontext()

    def str_to_dtype(self, dtype_str: str) -> Any:
        """Resolve checkpoint dtype strings (``mindspore.*`` or short ``str(Tensor.dtype)`` e.g. ``Float32``)."""
        if "." in dtype_str:
            prefix, name = dtype_str.split(".", 1)
            if prefix == "mindspore":
                return getattr(ms, name)
        dtype = getattr(ms, dtype_str.lower(), None)
        if dtype is not None:
            return dtype
        raise ValueError(
            f"Expected dtype string like 'mindspore.float32' or 'Float32', got {dtype_str!r}."
        )
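
    # Example (illustrative): both spellings resolve to the same dtype object:
    #
    #     >>> platform.str_to_dtype("mindspore.float32") is ms.float32
    #     True
    #     >>> platform.str_to_dtype("Float32") is ms.float32
    #     True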

    def list_to_size(self, size_list: list[int]) -> tuple[int, ...]:
        return tuple(size_list)

    @staticmethod
    def _maybe_reuse_world_group(rank_list):
        """Reuse the default world group for full-world rank lists."""
        normalized = tuple(sorted(rank_list))
        world_ranks = tuple(range(MindSporePlatform.get_world_size()))
        if normalized != world_ranks:
            return None

        EXISTING_COMM_GROUPS[str(normalized)] = GlobalComm.WORLD_COMM_GROUP
        return GlobalComm.WORLD_COMM_GROUP