Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / distributed_checkpoint / layout.py: 69%
65 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Layout I/O utilities for saving, loading and gathering layout information."""
16import json
17import logging
18import os
19from pathlib import Path
20from typing import Any, Union
22from hyper_parallel.platform import get_platform
# Platform backend handle, resolved once when this module is imported. The
# functions in this module use it to query the rank and world size, to
# enumerate a cell's parameters and to run the all-gather collective.
platform = get_platform()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def get_current_layout(cell: Any) -> dict:
    """
    Get the layout of every parameter of ``cell`` on the current rank.

    Args:
        cell (Any): Instance of Cell (model/network object).

    Returns:
        dict: A single-entry dictionary keyed by the current rank ID (as a
            string). Its value maps each parameter name (``param.name``) to
            that parameter's layout information, extended with ``"type"``
            (``str(param.dtype)``) and ``"full_shape"`` (``param.shape``).
            Parameters whose layout is missing or falsy only carry ``"type"``
            and ``"full_shape"``. A ``"mesh_shape"`` entry reported by the
            layout is stored under the key ``"device_matrix"`` instead.

    Raises:
        RuntimeError: If two parameters share the same ``param.name``.
    """
    current_rank = str(platform.get_rank())
    rank_layout = {}

    params_without_layout_attr = []
    # NOTE(review): assumes platform.parameters_dict yields (name, param)
    # pairs -- confirm against the platform implementation.
    for name, param in platform.parameters_dict(cell):
        # Entries are stored under param.name inside the per-rank mapping, so
        # uniqueness has to be checked there. The outer dictionary only ever
        # contains the rank ID and can never reveal a duplicated parameter.
        if param.name in rank_layout:
            raise RuntimeError("param in cell can not have same name")
        if not hasattr(param, "layout"):
            params_without_layout_attr.append(name)

        layout = getattr(param, "layout", None)
        # Copy the layout's dictionary so the keys added below never leak
        # back into whatever object to_dict() returned.
        layout_info = dict(layout.to_dict()) if layout else {}
        if "mesh_shape" in layout_info:
            layout_info["device_matrix"] = layout_info.pop("mesh_shape")
        layout_info["type"] = str(param.dtype)
        layout_info["full_shape"] = param.shape
        rank_layout[param.name] = layout_info

    if params_without_layout_attr:
        logger.info(
            "The following parameters have no layout attribute (layout entry is None): %s",
            params_without_layout_attr,
        )

    return {current_rank: rank_layout}
def save_layout(layout_dict: dict, file_path: Union[Path, str]) -> None:
    """
    Write a layout dictionary to ``file_path`` as UTF-8 encoded JSON.

    Missing parent directories are created first. The file is opened in
    write mode, so existing content is replaced. Non-ASCII characters are
    written as-is rather than escaped.

    Args:
        layout_dict (dict): Layout information to serialize.
        file_path (Union[Path, str]): Destination file.
    """
    target = Path(file_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as stream:
        json.dump(layout_dict, stream, ensure_ascii=False)
def load_layout(file_path: Union[Path, str]) -> dict:
    """
    Read a JSON layout file and return its decoded content.

    Args:
        file_path (Union[Path, str]): Layout file to read (UTF-8 JSON).

    Returns:
        dict: The object decoded from the file.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    source = Path(file_path)
    if source.exists():
        with source.open("r", encoding="utf-8") as stream:
            return json.load(stream)
    raise FileNotFoundError(f"Layout file not found: {source}")
def combine_layout(directory: Union[Path, str]) -> dict:
    """
    Combines layout files from the specified directory into a single layout dictionary.

    This function scans the given directory (non-recursively) for entries
    whose name ends with '.layout', loads each one with ``load_layout`` and
    merges their top-level entries into one dictionary. Files are processed
    in sorted name order so the result, and the file at which a duplicate is
    detected, are the same on every run.

    Args:
        directory (Union[Path, str]): The directory to scan for layout files.

    Returns:
        dict: A dictionary containing the combined layout information keyed by rank ID.

    Raises:
        ValueError: If duplicate rank IDs are found across the layout files.

    Note:
        Only processes files with '.layout' extension.
    """
    layout_dict = {}
    # os.listdir returns entries in arbitrary order; sort them so the merge
    # order does not depend on the file system.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.layout'):
            continue
        load_dict = load_layout(os.path.join(directory, filename))
        for rank_id, param_dict in load_dict.items():
            if rank_id in layout_dict:
                raise ValueError("rank_id in files must be unique")
            layout_dict[rank_id] = param_dict

    return layout_dict
def get_global_layout(cell: Any) -> dict:
    """
    Collect the layout of every rank and merge them into one dictionary.

    Each rank computes its own layout with ``get_current_layout`` and
    contributes it to an all-gather over ``platform.get_world_size()`` slots.
    The gathered per-rank dictionaries are then merged in list order.

    Args:
        cell (Any): Instance of Cell (model/network object).

    Return:
        dict: A dictionary containing the global layout information keyed by rank ID.
    """
    # One placeholder slot per rank, filled in by the collective below.
    gathered = [None] * platform.get_world_size()
    platform.all_gather_object(gathered, get_current_layout(cell))

    merged = {}
    for per_rank_layout in gathered:
        merged.update(per_rank_layout)
    return merged