Coverage for hyper_parallel / platform / mindspore / platform.py: 81%
347 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""MindSpore platform api"""
16from datetime import timedelta
17from typing import Optional
19import numpy as np
20import mindspore as ms
21import mindspore.common.dtype as mstype
22from mindspore.mint.distributed import TCPStore
24from mindspore.nn import Cell
25from mindspore import mint
26from mindspore.common.api import _no_grad
27from mindspore.common.dtype import type_size_in_bytes
28from mindspore.common.parameter import Parameter
29from mindspore.common.tensor import Tensor
30from mindspore.common.initializer import initializer
31from mindspore.communication import get_group_size
32from mindspore.communication import create_group as new_group
33from mindspore.communication import get_rank as get_rank_id
34from mindspore.communication import comm_func
35from mindspore._c_expression import TensorTransform
36import mindspore.mint.distributed as dist
38from hyper_parallel.platform.platform import Platform, PlatformType
39from hyper_parallel.platform.mindspore.dtensor import DTensorBase
40from hyper_parallel.platform.mindspore.pipeline_parallel.stage import PipelineStageBase
41from hyper_parallel.platform.mindspore.parameter_init import init_parameters as _init_parameters
# Process-wide singleton from MindSpore's C++ layer used to redistribute
# tensor data between sharding layouts; exposed via get_tensor_transform().
_tensor_transform = TensorTransform.get_instance()

# pylint: disable=C0103
class MindSporePlatform(Platform):
    """MindSpore implementation of the backend-agnostic ``Platform`` API.

    Re-exports MindSpore types and wraps its distributed/communication
    primitives behind the common interface so the rest of hyper_parallel
    stays framework agnostic.
    """

    # Framework types re-exported under platform-neutral names.
    Tensor = Tensor
    tensor = Tensor
    Parameter = Parameter
    Module = Cell
    DTensorBase = DTensorBase
    PipelineStageBase = PipelineStageBase
    # Discriminator used by callers to detect which backend is active.
    platform_type = PlatformType.MINDSPORE
    # Namespace holding the framework's dtype constants (mstype.float32, ...).
    tensor_dtype = mstype
60 def device_count(self, device_handle):
61 device_type = self.device_type()
62 if device_type == "cpu":
63 return device_handle.device_context.cpu.device_count()
64 if device_type == "gpu":
65 return device_handle.device_context.gpu.device_count()
66 return device_handle.device_context.ascend.device_count()
    @staticmethod
    def get_rng_state(device=None, device_handle=None):
        """Return MindSpore's global RNG state.

        ``device`` and ``device_handle`` are accepted for interface parity
        with other platforms and are ignored here.
        """
        _ = device, device_handle
        return ms.get_rng_state()

    @staticmethod
    def set_rng_state(state, device=None, device_handle=None):
        """Restore an RNG state previously obtained from ``get_rng_state``."""
        _ = device, device_handle
        return ms.set_rng_state(state)

    def device_type(self):
        """Return the device type as a lowercase string ("npu"/"gpu"/"cpu")."""
        device_type = ms.get_context("device_target")
        # MindSpore reports "Ascend"; normalize to the platform-neutral "npu".
        if device_type == "Ascend":
            return "npu"
        return device_type.lower()

    def device(self, device_idx=None):
        """Return the device identifier; ``device_idx`` is currently ignored."""
        _ = device_idx
        device_type = self.device_type()
        return device_type

    @staticmethod
    def get_device_handle():
        """Return the module acting as the device handle (MindSpore itself)."""
        return ms

    @staticmethod
    def manual_seed(seed):
        """Seed MindSpore's global random number generator."""
        return ms.manual_seed(seed)

    @staticmethod
    def ones(size, dtype=None):
        """Create a tensor of ones with the given shape and optional dtype."""
        return mint.ones(size, dtype=dtype)

    @staticmethod
    def zeros(size, dtype=None):
        """Create a tensor of zeros with the given shape and optional dtype."""
        return mint.zeros(size, dtype=dtype)

    @staticmethod
    def full(size, fill_value, dtype=None):
        """Create a tensor of the given shape filled with ``fill_value``."""
        return mint.full(size, fill_value, dtype=dtype)

    @staticmethod
    def empty(size, dtype=None):
        """Create an uninitialized tensor with the given shape and dtype."""
        return mint.empty(size, dtype=dtype)

    @staticmethod
    def get_rank():
        """Return the global rank id of the current process."""
        return get_rank_id()

    @staticmethod
    def get_global_rank(group, group_rank):
        """Translate a rank within ``group`` into the global rank."""
        return dist.get_global_rank(group, group_rank)

    @staticmethod
    def get_world_size():
        """Return the total number of processes in the global group."""
        return get_group_size()

    @staticmethod
    def get_op_name(func):
        """Return the ``name`` attribute of a MindSpore primitive/op."""
        return func.name
    @staticmethod
    def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
        """All-gather ``data`` across ``group``, joined along ``concat_dim``.

        The all-gather stacks the per-rank slices along dim 0; when the
        caller wants them joined along another dim, the dim-0 result is
        re-split into ``concat_size`` pieces and concatenated along
        ``concat_dim``. Uses comm_func/mint ops so the operation stays
        differentiable.
        """
        output, _ = comm_func.all_gather_into_tensor(data, group=group)
        # Fast path: gathered layout already matches the requested one.
        if concat_dim == 0:
            return output
        output_tensors = ms.ops.Split(output_num=concat_size)(output)
        return ms.mint.concat(output_tensors, concat_dim)
