Coverage for hyper_parallel / platform / torch / group_utils.py: 79%
86 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""_group_manager"""
17from typing import Dict, List, Tuple, Union
19import torch.distributed as dist
22def _validate_intra_step(normalized_template: List[int], template_len: int) -> int:
23 """Verify consistent intra-group step and return intra_step."""
24 intra_step = normalized_template[1] - normalized_template[0]
25 for i in range(1, template_len - 1):
26 diff = normalized_template[i + 1] - normalized_template[i]
27 if diff != intra_step:
28 msg = (
29 f"Template must have consistent intra-group step. "
30 f"Found {normalized_template[i+1]} - {normalized_template[i]} = {diff}, "
31 f"expected {intra_step}"
32 )
33 raise ValueError(msg)
34 return intra_step
37def _compute_group_starts(world_size: int, block_size: int, inter_step: int) -> List[int]:
38 """Compute all valid block start positions."""
39 return [s for s in range(0, world_size, inter_step) if s + block_size <= world_size]
42def _build_groups_for_blocks(
43 group_starts: List[int],
44 block_size: int,
45 template_span_int: int,
46 normalized_template: List[int],
47 template_len: int,
48 world_size: int,
49) -> List[List[int]]:
50 """Build all groups from block starts."""
51 all_groups = []
52 for start_block in group_starts:
53 max_offset = block_size - template_span_int
54 for offset in range(0, max_offset):
55 group = [start_block + offset + normalized_template[i] for i in range(template_len)]
56 if all(0 <= r < world_size for r in group):
57 all_groups.append(group)
58 return all_groups
def generate_groups_from_template(
    template: Union[List[int], Tuple[int, ...]],
    world_size: int,
    my_rank: int,
    verbose: bool = False
) -> List[List[int]]:
    """Expand one template group into the list of all communication groups.

    The template may start at any rank; it is shifted so that it starts at 0
    before its pattern is analysed.

    Args:
        template: Template group, e.g. ``[0, 1]``, ``[0, 2, 4, 6]`` or
            ``[1, 3, 5, 7]``.
        world_size: Total number of processes.
        my_rank: Rank of the calling process (only used in debug output).
        verbose: Whether to print debugging information.

    Returns:
        The generated rank groups, for example:
        - template ``[0, 1]``, world_size 8 -> ``[[0, 1], [2, 3], [4, 5], [6, 7]]``
        - template ``[0, 2, 4, 6]``, world_size 8 -> ``[[0, 2, 4, 6], [1, 3, 5, 7]]``
        - template ``[1, 3, 5, 7]``, world_size 8 -> ``[[0, 2, 4, 6], [1, 3, 5, 7]]``
        A single-rank template yields one singleton group per rank.

    Raises:
        ValueError: If the template is empty, is not evenly spaced, or if a
            rank ends up in more than one generated group.

    How it works:
        1. Normalise the template so that it starts at 0.
        2. Derive the intra-group step and the template span.
        3. Walk the world block by block and emit every group in each block.
        4. Verify that no rank appears in more than one group.
    """
    # Values may arrive as floats (e.g. from numpy/tensor), so coerce to int
    # and sort.
    template = sorted(int(value) for value in template)
    world_size = int(world_size)
    my_rank = int(my_rank)
    template_len = len(template)

    if verbose:
        print(f"Rank {my_rank}: Original Template = {template}, World size = {world_size}")

    if template_len == 0:
        raise ValueError(f"Template must have at least 2 ranks, got {template}")
    if template_len == 1:
        # Degenerate template: every rank forms its own group.
        return [[rank] for rank in range(world_size)]

    # 1. Normalise: shift the template so its first rank is 0.
    first_rank = template[0]
    normalized_template = [value - first_rank for value in template]
    if verbose:
        print(f"Rank {my_rank}: Normalized Template = {normalized_template}")

    # 2. Core parameters of the normalised template.
    # Spacing between neighbouring ranks of one group.
    intra_step = _validate_intra_step(normalized_template, template_len)
    # Distance between the last and the first rank of the template.
    template_span = normalized_template[-1] - normalized_template[0]
    # Number of ranks one block spans; consecutive blocks start this far apart.
    block_size = int(intra_step * template_len)
    inter_step = block_size

    if verbose:
        print(
            f"Rank {my_rank}: Template analysis - "
            f"intra_step={intra_step}, template_span={template_span}, "
            f"block_size={block_size}, inter_step={inter_step}"
        )

    # 3. All block start positions that fit into the world.
    group_starts = _compute_group_starts(world_size, block_size, inter_step)
    if verbose:
        print(f"Rank {my_rank}: Possible block starts: {group_starts}")

    # 4. Every group inside every block.
    generated = _build_groups_for_blocks(
        group_starts,
        block_size,
        int(template_span),
        normalized_template,
        template_len,
        world_size,
    )

    # 5. A rank must not be a member of two groups.
    unique_ranks = set().union(*generated)
    member_total = sum(len(group) for group in generated)
    if member_total != len(unique_ranks):
        raise ValueError("Duplicate ranks found! Some ranks appear in multiple groups.")

    # 6. Sort so that every process sees the groups in the same order.
    def _order_key(group):
        second = group[1] if len(group) > 1 else 0
        return (group[0], second)

    all_groups = sorted(generated, key=_order_key)

    if verbose:
        print(
            f"Rank {my_rank}: Generated {len(all_groups)} groups, "
            f"covering {len(unique_ranks)} unique ranks\n"
            f"Final group list: {all_groups}"
        )

    return all_groups
def create_sub_groups(
    rank_list: Union[List[int], Tuple[int, ...], List[List[int]]],
    verbose: bool = False
) -> Dict[tuple, dist.ProcessGroup]:
    """Create sub process groups from explicit groups or from a template group.

    Args:
        rank_list: One of two formats:
            1. A complete list of groups, e.g.
               ``[[0, 1], [2, 3], [4, 5], [6, 7]]``, which is used as given.
            2. A single template group, e.g. ``[0, 1]`` or ``[0, 2]``, which
               is expanded with ``generate_groups_from_template``.
        verbose: Whether to print debugging information.

    Returns:
        A dict whose keys are the sorted rank tuples of the groups that
        contain the current rank and whose values are the matching
        ``ProcessGroup`` objects.

    Raises:
        ValueError: If a group is not a list or tuple, is empty, contains
            duplicate ranks, or contains a non-integer rank.
    """
    my_rank = dist.get_rank()
    world_size = dist.get_world_size()
    entries = list(rank_list)

    # Nested input is already the complete group list. It used to be passed to
    # the template expander, which fails when it calls int() on a list, so the
    # documented "complete list" format never worked.
    if any(isinstance(entry, (list, tuple)) for entry in entries):
        full_rank_list = entries
    else:
        full_rank_list = generate_groups_from_template(
            entries, world_size, my_rank, verbose=verbose
        )

    if verbose:
        print(f"Rank {my_rank}: Full rank list to create: {full_rank_list}")

    # Validate the shape of the complete group list.
    for i, group in enumerate(full_rank_list):
        if not isinstance(group, (list, tuple)):
            raise ValueError(f"Group {i} must be a list or tuple, got {type(group)}")
        if len(group) == 0:
            raise ValueError(f"Group {i} is empty")
        if len(group) != len(set(group)):
            raise ValueError(f"Group {i} contains duplicate ranks")
        for rank in group:
            if not isinstance(rank, int):
                raise ValueError(f"Rank must be integer, got {type(rank)} in group {i}")

    # Sort the ranks inside each group, then the groups themselves, so every
    # process creates the same groups in the same order even when explicit
    # groups were given with unsorted ranks.
    sorted_groups = sorted(sorted(group) for group in full_rank_list)

    if verbose:
        print(f"Rank {my_rank}: Sorted groups for creation: {sorted_groups}")

    # Create every group and keep the ones the current process belongs to.
    group_dict = {}
    for sorted_ranks in sorted_groups:
        if verbose:
            print(f"Rank {my_rank}: Creating group with ranks {sorted_ranks}")

        # Important: every process takes part in creating every group,
        # including groups it is not a member of.
        group = dist.new_group(ranks=sorted_ranks)

        # Only remember the group when the current process is a member.
        if my_rank in sorted_ranks:
            group_dict[tuple(sorted_ranks)] = group

    if verbose:
        print(f"Rank {my_rank}: Created {len(group_dict)} groups I belong to")

    return group_dict