Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / distributed_checkpoint / util.py: 39%

139 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Common utility functions.""" 

16import dataclasses 

17from collections import defaultdict 

18from collections.abc import Collection, Mapping 

19from pathlib import Path 

20from typing import Any, Union 

21 

22from hyper_parallel.core.distributed_checkpoint.metadata import ChunkStorageMetadata, MetadataIndex 

23from hyper_parallel.core.distributed_checkpoint.planner import SavePlan, WriteItem 

24from hyper_parallel.core.distributed_checkpoint.reshard import infer_slice_area_by_rank 

25from hyper_parallel.core.dtensor.dtensor import DTensor 

26from hyper_parallel.platform import get_platform 

27 

28 

# Resolve the platform abstraction once at import time; `platform` is used
# below for rank queries (see create_chunk_list_for_tensor).
platform = get_platform()
# Alias of the active backend's tensor class; used for isinstance checks
# and as the leaf-tensor marker in traverse_state_dict.
Tensor = platform.Tensor

31 

32 

def check_path(path: Union[Path, str]) -> None:
    """
    Ensure the directory portion of ``path`` exists, creating it if needed.

    Args:
        path (Union[Path, str]): path to check. May be a bare file name, a
            pure directory, or a file name with a directory part. If the path
            does not exist, its directory component (the parent when the path
            has a file suffix, otherwise the path itself) is created along
            with any missing intermediate directories.
    """
    target = Path(path) if isinstance(path, str) else path

    if target.exists():
        return

    # A non-empty suffix means `target` names a file: create its parent.
    # Otherwise treat the whole path as a directory and create it directly.
    directory = target.parent if target.suffix else target
    directory.mkdir(parents=True, exist_ok=True)

51 

52 

def has_valid_filename(path: Path) -> bool:
    """
    Check whether ``path`` ends in a valid filename. A filename must carry
    both a stem and a suffix, the suffix must hold at least one character
    after the dot, and both stem and suffix must contain at least one letter
    (digits/underscores are allowed in addition).

    Args:
        path (Path): path to check.

    Return:
        bool: whether path has a valid filename.
    """
    stem, suffix = path.stem, path.suffix

    # Name, stem and suffix must all be non-empty, and the suffix must be
    # more than just the leading dot.
    if not (path.name and stem and suffix and len(suffix) > 1):
        return False

    stem_has_letter = any(ch.isalpha() for ch in stem)
    suffix_has_letter = any(ch.isalpha() for ch in suffix[1:])
    return stem_has_letter and suffix_has_letter

73 

74 

def narrow_tensor_by_index(tensor: Any, offsets: tuple, lengths: tuple) -> Any:
    """
    Narrow the tensor by (offsets, lengths) per dimension.

    Used for resharding operations to extract a slice from a tensor.
    Compatible with both torch and mindspore (uses slice indexing).

    Args:
        tensor (Any): The tensor to narrow (tensor-like object supporting indexing).
        offsets (tuple): Tuple of offsets per dimension.
        lengths (tuple): Tuple of lengths per dimension.

    Returns:
        Any: The narrowed tensor slice (tensor-like object).
    """
    # An empty offsets/lengths spec means "take the whole tensor".
    if not offsets or not lengths:
        return tensor

    index = []
    for start, length in zip(offsets, lengths):
        begin = int(start)
        index.append(slice(begin, begin + int(length)))
    return tensor[tuple(index)]

97 

98 

def chunk_to_area(chunk: ChunkStorageMetadata) -> tuple[tuple[int, int], ...]:
    """
    Convert ChunkStorageMetadata to (start, end) area per dimension.

    Args:
        chunk (ChunkStorageMetadata): chunk metadata with parallel ``offsets``
            and ``sizes`` sequences, one entry per dimension.

    Returns:
        tuple[tuple[int, int], ...]: (start, end) pair per dimension, where
        end is exclusive (``offset + size``).
    """
    # zip pairs offsets with sizes directly, avoiding index bookkeeping
    # over range(len(...)).
    return tuple((off, off + size) for off, size in zip(chunk.offsets, chunk.sizes))

113 

114 

def create_chunk_list_for_tensor(obj: Union[Tensor, DTensor]) -> list[ChunkStorageMetadata]:
    """
    Create list of local chunks for the given object (DTensor or plain tensor).

    Used to determine what this rank needs to load (resharding).

    Args:
        obj (Union[Tensor, DTensor]): hyper DTensor or platform Tensor.

    Returns:
        list[ChunkStorageMetadata]: List of ChunkStorageMetadata representing
            local chunks needed by this rank (empty when this rank holds no
            shard of a DTensor).

    Raises:
        ValueError: if ``obj`` is neither a DTensor nor a platform Tensor.
    """
    def _full_tensor_chunk(shape) -> list[ChunkStorageMetadata]:
        # One chunk covering the whole tensor, anchored at the origin.
        return [ChunkStorageMetadata(offsets=(0,) * len(shape), sizes=tuple(shape))]

    if isinstance(obj, DTensor):
        layout = obj.layout
        if layout is None:
            # No layout: treat the DTensor as unsharded.
            return _full_tensor_chunk(obj.shape if hasattr(obj, "shape") else obj.to_local().shape)

        # Layout attributes may be exposed publicly or through private
        # fields, depending on the layout implementation — probe both.
        mesh_shape = getattr(layout, "mesh_shape", None) or getattr(layout, "_mesh", None)
        tensor_map = getattr(layout, "tensor_map", None) or getattr(layout, "_tensor_map", None)
        rank_list = getattr(layout, "rank_list", None) or getattr(layout, "_rank_list", None)

        if mesh_shape is None or tensor_map is None or rank_list is None:
            # Incomplete sharding info: fall back to a single full chunk.
            return _full_tensor_chunk(obj.shape if hasattr(obj, "shape") else obj.to_local().shape)

        current_rank = platform.get_rank()
        if current_rank not in rank_list:
            # This rank holds no shard of the tensor.
            return []

        inner_rank_id = rank_list.index(current_rank)
        slice_area = infer_slice_area_by_rank(
            mesh_shape=mesh_shape,
            tensor_map=tensor_map,
            rank_id=inner_rank_id,
            full_shape=obj.shape,
        )
        offsets = tuple(start for start, _ in slice_area)
        sizes = tuple(end - start for start, end in slice_area)
        return [ChunkStorageMetadata(offsets=offsets, sizes=sizes)]

    if isinstance(obj, Tensor):
        # platform.Tensor has exactly one chunk in metadata (full tensor)
        return _full_tensor_chunk(tuple(obj.shape))

    raise ValueError(f"Not support type {type(obj)} for creating chunk list ")

164 

165 

def remove_redundant_plans(
    all_plans: list[SavePlan],
    save_to_minimum_rank: bool = False,
) -> list[SavePlan]:
    """
    Remove duplicate entries across SavePlans. For each duplicate, only one plan
    keeps the entry. The selection prefers the smallest planned storage size
    (or the minimum rank when save_to_minimum_rank is True).

    Args:
        all_plans (list[SavePlan]): List of save plans to deduplicate.
        save_to_minimum_rank (bool): If True, assign duplicates to the minimum rank; else to plan with minimal storage.
            Default False.
    """
    plan_count = len(all_plans)

    # owners: item index -> positions of the plans containing it.
    # registry: item index -> a WriteItem carrying it (any plan's copy).
    owners: dict[MetadataIndex, set[int]] = defaultdict(set)
    registry: dict[MetadataIndex, WriteItem] = {}
    for position, plan in enumerate(all_plans):
        for entry in plan.items:
            owners[entry.index].add(position)
            registry[entry.index] = entry

    # Item indices each plan still keeps; duplicates get pruned below.
    kept: list[set[MetadataIndex]] = [
        {entry.index for entry in plan.items} for plan in all_plans
    ]

    sizes = [0] * plan_count

    # Partition: items owned by exactly one plan vs. items owned by several.
    uniques = [(key, owner_set) for key, owner_set in owners.items() if len(owner_set) == 1]
    duplicates = [(key, owner_set) for key, owner_set in owners.items() if len(owner_set) > 1]

    # Account for unique items first so they seed the per-plan sizes and
    # don't skew the load balancing of duplicates afterwards.
    for key, owner_set in uniques:
        (only_owner,) = owner_set
        sizes[only_owner] += registry[key].tensor_storage_size() or 1

    # Duplicates go to the minimum rank, or to the least-loaded plan.
    for key, owner_set in duplicates:
        if save_to_minimum_rank:
            winner = min(owner_set)
        else:
            winner = min(owner_set, key=sizes.__getitem__)

        sizes[winner] += registry[key].tensor_storage_size() or 1
        for loser in owner_set:
            if loser != winner:
                kept[loser].discard(key)

    if plan_count != len(kept):
        raise AssertionError("len(all_plans) != len(remaining_items)")

    # Rebuild each plan with only the items it kept, preserving item order.
    return [
        dataclasses.replace(
            plan, items=[entry for entry in plan.items if entry.index in kept[position]]
        )
        for position, plan in enumerate(all_plans)
    ]

238 

239 

def traverse_state_dict(
    state_dict: Any,
    visitor: Any,
) -> None:
    """
    Invoke ``visitor`` for each value recursively in ``state_dict``.
    Mapping will be traversed and ``visitor`` will be applied to the leaf
    elements. ``visitor`` will only be applied to elements in a list or a
    tuple, if the container contains tensors or mappings.
    """

    def _is_leaf(node: Any) -> bool:
        """Return True when `node` contains nothing worth recursing into."""
        if isinstance(node, Mapping):
            return False
        if not isinstance(node, (list, tuple)):
            return True
        # A sequence stays a leaf only while no member is a tensor or a
        # non-leaf container.
        for member in node:
            if isinstance(member, (Mapping, list, tuple)) and not _is_leaf(member):
                return False
            if isinstance(member, Tensor):
                return False
        return True

    def _walk(prefix: tuple[Any, ...], node: Any) -> None:
        if isinstance(node, Mapping):
            for key, child in node.items():
                _walk(prefix + (str(key),), child)
        elif _is_leaf(node):
            visitor(prefix, node)
        elif isinstance(node, (list, tuple)):
            for position, child in enumerate(node):
                _walk(prefix + (position,), child)

    for key, child in state_dict.items():
        _walk((str(key),), child)

280 

281 

def flatten_state_dict(state_dict: Any) -> tuple[dict[str, Any], dict[str, tuple[Any, ...]]]:
    """Flatten a nested state dict to dotted FQN keys; returns ``(flat_dict, fqn -> path)``."""
    flat: dict[str, Any] = {}
    paths: dict[str, tuple[Any, ...]] = {}

    def _record(path: tuple[Any, ...], value: Any) -> None:
        new_fqn = ".".join(map(str, path))
        # Two distinct nested locations collapsing onto one dotted name
        # would silently drop data, so fail loudly instead.
        if new_fqn in flat:
            raise ValueError(
                f"Duplicate flattened FQN {new_fqn!r} when converting nested state_dict; "
                "two different values map to the same dotted name."
            )
        flat[new_fqn] = value
        paths[new_fqn] = path

    traverse_state_dict(state_dict, _record)
    return flat, paths

299 

300 

def set_element(root_dict: Any, path: tuple[Any, ...], value: Any) -> None:
    """Set ``value`` in ``root_dict`` along the ``path`` object path.

    String keys index into mappings and integer keys into lists; missing
    intermediate containers are created on demand, with lists padded by
    ``None`` until the target index is addressable.

    Raises:
        ValueError: if ``path`` is empty.
    """
    if not path:
        raise ValueError("path must be non-empty")

    def _pad(seq: list, index: int) -> None:
        # Grow `seq` with None placeholders so seq[index] is valid.
        seq.extend([None] * (index + 1 - len(seq)))

    node: Any = root_dict

    # Walk consecutive key pairs; the *next* key's type decides whether the
    # intermediate container must be a dict (str key) or a list (int key).
    for current_key, following_key in zip(path, path[1:]):
        default: Any = {} if isinstance(following_key, str) else []
        if isinstance(node, Mapping):
            node = node.setdefault(current_key, default)
        else:
            _pad(node, current_key)
            if node[current_key] is None:
                node[current_key] = default
            node = node[current_key]

    final_key = path[-1]
    if isinstance(final_key, int):
        _pad(node, final_key)

    node[final_key] = value