Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / distributed_checkpoint / reshard.py: 8%

160 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""resharding tensor""" 

16import operator 

17from typing import Any, Optional, Union 

18from functools import reduce 

19import numpy as np 

20 

21 

def check_layout(layout: Optional[Any], name: str) -> None:
    """
    Validates that a layout contains required attributes with correct types.

    Args:
        layout (Optional[Any]): Layout object to validate. Falsy values
            (e.g. ``None``) are accepted and skipped without any check.
        name (str): Name of the layout (for error messages).

    Raises:
        ValueError: If layout is missing required attributes, has an empty
            mesh shape, or its rank list size does not match the device count.
        TypeError: If layout components are not tuples/lists.
    """
    if not layout:
        return

    # Check for required attributes
    required_attrs = ['mesh_shape', '_tensor_map', '_rank_list']
    for attr in required_attrs:
        if not hasattr(layout, attr):
            raise ValueError(
                f"Layout {name} must contain attribute {attr}"
            )

    # Validate component types
    def check_type_is_sequence(obj: Any, obj_name: str) -> None:
        if not isinstance(obj, (tuple, list)):
            raise TypeError(
                f"Layout {name} {obj_name} must be tuple or list, "
                f"but got {type(obj).__name__}"
            )

    layout_dict = layout.to_dict()
    check_type_is_sequence(layout_dict['mesh_shape'], 'mesh_shape')
    check_type_is_sequence(layout_dict['tensor_map'], 'tensor_map')
    check_type_is_sequence(layout_dict['rank_list'], 'rank_list')

    # An empty mesh would make the device-count reduction below fail with an
    # opaque TypeError (reduce() of empty sequence with no initial value);
    # reject it explicitly with a clear message instead.
    if not layout_dict['mesh_shape']:
        raise ValueError(f"Layout {name} mesh_shape cannot be empty")

    # Validate rank list size matches device count (product of mesh dims)
    dev_num = reduce(operator.mul, layout_dict['mesh_shape'])
    if len(layout_dict['rank_list']) != dev_num:
        raise ValueError(
            f"Layout {name} rank_list size ({len(layout_dict['rank_list'])}) "
            f"must match device count ({dev_num})"
        )

65 

66 

def rank_id_to_dev_id_list(mesh_shape: tuple[int, ...], rank_id: int) -> list[int]:
    """
    Converts a rank ID to a list of device IDs based on the mesh shape.

    The rank is interpreted as a row-major (last axis fastest) linear index
    into the mesh, and is decomposed into one coordinate per mesh dimension.

    Args:
        mesh_shape (tuple[int, ...]): Shape of the device mesh.
        rank_id (int): Global rank ID to convert.

    Returns:
        list[int]: Device coordinate for each mesh dimension, in mesh order.
    """
    coords: list[int] = []
    remaining = rank_id

    # Peel off coordinates from the fastest-varying (last) axis first.
    for extent in reversed(mesh_shape):
        remaining, coord = divmod(remaining, extent)
        coords.append(coord)

    coords.reverse()
    return coords

86 

87 

def infer_intersection(
    area_a: tuple[tuple[int, int], ...],
    area_b: tuple[tuple[int, int], ...]
) -> Optional[tuple[tuple[int, int], ...]]:
    """
    Calculates the intersection of two tensor slice areas.

    Each area is a per-axis sequence of half-open (start, end) ranges; the
    result is the per-axis overlap, or None if any axis has no overlap.

    Args:
        area_a (tuple[tuple[int, int], ...]): First area to intersect.
        area_b (tuple[tuple[int, int], ...]): Second area to intersect.

    Returns:
        Optional[tuple[tuple[int, int], ...]]: Tuple of intersection boundaries
        or None if no intersection.

    Raises:
        TypeError: If an area or one of its axis ranges is malformed.
        ValueError: If the two areas have different dimensionality.
    """
    def _validate(area: Any) -> None:
        if not isinstance(area, (tuple, list)):
            raise TypeError("Area must be a tuple of ranges")
        for rng in area:
            if not isinstance(rng, (tuple, list)) or len(rng) != 2:
                raise TypeError("Each axis range must be a 2-element tuple")

    _validate(area_a)
    _validate(area_b)

    # Both areas must describe the same number of tensor dimensions.
    if len(area_a) != len(area_b):
        raise ValueError(
            f"Area dimension mismatch: {len(area_a)} vs {len(area_b)}"
        )

    overlaps: list[tuple[int, int]] = []
    for (a_lo, a_hi), (b_lo, b_hi) in zip(area_a, area_b):
        lo = max(a_lo, b_lo)
        hi = min(a_hi, b_hi)
        # Empty overlap on any single axis means no overall intersection.
        if lo >= hi:
            return None
        overlaps.append((lo, hi))

    return tuple(overlaps)

132 

133 

def infer_slice_area_by_rank(
    mesh_shape: tuple[int, ...],
    tensor_map: Union[list[int], tuple[int, ...]],
    rank_id: int,
    full_shape: tuple[int, ...]
) -> tuple[tuple[int, int], ...]:
    """
    Calculates the tensor slice boundaries for a specific rank.

    Args:
        mesh_shape (tuple[int, ...]): Shape of the device mesh.
        tensor_map (Union[list[int], tuple[int, ...]]): Mapping of tensor
            dimensions to mesh dimensions. Each entry is a mesh dim index
            (counted from the last mesh axis), a tuple of such indices when an
            axis is sharded over several mesh dims, or -1 for "not sharded".
        rank_id (int): Rank ID (position within the mesh) to calculate slice for.
        full_shape (tuple[int, ...]): Complete shape of the original tensor.

    Returns:
        tuple[tuple[int, int], ...]: Tuple of (start, end) boundaries for each
        tensor dimension (half-open ranges).

    Raises:
        ValueError: If a tensor dimension is not evenly divisible by the
            number of slices it is split into.
    """
    # Device count along a mapped mesh dimension. The map index counts from
    # the last mesh axis, hence mesh_shape[-dim - 1]; -1 means "replicated"
    # and contributes a factor of 1 (no split).
    def _get_dev_num_along_dim(dim: int) -> int:
        return mesh_shape[-dim - 1] if dim != -1 else 1

    dims = len(full_shape)
    dev_id_list = rank_id_to_dev_id_list(mesh_shape, rank_id)
    area: list[tuple[int, int]] = []

    for axis in range(dims):
        mapping = tensor_map[axis]
        if isinstance(mapping, int):
            mapping = (mapping,)  # Normalize to tuple for uniform handling

        # Total number of slices along this axis is the product of the device
        # counts of every mesh dimension it is mapped onto.
        split_num = 1
        for dim in mapping:
            split_num *= _get_dev_num_along_dim(dim)

        # Linearize this rank's mesh coordinates into the slice index along
        # this axis (last entry of the mapping varies fastest).
        slice_id = 0
        coef = 1
        for dim in reversed(mapping):
            if dim == -1:
                continue
            slice_id += dev_id_list[-dim - 1] * coef
            coef *= _get_dev_num_along_dim(dim)

        # Uneven splits are unsupported: every slice must be the same size.
        if full_shape[axis] % split_num != 0:
            raise ValueError(
                f"Shape cannot be divided along dimension {axis} by {split_num} devices."
            )
        slice_size = full_shape[axis] // split_num
        start = slice_id * slice_size
        end = start + slice_size
        area.append((start, end))

    return tuple(area)

188 

189 

class ReshardHandler:
    """
    Handles tensor resharding between different distributed layouts.

    This class manages the process of reshaping and redistributing tensors between
    different parallel layouts. It calculates necessary tensor slices, validates
    input layouts, and assembles the final tensor for the target rank.

    Typical usage: construct, call ``infer_all_tensor_offset()`` to learn which
    slice each source rank must provide, gather those slices, then call
    ``get_real_tensor()`` to assemble the target rank's tensor.

    Args:
        param_name (str): Name of the parameter (without pipeline stage prefix).
        full_shape (tuple[int, ...]): Complete shape of the tensor before sharding.
        from_layout (Optional[Any]): Source layout containing mesh shape, tensor map, and rank list.
        to_layout (Optional[Any]): Target layout containing mesh shape, tensor map, and rank list.
        to_rank_id (int): Target rank ID to receive the resharded tensor.

    Raises:
        ValueError: If both layouts are None or layouts contain invalid attributes
        TypeError: If layout components are not tuples/lists
    """
    def __init__(
        self,
        param_name: str,
        full_shape: tuple[int, ...],
        from_layout: Optional[Any],
        to_layout: Optional[Any],
        to_rank_id: int
    ):
        # Validate input layouts before reading any of their fields.
        check_layout(from_layout, 'from_layout')
        check_layout(to_layout, 'to_layout')

        if from_layout is None and to_layout is None:
            raise ValueError("`from_layout` and `to_layout` cannot both be None.")

        # Initialize basic attributes
        self.param_name = param_name
        self.full_shape = full_shape

        # Process source layout configuration. A missing source layout is
        # modeled as a single-device mesh holding the full tensor (a tensor
        # map of all zeros over a (1,) mesh yields split factor 1 per axis).
        if from_layout is None:
            self.from_mesh_shape = (1,)
            self.from_tensor_map = tuple(0 for _ in full_shape)
            self.from_rank_list = [0]
        else:
            # NOTE(review): to_dict() is assumed to expose the keys
            # 'mesh_shape', 'tensor_map' and 'rank_list' — the same keys
            # check_layout() validates.
            from_layout_dict = from_layout.to_dict()
            self.from_mesh_shape = from_layout_dict["mesh_shape"]
            self.from_tensor_map = from_layout_dict["tensor_map"]
            self.from_rank_list = from_layout_dict["rank_list"]

        # Process target layout configuration (same single-device fallback).
        if to_layout is None:
            self.to_mesh_shape = (1,)
            self.to_tensor_map = tuple(0 for _ in full_shape)
            self.to_rank_list = [0]
            self.to_rank_id = 0
        else:
            to_layout_dict = to_layout.to_dict()
            self.to_mesh_shape = to_layout_dict["mesh_shape"]
            self.to_tensor_map = to_layout_dict["tensor_map"]
            self.to_rank_list = to_layout_dict["rank_list"]
            self.to_rank_id = to_rank_id
            if self.to_rank_id not in self.to_rank_list:
                raise ValueError("Input to_rank_id is not in to_rank_list.")

        # Calculate device counts and internal rank mappings. An "inner" rank
        # is a position inside the rank list (i.e. a linear mesh index), as
        # opposed to the global rank IDs stored in the list itself.
        self.from_dev_num = len(self.from_rank_list)
        self.inner_from_rank_list = range(self.from_dev_num)
        self.inner_to_rank_id = self.to_rank_list.index(self.to_rank_id)

        # Compute redundancy information: the subset of source inner ranks
        # that hold unique (non-replicated) slices of the tensor.
        self.inner_deredundancy_from_rank_list = (
            self._infer_inner_deredundancy_rank_list_by_from_layout()
            if from_layout else [0]
        )
        # Maps global source rank -> absolute slice area required from it;
        # populated by infer_all_tensor_offset().
        self.global_union_area_map: dict[int, tuple[tuple[int, int], ...]] = {}
        self.to_area = ()  # Initialized in infer_all_tensor_offset()

    def _infer_inner_deredundancy_rank_list_by_from_layout(self) -> list[int]:
        """
        Infers ranks containing non-redundant data from the source layout.

        Mesh dimensions that the tensor map never references replicate the
        data; only ranks sitting at coordinate 0 along every such unused
        dimension hold a unique copy.

        Returns:
            List of inner ranks with unique data slices.
        """
        inner_deredundancy_rank_list: list[int] = []
        dev_dim = len(self.from_mesh_shape)

        # Collect the positional mesh dimensions referenced by the tensor map.
        # Map entries count from the last mesh axis, so position = dev_dim -
        # entry - 1. A -1 ("replicated") entry yields position dev_dim, which
        # is out of range and therefore never matches a real dimension below.
        from_dev_map = set()
        for map_dev in self.from_tensor_map:
            if isinstance(map_dev, (list, tuple)):
                for map_dev_inner in map_dev:
                    from_dev_map.add(dev_dim - map_dev_inner - 1)
            else:
                from_dev_map.add(dev_dim - map_dev - 1)

        # Filter ranks with non-redundant data
        unused_dims = [dim for dim in range(dev_dim) if dim not in from_dev_map]
        if not unused_dims:
            # Every mesh dimension shards the tensor: no replication at all.
            return list(self.inner_from_rank_list)
        for rank_id in self.inner_from_rank_list:
            dev_id_list = rank_id_to_dev_id_list(self.from_mesh_shape, rank_id)
            # A rank is redundant if it sits at a non-zero coordinate along
            # any unused (replicating) mesh dimension.
            found_redundant = False
            for dim in unused_dims:
                if dev_id_list[dim] > 0:
                    found_redundant = True
                    break

            # Keep only the canonical (coordinate-0) replica holders.
            if not found_redundant:
                inner_deredundancy_rank_list.append(rank_id)

        return inner_deredundancy_rank_list

    def infer_all_tensor_offset(self) -> dict[int, tuple[tuple[int, int], ...]]:
        """
        Calculates required tensor slices from each source rank.

        Determines which parts of the tensor need to be collected from each source
        rank to assemble the target tensor slice.

        Side effects: sets ``self.to_area`` and rebuilds
        ``self.global_union_area_map`` (consumed later by get_real_tensor()).

        Returns:
            Dictionary mapping global source ranks to their required slice
            offsets, expressed relative to each source rank's own slice.
        """
        # Calculate target area for the current rank in absolute coordinates.
        self.to_area = infer_slice_area_by_rank(
            self.to_mesh_shape,
            self.to_tensor_map,
            self.inner_to_rank_id,
            self.full_shape
        )

        # Calculate required slices from each non-redundant source rank.
        local_union_areas_map: dict[int, tuple[tuple[int, int], ...]] = {}
        self.global_union_area_map.clear()

        for inner_rank_id in self.inner_deredundancy_from_rank_list:
            # Absolute area held by this source rank.
            from_area = infer_slice_area_by_rank(
                self.from_mesh_shape,
                self.from_tensor_map,
                inner_rank_id,
                self.full_shape
            )

            # Overlap between what this source holds and what the target needs.
            union_area = infer_intersection(from_area, self.to_area)
            if union_area is not None:
                # Translate the inner (mesh-index) rank to its global rank ID.
                source_rank = self.from_rank_list[inner_rank_id]
                self.global_union_area_map[source_rank] = union_area

                # Re-express the overlap relative to the source slice's origin,
                # so the source rank can cut it out of its local tensor.
                local_union_areas_map[source_rank] = tuple(
                    (union_range[0] - from_range[0], union_range[1] - from_range[0])
                    for union_range, from_range in zip(union_area, from_area)
                )

        return local_union_areas_map

    def get_real_tensor(self, from_tensor_map: dict[int, np.ndarray]) -> np.ndarray:
        """
        Assembles the final tensor for the target rank from collected slices.

        Requires infer_all_tensor_offset() to have been called first, since it
        relies on ``self.to_area`` and ``self.global_union_area_map``.

        Args:
            from_tensor_map (dict[int, np.ndarray]): Dictionary mapping global
                source ranks to their tensor slices.

        Returns:
            np.ndarray: Assembled tensor for the target rank. Regions not
            covered by any provided slice remain zero.

        Raises:
            ValueError: If input slices are missing or have incorrect shapes
        """
        if not from_tensor_map:
            raise ValueError("Input from_tensor_map cannot be empty")

        # Validate that every required slice is present and correctly shaped.
        for from_rank_id, from_area in self.global_union_area_map.items():
            if from_rank_id not in from_tensor_map:
                raise ValueError(
                    f"Missing slice data from rank {from_rank_id}. "
                    "Please provide all required slices from infer_all_tensor_offset."
                )

            # Expected shape is the extent of the required area per axis.
            expected_shape = tuple(end - start for start, end in from_area)
            actual_shape = from_tensor_map[from_rank_id].shape
            if expected_shape != actual_shape:
                raise ValueError(
                    f"Slice from rank {from_rank_id} has incorrect shape. "
                    f"Expected {expected_shape}, got {actual_shape}."
                )

        # Create the target tensor and paste each slice into place. The dtype
        # is taken from the first provided slice; all slices are assumed to
        # share it.
        to_slice_shape = [end - start for start, end in self.to_area]
        dtype = next(iter(from_tensor_map.values())).dtype
        real_tensor = np.zeros(to_slice_shape, dtype=dtype)

        for from_rank_id, from_slice in from_tensor_map.items():
            # NOTE(review): ranks present in from_tensor_map but absent from
            # global_union_area_map raise KeyError here — callers appear
            # expected to pass exactly the ranks returned by
            # infer_all_tensor_offset().
            from_area = self.global_union_area_map[from_rank_id]

            # Convert the absolute overlap area into indices relative to the
            # target slice's origin.
            assign_slices = tuple(
                slice(from_axis[0] - to_axis[0], from_axis[1] - to_axis[0])
                for from_axis, to_axis in zip(from_area, self.to_area)
            )

            real_tensor[assign_slices] = from_slice

        return real_tensor