138 @staticmethod
139 def chunk(data, split_dim, split_size, index):
140 return ms.ops.Split(axis=split_dim, output_num=split_size)(data)[index]
    @staticmethod
    def differentiable_all_to_all(input_data, output_shape, group):
        """Differentiable all-to-all exchange of ``input_data`` within ``group``.

        ``output_shape`` is the tensor shape expected on this rank after the
        exchange. The call is synchronous (``async_op=False``).
        """
        output_tensor, _ = comm_func.all_to_all_single_with_output_shape(
            output_shape=output_shape,
            tensor=input_data,
            group=group,
            async_op=False
        )
        return output_tensor
152 @staticmethod
153 def tensor_type_cast(input_data, cast_type):
154 """Cast tensor to specified data type."""
155 type_mapping = {
156 'float32': ms.float32,
157 'float16': ms.float16,
158 'int64': ms.int64,
159 'int32': ms.int32
160 }
161 if cast_type not in type_mapping:
162 raise ValueError(f"Unknown cast type: {cast_type}. Supported types: {list(type_mapping.keys())}")
163 return input_data.to(type_mapping[cast_type])
    @staticmethod
    def differentiable_all_reduce(data, op, group):
        """Differentiable all-reduce of ``data`` with reduce op ``op`` in ``group``."""
        output, _ = comm_func.all_reduce(data, op, group)
        return output

    @staticmethod
    def differentiable_reduce_scatter(data, dev_num, axis, op, group):
        """Differentiable reduce-scatter of ``data`` across ``dev_num`` ranks.

        reduce_scatter only scatters along dim 0, so for ``axis > 0`` the
        tensor is first re-laid-out: split into ``dev_num`` chunks along
        ``axis`` and concatenated along dim 0. ``op == 'avg'`` is emulated
        by a 'sum' reduce followed by division by ``dev_num``.
        """
        if axis > 0:
            data = ms.mint.concat(ms.ops.Split(axis=axis, output_num=dev_num)(data), dim=0)
        output_tensor, _ = comm_func.reduce_scatter_tensor(data, 'sum', group)
        if op == 'avg':
            output_tensor = output_tensor / dev_num
        return output_tensor

    @staticmethod
    def init_parameters(module, stage_index):
        """Initialize ``module``'s parameters for the given pipeline stage."""
        return _init_parameters(module, stage_index)
183 # pylint: disable=W0212
184 @staticmethod
185 def update_param_data(param, data):
186 """update param data"""
187 if isinstance(param, DTensorBase):
188 param.set_data(data)
189 else:
190 param._update_data(data)
    @staticmethod
    def get_cell_construct(cell):
        """Return the cell's ``construct`` method (its forward function)."""
        return cell.construct

    @staticmethod
    def get_cells_and_names(cell):
        """Return an iterator over the cell and all of its sub-cells with names."""
        return cell.cells_and_names()
