Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/group_utils.py: 9%

88 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""_group_manager""" 

16 

17from typing import Dict, List, Tuple, Union 

18 

19import torch.distributed as dist 

20 

21from hyper_parallel.platform.platform import EXISTING_COMM_GROUPS 

22 

23 

24def _validate_intra_step(normalized_template: List[int], template_len: int) -> int: 

25 """Verify consistent intra-group step and return intra_step.""" 

26 intra_step = normalized_template[1] - normalized_template[0] 

27 for i in range(1, template_len - 1): 

28 diff = normalized_template[i + 1] - normalized_template[i] 

29 if diff != intra_step: 

30 msg = ( 

31 f"Template must have consistent intra-group step. " 

32 f"Found {normalized_template[i+1]} - {normalized_template[i]} = {diff}, " 

33 f"expected {intra_step}" 

34 ) 

35 raise ValueError(msg) 

36 return intra_step 

37 

38 

39def _compute_group_starts(world_size: int, block_size: int, inter_step: int) -> List[int]: 

40 """Compute all valid block start positions.""" 

41 return [s for s in range(0, world_size, inter_step) if s + block_size <= world_size] 

42 

43 

44def _build_groups_for_blocks( 

45 group_starts: List[int], 

46 block_size: int, 

47 template_span_int: int, 

48 normalized_template: List[int], 

49 template_len: int, 

50 world_size: int, 

51) -> List[List[int]]: 

52 """Build all groups from block starts.""" 

53 all_groups = [] 

54 for start_block in group_starts: 

55 max_offset = block_size - template_span_int 

56 for offset in range(0, max_offset): 

57 group = [start_block + offset + normalized_template[i] for i in range(template_len)] 

58 if all(0 <= r < world_size for r in group): 

59 all_groups.append(group) 

60 return all_groups 

61 

62 

def generate_groups_from_template(
    template: Union[List[int], Tuple[int, ...]],
    world_size: int,
    my_rank: int,
    verbose: bool = False
) -> List[List[int]]:
    """
    Auto-generate all communication groups from a template group.

    The template is shifted so its smallest rank is 0, its constant spacing
    (intra-step) is measured, and the resulting pattern is tiled over
    ``world_size`` ranks in blocks of ``intra_step * len(template)`` ranks.
    Within each block the pattern is shifted by ``0 .. intra_step - 1``.

    Args:
        template: Template group, e.g. [0,1], [0,2,4,6] or [1,3,5,7]. Elements may
            be anything accepted by ``int()`` (e.g. numpy/tensor scalars).
        world_size: Total number of processes.
        my_rank: Current process rank (used only in debug output).
        verbose: Whether to print debug info.

    Returns:
        List of rank groups (each ascending), sorted by first rank, e.g.:
        - template [0,1] + world_size=8 -> [[0,1], [2,3], [4,5], [6,7]]
        - template [0,2,4,6] + world_size=8 -> [[0,2,4,6], [1,3,5,7]]
        - template [1,3,5,7] + world_size=8 -> [[0,2,4,6], [1,3,5,7]]
        A single-rank template yields one singleton group per rank. When
        ``world_size`` is not a multiple of the block size, ranks after the
        last full block are not placed in any group.

    Raises:
        ValueError: if the template is empty, its ranks are not evenly spaced
            with a positive step, the generated groups overlap, or the
            generated groups do not include the template itself (e.g. ranks
            outside ``[0, world_size)`` or a template misaligned with the
            block tiling).

    Algorithm:
        1. Template normalization: convert any starting template to 0-based
        2. Analyze pattern (intra-step, template span)
        3. Iterate by blocks, generate valid sub-groups per block
        4. Reject results where a rank would appear in more than one group
        5. Reject results that do not contain the requested template
    """
    # convert template to int list and sort (rank_list may come from numpy/tensor as float)
    template = sorted([int(x) for x in list(template)])
    world_size = int(world_size)
    my_rank = int(my_rank)
    template_len = len(template)

    if verbose:
        print(f"Rank {my_rank}: Original Template = {template}, World size = {world_size}")

    # A single-rank template means every rank forms its own group.
    if template_len == 1:
        return [[i] for i in range(world_size)]

    if template_len == 0:
        raise ValueError(f"Template must contain at least 1 rank, got {template}")

    # 1. Template normalization: convert to 0-based template
    template_base = template[0]  # original template start value
    normalized_template = [x - template_base for x in template]  # normalize to 0-based
    if verbose:
        print(f"Rank {my_rank}: Normalized Template = {normalized_template}")

    # 2. Analyze normalized template core params
    # intra-step: spacing between elements in template (validated as constant and positive)
    intra_step = _validate_intra_step(normalized_template, template_len)
    # template span: last - first element of normalized template
    template_span = normalized_template[-1] - normalized_template[0]
    # block size: intra_step * template_len, so a block fits intra_step shifted
    # (interleaved) copies of the template
    block_size = int(intra_step * template_len)
    # inter-step: spacing between adjacent blocks (equals block_size)
    inter_step = block_size

    if verbose:
        print(
            f"Rank {my_rank}: Template analysis - "
            f"intra_step={intra_step}, template_span={template_span}, "
            f"block_size={block_size}, inter_step={inter_step}"
        )

    # 3. Compute all valid block start positions
    group_starts = _compute_group_starts(world_size, block_size, inter_step)
    if verbose:
        print(f"Rank {my_rank}: Possible block starts: {group_starts}")

    # 4. Generate all valid sub-groups for each block
    template_span_int = int(template_span)
    all_groups = _build_groups_for_blocks(
        group_starts, block_size, template_span_int,
        normalized_template, template_len, world_size
    )

    # 5. Validate: no rank may appear in more than one group
    all_ranks = [rank for group in all_groups for rank in group]
    unique_ranks = set(all_ranks)
    if len(all_ranks) != len(unique_ranks):
        raise ValueError("Duplicate ranks found! Some ranks appear in multiple groups.")

    # 6. The requested group must be among the results. Otherwise the template
    # is out of range or misaligned with the block tiling, and the caller's
    # group would silently never be created.
    if template not in all_groups:
        raise ValueError(
            f"Template {template} cannot be expanded over world_size={world_size}: "
            f"the generated groups do not include it"
        )

    # 7. Sort: ensure all processes generate groups in same order
    all_groups.sort(key=lambda x: (x[0], x[1] if len(x) > 1 else 0))

    if verbose:
        print(
            f"Rank {my_rank}: Generated {len(all_groups)} groups, "
            f"covering {len(unique_ranks)} unique ranks\n"
            f"Final group list: {all_groups}"
        )

    return all_groups

157 

158 

def create_sub_groups(
    rank_list: Union[List[int], Tuple[int, ...], List[List[int]]],
    verbose: bool = False
) -> Dict[tuple, dist.ProcessGroup]:
    """
    Create sub-communication groups, supports template auto-expansion.

    ``dist.new_group`` is called for every group on every process, in the same
    sorted order. All processes should therefore pass the same ``rank_list``.

    Args:
        rank_list: One of:
            1. Full group list, e.g. [[0,1], [2,3], [4,5], [6,7]]; used as given
               (a list/tuple whose first element is itself a list/tuple).
            2. Template group, e.g. [0,1] or [0,2]; expanded over the world via
               ``generate_groups_from_template``.
        verbose: Whether to print debug info.

    Returns:
        Dict mapping the sorted rank tuple of each created group that the
        current process belongs to, to that group's ProcessGroup.

    Raises:
        ValueError: if a group is not a list/tuple, is empty, contains duplicate
            ranks or non-int ranks, or the template cannot be expanded.
    """
    my_rank = dist.get_rank()
    world_size = dist.get_world_size()
    entries = list(rank_list)

    # A nested list/tuple is an explicit full group list (form 1). Anything else
    # is a template to expand (form 2). Mixed inputs such as [[0, 1], 2] are
    # rejected by the validation loop below.
    if entries and isinstance(entries[0], (list, tuple)):
        full_rank_list = entries
    else:
        full_rank_list = generate_groups_from_template(entries, world_size, my_rank, verbose=verbose)

    if verbose:
        print(f"Rank {my_rank}: Full rank list to create: {full_rank_list}")

    # validate full group list format
    for i, group in enumerate(full_rank_list):
        if not isinstance(group, (list, tuple)):
            raise ValueError(f"Group {i} must be a list or tuple, got {type(group)}")
        if len(group) == 0:
            raise ValueError(f"Group {i} is empty")
        if len(group) != len(set(group)):
            raise ValueError(f"Group {i} contains duplicate ranks")
        for rank in group:
            if not isinstance(rank, int):
                raise ValueError(f"Rank must be integer, got {type(rank)} in group {i}")

    # sort by first element to ensure all processes create groups in same order
    sorted_groups = sorted(full_rank_list, key=lambda x: x[0])

    if verbose:
        print(f"Rank {my_rank}: Sorted groups for creation: {sorted_groups}")

    # create all groups and collect groups current process belongs to
    group_dict = {}
    for group_ranks in sorted_groups:
        # ensure ranks are ordered so each process passes same order
        sorted_ranks = sorted(group_ranks)

        if verbose:
            print(f"Rank {my_rank}: Creating group with ranks {sorted_ranks}")

        # key: all processes participate in each group creation
        group = dist.new_group(ranks=sorted_ranks)
        # NOTE(review): this is recorded on every process, including processes
        # that are not members of the group (new_group may return a non-member
        # placeholder there). Confirm that consumers of EXISTING_COMM_GROUPS expect that.
        EXISTING_COMM_GROUPS[str(tuple(sorted_ranks))] = group

        # only save when current process is in the group
        if my_rank in sorted_ranks:
            group_dict[tuple(sorted_ranks)] = group

    if verbose:
        print(f"Rank {my_rank}: Created {len(group_dict)} groups I belong to")

    return group_dict