Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/group_utils.py: 9%
88 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""_group_manager"""
17from typing import Dict, List, Tuple, Union
19import torch.distributed as dist
21from hyper_parallel.platform.platform import EXISTING_COMM_GROUPS
24def _validate_intra_step(normalized_template: List[int], template_len: int) -> int:
25 """Verify consistent intra-group step and return intra_step."""
26 intra_step = normalized_template[1] - normalized_template[0]
27 for i in range(1, template_len - 1):
28 diff = normalized_template[i + 1] - normalized_template[i]
29 if diff != intra_step:
30 msg = (
31 f"Template must have consistent intra-group step. "
32 f"Found {normalized_template[i+1]} - {normalized_template[i]} = {diff}, "
33 f"expected {intra_step}"
34 )
35 raise ValueError(msg)
36 return intra_step
39def _compute_group_starts(world_size: int, block_size: int, inter_step: int) -> List[int]:
40 """Compute all valid block start positions."""
41 return [s for s in range(0, world_size, inter_step) if s + block_size <= world_size]
44def _build_groups_for_blocks(
45 group_starts: List[int],
46 block_size: int,
47 template_span_int: int,
48 normalized_template: List[int],
49 template_len: int,
50 world_size: int,
51) -> List[List[int]]:
52 """Build all groups from block starts."""
53 all_groups = []
54 for start_block in group_starts:
55 max_offset = block_size - template_span_int
56 for offset in range(0, max_offset):
57 group = [start_block + offset + normalized_template[i] for i in range(template_len)]
58 if all(0 <= r < world_size for r in group):
59 all_groups.append(group)
60 return all_groups
def generate_groups_from_template(
    template: Union[List[int], Tuple[int, ...]],
    world_size: int,
    my_rank: int,
    verbose: bool = False
) -> List[List[int]]:
    """
    Expand one example group into every group that follows the same pattern.

    Args:
        template: Example group, e.g. [0,1], [0,2,4,6] or [1,3,5,7]. Entries
            are converted with int() and sorted, so any starting offset works.
        world_size: Total number of processes.
        my_rank: Rank of the calling process; only used in debug output.
        verbose: Print intermediate values when True.

    Returns:
        The generated groups, ordered by their first and second rank. With
        world_size=8:
            - template [0,1]     -> [[0,1], [2,3], [4,5], [6,7]]
            - template [0,2,4,6] -> [[0,2,4,6], [1,3,5,7]]
            - template [1,3,5,7] -> [[0,2,4,6], [1,3,5,7]]
        A one-element template yields one singleton group per rank.

    Raises:
        ValueError: If the template is empty, if its spacing is not uniform,
            or if a rank ends up in more than one generated group.

    Note:
        Only overlap between groups is checked. Blocks that do not fit
        completely below world_size are skipped, so their ranks are absent
        from the result.
    """
    # Inputs may arrive as floats (numpy / tensor values), so coerce to int.
    template = [int(value) for value in list(template)]
    template.sort()
    world_size = int(world_size)
    my_rank = int(my_rank)
    template_len = len(template)

    if verbose:
        print(f"Rank {my_rank}: Original Template = {template}, World size = {world_size}")

    # A single rank has no spacing to analyse: every rank is its own group.
    if template_len == 1:
        return [[rank] for rank in range(world_size)]
    if template_len < 2:
        raise ValueError(f"Template must have at least 2 ranks, got {template}")

    # Shift the template so that it starts at 0.
    first_rank = template[0]
    normalized_template = [rank - first_rank for rank in template]
    if verbose:
        print(f"Rank {my_rank}: Normalized Template = {normalized_template}")

    # Derive the pattern: element spacing, total span, and the block size
    # that also serves as the distance between consecutive blocks.
    intra_step = _validate_intra_step(normalized_template, template_len)
    template_span = normalized_template[-1] - normalized_template[0]
    block_size = int(intra_step * template_len)
    inter_step = block_size
    if verbose:
        print(
            f"Rank {my_rank}: Template analysis - "
            f"intra_step={intra_step}, template_span={template_span}, "
            f"block_size={block_size}, inter_step={inter_step}"
        )

    group_starts = _compute_group_starts(world_size, block_size, inter_step)
    if verbose:
        print(f"Rank {my_rank}: Possible block starts: {group_starts}")

    all_groups = _build_groups_for_blocks(
        group_starts,
        block_size,
        int(template_span),
        normalized_template,
        template_len,
        world_size,
    )

    # A rank counted more often than it is distinct belongs to several groups.
    unique_ranks = {rank for group in all_groups for rank in group}
    total_ranks = sum(len(group) for group in all_groups)
    if total_ranks != len(unique_ranks):
        raise ValueError("Duplicate ranks found! Some ranks appear in multiple groups.")

    def _order_key(group):
        second = group[1] if len(group) > 1 else 0
        return (group[0], second)

    # Every process must see the groups in the same order.
    all_groups.sort(key=_order_key)

    if verbose:
        print(
            f"Rank {my_rank}: Generated {len(all_groups)} groups, "
            f"covering {len(unique_ranks)} unique ranks\n"
            f"Final group list: {all_groups}"
        )

    return all_groups
def create_sub_groups(
    rank_list: Union[List[int], Tuple[int, ...], List[List[int]]],
    verbose: bool = False
) -> Dict[tuple, dist.ProcessGroup]:
    """
    Create sub-communication groups, supports template auto-expansion.

    Args:
        rank_list: One of:
            1. Full group list, e.g. [[0,1], [2,3], [4,5], [6,7]]; used as
               given (every entry must itself be a list or tuple).
            2. Template group, e.g. [0,1] or [0,2]; expanded to all groups
               with generate_groups_from_template.
        verbose: Whether to print debug info

    Returns:
        Dict whose keys are tuples of sorted group ranks and whose values are
        the created ProcessGroup objects; only groups containing the calling
        rank are included.

    Raises:
        ValueError: If a group is not a list/tuple, is empty, contains
            duplicate ranks, or contains a non-int rank.
    """
    my_rank = dist.get_rank()
    world_size = dist.get_world_size()
    entries = list(rank_list)

    # A nested input is already a complete group list. Template expansion
    # calls int() on every entry and would reject it, so use it directly.
    if entries and all(isinstance(entry, (list, tuple)) for entry in entries):
        full_rank_list = [list(group) for group in entries]
    else:
        full_rank_list = generate_groups_from_template(entries, world_size, my_rank, verbose=verbose)

    if verbose:
        print(f"Rank {my_rank}: Full rank list to create: {full_rank_list}")

    # validate full group list format
    for i, group in enumerate(full_rank_list):
        if not isinstance(group, (list, tuple)):
            raise ValueError(f"Group {i} must be a list or tuple, got {type(group)}")
        if len(group) == 0:
            raise ValueError(f"Group {i} is empty")
        if len(group) != len(set(group)):
            raise ValueError(f"Group {i} contains duplicate ranks")
        for rank in group:
            if not isinstance(rank, int):
                raise ValueError(f"Rank must be integer, got {type(rank)} in group {i}")

    # sort by first element to ensure all processes create groups in same order
    sorted_groups = sorted(full_rank_list, key=lambda x: x[0])

    if verbose:
        print(f"Rank {my_rank}: Sorted groups for creation: {sorted_groups}")

    # create all groups and collect groups current process belongs to
    group_dict = {}
    for group_ranks in sorted_groups:
        # ensure ranks are ordered so each process passes same order
        sorted_ranks = sorted(group_ranks)

        if verbose:
            print(f"Rank {my_rank}: Creating group with ranks {sorted_ranks}")

        # every process calls new_group for every group, member or not
        group = dist.new_group(ranks=sorted_ranks)
        EXISTING_COMM_GROUPS[str(tuple(sorted_ranks))] = group

        # only save when current process is in the group
        if my_rank in sorted_ranks:
            group_dict[tuple(sorted_ranks)] = group

    if verbose:
        print(f"Rank {my_rank}: Created {len(group_dict)} groups I belong to")

    return group_dict