# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Standard planner implementations for checkpoint save and load."""
from dataclasses import dataclass
import dataclasses
import pickle
from typing import Any, Optional, Union

from hyper_parallel.core.distributed_checkpoint.metadata import (
    Metadata, MetadataIndex, ChunkStorageMetadata,
    TensorStorageMetadata, TensorProperties, BytesStorageMetadata
)
from hyper_parallel.core.distributed_checkpoint.planner import (
    SavePlan, SavePlanner, LoadPlan, LoadPlanner,
    WriteItem, WriteItemType, ReadItem, LoadItemType
)
from hyper_parallel.core.distributed_checkpoint.reshard import infer_slice_area_by_rank, infer_intersection
from hyper_parallel.core.distributed_checkpoint.util import (
    narrow_tensor_by_index,
    chunk_to_area,
    create_chunk_list_for_tensor,
    remove_redundant_plans,
    flatten_state_dict,
    set_element,
)
from hyper_parallel.core.dtensor.dtensor import DTensor, Layout
from hyper_parallel.platform import get_platform

platform = get_platform()
Tensor = platform.Tensor


@dataclass(frozen=True)
class CachedSaveResult:
    """Cached finalized save result keyed by planner cache namespace."""

    final_plan: SavePlan
    metadata: Metadata


class StandardSavePlanner(SavePlanner):
    """Standard implementation of SavePlanner for distributed checkpoint saving."""

    _cached_save_result: dict[str, CachedSaveResult] = {}

    def __init__(
        self,
        enable_plan_caching: bool = True,
        remove_redundancy: bool = True,
        save_to_minimum_rank: bool = False,
    ):
        self.state_dict: Optional[dict[str, Any]] = None
        self.is_coordinator: bool = False
        self.rank: int = 0
        self.remove_redundancy: bool = remove_redundancy
        self.save_to_minimum_rank: bool = save_to_minimum_rank
        self.flatten_state_dict: bool = True
        self._enable_plan_caching: bool = enable_plan_caching
        self._cached_plans_key: str = self.__class__.__name__

    def configure_planner(self, state_dict: dict[str, Any], **kwargs) -> None:
        """
        Configure planner.

        Args:
            state_dict (dict[str, Any]): The state_dict to save.
            **kwargs: Additional keyword arguments (e.g., is_coordinator, rank, remove_redundancy,
                save_to_minimum_rank).
        """
        self.is_coordinator = kwargs.get("is_coordinator", False)
        self.rank = kwargs.get("rank", 0)
        self.remove_redundancy = kwargs.get("remove_redundancy", self.remove_redundancy)
        self.save_to_minimum_rank = kwargs.get("save_to_minimum_rank", self.save_to_minimum_rank)
        self.flatten_state_dict = kwargs.get("flatten_state_dict", True)

        use_collectives = bool(kwargs.get("use_collectives", True))
        if not use_collectives:
            self.remove_redundancy = False
            self._enable_plan_caching = False
        elif "enable_plan_caching" in kwargs:
            self._enable_plan_caching = bool(kwargs["enable_plan_caching"])

        if self.flatten_state_dict:
            state_dict, self.name_mapping = flatten_state_dict(state_dict)
        self.state_dict = state_dict
        self._cached_plans_key = self._build_cache_key(state_dict)

    def _build_cache_key(self, state_dict: dict[str, Any]) -> str:
        """Build a stable cache namespace from sorted state_dict keys."""
        return f"{self.__class__.__name__}:{'||'.join(sorted(state_dict.keys()))}"

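    # Example key (illustrative): for flattened state_dict keys "bias" and
    # "weight" this yields "StandardSavePlanner:bias||weight"; sorting keeps
    # the key independent of dict insertion order.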

    def build_local_plan(self) -> SavePlan:
        """
        Create local save plan.

        Returns:
            SavePlan: Local save plan containing WriteItems for this rank.
        """
        if self.state_dict is None:
            raise RuntimeError("Planner not set up")

        def compute_global_offsets(global_shape: tuple[int, ...], dtensor_layout: Layout) -> tuple[int, ...]:
            """
            Compute the offsets of the local tensor in the global tensor based on layout.

            Args:
                global_shape (tuple[int, ...]): Global shape of the tensor.
                dtensor_layout (Layout): Layout of the DTensor.

            Returns:
                tuple[int, ...]: Tuple of offsets for each dimension.
            """
            if dtensor_layout is None:
                # If layout is None, return all zeros (no sharding)
                return tuple(0 for _ in global_shape)

            # Validate layout attributes
            if not hasattr(dtensor_layout, 'mesh_shape') or dtensor_layout.mesh_shape is None:
                raise ValueError("Layout must have mesh_shape attribute")
            if not hasattr(dtensor_layout, 'tensor_map') or dtensor_layout.tensor_map is None:
                raise ValueError("Layout must have tensor_map attribute")
            if not hasattr(dtensor_layout, 'rank_list') or dtensor_layout.rank_list is None:
                raise ValueError("Layout must have rank_list attribute")

            current_rank = self.rank
            if current_rank not in dtensor_layout.rank_list:
                raise ValueError(
                    f"Current rank {current_rank} not found in layout's rank_list {dtensor_layout.rank_list}")

            inner_rank_id = dtensor_layout.rank_list.index(current_rank)
            # Calculate slice area using infer_slice_area_by_rank
            slice_area = infer_slice_area_by_rank(
                mesh_shape=dtensor_layout.mesh_shape,
                tensor_map=dtensor_layout.tensor_map,
                rank_id=inner_rank_id,
                full_shape=global_shape
            )
            # Extract offsets (start values) from slice_area
            return tuple(start for start, _ in slice_area)

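        # Illustrative sketch (hypothetical numbers): for a tensor with
        # global_shape (8, 4) whose layout assigns this rank the row block
        # 4..8, infer_slice_area_by_rank returns per-dimension (start, end)
        # pairs [(4, 8), (0, 4)], so the offsets extracted above are (4, 0).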

        items = []
        for fqn, obj in self.state_dict.items():
            # Check if it's a DTensor
            if isinstance(obj, DTensor):
                # Create write item for DTensor
                local_tensor = obj.to_local()
                layout = obj.layout

                # Get chunk metadata with offsets
                if layout:
                    offsets = compute_global_offsets(obj.shape, layout)
                else:
                    offsets = (0,) * len(local_tensor.shape)

                sizes = local_tensor.shape
                chunk = ChunkStorageMetadata(offsets=offsets, sizes=sizes)
                # Get tensor properties
                dtype_str = str(local_tensor.dtype) if hasattr(local_tensor, 'dtype') else 'unknown'
                properties = TensorProperties(dtype=dtype_str)
                # Create write item for this tensor
                index = MetadataIndex(fqn=fqn, offset=offsets, index=None)
                write_item = WriteItem(
                    index=index,
                    type=WriteItemType.TENSOR,
                    tensor_data={
                        'chunk': chunk,
                        'properties': properties,
                        'size': obj.shape,
                    }
                )
                items.append(write_item)
            elif isinstance(obj, Tensor):
                # Create write item for platform.Tensor: build single chunk with tensor's own size
                dtype_str = str(obj.dtype) if hasattr(obj, 'dtype') else 'unknown'
                properties = TensorProperties(dtype=dtype_str)
                # Single chunk covering the whole tensor (offsets=0, sizes=shape)
                chunk = ChunkStorageMetadata(
                    offsets=(0,) * len(obj.shape),
                    sizes=obj.shape,
                )
                index = MetadataIndex(fqn=fqn, offset=(0,) * len(obj.shape), index=None)
                write_item = WriteItem(
                    index=index,
                    type=WriteItemType.TENSOR,
                    tensor_data={
                        'chunk': chunk,
                        'properties': properties,
                        'size': obj.shape,
                    }
                )
                items.append(write_item)
            else:
                # Handle non-tensor types (bytes, etc.)
                index = MetadataIndex(fqn=fqn)
                write_item = WriteItem(
                    index=index,
                    type=WriteItemType.BYTE_IO,
                    bytes_io_data=None
                )
                items.append(write_item)

        plan = SavePlan(items=items)
        if self.flatten_state_dict:
            plan.planner_data = self.name_mapping
        return plan

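    # Usage sketch (illustrative; the save entry point and the collective that
    # gathers plans across ranks are assumed, not shown):
    #
    #   planner = StandardSavePlanner()
    #   planner.configure_planner(state_dict, rank=0, is_coordinator=True)
    #   local_plan = planner.build_local_plan()
    #   # ... gather local plans from all ranks, then:
    #   # global_plans, metadata = planner.build_global_plan(all_plans)
    #   for item in local_plan.items:
    #       data = planner.get_data(item)  # object handed to the writer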

    def build_global_plan(self, all_plans: list[SavePlan]) -> tuple[list[SavePlan], Metadata]:
        """
        Build global plan from all local plans.

        Collects chunks from all ranks, validates consistency, and creates metadata for the checkpoint.

        Args:
            all_plans (list[SavePlan]): List of local plans from all ranks.

        Returns:
            tuple[list[SavePlan], Metadata]: Updated plans and checkpoint metadata.
        """
        # Deduplicate plans if redundancy removal is enabled
        if self.remove_redundancy and len(all_plans) > 1:
            all_plans = remove_redundant_plans(all_plans, save_to_minimum_rank=self.save_to_minimum_rank)

        # Collect all write items by FQN
        fqn_to_chunks: dict[str, list[ChunkStorageMetadata]] = {}
        fqn_to_properties: dict[str, TensorProperties] = {}
        fqn_to_size: dict[str, tuple] = {}
        state_dict_metadata: dict[str, Union[TensorStorageMetadata, BytesStorageMetadata]] = {}

        final_global_plans: list[SavePlan] = []
        for plan in all_plans:
            with_index_items = []
            for item in plan.items:
                if item.type == WriteItemType.TENSOR and item.tensor_data:
                    fqn = item.index.fqn
                    chunk = item.tensor_data['chunk']
                    properties = item.tensor_data['properties']
                    size = item.tensor_data['size']

                    # Validate consistency across ranks
                    if fqn in fqn_to_chunks and (fqn_to_properties[fqn] != properties or fqn_to_size[fqn] != size):
                        raise ValueError(f"Tensor {fqn} has inconsistent properties or size across ranks.")

                    # Initialize FQN entry if not exists
                    if fqn not in fqn_to_chunks:
                        fqn_to_properties[fqn] = properties
                        fqn_to_size[fqn] = size
                        fqn_to_chunks[fqn] = []

                    # Append chunk and set index (platform.Tensor has exactly one chunk)
                    new_index = dataclasses.replace(item.index, index=len(fqn_to_chunks[fqn]))
                    with_index_item = dataclasses.replace(item, index=new_index)
                    with_index_items.append(with_index_item)
                    fqn_to_chunks[fqn].append(chunk)

                elif item.type == WriteItemType.BYTE_IO:
                    with_index_items.append(item)
                    state_dict_metadata[item.index.fqn] = BytesStorageMetadata()
                else:
                    raise ValueError(f"Unsupported write item type: {item.type}")

            final_global_plans.append(dataclasses.replace(plan, items=with_index_items))

        # Create metadata for all tensors
        for fqn, chunks in fqn_to_chunks.items():
            state_dict_metadata[fqn] = TensorStorageMetadata(
                properties=fqn_to_properties[fqn],
                size=fqn_to_size[fqn],
                chunks=chunks
            )

        metadata = Metadata(state_dict_metadata=state_dict_metadata)
        if self.flatten_state_dict:
            merged_mapping = {}
            for p in all_plans:
                merged_mapping.update(p.planner_data)
            metadata.planner_data = merged_mapping
        return final_global_plans, metadata

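    # Resulting metadata sketch (hypothetical two-rank case): a tensor "w" of
    # global size (8, 4) split row-wise across two ranks ends up as
    #   TensorStorageMetadata(size=(8, 4), chunks=[
    #       ChunkStorageMetadata(offsets=(0, 0), sizes=(4, 4)),
    #       ChunkStorageMetadata(offsets=(4, 0), sizes=(4, 4)),
    #   ])
    # and each rank's WriteItem.index records its chunk's position in that list.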

    def finalize_plan(self, plan: SavePlan) -> SavePlan:
        """
        Finalize the plan.

        Args:
            plan (SavePlan): Plan to finalize.

        Returns:
            SavePlan: Finalized plan.
        """
        return plan

    def get_cached_result(self) -> Optional[tuple[SavePlan, Metadata]]:
        """Return cached finalized plan and metadata when plan caching is enabled."""
        if not self._enable_plan_caching:
            return None
        cached_result = StandardSavePlanner._cached_save_result.get(self._cached_plans_key)
        if cached_result is None:
            return None
        return cached_result.final_plan, cached_result.metadata

    def cache_result(self, final_plan: SavePlan, metadata: Metadata) -> None:
        """Store finalized plan and metadata in the class-level planner cache."""
        if not self._enable_plan_caching:
            return
        StandardSavePlanner._cached_save_result[self._cached_plans_key] = CachedSaveResult(
            final_plan=final_plan,
            metadata=metadata,
        )

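    # Caching sketch (illustrative): a save entry point can skip re-planning
    # when the same state_dict layout is saved repeatedly, e.g.
    #
    #   cached = planner.get_cached_result()
    #   if cached is not None:
    #       final_plan, metadata = cached
    #   else:
    #       # ... run build_local_plan / build_global_plan / finalize_plan ...
    #       planner.cache_result(final_plan, metadata)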

    def get_data(self, item: WriteItem) -> Any:
        """
        Get current runtime data from state_dict for a write item.

        Args:
            item (WriteItem): Write item describing what to write.

        Returns:
            Any: Runtime object to be written.
        """
        if self.state_dict is None:
            raise RuntimeError("Planner not set up")
        fqn = item.index.fqn
        if fqn not in self.state_dict:
            raise KeyError(f"Key {fqn} not found in state_dict")
        obj = self.state_dict[fqn]
        if item.type == WriteItemType.TENSOR:
            if isinstance(obj, DTensor):
                return obj.to_local().detach().cpu()
            if isinstance(obj, Tensor):
                return obj.detach().cpu()
            raise TypeError(f"Write item {fqn} expected tensor-like object, got {type(obj)}")
        if item.type == WriteItemType.BYTE_IO:
            return obj
        raise TypeError(f"Unsupported write item type: {item.type}")


def create_read_items_for_chunk_list(
        fqn: str,
        checkpoint_md: TensorStorageMetadata,
        local_chunks: list[ChunkStorageMetadata],
) -> list[ReadItem]:
    """
    Create ReadItems by matching local chunks (what this rank needs) with
    saved chunks (checkpoint_md.chunks), including resharding overlaps.

    Mirrors torch create_read_items_for_chunk_list behavior.

    Args:
        fqn (str): Fully qualified name of the tensor.
        checkpoint_md (TensorStorageMetadata): Tensor storage metadata from checkpoint.
        local_chunks (list[ChunkStorageMetadata]): List of local chunks needed by this rank.

    Returns:
        list[ReadItem]: List of ReadItems for loading the required data.
    """
    read_items: list[ReadItem] = []
    saved_chunks = checkpoint_md.chunks
    if not local_chunks or not saved_chunks:
        return read_items

    for local_idx, local_chunk in enumerate(local_chunks):
        local_area = chunk_to_area(local_chunk)
        for storage_idx, storage_chunk in enumerate(saved_chunks):
            saved_area = chunk_to_area(storage_chunk)
            overlap = infer_intersection(local_area, saved_area)
            if overlap is None:
                continue

            dest_offsets = tuple(overlap[i][0] - local_chunk.offsets[i] for i in range(len(overlap)))
            storage_offsets = tuple(overlap[i][0] - storage_chunk.offsets[i] for i in range(len(overlap)))
            lengths = tuple(overlap[i][1] - overlap[i][0] for i in range(len(overlap)))

            read_items.append(
                ReadItem(
                    type=LoadItemType.TENSOR,
                    dest_index=MetadataIndex(fqn=fqn, offset=local_chunk.offsets, index=local_idx),
                    dest_offsets=dest_offsets,
                    storage_index=MetadataIndex(fqn=fqn, offset=storage_chunk.offsets, index=storage_idx),
                    storage_offsets=storage_offsets,
                    lengths=lengths,
                )
            )
    return read_items

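# Worked sketch (hypothetical numbers): for a 1-D tensor, a saved chunk
# covering rows 0..4 and a local chunk needing rows 2..6 overlap on rows
# 2..4, so the ReadItem gets storage_offsets=(2,), dest_offsets=(0,) and
# lengths=(2,): two rows are read starting at row 2 of the saved chunk and
# land at the start of the local chunk.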

class StandardLoadPlanner(LoadPlanner):
    """
    Standard implementation of LoadPlanner.

    Iterates over the state_dict and creates load plans via chunk lists to support resharding.
    """

    def __init__(self, allow_partial_load: bool = False):
        """
        Args:
            allow_partial_load (bool): If True, allow loading when checkpoint has fewer keys than state_dict.
                Default False.
        """
        self.state_dict: Optional[dict[str, Any]] = None
        self.metadata: Optional[Metadata] = None
        self.is_coordinator: bool = False
        self.rank: int = 0
        self.allow_partial_load = allow_partial_load
        self.flatten_state_dict: bool = True

    def configure_planner(self, state_dict: dict[str, Any], metadata: Metadata, **kwargs) -> None:
        """
        Configure planner with state dict and metadata.

        Args:
            state_dict (dict[str, Any]): The state_dict to load into (modified in-place).
            metadata (Metadata): Checkpoint metadata.
            **kwargs: Additional keyword arguments (e.g., is_coordinator, rank).
        """
        self.state_dict = state_dict
        self.metadata = metadata
        self.is_coordinator = kwargs.get("is_coordinator", False)
        self.rank = kwargs.get("rank", 0)
        self.flatten_state_dict = kwargs.get("flatten_state_dict", True)
        self.original_state_dict = state_dict
        if self.flatten_state_dict:
            state_dict, self.name_mapping = flatten_state_dict(state_dict)
            self.state_dict = state_dict

    def build_local_plan(self) -> LoadPlan:
        """
        Build local load plan.

        Iterates over the state_dict and creates load plans via chunk lists to support resharding.

        Returns:
            LoadPlan: Local load plan containing ReadItems for this rank.
        """
        if self.state_dict is None or self.metadata is None:
            raise RuntimeError("Planner not configured")

        requests: list[ReadItem] = []
        strict = not self.allow_partial_load
        for fqn, obj in self.state_dict.items():
            if fqn not in self.metadata.state_dict_metadata:
                if strict:
                    raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
                continue
            md = self.metadata.state_dict_metadata[fqn]
            if isinstance(md, TensorStorageMetadata):
                obj_size = getattr(obj, "shape", None)
                if obj_size is None or md.size != tuple(obj_size):
                    raise ValueError(
                        f"Size mismatch between saved {md.size} and current: {obj_size} for {fqn}",
                    )
                if isinstance(obj, DTensor):
                    layout = getattr(obj, "layout", None)
                    rank_list = getattr(layout, "rank_list", None) if layout else None
                    if rank_list is None and layout is not None:
                        rank_list = getattr(layout, "_rank_list", None)
                    if layout is not None and rank_list is not None:
                        if get_platform().get_rank() not in rank_list:
                            continue
                # Both DTensor and platform.Tensor: create local chunks and read items
                local_chunks = create_chunk_list_for_tensor(obj)
                requests += create_read_items_for_chunk_list(fqn, md, local_chunks)
            else:
                requests.append(
                    ReadItem(
                        type=LoadItemType.BYTE_IO,
                        dest_index=MetadataIndex(fqn=fqn),
                        dest_offsets=(0,),
                        storage_index=MetadataIndex(fqn=fqn),
                        storage_offsets=(0,),
                        lengths=(0,),
                    )
                )
        return LoadPlan(items=requests)

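    # Usage sketch (illustrative; reading the checkpoint metadata and driving
    # the actual reads are the StorageReader's job, not shown):
    #
    #   planner = StandardLoadPlanner(allow_partial_load=True)
    #   planner.configure_planner(state_dict, metadata, rank=0)
    #   local_plan = planner.build_local_plan()
    #   # The reader then resolves each ReadItem via acquire_tensor /
    #   # apply_tensor (tensors) or apply_bytes (pickled objects).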

    def build_global_plan(self, all_plans: list[LoadPlan]) -> list[LoadPlan]:
        """
        Build global plan from all local plans.

        Currently returns the plans unchanged; a more sophisticated implementation could coordinate reads across ranks.

        Args:
            all_plans (list[LoadPlan]): List of local plans from all ranks.

        Returns:
            list[LoadPlan]: Global plans (currently the input plans, unchanged).
        """
        return all_plans

    def finalize_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Finalize the plan (no-op for default implementation).

        Args:
            plan (LoadPlan): Plan to finalize.

        Returns:
            LoadPlan: Finalized plan.
        """
        return plan

    def acquire_tensor(self, read_item: ReadItem) -> Any:
        """
        Acquire the destination slice (narrow view) for this read_item.

        StorageReader uses this to copy loaded data into the correct region.
        Torch-aligned behavior.

        Args:
            read_item (ReadItem): The read item specifying what to load.

        Returns:
            Any: The destination tensor slice where data should be written
                (tensor-like object).
        """
        if self.state_dict is None:
            raise RuntimeError("Planner not configured")

        fqn = read_item.dest_index.fqn
        if fqn not in self.state_dict:
            raise KeyError(f"Key {fqn} not found in state_dict")

        target = self.state_dict[fqn]
        local_tensor = target.to_local().detach() if isinstance(target, DTensor) else target.detach()
        return narrow_tensor_by_index(
            local_tensor,
            read_item.dest_offsets,
            read_item.lengths,
        )

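    # Sketch (hypothetical values): for a 2-D destination, a ReadItem with
    # dest_offsets=(4, 0) and lengths=(4, 4) resolves to (conceptually) the
    # view local_tensor[4:8, 0:4], into which the reader copies loaded data.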

    def apply_tensor(self, read_item: ReadItem, tensor: Any) -> None:
        """
        Apply tensor after reading.

        After read_data copies into the slice, this is a no-op when tensor is
        the same slice. When the backend has no copy_ (e.g. mindspore),
        read_data passes the loaded slice here; we copy it into the
        destination slice.

        Args:
            read_item (ReadItem): The read item that was processed.
            tensor (Any): The tensor data to apply (tensor-like object).
        """
        if tensor is None:
            return
        dest_slice = self.acquire_tensor(read_item)
        if dest_slice is tensor:
            return
        if hasattr(dest_slice, "copy_"):
            dest_slice.copy_(tensor)
        else:
            # Fallback: write into the destination via slice assignment
            dest_slice[...] = tensor

    def apply_bytes(self, read_item: ReadItem, value: bytes) -> None:
        """
        Load bytes data into state_dict.

        Args:
            read_item (ReadItem): The read item specifying the destination.
            value (bytes): The bytes data to deserialize and load.
        """
        if self.state_dict is None:
            raise RuntimeError("Planner not configured")

        fqn = read_item.dest_index.fqn
        # Deserialize bytes
        obj = pickle.loads(value)
        self.state_dict[fqn] = obj
        if self.flatten_state_dict:
            set_element(self.original_state_dict, self.name_mapping[fqn], obj)


class _DcpMergeLoadPlanner(StandardLoadPlanner):
    """Load planner that merges a distributed checkpoint (dcp) into a full ``state_dict`` in-place."""

    def __init__(self) -> None:
        super().__init__()

    def configure_planner(self, state_dict: dict[str, Any], metadata: Metadata, **kwargs) -> None:
        if len(state_dict) > 0:
            raise ValueError(
                "state_dict must be empty for _DcpMergeLoadPlanner; "
                "it is populated in-place from checkpoint metadata."
            )

        if metadata is None:
            raise ValueError("metadata must not be None for _DcpMergeLoadPlanner.")

        self.is_coordinator = kwargs.get("is_coordinator", False)
        for k, v in metadata.state_dict_metadata.items():
            if isinstance(v, TensorStorageMetadata):
                # Allocate an empty full-size tensor to receive the merged data
                v = platform.empty(
                    platform.list_to_size(v.size),
                    dtype=platform.str_to_dtype(v.properties.dtype),
                )

            state_dict[k] = v
            if metadata.planner_data is not None and k in metadata.planner_data:
                set_element(state_dict, metadata.planner_data[k], v)

        super().configure_planner(
            state_dict,
            metadata,
            is_coordinator=self.is_coordinator,
            flatten_state_dict=True,
        )
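
# Usage sketch (illustrative): merging a sharded checkpoint into one full
# state_dict, e.g. on a single process; the load entry point that reads the
# metadata file and drives the ReadItems is assumed, not shown.
#
#   state_dict: dict[str, Any] = {}
#   planner = _DcpMergeLoadPlanner()
#   planner.configure_planner(state_dict, metadata)
#   # state_dict now holds empty full-size tensors that the reader fills in.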