Coverage for hyper_parallel / core / checkpoint / layout.py: 87%

46 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Layout I/O utilities for saving, loading and gathering layout information.""" 

16import json 

17import os 

18from pathlib import Path 

19from typing import Union 

20 

21from hyper_parallel.platform import get_platform 

22 

# Module-wide platform handle, resolved once at import time. The functions in
# this module use it to query the rank and world size, to enumerate a cell's
# parameters, and to all-gather objects across ranks.
platform = get_platform()

24 

25 

def get_current_layout(cell):
    """
    Get the layout of every parameter of ``cell`` on the current rank.

    Args:
        cell (Any): Instance of Cell (model/network object).

    Returns:
        dict: A dictionary with a single key, the current rank ID as a string,
            whose value maps parameter names to their layout information
            (the result of ``param.layout.to_dict()`` extended with the
            ``"type"`` and ``"full_shape"`` entries). Parameters whose
            ``layout`` is falsy are skipped.

    Raises:
        RuntimeError: If two parameters carrying a layout share the same name.
    """
    current_rank = str(platform.get_rank())
    rank_layout = {}

    param_dict = platform.parameters_dict(cell)
    # Accept either a mapping or an iterable of (name, param) pairs: iterating
    # a mapping directly would only yield its keys and break the unpacking.
    named_params = param_dict.items() if hasattr(param_dict, "items") else param_dict
    for _, param in named_params:
        if not param.layout:
            continue
        # Duplicates must be looked up among this rank's parameter names, not
        # among the rank-ID keys of the outer dictionary.
        if param.name in rank_layout:
            raise RuntimeError("param in cell can not have same name")
        entry = param.layout.to_dict()
        entry["type"] = str(param.dtype)
        entry["full_shape"] = param.shape
        rank_layout[param.name] = entry

    return {current_rank: rank_layout}

50 

51 

def save_layout(layout_dict: dict, file_path: Union[Path, str]) -> None:
    """
    Save layout to a JSON file.

    Args:
        layout_dict (dict): Layout information to serialize. It must be
            JSON-serializable.
        file_path (Union[Path, str]): Destination file. Missing parent
            directories are created; an existing file is overwritten.
    """
    # Make sure the destination directory exists so that open() does not fail
    # with FileNotFoundError on a fresh checkpoint directory.
    Path(file_path).parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, 'w', encoding="utf-8") as f:
        json.dump(layout_dict, f, ensure_ascii=False)

59 

60 

def load_layout(file_path: Union[Path, str]) -> dict:
    """
    Read a layout back from a JSON file.

    Args:
        file_path (Union[Path, str]): Path of the UTF-8 encoded JSON file to read.

    Returns:
        dict: The object decoded from the file content.
    """
    # todo check path
    with open(file_path, mode='r', encoding='utf-8') as layout_file:
        raw_text = layout_file.read()
    return json.loads(raw_text)

69 

70 

def combine_layout(directory: Union[Path, str]) -> dict:
    """
    Combines layout files from the specified directory into a single layout dictionary.

    This function scans the given directory for entries with a '.layout' extension,
    loads each layout file, and merges them into one dictionary. Files are visited
    in sorted name order so that the result is deterministic.

    Args:
        directory (Union[Path, str]): The directory to scan for layout files.

    Returns:
        dict: A dictionary containing the combined layout information keyed by rank ID.

    Raises:
        ValueError: If duplicate rank IDs are found across the layout files.

    Note:
        Only processes files with '.layout' extension. Sub-directories are not
        scanned.
    """
    layout_dict = {}
    # os.listdir() returns entries in arbitrary order; sort them so the merge
    # order (and which duplicate gets reported) does not depend on the filesystem.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.layout'):
            continue
        load_dict = load_layout(os.path.join(directory, filename))
        for rank_id, param_dict in load_dict.items():
            if rank_id in layout_dict:
                raise ValueError("rank_id in files must be unique")
            layout_dict[rank_id] = param_dict

    return layout_dict

100 

101 

def get_global_layout(cell) -> dict:
    """
    Collect the layout of every rank and merge them into one dictionary.

    Args:
        cell (Any): Instance of Cell (model/network object).

    Returns:
        dict: A dictionary containing the global layout information keyed by rank ID.
    """
    # One placeholder slot per rank, filled in by the all-gather below.
    gathered_layouts = [None] * platform.get_world_size()

    # Layout of the current rank, exchanged with all other ranks.
    own_layout = get_current_layout(cell)
    platform.all_gather_object(gathered_layouts, own_layout)

    # Flatten the per-rank dictionaries into a single rank-keyed dictionary.
    merged_layout = {}
    for per_rank_layout in gathered_layouts:
        merged_layout.update(per_rank_layout)
    return merged_layout

130 

131 return global_layout_dict