Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/distributed_checkpoint/layout.py: 69%

65 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Layout I/O utilities for saving, loading and gathering layout information.""" 

16import json 

17import logging 

18import os 

19from pathlib import Path 

20from typing import Any, Union 

21 

22from hyper_parallel.platform import get_platform 

23 

# Backend abstraction resolved once at import time. The helpers in this
# module query it for the rank / world size, enumerate a cell's parameters
# through it, and use it for the all-gather of per-rank layouts.
platform = get_platform()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

26 

27 

def get_current_layout(cell: Any) -> dict:
    """
    Get current layout from cell.

    Collects, for every parameter of ``cell``, its layout description (when
    it has one) together with its dtype and full shape, grouped under the
    current rank ID.

    Args:
        cell (Any): Instance of Cell (model/network object).

    Returns:
        dict: ``{rank_id: {param_name: info}}``. ``rank_id`` is the current
        rank as a string. ``info`` holds the entries of
        ``param.layout.to_dict()`` (with ``"mesh_shape"`` renamed to
        ``"device_matrix"``) when the parameter has a truthy ``layout``
        attribute, plus ``"type"`` (``str(param.dtype)``) and
        ``"full_shape"`` (``param.shape``) for every parameter.

    Raises:
        RuntimeError: If two parameters of ``cell`` share the same name.
    """
    current_rank = str(platform.get_rank())
    rank_layout = {}
    layout_dict = {current_rank: rank_layout}

    params_without_layout_attr = []
    param_dict = platform.parameters_dict(cell)
    # NOTE(review): the return type of ``parameters_dict`` is backend
    # specific. Accept a mapping of name -> param as well as an iterable of
    # (name, param) pairs. Iterating a mapping directly would yield bare
    # names and fail to unpack.
    named_params = param_dict.items() if hasattr(param_dict, "items") else param_dict
    for name, param in named_params:
        # Entries are keyed by ``param.name``. A second parameter with the
        # same name would silently overwrite the first one's layout, so the
        # collision is checked against the per-rank dict (the outer dict is
        # keyed by rank ID and can never contain a parameter name).
        if param.name in rank_layout:
            raise RuntimeError("param in cell can not have same name")
        if not hasattr(param, "layout"):
            params_without_layout_attr.append(name)

        layout = getattr(param, "layout", None)
        if layout:
            # dict(...) copies, so the key rename below does not touch the
            # mapping returned by to_dict().
            layout_info = dict(layout.to_dict())
            if "mesh_shape" in layout_info:
                layout_info["device_matrix"] = layout_info.pop("mesh_shape")
        else:
            layout_info = {}
        # dtype and full shape are recorded for every parameter, with or
        # without a layout.
        layout_info["type"] = str(param.dtype)
        layout_info["full_shape"] = param.shape
        rank_layout[param.name] = layout_info

    if params_without_layout_attr:
        logger.info(
            "The following parameters have no layout attribute "
            "(only dtype and full shape are recorded): %s",
            params_without_layout_attr,
        )

    return layout_dict

69 

70 

def save_layout(layout_dict: dict, file_path: Union[Path, str]) -> None:
    """
    Save layout to file as JSON.

    The dictionary is serialized before the destination is opened. A value
    JSON cannot encode therefore raises without creating or truncating
    ``file_path``, instead of leaving an empty or partially written layout
    file that would later fail to load.

    Args:
        layout_dict (dict): Layout information to persist, e.g. the output of
            ``get_current_layout``. Tuples (such as shapes) are written as
            JSON arrays.
        file_path (Union[Path, str]): Destination path. Missing parent
            directories are created.

    Raises:
        TypeError: If ``layout_dict`` contains a value JSON cannot encode.
        ValueError: If ``layout_dict`` contains a circular reference.
    """
    file_path = Path(file_path)
    # Encode first so a serialization error leaves the filesystem untouched.
    payload = json.dumps(layout_dict, ensure_ascii=False)
    file_path.parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, 'w', encoding="utf-8") as f:
        f.write(payload)

79 

80 

def load_layout(file_path: Union[Path, str]) -> dict:
    """
    Read a layout file written as JSON and return its decoded contents.

    Args:
        file_path (Union[Path, str]): Path of the layout file.

    Returns:
        dict: The decoded JSON document.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    layout_path = Path(file_path)
    if not layout_path.exists():
        raise FileNotFoundError(f"Layout file not found: {layout_path}")
    return json.loads(layout_path.read_text(encoding="utf-8"))

91 

92 

def combine_layout(directory: Union[Path, str]) -> dict:
    """
    Combines layout files from the specified directory into a single layout dictionary.

    This function scans the given directory for files with a '.layout' extension,
    loads each layout file, and merges them into one dictionary. Files are
    visited in sorted filename order so the key order of the result does not
    depend on the arbitrary order returned by ``os.listdir``.

    Args:
        directory (Union[Path, str]): The directory to scan for layout files.

    Returns:
        dict: A dictionary containing the combined layout information keyed by rank ID.

    Raises:
        ValueError: If duplicate rank IDs are found across the layout files.
        FileNotFoundError: If ``directory`` does not exist.

    Note:
        Only processes files with '.layout' extension.
    """
    layout_dict = {}
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.layout'):
            continue
        load_dict = load_layout(os.path.join(directory, filename))
        for rank_id, param_dict in load_dict.items():
            if rank_id in layout_dict:
                # Name the offending rank and file so the conflict can be located.
                raise ValueError(
                    f"rank_id in files must be unique: rank {rank_id!r} "
                    f"in {filename!r} was already loaded from another file"
                )
            layout_dict[rank_id] = param_dict

    return layout_dict

122 

123 

def get_global_layout(cell: Any) -> dict:
    """
    Gather the layout information of every rank into a single dict.

    The local layout is built with ``get_current_layout`` and exchanged via
    ``platform.all_gather_object``. The gathered per-rank dictionaries are
    then merged in gather order (later entries win on key collisions).
    NOTE(review): ``all_gather_object`` is presumably a collective, so every
    rank must call this function.

    Args:
        cell (Any): Instance of Cell (model/network object).

    Returns:
        dict: A dictionary containing the global layout information keyed by rank ID.
    """
    # Pre-sized output list; all_gather_object is expected to replace each
    # None slot with the corresponding rank's layout dict.
    gathered = [None] * platform.get_world_size()
    local_layout = get_current_layout(cell)
    platform.all_gather_object(gathered, local_layout)

    merged = {}
    for per_rank_layout in gathered:
        merged.update(per_rank_layout)
    return merged