200 @staticmethod
201 def search_parameter_by_name(cell, param_name: str):
202 """
203 Find the parent Module of the parameter, the parameter's name in the parent Module, and the parameter.
204 Return value: (parent Module instance, parameter's name in parent Module, parameter object).
205 Returns None if not found.
206 """
207 # Remove the "self." prefix from param_name (to maintain compatibility with original logic)
208 param_name = param_name.replace("self.", "")
209 # Case 1: The parameter is a direct parameter of the current Module (not in any sub-Module)
210 if param_name in cell._params:
211 return (cell, param_name, cell._params[param_name])
213 # Case 2: The parameter is in a sub-Module (supports multi-level nesting, e.g., "net_b.dense1.weight")
214 if "." in param_name:
215 # Split into: sub-Module path + parameter name (e.g., "net_b.dense1" + "weight")
216 cell_path, param_key = param_name.rsplit(".", 1)
217 try:
218 # Locate the sub-Module where the parameter resides (supports multi-level paths)
219 target_cell = cell.get_sub_cell(cell_path)
220 # Check if the sub-Module directly contains this parameter
221 if param_key in target_cell._params:
222 return target_cell, param_key, target_cell._params[param_key]
223 except AttributeError:
224 # Sub-Module path does not exist or the parameter is not in that sub-Module
225 pass
227 # Traverse all sub-Modules (recursively) to search for the parameter
228 for _, child_cell in cell._cells.items():
229 if isinstance(child_cell, Cell):
230 # Recursively search within the sub-Module
231 result = MindSporePlatform.search_parameter_by_name(child_cell, param_name)
232 if result is not None:
233 return result
235 return None
    @staticmethod
    def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
        """
        Modify the original parameter in a Module or sub-Module using the search result
        Args:
            cell: The cell which parameter is to update
            result: A tuple contains parent Module, parameter key and old parameter.
            new_param: New Parameter object (used to replace the original parameter)
        """
        # ``cell`` itself is unused: the storage location comes from ``result``
        # (as produced by ``search_parameter_by_name``).
        parent_cell, param_key, _ = result
        # Key operation: directly modify the _params dictionary of the parent Module (original storage location)
        parent_cell._params[param_key] = new_param

        # If the parameter is also cached as a plain instance attribute, refresh
        # the cached references so stale Parameter objects are not used.
        # NOTE(review): the _params_list update is assumed to belong inside this
        # guard — confirm against the original source formatting.
        if param_key in parent_cell.__dict__:
            parent_cell.__dict__[param_key] = new_param
            parent_cell._params_list[param_key] = new_param
        return True
    @staticmethod
    def set_layout_into_parameter(param, layout):
        """Rebuild ``param`` as a Parameter wrapping a DTensor sharded per ``layout``.

        Raises:
            ValueError: If ``param`` already carries a layout (is a DTensor).
        """
        from hyper_parallel.core.dtensor import DTensor  # pylint: disable=import-outside-toplevel
        from hyper_parallel.core.layout import _infer_slice_shape_by_layout, \
            _get_slice_tensor_by_layout  # pylint: disable=import-outside-toplevel
        if isinstance(param, DTensor):
            raise ValueError(f"Parameter {param.name} has been configured layout, cannot be set repeatedly.")
        # Preserve the parameter's metadata across the rebuild.
        param_info = param.param_info
        requires_grad = param.requires_grad
        name = param.name
        slice_shape = _infer_slice_shape_by_layout(param.shape, layout)

        if not param.has_init:
            # Data already materialized: slice the real tensor per the layout
            # and wrap the local shard in a DTensor.
            param_dtensor = DTensor.from_local(
                _get_slice_tensor_by_layout(param, layout).value(), layout.mesh, layout.placements
            )
            param = Parameter(param_dtensor, name=name, requires_grad=requires_grad)
            param.param_info = param_info
        else:
            # Still lazily initialized: shrink the pending init shape to the
            # local slice so initialization later produces only the shard.
            param.init_mode.shape = slice_shape
            param_dtensor = DTensor.from_local(param.init_mode, layout.mesh, layout.placements)
            param = Parameter(param_dtensor, name=name, requires_grad=requires_grad)
            param.param_info = param_info
        return param
    @staticmethod
    def get_param_local_shape(param):
        """Return the local (per-rank) shape of ``param``."""
        if isinstance(param, DTensorBase):
            return param.local_shape
        return param.shape

    @staticmethod
    def get_param_local_data(param):
        """Return the local (per-rank) data of ``param``."""
        # Distributed tensors return their local shard; plain parameters are
        # already local.
        if isinstance(param, DTensorBase):
            return param.to_local()
        return param

    @staticmethod
    def get_param_type_size(param):
        """Return the size in bytes of one element of ``param``'s dtype."""
        return type_size_in_bytes(param.dtype)
    @staticmethod
    def new_zero_parameter(param_shape, param_type, requires_grad, device):
        """Create a zero-initialized Parameter, moved to ``device`` if it is an accelerator."""
        param = Parameter(initializer("zeros", param_shape, param_type), requires_grad=requires_grad)
        # Only accelerator targets need an explicit move; CPU data stays put.
        if device in ("GPU", "Ascend"):
            return param.to(device)
        return param

    @staticmethod
    def new_tensor(tensor_shape, tensor_type, device):
        """Create a tensor of the given shape/dtype, moved to ``device`` if an accelerator."""
        tensor = Tensor(shape=tensor_shape, dtype=tensor_type)
        if device in ("GPU", "Ascend"):
            return tensor.to(device)
        return tensor

    @staticmethod
    def full_like(tensor, fill_value, dtype=None):
        """Return a tensor shaped like ``tensor`` filled with ``fill_value``."""
        return mint.full_like(tensor, fill_value, dtype=dtype)

    @staticmethod
    def isend(tensor, dst=None, group=None, tag=0):
        """Non-blocking point-to-point send; returns a handle to wait on."""
        return dist.isend(tensor, dst, group, tag)

    @staticmethod
    def irecv(tensor, src=None, group=None, tag=0):
        """Non-blocking point-to-point receive into ``tensor``; returns a handle."""
        return dist.irecv(tensor, src, group, tag)

    @staticmethod
    def send_object_list(obj_list, dst=None, group=None):
        """Send a list of Python objects to rank ``dst`` via the pipeline helper."""
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.pipeline_parallel._utils import send_object_list
        send_object_list(obj_list, dst, group)

    @staticmethod
    def recv_object_list(obj_list, src=None, group=None):
        """Receive Python objects from rank ``src`` into ``obj_list`` in place."""
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.pipeline_parallel._utils import recv_object_list
        recv_object_list(obj_list, src, group)

    @staticmethod
    def set_tensor_requires_grad(input_tensor):
        """
        Set the requires-grad flag for ``input_tensor`` in place.
        """
        input_tensor.requires_grad_()
