Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/platform.py: 44%

548 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""MindSpore platform api""" 

16from datetime import timedelta 

17from typing import Any, Optional, Union 

18import dataclasses 

19from collections import OrderedDict 

20 

21import contextlib 

22import numpy as np 

23import mindspore as ms 

24import mindspore.common.dtype as mstype 

25from mindspore.mint.distributed import TCPStore 

26 

27from mindspore.nn import Cell 

28from mindspore import mint 

29from mindspore.common.api import _no_grad 

30from mindspore.common._grad_function import _Function 

31from mindspore.common.dtype import type_size_in_bytes 

32from mindspore.common.parameter import Parameter 

33from mindspore.common.tensor import Tensor 

34from mindspore.common.initializer import initializer 

35from mindspore.common.recompute import null_context_fn 

36from mindspore.communication import GlobalComm 

37from mindspore.communication import get_group_size 

38from mindspore.communication import create_group as new_group 

39from mindspore.communication import get_rank as get_rank_id 

40from mindspore.ops import communication as ops_comm 

41from mindspore.ops.function import comm_func 

42from mindspore._c_expression import TensorTransform 

43import mindspore.mint.distributed as dist 

44 

45from hyper_parallel.platform.platform import Platform, PlatformType, EXISTING_COMM_GROUPS 

46from hyper_parallel.platform.mindspore.dtensor import DTensorBase 

47from hyper_parallel.platform.mindspore.pipeline_parallel.stage import PipelineStageBase 

48from hyper_parallel.platform.mindspore.parameter_init import init_parameters as _init_parameters 

49from hyper_parallel.platform.mindspore.init_weights import ( 

50 init_on_device as _init_on_device, 

51 _install_cell_to_empty_patch, 

52) 

53 

54comm_func.set_comm_ops_inplace(False) 

55_tensor_transform = TensorTransform.get_instance() 

56 

57 

58# pylint: disable=C0103 

59 

60 

61def _a2a_reconstruct_ms(out_perm: Tensor, concat_dim: int) -> Tensor: 

62 """Reconstruct A2A result from raw out_perm buffer.""" 

63 new_ndim = out_perm.dim() 

64 chunk_in_perm = concat_dim + 1 

65 recon_perm = list(range(1, chunk_in_perm)) + [0] + list(range(chunk_in_perm, new_ndim)) 

66 x_recon = out_perm.permute(recon_perm).contiguous() 

67 shape = list(x_recon.shape) 

68 merged = shape[concat_dim] * shape[concat_dim + 1] 

69 return x_recon.reshape(shape[:concat_dim] + [merged] + shape[concat_dim + 2:]) 
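# Shape walk-through (illustrative values, not from the measured source): with
# ``out_perm`` of shape (W, B, S/W, H) and ``concat_dim=1``, the permute above
# yields (B, W, S/W, H) and the reshape merges W and S/W back into the full
# sequence axis, e.g. (2, 4, 8, 16) with concat_dim=1 -> (4, 16, 16).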

70 

71 

72def _normalize_all_to_all_single_result(result, output: Tensor) -> tuple[Tensor, object]: 

73 """Normalize MindSpore all_to_all_single return values to ``(output, handle)``.""" 

74 if isinstance(result, tuple): 

75 if len(result) != 2: 

76 raise ValueError( 

77 "mindspore all_to_all_single returned an unexpected tuple " 

78 f"with length {len(result)}" 

79 ) 

80 return result 

81 return output, result 

82 

83 

84def _mindspore_all_to_all_single(input_tensor: Tensor, output_shape, group, async_op=False) -> tuple[Tensor, object]: 

85 """Launch MindSpore all_to_all_single and normalize return values.""" 

86 output = mint.empty(tuple(output_shape), dtype=input_tensor.dtype) 

87 result = ops_comm.all_to_all_single(output, input_tensor, group=group, async_op=async_op) 

88 normalized_output, handle = _normalize_all_to_all_single_result(result, output) 

89 if not async_op: 

90 return normalized_output, None 

91 return normalized_output, handle 
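# Minimal usage sketch (assumes ``group`` names an already-created communication
# group; shapes are illustrative). The helper allocates the output buffer itself,
# so callers only pass the expected output shape:
#     out, handle = _mindspore_all_to_all_single(x, list(x.shape), group, async_op=True)
#     ...  # overlap other work here
#     handle.wait()  # handle is None when async_op=False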

92 

93 

94class _MSAsyncA2AFunction(_Function): 

95 """Differentiable wrapper for pre-launched async all-to-all.""" 

96 

97 @staticmethod 

98 def forward(ctx, x, work, out_perm, group, world_size, concat_dim, split_dim, handle_box): # pylint: disable=arguments-differ 

99 """Wait for pre-launched async A2A and return reconstructed output.""" 

100 ctx.group = group 

101 ctx.world_size = world_size 

102 ctx.concat_dim = concat_dim 

103 ctx.split_dim = split_dim 

104 ctx.handle_box = handle_box 

105 ctx.x_shape = tuple(x.shape) 

106 work.wait() 

107 return _a2a_reconstruct_ms(out_perm, concat_dim) 

108 

109 @staticmethod 

110 def backward(ctx, grad_output): 

111 """Launch async head->seq A2A for backward overlap, or return zero grad.""" 

112 if ctx.handle_box is not None: 

113 g = grad_output.contiguous() 

114 shape = list(g.shape) 

115 seq_dim = ctx.concat_dim 

116 s_full = shape[seq_dim] 

117 ndim = len(shape) + 1 

118 x_perm = g.reshape( 

119 shape[:seq_dim] + [ctx.world_size, s_full // ctx.world_size] + shape[seq_dim + 1:] 

120 ).permute( 

121 [seq_dim] + list(range(seq_dim)) + list(range(seq_dim + 1, ndim)) 

122 ).contiguous() 

123 out_perm, work = _mindspore_all_to_all_single( 

124 x_perm, 

125 list(x_perm.shape), 

126 ctx.group, 

127 async_op=True, 

128 ) 

129 ctx.handle_box.append((work, out_perm)) 

130 return mint.zeros(ctx.x_shape, dtype=grad_output.dtype), None, None, None, None, None, None, None 
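# Note on the handle_box protocol as implemented above: ``forward`` only waits on an
# all-to-all that the caller has already launched, while ``backward`` re-launches the
# head->seq all-to-all asynchronously and stashes ``(work, out_perm)`` in ``handle_box``
# rather than returning the real gradient; the zeros returned for ``x`` act as a
# placeholder, so the caller presumably completes the overlapped transfer from
# ``handle_box``.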

131 

132 

133class MindSporePlatform(Platform): 

134 """MindSpore platform api""" 

135 Tensor = Tensor 

136 tensor = Tensor 

137 Parameter = Parameter 

138 Module = Cell 

139 DTensorBase = DTensorBase 

140 PipelineStageBase = PipelineStageBase 

141 platform_type = PlatformType.MINDSPORE 

142 tensor_dtype = mstype 

143 dtype = ms.Type 

144 Function = _Function 

145 

146 def __init__(self): 

147 # Ensure MindSpore ``nn.Cell.to_empty`` is patched as soon as the 

148 # MindSpore platform instance is created. 

149 _install_cell_to_empty_patch() 

150 

151 @staticmethod 

152 def is_linear_module(module) -> bool: 

153 """Check whether *module* is a MindSpore ``Dense`` (linear) or ``mint.nn.Linear`` layer.""" 

154 return isinstance(module, (ms.nn.Dense, mint.nn.Linear)) 

155 

156 @staticmethod 

157 def is_embedding_module(module) -> bool: 

158 """Check whether *module* is a MindSpore ``Embedding`` or ``mint.nn.Embedding`` layer.""" 

159 return isinstance(module, (ms.nn.Embedding, mint.nn.Embedding)) 

160 

161 def device_count(self, device_handle): 

162 """ 

163 Get the number of available devices. 

164 

165 Args: 

166 device_handle: The device handle (e.g., ms.device_context). 

167 

168 Returns: 

169 int: The number of available devices. 

170 """ 

171 device_type = self.device_type() 

172 if device_type == "cpu": 

173 return device_handle.device_context.cpu.device_count() 

174 if device_type == "gpu": 

175 return device_handle.device_context.gpu.device_count() 

176 return device_handle.device_context.ascend.device_count() 

177 

178 @staticmethod 

179 def get_rng_state(device=None, device_handle=None): 

180 """ 

181 Get the random number generator state. 

182 

183 Args: 

184 device (Optional): The device to get RNG state from (not used in MindSpore). 

185 device_handle (Optional): The device handle (not used in MindSpore). 

186 

187 Returns: 

188 Tensor: The RNG state as a tensor. 

189 """ 

190 _ = device, device_handle 

191 return ms.get_rng_state() 

192 

193 @staticmethod 

194 def set_rng_state(state, device=None, device_handle=None): 

195 """ 

196 Set the random number generator state. 

197 

198 Args: 

199 state (Tensor): The RNG state to set. 

200 device (Optional): The device to set RNG state for (not used in MindSpore). 

201 device_handle (Optional): The device handle (not used in MindSpore). 

202 """ 

203 _ = device, device_handle 

204 return ms.set_rng_state(state) 

205 

206 def device_type(self): 

207 """ 

208 Get the current device type. 

209 

210 Returns: 

211 str: The device type string ("npu" for Ascend, "gpu" for GPU, "cpu" for CPU). 

212 """ 

213 device_type = ms.get_context("device_target") 

214 if device_type == "Ascend": 

215 return "npu" 

216 return device_type.lower() 

217 

218 def device(self, device_idx=None): 

219 """ 

220 Get the device type string. 

221 

222 Args: 

223 device_idx (Optional[int]): The device index (not used in MindSpore). 

224 

225 Returns: 

226 str: The device type string. 

227 """ 

228 _ = device_idx 

229 device_type = self.device_type() 

230 return device_type 

231 

232 @staticmethod 

233 def get_device_handle(): 

234 """ 

235 Get the MindSpore module as the device handle. 

236 

237 Returns: 

238 module: The mindspore module. 

239 """ 

240 return ms 

241 

242 @staticmethod 

243 def manual_seed(seed): 

244 """ 

245 Set the random seed for reproducibility. 

246 

247 Args: 

248 seed (int): The random seed value. 

249 

250 Returns: 

251 None 

252 """ 

253 return ms.manual_seed(seed) 

254 

255 @staticmethod 

256 def ones(size, dtype=None): 

257 """ 

258 Create a tensor filled with ones. 

259 

260 Args: 

261 size (tuple): The shape of the output tensor. 

262 dtype (Optional[ms.Type]): The desired data type. 

263 

264 Returns: 

265 Tensor: A tensor filled with ones. 

266 """ 

267 return mint.ones(size, dtype=dtype) 

268 

269 @staticmethod 

270 def zeros(size, dtype=None, device=None): 

271 """ 

272 Create a tensor filled with zeros. 

273 

274 Args: 

275 size (tuple): The shape of the output tensor. 

276 dtype (Optional[ms.Type]): The desired data type. 

277 device (Optional[ms.device]): The device to create the tensor on. 

278 

279 Returns: 

280 Tensor: A tensor filled with zeros. 

281 """ 

282 tensor = mint.zeros(size, dtype=dtype) 

283 if device in ("GPU", "Ascend"): 

284 return tensor.to(device) 

285 return tensor 

286 

287 @staticmethod 

288 def full(size, fill_value, dtype=None): 

289 """ 

290 Create a tensor filled with a scalar value. 

291 

292 Args: 

293 size (tuple): The shape of the output tensor. 

294 fill_value (scalar): The value to fill the tensor with. 

295 dtype (Optional[ms.Type]): The desired data type. 

296 

297 Returns: 

298 Tensor: A tensor filled with the specified value. 

299 """ 

300 return mint.full(size, fill_value, dtype=dtype) 

301 

302 @staticmethod 

303 def empty(size, dtype=None): 

304 """ 

305 Create an uninitialized tensor. 

306 

307 Args: 

308 size (tuple): The shape of the output tensor. 

309 dtype (Optional[ms.Type]): The desired data type. 

310 

311 Returns: 

312 Tensor: An uninitialized tensor. 

313 """ 

314 return mint.empty(size, dtype=dtype) 

315 

316 @staticmethod 

317 def get_rank(): 

318 """ 

319 Get the rank of the current process in the distributed group. 

320 

321 Returns: 

322 int: The rank of the current process. 

323 """ 

324 return get_rank_id() 

325 

326 @staticmethod 

327 def get_global_rank(group, group_rank): 

328 """ 

329 Get the global rank from a group rank. 

330 

331 Args: 

332 group (str): The process group name. 

333 group_rank (int): The rank within the group. 

334 

335 Returns: 

336 int: The global rank. 

337 """ 

338 return dist.get_global_rank(group, group_rank) 

339 

340 @staticmethod 

341 def get_world_size(): 

342 """ 

343 Get the total number of processes in the distributed group. 

344 

345 Returns: 

346 int: The world size. 

347 """ 

348 return get_group_size() 

349 

350 @staticmethod 

351 def get_op_name(func): 

352 """ 

353 Extract the operation name from a function. 

354 

355 Args: 

356 func: The function to extract the name from. 

357 

358 Returns: 

359 str: The operation name. 

360 """ 

361 return func.name 

362 

363 @staticmethod 

364 def differentiable_all_gather_concat(data, group, concat_size, concat_dim): 

365 output, _ = comm_func.all_gather_into_tensor(None, data, group=group) 

366 if concat_dim == 0: 

367 return output 

368 output_tensors = ms.ops.Split(output_num=concat_size)(output) 

369 return ms.mint.concat(output_tensors, concat_dim) 
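# Illustrative shapes (assumed, not from the source): gathering a (B, S, H) shard
# across W devices yields (W*B, S, H) along dim 0; for ``concat_dim=2`` the split
# and concat above rearrange this into (B, S, W*H), i.e. the gathered data is moved
# onto the requested axis.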

370 

371 @staticmethod 

372 def chunk(data, split_dim, split_size, index): 

373 return ms.ops.Split(axis=split_dim, output_num=split_size)(data)[index] 

374 

375 @staticmethod 

376 def differentiable_all_to_all(input_data, output_shape, group): 

377 output_tensor, _ = comm_func.all_to_all_single( 

378 output_shape, 

379 input_data, 

380 group=group, 

381 async_op=False 

382 ) 

383 return output_tensor 

384 

385 @staticmethod 

386 def tensor_type_cast(input_data, cast_type): 

387 """Cast tensor to specified data type.""" 

388 type_mapping = { 

389 'float32': ms.float32, 

390 'float16': ms.float16, 

391 'int64': ms.int64, 

392 'int32': ms.int32 

393 } 

394 if cast_type not in type_mapping: 

395 raise ValueError(f"Unknown cast type: {cast_type}. Supported types: {list(type_mapping.keys())}") 

396 return input_data.to(type_mapping[cast_type]) 
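# For instance, assuming ``x`` is a MindSpore tensor:
#     MindSporePlatform.tensor_type_cast(x, 'float16')  # returns x cast to ms.float16
#     MindSporePlatform.tensor_type_cast(x, 'uint8')    # raises ValueError (key not in mapping)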

397 

398 @staticmethod 

399 def differentiable_all_reduce(data, op, group): 

400 output, _ = comm_func.all_reduce(data, op, group) 

401 return output 

402 

403 @staticmethod 

404 def differentiable_reduce_scatter(data, dev_num, axis, op, group): 

405 if axis > 0: 

406 data = ms.mint.concat(ms.ops.Split(axis=axis, output_num=dev_num)(data), dim=0) 

407 output_tensor, _ = comm_func.reduce_scatter_tensor(None, data, 'sum', group) 

408 if op == 'avg': 

409 output_tensor = output_tensor / dev_num 

410 return output_tensor 

411 

412 @staticmethod 

413 def init_parameters(module, stage_index): 

414 return _init_parameters(module, stage_index) 

415 

416 # pylint: disable=W0212 

417 @staticmethod 

418 def update_param_data(param, data): 

419 """update param data""" 

420 if isinstance(param, DTensorBase): 

421 param.set_data(data) 

422 else: 

423 param._update_data(data) 

424 

425 @staticmethod 

426 def load_into_param(param, data): 

427 copy_tensor = MindSporePlatform.empty_like(data) 

428 copy_tensor.copy_(data) 

429 if isinstance(param, DTensorBase): 

430 param.set_data(copy_tensor) 

431 else: 

432 param._update(copy_tensor) 

433 

434 @staticmethod 

435 def get_cell_construct(cell): 

436 return cell.construct 

437 

438 @staticmethod 

439 def get_cells_and_names(cell): 

440 return cell.cells_and_names() 

441 

442 @staticmethod 

443 def search_parameter_by_name(cell, param_name: str): 

444 """ 

445 Find the parent Module of the parameter, the parameter's name in the parent Module, and the parameter. 

446 Return value: (parent Module instance, parameter's name in parent Module, parameter object). 

447 Returns None if not found. 

448 """ 

449 # Remove the "self." prefix from param_name (to maintain compatibility with original logic) 

450 param_name = param_name.replace("self.", "") 

451 # Case 1: The parameter is a direct parameter of the current Module (not in any sub-Module) 

452 if param_name in cell._params: 

453 return (cell, param_name, cell._params[param_name]) 

454 

455 # Case 2: The parameter is in a sub-Module (supports multi-level nesting, e.g., "net_b.dense1.weight") 

456 if "." in param_name: 

457 # Split into: sub-Module path + parameter name (e.g., "net_b.dense1" + "weight") 

458 cell_path, param_key = param_name.rsplit(".", 1) 

459 try: 

460 # Locate the sub-Module where the parameter resides (supports multi-level paths) 

461 target_cell = cell.get_sub_cell(cell_path) 

462 # Check if the sub-Module directly contains this parameter 

463 if param_key in target_cell._params: 

464 return target_cell, param_key, target_cell._params[param_key] 

465 except AttributeError: 

466 # Sub-Module path does not exist or the parameter is not in that sub-Module 

467 pass 

468 

469 # Traverse all sub-Modules (recursively) to search for the parameter 

470 for _, child_cell in cell._cells.items(): 

471 if isinstance(child_cell, Cell): 

472 # Recursively search within the sub-Module 

473 result = MindSporePlatform.search_parameter_by_name(child_cell, param_name) 

474 if result is not None: 

475 return result 

476 

477 return None 

478 

479 @staticmethod 

480 def update_parameter_by_name(cell, result: tuple, new_param) -> bool: 

481 """ 

482 Modify the original parameter in a Module or sub-Module using the search result.

483 Args:

484 cell: The Module whose parameter is to be updated.

485 result: A tuple containing the parent Module, the parameter key, and the old parameter.

486 new_param: New Parameter object (used to replace the original parameter) 

487 """ 

488 parent_cell, param_key, _ = result 

489 # Key operation: directly modify the _params dictionary of the parent Module (original storage location) 

490 parent_cell._params[param_key] = new_param 

491 

492 if param_key in parent_cell.__dict__: 

493 parent_cell.__dict__[param_key] = new_param 

494 parent_cell._params_list[param_key] = new_param 

495 return True 
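# Hypothetical usage sketch (``net``, the parameter name and ``new_param`` are
# assumptions, not from the source):
#     result = MindSporePlatform.search_parameter_by_name(net, "net_b.dense1.weight")
#     if result is not None:
#         MindSporePlatform.update_parameter_by_name(net, result, new_param)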

496 

497 @staticmethod 

498 def set_layout_into_parameter(param, layout): 

499 """Set layout in to parameter""" 

500 from hyper_parallel.core.dtensor.dtensor import DTensor # pylint: disable=import-outside-toplevel 

501 from hyper_parallel.core.dtensor.layout import _infer_slice_shape_by_layout, \ 

502 _get_slice_tensor_by_layout # pylint: disable=import-outside-toplevel 

503 if isinstance(param, DTensor): 

504 raise ValueError(f"Parameter {param.name} has been configured layout, cannot be set repeatedly.") 

505 param_info = param.param_info 

506 requires_grad = param.requires_grad 

507 name = param.name 

508 slice_shape = _infer_slice_shape_by_layout(param.shape, layout) 

509 

510 if not param.has_init: 

511 # already initialized: slice the existing data according to the layout

512 param_dtensor = DTensor.from_local( 

513 _get_slice_tensor_by_layout(param, layout).value(), layout.mesh, layout.alias_placements 

514 ) 

515 param = Parameter(param_dtensor, name=name, requires_grad=requires_grad) 

516 param.param_info = param_info 

517 else: 

518 # not yet initialized: adjust the lazy-init shape to the sliced shape

519 param.init_mode.shape = slice_shape 

520 param_dtensor = DTensor.from_local(param.init_mode, layout.mesh, layout.alias_placements) 

521 param = Parameter(param_dtensor, name=name, requires_grad=requires_grad) 

522 param.param_info = param_info 

523 return param 

524 

525 @staticmethod 

526 def get_param_local_shape(param): 

527 """get param local shape""" 

528 if isinstance(param, DTensorBase): 

529 return param.local_shape 

530 return param.shape 

531 

532 @staticmethod 

533 def get_param_local_data(param): 

534 """get param local shape""" 

535 if isinstance(param, DTensorBase): 

536 return param.to_local() 

537 return param 

538 

539 @staticmethod 

540 def get_param_type_size(param): 

541 return type_size_in_bytes(param.dtype) 

542 

543 @staticmethod 

544 def is_tensor(obj: Any) -> bool: 

545 """Return True if ``obj`` is a ``mindspore.Tensor``.""" 

546 return isinstance(obj, Tensor) 

547 

548 @staticmethod 

549 def get_tensor_storage_size(tensor: Any) -> int: 

550 """Return serialized byte size (numel * itemsize) for a MindSpore tensor.""" 

551 if not MindSporePlatform.is_tensor(tensor): 

552 raise TypeError( 

553 f"MindSporePlatform.get_tensor_storage_size expects mindspore.Tensor, got {type(tensor)!r}" 

554 ) 

555 return int(tensor.numel()) * int(tensor.itemsize) 
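# For example, a float32 tensor of shape (2, 3) has 6 elements of 4 bytes each,
# so the helper above returns 24:
#     MindSporePlatform.get_tensor_storage_size(mint.ones((2, 3), dtype=ms.float32))  # -> 24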

556 

557 @staticmethod 

558 def new_zero_parameter(param_shape, param_type, requires_grad, device): 

559 param = Parameter(initializer("zeros", param_shape, param_type), requires_grad=requires_grad) 

560 if device in ("GPU", "Ascend"): 

561 return param.to(device) 

562 return param 

563 

564 @staticmethod 

565 def new_tensor(tensor_shape, tensor_type, device): 

566 tensor = Tensor(shape=tensor_shape, dtype=tensor_type) 

567 if device in ("GPU", "Ascend"): 

568 return tensor.to(device) 

569 return tensor 

570 

571 @staticmethod 

572 def full_like(tensor, fill_value, dtype=None): 

573 return mint.full_like(tensor, fill_value, dtype=dtype) 

574 

575 @staticmethod 

576 def isend(tensor, dst=None, group=None, tag=0): 

577 return dist.isend(tensor, dst, group, tag) 

578 

579 @staticmethod 

580 def irecv(tensor, src=None, group=None, tag=0): 

581 return dist.irecv(tensor, src, group, tag) 

582 

583 @staticmethod 

584 def p2p_exchange(tensor, peer_rank: int, group=None): # pylint: disable=unused-argument 

585 raise NotImplementedError( 

586 "p2p_exchange is not yet supported on the MindSpore platform." 

587 ) 

588 

589 @staticmethod 

590 def send_object_list(obj_list, dst=None, group=None): 

591 # pylint: disable=C0415 

592 from hyper_parallel.platform.mindspore.pipeline_parallel._utils import send_object_list 

593 send_object_list(obj_list, dst, group) 

594 

595 @staticmethod 

596 def recv_object_list(obj_list, src=None, group=None): 

597 # pylint: disable=C0415 

598 from hyper_parallel.platform.mindspore.pipeline_parallel._utils import recv_object_list 

599 recv_object_list(obj_list, src, group) 

600 

601 @staticmethod 

602 def set_tensor_requires_grad(input_tensor): 

603 """ 

604 set requires grad flag for input tensor 

605 """ 

606 input_tensor.requires_grad_() 

607 

608 def _create_group(self, rank_list): 

609 world_group = self._maybe_reuse_world_group(rank_list) 

610 if world_group is not None: 

611 return world_group 

612 

613 group_name = str(tuple(sorted(rank_list))) 

614 new_group(rank_ids=rank_list, group=group_name) 

615 EXISTING_COMM_GROUPS[group_name] = group_name 

616 return group_name 
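# Illustrative naming, derived from the code above: a rank list such as [2, 0] is
# normalized to the group name "(0, 2)", while a list covering every rank falls back
# to the default world group via _maybe_reuse_world_group.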

617 

618 @staticmethod 

619 def all_gather_into_tensor(data, group_info, async_op=False): 

620 return comm_func.all_gather_into_tensor(None, data, group=group_info.group_name, async_op=async_op) 

621 

622 @staticmethod 

623 def all_reduce(data, group_info, async_op=False): 

624 if isinstance(group_info, str): 

625 handle = dist.all_reduce(data, group=group_info, async_op=async_op) 

626 else: 

627 handle = dist.all_reduce(data, group=group_info.group_name, async_op=async_op) 

628 return data, handle 

629 

630 @staticmethod 

631 def broadcast(data, src, group=None, async_op=False): 

632 handle = dist.broadcast(data, src, group, async_op) 

633 if async_op: 

634 handle.wait() 

635 return data 

636 

637 @staticmethod 

638 def reduce_scatter_tensor(data, group_info, async_op=False): 

639 return comm_func.reduce_scatter_tensor(None, data, group=group_info.group_name, async_op=async_op) 

640 

641 @staticmethod 

642 def all_to_all_single(input_tensor, output_shape, group, async_op=False): 

643 return _mindspore_all_to_all_single(input_tensor, output_shape, group, async_op=async_op) 

644 

645 @staticmethod 

646 def differentiable_async_a2a_wait(x, work, out_perm, group, world_size, concat_dim, split_dim, # pylint: disable=unused-argument 

647 handle_box=None): 

648 return _MSAsyncA2AFunction.apply( 

649 x, work, out_perm, group, world_size, concat_dim, split_dim, handle_box 

650 ) 

651 

652 @staticmethod 

653 def parameters_dict(cell: Cell): 

654 return cell.parameters_and_names() 

655 

656 @staticmethod 

657 def get_tensor_transform(): 

658 return _tensor_transform 

659 

660 @staticmethod 

661 def construct_strided_slice(x, begin, end, stride): 

662 return ms.ops.strided_slice(x, begin, end, stride) 

663 

664 @staticmethod 

665 def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None): 

666 # pylint: disable=C0415 

667 from hyper_parallel.platform.mindspore.pipeline_parallel._utils import _MicroBatch 

668 return _MicroBatch(micro_batch_num, args_batch_dim, kwargs_batch_dim) 

669 

670 @staticmethod 

671 def get_model_state_dict(model, *, options=None): 

672 raise NotImplementedError( 

673 "get_model_state_dict is not yet supported on MindSpore" 

674 ) 

675 

676 @staticmethod 

677 def save_checkpoint(cell: Union[Cell, dict], file_path: str, ckpt_format: str = "safetensors") -> None: 

678 if isinstance(cell, dict): 

679 save_dict = {} 

680 for k, v in cell.items(): 

681 if isinstance(v, Parameter): 

682 save_dict[k] = v 

683 elif isinstance(v, Tensor): 

684 save_dict[k] = Parameter(v, name=k) 

685 else: 

686 save_dict[k] = v 

687 else: 

688 save_dict = cell._params 

689 ms.save_checkpoint(save_obj=save_dict, ckpt_file_name=file_path, format=ckpt_format) 

690 

691 @staticmethod 

692 def load_checkpoint(file_path: str, ckpt_format: str = "safetensors") -> dict: 

693 return ms.load_checkpoint(ckpt_file_name=file_path, format=ckpt_format) 

694 

695 @staticmethod 

696 def get_symmetric_memory_handler(): 

697 # pylint: disable=C0415 

698 from hyper_parallel.platform.mindspore.symmetric_memory import MSSymmetricMemoryHandler 

699 symmetric_memory = MSSymmetricMemoryHandler() 

700 return symmetric_memory 

701 

702 @staticmethod 

703 def get_multicore_handler(): 

704 # pylint: disable=C0415 

705 from hyper_parallel.platform.mindspore.multicore import MSMulticoreHandler 

706 return MSMulticoreHandler() 

707 

708 def new_stream(self): 

709 return ms.runtime.Stream() 

710 

711 def get_stream_context(self): 

712 return ms.runtime.StreamCtx 

713 

714 @staticmethod 

715 def all_gather_object(object_list, obj, group=None) -> None: 

716 """ 

717 Gathers objects from the given group into object list. 

718 

719 Args: 

720 object_list (list[Any]): The output list; its size must equal the size of the group.

721 obj (Any): The object on current rank and in given process group. 

722 group (ProcessGroup, optional): The process group to gather obj. Default is ``None``, and ``None`` means 

723 global group. 

724 

725 Returns: 

726 None. Objects are gathered into ``object_list``.

727 """ 

728 dist.all_gather_object(object_list, obj, group) 

729 

730 @staticmethod 

731 def barrier(group=None, async_op: bool = False, device_ids=None) -> Any: 

732 """ 

733 Synchronize all processes in the given communication group. 

734 

735 Args: 

736 group (str, optional): The communication group to work on. Default is ``None``, 

737 meaning the default world group. 

738 async_op (bool, optional): Whether this op should be asynchronous. Default: ``False``. 

739 device_ids (list[int], optional): Reserved parameter on Ascend. Default: ``None``. 

740 

741 Returns: 

742 CommHandle if ``async_op`` is True; otherwise ``None``. 

743 """ 

744 return dist.barrier(group, async_op, device_ids) 

745 

746 @staticmethod 

747 def init_process_group( 

748 backend: str = None, 

749 *, 

750 init_method: Optional[str] = None, 

751 timeout: Optional[timedelta] = None, 

752 world_size: int = -1, 

753 rank: int = -1, 

754 store: TCPStore = None, 

755 pg_options=None, 

756 device_id=None 

757 ) -> None: 

758 """ 

759 Initialize global process group. 

760 

761 Args: 

762 backend (str): The backend used to initialize the process group. Default is ``"hccl"``, which is currently the only supported backend.

763 init_method (str, optional): URL specifying how to initialize the process group. Default is ``None``.

764 timeout (timedelta, optional): Timeout for the API execution. Default is ``None``.

765 world_size (int): Number of processes. Default is ``-1``. 

766 rank (int, optional): Rank of the current process. Default is ``-1``. 

767 store (Store, optional): An object that stores key/value data, facilitating the exchange of inter-process 

768 communication addresses and connection information. Default is ``None``. Currently, only the 

769 ``TCPStore`` type is supported. 

770 pg_options (ProcessGroupOptions, optional): Reserved parameter; currently has no effect.

771 device_id (int, optional): Reserved parameter; currently has no effect.

772 """ 

773 if backend is None: 

774 backend = "hccl" 

775 try: 

776 if dist.is_initialized(): 

777 return 

778 except AttributeError: 

779 pass 

780 dist.init_process_group(backend=backend, init_method=init_method, timeout=timeout, world_size=world_size, 

781 rank=rank, store=store, pg_options=pg_options, device_id=device_id) 

782 

783 @staticmethod 

784 def destroy_process_group(group: Optional[str] = None) -> None: 

785 """ 

786 Destroy given process group. 

787 

788 Args: 

789 group (str, optional): The group to destroy. Default: ``None``, meaning ``hccl_world_group``. If group is

790 None or "hccl_world_group", the global process group and all process groups derived from it are

791 destroyed.

792 """ 

793 if group in EXISTING_COMM_GROUPS.values(): 

794 keys_to_destroy = [k for k, v in EXISTING_COMM_GROUPS.items() if v == group] 

795 for k in keys_to_destroy: 

796 del EXISTING_COMM_GROUPS[k] 

797 dist.destroy_process_group(group) 

798 

799 @staticmethod 

800 def get_process_group_ranks(group: Optional[str] = None) -> list[int]: 

801 """ 

802 Get all ranks in given process group. 

803 

804 Args: 

805 group (str, optional): Specify the process group to work on. Default: ``None`` means ``hccl_world_group``. 

806 

807 Returns: 

808 List[int]: List of ranks in given process group. 

809 """ 

810 return dist.get_process_group_ranks(group) 

811 

812 @staticmethod 

813 def get_backend(group: Optional[str] = None) -> str: 

814 """ 

815 Get the backend of given process group. 

816 

817 Args: 

818 group (str, optional): Specify the process group to work on. Default: ``None`` means ``hccl_world_group``. 

819 

820 Returns: 

821 str: The backend of the group. 

822 """ 

823 return dist.get_backend(group) 

824 

825 @staticmethod 

826 def split_group(parent_pg: Optional[str] = None, 

827 split_ranks: Optional[list] = None, 

828 timeout: Optional[timedelta] = None, 

829 pg_options: Optional[str] = None, 

830 group_desc: Optional[str] = None, 

831 ) -> str: 

832 """ 

833 Create a split group from the entry in ``split_ranks`` that contains the current rank id.

834 

835 Args: 

836 parent_pg (str, optional): The parent process group from which the new group is split.

837 split_ranks (Optional[list]): A nested rank list (``list[list[int]]``); each inner list defines one split group.

838 timeout (Optional[timedelta]): Timeout for the API execution. Default is ``None``.

839 pg_options (Optional[str]): Reserved parameter; currently has no effect.

840 group_desc (Optional[str]): Description of process group. 

841 

842 Returns: 

843 str: The split group name. 

844 """ 

845 if split_ranks is None or len(split_ranks) == 0: 

846 raise ValueError("split_ranks cannot be None or empty") 

847 

848 rank_id = MindSporePlatform.get_rank() 

849 for split_rank in split_ranks: 

850 if rank_id in split_rank: 

851 world_group = MindSporePlatform._maybe_reuse_world_group(split_rank) 

852 if world_group is not None: 

853 return world_group 

854 split_group = MindSporePlatform.get_created_group(split_rank) 

855 if split_group: 

856 return split_group 

857 group_name = str(tuple(sorted(split_rank))) 

858 new_group(rank_ids=split_rank, group=group_name) 

859 EXISTING_COMM_GROUPS[group_name] = group_name 

860 return group_name 

861 raise ValueError(f"split_group received an invalid rank: split_ranks {split_ranks} does not contain"

862 f" the current rank {rank_id}")
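# Hypothetical call (rank ids are assumptions): on a 4-rank job, splitting into two
# pairs would look like
#     MindSporePlatform.split_group(split_ranks=[[0, 1], [2, 3]])
# and each process receives the name of the split group that contains its own rank.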

863 

864 @staticmethod 

865 def get_group_local_rank(group=None) -> int: 

866 """get group local rank id.""" 

867 return dist.get_group_rank(group, MindSporePlatform.get_rank()) 

868 

869 @staticmethod 

870 def no_grad(): 

871 return _no_grad() 

872 

873 @staticmethod 

874 def cat(tensors, dim=0): 

875 return mint.cat(tensors, dim=dim) 

876 

877 @staticmethod 

878 def empty_like(tensor, *, dtype=None, device=None, pin_memory=False): 

879 return mint.empty_like(tensor, dtype=dtype, device=device, pin_memory=pin_memory) 

880 

881 def get_current_stream(self): 

882 return ms.runtime.current_stream() 

883 

884 def new_event(self): 

885 return ms.runtime.Event() 

886 

887 def tree_map(self, fn, tree): 

888 """ 

889 Apply fn to each leaf in a nested structure (list / tuple / dict), 

890 preserving the original structure. 

891 """ 

892 if isinstance(tree, dict): 

893 return type(tree)( 

894 (k, self.tree_map(fn, v)) for k, v in tree.items() 

895 ) 

896 

897 if isinstance(tree, tuple): 

898 return tuple(self.tree_map(fn, v) for v in tree) 

899 

900 if isinstance(tree, list): 

901 return [self.tree_map(fn, v) for v in tree] 

902 

903 # leaf 

904 return fn(tree) 
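# Illustrative example (assuming ``platform`` is a MindSporePlatform instance):
#     platform.tree_map(lambda v: v * 2, {"a": [1, 2], "b": (3,)})
#     # -> {"a": [2, 4], "b": (6,)}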

905 

906 @staticmethod 

907 def register_forward_pre_hook(module, hook, prepend=False, with_kwargs=False): 

908 return module.register_forward_pre_hook(hook, with_kwargs=with_kwargs) 

909 

910 @staticmethod 

911 def register_full_backward_hook(module, hook, prepend=False): 

912 return module.register_backward_hook(hook) 

913 

914 @staticmethod 

915 def register_full_backward_pre_hook(module, hook, prepend=False): 

916 return module.register_backward_pre_hook(hook) 

917 

918 @property 

919 def checkpoint(self): 

920 return ms.recompute 

921 

922 @staticmethod 

923 def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs): 

924 # pylint: disable=C0415 

925 from hyper_parallel.platform.mindspore.activation_checkpoint.checkpoint_wrapper import checkpoint_wrapper 

926 return checkpoint_wrapper(module, checkpoint_fn=checkpoint_fn, **checkpoint_fn_kwargs) 

927 

928 @staticmethod 

929 def swap_wrapper(module, policy_fn=None): 

930 # pylint: disable=C0415 

931 from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import swap_wrapper 

932 return swap_wrapper(module, policy_fn=policy_fn) 

933 

934 @property 

935 def noop_context_fn(self): 

936 return null_context_fn 

937 

938 @staticmethod 

939 def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False): 

940 # pylint: disable=C0415 

941 from hyper_parallel.platform.mindspore.activation_checkpoint.sac import create_selective_checkpoint_contexts 

942 return create_selective_checkpoint_contexts(policy_fn_or_list, 

943 allow_cache_entry_mutation=allow_cache_entry_mutation) 

944 

945 @staticmethod 

946 def async_save_on_cpu(policy_fn=None): 

947 # pylint: disable=C0415 

948 from hyper_parallel.platform.mindspore.activation_checkpoint.activation_swap import AsyncSaveOnCpu 

949 return AsyncSaveOnCpu(policy_fn=policy_fn) 

950 

951 @staticmethod 

952 def get_element_size(tensor): 

953 """Get Tensor Element Size""" 

954 return tensor.itemsize 

955 

956 @staticmethod 

957 def tensor_to_numpy(tensor) -> np.ndarray: 

958 """Convert MindSpore tensor to numpy array.""" 

959 return tensor.asnumpy() 

960 

961 @staticmethod 

962 

963 def clip_grad_norm_( 

964 parameters, max_norm, norm_type=2.0, 

965 error_if_nonfinite=False, foreach=None, 

966 ): 

967 raise NotImplementedError( 

968 "clip_grad_norm_ is not yet supported on MindSpore" 

969 ) 

970 

971 @property 

972 def meta_device(self): 

973 return "meta" 

974 

975 def init_on_device(self, device, include_buffers=False): 

976 return _init_on_device(device, include_buffers=include_buffers) 

977 

978 def cast_fp_tensor(self, dtype, x): 

979 """ 

980 Cast floating-point tensor to target dtype if applicable. 

981 """ 

982 if ( 

983 not isinstance(x, ms.Tensor) 

984 or not ms.ops.is_floating_point(x) 

985 or x.dtype == dtype 

986 ): 

987 return x 

988 return x.to(dtype) 

989 

990 def apply_to_tensors(self, fn, container): 

991 """Recursively apply to all tensor in different kinds of container types.""" 

992 

993 def apply(x): 

994 if isinstance(x, ms.Tensor): 

995 return fn(x) 

996 if hasattr(x, "__dataclass_fields__"): 

997 dc = dataclasses.replace(x) 

998 changes = { 

999 f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc) 

1000 } 

1001 return dataclasses.replace(dc, **changes) 

1002 if isinstance(x, OrderedDict): 

1003 od = x.__class__() 

1004 for key, value in x.items(): 

1005 od[key] = apply(value) 

1006 return od 

1007 if isinstance(x, dict): 

1008 return {key: apply(value) for key, value in x.items()} 

1009 if isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields"): 

1010 res = (apply(el) for el in x) 

1011 return type(x)(*res) 

1012 if isinstance(x, (list, tuple, set)): 

1013 return type(x)(apply(el) for el in x) 

1014 return x 

1015 

1016 return apply(container) 
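# Illustrative example (assuming ``platform`` is a MindSporePlatform instance): every
# tensor leaf in a nested container is transformed while other leaves and the
# container types are preserved.
#     batch = {"ids": ms.Tensor([1, 2]), "mask": [ms.Tensor([1.0]), "keep-me"]}
#     batch = platform.apply_to_tensors(lambda t: t.astype(ms.float32), batch)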

1017 

1018 @staticmethod 

1019 def profiler_record(name): 

1020 """Profiler context manager for recording operations using mindspore.profiler.""" 

1021 return contextlib.nullcontext() 

1022 

1023 def str_to_dtype(self, dtype_str: str) -> Any: 

1024 """Resolve checkpoint dtype strings (``mindspore.*`` or short ``str(Tensor.dtype)`` e.g. ``Float32``).""" 

1025 if "." in dtype_str: 

1026 prefix, name = dtype_str.split(".", 1) 

1027 if prefix == "mindspore": 

1028 return getattr(ms, name) 

1029 dtype = getattr(ms, dtype_str.lower(), None) 

1030 if dtype is not None: 

1031 return dtype 

1032 raise ValueError( 

1033 f"Expected dtype string like 'mindspore.float32' or 'Float32', got {dtype_str!r}." 

1034 ) 
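# Both accepted spellings resolve to the same dtype object, e.g. (assuming
# ``platform`` is a MindSporePlatform instance):
#     platform.str_to_dtype("mindspore.float32")  # -> ms.float32
#     platform.str_to_dtype("Float32")            # -> ms.float32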

1035 

1036 def list_to_size(self, size_list: list[int]) -> tuple[int, ...]: 

1037 return tuple(size_list) 

1038 

1039 @staticmethod 

1040 def _maybe_reuse_world_group(rank_list): 

1041 """Reuse the default world group for full-world rank lists.""" 

1042 normalized = tuple(sorted(rank_list)) 

1043 world_ranks = tuple(range(MindSporePlatform.get_world_size())) 

1044 if normalized != world_ranks: 

1045 return None 

1046 

1047 EXISTING_COMM_GROUPS[str(normalized)] = GlobalComm.WORLD_COMM_GROUP 

1048 return GlobalComm.WORLD_COMM_GROUP