Coverage for hyper_parallel / platform / mindspore / platform.py: 81%

347 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""MindSpore platform api""" 

16from datetime import timedelta 

17from typing import Optional 

18 

19import numpy as np 

20import mindspore as ms 

21import mindspore.common.dtype as mstype 

22from mindspore.mint.distributed import TCPStore 

23 

24from mindspore.nn import Cell 

25from mindspore import mint 

26from mindspore.common.api import _no_grad 

27from mindspore.common.dtype import type_size_in_bytes 

28from mindspore.common.parameter import Parameter 

29from mindspore.common.tensor import Tensor 

30from mindspore.common.initializer import initializer 

31from mindspore.communication import get_group_size 

32from mindspore.communication import create_group as new_group 

33from mindspore.communication import get_rank as get_rank_id 

34from mindspore.communication import comm_func 

35from mindspore._c_expression import TensorTransform 

36import mindspore.mint.distributed as dist 

37 

38from hyper_parallel.platform.platform import Platform, PlatformType 

39from hyper_parallel.platform.mindspore.dtensor import DTensorBase 

40from hyper_parallel.platform.mindspore.pipeline_parallel.stage import PipelineStageBase 

41from hyper_parallel.platform.mindspore.parameter_init import init_parameters as _init_parameters 

42 

# Process-wide singleton; returned by MindSporePlatform.get_tensor_transform().
_tensor_transform = TensorTransform.get_instance()


# pylint: disable=C0103

class MindSporePlatform(Platform):
    """MindSpore platform api.

    Maps the framework-neutral ``Platform`` interface onto MindSpore
    primitives (``mindspore.mint``, ``mindspore.communication`` and
    ``mindspore.mint.distributed``).
    """

    # Framework type aliases exposed to platform-agnostic callers.
    Tensor = Tensor
    tensor = Tensor
    Parameter = Parameter
    Module = Cell
    DTensorBase = DTensorBase
    PipelineStageBase = PipelineStageBase
    platform_type = PlatformType.MINDSPORE
    tensor_dtype = mstype

    def device_count(self, device_handle):
        """Return the number of devices visible for the current device type."""
        device_type = self.device_type()
        if device_type == "cpu":
            return device_handle.device_context.cpu.device_count()
        if device_type == "gpu":
            return device_handle.device_context.gpu.device_count()
        return device_handle.device_context.ascend.device_count()

    @staticmethod
    def get_rng_state(device=None, device_handle=None):
        """Get RNG state (``device``/``device_handle`` are ignored on MindSpore)."""
        _ = device, device_handle
        return ms.get_rng_state()

    @staticmethod
    def set_rng_state(state, device=None, device_handle=None):
        """Set RNG state (``device``/``device_handle`` are ignored on MindSpore)."""
        _ = device, device_handle
        return ms.set_rng_state(state)

    def device_type(self):
        """Return the device type as a lowercase string; "Ascend" maps to "npu"."""
        device_type = ms.get_context("device_target")
        if device_type == "Ascend":
            return "npu"
        return device_type.lower()

    def device(self, device_idx=None):
        """Return the device identifier; ``device_idx`` is currently ignored."""
        _ = device_idx
        device_type = self.device_type()
        return device_type

    @staticmethod
    def get_device_handle():
        """Return the module acting as the device handle (the ``mindspore`` module)."""
        return ms

    @staticmethod
    def manual_seed(seed):
        """Seed MindSpore's global random number generator."""
        return ms.manual_seed(seed)

    @staticmethod
    def ones(size, dtype=None):
        """Return a tensor of ones with the given shape and dtype."""
        return mint.ones(size, dtype=dtype)

    @staticmethod
    def zeros(size, dtype=None):
        """Return a tensor of zeros with the given shape and dtype."""
        return mint.zeros(size, dtype=dtype)

    @staticmethod
    def full(size, fill_value, dtype=None):
        """Return a tensor of the given shape filled with ``fill_value``."""
        return mint.full(size, fill_value, dtype=dtype)

    @staticmethod
    def empty(size, dtype=None):
        """Return an uninitialized tensor with the given shape and dtype."""
        return mint.empty(size, dtype=dtype)

    @staticmethod
    def get_rank():
        """Return the global rank of the current process."""
        return get_rank_id()

    @staticmethod
    def get_global_rank(group, group_rank):
        """Translate a rank inside ``group`` into the corresponding global rank."""
        return dist.get_global_rank(group, group_rank)

    @staticmethod
    def get_world_size():
        """Return the total number of processes in the global group."""
        return get_group_size()

    @staticmethod
    def get_op_name(func):
        """Return the ``name`` attribute of a MindSpore operator instance."""
        return func.name

    @staticmethod
    def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
        """All-gather ``data`` across ``group`` and concatenate along ``concat_dim``.

        The all-gather stacks along dim 0; for any other ``concat_dim`` the
        result is re-split into ``concat_size`` chunks and re-concatenated.
        """
        output, _ = comm_func.all_gather_into_tensor(data, group=group)
        if concat_dim == 0:
            return output
        output_tensors = ms.ops.Split(output_num=concat_size)(output)
        return ms.mint.concat(output_tensors, concat_dim)

    @staticmethod
    def chunk(data, split_dim, split_size, index):
        """Split ``data`` into ``split_size`` chunks along ``split_dim`` and return chunk ``index``."""
        return ms.ops.Split(axis=split_dim, output_num=split_size)(data)[index]

    @staticmethod
    def differentiable_all_to_all(input_data, output_shape, group):
        """Differentiable all-to-all exchange producing a tensor of ``output_shape``."""
        output_tensor, _ = comm_func.all_to_all_single_with_output_shape(
            output_shape=output_shape,
            tensor=input_data,
            group=group,
            async_op=False
        )
        return output_tensor

    @staticmethod
    def tensor_type_cast(input_data, cast_type):
        """Cast tensor to specified data type.

        Args:
            input_data: Tensor to cast.
            cast_type (str): One of 'float32', 'float16', 'int64', 'int32'.

        Raises:
            ValueError: If ``cast_type`` is not a supported type name.
        """
        type_mapping = {
            'float32': ms.float32,
            'float16': ms.float16,
            'int64': ms.int64,
            'int32': ms.int32
        }
        if cast_type not in type_mapping:
            raise ValueError(f"Unknown cast type: {cast_type}. Supported types: {list(type_mapping.keys())}")
        return input_data.to(type_mapping[cast_type])

    @staticmethod
    def differentiable_all_reduce(data, op, group):
        """Differentiable all-reduce of ``data`` with reduction ``op`` over ``group``."""
        output, _ = comm_func.all_reduce(data, op, group)
        return output

    @staticmethod
    def differentiable_reduce_scatter(data, dev_num, axis, op, group):
        """Differentiable reduce-scatter over ``group``.

        When ``axis > 0`` the data is re-laid-out so the scatter dimension
        becomes dim 0 first. ``op == 'avg'`` is emulated by a 'sum'
        reduce-scatter followed by division by ``dev_num``.
        """
        if axis > 0:
            data = ms.mint.concat(ms.ops.Split(axis=axis, output_num=dev_num)(data), dim=0)
        output_tensor, _ = comm_func.reduce_scatter_tensor(data, 'sum', group)
        if op == 'avg':
            output_tensor = output_tensor / dev_num
        return output_tensor

    @staticmethod
    def init_parameters(module, stage_index):
        """Initialize the parameters of ``module`` for pipeline stage ``stage_index``."""
        return _init_parameters(module, stage_index)

    # pylint: disable=W0212
    @staticmethod
    def update_param_data(param, data):
        """update param data"""
        if isinstance(param, DTensorBase):
            param.set_data(data)
        else:
            # Plain Parameter: use the private in-place update API.
            param._update_data(data)

    @staticmethod
    def get_cell_construct(cell):
        """Return the ``construct`` method of a Cell (the forward function)."""
        return cell.construct

    @staticmethod
    def get_cells_and_names(cell):
        """Return an iterator over (name, sub-cell) pairs of ``cell``."""
        return cell.cells_and_names()

    @staticmethod
    def search_parameter_by_name(cell, param_name: str):
        """
        Find the parent Module of the parameter, the parameter's name in the parent Module, and the parameter.
        Return value: (parent Module instance, parameter's name in parent Module, parameter object).
        Returns None if not found.
        """
        # Remove the "self." prefix from param_name (to maintain compatibility with original logic)
        param_name = param_name.replace("self.", "")
        # Case 1: The parameter is a direct parameter of the current Module (not in any sub-Module)
        if param_name in cell._params:
            return (cell, param_name, cell._params[param_name])

        # Case 2: The parameter is in a sub-Module (supports multi-level nesting, e.g., "net_b.dense1.weight")
        if "." in param_name:
            # Split into: sub-Module path + parameter name (e.g., "net_b.dense1" + "weight")
            cell_path, param_key = param_name.rsplit(".", 1)
            try:
                # Locate the sub-Module where the parameter resides (supports multi-level paths)
                target_cell = cell.get_sub_cell(cell_path)
                # Check if the sub-Module directly contains this parameter
                if param_key in target_cell._params:
                    return target_cell, param_key, target_cell._params[param_key]
            except AttributeError:
                # Sub-Module path does not exist or the parameter is not in that sub-Module
                pass

        # Traverse all sub-Modules (recursively) to search for the parameter
        for _, child_cell in cell._cells.items():
            if isinstance(child_cell, Cell):
                # Recursively search within the sub-Module
                result = MindSporePlatform.search_parameter_by_name(child_cell, param_name)
                if result is not None:
                    return result

        return None

    @staticmethod
    def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
        """
        Modify the original parameter in a Module or sub-Module using the search result.

        Args:
            cell: The cell which parameter is to update
            result: A tuple contains parent Module, parameter key and old parameter.
            new_param: New Parameter object (used to replace the original parameter)

        Returns:
            bool: ``True`` once the parameter has been replaced.
        """
        parent_cell, param_key, _ = result
        # Key operation: directly modify the _params dictionary of the parent Module (original storage location)
        parent_cell._params[param_key] = new_param

        # Keep any cached attribute reference in sync with the replacement.
        if param_key in parent_cell.__dict__:
            parent_cell.__dict__[param_key] = new_param
            parent_cell._params_list[param_key] = new_param
        # Always report success; previously the un-cached path fell through
        # and returned None despite the declared bool return type.
        return True

    @staticmethod
    def set_layout_into_parameter(param, layout):
        """Set layout in to parameter.

        Wraps ``param`` into a DTensor-backed Parameter sliced according to
        ``layout``. Raises ValueError if ``param`` already carries a layout.
        """
        from hyper_parallel.core.dtensor import DTensor  # pylint: disable=import-outside-toplevel
        from hyper_parallel.core.layout import _infer_slice_shape_by_layout, \
            _get_slice_tensor_by_layout  # pylint: disable=import-outside-toplevel
        if isinstance(param, DTensor):
            raise ValueError(f"Parameter {param.name} has been configured layout, cannot be set repeatedly.")
        param_info = param.param_info
        requires_grad = param.requires_grad
        name = param.name
        slice_shape = _infer_slice_shape_by_layout(param.shape, layout)

        if not param.has_init:
            # has been init, get slice data
            param_dtensor = DTensor.from_local(
                _get_slice_tensor_by_layout(param, layout).value(), layout.mesh, layout.placements
            )
            param = Parameter(param_dtensor, name=name, requires_grad=requires_grad)
            param.param_info = param_info
        else:
            # has not been init, need to modify init shape
            param.init_mode.shape = slice_shape
            param_dtensor = DTensor.from_local(param.init_mode, layout.mesh, layout.placements)
            param = Parameter(param_dtensor, name=name, requires_grad=requires_grad)
            param.param_info = param_info
        return param

    @staticmethod
    def get_param_local_shape(param):
        """get param local shape"""
        if isinstance(param, DTensorBase):
            return param.local_shape
        return param.shape

    @staticmethod
    def get_param_local_data(param):
        """get param local data"""
        if isinstance(param, DTensorBase):
            return param.to_local()
        return param

    @staticmethod
    def get_param_type_size(param):
        """Return the size in bytes of one element of ``param``'s dtype."""
        return type_size_in_bytes(param.dtype)

    @staticmethod
    def new_zero_parameter(param_shape, param_type, requires_grad, device):
        """Create a zero-initialized Parameter, moved to ``device`` when it is GPU/Ascend."""
        param = Parameter(initializer("zeros", param_shape, param_type), requires_grad=requires_grad)
        if device in ("GPU", "Ascend"):
            return param.to(device)
        return param

    @staticmethod
    def new_tensor(tensor_shape, tensor_type, device):
        """Create an uninitialized Tensor, moved to ``device`` when it is GPU/Ascend."""
        tensor = Tensor(shape=tensor_shape, dtype=tensor_type)
        if device in ("GPU", "Ascend"):
            return tensor.to(device)
        return tensor

    @staticmethod
    def full_like(tensor, fill_value, dtype=None):
        """Return a tensor shaped like ``tensor`` filled with ``fill_value``."""
        return mint.full_like(tensor, fill_value, dtype=dtype)

    @staticmethod
    def isend(tensor, dst=None, group=None, tag=0):
        """Asynchronously send ``tensor`` to rank ``dst``; returns a handle."""
        return dist.isend(tensor, dst, group, tag)

    @staticmethod
    def irecv(tensor, src=None, group=None, tag=0):
        """Asynchronously receive into ``tensor`` from rank ``src``; returns a handle."""
        return dist.irecv(tensor, src, group, tag)

    @staticmethod
    def send_object_list(obj_list, dst=None, group=None):
        """Send a list of picklable objects to rank ``dst``."""
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.pipeline_parallel._utils import send_object_list
        send_object_list(obj_list, dst, group)

    @staticmethod
    def recv_object_list(obj_list, src=None, group=None):
        """Receive a list of picklable objects from rank ``src`` into ``obj_list``."""
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.pipeline_parallel._utils import recv_object_list
        recv_object_list(obj_list, src, group)

    @staticmethod
    def set_tensor_requires_grad(input_tensor):
        """
        set requires grad flag for input tensor
        """
        input_tensor.requires_grad_()

    def _create_group(self, rank_list, group_name=None):
        """Create a communication group from ``rank_list``.

        When ``group_name`` is None a deterministic name of the form
        ``"<size>-<rank1>-<rank2>-..."`` is derived from the ranks.
        """
        if group_name is None:
            hash_str_rank_list = '-'.join([str(rank) for rank in rank_list])
            group_name = f"{len(rank_list)}-{hash_str_rank_list}"
        new_group(rank_ids=rank_list, group=group_name)
        return group_name

    @staticmethod
    def all_gather_into_tensor(data, group_info, async_op=False):
        """All-gather ``data`` across ``group_info.group_name``; returns (tensor, handle)."""
        return comm_func.all_gather_into_tensor(data, group=group_info.group_name, async_op=async_op)

    @staticmethod
    def all_reduce(data, group_info, async_op=False):
        """In-place all-reduce of ``data``; ``group_info`` may be a group name or an info object."""
        if isinstance(group_info, str):
            handle = dist.all_reduce(data, group=group_info, async_op=async_op)
        else:
            handle = dist.all_reduce(data, group=group_info.group_name, async_op=async_op)
        return data, handle

    @staticmethod
    def broadcast(data, src, group=None, async_op=False):
        """Broadcast ``data`` from rank ``src``; always synchronous on return.

        NOTE: when ``async_op`` is True the handle is waited on immediately,
        so the call completes before returning either way.
        """
        handle = dist.broadcast(data, src, group, async_op)
        if async_op:
            handle.wait()
        return data

    @staticmethod
    def reduce_scatter_tensor(data, group_info, async_op=False):
        """Reduce-scatter ``data`` across ``group_info.group_name``; returns (tensor, handle)."""
        return comm_func.reduce_scatter_tensor(data, group=group_info.group_name, async_op=async_op)

    @staticmethod
    def parameters_dict(cell: Cell):
        """Return an iterator over (name, parameter) pairs of ``cell``."""
        return cell.parameters_and_names()

    @staticmethod
    def get_tensor_transform():
        """Return the process-wide TensorTransform singleton."""
        return _tensor_transform

    @staticmethod
    def construct_strided_slice(x, begin, end, stride):
        """Return the strided slice of ``x`` defined by ``begin``/``end``/``stride``."""
        return ms.ops.strided_slice(x, begin, end, stride)

    @staticmethod
    def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        """Return a micro-batch splitter for pipeline parallelism."""
        # pylint: disable=C0415
        from hyper_parallel.platform.mindspore.pipeline_parallel._utils import _MicroBatch
        return _MicroBatch(micro_batch_num, args_batch_dim, kwargs_batch_dim)

    @staticmethod
    def save_checkpoint(cell: Cell, file_path: str) -> None:
        """Save the parameters of ``cell`` to ``file_path`` in safetensors format."""
        save_dict = cell._params
        ms.save_checkpoint(save_obj=save_dict, ckpt_file_name=file_path, format="safetensors")

    @staticmethod
    def load_checkpoint(file_path: str) -> dict:
        """Load a safetensors checkpoint from ``file_path`` into a dict."""
        return ms.load_checkpoint(ckpt_file_name=file_path, format="safetensors")

    def new_stream(self):
        """Create a new device execution stream."""
        return ms.runtime.Stream()

    def get_stream_context(self):
        """Return the stream context-manager class."""
        return ms.runtime.StreamCtx

    @staticmethod
    def all_gather_object(object_list, obj, group=None) -> None:
        """
        Gathers objects from the given group into object list.

        Args:
            object_list (list[Any]): Define the output list, which size equal to the size of group.
            obj (Any): The object on current rank and in given process group.
            group (ProcessGroup, optional): The process group to gather obj. Default is ``None``, and ``None`` means
                global group.

        Returns:
            None. Objs are gathered into ``object_list``.
        """
        dist.all_gather_object(object_list, obj, group)

    @staticmethod
    def init_process_group(
            backend: str = None,
            *,
            init_method: Optional[str] = None,
            timeout: Optional[timedelta] = None,
            world_size: int = -1,
            rank: int = -1,
            store: TCPStore = None,
            pg_options=None,
            device_id=None
    ) -> None:
        """
        Initialize global process group.

        Args:
            backend (str): The backend used to init process group. Default is ``"hccl"`` and now only support hccl.
            init_method (str, optional): URL specifying how to initialize the process group. Default is ``None``.
            timeout (timedelta, optional): Timeout for API executed. Default is ``None``.
            world_size (int): Number of processes. Default is ``-1``.
            rank (int, optional): Rank of the current process. Default is ``-1``.
            store (Store, optional): An object that stores key/value data, facilitating the exchange of inter-process
                communication addresses and connection information. Default is ``None``. Currently, only the
                ``TCPStore`` type is supported.
            pg_options (ProcessGroupOptions, optional): Reserved parameter. Current not take effect.
            device_id (int, optional): Reserved parameter. Current not take effect.
        """
        if backend is None:
            backend = "hccl"
        dist.init_process_group(backend=backend, init_method=init_method, timeout=timeout, world_size=world_size,
                                rank=rank, store=store, pg_options=pg_options, device_id=device_id)

    @staticmethod
    def destroy_process_group(group: Optional[str] = None) -> None:
        """
        Destroy given process group.

        Args:
            group (str, optional): Specify the group to destroy. Default: ``None`` means ``hccl_world_group``. If group
                is None or "hccl_world_group", destroy global process group and all process groups relative to global
                process group.
        """
        dist.destroy_process_group(group)

    @staticmethod
    def get_process_group_ranks(group: Optional[str] = None) -> list[int]:
        """
        Get all ranks in given process group.

        Args:
            group (str, optional): Specify the process group to work on. Default: ``None`` means ``hccl_world_group``.

        Returns:
            List[int]: List of ranks in given process group.
        """
        return dist.get_process_group_ranks(group)

    @staticmethod
    def get_backend(group: Optional[str] = None) -> str:
        """
        Get the backend of given process group.

        Args:
            group (str, optional): Specify the process group to work on. Default: ``None`` means ``hccl_world_group``.

        Returns:
            str: The backend of the group.
        """
        return dist.get_backend(group)

    @staticmethod
    def split_group(parent_pg: Optional[str] = None,
                    split_ranks: Optional[list] = None,
                    timeout: Optional[timedelta] = None,
                    pg_options: Optional[str] = None,
                    group_desc: Optional[str] = None,
                    ) -> str:
        """
        Create split group for a specific group rank in split_ranks, which group contains current rank id.

        Args:
            parent_pg (str, Optional): A process group which the goal group split from.
            split_ranks (Optional[list]): A list like ``list[list[int]]``.
            timeout (Optional[timedelta]): Timeout for API executed. Default is ``None``.
            pg_options (Optional[str]): Reserved parameter. Current not take effect.
            group_desc (Optional[str]): Description of process group.

        Returns:
            str: The split group name.

        Raises:
            ValueError: If ``split_ranks`` is empty or does not contain the current rank.
        """
        if split_ranks is None or len(split_ranks) == 0:
            raise ValueError("split_ranks cannot be None or empty")

        rank_id = MindSporePlatform.get_rank()
        for split_rank in split_ranks:
            if rank_id in split_rank:
                if pg_options is None:
                    # Derive a deterministic group name from the rank list.
                    hash_str_rank_list = '-'.join([str(rank) for rank in split_rank])
                    pg_options = f"{len(split_rank)}-{hash_str_rank_list}"
                new_group(rank_ids=split_rank, group=pg_options)
                return pg_options
        raise ValueError(f"Split group invalid rank, the Split_ranks {split_ranks} does not contain current rank"
                         f" {rank_id}")

    @staticmethod
    def no_grad():
        """Return a context manager that disables gradient tracking."""
        return _no_grad()

    @staticmethod
    def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
        """Return an uninitialized tensor with the same shape as ``tensor``."""
        return mint.empty_like(tensor, dtype=dtype, device=device, pin_memory=pin_memory)

    def get_current_stream(self):
        """Return the stream currently in use on the device."""
        return ms.runtime.current_stream()

    def new_event(self):
        """Create a new device event for stream synchronization."""
        return ms.runtime.Event()

    def tree_map(self, fn, tree):
        """
        Apply fn to each leaf in a nested structure (list / tuple / dict),
        preserving the original structure.
        """
        if isinstance(tree, dict):
            # Preserve the concrete dict subtype (e.g. OrderedDict).
            return type(tree)(
                (k, self.tree_map(fn, v)) for k, v in tree.items()
            )

        if isinstance(tree, tuple):
            return tuple(self.tree_map(fn, v) for v in tree)

        if isinstance(tree, list):
            return [self.tree_map(fn, v) for v in tree]

        # leaf
        return fn(tree)

    @staticmethod
    def register_forward_pre_hook(module, hook, prepend=False, with_kwargs=False):
        """Register a forward pre-hook on ``module`` (``prepend`` is ignored on MindSpore)."""
        return module.register_forward_pre_hook(hook, with_kwargs)

    @staticmethod
    def register_full_backward_hook(module, hook, prepend=False):
        """Register a backward hook on ``module`` (``prepend`` is ignored on MindSpore)."""
        return module.register_backward_hook(hook)

    @staticmethod
    def register_full_backward_pre_hook(module, hook, prepend=False):
        """Register a backward pre-hook on ``module`` (``prepend`` is ignored on MindSpore)."""
        return module.register_backward_pre_hook(hook)

    @property
    def checkpoint(self):
        """Return the activation-recompute entry point (``ms.recompute``)."""
        return ms.recompute

    @staticmethod
    def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
        """Not available on MindSpore."""
        raise NotImplementedError("ckpt_wrapper is not supported on MindSpore platform")

    @property
    def noop_context_fn(self):
        """Not available on MindSpore."""
        raise NotImplementedError("noop_context_fn is not supported on MindSpore platform")

    @staticmethod
    def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
        """Not available on MindSpore."""
        raise NotImplementedError("create_selective_checkpoint_contexts is not supported on MindSpore platform")

    @staticmethod
    def async_save_on_cpu(policy_fn=None):
        """Not available on MindSpore."""
        raise NotImplementedError("async_save_on_cpu is not supported on MindSpore platform")

    @staticmethod
    def tensor_to_numpy(tensor) -> np.ndarray:
        """Convert MindSpore tensor to numpy array."""
        return tensor.asnumpy()