Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/auto_parallel/sapp_nd/__init__.py 100%  
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/_backbone.py 100%  
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/logger.py 100%  
hyper_parallel/auto_parallel/sapp_nd/nd/__init__.py 100%  
hyper_parallel/auto_parallel/sapp_nd/nd/balancing_adapter.py 46.6% 34,44,47-49,74-75,82-95,97-107,111-121,123,128,133,136-139,143-144,146-155,159-168,177-179,188-192,201-210,213,216-217,220,225,238,244,251,262-267,321-323,344,346,349,354-360
hyper_parallel/auto_parallel/sapp_nd/nd/common/cost_model_config.py 100%  
hyper_parallel/auto_parallel/sapp_nd/nd/common/hardware.py 71.4% 130,137
hyper_parallel/auto_parallel/sapp_nd/nd/debug.py 17.8% 51,55-76,92,101,114-118,126,130-132,136-138,142-149,154-166,168-170,175-182,187,197-203,208-210,223,228-230,240-252,267-273,277-279,286-299,306-307,309-310,312,316-324,332-340,343-346,353-362,369,372,374,377,380-381,383-384,391-392,394,397,401-403,414-416,418-419,425,428,435-437,439,450-452,460-467,470-479,481-482,487-504,509-519,524-532,534-537,540-546,548-551,554-556,561,564,572-573,578,581,584,587,590,594,600-604,606,609,611-613,615-617,619,625-627,629-633,635-636,638,641,645,647,649,653,655-658,664-666,668-671,676-678,683-689,691-692,697-703,705,708-710,715-718,723-724,730-746,751-754,756-757,759-761,763-764,766,771-774,779-780,786-787,790,797,799,802
hyper_parallel/auto_parallel/sapp_nd/nd/global_config.py 80.8% 39-40,55,73,93,102-104,122,138-139,183-185,211,216,222,228,237-238,240,242,248,250-251
hyper_parallel/auto_parallel/sapp_nd/nd/logger.py 100%  
hyper_parallel/auto_parallel/sapp_nd/nd/parallelize.py 43.4% 54,68,75,85,96,139,159-160,163,168,183,192-198,201,206-207,210,214-218,221-227,229,248,250,304,311-312,321-322,328,335-347,356-357,362,369-373,375,380,382,384-389,391-392,396-399,402-404,410-414,418-420,424-427,430-434,439-443,445-447,451,462-463,466-467,470,475-476,479,485-489,496,500-508,514-515,522,533,538-539,542-543,550-551,557-558,562-563,572,587,614,619,624,627-628,651-654,656-679,684,686-687,698
hyper_parallel/auto_parallel/sapp_nd/nd/run_nd.py 0.0% 17-18,20-24,26-27,33,40,49,56,63,84,99,115,121,135,142,149,157,166,174,180,187,189,195-197,202,204-207,209-212,218-220,226,228,243-244,246
hyper_parallel/auto_parallel/sapp_nd/perf_estimation/__init__.py 100%  
hyper_parallel/auto_parallel/sapp_nd/perf_estimation/comm_time.py 14.0% 42-45,47-48,51,54-55,57-58,66-67,72-74,76-77,79-80,83-84,89-92,112-114,119,137-139,144,155-156,158,160,162-163,166,168,170-171,173-177,179-183,192-194,202-204,206-208,210-212,214-220,222-224,226,231-233,235-239,245,247,251-252,254-255,257-258,264,270,277,283,290-291,296-299,301,304,307,310-311,317,323,326,329,332-338,343-349,351,354,357,366-368,391-394,396-397,399-404,406-409,411-419,421-423,425,430-431,434,438,450-454,459-463,468-472,477-479,484-486,493-495,498,507
hyper_parallel/auto_parallel/sapp_nd/perf_estimation/estimate.py 11.1% 47,49,52-58,60-61,64-65,76-77,79,85-86,88-89,92,96-98,100-106,108-109,120,122-123,127-128,130,137-138,140,142,147,157,160,164-165,170,173,181-183,187-206,210,212,221-223,226,229,232,234,240-244,250-259,263,265,269-270,274,281-283,285-286,292-296,305,308,313,315-318,325-326,332-333,338-340,350,365-367,369,374,379-380,382-385,390-392,395-396,399-400,402-403,405-411,413-414,416,418-419,421-423,425-426,429-430,438,441-442,449,458,460-461,463,471-474,477,479-481,483-487,489-490,492,500-501,507-508,511-513,516,519-520,522,525-526,528-529,532-534,542-544,546,555-557,559,562,569,572-573,581,590,592-593,596-597,599-605,607-613
hyper_parallel/auto_parallel/sapp_nd/perf_estimation/getters.py 22.2% 27,30,32-36,38,43-50,55-57,59,64
hyper_parallel/auto_parallel/sapp_nd/perf_estimation/utils_classes.py 80.0% 69-72,75,83
hyper_parallel/auto_parallel/sapp_nd/nd/balancing_adapter.py
30
31
32
33
34
35
36
37
38
        self.pp = pp
        self.vpp = vpp

    def __str__(self):
        """Render the configuration as ``PP: <pp>, VPP: <vpp>``."""
        return f"PP: {self.pp}, VPP: {self.vpp}"

    def chunk_stage(self):
        """Return ``pp`` multiplied by ``vpp``."""
        stages = self.pp
        chunks = self.vpp
        return stages * chunks
40
41
42
43
44
45
46
47
48
49
50
51
52
53

def infer_pp_and_vpp(offset):
    """Deduce the pipeline configuration that an offset was written for.

    A zero-dimensional offset maps to ``Pipeline(1, 1)``; a one-dimensional
    offset gives one stage per element; a two-dimensional offset gives one
    stage per element of its first row and one chunk per row.

    Raises:
        TypeError: If the offset is none of the three recognised shapes.
    """
    if is_zero_d(offset):
        pp, vpp = 1, 1
    elif is_one_d(offset):
        pp, vpp = len(offset), 1
    elif is_two_d(offset):
        pp, vpp = len(offset[0]), len(offset)
    else:
        raise TypeError(f"Offset {offset} has a wrong type!")
    return Pipeline(pp, vpp)


class BalancingAdapter:
    """Adapt pipeline balancing to a given 'new' pipeline configuration"""
