Coverage for hyper_parallel / core / checkpoint / layout.py: 87%
46 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Layout I/O utilities for saving, loading and gathering layout information."""
16import json
17import os
18from pathlib import Path
19from typing import Union
21from hyper_parallel.platform import get_platform
23platform = get_platform()
def get_current_layout(cell):
    """
    Get current layout from cell

    Collects the layout of every parameter of ``cell`` that carries a layout
    and tags each entry with the parameter's dtype and full shape.

    Args:
        cell (Any): Instance of Cell (model/network object).

    Returns:
        dict: A dictionary with a single key, the current rank ID as a string.
        Its value maps ``param.name`` to that parameter's layout information
        (the result of ``param.layout.to_dict()``). Each entry is extended with
        ``"type"`` (``str(param.dtype)``) and ``"full_shape"`` (``param.shape``).
        Parameters whose ``layout`` is falsy are skipped.

    Raises:
        RuntimeError: If two parameters that carry a layout share the same
            ``param.name``; otherwise one entry would silently overwrite
            the other.
    """
    current_rank = str(platform.get_rank())
    rank_layout = {}
    layout_dict = {current_rank: rank_layout}

    param_dict = platform.parameters_dict(cell)
    # The loop consumes (name, param) pairs. If the platform hands back a
    # mapping instead, iterate its items; iterating a dict directly would
    # yield bare keys and break the unpacking.
    # NOTE(review): confirm which form platform.parameters_dict returns.
    named_params = param_dict.items() if isinstance(param_dict, dict) else param_dict
    for _, param in named_params:
        if not param.layout:
            continue
        key = param.name
        # Entries are keyed by param.name, so guard that key. The previous
        # check compared against the rank-ID keys of layout_dict and could
        # never detect a duplicate parameter name.
        if key in rank_layout:
            raise RuntimeError("param in cell can not have same name")
        entry = param.layout.to_dict()
        entry["type"] = str(param.dtype)
        entry["full_shape"] = param.shape
        rank_layout[key] = entry
    return layout_dict
def save_layout(layout_dict: dict, file_path: Union[Path, str]) -> None:
    """
    Save layout to file

    Serializes ``layout_dict`` as JSON to ``file_path``. The file is written
    in UTF-8 with non-ASCII characters kept as-is, and any existing file is
    overwritten. Missing parent directories are created first.

    Args:
        layout_dict (dict): JSON-serializable layout information, e.g. the
            result of ``get_current_layout`` or ``get_global_layout``.
        file_path (Union[Path, str]): Destination file path.
    """
    target = Path(file_path)
    # Create the parent directory tree so that saving into a not-yet-existing
    # directory does not fail with FileNotFoundError.
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, 'w', encoding="utf-8") as f:
        json.dump(layout_dict, f, ensure_ascii=False)
def load_layout(file_path: Union[Path, str]) -> dict:
    """
    Load layout from file

    Reads a UTF-8 encoded JSON file (the format written by ``save_layout``)
    and returns the decoded content.

    Args:
        file_path (Union[Path, str]): Path of the layout file to read.

    Returns:
        dict: The decoded JSON content. It is expected to be a dict keyed by
        rank ID; the type is not validated here.
    """
    with open(file_path, encoding='utf-8') as layout_file:
        return json.load(layout_file)
def combine_layout(directory: Union[Path, str]) -> dict:
    """
    Combines layout files from the specified directory into a single layout dictionary.

    This function scans the given directory for regular files with a '.layout'
    extension. It loads each one with ``load_layout`` and merges them into one
    dictionary. Files are processed in sorted filename order, so the result is
    deterministic regardless of the order in which the filesystem lists entries.

    Args:
        directory (Union[Path, str]): The directory to scan for layout files.

    Returns:
        dict: A dictionary containing the combined layout information keyed by rank ID.

    Raises:
        ValueError: If duplicate rank IDs are found across the layout files.

    Note:
        Only processes files with '.layout' extension. Subdirectories are not
        searched, and entries that are not regular files are skipped.
    """
    layout_dict = {}
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.layout'):
            continue
        file_path = os.path.join(directory, filename)
        # A directory that happens to be named '*.layout' cannot be opened as a file.
        if not os.path.isfile(file_path):
            continue
        for rank_id, param_dict in load_layout(file_path).items():
            if rank_id in layout_dict:
                raise ValueError(
                    f"rank_id in files must be unique, but rank_id {rank_id!r} "
                    f"appears again in {filename!r}"
                )
            layout_dict[rank_id] = param_dict
    return layout_dict
def get_global_layout(cell) -> dict:
    """
    Get global layout information from all ranks, and gather them into a dict.

    This rank's layout is computed with ``get_current_layout``. The per-rank
    results are collected via ``platform.all_gather_object`` into a list with
    ``platform.get_world_size()`` slots. The gathered dictionaries are then
    merged in list order, so a later entry wins on a key collision.

    Args:
        cell (Any): Instance of Cell (model/network object).

    Returns:
        dict: A dictionary containing the global layout information keyed by rank ID.
    """
    # Pre-sized output buffer; all_gather_object populates it in place
    # (it is read back below without being reassigned).
    gathered = [None] * platform.get_world_size()

    local_layout = get_current_layout(cell)
    platform.all_gather_object(gathered, local_layout)

    merged = {}
    for rank_layout in gathered:
        merged.update(rank_layout)
    return merged