Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/platform.py: 55%
602 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Torch platform api"""
16from datetime import timedelta
17from typing import Optional, Any, Union
18import dataclasses
19from collections import OrderedDict
21import numpy as np
22from safetensors.torch import save_file, load_file
23import torch
24from torch import nn
25from torch import Tensor
26from torch._C._distributed_c10d import Store, ProcessGroup
27from torch.distributed import Backend
28from torch.distributed.distributed_c10d import _get_default_group
29from torch.nn import Parameter, Module
30from torch.nn.utils.rnn import PackedSequence
31from torch._ops import OpOverload, OpOverloadPacket
32from torch.utils.checkpoint import noop_context_fn
33from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
34import torch.distributed.nn.functional as dist_func
35import torch.distributed as dist
36from hyper_parallel.platform.torch.dtensor import DTensorBase
37from hyper_parallel.platform.torch.pipeline_parallel.stage import PipelineStageBase
38from hyper_parallel.platform.torch.group_utils import create_sub_groups
39from hyper_parallel.platform.platform import Platform, PlatformType, EXISTING_COMM_GROUPS
40from hyper_parallel.platform.torch.function_override import override_functions
41from hyper_parallel.platform.torch.init_weights import init_on_device as _init_on_device
# Install the overrides from ``hyper_parallel.platform.torch.function_override``
# once, at import time, before any TorchPlatform API in this module is used.
# See that module for exactly which functions are patched.
override_functions()


# ---------------------------------------------------------------------------
# Module-level A2A reshape helpers
# ---------------------------------------------------------------------------
def _a2a_reconstruct(out_perm: torch.Tensor, concat_dim: int) -> torch.Tensor:
    """Fold the leading world-size axis of an all-to-all buffer back into ``concat_dim``.

    ``out_perm`` is laid out as ``[ws, *rest_dims]``; the received chunk that
    belongs to ``concat_dim`` sits at index ``concat_dim + 1``. The leading
    ``ws`` axis is moved to sit directly in front of that chunk axis, and the
    two are merged. The result therefore has one dimension fewer than
    ``out_perm``, and its ``concat_dim`` has size ``ws * chunk``.
    """
    ndim = out_perm.dim()
    # Same axis order as before, minus the leading ws axis, which is
    # re-inserted at position concat_dim.
    order = [*range(1, concat_dim + 1), 0, *range(concat_dim + 1, ndim)]
    moved = out_perm.permute(order).contiguous()
    sizes = list(moved.shape)
    # Collapse the (ws, chunk) pair into a single axis.
    sizes[concat_dim:concat_dim + 2] = [sizes[concat_dim] * sizes[concat_dim + 1]]
    return moved.reshape(sizes)
class _TorchAsyncA2AFunction(torch.autograd.Function):
    """Differentiable wrapper for pre-launched async all-to-all.

    Forward: wait async handle, reconstruct A2A result.
    Backward: launch async head→seq A2A and store handle in ``handle_box``
    for the projection pre-hook to wait, achieving GEMM–A2A overlap.

    The gradient returned for ``x`` is always a zero tensor of ``x``'s shape.
    When ``handle_box`` is provided, backward additionally launches the
    reverse all-to-all asynchronously and appends its ``(work, out_perm)``
    pair to ``handle_box``; waiting on that pair happens outside this class.
    """

    @staticmethod
    def forward(ctx, x, work, out_perm, group, world_size, concat_dim, split_dim,  # pylint: disable=arguments-differ
                handle_box):
        """Wait for pre-launched async A2A and return reconstructed output.

        Args:
            ctx: Autograd context; the communication parameters are stashed on
                it for backward.
            x: Input tensor. Only its shape is used here, recorded so that
                backward can return a matching zero gradient.
            work: Async work handle of the already-launched all-to-all; it is
                waited on before reconstruction.
            out_perm: Receive buffer of that all-to-all, laid out as
                ``_a2a_reconstruct`` expects (world-size axis first).
            group: Process group, reused for the backward all-to-all.
            world_size: Number of ranks; used in backward to split ``concat_dim``.
            concat_dim: Dimension that the received chunks are merged into.
            split_dim: Saved on ``ctx`` but not read anywhere in this class.
            handle_box: Optional mutable list. Backward appends
                ``(work, out_perm)`` to it; ``None`` disables the backward A2A.

        Returns:
            Tensor: ``_a2a_reconstruct(out_perm, concat_dim)``, computed after
            ``work`` has completed.
        """
        ctx.group = group
        ctx.world_size = world_size
        ctx.concat_dim = concat_dim
        ctx.split_dim = split_dim
        ctx.handle_box = handle_box
        ctx.x_shape = x.shape
        # Block until the pre-launched exchange has filled out_perm.
        work.wait()
        return _a2a_reconstruct(out_perm, concat_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Launch async head→seq A2A for backward overlap, or return zero grad.

        When ``ctx.handle_box`` is set, ``grad_output`` is split along
        ``concat_dim`` into ``world_size`` equal chunks. The chunk axis is
        moved to the front, and an async ``all_to_all_single`` is started on
        the result. The ``(work, out_perm)`` pair is appended to
        ``handle_box`` without waiting.

        Returns:
            tuple: A zero tensor shaped like the forward ``x``, followed by
            ``None`` for each of the seven other forward inputs.
        """
        if ctx.handle_box is not None:
            # Launch async head→seq A2A (reverse of forward seq→head)
            g = grad_output.contiguous()
            shape = list(g.shape)
            seq_dim = ctx.concat_dim
            s_full = shape[seq_dim]
            # Rank after seq_dim is split into [world_size, s_full // world_size].
            ndim = len(shape) + 1
            # The reshape only succeeds when s_full is divisible by world_size.
            x_perm = g.reshape(
                shape[:seq_dim] + [ctx.world_size, s_full // ctx.world_size] + shape[seq_dim + 1:]
            ).permute(
                # Move the world_size axis (now at position seq_dim) to the front.
                [seq_dim] + list(range(seq_dim)) + list(range(seq_dim + 1, ndim))
            ).contiguous()
            out_perm = torch.empty_like(x_perm)
            work = dist.all_to_all_single(out_perm, x_perm, group=ctx.group, async_op=True)
            # Hand the in-flight work and its receive buffer to the waiter.
            ctx.handle_box.append((work, out_perm))
        return grad_output.new_zeros(ctx.x_shape), None, None, None, None, None, None, None
class _AsyncA2ALazyBwd(torch.autograd.Function):
    """All-to-all whose forward AND backward return ``AsyncCollectiveTensor``.

    PyTorch's stock ``all_to_all_single_autograd`` calls ``wait_tensor`` in
    its backward eagerly, and the autograd engine binds backward stream
    context to the forward stream — so even if the BWD thread is wrapped
    in a side-stream context, that wait still lands on the FWD main
    stream and blocks Attention launches.

    This Function bypasses the engine's binding by calling the
    non-autograd functional op in both directions and returning ACT.
    The wait is deferred to the next consumer's first non-view access
    (e.g. the indexing backward of ``_unpermute``), giving the FWD
    thread a small Python window to enqueue its Attention kernels onto
    the main stream **before** the wait lands there.
    """

    @staticmethod
    def forward(ctx, input_tensor, output_splits, input_splits, group):  # pylint: disable=arguments-differ
        """Run the variable-split all-to-all through the functional op.

        Args:
            ctx: Autograd context; the split lists and group are saved on it
                for backward.
            input_tensor: Tensor to exchange, split per ``input_splits``.
            output_splits: Rows received from each rank.
            input_splits: Rows sent to each rank.
            group: Process group.

        Returns:
            The result of the functional ``all_to_all_single``. Per the class
            docstring this is an ``AsyncCollectiveTensor``, waited on lazily.
        """
        ctx.input_splits = input_splits
        ctx.output_splits = output_splits
        ctx.group = group
        # pylint: disable=C0415
        from torch.distributed._functional_collectives import all_to_all_single
        return all_to_all_single(
            input_tensor, output_splits, input_splits, group,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Send gradients back with the reverse all-to-all.

        The two split lists are swapped relative to forward, so each rank
        returns gradient rows to the rank that originally sent them. The same
        functional op is used, so the gradient is also waited on lazily.

        Returns:
            tuple: ``(grad_input, None, None, None)``. Only ``input_tensor``
            receives a gradient.
        """
        # pylint: disable=C0415
        from torch.distributed._functional_collectives import all_to_all_single
        grad_input = all_to_all_single(
            grad_output, ctx.input_splits, ctx.output_splits, ctx.group,
        )
        return grad_input, None, None, None
class _TorchSyncHookFunction(torch.autograd.Function):
    """Autograd identity that fires HookCoordinator rendezvous on fwd/bwd.

    Uses a **4-hook** design (``A``, ``B``, ``C``, ``D``) with pure
    COMM / COMPUTE roles — no NONE role. Every rendezvous is a strict
    COMM + COMPUTE pair, guaranteeing NCCL-first dispatch ordering at
    **all** points including layer boundaries.

    Hook placement per MoE layer::

        [A] → dispatch → [B] → module → [C] → combine → [D] → (Attention) → [A_next]

    At layer boundaries (D / A hooks), the Attention that runs between
    layers is treated as COMPUTE, and the combine / combine.bwd is treated
    as COMM, so the coordinator enforces comm-first ordering even across
    layer transitions.

    ``"D_LAST"`` is a sentinel hook name for the last layer's closing D.
    Both forward and backward skip the rendezvous for it.
    """

    # Hook name -> (role of the op just issued, role of the op about to be
    # issued). Index encoding: 1 = COMM, 2 = COMPUTE.
    _FWD_ROLES = {
        "A": (2, 1),  # Attention    -> dispatch
        "B": (1, 2),  # dispatch     -> module
        "C": (2, 1),  # module       -> combine
        "D": (1, 2),  # combine      -> Attention
    }
    _BWD_ROLES = {
        "D": (2, 1),  # Attn.bwd     -> combine.bwd
        "C": (1, 2),  # combine.bwd  -> module.bwd
        "B": (2, 1),  # module.bwd   -> dispatch.bwd
        "A": (1, 2),  # dispatch.bwd -> Attn.bwd
    }

    # Lazily populated with (None, HookRole.COMM, HookRole.COMPUTE) so that a
    # role index can be used directly as a tuple position.
    _ROLE_CACHE = None

    @staticmethod
    def _role_enum(idx: int):
        """Translate a role index (1 = COMM, 2 = COMPUTE) into a ``HookRole`` member."""
        roles = _TorchSyncHookFunction._ROLE_CACHE
        if roles is None:
            from hyper_parallel.core.pipeline_parallel.hook_coordinator import HookRole  # pylint: disable=C0415
            roles = (None, HookRole.COMM, HookRole.COMPUTE)
            _TorchSyncHookFunction._ROLE_CACHE = roles
        return roles[idx]

    @staticmethod
    def _sync(role_table, hook_name, coordinator):
        """Perform one notify + rendezvous step for ``hook_name`` using ``role_table``.

        Nothing happens when the coordinator is disabled or ``hook_name`` is
        ``"D_LAST"``. Otherwise the previous op's role is reported via
        ``notify_dispatched``, then ``rendezvous`` is called for the next
        op's role.
        """
        if not coordinator.is_enabled() or hook_name == "D_LAST":
            return
        prev_idx, next_idx = role_table[hook_name]
        coordinator.notify_dispatched(_TorchSyncHookFunction._role_enum(prev_idx))
        coordinator.rendezvous(_TorchSyncHookFunction._role_enum(next_idx))

    @staticmethod
    def forward(ctx, x, hook_name, coordinator):  # pylint: disable=arguments-differ
        """Identity forward that fires a HookCoordinator rendezvous.

        ``hook_name`` and ``coordinator`` are stored on ``ctx`` for backward.
        The rendezvous follows ``_FWD_ROLES`` and is skipped when the
        coordinator is disabled or for ``"D_LAST"``: the last layer's closing
        D is followed by no Attention, so there is nothing to pair with.

        Args:
            ctx: Autograd context.
            x: Input tensor, returned unchanged.
            hook_name: One of ``"A"``, ``"B"``, ``"C"``, ``"D"``, ``"D_LAST"``.
            coordinator: The :class:`HookCoordinator` driving the rendezvous.

        Returns:
            ``x`` unchanged.
        """
        ctx.hook_name = hook_name
        ctx.coordinator = coordinator
        _TorchSyncHookFunction._sync(_TorchSyncHookFunction._FWD_ROLES, hook_name, coordinator)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Identity backward mirroring :meth:`forward` with ``_BWD_ROLES``.

        ``"D_LAST"`` is the first backward hook to fire, and ``combine.bwd``
        has already been dispatched freely by then, so its rendezvous is
        skipped as well.

        Returns:
            tuple: ``(grad_output, None, None)``. Only the tensor input
            receives a gradient.
        """
        _TorchSyncHookFunction._sync(_TorchSyncHookFunction._BWD_ROLES, ctx.hook_name, ctx.coordinator)
        return grad_output, None, None
class _TorchP2PExchangeFunction(torch.autograd.Function):
    """Symmetric bidirectional P2P: send local tensor to peer, receive peer's tensor.

    The same exchange runs in both directions: forward swaps activations with
    the peer, and backward swaps the incoming gradients back.
    """

    @staticmethod
    def _swap_with_peer(tensor: torch.Tensor, peer_rank: int, group) -> torch.Tensor:
        """Send ``tensor`` to ``peer_rank`` and return what the peer sent (blocking)."""
        outgoing = tensor.contiguous()
        incoming = torch.empty_like(outgoing)
        ops = [
            dist.P2POp(dist.isend, outgoing, peer_rank, group),
            dist.P2POp(dist.irecv, incoming, peer_rank, group),
        ]
        for request in dist.batch_isend_irecv(ops):
            request.wait()
        return incoming

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, peer_rank: int, group) -> torch.Tensor:  # pylint: disable=arguments-differ
        """Perform symmetric bidirectional P2P exchange with peer_rank."""
        ctx.peer_rank = peer_rank
        ctx.group = group
        return _TorchP2PExchangeFunction._swap_with_peer(tensor, peer_rank, group)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Perform symmetric P2P exchange for the backward gradient pass."""
        peer_grad = _TorchP2PExchangeFunction._swap_with_peer(grad_output, ctx.peer_rank, ctx.group)
        return peer_grad, None, None
# Mapping from string op names to torch.distributed.ReduceOp.
# The lookup sites in this module use ``_OP_MAP.get(op, dist.ReduceOp.SUM)``,
# so any name not listed here silently reduces with SUM.
_OP_MAP = {
    'sum': dist.ReduceOp.SUM,
    'prod': dist.ReduceOp.PRODUCT,
    'max': dist.ReduceOp.MAX,
    'min': dist.ReduceOp.MIN,
    # convert tensor elements to int32 and use MIN
    # ('all' is expressed as MIN over the elements. NOTE(review): the int32
    # conversion is expected to happen at the call site; none is done here.)
    'all': dist.ReduceOp.MIN,
    # 'avg' is typically handled by SUM followed by division in current implementation logic
    # (the caller must divide the summed result by the group size itself).
    'avg': dist.ReduceOp.SUM,
}

# Use the native AVG op for 'mean' when this torch build exposes it.
if hasattr(dist.ReduceOp, "AVG"):
    _OP_MAP['mean'] = dist.ReduceOp.AVG
else:
    # Fallback for older torch versions if necessary, though this might require manual division upstream
    # Assuming standard behavior where 'mean' implies native AVG support or upstream handling
    # (without AVG, 'mean' yields a plain sum unless the caller divides.)
    _OP_MAP['mean'] = dist.ReduceOp.SUM
322# pylint: disable=C0103
323class TorchPlatform(Platform):
324 """Torch platform api"""
325 Tensor = Tensor
326 tensor = torch.tensor
327 Parameter = Parameter
328 Module = Module
329 DTensorBase = DTensorBase
330 PipelineStageBase = PipelineStageBase
331 platform_type = PlatformType.PYTORCH
332 tensor_dtype = torch
333 dtype = torch.dtype
334 Function = torch.autograd.Function
336 @staticmethod
337 def is_linear_module(module) -> bool:
338 """Check whether *module* is a ``torch.nn.Linear`` instance."""
339 return isinstance(module, nn.Linear)
341 @staticmethod
342 def is_embedding_module(module) -> bool:
343 """Check whether *module* is a ``torch.nn.Embedding`` instance."""
344 return isinstance(module, nn.Embedding)
346 @staticmethod
347 def device_count(device_handle):
348 """
349 Get the number of available devices.
351 Args:
352 device_handle: The device handle (e.g., torch.cuda, torch.npu).
354 Returns:
355 int: The number of available devices.
356 """
357 return device_handle.device_count()
359 def device_type(self):
360 """
361 Get the current device type.
363 Returns:
364 str: The device type string ("npu" for NPU, "cuda" for GPU).
365 """
366 device_handle = self.get_device_handle()
367 if device_handle == torch.npu:
368 return "npu"
369 return "cuda"
371 def device(self, device_idx=None):
372 """
373 Get a torch.device object for the specified device index.
375 Args:
376 device_idx (Optional[int]): The device index. If None, returns device without index.
378 Returns:
379 torch.device: A torch device object.
380 """
381 device_type = self.device_type()
382 if device_idx is None:
383 return torch.device(device_type)
384 return torch.device(f"{device_type}:{device_idx:d}")
386 @staticmethod
387 def get_rng_state(device=None, device_handle=None):
388 """
389 Get the random number generator state.
391 Args:
392 device (Optional): The device to get RNG state from.
393 device_handle (Optional): The device handle (torch.cuda, torch.npu, etc.).
395 Returns:
396 Tensor: The RNG state as a byte tensor.
397 """
398 if device_handle is None:
399 return torch.get_rng_state()
400 if device is None:
401 return device_handle.get_rng_state()
402 return device_handle.get_rng_state(device)
404 @staticmethod
405 def set_rng_state(state, device=None, device_handle=None):
406 """
407 Set the random number generator state.
409 Args:
410 state (Tensor): The RNG state to set.
411 device (Optional): The device to set RNG state for.
412 device_handle (Optional): The device handle (torch.cuda, torch.npu, etc.).
413 """
414 if device_handle is None:
415 return torch.set_rng_state(state)
416 if device is None:
417 return device_handle.set_rng_state(state)
418 return device_handle.set_rng_state(state, device)
420 @staticmethod
421 def manual_seed(seed):
422 """
423 Set the random seed for reproducibility.
425 Args:
426 seed (int): The random seed value.
428 Returns:
429 torch.Generator: The random number generator.
430 """
431 return torch.manual_seed(seed)
433 @staticmethod
434 def ones(size, dtype=None):
435 """
436 Create a tensor filled with ones.
438 Args:
439 size (tuple): The shape of the output tensor.
440 dtype (Optional[torch.dtype]): The desired data type.
442 Returns:
443 Tensor: A tensor filled with ones.
444 """
445 return torch.ones(size, dtype=dtype)
447 @staticmethod
448 def zeros(size, dtype=None, device=None):
449 """
450 Create a tensor filled with zeros.
452 Args:
453 size (tuple): The shape of the output tensor.
454 dtype (Optional[torch.dtype]): The desired data type.
455 device (Optional[torch.device]): The device to create the tensor on.
457 Returns:
458 Tensor: A tensor filled with zeros.
459 """
460 return torch.zeros(size, dtype=dtype, device=device)
462 @staticmethod
463 def full(size, fill_value, dtype=None):
464 """
465 Create a tensor filled with a scalar value.
467 Args:
468 size (tuple): The shape of the output tensor.
469 fill_value (scalar): The value to fill the tensor with.
470 dtype (Optional[torch.dtype]): The desired data type.
472 Returns:
473 Tensor: A tensor filled with the specified value.
474 """
475 return torch.full(size, fill_value, dtype=dtype)
477 @staticmethod
478 def empty(size, dtype=None):
479 """
480 Create an uninitialized tensor.
482 Args:
483 size (tuple): The shape of the output tensor.
484 dtype (Optional[torch.dtype]): The desired data type.
486 Returns:
487 Tensor: An uninitialized tensor.
488 """
489 return torch.empty(size, dtype=dtype)
491 @staticmethod
492 def get_rank():
493 """
494 Get the rank of the current process in the distributed group.
496 Returns:
497 int: The rank of the current process.
498 """
499 return dist.get_rank()
501 @staticmethod
502 def get_global_rank(group, group_rank):
503 """
504 Get the global rank from a group rank.
506 Args:
507 group (ProcessGroup): The process group.
508 group_rank (int): The rank within the group.
510 Returns:
511 int: The global rank.
512 """
513 return dist.get_global_rank(group, group_rank)
515 @staticmethod
516 def get_world_size():
517 """
518 Get the total number of processes in the distributed group.
520 Returns:
521 int: The world size.
522 """
523 return dist.get_world_size()
525 @staticmethod
526 def get_param_local_shape(param):
527 """
528 Get the local shape of a parameter, handling both regular and distributed tensors.
530 Args:
531 param (Union[Tensor, DTensorBase]): The parameter tensor.
533 Returns:
534 torch.Size: The local shape of the parameter.
535 """
536 if isinstance(param, DTensorBase):
537 return param.local_shape
538 return param.shape
540 @staticmethod
541 def get_param_local_data(param):
542 """
543 Get the local data of a parameter, handling both regular and distributed tensors.
545 Args:
546 param (Union[Tensor, DTensorBase]): The parameter tensor.
548 Returns:
549 Tensor: The local tensor data.
550 """
551 if isinstance(param, DTensorBase):
552 return param.to_local()
553 return param
555 @staticmethod
556 def update_param_data(param, data):
557 """
558 Update the data of a parameter.
560 Args:
561 param (Parameter): The parameter to update.
562 data (Tensor): The new data tensor.
563 """
564 param.data = data
566 @staticmethod
567 def load_into_param(param, data):
568 """Load tensor *data* into *param* (plain tensor or DTensor)."""
569 if isinstance(param, DTensorBase):
570 local = param._local_tensor # pylint: disable=W0212
571 if local.is_meta:
572 # Meta tensor materialisation: replace the placeholder.
573 orig_requires_grad = param.requires_grad
574 param._local_tensor = data # pylint: disable=W0212
575 if data.requires_grad != orig_requires_grad:
576 param.requires_grad_(orig_requires_grad)
577 else:
578 local.copy_(data)
579 else:
580 param.copy_(data)
582 @staticmethod
583 def get_op_name(func):
584 """
585 Extract the operation name from various function types.
587 Args:
588 func: The function or operation to extract the name from.
590 Returns:
591 str: The operation name.
592 """
593 if hasattr(func, "__name__"):
594 return func.__name__
595 if isinstance(func, OpOverload):
596 full_name = func.name
597 core_name = full_name.split("::")[-1].split(".")[0]
598 return core_name
599 if isinstance(func, OpOverloadPacket):
600 return func.name.split("::")[-1]
601 func_str = str(func)
602 if "built-in function" in func_str:
603 return func_str.split()[-1].strip(">")
604 if "function" in func_str:
605 return func_str.split()[1]
606 return "unknown_op"
608 @staticmethod
609 def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
610 output = dist_func.all_gather(data, group=group)
611 return torch.cat(output, dim=concat_dim)
613 @staticmethod
614 def chunk(data, split_dim, split_size, index):
615 return torch.chunk(data, split_size, dim=split_dim)[index]
617 @staticmethod
618 def differentiable_all_to_all(input_data, output_shape, group):
619 output_tensor = torch.empty(output_shape, device=input_data.device, dtype=input_data.dtype)
620 output_tensor = dist_func.all_to_all_single(
621 output_tensor,
622 input_data,
623 group=group
624 )
625 return output_tensor
627 @staticmethod
628 def tensor_type_cast(input_data, cast_type):
629 """Cast tensor to specified data type."""
630 type_mapping = {
631 'float32': torch.float32,
632 'float16': torch.float16,
633 'int64': torch.int64,
634 'int32': torch.int32
635 }
636 if cast_type not in type_mapping:
637 raise ValueError(f"Unknown cast type: {cast_type}. Supported types: {list(type_mapping.keys())}")
638 return input_data.to(type_mapping[cast_type])
640 @staticmethod
641 def differentiable_all_reduce(data, op, group):
642 # Resolve the op from string to ReduceOp enum if necessary
643 reduce_op = _OP_MAP.get(op, dist.ReduceOp.SUM) if isinstance(op, str) else op
644 return dist_func.all_reduce(data, op=reduce_op, group=group)
646 @staticmethod
647 def get_cell_construct(cell):
648 return cell.forward
650 @staticmethod
651 def get_cells_and_names(cell):
652 return cell.named_modules()
654 @staticmethod
655 def search_parameter_by_name(cell, param_name: str):
656 """
657 Find the parent Module of the parameter, the parameter's name in the parent Module, and the parameter.
658 Return value: (parent Module instance, parameter's name in parent Module, parameter object).
659 Returns None if not found.
660 """
661 # Remove the "self." prefix from param_name
662 param_name = param_name.replace("self.", "")
663 # Case 1: The parameter is a direct parameter of the current Module
664 if param_name in cell._parameters: # pylint: disable=protected-access
665 return (cell, param_name, cell._parameters[param_name]) # pylint: disable=protected-access
667 # Case 2: The parameter is in a sub-Module
668 if "." in param_name:
669 cell_path, param_key = param_name.rsplit(".", 1)
670 try:
671 # Locate the sub-Module where the parameter resides (supports multi-level paths)
672 target_cell = cell.get_submodule(cell_path)
673 # Check if the sub-Module directly contains this parameter
674 if param_key in target_cell._parameters: # pylint: disable=protected-access
675 return target_cell, param_key, target_cell._parameters[param_key] # pylint: disable=protected-access
676 except AttributeError:
677 pass
679 # Traverse all sub-Modules (recursively) to search for the parameter
680 for _, child_cell in cell.named_children():
681 if isinstance(child_cell, Module):
682 result = TorchPlatform.search_parameter_by_name(child_cell, param_name)
683 if result is not None:
684 return result
686 return None
688 @staticmethod
689 def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
690 """
691 Modify the original parameter in a Module or sub-Module using the search result
692 """
693 parent_cell, param_key, _ = result
694 # Key operation: directly modify the _parameters dictionary.
695 if param_key in parent_cell._parameters: # pylint: disable=protected-access
696 parent_cell._parameters[param_key] = new_param # pylint: disable=protected-access
697 else:
698 parent_cell.register_parameter(param_key, new_param)
699 return True
701 @staticmethod
702 def set_layout_into_parameter(param, layout):
703 """Set layout into parameter"""
704 from hyper_parallel.core.dtensor.dtensor import DTensor # pylint: disable=import-outside-toplevel
705 from hyper_parallel.core.dtensor.layout import _get_slice_tensor_by_layout # pylint: disable=import-outside-toplevel
706 if isinstance(param, DTensor):
707 raise ValueError(f"Parameter {param} has been configured layout, cannot be set repeatedly.")
708 requires_grad = param.requires_grad
709 param_dtensor = DTensor.from_local(
710 _get_slice_tensor_by_layout(param, layout),
711 layout.mesh, layout.alias_placements)
712 new_param = Parameter(param_dtensor, requires_grad=requires_grad)
713 return new_param
715 @staticmethod
716 def differentiable_reduce_scatter(data, dev_num, axis, op, group):
717 input_tuple = torch.chunk(data, dev_num, dim=axis)
718 output_tensor = torch.empty(input_tuple[0].shape, device=data.device, dtype=data.dtype)
720 # Resolve the op from string to ReduceOp enum
721 reduce_op = _OP_MAP.get(op, dist.ReduceOp.SUM) if isinstance(op, str) else op
723 output_tensor = dist_func.reduce_scatter(output_tensor, input_tuple, op=reduce_op, group=group)
725 # Keep manual handling for 'avg' string as it maps to SUM in _OP_MAP
726 if op == 'avg':
727 output_tensor = output_tensor / dev_num
728 return output_tensor
730 @staticmethod
731 def get_device_handle(device_type: str = "npu"):
732 try:
733 handle = getattr(torch, device_type)
734 except AttributeError as e:
735 raise RuntimeError(f"TorchPlatform expect got device handle: 'torch.{device_type}' failed.") from e
736 return handle
738 @staticmethod
739 def get_param_type_size(param):
740 # pylint: disable=W0212
741 return torch._utils._element_size(param.dtype)
743 @staticmethod
744 def is_tensor(obj: Any) -> bool:
745 """Return True if ``obj`` is a ``torch.Tensor``."""
746 return isinstance(obj, Tensor)
748 @staticmethod
749 def get_tensor_storage_size(tensor: Any) -> int:
750 """Return serialized byte size (numel * element size) for a PyTorch tensor."""
751 if not TorchPlatform.is_tensor(tensor):
752 raise TypeError(
753 f"TorchPlatform.get_tensor_storage_size expects torch.Tensor, got {type(tensor)!r}"
754 )
755 return int(tensor.numel()) * int(tensor.element_size())
757 @staticmethod
758 def parameters_dict(cell: Module):
759 return cell.named_parameters()
761 @staticmethod
762 def get_model_state_dict(model, *, options=None):
763 # pylint: disable=C0415
764 from hyper_parallel.platform.torch.fully_shard.state_dict_utils import (
765 get_model_state_dict as _get_model_state_dict,
766 )
767 return _get_model_state_dict(model, options=options)
769 @staticmethod
770 def save_checkpoint(cell: Module, file_path: str, ckpt_format: str = "safetensors") -> None:
771 if ckpt_format == "safetensors":
772 save_file(tensors=cell, filename=file_path)
773 else:
774 torch.save(obj=cell, f=file_path)
776 @staticmethod
777 def load_checkpoint(file_path: str, ckpt_format: str = "safetensors") -> dict:
778 if ckpt_format == "safetensors":
779 return load_file(filename=file_path)
780 return torch.load(f=file_path)
782 @staticmethod
783 def new_zero_parameter(param_shape, param_type, requires_grad, device):
784 return nn.Parameter(torch.zeros(param_shape, dtype=param_type, device=device), requires_grad=requires_grad)
786 @staticmethod
787 def new_tensor(tensor_shape, tensor_type, device):
788 return torch.empty(size=tensor_shape, dtype=tensor_type, device=device)
790 @staticmethod
791 def full_like(tensor, fill_value, dtype=None):
792 return torch.full_like(tensor, fill_value, dtype=dtype)
794 @staticmethod
795 def set_tensor_requires_grad(input_tensor):
796 """
797 set requires grad flag for input tensor, only effective for leaf node
798 """
799 if input_tensor.is_leaf:
800 input_tensor.requires_grad = True
802 def _create_group(self, rank_list):
803 group_dict = create_sub_groups(rank_list)
804 return group_dict[tuple(rank_list)]
806 @staticmethod
807 def all_gather_into_tensor(data, group_info, async_op=False):
808 output_shape = list(data.shape)
809 output_shape[0] = output_shape[0] * group_info.rank_size
810 output = torch.empty(output_shape, dtype=data.dtype, device=data.device)
811 handle = dist.all_gather_into_tensor(output, data, group=group_info.group, async_op=async_op)
812 return output, handle
814 @staticmethod
815 def all_reduce(data, group_info, async_op=False):
816 if not data.is_contiguous():
817 data = data.contiguous()
818 handle = dist.all_reduce(data, group=group_info.group, async_op=async_op)
819 return data, handle
821 @staticmethod
822 def broadcast(data, src, group=None, async_op=False):
823 handle = dist.broadcast(data, src, group, async_op)
824 if async_op:
825 handle.wait()
827 @staticmethod
828 def isend(tensor, dst=None, group=None, tag=0):
829 return dist.isend(tensor, dst, group, tag)
831 @staticmethod
832 def irecv(tensor, src=None, group=None, tag=0):
833 return dist.irecv(tensor, src, group, tag)
835 @staticmethod
836 def p2p_exchange(tensor, peer_rank: int, group=None):
837 if peer_rank == dist.get_rank(group):
838 return tensor
839 return _TorchP2PExchangeFunction.apply(tensor, peer_rank, group)
841 @staticmethod
842 def send_object_list(obj_list, dst=None, group=None):
843 dist.send_object_list(obj_list, dst, group)
845 @staticmethod
846 def recv_object_list(obj_list, src=None, group=None):
847 dist.recv_object_list(obj_list, src, group)
849 @staticmethod
850 def reduce_scatter_tensor(data, group_info, async_op=False):
851 output_shape = list(data.shape)
852 output_shape[0] = output_shape[0] // group_info.rank_size
853 output = torch.empty(output_shape, dtype=data.dtype, device=data.device)
854 handle = dist.reduce_scatter_tensor(output, data, group=group_info.group, async_op=async_op)
855 return output, handle
857 @staticmethod
858 def all_to_all_single(input_tensor, output_shape, group, async_op=False):
859 output = torch.empty(output_shape, device=input_tensor.device, dtype=input_tensor.dtype)
860 work = dist.all_to_all_single(output, input_tensor, group=group, async_op=async_op)
861 return output, work
863 @staticmethod
864 def differentiable_all_to_all_single(input_tensor, input_splits, output_splits, group):
865 """Variable-split all-to-all with autograd support for EP token dispatch/combine."""
866 out_total = sum(output_splits)
867 output = torch.empty(
868 out_total, *input_tensor.shape[1:],
869 dtype=input_tensor.dtype, device=input_tensor.device,
870 )
871 output = dist_func.all_to_all_single(
872 output, input_tensor,
873 output_split_sizes=output_splits,
874 input_split_sizes=input_splits,
875 group=group,
876 )
877 return output
@staticmethod
def differentiable_all_to_all_single_async(input_tensor, input_splits, output_splits, group):
    """Lazily-waited variant of :meth:`differentiable_all_to_all_single`.

    The forward result and the backward gradient are both returned as
    :class:`AsyncCollectiveTensor`. The ``wait_tensor`` op is therefore only
    queued when a downstream kernel first reads the value.

    Rationale for deferring the wait in each direction:

    * Forward: the host returns immediately. The paired backward thread's
      compute kernel can then enter the queue ahead of the wait.
    * Backward: PyTorch's stock backward issues ``wait_tensor`` eagerly.
      The autograd engine also binds the backward stream to the forward
      stream, so wrapping backward in ``with torch.npu.stream(side_stream)``
      does not move the wait off the main stream. Returning an ACT from
      backward postpones the wait until the next backward op consumes it.
      That leaves a short window in which forward Attention kernels can be
      queued on the main stream **before** the wait lands.

    Args:
        input_tensor: Tensor split along dim 0 according to ``input_splits``.
        input_splits: ``list[int]`` — rows sent to each rank.
        output_splits: ``list[int]`` — rows received from each rank.
        group: Process group.

    Returns:
        ``AsyncCollectiveTensor`` shaped
        ``[sum(output_splits), *input_tensor.shape[1:]]``.
    """
    # NOTE: output_splits is passed before input_splits here, i.e. the reverse of
    # this method's own signature. Presumably this mirrors _AsyncA2ALazyBwd.forward's
    # parameter order — confirm against its definition before reordering.
    lazy_result = _AsyncA2ALazyBwd.apply(input_tensor, output_splits, input_splits, group)
    return lazy_result
912 @staticmethod
913 def arange(start, end=None, step=1, dtype=None, device=None):
914 """Create a 1-D tensor with evenly spaced values."""
915 if end is None:
916 return torch.arange(start, dtype=dtype, device=device)
917 return torch.arange(start, end, step, dtype=dtype, device=device)
@staticmethod
def differentiable_async_a2a_wait(x, work, out_perm, group, world_size, concat_dim, split_dim,
                                  handle_box=None):
    """Differentiably wait on an async all-to-all and rebuild its result.

    Args:
        x: Input tensor.
        work: Async work handle returned by the all-to-all launch.
        out_perm: Output buffer filled by that all-to-all.
        group: Process group.
        world_size: Number of ranks in ``group``.
        concat_dim: Dimension along which received chunks are concatenated.
        split_dim: Dimension along which the input was split.
        handle_box: Optional mutable list; the backward pass appends its
            ``(work, out_perm)`` pair here.

    Returns:
        Whatever ``_TorchAsyncA2AFunction.apply`` produces for these arguments.
    """
    forwarded = (x, work, out_perm, group, world_size, concat_dim, split_dim, handle_box)
    return _TorchAsyncA2AFunction.apply(*forwarded)
@staticmethod
def differentiable_sync_hook(x, hook_name: str, coordinator):
    """Identity op that triggers a coordinator rendezvous in forward and backward.

    ``_TorchSyncHookFunction.apply`` is invoked unconditionally, so every
    forward records a SyncHook node in the autograd graph even while the
    coordinator is disabled. A shortcut that skipped ``apply`` when disabled
    would cause a deadlock. Graphs forwarded during warmup would carry no
    hook nodes. A later ``overlap.run`` whose backward thread back-props such
    a graph would then hit zero hooks. Meanwhile the paired forward thread,
    whose current forward does record hooks, would wait at a barrier for a
    partner that never arrives.

    Args:
        x: Input tensor.
        hook_name: One of ``"A"``, ``"B"``, ``"C"``, ``"D"``.
        coordinator: A :class:`HookCoordinator` instance.

    Returns:
        Whatever ``_TorchSyncHookFunction.apply`` returns; this is expected to
        be an identity on ``x``.
    """
    hook_args = (x, hook_name, coordinator)
    return _TorchSyncHookFunction.apply(*hook_args)
958 @staticmethod
959 def get_tensor_transform():
960 raise NotImplementedError("Unsupported get_tensor_transform for torch platform")
962 @staticmethod
963 def construct_strided_slice(x, begin, end, stride):
964 raise NotImplementedError("Unsupported construct_strided_slice for torch platform")
@staticmethod
def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
    """Build the pipeline-parallel ``_MicroBatch`` helper.

    Args:
        micro_batch_num: Number of micro-batches.
        args_batch_dim: Optional batch-dim spec for positional arguments.
        kwargs_batch_dim: Optional batch-dim spec for keyword arguments.

    Returns:
        A ``_MicroBatch`` instance built from the three arguments above.
    """
    # Imported at call time rather than at module import.
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.pipeline_parallel._utils import (
        _MicroBatch as micro_batch_cls,
    )
    return micro_batch_cls(micro_batch_num, args_batch_dim, kwargs_batch_dim)
@staticmethod
def get_symmetric_memory_handler():
    """Return a new ``TorchSymmetricMemoryHandler`` (imported lazily)."""
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.symmetric_memory import TorchSymmetricMemoryHandler
    return TorchSymmetricMemoryHandler()
@staticmethod
def get_multicore_handler():
    """Return a new ``TorchMulticoreHandler`` (imported lazily)."""
    # pylint: disable=C0415
    from hyper_parallel.platform.torch import multicore
    handler = multicore.TorchMulticoreHandler()
    return handler
def new_stream(self):
    """Create a fresh stream on this platform's device backend.

    Returns:
        ``self.get_device_handle().Stream()``.
    """
    return self.get_device_handle().Stream()
def get_stream_context(self):
    """Return the device backend's ``stream`` context-manager factory.

    Returns:
        The ``stream`` attribute of ``self.get_device_handle()`` (uncalled).
    """
    return self.get_device_handle().stream
@staticmethod
def all_gather_object(object_list, obj, group=None) -> None:
    """
    Gather one picklable object from every rank of ``group`` into ``object_list``.

    Args:
        object_list (list[Any]): Output list; its length must equal the size of ``group``.
        obj (Any): This rank's contribution.
        group (ProcessGroup, optional): Group to gather over. ``None`` (default) selects
            the global group.

    Returns:
        None. Results are written into ``object_list`` in place.
    """
    gather = dist.all_gather_object
    gather(object_list, obj, group)
@staticmethod
def barrier(group=None, async_op: bool = False, device_ids=None) -> Any:
    """
    Block until every process in ``group`` reaches this point.

    Args:
        group (ProcessGroup, optional): Group to synchronize. ``None`` (default) selects
            the default process group.
        async_op (bool, optional): Return a work handle instead of blocking. Default: ``False``.
        device_ids (list[int], optional): Device ids, for backends that need a device to run
            a barrier (e.g. NCCL). Default: ``None``.

    Returns:
        An async work handle when ``async_op`` is True, otherwise ``None``.
    """
    work = dist.barrier(group, async_op, device_ids)
    return work
1026 @staticmethod
1027 def init_process_group(
1028 backend: Optional[str] = None,
1029 *,
1030 init_method: Optional[str] = None,
1031 timeout: Optional[timedelta] = None,
1032 world_size: int = -1,
1033 rank: int = -1,
1034 store: Optional[Store] = None,
1035 pg_options: Optional[Any] = None,
1036 device_id: Optional[Union[torch.device, int]] = None,
1037 ) -> None:
1038 """
1039 Initialize global process group.
1041 Args:
1042 backend (str or Backend, optional): The backend to use for distributed communication.
1043 init_method (str, optional): URL specifying how to initialize the process group. Default is "env://",
1044 can not be specified at the same time with ``store``.
1045 timeout (timedelta, optional): Timeout for process group. Default 10 minutes for NCCL and for other
1046 backends 30 minutes.
1047 world_size (int, optional): Number of processes. If ``store`` is specified, world_size is required.
1048 rank (int, optional): Rank of the current process, which value must between 0 and ``world_size``-1. If
1049 ``store`` is specified, rank is required.
1050 store (Store, optional): Key/value store accessible to all workers, used to exchange connection/address
1051 information. Can not be specified at the same time with ``init_method``.
1052 pg_options (ProcessGroupOptions, optional): Extra options to pass during constructing process groups.
1053 device_id (torch.device | int, optional): Specific device this process will work on.
1054 """
1055 try:
1056 _get_default_group()
1057 # except multi version error
1058 except (ValueError, RuntimeError):
1059 if backend is None:
1060 backend = "hccl"
1061 dist.init_process_group(backend=backend, init_method=init_method, timeout=timeout, world_size=world_size,
1062 rank=rank, store=store, pg_options=pg_options, device_id=device_id)
@staticmethod
def destroy_process_group(group: Optional[ProcessGroup] = None) -> None:
    """
    Destroy a process group and drop it from the ``EXISTING_COMM_GROUPS`` cache.

    Args:
        group (ProcessGroup, optional): Group to destroy. When omitted, the default (WORLD)
            group is targeted; ``torch.distributed`` then tears down all process groups.
    """
    group = group or _get_default_group()
    if group is dist.GroupMember.WORLD:
        # Destroying WORLD destroys every process group, so all cached
        # sub-group handles become stale and must not be handed out again
        # (e.g. by split_group's cache lookup).
        EXISTING_COMM_GROUPS.clear()
    else:
        stale_keys = [key for key, cached in EXISTING_COMM_GROUPS.items() if cached == group]
        for key in stale_keys:
            del EXISTING_COMM_GROUPS[key]
    dist.destroy_process_group(group)
1080 @staticmethod
1081 def get_process_group_ranks(group: Optional[ProcessGroup] = None) -> list[int]:
1082 """
1083 Get all ranks relative to given process group.
1085 Args:
1086 group (Optional[ProcessGroup]): Process group worked on. Default is ``None``, and ``None`` means global
1087 group.
1089 Returns:
1090 Rank list.
1091 """
1092 group = group or _get_default_group()
1093 return dist.get_process_group_ranks(group)
1095 @staticmethod
1096 def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
1097 """
1098 Get the backend of the given process group.
1100 Args:
1101 group (ProcessGroup, optional): Process group worked on. Default is ``None``, and ``None`` means global
1102 group.
1104 Returns:
1105 The backend object of the given process group.
1106 """
1107 group = group or _get_default_group()
1108 return dist.get_backend(group)
@staticmethod
def split_group(parent_pg: Optional[ProcessGroup] = None,
                split_ranks: Optional[list] = None,
                timeout: Optional[timedelta] = None,
                pg_options: Optional[Any] = None,
                group_desc: Optional[str] = None,
                ) -> Optional[ProcessGroup]:
    """
    Ensure a process group exists for every rank list in ``split_ranks``; return the current rank's.

    Groups already present in ``EXISTING_COMM_GROUPS`` (looked up via
    ``TorchPlatform.get_created_group``) are reused. Missing ones are created with
    ``dist.new_group`` and cached under the key ``str(tuple(sorted(ranks)))``.

    Args:
        parent_pg (Optional[ProcessGroup]): Accepted for API compatibility; it is not read here.
            New groups are created with ``dist.new_group``, which does not take a parent.
        split_ranks (Optional[list]): A ``list[list[int]]`` of rank sets, one per group.
        timeout (Optional[timedelta]): Timeout for newly created groups; ``None`` keeps torch's
            backend-specific default.
        pg_options (Optional[Any]): Extra backend options for newly created groups.
        group_desc (Optional[str]): Description for newly created groups.

    Return:
        Optional[ProcessGroup]: The group containing the current rank, or ``None`` if the
        current rank appears in none of ``split_ranks``.

    Raises:
        ValueError: If ``split_ranks`` is ``None`` or empty.
    """
    if split_ranks is None or len(split_ranks) == 0:
        raise ValueError("split_ranks cannot be None or empty")

    new_group_kwargs = {"timeout": timeout, "pg_options": pg_options}
    if group_desc is not None:
        # Only forwarded when given. Per the hedge in the analysis, older
        # torch.distributed.new_group releases may not accept this keyword —
        # confirm the minimum supported torch version.
        new_group_kwargs["group_desc"] = group_desc

    current_rank = TorchPlatform.get_rank()
    own_group = None
    # Every rank must walk split_ranks in the same order, because
    # dist.new_group is a collective over the default group.
    for ranks in split_ranks:
        dist_group = TorchPlatform.get_created_group(ranks)
        if dist_group is None:
            dist_group = dist.new_group(ranks=ranks, **new_group_kwargs)
            EXISTING_COMM_GROUPS[str(tuple(sorted(ranks)))] = dist_group
        if current_rank in ranks:
            own_group = dist_group

    return own_group
1146 @staticmethod
1147 def get_group_local_rank(group: ProcessGroup = None) -> int:
1148 """get group local rank id."""
1149 group = group or _get_default_group()
1150 return group.rank()
1152 @staticmethod
1153 def no_grad():
1154 return torch.no_grad()
1156 @staticmethod
1157 def cat(tensors, dim=0):
1158 return torch.cat(tensors, dim=dim)
1160 @staticmethod
1161 def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
1162 return torch.empty_like(tensor, dtype=dtype, device=device, pin_memory=pin_memory)
def get_current_stream(self):
    """Return the device backend's current stream (``get_device_handle().current_stream()``)."""
    return self.get_device_handle().current_stream()
def new_event(self):
    """Create a fresh event on the device backend (``get_device_handle().Event()``)."""
    return self.get_device_handle().Event()
def tree_map(self, fn, tree):
    """Apply ``fn`` to every leaf of ``tree`` using torch's pytree utilities."""
    mapper = torch.utils._pytree.tree_map  # pylint: disable=protected-access
    return mapper(fn, tree)
@property
def checkpoint(self):
    """The activation-checkpoint function: ``torch.utils.checkpoint.checkpoint``."""
    checkpoint_module = torch.utils.checkpoint
    return checkpoint_module.checkpoint
@staticmethod
def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
    """Wrap ``module`` with torch's ``checkpoint_wrapper``.

    A plain callable that is not an ``nn.Module`` is first wrapped in
    ``FuncModule``. Presumably ``FuncModule`` adapts the callable into a module
    that ``checkpoint_wrapper`` accepts — confirm in ``activation_swap``.

    Args:
        module: An ``nn.Module`` or a plain callable.
        checkpoint_fn: Optional checkpoint function forwarded to ``checkpoint_wrapper``.
        **checkpoint_fn_kwargs: Extra keyword arguments forwarded to ``checkpoint_wrapper``.

    Returns:
        The result of ``checkpoint_wrapper``.
    """
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import FuncModule
    needs_adapter = callable(module) and not isinstance(module, torch.nn.Module)
    target = FuncModule(module) if needs_adapter else module
    return checkpoint_wrapper(target, checkpoint_fn=checkpoint_fn, **checkpoint_fn_kwargs)
@staticmethod
def swap_wrapper(module, policy_fn=None):
    """Delegate to ``activation_swap.swap_wrapper`` (imported lazily) with ``policy_fn``."""
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.activation_checkpoint import activation_swap
    return activation_swap.swap_wrapper(module, policy_fn=policy_fn)
1193 @property
1194 def noop_context_fn(self):
1195 return noop_context_fn
@staticmethod
def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    """Delegate to the project's ``sac.create_selective_checkpoint_contexts`` (imported lazily).

    Args:
        policy_fn_or_list: Policy function or list, forwarded positionally.
        allow_cache_entry_mutation: Forwarded positionally. Default: ``False``.
    """
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.activation_checkpoint import sac
    return sac.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation)
@staticmethod
def async_save_on_cpu(policy_fn=None):
    """Build an ``AsyncSaveOnCpu`` (imported lazily from ``activation_swap``) with ``policy_fn``."""
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.activation_checkpoint import activation_swap
    saver = activation_swap.AsyncSaveOnCpu(policy_fn)
    return saver
1209 @staticmethod
1210 def get_element_size(tensor):
1211 """Get Tensor Element Size"""
1212 return tensor.element_size()
1214 @staticmethod
1215 def tensor_to_numpy(tensor) -> np.ndarray:
1216 """Convert PyTorch tensor to numpy array."""
1217 return tensor.cpu().numpy()
@staticmethod
def clip_grad_norm_(
    parameters, max_norm, norm_type=2.0,
    error_if_nonfinite=False, foreach=None,
):
    """Clip gradients of ``parameters`` via the project's ``clip_grad.clip_grad_norm_``.

    All arguments are forwarded unchanged. The return value is whatever that
    helper returns (presumably the total norm, as with
    ``torch.nn.utils.clip_grad_norm_`` — confirm).
    """
    # pylint: disable=C0415
    from hyper_parallel.platform.torch import clip_grad
    return clip_grad.clip_grad_norm_(
        parameters,
        max_norm,
        norm_type,
        error_if_nonfinite=error_if_nonfinite,
        foreach=foreach,
    )
1233 @staticmethod
1234 def profiler_record(name):
1235 """Profiler context manager for recording operations using torch.profiler."""
1236 return torch.profiler.record_function(name)
def cast_fp_tensor(self, dtype, x):
    """
    Cast ``x`` to ``dtype`` only if it is a floating-point tensor of a different dtype.

    Non-tensors, non-floating tensors, and tensors already of ``dtype`` are
    returned as the same object, uncopied.
    """
    needs_cast = isinstance(x, torch.Tensor) and torch.is_floating_point(x) and x.dtype != dtype
    return x.to(dtype) if needs_cast else x
def apply_to_tensors(self, fn, container):
    """Recursively apply ``fn`` to every tensor inside ``container``.

    Supported containers are dataclass instances, ``OrderedDict`` (subclass
    preserved), ``PackedSequence``, ``dict``, namedtuples, and
    ``list``/``tuple``/``set``. Any other value is returned unchanged.
    """

    def _walk(node):
        if isinstance(node, torch.Tensor):
            return fn(node)
        if hasattr(node, "__dataclass_fields__"):
            # Shallow-copy the dataclass, then rebuild it with every field mapped.
            shallow = dataclasses.replace(node)
            mapped_fields = {
                fld.name: _walk(getattr(shallow, fld.name))
                for fld in dataclasses.fields(shallow)
            }
            return dataclasses.replace(shallow, **mapped_fields)
        if isinstance(node, OrderedDict):
            rebuilt = node.__class__()
            for key, value in node.items():
                rebuilt[key] = _walk(value)
            return rebuilt
        if isinstance(node, PackedSequence):
            # fn runs on the packed data only for its side effects; the
            # original PackedSequence object is returned as-is.
            _walk(node.data)
            return node
        if isinstance(node, dict):
            return {key: _walk(value) for key, value in node.items()}
        is_namedtuple = isinstance(node, tuple) and hasattr(node, "_asdict") and hasattr(node, "_fields")
        if is_namedtuple:
            # Namedtuple constructors take fields positionally.
            return type(node)(*[_walk(item) for item in node])
        if isinstance(node, (list, tuple, set)):
            return type(node)(_walk(item) for item in node)
        return node

    return _walk(container)
@property
def meta_device(self):
    """A ``torch.device`` for the ``"meta"`` backend (shape-only, no storage)."""
    meta = torch.device("meta")
    return meta
def init_on_device(self, device, include_buffers=False):
    """Delegate to the module-level ``_init_on_device`` helper (defined elsewhere in this file).

    Args:
        device: Target device, forwarded positionally.
        include_buffers: Forwarded as a keyword. Default: ``False``.
    """
    helper_result = _init_on_device(device, include_buffers=include_buffers)
    return helper_result
def str_to_dtype(self, dtype_str: str) -> torch.dtype:
    """Map a ``"torch.<name>"`` string from checkpoint metadata to a ``torch.dtype``.

    Args:
        dtype_str: String such as ``"torch.float32"`` or ``"torch.bfloat16"``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If the string lacks a ``.`` separator, its prefix is not
            ``"torch"``, or the name does not resolve to a ``torch.dtype``
            (including names that are not attributes of ``torch`` at all).
    """
    parts = dtype_str.split(".", 1)
    if len(parts) != 2:
        raise ValueError(
            f"Expected dtype string like 'torch.float32', got {dtype_str!r}."
        )
    prefix, name = parts
    if prefix != "torch":
        raise ValueError(
            f"Expected PyTorch dtype string with prefix 'torch', got {dtype_str!r}."
        )
    # A default of None keeps unknown names (e.g. "torch.not_a_dtype") on the
    # documented ValueError path instead of leaking an AttributeError.
    dtype = getattr(torch, name, None)
    if isinstance(dtype, torch.dtype):
        return dtype
    raise ValueError(f"{dtype_str!r} does not resolve to a torch.dtype.")
def list_to_size(self, size_list: list[int]) -> torch.Size:
    """Wrap a sequence of ints in a ``torch.Size``."""
    return torch.Size([*size_list])