Coverage for hyper_parallel / platform / torch / group_utils.py: 79%

86 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""_group_manager""" 

16 

17from typing import Dict, List, Tuple, Union 

18 

19import torch.distributed as dist 

20 

21 

22def _validate_intra_step(normalized_template: List[int], template_len: int) -> int: 

23 """Verify consistent intra-group step and return intra_step.""" 

24 intra_step = normalized_template[1] - normalized_template[0] 

25 for i in range(1, template_len - 1): 

26 diff = normalized_template[i + 1] - normalized_template[i] 

27 if diff != intra_step: 

28 msg = ( 

29 f"Template must have consistent intra-group step. " 

30 f"Found {normalized_template[i+1]} - {normalized_template[i]} = {diff}, " 

31 f"expected {intra_step}" 

32 ) 

33 raise ValueError(msg) 

34 return intra_step 

35 

36 

37def _compute_group_starts(world_size: int, block_size: int, inter_step: int) -> List[int]: 

38 """Compute all valid block start positions.""" 

39 return [s for s in range(0, world_size, inter_step) if s + block_size <= world_size] 

40 

41 

42def _build_groups_for_blocks( 

43 group_starts: List[int], 

44 block_size: int, 

45 template_span_int: int, 

46 normalized_template: List[int], 

47 template_len: int, 

48 world_size: int, 

49) -> List[List[int]]: 

50 """Build all groups from block starts.""" 

51 all_groups = [] 

52 for start_block in group_starts: 

53 max_offset = block_size - template_span_int 

54 for offset in range(0, max_offset): 

55 group = [start_block + offset + normalized_template[i] for i in range(template_len)] 

56 if all(0 <= r < world_size for r in group): 

57 all_groups.append(group) 

58 return all_groups 

59 

60 

def generate_groups_from_template(
    template: Union[List[int], Tuple[int, ...]],
    world_size: int,
    my_rank: int,
    verbose: bool = False
) -> List[List[int]]:
    """
    Expand one template group into the full list of communication groups.

    Args:
        template: Template group, e.g. [0, 1], [0, 2, 4, 6] or [1, 3, 5, 7].
            Entries are converted with int() and sorted, so neither the
            starting value nor the input order matters.
        world_size: Total number of processes.
        my_rank: Rank of the current process (used only in debug output).
        verbose: Whether to print debug information.

    Returns:
        The list of rank groups, for example:
        - template [0, 1] with world_size=8 -> [[0, 1], [2, 3], [4, 5], [6, 7]]
        - template [0, 2, 4, 6] with world_size=8 -> [[0, 2, 4, 6], [1, 3, 5, 7]]
        - template [1, 3, 5, 7] with world_size=8 -> [[0, 2, 4, 6], [1, 3, 5, 7]]
        A single-element template yields one singleton group per rank.

    Raises:
        ValueError: If the template is empty, if the template spacing is
            rejected by the step validation, or if a rank ends up in more
            than one generated group.

    How it works:
        1. Normalize the template so that it starts at 0.
        2. Derive the pattern of the normalized template (intra-group step
           and template span).
        3. Walk the world block by block and generate every valid group
           inside each block.
        4. Verify that no rank appears in more than one group.
    """
    # Coerce to plain ints and sort (entries may arrive as floats, e.g. from
    # numpy or tensor values).
    ranks = sorted(int(entry) for entry in template)
    world_size = int(world_size)
    my_rank = int(my_rank)
    count = len(ranks)

    if verbose:
        print(f"Rank {my_rank}: Original Template = {ranks}, World size = {world_size}")

    if count == 0:
        raise ValueError(f"Template must have at least 2 ranks, got {ranks}")
    if count == 1:
        # Every rank forms its own group.
        return [[rank] for rank in range(world_size)]

    # 1. Normalize: shift the template so that its first entry becomes 0.
    origin = ranks[0]
    normalized = [rank - origin for rank in ranks]
    if verbose:
        print(f"Rank {my_rank}: Normalized Template = {normalized}")

    # 2. Core parameters of the normalized template.
    step = _validate_intra_step(normalized, count)  # spacing inside a group
    span = normalized[-1] - normalized[0]  # last entry minus first entry
    block = int(step * count)  # number of ranks covered by one block
    stride = block  # distance between consecutive block starts

    if verbose:
        print(
            f"Rank {my_rank}: Template analysis - "
            f"intra_step={step}, template_span={span}, "
            f"block_size={block}, inter_step={stride}"
        )

    # 3. All valid block start positions.
    starts = _compute_group_starts(world_size, block, stride)
    if verbose:
        print(f"Rank {my_rank}: Possible block starts: {starts}")

    # 4. Every valid group inside every block.
    groups = _build_groups_for_blocks(
        starts, block, int(span), normalized, count, world_size
    )

    # 5. Sanity check: a rank must not be assigned to two groups.
    flat = [rank for members in groups for rank in members]
    distinct = set(flat)
    if len(distinct) != len(flat):
        raise ValueError("Duplicate ranks found! Some ranks appear in multiple groups.")

    # 6. Sort so that every process sees the groups in the same order.
    groups = sorted(
        groups,
        key=lambda members: (members[0], members[1] if len(members) > 1 else 0),
    )

    if verbose:
        print(
            f"Rank {my_rank}: Generated {len(groups)} groups, "
            f"covering {len(distinct)} unique ranks\n"
            f"Final group list: {groups}"
        )

    return groups

155 

156 

def create_sub_groups(
    rank_list: Union[List[int], Tuple[int, ...], List[List[int]]],
    verbose: bool = False
) -> Dict[tuple, dist.ProcessGroup]:
    """
    Create sub communication groups, expanding a template group if needed.

    Args:
        rank_list: One of two formats:
            1. A complete list of groups, e.g. [[0, 1], [2, 3], [4, 5], [6, 7]].
               It is used as given, after format validation.
            2. A template group, e.g. [0, 1] or [0, 2], which is expanded to
               the whole world with generate_groups_from_template.
            The input is treated as format 1 as soon as any entry is a list
            or tuple.
        verbose: Whether to print debug information.

    Returns:
        A dict containing only the groups the current process belongs to.
        Each key is the tuple of that group's ranks in ascending order and
        each value is the ProcessGroup returned by dist.new_group.

    Raises:
        ValueError: If a group is not a list/tuple, is empty, contains
            duplicate ranks, or contains a rank that is not an int.
    """
    my_rank = dist.get_rank()
    world_size = dist.get_world_size()
    entries = list(rank_list)

    # A nested list/tuple marks the input as an explicit list of groups.
    # Sending it through the template expansion would fail with a TypeError
    # on int([...]), so use it directly and let the validation below vet it.
    if any(isinstance(entry, (list, tuple)) for entry in entries):
        full_rank_list = entries
    else:
        full_rank_list = generate_groups_from_template(
            entries, world_size, my_rank, verbose=verbose
        )

    if verbose:
        print(f"Rank {my_rank}: Full rank list to create: {full_rank_list}")

    # Validate the format of the complete group list.
    for i, group in enumerate(full_rank_list):
        if not isinstance(group, (list, tuple)):
            raise ValueError(f"Group {i} must be a list or tuple, got {type(group)}")
        if len(group) == 0:
            raise ValueError(f"Group {i} is empty")
        if len(group) != len(set(group)):
            raise ValueError(f"Group {i} contains duplicate ranks")
        for rank in group:
            if not isinstance(rank, int):
                raise ValueError(f"Rank must be integer, got {type(rank)} in group {i}")

    # Sort the ranks inside each group, then order the groups by their full
    # contents, so every process creates the groups in the same order with
    # the same rank order, regardless of how the caller ordered the input.
    sorted_groups = sorted(sorted(group) for group in full_rank_list)

    if verbose:
        print(f"Rank {my_rank}: Sorted groups for creation: {sorted_groups}")

    # Create every group and keep the ones the current process is part of.
    group_dict = {}
    for sorted_ranks in sorted_groups:
        if verbose:
            print(f"Rank {my_rank}: Creating group with ranks {sorted_ranks}")

        # Key point: every process takes part in creating every group, even
        # the groups it is not a member of.
        group = dist.new_group(ranks=sorted_ranks)

        # Store the handle only when the current process is in the group.
        if my_rank in sorted_ranks:
            group_dict[tuple(sorted_ranks)] = group

    if verbose:
        print(f"Rank {my_rank}: Created {len(group_dict)} groups I belong to")

    return group_dict

217 

218 return group_dict