Coverage for hyper_parallel / platform / torch / platform.py: 81%

372 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Torch platform api""" 

16from datetime import timedelta 

17from typing import Optional, Any, Union 

18import dataclasses 

19from collections import OrderedDict 

20 

21import numpy as np 

22import torch 

23from torch import nn 

24from torch import Tensor 

25from torch._C._distributed_c10d import Store, ProcessGroup 

26from torch.distributed import Backend 

27from torch.distributed.distributed_c10d import _get_default_group 

28from torch.nn import Parameter, Module 

29from torch.nn.utils.rnn import PackedSequence 

30from torch._ops import OpOverload, OpOverloadPacket 

31from torch.utils.checkpoint import noop_context_fn 

32from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper 

33import torch.distributed.nn.functional as dist_func 

34import torch.distributed as dist 

35from hyper_parallel.platform.torch.dtensor import DTensorBase 

36from hyper_parallel.platform.torch.pipeline_parallel.stage import PipelineStageBase 

37from hyper_parallel.platform.torch.group_utils import create_sub_groups 

38from hyper_parallel.platform.platform import Platform, PlatformType 

39from hyper_parallel.platform.torch.function_override import override_functions 

40 

# Install the torch function overrides once, at import time of this module,
# so every platform API below already sees the patched behaviour.
override_functions()

42 

43# Mapping from string op names to torch.distributed.ReduceOp 

44_OP_MAP = { 

45 'sum': dist.ReduceOp.SUM, 

46 'prod': dist.ReduceOp.PRODUCT, 

47 'max': dist.ReduceOp.MAX, 

48 'min': dist.ReduceOp.MIN, 

49 # convert tensor elements to int32 and use MIN 

50 'all': dist.ReduceOp.MIN, 

51 # 'avg' is typically handled by SUM followed by division in current implementation logic 

52 'avg': dist.ReduceOp.SUM, 

53} 

54 

55# Try to add AVG for 'mean' if supported by current torch version 

56if hasattr(dist.ReduceOp, "AVG"): 

57 _OP_MAP['mean'] = dist.ReduceOp.AVG 

58else: 

59 # Fallback for older torch versions if necessary, though this might require manual division upstream 

60 # Assuming standard behavior where 'mean' implies native AVG support or upstream handling 

61 _OP_MAP['mean'] = dist.ReduceOp.SUM 

62 

63 

# pylint: disable=C0103
class TorchPlatform(Platform):
    """Torch platform api.

    Concrete :class:`Platform` implementation backed by PyTorch and
    ``torch.distributed``.  Collective helpers assume the default process
    group has been initialised (see :meth:`init_process_group`).
    """

    # Re-exported torch types/factories so generic code stays platform-agnostic.
    Tensor = Tensor
    tensor = torch.tensor
    Parameter = Parameter
    Module = Module
    DTensorBase = DTensorBase
    PipelineStageBase = PipelineStageBase
    platform_type = PlatformType.PYTORCH
    # Namespace hosting dtype objects (torch.float32, torch.int64, ...).
    tensor_dtype = torch

    @staticmethod
    def device_count(device_handle):
        """Return the number of devices visible to ``device_handle``."""
        return device_handle.device_count()

    def device_type(self):
        """Return the device type string: ``"npu"`` on Ascend builds, else ``"cuda"``."""
        device_handle = self.get_device_handle()
        # Bug fix: on CUDA-only torch builds ``torch.npu`` does not exist, so a
        # direct ``device_handle == torch.npu`` access raised AttributeError.
        # ``getattr`` with a None default keeps npu detection and is safe on CUDA.
        if device_handle is getattr(torch, "npu", None):
            return "npu"
        return "cuda"

    def device(self, device_idx=None):
        """Build a ``torch.device`` for this platform, optionally at ``device_idx``."""
        device_type = self.device_type()
        if device_idx is None:
            return torch.device(device_type)
        return torch.device(f"{device_type}:{device_idx:d}")

    @staticmethod
    def get_rng_state(device=None, device_handle=None):
        """Return the RNG state: CPU state if no handle, else the handle's state."""
        if device_handle is None:
            return torch.get_rng_state()
        if device is None:
            return device_handle.get_rng_state()
        return device_handle.get_rng_state(device)

    @staticmethod
    def set_rng_state(state, device=None, device_handle=None):
        """Restore an RNG state previously captured by :meth:`get_rng_state`."""
        if device_handle is None:
            return torch.set_rng_state(state)
        if device is None:
            return device_handle.set_rng_state(state)
        return device_handle.set_rng_state(state, device)

    @staticmethod
    def manual_seed(seed):
        """Seed the global torch RNG and return the torch.Generator."""
        return torch.manual_seed(seed)

    @staticmethod
    def ones(size, dtype=None):
        """Return a tensor of ones with the given ``size`` and optional ``dtype``."""
        return torch.ones(size, dtype=dtype)

    @staticmethod
    def zeros(size, dtype=None):
        """Return a tensor of zeros with the given ``size`` and optional ``dtype``."""
        return torch.zeros(size, dtype=dtype)

    @staticmethod
    def full(size, fill_value, dtype=None):
        """Return a tensor of shape ``size`` filled with ``fill_value``."""
        return torch.full(size, fill_value, dtype=dtype)

    @staticmethod
    def empty(size, dtype=None):
        """Return an uninitialised tensor of the given ``size``."""
        return torch.empty(size, dtype=dtype)

    @staticmethod
    def get_rank():
        """Return the rank of the current process in the default group."""
        return dist.get_rank()

    @staticmethod
    def get_global_rank(group, group_rank):
        """Translate a group-local rank into a global rank."""
        return dist.get_global_rank(group, group_rank)

    @staticmethod
    def get_world_size():
        """Return the number of processes in the default group."""
        return dist.get_world_size()

    @staticmethod
    def get_param_local_shape(param):
        """Return the local (per-rank) shape of ``param``.

        For a :class:`DTensorBase` this is the local shard's shape; for a plain
        tensor it is simply ``param.shape``.
        """
        if isinstance(param, DTensorBase):
            return param.local_shape
        return param.shape

    @staticmethod
    def get_param_local_data(param):
        """Return the local (per-rank) data of ``param``.

        For a :class:`DTensorBase` this is the local shard; plain tensors are
        returned unchanged.
        """
        if isinstance(param, DTensorBase):
            return param.to_local()
        return param

    @staticmethod
    def update_param_data(param, data):
        """Replace the underlying data of ``param`` in place."""
        param.data = data

    @staticmethod
    def get_op_name(func):
        """Best-effort extraction of a short operator name from ``func``.

        Handles regular functions, ``OpOverload``/``OpOverloadPacket`` ops and
        built-ins; falls back to ``"unknown_op"`` when nothing matches.
        """
        if hasattr(func, "__name__"):
            return func.__name__
        if isinstance(func, OpOverload):
            # NOTE(review): assumes ``OpOverload.name`` is a string attribute in
            # the torch version in use — confirm, some versions expose a method.
            full_name = func.name
            # "namespace::op.overload" -> "op"
            core_name = full_name.split("::")[-1].split(".")[0]
            return core_name
        if isinstance(func, OpOverloadPacket):
            return func.name.split("::")[-1]
        # Fall back to parsing the repr, e.g. "<built-in function add>".
        func_str = str(func)
        if "built-in function" in func_str:
            return func_str.split()[-1].strip(">")
        if "function" in func_str:
            return func_str.split()[1]
        return "unknown_op"

    @staticmethod
    def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
        """All-gather ``data`` across ``group`` and concatenate along ``concat_dim``.

        ``concat_size`` is unused here (kept for interface parity with other
        platforms). The op is differentiable via ``torch.distributed.nn``.
        """
        output = dist_func.all_gather(data, group=group)
        return torch.cat(output, dim=concat_dim)

    @staticmethod
    def chunk(data, split_dim, split_size, index):
        """Split ``data`` into ``split_size`` chunks on ``split_dim`` and return chunk ``index``."""
        return torch.chunk(data, split_size, dim=split_dim)[index]

    @staticmethod
    def differentiable_all_to_all(input_data, output_shape, group):
        """Differentiable all-to-all: scatter ``input_data`` and return a tensor of ``output_shape``."""
        output_tensor = torch.empty(output_shape, device=input_data.device, dtype=input_data.dtype)
        output_tensor = dist_func.all_to_all_single(
            output_tensor,
            input_data,
            group=group
        )
        return output_tensor

    @staticmethod
    def tensor_type_cast(input_data, cast_type):
        """Cast tensor to specified data type.

        Args:
            input_data (Tensor): Tensor to cast.
            cast_type (str): One of 'float32', 'float16', 'int64', 'int32'.

        Raises:
            ValueError: If ``cast_type`` is not a supported type name.
        """
        type_mapping = {
            'float32': torch.float32,
            'float16': torch.float16,
            'int64': torch.int64,
            'int32': torch.int32
        }
        if cast_type not in type_mapping:
            raise ValueError(f"Unknown cast type: {cast_type}. Supported types: {list(type_mapping.keys())}")
        return input_data.to(type_mapping[cast_type])

    @staticmethod
    def differentiable_all_reduce(data, op, group):
        """Differentiable all-reduce of ``data`` with string or ReduceOp ``op``."""
        # Resolve the op from string to ReduceOp enum if necessary.
        reduce_op = _OP_MAP.get(op, dist.ReduceOp.SUM) if isinstance(op, str) else op
        return dist_func.all_reduce(data, op=reduce_op, group=group)

    @staticmethod
    def get_cell_construct(cell):
        """Return the forward callable of ``cell``."""
        return cell.forward

    @staticmethod
    def get_cells_and_names(cell):
        """Return an iterator over ``(name, module)`` pairs of ``cell``."""
        return cell.named_modules()

    @staticmethod
    def search_parameter_by_name(cell, param_name: str):
        """
        Find the parent Module of the parameter, the parameter's name in the parent Module, and the parameter.
        Return value: (parent Module instance, parameter's name in parent Module, parameter object).
        Returns None if not found.
        """
        # Remove the "self." prefix from param_name
        param_name = param_name.replace("self.", "")
        # Case 1: The parameter is a direct parameter of the current Module
        if param_name in cell._parameters:  # pylint:disable=protected-access
            return (cell, param_name, cell._parameters[param_name])  # pylint:disable=protected-access

        # Case 2: The parameter is in a sub-Module
        if "." in param_name:
            cell_path, param_key = param_name.rsplit(".", 1)
            try:
                # Locate the sub-Module where the parameter resides (supports multi-level paths)
                target_cell = cell.get_submodule(cell_path)
                # Check if the sub-Module directly contains this parameter
                if param_key in target_cell._parameters:  # pylint:disable=protected-access
                    return target_cell, param_key, target_cell._parameters[param_key]  # pylint:disable=protected-access
            except AttributeError:
                pass

        # Traverse all sub-Modules (recursively) to search for the parameter
        for _, child_cell in cell.named_children():
            if isinstance(child_cell, Module):
                result = TorchPlatform.search_parameter_by_name(child_cell, param_name)
                if result is not None:
                    return result

        return None

    @staticmethod
    def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
        """
        Modify the original parameter in a Module or sub-Module using the search result
        returned by :meth:`search_parameter_by_name`. Always returns True.
        """
        parent_cell, param_key, _ = result
        # Key operation: directly modify the _parameters dictionary.
        if param_key in parent_cell._parameters:  # pylint:disable=protected-access
            parent_cell._parameters[param_key] = new_param  # pylint:disable=protected-access
        else:
            parent_cell.register_parameter(param_key, new_param)
        return True

    @staticmethod
    def set_layout_into_parameter(param, layout):
        """Wrap ``param`` in a DTensor parameter sliced according to ``layout``.

        Raises:
            ValueError: If ``param`` is already a DTensor (layout set twice).
        """
        from hyper_parallel.core.dtensor import DTensor  # pylint: disable=import-outside-toplevel
        from hyper_parallel.core.layout import _get_slice_tensor_by_layout  # pylint: disable=import-outside-toplevel
        if isinstance(param, DTensor):
            raise ValueError(f"Parameter {param} has been configured layout, cannot be set repeatedly.")
        requires_grad = param.requires_grad
        param_dtensor = DTensor.from_local(_get_slice_tensor_by_layout(param, layout), layout.mesh, layout.placements)
        new_param = Parameter(param_dtensor, requires_grad=requires_grad)
        return new_param

    @staticmethod
    def differentiable_reduce_scatter(data, dev_num, axis, op, group):
        """Differentiable reduce-scatter of ``data`` split into ``dev_num`` chunks on ``axis``."""
        input_tuple = torch.chunk(data, dev_num, dim=axis)
        output_tensor = torch.empty(input_tuple[0].shape, device=data.device, dtype=data.dtype)

        # Resolve the op from string to ReduceOp enum.
        reduce_op = _OP_MAP.get(op, dist.ReduceOp.SUM) if isinstance(op, str) else op

        output_tensor = dist_func.reduce_scatter(output_tensor, input_tuple, op=reduce_op, group=group)

        # Manual averaging: the string 'avg' maps to SUM in _OP_MAP, so the
        # division by the device count must happen here.
        if op == 'avg':
            output_tensor = output_tensor / dev_num
        return output_tensor

    @staticmethod
    def get_device_handle():
        """Return the device module: ``torch.npu`` when present, else ``torch.cuda``."""
        if hasattr(torch, "npu"):
            return torch.npu
        return torch.cuda

    @staticmethod
    def get_param_type_size(param):
        """Return the element size in bytes for the dtype of ``param``."""
        # pylint: disable=W0212
        return torch._utils._element_size(param.dtype)

    @staticmethod
    def parameters_dict(cell: Module):
        """Return an iterator over ``(name, parameter)`` pairs of ``cell``."""
        return cell.named_parameters()

    @staticmethod
    def save_checkpoint(cell: Module, file_path: str) -> None:
        """Serialise ``cell`` to ``file_path`` with ``torch.save``."""
        torch.save(obj=cell, f=file_path)

    @staticmethod
    def load_checkpoint(file_path: str) -> dict:
        """Load a checkpoint previously written by :meth:`save_checkpoint`.

        NOTE(security): ``torch.load`` without ``weights_only=True`` unpickles
        arbitrary objects — only load checkpoints from trusted sources.
        """
        return torch.load(f=file_path)

    @staticmethod
    def new_zero_parameter(param_shape, param_type, requires_grad, device):
        """Create a zero-initialised ``nn.Parameter`` on ``device``."""
        return nn.Parameter(torch.zeros(param_shape, dtype=param_type, device=device), requires_grad=requires_grad)

    @staticmethod
    def new_tensor(tensor_shape, tensor_type, device):
        """Create an uninitialised tensor of ``tensor_shape`` on ``device``."""
        return torch.empty(size=tensor_shape, dtype=tensor_type, device=device)

    @staticmethod
    def full_like(tensor, fill_value, dtype=None):
        """Return a tensor shaped like ``tensor``, filled with ``fill_value``."""
        return torch.full_like(tensor, fill_value, dtype=dtype)

    @staticmethod
    def set_tensor_requires_grad(input_tensor):
        """
        Set the requires_grad flag for the input tensor; only effective for a
        leaf node (requires_grad cannot be toggled on non-leaf tensors).
        """
        if input_tensor.is_leaf:
            input_tensor.requires_grad = True

    def _create_group(self, rank_list, group_name=None):
        """Create (or fetch) the process group covering exactly ``rank_list``."""
        group_dict = create_sub_groups(rank_list)
        return group_dict[tuple(rank_list)]

    @staticmethod
    def all_gather_into_tensor(data, group_info, async_op=False):
        """All-gather ``data`` along dim 0 into a single preallocated tensor.

        Returns:
            (output, handle): the gathered tensor and the async work handle
            (None when ``async_op`` is False).
        """
        output_shape = list(data.shape)
        output_shape[0] = output_shape[0] * group_info.rank_size
        output = torch.empty(output_shape, dtype=data.dtype, device=data.device)
        handle = dist.all_gather_into_tensor(output, data, group=group_info.group, async_op=async_op)
        return output, handle

    @staticmethod
    def all_reduce(data, group_info, async_op=False):
        """All-reduce ``data`` in place; returns ``(data, handle)``."""
        # NCCL/HCCL collectives require contiguous buffers.
        if not data.is_contiguous():
            data = data.contiguous()
        handle = dist.all_reduce(data, group=group_info.group, async_op=async_op)
        return data, handle

    @staticmethod
    def broadcast(data, src, group=None, async_op=False):
        """Broadcast ``data`` from rank ``src``.

        NOTE: when ``async_op`` is True the handle is waited on immediately,
        so the call is effectively synchronous either way; returns None.
        """
        handle = dist.broadcast(data, src, group, async_op)
        if async_op:
            handle.wait()

    @staticmethod
    def isend(tensor, dst=None, group=None, tag=0):
        """Non-blocking send of ``tensor`` to rank ``dst``; returns the work handle."""
        return dist.isend(tensor, dst, group, tag)

    @staticmethod
    def irecv(tensor, src=None, group=None, tag=0):
        """Non-blocking receive into ``tensor`` from rank ``src``; returns the work handle."""
        return dist.irecv(tensor, src, group, tag)

    @staticmethod
    def send_object_list(obj_list, dst=None, group=None):
        """Blocking send of a list of picklable objects to rank ``dst``."""
        dist.send_object_list(obj_list, dst, group)

    @staticmethod
    def recv_object_list(obj_list, src=None, group=None):
        """Blocking receive of picklable objects from rank ``src`` into ``obj_list``."""
        dist.recv_object_list(obj_list, src, group)

    @staticmethod
    def reduce_scatter_tensor(data, group_info, async_op=False):
        """Reduce-scatter ``data`` along dim 0; returns ``(output, handle)``."""
        output_shape = list(data.shape)
        output_shape[0] = output_shape[0] // group_info.rank_size
        output = torch.empty(output_shape, dtype=data.dtype, device=data.device)
        handle = dist.reduce_scatter_tensor(output, data, group=group_info.group, async_op=async_op)
        return output, handle

    @staticmethod
    def get_tensor_transform():
        """Not supported on the torch platform."""
        raise NotImplementedError("Unsupported get_tensor_transform for torch platform")

    @staticmethod
    def construct_strided_slice(x, begin, end, stride):
        """Not supported on the torch platform."""
        raise NotImplementedError("Unsupported construct_strided_slice for torch platform")

    @staticmethod
    def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        """Build a micro-batch splitter for pipeline parallelism."""
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.pipeline_parallel._utils import _MicroBatch
        return _MicroBatch(micro_batch_num, args_batch_dim, kwargs_batch_dim)

    def new_stream(self):
        """Create a new device stream."""
        device = self.get_device_handle()
        return device.Stream()

    def get_stream_context(self):
        """Return the stream context-manager factory of the device module."""
        device = self.get_device_handle()
        return device.stream

    @staticmethod
    def all_gather_object(object_list, obj, group=None) -> None:
        """
        Gathers objects from the given group into object list.

        Args:
            object_list (list[Any]): Define the output list, which size equal to the size of group.
            obj (Any): The object on current rank and in given process group.
            group (ProcessGroup, optional): The process group to gather obj. Default is ``None``, and ``None`` means
                global group.

        Returns:
            None. Objs are gathered into ``object_list``.
        """
        dist.all_gather_object(object_list, obj, group)

    @staticmethod
    def init_process_group(
            backend: Optional[str] = None,
            *,
            init_method: Optional[str] = None,
            timeout: Optional[timedelta] = None,
            world_size: int = -1,
            rank: int = -1,
            store: Optional[Store] = None,
            pg_options: Optional[Any] = None,
            device_id: Optional[Union[torch.device, int]] = None,
    ) -> None:
        """
        Initialize global process group.

        Args:
            backend (str or Backend, optional): The backend to use for distributed communication.
            init_method (str, optional): URL specifying how to initialize the process group. Default is "env://",
                can not be specified at the same time with ``store``.
            timeout (timedelta, optional): Timeout for process group. Default 10 minutes for NCCL and for other
                backends 30 minutes.
            world_size (int, optional): Number of processes. If ``store`` is specified, world_size is required.
            rank (int, optional): Rank of the current process, which value must between 0 and ``world_size``-1. If
                ``store`` is specified, rank is required.
            store (Store, optional): Key/value store accessible to all workers, used to exchange connection/address
                information. Can not be specified at the same time with ``init_method``.
            pg_options (ProcessGroupOptions, optional): Extra options to pass during constructing process groups.
            device_id (torch.device | int, optional): Specific device this process will work on.
        """
        try:
            # If a default group already exists, initialisation is a no-op.
            _get_default_group()
        # except multi version error: older/newer torch raise different types here
        except (ValueError, RuntimeError):
            if backend is None:
                backend = "hccl"
            dist.init_process_group(backend=backend, init_method=init_method, timeout=timeout, world_size=world_size,
                                    rank=rank, store=store, pg_options=pg_options, device_id=device_id)

    @staticmethod
    def destroy_process_group(group: Optional[ProcessGroup] = None) -> None:
        """
        Destroy given process group.

        Args:
            group (ProcessGroup, optional): Given process group will be destroyed, if not given, the default
                process group is destroyed.
        """
        group = group or _get_default_group()
        dist.destroy_process_group(group)

    @staticmethod
    def get_process_group_ranks(group: Optional[ProcessGroup] = None) -> list[int]:
        """
        Get all ranks relative to given process group.

        Args:
            group (Optional[ProcessGroup]): Process group worked on. Default is ``None``, and ``None`` means global
                group.

        Returns:
            Rank list.
        """
        group = group or _get_default_group()
        return dist.get_process_group_ranks(group)

    @staticmethod
    def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
        """
        Get the backend of the given process group.

        Args:
            group (ProcessGroup, optional): Process group worked on. Default is ``None``, and ``None`` means global
                group.

        Returns:
            The backend object of the given process group.
        """
        group = group or _get_default_group()
        return dist.get_backend(group)

    @staticmethod
    def split_group(parent_pg: Optional[ProcessGroup] = None,
                    split_ranks: Optional[list] = None,
                    timeout: Optional[timedelta] = None,
                    pg_options: Optional[Any] = None,
                    group_desc: Optional[str] = None,
                    ) -> Optional[ProcessGroup]:
        """
        Create split groups for every group rank in split_ranks, and return the split process group which relative to
        current rank id.

        NOTE: ``parent_pg``, ``timeout``, ``pg_options`` and ``group_desc`` are
        currently unused — groups are created via ``dist.new_group`` from the
        default group; the parameters are kept for interface compatibility.

        Args:
            parent_pg (Optional[ProcessGroup]): A process group which the goal group split from.
            split_ranks (Optional[list]): A list like ``list[list[int]]``.
            timeout (Optional[timedelta]): Timeout for process group. Default 10 minutes for NCCL and for other
                backend 30 minutes.
            pg_options (Optional[Any]): Extra options to pass during constructing process groups.
            group_desc (Optional[str]): Description of process group.

        Return:
            Optional[ProcessGroup]: One of split process group which relative to current rank id

        Raises:
            ValueError: If ``split_ranks`` is None or empty.
        """
        if split_ranks is None or len(split_ranks) == 0:
            raise ValueError("split_ranks cannot be None or empty")

        # All ranks must call new_group for every sub-group, even those not in
        # it; remember the one that contains the current rank.
        matched_group = None
        for split_rank in split_ranks:
            dist_group = dist.new_group(ranks=split_rank)
            if TorchPlatform.get_rank() in split_rank:
                matched_group = dist_group

        return matched_group

    @staticmethod
    def no_grad():
        """Return the ``torch.no_grad`` context manager."""
        return torch.no_grad()

    @staticmethod
    def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
        """Return an uninitialised tensor with the same shape as ``tensor``."""
        return torch.empty_like(tensor, dtype=dtype, device=device, pin_memory=pin_memory)

    def get_current_stream(self):
        """Return the current device stream."""
        device = self.get_device_handle()
        return device.current_stream()

    def new_event(self):
        """Create a new device event."""
        device = self.get_device_handle()
        return device.Event()

    def tree_map(self, fn, tree):
        """Apply ``fn`` to every leaf of a pytree ``tree``."""
        return torch.utils._pytree.tree_map(fn, tree)  # pylint:disable=protected-access

    @property
    def checkpoint(self):
        """Activation-checkpointing entry point (``torch.utils.checkpoint.checkpoint``)."""
        return torch.utils.checkpoint.checkpoint

    @staticmethod
    def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
        """Wrap ``module`` with torch's activation-checkpoint wrapper."""
        return checkpoint_wrapper(module, checkpoint_fn=checkpoint_fn, **checkpoint_fn_kwargs)

    @property
    def noop_context_fn(self):
        """The no-op context factory used by selective activation checkpointing."""
        return noop_context_fn

    @staticmethod
    def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
        """Build selective-activation-checkpoint contexts from a policy."""
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.sac import create_selective_checkpoint_contexts
        return create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation)

    @staticmethod
    def async_save_on_cpu(policy_fn=None):
        """Build an activation-swap context that asynchronously offloads to CPU."""
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import AsyncSaveOnCpu
        return AsyncSaveOnCpu(policy_fn)

    @staticmethod
    def tensor_to_numpy(tensor) -> np.ndarray:
        """Convert PyTorch tensor to numpy array (moves it to CPU first)."""
        return tensor.cpu().numpy()

    def cast_fp_tensor(self, dtype, x):
        """
        Cast floating-point tensor to target dtype if applicable.

        Non-tensors, non-floating tensors and tensors already of ``dtype``
        are returned unchanged.
        """
        if (
                not isinstance(x, torch.Tensor)
                or not torch.is_floating_point(x)
                or x.dtype == dtype
        ):
            return x
        return x.to(dtype)

    def apply_to_tensors(self, fn, container):
        """Recursively apply ``fn`` to all tensors inside ``container``.

        Supports tensors, dataclasses, OrderedDict (and subclasses), dicts,
        PackedSequence, namedtuples, lists, tuples and sets; any other value
        is returned unchanged.
        """

        def apply(x):
            if isinstance(x, torch.Tensor):
                return fn(x)
            if hasattr(x, "__dataclass_fields__"):
                # Simplification: compute the transformed fields from x itself;
                # the previous extra ``dataclasses.replace(x)`` pre-copy was a
                # redundant shallow clone with identical field values.
                changes = {
                    f.name: apply(getattr(x, f.name)) for f in dataclasses.fields(x)
                }
                return dataclasses.replace(x, **changes)
            if isinstance(x, OrderedDict):
                # Preserve the concrete subclass of OrderedDict.
                od = x.__class__()
                for key, value in x.items():
                    od[key] = apply(value)
                return od
            if isinstance(x, PackedSequence):
                # PackedSequence data is visited for side effects only; the
                # original object is returned unchanged.
                apply(x.data)
                return x
            if isinstance(x, dict):
                return {key: apply(value) for key, value in x.items()}
            # Namedtuple: rebuild through the positional constructor.
            if isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields"):
                res = (apply(el) for el in x)
                return type(x)(*res)
            if isinstance(x, (list, tuple, set)):
                return type(x)(apply(el) for el in x)
            return x

        return apply(container)