Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/auto_parallel/sapp_ppb/__init__.py 0.0% 22-23,25,27,42,58,60-61,63-66
hyper_parallel/auto_parallel/sapp_ppb/run_pipeline_balance.py 0.0% 16-18,20-25,28,30,33,39,45-47,50,52,54,56,59,63,66,70,77,79,83,85,89,93,97,99,101,104,106-108,111,118-120,122-123,125-129,131-133,135-136,138-140,142,149,151-154,156-159,161-169,172,174-179,182-183
hyper_parallel/auto_parallel/sapp_ppb/sapp/sapp_pipeline.py 0.0% 16-18,20-21,23-28,31,34,69-74,76-81,84-86,95,97,99,101-104,107,109,112,114,116,118,120,122,124,126,128,130,132-133,135,137-138,140,142,144,148-150,152,155-156,158,164-168,170,175-180,187,189,194-197,204-209,211-212,214,219-227,231-235,237,242-249,251,253-257,259,264-274,277,280,283-290,293-295,297,300-306,308,319,321-326,328-333,335,343-344,352,354,356-363,365,367-369,371-372,374-376,379-382,384-388,390,392-399,401,403-405,407-408,410-412,415-419,421,426-429,431,433-434,447,449-450,452-456,458-459,462,464,466-469,471,473,475-481,483,485-491,494,497,502,504-506,509,512-513,516-517,520,523,526,529,532-535,538-541,544,548,550,552-557,559,564-576,578,583-588,592,595,603,605-608,615,617-619,621,626,628,630-633,635,638,640,651,663-664,678-679,690,699,707-712,714,716,731,733-734,737,745-748,750-751,760-767,769,772,774-778
hyper_parallel/auto_parallel/sapp_ppb/sapp/sapp_solver.py 0.0% 17-20,22,24-26,29-31,33,36-37,39-45,47,50,52-61,63,98-102,104-108,110,112-116,118,120,122,124-126,128,130-135,137,139-141,143-150,152-153,158,163,169,171-174,176,178,180-183,185,187,189-194,196,198-208,210-211,214-222,224-225,227,229-230,233,235-243,245-247,250-253,255,257-258,261,263-270,272-274,277,279-280,283-290,292-296,303-304,307-309,311-318,320-334,336-338,341-344,346-347,349-350,355-359,363-364,367-368,370-371,373-374,376-377,380-384,387-390,392-393,397-400,402,405-406,408-410,414-415,418-421,423-424,428-431,435-436,438,440,442,445-446,449,451,454-460,465,467,469-471,473,477-482,484,486-489,493,495,497-499,503,505-508,511,513,515,522,524,531,533-538,540,542-544,547,549,557-559,562,564,566,568-575,580-585,588,590-591,594,596,599,601-603,608-609,611,614-615,622,628-630,633-635,638,641,645-646,648-649,651,653-654,657,661-662,664-665,668,675-676,678,682,684,687-690,696,701,707-713,716-720,722-727,730-738,740,745-749,751,753,758-763,767,769,771-779,781-783,785-786,790-793,801,806-807,815,821,829,833,835,837,839-842,844,847-853,855-857,860,863-864,868-869,872,874,876-878,885-886,889,900,903-909,916,918,920-924,930-935,937-938,940,942-950,955-958,960,962,964-972,974-977,979,981,983-994,997,1001,1004,1006-1010,1015,1019-1022,1024,1026-1029,1031-1036,1038,1040-1050,1052-1060,1062-1063,1070-1071,1078-1079,1086-1089,1091-1094,1103,1105,1107-1109,1112,1115-1116,1119,1123-1124,1132,1141-1142,1144-1149,1152,1159-1162,1164,1166,1168-1169,1171,1173-1182,1186-1188,1190,1192-1197,1200,1204,1206,1208-1209,1212-1213,1215,1224-1226,1228,1235,1237,1240,1247-1248,1250,1257-1258,1260,1270-1271,1273,1283-1284,1286,1289,1292,1295,1302-1303,1305-1306,1317-1322,1324,1329-1334,1336,1338-1343,1345,1347,1350-1365,1367,1369-1375,1378,1380,1383-1384,1388,1392,1394,1398-1407,1409-1417,1419,1421-1424,1426-1427,1429,1431,1439,1447,1449-1451,1453,1456-1458,1462,1466,1468,1470,1472-1473,1481,1489,1491,1496-1498,1501,1505,1507
hyper_parallel/auto_parallel/sapp_ppb/simulator/causal_error.py 0.0% 16,18,20-22,25,28,36-41,43,45-46,49,52,60-65,67,69-70
hyper_parallel/auto_parallel/sapp_ppb/simulator/pipeline_builder.py 0.0% 16,18,20,22,25,27-28,30-33,35-47,49-56,58-59,62-70,72,74,92-96,98,100-101,103-104,109-116,119,122,125,128-129,131-132,137-144,147,150,153,157-158,160-161,170-176
hyper_parallel/auto_parallel/sapp_ppb/simulator/plot_manager.py 0.0% 16,18-19,21-23,25,28,32,44-45,47-55,57,59,61-65,67-69,71,75-85,87-90,92,94-98,100,104-107,109,111,115-119,121-122,124-125,127,129,133-138,142-145,149,151,156-158,160-168,170,174-185,190-191,193,196-207,213-214,216,218-223,225-227,229,231-232,234-235,238,240-241,243,245-248
hyper_parallel/auto_parallel/sapp_ppb/simulator/pp_simulator.py 0.0% 16,18-19,21,23-28,30,33,104,111,115,120-129,131-132,134-138,142-143,145,161-169,171-177,179,190-193,195-199,202,205-207,209,212-214,216,218-219,222-227,231,233,235-241,243,245-248,250,252-254,256-258,260,262-263,265,267,269,271-273,278,280-281,285-292,294-299,301,303,305,307-309,311-319,321-322,324,326-331,333-335,337,339-357,359,361-370,372,374-387,389-390,392,394-399,401,403-408,410,412-420,422-428,430-431,433,435-440,442,444-447,449,451-465,467,469-470,472,474-481,483,485-491,493,495-506,508-511,514,519
hyper_parallel/auto_parallel/sapp_ppb/simulator/sim_block.py 0.0% 16,18-19,21-23,25,28-29,31-40,42-49,51-52,54-55,57-58,60,62-63,65-68,70-71,73-76,78,80-82,85,88-91,93,95-114,116,118,121-122,124-132,134-135,137,139-140,142-147,150,152,154,156,158,160,162,164,167-168,170-173,175,177,179,181,184,186-193,195-198,200,202,204-207,210-211,213-214,216,219,222,224-227,229-232,235-236,238-239,241,244-251,253,256,259,261-265,268-269,271-274,276,278-280,282,284-303,306-307,309-310,312,315-322,324,327,329-330,332-335
hyper_parallel/auto_parallel/sapp_ppb/simulator/utils.py 0.0% 16-18,20-21,23,25,28,43-51,54,64-67,70,79-83,86,98-99,102,113,116-117,119-132,134,137,147-148,150-154,156
hyper_parallel/auto_parallel/sapp_ppb/utils/check_rules.py 0.0% 17,19-20,22,25,27-28,31,34,45-51,54-55
hyper_parallel/auto_parallel/sapp_ppb/utils/computation_analyzer.py 0.0% 18-23,25,27,29,32,35,37,48-55,57-60,62-67,69,72-79,81,83-85,87-90,92-96,98,101-118,120-122,124-127,129-131,133,135-136,138-145,147,149-176,180-186,188-194,196,199-213,215-220,223,225,227-233,235,237-240,242,247-251,253-258,260-261,263,265,267-271,273-276,279-285,290-292,294,296-297,301,304-309
hyper_parallel/auto_parallel/sapp_ppb/utils/compute_memory.py 0.0% 16,18-21,24,46-54,56,65-67,69-74,76,78-84,86-88,90,92-98,100-104,106-108,110,112-113,115-118,120,125-133,135-137,139,147-153,155-160,162-167,169,174,176-181,183,185-189,191,193-201,205-206,208,214,216,222-231,234-236,240,246-255,260-261,263-266,270-271,274-277,279,287-289,291-293,295-296,298-300,304-305,310,318-319,322-323,327,334,344-346,348,350,352-354,356-359,361-363,365,368-371,373-374,376-378,380-382,384,386-394,396-398,400,402-412,414,416,418,420-422,424,427-429,431,433-435,437,439-441,444,446,448-450,452-461
hyper_parallel/auto_parallel/sapp_ppb/utils/config.py 0.0% 16-21,23-24,26-32,34,37-38,41-46,48,52-56,59-60,63-65,67,70-73,76,79,83,85-88,90,92,94,97-99,101,103-110,112,114-115,118,120-122,124-126,128-132,134-136,138-143,145-154,156-173,175,177,179,182,185-187,189-195,198-202,205,208,213-218,220-222,224,226,228,231,235,247,250,252-257,259-263,266-268,270,272,275-285,289-294,297-305,312,318,326-327,333,339,346,349,351-354,356-360,364-373,375-376,379,381-386,389,391-403,406,408-409,411-413,415-421,423-425,427-428,430-434,436-438,441,443-449,459-461,464,467,472-473,479-484,486-487,490,493-496,499,504-506,508-512,515-521,524,527,530,532,535-537,539,542,544-546,549-551,554,565-567,569,571-573
hyper_parallel/auto_parallel/sapp_ppb/utils/error.py 0.0% 16,18,21,25,35-36,39,51,53
hyper_parallel/auto_parallel/sapp_ppb/utils/interactive.py 0.0% 16-17,19-24,26,28,30,33,35,38,40,43,45,48,50-57,59-62,64-67,69-72,74,77,79-100,102-104,106,111,113-119,121,123-126,128,130-132,134-136,138-140,142-144,146-147,150,152,154-155,157-161,163-166,168-172,174,178-179,181-184
hyper_parallel/auto_parallel/sapp_ppb/utils/layer.py 0.0% 16-19,21-23,26,47-59,61,91-108,110,112-129,131,133,135,137,139,141,143-145,147,149,155-157,159-161,165-166,170,174,180-182,184-187,191,195,199,205-206,208,210,218,220-226,247-248,250,254,257,261-265,268,271,273-277,279,282,284-288,291,294,297,300
hyper_parallel/auto_parallel/sapp_ppb/utils/logger.py 0.0% 16,18-21,24,34-36,38,40,42-45,47,49
hyper_parallel/auto_parallel/sapp_ppb/utils/recompute.py 0.0% 16-17,19,21-22,24,32,40,48,57,65,74,76-79,82,84-86,89,91-93,96,110-113,115,121-136,138-139,142,156-161,163-164,169-171,176-184,187,190,193,204-205,208,210,213,223-227,230,243-248,254,257,259,262,265-273,276,285-294,298-301,304,306-312,315,317-321,324,326-330,333,335-340,343,345-349
hyper_parallel/auto_parallel/sapp_ppb/utils/stage.py 0.0% 17-18,21,36-40,42,53-58,60,62,65-69,71,73-74,76,78,84,88,93,95,100-101,105,107-111
hyper_parallel/auto_parallel/sapp_ppb/__init__.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
recomputation policies across pipeline parallel stages.
"""
# pylint: disable=invalid-name,undefined-all-variable

from importlib import import_module as _import_module
import sys as _sys

_sys.modules.setdefault("sapp_ppb", _sys.modules[__name__])

__all__ = [
    "SappPipeline",
    "choose_interleave",
    "flatten",
    "SappSolver",
38
39
40
41
42
43
44
45
46
    "run",
    "main",
]

_EXPORTS = {
    "SappPipeline": ".sapp.sapp_pipeline",
    "choose_interleave": ".sapp.sapp_pipeline",
    "flatten": ".sapp.sapp_pipeline",
    "SappSolver": ".sapp.sapp_solver",
54
55
56
57
58
59
60
61
62
63
64
65
66
    "main": ".run_pipeline_balance",
}


def __getattr__(name):
    """Resolve a public SAPP-PPB name on first access (module-level attribute hook).

    The submodule registered for ``name`` in ``_EXPORTS`` is imported relative to
    this package, the attribute of the same name is read from it, and the result
    is stored in this module's globals before being returned.

    Args:
        name (str): Attribute requested on this package.

    Returns:
        The object called ``name`` taken from the submodule listed in ``_EXPORTS``.

    Raises:
        AttributeError: If ``name`` is not a key of ``_EXPORTS``.
    """
    if name in _EXPORTS:
        owner = _import_module(_EXPORTS[name], __name__)
        resolved = getattr(owner, name)
        # Store the object in the module namespace next to the other globals.
        globals()[name] = resolved
        return resolved

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
hyper_parallel/auto_parallel/sapp_ppb/run_pipeline_balance.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CLI entrypoint for SAPP-PPB pipeline balancing."""
import argparse
import os
import sys

from sapp_ppb.sapp.sapp_pipeline import SappPipeline
from sapp_ppb.utils import interactive
from sapp_ppb.utils.compute_memory import compute_memories
from sapp_ppb.utils.config import initialize_layer_json
from sapp_ppb.utils.layer import generate_layers_list
from sapp_ppb.utils.logger import logger


def _str2bool(value: str) -> bool:
    """Convert an ``argparse`` option value into a boolean.

    The value is stringified and lower-cased; ``'true'``, ``'1'`` and ``'yes'``
    yield ``True`` and every other input yields ``False``.

    Args:
        value (str): Raw option value.

    Returns:
        bool: ``True`` only for the accepted truthy spellings.
    """
    normalized = str(value).lower()
    return normalized == 'true' or normalized == '1' or normalized == 'yes'


def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser of the pipeline-balance CLI.

    Every option takes a value; boolean-like options are converted with
    ``_str2bool`` rather than being store-true flags.

    Returns:
        Configured :class:`argparse.ArgumentParser` instance.
    """
    # One row per option, in registration order:
    # (short flag, long flag, converter, default, help text).
    option_specs = (
        # Pipeline info
        ('-s', '--stage', int, 4, "Number of stages"),
        ('-mb', '--micro_batch', int, 4, "Number of micro batch"),
        ('-i', '--interleave_degree', int, 1, "Interleave level"),
        # Memory size
        ('-mem', '--max_memory', int, 56000, "Maximum memory available (MB)"),
        ('-lm', '--less_memory', _str2bool, False,
         "Compute Memory with 'Less Memory interleave' option"),
        ('-dual', '--dualpipe_v', _str2bool, False,
         "Compute Memory with 'DualpipeV' option"),
        ('-mc', '--constant_memory', int, 0, "Constant memory per stages"),
        ('-o', '--output_folder', str, "./output", "output files location"),
        # Model info
        ('-m', '--model_name', str, "model_name", ""),
        # Search time
        ('-t', '--time_limit', int, 90, "Limitation on searching time"),
        # Optimization level
        ('-O', '--optimization_level', int, 1,
         "Defines optimization level when Stage (S) = Micro Batch number (M). "
         "0 for same approach as M > S. "
         "1 (default) generally better. "
         "2 better for memory constrained cases."),
        # Simulate naive or manual config
        ('-naive', '--simulate_naive', _str2bool, False, "Simulate naive configs"),
        ('-manual', '--manual_config', str, None, "Path of manual config"),
        # Layer info
        ('-lf', '--layer_folder', str, "./layers/", "Path to the layer folder"),
        ('-dump', '--dump_layer', _str2bool, False, "Dump the layers"),
        # For Computation of memory
        ('-mf', '--memory_folder', str, "./memory/",
         "Path to the profiler memory folder"),
        # For Initialization
        ('-init', '--init', str, None, "Path to the init file"),
        # Computation argument
        ('-cm', '--compute_memory', _str2bool, False,
         "Parse Mindspore log to generate MEMORY of the layer (unavailable)"),
        ('-exec', '--exec', _str2bool, True, "Compute solver"),
    )

    parser = argparse.ArgumentParser(
        prog='SAPP AutoBalancing',
        description='Balance layers onto pipeline stages, considering recomputation and interleaving',
        epilog='')
    for short_flag, long_flag, converter, default, help_text in option_specs:
        parser.add_argument(short_flag, long_flag, type=converter,
                            default=default, help=help_text)
    return parser


def _resolve_path(base_dir: str, path: str) -> str:
    """Anchor a relative ``path`` at ``base_dir``.

    Args:
        base_dir (str): Directory prepended to relative paths.
        path (str): Path to resolve.

    Returns:
        str: ``path`` unchanged when it is absolute, otherwise ``base_dir`` joined with ``path``.
    """
    return path if os.path.isabs(path) else os.path.join(base_dir, path)


def run(args: argparse.Namespace, base_dir: str) -> None:
    """Run the pipeline balancing workflow described by ``args``.

    The ``init``, ``output_folder`` and ``manual_config`` paths are resolved
    against ``base_dir`` when relative; ``layer_folder`` and ``memory_folder``
    are passed on exactly as given. A manual config is only used when its
    resolved path ends with ``yaml`` or ``yml``.

    Args:
        args (argparse.Namespace): Parsed CLI arguments.
        base_dir (str): Directory used to resolve relative input / output paths.
    """
    if args.init:
        initialize_layer_json(args.model_name, _resolve_path(base_dir, args.init))

    output_folder = _resolve_path(base_dir, args.output_folder)
    os.makedirs(output_folder, exist_ok=True)

    manual_config = None
    if args.manual_config:
        manual_path = _resolve_path(base_dir, args.manual_config)
        # Any other suffix leaves the manual config disabled.
        manual_config = manual_path if manual_path.endswith(('yaml', 'yml')) else None

    layers = generate_layers_list(args.layer_folder, args.model_name)
    if args.compute_memory:
        layers = compute_memories(
            layers=layers,
            memory_folder=args.memory_folder,
            number_of_stage=args.stage,
        )
    for layer in layers:
        logger.output("%s", layer)

    if args.dump_layer:
        for layer in layers:
            layer.dump()

    pipe = SappPipeline(
        model_name=args.model_name,
        num_of_stage=args.stage,
        num_of_micro_batch=args.micro_batch,
        max_memory=args.max_memory,
        layers=layers,
        num_of_interleave=args.interleave_degree,
        vpp_less_memory=args.less_memory,
        dual=args.dualpipe_v,
        constant_memory=args.constant_memory,
        optimization_level=args.optimization_level,
    )
    pipe.construct_problem(solver="pulp")

    if not args.exec:
        # Without solving, only a manual configuration can be simulated.
        if manual_config:
            logger.output("Simulating manual configs")
            pipe.simulate_only_manual(manual_config, output_folder)
        return

    pipe.solve_problem(time_limit=args.time_limit, dump_folder=output_folder)
    pipe.print_yaml_results()
    result_figure = os.path.join(output_folder, "result.svg")
    total_time = pipe.simulate(show=True, file_name=result_figure)

    logger.output("total_time: %d", total_time)
    logger.output("time: %s", pipe.get_time())
    logger.output("mem_par: %s", pipe.get_memory_parameter())
    logger.output("mem_act: %s", pipe.get_memory_activation())

    if manual_config:
        logger.output("Simulating manual configs")
        pipe.simulate_comparison(manual_config, output_folder)
    if args.simulate_naive:
        logger.output("Simulating naive configs")
        pipe.simulate_naive(layers, output_folder)


def main() -> None:
    """Dispatch to the interactive front-end or to the argument-driven workflow.

    When ``sys.argv`` holds exactly one entry, ``interactive.main()`` is started.
    Otherwise the command line is parsed with ``build_arg_parser`` and handed to
    ``run``, with the directory containing this file as ``base_dir``.
    """
    if len(sys.argv) != 1:
        cli_args = build_arg_parser().parse_args()
        module_dir = os.path.dirname(os.path.abspath(__file__))
        run(cli_args, base_dir=module_dir)
    else:
        interactive.main()


if __name__ == "__main__":
    main()
hyper_parallel/auto_parallel/sapp_ppb/sapp/sapp_pipeline.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""High-level orchestrator around :class:`SappSolver`: build, solve, simulate, export YAML."""
import os
import sys
from typing import Any, Dict, List, Optional, Union

import matplotlib.pyplot as plt
import yaml

import sapp_ppb.simulator.pp_simulator as sim
import sapp_ppb.utils.recompute as Recompute
from sapp_ppb.sapp.sapp_solver import SappSolver
from sapp_ppb.utils.check_rules import check_yaml_depth_before_loading
from sapp_ppb.utils.layer import Layer, filter_layer_type
from sapp_ppb.utils.logger import logger


