Coverage for hyper_parallel/platform/torch/platform.py: 81%
372 statements

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Torch platform api"""
from datetime import timedelta
from typing import Optional, Any, Union
import dataclasses
from collections import OrderedDict

import numpy as np
import torch
from torch import nn
from torch import Tensor
from torch._C._distributed_c10d import Store, ProcessGroup
from torch.distributed import Backend
from torch.distributed.distributed_c10d import _get_default_group
from torch.nn import Parameter, Module
from torch.nn.utils.rnn import PackedSequence
from torch._ops import OpOverload, OpOverloadPacket
from torch.utils.checkpoint import noop_context_fn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
import torch.distributed.nn.functional as dist_func
import torch.distributed as dist
from hyper_parallel.platform.torch.dtensor import DTensorBase
from hyper_parallel.platform.torch.pipeline_parallel.stage import PipelineStageBase
from hyper_parallel.platform.torch.group_utils import create_sub_groups
from hyper_parallel.platform.platform import Platform, PlatformType
from hyper_parallel.platform.torch.function_override import override_functions

override_functions()

# Mapping from string op names to torch.distributed.ReduceOp
_OP_MAP = {
    'sum': dist.ReduceOp.SUM,
    'prod': dist.ReduceOp.PRODUCT,
    'max': dist.ReduceOp.MAX,
    'min': dist.ReduceOp.MIN,
    # convert tensor elements to int32 and use MIN
    'all': dist.ReduceOp.MIN,
    # 'avg' is typically handled by SUM followed by division in current implementation logic
    'avg': dist.ReduceOp.SUM,
}

# Try to add AVG for 'mean' if supported by current torch version
if hasattr(dist.ReduceOp, "AVG"):
    _OP_MAP['mean'] = dist.ReduceOp.AVG
else:
    # Fallback for older torch versions if necessary, though this might require manual division upstream
    _OP_MAP['mean'] = dist.ReduceOp.SUM
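
# Illustrative resolution (a sketch, not exhaustive): _OP_MAP['max'] resolves to
# dist.ReduceOp.MAX, while an unrecognised string falls back to dist.ReduceOp.SUM
# in the collective helpers defined on TorchPlatform below.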


# pylint: disable=C0103
class TorchPlatform(Platform):
    """Torch platform api"""
    Tensor = Tensor
    tensor = torch.tensor
    Parameter = Parameter
    Module = Module
    DTensorBase = DTensorBase
    PipelineStageBase = PipelineStageBase
    platform_type = PlatformType.PYTORCH
    tensor_dtype = torch

    @staticmethod
    def device_count(device_handle):
        return device_handle.device_count()

    def device_type(self):
        device_handle = self.get_device_handle()
        # guard the torch.npu access so CUDA-only builds without torch_npu do not raise AttributeError
        if hasattr(torch, "npu") and device_handle is torch.npu:
            return "npu"
        return "cuda"

    def device(self, device_idx=None):
        device_type = self.device_type()
        if device_idx is None:
            return torch.device(device_type)
        return torch.device(f"{device_type}:{device_idx:d}")

    @staticmethod
    def get_rng_state(device=None, device_handle=None):
        if device_handle is None:
            return torch.get_rng_state()
        if device is None:
            return device_handle.get_rng_state()
        return device_handle.get_rng_state(device)

    @staticmethod
    def set_rng_state(state, device=None, device_handle=None):
        if device_handle is None:
            return torch.set_rng_state(state)
        if device is None:
            return device_handle.set_rng_state(state)
        return device_handle.set_rng_state(state, device)

    @staticmethod
    def manual_seed(seed):
        return torch.manual_seed(seed)

    @staticmethod
    def ones(size, dtype=None):
        return torch.ones(size, dtype=dtype)

    @staticmethod
    def zeros(size, dtype=None):
        return torch.zeros(size, dtype=dtype)

    @staticmethod
    def full(size, fill_value, dtype=None):
        return torch.full(size, fill_value, dtype=dtype)

    @staticmethod
    def empty(size, dtype=None):
        return torch.empty(size, dtype=dtype)

    @staticmethod
    def get_rank():
        return dist.get_rank()

    @staticmethod
    def get_global_rank(group, group_rank):
        return dist.get_global_rank(group, group_rank)

    @staticmethod
    def get_world_size():
        return dist.get_world_size()

    @staticmethod
    def get_param_local_shape(param):
        """get param local shape"""
        if isinstance(param, DTensorBase):
            return param.local_shape
        return param.shape

    @staticmethod
    def get_param_local_data(param):
        """get param local data"""
        if isinstance(param, DTensorBase):
            return param.to_local()
        return param

    @staticmethod
    def update_param_data(param, data):
        """update param data"""
        param.data = data

    @staticmethod
    def get_op_name(func):
        if hasattr(func, "__name__"):
            return func.__name__
        if isinstance(func, OpOverload):
            full_name = func.name
            core_name = full_name.split("::")[-1].split(".")[0]
            return core_name
        if isinstance(func, OpOverloadPacket):
            return func.name.split("::")[-1]
        func_str = str(func)
        if "built-in function" in func_str:
            return func_str.split()[-1].strip(">")
        if "function" in func_str:
            return func_str.split()[1]
        return "unknown_op"
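
    # Illustrative behaviour of get_op_name (a sketch; the exact strings returned for
    # torch operator objects depend on the torch version):
    #     def my_func(x): return x
    #     TorchPlatform.get_op_name(my_func)             # "my_func", via __name__
    #     TorchPlatform.get_op_name(torch.ops.aten.mul)  # OpOverloadPacket branch, roughly "mul"
    #     TorchPlatform.get_op_name(object())            # no match, "unknown_op"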

    @staticmethod
    def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
        output = dist_func.all_gather(data, group=group)
        return torch.cat(output, dim=concat_dim)

    @staticmethod
    def chunk(data, split_dim, split_size, index):
        return torch.chunk(data, split_size, dim=split_dim)[index]

    @staticmethod
    def differentiable_all_to_all(input_data, output_shape, group):
        output_tensor = torch.empty(output_shape, device=input_data.device, dtype=input_data.dtype)
        output_tensor = dist_func.all_to_all_single(
            output_tensor,
            input_data,
            group=group
        )
        return output_tensor
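
    # Shape sketch for differentiable_all_to_all (hypothetical values): with a group of
    # 4 ranks and input_data of shape (4, 16) split evenly along dim 0, output_shape is
    # also (4, 16); the caller must pass an output_shape consistent with the split sizes.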

    @staticmethod
    def tensor_type_cast(input_data, cast_type):
        """Cast tensor to specified data type."""
        type_mapping = {
            'float32': torch.float32,
            'float16': torch.float16,
            'int64': torch.int64,
            'int32': torch.int32
        }
        if cast_type not in type_mapping:
            raise ValueError(f"Unknown cast type: {cast_type}. Supported types: {list(type_mapping.keys())}")
        return input_data.to(type_mapping[cast_type])

    @staticmethod
    def differentiable_all_reduce(data, op, group):
        # Resolve the op from string to ReduceOp enum if necessary
        reduce_op = _OP_MAP.get(op, dist.ReduceOp.SUM) if isinstance(op, str) else op
        return dist_func.all_reduce(data, op=reduce_op, group=group)

    @staticmethod
    def get_cell_construct(cell):
        return cell.forward

    @staticmethod
    def get_cells_and_names(cell):
        return cell.named_modules()

    @staticmethod
    def search_parameter_by_name(cell, param_name: str):
        """
        Find the parent Module of the parameter, the parameter's name in the parent Module, and the parameter.
        Return value: (parent Module instance, parameter's name in parent Module, parameter object).
        Returns None if not found.
        """
        # Remove the "self." prefix from param_name
        param_name = param_name.replace("self.", "")
        # Case 1: The parameter is a direct parameter of the current Module
        if param_name in cell._parameters:  # pylint:disable=protected-access
            return (cell, param_name, cell._parameters[param_name])  # pylint:disable=protected-access

        # Case 2: The parameter is in a sub-Module
        if "." in param_name:
            cell_path, param_key = param_name.rsplit(".", 1)
            try:
                # Locate the sub-Module where the parameter resides (supports multi-level paths)
                target_cell = cell.get_submodule(cell_path)
                # Check if the sub-Module directly contains this parameter
                if param_key in target_cell._parameters:  # pylint:disable=protected-access
                    return target_cell, param_key, target_cell._parameters[param_key]  # pylint:disable=protected-access
            except AttributeError:
                pass

        # Traverse all sub-Modules (recursively) to search for the parameter
        for _, child_cell in cell.named_children():
            if isinstance(child_cell, Module):
                result = TorchPlatform.search_parameter_by_name(child_cell, param_name)
                if result is not None:
                    return result

        return None

    @staticmethod
    def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
        """
        Modify the original parameter in a Module or sub-Module using the search result
        """
        parent_cell, param_key, _ = result
        # Key operation: directly modify the _parameters dictionary.
        if param_key in parent_cell._parameters:  # pylint:disable=protected-access
            parent_cell._parameters[param_key] = new_param  # pylint:disable=protected-access
        else:
            parent_cell.register_parameter(param_key, new_param)
        return True
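
    # Minimal usage sketch for search_parameter_by_name / update_parameter_by_name
    # (module structure and names here are hypothetical):
    #     model = nn.Sequential(nn.Linear(4, 4))
    #     found = TorchPlatform.search_parameter_by_name(model, "0.weight")
    #     if found is not None:        # found == (model[0], "weight", model[0].weight)
    #         TorchPlatform.update_parameter_by_name(model, found,
    #                                                nn.Parameter(torch.zeros(4, 4)))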

    @staticmethod
    def set_layout_into_parameter(param, layout):
        """Set layout into parameter"""
        from hyper_parallel.core.dtensor import DTensor  # pylint: disable=import-outside-toplevel
        from hyper_parallel.core.layout import _get_slice_tensor_by_layout  # pylint: disable=import-outside-toplevel
        if isinstance(param, DTensor):
            raise ValueError(f"Parameter {param} already has a layout configured; it cannot be set again.")
        requires_grad = param.requires_grad
        param_dtensor = DTensor.from_local(_get_slice_tensor_by_layout(param, layout), layout.mesh, layout.placements)
        new_param = Parameter(param_dtensor, requires_grad=requires_grad)
        return new_param

    @staticmethod
    def differentiable_reduce_scatter(data, dev_num, axis, op, group):
        input_tuple = torch.chunk(data, dev_num, dim=axis)
        output_tensor = torch.empty(input_tuple[0].shape, device=data.device, dtype=data.dtype)

        # Resolve the op from string to ReduceOp enum
        reduce_op = _OP_MAP.get(op, dist.ReduceOp.SUM) if isinstance(op, str) else op

        output_tensor = dist_func.reduce_scatter(output_tensor, input_tuple, op=reduce_op, group=group)

        # Keep manual handling for 'avg' string as it maps to SUM in _OP_MAP
        if op == 'avg':
            output_tensor = output_tensor / dev_num
        return output_tensor
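
    # Shape sketch for differentiable_reduce_scatter (hypothetical values): with data of
    # shape (8, 16), dev_num=4 and axis=0, each rank contributes four (2, 16) chunks and
    # receives a single reduced (2, 16) tensor; for op='avg' the SUM result is divided
    # by dev_num above.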

    @staticmethod
    def get_device_handle():
        if hasattr(torch, "npu"):
            return torch.npu
        return torch.cuda

    @staticmethod
    def get_param_type_size(param):
        # pylint: disable=W0212
        return torch._utils._element_size(param.dtype)

    @staticmethod
    def parameters_dict(cell: Module):
        return cell.named_parameters()

    @staticmethod
    def save_checkpoint(cell: Module, file_path: str) -> None:
        torch.save(obj=cell, f=file_path)

    @staticmethod
    def load_checkpoint(file_path: str) -> dict:
        return torch.load(f=file_path)

    @staticmethod
    def new_zero_parameter(param_shape, param_type, requires_grad, device):
        return nn.Parameter(torch.zeros(param_shape, dtype=param_type, device=device), requires_grad=requires_grad)

    @staticmethod
    def new_tensor(tensor_shape, tensor_type, device):
        return torch.empty(size=tensor_shape, dtype=tensor_type, device=device)

    @staticmethod
    def full_like(tensor, fill_value, dtype=None):
        return torch.full_like(tensor, fill_value, dtype=dtype)

    @staticmethod
    def set_tensor_requires_grad(input_tensor):
        """
        Set the requires_grad flag on the input tensor; only effective for leaf tensors.
        """
        if input_tensor.is_leaf:
            input_tensor.requires_grad = True

    def _create_group(self, rank_list, group_name=None):
        group_dict = create_sub_groups(rank_list)
        return group_dict[tuple(rank_list)]

    @staticmethod
    def all_gather_into_tensor(data, group_info, async_op=False):
        output_shape = list(data.shape)
        output_shape[0] = output_shape[0] * group_info.rank_size
        output = torch.empty(output_shape, dtype=data.dtype, device=data.device)
        handle = dist.all_gather_into_tensor(output, data, group=group_info.group, async_op=async_op)
        return output, handle
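
    # Shape sketch for all_gather_into_tensor (hypothetical values): gathering a (2, 16)
    # tensor over a group with rank_size=4 yields an (8, 16) output, with each rank's
    # slice stacked along dim 0. When async_op=True, wait on the returned handle before
    # reading the output tensor.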

    @staticmethod
    def all_reduce(data, group_info, async_op=False):
        if not data.is_contiguous():
            data = data.contiguous()
        handle = dist.all_reduce(data, group=group_info.group, async_op=async_op)
        return data, handle

    @staticmethod
    def broadcast(data, src, group=None, async_op=False):
        handle = dist.broadcast(data, src, group, async_op)
        if async_op:
            handle.wait()

    @staticmethod
    def isend(tensor, dst=None, group=None, tag=0):
        return dist.isend(tensor, dst, group, tag)

    @staticmethod
    def irecv(tensor, src=None, group=None, tag=0):
        return dist.irecv(tensor, src, group, tag)

    @staticmethod
    def send_object_list(obj_list, dst=None, group=None):
        dist.send_object_list(obj_list, dst, group)

    @staticmethod
    def recv_object_list(obj_list, src=None, group=None):
        dist.recv_object_list(obj_list, src, group)

    @staticmethod
    def reduce_scatter_tensor(data, group_info, async_op=False):
        output_shape = list(data.shape)
        output_shape[0] = output_shape[0] // group_info.rank_size
        output = torch.empty(output_shape, dtype=data.dtype, device=data.device)
        handle = dist.reduce_scatter_tensor(output, data, group=group_info.group, async_op=async_op)
        return output, handle

    @staticmethod
    def get_tensor_transform():
        raise NotImplementedError("Unsupported get_tensor_transform for torch platform")

    @staticmethod
    def construct_strided_slice(x, begin, end, stride):
        raise NotImplementedError("Unsupported construct_strided_slice for torch platform")

    @staticmethod
    def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.pipeline_parallel._utils import _MicroBatch
        return _MicroBatch(micro_batch_num, args_batch_dim, kwargs_batch_dim)

    def new_stream(self):
        device = self.get_device_handle()
        return device.Stream()

    def get_stream_context(self):
        device = self.get_device_handle()
        return device.stream

    @staticmethod
    def all_gather_object(object_list, obj, group=None) -> None:
        """
        Gather objects from the given group into ``object_list``.

        Args:
            object_list (list[Any]): Output list; its size must equal the size of the group.
            obj (Any): The object contributed by the current rank in the given process group.
            group (ProcessGroup, optional): The process group to gather objects over. Default is ``None``,
                which means the global group.

        Returns:
            None. Objects are gathered into ``object_list``.
        """
        dist.all_gather_object(object_list, obj, group)

    @staticmethod
    def init_process_group(
            backend: Optional[str] = None,
            *,
            init_method: Optional[str] = None,
            timeout: Optional[timedelta] = None,
            world_size: int = -1,
            rank: int = -1,
            store: Optional[Store] = None,
            pg_options: Optional[Any] = None,
            device_id: Optional[Union[torch.device, int]] = None,
    ) -> None:
        """
        Initialize the global process group.

        Args:
            backend (str or Backend, optional): The backend to use for distributed communication.
            init_method (str, optional): URL specifying how to initialize the process group. Default is "env://";
                cannot be specified together with ``store``.
            timeout (timedelta, optional): Timeout for the process group. Default is 10 minutes for NCCL and
                30 minutes for other backends.
            world_size (int, optional): Number of processes. Required if ``store`` is specified.
            rank (int, optional): Rank of the current process; must be between 0 and ``world_size`` - 1.
                Required if ``store`` is specified.
            store (Store, optional): Key/value store accessible to all workers, used to exchange connection/address
                information. Cannot be specified together with ``init_method``.
            pg_options (ProcessGroupOptions, optional): Extra options to pass when constructing process groups.
            device_id (torch.device | int, optional): Specific device this process will work on.
        """
        try:
            _get_default_group()
        # the exception type raised here differs across torch versions
        except (ValueError, RuntimeError):
            if backend is None:
                backend = "hccl"
            dist.init_process_group(backend=backend, init_method=init_method, timeout=timeout, world_size=world_size,
                                    rank=rank, store=store, pg_options=pg_options, device_id=device_id)
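
    # Calling sketch (assumes the usual env:// environment variables are already set):
    #     TorchPlatform.init_process_group()    # backend falls back to "hccl" when not given
    # If a default process group already exists, the call above is a no-op.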

    @staticmethod
    def destroy_process_group(group: Optional[ProcessGroup] = None) -> None:
        """
        Destroy the given process group.

        Args:
            group (ProcessGroup, optional): The process group to destroy. If not given, all process groups
                will be destroyed.
        """
        group = group or _get_default_group()
        dist.destroy_process_group(group)

    @staticmethod
    def get_process_group_ranks(group: Optional[ProcessGroup] = None) -> list[int]:
        """
        Get all ranks belonging to the given process group.

        Args:
            group (Optional[ProcessGroup]): Process group to work on. Default is ``None``, which means the
                global group.

        Returns:
            List of ranks.
        """
        group = group or _get_default_group()
        return dist.get_process_group_ranks(group)

    @staticmethod
    def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
        """
        Get the backend of the given process group.

        Args:
            group (ProcessGroup, optional): Process group to work on. Default is ``None``, which means the
                global group.

        Returns:
            The backend object of the given process group.
        """
        group = group or _get_default_group()
        return dist.get_backend(group)

    @staticmethod
    def split_group(parent_pg: Optional[ProcessGroup] = None,
                    split_ranks: Optional[list] = None,
                    timeout: Optional[timedelta] = None,
                    pg_options: Optional[Any] = None,
                    group_desc: Optional[str] = None,
                    ) -> Optional[ProcessGroup]:
        """
        Create a sub-group for every rank list in ``split_ranks`` and return the sub-group that contains
        the current rank.

        Args:
            parent_pg (Optional[ProcessGroup]): The process group from which the new groups are split.
            split_ranks (Optional[list]): A nested list of ranks, e.g. ``list[list[int]]``.
            timeout (Optional[timedelta]): Timeout for the process group. Default is 10 minutes for NCCL and
                30 minutes for other backends.
            pg_options (Optional[Any]): Extra options to pass when constructing process groups.
            group_desc (Optional[str]): Description of the process group.

        Returns:
            Optional[ProcessGroup]: The split process group that contains the current rank, if any.
        """
        if split_ranks is None or len(split_ranks) == 0:
            raise ValueError("split_ranks cannot be None or empty")

        split_group = None
        for split_rank in split_ranks:
            dist_group = dist.new_group(ranks=split_rank)
            if TorchPlatform.get_rank() in split_rank:
                split_group = dist_group

        return split_group
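
    # Usage sketch for split_group (hypothetical ranks): with a world size of 4,
    #     split_ranks = [[0, 1], [2, 3]]
    # every rank calls dist.new_group for both sub-lists (new_group is a collective call),
    # and e.g. rank 2 gets back the group containing ranks [2, 3].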

    @staticmethod
    def no_grad():
        return torch.no_grad()

    @staticmethod
    def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
        return torch.empty_like(tensor, dtype=dtype, device=device, pin_memory=pin_memory)

    def get_current_stream(self):
        device = self.get_device_handle()
        return device.current_stream()

    def new_event(self):
        device = self.get_device_handle()
        return device.Event()

    def tree_map(self, fn, tree):
        return torch.utils._pytree.tree_map(fn, tree)  # pylint:disable=protected-access

    @property
    def checkpoint(self):
        return torch.utils.checkpoint.checkpoint

    @staticmethod
    def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
        return checkpoint_wrapper(module, checkpoint_fn=checkpoint_fn, **checkpoint_fn_kwargs)

    @property
    def noop_context_fn(self):
        return noop_context_fn

    @staticmethod
    def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.sac import create_selective_checkpoint_contexts
        return create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation)

    @staticmethod
    def async_save_on_cpu(policy_fn=None):
        # pylint: disable=C0415
        from hyper_parallel.platform.torch.activation_checkpoint.activation_swap import AsyncSaveOnCpu
        return AsyncSaveOnCpu(policy_fn)

    @staticmethod
    def tensor_to_numpy(tensor) -> np.ndarray:
        """Convert PyTorch tensor to numpy array."""
        return tensor.cpu().numpy()

    def cast_fp_tensor(self, dtype, x):
        """
        Cast floating-point tensor to target dtype if applicable.
        """
        if (
                not isinstance(x, torch.Tensor)
                or not torch.is_floating_point(x)
                or x.dtype == dtype
        ):
            return x
        return x.to(dtype)
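
    # Illustrative behaviour of cast_fp_tensor (assumed inputs):
    #     platform.cast_fp_tensor(torch.float16, torch.ones(2))        # cast to float16
    #     platform.cast_fp_tensor(torch.float16, torch.ones(2).int())  # returned unchanged (not floating point)
    #     platform.cast_fp_tensor(torch.float16, "not a tensor")       # returned unchanged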

    def apply_to_tensors(self, fn, container):
        """Recursively apply fn to every tensor in the supported container types."""

        def apply(x):
            if isinstance(x, torch.Tensor):
                return fn(x)
            if hasattr(x, "__dataclass_fields__"):
                # dataclass: rebuild with every field transformed
                dc = dataclasses.replace(x)
                changes = {
                    f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc)
                }
                return dataclasses.replace(dc, **changes)
            if isinstance(x, OrderedDict):
                od = x.__class__()
                for key, value in x.items():
                    od[key] = apply(value)
                return od
            if isinstance(x, PackedSequence):
                # apply fn to the underlying data; the PackedSequence itself is returned unchanged
                apply(x.data)
                return x
            if isinstance(x, dict):
                return {key: apply(value) for key, value in x.items()}
            if isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields"):
                # namedtuple: rebuild with the same type
                res = (apply(el) for el in x)
                return type(x)(*res)
            if isinstance(x, (list, tuple, set)):
                return type(x)(apply(el) for el in x)
            return x

        return apply(container)
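
    # Usage sketch for apply_to_tensors (hypothetical container):
    #     batch = {"input": torch.ones(2), "meta": [torch.zeros(3), "label"]}
    #     moved = platform.apply_to_tensors(lambda t: t.float(), batch)
    # Non-tensor leaves ("label") pass through unchanged; for PackedSequence inputs the
    # result of fn is discarded and the original object is returned as-is.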