Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/tensor_parallel/api.py: 98%
53 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""High-level API to apply parallel styles to modules (aligned with PyTorch ``parallelize_module``)."""
16from __future__ import annotations
18import warnings
19from contextlib import contextmanager
20from fnmatch import fnmatch
21from typing import Iterator, Optional, Union
23from hyper_parallel.core.dtensor.device_mesh import DeviceMesh, _mesh_resources
24from hyper_parallel.core.tensor_parallel.style import ParallelStyle
25from hyper_parallel.platform import get_platform
27platform = get_platform()
28Module = platform.Module
30__all__ = ["parallelize_module"]


def _named_children(module: Module):
    """Immediate child modules: PyTorch ``nn.Module.named_children`` or MindSpore ``Cell.name_cells``."""
    if hasattr(module, "named_children"):
        return module.named_children()
    return module.name_cells().items()
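
# Example (sketch): child enumeration is uniform across both backends. Assuming
# ``net`` is a platform ``Module`` with immediate submodules ``fc1`` and ``fc2``
# (illustrative names), either backend yields the same (name, submodule) pairs:
#
#     for name, child in _named_children(net):
#         ...  # ("fc1", <Module>), then ("fc2", <Module>)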

@contextmanager
def _tensor_parallel_mesh_context(device_mesh: DeviceMesh) -> Iterator[DeviceMesh]:
    """Internal: same thread-local stack as ``with device_mesh:`` for ``parallelize_module(..., None)``.

    Prefer ``with mesh:`` in user code; this helper exists for tests and library utilities.
    """
    with device_mesh:
        yield device_mesh
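
# Example (sketch): lets ``parallelize_module(model, None, plan)`` resolve its
# mesh from the ambient context, exactly as ``with mesh:`` would. ``mesh``,
# ``model``, and ``plan`` are illustrative placeholders:
#
#     with _tensor_parallel_mesh_context(mesh):
#         parallelize_module(model, None, plan)  # mesh comes from the context stack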

def _validate_tp_mesh_dim(device_mesh: DeviceMesh) -> None:
    """Require a 1-D mesh, matching PyTorch tensor-parallel constraints."""
    if device_mesh.ndim > 1:
        raise ValueError(
            f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D! "
            f'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"] '
            f'or another 1-D sub-mesh slice (e.g. mesh["cp"]).'
        )
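
# Example (sketch): a multi-dim mesh must be sliced down to 1-D first. Assuming
# ``mesh_2d`` is a 2-D DeviceMesh with named dims ("dp", "tp"):
#
#     _validate_tp_mesh_dim(mesh_2d)        # raises ValueError (ndim == 2)
#     _validate_tp_mesh_dim(mesh_2d["tp"])  # passes: 1-D sub-mesh slice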

def parallelize_module(  # type: ignore[return]
    module: Module,
    device_mesh: Optional[DeviceMesh] = None,
    parallelize_plan: Optional[Union[ParallelStyle, dict[str, ParallelStyle]]] = None,
    *,
    src_data_rank: Optional[int] = 0,
) -> Module:
    """Apply parallel styles to *module* or its submodules (PyTorch-compatible interface).

    Behaviour follows ``torch.distributed.tensor.parallel.parallelize_module``:

    - *device_mesh* should normally be passed explicitly. Omitting it (``None``) requires an
      active mesh context: ``with mesh:`` (see :meth:`hyper_parallel.core.dtensor.device_mesh.DeviceMesh.__enter__`)
      or :func:`_tensor_parallel_mesh_context` for tests/libraries.
    - *parallelize_plan* may be a single :class:`ParallelStyle` (applied to *module*)
      or a dict mapping submodule paths to styles. Path segments support ``fnmatch``
      patterns (e.g. ``"layers.*"``), like PyTorch FQN rules.
    - Only a **1-D** :class:`DeviceMesh` is accepted; slice a sub-mesh from a multi-dim mesh first.
    - *src_data_rank* is stored on the style (``style.src_data_rank``) before ``apply``; styles
      that shard parameters from a logical global tensor may use it (see PyTorch TP).

    Note:
        When ``parallelize_plan`` is a single :class:`ParallelStyle` (not a dict), this
        function modifies it in place by setting ``parallelize_plan.src_data_rank``.
        Callers should be aware that the passed object will be mutated.

    Args:
        module: Root module to parallelize.
        device_mesh: Mesh for this TP/CP slice. Use ``None`` only inside ``with mesh:`` (or
            :func:`_tensor_parallel_mesh_context`) so ``_mesh_resources.get_current_mesh()`` resolves
            (see PyTorch ``distribute_module``).
        parallelize_plan: A :class:`ParallelStyle` or dict ``{path: ParallelStyle}``.
        src_data_rank: Source rank for global tensor semantics; ``None`` means use local data only
            (PyTorch parity). Default: ``0``.

    Returns:
        *module* after in-place parallelization.
    """
    if device_mesh is None:
        device_mesh = _mesh_resources.get_current_mesh()
    _validate_tp_mesh_dim(device_mesh)

    if parallelize_plan is None:
        warnings.warn(
            "No parallelize_plan is provided and auto-parallel is not supported "
            "at the moment, so this parallelize_module call will do nothing.",
            stacklevel=2,
        )
        return module

    if isinstance(parallelize_plan, ParallelStyle):
        parallelize_plan.src_data_rank = src_data_rank
        return parallelize_plan.apply(module, device_mesh)
    if isinstance(parallelize_plan, dict):

        def _apply_path(
            current_module: Module,
            atoms: list[str],
            style: ParallelStyle,
            src_rank: Optional[int],
        ) -> bool:
            """Recursively match ``atoms`` (one ``fnmatch`` pattern per path
            segment) against child names, applying ``style`` wherever the full
            path matches; returns whether any submodule matched."""
            atom = atoms[0]
            matched_children = list(
                filter(
                    lambda t, pattern=atom: fnmatch(t[0], pattern),
                    _named_children(current_module),
                )
            )
            applied = False
            for _, submodule in matched_children:
                if len(atoms) == 1:
                    # Full path matched: apply the style via the single-style branch.
                    parallelize_module(
                        submodule,
                        device_mesh,
                        style,
                        src_data_rank=src_rank,
                    )
                    applied = True
                else:
                    # Descend into the remaining path segments.
                    applied = _apply_path(submodule, atoms[1:], style, src_rank) or applied
            return applied

        for module_path, parallelize_style in parallelize_plan.items():
            if not isinstance(parallelize_style, ParallelStyle):
                raise TypeError(
                    "Expect ParallelStyle values in parallelize_plan dict, but got "
                    f"{type(parallelize_style)} for path '{module_path}'."
                )
            path_splits = module_path.split(".")
            if module_path == "" or any(path == "" for path in path_splits):
                raise ValueError(
                    f"Expect module path to be non-empty dot-separated atoms, but got '{module_path}'."
                )
            if not _apply_path(module, path_splits, parallelize_style, src_data_rank):
                warnings.warn(
                    f"parallelize_plan path '{module_path}' has no matches, so this path is skipped.",
                    stacklevel=2,
                )
        return module
    raise TypeError(
        "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
        f" parallelize_plan, {type(parallelize_plan)} found!"
    )
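
# Usage sketch (assumptions: ``ColwiseParallel``/``RowwiseParallel`` styles exist
# alongside ``ParallelStyle`` in this package, mirroring PyTorch TP, and ``model``
# exposes submodules under ``layers.<i>.attn``; all names are illustrative only):
#
#     plan = {
#         "layers.*.attn.q_proj": ColwiseParallel(),  # "layers.*" is an fnmatch pattern
#         "layers.*.attn.o_proj": RowwiseParallel(),
#     }
#     with tp_mesh:  # or pass tp_mesh explicitly as device_mesh
#         parallelize_module(model, None, plan, src_data_rank=0)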