Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/distributed_checkpoint/planner.py: 75%

83 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Planner interfaces and implementations""" 

16import abc 

17from dataclasses import dataclass, field 

18from enum import Enum 

19from typing import Any, Optional, Union 

20 

21from hyper_parallel.core.distributed_checkpoint.metadata import Metadata, MetadataIndex 

22 

23 

class WriteItemType(Enum):
    """Enumerates the kinds of write item."""

    # Write item of the tensor kind.
    TENSOR = "tensor"

    # Write item of the byte-IO kind.
    BYTE_IO = "byte_io"

28 

29 

class LoadItemType(Enum):
    """Enumerates the kinds of load item."""

    # Load item of the tensor kind.
    TENSOR = "tensor"

    # Load item of the byte-IO kind.
    BYTE_IO = "byte_io"

34 

35 

@dataclass(frozen=True)
class WriteItem:
    """
    Item to be written to storage.

    Represents a single logical item (tensor or bytes) to be saved.

    Attributes:
        index: Metadata index identifying this item.
        type: Type of write item (TENSOR or BYTE_IO).
        tensor_data: Dictionary describing the tensor (for TENSOR type). It holds
            metadata only, not the tensor payload itself. Default None.
        bytes_io_data: Bytes data (for BYTE_IO type). Default None.
    """
    index: MetadataIndex
    type: WriteItemType
    # Keys: 'chunk' (ChunkStorageMetadata), 'properties' (TensorProperties), 'size' (tuple).
    # Actual tensor data is in planner's tensor cache, not here, to avoid all_gather of tensors.
    tensor_data: Optional[dict[str, Any]] = None
    bytes_io_data: Optional[Union[bytes, Any]] = None  # Bytes or pickle-serializable object

    def tensor_storage_size(self) -> Optional[int]:
        """
        Best-effort storage size estimation in bytes for tensor items.

        The estimate is the number of elements in the local chunk
        (``chunk.sizes``) multiplied by a per-element byte width looked up
        from the string form of ``properties.dtype``.

        Returns:
            Optional[int]: Estimated storage size in bytes, or None when the
            item is not a TENSOR item, ``tensor_data`` is empty, or the
            'chunk' / 'properties' entries are missing. When ``properties``
            has no ``dtype`` the bare element count is returned. A dtype
            whose name is not recognised is assumed to take 4 bytes per
            element.
        """
        if self.type != WriteItemType.TENSOR or not self.tensor_data:
            return None

        chunk = self.tensor_data.get("chunk")
        properties = self.tensor_data.get("properties")
        if chunk is None or properties is None:
            return None

        # Element count of the local chunk (local chunk size, not global size).
        num_elements = 1
        for dim in chunk.sizes:
            num_elements *= int(dim)

        dtype = getattr(properties, "dtype", None)
        if dtype is None:
            # No dtype information available: keep the historical fallback of
            # returning the element count unchanged.
            return num_elements

        # Substring -> bytes per element. Matching is by substring so that
        # framework prefixes ("torch.float32") and unsigned variants
        # ("uint8" contains "int8") resolve too. The first six entries are the
        # originally supported names and stay first so that every dtype string
        # that was recognised before still resolves to the same width.
        dtype_item_sizes = (
            ("int32", 4),
            ("int64", 8),
            ("bfloat16", 2),
            ("float16", 2),
            ("float32", 4),
            ("float64", 8),
            ("complex128", 16),
            ("complex64", 8),
            ("int16", 2),
            ("int8", 1),
            ("float8", 1),
            ("bool", 1),
        )
        dtype_name = str(dtype).lower()
        # Unknown dtypes default to 4 bytes per element.
        elem_size = next(
            (width for name, width in dtype_item_sizes if name in dtype_name),
            4,
        )
        return num_elements * elem_size

93 

94 

95 

@dataclass(frozen=True)
class ReadItem:
    """
    A single logical read from checkpoint storage.

    Each instance maps a region of data in the checkpoint storage onto a
    location in the destination state_dict.

    Attributes:
        type: Kind of load item (TENSOR or BYTE_IO).
        dest_index: Metadata index of the destination inside the state_dict.
        dest_offsets: Offsets into the destination tensor (for TENSOR type).
        storage_index: Metadata index of the source inside the checkpoint.
        storage_offsets: Offsets into the checkpoint storage data.
        lengths: Size of the hypercube to copy (dimensions of the data region).
    """

    type: LoadItemType

    # Destination side: where in the state_dict the data goes.
    dest_index: MetadataIndex
    dest_offsets: tuple

    # Source side: where in the checkpoint the data comes from.
    storage_index: MetadataIndex
    storage_offsets: tuple

    # Extent of the hypercube that is copied.
    lengths: tuple

118 

119 

@dataclass
class SavePlan:
    """
    Describes what has to be written when saving a checkpoint.

    Holds the write items plus optional data specific to the storage layer
    or to the planner.

    Attributes:
        items: WriteItems to be saved. Defaults to an empty list.
        storage_data: Optional storage-specific data. Defaults to None.
        planner_data: Optional planner-specific data. Defaults to None.
    """

    # A fresh list per instance, so plans never share the same list object.
    items: list[WriteItem] = field(default_factory=list)

    # Storage-specific data, keyed by metadata index.
    storage_data: Optional[dict[MetadataIndex, Any]] = None

    # Planner-specific data; may be of any type.
    planner_data: Any = None

135 

136 

@dataclass
class LoadPlan:
    """
    Describes what has to be read when loading a checkpoint.

    Holds the read items plus optional data specific to the storage layer
    or to the planner.

    Attributes:
        items: ReadItems to be loaded. Defaults to an empty list.
        storage_data: Optional storage-specific data. Defaults to None.
        planner_data: Optional planner-specific data. Defaults to None.
    """

    # A fresh list per instance, so plans never share the same list object.
    items: list[ReadItem] = field(default_factory=list)

    # Storage-specific data, keyed by metadata index.
    storage_data: Optional[dict[MetadataIndex, Any]] = None

    # Planner-specific data; may be of any type.
    planner_data: Any = None

152 

class SavePlanner(abc.ABC):
    """Interface that save planners have to implement."""

    @abc.abstractmethod
    def configure_planner(self, state_dict: dict[str, Any], **kwargs) -> None:
        """
        Set the planner up with the state dict.

        Args:
            state_dict (dict[str, Any]): State dict that is going to be saved.
            **kwargs: Extra keyword arguments, for example is_coordinator, rank,
                remove_redundancy or save_to_minimum_rank.
        """

    @abc.abstractmethod
    def build_local_plan(self) -> SavePlan:
        """
        Produce the save plan of the current rank.

        The plan is built from this rank's point of view and holds a WriteItem
        for every tensor and bytes object that this rank needs to save.

        Returns:
            SavePlan: Local save plan with the WriteItems of this rank.
        """

    @abc.abstractmethod
    def build_global_plan(self, all_plans: list[SavePlan]) -> tuple[list[SavePlan], Metadata]:
        """
        Merge the local plans of all ranks into a global plan.

        Besides combining the plans, this step creates the checkpoint metadata.
        Implementations may deduplicate data that is redundant across ranks and
        assign storage indices.

        Args:
            all_plans (list[SavePlan]): Local save plans collected from all ranks.

        Returns:
            tuple[list[SavePlan], Metadata]: The updated plans (one per rank)
            together with the checkpoint metadata describing all saved items.
        """

    @abc.abstractmethod
    def finalize_plan(self, plan: SavePlan) -> SavePlan:
        """
        Apply last adjustments to a plan before it is executed.

        Examples of such adjustments are updating tensor cache keys or
        planner-specific optimizations.

        Args:
            plan (SavePlan): Plan that should be finalized.

        Returns:
            SavePlan: Finalized plan, ready for execution.
        """

    @abc.abstractmethod
    def get_data(self, item: WriteItem) -> Any:
        """
        Look up the runtime data of a write item in the current state_dict.

        Args:
            item (WriteItem): Write item whose data is requested.

        Returns:
            Any: Runtime object that will be written for this item.
        """

221 

222 

class LoadPlanner(abc.ABC):
    """Interface that load planners have to implement."""

    @abc.abstractmethod
    def configure_planner(self, state_dict: dict[str, Any], metadata: Metadata, **kwargs) -> None:
        """
        Set the planner up with the state dict and the checkpoint metadata.

        Args:
            state_dict (dict[str, Any]): State dict to load into; it is modified in-place.
            metadata (Metadata): Metadata of the checkpoint.
            **kwargs: Extra keyword arguments, for example is_coordinator or rank.
        """

    @abc.abstractmethod
    def build_local_plan(self) -> LoadPlan:
        """
        Produce the load plan of the current rank.

        The plan is built from this rank's point of view and holds a ReadItem
        for every tensor and bytes object that this rank needs to load.

        Returns:
            LoadPlan: Local load plan with the ReadItems of this rank.
        """

    @abc.abstractmethod
    def build_global_plan(self, all_plans: list[LoadPlan]) -> list[LoadPlan]:
        """
        Merge the local plans of all ranks into a global plan.

        Implementations may coordinate across ranks or perform optimizations
        while combining the plans.

        Args:
            all_plans (list[LoadPlan]): Local load plans collected from all ranks.

        Returns:
            list[LoadPlan]: The updated load plans, one per rank.
        """

    @abc.abstractmethod
    def finalize_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Apply last adjustments to a plan before it is executed.

        Examples of such adjustments are planner-specific optimizations or
        validations.

        Args:
            plan (LoadPlan): Plan that should be finalized.

        Returns:
            LoadPlan: Finalized plan, ready for execution.
        """

    @abc.abstractmethod
    def acquire_tensor(self, read_item: ReadItem) -> Any:
        """
        Provide the tensor that a read item should be written into.

        The result is a tensor slice/view that receives the data.

        Args:
            read_item (ReadItem): Read item for which a tensor is requested.

        Returns:
            Any: The acquired tensor slice/view (a tensor-like object).
        """

    @abc.abstractmethod
    def apply_tensor(self, read_item: ReadItem, tensor: Any) -> None:
        """
        Apply a tensor once it has been read.

        Args:
            read_item (ReadItem): The read item.
            tensor (Any): Tensor data to apply (a tensor-like object).
        """

    @abc.abstractmethod
    def apply_bytes(self, read_item: ReadItem, value: bytes) -> None:
        """
        Apply bytes data.

        Args:
            read_item (ReadItem): Read item that specifies the destination.
            value (bytes): Bytes data to deserialize and apply.
        """