346 def _create_group(self, rank_list, group_name=None):
347 if group_name is None:
348 hash_str_rank_list = '-'.join([str(rank) for rank in rank_list])
349 group_name = f"{len(rank_list)}-{hash_str_rank_list}"
350 new_group(rank_ids=rank_list, group=group_name)
351 return group_name
    @staticmethod
    def all_gather_into_tensor(data, group_info, async_op=False):
        """All-gather ``data`` within the group; returns (tensor, handle)."""
        return comm_func.all_gather_into_tensor(data, group=group_info.group_name, async_op=async_op)

    @staticmethod
    def all_reduce(data, group_info, async_op=False):
        """All-reduce ``data`` in place; returns ``(data, handle)``.

        ``group_info`` may be either a group-name string or an info object
        exposing ``group_name``.
        """
        if isinstance(group_info, str):
            handle = dist.all_reduce(data, group=group_info, async_op=async_op)
        else:
            handle = dist.all_reduce(data, group=group_info.group_name, async_op=async_op)
        return data, handle

    @staticmethod
    def broadcast(data, src, group=None, async_op=False):
        """Broadcast ``data`` from rank ``src`` and return it after completion.

        NOTE(review): when ``async_op`` is True the handle is waited on
        immediately, so the call is effectively synchronous either way —
        confirm this is intentional (contrast with ``all_reduce``, which
        hands the handle back to the caller).
        """
        handle = dist.broadcast(data, src, group, async_op)
        if async_op:
            handle.wait()
        return data

    @staticmethod
    def reduce_scatter_tensor(data, group_info, async_op=False):
        """Reduce-scatter ``data`` within the group; returns (tensor, handle)."""
        return comm_func.reduce_scatter_tensor(data, group=group_info.group_name, async_op=async_op)
    @staticmethod
    def parameters_dict(cell: Cell):
        """Return the cell's parameters with their names.

        NOTE(review): despite the name, this delegates to
        ``parameters_and_names()`` and returns an iterator of pairs rather
        than a dict — verify callers iterate rather than index it.
        """
        return cell.parameters_and_names()

    @staticmethod
    def get_tensor_transform():
        """Return the process-wide TensorTransform singleton."""
        return _tensor_transform

    @staticmethod
    def construct_strided_slice(x, begin, end, stride):
        """Slice ``x`` with explicit begin/end/stride per dimension."""
        return ms.ops.strided_slice(x, begin, end, stride)

    @staticmethod
    def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        """Build a callable that splits inputs into ``micro_batch_num`` micro-batches."""
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.pipeline_parallel._utils import _MicroBatch
        return _MicroBatch(micro_batch_num, args_batch_dim, kwargs_batch_dim)

    @staticmethod
    def save_checkpoint(cell: Cell, file_path: str) -> None:
        """Save the cell's parameters to ``file_path`` in safetensors format.

        NOTE(review): ``cell._params`` holds only the cell's DIRECT
        parameters; confirm nested sub-cell parameters are flattened in by
        the caller, otherwise they are silently dropped from the checkpoint.
        """
        save_dict = cell._params
        ms.save_checkpoint(save_obj=save_dict, ckpt_file_name=file_path, format="safetensors")

    @staticmethod
    def load_checkpoint(file_path: str) -> dict:
        """Load a safetensors checkpoint from ``file_path`` into a dict."""
        return ms.load_checkpoint(ckpt_file_name=file_path, format="safetensors")
    def new_stream(self):
        """Create a new device execution stream."""
        return ms.runtime.Stream()

    def get_stream_context(self):
        """Return the stream context-manager class (used as ``ctx(stream)``)."""
        return ms.runtime.StreamCtx

    @staticmethod
    def all_gather_object(object_list, obj, group=None) -> None:
        """
        Gathers objects from the given group into object list.

        Args:
            object_list (list[Any]): Define the output list, which size equal to the size of group.
            obj (Any): The object on current rank and in given process group.
            group (ProcessGroup, optional): The process group to gather obj. Default is ``None``, and ``None`` means
                global group.

        Returns:
            None. Objs are gathered into ``object_list``.
        """
        dist.all_gather_object(object_list, obj, group)
    @staticmethod
    def init_process_group(
            backend: Optional[str] = None,
            *,
            init_method: Optional[str] = None,
            timeout: Optional[timedelta] = None,
            world_size: int = -1,
            rank: int = -1,
            store: Optional[TCPStore] = None,
            pg_options=None,
            device_id=None
    ) -> None:
        """
        Initialize global process group.

        Args:
            backend (str): The backend used to init process group. Default is ``"hccl"`` and now only support hccl.
            init_method (str, optional): URL specifying how to initialize the process group. Default is ``None``.
            timeout (timedelta, optional): Timeout for API executed. Default is ``None``.
            world_size (int): Number of processes. Default is ``-1``.
            rank (int, optional): Rank of the current process. Default is ``-1``.
            store (Store, optional): An object that stores key/value data, facilitating the exchange of inter-process
                communication addresses and connection information. Default is ``None``. Currently, only the
                ``TCPStore`` type is supported.
            pg_options (ProcessGroupOptions, optional): Reserved parameter. Current not take effect.
            device_id (int, optional): Reserved parameter. Current not take effect.
        """
        # Default to HCCL, the Ascend collective-communication backend.
        if backend is None:
            backend = "hccl"
        dist.init_process_group(backend=backend, init_method=init_method, timeout=timeout, world_size=world_size,
                                rank=rank, store=store, pg_options=pg_options, device_id=device_id)
    @staticmethod
    def destroy_process_group(group: Optional[str] = None) -> None:
        """
        Destroy given process group.

        Args:
            group (str, optional): Specify the group to destroy. Default: ``None`` means ``hccl_world_group``. If group
                is None or "hccl_world_group", destroy global process group and all process groups relative to global
                process group.
        """
        dist.destroy_process_group(group)

    @staticmethod
    def get_process_group_ranks(group: Optional[str] = None) -> list[int]:
        """
        Get all ranks in given process group.

        Args:
            group (str, optional): Specify the process group to work on. Default: ``None`` means ``hccl_world_group``.

        Returns:
            List[int]: List of ranks in given process group.
        """
        return dist.get_process_group_ranks(group)

    @staticmethod
    def get_backend(group: Optional[str] = None) -> str:
        """
        Get the backend of given process group.

        Args:
            group (str, optional): Specify the process group to work on. Default: ``None`` means ``hccl_world_group``.

        Returns:
            str: The backend of the group.
        """
        return dist.get_backend(group)
    @staticmethod
    def split_group(parent_pg: Optional[str] = None,
                    split_ranks: Optional[list] = None,
                    timeout: Optional[timedelta] = None,
                    pg_options: Optional[str] = None,
                    group_desc: Optional[str] = None,
                    ) -> str:
        """
        Create split group for a specific group rank in split_ranks, which group contains current rank id.

        Args:
            parent_pg (str, Optional): A process group which the goal group split from. Currently unused.
            split_ranks (Optional[list]): A list like ``list[list[int]]``.
            timeout (Optional[timedelta]): Timeout for API executed. Default is ``None``. Currently unused.
            pg_options (Optional[str]): Name for the new group. When ``None``, a deterministic name
                ``"<size>-<rank0>-<rank1>-..."`` is derived from the matching rank list.
            group_desc (Optional[str]): Description of process group. Currently unused.

        Returns:
            str: The split group name.

        Raises:
            ValueError: If ``split_ranks`` is empty, or no sub-list contains the current rank.
        """
        if split_ranks is None or len(split_ranks) == 0:
            raise ValueError("split_ranks cannot be None or empty")

        rank_id = MindSporePlatform.get_rank()
        # Only the sub-list containing this rank is turned into a group.
        for split_rank in split_ranks:
            if rank_id in split_rank:
                if pg_options is None:
                    hash_str_rank_list = '-'.join([str(rank) for rank in split_rank])
                    pg_options = f"{len(split_rank)}-{hash_str_rank_list}"
                new_group(rank_ids=split_rank, group=pg_options)
                return pg_options
        raise ValueError(f"Split group invalid rank, the Split_ranks {split_ranks} does not contain current rank"
                         f" {rank_id}")
    @staticmethod
    def no_grad():
        """Return a context manager that disables gradient recording."""
        return _no_grad()

    @staticmethod
    def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
        """Return an uninitialized tensor with the same shape as ``tensor``."""
        return mint.empty_like(tensor, dtype=dtype, device=device, pin_memory=pin_memory)

    def get_current_stream(self):
        """Return the stream currently active on the device."""
        return ms.runtime.current_stream()

    def new_event(self):
        """Create a new device event for stream synchronization."""
        return ms.runtime.Event()
