# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Torch platform API"""
from datetime import timedelta
from typing import Optional, Any, Union
import dataclasses
from collections import OrderedDict

import numpy as np
from safetensors.torch import save_file, load_file
import torch
from torch import nn
from torch import Tensor
from torch._C._distributed_c10d import Store, ProcessGroup
from torch.distributed import Backend
from torch.distributed.distributed_c10d import _get_default_group
from torch.nn import Parameter, Module
from torch.nn.utils.rnn import PackedSequence
from torch._ops import OpOverload, OpOverloadPacket
from torch.utils.checkpoint import noop_context_fn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
import torch.distributed.nn.functional as dist_func
import torch.distributed as dist
from hyper_parallel.platform.torch.dtensor import DTensorBase
from hyper_parallel.platform.torch.pipeline_parallel.stage import PipelineStageBase
from hyper_parallel.platform.torch.group_utils import create_sub_groups
from hyper_parallel.platform.platform import Platform, PlatformType, EXISTING_COMM_GROUPS
from hyper_parallel.platform.torch.function_override import override_functions
from hyper_parallel.platform.torch.init_weights import init_on_device as _init_on_device

override_functions()


# ---------------------------------------------------------------------------
# Module-level A2A reshape helpers
# ---------------------------------------------------------------------------

def _a2a_reconstruct(out_perm: torch.Tensor, concat_dim: int) -> torch.Tensor:
    """Reconstruct A2A result from raw out_perm buffer.

    ``out_perm`` has shape ``[ws, *rest_dims]``, chunk at ``concat_dim + 1``.
    Returns tensor with merged chunk dimension.
    """
    new_ndim = out_perm.dim()
    chunk_in_perm = concat_dim + 1
    recon_perm = list(range(1, chunk_in_perm)) + [0] + list(range(chunk_in_perm, new_ndim))
    x_recon = out_perm.permute(recon_perm).contiguous()
    shape = list(x_recon.shape)
    merged = shape[concat_dim] * shape[concat_dim + 1]
    return x_recon.reshape(shape[:concat_dim] + [merged] + shape[concat_dim + 2:])
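
# Shape walk-through for the helper above (illustrative dims, assuming
# ws = 2 ranks and concat_dim = 1; not part of the original source):
#   out_perm               [2, B, C, H]    ws leading, chunk dim at concat_dim + 1
#   permute([1, 0, 2, 3])  [B, 2, C, H]    move ws next to the chunk dim
#   reshape                [B, 2 * C, H]   merge ws into concat_dim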


class _TorchAsyncA2AFunction(torch.autograd.Function):
    """Differentiable wrapper for pre-launched async all-to-all.

    Forward: wait async handle, reconstruct A2A result.
    Backward: launch async head→seq A2A and store handle in ``handle_box``
    for the projection pre-hook to wait, achieving GEMM–A2A overlap.
    """

    @staticmethod
    def forward(ctx, x, work, out_perm, group, world_size, concat_dim, split_dim,  # pylint: disable=arguments-differ
                handle_box):
        """Wait for pre-launched async A2A and return reconstructed output."""
        ctx.group = group
        ctx.world_size = world_size
        ctx.concat_dim = concat_dim
        ctx.split_dim = split_dim
        ctx.handle_box = handle_box
        ctx.x_shape = x.shape
        work.wait()
        return _a2a_reconstruct(out_perm, concat_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Launch async head→seq A2A for backward overlap, or return zero grad."""
        if ctx.handle_box is not None:
            # Launch async head→seq A2A (reverse of forward seq→head)
            g = grad_output.contiguous()
            shape = list(g.shape)
            seq_dim = ctx.concat_dim
            s_full = shape[seq_dim]
            ndim = len(shape) + 1
            x_perm = g.reshape(
                shape[:seq_dim] + [ctx.world_size, s_full // ctx.world_size] + shape[seq_dim + 1:]
            ).permute(
                [seq_dim] + list(range(seq_dim)) + list(range(seq_dim + 1, ndim))
            ).contiguous()
            out_perm = torch.empty_like(x_perm)
            work = dist.all_to_all_single(out_perm, x_perm, group=ctx.group, async_op=True)
            ctx.handle_box.append((work, out_perm))
        return grad_output.new_zeros(ctx.x_shape), None, None, None, None, None, None, None
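
# Call-pattern sketch (hypothetical caller; `box`, `x_perm` and `g` are
# illustrative names, not part of this module):
#   box = []                                  # handle_box shared with a pre-hook
#   out_buf = torch.empty_like(x_perm)
#   work = dist.all_to_all_single(out_buf, x_perm, group=g, async_op=True)
#   y = _TorchAsyncA2AFunction.apply(x, work, out_buf, g, ws, 1, 0, box)
#   # In backward, `box` receives (work, out_perm); a projection pre-hook is
#   # expected to pop it and call work.wait(), overlapping the A2A with its GEMM.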


class _AsyncA2ALazyBwd(torch.autograd.Function):
    """All-to-all whose forward AND backward return ``AsyncCollectiveTensor``.

    PyTorch's stock ``all_to_all_single_autograd`` calls ``wait_tensor`` in
    its backward eagerly, and the autograd engine binds backward stream
    context to the forward stream — so even if the BWD thread is wrapped
    in a side-stream context, that wait still lands on the FWD main
    stream and blocks Attention launches.

    This Function bypasses the engine's binding by calling the
    non-autograd functional op in both directions and returning ACT.
    The wait is deferred to the next consumer's first non-view access
    (e.g. the indexing backward of ``_unpermute``), giving the FWD
    thread a small Python window to enqueue its Attention kernels onto
    the main stream **before** the wait lands there.
    """

    @staticmethod
    def forward(ctx, input_tensor, output_splits, input_splits, group):  # pylint: disable=arguments-differ
        ctx.input_splits = input_splits
        ctx.output_splits = output_splits
        ctx.group = group
        # pylint: disable=C0415
        from torch.distributed._functional_collectives import all_to_all_single
        return all_to_all_single(
            input_tensor, output_splits, input_splits, group,
        )

    @staticmethod
    def backward(ctx, grad_output):
        # pylint: disable=C0415
        from torch.distributed._functional_collectives import all_to_all_single
        grad_input = all_to_all_single(
            grad_output, ctx.input_splits, ctx.output_splits, ctx.group,
        )
        return grad_input, None, None, None
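
# Behavioural sketch of the lazy wait (illustrative; split values are made up):
#   act = _AsyncA2ALazyBwd.apply(tokens, [2, 2], [3, 1], ep_group)
#   # `act` is an AsyncCollectiveTensor: no wait_tensor has been issued yet,
#   # so this host thread keeps running and can enqueue more kernels.
#   dense = act + bias    # first non-view access; the deferred wait lands here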


class _TorchSyncHookFunction(torch.autograd.Function):
    """Autograd identity that fires HookCoordinator rendezvous on fwd/bwd.

    Uses a **4-hook** design (``A``, ``B``, ``C``, ``D``) with pure
    COMM / COMPUTE roles — no NONE role. Every rendezvous is a strict
    COMM + COMPUTE pair, guaranteeing NCCL-first dispatch ordering at
    **all** points including layer boundaries.

    Hook placement per MoE layer::

        [A] → dispatch → [B] → module → [C] → combine → [D] → (Attention) → [A_next]

    At layer boundaries (D / A hooks), the Attention that runs between
    layers is treated as COMPUTE, and the combine / combine.bwd is treated
    as COMM, so the coordinator enforces comm-first ordering even across
    layer transitions.
    """

    # 4-hook role tables: (prev_role_idx, next_role_idx).
    # Index encoding: 1 = COMM, 2 = COMPUTE.
    _FWD_ROLES = {
        # (prev, next)        prev op      next op
        "A": (2, 1),  # COMPUTE, COMM    Attention  | dispatch
        "B": (1, 2),  # COMM, COMPUTE    dispatch   | module
        "C": (2, 1),  # COMPUTE, COMM    module     | combine
        "D": (1, 2),  # COMM, COMPUTE    combine    | Attention
    }
    _BWD_ROLES = {
        "D": (2, 1),  # COMPUTE, COMM    Attn.bwd     | combine.bwd
        "C": (1, 2),  # COMM, COMPUTE    combine.bwd  | module.bwd
        "B": (2, 1),  # COMPUTE, COMM    module.bwd   | dispatch.bwd
        "A": (1, 2),  # COMM, COMPUTE    dispatch.bwd | Attn.bwd
    }
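
    # Rendezvous walk-through for one forward MoE layer (derived from the role
    # tables above; purely illustrative):
    #   hook A: notify_dispatched(COMPUTE); rendezvous(COMM)     # Attention | dispatch
    #   hook B: notify_dispatched(COMM);    rendezvous(COMPUTE)  # dispatch  | module
    #   hook C: notify_dispatched(COMPUTE); rendezvous(COMM)     # module    | combine
    #   hook D: notify_dispatched(COMM);    rendezvous(COMPUTE)  # combine   | Attention
    # The backward table runs the same pairs in reverse hook order (D, C, B, A).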

    _ROLE_CACHE = None

    @staticmethod
    def _role_enum(idx: int):
        if _TorchSyncHookFunction._ROLE_CACHE is None:
            from hyper_parallel.core.pipeline_parallel.hook_coordinator import HookRole  # pylint: disable=C0415
            _TorchSyncHookFunction._ROLE_CACHE = (None, HookRole.COMM, HookRole.COMPUTE)
        return _TorchSyncHookFunction._ROLE_CACHE[idx]

    @staticmethod
    def forward(ctx, x, hook_name, coordinator):  # pylint: disable=arguments-differ
        """Identity forward that fires a HookCoordinator rendezvous.

        Notifies the previous op's role and rendezvouses for the next op's
        role per the ``_FWD_ROLES`` table. ``"D_LAST"`` is a sentinel
        meaning "skip this rendezvous" (last layer's closing D — no
        Attention follows).

        Args:
            ctx: Autograd context, stores ``hook_name`` and
                ``coordinator`` for the backward pass.
            x: Input tensor, returned unchanged.
            hook_name: One of ``"A"``, ``"B"``, ``"C"``, ``"D"``,
                ``"D_LAST"``.
            coordinator: The :class:`HookCoordinator` driving the rendezvous.

        Returns:
            ``x`` unchanged.
        """
        ctx.hook_name = hook_name
        ctx.coordinator = coordinator

        if not coordinator.is_enabled():
            return x

        # ``D_LAST`` marks the last layer's D hook. The "next op" after
        # this hook is the chunk's output (no Attention follows), so the
        # rendezvous is meaningless — skip it. In backward this same
        # hook is the very first BWD hook to fire, where ``combine.bwd``
        # has already free-run before any rendezvous is possible — also
        # skip. Tagging at wrap time replaces the old runtime
        # ``increment_cycle`` / ``bwd_d_should_skip`` mechanisms.
        if hook_name == "D_LAST":
            return x

        prev_idx, next_idx = _TorchSyncHookFunction._FWD_ROLES[hook_name]
        role_of = _TorchSyncHookFunction._role_enum
        coordinator.notify_dispatched(role_of(prev_idx))
        coordinator.rendezvous(role_of(next_idx))
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Identity backward that fires a HookCoordinator rendezvous.

        Mirror of :meth:`forward` using the ``_BWD_ROLES`` table.
        ``"D_LAST"`` skips the rendezvous because this is the first BWD
        hook to fire and ``combine.bwd`` has already dispatched freely
        before any rendezvous can happen.

        Args:
            ctx: Autograd context with ``hook_name`` and
                ``coordinator`` saved during forward.
            grad_output: Gradient w.r.t. the forward output, returned
                unchanged.

        Returns:
            ``(grad_output, None, None)`` — gradients only flow back to
            the tensor input; ``hook_name`` and ``coordinator`` are
            non-tensor inputs.
        """
        hook_name = ctx.hook_name
        coordinator = ctx.coordinator

        if not coordinator.is_enabled():
            return grad_output, None, None

        # Same ``D_LAST`` semantics as forward: this is the first BWD
        # hook to fire and combine.bwd has already dispatched freely
        # before any rendezvous can happen, so skip the rendezvous.
        if hook_name == "D_LAST":
            return grad_output, None, None

        prev_idx, next_idx = _TorchSyncHookFunction._BWD_ROLES[hook_name]
        role_of = _TorchSyncHookFunction._role_enum
        coordinator.notify_dispatched(role_of(prev_idx))
        coordinator.rendezvous(role_of(next_idx))
        return grad_output, None, None


class _TorchP2PExchangeFunction(torch.autograd.Function):
    """Symmetric bidirectional P2P: send local tensor to peer, receive peer's tensor."""

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, peer_rank: int, group) -> torch.Tensor:  # pylint: disable=arguments-differ
        """Perform symmetric bidirectional P2P exchange with peer_rank."""
        ctx.peer_rank = peer_rank
        ctx.group = group
        send_buf = tensor.contiguous()
        recv_buf = torch.empty_like(send_buf)
        reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, send_buf, peer_rank, group),
            dist.P2POp(dist.irecv, recv_buf, peer_rank, group),
        ])
        for req in reqs:
            req.wait()
        return recv_buf

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Perform symmetric P2P exchange for the backward gradient pass."""
        send_buf = grad_output.contiguous()
        recv_buf = torch.empty_like(send_buf)
        reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, send_buf, ctx.peer_rank, ctx.group),
            dist.P2POp(dist.irecv, recv_buf, ctx.peer_rank, ctx.group),
        ])
        for req in reqs:
            req.wait()
        return recv_buf, None, None


# Mapping from string op names to torch.distributed.ReduceOp
_OP_MAP = {
    'sum': dist.ReduceOp.SUM,
    'prod': dist.ReduceOp.PRODUCT,
    'max': dist.ReduceOp.MAX,
    'min': dist.ReduceOp.MIN,
    # convert tensor elements to int32 and use MIN
    'all': dist.ReduceOp.MIN,
    # 'avg' is handled as SUM followed by a division in the current implementation
    'avg': dist.ReduceOp.SUM,
}

# Add AVG for 'mean' if the current torch version supports it
if hasattr(dist.ReduceOp, "AVG"):
    _OP_MAP['mean'] = dist.ReduceOp.AVG
else:
    # Fallback for older torch versions: map 'mean' to SUM; callers must
    # divide by the group size themselves.
    _OP_MAP['mean'] = dist.ReduceOp.SUM
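
# Resolution sketch (illustrative):
#   _OP_MAP.get('max')                        -> dist.ReduceOp.MAX
#   _OP_MAP.get('mean')                       -> AVG where available, else SUM
#   _OP_MAP.get('median', dist.ReduceOp.SUM)  -> SUM  # unknown names fall back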


# pylint: disable=C0103
class TorchPlatform(Platform):
    """Torch platform API"""
    Tensor = Tensor
    tensor = torch.tensor
    Parameter = Parameter
    Module = Module
    DTensorBase = DTensorBase
    PipelineStageBase = PipelineStageBase
    platform_type = PlatformType.PYTORCH
    tensor_dtype = torch
    dtype = torch.dtype
    Function = torch.autograd.Function

    @staticmethod
    def is_linear_module(module) -> bool:
        """Check whether *module* is a ``torch.nn.Linear`` instance."""
        return isinstance(module, nn.Linear)

    @staticmethod
    def is_embedding_module(module) -> bool:
        """Check whether *module* is a ``torch.nn.Embedding`` instance."""
        return isinstance(module, nn.Embedding)

    @staticmethod
    def device_count(device_handle):
        """
        Get the number of available devices.

        Args:
            device_handle: The device handle (e.g., torch.cuda, torch.npu).

        Returns:
            int: The number of available devices.
        """
        return device_handle.device_count()

    def device_type(self):
        """
        Get the current device type.

        Returns:
            str: The device type string ("npu" for NPU, "cuda" for GPU).
        """
        device_handle = self.get_device_handle()
        if device_handle == torch.npu:
            return "npu"
        return "cuda"

    def device(self, device_idx=None):
        """
        Get a torch.device object for the specified device index.

        Args:
            device_idx (Optional[int]): The device index. If None, returns device without index.

        Returns:
            torch.device: A torch device object.
        """
        device_type = self.device_type()
        if device_idx is None:
            return torch.device(device_type)
        return torch.device(f"{device_type}:{device_idx:d}")

    @staticmethod
    def get_rng_state(device=None, device_handle=None):
        """
        Get the random number generator state.

        Args:
            device (Optional): The device to get RNG state from.
            device_handle (Optional): The device handle (torch.cuda, torch.npu, etc.).

        Returns:
            Tensor: The RNG state as a byte tensor.
        """
        if device_handle is None:
            return torch.get_rng_state()
        if device is None:
            return device_handle.get_rng_state()
        return device_handle.get_rng_state(device)

    @staticmethod
    def set_rng_state(state, device=None, device_handle=None):
        """
        Set the random number generator state.

        Args:
            state (Tensor): The RNG state to set.
            device (Optional): The device to set RNG state for.
            device_handle (Optional): The device handle (torch.cuda, torch.npu, etc.).
        """
        if device_handle is None:
            return torch.set_rng_state(state)
        if device is None:
            return device_handle.set_rng_state(state)
        return device_handle.set_rng_state(state, device)

    @staticmethod
    def manual_seed(seed):
        """
        Set the random seed for reproducibility.

        Args:
            seed (int): The random seed value.

        Returns:
            torch.Generator: The random number generator.
        """
        return torch.manual_seed(seed)

    @staticmethod
    def ones(size, dtype=None):
        """
        Create a tensor filled with ones.

        Args:
            size (tuple): The shape of the output tensor.
            dtype (Optional[torch.dtype]): The desired data type.

        Returns:
            Tensor: A tensor filled with ones.
        """
        return torch.ones(size, dtype=dtype)

    @staticmethod
    def zeros(size, dtype=None, device=None):
        """
        Create a tensor filled with zeros.

        Args:
            size (tuple): The shape of the output tensor.
            dtype (Optional[torch.dtype]): The desired data type.
            device (Optional[torch.device]): The device to create the tensor on.

        Returns:
            Tensor: A tensor filled with zeros.
        """
        return torch.zeros(size, dtype=dtype, device=device)

    @staticmethod
    def full(size, fill_value, dtype=None):
        """
        Create a tensor filled with a scalar value.

        Args:
            size (tuple): The shape of the output tensor.
            fill_value (scalar): The value to fill the tensor with.
            dtype (Optional[torch.dtype]): The desired data type.

        Returns:
            Tensor: A tensor filled with the specified value.
        """
        return torch.full(size, fill_value, dtype=dtype)

    @staticmethod
    def empty(size, dtype=None):
        """
        Create an uninitialized tensor.

        Args:
            size (tuple): The shape of the output tensor.
            dtype (Optional[torch.dtype]): The desired data type.

        Returns:
            Tensor: An uninitialized tensor.
        """
        return torch.empty(size, dtype=dtype)

    @staticmethod
    def get_rank():
        """
        Get the rank of the current process in the distributed group.

        Returns:
            int: The rank of the current process.
        """
        return dist.get_rank()

    @staticmethod
    def get_global_rank(group, group_rank):
        """
        Get the global rank from a group rank.

        Args:
            group (ProcessGroup): The process group.
            group_rank (int): The rank within the group.

        Returns:
            int: The global rank.
        """
        return dist.get_global_rank(group, group_rank)

    @staticmethod
    def get_world_size():
        """
        Get the total number of processes in the distributed group.

        Returns:
            int: The world size.
        """
        return dist.get_world_size()

    @staticmethod
    def get_param_local_shape(param):
        """
        Get the local shape of a parameter, handling both regular and distributed tensors.

        Args:
            param (Union[Tensor, DTensorBase]): The parameter tensor.

        Returns:
            torch.Size: The local shape of the parameter.
        """
        if isinstance(param, DTensorBase):
            return param.local_shape
        return param.shape

    @staticmethod
    def get_param_local_data(param):
        """
        Get the local data of a parameter, handling both regular and distributed tensors.

        Args:
            param (Union[Tensor, DTensorBase]): The parameter tensor.

        Returns:
            Tensor: The local tensor data.
        """
        if isinstance(param, DTensorBase):
            return param.to_local()
        return param

    @staticmethod
    def update_param_data(param, data):
        """
        Update the data of a parameter.

        Args:
            param (Parameter): The parameter to update.
            data (Tensor): The new data tensor.
        """
        param.data = data

    @staticmethod
    def load_into_param(param, data):
        """Load tensor *data* into *param* (plain tensor or DTensor)."""
        if isinstance(param, DTensorBase):
            local = param._local_tensor  # pylint: disable=W0212
            if local.is_meta:
                # Meta tensor materialisation: replace the placeholder.
                orig_requires_grad = param.requires_grad
                param._local_tensor = data  # pylint: disable=W0212
                if data.requires_grad != orig_requires_grad:
                    param.requires_grad_(orig_requires_grad)
            else:
                local.copy_(data)
        else:
            param.copy_(data)

    @staticmethod
    def get_op_name(func):
        """
        Extract the operation name from various function types.

        Args:
            func: The function or operation to extract the name from.

        Returns:
            str: The operation name.
        """
        if hasattr(func, "__name__"):
            return func.__name__
        if isinstance(func, OpOverload):
            full_name = func.name
            core_name = full_name.split("::")[-1].split(".")[0]
            return core_name
        if isinstance(func, OpOverloadPacket):
            return func.name.split("::")[-1]
        func_str = str(func)
        if "built-in function" in func_str:
            return func_str.split()[-1].strip(">")
        if "function" in func_str:
            return func_str.split()[1]
        return "unknown_op"

    @staticmethod
    def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
        output = dist_func.all_gather(data, group=group)
        return torch.cat(output, dim=concat_dim)

    @staticmethod
    def chunk(data, split_dim, split_size, index):
        return torch.chunk(data, split_size, dim=split_dim)[index]

    @staticmethod
    def differentiable_all_to_all(input_data, output_shape, group):
        output_tensor = torch.empty(output_shape, device=input_data.device, dtype=input_data.dtype)
        output_tensor = dist_func.all_to_all_single(
            output_tensor,
            input_data,
            group=group
        )
        return output_tensor

    @staticmethod
    def tensor_type_cast(input_data, cast_type):
        """Cast tensor to specified data type."""
        type_mapping = {
            'float32': torch.float32,
            'float16': torch.float16,
            'int64': torch.int64,
            'int32': torch.int32
        }
        if cast_type not in type_mapping:
            raise ValueError(f"Unknown cast type: {cast_type}. Supported types: {list(type_mapping.keys())}")
        return input_data.to(type_mapping[cast_type])

    @staticmethod
    def differentiable_all_reduce(data, op, group):
        # Resolve the op from string to ReduceOp enum if necessary
        reduce_op = _OP_MAP.get(op, dist.ReduceOp.SUM) if isinstance(op, str) else op
        return dist_func.all_reduce(data, op=reduce_op, group=group)

    @staticmethod
    def get_cell_construct(cell):
        return cell.forward

    @staticmethod
    def get_cells_and_names(cell):
        return cell.named_modules()

    @staticmethod
    def search_parameter_by_name(cell, param_name: str):
        """
        Find the parent Module that owns the parameter, the parameter's name within
        that parent, and the parameter itself.

        Returns:
            tuple: (parent Module instance, parameter's name in the parent Module,
            parameter object), or None if the parameter is not found.
        """
        # Remove the "self." prefix from param_name
        param_name = param_name.replace("self.", "")
        # Case 1: The parameter is a direct parameter of the current Module
        if param_name in cell._parameters:  # pylint: disable=protected-access
            return (cell, param_name, cell._parameters[param_name])  # pylint: disable=protected-access

        # Case 2: The parameter is in a sub-Module
        if "." in param_name:
            cell_path, param_key = param_name.rsplit(".", 1)
            try:
                # Locate the sub-Module where the parameter resides (supports multi-level paths)
                target_cell = cell.get_submodule(cell_path)
                # Check if the sub-Module directly contains this parameter
                if param_key in target_cell._parameters:  # pylint: disable=protected-access
                    return target_cell, param_key, target_cell._parameters[param_key]  # pylint: disable=protected-access
            except AttributeError:
                pass

        # Traverse all sub-Modules (recursively) to search for the parameter
        for _, child_cell in cell.named_children():
            if isinstance(child_cell, Module):
                result = TorchPlatform.search_parameter_by_name(child_cell, param_name)
                if result is not None:
                    return result

        return None
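
    # Lookup sketch (hypothetical module tree):
    #   model = nn.Sequential(nn.Linear(4, 4))
    #   hit = TorchPlatform.search_parameter_by_name(model, "0.weight")
    #   # -> (model[0], "weight", model[0].weight); a miss returns None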

    @staticmethod
    def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
        """
        Modify the original parameter in a Module or sub-Module using the search result.
        """
        parent_cell, param_key, _ = result
        # Key operation: directly modify the _parameters dictionary.
        if param_key in parent_cell._parameters:  # pylint: disable=protected-access
            parent_cell._parameters[param_key] = new_param  # pylint: disable=protected-access
        else:
            parent_cell.register_parameter(param_key, new_param)
        return True

    @staticmethod
    def set_layout_into_parameter(param, layout):
        """Set layout into parameter"""
        from hyper_parallel.core.dtensor.dtensor import DTensor  # pylint: disable=import-outside-toplevel
        from hyper_parallel.core.dtensor.layout import _get_slice_tensor_by_layout  # pylint: disable=import-outside-toplevel
        if isinstance(param, DTensor):
            raise ValueError(f"Parameter {param} already has a layout configured; it cannot be set repeatedly.")
        requires_grad = param.requires_grad
        param_dtensor = DTensor.from_local(
            _get_slice_tensor_by_layout(param, layout),
            layout.mesh, layout.alias_placements)
        new_param = Parameter(param_dtensor, requires_grad=requires_grad)
        return new_param

    @staticmethod
    def differentiable_reduce_scatter(data, dev_num, axis, op, group):
        input_tuple = torch.chunk(data, dev_num, dim=axis)
        output_tensor = torch.empty(input_tuple[0].shape, device=data.device, dtype=data.dtype)

        # Resolve the op from string to ReduceOp enum
        reduce_op = _OP_MAP.get(op, dist.ReduceOp.SUM) if isinstance(op, str) else op

        output_tensor = dist_func.reduce_scatter(output_tensor, input_tuple, op=reduce_op, group=group)

        # Keep manual handling for the 'avg' string, since it maps to SUM in _OP_MAP
        if op == 'avg':
            output_tensor = output_tensor / dev_num
        return output_tensor

    @staticmethod
    def get_device_handle(device_type: str = "npu"):
        try:
            handle = getattr(torch, device_type)
        except AttributeError as e:
            raise RuntimeError(f"TorchPlatform failed to get device handle 'torch.{device_type}'.") from e
        return handle

    @staticmethod
    def get_param_type_size(param):
        # pylint: disable=W0212
        return torch._utils._element_size(param.dtype)

    @staticmethod
    def is_tensor(obj: Any) -> bool:
        """Return True if ``obj`` is a ``torch.Tensor``."""
        return isinstance(obj, Tensor)

    @staticmethod
    def get_tensor_storage_size(tensor: Any) -> int:
        """Return serialized byte size (numel * element size) for a PyTorch tensor."""
        if not TorchPlatform.is_tensor(tensor):
            raise TypeError(
                f"TorchPlatform.get_tensor_storage_size expects torch.Tensor, got {type(tensor)!r}"
            )
        return int(tensor.numel()) * int(tensor.element_size())

    @staticmethod
    def parameters_dict(cell: Module):
        return cell.named_parameters()

    @staticmethod
    def get_model_state_dict(model, *, options=None):
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.fully_shard.state_dict_utils import (
            get_model_state_dict as _get_model_state_dict,
        )
        return _get_model_state_dict(model, options=options)

    @staticmethod
    def save_checkpoint(cell: Module, file_path: str, ckpt_format: str = "safetensors") -> None:
        if ckpt_format == "safetensors":
            save_file(tensors=cell, filename=file_path)
        else:
            torch.save(obj=cell, f=file_path)

    @staticmethod
    def load_checkpoint(file_path: str, ckpt_format: str = "safetensors") -> dict:
        if ckpt_format == "safetensors":
            return load_file(filename=file_path)
        return torch.load(f=file_path)

    @staticmethod
    def new_zero_parameter(param_shape, param_type, requires_grad, device):
        return nn.Parameter(torch.zeros(param_shape, dtype=param_type, device=device), requires_grad=requires_grad)

    @staticmethod
    def new_tensor(tensor_shape, tensor_type, device):
        return torch.empty(size=tensor_shape, dtype=tensor_type, device=device)

    @staticmethod
    def full_like(tensor, fill_value, dtype=None):
        return torch.full_like(tensor, fill_value, dtype=dtype)

    @staticmethod
    def set_tensor_requires_grad(input_tensor):
        """
        Set the requires-grad flag for the input tensor; only effective for leaf nodes.
        """
        if input_tensor.is_leaf:
            input_tensor.requires_grad = True

    def _create_group(self, rank_list):
        group_dict = create_sub_groups(rank_list)
        return group_dict[tuple(rank_list)]

    @staticmethod
    def all_gather_into_tensor(data, group_info, async_op=False):
        output_shape = list(data.shape)
        output_shape[0] = output_shape[0] * group_info.rank_size
        output = torch.empty(output_shape, dtype=data.dtype, device=data.device)
        handle = dist.all_gather_into_tensor(output, data, group=group_info.group, async_op=async_op)
        return output, handle

    @staticmethod
    def all_reduce(data, group_info, async_op=False):
        if not data.is_contiguous():
            data = data.contiguous()
        handle = dist.all_reduce(data, group=group_info.group, async_op=async_op)
        return data, handle

    @staticmethod
    def broadcast(data, src, group=None, async_op=False):
        handle = dist.broadcast(data, src, group, async_op)
        if async_op:
            handle.wait()

    @staticmethod
    def isend(tensor, dst=None, group=None, tag=0):
        return dist.isend(tensor, dst, group, tag)

    @staticmethod
    def irecv(tensor, src=None, group=None, tag=0):
        return dist.irecv(tensor, src, group, tag)

    @staticmethod
    def p2p_exchange(tensor, peer_rank: int, group=None):
        if peer_rank == dist.get_rank(group):
            return tensor
        return _TorchP2PExchangeFunction.apply(tensor, peer_rank, group)
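
    # Pairing sketch for p2p_exchange (illustrative 2-rank pattern; `buf` is a
    # hypothetical tensor): rank 0 calls p2p_exchange(buf, peer_rank=1) while
    # rank 1 calls p2p_exchange(buf, peer_rank=0); each returns the peer's
    # tensor, and the same symmetric exchange runs again in backward for the
    # gradients.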

    @staticmethod
    def send_object_list(obj_list, dst=None, group=None):
        dist.send_object_list(obj_list, dst, group)

    @staticmethod
    def recv_object_list(obj_list, src=None, group=None):
        dist.recv_object_list(obj_list, src, group)

    @staticmethod
    def reduce_scatter_tensor(data, group_info, async_op=False):
        output_shape = list(data.shape)
        output_shape[0] = output_shape[0] // group_info.rank_size
        output = torch.empty(output_shape, dtype=data.dtype, device=data.device)
        handle = dist.reduce_scatter_tensor(output, data, group=group_info.group, async_op=async_op)
        return output, handle

    @staticmethod
    def all_to_all_single(input_tensor, output_shape, group, async_op=False):
        output = torch.empty(output_shape, device=input_tensor.device, dtype=input_tensor.dtype)
        work = dist.all_to_all_single(output, input_tensor, group=group, async_op=async_op)
        return output, work

    @staticmethod
    def differentiable_all_to_all_single(input_tensor, input_splits, output_splits, group):
        """Variable-split all-to-all with autograd support for EP token dispatch/combine."""
        out_total = sum(output_splits)
        output = torch.empty(
            out_total, *input_tensor.shape[1:],
            dtype=input_tensor.dtype, device=input_tensor.device,
        )
        output = dist_func.all_to_all_single(
            output, input_tensor,
            output_split_sizes=output_splits,
            input_split_sizes=input_splits,
            group=group,
        )
        return output
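
    # Split-accounting sketch (illustrative, 2-rank EP group): with
    # input_splits = [3, 1] this rank sends rows 0:3 to rank 0 and row 3 to
    # rank 1; with output_splits = [2, 2] it receives two rows from each rank,
    # so the result has sum(output_splits) = 4 rows of the same hidden width.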

    @staticmethod
    def differentiable_all_to_all_single_async(input_tensor, input_splits, output_splits, group):
        """Truly-async variant of :meth:`differentiable_all_to_all_single`.

        Both forward AND backward return :class:`AsyncCollectiveTensor`,
        so the ``wait_tensor`` op is queued lazily — only when a downstream
        kernel actually reads the result.

        Why both directions need lazy wait:

        * FWD: ACT lazy wait lets host return immediately and the paired
          BWD thread's compute kernel slip into the queue before the wait.
        * BWD: PyTorch's stock backward issues ``wait_tensor`` eagerly,
          and the autograd engine binds backward stream to the forward
          stream — so even running BWD inside a ``with torch.npu.stream
          (side_stream)`` context does not move that wait off the main
          stream. Returning ACT from backward defers the wait to the
          next backward op's first consumption, opening a small window
          during which FWD's Attention kernels can be queued onto the
          main stream **before** the wait lands.

        Args:
            input_tensor: Input tensor, split along dim 0 by ``input_splits``.
            input_splits: ``list[int]`` — rows sent to each rank.
            output_splits: ``list[int]`` — rows received from each rank.
            group: Process group.

        Returns:
            ``AsyncCollectiveTensor`` of shape
            ``[sum(output_splits), *input_tensor.shape[1:]]``.
        """
        return _AsyncA2ALazyBwd.apply(input_tensor, output_splits, input_splits, group)
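
    # Overlap sketch (hypothetical fragment; `launch_attention` and `scale`
    # are illustrative names): the returned ACT can be handed to the next op
    # and the wait is only issued there.
    #   act = TorchPlatform.differentiable_all_to_all_single_async(x, in_s, out_s, g)
    #   launch_attention(q, k, v)   # compute enqueued before any wait_tensor
    #   y = act * scale             # first real read; the deferred wait lands here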

    @staticmethod
    def arange(start, end=None, step=1, dtype=None, device=None):
        """Create a 1-D tensor with evenly spaced values."""
        if end is None:
            return torch.arange(start, dtype=dtype, device=device)
        return torch.arange(start, end, step, dtype=dtype, device=device)

    @staticmethod
    def differentiable_async_a2a_wait(x, work, out_perm, group, world_size, concat_dim, split_dim,
                                      handle_box=None):
        """Wait on an async A2A handle and reconstruct the result (differentiable).

        Args:
            x: Input tensor.
            work: Async work handle from all_to_all.
            out_perm: Output buffer from all_to_all.
            group: Process group.
            world_size: World size.
            concat_dim: Dimension for concatenation.
            split_dim: Dimension for split.
            handle_box: Optional mutable list; backward appends (work, out_perm) here.
        """
        return _TorchAsyncA2AFunction.apply(
            x, work, out_perm, group, world_size, concat_dim, split_dim, handle_box
        )

    @staticmethod
    def differentiable_sync_hook(x, hook_name: str, coordinator):
        """Identity op that fires coordinator rendezvous on forward and backward.

        Always goes through ``_TorchSyncHookFunction.apply`` so that the
        autograd graph **records a SyncHook node regardless of whether the
        coordinator is currently enabled**. Skipping ``apply`` when
        disabled would leave warmup-forwarded graphs without the hook
        nodes, and a later ``overlap.run`` — whose BWD thread back-props
        such a graph — would then traverse zero hooks while the paired FWD
        thread (whose current forward DOES record hooks) waits at a
        barrier for a partner that never arrives.

        Args:
            x: Input tensor.
            hook_name: One of ``"A"``, ``"B"``, ``"C"``, ``"D"``, or ``"D_LAST"``.
            coordinator: A :class:`HookCoordinator` instance.
        """
        return _TorchSyncHookFunction.apply(x, hook_name, coordinator)

    @staticmethod
    def get_tensor_transform():
        raise NotImplementedError("Unsupported get_tensor_transform for torch platform")

    @staticmethod
    def construct_strided_slice(x, begin, end, stride):
        raise NotImplementedError("Unsupported construct_strided_slice for torch platform")

    @staticmethod
    def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.pipeline_parallel._utils import _MicroBatch
        return _MicroBatch(micro_batch_num, args_batch_dim, kwargs_batch_dim)

    @staticmethod
    def get_symmetric_memory_handler():
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.symmetric_memory import TorchSymmetricMemoryHandler
        symmetric_memory = TorchSymmetricMemoryHandler()
        return symmetric_memory

    @staticmethod
    def get_multicore_handler():
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.multicore import TorchMulticoreHandler
        return TorchMulticoreHandler()

    def new_stream(self):
        device = self.get_device_handle()
        return device.Stream()

    def get_stream_context(self):
        device = self.get_device_handle()
        return device.stream

    @staticmethod
    def all_gather_object(object_list, obj, group=None) -> None:
        """
        Gather objects from the given group into ``object_list``.

        Args:
            object_list (list[Any]): The output list; its size must equal the size of the group.
            obj (Any): The object on the current rank in the given process group.
            group (ProcessGroup, optional): The process group to gather from. Default is ``None``,
                which means the global group.

        Returns:
            None. Objects are gathered into ``object_list``.
        """
        dist.all_gather_object(object_list, obj, group)

    @staticmethod
    def barrier(group=None, async_op: bool = False, device_ids=None) -> Any:
        """
        Synchronize all processes in the given process group.

        Args:
            group (ProcessGroup, optional): The process group to work on. Default is ``None``,
                meaning the default process group.
            async_op (bool, optional): Whether this op should be asynchronous. Default: ``False``.
            device_ids (list[int], optional): Device ids for backends that require a device for
                barrier (e.g. NCCL). Default: ``None``.

        Returns:
            Async work handle if ``async_op`` is True; otherwise ``None``.
        """
        return dist.barrier(group, async_op, device_ids)

    @staticmethod
    def init_process_group(
            backend: Optional[str] = None,
            *,
            init_method: Optional[str] = None,
            timeout: Optional[timedelta] = None,
            world_size: int = -1,
            rank: int = -1,
            store: Optional[Store] = None,
            pg_options: Optional[Any] = None,
            device_id: Optional[Union[torch.device, int]] = None,
    ) -> None:
        """
        Initialize the global process group.

        Args:
            backend (str or Backend, optional): The backend to use for distributed communication.
            init_method (str, optional): URL specifying how to initialize the process group. Default is
                "env://"; cannot be specified together with ``store``.
            timeout (timedelta, optional): Timeout for the process group. Default is 10 minutes for NCCL
                and 30 minutes for other backends.
            world_size (int, optional): Number of processes. Required if ``store`` is specified.
            rank (int, optional): Rank of the current process, which must be between 0 and
                ``world_size`` - 1. Required if ``store`` is specified.
            store (Store, optional): Key/value store accessible to all workers, used to exchange
                connection/address information. Cannot be specified together with ``init_method``.
            pg_options (ProcessGroupOptions, optional): Extra options to pass when constructing process groups.
            device_id (torch.device | int, optional): Specific device this process will work on.
        """
        try:
            _get_default_group()
        # different torch versions raise different error types here
        except (ValueError, RuntimeError):
            if backend is None:
                backend = "hccl"
            dist.init_process_group(backend=backend, init_method=init_method, timeout=timeout, world_size=world_size,
                                    rank=rank, store=store, pg_options=pg_options, device_id=device_id)

    @staticmethod
    def destroy_process_group(group: Optional[ProcessGroup] = None) -> None:
        """
        Destroy the given process group.

        Args:
            group (ProcessGroup, optional): The process group to destroy. If not given, the default
                process group is destroyed.
        """
        group = group or _get_default_group()
        if group in EXISTING_COMM_GROUPS.values():
            keys_to_destroy = [k for k, v in EXISTING_COMM_GROUPS.items() if v == group]
            for k in keys_to_destroy:
                del EXISTING_COMM_GROUPS[k]
        dist.destroy_process_group(group)

    @staticmethod
    def get_process_group_ranks(group: Optional[ProcessGroup] = None) -> list[int]:
        """
        Get all global ranks associated with the given process group.

        Args:
            group (Optional[ProcessGroup]): The process group to work on. Default is ``None``, and
                ``None`` means the global group.

        Returns:
            list[int]: The rank list.
        """
        group = group or _get_default_group()
        return dist.get_process_group_ranks(group)

    @staticmethod
    def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
        """
        Get the backend of the given process group.

        Args:
            group (ProcessGroup, optional): The process group to work on. Default is ``None``, and
                ``None`` means the global group.

        Returns:
            The backend object of the given process group.
        """
        group = group or _get_default_group()
        return dist.get_backend(group)

    @staticmethod
    def split_group(parent_pg: Optional[ProcessGroup] = None,
                    split_ranks: Optional[list] = None,
                    timeout: Optional[timedelta] = None,
                    pg_options: Optional[Any] = None,
                    group_desc: Optional[str] = None,
                    ) -> Optional[ProcessGroup]:
        """
        Create a sub-group for every rank list in ``split_ranks``, and return the split process
        group that contains the current rank.

        Args:
            parent_pg (Optional[ProcessGroup]): The process group that the goal groups are split from.
            split_ranks (Optional[list]): A list like ``list[list[int]]``.
            timeout (Optional[timedelta]): Timeout for the process group. Default is 10 minutes for
                NCCL and 30 minutes for other backends.
            pg_options (Optional[Any]): Extra options to pass when constructing process groups.
            group_desc (Optional[str]): Description of the process group.

        Returns:
            Optional[ProcessGroup]: The split process group containing the current rank, or ``None``.
        """
        if split_ranks is None or len(split_ranks) == 0:
            raise ValueError("split_ranks cannot be None or empty")

        split_group = None
        for split_rank in split_ranks:
            dist_group = TorchPlatform.get_created_group(split_rank)
            if dist_group is None:
                dist_group = dist.new_group(ranks=split_rank)
                EXISTING_COMM_GROUPS[str(tuple(sorted(split_rank)))] = dist_group
            if TorchPlatform.get_rank() in split_rank:
                split_group = dist_group

        return split_group
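
    # Selection sketch (illustrative): on an 8-rank job,
    #   split_ranks = [[0, 1, 2, 3], [4, 5, 6, 7]]
    # creates (or reuses) both sub-groups on every rank, and rank 5 gets the
    # ProcessGroup for [4, 5, 6, 7] back; a rank outside every list gets None.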

    @staticmethod
    def get_group_local_rank(group: Optional[ProcessGroup] = None) -> int:
        """Get the group-local rank id."""
        group = group or _get_default_group()
        return group.rank()

    @staticmethod
    def no_grad():
        return torch.no_grad()

    @staticmethod
    def cat(tensors, dim=0):
        return torch.cat(tensors, dim=dim)

    @staticmethod
    def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
        return torch.empty_like(tensor, dtype=dtype, device=device, pin_memory=pin_memory)

    def get_current_stream(self):
        device = self.get_device_handle()
        return device.current_stream()

    def new_event(self):
        device = self.get_device_handle()
        return device.Event()

    def tree_map(self, fn, tree):
        return torch.utils._pytree.tree_map(fn, tree)  # pylint: disable=protected-access

    @property
    def checkpoint(self):
        return torch.utils.checkpoint.checkpoint

    @staticmethod
    def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import FuncModule
        if callable(module) and not isinstance(module, torch.nn.Module):
            module = FuncModule(module)
        return checkpoint_wrapper(module, checkpoint_fn=checkpoint_fn, **checkpoint_fn_kwargs)

    @staticmethod
    def swap_wrapper(module, policy_fn=None):
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import swap_wrapper
        return swap_wrapper(module, policy_fn=policy_fn)

    @property
    def noop_context_fn(self):
        return noop_context_fn

    @staticmethod
    def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.sac import create_selective_checkpoint_contexts
        return create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation)

    @staticmethod
    def async_save_on_cpu(policy_fn=None):
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import AsyncSaveOnCpu
        return AsyncSaveOnCpu(policy_fn)

    @staticmethod
    def get_element_size(tensor):
        """Get the tensor element size."""
        return tensor.element_size()

    @staticmethod
    def tensor_to_numpy(tensor) -> np.ndarray:
        """Convert a PyTorch tensor to a numpy array."""
        return tensor.cpu().numpy()

    @staticmethod
    def clip_grad_norm_(
            parameters, max_norm, norm_type=2.0,
            error_if_nonfinite=False, foreach=None,
    ):
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.clip_grad import (
            clip_grad_norm_ as _clip_grad_norm,
        )
        return _clip_grad_norm(
            parameters, max_norm, norm_type,
            error_if_nonfinite=error_if_nonfinite, foreach=foreach,
        )

    @staticmethod
    def profiler_record(name):
        """Profiler context manager for recording operations using torch.profiler."""
        return torch.profiler.record_function(name)

    def cast_fp_tensor(self, dtype, x):
        """
        Cast a floating-point tensor to the target dtype if applicable.
        """
        if (
                not isinstance(x, torch.Tensor)
                or not torch.is_floating_point(x)
                or x.dtype == dtype
        ):
            return x
        return x.to(dtype)

    def apply_to_tensors(self, fn, container):
        """Recursively apply ``fn`` to all tensors in various container types."""

        def apply(x):
            if isinstance(x, torch.Tensor):
                return fn(x)
            if hasattr(x, "__dataclass_fields__"):
                dc = dataclasses.replace(x)
                changes = {
                    f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc)
                }
                return dataclasses.replace(dc, **changes)
            if isinstance(x, OrderedDict):
                od = x.__class__()
                for key, value in x.items():
                    od[key] = apply(value)
                return od
            if isinstance(x, PackedSequence):
                apply(x.data)
                return x
            if isinstance(x, dict):
                return {key: apply(value) for key, value in x.items()}
            if isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields"):
                res = (apply(el) for el in x)
                return type(x)(*res)
            if isinstance(x, (list, tuple, set)):
                return type(x)(apply(el) for el in x)
            return x

        return apply(container)
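
    # Container walk sketch (hypothetical data): moving every tensor in a
    # nested batch to CPU without touching non-tensor leaves:
    #   batch = {"ids": torch.ones(2), "meta": [torch.zeros(1), "tag"]}
    #   cpu_batch = platform.apply_to_tensors(lambda t: t.cpu(), batch)
    #   # -> same structure; both tensors are now on CPU, "tag" is untouched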

    @property
    def meta_device(self):
        return torch.device("meta")

    def init_on_device(self, device, include_buffers=False):
        return _init_on_device(device, include_buffers=include_buffers)

    def str_to_dtype(self, dtype_str: str) -> torch.dtype:
        """Map ``torch.<type>`` strings from checkpoint metadata to ``torch.dtype``."""
        parts = dtype_str.split(".", 1)
        if len(parts) != 2:
            raise ValueError(
                f"Expected dtype string like 'torch.float32', got {dtype_str!r}."
            )
        prefix, name = parts
        if prefix != "torch":
            raise ValueError(
                f"Expected PyTorch dtype string with prefix 'torch', got {dtype_str!r}."
            )
        # Use a default so unknown names raise the ValueError below instead of
        # an AttributeError from getattr.
        dtype = getattr(torch, name, None)
        if isinstance(dtype, torch.dtype):
            return dtype
        raise ValueError(f"{dtype_str!r} does not resolve to a torch.dtype.")
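
    # Mapping sketch (illustrative):
    #   platform.str_to_dtype("torch.bfloat16") -> torch.bfloat16
    #   platform.str_to_dtype("numpy.float32")  -> ValueError (wrong prefix)
    #   platform.str_to_dtype("torch.Tensor")   -> ValueError (not a dtype)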

    def list_to_size(self, size_list: list[int]) -> torch.Size:
        return torch.Size(size_list)