Coverage for hyper_parallel / core / checkpoint / reshard.py: 95%

159 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""resharding tensor""" 

16import operator 

17from typing import Any, Optional, Union 

18from functools import reduce 

19import numpy as np 

20 

21 

def check_layout(layout: Optional[Any], name: str) -> None:
    """
    Validates that a layout exposes the attributes and shapes the resharder relies on.

    ``None`` is accepted and skipped: it stands for "no layout" on that side.
    Only ``None`` is skipped, so a layout object that happens to evaluate falsy
    (e.g. defines ``__len__``) is still validated.

    Args:
        layout (Optional[Any]): Layout object to validate. It must have the
            attributes ``mesh_shape``, ``_tensor_map`` and ``_rank_list`` and a
            ``to_dict()`` method returning the keys ``mesh_shape``,
            ``tensor_map`` and ``rank_list``.
        name (str): Name of the layout (for error messages).

    Raises:
        ValueError: If the layout is missing a required attribute, or the
            length of its rank list differs from the number of devices in its
            mesh (the product of ``mesh_shape``).
        TypeError: If ``mesh_shape``, ``tensor_map`` or ``rank_list`` returned by
            ``to_dict()`` is not a tuple or list.
    """
    if layout is None:
        return

    # Check for required attributes.
    for attr in ('mesh_shape', '_tensor_map', '_rank_list'):
        if not hasattr(layout, attr):
            raise ValueError(
                f"Layout {name} must contain attribute {attr}"
            )

    # Validate component types.
    layout_dict = layout.to_dict()
    for key in ('mesh_shape', 'tensor_map', 'rank_list'):
        value = layout_dict[key]
        if not isinstance(value, (tuple, list)):
            raise TypeError(
                f"Layout {name} {key} must be tuple or list, "
                f"but got {type(value).__name__}"
            )

    # Validate rank list size matches device count. The initializer 1 makes an
    # empty mesh_shape count as one device instead of reduce() raising TypeError.
    dev_num = reduce(operator.mul, layout_dict['mesh_shape'], 1)
    rank_list = layout_dict['rank_list']
    if len(rank_list) != dev_num:
        raise ValueError(
            f"Layout {name} rank_list size ({len(rank_list)}) "
            f"must match device count ({dev_num})"
        )

65 

66 

def rank_id_to_dev_id_list(mesh_shape: tuple[int, ...], rank_id: int) -> list[int]:
    """
    Decomposes a flat rank ID into per-axis device coordinates on a mesh.

    The last mesh axis varies fastest (row-major order), so for a mesh of
    shape ``(2, 3)`` rank 5 maps to ``[1, 2]``.

    Args:
        mesh_shape (tuple[int, ...]): Number of devices along each mesh axis.
        rank_id (int): Global rank ID to convert.

    Returns:
        list[int]: One coordinate per mesh axis, in the same order as ``mesh_shape``.
    """
    coords: list[int] = []
    remaining = rank_id
    # Peel off the fastest-varying axis first, then restore axis order at the end.
    for axis_size in reversed(mesh_shape):
        remaining, coord = divmod(remaining, axis_size)
        coords.append(coord)
    coords.reverse()
    return coords

86 

87 

def infer_intersection(
    area_a: tuple[tuple[int, int], ...],
    area_b: tuple[tuple[int, int], ...]
) -> Optional[tuple[tuple[int, int], ...]]:
    """
    Computes the overlap of two axis-aligned tensor regions.

    Each region is a sequence of half-open ``(start, end)`` ranges, one per
    tensor dimension.

    Args:
        area_a (tuple[tuple[int, int], ...]): First region.
        area_b (tuple[tuple[int, int], ...]): Second region.

    Returns:
        Optional[tuple[tuple[int, int], ...]]: The per-dimension overlap, or
        ``None`` as soon as any dimension has an empty overlap.

    Raises:
        TypeError: If a region is not a tuple/list of 2-element tuples/lists.
        ValueError: If the two regions have a different number of dimensions.
    """
    def _validate(area: Any) -> None:
        if not isinstance(area, (tuple, list)):
            raise TypeError("Area must be a tuple of ranges")
        if any(not isinstance(rng, (tuple, list)) or len(rng) != 2 for rng in area):
            raise TypeError("Each axis range must be a 2-element tuple")

    _validate(area_a)
    _validate(area_b)

    if len(area_a) != len(area_b):
        raise ValueError(
            f"Area dimension mismatch: {len(area_a)} vs {len(area_b)}"
        )

    overlap: list[tuple[int, int]] = []
    for (a_lo, a_hi), (b_lo, b_hi) in zip(area_a, area_b):
        lo = max(a_lo, b_lo)
        hi = min(a_hi, b_hi)
        # An empty range along one axis means the regions are disjoint.
        if lo >= hi:
            return None
        overlap.append((lo, hi))

    return tuple(overlap)

132 

133 

def infer_slice_area_by_rank(
    mesh_shape: tuple[int, ...],
    tensor_map: Union[list[int], tuple[int, ...]],
    rank_id: int,
    full_shape: tuple[int, ...]
) -> tuple[tuple[int, int], ...]:
    """
    Calculates the tensor slice boundaries for a specific rank.

    ``tensor_map[axis]`` says which mesh axes split tensor dimension ``axis``.
    Mesh axes are counted from the right: ``0`` is the last entry of
    ``mesh_shape``, ``1`` the one before it, and so on. ``-1`` means the
    dimension is not split.

    An entry may also be a list/tuple of mesh-axis indices. In that case the
    dimension is split across all listed axes, and the first listed axis is the
    most significant when numbering the slices. Scalar entries may be Python
    or NumPy integers.

    Args:
        mesh_shape (tuple[int, ...]): Number of devices along each mesh axis.
        tensor_map (Union[list[int], tuple[int, ...]]): Mapping of tensor dimensions to mesh axes.
        rank_id (int): Rank ID (index into the mesh) to calculate the slice for.
        full_shape (tuple[int, ...]): Complete shape of the original tensor.

    Returns:
        tuple[tuple[int, int], ...]: Half-open (start, end) boundaries for each tensor dimension.

    Raises:
        ValueError: If a dimension is not evenly divisible by its number of splits.
    """
    def _get_dev_num_along_dim(dim: int) -> int:
        # Devices along a (right-counted) mesh axis; -1 means "not split".
        return mesh_shape[-dim - 1] if dim != -1 else 1

    dev_id_list = rank_id_to_dev_id_list(mesh_shape, rank_id)
    area: list[tuple[int, int]] = []

    for axis, axis_len in enumerate(full_shape):
        mapping = tensor_map[axis]
        # Treat any integer scalar (including NumPy integers such as np.int64)
        # as a single-axis mapping; sequences/arrays are iterated as-is.
        if isinstance(mapping, (int, np.integer)):
            mapping = (mapping,)

        # Total number of splits for this axis.
        split_num = 1
        for dim in mapping:
            split_num *= _get_dev_num_along_dim(dim)

        # Slice index of this rank: mixed-radix number whose least significant
        # digit is the last listed mesh axis.
        slice_id = 0
        coef = 1
        for dim in reversed(mapping):
            if dim == -1:
                continue
            slice_id += dev_id_list[-dim - 1] * coef
            coef *= _get_dev_num_along_dim(dim)

        if axis_len % split_num != 0:
            raise ValueError(f"Shape can not divided along dimension {axis} by {split_num} dev.")
        slice_size = axis_len // split_num
        start = slice_id * slice_size
        area.append((start, start + slice_size))

    return tuple(area)

188 

189 

class ReshardHandler:
    """
    Handles tensor resharding between different distributed layouts.

    Given the full (unsharded) shape of a tensor, a source layout and a target
    layout, this class works out three things:

    * which region of the tensor the target rank owns;
    * which source ranks hold non-redundant pieces of that region;
    * how to stitch the collected pieces into the target rank's local slice.

    Typical flow::

        handler = ReshardHandler(name, full_shape, from_layout, to_layout, rank)
        offsets = handler.infer_all_tensor_offset()  # {source_rank: local area}
        # ... gather from each source rank its slice cut to offsets[rank] ...
        local = handler.get_real_tensor(collected)

    Args:
        param_name (str): Name of the parameter (without pipeline stage prefix).
        full_shape (tuple[int, ...]): Complete shape of the tensor before sharding.
        from_layout (Optional[Any]): Source layout whose ``to_dict()`` provides
            ``mesh_shape``, ``tensor_map`` and ``rank_list``. ``None`` models the
            source as a single device (rank 0) holding the full tensor.
        to_layout (Optional[Any]): Target layout with the same structure.
            ``None`` models the target as a single device holding the full tensor.
        to_rank_id (int): Target rank ID to receive the resharded tensor.
            Ignored (treated as 0) when ``to_layout`` is ``None``.

    Raises:
        ValueError: If both layouts are None, a layout contains invalid
            attributes, or ``to_rank_id`` is not in the target rank list.
        TypeError: If layout components are not tuples/lists.
    """
    def __init__(
        self,
        param_name: str,
        full_shape: tuple[int, ...],
        from_layout: Optional[Any],
        to_layout: Optional[Any],
        to_rank_id: int
    ):
        # Validate input layouts (None is accepted by check_layout).
        check_layout(from_layout, 'from_layout')
        check_layout(to_layout, 'to_layout')

        if from_layout is None and to_layout is None:
            raise ValueError("`from_layout` and `to_layout` cannot both be None.")

        self.param_name = param_name
        self.full_shape = full_shape

        # Source layout: None becomes a 1-device mesh whose slice is the whole tensor.
        if from_layout is None:
            self.from_mesh_shape = (1,)
            self.from_tensor_map = tuple(0 for _ in full_shape)
            self.from_rank_list = [0]
        else:
            from_layout_dict = from_layout.to_dict()
            self.from_mesh_shape = from_layout_dict["mesh_shape"]
            self.from_tensor_map = from_layout_dict["tensor_map"]
            self.from_rank_list = from_layout_dict["rank_list"]

        # Target layout: None likewise becomes a 1-device mesh (rank 0).
        if to_layout is None:
            self.to_mesh_shape = (1,)
            self.to_tensor_map = tuple(0 for _ in full_shape)
            self.to_rank_list = [0]
            self.to_rank_id = 0
        else:
            to_layout_dict = to_layout.to_dict()
            self.to_mesh_shape = to_layout_dict["mesh_shape"]
            self.to_tensor_map = to_layout_dict["tensor_map"]
            self.to_rank_list = to_layout_dict["rank_list"]
            self.to_rank_id = to_rank_id
            if self.to_rank_id not in self.to_rank_list:
                raise ValueError("Input to_rank_id is not in to_rank_list.")

        # "Inner" ranks are positions within a rank list, i.e. mesh indices.
        self.from_dev_num = len(self.from_rank_list)
        self.inner_from_rank_list = range(self.from_dev_num)
        self.inner_to_rank_id = self.to_rank_list.index(self.to_rank_id)

        # Source ranks whose slices are not duplicates of another kept rank.
        self.inner_deredundancy_from_rank_list = (
            self._infer_inner_deredundancy_rank_list_by_from_layout()
            if from_layout is not None else [0]
        )
        # Target area of this rank; filled in by infer_all_tensor_offset().
        self.to_area: Optional[tuple[tuple[int, int], ...]] = None
        # Global (full-tensor coordinate) overlap per required source rank.
        self.global_union_area_map: dict[int, tuple[tuple[int, int], ...]] = {}

    def _infer_inner_deredundancy_rank_list_by_from_layout(self) -> list[int]:
        """
        Infers source mesh indices that hold non-redundant data.

        Mesh axes not referenced by ``from_tensor_map`` do not split the tensor,
        so every device along such an axis holds a replica. Only the device at
        coordinate 0 on all of those axes is kept.

        Returns:
            list[int]: Inner (mesh-index) source ranks with unique data slices.
        """
        dev_dim = len(self.from_mesh_shape)

        # Mesh axes (in dev_id_list order) used by the tensor map. Tensor-map
        # indices count from the right, hence dev_dim - m - 1; m == -1 maps to
        # dev_dim, which lies outside range(dev_dim) and is therefore ignored.
        from_dev_map = set()
        for map_dev in self.from_tensor_map:
            entries = (map_dev,) if isinstance(map_dev, (int, np.integer)) else map_dev
            for map_dev_inner in entries:
                from_dev_map.add(dev_dim - map_dev_inner - 1)

        unused_dims = [dim for dim in range(dev_dim) if dim not in from_dev_map]
        if not unused_dims:
            return list(self.inner_from_rank_list)

        inner_deredundancy_rank_list: list[int] = []
        for rank_id in self.inner_from_rank_list:
            dev_id_list = rank_id_to_dev_id_list(self.from_mesh_shape, rank_id)
            # A non-zero coordinate along an unused axis marks a replica.
            if not any(dev_id_list[dim] > 0 for dim in unused_dims):
                inner_deredundancy_rank_list.append(rank_id)

        return inner_deredundancy_rank_list

    def infer_all_tensor_offset(self) -> dict[int, tuple[tuple[int, int], ...]]:
        """
        Calculates the required tensor slices from each source rank.

        Also sets ``self.to_area`` (the target rank's region in full-tensor
        coordinates) and rebuilds ``self.global_union_area_map`` (each required
        source rank's overlap with that region, in full-tensor coordinates).

        Returns:
            dict[int, tuple[tuple[int, int], ...]]: Maps each required source rank
            (an entry of the source rank list) to the (start, end) ranges to read,
            relative to that rank's local slice.
        """
        self.to_area = infer_slice_area_by_rank(
            self.to_mesh_shape,
            self.to_tensor_map,
            self.inner_to_rank_id,
            self.full_shape
        )

        local_union_areas_map: dict[int, tuple[tuple[int, int], ...]] = {}
        self.global_union_area_map.clear()

        for inner_rank_id in self.inner_deredundancy_from_rank_list:
            from_area = infer_slice_area_by_rank(
                self.from_mesh_shape,
                self.from_tensor_map,
                inner_rank_id,
                self.full_shape
            )

            union_area = infer_intersection(from_area, self.to_area)
            if union_area is None:
                continue

            source_rank = self.from_rank_list[inner_rank_id]
            self.global_union_area_map[source_rank] = union_area
            # Shift the overlap into the source rank's local coordinates.
            local_union_areas_map[source_rank] = tuple(
                (union_range[0] - from_range[0], union_range[1] - from_range[0])
                for union_range, from_range in zip(union_area, from_area)
            )

        return local_union_areas_map

    def get_real_tensor(self, from_tensor_map: dict[int, np.ndarray]) -> np.ndarray:
        """
        Assembles the final tensor for the target rank from collected slices.

        Only the source ranks reported by :meth:`infer_all_tensor_offset` are
        used. Extra entries in ``from_tensor_map`` are ignored. If the offsets
        have not been computed yet, they are computed first.

        Args:
            from_tensor_map (dict[int, np.ndarray]): Maps source ranks to the
                slice each one contributed. Each slice has the shape of that
                rank's overlap with the target area.

        Returns:
            np.ndarray: Assembled tensor for the target rank.

        Raises:
            ValueError: If input is empty, or required slices are missing or
                have incorrect shapes.
        """
        if not from_tensor_map:
            raise ValueError("Input from_tensor_map cannot be empty")

        # Do not require the caller to have invoked infer_all_tensor_offset().
        if self.to_area is None:
            self.infer_all_tensor_offset()

        # Validate input slices.
        for from_rank_id, from_area in self.global_union_area_map.items():
            if from_rank_id not in from_tensor_map:
                raise ValueError(
                    f"Missing slice data from rank {from_rank_id}. "
                    "Please provide all required slices from infer_all_tensor_offset."
                )

            expected_shape = tuple(end - start for start, end in from_area)
            actual_shape = from_tensor_map[from_rank_id].shape
            if expected_shape != actual_shape:
                raise ValueError(
                    f"Slice from rank {from_rank_id} has incorrect shape. "
                    f"Expected {expected_shape}, got {actual_shape}."
                )

        # Take the dtype from a required slice; fall back to any provided slice
        # when nothing is required (e.g. a zero-sized target area).
        required_ranks = list(self.global_union_area_map)
        dtype_source = (
            from_tensor_map[required_ranks[0]] if required_ranks
            else next(iter(from_tensor_map.values()))
        )
        to_slice_shape = [end - start for start, end in self.to_area]
        real_tensor = np.zeros(to_slice_shape, dtype=dtype_source.dtype)

        # Place each required slice at its offset within the target area.
        for from_rank_id, from_area in self.global_union_area_map.items():
            assign_slices = tuple(
                slice(from_axis[0] - to_axis[0], from_axis[1] - to_axis[0])
                for from_axis, to_axis in zip(from_area, self.to_area)
            )
            real_tensor[assign_slices] = from_tensor_map[from_rank_id]

        return real_tensor