70
71
72
73
74
75
76
77
78
79
        self.prev_pip = infer_pp_and_vpp(offset)

    def treat_pp_list(self, new_pip, stages):
        """Treat 1D offset or recompute config"""
        current_pp_len = len(stages)
        logger.debug(
            "new pp (%d) = prev pp (%d)? l = %s, len(l) = %d",
            new_pip.pp,
            self.prev_pip.pp,
            str(stages),
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
            self.prev_pip.pp,
            str(stages),
            len(stages),
        )
        if new_pip.pp == current_pp_len:
            return stages
        new_l = []
        if new_pip.pp > current_pp_len:
            while new_pip.pp % len(stages) != 0:
                stages.insert(len(stages) // 2, 0)
            logger.debug("stages: %s", str(stages))
            factor = new_pip.pp // current_pp_len
            for s in stages:
                rest = s % factor
                for _ in range(factor):
                    if rest > 0:
                        new_l.append(s // factor + 1)
                        rest -= 1
                    else:
                        new_l.append(s // factor)
        if new_pip.pp < current_pp_len:
            for _ in range(current_pp_len % new_pip.pp):
                stages.insert(len(stages) // 2, 0)
            factor = current_pp_len // new_pip.pp
            for i in range(new_pip.pp):
                total_rec_layers = 0
                for j in range(factor):
                    total_rec_layers += stages[i * factor + j]
                new_l.append(total_rec_layers)
        return new_l

    def treat_vpp_list(self, new_pip, ll):
        """Treat 2D offset or recompute config

        Args:
            new_pip: Target pipeline configuration; ``new_pip.vpp`` is the
                wanted number of rows, and the object is forwarded to
                ``treat_pp_list`` for the per-row treatment.
            ll: List of rows (one list of per-stage ints per chunk). It is
                modified in place (rows are appended, inserted, merged or
                deleted, and single entries incremented).

        Returns:
            A new list holding ``treat_pp_list(new_pip, row)`` for every row
            left in ``ll`` after the row count was adjusted.
        """
        logger.debug("treat_vpp_list start: ll = %s", str(ll))
        prev_vpp = len(ll)
        if prev_vpp < new_pip.vpp:
            # More rows are wanted than currently exist.
            if prev_vpp == 1:
                # A single row is first split in two: a zero row is appended
                # and the value of the last stage is moved into it.
                ll.append([0] * self.prev_pip.pp)
                ll[1][-1] = ll[0][-1]
                ll[0][-1] = 0
                prev_vpp = 2
            # Remaining rows are zero-filled and inserted at the middle index
            # of the (possibly just updated) previous row count.
            for _ in range(new_pip.vpp - prev_vpp):
                ll.insert(prev_vpp // 2, [0] * self.prev_pip.pp)
            logger.debug("B: ll = %s", str(ll))

            # Per-stage totals before the change: an even share term plus the
            # column sum of the stored previous offset.
            # NOTE(review): prev_vpp is 2 here when the input had one row,
            # because of the re-assignment above -- confirm this is intended.
            prev_layer_per_stage = [
                prev_vpp * (self.layers // self.prev_pip.chunk_stage())
                + sum(stages[p] for stages in make_two_d(self.prev_offset))
                for p in range(self.prev_pip.pp)
            ]
            # Same quantity for the enlarged list, using the new row count
            # but still the previous chunk_stage() as divisor.
            new_layer_per_stage = [
                new_pip.vpp * (self.layers // self.prev_pip.chunk_stage())
                + sum(stages[p] for stages in ll)
                for p in range(self.prev_pip.pp)
            ]
            logger.debug(
                "prev_layer_per_stage = %s", str(prev_layer_per_stage)
            )
            logger.debug("new_layer_per_stage = %s", str(new_layer_per_stage))
            # For each stage, walk the rows top to bottom and add 1 to entries
            # below 1 for as long as the new total is under the previous one.
            for s in range(self.prev_pip.pp):
                for v in range(new_pip.vpp):
                    if (
                        prev_layer_per_stage[s] - new_layer_per_stage[s] > 0
                        and ll[v][s] < 1
                    ):
                        ll[v][s] += 1
                        new_layer_per_stage[s] += 1

        elif prev_vpp > new_pip.vpp:
            # Fewer rows are wanted: repeatedly fold the last row into the
            # one before it (element-wise sum) and drop it.
            for _ in range(prev_vpp - new_pip.vpp):
                ll[-2] = [sum(x) for x in zip(ll[-1], ll[-2])]
                del ll[-1]
        logger.debug("C: ll = %s", str(ll))
        # Every remaining row is then adapted along the stage axis.
        new_vpp_list = []
        for stages in ll:
            new_vpp_list.append(self.treat_pp_list(new_pip, stages))
        logger.debug("D: new_vpp_list = %s", str(new_vpp_list))
        return new_vpp_list

    def treat_recompute_list(self, new_pip, recompute):
        """Adapt a recompute list to the given pipeline configuration.

        Args:
            new_pip: Target pipeline configuration; ``new_pip.vpp == 1``
                selects a flat result, any other value a list of rows.
            recompute: Either a flat list of ints or a list of int lists.
                Anything else is returned untouched.

        Returns:
            The result of ``treat_pp_list`` (flat) or ``treat_vpp_list``
            (list of rows), or ``recompute`` itself when it matches neither
            supported shape.
        """
        if all(isinstance(x, int) for x in recompute):
            if new_pip.vpp == 1:
                return self.treat_pp_list(new_pip, recompute)
            # treat_vpp_list indexes its argument as rows of stages, so the
            # flat list is wrapped as a single row, as treat_offset does for
            # a one-dimensional offset. Passing it unwrapped raised TypeError.
            return self.treat_vpp_list(new_pip, [recompute])
        if all(isinstance(x, int) for stages in recompute for x in stages):
            adapted = self.treat_vpp_list(new_pip, recompute)
            # With a single chunk the caller expects the flat form back.
            return adapted[0] if new_pip.vpp == 1 else adapted
        return recompute

    def treat_recompute(self, new_pp, new_vpp):
        """Treat recompute config for the new given Pipeline config"""
        new_pip = Pipeline(new_pp, new_vpp)
173
174
175
176
177
178
179
180
181
182
183
        if not self.from_config:
            return self.default_recompute(
                new_pip, copy.deepcopy(self.prev_recompute)
            )
        if new_pip == self.prev_pip or isinstance(self.prev_recompute, bool):
            return copy.deepcopy(self.prev_recompute)
        return self.treat_recompute_list(
            new_pip, copy.deepcopy(self.prev_recompute)
        )

    def treat_offset(self, new_pp, new_vpp):
184
185
186
187
188
189
190
191
192
193
194
195
196
        """Treat offset for the new given Pipeline config"""
        new_pip = Pipeline(new_pp, new_vpp)
        if not self.from_config:
            return self.make_valid(new_pip, self.default_offset(new_pip))
        offset = copy_offset(self.prev_offset)
        if new_pip == self.prev_pip:
            logger.debug("Same pipeline, no offset change")
            return offset
        logger.debug(
            "change offset %s from PP = %d, VPP = %d, to PP = %d, VPP = %d",
            str(offset),
            self.prev_pip.pp,
            self.prev_pip.vpp,
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
            new_pip.pp,
            new_pip.vpp,
        )

        if is_zero_d(offset):
            offset = []
            for _ in range(new_pip.vpp):
                offset.append([])
                for _ in range(new_pip.pp):
                    offset[-1].append(0)
            return self.make_valid(new_pip, offset)
        if is_one_d(offset):
            if new_pip.vpp == 1:
                return self.make_valid(
                    new_pip, self.treat_pp_list(new_pip, offset)
                )
            return self.make_valid(
                new_pip, self.treat_vpp_list(new_pip, [offset])
            )
        if is_two_d(offset):
            return self.make_valid(
                new_pip, self.treat_vpp_list(new_pip, offset)
            )
        raise TypeError(f"Offset {offset} has a wrong type!")

    def check_offset(self, new_pip, offset):
        """Check offset validity"""
        if is_zero_d(offset):
            return (new_pip.pp == 1) and (new_pip.vpp == 1)
        if is_one_d(offset):
            return (
                len(offset) == new_pip.pp
                and sum(offset) == self.layers % new_pip.pp
234
235
236
237
238
239
240
241
242
                and all(len(stages) == new_pip.pp for stages in offset)
                and sum(sum(stages) for stages in offset)
                == self.layers % (new_pip.chunk_stage())
            )
        return False

    def offset_checker(self, new_pp, new_vpp, offset):
        """Log an error message if offset is invalid"""
        new_pip = Pipeline(new_pp, new_vpp)
240
241
242
243
244
245
246
247
248
    def offset_checker(self, new_pp, new_vpp, offset):
        """Log an error message if offset is invalid"""
        new_pip = Pipeline(new_pp, new_vpp)
        if not self.check_offset(new_pip, offset):
            logger.error(
                "offset %s is wrong!! pp = %d, vpp = %d, L = %d",
                str(offset),
                new_pip.pp,
                new_pip.vpp,
247
248
249
250
251
252
253
254
255
                new_pip.pp,
                new_pip.vpp,
                self.layers,
            )
            return False
        return True

    def make_valid(self, new_pip, offset):
        """Transform an invalid offset into a valid one"""
258
259
260
261
262
263
264
265
266
267
268
269
270
271

        delta = self.layers % (new_pip.chunk_stage()) - sum(flat)
        logger.debug("delta = %s", str(delta))
        if delta < 0:
            for _ in range(-delta):
                top = max(flat)
                for i, _ in enumerate(flat):
                    if flat[i] == top:
                        flat[i] = flat[i] - 1
                        break
        elif delta > 0:
            for _ in range(delta):
                bot = min(flat)
                for i, _ in enumerate(flat):
317
318
319
320
321
322
323
324
325
326
327


def copy_offset(offset):
    """Return a copy of an offset that is safe to modify.

    A zero-dimensional offset is handed back as is; any other offset is
    deep-copied.
    """
    return offset if is_zero_d(offset) else copy.deepcopy(offset)


def is_zero_d(offset):
    """Check if offset is an int"""
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360

def make_one_d(offset):
    """Return the offset as a flat list of ints.

    A zero-dimensional offset becomes ``[0]``, a one-dimensional offset is
    returned unchanged (same object), and a two-dimensional offset is
    flattened row by row.

    Raises:
        TypeError: If the offset is none of the three recognised shapes.
    """
    if is_zero_d(offset):
        return [0]
    if is_one_d(offset):
        return offset
    if not is_two_d(offset):
        raise TypeError(f"Offset {offset} has a wrong type!")
    flat = []
    for stages in offset:
        flat.extend(stages)
    return flat


def make_two_d(offset):
    """Return the offset as a list of int lists.

    A zero-dimensional offset becomes ``[[0]]``, a one-dimensional offset is
    wrapped as a single row, and a two-dimensional offset is returned
    unchanged (same object).

    Raises:
        TypeError: If the offset is none of the three recognised shapes.
    """
    if is_zero_d(offset):
        rows = [[0]]
    elif is_one_d(offset):
        rows = [offset]
    elif is_two_d(offset):
        rows = offset
    else:
        raise TypeError(f"Offset {offset} has a wrong type!")
    return rows
hyper_parallel/auto_parallel/sapp_nd/nd/common/hardware.py
126
127
128
129
130
131
132
133
134
    def __init__(self, number, device):
        self.number = number
        if isinstance(device, int):
            if device == 2:
                self.device = Device_A2
            elif device == 3:
                self.device = Device_A3
            else:
                raise ValueError(f"Ascend A{device} unknown")
133
134
135
136
137
138
139
140
141
            else:
                raise ValueError(f"Ascend A{device} unknown")
        elif isinstance(device, str):
            if device not in device_map:
                raise ValueError(
                    f"Device {device} is not supported. "
                    f"Supported devices: {list(device_map.keys())}"
                )
            self.device = device_map[device]
hyper_parallel/auto_parallel/sapp_nd/nd/debug.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
    TOTAL = auto()
    MEMORY = auto()

    def __str__(self):
        """Use the bare member name as the string form."""
        member_name = self.name
        return member_name

    def short_name(self):
        """Return the abbreviated label of this component.

        Members without a dedicated abbreviation are labelled ``"Perf"``.
        """
        abbreviations = {
            "FW_COMPUTE": "FW",
            "BW_COMPUTE": "BW",
            "RECOMPUTE": "Rec",
            "DP_COMM": "DP",
            "MP_COMM": "MP",
            "EP_COMM": "EP",
            "CP_COMM": "CP",
            "PP_COMM": "P2P",
            "BUBBLE": "BBL",
            "MEMORY": "MEM",
        }
        return abbreviations.get(self.name, "Perf")


class RealParts(Enum):
    """decomposition of performance"""
88
89
90
91
92
93
94
95
96
    IDLE = auto()
    TOTAL = auto()

    def __str__(self):
        """Use the lower-cased member name as the string form."""
        lowered = self.name.lower()
        return lowered


class MemParts(Enum):
    """decomposition of memory"""
 97
 98
 99
100
101
102
103
104
105

    TOTAL = auto()

    def __str__(self):
        """Use the bare member name as the string form."""
        member_name = self.name
        return member_name


class Debug:
    """Debugging tools"""
110
111
112
113
114
115
116
117
118
119
120
121
122
        info_type,
        enable=True,
        output_file="debug.csv",
    ):
        self.enable = enable
        if self.enable:
            self.parallel_dimensions = parallel_dimensions
            self.info = {p: 0 for p in info_type}
            self.output_file = (
                os.path.dirname(os.path.abspath(__file__))
                + "/output/"
                + output_file
            )
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
            )

    def is_enabled(self):
        """Return the ``enable`` flag of this instance."""
        enabled = self.enable
        return enabled

    def column_titles(self):
        """Build the header line of the debug CSV.

        Returns:
            The names of the parallel dimensions followed by the keys of
            ``self.info``, each converted with ``str``, joined by commas and
            terminated by a newline.
        """
        # list() makes the concatenation work whether keys() returns a list
        # or a dict view (a view cannot be added to a list).
        titles = list(self.parallel_dimensions.keys()) + list(self.info.keys())
        return ",".join(str(title) for title in titles) + "\n"

    def values(self):
        """Build one data line of the debug CSV.

        The parallel-dimension values are written with ``str``; the info
        values are truncated to ints first. The line ends with a newline.
        """
        fields = [str(degree) for degree in self.parallel_dimensions.values()]
        fields.extend([str(int(amount)) for amount in self.info.values()])
        return ",".join(fields) + "\n"

    def write(self):
        """Append the current values to the output CSV when enabled.

        The parent directory is created if needed, and the column titles are
        written first when the file did not exist yet.
        """
        if not self.enable:
            return
        os.makedirs(os.path.dirname(self.output_file), exist_ok=True)
        needs_header = not os.path.exists(self.output_file)
        logger.info("debug written")
        with open(self.output_file, "a", encoding="utf-8") as outfile:
            if needs_header:
                outfile.write(self.column_titles())
            outfile.write(self.values())


def pastel(color, l_delta=0.0, lbl=None, sat=None):
    """Return a pastel RGB variant of the input color.

    ``"white"`` and ``"black"`` map to fixed tuples. Any other color keeps
    its hue; lightness is ``lbl`` (0.7 when omitted) plus ``l_delta``, and
    saturation is ``sat`` (0.6 when omitted).
    """
    if color == "white":
        return (1.0, 1.0, 1.0)
    if color == "black":
        return (0.5, 0.5, 0.5)
    # Named colors are translated through the cnames table; anything that
    # is not a key of that table is passed on unchanged.
    resolved = mc.cnames.get(color, color)
    hue, _, _ = colorsys.rgb_to_hls(*mc.to_rgb(resolved))
    lightness = (0.7 if lbl is None else lbl) + l_delta
    saturation = 0.6 if sat is None else sat
    return colorsys.hls_to_rgb(hue, lightness, saturation)


def near_white(color, ratio):
    """Blend a color towards white for use as a background.

    Each RGB channel ``c`` becomes ``c + (1 - c) * ratio``. The string
    ``"white"`` is returned when the color conversion yields ``None``.
    """
    rgb = mc.to_rgb(color)
    if rgb is None:
        return "white"
    red, green, blue = rgb
    return (
        red + (1 - red) * ratio,
        green + (1 - green) * ratio,
        blue + (1 - blue) * ratio,
    )


def dim_color(dim, default="black"):
    """Color of parallel dimensions for plot"""
    color = {
        Dim.DP: "orange",
        Dim.OP: "orange",
        Dim.TP: "red",
        Dim.EP: "blue",
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
        Dim.PP: "green",
        Dim.VPP: "green",
        Dim.MBN: "green",
    }
    try:
        dim = Dim.get_dim(dim)
        if dim in color:
            return color[dim]
        return default
    except ValueError:
        return default


def gen_colors(categories):
    """Color of each time component"""
    compute_color = "purple"
    idle_color = "grey"
    col_d = {
        str(PerfParts.FW_COMPUTE): pastel(compute_color, -0.2),
        str(PerfParts.BW_COMPUTE): pastel(compute_color, -0.1),
        str(PerfParts.RECOMPUTE): pastel(compute_color),
        str(PerfParts.DP_COMM): pastel(dim_color(Dim.DP)),
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
        str(PerfParts.BUBBLE): pastel(dim_color(Dim.PP), -0.15),
        "IDLE": idle_color,
        "COMPUTATION": pastel(compute_color, -0.2),
    }
    return [col_d[cat] for cat in categories]


def set_twin_handles(ax1, data_frame, dbg_cols):
    """Set legend for estimation and real"""
    handle1, label1 = ax1.get_legend_handles_labels()
    ax2 = plt.twinx()
    data_frame[dbg_cols].plot.bar(
        stacked=True,
        sharex=True,
        ax=ax2,
        position=0,
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
        width=0.4,
        rot=0,
    )

    handle2, label2 = ax2.get_legend_handles_labels()  # type: ignore
    for handle in handle2:
        if handle not in handle1:
            handle1.append(handle)
    for lbl in label2:
        if lbl not in label1:
            label1.append(lbl)
    handles = handle1
    labels = label1
    plt.legend(handles, labels, loc="upper left", bbox_to_anchor=(1, 1))
    leg = ax2.get_legend()
    pp_color = gen_colors(["PP_COMM"])[0]
    leg.legend_handles[-1].set_facecolor(pp_color)  # type: ignore


class Plot:
    """plot ND top configs"""
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
    dbg_cols: list[str]
    top: int

    def __init__(self, title, rows, debug_parts, top=None):
        """Prepare an empty plot description.

        Args:
            title: Plot title, stored as given.
            rows: List of row labels; ``"MEM"`` is appended as the last row.
            debug_parts: Iterable whose items are converted with ``str`` to
                form the debug column names.
            top: Stored as ``self.top``; 20 is used when omitted.
        """
        self.title = title
        self.top = 20 if top is None else top
        self.row_title = rows + ["MEM"]
        self.dbg_cols = [str(part) for part in debug_parts]
        self.col_title, self.cell_text, self.data = [], [], []

    def make_table(self):
        """Make table below plot with each parallelism degree"""
        self.cell_text = list(map(list, zip(*self.cell_text)))  # transpose
        max_rows = list(map(max, map(partial(map, float), self.cell_text)))
        the_table = plt.table(
            cellText=self.cell_text,
            rowLabels=self.row_title,
            colLabels=self.col_title,
            cellLoc="center",
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
            colLabels=self.col_title,
            cellLoc="center",
            loc="bottom",
        )
        row_colors = list(map(dim_color, self.row_title))
        for row in range(len(self.row_title)):
            cell = the_table[row + 1, -1]
            cell.set_edgecolor("none")
            cell.get_text().set_color(row_colors[row])
            cell.set_text_props(fontproperties=FontProperties(weight="bold"))
            for col in range(len(self.cell_text[0])):
                cell = the_table[row + 1, col]
                value = float(str(cell.get_text().get_text()))
                try:
                    ratio = 1 - (value / max_rows[row])
                except ZeroDivisionError:
                    ratio = 0
                logger.debug(
                    "tmax = %s, ratio = %f, col=%s, newcolor=%s",
                    str(max_rows[row]),
                    ratio,
                    str(mc.to_rgb(row_colors[row])),
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
                    ratio,
                    str(mc.to_rgb(row_colors[row])),
                    str(near_white(row_colors[row], ratio)),
                )
                cell.set_facecolor(near_white(pastel(row_colors[row]), ratio))
                cell.set_edgecolor("none")

        for col in range(len(self.cell_text[0])):
            the_table[0, col].set_edgecolor("none")

        the_table.scale(xscale=1, yscale=1.2)  # +len(rows)/5)

    def close(self, output_path, filename):
        """Finish the current figure and save it as a PDF.

        Tick labels are removed, the x range is fitted to the number of data
        entries, the title is set when there is one, and the figure is saved
        to ``<output_path>/<filename>.pdf`` before being cleared.
        """
        axes = plt.gca()
        axes.set_xticklabels([])
        axes.set_yticklabels([])
        plt.xlim([-0.5, len(self.data) - 0.5])
        if self.title is not None:
            plt.title(self.title)
        # Bottom margin grows with the number of table rows.
        bottom_margin = 0.047 * (2 + len(self.row_title))
        plt.subplots_adjust(left=0.1, bottom=bottom_margin)
        plotfile = os.path.join(output_path, filename + ".pdf")
        plt.savefig(plotfile, bbox_inches="tight")
        plt.clf()

    def parse_data(
        self,
        configs_estimated,
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
        configs_estimated,
        **kwargs,
    ):
        """Parse test data for plot"""
        real_data = kwargs.get("real_data", None)
        plot_idle = kwargs.get("plot_idle", False)
        min_e = configs_estimated[0][2]
        i = 0
        for cfg_e in configs_estimated:
            self.cell_text.append(cfg_e[0].values() + [cfg_e[1]])
            self.col_title.append("")
            try:
                self.data.append(
                    tuple([cfg_e[0], cfg_e[2], cfg_e[3]] + cfg_e[4])
                )
                if real_data is not None:
                    waits = cfg_e[5]
                    logger.info(waits)
                    wait_list = [
                        waits["comp"],
                        waits["dp_wait"],
                        waits["mp_wait"],
                        waits["ep_wait"],
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
                        waits["mp_wait"],
                        waits["ep_wait"],
                        waits["BUBBLE"],
                    ]
                    if plot_idle:
                        wait_list.append(waits["IDLE"])
                    real_data.append(tuple(wait_list))
            except IndexError:
                score = cfg_e[2]
                if i >= self.top or (min_e is not None and score > min_e * 20):
                    self.cell_text.pop()
                    break
                self.data.append(tuple([cfg_e[0], score] + cfg_e[3]))
                i += 1


def plot_nd(
    configs_estimated, output_path, debug_parts, title=None, max_num=None
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
def plot_nd(
    configs_estimated, output_path, debug_parts, title=None, max_num=None
):
    """Draw the estimated configurations as stacked bars with a table.

    The row labels come from the keys of the first configuration, and the
    figure is closed through ``Plot.close`` under the name ``"results"``.
    """
    dimension_names = configs_estimated[0][0].keys()
    plot = Plot(title, dimension_names, debug_parts, top=max_num)
    plot.parse_data(configs_estimated)

    frame_columns = ["config", "estim"] + plot.dbg_cols
    data_frame = pd.DataFrame(plot.data, columns=frame_columns)
    bar_colors = gen_colors(plot.dbg_cols)
    axis = data_frame[plot.dbg_cols].plot.bar(
        stacked=True, color=bar_colors, width=0.4, rot=0
    )
    axis.set_ylim(ymin=1)
    axis.legend(loc="upper left", bbox_to_anchor=(1, 1))

    plot.make_table()
    plot.close(output_path, "results")


def plot_vs_real(
    configs_estimated, csv_f, output_path, debug_parts, title=None
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
def plot_vs_real(
    configs_estimated, csv_f, output_path, debug_parts, title=None
):
    """Draw the real time next to the stacked estimation for each config.

    The output file is named after the stem of ``csv_f``.
    """
    dimension_names = configs_estimated[0][0].keys()
    plot = Plot(title, dimension_names, debug_parts)
    plot.parse_data(configs_estimated)

    frame_columns = ["config", "Real", "estim"] + plot.dbg_cols
    data_frame = pd.DataFrame(plot.data, columns=frame_columns)
    real_series = data_frame["Real"]
    ax1 = real_series.plot.bar(
        position=1.1, width=0.4, secondary_y="real", color="grey", rot=0
    )

    set_twin_handles(ax1, data_frame, plot.dbg_cols)
    plot.make_table()
    output_name = Path(os.path.basename(csv_f)).stem
    plot.close(output_path, output_name)


def plot_vs_real_comm_classified(
    configs_estimated,
410
411
412
413
414
415
416
417
418
419
420
421
422
423
    debug_parts,
    **kwargs,
):
    """Plot estimation vs real detailed time"""
    plot_idle = kwargs.get("plot_idle", False)
    title = kwargs.get("title", None)
    real_data = []

    plot = Plot(title, configs_estimated[0][0].keys(), debug_parts)
    plot.parse_data(
        configs_estimated,
        real_data=real_data,
        plot_idle=plot_idle,
    )
421
422
423
424
425
426
427
428
429
430
431
432
        real_data=real_data,
        plot_idle=plot_idle,
    )

    data_frame = pd.DataFrame(
        plot.data, columns=(["config", "real", "estim"] + plot.dbg_cols)
    )
    real_cols = [
        "COMPUTATION",
        "DP_COMM",
        "MP_COMM",
        "EP_COMM",
431
432
433
434
435
436
437
438
439
440
441
442
443
        "MP_COMM",
        "EP_COMM",
        "BUBBLE",
    ]
    if plot_idle:
        real_cols.append("IDLE")
    real_df = pd.DataFrame(real_data, columns=real_cols)

    ax1 = real_df[real_cols].plot.bar(
        stacked=True,
        sharex=True,
        position=1,
        secondary_y="real",
446
447
448
449
450
451
452
453
454
455
        rot=0,
        legend=False,
    )

    set_twin_handles(ax1, data_frame, plot.dbg_cols)
    plot.make_table()
    plot.close(
        output_path,
        Path(os.path.basename(csv_f)).stem,
    )
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568


def correlation_topk(configs_estimated, csv_f):
    """Compute the Pearson correlation and a top-k score for the estimates.

    Each entry of ``configs_estimated`` is a 5-tuple whose third and fourth
    items are the real time and the estimated score.

    Returns:
        ``(correl, topk)``. ``topk`` counts the leading estimates that equal
        the minimum of the estimates from their position onwards; when the
        first estimate already fails, it is minus the number of leading
        estimates that fail.
    """
    times, estims = [], []
    for _, _, real_time, estimate, _ in configs_estimated:
        times.append(real_time)
        estims.append(estimate)
    correl = pearsonr(times, estims).statistic  # type: ignore
    if isnan(correl):
        logger.critical(
            "An input array is constant: %s or %s", str(times), str(estims)
        )

    def is_suffix_min(idx):
        # True when the estimate at idx is the smallest from idx onwards.
        return estims[idx] == min(estims[idx:])

    count = len(estims)
    topk = 0
    while topk < count and is_suffix_min(topk):
        topk += 1
    if topk == 0:
        misplaced = 0
        while misplaced < count and not is_suffix_min(misplaced):
            misplaced += 1
        topk = -misplaced

    logger.info("Correlation for file %s is: %.3f", csv_f, correl * 100)
    return correl, topk


def get_real_data(csv_f):
    """Read the measured configurations of a CSV file.

    Every row must have a ``time`` column. The other columns are turned
    into ``(dim, value)`` pairs; columns for which the conversion raises
    ``ValueError`` are ignored.

    Returns:
        ``(configs, row_num)`` where ``configs`` is a list of
        ``(Dim.Dimensions, real_time)`` tuples and ``row_num`` is the number
        of rows read.
    """
    configs = []
    row_num = 0
    with open(csv_f, newline="", encoding="utf-8") as csv_file:
        for row_num, row in enumerate(csv.DictReader(csv_file), start=1):
            logger.info(row)
            real_time = float(row.pop("time"))
            pairs = []
            for dim_str, value in row.items():
                try:
                    dim = Dim.get_dim(dim_str)
                    logger.debug("%s : %s", str(dim), str(dim.from_str(value)))
                    pairs.append((dim, dim.from_str(value)))
                except ValueError:
                    continue
            configs.append((Dim.Dimensions(pairs), real_time))
    return configs, row_num


def get_diff_dims(csv_f):
    """List the dimensions whose column holds more than one distinct value.

    Columns of the CSV file for which the lookup raises ``ValueError`` are
    ignored.
    """
    varying = []
    for column, degrees in pd.read_csv(csv_f).items():
        try:
            dim = Dim.get_dim(column)
            if len(set(degrees)) > 1:
                varying.append(dim)
        except ValueError:
            continue
    return varying


def get_comm_classified_data(csv_f, plot_idle=False):
    """Load configurations with their per-component times from a csv file.

    Each row must provide a ``time`` column. The other columns are sorted
    by name: a name containing ``"wait"`` is stored under that name, a name
    containing ``"comp"`` is stored under ``"comp"``, and anything else is
    treated as a dimension and parsed through ``Dim.get_dim``. The value
    stored under ``str(RealParts.PP_WAIT)`` is also exposed as ``"BUBBLE"``.

    Args:
        csv_f: path of the csv file to read.
        plot_idle: when true, add an ``"IDLE"`` entry equal to the total
            time minus the sum of all wait and computation values.

    Returns:
        A list of ``(Dim.Dimensions, time, components)`` tuples in file
        order, ``components`` being the dict described above.
    """
    parsed_rows = []
    with open(csv_f, newline="", encoding="utf-8") as handle:
        for record in csv.DictReader(handle):
            logger.info(record)
            elapsed = float(record.pop("time"))
            dims = []
            breakdown = {}
            accounted = 0

            for column, raw in record.items():
                is_wait = "wait" in column
                if not is_wait and "comp" not in column:
                    # Neither a wait nor a computation time: a dimension.
                    logger.info("d = %s, v = %s", column, raw)
                    dim = Dim.get_dim(column)
                    dims.append((dim, dim.from_str(raw)))
                    continue
                amount = float(raw)
                if is_wait:
                    logger.info(
                        "Comm_wait = %s, v = %f", column, amount
                    )
                    breakdown[column] = amount
                else:
                    logger.info("Computation = %f", amount)
                    breakdown["comp"] = amount
                accounted += amount
            breakdown["BUBBLE"] = breakdown[str(RealParts.PP_WAIT)]
            if plot_idle:
                breakdown["IDLE"] = elapsed - accounted
                logger.info(
                    "idle = total time - total waits = %.3f - %.3f",
                    elapsed,
                    accounted,
                )
            parsed_rows.append((Dim.Dimensions(dims), elapsed, breakdown))
    return parsed_rows


def estimation_in_real_parts(
    estimations_in_real_components, estimations, score
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
    estimations_in_real_components, estimations, score
):
    """Transform the estimation components into
    the RealParts components for comparison with real time"""
    estimations_in_real_components[RealParts.TOTAL].append(score)
    estimations_in_real_components[RealParts.COMP].append(
        estimations[PerfParts.FW_COMPUTE.value - 1]
        + estimations[PerfParts.BW_COMPUTE.value - 1]
        + estimations[PerfParts.RECOMPUTE.value - 1]
    )
    estimations_in_real_components[RealParts.DP_WAIT].append(
        estimations[PerfParts.DP_COMM.value - 1]
    )
    estimations_in_real_components[RealParts.MP_WAIT].append(
        estimations[PerfParts.MP_COMM.value - 1]
    )
    estimations_in_real_components[RealParts.CP_WAIT].append(
        estimations[PerfParts.CP_COMM.value - 1]
    )
    estimations_in_real_components[RealParts.EP_WAIT].append(
        estimations[PerfParts.EP_COMM.value - 1]
    )
    estimations_in_real_components[RealParts.PP_WAIT].append(
        estimations[PerfParts.BUBBLE.value - 1]
        + estimations[PerfParts.PP_COMM.value - 1]
    )
    return estimations_in_real_components


def real_in_parts(parts, real, time):
    """Transform the real time components into
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623

def real_in_parts(parts, real, time):
    """Append one measured sample to the per-``RealParts`` series.

    ``time`` is appended to the ``RealParts.TOTAL`` series. For every other
    part except ``RealParts.IDLE``, the value found in ``real`` under
    ``str(part)`` is appended, or ``0`` (with a warning) when it is missing.
    ``"op_wait"`` and ``"sp_wait"`` entries of ``real``, when present, are
    added onto the freshly appended DP and MP wait values respectively.

    Returns:
        ``parts``, updated in place.
    """
    parts[RealParts.TOTAL].append(time)
    for part in RealParts:
        if part in (RealParts.TOTAL, RealParts.IDLE):
            continue
        key = str(part)
        if key in real:
            parts[part].append(real[key])
            continue
        logger.warning(
            "part = %s not in real keys = %s", part, real.keys()
        )
        parts[part].append(0)

    # Extra wait columns folded into an existing part.
    folded = (
        ("op_wait", RealParts.DP_WAIT),
        ("sp_wait", RealParts.MP_WAIT),
    )
    for column, target in folded:
        if column in real:
            parts[target][-1] += real[column]

    return parts


def correlation_with_classified_comms(configs_estimated):
    """Computes correlation and distance
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662

def correlation_with_classified_comms(configs_estimated):
    """Computes correlation and distance
    between components time & estimation"""
    score_classified = {}
    time_classified = {}
    distances = {}

    for wait in RealParts:
        if wait not in {RealParts.IDLE}:
            score_classified[wait] = []
            time_classified[wait] = []
            distances[wait] = []

    topk = 0
    still_top_k = True

    for i, (_, _, time, score, values, real_values) in enumerate(
        configs_estimated
    ):
        if (
            still_top_k
            and score == (min(configs_estimated[i:], key=lambda t: t[3]))[3]
        ):
            topk += 1
        else:
            still_top_k = False

        score_classified = estimation_in_real_parts(
            score_classified, values, score
        )

        time_classified = real_in_parts(time_classified, real_values, time)

        square_distances_sum = 0
        for wait in RealParts:
            if wait not in {RealParts.TOTAL, RealParts.IDLE}:
                distance = (
                    time_classified[wait][-1]
                    / time_classified[RealParts.TOTAL][-1]
                    - score_classified[wait][-1]
                    / score_classified[RealParts.TOTAL][-1]
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
                    / time_classified[RealParts.TOTAL][-1]
                    - score_classified[wait][-1]
                    / score_classified[RealParts.TOTAL][-1]
                )
                square_distances_sum += distance * distance
                distances[wait].append(abs(distance))
        distances[RealParts.TOTAL] = sqrt(square_distances_sum)

    correls = {}
    for wait in RealParts:
        pearson_wait(correls, time_classified, score_classified, wait)
    return correls, distances, topk, len(configs_estimated)


def color_diff(diff):
    """Describe a percentage change as an ANSI-colored phrase.

    A strictly positive ``diff`` is reported in green as an improvement;
    any other value is reported in red as a worsening, with its sign
    flipped.
    """
    if diff > 0:
        color, verb, amount = "\033[92m", "improved", diff
    else:
        color, verb, amount = "\033[91m", "worsened", -diff
    return f"{color} {verb} by {amount:.3f}%\033[00m"


def color_correl(correlation):
    """Format a correlation as a percentage wrapped in ANSI color codes.

    Values above 0.9 are green, values below 0.5 are red, anything else
    uses the default color. Negative values get no leading space; every
    other case is prefixed with one space.
    """
    percent = f"{correlation*100:.3f}%"
    if correlation > 0.9:
        lead, color = " ", "\033[92m"
    elif correlation < 0:
        lead, color = "", "\033[91m"
    elif correlation < 0.5:
        lead, color = " ", "\033[91m"
    else:
        lead, color = " ", "\033[00m"
    return f"{lead}{color}{percent}\033[00m "


def print_diff(case, prev, new, **kwargs):
    """Log how the correlation of ``case`` moved from ``prev`` to ``new``.

    The change is expressed in percentage points; when its magnitude is
    below 0.1 only the new correlation is shown, otherwise the change is
    shown as well.

    Keyword Args:
        topk: optional rank; appended as ``topk/total`` when both are given.
        total: optional population size paired with ``topk``.
        tabsize: tab width used to expand the message (default 40).
    """
    delta = (new - prev) * 100
    if -0.1 < delta < 0.1:
        text = f"{case} \tcorrelation :{color_correl(new)}  \033[00m\033[00m"
    else:
        text = (
            f"{case} \tcorrelation ({color_correl(new)}) is{color_diff(delta)}"
        )
    rank = kwargs.get("topk", None)
    population = kwargs.get("total", None)
    if not (rank is None or population is None):
        text = f"{text}   topk = {rank}/{population}"
    logger.output(text.expandtabs(kwargs.get("tabsize", 40)))


def get_distance_i(part, data_i):
    """Return the distance recorded for ``part`` in one result tuple.

    ``data_i`` is a 4-tuple whose second element maps parts to distances.
    The ``RealParts.TOTAL`` entry is returned as is; any other entry is a
    sequence and its arithmetic mean is returned.
    """
    _, per_part, _, _ = data_i
    entry = per_part[part]
    if part is RealParts.TOTAL:
        return entry
    return sum(entry) / len(entry)


def get_correl_i(part, data_i):
    """Return the correlation recorded for ``part`` in one result tuple.

    ``data_i`` is a 4-tuple whose first element maps parts to correlations;
    a missing part raises ``KeyError``.
    """
    correlations, _distances, _topk, _total = data_i
    return correlations[part]


def print_part_x_file(data, fun):
    """prints a metric computed by fun for
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784

def print_part_x_file(data, fun):
    """Build a text table of ``fun(part, entry)`` for each part and entry.

    One line is produced per ``RealParts`` member except ``IDLE``. Each
    cell is the metric as a percentage with one decimal; an entry for which
    ``fun`` raises ``KeyError`` is rendered as ``-``. When at least one
    cell was computed, the line ends with the average of those cells.

    Returns:
        The table as a single string (each line starts with a newline).
    """
    pieces = []
    for part in RealParts:
        if part is RealParts.IDLE:
            continue
        pieces.append("\n" + str(part) + "\t")
        running = 0
        counted = 0
        for entry in data:
            try:
                metric = fun(part, entry)
            except KeyError:
                pieces.append("\t  -")
            else:
                pieces.append(f"\t{(metric*100):.1f}%")
                running += metric
                counted += 1
        if counted > 0:
            pieces.append(f"\t\t{(running/counted)*100:.1f}%")
    return "".join(pieces)


def print_correlations_classified(data):
    """Log a report comparing estimations with detailed profiling.

    ``data`` holds one 4-tuple per file, ending with ``(top_k, total)``.
    The report contains a header with one column per file plus an average
    column, the correlation table, the ``top_k/total`` line and the
    distance table.
    """
    pieces = ["\n\t"]
    pieces.extend(
        "\tFile " + str(rank) for rank, _ in enumerate(data, start=1)
    )
    pieces.append("\t\tavg")

    pieces.append("\nCorrelation (higher is better)")
    pieces.append(print_part_x_file(data, get_correl_i))

    pieces.append("\ntop_k\t")
    for _, _, top_k, total in data:
        pieces.append("\t" + str(top_k) + "/" + str(total))

    pieces.append("\n\nEuclidean Distance (lower is better)")
    pieces.append(print_part_x_file(data, get_distance_i))

    logger.output("".join(pieces))


def is_constant(array):
    """Tell whether every element of ``array`` equals its first element.

    An empty ``array`` counts as constant.
    """
    if len(array) == 0:
        return True
    reference = array[0]
    for element in array:
        if not element == reference:
            return False
    return True


def pearson_wait(correls, real, estim, wait):
    """Compute Pearson correlation if inputs are not empty"""
    if wait not in {RealParts.IDLE}:
        logger.debug(
            "correlation for %s between real = %s && estim = %s",
            str(wait),
            str(real[wait]),
            str(estim[wait]),
782
783
784
785
786
787
788
789
790
791
792
793
794
            str(wait),
            str(real[wait]),
            str(estim[wait]),
        )
        if not is_constant(real[wait]) and not is_constant(estim[wait]):
            pearson = pearsonr(
                real[wait], estim[wait]
            ).statistic  # type: ignore
            logger.info(
                "correlation[%s] of real %s vs estim %s = %f",
                wait,
                real[wait],
                estim[wait],
793
794
795
796
797
798
799
800
801
802
803
804
                real[wait],
                estim[wait],
                pearson,
            )
            correls[wait] = pearson
        else:
            logger.warning(
                "either estim[%s] = %s is constant", wait, str(estim[wait])
            )
            logger.warning(
                "or      real[%s] = %s is constant", wait, str(real[wait])
            )
hyper_parallel/auto_parallel/sapp_nd/nd/global_config.py
35
36
37
38
39
40
41
42
43
44
        if dimensions is not None:
            logger.debug("dimensions = %s", str(dimensions))
            self.dimensions = dimensions
        else:
            logger.debug("dimensions = %s", str(Dim.ALL_DIMS))
            self.dimensions = Dim.ALL_DIMS.copy()
        logger.debug("self.dimensions = %s", str(self.dimensions))
        logger.debug("layer_num_for_offset = %d", self.layer_num_for_offset())
        logger.debug("total layer num = %d", self.total_layer_num())
        self.balancing = BA.BalancingAdapter(
51
52
53
54
55
56
57
58
59
    def dim_val(self, dim, parallel_config):
        """Return the value of ``dim`` for the given parallel configuration.

        The value comes from ``parallel_config`` when it holds ``dim``;
        otherwise it is read from ``self.ccfg`` through ``dim.from_config``.
        """
        if not parallel_config.has_dim(dim):
            return dim.from_config(self.ccfg)
        return parallel_config.val(dim)

    def global_batch_size(self, parallel_config):
        """Compute global batch size from hyperparameters"""
        dp = self.dim_val(Dim.DP, parallel_config)
69
70
71
72
73
74
75
76
    def layer_num_for_offset(self):
        """Return the layer count used for offsets.

        Starts from ``ccfg.n_lay``, adds 2 when ``ccfg.emb_out_in_offset``
        is set and adds ``ccfg.n_mtp`` when ``ccfg.is_mtp_in_offset`` is set.
        """
        cfg = self.ccfg
        count = cfg.n_lay
        if cfg.emb_out_in_offset:
            count = count + 2
        if cfg.is_mtp_in_offset:
            count = count + cfg.n_mtp
        return count
89
90
91
92
93
94
95
96
97
        new_offset = self.balancing.treat_offset(new_pp, new_vpp)
        logger.debug("adapted offset: %s", str(new_offset))
        ok = self.balancing.offset_checker(new_pp, new_vpp, new_offset)
        if not ok:
            logger.error("Offset {%s} NOT VALID", str(new_offset))
        return new_offset, new_recompute_config

    def adapt_config(self, pp, vpp):
        """Adapt configuration to different parallel config"""
 98
 99
100
101
102
103
104
105
106
107
108
        return self.adapt_config_balancing(pp, vpp)

    def write(self, folder, parallel_config):
        """Dump the configuration into ``folder``.

        The file is named after ``parallel_config.unique_name()`` and is
        written by ``self.ccfg.config.dump`` (a yaml file according to the
        original note; the format is not visible here). Nothing is written
        when ``folder`` is falsy.
        """
        if not folder:
            return
        self.ccfg.config.dump(parallel_config.unique_name(), folder)

    def moe_valid(self, parallel_config):
        """Check whether  the model is MoE"""
        expert_num = self.ccfg.n_exp
118
119
120
121
122
123
124
125
126
                dp,
                mp,
            )
            return ep <= min(expert_num, dp * mp)
        return True

    def make_parallel_config_args(self, **kwargs):
        """Create a parallel config from parallel values"""
        logger.debug("dimensions considered: %s", str(self.dimensions))
134
135
136
137
138
139
140
141
142
143
            Dim.MBN not in self.dimensions
            and Dim.PP in self.dimensions
            and (Dim.DP in self.dimensions or Dim.MBS in self.dimensions)
        ):
            dims.append((Dim.MBN, kwargs.get(Dim.MBN.lname())))
            self.dimensions.append(Dim.MBN)
        return Dim.Dimensions(dims, all_dims=self.dimensions)

    def make_parallel_config(self, dtpc_p, mbsn, evos_p):
        """Create a parallel config from parallel values"""
179
180
181
182
183
184
185
186
187
188
189
                    "search in predefined arch_hooks"
                )
                check_and_apply_custom_hook(self.ccfg)
            else:
                logger.info("Apply hooks")
                hook = list(self.ccfg.hooks_dict.values())[0]
                hook(self.wrap)

        return ok

    def space(self, dim, divide, reverse=False):
207
208
209
210
211
212
213
214
215
216
217
218
219
220
                str(dim),
                str(Hard.all_divisors(divide, reverse=reverse)),
            )
            return Hard.all_divisors(divide, reverse=reverse)
        logger.debug(
            "Space of original dim %s is [%s]",
            str(dim),
            str(dim.from_config(self.ccfg)),
        )
        return [dim.from_config(self.ccfg)]

    def range_space(self, dim, bound):
        """Generate the space for a given dimension"""
        if dim in self.dimensions:
218
219
220
221
222
223
224
225
226
    def range_space(self, dim, bound):
        """Return the candidate values of an integer dimension.

        A dimension listed in ``self.dimensions`` spans ``1..bound``
        inclusive; any other dimension keeps the single value read from
        ``self.ccfg``.
        """
        if dim not in self.dimensions:
            return [dim.from_config(self.ccfg)]
        return range(1, bound + 1)

    def bool_space(self, dim):
        """Generate the space for a given boolean dimension"""
        if dim in self.dimensions:
224
225
226
227
228
229
230
231
232
    def bool_space(self, dim):
        """Return the candidate values of a boolean dimension.

        A dimension listed in ``self.dimensions`` may be ``False`` or
        ``True``; any other dimension keeps the single value read from
        ``self.ccfg``.
        """
        if dim not in self.dimensions:
            return [dim.from_config(self.ccfg)]
        return [False, True]

    def max_op(self, dp, tp, ep):
        """Compute bound for dimension OP"""
        if (
233
234
235
236
237
238
239
240
241
242
243
244
245
246
            isinstance(self.ccfg.optimizer, str)
            and "muon" not in self.ccfg.optimizer.lower()
        ):
            return dp
        if self.ccfg.n_exp and self.ccfg.n_exp > 1:
            exp_gcd = gcd(dp * tp // max(tp, ep), self.ccfg.n_exp)
        else:
            exp_gcd = dp

        if (
            self.ccfg.dc_kv
            and self.ccfg.dc_kv > 1
            and self.ccfg.dhr
            and self.ccfg.dhr > 1
244
245
246
247
248
249
250
251
            and self.ccfg.dc_kv > 1
            and self.ccfg.dhr
            and self.ccfg.dhr > 1
        ):
            att_gcd = gcd(self.ccfg.h, self.ccfg.dc_kv + self.ccfg.dhr)
        else:
            att_gcd = self.ccfg.h
        return gcd(exp_gcd, att_gcd)
hyper_parallel/auto_parallel/sapp_nd/nd/parallelize.py
50
51
52
53
54
55
56
57
58
        self.machine = machine
        if "mppb" in extra_config:
            manual_ppb = extra_config.pop("mppb")
        else:
            manual_ppb = False

        self.mem_eval = evaluator

        self.model_name = self.mem_eval._ccfg.model_name
64
65
66
67
68
69
70
71
72

        if "max_mem" in extra_config:
            max_mem = extra_config.pop("max_mem")
            if max_mem is not None:
                self.mem_eval._ccfg.device_capacity.set(max_mem)

        logger.debug("before global config init")

        if "sub_model" in extra_config:
71
72
73
74
75
76
77
78
79

        if "sub_model" in extra_config:
            sub_model = extra_config.pop("sub_model")
            if sub_model is not None:
                self.config = GlobalConfig(
                    self.mem_eval._ccfg.mm_ccfgs[sub_model],
                    dimensions,
                    mppb=manual_ppb,
                )
81
82
83
84
85
86
87
88
89
                self.config = GlobalConfig(
                    self.mem_eval._ccfg, dimensions, mppb=manual_ppb
                )
        else:
            self.config = GlobalConfig(
                self.mem_eval._ccfg, dimensions, mppb=manual_ppb
            )

        self.mem_eval.set_passes(**extra_config)
 92
 93
 94
 95
 96
 97
 98
 99
100
            self.config.ccfg.strategy_num_devices()
        )

        if global_batch_size:
            self.global_batch_size = global_batch_size
        else:
            self.global_batch_size = self.config.ccfg.gbs

        self.bound_space()
135
136
137
138
139
140
141
142
143
                str(kv_heads),
            )
        else:
            # num_head % (TP * UP) == 0. Add UP later
            Dim.TP.set_bound(
                Hard.highest_power_of_2_divisor(self.config.ccfg.a)
            )

    def filtered_out(self, _):
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
        if not self.config.moe_valid(parallel_config):
            logger.warning("expert parallel is higher than expert number")
            return False
        if self.filtered_out(parallel_config):
            logger.warning("Config manually filtered out")
            return False
        gbs = self.config.global_batch_size(parallel_config)
        if not gbs == self.global_batch_size:
            logger.error(
                "wrong global batch size: ccfg is %d, instead of %d",
                gbs,
                self.global_batch_size,
            )
            return False
        return True

    def memory_estim(self, debugger=None):
        """Whether the config fits memory"""
179
180
181
182
183
184
185
186
187
            verbose=verbose
        )  # (logger.level>2))
        logger.debug("peak memory = %d", peak)
        if debugger and debugger.is_enabled():
            debugger.info[Debug.MemParts.TOTAL] = peak
        return peak

    def generate_search_space(self, folder, threads_num):
        """Return a search space computed with memory estimation"""
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
        space = ({}, 0)
        configs = []
        results = {}
        if threads_num:
            with proc.Pool(processes=threads_num) as pool:
                logger.debug("before loops")
                results, size = self.device_loops(space, pool)
                logger.debug("%d results", len(results))
                for config, result in results.items():
                    logger.debug("result = %s", str(result))
                    logger.debug(
                        "before get: is ready ? %s", str(result.ready())
                    )
                    logger.debug(
                        "before get: is successful ? %s",
                        str(result.successful()),
                    )
                    # if result.successful():
                    peak_mem = result.get(1)
                    logger.debug(
                        "after get: is ready ? %s", str(result.ready())
                    )
                    logger.debug(
                        "after get: is successful ? %s",
                        str(result.successful()),
                    )
                    logger.debug("peak_mem = %s", str(peak_mem))
                    if self.mem_eval.mem_fit(peak_mem):
                        configs.append((config, peak_mem))
                pool.close()
                pool.join()
        else:
            results, size = self.device_loops(space, None)
            for config, peak_mem in results.items():
                if self.mem_eval.mem_fit(peak_mem):
                    configs.append((config, peak_mem))
                    if folder:
                        self.config.write(folder, config)
        logger.output("%d valid configurations generated", size)
        logger.output("%d configuration fitting memory to order", len(configs))

        return configs

    def device_loops(self, space, pool):
        """Exploration loop nest level 0:
        parallel dimensions dividing devices"""
244
245
246
247
248
249
250
251
252
253
254
                        pp,
                    )
                    dp = self.machine.number // tp // cp // pp
                    if dp < 1:
                        break
                    space = self.batch_loops(space, pool, (dp, tp, pp, cp))
        return space

    def batch_loops(self, space, pool, dtpc_p):
        """Exploration loop nest level 1:
        dimensions dividing batch (except already processed DP)"""
300
301
302
303
304
305
306
307
308
            parallel_config
        ):
            if pool is None:
                if self.enable_debug:
                    mem_debugger = Debug.Debug(
                        parallel_config,
                        info_type=Debug.MemParts,
                        enable=self.enable_debug,
                        output_file="debug_mem.csv",
307
308
309
310
311
312
313
314
315
316
                        enable=self.enable_debug,
                        output_file="debug_mem.csv",
                    )
                    # try:
                    peak = self.memory_estim(mem_debugger)
                    mem_debugger.write()
                else:
                    peak = self.memory_estim()
                # except:
                # logger.error()
317
318
319
320
321
322
323
324
325
326
                # return (configs, size)
            else:
                # logger.debug("before evaluator copy")
                # evaluator = copy.deepcopy(self.mem_eval)
                logger.debug("before apply_async")
                peak = pool.apply_async(
                    pool_estimate_memory,
                    args=(copy.deepcopy(self.config.ccfg),),
                    # args=(evaluator,),
                    # self.memory_estim,
324
325
326
327
328
329
330
331
                    args=(copy.deepcopy(self.config.ccfg),),
                    # args=(evaluator,),
                    # self.memory_estim,
                )
                logger.debug("after apply_async")
            configs[parallel_config] = peak

        return (configs, size)
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
        return (configs, size)

    def order_search_space(self, space, threads_num, cache_file):
        """Sort the search space computed with performance estimation"""
        if not space:
            return ([], [])
        multiproc = False
        if threads_num and threads_num > 5 * len(space):
            multiproc = True
        scored_space = []
        debug_parts = []
        for config, mem in space:
            self.config.set_parallel_config(config)
            values = []
            if multiproc:
                with proc.Pool(processes=threads_num) as pool:
                    score = pool.apply_async(
                        pool_estimate_performance,
                        args=(
                            copy.deepcopy(self.config),
                            self.machine.device,
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
                            cache_file,
                        ),
                    )
            else:
                if self.enable_debug:
                    debugger = Debug.Debug(
                        config,
                        info_type=Debug.PerfParts,
                        enable=self.enable_debug,
                    )
                    score = estimate_performance(
                        self.config.ccfg,
                        debugger=debugger,
                        device_type=self.machine.device,
                        memory=mem,
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
                        device_type=self.machine.device,
                        memory=mem,
                        cache_file=cache_file,
                    )
                    debugger.write()
                    debug_parts = list(debugger.info.keys())
                    values = list(debugger.info.values())
                    del values[-2:]
                    del debug_parts[-2:]
                else:
                    score = estimate_performance(
                        self.config.ccfg,
                        device_type=self.machine.device,
                        memory=mem,
                    )
            scored_space.append((config, mem, score, values))

            logger.info("config %s has score %f", str(config), score)

        if multiproc:
            new_scored_space = []
            pool.close()
            pool.join()
            for config, mem, score, values in scored_space:
                new_scored_space.append((config, mem, score.get(), values))
        else:
            new_scored_space = scored_space
        return (sorted(new_scored_space, key=lambda x: x[2]), debug_parts)

    def order_space_test_comm_classified(self, space, order_by=2):
        """Order the given space with performance estimation"""
        scored_space = []
        debug_parts = []
        for config, real_time, real_comm_wait in space:
            debugger = Debug.Debug(
                config, info_type=Debug.PerfParts, enable=self.enable_debug
            )
            self.config.set_parallel_config(config)
            peak_mem = self.memory_estim()
            score = estimate_performance(
                self.config.ccfg,
                debugger=debugger,
                device_type=self.machine.device,
                stage_focused=0,
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
                debugger=debugger,
                device_type=self.machine.device,
                stage_focused=0,
            )  # , memory = mem)
            debugger.write()
            debug_parts = list(debugger.info.keys())
            values = list(debugger.info.values())
            del values[-2:]
            scored_space.append(
                (config, peak_mem, real_time, score, values, real_comm_wait)
            )

            logger.info("config %s has score %f", str(config), score)
        del debug_parts[-2:]
        return (sorted(scored_space, key=lambda x: x[order_by]), debug_parts)

    def order_space_test(self, space, order_by=2):
        """Estimate every measured configuration and sort the results.

        For each ``(config, real_time)`` pair of ``space`` the configuration
        is applied, its peak memory and performance score are estimated and
        the debugger information is written out.

        Args:
            space: iterable of ``(config, real_time)`` pairs.
            order_by: index of the field used as sort key in the result
                tuples ``(config, peak_mem, real_time, score, values)``;
                the default ``2`` sorts by real time.

        Returns:
            A ``(sorted_results, debug_parts)`` pair. ``debug_parts`` holds
            the debugger keys of the last configuration; both it and each
            ``values`` list drop their last two entries (presumably
            non-component fields of the debugger -- to be confirmed).
        """
        ranked = []
        part_names = []
        for parallel_cfg, measured in space:
            tracer = Debug.Debug(
                parallel_cfg,
                info_type=Debug.PerfParts,
                enable=self.enable_debug,
            )
            logger.info("Test config %s", str(parallel_cfg))
            self.config.set_parallel_config(parallel_cfg)
            logger.debug(self.mem_eval.get_strategy())
            peak = self.memory_estim()
            estimate = estimate_performance(
                self.config.ccfg,
                debugger=tracer,
                device_type=self.machine.device,
            )
            tracer.write()
            part_names = list(tracer.info.keys())
            components = list(tracer.info.values())[:-2]
            ranked.append(
                (parallel_cfg, peak, measured, estimate, components)
            )

            logger.info(
                "config %s has score %f", str(parallel_cfg), estimate
            )
        ranked.sort(key=lambda entry: entry[order_by])
        return (ranked, part_names[:-2])

    def plot_title(self):
        """Build the plot title from model, machine and global batch size."""
        machine = self.machine
        return (
            f"{self.model_name} on {machine.number} {machine.device}"
            f" with {self.global_batch_size} GBS"
        )
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
    ):
        """Test some functions"""
        start = time.time()
        space = self.generate_search_space(yaml_folder, threads_num)
        generation = time.time()
        scored_space, dbg = self.order_search_space(
            space, threads_num, cache_file=cache_file
        )
        ordering = time.time()
        logger.output(
            space_to_string(scored_space, max_num=top_num, debug_parts=dbg)
        )
        logger.output(
            "Space generation took %.2fs and ordering took %.2fs",
            generation - start,
            ordering - generation,
        )
        is_not = " NOT" if not self.config.balancing.from_config else ""
        logger.output(
            "Offset & Recompute were%s computed from config info", is_not
        )
        logger.output(
            "Device number is %d, global batch size is %d, dimensions are %s",
            self.machine.number,
            self.global_batch_size,
            str(self.config.dimensions),
481
482
483
484
485
486
487
488
489
490
491
492
493
            self.machine.number,
            self.global_batch_size,
            str(self.config.dimensions),
        )
        if self.enable_debug:
            file_path = os.path.dirname(os.path.realpath(__file__))
            output_path = file_path + "/output/"
            if scored_space:
                Debug.plot_nd(
                    scored_space,
                    output_path,
                    dbg,
                    title=self.plot_title(),
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
                    dbg,
                    title=self.plot_title(),
                    max_num=top_num,
                )
        return scored_space

    def to_ppb(self, scored_space, k, cfg_name):
        """Create an input file for pipeline balancing"""
        parallel_config = scored_space[k][0]
        self.config.set_parallel_config(parallel_config)
        self.mem_eval.update_config(self.config)
        m = cfg_name + "_nd_to_ppb_" + str(k)
        s = self.config.dim_val(Dim.PP, parallel_config)
        mb = self.config.dim_val(Dim.MBN, parallel_config)
        i = self.config.dim_val(Dim.VPP, parallel_config)
        mem = str(self.config.ccfg.device_capacity.to_mb)
        filename = (
            os.path.dirname(os.path.realpath(__file__))
            + "/../pipeline_balance/layers/"
            + m
            + ".json"
510
511
512
513
514
515
516
517
518
519
            + "/../pipeline_balance/layers/"
            + m
            + ".json"
        )
        with open(filename, "w+", encoding="utf-8") as fp:
            json.dump(
                self.mem_eval.estimate_layer_memory(
                    device_type=self.machine.device
                ),
                fp,
518
519
520
521
522
523
524
525
526
                ),
                fp,
                indent=4,
            )
        logger.output(
            "To run pipeline balancing on configuration %s:"
            "\npython run_pipeline_balance.py "
            "-m %d -s %d -mb %d -i %d -mem %d",
            parallel_config,
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
            mb,
            i,
            mem,
        )
        logger.output("Warning: currently select_recompute_memory \
                should be removed & layer time need to be added")

    def test_from_csv(self, csv_f, output_path=None):
        """Run estimation tests against a real run profiling in csv format"""
        configs, row_num = Debug.get_real_data(csv_f)
        configs_estimated, debug_parts = self.order_space_test(
            configs, order_by=2
        )
        if output_path is not None:
            Debug.plot_vs_real(
                configs_estimated,
                csv_f,
                output_path,
                debug_parts,
546
547
548
549
550
551
552
553
554
555
                output_path,
                debug_parts,
                title=self.plot_title(),
            )
        correl, topk = Debug.correlation_topk(configs_estimated, csv_f)
        return correl, topk, row_num

    def test_from_csv_comm_classified(
        self, csv_f, output_path=None, plot_idle=False
    ):
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
    def test_from_csv_comm_classified(
        self, csv_f, output_path=None, plot_idle=False
    ):
        """Run test to compare estimation with detailed profiling"""
        configs = Debug.get_comm_classified_data(csv_f, plot_idle=plot_idle)
        configs_estimated, debug_parts = self.order_space_test_comm_classified(
            configs, order_by=2
        )

        if output_path is not None:
            Debug.plot_vs_real_comm_classified(
                configs_estimated,
                csv_f,
                output_path,
                debug_parts,
568
569
570
571
572
573
574
575
576
                title=self.plot_title(),
                plot_idle=plot_idle,
            )

        return Debug.correlation_with_classified_comms(configs_estimated)


class ParallelizeMultiModal(ParallelizeLayer):
    """Parallelize a MultiModel"""
583
584
585
586
587
588
589
590
591
        dimensions=None,
        **extra_config,
    ):

        super().__init__(
            evaluator,
            machine,
            global_batch_size=global_batch_size,
            dimensions=dimensions,
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
            mem_eval = EvaluatorV2(
                config, framework=framework, hook_cls=model_name, machine=machine
            )
        else:
            mem_eval = EvaluatorV2(config, framework=framework, machine=machine)

        if "global_batch_size" in extra_config:
            global_batch_size = extra_config.pop("global_batch_size")
        else:
            global_batch_size = None

        if "dimensions" in extra_config:
            dimensions = extra_config.pop("dimensions")
        else:
            dimensions = None

        if mem_eval.ccfg.multimodal:
            logger.debug("MultiModal is triggered")
            self.instance = ParallelizeMultiModal(
                mem_eval,
                machine,
                global_batch_size=global_batch_size,
                dimensions=dimensions,
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691


def space_to_string(space, max_num=None, debug_parts=None):
    """Render a list of evaluated configurations as a text table.

    Args:
        space: Sequence of entries indexed as ``(dims, memory, score, parts)``.
            ``dims`` must expose ``all_dims`` (iterated for the column
            headers) and ``values()`` (iterated for the row cells);
            ``parts`` is iterated and each element is printed as a
            percentage of ``score``.
        max_num: When not None, a "Top N configurations" title is emitted
            and printing stops once ``max_num`` rows have been written.
            When None, the title is a bare newline and every entry is
            printed.
        debug_parts: Optional iterable of objects exposing ``short_name()``;
            each name is appended to the header row after a tab.

    Returns:
        The table as a single string. For an empty ``space`` only the title
        line is returned.
    """
    col_width = 6

    if max_num is not None:
        title = "Top " + str(max_num) + " configurations:\n"
    else:
        title = "\n"
    if len(space) == 0:
        return title

    # Collect fragments and join once instead of repeated string "+=".
    pieces = [title, "\t"]
    pieces.extend(str(dim).ljust(col_width) for dim in space[0][0].all_dims)
    pieces.append("Memory    Performance score  ")
    if debug_parts is not None:
        pieces.extend("\t" + part.short_name() for part in debug_parts)
    pieces.append("\n")

    for row_idx, config in enumerate(space):
        if max_num is not None and max_num == row_idx:
            break
        pieces.append("\t")
        # Coerce cells with str(), as the header row does, so non-string
        # dimension values (e.g. ints) no longer raise TypeError.
        pieces.extend(
            str(value).ljust(col_width) for value in config[0].values()
        )
        pieces.append(str(config[1]) + " MB  ")
        pieces.append(f"{(config[2]):16.12e}")
        pieces.extend(
            f"\t{(100*share/config[2]):.2f}%" for share in config[3]
        )
        pieces.append("\n")
    return "".join(pieces)


def pool_estimate_memory(config):
    """Build an ``EvaluatorV2`` from ``config`` and return its peak estimate.

    Intended as a multiprocessing worker target (per the original note);
    emits a debug log line before evaluating.
    """
    logger.debug("estimate_peak")
    return EvaluatorV2(config).estimate_peak()


# def pool_estimate_memory(evaluator):
#     """Calls memory estimation for multiprocessing"""
694
695
696
697
698


def pool_estimate_performance(config, device):
    """Return ``estimate_performance`` for ``config`` on the given device.

    Thin adapter, intended as a multiprocessing worker target (per the
    original note): ``device`` is forwarded as the ``device_type`` keyword.
    """
    performance = estimate_performance(config, device_type=device)
    return performance
hyper_parallel/auto_parallel/sapp_nd/nd/run_nd.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# limitations under the License.
# ============================================================================
"""run parallelization"""

import argparse
import os

from hyper_parallel.auto_parallel.sapp_nd.memory_estimation.size import Memory
from hyper_parallel.auto_parallel.sapp_nd.nd.logger import logger, set_verbose_level
import hyper_parallel.auto_parallel.sapp_nd.nd.parallelize as Par
import hyper_parallel.auto_parallel.sapp_nd.nd.dimensions as Dim
import hyper_parallel.auto_parallel.sapp_nd.nd.common.hardware as Hard

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="python run_nd.py",
        description=("Provides a degree to *N* parallelism dimensions"),
        epilog="",
    )
29
30
31
32
33
34
35
36
37
        description=("Provides a degree to *N* parallelism dimensions"),
        epilog="",
    )

    parser.add_argument(
        "-y",
        "--yaml_config",
        type=str,
        required=True,
36
37
38
39
40
41
42
43
44
        type=str,
        required=True,
        help="Path to yaml configuration file",
    )
    parser.add_argument(
        "-f",
        "--framework",
        default="mindformers",
        type=str,
45
46
47
48
49
50
51
52
53
        required=False,
        help="Framework to evaluate in "
        "[mindformers, mindspeed, hyperparallel, torchtitan]",
    )
    parser.add_argument(
        "-d",
        "--devices",
        type=int,
        default=None,
52
53
54
55
56
57
58
59
60
        type=int,
        default=None,
        help="Number of devices. Takes yaml value if unspecified",
    )
    parser.add_argument(
        "-b",
        "--global_batch_size",
        type=int,
        default=None,
59
60
61
62
63
64
65
66
67
        type=int,
        default=None,
        help="Global batch size. Takes yaml value if unspecified",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default=None,
80
81
82
83
84
85
86
87
88
    #     type=str,
    #     default=None,
    #     help="Computes correlation coefficient from csv results file",
    # )
    parser.add_argument(
        "-l",
        "--dimensions",
        nargs="*",
        type=str,
 95
 96
 97
 98
 99
100
101
102
103
    #     type=int,
    #     default=None,
    #     help="Number of threads for the space generation",
    # )
    parser.add_argument(
        "-v",
        "--verbosity",
        type=int,
        default=2,
111
112
113
114
115
116
117
118
119
    #     type=int,
    #     default=None,
    #     help="choose configuration number k for ppb",
    # )
    parser.add_argument(
        "-A",
        "--device_type",
        default="A2",
        help="choose device type between A2 or A3",
117
118
119
120
121
122
123
124
125
        "--device_type",
        default="A2",
        help="choose device type between A2 or A3",
    )
    parser.add_argument(
        "-swap_os",
        "--swap_opt_state",
        action=argparse.BooleanOptionalAction,
        default=False,
131
132
133
134
135
136
137
138
139
    #     action=argparse.BooleanOptionalAction,
    #     default=False,
    #     help="Activate less memory schedule",
    # )
    parser.add_argument(
        "-mppb",
        "-–manual_pipeline_balance",
        action=argparse.BooleanOptionalAction,
        default=False,
138
139
140
141
142
143
144
145
146
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Takes offset and recompute from yaml",
    )
    parser.add_argument(
        "-t",
        "--top_config_number",
        type=int,
        default=None,
145
146
147
148
149
150
151
152
153
        type=int,
        default=None,
        help="Number of top configs to print & plot",
    )
    parser.add_argument(
        "-mem",
        "--mem_for_ppb",
        type=str,
        default="0GB",
153
154
155
156
157
158
159
160
161
        default="0GB",
        help="Memory to reserve for pipeline balancing. "
        "Will be decreased from the memory budget allowed by ND (default 0GB)",
    )
    parser.add_argument(
        "-c",
        "--cache_file",
        type=str,
        default=None,
162
163
164
165
166
167
168
169
170
        help="Cache file with ratios to recalibrate ND scores. "
        "Will be defaulted to 'None'.",
    )

    parser.add_argument(
        "-M",
        "--max_mem",
        type=str,
        default=None,
170
171
172
173
174
175
176
177
178
        default=None,
        help="Memory to reserve for pipeline balancing. "
        "Will be decreased from the memory budget allowed by ND (default 0GB)",
    )
    parser.add_argument(
        "--train-yaml",
        type=str,
        default=None,
        help="Path to training configuration yaml file (for hyperparallel2)",
176
177
178
179
180
181
182
183
184
        type=str,
        default=None,
        help="Path to training configuration yaml file (for hyperparallel2)",
    )
    parser.add_argument(
        "--accelerate-yaml",
        type=str,
        default=None,
        help="Path to accelerate configuration yaml file (for hyperparallel2)",
183
184
185
186
187
188
189
190
191
192
193
        default=None,
        help="Path to accelerate configuration yaml file (for hyperparallel2)",
    )

    args = parser.parse_args()

    max_mem = (
        Memory.from_string(args.max_mem.strip())
        if args.max_mem is not None
        else None
    )
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
        if args.max_mem is not None
        else None
    )

    if args.cache_file is not None:
        if not os.path.exists(args.cache_file):
            logger.error(
                f"cache file not found:"
                f" {args.cache_file}"
                "\nProceeding without cache file..."
            )
            args.cache_file = None

    set_verbose_level(args.verbosity)
    dims = Dim.get_dims(args.dimensions)
    YAML_FOLDER = None  # args.generate_yaml_in
    machine = Hard.Machine(args.devices, args.device_type)

    if args.framework == "hyperparallel2":
        if args.yaml_config is None or args.train_yaml is None or args.accelerate_yaml is None:
            parser.error("-y (model yaml), --train-yaml, and --accelerate-yaml are required for hyperparallel2")
        input_config = {
            "model": args.yaml_config,
            "train": args.train_yaml,
            "accelerate": args.accelerate_yaml,
            "machine": args.devices
214
215
216
217
218
219
220
221
222
223
224
            "train": args.train_yaml,
            "accelerate": args.accelerate_yaml,
            "machine": args.devices
        }
    elif args.framework == "torchtitan":
        module, config = args.yaml_config.split(":")
        input_config = {
            "module": module,
            "config": config,
            "machine": machine,
        }
222
223
224
225
226
227
228
229
230
231
232
            "config": config,
            "machine": machine,
        }
    else:
        input_config = args.yaml_config

    nd_runner = Par.Parallelize(
        args.framework,
        input_config,
        machine,
        global_batch_size=args.global_batch_size,
239
240
241
242
243
244
245
246
247
248
249
250
        mem_for_ppb=Memory.from_string(args.mem_for_ppb.strip()),
        # vpp_less_mem=args.less_memory,
    )

    if YAML_FOLDER and not os.path.exists(YAML_FOLDER):
        os.makedirs(YAML_FOLDER)

    space = nd_runner.run_generation_to_ordering(
        YAML_FOLDER,
        threads_num=None,  # args.threads_num
        top_num=args.top_config_number,
        cache_file=args.cache_file,
hyper_parallel/auto_parallel/sapp_nd/perf_estimation/comm_time.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
COUNT_OPTIMIZER = False

def fill_dp_table(cfg, tables):
    """DP"""
    table_dp = {}
    table_dp["n_attMM"] = cfg.h * cfg.h / cfg.t
    table_dp["n_ffMM"] = cfg.h * cfg.hff / cfg.t
    table_dp["n_normOp"] = 2 * cfg.h / cfg.sp

    if COUNT_OPTIMIZER:
        table_dp["n_attParamCast"] = (
            11 * cfg.h * cfg.h / (cfg.d if cfg.has_op else 1)
        )
        table_dp["n_ffParamCast"] = (
            11 * cfg.h * cfg.hff / (cfg.d if cfg.has_op else 1)
        )
    for op in table_dp:
        table_dp[op] *= cfg.bytes_norm if op == "n_normOp" else cfg.bytes_p

    table_exp_dp = deepcopy(table_dp)
    table_exp_dp["n_ffMM"] = (
        2
        * (cfg.n_exp + cfg.n_shared_exp)
        * cfg.h
        * cfg.hff_exp
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
        * cfg.hff_exp
        / cfg.t
        * cfg.bytes_p
    )
    tables[Dim.DP] = table_dp
    tables["exp_dp"] = table_exp_dp


def fill_tp_table(cfg, tables):
    """Store the TP gather quantities into ``tables``.

    Two single-entry dicts keyed by ``"n_gather"`` are written:
    ``tables["tp"]`` and ``tables["exp_tp"]``. Both are scaled by
    ``cfg.bytes_compute``. Nothing is returned.
    """
    # Factor applied once t reaches 8; the original author marked it
    # "Fix this", so treat it as a provisional heuristic.
    if cfg.t >= 8:
        gather_bias = 11 / 16
    else:
        gather_bias = 1

    regular = {
        "n_gather": cfg.b * cfg.s * cfg.h * gather_bias * cfg.bytes_compute,
    }
    expert = {
        "n_gather": (
            cfg.b * cfg.s * cfg.h * 1.5 * (cfg.ep / cfg.d) * cfg.bytes_compute
        ),
    }

    tables["tp"] = regular
    tables["exp_tp"] = expert


def fill_ep_table(cfg, tables, device_type):
    """EP"""
    intra_devices = device_type.intra_node_num()
    table_ep = {}
    inter_node_bias_ep = 1
    table_ep["n_ffMM"] = (
        4
        * cfg.n_chosen_exp
        * cfg.b
        * cfg.s
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
            )
        )
    )

    for op in table_ep:
        table_ep[op] *= cfg.bytes_compute
    tables[Dim.EP] = table_ep


def dp_ratio(cfg, device_type):
    """formula"""
    return (
        0
        if cfg.comm_d_non_exp == 0
        else 1
        - True  # overlap_dp, Completely overlap standard DP comm
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148


def comm_embed_ouput(cfg):
    """Return the ``(comm_embed, comm_output)`` quantities for ``cfg``.

    ``comm_embed`` is ``bytes_compute * h * v / shard_embed`` and
    ``comm_output`` is ``h * v / t``; note that only the first is scaled by
    a byte width here.
    """
    return (
        cfg.bytes_compute * cfg.h * cfg.v / cfg.shard_embed,
        cfg.h * cfg.v / cfg.t,
    )


def estimate_op_bulk_comm(*args, **kwargs):
    """FW + BW"""
    param = {
        "cfg": args[0],
        "ccfg": args[1],
        "stages": args[2],
        "device_type": args[3],
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
        ),
        "debugger": kwargs.get("debugger", args[5] if len(args) > 5 else None),
    }

    param["tables"] = {}
    fill_dp_table(param["cfg"], param["tables"])

    param['dp_ratio'] = dp_ratio(param['cfg'], param['device_type'])

    param["comm_embed"], param["comm_output"] = comm_embed_ouput(param["cfg"])

    if param["cfg"].dc_kv != 0:  # Deepseek
        param["comm_output"] += param["cfg"].h * (
            2 * param["cfg"].h + param["cfg"].v
        )
        param["comm_output"] *= param["cfg"].n_mtp

    param["comm_output"] *= param["cfg"].bytes_p

    fill_tp_table(param["cfg"], param["tables"])
    fill_ep_table(param["cfg"], param["tables"], param["device_type"])

    lccfgs = get_layer_custom_configs(param["cfg"])
    logger.info(lccfgs)
    param["layer_count"] = 0
    param["idx_lccfg"] = 0
    comms = {Dim.DP: [], Dim.TP: [], Dim.EP: []}
    # ignores comm recomp, to improve
    for stage in param["stages"]:
        comm = {Dim.DP: 0.0, Dim.TP: 0.0, Dim.EP: 0.0}
        for chunk in stage:
            for layer in chunk:
                param["layer_count"], param["idx_lccfg"] = (
                    estimate_op_bulk_comm_layer(
                        param,
                        lccfgs,
                        layer=layer,
188
189
190
191
192
193
194
195
196
197
198
                        layer_count=param["layer_count"],
                        idx_lccfg=param["idx_lccfg"],
                    )
                )
        if param["ccfg"].ttype == PerformanceType.TIME:
            for dim, ov in zip([Dim.DP, Dim.TP, Dim.DP], [0.0, 0.0, 0.0]):
                comm[dim] = estimate_comm_score(
                    param["cfg"],
                    comm[dim],
                    dim,
                    overlap=ov,
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
                    overlap=ov,
                    device=param["device_type"],
                )

        comm[Dim.DP] *= param["dp_ratio"]
        comm[Dim.TP] *= param["cfg"].comm_t
        comm[Dim.EP] *= param["cfg"].comm_ep

        if param["device_type"].name == "A3":
            logger.info("A3 ratio")
            comm[Dim.TP] /= 3

        comms[Dim.DP].append(comm[Dim.DP])
        comms[Dim.TP].append(comm[Dim.TP])
        comms[Dim.EP].append(comm[Dim.EP])

    if param["debugger"] and param["debugger"].is_enabled():
        logger.info("DP_COMM = %s", comms[Dim.DP])
        logger.info("MP_COMM = %s", comms[Dim.TP])
        logger.info("EP_COMM = %s", comms[Dim.EP])
        param["debugger"].info[PerfParts.DP_COMM] = comms[Dim.DP]
        param["debugger"].info[PerfParts.MP_COMM] = comms[Dim.TP]
        param["debugger"].info[PerfParts.EP_COMM] = comms[Dim.EP]

    res = []
    for i, c in enumerate(comms[Dim.TP]):
        res.append(comms[Dim.DP][i] + c + comms[Dim.EP][i])

    return res


def estimate_op_bulk_comm_layer(cfg, lccfgs, **kwargs):
    """for estimate_op_bulk_comm"""
    if kwargs["layer"] == LayerType.EMBEDDING_LAYER:
        kwargs["comm"][Dim.DP] += kwargs["param"]["comm_embed"]
        return kwargs["layer_count"]

    if kwargs["layer"] == LayerType.OUTPUT_LAYER:
        kwargs["comm"][Dim.DP] += kwargs["param"]["comm_output"]
        if cfg.dc_kv != 0:  # Deepseek
            lccfg = lccfgs[kwargs["idx_lccfg"]][0]
            kwargs["comm"][Dim.TP] += cfg.n_mtp * get_table_quantity(
                lccfg,
                kwargs["param"]["tables"]["exp_tp"],
                LayerType.NOT_REC_LAYER,
                kwargs["param"]["with_recomp"],
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
                kwargs["param"]["tables"]["exp_tp"],
                LayerType.NOT_REC_LAYER,
                kwargs["param"]["with_recomp"],
            )
        return kwargs["layer_count"]

    if (
        kwargs["idx_lccfg"] + 1 < len(lccfgs)
        and lccfgs[kwargs["idx_lccfg"]][1] == kwargs["layer_count"]
    ):
        kwargs["layer_count"] = 0
        kwargs["idx_lccfg"] += 1

    lccfg = lccfgs[kwargs["idx_lccfg"]][0]
    is_moe_layer = lccfg.n_exp > 1

    if is_moe_layer:
        kwargs["comm"][Dim.DP] += get_table_quantity(
            lccfg,
            kwargs["param"]["tables"]["exp_dp"],
            kwargs["layer"],
            kwargs["param"]["with_recomp"],
260
261
262
263
264
265
266
267
268
            kwargs["param"]["tables"]["exp_dp"],
            kwargs["layer"],
            kwargs["param"]["with_recomp"],
        )
        kwargs["comm"][Dim.TP] += get_table_quantity(
            lccfg,
            kwargs["param"]["tables"]["exp_tp"],
            kwargs["layer"],
            kwargs["param"]["with_recomp"],
266
267
268
269
270
271
272
273
274
            kwargs["param"]["tables"]["exp_tp"],
            kwargs["layer"],
            kwargs["param"]["with_recomp"],
        )
        kwargs["comm"][Dim.EP] += get_table_quantity(
            lccfg,
            kwargs["param"]["tables"][Dim.EP],
            kwargs["layer"],
            kwargs["param"]["with_recomp"],
273
274
275
276
277
278
279
280
281
            kwargs["layer"],
            kwargs["param"]["with_recomp"],
        )
    else:
        kwargs["comm"][Dim.DP] += get_table_quantity(
            lccfg,
            kwargs["param"]["tables"][Dim.DP],
            kwargs["layer"],
            kwargs["param"]["with_recomp"],
279
280
281
282
283
284
285
286
287
            kwargs["param"]["tables"][Dim.DP],
            kwargs["layer"],
            kwargs["param"]["with_recomp"],
        )
        kwargs["comm"][Dim.TP] += get_table_quantity(
            lccfg,
            kwargs["param"]["tables"]["tp"],
            kwargs["layer"],
            kwargs["param"]["with_recomp"],
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
            kwargs["layer"],
            kwargs["param"]["with_recomp"],
        )

    kwargs["layer_count"] += 1
    return kwargs["layer_count"], kwargs["idx_lccfg"]


def prepare_context():
    """Create and return a ``Context`` wired with parameter-count callbacks.

    The attention / FFN / norm counters are set as attributes, and one
    ``NodeEval`` (counter plus two ``None`` slots) is registered in
    ``node_eval`` for the embedding, output and non-recomputed layer types.
    ``enable_accu_log`` is set to False.
    """
    ctx = Context()
    ctx.attn_num_p = EvalAttn.num_params_attn
    ctx.ffn_num_p = EvalFFn.num_params_ffn
    ctx.norm_num_p = EvalNorm.num_params_norm

    param_counters = (
        (LayerType.EMBEDDING_LAYER, EvalHead.num_params_embed),
        (LayerType.OUTPUT_LAYER, EvalTail.num_params_output),
        (LayerType.NOT_REC_LAYER, EvalBody.num_params_layer),
    )
    for layer_type, counter in param_counters:
        ctx.node_eval[layer_type] = NodeEval(counter, None, None)

    ctx.enable_accu_log = False
    return ctx


def estimate_from_mem_comm(*args, **kwargs):
    """For memory estimation"""
313
314
315
316
317
318
319
320
321

def estimate_from_mem_comm(*args, **kwargs):
    """For memory estimation"""

    param = {
        "cfg": args[0],
        "ccfg": args[1],
        "stages": args[2],
        "device_type": args[3],
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
        "ccfg": args[1],
        "stages": args[2],
        "device_type": args[3],
    }
    param["debugger"] = kwargs.get(
        "debugger", args[5] if len(args) > 5 else None
    )
    param["ctx"] = prepare_context()

    # For layer type
    param["flatten"] = sum(
        [[f[1]] * f[0] for f in param["cfg"].layer_custom_config], []
    )
    comms = {Dim.DP: [], Dim.TP: [], Dim.EP: [], Dim.CP: []}
    for stage in param["stages"]:
        comm = {Dim.DP: 0.0, Dim.TP: 0.0, Dim.EP: 0.0, Dim.CP: 0.0}
        for chunk in stage:
            for layer in chunk:
                param["ctx"].current_node = layer
                if (
                    layer
                    not in [LayerType.EMBEDDING_LAYER, LayerType.OUTPUT_LAYER]
                    and param["flatten"]
                ):
                    custom_fun = param["flatten"].pop(0)
                    if custom_fun:
                        custom_fun(param["cfg"])
                    logger.info("is layer moe ? %s", param["cfg"].n_exp > 1)
                    param["ctx"].current_node = LayerType.NOT_REC_LAYER
                    logger.info("param ctx %s", param["ctx"])
                    comm[Dim.DP] += EvalLayerComm.dp_comm_layer(param["cfg"], param["ctx"])

                comm[Dim.TP] += EvalLayerComm.tp_comm_layer(
                    param["cfg"], param["ctx"], 1
                )  # / 4 #* (param["cfg"].t - 1)
                comm[Dim.EP] += EvalLayerComm.ep_comm_layer(
                    param["cfg"], param["ctx"], 1
                )  # * param["cfg"].ep
                comm[Dim.CP] += EvalLayerComm.cp_comm_layer(
                    param["cfg"], param["ctx"]
                )
                # min(device_type.level_bound_number[0], param["cfg"].ep)
                # comm_cp += EvalLayerComm.cp_comm_layer
362
363
364
365
366
367
368
369
370
371
372
                # (param["cfg"], param["ctx"])



        if param["ccfg"].ttype == PerformanceType.TIME:
            for dim, ov in zip([Dim.DP, Dim.TP, Dim.DP], [0.9, 0, 0.0]):
                comm[dim] = estimate_comm_score(
                    param["cfg"],
                    comm[dim],
                    dim,
                    overlap=ov,
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442

        # comm[Dim.EP] += comm[Dim.EP] * param["cfg"].ep / 100
        # comm[Dim.TP] += comm[Dim.TP] * param["cfg"].t / 100

        dev_per_node = param["device_type"].level_bound_number[0]
        comm[Dim.TP] *= max(1, param["cfg"].t // dev_per_node)
        comm[Dim.EP] *= max(1, param["cfg"].ep // dev_per_node)
        comm[Dim.CP] *= max(1, param["cfg"].cp // dev_per_node)

        comm[Dim.TP] /= 2 # TO REMOVE: FOR RUIWEN TEST ONLY
        comm[Dim.DP] = 0 #/= 10 # TO REMOVE: FOR RUIWEN TEST ONLY

        if param["device_type"].name == "A3":
            logger.info("A3 ratio")
            comm[Dim.DP] /= 2
            comm[Dim.TP] /= 2
            comm[Dim.EP] /= 2
            comm[Dim.CP] /= 2

        comms[Dim.DP].append(comm[Dim.DP])
        comms[Dim.TP].append(comm[Dim.TP])
        comms[Dim.EP].append(comm[Dim.EP])
        comms[Dim.CP].append(comm[Dim.CP])

    if param["debugger"] and param["debugger"].is_enabled():
        logger.info("DP_COMM = %s", comms[Dim.DP])
        logger.info("MP_COMM = %s", comms[Dim.TP])
        logger.info("EP_COMM = %s", comms[Dim.EP])
        logger.info("CP_COMM = %s", comms[Dim.CP])
        param["debugger"].info[PerfParts.DP_COMM] = comms[Dim.DP]
        param["debugger"].info[PerfParts.MP_COMM] = comms[Dim.TP]
        param["debugger"].info[PerfParts.EP_COMM] = comms[Dim.EP]
        param["debugger"].info[PerfParts.CP_COMM] = comms[Dim.CP]

    res = []
    for i, c in enumerate(comms[Dim.TP]):
        res += [c + comms[Dim.DP][i] + comms[Dim.EP][i] + comms[Dim.CP][i]]

    return res


def estimate_comm(*args, **kwargs):
    """wrapper"""
    cfg, ccfg, stages, device_type = args[0], args[1], args[2], args[3]
    with_recomp = kwargs.get(
        "with_recomp", args[4] if len(args) > 4 else False
    )
    debugger = kwargs.get("debugger", args[5] if len(args) > 5 else None)
    # return estimate_op_bulk_comm(cfg, ccfg, stages,
    # device_type=device_type, with_recomp=with_recomp,
    # debugger=debugger)
    return estimate_from_mem_comm(
        cfg,
        ccfg,
        stages,
        device_type,
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490


def level_efficiency(level):
    """Return the hard-coded efficiency factor for a network level.

    Yields 0.7 for ``NetworkLevel.NODE`` and 0.9 for
    ``NetworkLevel.CLUSTER``. The original author flagged these values as
    "to improve for Ascend A2".

    Raises:
        ValueError: if ``level`` matches neither known level.
    """
    known_levels = (
        (NetworkLevel.NODE, 0.7),
        (NetworkLevel.CLUSTER, 0.9),
    )
    for candidate, factor in known_levels:
        if level == candidate:
            return factor
    raise ValueError


def level_bandwidth(level):
    """Return the hard-coded bandwidth figure for a network level.

    Yields 300 for ``NetworkLevel.NODE`` and 25 for
    ``NetworkLevel.CLUSTER`` (units are not stated in this file). The
    original author flagged these values as "to improve for Ascend A2".

    Raises:
        ValueError: if ``level`` matches neither known level.
    """
    known_levels = (
        (NetworkLevel.NODE, 300),
        (NetworkLevel.CLUSTER, 25),
    )
    for candidate, bandwidth in known_levels:
        if level == candidate:
            return bandwidth
    raise ValueError


def level_latency(level):
    """Return the hard-coded latency figure for a network level.

    Yields 0.00001 for ``NetworkLevel.NODE`` and 0.00002 for
    ``NetworkLevel.CLUSTER`` (units are not stated in this file). The
    original author flagged these values as "to improve for Ascend A2".

    Raises:
        ValueError: if ``level`` matches neither known level.
    """
    known_levels = (
        (NetworkLevel.NODE, 0.00001),
        (NetworkLevel.CLUSTER, 0.00002),
    )
    for candidate, latency in known_levels:
        if level == candidate:
            return latency
    raise ValueError


def comm_throughput(level):
    """Return the level's bandwidth multiplied by its efficiency factor.

    Propagates the ``ValueError`` raised by the per-level lookups when
    ``level`` is not a known network level.
    """
    return level_bandwidth(level) * level_efficiency(level)


def estimate_comm_size_time(_, comm_size, level):
    """Return latency plus ``comm_size`` divided by the level throughput.

    The first positional argument is accepted but ignored.
    """
    transfer = comm_size / comm_throughput(level)
    return level_latency(level) + transfer


def estimate_comm_score(
    cfg, comm_volume, dim, overlap=0.0, device=Hard.device_map["A2"]
489
490
491
492
493
494
495
496
497
498
499
500
501
502
def estimate_comm_score(
    cfg, comm_volume, dim, overlap=0.0, device=Hard.device_map["A2"]
):
    """score assignment"""
    assignment = device.level_assign(dp=cfg.d, tp=cfg.t, pp=cfg.p)
    score = 0
    for level in range(device.levels):
        # intra_comm = comm_volume * (1-overlap)
        # * (assignment[dim][0]-1) / device.intra_node_bw
        score += (
            comm_volume
            * (1 - overlap)
            * (
                (assignment[dim][level] - 1)
503
504
505
506
507
                * device.devices_below_level(level)
                / device.level_bandwidth[level]
            )
        )
    return score
hyper_parallel/auto_parallel/sapp_nd/perf_estimation/estimate.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
BACKWARD_RATIO = 2

def op_table(cfg):
    """op compute load formulas"""
    table = {}
    # cfg.s /= cfg.cp
    table["n_attMM"] = (
        3 * (1 + cfg.n_kv / cfg.a) * cfg.b * cfg.s * cfg.h * cfg.h
    )
    table["n_ffMM"] = 6 * cfg.b * cfg.s * cfg.h * cfg.hff
    table["n_attBMM"] = 6 * cfg.b * cfg.s * cfg.s * cfg.h
    table["n_ffBMM"] = 6 * cfg.b * cfg.s * cfg.s * cfg.hff
    table["n_softmax"] = 13 * cfg.a * cfg.b * cfg.s * cfg.s
    table["n_headCast"] = 3 * cfg.a * cfg.b * cfg.s * cfg.s
    table["n_gather"] = cfg.b * cfg.s * cfg.h * (cfg.t - 1)
    table["n_ffAct"] = 21 * cfg.b * cfg.hff

    table["n_normOp"] = 30 * cfg.b * cfg.s * cfg.h * cfg.t / cfg.sp
    table["n_dropout"] = (
        3 * cfg.b * cfg.s * max(cfg.a * cfg.s, 3 * cfg.h * cfg.t / cfg.sp)
    )
    if cfg.dc_kv != 0:  # Deepseek
        table["n_attMM"] = (
            3
            / 2
            * (
                2 * cfg.dc_kv * cfg.n_kv * cfg.dh
72
73
74
75
76
77
78
79
80
81
82
83
            )
            * cfg.b
            * cfg.s
        )
    for op in table:
        table[op] *= cfg.bytes_p / cfg.t / cfg.cp
    # cfg.s *= cfg.cp
    return table


# Evaluation functions
def estimate_op_bulk_comp(cfg, ccfg, stages, with_recomp=False, debugger=None):
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113

# Evaluation functions
def estimate_op_bulk_comp(cfg, ccfg, stages, with_recomp=False, debugger=None):
    """FW + BW"""
    _ = debugger
    table = op_table(cfg)

    table_exp = deepcopy(table)  # Verify this with MF MoEV2
    table_exp["n_ffMM"] *= (
        cfg.hff_exp / cfg.hff * max(1, cfg.n_chosen_exp) * cfg.cap_fact
    )
    table_exp["n_ffBMM"] *= (
        cfg.hff_exp / cfg.hff * max(1, cfg.n_chosen_exp) * cfg.cap_fact
    )

    lccfgs = get_layer_custom_configs(cfg)
    layer_count = 0
    idx_lccfg = 0

    flops = []
    for stage in stages:
        flops += [0]
        for chunk in stage:
            for layer in chunk:
                if layer == LayerType.EMBEDDING_LAYER:
                    continue

                if layer == LayerType.OUTPUT_LAYER:
                    flops[-1] += (1 if cfg.dc_kv == 0 else cfg.n_mtp) * (
                        1
                        / 16  # bias_imbalance
                        * 6
                        * cfg.b
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
                        * cfg.s
                        * cfg.bytes_p
                        / cfg.t
                    )
                    continue

                layer_count += 1
                if (
                    idx_lccfg + 1 < len(lccfgs)
                    and lccfgs[idx_lccfg][1] <= layer_count
                ):
                    layer_count = 0
                    idx_lccfg += 1

                flop = get_table_quantity(
                    lccfgs[idx_lccfg][0],
                    table_exp if (lccfgs[idx_lccfg][0].n_exp > 1) else table,
                    layer,
                    with_recomp,
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
                    layer,
                    with_recomp,
                )

                if ccfg.ttype == PerformanceType.TIME:
                    flop = estimate_comp_flop_time(lccfgs[idx_lccfg][0], flop)

                flops[-1] += flop

    return flops


def estimate_comp(cfg, ccfg, stages, with_recomp=False, debugger=None):
    """Delegate the compute estimation to ``estimate_op_bulk_comp``.

    All arguments are forwarded unchanged and the callee's result is
    returned as-is.
    """
    result = estimate_op_bulk_comp(
        cfg,
        ccfg,
        stages,
        with_recomp=with_recomp,
        debugger=debugger,
    )
    return result

153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176


def efficiency(x):
    """Evaluate the fitted cubic in ``x`` and clamp it to ``[0.1, 1.0]``.

    The coefficients were "obtained via extrapolation" per the original
    note; their derivation is not shown in this file.
    """
    fitted = 0.00004694 * x**3 + 0.0014 * x**2 - 0.0336 * x + 0.1
    floored = max(0.1, fitted)
    return min(1.0, floored)

def throughput(precision_bytes, flop):
    """Return ``precision_bytes**2 * 1e12`` scaled by an efficiency factor.

    The efficiency factor is ``efficiency(flop / 1e12)``. The original note
    says the formula "assumes matrix" workloads.
    """
    tera = 10.0**12
    scale = efficiency(flop / tera)
    return precision_bytes**2 * tera * scale


def estimate_comp_flop_time(cfg, flop, is_softmax=False):
    """Return ``flop`` divided by the throughput for the chosen precision.

    The precision passed to ``throughput`` is ``cfg.bytes_softmax`` when
    ``is_softmax`` is true and ``cfg.bytes_compute`` otherwise.
    """
    if is_softmax:
        precision = cfg.bytes_softmax
    else:
        precision = cfg.bytes_compute
    return flop / throughput(precision, flop)


##############################################################
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216


def get_dynamic_ratio(cfg):
    """Return the dynamic ratio computed from the model sizes in ``cfg``.

    The original note labels it "comm/comp". The feed-forward width used is
    ``cfg.hff`` when ``cfg.n_exp == 1`` and ``cfg.hff_exp`` otherwise; only
    the selected attribute is read.
    """
    ffn_width = cfg.hff if cfg.n_exp == 1 else cfg.hff_exp
    return 3 / 2 * (ffn_width + cfg.s) * (8192 / (cfg.h + cfg.s))

def estimate_stage(*args, **kwargs):
    """stage level estimation"""
    cfg = args[0]
    ccfg = args[1]
    compute_perfs = args[2]
    comm_perfs = args[3]
    recompute_perfs = args[4]
    recomm_perfs = args[5]
    debugger = kwargs.get("debugger", args[6] if len(args) > 6 else None)
    comp_w = 1
    comm_w = 1
    if ccfg.rtype == RatioType.COMM_ONLY:
        comp_w = 0
    elif ccfg.rtype == RatioType.COMPUTE_ONLY:
        comm_w = 0
    elif ccfg.rtype == RatioType.STATIC:
        comm_w = 10**4
        ccfg.static_ratio = comm_w
    elif ccfg.rtype == RatioType.DYNAMIC:
        comm_w = get_dynamic_ratio(cfg)
        ccfg.dynamic_ratio = comm_w
    perf = [
        comp_w * compute_perfs[i] + comm_w * comm_perfs[i]
        for i in range(len(compute_perfs))
    ]
    logger.info("ratio = %s", comm_w)
    # ignores comm recomp, to improve
    re_perf = [
        (
            max(0, comp_w * (recompute_perfs[i] - compute_perfs[i]))
            + max(0, comm_w * (recomm_perfs[i] - comm_perfs[i]))
        )
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
        / (1 + BACKWARD_RATIO)
        for i in range(len(compute_perfs))
    ]

    if debugger and debugger.is_enabled():
        for p in [PerfParts.DP_COMM, PerfParts.MP_COMM, PerfParts.EP_COMM, PerfParts.CP_COMM]:
            debugger.info[p] = [
                comm_w * c for c in debugger.info[p]
            ]
        debugger.info[PerfParts.FW_COMPUTE] = [
            comp_w * comp / (1 + BACKWARD_RATIO) for comp in compute_perfs
        ]
        debugger.info[PerfParts.BW_COMPUTE] = [
            fw * BACKWARD_RATIO for fw in debugger.info[PerfParts.FW_COMPUTE]
        ]
        debugger.info[PerfParts.RECOMPUTE] = re_perf

    return [perf[i] + re_perf[i] for i in range(len(perf))]
    #penalty_fn(stage)
    #return stage

def estimate_pipeline(cfg, stage_perfs, stage_focused=None, debugger=None):
236
237
238
239
240
241
242
243
244
245
246
247
248
    #return stage

def estimate_pipeline(cfg, stage_perfs, stage_focused=None, debugger=None):
    """pipeline level estimation"""
    logger.info("stage_perfs = %s", stage_perfs)
    straggler_time = max(stage_perfs)
    sum_time = sum(stage_perfs)
    last_straggler_idx = cfg.p - 1 - np.argmax(stage_perfs[::-1])
    logger.info(
        "straggler estim is %s and its stage is %s",
        straggler_time,
        last_straggler_idx,
    )
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
        straggler_time,
        last_straggler_idx,
    )

    non_steady_perf = 0
    steady_perf = 0
    if cfg.p == 1:
        assert len(stage_perfs) == 1
        steady_perf = sum_time * cfg.m
    elif cfg.vp == 1:
        non_steady_perf = sum_time
        if GENERALIZE_PIPELINE_CALCULATION:
            last_idx = last_straggler_idx + 1
            steady_perf = (
                cfg.m - cfg.p + last_straggler_idx
            ) * straggler_time + sum(stage_perfs[last_idx:])
        else:
            steady_perf = (cfg.m - 1) * straggler_time
    else:
        less_extra = cfg.p * (cfg.vp - 1)
        # big_extra = less_extra + cfg.p

        # we assume that times of all micro-blocks in one vp chunk are the same
        straggler_time /= cfg.vp
        sum_time /= cfg.vp

        # more_memory has more micro blocks in warm-up but it does
        # not matter since they will overlap with steady phase
        non_steady_perf = sum_time + less_extra * straggler_time

        # more_memory is a boost to performance and a nerf to memory
        # if cfg.vp_less_memory or True:
        #     steady_perf = (cfg.m * cfg.vp - less_extra - 1) * straggler_time
277
278
279
280
281
282
283
284
285
286
287
288
289
290
        # if cfg.vp_less_memory or True:
        #     steady_perf = (cfg.m * cfg.vp - less_extra - 1) * straggler_time
        # else:
        #     steady_perf = (cfg.m * cfg.vp - big_extra - 1) * straggler_time
        steady_perf = (cfg.m * cfg.vp - less_extra - 1) * straggler_time
        straggler_time *= cfg.vp
        sum_time *= cfg.vp

    pipeline_perf = non_steady_perf + steady_perf
    logger.info(
        "pipeline_perf = non_steady_perf(%.2E) + steady_perf(%.2E)",
        non_steady_perf,
        steady_perf,
    )
288
289
290
291
292
293
294
295
296
297
298
299
300
        non_steady_perf,
        steady_perf,
    )

    if stage_focused is not None:
        last_straggler_idx = stage_focused
    if debugger and debugger.is_enabled():
        time_sum = 0
        for k in [
            PerfParts.DP_COMM,
            PerfParts.MP_COMM,
            PerfParts.EP_COMM,
            PerfParts.CP_COMM,
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
            PerfParts.FW_COMPUTE,
            PerfParts.BW_COMPUTE,
            PerfParts.RECOMPUTE,
        ]:
            debugger.info[k] = (
                debugger.info[k][last_straggler_idx] * cfg.m
            )  # / cfg.vp
            logger.info(
                "time_sum += debugger[%s] = %.2E",
                k,
                debugger.info[k],
            )
            time_sum += debugger.info[k]

        if abs(time_sum - straggler_time * cfg.m) < 1e-9:
            logger.warning("Inconsistency found in straggler time calculation")
            time_sum = straggler_time * cfg.m
        logger.info(
            "straggler time = %.2E. %s x stragglers = %.2E",
            straggler_time,
            cfg.m,
            straggler_time * cfg.m,
321
322
323
324
325
326
327
328
329
330
            cfg.m,
            straggler_time * cfg.m,
        )

        bubble = pipeline_perf - time_sum
        logger.info(
            "bubble(%.2E) = pipeline_perf(%.2E) - time_sum(%.2E)",
            bubble,
            pipeline_perf,
            time_sum,
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
            bubble,
            pipeline_perf,
            time_sum,
        )
        debugger.info[PerfParts.BUBBLE] = bubble
    return pipeline_perf


def estimate_p2p_comm(cfg, straggler, ratio=MANUAL_P2P_RATIO, debugger=None):
    """pipeline comm"""
    nb_send_recv = 0
    if cfg.vp == 1:
        nb_send_recv = (
            0
            if cfg.p == 1
            else (
                4 * cfg.m
346
347
348
349
350
351
352
353
354
                else 4 * cfg.p * cfg.m + 4 * cfg.p * cfg.p - 14 * cfg.p
            )
        )
    else:
        nb_send_recv = (
            0
            if cfg.p == 1
            else (
                8 * cfg.m * cfg.vp - 4 * cfg.m
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
                    - 13 * cfg.p
                )
            )
        )
    pp_comm = ratio * nb_send_recv / cfg.p * straggler / cfg.sp
    if debugger and debugger.is_enabled():
        debugger.info[PerfParts.PP_COMM] = pp_comm

    return pp_comm


def estimate_perf(cfg, _, stage_perfs, stage_focused=None, debugger=None):
    """Forward to ``estimate_pipeline`` and return its result.

    The second positional argument is accepted but not used.
    """
    pipeline_perf = estimate_pipeline(
        cfg,
        stage_perfs,
        debugger=debugger,
        stage_focused=stage_focused,
    )
    return pipeline_perf


def estimate_p2p(cfg, ccfg, stage_perfs, debugger=None):
    """Return the point-to-point communication term.

    The term is 0 unless ``ccfg.ptype`` is ``P2PCommType.MANUAL``, in which
    case it is ``estimate_p2p_comm`` evaluated on the largest stage perf.
    When an enabled debugger is supplied, the value is also stored under
    ``PerfParts.PP_COMM``.
    """
    p2p = (
        0
        if ccfg.ptype != P2PCommType.MANUAL
        else estimate_p2p_comm(cfg, max(stage_perfs), debugger=debugger)
    )

    record = debugger and debugger.is_enabled()
    if record:
        debugger.info[PerfParts.PP_COMM] = p2p
    return p2p


def estimate_layer_perf(*args, **kwargs):
    """for PPB"""
    cfg = args[0]
    stages = kwargs.get("stages", args[2] if len(args) > 2 else None)
    extra_custom_func = kwargs.get(
        "extra_custom_func", args[3] if len(args) > 3 else None
    )
    ccfg = kwargs.get("ccfg", args[4] if len(args) > 4 else CustomConfig())
    debugger = kwargs.get("debugger", args[5] if len(args) > 5 else None)
    # cfg = CostModelConfig(mf_config)
    # Process custom model config
    if extra_custom_func:
        extra_custom_func(cfg)
    else:
        logger.info("auto applying custom model config")
        check_and_apply_custom_hook(cfg)

    new_layer_config = []
    stages = [[LayerType.EMBEDDING_LAYER]]
    for _, layer in cfg.layer_custom_config:
        new_layer_config.append((1, layer))
        stages.append([LayerType.NOT_REC_LAYER])
    stages.append([LayerType.OUTPUT_LAYER])
    cfg.layer_custom_config = new_layer_config

    logger.output("cfg.layer_custom_config = %s", cfg.layer_custom_config)
    logger.output("stages = %s", stages)

    cfg.n = cfg.d * cfg.t * cfg.p

    cfg.n_headCast = 1
    cfg.n_ffAct = 1

    logger.info(str(cfg))
    logger.info(stages)
    logger.info(ccfg)

    perfs = {}
    perfs["compute_perfs"] = estimate_comp(
        cfg, ccfg, stages, with_recomp=False, debugger=debugger
    )
    logger.info("PerfEst: compute_perfs %s", perfs["compute_perfs"])
    perfs["recompute_perfs"] = (
        [0] * cfg.p
        if ccfg.retype not in {RecType.COMPUTE_ONLY, RecType.WITH}
        else estimate_comp(
            cfg, ccfg, stages, with_recomp=True, debugger=debugger
434
435
436
437
438
439
440
441
442
443
444
445
446
            cfg, ccfg, stages, with_recomp=True, debugger=debugger
        )
    )

    perfs["comm_perfs"] = estimate_comm(
        cfg, ccfg, stages, args[1], with_recomp=False, debugger=debugger
    )
    logger.info("PerfEst: comm_perfs %s", perfs["comm_perfs"])
    perfs["recomm_perfs"] = (
        [0] * cfg.p
        if ccfg.retype not in {RecType.COMM_ONLY, RecType.WITH}
        else estimate_comm(
            cfg, ccfg, stages, args[1], with_recomp=True, debugger=debugger
445
446
447
448
449
450
451
452
453
        else estimate_comm(
            cfg, ccfg, stages, args[1], with_recomp=True, debugger=debugger
        )
    )
    stage_perfs = estimate_stage(
        cfg,
        ccfg,
        perfs["compute_perfs"],
        perfs["comm_perfs"],
454
455
456
457
458
459
460
461
462
463
464
465
466
467
        perfs["recompute_perfs"],
        perfs["recomm_perfs"],
        debugger=debugger,
    )
    logger.output("PerfEst: stage_perfs %s", stage_perfs)

    for s, perf in enumerate(stage_perfs):
        stage_perfs[s] = int(perf / 10**12)

    return stage_perfs



def apply_regression_coefficients(coeffs, debugger, old_perf):
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
def apply_regression_coefficients(coeffs, debugger, old_perf):
    """
    applies the coefficients present in regression's cache_file
    """
    compute_ratio = coeffs.get("COMPUTE")
    for part, raw in list(debugger.info.items()):
        if part in (PerfParts.TOTAL, PerfParts.MEMORY): continue
        if part in (PerfParts.FW_COMPUTE,
                   PerfParts.BW_COMPUTE,
                   PerfParts.RECOMPUTE):
            ratio = compute_ratio
        else:
            ratio = coeffs.get(part.name)
        new_val = 0.0 if raw == 0.0 else raw * ratio
        debugger.info[part] = new_val

    max_idx = max(p.value for p in PerfParts) -1
    estimations = [0.0] * max_idx
    for part in PerfParts:
        if part in (PerfParts.TOTAL, PerfParts.MEMORY): continue
        estimations[part.value - 1] = debugger.info.get(part) or 0.0

    real_buckets = {rp: [] for rp in RealParts}
    real_buckets = estimation_in_real_parts(real_buckets, estimations, old_perf)

    perf = (
            real_buckets[RealParts.COMP][-1]
            + real_buckets[RealParts.DP_WAIT][-1]
            + real_buckets[RealParts.MP_WAIT][-1]
            + real_buckets[RealParts.EP_WAIT][-1]
496
497
498
499
500
501
502
503
504
505
            + real_buckets[RealParts.EP_WAIT][-1]
            + real_buckets[RealParts.CP_WAIT][-1]
            + real_buckets[RealParts.PP_WAIT][-1]
    )
    debugger.info[PerfParts.TOTAL] = perf
    return perf


# performance estimation
def estimate_performance(*args, **kwargs):
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538

# performance estimation
def estimate_performance(*args, **kwargs):
    """main estimation"""
    stages = kwargs.get("stages", args[1] if len(args) > 1 else None)
    extra_custom_func = kwargs.get(
        "extra_custom_func", args[2] if len(args) > 2 else None
    )
    ccfg = kwargs.get("ccfg", args[3] if len(args) > 3 else CustomConfig())
    debugger = kwargs.get("debugger", args[4] if len(args) > 4 else None)
    device_type = kwargs.get(
        "device_type", args[5] if len(args) > 5 else Hard.device_map["A2"]
    )
    memory = kwargs.get("memory", args[6] if len(args) > 6 else None)

    # cfg = CostModelConfig(args[0])
    if isinstance(args[0], CostModelConfig):
        cfg = args[0]
    else:
        cfg = CostModelConfig(args[0])

    # Process custom model config
    if extra_custom_func:
        extra_custom_func(cfg)
    else:
        logger.info("auto applying custom model config")
        check_and_apply_custom_hook(cfg)

    # Process partition generation
    if not stages:
        logger.info("stage partitions are generated")
        stages = cfg.generate_partitions_vpp()

    # print(f"DP = {cfg.d}; MP = {cfg.t}; EP = {cfg.ep}; PP = {cfg.p}")

    # print(list(map(list,
538
539
540
541
542
543
544
545
546
547
548
549
550
    # print(list(map(list,
    # zip(*list(map(lambda x: list(map(len, x)),stages))))))
    # print(stages)

    cfg.n = cfg.d * cfg.t * cfg.p
    cfg.n_headCast = 1
    cfg.n_ffAct = 1

    logger.debug(
        "perf_model: DP = %d, TP = %d, EP = %d, PP = %d, MB = %d",
        cfg.d,
        cfg.t,
        cfg.ep,
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
        cfg.p,
        cfg.m,
    )

    logger.info(str(cfg))
    logger.info(stages)
    logger.info(ccfg)

    compute_perfs = estimate_comp(
        cfg, ccfg, stages, with_recomp=False, debugger=debugger
    )
    recompute_perfs = (
        [0] * cfg.p
        if ccfg.retype not in {RecType.COMPUTE_ONLY, RecType.WITH}
        else estimate_comp(
            cfg, ccfg, stages, with_recomp=True, debugger=debugger
565
566
567
568
569
570
571
572
573
574
575
576
577
        else estimate_comp(
            cfg, ccfg, stages, with_recomp=True, debugger=debugger
        )
    )
    comm_perfs = estimate_comm(
        cfg, ccfg, stages, device_type, with_recomp=False, debugger=debugger
    )
    logger.info("PerfEst: comm_perfs %s", comm_perfs)
    recomm_perfs = (
        [0] * cfg.p
        if ccfg.retype not in {RecType.COMM_ONLY, RecType.WITH}
        else estimate_comm(
            cfg, ccfg, stages, device_type, with_recomp=True, debugger=debugger
577
578
579
580
581
582
583
584
585
            cfg, ccfg, stages, device_type, with_recomp=True, debugger=debugger
        )
    )

    stage_perfs = estimate_stage(
        cfg,
        ccfg,
        compute_perfs,
        comm_perfs,
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
        recompute_perfs,
        recomm_perfs,
        debugger=debugger,
    )
    logger.info("PerfEst: stage_perfs %s", stage_perfs)

    stage_focused = kwargs.get("stage_focused", None)
    perf = estimate_perf(
        cfg, ccfg, stage_perfs, stage_focused=stage_focused, debugger=debugger
    )
    perf += estimate_p2p(cfg, ccfg, stage_perfs, debugger=debugger)
    logger.info("PerfEst: perf %s", perf)

    cache_file = kwargs.get("cache_file")
    coeffs = None
    cache = False
    if cache_file is not None:
        with open(cache_file, 'r', encoding='utf-8') as f:
            coeffs = json.load(f)
        cache = True

    if debugger and debugger.is_enabled():
        if cache:
            perf = apply_regression_coefficients(coeffs, debugger, perf)
        debugger.info[PerfParts.TOTAL] = perf
        if memory is not None:
            debugger.info[PerfParts.MEMORY] = memory
    return perf  # / cfg.gbs

# TODO:
# - Fix the more_memory case
# - Add context parallelism
hyper_parallel/auto_parallel/sapp_nd/perf_estimation/getters.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def get_layer_custom_configs(cfg):
    """Pair each layer-specific configuration with its layer count.

    Returns a list of ``(config, nb_layers)`` tuples in the same order as
    ``cfg.layer_custom_config`` (ascending order of execution in a forward
    pass). When ``cfg.layer_custom_config`` is ``None``, or any of its
    entries carries a ``None`` function, the single pair
    ``(cfg, cfg.n_lay)`` is returned instead. Otherwise every entry
    ``(nb_layers, func)`` yields a deep copy of ``cfg`` that ``func`` has
    been applied to; ``cfg`` itself is left untouched.
    """
    custom_entries = cfg.layer_custom_config
    fallback = [(cfg, cfg.n_lay)]

    if custom_entries is None:
        return fallback
    for _, hook in custom_entries:
        if hook is None:
            return fallback

    def _specialise(hook):
        # Each entry works on its own private copy of the base config.
        clone = deepcopy(cfg)
        hook(clone)
        return clone

    return [(_specialise(hook), count) for count, hook in custom_entries]


def get_recomp_factor(lccfg, layer, op_name):
    """Return the recomputation factor of ``op_name`` for a layer type.

    ``FULL_REC_LAYER`` gives 1 and ``NOT_REC_LAYER`` gives 0. For
    ``SEL_REC_LAYER`` the factor is read from the attribute ``op_name`` of
    ``lccfg.rec_op``, defaulting to 0 when it is absent. Any other layer
    type logs a warning and gives 0.
    """
    if layer == LayerType.FULL_REC_LAYER:
        factor = 1
    elif layer == LayerType.NOT_REC_LAYER:
        factor = 0
    elif layer == LayerType.SEL_REC_LAYER:
        factor = getattr(lccfg.rec_op, op_name, 0)
    else:
        logger.warning("Unrecognized recompute type %s", layer)
        factor = 0
    return factor


def get_table_quantity(lccfg, table, layer, with_recomp):
    """Accumulate the load described by ``table`` for one layer.

    For every ``(op, quantity)`` entry the contribution is
    ``(1 + with_recomp * recomp_factor) * getattr(lccfg, op) * quantity``.
    The recompute factor is looked up under ``op`` with its first two
    characters removed.
    """
    total = 0
    for op_key, amount in table.items():
        # The recompute lookup uses the op key without its 2-char prefix.
        recomp_factor = get_recomp_factor(lccfg, layer, op_key[2:])
        weight = 1 + with_recomp * recomp_factor
        total += weight * getattr(lccfg, op_key) * amount
    return total
hyper_parallel/auto_parallel/sapp_nd/perf_estimation/utils_classes.py
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
        ttype=PerformanceType.FLOP,
        ptype=P2PCommType.NONE, #MANUAL,
        retype=RecType.COMPUTE_ONLY,
    ):
        self.rtype = rtype
        self.ttype = ttype
        self.ptype = ptype
        self.retype = retype

    def __repr__(self):
        return (
            f"CustomConfig(rtype={self.rtype}, "
            f"ttype={self.ttype}, "
            f"ptype={self.ptype}, "
            f"retype={self.retype})"
79
80
81
82
83
            f"retype={self.retype})"
        )

    def __str__(self):
        """Return the same text as ``repr`` of this object."""
        return repr(self)