Coverage for hyper_parallel / core / checkpoint / planner.py: 98%

83 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Planner interfaces and implementations""" 

16import abc 

17from dataclasses import dataclass, field 

18from enum import Enum 

19from typing import Any, Optional, Union 

20 

21from hyper_parallel.core.checkpoint.metadata import ( 

22 Metadata, MetadataIndex 

23) 

24 

25 

class WriteItemType(Enum):
    """Discriminator for the kind of payload a WriteItem carries."""

    # Tensor payload; per the WriteItem contract the actual tensor data
    # lives in the planner's tensor cache, not on the item itself.
    TENSOR = "tensor"
    # Raw bytes / pickle-serializable payload.
    BYTE_IO = "byte_io"

30 

31 

class LoadItemType(Enum):
    """Discriminator for the kind of payload a ReadItem loads."""

    # Tensor payload read back into a destination tensor region.
    TENSOR = "tensor"
    # Raw bytes payload to be deserialized and applied.
    BYTE_IO = "byte_io"

36 

37 

@dataclass(frozen=True)
class WriteItem:
    """
    Item to be written to storage.

    Represents a single logical item (tensor or bytes) to be saved.

    Attributes:
        index: Metadata index identifying this item.
        type: Type of write item (TENSOR or BYTE_IO).
        tensor_data: Dictionary containing tensor metadata (for TENSOR type). Default None.
        bytes_io_data: Bytes data (for BYTE_IO type). Default None.
    """
    index: MetadataIndex
    type: WriteItemType
    # Keys: 'chunk' (ChunkStorageMetadata), 'properties' (TensorProperties), 'size' (tuple).
    # Actual tensor data is in planner's tensor cache, not here, to avoid all_gather of tensors.
    tensor_data: Optional[dict[str, Any]] = None
    bytes_io_data: Optional[Union[bytes, Any]] = None  # Bytes or pickle-serializable object

    def tensor_storage_size(self) -> Optional[int]:
        """
        Best-effort storage size estimation in bytes for tensor items.

        Returns:
            Optional[int]: Estimated storage size in bytes for tensor items,
                or None if estimation cannot be performed (e.g., for non-tensor
                items or when chunk/properties metadata is missing).
        """
        if self.type != WriteItemType.TENSOR or not self.tensor_data:
            return None

        # Estimate purely from metadata; the tensor itself is not held here.
        chunk = self.tensor_data.get("chunk")
        properties = self.tensor_data.get("properties")
        if chunk is None or properties is None:
            return None

        # Element count from the local chunk size (not the global tensor size).
        num_elements = 1
        for dim in chunk.sizes:
            num_elements *= int(dim)

        dtype_str = getattr(properties, "dtype", None)
        if dtype_str is None:
            # No dtype information: fall back to the raw element count.
            return int(num_elements)

        # Map common dtype names to per-element byte sizes. Matching is by
        # substring so framework-prefixed names (e.g. "torch.float32") resolve.
        # "float16" is a substring of "bfloat16", but both map to 2 bytes, so
        # iteration order is benign here.
        dtype_to_size_map = {
            "int32": 4, "int64": 8, "bfloat16": 2, "float16": 2, "float32": 4, "float64": 8
        }
        dtype_str_lower = str(dtype_str).lower()
        elem_size = next(
            (nbytes for name, nbytes in dtype_to_size_map.items() if name in dtype_str_lower),
            4,  # default element size for unrecognized dtypes
        )
        return int(num_elements) * int(elem_size)

95 

96 

97 

@dataclass(frozen=True)
class ReadItem:
    """
    A single logical read operation against checkpoint storage.

    Maps one region of checkpoint data onto its destination location inside
    the target state_dict.

    Attributes:
        type: Kind of payload being loaded (TENSOR or BYTE_IO).
        dest_index: Metadata index of the destination entry in the state_dict.
        dest_offsets: Per-dimension offsets into the destination tensor
            (for TENSOR type).
        storage_index: Metadata index of the source entry in the checkpoint.
        storage_offsets: Per-dimension offsets into the stored checkpoint data.
        lengths: Per-dimension extents of the hypercube region to copy.
    """
    type: LoadItemType
    dest_index: MetadataIndex
    dest_offsets: tuple
    storage_index: MetadataIndex
    storage_offsets: tuple
    lengths: tuple

120 

121 

@dataclass
class SavePlan:
    """
    Plan describing what a rank will write during checkpoint save.

    Bundles the rank's WriteItems with optional storage- and
    planner-specific payloads.

    Attributes:
        items: WriteItems scheduled for saving. Defaults to an empty list.
        storage_data: Optional storage-specific mapping keyed by MetadataIndex.
        planner_data: Optional planner-specific data; may be of any type.
    """
    items: list[WriteItem] = field(default_factory=list)
    storage_data: Optional[dict[MetadataIndex, Any]] = None
    planner_data: Any = None

137 

138 

@dataclass
class LoadPlan:
    """
    Plan describing what a rank will read during checkpoint load.

    Bundles the rank's ReadItems with optional storage- and
    planner-specific payloads.

    Attributes:
        items: ReadItems scheduled for loading. Defaults to an empty list.
        storage_data: Optional storage-specific mapping keyed by MetadataIndex.
        planner_data: Optional planner-specific data; may be of any type.
    """
    items: list[ReadItem] = field(default_factory=list)
    storage_data: Optional[dict[MetadataIndex, Any]] = None
    planner_data: Any = None

154 

155 

class SavePlanner(abc.ABC):
    """Abstract interface defining how a checkpoint save plan is produced."""

    @abc.abstractmethod
    def configure_planner(self, state_dict: dict[str, Any], **kwargs) -> None:
        """
        Prepare the planner for saving the given state_dict.

        Args:
            state_dict (dict[str, Any]): State dict to be saved.
            **kwargs: Extra options (e.g. is_coordinator, rank,
                remove_redundancy, save_to_minimum_rank).
        """

    @abc.abstractmethod
    def build_local_plan(self) -> SavePlan:
        """
        Produce this rank's local save plan.

        The plan holds a WriteItem for every tensor and bytes object the
        current rank is responsible for writing.

        Returns:
            SavePlan: Local plan with this rank's WriteItems.
        """

    @abc.abstractmethod
    def build_global_plan(self, all_plans: list[SavePlan]) -> tuple[list[SavePlan], Metadata]:
        """
        Merge per-rank plans into a coordinated global plan.

        Implementations may deduplicate items that appear on several ranks
        and assign storage indices, producing the checkpoint metadata as
        part of the same pass.

        Args:
            all_plans (list[SavePlan]): Local plans gathered from every rank.

        Returns:
            tuple[list[SavePlan], Metadata]: Updated plans (one per rank)
                plus metadata describing every saved item.
        """

    @abc.abstractmethod
    def finalize_plan(self, plan: SavePlan) -> SavePlan:
        """
        Apply last-mile adjustments before the plan executes.

        Typical work includes refreshing tensor cache keys or other
        planner-specific optimizations.

        Args:
            plan (SavePlan): Plan to finalize.

        Returns:
            SavePlan: Plan ready for execution.
        """

    @abc.abstractmethod
    def get_tensor(self, index: MetadataIndex) -> Any:
        """
        Resolve the tensor stored under the given metadata index.

        Lets storage writers fetch tensor data on demand instead of carrying
        it inside WriteItem.tensor_data (which would be transmitted during
        all_gather operations).

        Args:
            index (MetadataIndex): Index identifying the tensor.

        Returns:
            Any: Tensor-like object, or None when not found.
        """

228 

229 

class LoadPlanner(abc.ABC):
    """Abstract interface defining how a checkpoint load plan is produced."""

    @abc.abstractmethod
    def configure_planner(self, state_dict: dict[str, Any], metadata: Metadata, **kwargs) -> None:
        """
        Prepare the planner for loading into the given state_dict.

        Args:
            state_dict (dict[str, Any]): Target state_dict; modified in-place.
            metadata (Metadata): Checkpoint metadata.
            **kwargs: Extra options (e.g. is_coordinator, rank).
        """

    @abc.abstractmethod
    def build_local_plan(self) -> LoadPlan:
        """
        Produce this rank's local load plan.

        The plan holds a ReadItem for every tensor and bytes object the
        current rank needs to load.

        Returns:
            LoadPlan: Local plan with this rank's ReadItems.
        """

    @abc.abstractmethod
    def build_global_plan(self, all_plans: list[LoadPlan]) -> list[LoadPlan]:
        """
        Merge per-rank plans into a coordinated global plan.

        Implementations may coordinate across ranks or apply optimizations
        before execution.

        Args:
            all_plans (list[LoadPlan]): Local plans gathered from every rank.

        Returns:
            list[LoadPlan]: Updated load plans, one per rank.
        """

    @abc.abstractmethod
    def finalize_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Apply last-mile adjustments before the plan executes.

        Typical work includes planner-specific optimizations or validations.

        Args:
            plan (LoadPlan): Plan to finalize.

        Returns:
            LoadPlan: Plan ready for execution.
        """

    @abc.abstractmethod
    def acquire_tensor(self, read_item: ReadItem) -> Any:
        """
        Obtain the destination tensor region for a read item.

        Args:
            read_item (ReadItem): Read item to resolve.

        Returns:
            Any: Tensor slice/view into which loaded data should be written.
        """

    @abc.abstractmethod
    def apply_tensor(self, read_item: ReadItem, tensor: Any) -> None:
        """
        Commit tensor data after it has been read.

        Args:
            read_item (ReadItem): Read item being applied.
            tensor (Any): Tensor-like data to apply.
        """

    @abc.abstractmethod
    def apply_bytes(self, read_item: ReadItem, value: bytes) -> None:
        """
        Deserialize and commit bytes data to its destination.

        Args:
            read_item (ReadItem): Read item specifying the destination.
            value (bytes): Bytes payload to deserialize and apply.
        """