Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/tensor_parallel/api.py: 98% (53 statements), coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""High-level API to apply parallel styles to modules (aligned with PyTorch ``parallelize_module``)."""
from __future__ import annotations

import warnings
from contextlib import contextmanager
from fnmatch import fnmatch
from typing import Iterator, Optional, Union

from hyper_parallel.core.dtensor.device_mesh import DeviceMesh, _mesh_resources
from hyper_parallel.core.tensor_parallel.style import ParallelStyle
from hyper_parallel.platform import get_platform

platform = get_platform()
Module = platform.Module

__all__ = ["parallelize_module"]


def _named_children(module: Module):
    """Immediate child modules: PyTorch ``nn.Module.named_children`` or MindSpore ``Cell.name_cells``."""
    if hasattr(module, "named_children"):
        return module.named_children()
    return module.name_cells().items()


@contextmanager
def _tensor_parallel_mesh_context(device_mesh: DeviceMesh) -> Iterator[DeviceMesh]:
    """Internal: same thread-local stack as ``with device_mesh:`` for ``parallelize_module(..., None)``.

    Prefer user code using ``with mesh:``; this exists for tests and library helpers.
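
    Example: a minimal sketch of the intended call pattern, where ``mesh``, ``model`` and
    ``plan`` are placeholders for a 1-D :class:`DeviceMesh`, a module, and a parallelize plan::

        with _tensor_parallel_mesh_context(mesh):
            parallelize_module(model, None, plan)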

    """
    with device_mesh:
        yield device_mesh


def _validate_tp_mesh_dim(device_mesh: DeviceMesh) -> None:
    """Require a 1-D mesh, matching PyTorch tensor-parallel constraints."""
    if device_mesh.ndim > 1:
        raise ValueError(
            f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D! "
            'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"] '
            'or another 1-D sub-mesh slice (e.g. mesh["cp"]).'
        )


def parallelize_module(  # type: ignore[return]
    module: Module,
    device_mesh: Optional[DeviceMesh] = None,
    parallelize_plan: Optional[Union[ParallelStyle, dict[str, ParallelStyle]]] = None,
    *,
    src_data_rank: Optional[int] = 0,
) -> Module:
    """Apply parallel styles to *module* or submodules (PyTorch-compatible interface).

    Behaviour follows ``torch.distributed.tensor.parallel.parallelize_module``:

    - *device_mesh* should normally be passed explicitly. Omitting it (``None``) requires an
      active mesh context: ``with mesh:`` (see :meth:`hyper_parallel.core.dtensor.device_mesh.DeviceMesh.__enter__`)
      or :func:`_tensor_parallel_mesh_context` for tests/libraries.
    - *parallelize_plan* may be a single :class:`ParallelStyle` (applied to *module*)
      or a dict mapping submodule paths to styles. Path segments support ``fnmatch``
      patterns (e.g. ``"layers.*"``) like PyTorch FQN rules.
    - Only **1-D** :class:`DeviceMesh` is accepted; slice a sub-mesh from a multi-dim mesh first.
    - *src_data_rank* is stored on the style (``style.src_data_rank``) before ``apply``; styles
      that shard parameters from a logical global tensor may use it (see PyTorch TP).

    Note:
        When ``parallelize_plan`` is a single :class:`ParallelStyle` (not a dict), this
        function modifies it in-place by setting ``parallelize_plan.src_data_rank``.
        The caller should be aware that the passed object will be mutated.

    Args:
        module: Root module to parallelize.
        device_mesh: Mesh for this TP/CP slice. Use ``None`` only inside ``with mesh:`` (or
            :func:`_tensor_parallel_mesh_context`) so ``_mesh_resources.get_current_mesh()``
            resolves (see PyTorch ``distribute_module``).
        parallelize_plan: A :class:`ParallelStyle` or dict ``{path: ParallelStyle}``.
        src_data_rank: Source rank for global tensor semantics; ``None`` means use local data
            only (PyTorch parity). Default ``0``.

    Returns:
        *module* after in-place parallelization.
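
    Example:
        A minimal sketch using a dict plan. ``ColwiseStyle`` and ``RowwiseStyle`` are
        placeholders for whatever concrete :class:`ParallelStyle` subclasses are in use,
        and ``model`` is assumed to expose ``layers.<i>.wq`` / ``layers.<i>.wo`` submodules::

            tp_mesh = ...  # a 1-D DeviceMesh for the tensor-parallel group
            plan = {
                "layers.*.wq": ColwiseStyle(),  # fnmatch pattern over submodule names
                "layers.*.wo": RowwiseStyle(),
            }
            parallelize_module(model, tp_mesh, plan, src_data_rank=0)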

    """
    if device_mesh is None:
        device_mesh = _mesh_resources.get_current_mesh()
    _validate_tp_mesh_dim(device_mesh)

    if parallelize_plan is None:
        warnings.warn(
            "No parallelize_plan is provided and auto-parallel is not supported "
            "at the moment, so this parallelize_module call will do nothing.",
            stacklevel=2,
        )
        return module

    if isinstance(parallelize_plan, ParallelStyle):
        parallelize_plan.src_data_rank = src_data_rank
        return parallelize_plan.apply(module, device_mesh)
    if isinstance(parallelize_plan, dict):

        def _apply_path(
            current_module: Module,
            atoms: list[str],
            style: ParallelStyle,
            src_rank: Optional[int],
        ) -> bool:
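            # Match one dotted-path segment ("atom") per recursion level: ``atom`` is
            # fnmatch-ed against the names of the *immediate* children of
            # ``current_module`` (so the path "layers.*" reaches every child of the
            # "layers" submodule). On the final atom, each match is parallelized via a
            # recursive ``parallelize_module`` call with the given style; otherwise the
            # recursion descends one level. Returns True if at least one submodule
            # matched at any depth.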

            atom = atoms[0]
            matched_children = list(
                filter(
                    lambda t, pattern=atom: fnmatch(t[0], pattern),
                    _named_children(current_module),
                )
            )
            applied = False
            for _, submodule in matched_children:
                if len(atoms) == 1:
                    parallelize_module(
                        submodule,
                        device_mesh,
                        style,
                        src_data_rank=src_rank,
                    )
                    applied = True
                else:
                    applied = _apply_path(submodule, atoms[1:], style, src_rank) or applied
            return applied

        for module_path, parallelize_style in parallelize_plan.items():
            if not isinstance(parallelize_style, ParallelStyle):
                raise TypeError(
                    "Expect ParallelStyle values in parallelize_plan dict, but got "
                    f"{type(parallelize_style)} for path '{module_path}'."
                )
            path_splits = module_path.split(".")
            if module_path == "" or any(path == "" for path in path_splits):
                raise ValueError(
                    f"Expect module path to be non-empty dot-separated atoms, but got '{module_path}'."
                )
            if not _apply_path(module, path_splits, parallelize_style, src_data_rank):
                warnings.warn(
                    f"parallelize_plan path '{module_path}' has no matches, so this path is skipped.",
                    stacklevel=2,
                )
        return module
    raise TypeError(
        "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
        f" parallelize_plan, {type(parallelize_plan)} found!"
    )