class SappPipeline:
    """pipeline balancer"""

    def __init__(
            self,
            model_name: str,
            num_of_stage: int,
            num_of_micro_batch: int,
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
                seqpp. Default: ``None``.
            seq_split_num (int, optional): Number of sequence splits; ``>1`` enables sequence pipeline.
                Default: ``1``.
        """
        self.model_name_ = model_name
        self.num_of_stage_ = num_of_stage
        self.num_of_micro_batch_ = num_of_micro_batch
        self.num_of_interleave_ = num_of_interleave
        self.max_memory_ = max_memory
        self.vpp_less_memory_ = vpp_less_memory
        # Add arg dual_
        self.dual_ = dual
        self.constant_memory_ = constant_memory
        self.optimization_level = optimization_level
        self.extracted_training_params_ = extracted_training_params
        self.seq_split_num_ = seq_split_num
        self.seqpipe_ = self.seq_split_num_ > 1
        # logger.output("seq chunk: %s",self.seq_split_num_)

        self.problem_ = None
        self.layers_ = layers
        self.layers_sorted_ = {
            Layer.type_enum.HEAD: filter_layer_type(layers,
                                                    Layer.type_enum.HEAD),
            Layer.type_enum.BODY: filter_layer_type(layers,
                                                    Layer.type_enum.BODY),
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
            Layer.type_enum.TAIL: filter_layer_type(layers,
                                                    Layer.type_enum.TAIL),
        }

    def has_some_memory_info(self) -> bool:
        """Report whether the constructed problem has memory information.

        Delegates to ``has_some_memory_info`` of the object stored in
        ``self.problem_``, which is ``None`` until ``construct_problem`` has run.
        """
        problem = self.problem_
        return problem.has_some_memory_info()

    def construct_problem(self, solver: str = "pulp") -> None:
        """Build the solver problem and store it in ``self.problem_``.

        Only the ``"pulp"`` backend is available: any other ``solver`` value
        logs a warning and falls back to it.

        Args:
            solver (str): Requested backend name. Default: ``"pulp"``.
        """
        if solver != "pulp":
            # "other" and every unrecognised name share the same fallback.
            logger.warning(
                "No other solver available..., automatically switch to pulp!!!"
            )
        self.problem_ = self._construct_problem_pulp_()

    def solve_problem(self, time_limit: int = 90, dump_folder: Optional[str] = None) -> None:
        """Solve the constructed problem.

        Args:
            time_limit (int): Passed as the first argument of the problem's
                ``solve``. Default: ``90``.
            dump_folder (str, optional): Passed as the second argument of
                ``solve`` (the folder the LP model may be dumped into).
                Default: ``None``.
        """
        problem = self.problem_
        problem.solve(time_limit, dump_folder)

    def get_result(self) -> dict[str, list[list[str]]]:
        """Return the solution distribution in compact form, as given by the problem's ``result()``."""
        distribution = self.problem_.result()
        return distribution

    def get_memory_activation(self) -> list[float]:
        """Return the per-stage activation memory that the problem exposes for the simulator."""
        activation = self.problem_.get_simulator_memory_activation()
        return activation

    def get_memory_parameter(self) -> list[float]:
        """Return the per-stage parameter memory that the problem exposes for the simulator."""
        parameter = self.problem_.get_simulator_memory_parameter()
        return parameter

    def get_fw_time(self) -> list[float]:
        """Return the per-stage forward time that the problem exposes for the simulator."""
        return self.problem_.get_simulator_forward_time()

    def get_recompute_time(self) -> list[float]:
        """Return the per-stage recompute time that the problem exposes for the simulator."""
        return self.problem_.get_simulator_recompute_time()

    def get_time(self) -> list[float]:
        """Return the per-stage time that the problem exposes for the simulator."""
        stage_time = self.problem_.get_simulator_time()
        return stage_time

    def naive_layer_per_stage(self,
                              layer_num: int,
                              num_of_interleave: int = 1) -> List[List[int]]:
        """Build the naive layer-to-stage assignment, an even split of ``layer_num``.

        Args:
            layer_num (int): Number of layers to distribute.
            num_of_interleave (int): Number of interleave rows. Default: ``1``.

        Returns:
            List[List[int]]: ``num_of_interleave`` rows of ``self.num_of_stage_``
            entries, each holding the floor of
            ``layer_num / (self.num_of_stage_ * num_of_interleave)``; any
            remainder is not assigned.
        """
        logger.output("layer_num = %s", layer_num)
        stage_count = self.num_of_stage_
        per_slot = layer_num // (stage_count * num_of_interleave)
        return [[per_slot for _ in range(stage_count)]
                for _ in range(num_of_interleave)]

    def print_yaml_results(self) -> None:
        """Log, for every body layer, the solver result rendered as YAML entries.

        For each layer the naive assignment is computed as a baseline, the
        solver variables of that layer are converted with
        ``Recompute.yaml_from_internal``, and both are written to the logger.
        """
        interleave = self.num_of_interleave_
        for layer in self.layers_sorted_[Layer.type_enum.BODY]:
            nass = self.naive_layer_per_stage(layer.nb_layer_, interleave)
            yaml_format = Recompute.yaml_from_internal(
                interleave,
                self.num_of_stage_,
                self.problem_.variables_[layer.name_],
                nass,
            )
            logger.output("layer-to-stage assignment baseline is \n\t%s", nass)
            # One tab-indented "key: value" line per YAML entry.
            entries = "".join(f"\n\t{key}: {value}" for key, value in yaml_format.items())
            logger.output("\nTo put in yaml configuration:" + entries)

    def get_manual_memory_activation(
            self,
            each_layer_per_recompute: Dict[Layer, Dict[Recompute.TYPE, List[List[int]]]],
            interleave_num: int = 1) -> List[List[float]]:
        """Compute the activation memory of each stage for a user-supplied assignment.

        Args:
            each_layer_per_recompute: For each layer and recompute type, the
                layer count indexed as ``[interleave][stage]``.
            interleave_num (int): Number of interleave rows to evaluate. Default: ``1``.

        Returns:
            List[List[float]]: One row per interleave with one total per stage,
            or an empty list when ``has_some_memory_info()`` is false.
        """
        if not self.has_some_memory_info():
            return []

        def stage_activation(inter, stage):
            # Only recompute types in use and holding layers on this slot contribute.
            return sum(
                each_layer_per_recompute[layer][rec][inter][stage] *
                layer.memory_activation_rec_[rec]
                for layer in self.layers_sorted_[Layer.type_enum.BODY]
                for rec in Recompute.TYPE
                if rec not in Recompute.get_unused_list(each_layer_per_recompute[layer])
                and each_layer_per_recompute[layer][rec][inter][stage] > 0)

        return [[stage_activation(inter, stage) for stage in range(self.num_of_stage_)]
                for inter in range(interleave_num)]

    def get_manual_memory_parameter(
            self,
            each_layer_per_recompute: Dict[Layer, Dict[Recompute.TYPE, List[List[int]]]],
            interleave_num: int = 1) -> List[List[float]]:
        """Return the per-stage parameter memory for a user-supplied layer assignment.

        Parameter memory of the body layers is accumulated per stage over all
        interleaves. Head layers are charged to the first stage and tail layers
        to the last one. The same per-stage totals are reported for every
        interleave row.

        Args:
            each_layer_per_recompute: For each layer and recompute type, the
                layer count indexed as ``[interleave][stage]``.
            interleave_num (int): Number of interleave rows. Default: ``1``.

        Returns:
            List[List[float]]: ``interleave_num`` independent rows, each a copy
            of the per-stage totals.
        """
        memory_param_stage = [0] * self.num_of_stage_
        for inter in range(interleave_num):
            for stage in range(self.num_of_stage_):
                memory_param_stage[stage] += sum(
                    each_layer_per_recompute[layer][rec][inter][stage] *
                    layer.memory_parameter_ for rec in Recompute.TYPE
                    for layer in self.layers_sorted_[Layer.type_enum.BODY]
                    if layer.memory_parameter_ is not None
                    and rec not in Recompute.get_unused_list(each_layer_per_recompute[layer])
                    and each_layer_per_recompute[layer][rec][inter][stage] > 0)
        for head in self.layers_sorted_[Layer.type_enum.HEAD]:
            if head.memory_parameter_ is not None:
                memory_param_stage[0] += head.memory_parameter_
        for tail in self.layers_sorted_[Layer.type_enum.TAIL]:
            if tail.memory_parameter_ is not None:
                memory_param_stage[self.num_of_stage_ - 1] += tail.memory_parameter_
        # Copy per row: ``[row] * n`` would repeat one shared list object, so a
        # caller mutating one interleave row would silently change all of them.
        return [list(memory_param_stage) for _ in range(interleave_num)]

    def get_manual_time(
            self,
            each_layer_per_recompute: Dict[Layer, Dict[Recompute.TYPE, List[List[int]]]],
            interleave_num: int = 1) -> List[List[float]]:
        """Compute the execution time of each stage for a user-supplied assignment.

        Each body layer contributes its layer count times the sum of its
        forward time and the backward time of the recompute type. Head layer
        times are added to the first stage of the first interleave, tail layer
        times to the last stage of the last interleave.

        Args:
            each_layer_per_recompute: For each layer and recompute type, the
                layer count indexed as ``[interleave][stage]``.
            interleave_num (int): Number of interleave rows. Default: ``1``.

        Returns:
            List[List[float]]: Times indexed as ``[interleave][stage]``.
        """
        body_layers = self.layers_sorted_[Layer.type_enum.BODY]
        stage_times = []
        for chunk in range(interleave_num):
            row = []
            for stage in range(self.num_of_stage_):
                elapsed = 0
                for layer in body_layers:
                    for rec in Recompute.TYPE:
                        count = each_layer_per_recompute[layer][rec][chunk][stage]
                        if count > 0:
                            elapsed += count * (
                                layer.forward_time_ +
                                layer.backward_time_rec_[rec])
                row.append(elapsed)
            stage_times.append(row)

        for head in self.layers_sorted_[Layer.type_enum.HEAD]:
            stage_times[0][0] += head.time_
        last_stage = self.num_of_stage_ - 1
        for tail in self.layers_sorted_[Layer.type_enum.TAIL]:
            stage_times[interleave_num - 1][last_stage] += tail.time_
        return stage_times

    def get_manual_fw_time(
            self,
            each_layer_per_recompute: Dict[Layer, Dict[Recompute.TYPE, List[List[int]]]],
            interleave_num: int = 1) -> List[List[float]]:
        """Compute the forward time of each stage for a user-supplied assignment.

        Only recompute types that are not reported as unused for a layer are
        counted. Head layer times are added to the first stage of the first
        interleave, tail layer times to the last stage of the last interleave.

        Args:
            each_layer_per_recompute: For each layer and recompute type, the
                layer count indexed as ``[interleave][stage]``.
            interleave_num (int): Number of interleave rows. Default: ``1``.

        Returns:
            List[List[float]]: Forward times indexed as ``[interleave][stage]``.
        """
        stage_times = []
        for chunk in range(interleave_num):
            row = []
            for stage in range(self.num_of_stage_):
                elapsed = 0
                for layer in self.layers_sorted_[Layer.type_enum.BODY]:
                    for rec in Recompute.TYPE:
                        if rec in Recompute.get_unused_list(each_layer_per_recompute[layer]):
                            continue
                        count = each_layer_per_recompute[layer][rec][chunk][stage]
                        if count > 0:
                            elapsed += count * layer.forward_time_
                row.append(elapsed)
            stage_times.append(row)

        for head in self.layers_sorted_[Layer.type_enum.HEAD]:
            stage_times[0][0] += head.time_
        last_stage = self.num_of_stage_ - 1
        for tail in self.layers_sorted_[Layer.type_enum.TAIL]:
            stage_times[interleave_num - 1][last_stage] += tail.time_
        return stage_times

    def get_manual_recompute_time(
            self,
            each_layer_per_recompute: Dict[Layer, Dict[Recompute.TYPE, List[List[int]]]],
            interleave_num: int = 1) -> List[List[float]]:
        """Compute the recompute-only time of each stage for a user-supplied assignment.

        Two totals are accumulated per ``[interleave][stage]`` slot by
        ``_add_manual_recompute_time``: the backward time with the configured
        recompute types, and the backward time without recomputation. Their
        difference is returned.

        Args:
            each_layer_per_recompute: For each layer and recompute type, the
                layer count indexed as ``[interleave][stage]``.
            interleave_num (int): Number of interleave rows. Default: ``1``.

        Returns:
            List[List[float]]: Recompute overhead indexed as ``[interleave][stage]``.
        """
        logger.output("each_layer_per_recompute = %s", each_layer_per_recompute)
        stage_count = self.num_of_stage_
        time_all_rec = [[0] * stage_count for _ in range(interleave_num)]
        time_no_rec = [[0] * stage_count for _ in range(interleave_num)]

        for interleave in range(interleave_num):
            for stage in range(stage_count):
                for layer in self.layers_sorted_[Layer.type_enum.BODY]:
                    self._add_manual_recompute_time(
                        each_layer_per_recompute, layer, interleave, stage,
                        time_all_rec, time_no_rec)

        overhead = []
        for with_rec_row, no_rec_row in zip(time_all_rec, time_no_rec):
            overhead.append([with_rec - no_rec
                             for with_rec, no_rec in zip(with_rec_row, no_rec_row)])
        return overhead

    def _add_manual_recompute_time(self, each_layer_per_recompute, layer, interleave, stage,
                                   time_all_rec, time_no_rec):
        """Accumulate recompute time for a single layer and stage.

        For every recompute type that is in use for ``layer`` and has layers on
        the ``[interleave][stage]`` slot, the slot of ``time_all_rec`` grows by
        the backward time of that type and the slot of ``time_no_rec`` by the
        backward time of ``Recompute.TYPE.NONE``, both scaled by the layer count.

        Args:
            each_layer_per_recompute: For each layer and recompute type, the
                layer count indexed as ``[interleave][stage]``.
            layer: Layer whose contribution is added.
            interleave (int): Interleave index of the slot.
            stage (int): Stage index of the slot.
            time_all_rec (List[List[float]]): Updated in place.
            time_no_rec (List[List[float]]): Updated in place.

        Raises:
            ValueError: If a used recompute type has no backward time on ``layer``.
        """
        logger.output("backward_time_rec_(%s) = %s", layer, layer.backward_time_rec_)
        unused_rec = Recompute.get_unused_list(each_layer_per_recompute[layer])
        for rec in Recompute.TYPE:
            layer_num = each_layer_per_recompute[layer][rec][interleave][stage]
            if rec in unused_rec or layer_num <= 0:
                continue
            if layer.backward_time_rec_[rec] is None:
                raise ValueError("No backward time is specified for this "
                                 "recomputation. Recomputation "
                                 f"'{Recompute.YAML_NAME[rec]}' is likely not considered")
            logger.output("r = %s; i = %s; s = %s", rec, interleave, stage)
            time_all_rec[interleave][stage] += layer_num * layer.backward_time_rec_[rec]
            time_no_rec[interleave][stage] += layer_num * layer.backward_time_rec_[Recompute.TYPE.NONE]

    def simulate(self, show: bool = True, file_name: Optional[str] = None,
                 sub_fig: Optional[plt.Figure] = None) -> float:
        """Run the simulator on the solved schedule and return its estimated total time."""
        forward_time = self.get_fw_time()
        recompute_overhead = self.get_recompute_time()
        stage_mem_par = 0
        stage_mem_act = 0
        if self.has_some_memory_info():
            stage_mem_par = self.get_memory_parameter()
            stage_mem_act = self.get_memory_activation()

        return self.simulation(
            forward_time,
            recompute_overhead,
            stage_mem_par,
            stage_mem_act,
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
            file_name,
            sub_fig
        )

    def simulate_naive(self, layers: List[Layer], output_folder: str) -> None:
        """Simulate the naive (even) layer-to-stage assignments for sanity comparison."""
        num_layers = 0
        rec_considered = {}
        for layer in layers:
            if layer.type_ == Layer.type_enum.BODY:
                num_layers = layer.nb_layer_
                rec_considered = layer.recompute_considered_

        all_recomp = {"offset": 0}
        no_recomp = {"offset": 0}
        for rec in [Recompute.TYPE.FULL, Recompute.TYPE.SLCT, Recompute.TYPE.COMM]:
            if rec_considered.get(rec, False):
                all_recomp[Recompute.YAML_NAME[rec]] = True
                no_recomp[Recompute.YAML_NAME[rec]] = False

        self.simulate_yaml(
            yaml_format=all_recomp,
            show=True,
            interleave_num=self.num_of_interleave_,
            file_name=os.path.join(output_folder,
339
340
341
342
343
344
345
346
347
348
            file_name=os.path.join(output_folder,
                                   "result_naive_all_recomp.svg"),
        )

        if num_layers % self.num_of_stage_ == 0:
            self.simulate_yaml(
                yaml_format=no_recomp,
                show=True,
                interleave_num=self.num_of_interleave_,
                file_name=os.path.join(output_folder,
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
                file_name=os.path.join(output_folder,
                                       "result_naive_no_recomp.svg"),
            )
        else:
            logger.warning("num layer cannot be divided by num stage")

    def simulate_comparison(self, manual_config_file: str, output_folder: str) -> None:
        """Render side-by-side automatic vs manual simulations for every entry in the YAML.

        Every top-level value of ``manual_config_file`` is treated as one
        manual configuration. For each one a two-panel figure is drawn: the
        left panel is produced by :meth:`simulate`, the right panel by
        :meth:`simulate_yaml` with the values read from the entry.

        Keys read from each entry: ``Recompute.OFFSET``, every value of
        ``Recompute.YAML_NAME``, ``"interleave_num"`` (defaults to
        ``self.num_of_interleave_``), ``"show"`` (defaults to ``False``) and
        ``"file_name"``. When ``file_name`` is missing or empty, nothing is
        written to disk for that entry (previously this raised ``TypeError``
        while building the output path).

        Args:
            manual_config_file: Path of the YAML file describing the manual
                configurations.
            output_folder: Directory in which the figures are saved.
        """
        with open(manual_config_file, encoding="utf-8") as fp:
            check_yaml_depth_before_loading(fp)
            fp.seek(0)
            # An empty YAML document loads as None; treat it as "no entries".
            data = yaml.safe_load(fp) or {}

        for manual in data.values():
            # Build a fresh dict per entry so nothing leaks between entries.
            yaml_data = {}
            for key in (Recompute.OFFSET, *Recompute.YAML_NAME.values()):
                value = manual.get(key)
                # A flat list of ints is wrapped into a list of lists
                # (presumably one row per interleave -- confirm against
                # Recompute.internal_from_yaml).
                if isinstance(value, list) and all(isinstance(item, int) for item in value):
                    value = [value]
                yaml_data[key] = value

            interleave_num = manual.get("interleave_num",
                                        self.num_of_interleave_)
            show = manual.get("show", False)
            file_name = manual.get("file_name")
            if file_name:
                auto_file_name = os.path.join(output_folder, "Auto_" + file_name)
                full_file_name = os.path.join(output_folder, file_name)
                comparison_file_name = os.path.join(output_folder, "Comparison_" + file_name)
            else:
                # No file name given: simulate and (optionally) show, but save nothing.
                auto_file_name = full_file_name = comparison_file_name = None

            fig = plt.figure(figsize=(24, 8))
            sub_figs = fig.subfigures(1, 2, wspace=0.07)
            sub_figs[0].suptitle('Automatic', fontsize='x-large')
            self.simulate(show=False, file_name=auto_file_name, sub_fig=sub_figs[0])

            sub_figs[1].suptitle('Manual', fontsize='x-large')
            self.simulate_yaml(yaml_data, False, interleave_num, full_file_name, sub_figs[1])
            if comparison_file_name:
                plt.savefig(comparison_file_name)
            if show:
                plt.show()
            else:
                # The figure is not displayed, so release it instead of letting
                # one figure per entry pile up in pyplot's registry.
                plt.close(fig)

    def simulate_only_manual(self, manual_config_file: str, output_folder: str) -> None:
        """Render only the manual simulation for every entry in ``manual_config_file``.

        Every top-level value of the YAML file is treated as one manual
        configuration and simulated through :meth:`simulate_yaml` on its own
        figure.

        Keys read from each entry: ``Recompute.OFFSET``, every value of
        ``Recompute.YAML_NAME``, ``"interleave_num"`` (defaults to
        ``self.num_of_interleave_``), ``"show"`` (defaults to ``False``) and
        ``"file_name"``. When ``file_name`` is missing or empty, nothing is
        written to disk for that entry (previously this raised ``TypeError``
        while building the output path).

        Args:
            manual_config_file: Path of the YAML file describing the manual
                configurations.
            output_folder: Directory in which the figures are saved.
        """
        with open(manual_config_file, encoding="utf-8") as fp:
            check_yaml_depth_before_loading(fp)
            fp.seek(0)
            # An empty YAML document loads as None; treat it as "no entries".
            data = yaml.safe_load(fp) or {}

        for manual in data.values():
            # Build a fresh dict per entry so nothing leaks between entries.
            yaml_data = {}
            for key in (Recompute.OFFSET, *Recompute.YAML_NAME.values()):
                value = manual.get(key)
                # A flat list of ints is wrapped into a list of lists
                # (presumably one row per interleave -- confirm against
                # Recompute.internal_from_yaml).
                if isinstance(value, list) and all(isinstance(item, int) for item in value):
                    value = [value]
                yaml_data[key] = value

            interleave_num = manual.get("interleave_num",
                                        self.num_of_interleave_)
            show = manual.get("show", False)
            file_name = manual.get("file_name")
            if file_name:
                full_file_name = os.path.join(output_folder, file_name)
                manual_file_name = os.path.join(output_folder, "manual_file_" + file_name)
            else:
                # No file name given: simulate and (optionally) show, but save nothing.
                full_file_name = manual_file_name = None

            fig = plt.figure(figsize=(12, 8))
            self.simulate_yaml(yaml_data, False, interleave_num, full_file_name, fig)
            if manual_file_name:
                plt.savefig(manual_file_name)
            if show:
                plt.show()
            else:
                # The figure is not displayed, so release it instead of letting
                # one figure per entry pile up in pyplot's registry.
                plt.close(fig)

    def simulate_yaml(self, yaml_format: Dict[str, Any], show: bool = True,
                      interleave_num: int = 1,
                      file_name: Optional[str] = None,
                      sub_fig: Optional[plt.Figure] = None) -> float:
        """Simulate a manual pipeline configuration encoded as a YAML-compatible dict."""
        layer_num = 0
        for layer in self.layers_sorted_[Layer.type_enum.BODY]:
            layer_num += layer.nb_layer_
        nass = self.naive_layer_per_stage(layer_num,
                                          num_of_interleave=interleave_num)
        layer_per_recompute = Recompute.internal_from_yaml(
            interleave_num, self.num_of_stage_, yaml_format, nass)
        each_layer_per_recompute = self.split_layer_per_recompute(layer_per_recompute)
        return self.simulate_manual(
            each_layer_per_recompute,
            show,
            interleave_num=interleave_num,
            file_name=file_name,
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
    ##                                                                   ##
    ##                      Print Solver Model                           ##
    ##                                                                   ##
    #######################################################################
    def _calculate_activation_memory(self, each_layer_per_recompute, v, s):
        """Sum BODY-layer activation memory at stage ``s`` for index ``v + 1`` and ``v``.

        Only recompute types flagged in ``self.problem_.recompute_considered_``
        contribute, and only cells with a strictly positive layer count.

        Returns:
            A ``(next, current)`` tuple: the total for index ``v + 1`` first,
            then the total for index ``v``.
        """
        next_total = 0
        current_total = 0

        for layer in self.layers_sorted_[Layer.type_enum.BODY]:
            for rec in Recompute.TYPE:
                if not self.problem_.recompute_considered_[rec]:
                    continue
                counts = each_layer_per_recompute[layer][rec]

                upcoming = counts[v + 1][s]
                if upcoming > 0:
                    next_total += upcoming * layer.memory_activation_rec_[rec]

                present = counts[v][s]
                if present > 0:
                    current_total += present * layer.memory_activation_rec_[rec]

        return next_total, current_total

    def _compute_parameter_memory_manually_solver(self, each_layer_per_recompute, s, interleave_num=1):
        """Return the parameter memory of stage ``s`` summed over all BODY layers.

        Layers whose ``memory_parameter_`` is ``None`` are ignored; the
        per-layer amount comes from ``_calculate_layer_parameter_memory``.
        """
        return sum(
            self._calculate_layer_parameter_memory(
                layer, each_layer_per_recompute[layer], s, interleave_num)
            for layer in self.layers_sorted_[Layer.type_enum.BODY]
            if layer.memory_parameter_ is not None
        )

    def _calculate_layer_parameter_memory(self, layer, layer_per_recompute, s, interleave_num):
        """Return the parameter memory one layer contributes to stage ``s``.

        Walks every interleave index below ``interleave_num`` and every
        considered recompute type, adding ``count * layer.memory_parameter_``
        for each strictly positive count at ``[rec][inter][s]``.
        """
        total = 0
        for inter in range(interleave_num):
            for rec in Recompute.TYPE:
                if not self.problem_.recompute_considered_[rec]:
                    continue
                count = layer_per_recompute[rec][inter][s]
                if count > 0:
                    total += count * layer.memory_parameter_
        return total

    def _calculate_activation_memory_solver(self, each_layer_per_recompute, s, interleave_num, activation_nums):
        """Return the activation memory of stage ``s`` across all BODY layers.

        Each strictly positive layer count of a considered recompute type is
        weighted by the layer's activation memory for that type and by
        ``activation_nums[inter][s]``.
        """
        total = 0
        for layer in self.layers_sorted_[Layer.type_enum.BODY]:
            for inter in range(interleave_num):
                for rec in Recompute.TYPE:
                    if not self.problem_.recompute_considered_[rec]:
                        continue
                    count = each_layer_per_recompute[layer][rec][inter][s]
                    if count > 0:
                        total += (count *
                                  layer.memory_activation_rec_[rec] *
                                  activation_nums[inter][s])
        return total


    def debug_print_manual_theoretical_memory(
            self,
            each_layer_per_recompute: Dict[Layer, Dict[Recompute.TYPE, List[List[int]]]],
            interleave_num: int = 1) -> None:
        """Log, stage by stage, the memory the solver model predicts for a manual strategy.

        The activation multipliers are obtained from ``self.problem_`` with the
        method matching the active schedule flags (less-memory VPP, dual,
        sequence pipeline). For each stage the parameter memory (BODY layers,
        plus HEAD layers on the first stage and TAIL layers on the last),
        the activation memory, a zero overhead and ``self.constant_memory_``
        are logged together with their sum.

        Args:
            each_layer_per_recompute: Layer counts indexed as
                ``[layer][rec][interleave][stage]``.
            interleave_num: Number of interleave indices to account for.
        """
        banner = "=" * 20
        logger.info("%s Manual Theoretical Memory Analysis %s", banner, banner)

        # Pick the activation-count model that matches the schedule flags.
        if self.vpp_less_memory_ and self.seqpipe_:
            activation_nums = self.problem_.compute_activation_seq_nums(
                self.num_of_stage_, interleave_num, self.seq_split_num_, self.num_of_micro_batch_, True)
        elif self.vpp_less_memory_:
            activation_nums = self.problem_.compute_less_activation_nums(
                self.num_of_stage_, interleave_num)
        elif self.dual_:
            # dualpipe_v schedule
            activation_nums = self.problem_.compute_activation_nums_dual(
                self.num_of_stage_, interleave_num, self.num_of_micro_batch_)
        elif self.seqpipe_:
            activation_nums = self.problem_.compute_activation_seq_nums(
                self.num_of_stage_, interleave_num, self.seq_split_num_, self.num_of_micro_batch_, False)
        else:
            activation_nums = self.problem_.compute_activation_nums(
                self.num_of_stage_, interleave_num, self.num_of_micro_batch_)

        logger.info("Activation nums = %s", activation_nums)

        final_stage = self.num_of_stage_ - 1
        for s in range(self.num_of_stage_):
            # Parameters of the BODY layers placed on this stage.
            param_mem = self._compute_parameter_memory_manually_solver(each_layer_per_recompute, s, interleave_num)

            # HEAD layers count on the first stage, TAIL layers on the last one.
            edge_types = []
            if s == 0:
                edge_types.append(Layer.type_enum.HEAD)
            if s == final_stage:
                edge_types.append(Layer.type_enum.TAIL)
            for edge_type in edge_types:
                for edge_layer in self.layers_sorted_[edge_type]:
                    if edge_layer.memory_parameter_ is not None:
                        param_mem += edge_layer.memory_parameter_

            act_mem = self._calculate_activation_memory_solver(each_layer_per_recompute, s,
                                                               interleave_num, activation_nums)

            # No overhead term is modelled here.
            overhead = 0

            total = param_mem + act_mem + overhead + self.constant_memory_

            logger.info("Stage %d Manual Memory Analysis:", s)
            report = (
                ("Parameter Memory:     %.2f", param_mem),
                ("Activation Memory:    %.2f", act_mem),
                ("Memory Overhead:      %.2f", overhead),
                ("Constant Memory:      %.2f", self.constant_memory_),
                ("Total Theoretical Memory: %.2f", total),
            )
            for template, amount in report:
                logger.info(template, amount)

    def split_layer_per_recompute(
            self,
            layer_per_recompute: Dict[Recompute.TYPE, List[List[int]]]
    ) -> Dict[Layer, Dict[Recompute.TYPE, List[List[int]]]]:
        """Distribute aggregate per-recompute layer counts over the individual BODY layers.

        BODY layers are served in order; each one greedily takes up to
        ``nb_layer_`` units from the aggregate cells, visiting recompute type,
        then interleave, then stage.

        Note:
            ``layer_per_recompute`` is consumed in place: every amount handed
            to a layer is subtracted from the corresponding input cell.

        Returns:
            Counts indexed as ``[layer][rec][interleave][stage]``.
        """
        distribution = {}
        for layer in self.layers_sorted_[Layer.type_enum.BODY]:
            remaining = layer.nb_layer_
            per_rec = {rec: [] for rec in Recompute.TYPE}
            distribution[layer] = per_rec
            for rec in Recompute.TYPE:
                for i in range(self.num_of_interleave_):
                    row = [0] * self.num_of_stage_
                    per_rec[rec].append(row)
                    for s in range(self.num_of_stage_):
                        taken = min(layer_per_recompute[rec][i][s], remaining)
                        layer_per_recompute[rec][i][s] -= taken
                        remaining -= taken
                        row[s] += taken
        return distribution

    def fuse_layer_per_recompute(
            self,
            each_layer_per_recompute: Dict[Layer, Dict[Recompute.TYPE, List[List[int]]]]
    ) -> Dict[Recompute.TYPE, List[List[int]]]:
        """Sum per-layer recompute counts into one ``[interleave][stage]`` grid per recompute type.

        Each output cell is the sum, over all BODY layers, of the matching
        cell in ``each_layer_per_recompute``.
        """
        fused = {}
        for rec in Recompute.TYPE:
            fused[rec] = [
                [
                    sum(each_layer_per_recompute[layer][rec][i][s]
                        for layer in self.layers_sorted_[Layer.type_enum.BODY])
                    for s in range(self.num_of_stage_)
                ]
                for i in range(self.num_of_interleave_)
            ]
        return fused


    def simulate_manual(
            self,
            each_layer_per_recompute: Optional[Dict[Layer, Dict[Recompute.TYPE, List[List[int]]]]] = None,
            show: bool = True,
            interleave_num: int = 1,
599
600
601
602
603
604
605
606
607
608
609
610
611
612
            interleave_num: int = 1,
            file_name: Optional[str] = None,
            sub_fig: Optional[plt.Figure] = None) -> float:
        """Run the simulator on a user-supplied per-layer recompute strategy."""
        logger.output("Simulating given strategy: %s", each_layer_per_recompute)

        for layer in self.layers_sorted_[Layer.type_enum.BODY]:
            for rec in Recompute.TYPE:
                if len(each_layer_per_recompute[layer][rec]) != interleave_num:
                    logger.error(
                        "For layer %s with recompute %s, %s does not match interleave number %s",
                        layer,
                        rec,
                        len(each_layer_per_recompute[layer][rec]),
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
                        rec,
                        len(each_layer_per_recompute[layer][rec]),
                        interleave_num,
                    )
                    return sys.maxsize

        for layer in self.layers_sorted_[Layer.type_enum.BODY]:
            for rec in Recompute.TYPE:
                if any(x < 0 for sublist in each_layer_per_recompute[layer][rec]
                    for x in sublist):
                    raise ValueError(
                        f"for {rec}, there is strategy less than 0 in "
                        f"{each_layer_per_recompute[layer][rec]}"
                    )

        forward_time = self.get_manual_fw_time(each_layer_per_recompute,
                                               interleave_num)
        recompute_overhead = self.get_manual_recompute_time(
            each_layer_per_recompute, interleave_num)
        stage_mem_par = 0
        stage_mem_act = 0
        if self.has_some_memory_info():
            stage_mem_par = self.get_manual_memory_parameter(
                each_layer_per_recompute, interleave_num=interleave_num)
            stage_mem_act = self.get_manual_memory_activation(
                each_layer_per_recompute, interleave_num=interleave_num)

        self.debug_print_manual_theoretical_memory(each_layer_per_recompute, interleave_num)

        return self.simulation(
            forward_time,
            recompute_overhead,
            stage_mem_par,
            stage_mem_act,
647
648
649
650
651
652
653
654
655
            file_name,
            sub_fig
        )

    def simulation(
            self,
            forward_time: List[List[float]],
            recompute_overhead: Union[int, List[List[float]]] = 0,
            stage_mem_par: Union[int, List[List[float]]] = 0,
659
660
661
662
663
664
665
666
667
668
            file_name: Optional[str] = None,
            sub_fig: Optional[plt.Figure] = None,
    ) -> float:
        """Run the low-level :class:`PipelineSimulator` and return its reported end time."""
        if self.has_some_memory_info():
            logger.output(
                "PipelineSimulator(\n\t%s, %s,"
                "\n\tblock_mem_act=%s,"
                "\n\tblock_mem_par=%s,"
                "\n\tlayer_recompute=%s,"
674
675
676
677
678
679
680
681
682
683
                recompute_overhead,
                self.vpp_less_memory_,
            )

            sim_method = "vpp2" if self.vpp_less_memory_ else "vpp"
            simulator = sim.PipelineSimulator(
                forward_time,
                self.num_of_micro_batch_,
                block_mem=stage_mem_act,
                block_mem_par=stage_mem_par,
686
687
688
689
690
691
692
693
694
                method=sim_method,
                sub_fig=sub_fig
            )
        else:
            logger.output(
                "PipelineSimulator(\n\t%s, %s,"
                "\n\tlayer_recompute=%s)"
                "\n\tless_memory=%s )",
                forward_time,
695
696
697
698
699
700
701
702
703
                self.num_of_micro_batch_,
                recompute_overhead,
                self.vpp_less_memory_,
            )
            simulator = sim.PipelineSimulator(
                forward_time,
                self.num_of_micro_batch_,
                layer_recompute=recompute_overhead,
                less_memory=self.vpp_less_memory_,
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
                less_memory=self.vpp_less_memory_,
                sub_fig=sub_fig
            )

        simulator.run(comm=False)
        if file_name:
            simulator.save(file_name)
        if show:
            simulator.show()
        return simulator.end_time

    def _construct_problem_pulp_(self) -> SappSolver:
        """construct the problem using pulp"""
        prob = SappSolver(
            num_of_stage=self.num_of_stage_,
            num_of_micro_batch=self.num_of_micro_batch_,
            num_of_interleave=self.num_of_interleave_,
            max_memory=self.max_memory_,
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
            optimization_level=self.optimization_level,
            extracted_training_params=self.extracted_training_params_,
            seq_split_num=self.seq_split_num_
        )
        return prob

    def _recompute_considered(self):
        return self.problem_.recompute_considered_


def choose_interleave(
        model_name: str,
        number_of_stage: int,
        number_of_micro_batch: int,
        max_memory: int,
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
        max_memory: int,
        layers: list[Layer],
) -> tuple[int, int, dict[str, list[list[str]]]]:
    """Simulates different interleaves and returns the best."""
    max_inter = 4
    best_time = int(sys.maxsize)
    best_inter = 1
    best_distribution = {}

    for inter in range(1, max_inter + 1):
        pipe = SappPipeline(
            model_name=model_name,
            num_of_stage=number_of_stage,
            num_of_micro_batch=number_of_micro_batch,
            max_memory=max_memory,
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
            layers=layers,
            num_of_interleave=inter,
        )

        pipe.construct_problem(solver="pulp")
        pipe.solve_problem()
        time = pipe.simulate(show=False)
        logger.output("for interleave %s, time = %s", inter, time)
        if time < best_time:
            best_time = time
            best_inter = inter
            best_distribution = pipe.get_result()

    return (best_inter, best_time, best_distribution)


def flatten(inter_stage_list: List[List[float]]) -> List[float]:
    """Collapse an ``[interleave][stage]`` matrix into a per-stage list via summation.

    Args:
        inter_stage_list: Rows of per-stage values, one row per interleave.
            The first row fixes the number of stages; a later row may be
            shorter, but a longer one raises ``IndexError``.

    Returns:
        A list whose ``k``-th entry is the sum of column ``k`` over all rows.
        An empty input yields an empty list.
    """
    # Without rows there is no first row to size the result from.
    if not inter_stage_list:
        return []
    stage_list = [0] * len(inter_stage_list[0])
    for row in inter_stage_list:
        for stage, value in enumerate(row):
            stage_list[stage] += value
    return stage_list
hyper_parallel/auto_parallel/sapp_ppb/sapp/sapp_solver.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# limitations under the License.
# ============================================================================
"""Solver Class"""

import os
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Dict, List, Optional

import pulp as lpSolver

import sapp_ppb.utils.recompute as Recompute
from sapp_ppb.utils.layer import Layer
from sapp_ppb.utils.logger import logger

# seqpipe const
tensor_float_16 = 2
tensor_float_32 = 4
const_from_byte_to_mb = 1024 * 1024
# llama intermediate_size
llama_intermideate_size = 11008


@dataclass
class PipelineMemoryConstraint:
    """Plain data holder grouping the inputs of a pipeline memory constraint.

    The fields carry no behaviour of their own; how they are consumed is
    defined by the code that receives an instance.
    """
    # Problem object the constraint belongs to (presumably the pulp problem -- confirm).
    prob: Any
    # Decision variables of the problem (structure not visible here).
    variables: Any
    # Layers grouped by kind (presumably keyed by Layer.type_enum -- confirm).
    layers_sorted: dict[Any, Any]
    # Number of pipeline stages.
    num_of_stage: int
    # Number of interleaves.
    num_of_interleave: int
    # Micro-batch count.
    micro_batch: int
    # Memory limit (unit not visible here).
    memory_limit: int

class SappSolver:
    """solver for pipeline balance"""

    BIG_M = 1000000

    MEM_OVERHEAD_NAME = "memory_overhead"
    TOTAL_SUM = "var_sum_FPi_BPi"
    CHUNKS_SUM = "chunks_sum"
    PREV_DIFF = "prev_diff"
    NEXT_DIFF = "next_diff"
    MAX_STAGE_TIME = "max_stage_time"
    MAX_LAST_CHUNK = "max_last_chunk"
    LAYER_FRONTIER = "layer_frontier"
    REC_FRONTIER = "recompute_frontier"
    PROP_PHASE = IntEnum("Propagation", ["FW", "BW"], start=0)

    def __init__(
            self,
            num_of_stage: int,
            num_of_interleave: int,
            num_of_micro_batch: int,
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
            extracted_training_params: Optional training params for sequence-pipeline mode.
            seq_split_num: Number of sequence splits (``>1`` enables sequence pipeline).
        """

        self.num_of_stage_ = num_of_stage
        self.num_of_interleave_ = num_of_interleave
        self.num_of_micro_batch_ = num_of_micro_batch
        self.max_memory_ = max_memory
        self.vpp_less_memory_ = vpp_less_memory
        # Add dualpipe_v
        self.dual_ = dual
        self.constant_memory_ = constant_memory
        self.optimization_level_ = optimization_level
        self.layers_ = layers
        self.layers_sorted_ = layers_sorted

        self.recompute_considered_ = self.find_recompute_considered(
            layers_sorted)
        self.extracted_training_params_ = extracted_training_params
        self.seq_split_num_ = seq_split_num
        self.seq_pipe = self.seq_split_num_ > 1
        if self.seq_pipe:
            self._initialize_seq_pipe_layers()

        self.variables_ = self._create_variables_to_solve_(
            num_of_stage, num_of_interleave, layers_sorted)
        self.problem_ = self._create_problem_(description)

    def _initialize_seq_pipe_layers(self):
        """Update memory and time metadata for sequence pipeline mode.

        Runs three steps in order: rewrites the layers' memory values,
        multiplies ``num_of_micro_batch_`` by ``seq_split_num_``, then
        rescales the layers' times.
        """
        self._update_seq_pipe_memory()
        # The micro-batch count is scaled by the number of sequence splits.
        self.num_of_micro_batch_ *= self.seq_split_num_
        self._update_seq_pipe_time()

    def _update_seq_pipe_memory(self):
        """Apply the sequence-pipeline memory update to BODY, HEAD and TAIL layers, in that order."""
        updaters = (
            (Layer.type_enum.BODY, self._update_body_seq_memory),
            (Layer.type_enum.HEAD, self._update_head_seq_memory),
            (Layer.type_enum.TAIL, self._update_tail_seq_memory),
        )
        for layer_kind, update in updaters:
            for layer in self.layers_sorted_[layer_kind]:
                update(layer)

    def _update_body_seq_memory(self, layer):
        """Update body layer memory values for sequence pipeline mode."""
        if layer.memory_parameter_ is not None:
            logger.info("Body Layer 1f1b Parameter Memory: %s", layer.memory_parameter_)
            layer.memory_parameter_ = self.compute_seq_mem_parameter(
                layer.memory_parameter_, self.extracted_training_params_)
            logger.info("Body Layer Seq Parameter Memory: %s", layer.memory_parameter_)
        for rec in Recompute.TYPE:
            if not self.recompute_considered_[rec]:
                continue
            if rec.name == "FULL":
                self.recompute_considered_[rec] = False
                layer.memory_activation_rec_[rec] = None
                logger.error("Seqpipe doesn't support full recomputation, "
                             "recompute_activation is set as None for seqpp")
                continue
            logger.info(
                "Body Layer 1f1b %s activation Memory: %s",
                rec,
                layer.memory_activation_rec_[rec],
            )
            layer.memory_activation_rec_[rec] = self.compute_seq_mem_activation(
                layer.memory_activation_rec_[rec],
                self.extracted_training_params_,
                self.seq_split_num_
            )
            logger.info(
                "Body Layer seq %s activation Memory: %s",
                rec,
                layer.memory_activation_rec_[rec],
            )
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
                rec,
                layer.memory_activation_rec_[rec],
            )

    def _update_head_seq_memory(self, head):
        """Replace a head layer's parameter memory with its sequence-pipeline value.

        Does nothing when ``head.memory_parameter_`` is ``None``; otherwise the
        value before and after the conversion is logged.
        """
        original_cost = head.memory_parameter_
        if original_cost is not None:
            logger.info("Head cost 1f1b: %s", original_cost)
            head.memory_parameter_ = self.compute_seq_mem_head_cost(
                original_cost, self.extracted_training_params_, self.seq_split_num_)
            logger.info("Head cost Seq: %s", head.memory_parameter_)

    def _update_tail_seq_memory(self, tail):
        """Replace a tail layer's parameter memory with its sequence-pipeline value.

        Does nothing when ``tail.memory_parameter_`` is ``None``; otherwise the
        value before and after the conversion is logged.
        """
        original_cost = tail.memory_parameter_
        if original_cost is not None:
            logger.info("Tail cost 1f1b: %s", original_cost)
            tail.memory_parameter_ = self.compute_seq_mem_tail_cost(
                original_cost, self.extracted_training_params_, self.seq_split_num_)
            logger.info("Tail cost seq: %s", tail.memory_parameter_)

    def _update_seq_pipe_time(self):
        """Rescale the times of BODY, HEAD and TAIL layers (in that order) for sequence pipeline mode."""
        labelled_kinds = (
            (Layer.type_enum.BODY, "Body"),
            (Layer.type_enum.HEAD, "Head"),
            (Layer.type_enum.TAIL, "Tail"),
        )
        for layer_kind, label in labelled_kinds:
            for layer in self.layers_sorted_[layer_kind]:
                self._update_layer_seq_time(layer, label)

    def _update_layer_seq_time(self, layer, layer_name):
        """Divide a layer's ``time_`` and ``forward_time_`` by the sequence split number.

        The forward time and every backward-time entry are logged before and
        after the change; ``layer.update_internal_time_for_seqpp()`` is called
        once the two times have been divided.
        """
        def log_backward_times():
            # Dump every entry of the layer's backward-time table.
            for rec_key, rec_time in layer.backward_time_rec_.items():
                logger.output("%s: %s", rec_key, rec_time)

        split = self.seq_split_num_

        logger.info("%s Layer 1f1b fp time: %s", layer_name, layer.forward_time_)
        logger.info("%s Layer 1f1b bp time:", layer_name)
        log_backward_times()

        layer.time_ = layer.time_ / split
        layer.forward_time_ = layer.forward_time_ / split
        layer.update_internal_time_for_seqpp()

        logger.info("%s Layer seq fp time: %s", layer_name, layer.forward_time_)
        logger.info("%s Layer seq bp time:", layer_name)
        log_backward_times()

    @staticmethod
    def compute_forward_in_backward(num_of_stage: int,
                                    micro_batch: int) -> list[int]:
        """Computes the number of forward propagation happening after a backward"""
        n = num_of_stage - 1
        factors = []
        for _ in range(num_of_stage):
            factors.append(abs(n))
            n -= 2
        if micro_batch < 2 * num_of_stage:
            for i in range(num_of_stage // 2):
                factors[i] = 0
        return factors

    @staticmethod
    def compute_lm_forward_in_backward(num_of_stage: int) -> list[int]:
        """Function compute_forward_in_backward in less_memory schedule"""
        return list(range(num_of_stage))

    @staticmethod
    def compute_activation_nums(num_of_stage: int, num_of_interleave: int,
                                micro_batch: int) -> list[list[int]]:
        """compute the number of activation"""
        activation_nums = []

        if num_of_interleave > 1:
            for i in range(num_of_interleave):
                activation_nums.append([])
                for _ in range(num_of_stage):
                    activation_nums[i].append(num_of_stage)
            for s in range(num_of_stage):
                activation_nums[0][s] += max(0, num_of_stage - 2 * s - 1)
            for s in range(num_of_stage):
                activation_nums[num_of_interleave - 1][s] += min(
                    0, num_of_stage - 2 * s - 1)
            for i in range(num_of_interleave):
                for s in range(num_of_stage):
                    activation_nums[i][s] = min(activation_nums[i][s],
                                                micro_batch)
        else:
            for i in range(num_of_interleave):
                activation_nums.append([])
                for s in range(num_of_stage):
                    activation_nums[i].append(num_of_stage - s)

        return activation_nums

    @staticmethod
    def compute_activation_nums_dual(num_of_stage: int, num_of_interleave: int,
                                     micro_batch: int) -> list[list[int]]:
        """compute the number of activation for dualpipe_v"""
        activation_nums = []

        for i in range(num_of_interleave):
            activation_nums.append([])
            for _ in range(num_of_stage):
                activation_nums[i].append(0)
        for s in range(num_of_stage):
            activation_nums[0][s] += max(0, 2 * num_of_stage - s)
        for s in range(num_of_stage):
            activation_nums[num_of_interleave - 1][s] += max(
                    0, s + 1)
        for i in range(num_of_interleave):
            for s in range(num_of_stage):
                activation_nums[i][s] = min(activation_nums[i][s],
                                            micro_batch)

        return activation_nums

    @staticmethod
    def compute_less_activation_nums(
            num_of_stage: int, num_of_interleave: int) -> list[list[int]]:
        """compute number of less_mem activation"""
        activation_nums = []
        if num_of_interleave > 1:
            for i in range(num_of_interleave):
                activation_nums.append([])
                for _ in range(num_of_stage):
                    activation_nums[i].append(num_of_stage)
            for s in range(num_of_stage):
                activation_nums[num_of_interleave - 1][s] -= s
        else:
            for i in range(num_of_interleave):
                activation_nums.append([])
                for s in range(num_of_stage):
                    activation_nums[i].append(num_of_stage - s)
        return activation_nums

    #######################################################################
    ##                                                                   ##
    ##                            SeqPipe                                ##
    ##                                                                   ##
    #######################################################################
    @staticmethod
    def compute_activation_seq_nums(num_of_stage: int, num_of_interleave: int,
                                    seq_split_num: int, micro_batch: int,
                                    less_memory: bool = False) -> list[list[int]]:
        """Build the ``[interleave][stage]`` activation-count table for sequence chunks.

        Args:
            num_of_stage: number of pipeline stages.
            num_of_interleave: number of interleaved chunks per stage.
            seq_split_num: number of sequence chunks.
            micro_batch: cap applied to every cell in the interleaved case only.
            less_memory: when true, counts grow by 1 per step instead of 2.

        Returns:
            A list of ``num_of_interleave`` rows with ``num_of_stage`` counts each.
        """
        act_gap = 1 if less_memory else 2
        if num_of_interleave > 1:
            # Every chunk starts at ``num_of_stage``; the last chunk starts at
            # ``seq_split_num``.
            activation_nums = [[num_of_stage] * num_of_stage
                               for _ in range(num_of_interleave)]
            activation_nums[num_of_interleave - 1] = [seq_split_num] * num_of_stage

            # Walk from the second-to-last stage to the first. At each step the
            # first chunk whose count at the next stage is not a multiple of
            # ``num_of_stage``, or has not yet reached ``loop_index`` multiples,
            # gets ``act_gap`` added on stages ``0..stage_index``.
            loop_index = 1
            for stage_index in range(num_of_stage - 2, -1, -1):
                for chunk_index in range(num_of_interleave):
                    next_count = activation_nums[chunk_index][stage_index + 1]
                    if (next_count % num_of_stage != 0
                            or next_count // num_of_stage < loop_index):
                        for update in range(stage_index + 1):
                            activation_nums[chunk_index][update] += act_gap
                        break
                else:
                    # No chunk qualified: grow the first chunk and raise the bar.
                    for update in range(stage_index + 1):
                        activation_nums[0][update] += act_gap
                    loop_index += 1

            # A stage never holds more activations than there are micro-batches.
            activation_nums = [[min(count, micro_batch) for count in row]
                               for row in activation_nums]
        else:
            activation_nums = [
                [num_of_stage - s + seq_split_num - 1 for s in range(num_of_stage)]
                for _ in range(num_of_interleave)
            ]

        logger.output("compute_activation_seq_nums: %s", activation_nums)
        return activation_nums

    @staticmethod
    def compute_seq_mem_activation(original_memory_activation: float,
                                   extracted_training_params: dict[str, int],
                                   seq_split_num: int) -> float:
        """Return the per-layer activation memory under seqpipe.

        The key and value tensors are removed from the original activation
        memory, the remainder is divided by ``seq_split_num``, and the memory of
        two such tensors (the KV update) is added back.
        """
        params = extracted_training_params
        batch_size = params['batch_size']
        heads = params['num_heads']
        seq_length = params['seq_length']
        head_dim = params['head_dim']
        mp = params['model_parallel']
        # NOTE(review): context parallelism (cp) is not taken into account here.

        # Bytes of one (batch, heads, seq, head_dim) fp16 tensor per model-parallel rank.
        per_tensor_byte = (tensor_float_16 * batch_size * heads * seq_length * head_dim) / (mp)

        kv_update_mem = (2 * per_tensor_byte) / const_from_byte_to_mb
        key_mem = per_tensor_byte / const_from_byte_to_mb
        value_mem = per_tensor_byte / const_from_byte_to_mb

        return (original_memory_activation - key_mem - value_mem) / seq_split_num + kv_update_mem

    @staticmethod
    def compute_seq_mem_parameter(original_memory_parameter: float, extracted_training_params: dict[str, int]) -> float:
        """Return the per-layer parameter memory under seqpipe.

        Adds the memory of four (batch, heads, seq, head_dim) fp16 tensors,
        divided by the model-parallel degree, to the original parameter memory.
        """
        params = extracted_training_params
        batch_size = params['batch_size']
        heads = params['num_heads']
        seq_length = params['seq_length']
        head_dim = params['head_dim']
        mp = params['model_parallel']
        # NOTE(review): context parallelism (cp) is not taken into account here.

        per_tensor_byte = tensor_float_16 * batch_size * heads * seq_length * head_dim / (mp)
        kv_cache_parameter_mem = (4 * per_tensor_byte) / const_from_byte_to_mb
        return original_memory_parameter + kv_cache_parameter_mem

    @staticmethod
    def compute_seq_mem_head_cost(original_head_cost: float,
                                  extracted_training_params: dict[str, int],
                                  seq_split_num: int) -> float:
        """Return the head-stage extra memory cost under seqpipe.

        A fraction ``1 - 1 / seq_split_num`` of two operator buffers is removed
        from the original cost; which buffers depends on whether model
        parallelism is enabled.
        """
        params = extracted_training_params
        batch_size = params['batch_size']
        seq_length = params['seq_length']
        hidden_size = params['hidden_size']
        mp = params['model_parallel']
        # NOTE(review): context parallelism (cp) is not taken into account here.

        if mp > 1:
            # Communication buffers (recv + reduceScatter).
            comm_operator_mem = (
                2 * (tensor_float_16 * batch_size * seq_length * hidden_size / (mp))
            ) / const_from_byte_to_mb
            # StridedSliceGrad operator buffer.
            other_operator_mem = (
                tensor_float_16 * batch_size * seq_length * hidden_size
            ) / const_from_byte_to_mb
        else:
            # Communication buffer (recv only).
            comm_operator_mem = (
                tensor_float_16 * batch_size * seq_length * hidden_size / (mp)
            ) / const_from_byte_to_mb
            # Grad/MatMul and Grad/Mul operator buffer.
            other_operator_mem = (
                tensor_float_16 * batch_size * seq_length * llama_intermideate_size / (mp)
            ) / const_from_byte_to_mb

        return original_head_cost - (1 - 1 / seq_split_num) * (comm_operator_mem + other_operator_mem)

    @staticmethod
    def compute_seq_mem_tail_cost(original_tail_cost: float,
                                  extracted_training_params: dict[str, int],
                                  seq_split_num: int) -> float:
        """Return the tail-stage extra memory cost under seqpipe.

        With ``M`` the memory of the loss operator and ``k`` the number of
        sequence chunks: ``new = old - (3 - 3 / k) * M + (k - 1) * (M / k)``.
        """
        params = extracted_training_params
        batch_size = params['batch_size']
        seq_length = params['seq_length']
        vocab_size = params['vocab_size']
        mp = params['model_parallel']
        # NOTE(review): context parallelism (cp) is not taken into account here.

        # Extra memory introduced by the loss operator (fp32).
        loss_operator_mem = (
            tensor_float_32 * batch_size * seq_length * vocab_size / (mp)
        ) / const_from_byte_to_mb

        removed_share = (3 - 3 / seq_split_num) * loss_operator_mem
        added_share = (seq_split_num - 1) * (loss_operator_mem / seq_split_num)
        return original_tail_cost - removed_share + added_share

    def add_total_nb_layer_constraint(self, prob: Any, variables: Any,
                                      sorted_layers: Dict[Layer.type_enum, list[Layer]]) -> Any:
        """Constrain every BODY layer type so its assigned layers sum to ``layer.nb_layer_``.

        The sum runs over all variables of the layer for each recomputation type
        that is marked as considered. Returns ``prob`` with the constraints added.
        """
        for body_layer in sorted_layers[Layer.type_enum.BODY]:
            assigned = lpSolver.lpSum(
                variables[body_layer.name_][rec]
                for rec in Recompute.TYPE
                if self.recompute_considered_[rec])
            prob += (assigned == body_layer.nb_layer_)
        return prob

    def add_stage_nb_layer_constraint(self, prob: Any, variables: Any,
                                      sorted_layers: Dict[Layer.type_enum, List[Layer]]) -> Any:
        """Force every ``(interleave, stage)`` cell that is not reserved to hold at least one BODY layer."""
        body_layers = sorted_layers[Layer.type_enum.BODY]
        reserved_positions = self._reserved_stage_positions()
        for interleave in range(self.num_of_interleave_):
            for stage in range(self.num_of_stage_):
                if (interleave, stage) not in reserved_positions:
                    hosted = lpSolver.lpSum(
                        variables[body_layer.name_][rec][interleave][stage]
                        for rec in Recompute.TYPE
                        if self.recompute_considered_[rec]
                        for body_layer in body_layers)
                    prob += (hosted >= 1)
        return prob

    def _reserved_stage_positions(self):
        """Return the set of ``(interleave, stage)`` cells reserved for head and tail layers."""
        first_cell = (0, 0)
        if self.dual_:
            second_cell = (1, 0)
        else:
            second_cell = (self.num_of_interleave_ - 1, self.num_of_stage_ - 1)
        return {first_cell, second_cell}

    def add_multimodal_sequence_constraint(
            self, prob: Any, variables: Any,
            sorted_layers: Dict[Layer.type_enum, List[Layer]]) -> Any:
        """Add the frontier constraints separating successive BODY layer types.

        A lower bound is added for every BODY layer type after the first, on
        every ``(interleave, stage)`` cell; the matching upper bounds are added
        afterwards.
        """
        body_layers = sorted_layers[Layer.type_enum.BODY]
        for frontier, body_layer in enumerate(body_layers):
            if frontier == 0:
                # The first layer type has no preceding frontier.
                continue
            for interleave in range(self.num_of_interleave_):
                for stage in range(self.num_of_stage_):
                    prob = self._add_frontier_lower_bound(
                        prob, variables, body_layer.name_, frontier, interleave, stage)
        return self._add_frontier_upper_bounds(prob, variables, sorted_layers)

    def _add_frontier_lower_bound(self, prob, variables, layer, frontier, interleave, stage):
        """Bound one frontier variable from below by its layer sum divided by ``BIG_M``.

        Nothing is added when no layer sum exists for this cell.
        """
        layer_sum = self._frontier_layer_sum(variables, layer, interleave, stage)
        if layer_sum is not None:
            frontier_var = variables[self.LAYER_FRONTIER][frontier - 1][interleave][stage]
            prob += (frontier_var >= layer_sum / self.BIG_M)
        return prob

    def _frontier_layer_sum(self, variables, layer, interleave, stage):
        """Return the layer sum feeding a frontier constraint.

        In dual mode the computation is delegated to the dualpipe_v variant.
        Otherwise it is the sum over stages before ``stage`` in the current
        interleave plus the sum over all previous interleaves.
        """
        if not self.dual_:
            current = self._current_layer_sum(variables, layer, interleave, range(stage))
            previous = self._previous_layer_sum(variables, layer, interleave)
            return current + previous
        return self._dual_frontier_layer_sum(variables, layer, interleave, stage)

    def _dual_frontier_layer_sum(self, variables, layer, interleave, stage):
        """Return the dualpipe_v layer sum for a frontier constraint, or ``None``.

        Interleave 0 sums the stages before ``stage``. Interleave 1 sums the
        stages from ``stage`` to the last one and adds the previous interleaves.
        Any other interleave yields ``None``.
        """
        if interleave not in (0, 1):
            return None
        if interleave == 0:
            return self._current_layer_sum(variables, layer, interleave, range(stage))
        remaining_stages = range(stage, self.num_of_stage_)
        current = self._current_layer_sum(variables, layer, interleave, remaining_stages)
        previous = self._previous_layer_sum(variables, layer, interleave)
        return current + previous

    def _current_layer_sum(self, variables, layer, interleave, stage_range):
        """Sum the variables of ``layer`` in one interleave over ``stage_range``.

        Only recomputation types marked as considered contribute.
        """
        return lpSolver.lpSum(
            variables[layer][rec][interleave][stage]
            for rec in Recompute.TYPE
            if self.recompute_considered_[rec]
            for stage in stage_range
        )

    def _previous_layer_sum(self, variables, layer, interleave):
        """Sum the variables of ``layer`` over every stage of the interleaves before ``interleave``.

        Only recomputation types marked as considered contribute.
        """
        return lpSolver.lpSum(
            variables[layer][rec][prev_interleave][stage]
            for rec in Recompute.TYPE if self.recompute_considered_[rec]
            for prev_interleave in range(interleave)
            for stage in range(self.num_of_stage_)
        )

    def _add_frontier_upper_bounds(self, prob, variables, sorted_layers):
        """Add the upper bounds that keep each earlier BODY layer type behind its frontier.

        For frontier ``f`` the bound applies to the layer type at index ``f - 1``
        on every ``(interleave, stage)`` cell.
        """
        body_layers = sorted_layers[Layer.type_enum.BODY]
        for frontier in range(1, len(body_layers)):
            previous_layer = body_layers[frontier - 1].name_
            for stage in range(self.num_of_stage_):
                for interleave in range(self.num_of_interleave_):
                    prob = self._add_frontier_upper_bound(
                        prob, variables, previous_layer, frontier, interleave, stage)
        return prob

    def _add_frontier_upper_bound(self, prob, variables, layer, frontier, interleave, stage):
        """Bound the layer variables of one cell by ``(1 - frontier_var) * BIG_M``.

        One constraint is added per recomputation type marked as considered.
        """
        for rec in Recompute.TYPE:
            if not self.recompute_considered_[rec]:
                continue
            prob += variables[layer][rec][interleave][stage] <= (
                1 - variables[self.LAYER_FRONTIER][frontier - 1][interleave][stage]
            ) * self.BIG_M
        return prob

    def add_multimodal_recompute_constraint(
            self, prob: Any, variables: Any,
            sorted_layers: Dict[Layer.type_enum, List[Layer]]) -> Any:
        """Keep recomputation schemes consistent across BODY layer types (MindFormer constraint).

        The constraint is only built when exactly two recomputation types are
        considered. With more than two an error is logged and ``prob`` is
        returned unchanged; with fewer than two nothing is needed.

        Returns:
            ``prob`` with the recompute-frontier constraints added.
        """
        considered = Recompute.get_used_list(self.recompute_considered_)
        if len(considered) > 2:
            logger.error("Careful: MindFormer does not allow a fine recomputation scheme "
                         "for heterogeneous models. Pipeline balancing is currently unable to "
                         "comply with MF constraint for more than 1 recomputation type.")
            return prob

        if len(considered) < 2:
            # this constraint is unnecessary if there is no recomputation
            return prob

        # Lower bound: the frontier of a layer type is raised as soon as a later
        # layer type uses the largest considered recomputation type in the cell.
        most_rec = max(considered)
        layer_type_num = len(sorted_layers[Layer.type_enum.BODY])
        for v in range(self.num_of_interleave_):
            for s in range(self.num_of_stage_):
                for rec in Recompute.TYPE:
                    if self.recompute_considered_[rec] and rec is not Recompute.TYPE.NONE:
                        for layer_idx in range(0, layer_type_num - 1):
                            prob += variables[self.REC_FRONTIER][v][s][layer_idx] >= (
                                lpSolver.lpSum(
                                    variables[sorted_layers[Layer.type_enum.BODY][next_idx].name_][most_rec][v][s]
                                    for next_idx in range(layer_idx + 1, layer_type_num))) / self.BIG_M

        # Upper bound: once the frontier is raised, the layer type may no longer
        # use the smallest considered recomputation type in that cell.
        least_rec = min(considered)
        for layer_idx in range(0, layer_type_num - 1):
            layer = sorted_layers[Layer.type_enum.BODY][layer_idx].name_
            for v in range(0, self.num_of_interleave_):
                for s in range(0, self.num_of_stage_):
                    prob += variables[layer][least_rec][v][s] <= (
                        1 - variables[self.REC_FRONTIER][v][s][layer_idx]
                        ) * self.BIG_M
        return prob

    @staticmethod
    def find_recompute_considered(
            layers_sorted: Dict[Layer.type_enum, List[Layer]]) -> Dict[Recompute.TYPE, bool]:
        """Return the ``recompute_considered_`` mapping of the first BODY layer (not a copy)."""
        first_body_layer = layers_sorted[Layer.type_enum.BODY][0]
        return first_body_layer.recompute_considered_

    def max_stage_micro_eq_stage(self, prob: Any,
                                 layers_sorted: Dict[Layer.type_enum, List[Layer]]) -> Any:
        """Apply additional VPP optimisations when ``pp == num_of_micro_batch``.

        Constraints are added to ``prob`` in place. The return value is the LP
        expression to use as the maximum stage time: ``MAX_STAGE_TIME`` alone
        when ``optimization_level_ >= 2``, otherwise
        ``MAX_STAGE_TIME + MAX_LAST_CHUNK``.
        """
        last_chunk = self.num_of_interleave_ - 1

        # Every chunk but the last: backward time plus head/tail contribution.
        for i_stage in range(self.num_of_stage_):
            for inter in range(last_chunk):
                prob += self.variables_[self.MAX_STAGE_TIME] >= (
                    self._max_stage_bound_i_bp(layers_sorted, i_stage, inter) +
                    self._max_stage_bound_head_tail(layers_sorted, i_stage,
                                                    -1, inter))

        if self.vpp_less_memory_:
            factors = self.compute_lm_forward_in_backward(self.num_of_stage_)
        else:
            factors = self.compute_forward_in_backward(
                self.num_of_stage_, self.num_of_micro_batch_)

        # Last chunk: add a share of the forward time weighted by the
        # per-stage factor over the number of micro-batches.
        for i_stage in range(self.num_of_stage_):
            logger.debug(
                "v=%s, s=%s: (BP + HT) + (%s / %s * FP)",
                last_chunk,
                i_stage,
                factors[i_stage],
                self.num_of_micro_batch_,
            )
            prob += self.variables_[self.MAX_LAST_CHUNK] >= (
                self._max_stage_bound_i_bp(layers_sorted, i_stage, last_chunk) +
                self._max_stage_bound_head_tail(layers_sorted, i_stage, last_chunk, last_chunk) +
                (factors[i_stage] / self.num_of_micro_batch_) *
                self._max_stage_bound_i_fp(layers_sorted, i_stage, last_chunk))

        if self.optimization_level_ >= 2:
            logger.debug("Approach 2a")
            prob += self.variables_[self.MAX_STAGE_TIME] >= (
                self.variables_[self.MAX_LAST_CHUNK])

            return self.variables_[self.MAX_STAGE_TIME]
        logger.debug("Approach 2b")
        prob += self.variables_[self.MAX_LAST_CHUNK] >= (
            self.variables_[self.MAX_STAGE_TIME])

        return (self.variables_[self.MAX_STAGE_TIME] +
                self.variables_[self.MAX_LAST_CHUNK])

    def add_performance_constraint(self, prob: Any,
                                   layers_sorted: Dict[Layer.type_enum, List[Layer]],
                                   pipeline_total_time: Any) -> Any:
        """Add the ``pipeline_total_time >= …`` performance constraints.

        With ``optimization_level_ >= 2`` (approach A) the bound combines the
        total, per-chunk and difference terms with the maximum stage time;
        otherwise (approach B) the total time is only bounded by the maximum
        stage time. Returns ``prob``.
        """
        max_stage_time = self.variables_[self.MAX_STAGE_TIME]
        max_stage_time = self.add_max_stage_constraint(prob, layers_sorted, max_stage_time)

        total_sum = self.variables_[self.TOTAL_SUM]
        prob += total_sum >= self._total_sum(layers_sorted)

        if self.optimization_level_ >= 2:
            # approach A
            for v in range(self.num_of_interleave_ - 1):
                prob += self.variables_[self.PREV_DIFF][v] >= (
                    self._prev_diff_sum(layers_sorted, prob, v))

                prob += self.variables_[self.CHUNKS_SUM][v] >= (
                    (self.num_of_interleave_ - v) / self.num_of_interleave_ *
                    self._chunks_sum(layers_sorted, v))

            chunks_sum = lpSolver.lpSum(self.variables_[self.CHUNKS_SUM])
            prev_diff = lpSolver.lpSum(self.variables_[self.PREV_DIFF])

            next_diff = self.variables_[self.NEXT_DIFF]
            prob += next_diff >= (
                self._next_diff_sum(layers_sorted, prob))

            prob += pipeline_total_time >= (
                (total_sum + chunks_sum + prev_diff + next_diff)
                / max(1, (self.num_of_interleave_ - 2))
                + max_stage_time * (self.num_of_micro_batch_ - 2)
            )
        else:
            # approach B
            prob += pipeline_total_time >= max_stage_time
        return prob

    def add_max_stage_constraint(self, prob: Any,
                                 layers_sorted: Dict[Layer.type_enum, List[Layer]],
                                 max_stage_time: Any) -> Any:
        """Add the ``max_stage_time`` lower-bound constraints over every ``(interleave, stage)``.

        When interleaving is on, ``optimization_level_ >= 1`` and the number of
        micro-batches equals the number of stages, the specialised
        ``max_stage_micro_eq_stage`` bounds are used and its result is returned.
        Otherwise ``max_stage_time`` is bounded for every stage and every pair
        of forward / backward chunks, and returned unchanged.
        """
        if (self.num_of_interleave_ > 1 and self.optimization_level_ >= 1
                and self.num_of_micro_batch_ == self.num_of_stage_):
            max_stage_time = self.max_stage_micro_eq_stage(prob, layers_sorted)
        else:
            # Constraints on sub-main-part of a stage that it may take (for all stage)
            for i_stage in range(self.num_of_stage_):
                for inter_f in range(self.num_of_interleave_):
                    for inter_b in range(self.num_of_interleave_):
                        prob += max_stage_time >= (
                            self._max_stage_bound_i_fp(layers_sorted, i_stage, inter_f) +
                            self._max_stage_bound_i_bp(layers_sorted, i_stage, inter_b) +
                            self._max_stage_bound_head_tail(layers_sorted, i_stage,
                                                            inter_f, inter_b))

        return max_stage_time

    ############################################
    #            Memory Constraint             #
    ############################################
    def stage_param_memory(self, variables: Any,
                           layers_sorted: Dict[Layer.type_enum, List[Layer]],
                           stage_id: int, num_of_stage: int,
                           num_of_interleave: int) -> Any:
        """Return an LP expression for the parameter memory of ``stage_id``.

        The expression sums, over every interleave, BODY layer and considered
        recomputation type, the layer-count variable times the layer's
        parameter memory. Stage 0 additionally carries the HEAD and TAIL
        parameter memory; outside dualpipe_v mode the last stage carries the
        TAIL parameter memory as well.
        """
        bound = lpSolver.LpAffineExpression()
        for inter_id in range(num_of_interleave):
            for layer in layers_sorted[Layer.type_enum.BODY]:
                for rec in Recompute.TYPE:
                    if self.recompute_considered_[rec]:
                        bound += (
                            variables[layer.name_][rec][inter_id][stage_id] *
                            layer.memory_parameter_)

        if stage_id == 0:
            # NOTE(review): TAIL parameters are counted on stage 0 in both
            # modes — confirm this is intended outside dualpipe_v.
            for head in layers_sorted[Layer.type_enum.HEAD]:
                bound += head.memory_parameter_
            for tail in layers_sorted[Layer.type_enum.TAIL]:
                bound += tail.memory_parameter_

        # dualpipe_v only charges stage 0; the regular layout also charges the
        # TAIL parameters to the last stage.
        if not self.dual_ and stage_id == num_of_stage - 1:
            for tail in layers_sorted[Layer.type_enum.TAIL]:
                bound += tail.memory_parameter_
        return bound

    def stage_active_memory_per_micro(
            self, variables: Any,
            layers_sorted: Dict[Layer.type_enum, List[Layer]],
            stage_id: int, inter_id: int) -> Any:
        """Return the activation-memory LP expression of one ``(inter_id, stage_id)`` cell.

        Each BODY layer contributes its layer-count variable times its
        activation memory, for every recomputation type marked as considered.
        """
        per_micro = lpSolver.LpAffineExpression()
        for body_layer in layers_sorted[Layer.type_enum.BODY]:
            for rec in Recompute.TYPE:
                if not self.recompute_considered_[rec]:
                    continue
                layer_count = variables[body_layer.name_][rec][inter_id][stage_id]
                per_micro += (layer_count * body_layer.memory_activation_rec_[rec])
        return per_micro

    def stage_active_memory(self, variables: Any,
                            layers_sorted: Dict[Layer.type_enum, List[Layer]],
                            stage_id: int, num_of_interleave: int,
                            activation_nums: List[List[int]]) -> Any:
        """Return the total activation-memory LP expression of ``stage_id``.

        Every interleave contributes its per-layer activation memory weighted
        by ``activation_nums[inter_id][stage_id]``.
        """
        total = lpSolver.LpAffineExpression()
        for inter_id in range(num_of_interleave):
            for body_layer in layers_sorted[Layer.type_enum.BODY]:
                for rec in Recompute.TYPE:
                    if not self.recompute_considered_[rec]:
                        continue
                    layer_count = variables[body_layer.name_][rec][inter_id][stage_id]
                    total += (
                        layer_count *
                        body_layer.memory_activation_rec_[rec] *
                        activation_nums[inter_id][stage_id])
        return total

    def init_overhead_variables(self, variables: Any, s: int) -> Any:
        """Compute the per-stage overhead LP expression used in the VPP memory constraint.

        The expression is a weighted sum of differences between the per-micro
        activation memory of a "forward" chunk ``vf`` and a "backward" chunk
        ``vb`` on stage ``s``. Both indices start at the last chunk; successive
        iterations alternately advance ``vf`` (wrapping around) and step ``vb``
        back. The weights depend on whether the less-memory schedule is used.
        """
        bound = lpSolver.LpAffineExpression()
        vf = self.num_of_interleave_ - 1
        vb = self.num_of_interleave_ - 1
        incr_f = True
        if self.vpp_less_memory_:
            for _ in range(self.num_of_interleave_ - 1):
                if incr_f:
                    vf = (vf + 1) % self.num_of_interleave_
                    factor = abs(self.num_of_stage_ - s)
                else:
                    vb = vb - 1
                    factor = s
                incr_f = not incr_f

                logger.debug("%s * (act(%s,%s) - act(%s,%s)", factor, vf, s, vb, s)
                bound += factor * (
                    self.stage_active_memory_per_micro(variables, self.layers_sorted_, s, vf)
                    - self.stage_active_memory_per_micro(variables, self.layers_sorted_, s, vb))
        else:
            for _ in range(self.num_of_interleave_ - 1):
                if incr_f:
                    vf = (vf + 1) % self.num_of_interleave_
                    logger.debug(
                        "%s * (act(%s,%s) - act(%s,%s)",
                        self.num_of_stage_ - abs(self.num_of_stage_ - 2 * s - 1),
                        vf,
                        s,
                        vb,
                        s,
                    )
                    bound += (self.num_of_stage_ - abs(self.num_of_stage_ - 2 * s - 1)) * (
                        self.stage_active_memory_per_micro(variables, self.layers_sorted_, s, vf)
                        - self.stage_active_memory_per_micro(variables, self.layers_sorted_, s, vb)
                        )
                else:
                    vb = vb - 1
                    # Positive part of the stage skew, on the chunks one past
                    # the current indices.
                    logger.debug(
                        "%s * (act(%s,%s) - act(%s,%s)",
                        max(self.num_of_stage_ - 2 * s - 1, 0),
                        vf + 1,
                        s,
                        vb + 1,
                        s,
                    )
                    bound += max(self.num_of_stage_ - 2 * s - 1, 0) * (
                        self.stage_active_memory_per_micro(variables,
                                                           self.layers_sorted_, s, vf + 1)
                        - self.stage_active_memory_per_micro(variables,
                                                             self.layers_sorted_, s, vb + 1)
                        )
                    # Negative part of the stage skew, on the current chunks.
                    logger.debug(
                        "%s * (act(%s,%s) - act(%s,%s)",
                        max(-(self.num_of_stage_ - 2 * s - 1), 0),
                        vf,
                        s,
                        vb,
                        s,
                    )
                    bound += max(-(self.num_of_stage_ - 2 * s - 1), 0) * (
                        self.stage_active_memory_per_micro(variables, self.layers_sorted_, s, vf)
                        - self.stage_active_memory_per_micro(variables, self.layers_sorted_, s, vb)
                        )
                incr_f = not incr_f

        return bound

    def stage_overhead_memory(self, variables: Any, stage_id: int) -> Any:
        """Sum the overhead entries of stage ``stage_id`` into one LP expression.

        Adds ``variables[self.MEM_OVERHEAD_NAME][stage_id][v]`` for every ``v``
        in ``range(self.num_of_interleave_ - 1)`` and returns the total.
        """
        overhead = lpSolver.LpAffineExpression()
        chunk_count = self.num_of_interleave_ - 1
        for chunk in range(chunk_count):
            overhead += variables[self.MEM_OVERHEAD_NAME][stage_id][chunk]
        return overhead

    def add_pipeline_memory_constraint(self,
                                       constraint: PipelineMemoryConstraint) -> None:
        """Add per-stage memory upper-bound constraints to the solver problem.

        For every stage the constraint
        ``memory_limit >= parameter memory + activation memory [+ overhead] + constant_memory_``
        is added to ``constraint.prob``. The overhead variable of the stage is
        included, and itself bounded from below by ``init_overhead_variables``,
        only when ``num_of_stage_`` differs from ``num_of_micro_batch_``.

        Args:
            constraint: Bundle providing ``prob``, ``variables``, ``layers_sorted``,
                ``num_of_stage``, ``num_of_interleave``, ``micro_batch`` and
                ``memory_limit``.
        """
        prob = constraint.prob
        variables = constraint.variables
        layers_sorted = constraint.layers_sorted
        num_of_stage = constraint.num_of_stage
        num_of_interleave = constraint.num_of_interleave
        micro_batch = constraint.micro_batch
        memory_limit = constraint.memory_limit

        # Choose the activation-count model that matches the configured schedule.
        if self.vpp_less_memory_:
            if self.seq_pipe:
                activation_nums = self.compute_activation_seq_nums(
                    num_of_stage, num_of_interleave, self.seq_split_num_, micro_batch, True)
            else:
                activation_nums = self.compute_less_activation_nums(
                    num_of_stage, num_of_interleave)
        elif self.dual_:
            # dualpipe_v schedule
            activation_nums = self.compute_activation_nums_dual(
                num_of_stage, num_of_interleave, micro_batch)
        elif self.seq_pipe:
            activation_nums = self.compute_activation_seq_nums(
                num_of_stage, num_of_interleave, self.seq_split_num_, micro_batch, False)
        else:
            activation_nums = self.compute_activation_nums(
                num_of_stage, num_of_interleave, micro_batch)
        logger.info("activation nums = %s", activation_nums)

        needs_overhead = self.num_of_stage_ != self.num_of_micro_batch_
        for s in range(num_of_stage):
            overhead = None
            if needs_overhead:
                overhead = variables[self.MEM_OVERHEAD_NAME][s]
                prob += overhead >= self.init_overhead_variables(variables, s)

            # NOTE(review): the arguments of this stage_active_memory call were
            # restored from the identical call of the no-overhead branch.
            usage = (
                self.stage_param_memory(
                    variables, layers_sorted, s, num_of_stage, num_of_interleave)
                + self.stage_active_memory(
                    variables, layers_sorted, s, num_of_interleave, activation_nums)
            )
            if overhead is not None:
                usage = usage + overhead
            prob += memory_limit >= usage + self.constant_memory_

    def get_simulator_memory_activation(self) -> list[list[float]]:
        """Give the activation memory per chunk and stage for the simulator.

        Returns:
            ``memory_active[inter][stage]``: the sum, over every considered
            recomputation type and every BODY layer, of the solved variable
            value times ``layer.memory_activation_rec_[rec]``. The list is empty
            when ``has_some_memory_info()`` is false.
        """
        memory_active = []
        if not self.has_some_memory_info():
            return memory_active

        body_layers = self.layers_sorted_[Layer.type_enum.BODY]
        for inter in range(self.num_of_interleave_):
            memory_active.append([
                sum(
                    self.variables_.get(layer.name_)[rec][inter][stage].varValue
                    * layer.memory_activation_rec_[rec]
                    for rec in Recompute.TYPE
                    if self.recompute_considered_[rec]
                    for layer in body_layers
                )
                for stage in range(self.num_of_stage_)
            ])
        return memory_active

    def get_simulator_memory_parameter(self) -> list[list[float]]:
        """Give the parameter memory per stage for the simulator.

        The per-stage total sums, over all chunks, considered recomputation
        types and BODY layers, the solved variable value times
        ``layer.memory_parameter_``. HEAD parameters are added to stage 0 and
        TAIL parameters to the last stage (layers whose ``memory_parameter_`` is
        ``None`` are skipped).

        Returns:
            ``num_of_interleave_`` rows, each an independent copy of the
            per-stage totals, so mutating one row does not affect the others.
        """
        memory_param_stage = [0] * self.num_of_stage_
        if self.has_some_memory_info():
            body_layers = self.layers_sorted_[Layer.type_enum.BODY]
            for inter in range(self.num_of_interleave_):
                for stage in range(self.num_of_stage_):
                    memory_param_stage[stage] += sum(
                        self.variables_.get(layer.name_)[rec][inter][stage].varValue
                        * layer.memory_parameter_
                        for rec in Recompute.TYPE if self.recompute_considered_[rec]
                        for layer in body_layers)

        for head in self.layers_sorted_[Layer.type_enum.HEAD]:
            if head.memory_parameter_ is not None:
                memory_param_stage[0] += head.memory_parameter_
        for tail in self.layers_sorted_[Layer.type_enum.TAIL]:
            if tail.memory_parameter_ is not None:
                memory_param_stage[self.num_of_stage_ - 1] += tail.memory_parameter_

        # Copy per chunk: ``[row] * n`` would hand out the same list object n times.
        return [list(memory_param_stage) for _ in range(self.num_of_interleave_)]

    def get_simulator_time(self) -> list[float]:
        """Give the forward-plus-backward time per chunk and stage for the simulator.

        Each cell accumulates, for every BODY layer and considered recomputation
        type, the solved variable value times
        ``forward_time_ + backward_time_rec_[rec]``. HEAD times are then added to
        cell ``[0][0]`` and TAIL times to the last chunk of the last stage.
        """
        per_chunk = []
        for chunk in range(self.num_of_interleave_):
            row = []
            for stage in range(self.num_of_stage_):
                elapsed = 0
                for layer in self.layers_sorted_[Layer.type_enum.BODY]:
                    for rec in Recompute.TYPE:
                        if not self.recompute_considered_[rec]:
                            continue
                        count = self.variables_.get(layer.name_)[rec][chunk][stage].varValue
                        elapsed += count * (layer.forward_time_ + layer.backward_time_rec_[rec])
                row.append(elapsed)
            per_chunk.append(row)

        for head in self.layers_sorted_[Layer.type_enum.HEAD]:
            per_chunk[0][0] += head.time_

        last_chunk = self.num_of_interleave_ - 1
        last_stage = self.num_of_stage_ - 1
        for tail in self.layers_sorted_[Layer.type_enum.TAIL]:
            per_chunk[last_chunk][last_stage] += tail.time_
        return per_chunk

    def get_simulator_forward_time(self) -> list[float]:
        """Give the forward time per chunk and stage for the simulator.

        Each cell accumulates, for every BODY layer and considered recomputation
        type, the solved variable value times ``forward_time_``. HEAD times are
        then added to cell ``[0][0]`` and TAIL times to the last chunk of the
        last stage.
        """
        per_chunk = []
        for chunk in range(self.num_of_interleave_):
            row = []
            for stage in range(self.num_of_stage_):
                elapsed = 0
                for layer in self.layers_sorted_[Layer.type_enum.BODY]:
                    for rec in Recompute.TYPE:
                        if not self.recompute_considered_[rec]:
                            continue
                        count = self.variables_[layer.name_][rec][chunk][stage].varValue
                        elapsed += count * (layer.forward_time_)
                row.append(elapsed)
            per_chunk.append(row)

        for head in self.layers_sorted_[Layer.type_enum.HEAD]:
            per_chunk[0][0] += head.time_

        last_chunk = self.num_of_interleave_ - 1
        last_stage = self.num_of_stage_ - 1
        for tail in self.layers_sorted_[Layer.type_enum.TAIL]:
            per_chunk[last_chunk][last_stage] += tail.time_
        return per_chunk

    def get_simulator_recompute_time(self) -> list[float]:
        """Give the extra backward time caused by recomputation, per chunk and stage.

        For each cell two totals are accumulated over BODY layers and considered
        recomputation types: the solved variable value times
        ``backward_time_rec_[rec]``, and the same value times
        ``backward_time_rec_[Recompute.TYPE.NONE]``. The cell holds the first
        total minus the second.
        """
        extra = []
        for chunk in range(self.num_of_interleave_):
            row = []
            for stage in range(self.num_of_stage_):
                with_rec = 0
                without_rec = 0
                for layer in self.layers_sorted_[Layer.type_enum.BODY]:
                    per_rec = self.variables_[layer.name_]
                    for rec in Recompute.TYPE:
                        if not self.recompute_considered_[rec]:
                            continue
                        count = per_rec[rec][chunk][stage].varValue
                        with_rec += count * (layer.backward_time_rec_[rec])
                        without_rec += count * (layer.backward_time_rec_[Recompute.TYPE.NONE])
                row.append(with_rec - without_rec)
            extra.append(row)
        return extra

    def has_some_memory_info(self) -> bool:
        """Tell whether at least one recomputation type is flagged in ``recompute_considered_``."""
        return any(self.recompute_considered_[rec] for rec in Recompute.TYPE)

    ############################################
    #            General Constraint            #
    ############################################
    def add_optional_recompute_constraint(
            self, prob: Any, variables: Any,
            sorted_layers: Dict[Layer.type_enum, List[Layer]]) -> None:
        """Force the variables of every non-considered recomputation type to sum to zero.

        One equality constraint is added to ``prob`` per BODY layer and per
        recomputation type whose ``recompute_considered_`` flag is false.
        """
        unused_types = [rec for rec in Recompute.TYPE
                        if not self.recompute_considered_[rec]]
        for body_layer in sorted_layers[Layer.type_enum.BODY]:
            layer_vars = variables[body_layer.name_]
            for rec in unused_types:
                prob += lpSolver.lpSum(layer_vars[rec]) == 0

    def dump_problem(self, folder: Optional[str] = None) -> None:
        """Write the LP model to ``<folder>/problem_<model>_<memory>_<interleave>_<stage>.lp``.

        When ``folder`` is ``None`` the file name is used as given, without a
        directory prefix.
        """
        name_parts = (
            "problem",
            self.layers_[0].model_name_,
            self.max_memory_,
            self.num_of_interleave_,
            self.num_of_stage_,
        )
        base_name = "_".join(str(part) for part in name_parts)

        logger.info("dump_problem:out folder = %s", folder)
        target = base_name if folder is None else os.path.join(folder, base_name)
        target = target + ".lp"
        logger.info("dump problem file: %s", target)
        self.problem_.writeLP(target)

    def print_results(self) -> None:
        """Log the detailed per-layer solver assignment for the solved problem.

        Emits, in order: the memory header (only when ``has_some_memory_info()``),
        the per-layer timings and per-(chunk, stage) assignments, the debug
        values of the auxiliary solver variables, and the layer-frontier
        variables at info level.
        """
        body_layers = self.layers_sorted_[Layer.type_enum.BODY]

        if self.has_some_memory_info():
            logger.output("For max memory %s", self.max_memory_)
            logger.output("==============")

        for body_layer in body_layers:
            layer_name = body_layer.name_
            logger.output("For layer: %s", layer_name)
            logger.output("=========")
            logger.output("  Forward Prop time: %s", body_layer.forward_time_)
            for rec in Recompute.TYPE:
                if body_layer.recompute_considered_[rec]:
                    logger.output("  Backward Prop %s time: %s",
                                  Recompute.YAML_NAME[rec], body_layer.backward_time_rec_[rec])
            for inter in range(self.num_of_interleave_):
                for stage in range(self.num_of_stage_):
                    parts = []
                    for rec in Recompute.TYPE:
                        if self.recompute_considered_[rec]:
                            value = str(int(self.variables_[layer_name][rec][inter][stage].varValue))
                            parts.append(value if rec is Recompute.TYPE.NONE else f"+ {value} {rec.name}")
                    chunk = f" of chunk {inter}" if self.num_of_interleave_ != 1 else ""
                    logger.output("    Assign %s: %s%s  to stage %d",
                                  layer_name, " ".join(parts), chunk, stage)

        # Per-stage memory overhead variables.
        for s in range(self.num_of_stage_):
            logger.debug(
                "%s[%s] =%s",
                self.MEM_OVERHEAD_NAME,
                s,
                self.variables_[self.MEM_OVERHEAD_NAME][s].varValue,
            )

        # Per-chunk series; each holds num_of_interleave_ - 1 variables.
        for series_name in (self.CHUNKS_SUM, self.PREV_DIFF):
            for v in range(self.num_of_interleave_ - 1):
                logger.debug(
                    "%s[%s] = %s",
                    series_name,
                    v,
                    self.variables_[series_name][v].varValue,
                )

        # Scalar auxiliary variables.
        for scalar_name in (self.NEXT_DIFF, self.TOTAL_SUM,
                            self.MAX_STAGE_TIME, self.MAX_LAST_CHUNK):
            logger.debug("%s = %s", scalar_name, self.variables_[scalar_name].varValue)

        # One frontier entry between each pair of consecutive BODY layers.
        for frontier in range(len(body_layers) - 1):
            for v in range(self.num_of_interleave_):
                for s in range(self.num_of_stage_):
                    logger.info(
                        "%s[%s][%s][%s] = %s",
                        self.LAYER_FRONTIER,
                        frontier,
                        v,
                        s,
                        self.variables_[self.LAYER_FRONTIER][frontier][v][s].varValue,
                    )

    def debug_print_solver_theoretical_memory(self) -> None:
        """Log the solver-implied per-stage theoretical memory (debug aid).

        The activation counts are chosen with the same schedule selection as
        ``add_pipeline_memory_constraint`` (less-memory VPP, dual, sequence
        pipeline or default), so the report matches what the solver was
        constrained with. For every stage the parameter memory, activation
        memory, overhead and constant memory are logged with their total.
        """
        logger.info("%s Solver Theoretical Memory Analysis %s", "=" * 20, "=" * 20)

        if self.vpp_less_memory_:
            if self.seq_pipe:
                activation_nums = self.compute_activation_seq_nums(
                    self.num_of_stage_, self.num_of_interleave_, self.seq_split_num_,
                    self.num_of_micro_batch_, True)
            else:
                activation_nums = self.compute_less_activation_nums(
                    self.num_of_stage_, self.num_of_interleave_)
        elif self.dual_:
            # Mirror the dual branch of add_pipeline_memory_constraint.
            activation_nums = self.compute_activation_nums_dual(
                self.num_of_stage_, self.num_of_interleave_, self.num_of_micro_batch_)
        elif self.seq_pipe:
            activation_nums = self.compute_activation_seq_nums(
                self.num_of_stage_, self.num_of_interleave_, self.seq_split_num_,
                self.num_of_micro_batch_, False)
        else:
            activation_nums = self.compute_activation_nums(
                self.num_of_stage_, self.num_of_interleave_, self.num_of_micro_batch_)

        # compute theoretical value for each stage
        for s in range(self.num_of_stage_):
            param_mem = self.stage_param_memory(
                self.variables_,
                self.layers_sorted_,
                s,
                self.num_of_stage_,
                self.num_of_interleave_
            ).value()

            act_mem = self.stage_active_memory(
                self.variables_,
                self.layers_sorted_,
                s,
                self.num_of_interleave_,
                activation_nums
            ).value()

            # NOTE: the solved MEM_OVERHEAD variable is not included in this
            # report; the overhead is always shown as 0.
            overhead = 0
            total = param_mem + act_mem + overhead + self.constant_memory_

            logger.info("Stage %d Solver Memory Analysis:", s)
            logger.info("Parameter Memory:     %.2f", param_mem)
            logger.info("Activation Memory:    %.2f", act_mem)
            logger.info("Memory Overhead:      %.2f", overhead)
            logger.info("Constant Memory:      %.2f", self.constant_memory_)
            logger.info("Total Theoretical Memory: %.2f", total)


    def solve(self, time_limit: int = 90, dump_folder: Optional[str] = None) -> None:
        """Solve the ILP problem using PuLP's bundled CBC backend.

        The model is first written out through ``dump_problem``, then solved;
        afterwards the assignment, the theoretical memory analysis and the
        per-layer result table are logged.

        Args:
            time_limit: Upper bound on solver wall-clock time in seconds.
            dump_folder: Directory passed to ``dump_problem`` for the LP file;
                with ``None`` the file name is used without a directory prefix.
        """
        logger.info("solve:out folder = %s", dump_folder)
        self.dump_problem(dump_folder)
        solver = lpSolver.getSolver("PULP_CBC_CMD", timeLimit=time_limit)
        self.problem_.solve(solver)

        self.print_results()

        self.debug_print_solver_theoretical_memory()

        for name, result in self.result().items():
            logger.output("%s %s %s", name, result, "\n")

    def result(self) -> dict[str, list[list[str]]]:
        """Return the schedule distribution of every BODY layer as a dict.

        The dict maps each layer name to one list per chunk. A chunk list holds,
        stage by stage and for each considered recomputation type, the solved
        variable value rendered as ``str(value) + " + "``.
        """
        distribution = {}
        for body_layer in self.layers_sorted_[Layer.type_enum.BODY]:
            name = body_layer.name_
            per_chunk = []
            for chunk in range(self.num_of_interleave_):
                cells = [
                    str(self.variables_.get(name)[rec][chunk][stage].varValue) + " + "
                    for stage in range(self.num_of_stage_)
                    for rec in Recompute.TYPE
                    if self.recompute_considered_[rec]
                ]
                per_chunk.append(cells)
            distribution[name] = per_chunk
        return distribution

    def _create_problem_(self, description: str) -> lpSolver.LpProblem:
        """Create the minimisation problem and register every constraint on it.

        The objective is the continuous, non-negative variable
        ``pipeline_total_time``. The memory constraints are added only when
        ``has_some_memory_info()`` is true.

        Args:
            description: Name given to the ``LpProblem``.

        Returns:
            The populated ``LpProblem``.
        """
        prob = lpSolver.LpProblem(description, lpSolver.LpMinimize)
        layers_sorted = self.layers_sorted_

        # Objective: the total pipeline time (variable to minimise).
        pipeline_total_time = lpSolver.LpVariable("pipeline_total_time", 0,
                                                  None, lpSolver.LpContinuous)
        prob += pipeline_total_time

        self.add_total_nb_layer_constraint(prob, self.variables_, layers_sorted)
        # NOTE: the ``*_dual`` variants of the next two constraints are not enabled.
        self.add_stage_nb_layer_constraint(prob, self.variables_, layers_sorted)
        self.add_multimodal_sequence_constraint(prob, self.variables_, layers_sorted)
        self.add_multimodal_recompute_constraint(prob, self.variables_, layers_sorted)
        self.add_performance_constraint(prob, layers_sorted, pipeline_total_time)

        constraint = PipelineMemoryConstraint(
            prob=prob,
            variables=self.variables_,
            layers_sorted=layers_sorted,
            num_of_stage=self.num_of_stage_,
            num_of_interleave=self.num_of_interleave_,
            micro_batch=self.num_of_micro_batch_,
            memory_limit=self.max_memory_,
        )
        if self.has_some_memory_info():
            self.add_pipeline_memory_constraint(constraint)
        return prob

    def _create_variables_to_solve_(
            self,
            num_of_stage: int,
            num_of_interleave: int,
            layers: dict[Layer.type_enum, list[Layer]],
    ):
        """Create every LP variable of the problem.

        Args:
            num_of_stage: Number of pipeline stages used to index the per-layer
                variables.
            num_of_interleave: Number of chunks used to index the per-layer
                variables.
            layers: Layers grouped by type; one variable family is created per
                BODY layer.

        Returns:
            A dict mapping each auxiliary variable name, and each BODY layer
            name, to its variable or nested container of variables. Per-layer
            variables are addressed as ``[rec][interleave][stage]``.
        """
        def scalar(name: str) -> Any:
            """Create one non-negative continuous variable."""
            return lpSolver.LpVariable(name, 0, None, lpSolver.LpContinuous)

        def chunk_series(name: str) -> list:
            """Create ``num_of_interleave_ - 1`` non-negative continuous variables."""
            by_chunk = lpSolver.LpVariable.dicts(
                name=name,
                indices=(range(0, self.num_of_interleave_ - 1)),
                lowBound=0,
                upBound=None,
                cat=lpSolver.LpContinuous
            )
            return list(by_chunk.values())

        variables = {}

        variables[self.TOTAL_SUM] = scalar(self.TOTAL_SUM)
        variables[self.CHUNKS_SUM] = chunk_series(self.CHUNKS_SUM)
        variables[self.PREV_DIFF] = chunk_series(self.PREV_DIFF)

        # NOTE(review): the stage index below was restored from how
        # LAYER_FRONTIER is read elsewhere ([frontier][interleave][stage]).
        layer_frontier_dict = lpSolver.LpVariable.dicts(
            name=self.LAYER_FRONTIER,
            indices=(
                range(1, len(self.layers_sorted_[Layer.type_enum.BODY])),
                range(0, self.num_of_interleave_),
                range(0, self.num_of_stage_),
            ),
            lowBound=0,
            upBound=1,
            cat=lpSolver.LpBinary
        )
        variables[self.LAYER_FRONTIER] = list(layer_frontier_dict.values())

        # NOTE(review): the index tuple was truncated after these two entries;
        # confirm against the constraints that read REC_FRONTIER that no
        # further index is expected.
        rec_frontier_dict = lpSolver.LpVariable.dicts(
            name=self.REC_FRONTIER,
            indices=(
                range(0, self.num_of_interleave_),
                range(0, self.num_of_stage_),
            ),
            lowBound=0,
            upBound=1,
            cat=lpSolver.LpBinary
        )
        variables[self.REC_FRONTIER] = list(rec_frontier_dict.values())

        variables[self.NEXT_DIFF] = scalar(self.NEXT_DIFF)
        variables[self.MAX_STAGE_TIME] = scalar(self.MAX_STAGE_TIME)
        variables[self.MAX_LAST_CHUNK] = scalar(self.MAX_LAST_CHUNK)

        overhead_dict = lpSolver.LpVariable.dicts(
            name=self.MEM_OVERHEAD_NAME,
            indices=(range(0, self.num_of_stage_)),
            lowBound=0,
            upBound=None,
            cat=lpSolver.LpInteger,
        )
        variables[self.MEM_OVERHEAD_NAME] = list(overhead_dict.values())

        for layer in layers[Layer.type_enum.BODY]:
            # NOTE(review): the stage index was restored from the
            # ``[rec][interleave][stage]`` accesses made throughout the class.
            variable_dict = lpSolver.LpVariable.dicts(
                name=layer.name_,
                indices=(
                    range(0, len(Recompute.TYPE)),
                    range(0, num_of_interleave),
                    range(0, num_of_stage),
                ),
                lowBound=0,
                upBound=None,
                cat=lpSolver.LpInteger,
            )
            # Outer two levels become lists; the innermost level stays a dict
            # keyed by stage index.
            variables[layer.name_] = [
                list(per_rec.values()) for per_rec in variable_dict.values()
            ]

        return variables

    ############################################
    #             Time Constraint              #
    ############################################
    def _max_stage_bound_i_fp(self, layers_sorted, stage_id, inter_f):
        """Return the forward-time LP expression of chunk ``inter_f`` on stage ``stage_id``.

        Sums, over BODY layers and considered recomputation types, the layer
        variable times ``layer.forward_time_``.
        """
        considered = [rec for rec in Recompute.TYPE if self.recompute_considered_[rec]]
        total = lpSolver.LpAffineExpression()
        for body_layer in layers_sorted[Layer.type_enum.BODY]:
            layer_vars = self.variables_[body_layer.name_]
            for rec in considered:
                total += (layer_vars[rec][inter_f][stage_id] *
                          body_layer.forward_time_)
        return total

    def _max_stage_bound_i_bp(self, layers_sorted, stage_id, inter_b):
        """Return the backward-time LP expression of chunk ``inter_b`` on stage ``stage_id``.

        Sums, over BODY layers and considered recomputation types, the layer
        variable times ``layer.backward_time_rec_[rec]``.
        """
        considered = [rec for rec in Recompute.TYPE if self.recompute_considered_[rec]]
        total = lpSolver.LpAffineExpression()
        for body_layer in layers_sorted[Layer.type_enum.BODY]:
            layer_vars = self.variables_[body_layer.name_]
            for rec in considered:
                total += (layer_vars[rec][inter_b][stage_id] *
                          body_layer.backward_time_rec_[rec])
        return total

    def _max_stage_bound_head_tail(self, layers_sorted, stage_id, inter_f,
                                   inter_b):
        """Return the HEAD/TAIL time contribution for a stage and a chunk pair.

        On stage 0, HEAD times are added once when ``inter_f`` is chunk 0 and
        twice when ``inter_b`` is chunk 0. On the last stage, TAIL times are
        added once when ``inter_f`` is the last chunk and twice when ``inter_b``
        is the last chunk.
        """
        total = lpSolver.LpAffineExpression()
        on_first_stage = stage_id == 0
        on_last_stage = stage_id == self.num_of_stage_ - 1

        if on_first_stage and inter_f == 0:
            for head in layers_sorted[Layer.type_enum.HEAD]:
                total += head.time_
        if on_first_stage and inter_b == 0:
            for head in layers_sorted[Layer.type_enum.HEAD]:
                total += head.time_ * 2
        if on_last_stage and inter_f == self.num_of_interleave_ - 1:
            for tail in layers_sorted[Layer.type_enum.TAIL]:
                total += tail.time_
        if on_last_stage and inter_b == self.num_of_interleave_ - 1:
            for tail in layers_sorted[Layer.type_enum.TAIL]:
                total += tail.time_ * 2
        return total

    def _total_sum(self, layers_sorted):
        """Return the total BODY-layer time over all chunks and stages.

        Each layer variable is weighted by
        ``forward_time_ + backward_time_rec_[rec]``; only considered
        recomputation types contribute.
        """
        total = lpSolver.LpAffineExpression()
        for body_layer in layers_sorted[Layer.type_enum.BODY]:
            for rec in Recompute.TYPE:
                if not self.recompute_considered_[rec]:
                    continue
                for chunk in range(self.num_of_interleave_):
                    for stage_id in range(self.num_of_stage_):
                        total += self.variables_[body_layer.name_][rec][chunk][stage_id] * (
                            body_layer.forward_time_ +
                            body_layer.backward_time_rec_[rec])
        return total

    def body_layer_time(self, prop: "SappSolver.PROP_PHASE", layer: Layer,
                        inter: int, stage: int) -> Any:
        """Return the time LP expression of ``layer`` at ``(inter, stage)`` for one phase.

        For ``PROP_PHASE.FW`` every considered variable is weighted by
        ``layer.forward_time_``; for any other phase it is weighted by
        ``layer.backward_time_rec_[rec]``.
        """
        is_forward = prop == self.PROP_PHASE.FW
        terms = []
        for rec in Recompute.TYPE:
            if not self.recompute_considered_[rec]:
                continue
            unit_time = layer.forward_time_ if is_forward else layer.backward_time_rec_[rec]
            terms.append(self.variables_[layer.name_][rec][inter][stage] * unit_time)
        return lpSolver.lpSum(terms)

    def micro_batch_time(self, prop: "SappSolver.PROP_PHASE",
                         layers_sorted: Dict[Layer.type_enum, List[Layer]],
                         inter: int, stage: int) -> Any:
        """Return the total micro-batch time LP expression at ``(inter, stage)``.

        The expression is the sum of ``body_layer_time`` over every BODY layer.
        HEAD times are added on the first chunk of stage 0 and TAIL times on the
        last chunk of the last stage; both are counted once for
        ``PROP_PHASE.FW`` and twice for any other phase.

        Args:
            prop: Propagation phase; ``PROP_PHASE.FW`` selects forward timing.
            layers_sorted: Layers grouped by type (HEAD, BODY, TAIL).
            inter: Chunk (interleave) index.
            stage: Pipeline stage index.

        Returns:
            The accumulated ``LpAffineExpression``.
        """
        edge_weight = 1 if prop == self.PROP_PHASE.FW else 2

        bound = lpSolver.LpAffineExpression()
        for layer in layers_sorted[Layer.type_enum.BODY]:
            # Accumulate: plain assignment here would keep only the last layer.
            bound += self.body_layer_time(prop, layer, inter, stage)

        if stage == 0 and inter == 0:
            for head in layers_sorted[Layer.type_enum.HEAD]:
                bound += head.time_ * edge_weight
        if stage == self.num_of_stage_ - 1 and inter == self.num_of_interleave_ - 1:
            for tail in layers_sorted[Layer.type_enum.TAIL]:
                bound += tail.time_ * edge_weight
        return bound

    def _chunks_sum(self, layers_sorted, v):
        """Return the forward plus backward micro-batch time of chunk ``v``, averaged over stages.

        Args:
            layers_sorted: Layers grouped by ``Layer.type_enum``.
            v: Interleave (chunk) index.

        Returns:
            LP expression: sum over all stages of the FW and BW micro-batch
            times, divided by the number of stages.
        """
        phases = (self.PROP_PHASE.FW, self.PROP_PHASE.BW)
        total = lpSolver.LpAffineExpression()
        for stage in range(self.num_of_stage_):
            for phase in phases:
                total += self.micro_batch_time(phase, layers_sorted, v, stage)
        # Normalise by the stage count.
        return total / self.num_of_stage_

    def _prev_diff_sum(self, layers_sorted, prob, v):
        """Model the bubble time of the first diagonal (forward pass) for chunk ``v``.

        Two families of continuous, non-negative LP variables are created:

        * ``max_prev_stages[s]`` is constrained to be at least the forward
          micro-batch time of every stage ``<= s`` (a running maximum).
        * ``diff_with_prev_stages[s]`` is constrained to be at least the gap
          between the running maximum of the previous stages and the forward
          time of stage ``s``.

        Args:
            layers_sorted: Layers grouped by ``Layer.type_enum``.
            prob: LP problem the constraints are added to (mutated in place).
            v: Interleave (chunk) index; also used to make variable names unique.

        Returns:
            LP expression ``num_of_micro_batch_ * sum(diff_with_prev_stages[1:])``.
        """
        def forward_time(stage):
            return self.micro_batch_time(self.PROP_PHASE.FW, layers_sorted, v, stage)

        max_prev_stages = lpSolver.LpVariable.dicts(
            name="max_prev_stages_" + str(v),
            indices=(range(self.num_of_stage_)),
            lowBound=0,
            upBound=None,
            cat=lpSolver.LpContinuous,
        )

        diff_with_prev_stages = lpSolver.LpVariable.dicts(
            name="diff_with_prev_stages_" + str(v),
            indices=(range(self.num_of_stage_)),
            lowBound=0,
            upBound=None,
            cat=lpSolver.LpContinuous,
        )

        # Sum every head layer, mirroring how micro_batch_time adds them all;
        # keeping only the last head would leave the others in the bound.
        # NOTE(review): micro_batch_time only adds head time when inter == 0,
        # but it is subtracted here for any ``v`` -- confirm callers pass v == 0.
        head_time = sum(head.time_ for head in layers_sorted[Layer.type_enum.HEAD])

        prob += max_prev_stages[0] >= forward_time(0) - head_time

        for stage in range(1, self.num_of_stage_):
            prob += max_prev_stages[stage] >= max_prev_stages[stage - 1]
            prob += max_prev_stages[stage] >= forward_time(stage)

            prob += diff_with_prev_stages[stage] >= (
                max_prev_stages[stage - 1] - forward_time(stage))

        return self.num_of_micro_batch_ * lpSolver.lpSum(
            diff_with_prev_stages[s] for s in range(1, self.num_of_stage_))

    def _next_diff_sum(self, layers_sorted, prob):
        """Model the bubble time of the last diagonal (forward pass, last chunk).

        Walking the stages from last to first, ``max_next_stages[s]`` is
        constrained to be at least the forward micro-batch time of every stage
        ``>= s``, and ``diff_with_next_stages[s]`` to be at least the gap
        between the running maximum of the following stages and the forward
        time of stage ``s``.

        Args:
            layers_sorted: Layers grouped by ``Layer.type_enum``.
            prob: LP problem the constraints are added to (mutated in place).

        Returns:
            LP expression ``num_of_micro_batch_ * sum(diff_with_next_stages[:-1])``.
        """
        last_chunk = self.num_of_interleave_ - 1
        final_stage = self.num_of_stage_ - 1

        def forward_time(stage):
            return self.micro_batch_time(
                self.PROP_PHASE.FW, layers_sorted, last_chunk, stage)

        max_next_stages = lpSolver.LpVariable.dicts(
            name="max_next_stages",
            indices=(range(self.num_of_stage_)),
            lowBound=0,
            upBound=None,
            cat=lpSolver.LpContinuous,
        )

        diff_with_next_stages = lpSolver.LpVariable.dicts(
            name="diff_with_next_stages",
            indices=(range(self.num_of_stage_)),
            lowBound=0,
            upBound=None,
            cat=lpSolver.LpContinuous,
        )

        # Seed the running maximum with the last stage.
        prob += max_next_stages[final_stage] >= forward_time(final_stage)

        # Remaining stages, from the second-to-last down to stage 0.
        for stage in range(final_stage - 1, -1, -1):
            following = max_next_stages[stage + 1]
            prob += max_next_stages[stage] >= following
            prob += max_next_stages[stage] >= forward_time(stage)

            prob += diff_with_next_stages[stage] >= (
                following - forward_time(stage))

        bubble = lpSolver.LpAffineExpression()
        bubble += self.num_of_micro_batch_ * lpSolver.lpSum(
            diff_with_next_stages[s] for s in range(final_stage))
        return bubble
hyper_parallel/auto_parallel/sapp_ppb/simulator/causal_error.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Exception types raised by the pipeline simulator when dependency loops are detected."""
from __future__ import annotations

import matplotlib.pyplot as plt

from sapp_ppb.simulator.plot_manager import PlotMgr
from sapp_ppb.simulator.sim_block import BlockSim, MicroBlockSim
from sapp_ppb.utils.logger import logger


class CausalError(Exception):
    """Raised when the block pipeline (without comm) contains a dependency loop."""

    def __init__(self, msg: str, blocks: list[list[MicroBlockSim]], loop: list[BlockSim]) -> None:
        """Create the error, draw the offending loop and log a diagnostic message.

        Args:
            msg: Human-readable description of the loop.
            blocks: Full 2-D grid of simulator blocks, ``[pp_rank][block_idx]``.
            loop: Sequence of blocks participating in the dependency loop.
        """
        # Forward the message to ``Exception`` so ``args`` and ``repr()`` carry
        # it; with no argument they would be empty.
        super().__init__(msg)
        self.msg = msg
        self.canvas = PlotMgr(num_plots=1, figsize=(12, 6))
        # Positional flags: ax_index=0, comm=False, connect=False, equal_wide=True.
        self.canvas.draw_loop(blocks, loop, 0, False, False, True)
        self.canvas.ax[0].set_title("Block pipeline dependency")
        # ``draw_loop`` stores the textual loop trace on the canvas.
        logger.error("%s", self.canvas.msg)

    def __str__(self) -> str:
        """Show the diagnostic plot and return the error message.

        NOTE(review): ``plt.show()`` runs every time the exception is converted
        to a string (including traceback formatting) -- confirm this is intended.
        """
        plt.show()
        return f"{self.msg}"


class CausalCommError(Exception):
    """Raised when the block pipeline with communication contains a dependency loop."""

    def __init__(self, msg: str, blocks: list[list[MicroBlockSim]], loop: list[BlockSim]) -> None:
        """Create the error, draw the offending loop and log a diagnostic message.

        Args:
            msg: Human-readable description of the loop.
            blocks: Full 2-D grid of simulator blocks, ``[pp_rank][block_idx]``.
            loop: Sequence of blocks (compute + comm) participating in the dependency loop.
        """
        # Forward the message to ``Exception`` so ``args`` and ``repr()`` carry
        # it; with no argument they would be empty.
        super().__init__(msg)
        self.msg = msg
        self.canvas = PlotMgr(num_plots=1, figsize=(12, 6))
        self.canvas.draw_comm_loop(blocks, loop, 0)
        self.canvas.ax[0].set_title("Block comm pipeline dependency")
        # ``draw_comm_loop`` stores the textual loop trace on the canvas.
        logger.error("%s", self.canvas.msg)

    def __str__(self) -> str:
        """Show the diagnostic plot and return the error message.

        NOTE(review): ``plt.show()`` runs every time the exception is converted
        to a string (including traceback formatting) -- confirm this is intended.
        """
        plt.show()
        return f"{self.msg}"
hyper_parallel/auto_parallel/sapp_ppb/simulator/pipeline_builder.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Build pipeline scheduler chains (1F1B, VPP, VPP-less-memory)."""
from __future__ import annotations

from typing import Callable, List

from sapp_ppb.simulator.sim_block import BlockSim, HeadBlockSim, MicroBlockSim

BuilderFn = Callable[..., List[BlockSim]]


class PipelineBuilder:
    r"""Build per-rank pipeline schedule chains (1F1B, VPP, VPP less-memory)."""

    @staticmethod
    def _inter_merge(a: list[MicroBlockSim], b: list[MicroBlockSim], delta: int = 0) -> list[MicroBlockSim]:
        r"""Merge the forward chain ``a`` and the backward chain ``b`` for 1F1B.

        The first ``delta`` blocks of ``a`` (or the first ``-delta`` blocks of
        ``b`` when ``delta`` is negative) are emitted unchanged.  The remaining
        blocks are then alternated ``a, b, a, b, ...`` and marked ``'stable'``
        until one chain runs out; the last alternated block is re-marked
        ``'cooldown'``.  Whatever is left of either chain is appended as is.

        The input lists themselves are not modified (only the ``phase``
        attribute of the blocks is written).

        Args:
            a: Forward blocks in issue order.
            b: Backward blocks in issue order.
            delta: Number of leading blocks emitted before alternation starts;
                taken from ``a`` if non-negative, from ``b`` otherwise.

        Returns:
            A new list holding every block of ``a`` and ``b`` in schedule order.
        """
        if delta >= 0:
            merged = list(a[:delta])
            fwd, bwd = a[delta:], b
        else:
            merged = list(b[:-delta])
            fwd, bwd = a, b[-delta:]

        # Index cursors instead of ``pop(0)``: no O(n) shifting and the
        # caller's lists stay intact.
        fi = bi = 0
        while fi < len(fwd):
            fwd[fi].phase = 'stable'
            merged.append(fwd[fi])
            fi += 1
            if bi >= len(bwd):
                break
            bwd[bi].phase = 'stable'
            merged.append(bwd[bi])
            bi += 1

        if fi:
            # At least one block was alternated; the last one closes the stable phase.
            merged[-1].phase = 'cooldown'

        if fi < len(fwd):
            merged.extend(fwd[fi:])
        elif bi < len(bwd):
            merged.extend(bwd[bi:])
        return merged

    @staticmethod
    def _build_chain(line: list[MicroBlockSim], p: int) -> list[BlockSim]:
        r"""Link ``line`` into a doubly linked chain that starts at a new head block.

        Args:
            line: Blocks of one pipeline rank in schedule order.
            p: Pipeline rank the chain belongs to.

        Returns:
            ``line`` itself, with ``left``/``right`` set on each block.
        """
        # pylint: disable=E1120
        head = HeadBlockSim(p)
        previous = head
        for item in line:
            previous.right = item
            item.left = previous
            previous = item
        # On rank 0 the first block additionally gets the head as its ``pre``.
        # Skipped for an empty line, where there is no first block to update.
        if p == 0 and line:
            head.right.pre = head
        return line

    @staticmethod
    def _build_vpp_lines(pp: int, micro_num: int, vp: int, p: int,
                         forward_time: List[float], backward_time: List[float],
                         block_mem: List[float],
                         block_mem_par: List[float]) -> tuple:
        """Create the forward and backward block sequences shared by both VPP schedules.

        Micro-batches are processed in groups of ``pp``; the ``micro_num % pp``
        remainder micro-batches are emitted together with the first group.
        Forward blocks walk the chunks ``0 .. vp-1`` while backward blocks walk
        them in reverse.  No block is produced when ``micro_num < pp`` because
        the group loop then has zero iterations.

        Returns:
            ``(for_line, back_line)`` lists of :class:`MicroBlockSim`.
        """
        def forward_block(micro: int, chunk: int) -> MicroBlockSim:
            return MicroBlockSim(p, 'f', micro, chunk, forward_time[chunk],
                                 mem=block_mem[chunk], mem_par=block_mem_par[chunk],
                                 phase='warmup')

        def backward_block(micro: int, chunk: int) -> MicroBlockSim:
            return MicroBlockSim(p, 'b', micro, chunk, backward_time[chunk],
                                 mem=block_mem[chunk], mem_par=block_mem_par[chunk],
                                 phase='cooldown')

        for_line = []
        back_line = []
        r = micro_num % pp
        for inter in range(micro_num // pp):
            for i in range(vp):
                bi = vp - 1 - i
                if inter == 0:
                    for_line.extend([forward_block(m, i) for m in range(r)])
                    back_line.extend([backward_block(m, bi) for m in range(r)])
                first = r + inter * pp
                for_line.extend([forward_block(first + m, i) for m in range(pp)])
                back_line.extend([backward_block(first + m, bi) for m in range(pp)])
        return for_line, back_line

    @staticmethod
    # pylint: disable=W0613
    def build_1f1b(pp: int, micro_num: int, vp: int, p: int,
                   forward_time: List[float], backward_time: List[float],
                   block_mem: List[float], block_mem_par: List[float]) -> List[BlockSim]:
        """Build a 1F1B schedule chain for one pipeline rank.

        Args:
            pp: Number of pipeline stages.
            micro_num: Number of micro-batches.
            vp: Unused; present so that all builders share one signature.
            p: Pipeline rank to build the chain for.
            forward_time: Per-chunk forward times; only element 0 is used.
            backward_time: Per-chunk backward times; only element 0 is used.
            block_mem: Per-chunk ``mem`` values; only element 0 is used.
            block_mem_par: Per-chunk ``mem_par`` values; only element 0 is used.

        Returns:
            The ordered chain of :class:`BlockSim` nodes.
        """
        fwd_time = forward_time[0]
        bwd_time = backward_time[0]
        mem = block_mem[0]
        mem_par = block_mem_par[0]
        for_line = [MicroBlockSim(p, 'f', i, 0, fwd_time, mem=mem, mem_par=mem_par, phase='warmup')
                    for i in range(micro_num)]
        back_line = [MicroBlockSim(p, 'b', i, 0, bwd_time, mem=mem, mem_par=mem_par, phase='cooldown')
                     for i in range(micro_num)]
        # ``pp - p - 1`` forward blocks are issued before alternation starts.
        line = PipelineBuilder._inter_merge(for_line, back_line, pp - p - 1)
        return PipelineBuilder._build_chain(line, p)

    @staticmethod
    def build_virtualpipeline(pp: int, micro_num: int, vp: int, p: int,
                              forward_time: List[float], backward_time: List[float],
                              block_mem: List[float],
                              block_mem_par: List[float]) -> List[BlockSim]:
        """Build a virtual-pipeline (VPP) 1F1B chain for one pipeline rank.

        Forward/backward alternation starts after
        ``(vp + 1) * pp - 2 * p - 2 + r * (vp - 1)`` leading blocks, where
        ``r = micro_num % pp``.
        """
        for_line, back_line = PipelineBuilder._build_vpp_lines(
            pp, micro_num, vp, p, forward_time, backward_time, block_mem, block_mem_par)
        r = micro_num % pp
        line = PipelineBuilder._inter_merge(for_line, back_line, (vp + 1) * pp - 2 * p - 2 + r * (vp - 1))
        return PipelineBuilder._build_chain(line, p)

    @staticmethod
    def build_virtualpipeline2(pp: int, micro_num: int, vp: int, p: int,
                               forward_time: List[float], backward_time: List[float],
                               block_mem: List[float],
                               block_mem_par: List[float]) -> List[BlockSim]:
        """Build a VPP 1F1B chain using the less-memory scheduler variant.

        Identical to :meth:`build_virtualpipeline` except that alternation
        starts after ``vp * pp - p - 1`` leading blocks.
        """
        for_line, back_line = PipelineBuilder._build_vpp_lines(
            pp, micro_num, vp, p, forward_time, backward_time, block_mem, block_mem_par)
        line = PipelineBuilder._inter_merge(for_line, back_line, vp * pp - p - 1)
        return PipelineBuilder._build_chain(line, p)

    @staticmethod
    def get_builder(method: str = '1f1b') -> BuilderFn:
        """Return the schedule-builder callable for a given schedule ``method``.

        Args:
            method: One of ``'1f1b'``, ``'vpp'``, ``'vpp2'``.

        Raises:
            ValueError: If ``method`` is not one of the supported values.
        """
        if method == '1f1b':
            return PipelineBuilder.build_1f1b
        if method == 'vpp':
            return PipelineBuilder.build_virtualpipeline
        if method == 'vpp2':
            return PipelineBuilder.build_virtualpipeline2
        raise ValueError(f"`method` only support ['1f1b', 'vpp', 'vpp2'], but got {method}")
hyper_parallel/auto_parallel/sapp_ppb/simulator/plot_manager.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Matplotlib canvas manager that renders the pipeline simulator timeline & memory plots."""
from __future__ import annotations

from collections.abc import Iterable
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.transforms import ScaledTranslation

from sapp_ppb.simulator.sim_block import BlockSim, MicroBlockSim


class PlotMgr:
    """Holds a matplotlib figure and its axes, and provides the simulator draw helpers."""

    # pylint: disable=W0613
    def __init__(self, *args: object, num_plots: int = 2, ax_type: object = 'block',
                 subplot_args: Optional[List[int]] = None,
                 sub_fig: Optional[plt.Figure] = None, **kwargs: object) -> None:
        """Create ``num_plots`` sub-axes on a new or reused matplotlib figure.

        Args:
            num_plots: Number of axes to create.
            ax_type: Accepted for interface compatibility; not used.
            subplot_args: Optional explicit ``add_subplot`` specifiers; must have length
                ``>= num_plots``.
            sub_fig: Reuse this figure if given, otherwise create one of ``figsize``.
            **kwargs: Only ``figsize`` is read (default ``(12, 8)``).

        Raises:
            ValueError: If ``subplot_args`` is given but is not a sized iterable
                with at least ``num_plots`` entries.
        """
        if sub_fig:
            self.fig = sub_fig
        else:
            self.fig = plt.figure(figsize=kwargs.get('figsize', (12, 8)))
        self.fig.subplots_adjust(wspace=0, hspace=0.4)
        self.ax: List[plt.Axes] = []
        # Filled by draw_loop / draw_comm_loop; defined up front so reading it
        # before any loop was drawn does not raise AttributeError.
        self.msg: str = ''
        for i in range(num_plots):
            if subplot_args is None:
                # Three-digit spec: ``num_plots`` rows, one column, 1-based index.
                spec = num_plots * 100 + 10 + i + 1
            elif isinstance(subplot_args, Iterable) and len(subplot_args) >= num_plots:
                spec = subplot_args[i]
            else:
                raise ValueError(f"Unsupported subplot_args format: {subplot_args}")
            self.ax.append(self.fig.add_subplot(spec))

    @staticmethod
    def _loop_message(loop: List[BlockSim]) -> str:
        """Return the textual trace of ``loop`` built from each block's ``color_label``."""
        msg = 'dependency loop: '
        for block in loop[:-1]:
            msg = f'{msg} {block.color_label} -> '
        return f'{msg} {loop[-1].color_label}'

    def _set_block_ax(self, ax: plt.Axes, pp: int) -> None:
        """Configure one block-timeline axis (title, y ticks, y limits)."""
        ax.set_title("Pipeline Flow Timeline")
        ax.set_yticks(range(pp), [f"stage {p}" for p in range(pp)])
        for tick in ax.get_yticklabels():
            tick.set_verticalalignment('top')
            # Shift each label by a figure-scaled offset that depends on ``pp``.
            tick.set_transform(
                tick.get_transform() + ScaledTranslation(0, 0.05 - 1 / pp, self.fig.dpi_scale_trans))
            tick.set_fontsize(12)
        ax.set_ylim(0, pp)
        ax.invert_yaxis()

    def _get_block_indices(self, blocks: List[List[MicroBlockSim]],
                           mode: str = 'compact',
                           equal_wide: bool = False) -> List[np.ndarray]:
        """Return per-stage cumulative x-coordinates suitable for drawing ``blocks``.

        Each block contributes a width, and the returned array for a stage is
        the cumulative sum of those widths with a leading ``0``:

        * ``'compact'``: compute blocks (``type == 'c'``) contribute their
          ``time`` (or ``1`` if ``equal_wide``); other blocks contribute ``0``.
        * ``'joint'``: compute blocks as above; other blocks contribute their ``time``.
        * ``'timeline'``: every block contributes ``1``.

        Raises:
            ValueError: On an unknown ``mode``, or in ``'timeline'`` mode when
                the last block of the last stage is not finished.
        """
        if mode not in ['compact', 'joint', 'timeline']:
            raise ValueError(f"Get unsupported draw mode: {mode}")
        if mode == 'timeline' and not blocks[-1][-1].finish:
            raise ValueError("Block building should be finished before drawing timeline")
        block_index: List[np.ndarray] = []
        for block_p in blocks:
            widths: List[float] = [0]
            for block in block_p:
                if mode == 'timeline':
                    widths.append(1)
                elif block.type == 'c':
                    widths.append(1 if equal_wide else block.time)
                elif mode == 'joint':
                    widths.append(block.time)
                else:
                    widths.append(0)
            block_index.append(np.cumsum(widths))
        return block_index

    def draw_block(self, block_index: List[np.ndarray], blocks: List[List[MicroBlockSim]],
                   ax_index: int = 0, equal_wide: bool = False,
                   width: float = 1, phase: bool = False) -> "PlotMgr":
        """Draw all compute blocks onto ``self.ax[ax_index]`` and return ``self``."""
        ax = self.ax[ax_index]
        for p, block_p in enumerate(blocks):
            for b, block in enumerate(block_p):
                if block.type == 'c':
                    block.draw(ax, index=block_index[p][b],
                               equal_wide=equal_wide, width=width, phase=phase)
        return self

    def draw_comm(self, block_index: List[np.ndarray], blocks: List[List[MicroBlockSim]],
                  ax_index: int = 0, equal_wide: bool = False,
                  mode: str = 'compact') -> "PlotMgr":
        """Draw send/receive comm blocks onto ``self.ax[ax_index]`` and return ``self``.

        In ``'compact'`` mode the comm blocks are reached through each compute
        block's ``send_block`` / ``rec_block``; in ``'joint'`` and
        ``'timeline'`` modes the ``'s'`` / ``'r'`` blocks of the grid are drawn.
        """
        ax = self.ax[ax_index]
        for p, block_p in enumerate(blocks):
            for b, block in enumerate(block_p):
                if block.type == 'c' and mode == 'compact':
                    if block.send_block:
                        block.send_block.draw(ax, index=block_index[p][b],
                                              equal_wide=equal_wide)
                    if block.rec_block:
                        block.rec_block.draw(ax, index=block_index[p][b],
                                             equal_wide=equal_wide)
                elif block.type in ['s', 'r'] and mode in ['joint', 'timeline']:
                    block.draw(ax, index=block_index[p][b],
                               equal_wide=equal_wide, mode=mode)
        return self

    def draw_connect(self, block_index: List[np.ndarray], blocks: List[List[MicroBlockSim]],
                     ax_index: int = 0, equal_wide: bool = False,
                     mode: str = 'compact') -> "PlotMgr":
        """Draw the arrows that connect each send block to its matching receive block."""
        ax = self.ax[ax_index]
        for p, block_p in enumerate(blocks):
            for b, block in enumerate(block_p):
                if block.type == 'c' and mode == 'compact' and block.send_block:
                    # The arrow target is the compute block hosting the dual comm block.
                    dual_p = block.send_block.dual.stage
                    dual_ind = blocks[dual_p].index(block.send_block.dual.host)
                    block.send_block.draw_comm(
                        ax, index_from=block_index[p][b],
                        index_to=block_index[dual_p][dual_ind],
                        equal_wide=equal_wide, mode=mode)
                elif block.type == 's' and mode in ['joint', 'timeline']:
                    dual_p = block.dual.stage
                    dual_ind = blocks[dual_p].index(block.dual)
                    block.draw_comm(
                        ax, index_from=block_index[p][b],
                        index_to=block_index[dual_p][dual_ind],
                        equal_wide=equal_wide, mode=mode)
        return self

    def draw(self, blocks: List[List[MicroBlockSim]], ax_index: int = 0,
             comm: bool = False, connect: bool = False,
             equal_wide: bool = False, mode: str = 'compact',
             phase: bool = False) -> "PlotMgr":
        """Draw the full pipeline timeline: blocks, comm layer and connect arrows."""
        pp = len(blocks)
        block_index = self._get_block_indices(blocks, mode=mode, equal_wide=equal_wide)
        # Use the drawing indices for the x extent while the last block of
        # stage 0 has no ``end`` yet; otherwise use the largest ``end``.
        if blocks[0][-1].end is None:
            width = max(np.max(block_index[p]) for p in range(pp))
        else:
            width = max(blocks[p][-1].end for p in range(pp))
        self.draw_block(block_index, blocks, ax_index, equal_wide, width, phase=phase)
        if comm:
            self.draw_comm(block_index, blocks, ax_index, equal_wide, mode)
        if connect:
            self.draw_connect(block_index, blocks, ax_index, equal_wide, mode)
        self._set_block_ax(self.ax[ax_index], pp)
        self.ax[ax_index].set_xlim(0, width)
        self.ax[ax_index].set_xticks(np.linspace(0, width, 8))
        return self

    def draw_loop(self, blocks: List[List[MicroBlockSim]], loop: List[BlockSim],
                  ax_index: int = 0, comm: bool = False, connect: bool = False,
                  equal_wide: bool = False) -> "PlotMgr":
        """Highlight a dependency loop (non-comm) with red arrows and a textual trace.

        The trace is stored in ``self.msg``.
        """
        self.draw(blocks, ax_index, comm, connect, equal_wide, phase=True)
        block_index = self._get_block_indices(blocks, equal_wide=equal_wide)
        # One arrow per consecutive pair of blocks in the loop.
        for src, dst in zip(loop, loop[1:]):
            p = src.stage
            ind = blocks[p].index(src)
            x1, y1, dx1, _ = src.loc_size(block_index[p][ind], equal_wide)
            p = dst.stage
            ind = blocks[p].index(dst)
            x2, y2, dx2, _ = dst.loc_size(block_index[p][ind], equal_wide)
            self.ax[ax_index].annotate(
                None, xy=(x1 + dx1 / 2, y1), xytext=(x2 + dx2 / 2, y2),
                arrowprops={"fc": 'white', "ec": 'r', "arrowstyle": 'simple',
                            "shrinkA": 5, "shrinkB": 5,
                            "connectionstyle": "arc3,rad=-0.1"})
        self.msg = self._loop_message(loop)
        return self

    def draw_comm_loop(self, lines: List[List[BlockSim]], loop: List[BlockSim],
                       ax_index: int = 0) -> "PlotMgr":
        """Highlight a dependency loop in the send-receive graph.

        The grid is drawn in ``'joint'`` mode with equal-width blocks; the
        textual trace is stored in ``self.msg``.
        """
        self.draw(lines, ax_index, True, True, True, 'joint', phase=True)
        block_index = self._get_block_indices(lines, mode='joint', equal_wide=True)
        # One arrow per consecutive pair of blocks in the loop.
        for src, dst in zip(loop, loop[1:]):
            p = src.stage
            ind = lines[p].index(src)
            x1, y1, dx1, _ = src.loc_size(block_index[p][ind], True, 'joint')
            p = dst.stage
            ind = lines[p].index(dst)
            x2, y2, dx2, _ = dst.loc_size(block_index[p][ind], True, 'joint')
            self.ax[ax_index].annotate(
                None, xy=(x1 + abs(dx1) / 2, y1), xytext=(x2 + abs(dx2) / 2, y2),
                size=10,
                arrowprops={"fc": 'white', "ec": 'r', "arrowstyle": 'simple',
                            "shrinkA": 3, "shrinkB": 3,
                            "connectionstyle": "arc3,rad=-0.1", "lw": 0.8})
        self.msg = self._loop_message(loop)
        return self

    def draw_mem(self, block_mem_list: List[np.ndarray], ax_index: int = 0) -> "PlotMgr":
        """Plot per-stage block memory curves on the axis at ``ax_index``.

        Each entry of ``block_mem_list`` is transposed; row 0 is plotted as x
        and row 1 as y.
        """
        ax = self.ax[ax_index]
        for p, block_mem in enumerate(block_mem_list):
            ax.plot((block_mem.T)[0], (block_mem.T)[1], label=f"stage-{p}")
        ax.set_title("Block Memory Timeline")
        width = max(np.max((block_mem.T)[0]) for block_mem in block_mem_list)
        height = max(np.max((block_mem.T)[1]) for block_mem in block_mem_list)
        ax.set_xlim(0, width)
        ax.set_xticks(np.linspace(0, width, 8))
        ax.set_yticks(np.linspace(0, height, 4))
        return self

    def draw_info(self, bubble_info: dict, mem_info: List[float]) -> None:
        """Draw the bubble / peak-memory annotation lines on the figure."""
        info_list = [f'{k} bubble: {v:.4f}' for k, v in bubble_info.items()]
        self.fig.text(0.5, 0.5, ', '.join(info_list), ha='center', va='center',
                      fontdict={'fontsize': 13, 'weight': 'medium'}, color='C3')
        info_list = [f"{v:.0f}" for v in mem_info]
        self.fig.text(0.5, 0.05, f"peak memory: {', '.join(info_list)}", ha='center', va='center',
                      fontdict={'fontsize': 10, 'weight': 'medium'}, color='C0')

    def save(self, file_name: str) -> None:
        """Add the figure legend and save to ``file_name``."""
        self.fig.legend(bbox_to_anchor=(0.22, 0.45))
        # NOTE(review): ``plt.savefig`` writes pyplot's *current* figure, which
        # may differ from ``self.fig`` if another figure was created since.
        plt.savefig(file_name)

    def show(self, file_name: Optional[str] = None) -> None:
        """Display the figure interactively, optionally also saving it to ``file_name``."""
        self.fig.legend(bbox_to_anchor=(0.22, 0.45))
        if file_name is not None:
            # NOTE(review): saves pyplot's current figure, see ``save``.
            plt.savefig(file_name)
        plt.show()
hyper_parallel/auto_parallel/sapp_ppb/simulator/pp_simulator.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pipeline scheduler simulator: builds block dependencies, computes bubbles and peak memory."""
from __future__ import annotations

import copy
import sys

import numpy as np

from sapp_ppb.simulator.causal_error import CausalCommError, CausalError
from sapp_ppb.simulator.pipeline_builder import PipelineBuilder
from sapp_ppb.simulator.plot_manager import PlotMgr
from sapp_ppb.simulator.sim_block import BlockSim, RecBlockSim, SendBlockSim
from sapp_ppb.simulator.utils import apply_color, apply_format, format_2d_inputs
from sapp_ppb.utils.logger import logger

sys.setrecursionlimit(8192)


class PipelineSimulator:
    r"""
    Pipeline Simulator which provide pipeline flow process, bubbles and relative memories for stages.

    Args:
100
101
102
103
104
105
106
107
108
        --------------------  memory  --------------------
        peak memory: 96.00, 96.00, 96.00, 96.00, 96.00, 96.00, 96.00, 93.00, 103.00, 97.00, 91.00,
        85.00, 79.00, 73.00, 70.20
    """
    def __init__(self, block_time: list, micro_num: int, *args: object,
                 comm_time: float = 0.1,
                 layer_recompute: object = False, block_mem: object = 1,
                 block_mem_par: object = 0, constant_mem: float = 0,
                 backward_ratio: object = 2.,
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
                 block_mem_par: object = 0, constant_mem: float = 0,
                 backward_ratio: object = 2.,
                 sub_fig: object = None, **kwargs: object) -> None:
        """Delegate initialisation to :meth:`init` (kept as a named method for subclassing)."""
        self.init(block_time, micro_num, comm_time, layer_recompute, block_mem,
                  block_mem_par, constant_mem, backward_ratio, sub_fig, *args, **kwargs)

    # pylint: disable=W0613
    def init(self, block_time: list, micro_num: int, comm_time: float,
             layer_recompute: object, block_mem: object, block_mem_par: object,
             constant_mem: float, backward_ratio: object,
             sub_fig: object, *args: object, **kwargs: object) -> None:
        """Set up inputs, statistics and the per-stage block chains.

        Args:
            block_time: Flat (``vp == 1``) or nested per-chunk block times.
            micro_num: Number of micro batches.
            comm_time: Duration given to every communication block.
            layer_recompute: ``bool`` or per-block recompute time.
            block_mem: Scalar factor applied to ``block_time`` or explicit per-block memory.
            block_mem_par: Scalar factor applied to ``block_time`` or explicit per-block value.
            constant_mem: Constant memory added to every stage.
            backward_ratio: Backward/forward time ratio (scalar or per-block).
            sub_fig: Forwarded to the plot manager when drawing.
            *args: Unused.
            **kwargs: ``method`` selects the pipeline builder when ``vp > 1``
                (default ``'vpp'``).
        """
        # ``_base_init`` reads ``self.micro_num``, so it has to be stored first.
        self.micro_num = micro_num
        self.pp, self.vp = self._base_init(block_time)
        self.block_num = 2 * self.vp * self.micro_num
        self.comm_time = comm_time
        self._input_format(block_time, layer_recompute, block_mem, block_mem_par, backward_ratio)
        self.constant_mem = constant_mem
        self._statistic_init()
        self._comm = True
        self.sub_fig = sub_fig

        # Choose the schedule and the line-adjustment passes that go with it.
        adjusters = [self.swap_send_rec]
        if self.vp == 1:
            schedule = '1f1b'
        else:
            schedule = kwargs.get('method', 'vpp')
            if self.micro_num >= self.pp:
                adjusters = [self.vpp_send_delay, self.residue_delay] + adjusters
        self.adjust_func_list = adjusters

        # One block chain per pipeline stage.
        build_stage = PipelineBuilder.get_builder(schedule)
        self.blocks = [
            build_stage(self.pp, self.micro_num, self.vp, stage, self.block_time[:, stage],
                        self.backward_time[:, stage], self.block_mem[:, stage],
                        self.block_mem_par[:, stage])
            for stage in range(self.pp)
        ]

        self._build_block()  # connect compute blocks to their predecessors
        self._build_comm_block()  # attach send/receive blocks where stages differ

    def run(self, comm: bool = True, print_info: bool = True) -> "PipelineSimulator":
        """Schedule the pipeline and compute bubble / memory statistics.

        Args:
            comm: Whether to build the timeline with communication blocks and
                their dependencies. Default: ``True``.
            print_info: Whether to log the summary via :meth:`print_info`
                once the statistics are computed. Default: ``True``.

        Returns:
            The current :class:`PipelineSimulator` instance (for chaining).

        Raises:
            CausalError: If the block sequences contain a dependency loop.
            CausalCommError: If the block-with-comm sequences contain a dependency loop.
        """
        self._comm = comm
        self._check_loop()
        if not comm:
            for stage in range(self.pp):
                for compute_block in self.blocks[stage]:
                    compute_block.build_without_comm()
        else:
            self.lines = self._create_lines(*self.adjust_func_list)
            self._check_comm_loop()
            # Visit blocks position by position across all stages.
            for position in range(self.block_num):
                for stage in range(self.pp):
                    self.blocks[stage][position].build_with_comm()
            # The last entry of stage 0's line is built explicitly as well.
            self.lines[0][-1].build_with_comm()
        self._statistic_info()
        if print_info:
            self.print_info()
        return self

    def draw(self, comm: bool = True, connect: bool = None) -> "PipelineSimulator":
        """Draw the pipeline timeline and the memory timeline on a fresh canvas.

        Args:
            comm: Whether to show the comm blocks. Default: ``True``.
            connect: Forwarded to the canvas ``draw`` call. When ``None`` it
                follows the mode of the last :meth:`run`: ``True`` if the
                simulation was run with communication, ``False`` otherwise.

        Returns:
            The current :class:`PipelineSimulator` instance (for chaining).
        """
        self.canvas = PlotMgr(2, ['block', 'memory'], sub_fig=self.sub_fig)
        if self._comm:
            timeline = self.lines
            default_connect = True
        else:
            timeline = self.blocks
            default_connect = False
        if connect is None:
            connect = default_connect
        self.canvas.draw(timeline, 0, comm, connect, False, 'timeline')
        self.canvas.draw_mem(self.states.get('block_mem_list', []), 1)
        self.canvas.draw_info(self.bubbles, self.peak_memory)
        return self


    def show(self, comm: bool = True, connect: bool = None,
             file_name: str = None) -> "PipelineSimulator":
        """Render the figure via :meth:`draw`, then hand it to the canvas ``show``.

        Args:
            comm: Forwarded to :meth:`draw`.
            connect: Forwarded to :meth:`draw`.
            file_name: Forwarded to the canvas ``show`` call.

        Returns:
            The current :class:`PipelineSimulator` instance (for chaining).
        """
        drawn = self.draw(comm, connect)
        drawn.canvas.show(file_name)
        return self

    def save(self, file_name: str, comm: bool = True,
             connect: bool = None) -> "PipelineSimulator":
        """Render the figure via :meth:`draw`, then save it through the canvas.

        Args:
            file_name: Destination passed to the canvas ``save`` call.
            comm: Forwarded to :meth:`draw`.
            connect: Forwarded to :meth:`draw`.

        Returns:
            The current :class:`PipelineSimulator` instance (for chaining).
        """
        drawn = self.draw(comm, connect)
        drawn.canvas.save(file_name)
        return self

    def print_info(self) -> "PipelineSimulator":
        """Emit the pipeline shape, the bubble table and the peak memory through the logger.

        Returns:
            The current :class:`PipelineSimulator` instance (for chaining).
        """
        palette = ['1;33', '1;32', '1;31', '1;35', '1;36']
        rule = '-' * 20
        shape_text = f' pp:{self.pp:>2}, vp:{self.vp:>2}, micro:{self.micro_num:>3} '
        title = ''.join(['\033[1;37m', '—' * 13, shape_text, '—' * 12, '\033[0m'])
        bubble_title = rule + ' bubble ' + rule
        names_row = apply_format(apply_color(list(self.bubbles.keys()), palette))
        values_row = apply_format(apply_color(list(self.bubbles.values()), palette))
        memory_title = rule + ' memory ' + rule
        peaks = ', '.join(f'{peak:.2f}' for peak in self.peak_memory)
        memory_row = f"peak memory: {peaks}"
        logger.output(
            "%s\n%s\n%s\n%s\n%s\n%s",
            title, bubble_title, names_row, values_row, memory_title, memory_row,
        )
        return self

    def _base_init(self, block_time) -> tuple:
        r"""init base setting"""
        if isinstance(block_time, (list, tuple)):
            if all(isinstance(item, (list, tuple)) for item in block_time):
                vp = len(block_time)
                pp = len(block_time[0])
            elif all(isinstance(item, (int, float)) for item in block_time):
                vp = 1
                pp = len(block_time)
            else:
                raise ValueError(f"Unsupported input format block_time: {block_time}")
        else:
            raise ValueError(f"Unsupported input format block_time: {block_time}")
        if self.micro_num < pp:
            raise ValueError(f" `micro_num`({self.micro_num}) should equal or larger than `pp`({pp})")
        return pp, vp

    def _input_format(self, block_time, layer_recompute, block_mem, block_mem_par, backward_ratio) -> None:
        r"""Normalise every per-block input through ``format_2d_inputs(value, vp, pp)``.

        ``layer_recompute`` given as ``True`` reuses ``self.block_time``; ``False``
        formats ``0``. Numeric ``block_mem`` / ``block_mem_par`` act as factors on
        ``self.block_time``; any other value is formatted directly.
        """
        shape = (self.vp, self.pp)
        numeric = (int, float)

        self.block_time = format_2d_inputs(block_time, *shape)

        if not isinstance(layer_recompute, bool):
            self.layer_recompute = format_2d_inputs(layer_recompute, *shape)
        elif layer_recompute:
            self.layer_recompute = self.block_time
        else:
            self.layer_recompute = format_2d_inputs(0, *shape)

        self.block_mem = (self.block_time * block_mem if isinstance(block_mem, numeric)
                          else format_2d_inputs(block_mem, *shape))

        self.block_mem_par = (self.block_time * block_mem_par if isinstance(block_mem_par, numeric)
                              else format_2d_inputs(block_mem_par, *shape))

        self.backward_ratio = format_2d_inputs(backward_ratio, *shape)

    def _statistic_init(self) -> None:
        r"""init statistic info"""
        self.forward_time = self.block_time
        self.backward_time = self.block_time * self.backward_ratio + self.layer_recompute
        self.states = {'last_time': np.zeros(self.pp),
                       'warmup_time': np.zeros(self.pp),
                       'cooldown_time': np.zeros(self.pp),
                       'stable_free_time': (np.zeros((self.vp, self.pp)), np.zeros((self.vp, self.pp))),
                       'block_mem_list': [np.array([[0, 0]]) for _ in range(self.pp)]}
        self.model_compute_time = (np.sum(self.forward_time) + \
                                   np.sum(self.forward_time * self.backward_ratio)) * self.micro_num
        self.hardware_compute_time = (np.sum(self.forward_time) + np.sum(self.backward_time)) * self.micro_num
        self.bubbles = {'real': 0,
                        'ideal': (self.pp - 1) / self.vp / self.micro_num,
                        'imba': 0,
                        'comm': 0}
        if np.sum(self.layer_recompute) > 1e-5:
            self.bubbles['recompute'] = self.hardware_compute_time / self.model_compute_time - 1
        p, v, m = self.pp, self.vp, self.micro_num
        if self.vp == 1:
            if self.pp == 2:
                self.bubbles['comm'] = 4 * m
            elif self.pp % 2 == 0:
                self.bubbles['comm'] = 4 * p * m + 4 * p ** 2 - 14 * p
            else:
                self.bubbles['comm'] = 4 * p * m + 4 * p ** 2 - 12 * p
        elif self.pp <= 5:
            comm_coef_list = [[4, -2, 0], [6, -2, -6], [4, 0, 12], [6, -2, 40]]
            self.bubbles['comm'] = np.dot(np.array([p * v * m, m * p, 1]), comm_coef_list[self.pp - 2])
        elif self.pp % 2 == 0:
            self.bubbles['comm'] = 4 * p * v * m + 4 * p ** 2 - 13 * p
        else:
            self.bubbles['comm'] = 6 * p * v * m - 2 * v * p ** 2 + 4 * v * p - 2 * p * m + 6 * p ** 2 - 16 * p

        self.bubbles['comm'] *= self.comm_time / self.model_compute_time

    def _statistic_info(self) -> None:
        r"""compute statistic info"""
        for p in range(self.pp):
            blocks = self.lines[p] if self._comm else self.blocks[p]
            current_mem = self.constant_mem + blocks[0].mem_par

            for block in blocks:
                if block.type == 'c' and block.state == 'f':
                    current_mem += block.mem
                elif block.type == 'c' and block.state == 'b':
                    if not self._comm or not block.rec_block:
                        current_mem -= block.mem
                elif block.type == 'r' and block.host.state == 'b':
                    current_mem -= block.host.mem
                    block = block.host
                else:
                    continue
                self.states['block_mem_list'][p] = np.append(self.states['block_mem_list'][p],
                                                             np.array([[block.end, current_mem]]), axis=0)
            self.states['block_mem_list'][p] = np.append(self.states['block_mem_list'][p],
                                                         np.array([[blocks[-1].end, current_mem]]), axis=0)
        self.peak_memory = [np.max((self.states['block_mem_list'][p].T)[1]) for p in range(self.pp)]
        self.end_time = max(np.max((self.states['block_mem_list'][p].T)[0]) for p in range(self.pp))
        self.bubbles['real'] = (self.pp * self.end_time - self.model_compute_time) / self.model_compute_time
        self.bubbles['imba'] = self.bubbles['real'] - self.bubbles['ideal'] + 1e-10
        if not self._comm:
            self.bubbles.pop('comm')
        else:
            self.bubbles['imba'] -= self.bubbles['comm']
        if self.bubbles.get('recompute'):
            self.bubbles['imba'] -= self.bubbles['recompute']

    def _get_pre_label(self, label: tuple) -> tuple:
        r"""get pre block label"""
        t, s, m, v, p = label
        if (s, v, p) == ('f', 0, 0):
            return ('h', p)
        if (s, p) == ('f', 0):
            res = (t, s, m, v - 1, self.pp - 1)
            return res
        if (s, p) == ('b', self.pp - 1):
            if v == self.vp - 1:
                res = (t, 'f', m, self.vp - 1, p)
                return res
            res = (t, s, m, v + 1, 0)
            return res
        if s == 'f':
            res = (t, s, m, v, p - 1)
            return res
        if s == 'b':
            res = (t, s, m, v, p + 1)
            return res
        raise ValueError(f"Illegal label: {label}")

    def _build_block(self) -> None:
        r"""Resolve the ``pre`` link of every compute block.

        All blocks (plus the ``pre`` of stage 0's first block) are indexed by
        label; each chain is then walked through ``right`` and every block's
        ``pre`` is set to the block registered under its predecessor label, or
        ``None`` when no such label is registered.
        """
        head = self.blocks[0][0].pre
        registry = {head.label: head}
        for stage in range(self.pp):
            for member in self.blocks[stage]:
                registry[member.label] = member
        for stage in range(self.pp):
            cursor = self.blocks[stage][0]
            while cursor is not None:
                cursor.pre = registry.get(self._get_pre_label(cursor.label))
                cursor = cursor.right

    def _build_comm_block(self) -> None:
        r"""Attach send/receive blocks wherever a block's ``pre`` lives on another stage.

        For such a pair a receive block is hosted by the consumer and a send block
        by the producer; the two are linked through ``dual`` and chained as
        ``producer -> send -> receive -> consumer`` via ``depend_pre``. When both
        blocks share a stage, the consumer simply depends on the producer.
        """
        for stage in range(self.pp):
            consumer = self.blocks[stage][0]
            while consumer is not None:
                producer = consumer.pre
                if producer.stage != consumer.stage:
                    receiver = RecBlockSim(stage, consumer.state, consumer.id,
                                           consumer.chunk, self.comm_time)
                    consumer.rec_block = receiver
                    sender = SendBlockSim(producer.stage, producer.state, producer.id,
                                          producer.chunk, self.comm_time)
                    producer.send_block = sender
                    # Pair the two comm blocks and remember who owns them.
                    receiver.host, receiver.dual = consumer, sender
                    sender.host, sender.dual = producer, receiver
                    # Dependency chain used when building with communication.
                    consumer.depend_pre = receiver
                    receiver.depend_pre = sender
                    sender.depend_pre = producer
                else:
                    consumer.depend_pre = producer
                consumer = consumer.right

    def _check_loop(self) -> None:
        r"""Fail if the compute-block dependencies contain a cycle.

        The search starts from the last block of stage 0. Afterwards the ``flag``
        marker of every compute block is cleared.

        Raises:
            CausalError: If a dependency loop is reported.
        """
        cycle = self.blocks[0][-1].loop()
        if cycle:
            raise CausalError('Block dependency exist loops!', self.blocks, cycle)
        for stage in range(self.pp):
            for member in self.blocks[stage]:
                member.flag = False

    def _check_comm_loop(self) -> None:
        r"""Fail if the dependencies including comm blocks contain a cycle.

        The search starts from the last entry of stage 0's line. Afterwards the
        ``flag`` marker of every entry in ``self.lines`` is cleared.

        Raises:
            CausalCommError: If a dependency loop is reported.
        """
        cycle = self.lines[0][-1].comm_loop()
        if cycle:
            raise CausalCommError('Block comm dependency exist loops!', self.lines, cycle)
        for stage in range(self.pp):
            for member in self.lines[stage]:
                member.flag = False

    def _create_lines(self, *adjust_func) -> list[list[BlockSim]]:
        r"""Build, per stage, the ordered line of compute and comm blocks.

        Each line starts as a shallow copy of the stage's compute blocks. A receive
        block is inserted directly before its compute block and the paired send
        block directly after its producer (or at the front of the producer's line
        when the producer is a head block). The ``adjust_func`` callables are then
        applied in order, and finally ``depend_left`` of every entry is set to its
        left neighbour in the line.

        Args:
            *adjust_func: Callables taking and returning the list of lines.

        Returns:
            The list of per-stage lines.
        """
        lines = [copy.copy(self.blocks[stage]) for stage in range(self.pp)]
        for stage in range(self.pp):
            for position in range(self.block_num):
                consumer = self.blocks[stage][position]
                if not consumer.rec_block:
                    continue
                producer = consumer.pre
                own_line = lines[stage]
                own_line.insert(own_line.index(consumer), consumer.rec_block)
                producer_line = lines[producer.stage]
                if producer.type == 'h':
                    slot = 0
                else:
                    slot = producer_line.index(producer) + 1
                producer_line.insert(slot, producer.send_block)
        for adjust in adjust_func:
            lines = adjust(lines)
        for stage in range(self.pp):
            line = lines[stage]
            for position, entry in enumerate(line):
                if position > 0:
                    entry.depend_left = line[position - 1]
                else:
                    # First entry: fall back to the host's left neighbour when the
                    # entry itself has none.
                    entry.depend_left = entry.left if entry.left else entry.host.left
        return lines

    def _get_block_phase(self, p: int, b: int) -> str:
        r"""get block phase"""
        r = self.micro_num % self.pp
        if b < (self.vp + 1) * self.pp - 2 * p - 2 + r:
            return 'warmup'
        if b > self.block_num - (self.vp + 1) * self.pp + 2 * p:
            return 'cooldown'
        return 'stable'

    def _send_block_delay(self, lines, p: int, b: int, distance: int) -> None:
        r"""Move a send block so that it follows a later compute block.

        The send block of ``self.blocks[p][b]`` is taken out of ``lines[p]`` and
        re-inserted right after ``self.blocks[p][b + distance]``.
        """
        stage_line = lines[p]
        owner_position = stage_line.index(self.blocks[p][b].send_block)
        moved = stage_line.pop(owner_position)
        anchor = self.blocks[p][b + distance]
        stage_line.insert(stage_line.index(anchor) + 1, moved)

    def _process_swap(self, block, lines, p: int, b: int, i_b: int, i_bn: int) -> bool:
        r"""Reorder the entries that sit between a compute block and its successor.

        Args:
            block: The compute block at ``self.blocks[p][b]``.
            lines: Per-stage lines; ``lines[p]`` is modified in place.
            p: Stage index.
            b: Position of ``block`` in ``self.blocks[p]``.
            i_b: Index of ``block`` in ``lines[p]``.
            i_bn: Index of the next compute block in ``lines[p]``.

        Returns:
            ``False`` when the odd-stage swap is skipped because ``block`` is in
            warmup and the next compute block is in cooldown; ``True`` otherwise.
        """
        # Two entries between the compute blocks.
        # NOTE(review): they are presumably always comm blocks (``.dual`` is read
        # on both below) — confirm.
        if i_bn - i_b == 3:
            # Even stage: turn (receive, send) into (send, receive).
            if p % 2 == 0 and lines[p][i_b + 1].type == 'r' and lines[p][i_b + 2].type == 's':
                lines[p][i_b + 1], lines[p][i_b + 2] = lines[p][i_b + 2], lines[p][i_b + 1]
            # Odd stage: turn (send, receive) into (receive, send), except across
            # a warmup -> cooldown boundary.
            if p % 2 == 1 and lines[p][i_b + 1].type == 's' and lines[p][i_b + 2].type == 'r':
                if block.phase == 'warmup' and self.blocks[p][b + 1].phase == 'cooldown':
                    return False
                lines[p][i_b + 1], lines[p][i_b + 2] = lines[p][i_b + 2], lines[p][i_b + 1]
            # Both peers on the same stage: follow the order of the peers in that
            # stage's line.
            if lines[p][i_b + 1].dual.stage == lines[p][i_b + 2].dual.stage:
                pd = lines[p][i_b + 1].dual.stage
                j_b1 = lines[pd].index(lines[p][i_b + 1].dual)
                j_b2 = lines[pd].index(lines[p][i_b + 2].dual)
                if j_b1 > j_b2:
                    lines[p][i_b + 1], lines[p][i_b + 2] = lines[p][i_b + 2], lines[p][i_b + 1]
        # Three entries between the compute blocks, all paired with the same stage:
        # for the pattern (send, send, receive) swap the two send blocks.
        if i_bn - i_b == 4:
            if lines[p][i_b + 1].dual.stage == lines[p][i_b + 2].dual.stage and \
                lines[p][i_b + 2].dual.stage == lines[p][i_b + 3].dual.stage:
                if lines[p][i_b + 1].type == 's' and lines[p][i_b + 2].type == 's' \
                    and lines[p][i_b + 3].type == 'r':
                    lines[p][i_b + 1], lines[p][i_b + 2] = lines[p][i_b + 2], lines[p][i_b + 1]
        return True

    def swap_send_rec(self, lines: list[list[BlockSim]]) -> list[list[BlockSim]]:
        """Run :meth:`_process_swap` on every pair of consecutive compute blocks.

        Args:
            lines: Per-stage lines, modified in place.

        Returns:
            The same ``lines`` object.
        """
        for p in range(self.pp):
            stage_blocks = self.blocks[p]
            for b in range(len(stage_blocks) - 1):
                current = stage_blocks[b]
                # Indices are looked up again for every pair because earlier swaps
                # may have reordered the line.
                i_current = lines[p].index(current)
                i_following = lines[p].index(stage_blocks[b + 1])
                self._process_swap(current, lines, p, b, i_current, i_following)
        return lines

    def vpp_send_delay(self, lines: list[list[BlockSim]]) -> list[list[BlockSim]]:
        """Delay the send block of every stable-phase block by one compute block.

        Does nothing unless ``micro_num`` is a multiple of ``pp``.

        Args:
            lines: Per-stage lines, modified in place.

        Returns:
            The same ``lines`` object.
        """
        if self.micro_num % self.pp:
            return lines
        for p in range(self.pp):
            for b, block in enumerate(self.blocks[p]):
                if block.send_block is None or block.phase != 'stable':
                    continue
                self._send_block_delay(lines, p, b, 1)
        return lines

    def residue_delay(self, lines: list[list[BlockSim]]) -> list[list[BlockSim]]:
        """Delay send blocks when ``micro_num`` is not a multiple of ``pp``.

        Early forward blocks on the last stage and early backward blocks on
        stage 0 (``id < pp + residue``) are delayed by a residue-dependent
        distance; every other stable-phase block is delayed by one.

        Args:
            lines: Per-stage lines, modified in place.

        Returns:
            The same ``lines`` object.
        """
        residue = self.micro_num % self.pp
        if residue == 0:
            return lines
        last_stage = self.pp - 1
        id_limit = self.pp + residue
        for p in range(self.pp):
            for b, block in enumerate(self.blocks[p]):
                if block.send_block is None:
                    continue
                if p == last_stage and block.id < id_limit and block.state == 'f':
                    shift = residue + max(0, block.id - self.pp + 1)
                    # ``-1`` addresses the last stage.
                    self._send_block_delay(lines, -1, b, shift)
                elif p == 0 and block.id < id_limit and block.state == 'b':
                    shift = residue if self.micro_num // self.pp == 1 else residue + self.pp
                    self._send_block_delay(lines, 0, b, shift)
                elif block.phase == 'stable':
                    self._send_block_delay(lines, p, b, 1)
        return lines


if __name__ == '__main__':

    # PipelineSimulator([[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4 + 0.8]], 8, 0.1,
    #                   [[1, 0, 0, 0], [1, 0, 0, 0], [1, 1, 0, 0]],
    #                   [[1.1, 2, 2, 2], [1.1, 2, 2, 2], [1.1, 1.1, 2, 2]], method='vpp').run().show()
    PipelineSimulator(
        [[186.0, 171.0, 132.0, 132.0, 132.0, 132.0, 132.0, 132.0,
          132.0, 132.0, 132.0, 132.0, 132.0, 132.0, 132.0, 133.0]], 32,
        block_mem_act=[[1146, 908, 736, 736, 736, 736, 736, 736,
                        736, 736, 2623, 2623, 4510, 4510, 8284, 8284]],
hyper_parallel/auto_parallel/sapp_ppb/simulator/sim_block.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Simulator block primitives: compute blocks, comm (send/receive) blocks, head sentinels."""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Polygon, Rectangle

from sapp_ppb.simulator.utils import color_mix, dfs_builder


@dataclass
class BlockSim:
    r"""Base class of every simulated block.

    A block has an identity (``type``, ``state``, ``id``, ``chunk``, ``stage``),
    a duration ``time`` and, once scheduled, ``start`` / ``end`` timestamps.
    ``pre`` / ``left`` are its dependencies without communication;
    ``depend_pre`` / ``depend_left`` are the dependencies used when communication
    blocks are part of the schedule.
    """
    stage: int  # p: pipeline stage index
    state: str  # s: state tag ('f' and 'b' are the values checked in this file)
    id: int  # m: micro index
    chunk: int  # v: chunk index
    time: float  # duration of the block
    type: str  # block kind tag ('c', 's' and 'h' are defined in this file)
    start: float = None
    end: float = None
    pre: BlockSim = field(repr=False, default=None)
    left: BlockSim = field(repr=False, default=None)
    # pylint: disable=E0601
    right: MicroBlockSim = field(repr=False, default=None)
    depend_pre: BlockSim = field(repr=False, default=None)
    depend_left: BlockSim = field(repr=False, default=None)
    # The four names below are plain class attributes, not dataclass fields.
    finish = False
    in_queue = False  # True while the block is on the stack of a ``loop`` search
    flag = False  # True once the block has been visited by a ``loop`` search
    _color = '0;38'
    # Block through which the running ``loop`` search reached this block.
    father: BlockSim = field(repr=False, default=None)

    @property
    def label(self) -> tuple:
        """Identity tuple ``(type, state, id, chunk, stage)``."""
        return (self.type, self.state, self.id, self.chunk, self.stage)

    @property
    def color_label(self) -> str:
        """:attr:`label` wrapped in the ANSI colour escape of this block."""
        return f"\033[{self._color}m{self.label}\033[0m"

    @dfs_builder(False)
    def build_without_comm(self) -> None:
        r"""Schedule this block from ``pre`` and ``left`` (no comm blocks).

        Both dependencies are built first; the block starts when the later of
        the two has ended.
        """
        self.pre.build_without_comm()
        self.left.build_without_comm()
        self.start = max(self.pre.end, self.left.end)
        self.end = self.start + self.time

    @dfs_builder(True)
    def build_with_comm(self) -> None:
        r"""Schedule this block from ``depend_pre`` and ``depend_left`` (with comm blocks).

        Both dependencies are built first; the block starts when the later of
        the two has ended.
        """
        self.depend_pre.build_with_comm()
        self.depend_left.build_with_comm()
        self.start = max(self.depend_pre.end, self.depend_left.end)
        self.end = self.start + self.time

    def reset_time(self) -> None:
        r"""Clear ``start`` / ``end`` and mark the block as not finished."""
        self.start = None
        self.end = None
        self.finish = False

    # pylint: disable=W0613
    def loc_size(self, x: float = 0, equal_wide: bool = False,
                 mode: str = 'compact') -> Tuple[float, float, float, float]:
        """Return ``(x, y, dx, dy)`` for plotting this block.

        Args:
            x: Fallback x position, used only while ``start`` is ``None``.
            equal_wide: Use width ``1`` instead of ``time``.
            mode: Unused here; accepted so subclasses can share the signature.
        """
        x = x if self.start is None else self.start
        dx = 1 if equal_wide else self.time
        return x, self.stage + 0.5, dx, 1

    def loop(self, comm: bool = False) -> List["BlockSim"]:
        """Search the dependencies of this block for a cycle (depth-first).

        Args:
            comm: Follow ``depend_pre`` / ``depend_left`` instead of ``pre`` / ``left``.

        Returns:
            An empty list when no cycle is reachable. Otherwise the blocks of the
            detected cycle(s): each cycle starts at the block that was reached a
            second time, follows the ``father`` links and ends with that block again.
        """
        if self.in_queue:
            # Reached again while still on the search stack: a cycle. The caller
            # has just pointed ``father`` at the block that depends on us, so the
            # father chain walks the whole cycle. Every visited node is recorded,
            # including the direct father.
            cycle = [self]
            node = self.father
            while node is not None:
                cycle.append(node)
                if node is self:
                    break
                node = node.father
            return cycle
        if self.flag:
            # Already explored completely by this search.
            return []
        res = []
        self.flag = True
        self.in_queue = True
        depends = [self.depend_pre, self.depend_left] if comm else [self.pre, self.left]
        for dep in depends:
            if dep:
                dep.father = self
                res.extend(dep.loop(comm=comm))
                dep.father = None
        self.in_queue = False
        return res

    def comm_loop(self) -> list[BlockSim]:
        r"""Run :meth:`loop` over the dependencies that include comm blocks."""
        return self.loop(True)


@dataclass
class HeadBlockSim(BlockSim):
    r"""Head sentinel of a stage: zero duration, already finished at time 0."""
    stage: int  # p
    type: str = 'h'
    # ``id``, ``state`` and ``chunk`` are excluded from ``__init__`` and ``repr``.
    id: int = field(repr=False, init=False)
    state: str = field(repr=False, init=False)
    chunk: int = field(repr=False, init=False)
    time: float = 0.
    start: float = 0.
    end: float = 0.
    finish = True

    @property
    def label(self) -> tuple:
        """Identity tuple ``(type, stage)``."""
        return (self.type, self.stage)

    @property
    def repr(self) -> str:
        """One ``repr`` line per block, following ``right`` links from this head."""
        def walk(node):
            while node:
                yield node
                node = node.right

        return '\n'.join(map(repr, walk(self)))

    # pylint: disable=W0613
    def draw(self, ax: plt.Axes, *args: object, **kwargs: object) -> None:
        """Do nothing: the head sentinel is not drawn."""

    def build_without_comm(self) -> None:
        """Do nothing: the head sentinel already carries its times."""

    def build_with_comm(self) -> None:
        """Do nothing: the head sentinel already carries its times."""

    def reset_time_recursive(self) -> None:
        """Do nothing: the head sentinel keeps its times."""


@dataclass
class MicroBlockSim(BlockSim):
    r"""Compute block, optionally paired with a send and/or a receive block."""
    type: str = 'c'
    mem: float = 0.
    mem_par: float = 0.
    phase: str = None
    # pylint: disable=E0601
    send_block: SendBlockSim = field(repr=False, default=None)
    # pylint: disable=E0601
    rec_block: RecBlockSim = field(repr=False, default=None)

    def __post_init__(self) -> None:
        """Pick the ANSI colour: one for state ``'f'``, another for anything else."""
        if self.state == 'f':
            self._color = '1;34'
        else:
            self._color = '1;33'

    # pylint: disable=W0613
    def draw(self, ax: plt.Axes, *args: object, **kwargs: object) -> None:
        """Draw the block as a rectangle, labelled with ``id`` when wide enough.

        Recognised ``kwargs``: ``index``, ``equal_wide`` (passed to
        :meth:`loc_size`), ``phase`` (colour the edge by warmup/cooldown phase)
        and ``width`` (the label is drawn only if ``dx > 0.008 * width``).
        """
        x, y, dx, dy = self.loc_size(kwargs.get('index', 0), kwargs.get('equal_wide', False))
        if self.state == 'f':
            base_color = (167 / 255, 184 / 255, 231 / 255)
            tint = (240 / 255, 255 / 255, 245 / 255)
        else:
            base_color = (255 / 255, 213 / 255, 143 / 255)
            tint = (255 / 255, 240 / 255, 255 / 255)
        face = color_mix(tint, base_color, w1=self.chunk / 3)

        edge = 'black'
        if kwargs.get('phase', False):
            if self.phase == 'warmup':
                edge = 'lightblue'
            elif self.phase == 'cooldown':
                edge = 'orange'

        rect = Rectangle((x, y - dy / 2), dx, dy, facecolor=face, edgecolor=edge, linewidth=0.4)
        if dx > 0.008 * kwargs.get('width', 0):
            anchor_x, anchor_y = rect.xy[0], rect.xy[1]
            ax.text(anchor_x + dx / 2, anchor_y + dy / 2, str(self.id),
                    ha='center', va='center', color='black', fontdict={'fontsize': 9})
        ax.add_patch(rect)

    def reset_time_recursive(self) -> None:
        r"""Reset this block and, first, its ``pre`` / ``left`` dependencies.

        Nothing happens unless the block is marked as finished.
        """
        if not self.finish:
            return
        self.pre.reset_time_recursive()
        self.left.reset_time_recursive()
        self.reset_time()


@dataclass
class CommBlockSim(BlockSim):
    r"""Base class of communication blocks.

    ``host`` is the compute block the comm block belongs to; ``dual`` is the
    paired comm block on the other side of the transfer.
    """
    host: MicroBlockSim = field(repr=False, default=None)
    dual: CommBlockSim = field(repr=False, default=None)

    def get_triangle(self, x: float, y: float, dx: float,
                     dy: float) -> List[List[float]]:
        """Return the triangle vertices used by :meth:`draw` (subclasses implement this)."""
        raise NotImplementedError

    # pylint: disable=W0613
    def draw(self, ax: plt.Axes, *args: object, **kwargs: object) -> None:
        """Draw the block as a triangle patch on ``ax``.

        Recognised ``kwargs``: ``index``, ``equal_wide`` and ``mode``, all passed
        on to :meth:`loc_size`.
        """
        if self.state == 'f':
            base_color = (167 / 255, 184 / 255, 231 / 255)
            tint = (240 / 255, 255 / 255, 255 / 255)
        else:
            base_color = (255 / 255, 213 / 255, 143 / 255)
            tint = (255 / 255, 240 / 255, 255 / 255)
        face = color_mix(tint, base_color, w1=1.2 * self.chunk / 3)

        index = kwargs.get('index', 0)
        equal_wide = kwargs.get('equal_wide', False)
        mode = kwargs.get('mode', 'compact')
        x, y, dx, dy = self.loc_size(index, equal_wide, mode)
        vertices = self.get_triangle(x, y, dx, dy)
        ax.add_patch(Polygon(vertices, closed=True, facecolor=face,
                             edgecolor='black', linewidth=0.4))


@dataclass
class SendBlockSim(CommBlockSim):
    r"""Simulated *send* communication block."""
    type: str = 's'
    _color = '35'

    def loc_size(self, x: float = 0, equal_wide: bool = False,
                 mode: str = 'compact') -> Tuple[float, float, float, float]:
        """Return ``(x, y, dx, dy)`` for this send block.

        The width is ``self.time`` and the height grows with ``sqrt(self.time)``, capped
        at ``0.6``. In ``'compact'`` mode the block is right-aligned with its host
        (``host_x + host_dx - dx``); otherwise the base-class ``x`` is kept.
        """
        host_left, _, host_width, _ = self.host.loc_size(x, equal_wide)
        own_left, row, _, _ = super().loc_size(x, equal_wide)
        width = self.time
        height = min(np.sqrt(self.time) * 0.6, 0.6)
        left = host_left + host_width - width if mode == 'compact' else own_left
        return left, row, width, height

    def get_triangle(self, x: float, y: float, dx: float,
                     dy: float) -> List[List[float]]:
        """Return the vertices of a triangle whose base is at ``x`` and whose apex is at ``x + dx``."""
        lower = [x, y - dy / 2]
        upper = [x, y + dy / 2]
        apex = [x + dx, y]
        return [lower, upper, apex]

    # pylint: disable=W0613
    def draw_comm(self, ax: plt.Axes, *args: object, **kwargs: object) -> None:
        """Draw a grey arrow on ``ax`` from this send block to its ``dual``.

        Keyword arguments: ``index_from`` / ``index_to`` (default ``0``) are passed to
        the two ``loc_size`` calls, together with ``equal_wide`` (default ``False``)
        and ``mode`` (default ``'compact'``).
        """
        index_from = kwargs.get('index_from', 0)
        index_to = kwargs.get('index_to', 0)
        equal_wide = kwargs.get('equal_wide', False)
        mode = kwargs.get('mode', 'compact')
        src_x, src_y, src_dx, _ = self.loc_size(index_from, equal_wide, mode)
        dst_x, dst_y, dst_dx, _ = self.dual.loc_size(index_to, equal_wide, mode)
        arrow_style = {"ec": 'grey', "arrowstyle": '->', "shrinkA": 2, "shrinkB": 2}
        ax.annotate(None, xy=(dst_x - dst_dx / 2, dst_y), xytext=(src_x + src_dx / 2, src_y),
                    arrowprops=arrow_style)

    @dfs_builder(True)
    def build_with_comm(self) -> None:
        r"""Schedule this send block once both of its dependencies are built.

        The block starts at the later of ``depend_left.end`` and
        ``dual.depend_left.end`` and lasts ``self.time``.
        """
        peer_dependency = self.dual.depend_left
        peer_dependency.build_with_comm()
        self.depend_left.build_with_comm()
        self.start = max(self.depend_left.end, self.dual.depend_left.end)
        self.end = self.start + self.time

    def loop(self, comm: bool = False) -> List["BlockSim"]:
        """Return :meth:`comm_loop` when ``comm`` is truthy, otherwise the base-class ``loop``."""
        return self.comm_loop() if comm else super().loop(comm)

    def comm_loop(self) -> list[BlockSim]:
        r"""Recursively search the communication dependencies for a loop.

        Returns the chain of blocks collected by walking ``father`` links when a block
        that is still ``in_queue`` is reached again, and an empty list for blocks that
        were already visited (``flag`` set) or lead to no loop.
        """
        if self.in_queue:
            # Reached a block that is still on the current path: walk the father links back.
            cycle = [self]
            node = self.father
            while node.father and node is not self:
                node = node.father
                cycle.append(node)
            return cycle
        if self.flag:
            return []
        self.flag = True
        self.in_queue = True
        found = []
        for dependency in (self.dual.depend_left, self.depend_left):
            if dependency:
                # ``father`` is only set for the duration of the recursive visit.
                dependency.father = self
                found.extend(dependency.comm_loop())
                dependency.father = None
        self.in_queue = False
        return found


@dataclass
class RecBlockSim(CommBlockSim):
    r"""Simulated *receive* communication block."""
    type: str = 'r'
    _color = '32'

    def loc_size(self, x: float = 0, equal_wide: bool = False,
                 mode: str = 'compact') -> Tuple[float, float, float, float]:
        """Return ``(x, y, dx, dy)`` for this receive block.

        ``dx`` and ``dy`` are returned negated (``-self.time`` and minus the capped
        ``sqrt``-scaled height). In ``'compact'`` mode ``x`` is the host's ``x``;
        otherwise the base-class ``x`` is kept.
        """
        host_left, _, _, _ = self.host.loc_size(x, equal_wide)
        own_left, row, _, _ = super().loc_size(x, equal_wide)
        width = self.time
        height = min(np.sqrt(self.time) * 0.6, 0.6)
        left = host_left if mode == 'compact' else own_left
        return left, row, -width, -height

    def get_triangle(self, x: float, y: float, dx: float,
                     dy: float) -> List[List[float]]:
        """Return the vertices of a triangle with its apex at ``(x, y)`` and its base at ``x - dx``."""
        apex = [x, y]
        base_x = x - dx
        return [apex, [base_x, y + dy / 2], [base_x, y - dy / 2]]

    @dfs_builder(True)
    def build_with_comm(self) -> None:
        r"""Schedule this receive block once its ``dual`` and ``depend_left`` are built.

        The block starts at the later of ``depend_left.end`` and ``dual.start`` and
        lasts ``self.time``.
        """
        sender = self.dual
        sender.build_with_comm()
        self.depend_left.build_with_comm()
        self.start = max(self.depend_left.end, self.dual.start)
        self.end = self.start + self.time
hyper_parallel/auto_parallel/sapp_ppb/simulator/utils.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Helpers used by the pipeline simulator: numeric coercion, colouring, decorators."""
import time
from functools import wraps
from typing import Any, Callable, List, Tuple, Union

import numpy as np
from matplotlib import colors

from sapp_ppb.utils.logger import logger

ScalarOrMatrix = Union[int, float, List[List[Union[int, float]]], Tuple[Tuple[Union[int, float], ...], ...]]


def format_2d_inputs(a: ScalarOrMatrix, raw: int, col: int) -> np.ndarray:
    """Coerce ``a`` into a :class:`numpy.ndarray`.

    Args:
        a: A scalar, a flat list/tuple of scalars, or a list/tuple of lists/tuples.
        raw: Number of rows used when ``a`` is a scalar.
        col: Number of columns used when ``a`` is a scalar.

    Returns:
        For a scalar, ``a`` broadcast to shape ``(raw, col)`` (a read-only view, as
        produced by :func:`numpy.broadcast_to`). For a sequence of sequences,
        ``np.array(a)``. For a flat sequence of scalars, a one-row array
        ``np.array([a])``. ``raw`` and ``col`` are not used for sequence inputs.

    Raises:
        ValueError: If ``a`` does not match any of the supported shapes.
    """
    if isinstance(a, (int, float)):
        return np.broadcast_to(a, (raw, col))
    if not isinstance(a, (list, tuple)):
        raise ValueError(f"Unsupported inputs: {a}")

    # Nested sequences are taken as they are; a flat sequence becomes a single row.
    already_nested = all(isinstance(item, (list, tuple)) for item in a)
    if already_nested:
        return np.array(a)
    only_scalars = all(isinstance(item, (int, float)) for item in a)
    if only_scalars:
        return np.array([a])
    raise ValueError(f"Unsupported inputs: {a}")


def apply_color(target_list: list, c: List[str]) -> list:
    """Wrap each element of ``target_list`` in an ANSI colour escape sequence.

    Args:
        target_list: Values to colour (floats are formatted to four decimals). The list
            is modified in place.
        c: Escape codes; ``c[i]`` is inserted into ``\\033[...m`` for ``target_list[i]``.

    Returns:
        The same list object, each element replaced by its coloured string.
    """
    for position, value in enumerate(target_list):
        if isinstance(value, float):
            text = f'{value:.4f}'
        else:
            text = value
        code = c[position]
        target_list[position] = f"\033[{code}m{text}\033[0m"
    return target_list


def apply_format(target_list: list) -> str:
    """Join values into one line of the form ``v0 = v1 + v2 + ...``.

    Every value is centred in a 22-character field. The first separator is ``=`` and
    every following one is ``+``. The separators are generated on the fly, so the
    list may have any length (the previous fixed table of six separators failed with
    ``IndexError`` for more than seven values).

    Args:
        target_list: Non-empty sequence of values, typically the coloured strings
            produced by :func:`apply_color`.

    Returns:
        The formatted single-line string.

    Raises:
        IndexError: If ``target_list`` is empty.
    """
    pieces = [f'{target_list[0]:^22}']
    for position, item in enumerate(target_list[1:]):
        pieces.append('=' if position == 0 else '+')
        pieces.append(f'{item:^22}')
    return ''.join(pieces)


def color_mix(c1: Any, c2: Any, w1: float = 0.5, w2: float = 0.5) -> Tuple[float, float, float, float]:
    """Blend two matplotlib colours as a weighted average.

    Both colours are converted with :func:`matplotlib.colors.to_rgba` using an alpha
    of ``1`` and combined as ``(c1 * w1 + c2 * w2) / (w1 + w2)``.

    Args:
        c1: First colour in any format understood by :func:`matplotlib.colors.to_rgba`.
        c2: Second colour, same accepted formats.
        w1: Weight of ``c1``.
        w2: Weight of ``c2``.

    Returns:
        The blended colour as an ``(r, g, b, a)`` tuple.
    """
    first = np.array(colors.to_rgba(c1, 1))
    second = np.array(colors.to_rgba(c2, 1))
    blended = (first * w1 + second * w2) / (w1 + w2)
    return colors.to_rgba(blended)


def dfs_builder(comm: bool = False) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that guards a DFS visit against re-entry and unmet dependencies.

    The decorated callable receives the visited node as its first positional argument.
    The node must expose ``finish`` and ``in_queue`` flags plus the dependency
    attributes selected by ``comm``.

    Args:
        comm: When ``True``, check the ``depend_pre``/``depend_left`` attributes of the
            node; otherwise check ``pre``/``left``.

    Returns:
        A decorator. The wrapped call returns ``None`` without running the function
        when the node is already finished, raises ``NotImplementedError`` when either
        dependency is ``None``, raises ``ValueError`` when the node is re-entered
        while still in progress, and otherwise runs the function once and marks the
        node as finished.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        """Wrap ``func`` with the visit guards."""

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Visit the node passed as first argument unless a guard stops the call."""
            node = args[0]
            if comm:
                pre, left = node.depend_pre, node.depend_left
            else:
                pre, left = node.pre, node.left
            if node.finish:
                return None
            dependency_missing = pre is None or left is None
            if dependency_missing:
                raise NotImplementedError
            if node.in_queue:
                raise ValueError("Dependency loop detected during DFS traversal")
            # ``in_queue`` marks the node as being on the current DFS path.
            node.in_queue = True
            outcome = func(*args, **kwargs)
            node.finish = True
            node.in_queue = False
            return outcome
        return wrapper

    return decorator


def timer(func: Callable[..., Any]) -> Callable[..., Any]:
    """Log how long each call to ``func`` takes.

    The elapsed time is measured with :func:`time.perf_counter`, a monotonic
    high-resolution clock, so the reported duration is not distorted by system clock
    adjustments (which could affect the previous ``time.time()`` based measurement).

    Args:
        func: Callable to time.

    Returns:
        A wrapper that calls ``func``, logs the elapsed seconds at INFO level and
        returns the result of ``func``. Nothing is logged when ``func`` raises.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        """Time one call to ``func`` and log the elapsed seconds."""
        started = time.perf_counter()
        res = func(*args, **kwargs)
        elapsed = time.perf_counter() - started
        logger.info("function `%s` time used: %.4f s", func.__name__, elapsed)
        return res

    return wrapper
hyper_parallel/auto_parallel/sapp_ppb/utils/check_rules.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# limitations under the License.
# ============================================================================

"""Defensive YAML loading helpers (caps nesting depth to guard against malicious input)."""
from typing import Union

import yaml
from yaml.nodes import MappingNode, Node

YAML_MAX_NESTING_DEPTH = 10


def _get_yaml_ast_depth(node: Node, depth: int = 0) -> int:
    """Recursively return the maximum nesting depth of a YAML AST.

    Both mappings and sequences count as one nesting level. Previously only mappings
    were descended, so a document could hide arbitrarily deep structure inside
    sequences (or inside mappings placed in sequences) without being counted.

    Args:
        node: Root node of the (sub)tree, as produced by ``yaml.compose``.
        depth: Depth of ``node`` itself; callers normally leave the default ``0``.

    Returns:
        The largest depth reached below ``node``. A scalar, an empty mapping or an
        empty sequence yields ``depth`` unchanged. Only mapping values are descended,
        not mapping keys.
    """
    if isinstance(node, MappingNode):
        # ``MappingNode.value`` is a list of (key_node, value_node) pairs.
        children = (value_node for _, value_node in node.value)
    elif isinstance(node, yaml.nodes.SequenceNode):
        children = iter(node.value)
    else:
        return depth
    return max((_get_yaml_ast_depth(child, depth + 1) for child in children), default=depth)


def check_yaml_depth_before_loading(yaml_str: Union[str, bytes],
                                    max_depth: int = YAML_MAX_NESTING_DEPTH) -> None:
    """Reject YAML documents whose nesting depth exceeds ``max_depth``.

    The text is composed into a node tree with ``yaml.compose`` (no Python objects are
    constructed) and its depth is measured with :func:`_get_yaml_ast_depth`. An empty
    document is accepted.

    Args:
        yaml_str: YAML text to check.
        max_depth: Largest accepted nesting depth.

    Raises:
        ValueError: If the document exceeds ``max_depth``, fails to parse, or is so
            deeply nested (or self-referential through aliases) that analysing it
            exhausts the recursion limit.
    """
    try:
        node = yaml.compose(yaml_str)
        depth = 0 if node is None else _get_yaml_ast_depth(node)
    except yaml.YAMLError as e:
        raise ValueError(f"YAML parse error: {e}") from e
    except RecursionError as e:
        # Both the composer and the depth walk are recursive; report runaway nesting
        # through the documented exception type instead of leaking RecursionError.
        raise ValueError(
            f"YAML nesting is too deep to analyse (maximum allowed depth is {max_depth})"
        ) from e
    if depth > max_depth:
        raise ValueError(
            f"YAML nesting depth {depth} exceeds the maximum allowed value of {max_depth}"
        )
hyper_parallel/auto_parallel/sapp_ppb/utils/computation_analyzer.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# ============================================================================

"""Computation cost analyzer for pipeline balancing."""

import json
import os
import re
import sys
from itertools import chain
from typing import Dict, List, Optional

from tqdm import tqdm

from sapp_ppb.utils.logger import logger

UNSTABLE_STEPS = 2


class ComputationAnalyzer:
    """Parser & Analyzer for profiling timelines"""

    is_msprof_file = False

    def __init__(self, timeline_folder_path: str, model_name: str,
                 num_of_micro_batch: int, layer_list: Optional[dict] = None) -> None:
        """Load profiling timelines and pre-compute the per-layer cost maps.

        Args:
            timeline_folder_path: Directory that is scanned for ``.json`` timeline files.
            model_name: Model name looked up in ``cfgs/model_layers.json`` when no
                ``layer_list`` is supplied.
            num_of_micro_batch: Micro-batch count for the observed run (used for
                normalizing totals).
            layer_list: Optional pre-parsed layer metadata. When omitted or empty it is
                loaded from disk.
        """
        self.timeline_folder_path = timeline_folder_path
        self.model_name = model_name
        self.num_of_micro_batch = num_of_micro_batch
        self.counted_steps = 0
        self.step_time = 0.0
        # The step right after the unstable warm-up steps is the one that gets measured.
        self.select_step_number = UNSTABLE_STEPS + 1
        self.layer_list = layer_list if layer_list else self._get_layer_list()
        self.timeline_data = self._get_timeline_data()

        # Each stage below consumes attributes produced by the previous one.
        logger.info("parsing layer objs")
        auto_partition_objects, pre_defined_objects = self._parse_layer_objects()
        self.auto_partition_layer_objects = auto_partition_objects
        self.pre_defined_layer_objects = pre_defined_objects
        logger.info("parsing auto partition layer name")
        self.auto_partition_layer_name_list = self.parse_auto_partition_layer_name_list()
        logger.info("parse layer with computation time list")
        self.layer_with_computation_time_list = self.parse_layer_with_computation_time_list()
        logger.info("transform layer with cost list")
        self.layer_with_cost_list = self.transform_layer_with_cost_list()

    def _get_layer_list(self):
        """Return cfgs from model config file"""

        model_config_file = os.path.join(os.getcwd(), "cfgs", "model_layers.json")
        with open(model_config_file, encoding="utf-8") as json_file:
            model_layers_data = json.load(json_file)
            for layer_list in model_layers_data:
                if self.model_name in layer_list["name"]:
                    return layer_list
        logger.info("ERROR: Not found model in model config file")
        return False

    def _get_timeline_data(self):
        """Load every ``.json`` file of the timeline folder and return the parsed contents.

        ``self.is_msprof_file`` is updated per file: ``True`` for names starting with
        ``trace_view`` or ``msprof``, ``False`` otherwise (an error is logged, but the
        file is still loaded). The flag therefore reflects the last file processed.
        """
        logger.info("loading timeline data")
        loaded = []
        json_names = [name for name in os.listdir(self.timeline_folder_path)
                      if name.endswith(".json")]
        for name in json_names:
            self.is_msprof_file = name.startswith(("trace_view", "msprof"))
            if not self.is_msprof_file:
                logger.error("ERROR: Not support timeline file type")
            full_path = os.path.join(self.timeline_folder_path, name)
            with open(full_path, encoding="utf-8") as handle:
                loaded.append(json.load(handle))
        return loaded

    def _parse_step_duration(self, timeline_data):
        """Return timeline objects during a training step."""

        op_name = ""
        step_start = 0.0
        step_end = 0.0
        cpt = 0
        for obj in timeline_data:
            if "MatMul-op" in obj["name"]:
                op_name = obj["name"]
                break
        for obj in timeline_data:
            if obj["name"] == op_name:
                cpt += 1
                if cpt == self.select_step_number:
                    step_start = float(obj["ts"])
                if cpt == (self.select_step_number + 1):
                    step_end = float(obj["ts"])
        step_time = step_end - step_start
        self.step_time = step_time
        return (step_start, step_end)

    def _load_json_data(self, file_path):
        with open(file_path, encoding="utf-8") as json_file:
            return json.load(json_file)

    def _initialize_step_duration(self, timeline_data, step_start, step_end):
        if step_start == 0 or step_end == 0:
            step_start, step_end = self._parse_step_duration(timeline_data)
        return step_start, step_end

    def _add_layer_object(self, objects_list, condition, obj):
        if condition and obj not in objects_list:
            objects_list.append(obj)

    def _is_counted(self, default_table: list, step_start, step_end, cell_object):
        """Check if cell in under forward scope"""
        if float(cell_object["ts"]) < step_start or float(cell_object["ts"]) + float(cell_object["dur"]) > step_end:
            return False

        is_counted = False
        for duration in default_table:
            start = float(cell_object["ts"])
            end = float(cell_object["ts"]) + float(cell_object["dur"])
            if start >= duration[0] and end <= duration[1]:
                is_counted = True
                break
        return is_counted

    def _forward_parser(self, timeline_data):
        """Collect the forward windows, the per-layer events and the selected step range.

        Returns:
            A tuple ``(default_durations, cell_durations, select_step_start,
            select_step_end)``: the ``(start, end)`` windows of the ``"Default"``
            events with ``tid == 0`` in the scope process, a dict mapping each
            configured layer name to the scope events whose name contains it, and the
            time range that is taken into account.

        Raises:
            ValueError: If no step marker (``"MatMul-op"`` event) is found.
        """
        logger.info("parsing forward scope")
        # The scope process id defaults to 3 unless a "Scope Layer" event names another one.
        scope_pid = next((event["pid"] for event in timeline_data
                          if event["name"] == "Scope Layer"), 3)
        default_durations = []
        cell_durations = {}
        step_range = []
        op_name = ""
        for event in tqdm(timeline_data):
            event_name = event["name"]
            # The first MatMul operator seen becomes the per-step marker.
            if op_name == "" and "MatMul-op" in event_name:
                op_name = event_name
            if event_name == op_name:
                step_range.append(float(event["ts"]))
            if event["pid"] != scope_pid:
                continue
            if event_name == "Default" and event["tid"] == 0:
                begin = float(event["ts"])
                default_durations.append((begin, begin + float(event["dur"])))
                continue
            configured_layers = chain(self.layer_list["pre_defined_layer"],
                                      self.layer_list["auto_partition_layer"])
            for layer_name in configured_layers:
                if layer_name in event_name:
                    cell_durations.setdefault(layer_name, []).append(event)

        # step times of first 2 steps are not stable
        # so we don't consider them when enough steps are given
        steps = len(step_range)
        logger.info("There are %s steps in given timeline data", steps)
        if steps == 0:
            raise ValueError("Failed to parse timeline")
        if steps == 1:
            select_step_start, select_step_end = 0.0, sys.float_info.max
        else:
            first_kept = min(steps - UNSTABLE_STEPS, UNSTABLE_STEPS)
            select_step_start, select_step_end = step_range[first_kept], step_range[-1]
        logger.info("select_step_start: %f", select_step_start)
        logger.info("select_step_end: %f", select_step_end)
        self.counted_steps = max(steps - (UNSTABLE_STEPS + 1), 1)
        logger.info("counted_steps: %f", self.counted_steps)
        return default_durations, cell_durations, select_step_start, select_step_end

    def _process_timelines(self, timeline_data, step_start, step_end, pre_defined_layer_objects,
                           auto_partition_layer_objects):
        """Collect the counted layer events of one timeline into the two output lists.

        Args:
            timeline_data: Events of a single timeline file.
            step_start: Step start carried over from the previous timeline (``0`` if unknown).
            step_end: Step end carried over from the previous timeline (``0`` if unknown).
            pre_defined_layer_objects: List extended in place with events of pre-defined layers.
            auto_partition_layer_objects: List extended in place with events of
                auto-partition layers.

        Returns:
            The ``(step_start, step_end)`` range selected by :meth:`_forward_parser`.
        """
        logger.info("processing timeline. %d objects in it.", len(timeline_data))
        if not self.is_msprof_file:
            # Run for its side effect of setting ``self.step_time``; the bounds are
            # replaced by the forward parser below. The current timeline must be passed
            # here: ``self.timeline_data`` is the list of all timelines, whose elements
            # are not events.
            step_start, step_end = self._initialize_step_duration(timeline_data, step_start, step_end)
        default_durations, cell_durations, step_start, step_end = self._forward_parser(timeline_data)
        for cell_name, cell_objs in cell_durations.items():
            for obj in cell_objs:
                if not self._is_counted(default_durations, step_start, step_end, obj):
                    continue
                for layer_name in self.layer_list["pre_defined_layer"]:
                    if layer_name in cell_name:
                        self._add_layer_object(pre_defined_layer_objects, True, obj)
                for layer_name in self.layer_list["auto_partition_layer"]:
                    if layer_name in cell_name:
                        self._add_layer_object(auto_partition_layer_objects, True, obj)
        return step_start, step_end

    def _parse_layer_objects(self):
        auto_partition_layer_objects = []
        pre_defined_layer_objects = []
        step_start, step_end = 0, 0
        for timeline in self.timeline_data:
            step_start, step_end = self._process_timelines(timeline, step_start, step_end,
                                                           pre_defined_layer_objects,
                                                           auto_partition_layer_objects)
        return auto_partition_layer_objects, pre_defined_layer_objects

    def parse_auto_partition_layer_name_list(self) -> List[str]:
        """example: [42-TransformerEncoderLayer,43-TransformerEncoderLayer]"""
        auto_partition_layer_name_list = []
        for auto_partition_name in self.layer_list["auto_partition_layer"]:
            for obj in [item for timeline in self.timeline_data for item in timeline]:
                object_name = obj["name"]
                if auto_partition_name in object_name:
                    if self.is_msprof_file:
                        find_layer_name = re.findall(r"[0-9]*-" + auto_partition_name,
                                                     object_name)
                        layer_name = find_layer_name[0]
                    else:
                        layer_name = object_name
                    if layer_name not in auto_partition_layer_name_list:
                        auto_partition_layer_name_list.append(layer_name)
        return auto_partition_layer_name_list

    def parse_layer_with_computation_time_list(self) -> Dict[str, float]:
        """
        Map each layer_name with its duration time.
        For example: [46-TransformerEncoderLayer':37.24729124999999, '47-TransformerEncoderLayer': 37.36572429687501]
        """
        layer_with_computation_time_list = {}
        for pre_defined_layer_name in self.layer_list["pre_defined_layer"]:
            layer_with_computation_time_list[pre_defined_layer_name] = 0
        for auto_partition_layer_name in self.auto_partition_layer_name_list:
            layer_with_computation_time_list[auto_partition_layer_name] = 0

        for obj in self.pre_defined_layer_objects:
            for pre_defined_layer_name in self.layer_list["pre_defined_layer"]:
                if pre_defined_layer_name in obj["name"]:
                    layer_with_computation_time_list[pre_defined_layer_name] += (float(obj["dur"]) / 1000)
        for obj in self.auto_partition_layer_objects:
            for auto_partition_layer_name in self.auto_partition_layer_name_list:
                # if auto_partition_layer_name in obj["name"]:
                if re.search(r"\b" + re.escape(auto_partition_layer_name), obj["name"]):
                    layer_with_computation_time_list[auto_partition_layer_name] += (
                        float(obj["dur"]) / 1000)
        return layer_with_computation_time_list

    def transform_layer_with_cost_list(self) -> Dict[str, float]:
        """calculating the average value of layer cost"""
        total_cost_auto_partition_layer = {}
        number_of_auto_partition_layer = {}
        transform_layer_with_cost_list = {}
        for pre_defined_layer_name in self.layer_list["pre_defined_layer"]:
            transform_layer_with_cost_list[pre_defined_layer_name] = 0

        for auto_partition_layer_name in self.layer_list["auto_partition_layer"]:
            total_cost_auto_partition_layer[auto_partition_layer_name] = 0
            number_of_auto_partition_layer[auto_partition_layer_name] = 0
            transform_layer_with_cost_list[auto_partition_layer_name] = 0

        # test
        for layer_name, layer_time in self.layer_with_computation_time_list.items():
            for pre_defined_layer_name in self.layer_list["pre_defined_layer"]:
                if pre_defined_layer_name in layer_name:
                    var_tmp = layer_time / self.counted_steps / self.num_of_micro_batch
                    transform_layer_with_cost_list[pre_defined_layer_name] += var_tmp
            for auto_partition_layer_name in self.layer_list["auto_partition_layer"]:
                if auto_partition_layer_name in layer_name:
                    # assuming that the duration time of a layer can not exceed 10% of step time
                    # in order to
                    # avoid some specific long time layers that caused from the error caused by
                    # time_line.json
                    if not self.is_msprof_file and layer_time > self.step_time / 1000 / 10:
                        continue
                    total_cost_auto_partition_layer[auto_partition_layer_name] += \
                        layer_time
                    number_of_auto_partition_layer[auto_partition_layer_name] += 1

        for auto_partition_layer_name in self.layer_list["auto_partition_layer"]:
            transform_layer_with_cost_list[auto_partition_layer_name] = (
                total_cost_auto_partition_layer[auto_partition_layer_name] /
                number_of_auto_partition_layer[auto_partition_layer_name] /
                self.counted_steps / self.num_of_micro_batch)
        return transform_layer_with_cost_list


if __name__ == "__main__":
    # Manual example run: ``path`` is a placeholder and must point to a folder that
    # contains the profiling timeline ``.json`` files.
    path = "/your/path/here"
    example_model_name = "LLaMA_prof"
    # The third argument is ``num_of_micro_batch``.
    comp1 = ComputationAnalyzer(path, example_model_name, 8)
    logger.info(comp1.layer_with_computation_time_list)
    logger.info(comp1.layer_with_cost_list)
hyper_parallel/auto_parallel/sapp_ppb/utils/compute_memory.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Derive per-layer memory parameters from a set of dry-run stage observations."""
import numpy as np

import sapp_ppb.utils.recompute as Recompute
from sapp_ppb.utils.layer import Layer
from sapp_ppb.utils.logger import logger
from sapp_ppb.utils.stage import Stage, filter_stage_id


class ComputeMemory:
    """
    ComputeMemory class to compute the different memories with stages information running (dry) log

    stage{A|B} means stage with different configuration A and B
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
    memory_head_ (float): memory required to run the head layer
    memory_tail_ (float): memory required to run the tail layer
    """

    number_of_stage_: int
    stages_a: list[Stage]
    stages_b: list[Stage]
    memory_parameter_: float
    memory_activation_rec_: dict[Recompute.TYPE, float]
    recompute_considered_: dict[Recompute.TYPE, bool]
    memory_const_: float
    memory_head_: float
    memory_tail_: float

    def __init__(self, number_of_stage: int, stages_a: list[Stage] = None,
                 stages_b: list[Stage] = None) -> None:
        """Build a :class:`ComputeMemory` solver instance.

        All derived memory values start as ``None``; only the set of recompute types
        to consider is determined right away from ``stages_a``.

        Args:
            number_of_stage: Total number of pipeline stages in the target LLM.
            stages_a: Dry-run observations with configuration A (at least stages ``0, i, j, n-1``).
            stages_b: Dry-run observations with configuration B (must differ from A).
        """
        # number_of_stage != len(stages) can be true
        self.number_of_stage_ = number_of_stage
        self.set_stages_a(stages_a)
        self.set_stages_b(stages_b)

        self.memory_parameter_ = None
        self.memory_const_ = None
        self.memory_head_ = None
        self.memory_tail_ = None
        self.memory_activation_rec_ = dict.fromkeys(Recompute.TYPE)
        self.find_recompute_considered()

    def set_stages_a(self, stages: list[Stage]) -> None:
        """Assign dry-run observations to configuration A after a consistency check."""
        if stages is None:
            self.stages_a = []
            return
        for stage1 in stages:
            for stage2 in stages:
                if not stage1.same_global_config(stage2):
                    logger.error(
                        "Cannot set stagesA, all elements don't have the same configuration",)
                    self.stages_a = []
                    return
        self.stages_a = stages

    def set_stages_b(self, stages: list[Stage]) -> None:
        """Assign dry-run observations to configuration B (must differ from A)."""
        if stages is None:
            self.stages_b = []
            return
        for stage1 in stages:
            for stage2 in stages:
                if not stage1.same_global_config(stage2):
                    logger.error(
                        "Cannot set stagesB, all elements don't have the same configuration")
                    self.stages_b = []
                    return
            for stage_a in self.stages_b:
                if stage1.same_global_config(stage_a):
                    logger.error(
                        "Cannot set stagesB, an elements have the same configuration than stagesA")
                    self.stages_b = []
                    return
        self.stages_b = stages

    def find_recompute_considered(self) -> None:
        """Populate :attr:`recompute_considered_` from the observed ``stages_a`` data.

        ``Recompute.TYPE.NONE`` is always considered; any other type is considered as
        soon as one stage of ``stages_a`` reports a positive layer count for it.
        """
        self.recompute_considered_ = dict.fromkeys(Recompute.TYPE, False)
        self.recompute_considered_[Recompute.TYPE.NONE] = True

        for stage in self.stages_a:
            layers_per_type = stage.nb_layer_rec_
            for rec in Recompute.TYPE:
                if layers_per_type[rec] > 0:
                    self.recompute_considered_[rec] = True

    def _compute_memory_parameter_local_(self, stage1: Stage, stage2: Stage) -> float:
        """
        Given 2 stages information with the same configuration, and different id,
        Compute the memory_parameter
        """
        if stage1.same_config(stage2):
            if stage1.id_ != stage2.id_:
                res = stage1.memory_usage_ * (stage1.nb_stage_ - stage1.id_)
                res -= stage2.memory_usage_ * (stage2.nb_stage_ - stage2.id_)
                res /= stage1.id_ - stage2.id_
                res = abs(res)
                res /= stage1.nb_layer_
                return res
            logger.error(
                "stage with same characteristic, BUT SAME ID too, cannot compute memory_parameter")
            return 0
        logger.error("stage with different characteristic, cannot compute memory_parameter")
        return 0

    def _compute_memory_parameter_(self, multi_run=False) -> float:
        """Compute memory_parameter
            With all available stages compute all combinations of memory parameter
            and return the mean of all the memory_parameter found
        BEWARE: can update memory_parameter_ & memory_activation_rec_
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
        BEWARE: can update memory_parameter_ & memory_activation_rec_
                because of _compute_memories_layers_()
        return: memory_parameter
        """
        if multi_run or (len(self.stages_a) < 5 and len(self.stages_b) < 5):
            memory_parameter_list = []
            for stage1 in self.stages_a:
                if stage1.id_ not in [0, (self.number_of_stage_ - 1)]:
                    mem_param = self._compute_memory_parameter_local_(stage1, stage2)
                    for stage2 in self.stages_a:
                        if (stage2.id_ not in [0, (self.number_of_stage_ - 1),
                                               stage1.id_] and mem_param != 0):
                            memory_parameter_list.append(mem_param)
            for stage1 in self.stages_b:
                if stage1.id_ not in [0, (self.number_of_stage_ - 1)]:
                    for stage2 in self.stages_b:
                        mem_param = self._compute_memory_parameter_local_(stage1, stage2)
                        if (stage2.id_ not in [0, (self.number_of_stage_ - 1),
                                               stage1.id_] and mem_param != 0):
                            memory_parameter_list.append(mem_param)
            return np.mean(memory_parameter_list)
        if self._compute_memories_layers_():
            return self.memory_parameter_
        logger.error("Issue with _compute_memory_parameter_!!!")
        return 0

    def _compute_memory_activation_(self, rec, multi_run=False) -> float:
        """
        Compute memory_activation for a given recomputation type
        return: memory_activation
        """
        if multi_run or (len(self.stages_a) < 5 and len(self.stages_b) < 5):
            # look at solution 4 stages
            logger.error("Not implemented yet!!!")
            return 0
        if self._compute_memories_layers_():
            return self.memory_activation_rec_[rec]
        logger.error("Issue with _compute_memory_activation_!!!")
        return 0

    def zero_offset(self) -> bool:
        """Tell whether all stages of ``stages_a`` host as many layers as the first one."""
        reference = self.stages_a[0].nb_layer_
        return all(stage.nb_layer_ == reference for stage in self.stages_a)

    def _compute_memories_layers_(self) -> bool:
        """Check if enough stages are provided, then compute the layer memories.

        The number of stages in ``stages_a`` is compared with the number of
        recomputation types in use (``Recompute.get_used_list``):

        * exactly 3 more stages: solve without the constant memory component;
        * 4 or more extra stages: solve with the constant memory component, which
          is refused when ``zero_offset()`` reports that every stage hosts the
          same number of layers;
        * anything else: an error is logged.

        Returns:
            The result of ``_compute_memories_layer_bodies_`` when it is called,
            ``False`` otherwise.
        """
        used_rec_num = len(Recompute.get_used_list(self.recompute_considered_))
        stage_num = len(self.stages_a)

        if stage_num == used_rec_num + 3:
            return self._compute_memories_layer_bodies_(False)

        if stage_num >= used_rec_num + 4:
            logger.info("Enabled const memory component because enough stages were given")
            if self.zero_offset():
                logger.error(
                    "The number of layer per stage cannot be the same for all stages "
                    "when const component is enabled. Some offset must be used"
                )
                return False
            return self._compute_memories_layer_bodies_(True)

        # Trailing space added: the two literals are concatenated by the parser.
        logger.error(
            "%s stages found and (%s) recomputation considered "
            "is not coherent. There should be 3 or 4 more stages than recomputation considered",
            stage_num,
            used_rec_num,
        )
        return False

    def _compute_memories_layer_bodies_local_(
            self, unused_rec: list[Recompute.TYPE],
            stages: list[Stage]) -> tuple:
        """Compute memory_parameter & memory activation for all recomputation types
        Require at least 3 Stages different from first and last stage

        One linear equation is built per stage that is neither the first nor the
        last pipeline stage: the coefficients come from
        ``stage.get_index_memory_var()`` with the columns of the unused
        recomputation types removed, the right-hand side is
        ``stage.memory_usage_``. The system is solved with ``np.linalg.solve``.

        Args:
            unused_rec: recomputation types whose column must be dropped.
                The list is not modified.
            stages: observed stages (first/last stage entries are ignored).

        Returns:
            ``(memory_param, memory_act_rec)`` where ``memory_act_rec`` is whatever
            ``Recompute.assign_used`` builds from the remaining unknowns.
        """
        # Work on a sorted copy so the caller's list is left untouched.
        # Highest index first, so one removal does not shift the next one.
        drop_order = sorted(unused_rec, reverse=True)
        last_stage_id = self.number_of_stage_ - 1

        variable_factor_list = []
        constant_memory_list = []
        for stage in stages:
            if stage.id_ in (0, last_stage_id):
                continue
            # Copy the coefficients before editing them, in case the stage
            # hands out a list it still owns.
            row = list(stage.get_index_memory_var())
            for rec_i in drop_order:
                row.pop(1 + rec_i)
            variable_factor_list.append(row)
            constant_memory_list.append(stage.memory_usage_)

        solution = list(
            np.linalg.solve(np.array(variable_factor_list),
                            np.array(constant_memory_list)))
        memory_param = solution.pop(0)
        memory_act_rec = Recompute.assign_used(solution, drop_order)
        return (memory_param, memory_act_rec)



    def _compute_memories_layer_bodies_local_with_fix_(
            self, unused_rec: list[Recompute.TYPE],
            stages: list[Stage]) -> tuple:
        """Compute memory_const, memory_parameter & memory activation for all recomputation types
        Require at least 4 Stages different from first and last stage

        Same linear system as the variant without constant component, with a
        leading coefficient of 1 on every row for the constant memory term.
        ``np.linalg.solve`` is used when ``len(stages)`` equals the number of used
        recomputation types plus 4; with more stages a least-squares fit
        (``np.linalg.lstsq``) is used instead.

        Args:
            unused_rec: recomputation types whose column must be dropped.
                The list is not modified.
            stages: observed stages (first/last stage entries are ignored).

        Returns:
            ``(memory_const, memory_param, memory_act_rec)`` where
            ``memory_act_rec`` is whatever ``Recompute.assign_used`` builds from
            the remaining unknowns.

        Raises:
            ValueError: if fewer stages than needed were given.
        """
        # Work on a sorted copy so the caller's list is left untouched.
        # Highest index first, so one removal does not shift the next one.
        drop_order = sorted(unused_rec, reverse=True)
        last_stage_id = self.number_of_stage_ - 1

        variable_factor_list = []
        constant_memory_list = []
        for stage in stages:
            if stage.id_ in (0, last_stage_id):
                continue
            # Leading 1 is the coefficient of the constant memory component.
            row = [1] + stage.get_index_memory_var()
            for rec_i in drop_order:
                row.pop(2 + rec_i)
            variable_factor_list.append(row)
            constant_memory_list.append(stage.memory_usage_)

        factors = np.array(variable_factor_list)
        memories = np.array(constant_memory_list)
        logger.debug(
            "solve(\n %s, \n %s) ",
            factors,
            memories,
        )
        used_rec_num = len(Recompute.get_used_list(self.recompute_considered_))

        if len(stages) < used_rec_num + 4:
            raise ValueError("Stages given are not enough to solve memory constraints")
        if len(stages) == used_rec_num + 4:
            solution = list(np.linalg.solve(factors, memories))
        else:
            logger.warning("Stages given are more than needed, switch to least squares method")
            solution = list(np.linalg.lstsq(factors, memories, rcond=None)[0])

        memory_const = solution.pop(0)
        memory_param = solution.pop(0)
        memory_act_rec = Recompute.assign_used(solution, drop_order)
        return (memory_const, memory_param, memory_act_rec)

    def _compute_memories_layer_bodies_(self, with_fix: bool) -> bool:
        """
        Compute memory_parameter, memory_recompute, memory_activation
        Require at least 3 Stages different from first and last stage
        BEWARE: can update memory_parameter_, memory_recompute_, memory_activation_

        ``stages_a`` and ``stages_b`` are each solved only when they hold at least
        5 stages; the results are then merged by ``_average_if_needed_fix`` (when
        ``with_fix`` is set, i.e. with the constant memory component) or by
        ``_average_if_needed``.

        return True if success to update memory_parameter_, memory_recompute_, memory_activation_
        """
        memory_const_a = None
        memory_parameter_a = None
        memory_recompute_a = {r: None for r in Recompute.TYPE}

        memory_const_b = None
        memory_parameter_b = None
        memory_recompute_b = {r: None for r in Recompute.TYPE}

        unused_rec = Recompute.get_unused_list(self.recompute_considered_)
        logger.info("unused recomputation: %s", unused_rec)

        if with_fix:
            if len(self.stages_a) >= 5:
                (memory_const_a,
                 memory_parameter_a,
                 memory_recompute_a) = self._compute_memories_layer_bodies_local_with_fix_(
                     unused_rec, self.stages_a)
            if len(self.stages_b) >= 5:
                (memory_const_b,
                 memory_parameter_b,
                 memory_recompute_b) = self._compute_memories_layer_bodies_local_with_fix_(
                     unused_rec, self.stages_b)

            return self._average_if_needed_fix(
                memory_const_a,
                memory_parameter_a,
                memory_recompute_a,
                memory_const_b,
                memory_parameter_b,
                memory_recompute_b,
            )

        if len(self.stages_a) >= 5:
            (memory_parameter_a,
             memory_recompute_a) = self._compute_memories_layer_bodies_local_(
                 unused_rec, self.stages_a)
        if len(self.stages_b) >= 5:
            (memory_parameter_b,
             memory_recompute_b) = self._compute_memories_layer_bodies_local_(
                 unused_rec, self.stages_b)

        return self._average_if_needed(
            memory_parameter_a,
            memory_recompute_a,
            memory_parameter_b,
            memory_recompute_b,
        )

    def _average_if_needed_fix(
            self,
            memory_const_a,
            memory_parameter_a,
            memory_recompute_a,
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
            memory_parameter_b,
            memory_recompute_b,
    ):
        """check if average is needed"""
        if memory_parameter_a is not None and memory_parameter_a != 0:
            if memory_parameter_b is not None and memory_parameter_b != 0:
                self.memory_const_ = (memory_const_a +
                                      memory_const_b) / 2
                self.memory_parameter_ = (memory_parameter_a +
                                          memory_parameter_b) / 2
                Recompute.average([memory_recompute_a, memory_recompute_b])
            else:
                self.memory_const_ = memory_const_a
                self.memory_parameter_ = memory_parameter_a
                self.memory_activation_rec_ = memory_recompute_a

        elif memory_parameter_b is not None and memory_parameter_b != 0:
            self.memory_const_ = memory_const_b
            self.memory_parameter_ = memory_parameter_b
            self.memory_activation_rec_ = memory_recompute_b
        else:
            logger.error("failed to compute memories")
            return False
        return True

    def _average_if_needed(self, memory_parameter_a, memory_recompute_a, memory_parameter_b,
                           memory_recompute_b,):
        """check if average is needed"""
        if memory_parameter_a is not None and memory_parameter_a != 0:
            if memory_parameter_b is not None and memory_parameter_b != 0:
                self.memory_parameter_ = (memory_parameter_a + memory_parameter_b) / 2
                Recompute.average([memory_recompute_a, memory_recompute_b])
            else:
                self.memory_parameter_ = memory_parameter_a
                self.memory_activation_rec_ = memory_recompute_a

        elif memory_parameter_b is not None and memory_parameter_b != 0:
            self.memory_parameter_ = memory_parameter_b
            self.memory_activation_rec_ = memory_recompute_b
        else:
            logger.error("failed to compute memories")
            return False
        return True

    def _compute_memory_head_(self) -> float:
        """Estimate the memory attributed to the head.

        For every stage with id 0 found in ``stages_a`` and ``stages_b``, the
        activation memory of the considered recomputation types (scaled by
        ``number_of_stage_``) and the parameter memory of the hosted layers are
        subtracted from the stage memory usage; the mean of what is left is
        returned.
        """
        first_stages = filter_stage_id(self.stages_a, 0)
        first_stages += filter_stage_id(self.stages_b, 0)
        per_layer_parameter = self.get_memory_parameter()

        leftovers = []
        for stage in first_stages:
            remaining = stage.memory_usage_
            for rec in Recompute.TYPE:
                if self.recompute_considered_[rec] is not True:
                    continue
                activation = self.get_memory_activation(rec)
                remaining -= (stage.nb_layer_rec_[rec] * activation * self.number_of_stage_)
            remaining -= (stage.nb_layer_) * per_layer_parameter
            leftovers.append(remaining)
        return np.mean(leftovers)

    def _compute_memory_tail_(self) -> float:
        """Estimate the memory attributed to the tail.

        For every stage with id ``number_of_stage_ - 1`` found in ``stages_a``
        and ``stages_b``, the activation memory of the considered recomputation
        types and the parameter memory of the hosted layers are subtracted from
        the stage memory usage; the mean of what is left is returned.
        """
        last_id = self.number_of_stage_ - 1
        last_stages = filter_stage_id(self.stages_a, last_id)
        last_stages += filter_stage_id(self.stages_b, self.number_of_stage_ - 1)

        leftovers = []
        for stage in last_stages:
            remaining = stage.memory_usage_
            for rec in Recompute.TYPE:
                if self.recompute_considered_[rec] is not True:
                    continue
                activation = self.get_memory_activation(rec)
                # Unlike the head, the activation term is scaled by 1 here.
                remaining -= (stage.nb_layer_rec_[rec] * activation * 1)
            remaining -= (stage.nb_layer_) * self.get_memory_parameter()
            leftovers.append(remaining)
        return np.mean(leftovers)

    def get_memory_const(self) -> float:
        """Give back the stored ``memory_const_`` value; nothing is computed here."""
        const_component = self.memory_const_
        return const_component

    def get_memory_parameter(self, force_recompute: bool = False) -> float:
        """Give back the cached per-body-layer parameter memory.

        ``_compute_memory_parameter_`` is run first when ``force_recompute`` is
        set or when nothing has been cached yet.
        """
        cache_usable = not force_recompute and self.memory_parameter_ is not None
        if not cache_usable:
            self.memory_parameter_ = self._compute_memory_parameter_()
        return self.memory_parameter_

    def get_memory_activation(self, rec: Recompute.TYPE,
                              force_recompute: bool = False) -> float:
        """Give back the cached activation memory of recomputation type ``rec``.

        ``_compute_memory_activation_`` is run first when ``force_recompute`` is
        set or when no value is cached for ``rec``.
        """
        cache_usable = not force_recompute and self.memory_activation_rec_[rec] is not None
        if not cache_usable:
            self.memory_activation_rec_[rec] = self._compute_memory_activation_(rec)
        return self.memory_activation_rec_[rec]

    def get_memory_head(self, force_recompute: bool = False) -> float:
        """Give back the cached HEAD memory.

        ``_compute_memory_head_`` is run first when ``force_recompute`` is set or
        when nothing has been cached yet.
        """
        cache_usable = not force_recompute and self.memory_head_ is not None
        if not cache_usable:
            self.memory_head_ = self._compute_memory_head_()
        return self.memory_head_

    def get_memory_tail(self, force_recompute: bool = False) -> float:
        """Give back the cached TAIL memory.

        ``_compute_memory_tail_`` is run first when ``force_recompute`` is set or
        when nothing has been cached yet.
        """
        cache_usable = not force_recompute and self.memory_tail_ is not None
        if not cache_usable:
            self.memory_tail_ = self._compute_memory_tail_()
        return self.memory_tail_


def compute_memories(layers: list[Layer], memory_folder: str, number_of_stage: int) -> list[Layer]:
    """Fill the memory fields of ``layers`` from a ``ComputeMemory`` estimator.

    HEAD and TAIL layers receive their ``memory_parameter_``; BODY layers
    receive ``memory_parameter_`` plus one ``memory_activation_rec_`` entry per
    recomputation type. Layers of any other type are left untouched. The same
    list object is returned.

    The estimator is created without any stage observation, and the path
    ``memory_folder + ""`` is only opened, never read.
    """
    filename = ""
    # Put some meta information in a predefine .json file like layers info?
    with open(memory_folder + filename, encoding="utf-8"):
        pass
    estimator = ComputeMemory(number_of_stage=number_of_stage, stages_a=[], stages_b=[])

    for layer in layers:
        if layer.type_ == Layer.type_enum.HEAD:
            layer.memory_parameter_ = estimator.get_memory_head()
            continue
        if layer.type_ == Layer.type_enum.TAIL:
            layer.memory_parameter_ = estimator.get_memory_tail()
            continue
        if layer.type_ != Layer.type_enum.BODY:
            continue
        layer.memory_parameter_ = estimator.get_memory_parameter()
        for rec in Recompute.TYPE:
            layer.memory_activation_rec_[rec] = estimator.get_memory_activation(rec)
    return layers
hyper_parallel/auto_parallel/sapp_ppb/utils/config.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Layer-description JSON generator and YAML-based dry-run configuration parser."""
import json
import os
import random
from dataclasses import asdict, dataclass
from math import ceil
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import yaml

import sapp_ppb.utils.recompute as Recompute
from sapp_ppb.sapp.sapp_solver import SappSolver
from sapp_ppb.utils.check_rules import check_yaml_depth_before_loading
from sapp_ppb.utils.computation_analyzer import ComputationAnalyzer
from sapp_ppb.utils.compute_memory import ComputeMemory, Stage
from sapp_ppb.utils.layer import Layer
from sapp_ppb.utils.logger import logger

random.seed()


@dataclass
class LayersDescription:
    """Description of one layer group of a model, as written to the layer JSON.

    The hand-written ``__init__`` is kept by ``@dataclass`` (an explicitly
    defined ``__init__`` is never replaced); the decorator still supplies
    ``__repr__``/``__eq__`` and makes ``dataclasses.asdict`` usable.
    """

    name: str
    # Holds the enum member *name* (a string), not the enum member itself.
    type: str
    model_name: str
    time: int
    nb_layer: int
    memory_parameter: int

    def __init__(
            self, layer_type: Layer.type_enum, time: int, nb_layer: int, model_name: str
    ) -> None:
        """Store layer metadata using the enum ``layer_type.name`` for both ``type`` and ``name``.

        Args:
            layer_type: enum member whose ``name`` identifies the layer group.
            time: time value recorded for the layer group.
            nb_layer: number of layers in the group.
            model_name: name of the model the layer group belongs to.
        """
        self.type = layer_type.name
        self.time = time
        self.name = layer_type.name
        self.nb_layer = nb_layer
        self.model_name = model_name
        # Every declared field must exist, otherwise repr(), == and asdict()
        # raise AttributeError until someone assigns the value from outside.
        self.memory_parameter = 0


@dataclass
class ModelInfo:
    """Name, per-stage constant memory and HEAD/BODY/TAIL descriptions of a model."""

    name: str
    stage_const_mem: int
    layers_description: list[LayersDescription]

    def __init__(self, model_name: str, head_time: int, body_time: int,
                 tail_time: int, nb_layer: int) -> None:
        """Create the HEAD, BODY and TAIL descriptions; the stage-constant memory starts at 0.

        HEAD and TAIL are recorded with a single layer, BODY with ``nb_layer`` layers.
        """
        self.name = model_name
        self.stage_const_mem = 0
        layer_specs = (
            (Layer.type_enum.HEAD, head_time, 1),
            (Layer.type_enum.BODY, body_time, nb_layer),
            (Layer.type_enum.TAIL, tail_time, 1),
        )
        self.layers_description = [
            LayersDescription(kind, duration, count, model_name)
            for kind, duration, count in layer_specs
        ]

    def get_layer_by_type(self, layer_type: Layer.type_enum) -> Optional[LayersDescription]:
        """Look up the first description whose ``type`` equals ``layer_type.name`` (``None`` if none)."""
        matches = (
            description for description in self.layers_description
            if description.type == layer_type.name
        )
        return next(matches, None)

    def set_stage_const_mem(self, mem_const: int) -> None:
        """Replace the per-stage constant memory component with ``mem_const``."""
        self.stage_const_mem = mem_const

    def layer_memory_update(self, mem_act: Dict[Recompute.TYPE, int], mem_par: int,
                            mem_head: int, mem_tail: int) -> None:
        """Record the memory values and rebuild the ``to_json_`` payload.

        HEAD, TAIL and BODY descriptions get their ``memory_parameter``; the
        dict form of this object is then extended, for BODY entries only, with
        one ``Recompute.JSON_MEMORY_NAME`` key per recomputation type that has a
        value in ``mem_act``.
        """
        assignments = (
            (Layer.type_enum.HEAD, mem_head),
            (Layer.type_enum.TAIL, mem_tail),
            (Layer.type_enum.BODY, mem_par),
        )
        for kind, value in assignments:
            self.get_layer_by_type(kind).memory_parameter = value

        payload = asdict(self)
        body_entries = [
            entry for entry in payload["layers_description"]
            if entry["type"] == Layer.type_enum.BODY.name
        ]
        for rec in Recompute.TYPE:
            rec_mem = mem_act.get(rec)
            if rec_mem is not None:
                for entry in body_entries:
                    entry[Recompute.JSON_MEMORY_NAME[rec]] = rec_mem
        self.to_json_ = payload

    def dump_json(self, file_name: str) -> None:
        """Serialise ``to_json_`` into ``file_name`` as JSON indented by 4 spaces."""
        with open(file_name, "w", encoding="utf-8") as target:
            json.dump(self.to_json_, target, indent=4)


def time_parser(file_name: str, model_name: str) -> Tuple[float, float, float]:
    """Parse the HEAD/BODY/TAIL timing values from a SAPP-PPB YAML config file.

    A ``time_config`` section providing ``head``, ``body`` and ``tail`` is
    returned as is. Otherwise, when a non-empty ``profiling_config`` section
    exists, the per-layer costs reported by ``ComputationAnalyzer`` are added to
    the head, tail or body total depending on the layer lists of that section.

    Args:
        file_name: path of the YAML file (must end with ``yaml`` or ``yml``).
        model_name: model name handed over to ``ComputationAnalyzer``.

    Returns:
        ``(head_time, body_time, tail_time)``.

    Raises:
        ValueError: if ``file_name`` is ``None`` or is not a YAML file name.
    """
    if file_name is None:
        logger.error("input file cannot be none")
        raise ValueError("input file cannot be none")

    if not file_name.endswith(("yaml", "yml")):
        logger.error("Only accept yaml as input format")
        raise ValueError(f"Only accept yaml as input format. not {file_name}")

    filepath = os.path.realpath(file_name)
    with open(filepath, encoding="utf-8") as fp:
        check_yaml_depth_before_loading(fp)
        fp.seek(0)
        cfg_dict = yaml.safe_load(fp)

    head_time = 0
    body_time = 0
    tail_time = 0

    if "time_config" in cfg_dict:
        time_config = cfg_dict["time_config"]
        if all(key in time_config for key in ("head", "body", "tail")):
            return time_config["head"], time_config["body"], time_config["tail"]
        # Partial section: a missing (or empty) entry must stay at 0, not become
        # None, otherwise the accumulation below fails with a TypeError.
        head_time = time_config.get("head") or 0
        body_time = time_config.get("body") or 0
        tail_time = time_config.get("tail") or 0

    if cfg_dict.get("profiling_config"):
        profiling_config = cfg_dict["profiling_config"]
        head_layers = profiling_config.get("head_layers", ["LlamaEmbedding"])
        body_layers = profiling_config.get("body_layers", ["LLamaDecodeLayer"])
        tail_layers = profiling_config.get("tail_layers", ["lm_head-Linear", "LlamaRMSNorm"])
        # A single layer name may be given as a plain string.
        if isinstance(head_layers, str):
            head_layers = [head_layers]
        if isinstance(tail_layers, str):
            tail_layers = [tail_layers]
        if isinstance(body_layers, str):
            body_layers = [body_layers]

        num_layer = cfg_dict["pipeline_config"]["num_layer"]
        micro_batch_num = profiling_config["micro_batch_num"]
        timeline_folder_path = profiling_config["folder_path"]

        # Pre-defined layers are tagged 0 (head) or -1 (tail); a name present in
        # both lists ends up tagged as tail.
        pre_defined = {layer: 0 for layer in head_layers}
        pre_defined.update({layer: -1 for layer in tail_layers})
        layer_list = {
            "pre_defined_layer": pre_defined,
            "auto_partition_layer": {layer: num_layer for layer in body_layers},
        }
        analyzer = ComputationAnalyzer(timeline_folder_path, model_name, micro_batch_num, layer_list)
        cost_list = analyzer.layer_with_cost_list
        logger.info(cost_list)
        for layer, layer_time in cost_list.items():
            tag = layer_list["pre_defined_layer"].get(layer)
            if tag == 0:
                head_time += layer_time
            elif tag == -1:
                tail_time += layer_time
            else:
                body_time += layer_time

    logger.info("head_time: %s, body_time: %s, tail_time: %s", head_time, body_time, tail_time)

    return head_time, body_time, tail_time


def process_offset(offset: Union[int, List[int], List[List[int]]],
                   pipeline_num: int) -> Tuple[Union[List[int], List[List[int]]], int]:
    """Normalise the YAML ``offset`` field into a list (or list of lists) and return the rounds.

    Accepted inputs:

    * ``0``: expanded to ``pipeline_num`` zeros;
    * a list of ``pipeline_num`` items: returned unchanged;
    * a list containing at least one list: every item must be ``0`` (expanded to
      zeros) or a list of ``pipeline_num`` items; the number of items is the
      number of rounds.

    Args:
        offset: raw value of the ``offset`` field.
        pipeline_num: number of pipeline stages.

    Returns:
        ``(offset, rounds)`` where ``rounds`` is 1 unless a list of lists was given.

    Raises:
        ValueError: if ``offset`` matches none of the accepted formats.
    """
    rounds = 1
    if isinstance(offset, int) and offset == 0:
        offset = [0] * pipeline_num
    # if offset is list of lists (usually when pp=4)
    elif isinstance(offset, list) and any(isinstance(item, list) for item in offset):
        tmp_offset = []
        for item in offset:
            if isinstance(item, int) and item == 0:
                tmp_offset.append([0] * pipeline_num)
            elif isinstance(item, list) and len(item) == pipeline_num:
                tmp_offset.append(item)
            else:
                # One message string: a stray comma used to hand ValueError two
                # arguments, so the error was rendered as a tuple.
                raise ValueError(
                    f"Unsupported input format offset: {item}, "
                    "please check the length of your offset list and the pipeline number")
        offset = tmp_offset
        rounds = len(offset)
    elif not (isinstance(offset, list) and len(offset) == pipeline_num):
        raise ValueError(
            f"Unsupported input format offset: {offset}, "
            "please check the length of your offset list and the pipeline number")

    return offset, rounds


def process_rec_config(
        layer_per_stage: int, pipeline_num: int, offset: List[int],
        rec_config: Optional[Union[bool, List[int]]]
) -> Optional[List[List[int]]]:
    """Normalise a per-recomputation-type YAML entry into a two-dimensional list.

    Args:
        layer_per_stage: base number of layers per stage.
        pipeline_num: number of pipeline stages.
        offset: per-stage layer offset.
        rec_config: ``True`` (all layers of every stage, i.e.
            ``layer_per_stage`` plus the stage offset), ``False`` (no layer), or
            an explicit list with one value per stage.

    Returns:
        A list holding the single per-stage list, or ``None`` when ``rec_config``
        or ``offset`` is ``None``.

    Raises:
        ValueError: if ``rec_config`` is neither a bool nor a list of
            ``pipeline_num`` items.
    """
    if rec_config is None or offset is None:
        return None
    if isinstance(rec_config, bool):
        if rec_config:
            base = [layer_per_stage] * pipeline_num
            # Pair each stage with its own offset (zip); iterating over the
            # tuple (base, offset) itself unpacked the two lists instead.
            rec_config = [recom + bias for recom, bias in zip(base, offset)]
        else:
            rec_config = [0] * pipeline_num
        rec_config = [rec_config]
    elif isinstance(rec_config, list) and len(rec_config) == pipeline_num:
        # in order to be compatible with internal_from_yaml, change list into double list
        rec_config = [rec_config]
    else:
        raise ValueError(f"Unsupported input format recompute: {rec_config}, please check the length of list")

    return rec_config


def instantiate_stage(stage_id: int, pipeline_num: int, nb_layer: int,
                      layer_per_recompute: Dict[Recompute.TYPE, List[List[int]]],
                      memory: int) -> Stage:
    """Build a :class:`Stage` snapshot from per-recomputation-type layer assignments."""
    stage = Stage(
        sid=stage_id,
        nb_stage=pipeline_num,
        nb_layer=nb_layer,
        nb_layer_rec={
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
            Recompute.TYPE.BOTH: layer_per_recompute[Recompute.TYPE.BOTH][0][stage_id],
        },
        memory_usage=memory,
    )
    return stage


def memory_parser(file_name: str) -> Tuple[int, List[Stage], int]:
    """Parse a SAPP-PPB memory-usage YAML file into solver-ready stage observations.

    The ``pipeline_config``, ``recompute_config`` and ``memory_usage`` sections
    are read. One ``Stage`` is created per observed body stage (for every round
    when ``offset`` is a list of lists), followed by one stage for the head
    (id 0) and one for the tail (id ``pipeline_num - 1``). With several rounds
    the head and tail stages use the offset and recompute layout of the last
    round.

    Args:
        file_name: path of the YAML file (must end with ``yaml`` or ``yml``).

    Returns:
        ``(pipeline_num, stages, num_layer)``.

    Raises:
        ValueError: if ``file_name`` is ``None`` or is not a YAML file name.
    """
    if file_name is None:
        logger.error("input file cannot be none")
        raise ValueError("input file cannot be none")
    if not file_name.endswith(("yaml", "yml")):
        logger.error("Only accept yaml as input format")
        raise ValueError(f"Only accept yaml as input format. not {file_name}")

    filepath = os.path.realpath(file_name)
    with open(filepath, encoding="utf-8") as fp:
        check_yaml_depth_before_loading(fp)
        fp.seek(0)
        cfg_dict = yaml.safe_load(fp)

    # get pipeline config
    pipeline_num = cfg_dict["pipeline_config"]["pipeline_num"]
    num_layer = cfg_dict["pipeline_config"]["num_layer"]
    offset = cfg_dict["pipeline_config"]["offset"]

    offset, rounds = process_offset(offset, pipeline_num)

    layer_per_stage = int(num_layer / pipeline_num)

    # get recompute config
    if rounds > 1:
        layer_per_recompute = []
        for i in range(rounds):
            rec_config = {}
            for rec in Recompute.YAML_NAME.values():
                rec_list = cfg_dict["recompute_config"].get(rec)
                if rec_list is None:
                    continue
                rec_config[rec] = process_rec_config(layer_per_stage, pipeline_num, offset[i], rec_list[i])
            rec_config[Recompute.OFFSET] = [offset[i]]
            layer_per_recompute.append(
                Recompute.internal_from_yaml(1, pipeline_num, rec_config, [[layer_per_stage] * pipeline_num])
            )
    else:
        rec_config = {}
        for rec in Recompute.YAML_NAME.values():
            rec_list = cfg_dict["recompute_config"].get(rec)
            rec_config[rec] = process_rec_config(layer_per_stage, pipeline_num, offset, rec_list)
        rec_config[Recompute.OFFSET] = [offset]
        layer_per_recompute = Recompute.internal_from_yaml(1, pipeline_num, rec_config,
                                                           [[layer_per_stage] * pipeline_num])

    # get memory usage
    stage_id = cfg_dict["memory_usage"]["body_memories"]["stage_id"]
    mem_head_stage = cfg_dict["memory_usage"]["head_memory"]
    mem_tail_stage = cfg_dict["memory_usage"]["tail_memory"]
    body_memories = cfg_dict["memory_usage"]["body_memories"]["memories"]

    tail_id = pipeline_num - 1
    stages_a = []
    if rounds > 1:
        for i in range(rounds):
            for idx, sg_id in enumerate(stage_id[i]):
                stages_a.append(
                    instantiate_stage(
                        sg_id, pipeline_num,
                        layer_per_stage + offset[i][sg_id],
                        layer_per_recompute[i], body_memories[i][idx],
                    )
                )
        # Head and tail are described once, with the data of the last round
        # (made explicit instead of reusing the leaked loop variable).
        last_round = rounds - 1
        stages_a.append(
            instantiate_stage(
                0, pipeline_num, layer_per_stage + offset[last_round][0],
                layer_per_recompute[last_round], mem_head_stage,
            )
        )
        stages_a.append(
            instantiate_stage(
                tail_id, pipeline_num,
                layer_per_stage + offset[last_round][tail_id],
                layer_per_recompute[last_round], mem_tail_stage,
            )
        )
    else:
        for idx, sg_id in enumerate(stage_id):
            stages_a.append(
                instantiate_stage(
                    sg_id, pipeline_num, layer_per_stage + offset[sg_id],
                    layer_per_recompute, body_memories[idx],
                )
            )
        stages_a.append(
            instantiate_stage(
                0, pipeline_num, layer_per_stage + offset[0],
                layer_per_recompute, mem_head_stage,
            )
        )
        stages_a.append(
            instantiate_stage(
                tail_id, pipeline_num, layer_per_stage + offset[tail_id],
                layer_per_recompute, mem_tail_stage,
            )
        )

    return pipeline_num, stages_a, num_layer


def initialize_layer_json(model_name: str, file_name: str) -> None:
    """Build the layer-description JSON of ``model_name`` from a dry-run file.

    Memory and time figures are parsed from ``file_name``, folded into a
    ``ModelInfo`` and dumped to ``./layers/<model_name>.json``.

    Args:
        model_name: Model name; also used as the stem of the dumped JSON file.
        file_name: Path handed to ``memory_parser`` and ``time_parser``.
    """
    stage_count, stage_list, layer_count = memory_parser(file_name)
    time_head, time_body, time_tail = time_parser(file_name, model_name)
    memory = ComputeMemory(number_of_stage=stage_count, stages_a=stage_list)
    model_info = ModelInfo(model_name, time_head, time_body, time_tail, layer_count)

    # Activation memory is only collected for the recomputation types that
    # the memory model reports as considered.
    activation_by_rec = {}
    for rec_type in Recompute.TYPE:
        if not memory.recompute_considered_[rec_type]:
            continue
        activation_by_rec[rec_type] = int(memory.get_memory_activation(rec_type))
        logger.info(
            "[INFO] %s = %f",
            Recompute.JSON_MEMORY_NAME[rec_type],
            int(memory.get_memory_activation(rec_type)),
        )

    # The constant per-stage memory is optional.
    if memory.get_memory_const() is not None:
        const_memory = int(memory.get_memory_const())
        logger.info("[INFO] memory_const       = %s", const_memory)
        model_info.set_stage_const_mem(const_memory)

    parameter_memory = int(memory.get_memory_parameter())
    tail_memory = int(memory.get_memory_tail())
    head_memory = int(memory.get_memory_head())
    logger.info("[INFO] memory_parameter  = %s", parameter_memory)
    logger.info("[INFO] memory_tail       = %s", tail_memory)
    logger.info("[INFO] memory_head       = %s", head_memory)

    model_info.layer_memory_update(
        activation_by_rec, parameter_memory, head_memory, tail_memory
    )
    output_path = os.path.join("./layers", model_name + ".json")
    model_info.dump_json(output_path)


def get_stage_const_mem(layer_folder: str, model_name: str) -> int:
    """Return the ``stage_const_mem`` entry of a model's layer JSON.

    Args:
        layer_folder: Directory holding the layer-description files.
        model_name: Model name; ``<model_name>.json`` is read from ``layer_folder``.

    Returns:
        The value stored under ``"stage_const_mem"``, or ``0`` when the key
        is absent from the file.
    """
    layer_json_path = os.path.join(layer_folder, model_name + '.json')
    with open(layer_json_path, encoding="utf-8") as handle:
        content = json.load(handle)
    if "stage_const_mem" not in content:
        return 0
    return content["stage_const_mem"]


def _generate_offset_config(rounds, unknowns, target_sum, array_length):
    """Generate one random offset array per dry-run round.

    The whole set is redrawn until the interior offsets (first and last
    stage dropped from every round), flattened and truncated to their first
    ``unknowns + 1`` entries, are not all identical.

    Args:
        rounds: Number of offset arrays to draw.
        unknowns: Number of unknowns of the equation system to be solved.
        target_sum: Total number of layers to distribute.
        array_length: Number of pipeline stages.

    Returns:
        A list of ``rounds`` offset arrays, each of length ``array_length``.

    Raises:
        ValueError: If the inputs can never yield differing interior
            offsets: ``target_sum <= array_length`` (every draw is all
            zeros, or no draw exists), fewer than three stages, or fewer
            than one round. Without this check the first case spins forever.
    """
    if target_sum <= array_length:
        raise ValueError(
            "number of layers must be larger than stage number "
            f"(got {target_sum} layers for {array_length} stages)"
        )
    if rounds < 1 or array_length < 3:
        raise ValueError(
            "at least one round and three stages are required "
            f"(got {rounds} rounds and {array_length} stages)"
        )

    while True:
        offset_config_list = [
            _generate_offset_array(target_sum, array_length) for _ in range(rounds)
        ]
        # Only interior stages contribute equations, so only their offsets
        # are checked for variety.
        interior = np.array([config[1:-1] for config in offset_config_list])
        flat = interior.flatten()[0 : unknowns + 1]
        if not np.all(flat == flat[0]):
            return offset_config_list


def _generate_offset_array(target_sum, array_length):
    """Draw random per-stage layer offsets relative to an even split.

    Random positive stage sizes summing to ``target_sum`` are drawn and
    returned as offsets from the baseline ``target_sum // array_length``.

    Args:
        target_sum: Total number of layers to distribute.
        array_length: Number of pipeline stages.

    Returns:
        A list of ``array_length`` integer offsets such that every stage
        size ``baseline + offset`` is at least 1 and the stage sizes sum to
        ``target_sum``. All zeros when ``target_sum == array_length``.
        ``None`` (after logging an error) when ``target_sum < array_length``.
    """
    if target_sum == array_length:
        return [0] * array_length

    if target_sum < array_length:
        logger.error("number of layers must be larger than stage number")
        return None

    random_array = np.random.randint(1, 10, size=array_length)
    total_sum = random_array.sum()
    scaled_array = (random_array / total_sum) * target_sum
    scaled_array = np.round(scaled_array).astype(int)
    # Every stage must hold at least one layer.
    scaled_array = np.maximum(scaled_array, 1)
    diff = int(target_sum - scaled_array.sum())

    if diff >= 0:
        # Rounding lost some layers: hand them out from the first stage on.
        for i in range(diff):
            scaled_array[i % array_length] += 1
    else:
        # Rounding and the floor of 1 added layers: take them back, walking
        # from the last stage. A single pass is not always enough (one big
        # stage may have to give up several layers), so repeat until the
        # excess is gone. The sum exceeds target_sum > array_length here, so
        # some stage always holds more than one layer and the loop ends.
        excess = -diff
        while excess > 0:
            for idx in range(array_length - 1, -1, -1):
                if scaled_array[idx] > 1:
                    scaled_array[idx] -= 1
                    excess -= 1
                    if excess == 0:
                        break

    baseline = target_sum // array_length
    offset = scaled_array - baseline
    return offset.tolist()


def _get_coef_matrix(pp, layer_per_stage, offset_config_list, rec_config_list, considered_rec):
    """Build the coefficient rows of the dry-run equation system.

    Interior stages (first and last excluded) are visited round after round.
    Each contributes the row ``[1, layers in the stage]`` extended with
    ``Recompute.to_list`` of the stage's per-type recompute counts.

    Args:
        pp: Number of pipeline stages.
        layer_per_stage: Baseline number of layers per stage.
        offset_config_list: Per-round stage offsets.
        rec_config_list: Per-round mapping from recompute type to per-stage counts.
        considered_rec: Recompute types taking part in the system.

    Returns:
        The first ``2 + len(considered_rec)`` rows, or ``None`` when the
        rounds run out before that many rows were collected.
    """
    # The result is not used below; the call is kept as in the original.
    activation_nums = SappSolver.compute_activation_nums(pp, 1, 0)[0]
    needed_rows = 2 + len(considered_rec)
    rows = []
    for round_ in range(len(offset_config_list)):
        for stage in range(1, pp - 1):
            row_head = [1, layer_per_stage + offset_config_list[round_][stage]]
            rec_counts = {
                rec: rec_config_list[round_][rec][stage] for rec in considered_rec
            }
            rows.append(row_head + Recompute.to_list(rec_counts))
            if len(rows) == needed_rows:
                return rows
    return None


def print_dryrun_config(offset_config_list: List[List[int]],
                        rec_config_list: List[Dict[Recompute.TYPE, List[int]]]) -> None:
    """Log, round by round, the config the user has to dry-run.

    For each round the offsets are printed together with three recompute
    lists: the full counts as they are, and the select / select-comm lists
    as the per-stage sum of their own counts plus the ``BOTH`` and ``FULL``
    counts. Types missing from a round count as zero on every stage.

    Args:
        offset_config_list: One offset array per round.
        rec_config_list: One recompute-count mapping per round.
    """
    logger.output(
        "Please dryrun following config, %s round(s) is needed",
        len(offset_config_list),
    )

    for round_, offset_config in enumerate(offset_config_list):
        pp = len(offset_config)
        round_rec = rec_config_list[round_]
        slct = round_rec.get(Recompute.TYPE.SLCT, [0] * pp)
        comm = round_rec.get(Recompute.TYPE.COMM, [0] * pp)
        full = round_rec.get(Recompute.TYPE.FULL, [0] * pp)
        both = round_rec.get(Recompute.TYPE.BOTH, [0] * pp)

        # Insertion order fixes the order of the printed lines.
        yaml_config = {
            Recompute.OFFSET: offset_config,
            Recompute.YAML_NAME[Recompute.TYPE.FULL]: full,
            Recompute.YAML_NAME[Recompute.TYPE.SLCT]: [
                x + y + z for x, y, z in zip(slct, both, full)
            ],
            Recompute.YAML_NAME[Recompute.TYPE.COMM]: [
                x + y + z for x, y, z in zip(comm, both, full)
            ],
        }
        entries = "".join(f"\n\t{key}: {value}" for key, value in yaml_config.items())
        logger.output(f"for round {round_ + 1}, please dryrun config:" + entries)


def generate_solvable_config(
        pp: int, num_layers: int,
        considered_rec: List[Recompute.TYPE]
) -> Optional[Tuple[List[List[int]], List[Dict[Recompute.TYPE, List[int]]]]]:
    """Generate offset / recompute configs whose coefficient matrix is full-rank.

    Offsets are drawn once. Random recompute counts are then redrawn until
    the matrix built by ``_get_coef_matrix`` has rank
    ``2 + len(considered_rec)``.

    Args:
        pp: Number of pipeline stages; must be greater than 2.
        num_layers: Total number of layers; must be greater than ``pp``.
        considered_rec: Recompute types to consider. ``Recompute.TYPE.NONE``
            is appended to this list in place when not already present.

    Returns:
        ``(offset_config_list, rec_config_list)``, or ``None`` (after logging
        an error) when the inputs are not supported.
    """
    # Only interior stages provide equations, so pp <= 2 leaves none
    # (and pp == 2 would divide by zero below).
    if pp <= 2:
        logger.error("pp = %s is not supported yet", pp)
        return None
    # With num_layers <= pp no usable offsets exist; the offset generator
    # would hang or crash.
    if num_layers <= pp:
        logger.error("number of layers must be larger than stage number")
        return None

    # A duplicated NONE entry would raise the expected rank above the number
    # of distinct columns and make the search below endless.
    if Recompute.TYPE.NONE not in considered_rec:
        considered_rec.append(Recompute.TYPE.NONE)
    unknowns = 2 + len(considered_rec)
    rounds = ceil(unknowns / (pp - 2))
    layer_per_stage = num_layers // pp
    offset_config_list = _generate_offset_config(rounds, unknowns, num_layers, pp)

    while True:
        rec_config_list = []
        for offset_config in offset_config_list:
            # Layers of each stage already given to an earlier recompute type.
            assigned = [0] * pp
            layer_per_recompute = {r: [0] * pp for r in considered_rec}
            for rec in considered_rec:
                picked = [
                    random.randint(0, offset_config[i] + layer_per_stage - assigned[i])
                    for i in range(pp)
                ]
                layer_per_recompute[rec] = picked
                assigned = [done + new for done, new in zip(assigned, picked)]
            rec_config_list.append(layer_per_recompute)

        coef_matrix = _get_coef_matrix(
            pp, layer_per_stage, offset_config_list, rec_config_list, considered_rec
        )
        if coef_matrix is None:
            continue
        if np.linalg.matrix_rank(coef_matrix) == unknowns:
            return offset_config_list, rec_config_list


def parse_training_config(yaml_path: str) -> Optional[Dict[str, Any]]:
    """Extract the training-YAML fields needed to compute operator shapes for seqpp."""
    try:
        with open(yaml_path, "r", encoding="utf-8") as file:
            config = yaml.safe_load(file)

        # Extract the requested values
        model_config = config["model"]["model_config"]
        parallel_config = config["parallel_config"]
        runner_config = config["runner_config"]

        # Create a dictionary with the requested parameters
        extracted_params = {
            "num_heads": model_config["num_heads"],
            "hidden_size": model_config["hidden_size"],
            "head_dim": int(model_config["hidden_size"] / model_config["num_heads"]),
            "seq_length": model_config["seq_length"],
561
562
563
564
565
566
567
568
569
570
571
572
573
            "data_parallel": parallel_config.get("data_parallel", 1),
            "model_parallel": parallel_config.get("model_parallel", 1),
            "context_parallel": parallel_config.get("context_parallel", 1),
        }
        logger.output("Extracted training parameters:")
        for key, value in extracted_params.items():
            logger.output("%s: %s", key, value)

        return extracted_params

    except (yaml.YAMLError, FileNotFoundError, PermissionError) as e:
        logger.error("Error parsing Training YAML file: %s", e)
        return None
hyper_parallel/auto_parallel/sapp_ppb/utils/error.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Error types and input validation helpers for SAPP-PPB."""
from typing import Union

Number = Union[int, float]


class SAPPError(ValueError):
    """Raised when SAPP-PPB detects an invalid input or configuration.

    Subclasses :class:`ValueError`, so handlers written for ``ValueError``
    also catch it.
    """


def assert_sapp(test: bool, msg: str) -> None:
    """Raise :class:`SAPPError` carrying ``msg`` unless ``test`` holds.

    Args:
        test: Condition that must hold.
        msg: Message given to the raised :class:`SAPPError`.

    Raises:
        SAPPError: If ``test`` is falsy.
    """
    if test:
        return
    raise SAPPError(msg)


def check_in_bounds(n: Number, n_desc: str, lower_bound: Number, higher_bound: Number) -> None:
    """Validate that ``n`` is within ``[lower_bound, higher_bound]``, both ends included.

    Args:
        n: The value being validated.
        n_desc: Label of the value, used at the start of the error message.
        lower_bound: Smallest accepted value.
        higher_bound: Largest accepted value.

    Raises:
        SAPPError: If ``n`` falls outside ``[lower_bound, higher_bound]``.
    """
    reaches_lower = n >= lower_bound
    assert_sapp(reaches_lower, f"{n_desc} {n} should be higher than {lower_bound}")
    within_higher = n <= higher_bound
    assert_sapp(within_higher, f"{n_desc} {n} should be lower than {higher_bound}")
hyper_parallel/auto_parallel/sapp_ppb/utils/interactive.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Interactive CLI flow used when ``run_pipeline_balance.py`` is launched without arguments."""
from collections import namedtuple
from typing import Any, List

import sapp_ppb.utils.recompute as Recompute
from sapp_ppb.sapp.sapp_pipeline import SappPipeline
from sapp_ppb.utils.config import generate_solvable_config, print_dryrun_config
from sapp_ppb.utils.error import check_in_bounds
from sapp_ppb.utils.layer import Layer
from sapp_ppb.utils.logger import logger

# Suffix appended to the yes/no questions.
YES_OR_NO = "[y/n]? "

# Suffix appended to prompts whose answer may be left blank.
OPTIONAL = " if you wish: "

# Bundle of the answers collected by ``global_arguments``.
GLOBALARGUMENTS = namedtuple('GLOBALARGUMENTS', ['stage_num', 'micro_batch', 'interleave', 'max_memory'])


def default_v(d: Any) -> str:
    """Build the ``" (<default> if none): "`` suffix shown at the end of a prompt."""
    return f" ({d!s} if none): "


def is_yes(s: str) -> bool:
    """Tell whether ``s`` is an affirmative answer.

    Any string starting with ``y`` or ``Y`` counts, and so does the exact
    string ``"1"``.
    """
    starts_with_y = s.lower().startswith('y')
    if starts_with_y:
        return True
    return s == "1"


def is_empty(s: str) -> bool:
    """Tell whether ``s``, ignoring surrounding whitespace, is empty or a lone ``*``."""
    stripped = s.strip()
    return stripped in ("", "*")


def global_arguments() -> GLOBALARGUMENTS:
    """Ask for the global pipeline settings and return them as a named tuple.

    Four values are requested in order: stage number (default 4, at most
    10000), micro batch number (default 8, at most 10000), interleave
    (default 1, at most 10) and maximum memory (default 56000, at most
    1000000). A blank answer keeps the default; any other answer is
    converted with ``int`` and must be at least 1.
    """

    def ask(question: str, fallback: int, label: str, upper: int) -> int:
        """Prompt once and return the validated answer, or ``fallback`` when blank."""
        answer = input(question + default_v(fallback))
        if is_empty(answer):
            return fallback
        value = int(answer)
        check_in_bounds(value, label, 1, upper)
        return value

    stage_num = ask(
        "Please enter the pipeline stage number", 4, "Pipeline stage number", 10000
    )
    micro_batch = ask(
        "Please enter the micro batch number", 8, "Micro batch number", 10000
    )
    interleave = ask(
        "Please enter the pipeline interleave number", 1, "Interleave", 10
    )
    max_memory = ask(
        "Please enter maximum memory", 56000, "Maximum memory", 1000000
    )

    return GLOBALARGUMENTS(stage_num, micro_batch, interleave, max_memory)


def make_layer(t: Layer.type_enum, model_name: str) -> Layer:
    """Interactively collect the description of one layer of type ``t``.

    The name (default ``misc_<type name>``) and execution time (default 0)
    are always asked. A ``BODY`` layer is then asked for its count (default
    1), its parameter memory (default 0) and one optional activation memory
    per recompute type (``None`` when left blank). Any other type is only
    asked for its memory (default 0).

    Args:
        t: Type of the layer being described.
        model_name: Name of the model the layer belongs to.

    Returns:
        The :class:`Layer` built from the answers.
    """

    def ask_int(prompt: str, fallback: Any) -> Any:
        """Prompt once; return ``int`` of the answer, or ``fallback`` when blank."""
        reply = input(prompt)
        if is_empty(reply):
            return fallback
        return int(reply)

    default_name = "misc_" + t.name
    activation_by_rec = {rec: None for rec in Recompute.TYPE}

    reply = input("\tEnter the layer name" + OPTIONAL)
    layer_name = default_name if is_empty(reply) else reply
    layer_time = ask_int("\tEnter the layer execution time: ", 0)

    nb_layer = 1
    if t is Layer.type_enum.BODY:
        nb_layer = ask_int("\tEnter the number of such layer: ", 1)
        memory_parameter = ask_int("\tEnter the layer parameter memory (MB): ", 0)
        for rec in Recompute.TYPE:
            activation_by_rec[rec] = ask_int(
                "\tEnter the layer " + Recompute.JSON_MEMORY_NAME[rec] + OPTIONAL,
                None,
            )
    else:
        memory_parameter = ask_int("\tEnter the layer memory (MB): ", 0)

    return Layer(
        name=layer_name,
        ltype=t,
        nb_layer=nb_layer,
        time=layer_time,
        model_name=model_name,
        memory_activation_rec=activation_by_rec,
        memory_parameter=memory_parameter,
    )


def dryrun_guide() -> None:
    """Ask for a dry-run setup and print the configs the user should dry-run.

    The stage number and the number of layers are mandatory: a blank answer
    to either ends the guide silently. Four yes/no questions then select the
    recompute types to consider. Nothing is printed when
    ``generate_solvable_config`` rejects the setup.

    Raises:
        SAPPError: If the stage number or the number of layers is outside
            ``[1, 10000]``.
        ValueError: If an answer cannot be converted with ``int``.
    """
    answer = input("Please enter the pipeline stage number" + default_v(0))
    if is_empty(answer):
        return
    stage_num = int(answer)
    check_in_bounds(stage_num, "Pipeline stage number", 1, 10000)

    answer = input("Please enter the number of layers" + default_v(0))
    if is_empty(answer):
        return
    num_layers = int(answer)
    # This value is the layer count; label the bound error accordingly.
    check_in_bounds(num_layers, "Number of layers", 1, 10000)

    questions = (
        ("Do you consider full recomputation?", Recompute.TYPE.FULL),
        ("Do you consider select recomputation?", Recompute.TYPE.SLCT),
        ("Does your communication recomputation co-work with select recomputation?",
         Recompute.TYPE.BOTH),
        ("Do you consider extra communication recomputation?", Recompute.TYPE.COMM),
    )
    considered_rec: List[Recompute.TYPE] = []
    for question, rec_type in questions:
        if is_yes(input(question + YES_OR_NO)):
            considered_rec.append(rec_type)

    config = generate_solvable_config(stage_num, num_layers, considered_rec)
    if config is None:
        # generate_solvable_config logs the reason and returns None for
        # setups it does not support (e.g. pp = 2); unpacking that would
        # raise a TypeError.
        return
    offset_config_list, rec_config_list = config
    print_dryrun_config(offset_config_list, rec_config_list)


def main() -> None:
    """Run the interactive session used when no CLI argument was given.

    After confirmation, the global arguments, the model name and one layer
    description per known layer type are collected. The pipeline problem is
    then built, solved, printed and simulated.
    """
    answer = input(
        "No arguments were given. Would you like to proceed to the interactive mode " + YES_OR_NO)
    if not is_yes(answer):
        return

    stage_count, micro_batch_count, interleave, memory_cap = global_arguments()

    answer = input("\tEnter the model name" + OPTIONAL)
    model_name = "misc" if is_empty(answer) else answer

    layers = []
    for ltype in Layer.type_enum:
        if ltype is Layer.type_enum.UNKNOWN:
            continue
        logger.info("Please enter information of your network %s", ltype.name)
        layers.append(make_layer(ltype, model_name))

    pipe = SappPipeline(
        model_name=model_name,
        num_of_stage=stage_count,
        num_of_micro_batch=micro_batch_count,
        max_memory=memory_cap,
        layers=layers,
        num_of_interleave=interleave,
    )

    for layer in layers:
        logger.info("%s", layer)

    pipe.construct_problem(solver="pulp")
    pipe.solve_problem(time_limit=40, dump_folder="output")
    pipe.print_yaml_results()
    pipe.simulate(show=True)
hyper_parallel/auto_parallel/sapp_ppb/utils/layer.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Layer descriptor used throughout SAPP-PPB: time, memory and recomputation metadata."""
import json
import os
from enum import Enum
from typing import Any, Dict, Optional

import sapp_ppb.utils.recompute as Recompute
from sapp_ppb.utils.computation_analyzer import ComputationAnalyzer
from sapp_ppb.utils.logger import logger


class Layer:
    """
    Mandatory parameter:
    name_ (str): name of the layer
    type_ (LayerType): type of the layer 'HEAD', 'BODY', 'TAIL'
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
    Not manage yet parameter (for multimodal):
    model_name_ (str): name of the model the layer be part of (for multimodal)
    """

    type_enum = Enum("LayerType", ["UNKNOWN", "HEAD", "BODY", "TAIL"])
    backward_default_ratio = 2  # of forward time
    name_: str
    model_name_: str
    type_: type_enum
    nb_layer_: int
    time_: float
    memory_parameter_: float
    memory_activation_rec_: dict[Recompute.TYPE, float]
    forward_time_: float
    backward_time_rec_: dict[Recompute.TYPE, float]
    backward_coef_rec_: dict[Recompute.TYPE, float]
    recompute_considered_: dict[Recompute.TYPE, bool]

    def __init__(
        self,
        model_name: str = "misc",
        name: str = "misc",
        ltype: type_enum = type_enum.UNKNOWN,
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
            memory_parameter (float, optional): Parameter memory in MB. Default: ``0.0``.
            memory_activation_rec (Optional[Dict[Recompute.TYPE, float]], optional): Per-recomputation-type
                activation memory (``None`` -> zeros). Default: ``None``.
        """
        if backward_time_rec is None:
            backward_time_rec = {r: 0 for r in Recompute.TYPE}
        if backward_coef_rec is None:
            backward_coef_rec = {r: 0 for r in Recompute.TYPE}
        if memory_activation_rec is None:
            memory_activation_rec = {r: 0.0 for r in Recompute.TYPE}
        self.name_ = name
        self.model_name_ = model_name
        self.type_ = ltype
        self.nb_layer_ = nb_layer
        self.time_ = time
        self.memory_activation_rec_ = memory_activation_rec
        self.memory_parameter_ = memory_parameter
        self.backward_time_rec_ = backward_time_rec
        self.backward_coef_rec_ = backward_coef_rec
        self.forward_time_ = forward_time
        self.recompute_considered_ = self.find_recompute_considered()
        self.compute_internal_time()

    def __str__(self) -> str:
        """Describe the layer on several lines, one ``label = value`` pair per line.

        Activation memories and backward times are listed only for the
        recomputation types flagged in ``recompute_considered_``.
        """
        considered = [r for r in Recompute.TYPE if self.recompute_considered_[r]]
        lines = [
            "Layer Description:",
            "  name             = " + self.name_,
            "  model_name       = " + str(self.model_name_),
            "  type             = " + self.type_.name,
            "  nb_layer         = " + str(self.nb_layer_),
            "  time             = " + str(self.time_),
            "  memory_parameter = " + str(self.memory_parameter_),
        ]
        lines.extend(
            "  " + Recompute.JSON_MEMORY_NAME[r] + " = " + str(self.memory_activation_rec_[r])
            for r in considered
        )
        lines.append("  forward_time     = " + str(self.forward_time_))
        lines.extend(
            "  " + Recompute.JSON_TIME_NAME[r] + " = " + str(self.backward_time_rec_[r])
            for r in considered
        )
        # Every line, the last one included, ends with a newline.
        return "\n".join(lines) + "\n"

    def dump(self, dump_file: str) -> None:
        """Placeholder for dumping the layer to ``dump_file``.

        Not implemented: nothing is written, the call only logs an error
        naming ``dump_file``.
        """
        logger.error("dump file (%s) Not implemented yet!!!", dump_file)

    def to_json(self) -> None:
        """Placeholder for the JSON representation of the layer.

        Not implemented: the call only logs an error and returns ``None``.
        """
        logger.error("Not implemented yet!!!")

    def find_recompute_considered(self) -> Dict[Recompute.TYPE, bool]:
        """Flag, per recomputation type, whether an activation memory is set.

        Returns:
            A mapping over every ``Recompute.TYPE`` member that is ``True``
            exactly when ``memory_activation_rec_`` holds a non-``None``
            value for that type.
        """
        return {
            rec: self.memory_activation_rec_[rec] is not None
            for rec in Recompute.TYPE
        }

    def compute_internal_time(
        self,
        back_ratio: float = backward_default_ratio,
        force_fb: bool = False,
    ) -> None:
        """Fill in forward and backward times from ``time_``.

        ``forward_time_`` takes the value of ``time_`` when ``force_fb`` is
        true or it is still ``None``. ``backward_time_`` is always set to
        ``back_ratio * time_``. For each considered recomputation type whose
        backward time is ``None`` or ``0``, the time becomes
        ``(1 + coef) * backward_time_``, where ``coef`` is the layer's own
        coefficient or, when that is ``None``, ``Recompute.DEFAULT_COEF``.

        Args:
            back_ratio: Backward time expressed as a multiple of ``time_``.
            force_fb: Overwrite ``forward_time_`` even when already set.
        """
        if force_fb or self.forward_time_ is None:
            self.forward_time_ = self.time_
        self.backward_time_ = back_ratio * self.time_

        for rec in Recompute.TYPE:
            if not self.recompute_considered_[rec]:
                continue
            known_time = self.backward_time_rec_[rec]
            if known_time is not None and known_time != 0:
                # A backward time was provided explicitly: keep it.
                continue
            coef = self.backward_coef_rec_[rec]
            if coef is None:
                coef = Recompute.DEFAULT_COEF[rec]
            self.backward_time_rec_[rec] = (1 + coef) * self.backward_time_

    def update_internal_time_for_seqpp(
        self,
        back_ratio: float = backward_default_ratio,
        force_fb: bool = False,
    ) -> None:
        """Split ``time_`` into forward and backward parts for the sequence-pipeline mode.

        ``backward_time_`` becomes ``back_ratio * time_``. ``forward_time_``
        is set to ``(1 - back_ratio) * time_`` when ``force_fb`` is true or it
        is still ``None``. The backward time of every considered
        recomputation type is then overwritten with
        ``(1 + coef) * backward_time_``, ``coef`` being the layer's own
        coefficient or, when that is ``None``, ``Recompute.DEFAULT_COEF``.

        Args:
            back_ratio: Share of ``time_`` attributed to the backward pass.
            force_fb: Overwrite ``forward_time_`` even when already set.
        """
        # NOTE(review): with the default back_ratio (backward_default_ratio,
        # i.e. 2) the factor (1 - back_ratio) is negative; callers presumably
        # pass a fraction below 1 -- confirm.
        if force_fb or self.forward_time_ is None:
            self.forward_time_ = (1 - back_ratio) * self.time_
        self.backward_time_ = back_ratio * self.time_

        for rec in Recompute.TYPE:
            if not self.recompute_considered_[rec]:
                continue
            coef = self.backward_coef_rec_[rec]
            if coef is None:
                coef = Recompute.DEFAULT_COEF[rec]
            self.backward_time_rec_[rec] = (1 + coef) * self.backward_time_

    def compute_timer(
        self,
        timeline_folder: str = "./timeline",
        tmp_layer_info: Optional[dict] = None,
    ) -> None:
        """Set ``time_`` from a ``ComputationAnalyzer`` run and refresh the derived times.

        ``time_`` is looked up by layer name in the analyzer's
        ``layer_with_cost_list`` (``None`` when the name is absent), then
        ``compute_internal_time`` is called with ``force_fb=True``.

        Args:
            timeline_folder: Folder handed to ``ComputationAnalyzer``.
            tmp_layer_info: Forwarded to the analyzer as ``layer_list``.
        """
        analyzer = ComputationAnalyzer(
            timeline_folder,
            self.model_name_,
            num_of_micro_batch=0,
            layer_list=tmp_layer_info,
        )
        cost_by_layer = analyzer.layer_with_cost_list
        self.time_ = cost_by_layer.get(self.name_)
        self.compute_internal_time(force_fb=True)

    def compute_memory(self, memory_folder: str = "./memory") -> None:
        """Placeholder for computing the memory information from ``memory_folder``.

        Not implemented: no attribute is updated, the call only logs an error
        naming ``memory_folder``.
        """
        logger.error(
            "compute_memory (%s) Not implemented yet!!!", memory_folder
        )

214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230

# Helper functions on layer list


def generate_layers_list(layer_folder: str, model_name: str) -> list[Layer]:
    """ "Parse layer_folder/model_name.json to generate a list of layer"""
    layers = []
    json_layer = os.path.join(layer_folder, model_name + ".json")
    with open(json_layer, encoding="utf-8") as json_file:
        layer_data_json = json.load(json_file)
        if "layers_description" in layer_data_json:
            for layer_data in layer_data_json["layers_description"]:
                new_layer = Layer(
                    name=layer_data["name"],
                    ltype=Layer.type_enum[layer_data["type"]],
                    nb_layer=layer_data["nb_layer"],
                    time=layer_data["time"],
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
                        for r in Recompute.TYPE
                    },
                    memory_parameter=layer_data.get("memory_parameter"),
                )
                new_layer.compute_internal_time()
                layers.append(new_layer)
        else:
            logger.error(
                'ERROR: File "%s" doesn\'t have layers_description to parse.\n',
                json_layer,
            )
    return layers


def filter_layer_type(
    layers: list[Layer], layer_type: Layer.type_enum
) -> list[Layer]:
    """Return, in their original order, the layers whose ``type_`` equals ``layer_type``."""
    return [candidate for candidate in layers if candidate.type_ == layer_type]


def aggregate(layers: list[Layer]) -> Layer:
    """Fold every layer of ``layers`` into the first one and return it.

    The first layer is updated in place: time, per-recomputation backward
    times and activation memories, parameter memory and layer count are
    summed over all layers. A ``None`` operand is treated as missing, so the
    other operand is kept. ``layers`` itself is left untouched.

    Args:
        layers: Layers to merge; must hold at least one layer.

    Returns:
        The first element of ``layers``, carrying the aggregated values.

    Raises:
        IndexError: If ``layers`` is empty.
    """

    def add_none(a: Optional[Any], b: Optional[Any]) -> Any:
        """Add ``a`` and ``b``, returning whichever is not ``None`` when one is missing."""
        if a is None:
            return b
        if b is None:
            return a
        return a + b

    def add_rec_dict(a: Dict[Recompute.TYPE, Any],
                     b: Dict[Recompute.TYPE, Any]) -> Dict[Recompute.TYPE, Any]:
        """Add two per-recomputation-type dictionaries entry by entry.

        Entries of recomputation types that are not considered are ``None``,
        hence the ``None``-tolerant addition.
        """
        return {rec: add_none(a[rec], b[rec]) for rec in Recompute.TYPE}

    aggregation = layers[0]
    # Iterate over a slice instead of popping, so the caller's list keeps
    # its first element.
    for layer in layers[1:]:
        aggregation.time_ = add_none(aggregation.time_, layer.time_)
        aggregation.backward_time_rec_ = add_rec_dict(
            aggregation.backward_time_rec_, layer.backward_time_rec_
        )
        aggregation.memory_activation_rec_ = add_rec_dict(
            aggregation.memory_activation_rec_, layer.memory_activation_rec_
        )
        aggregation.memory_parameter_ = add_none(
            aggregation.memory_parameter_, layer.memory_parameter_
        )
        aggregation.nb_layer_ = add_none(
            aggregation.nb_layer_, layer.nb_layer_
        )
    return aggregation
hyper_parallel/auto_parallel/sapp_ppb/utils/logger.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""logger for pipeline balance"""
import logging

# Alternative record layout including file name and line number; not
# referenced by the code of this module visible here.
DEFAULT_STDOUT_FORMAT = '%(levelname)s %(asctime)s %(filename)s:%(lineno)d - %(message)s'
# Formatter attached to the stream handler created by ``setup_logger``.
FORMATTER = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# "OUTPUT" reuses the numeric value of WARNING: after the addLevelName call
# below, records logged at that level are displayed with the name "OUTPUT".
OUTPUT_LEVEL_NUM = logging.WARNING
logging.addLevelName(OUTPUT_LEVEL_NUM, "OUTPUT")


def setup_logger(name: str, level: int = logging.DEBUG) -> logging.Logger:
    """Create a namespaced logger and register the ``output`` convenience method.

    The logger gets one console handler using ``FORMATTER``. Calling the
    function again for the same ``name`` reuses that handler (only its level
    is updated) instead of stacking a second one, which would print every
    record twice.

    Args:
        name: Logger name (typically a package or module name).
        level: Threshold applied to both the logger and its console handler.

    Returns:
        The configured :class:`logging.Logger` instance.
    """

    def output(self: logging.Logger, message: str, *args: object) -> None:
        """Emit ``message`` at the WARNING level (aliased as ``OUTPUT``)."""
        self.warning(message, *args)

    # Installed on the Logger class, so every logger gains ``output``.
    logging.Logger.output = output
    ppb_logger = logging.getLogger(name)
    ppb_logger.setLevel(level)

    # Exact type match on purpose: FileHandler and friends subclass
    # StreamHandler but are not console handlers.
    console_handlers = [
        handler for handler in ppb_logger.handlers
        if type(handler) is logging.StreamHandler
    ]
    if console_handlers:
        for handler in console_handlers:
            handler.setLevel(level)
    else:
        ch = logging.StreamHandler()
        ch.setLevel(level)
        ch.setFormatter(FORMATTER)
        ppb_logger.addHandler(ch)

    return ppb_logger

# Package-wide logger imported by the other SAPP-PPB modules; INFO threshold.
logger = setup_logger('sapp_ppb', level=logging.INFO)
hyper_parallel/auto_parallel/sapp_ppb/utils/recompute.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Recomputation taxonomy and conversion helpers between internal dicts and the YAML schema."""
from enum import IntEnum
from typing import Any, Dict, List, Optional

from sapp_ppb.utils.logger import logger

# Recomputation kinds; members are numbered from 0 in the order listed here.
TYPE = IntEnum("RecomputeType", ["NONE", "SLCT", "COMM", "BOTH", "FULL"], start=0)
# Key under which the YAML dictionaries store the offset table.
OFFSET = "offset"

# Default coefficient per recomputation kind.
# NOTE(review): presumably a relative overhead factor - confirm against the solver.
DEFAULT_COEF = {
    TYPE.NONE: 0,
    TYPE.SLCT: 0.04,
    TYPE.COMM: 0.125,
    TYPE.BOTH: 0.165,
    TYPE.FULL: 0.5,
}

# YAML field name per recomputation kind (NONE has no field of its own).
YAML_NAME = {
    TYPE.NONE: "",
    TYPE.COMM: "select_comm_recompute",
    TYPE.SLCT: "select_recompute",
    TYPE.BOTH: "both_comm_select",
    TYPE.FULL: "recompute",
}

# JSON key of the memory figure per recomputation kind.
JSON_MEMORY_NAME = {
    TYPE.NONE: "memory_activation",
    TYPE.COMM: "memory_select_comm",
    TYPE.BOTH: "memory_both_comm_select",
    TYPE.SLCT: "memory_select_rec",
    TYPE.FULL: "memory_recompute",
}

# Same names as JSON_MEMORY_NAME, right-padded with spaces; the padding is part of the value.
JSON_MEMORY_NAME_ALIGNED = {
    TYPE.NONE: "memory_activation ",
    TYPE.COMM: "memory_select_comm",
    TYPE.BOTH: "memory_both_comm_select",
    TYPE.SLCT: "memory_select_rec ",
    TYPE.FULL: "memory_recompute  ",
}


# JSON key of the time figure per recomputation kind.
# The FULL entry previously carried a trailing space ("recompute_time "), unlike
# every other key in the non-aligned tables; it is stripped here.
# NOTE(review): confirm no stored file relies on the padded key.
JSON_TIME_NAME = {
    TYPE.NONE: "backward_time",
    TYPE.COMM: "select_comm_time",
    TYPE.BOTH: "both_comm_select_time",
    TYPE.SLCT: "select_rec_time",
    TYPE.FULL: "recompute_time",
}

# JSON key of the coefficient per recomputation kind.
JSON_COEF_NAME = {
    TYPE.NONE: "backward_coef",
    TYPE.SLCT: "select_rec_coef",
    TYPE.BOTH: "both_comm_select_coef",
    TYPE.COMM: "select_comm_coef",
    TYPE.FULL: "recompute_coef",
}


def sums(rec_dict: Dict[TYPE, int]) -> int:
    """Add up the values stored in ``rec_dict`` for every :class:`TYPE` member.

    Every member must be present as a key; a missing one raises ``KeyError``.
    """
    return sum(rec_dict[rec] for rec in TYPE)


def zero_if_none_var(v: Any, i: int, s: int) -> int:
    """Return ``int(v[i][s].varValue)``, or 0 when it is unavailable.

    Zero is returned when the table ``v`` itself is ``None`` or when the
    addressed cell's ``varValue`` is ``None``.
    """
    if v is None:
        return 0
    value = v[i][s].varValue
    return 0 if value is None else int(value)


def zero_if_none(v: Any, i: int, s: int) -> int:
    """Return ``int(v[i][s])``, or 0 when it is unavailable.

    Zero is returned when the table ``v`` itself is ``None`` or when the
    addressed cell is ``None``.
    """
    if v is None:
        return 0
    cell = v[i][s]
    return 0 if cell is None else int(cell)


def yaml_from_internal(vpp: int, pp: int,
                       lp_variables: Dict[TYPE, Any],
                       nass: List[List[int]]) -> Dict[str, List[List[int]]]:
    """Convert solver variables into the MindFormers YAML schema.

    For every index pair ``(i, s)`` the value of each recomputation type is
    read with :func:`zero_if_none_var` and written to the output fields:

    * ``offset``: sum over all types minus ``nass[i][s]``;
    * ``recompute``: the FULL count;
    * ``select_recompute``: SLCT + BOTH + FULL;
    * ``select_comm_recompute``: COMM + BOTH + FULL.

    Args:
        vpp: Number of rows (first index ``i``) to produce.
        pp: Number of entries (second index ``s``) per row.
        lp_variables: Per-type ``[i][s]`` tables whose cells expose
            ``varValue``; a ``None`` table or value counts as zero.
        nass: ``[i][s]`` table subtracted from the per-cell total to form
            the offset. NOTE(review): presumably the baseline layer
            assignment - confirm with the caller.

    Returns:
        A mapping from YAML field name to a 2-D list of ``(vpp, pp)`` integers.
    """
    full_key = YAML_NAME[TYPE.FULL]
    slct_key = YAML_NAME[TYPE.SLCT]
    comm_key = YAML_NAME[TYPE.COMM]

    yaml_out: Dict[str, List[List[int]]] = {
        OFFSET: [],
        full_key: [],
        slct_key: [],
        comm_key: [],
    }
    logger.debug("pp = %s, vpp = %s", pp, vpp)
    for i in range(vpp):
        for rows in yaml_out.values():
            rows.append([])
        for s in range(pp):
            # Read each solver variable once; the offset needs the sum of all of them.
            count = {r: zero_if_none_var(lp_variables[r], i, s) for r in TYPE}
            full_is = count[TYPE.FULL]
            both_is = count[TYPE.BOTH]
            yaml_out[OFFSET][i].append(sum(count.values()) - nass[i][s])
            yaml_out[full_key][i].append(full_is)
            yaml_out[slct_key][i].append(count[TYPE.SLCT] + both_is + full_is)
            yaml_out[comm_key][i].append(count[TYPE.COMM] + both_is + full_is)

    logger.debug("yaml = %s", yaml_out)
    return yaml_out


def internal_from_yaml(vpp: int, pp: int,
                       yaml_in: Dict[str, Any],
                       nass: List[List[int]]) -> Dict[TYPE, List[List[int]]]:
    """Convert a MindFormers YAML schema back into per-type layer counts.
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197

    Returns:
        A mapping from :class:`TYPE` to a 2-D list of ``(vpp, pp)`` integers.
    """
    slct_is = 0
    comm_is = 0
    full_is = 0
    layer_per_recompute: Dict[TYPE, List[List[int]]] = {r: [] for r in TYPE}
    if yaml_in[OFFSET] == 0:
        yaml_in[OFFSET] = [[0] * pp for _ in range(vpp)]

    for rec in [TYPE.SLCT, TYPE.COMM, TYPE.FULL]:
        if (
                YAML_NAME[rec] not in yaml_in
                or yaml_in[YAML_NAME[rec]] is False
                or yaml_in[YAML_NAME[rec]] == 0
        ):
            yaml_in[YAML_NAME[rec]] = [[0] * pp for _ in range(vpp)]
        if yaml_in[YAML_NAME[rec]] is True:
            yaml_in[YAML_NAME[rec]] = [
                [a + b for a, b in zip(list1, list2)]
                for list1, list2 in zip(nass, yaml_in[OFFSET])
            ]

    for i in range(vpp):
        for _, v in layer_per_recompute.items():
            v.append([])
        for s in range(pp):
            slct_is = zero_if_none(yaml_in[YAML_NAME[TYPE.SLCT]], i, s)
            comm_is = zero_if_none(yaml_in[YAML_NAME[TYPE.COMM]], i, s)
            full_is = zero_if_none(yaml_in[YAML_NAME[TYPE.FULL]], i, s)
            layer_per_recompute[TYPE.FULL][i].append(full_is)
            layer_per_recompute[TYPE.BOTH][i].append(
                max(min(slct_is - full_is, comm_is - full_is), 0)
            )
            layer_per_recompute[TYPE.SLCT][i].append(
                max(slct_is - full_is - layer_per_recompute[TYPE.BOTH][i][s], 0)
            )
            layer_per_recompute[TYPE.COMM][i].append(
                max(comm_is - full_is - layer_per_recompute[TYPE.BOTH][i][s], 0)
            )
            layer_per_recompute[TYPE.NONE][i].append(
                (
                    yaml_in[OFFSET][i][s]
                    + nass[i][s]
                    - layer_per_recompute[TYPE.FULL][i][s]
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
                    - layer_per_recompute[TYPE.COMM][i][s]
                )
            )

    logger.debug("layer_per_recompute = %s", layer_per_recompute)
    return layer_per_recompute


def to_list(rec_dict: Dict[TYPE, Any]) -> List[Any]:
    """Collect the values of ``rec_dict`` into a new list.

    The values come out in the dictionary's own iteration order, which equals
    the :class:`TYPE` enum order only when the dictionary was filled in that
    order.
    """
    return [*rec_dict.values()]


def right_extend(ll: List[List[int]], n: int) -> List[List[int]]:
    """Build every combination of a sub-list of ``ll`` followed by a value of ``range(n)``.

    Args:
        ll: List of partially built index vectors.
        n: Number of candidate values; each of ``0 .. n-1`` is appended in turn.

    Returns:
        A new list of ``len(ll) * n`` vectors, grouped by appended value (all
        vectors ending in 0 first, then those ending in 1, and so on). The
        input lists are left untouched.
    """
    return [prefix + [value] for value in range(n) for prefix in ll]


def make_all_indexes_local(used_rec: Dict[TYPE, bool], num_of_interleave: int,
                           all_indexes: List[List[int]], r: TYPE) -> List[List[int]]:
    """Helper behind :func:`make_all_indexes`.

    Starting at type ``r`` and walking to the last :class:`TYPE` member, every
    type flagged in ``used_rec`` adds one position to each index vector via
    :func:`right_extend`; unflagged types add nothing.

    Args:
        used_rec: Flag per recomputation type telling whether it takes part.
        num_of_interleave: Number of candidate values per added position.
        all_indexes: Index vectors built so far (not modified).
        r: First recomputation type to process.

    Returns:
        The completed list of index vectors once the last :class:`TYPE` is reached.
    """
    for rec in TYPE:
        # Types before ``r`` are treated as already handled.
        if rec >= r and used_rec[rec]:
            all_indexes = right_extend(all_indexes, num_of_interleave)
    return all_indexes


def make_all_indexes(used_rec: Dict[TYPE, bool], num_of_interleave: int) -> List[List[int]]:
    """List every index vector holding one value in ``range(num_of_interleave)`` per used type."""
    # Start from a single empty vector and from the first recomputation type.
    seed: List[List[int]] = [[]]
    first_type = TYPE.NONE
    return make_all_indexes_local(used_rec, num_of_interleave, seed, first_type)


def recomputes_from_indexes(used_rec: Dict[TYPE, bool],
                            indexes: List[List[int]]) -> List[Dict[TYPE, Optional[int]]]:
    """Decode index vectors produced by :func:`make_all_indexes` into per-type dictionaries.

    Each vector holds one value per type flagged in ``used_rec``, in
    :class:`TYPE` order; unflagged types map to ``None``. The vectors are read
    by position and left unchanged, so ``indexes`` stays usable after the call.

    Args:
        used_rec: Flag per recomputation type telling whether it takes part.
        indexes: Index vectors, one value per used type.

    Returns:
        One dictionary per index vector, keyed by every :class:`TYPE` member.

    Raises:
        IndexError: If a vector holds fewer values than there are used types.
    """
    recomputes: List[Dict[TYPE, Optional[int]]] = []
    for idx in indexes:
        recompute: Dict[TYPE, Optional[int]] = {r: None for r in TYPE}
        position = 0
        for r in TYPE:
            if used_rec[r]:
                recompute[r] = idx[position]
                position += 1
        recomputes.append(recompute)
    return recomputes


def average(rec_list: List[Dict[TYPE, Optional[float]]]) -> Dict[TYPE, Optional[float]]:
    """Return the per-type mean of a list of per-type recomputation dicts.

    Neither ``rec_list`` nor the dictionaries inside it are modified.

    For each type the non-``None`` values are summed and divided by
    ``len(rec_list)``. If the first dictionary holds ``None`` for a type, the
    result stays ``None`` for it. If only some dictionaries hold ``None``, a
    warning is logged and those entries add nothing to the sum, while the
    divisor remains the full list length.

    Args:
        rec_list: Dictionaries mapping every :class:`TYPE` to a numeric value or ``None``.

    Returns:
        A new dict holding the arithmetic mean for each :class:`TYPE`
        (``None`` propagates), or an empty dict when ``rec_list`` is empty.
    """
    if not rec_list:
        return {}
    num = len(rec_list)
    first, *others = rec_list
    # Work on a copy so the caller's first dictionary keeps its original values.
    totals: Dict[TYPE, Optional[float]] = dict(first)
    for rec_i in others:
        for r in TYPE:
            if totals[r] is not None and rec_i[r] is not None:
                totals[r] = totals[r] + rec_i[r]
            elif not (totals[r] is None and rec_i[r] is None):
                logger.warning(
                    "WARNING: Recomputation %s is not taken into consideration by all body layers",
                    r.name,
                )
    for r in TYPE:
        if totals[r] is not None:
            totals[r] = totals[r] / num
    return totals


def assign_used(values: List[int], unused_rec: List[TYPE]) -> Dict[TYPE, Optional[int]]:
    """Map consecutive ``values`` onto the recomputation types not listed in ``unused_rec``.

    Types are visited in :class:`TYPE` order; types in ``unused_rec`` keep ``None``.
    """
    assignment: Dict[TYPE, Optional[int]] = dict.fromkeys(TYPE)
    used_types = [rec for rec in TYPE if rec not in unused_rec]
    for position, rec in enumerate(used_types):
        assignment[rec] = values[position]
    return assignment


def get_used_list(recompute_considered: Dict[TYPE, bool]) -> List[TYPE]:
    """List, in :class:`TYPE` order, the types whose flag in ``recompute_considered`` is truthy."""
    return [rec for rec in TYPE if recompute_considered[rec]]


def get_unused_list(recompute_considered: Dict[TYPE, bool]) -> List[TYPE]:
    """List, in :class:`TYPE` order, the types that are absent from or falsy in ``recompute_considered``."""
    return [
        rec for rec in TYPE
        if not (rec in recompute_considered and recompute_considered[rec])
    ]


def least_recomputed(recompute_considered: Dict[TYPE, bool]) -> TYPE:
    """Return the first enabled :class:`TYPE` in enum order, or ``TYPE.NONE`` if none is enabled."""
    enabled = (rec for rec in TYPE if recompute_considered[rec])
    return next(enabled, TYPE.NONE)


def most_recomputed(recompute_considered: Dict[TYPE, bool]) -> TYPE:
    """Return the last enabled :class:`TYPE` in enum order, or ``TYPE.FULL`` if none is enabled."""
    enabled = [rec for rec in TYPE if recompute_considered[rec]]
    if not enabled:
        return TYPE.FULL
    return enabled[-1]
hyper_parallel/auto_parallel/sapp_ppb/utils/stage.py
13
14
15
16
17
18
19
20
21
22
23
24
25
# limitations under the License.
# ============================================================================
"""Stage descriptor used by the ILP solver to represent one pipeline stage's workload."""

import sapp_ppb.utils.recompute as Recompute
from sapp_ppb.utils.error import assert_sapp


class Stage:
    """Stage Class to describe a run from a log

    id_ (int): stage id of the run
    nb_stage_ (int): total number of stage present
    nb_layer_ (int): total number of layers assigned to this stage
    nb_layer_rec_ (dict): number of layers per recomputation type
    memory_usage_ (int): memory usage recorded for this stage

    Checked at construction:
    nb_layer_ == sum of nb_layer_rec_ over all recomputation types
    id_ < nb_stage_
    """

    id_: int
    nb_stage_: int
    nb_layer_: int
    nb_layer_rec_: dict[Recompute.TYPE, int]
    memory_usage_: int

    def __init__(self, sid: int, nb_stage: int, nb_layer: int,
                 nb_layer_rec: dict[Recompute.TYPE, int], memory_usage: int) -> None:
        """Build a :class:`Stage` record.

        Args:
            sid: Stage id; must be lower than ``nb_stage``.
            nb_stage: Total number of stages.
            nb_layer: Total number of layers assigned to this stage.
            nb_layer_rec: Per-recomputation-type layer counts; missing keys are
                completed on a private copy, the argument is left unchanged.
            memory_usage: Observed memory usage in MB.
        """
        self.id_ = sid
        self.nb_stage_ = nb_stage
        # nb_layer_ must be set first: the completion below reads it.
        self.nb_layer_ = nb_layer
        self.nb_layer_rec_ = self.complete_nb_layer_rec_(nb_layer_rec)
        self.memory_usage_ = memory_usage
        # Validate the completed copy; the raw argument may lack some types.
        assert_sapp(nb_layer == Recompute.sums(self.nb_layer_rec_),
                    "init stage, nb_layer == (nb_recompute+nb_norecompute)")
        assert_sapp(sid < nb_stage, "init stage, id < nb_stage")

    def complete_nb_layer_rec_(
            self, nb_layer_rec: dict[Recompute.TYPE, int]) -> dict[Recompute.TYPE, int]:
        """Return a copy of ``nb_layer_rec`` in which every :class:`Recompute.TYPE` is present.

        Missing types other than ``NONE`` are set to 0. A missing ``NONE``
        entry is set to ``self.nb_layer_`` minus the sum of the other types
        that were given. The input dictionary is not modified, so one
        dictionary can safely be reused to build several stages.
        """
        completed = dict(nb_layer_rec)
        sum_layers = 0
        for r in Recompute.TYPE:
            if r is Recompute.TYPE.NONE:
                continue
            if r in completed:
                sum_layers += completed[r]
            else:
                completed[r] = 0

        if Recompute.TYPE.NONE not in completed:
            completed[Recompute.TYPE.NONE] = self.nb_layer_ - sum_layers

        return completed

    def same_config(self, other: 'Stage') -> bool:
        """
        Check if self and other have same configuration
        same number of total layers, number of total stages, recompute layers and no recompute
        layers
        """
        return (
            self.nb_layer_ == other.nb_layer_ and self.nb_stage_ == other.nb_stage_ and
            self.nb_layer_rec_ == other.nb_layer_rec_)

    def same_global_config(self, other: 'Stage') -> bool:
        """
        Check if self and other have same global configuration
        only the number of total stages is compared
        """
        return self.nb_stage_ == other.nb_stage_

    def get_index_memory_var(self) -> list[int]:
        """
        Returns memory factors for parameter and
        activation for all recomputation types

        The first entry is nb_layer_; the following ones are the per-type
        layer counts multiplied by (nb_stage_ - id_).
        """
        diff = self.nb_stage_ - self.id_
        return [self.nb_layer_] + Recompute.to_list(
            {r: self.nb_layer_rec_[r] * diff for r in Recompute.TYPE})


def filter_stage_id(stages: list[Stage], stage_id: int) -> list[Stage]:
    """Return, in their original order, the stages whose ``id_`` equals ``stage_id``."""
    return [stage for stage in stages if stage.id_ == stage_id]