543 def tree_map(self, fn, tree):
544 """
545 Apply fn to each leaf in a nested structure (list / tuple / dict),
546 preserving the original structure.
547 """
548 if isinstance(tree, dict):
549 return type(tree)(
550 (k, self.tree_map(fn, v)) for k, v in tree.items()
551 )
553 if isinstance(tree, tuple):
554 return tuple(self.tree_map(fn, v) for v in tree)
556 if isinstance(tree, list):
557 return [self.tree_map(fn, v) for v in tree]
559 # leaf
560 return fn(tree)
    @staticmethod
    def register_forward_pre_hook(module, hook, prepend=False, with_kwargs=False):
        """Register a forward pre-hook on ``module``.

        ``prepend`` is accepted for interface parity but not forwarded.
        """
        return module.register_forward_pre_hook(hook, with_kwargs)

    @staticmethod
    def register_full_backward_hook(module, hook, prepend=False):
        """Register a backward hook on ``module`` (``prepend`` is not forwarded)."""
        return module.register_backward_hook(hook)

    @staticmethod
    def register_full_backward_pre_hook(module, hook, prepend=False):
        """Register a backward pre-hook on ``module`` (``prepend`` is not forwarded)."""
        return module.register_backward_pre_hook(hook)
    @property
    def checkpoint(self):
        """Return the activation-recompute entry point (``ms.recompute``)."""
        return ms.recompute

    @staticmethod
    def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
        """Not available on MindSpore; always raises NotImplementedError."""
        raise NotImplementedError("ckpt_wrapper is not supported on MindSpore platform")

    @property
    def noop_context_fn(self):
        """Not available on MindSpore; always raises NotImplementedError."""
        raise NotImplementedError("noop_context_fn is not supported on MindSpore platform")

    @staticmethod
    def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
        """Not available on MindSpore; always raises NotImplementedError."""
        raise NotImplementedError("create_selective_checkpoint_contexts is not supported on MindSpore platform")

    @staticmethod
    def async_save_on_cpu(policy_fn=None):
        """Not available on MindSpore; always raises NotImplementedError."""
        raise NotImplementedError("async_save_on_cpu is not supported on MindSpore platform")
594 @staticmethod
595 def tensor_to_numpy(tensor) -> np.ndarray:
596 """Convert MindSpore tensor to numpy array."""
597 return tensor.asnumpy()