Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/_backbone.py 30.0% 659-660,662-666
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/evaluators/body.py 100%  
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/evaluators/comm.py 93.8% 216,251
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/evaluators/layer_block.py 100%  
hyper_parallel/auto_parallel/sapp_nd/nd/common/_cost_model_variables.py 100%  
hyper_parallel/auto_parallel/sapp_nd/nd/common/cost_model_preprocess.py 100%  
hyper_parallel/auto_parallel/sapp_nd/nd/common/cp_types.py 97.4% 46,49
hyper_parallel/auto_parallel/sapp_nd/nd/common/framework_parsers/_cost_model_parser.py 100%  
hyper_parallel/auto_parallel/sapp_nd/nd/common/hardware.py 13.6% 259,261-264,266-268,270,283,285-287,299-304
hyper_parallel/auto_parallel/sapp_nd/nd/dimensions.py 94.1% 232,262,376,378,462
hyper_parallel/auto_parallel/sapp_nd/nd/parallelize.py 22.7% 165-171,173,175-176,178,180,195-197,199-200
hyper_parallel/auto_parallel/sapp_nd/perf_estimation/comm_time.py 82.5% 63,132,496-497,505,560-564
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/_backbone.py
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
            ins["AllGather Comm"] = self.mb(stage_accu[MemType.AG_COMM])
            ins["All2All Comm"] = self.mb(stage_accu[MemType.A2A_COMM])

            if self._ccfg.cp > 1:
                cp_memory = EvalBody.act_cp_layer(self._ccfg, self._ctx)
                cp_comm_buffer = EvalLayerComm.cp_comm_buffer(self._ccfg, self._ctx)

                ins["CP KV Cache"] = self.mb(cp_memory.kv_cache_memory)
                ins["CP Attn Scores"] = self.mb(cp_memory.attention_scores_memory)
                ins["CP Softmax"] = self.mb(cp_memory.softmax_outputs_memory)
                ins["CP Comm Buffer"] = self.mb(cp_comm_buffer)
                ins["CP Reduction"] = self.mb(cp_memory.total_reduction)

            ins["Node Log"] = sm["logs"][stage_id].node_compute_log
            # VERBOSE
            if verbose and spec_stage_id in (-1, stage_id):
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/evaluators/comm.py
212
213
214
215
216
217
218
219
220
        rank holds only 1/t of the total KV dimension.  MLA is an
        exception: the compressed latent vector is not split by TP,
        so kv_lora_rank stays unchanged.
        """
        return compute_kv_dim(ccfg)

    # pylint: disable=unused-argument
    @staticmethod
    def cp_comm_buffer(ccfg: CostModelConfig, ctx: Context) -> float:
247
248
249
250
251
252
253
254
255
            intra_ranks = min(int(cp), int(ccfg.device_per_node))
            if cp <= ccfg.device_per_node:
                extra_chunks = intra_ranks - 1
            else:
                extra_chunks = 2 * intra_ranks - 1
            return extra_chunks * chunk

        kv_dim = compute_kv_dim(ccfg)
        chunk = (s / cp) * b * kv_dim * kv_bytes
hyper_parallel/auto_parallel/sapp_nd/nd/common/cp_types.py
42
43
44
45
46
47
48
49
50
51
52
53
def _resolve_cp_algo(ccfg) -> CPAlgo:
    """Return the CP algorithm selected on ``ccfg``.

    The ``cp_algo`` attribute of ``ccfg`` is read if present.  A ``CPAlgo``
    member is returned unchanged; a string is looked up in
    ``_CP_ALGO_STR_MAP``.  A missing attribute, an unknown string, or a
    value of any other type all fall back to ``CPAlgo.COLOSSALAI_CP``.
    """
    configured = getattr(ccfg, 'cp_algo', None)

    # Already an enum member: nothing to translate.
    if isinstance(configured, CPAlgo):
        return configured

    # Anything that is not a string (including None) gets the default.
    if not isinstance(configured, str):
        return CPAlgo.COLOSSALAI_CP

    # Exact-match lookup; unrecognised names fall back to the default.
    return _CP_ALGO_STR_MAP.get(configured, CPAlgo.COLOSSALAI_CP)


@dataclass
class CPMemoryBreakdown:
hyper_parallel/auto_parallel/sapp_nd/nd/common/hardware.py
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
        - topology_type: "intra-node" or "cross-node"
        - effective_bandwidth: Bandwidth in GB/s
        - is_intra_node: True if CP stays within node
    """
    total_devices_needed = tp_degree * cp_degree

    if total_devices_needed <= device_per_node:
        topology_type = "intra-node"
        is_intra_node = True
        effective_bandwidth = 300.0
    else:
        topology_type = "cross-node"
        is_intra_node = False
        effective_bandwidth = 25.0

    return topology_type, effective_bandwidth, is_intra_node


def get_cp_bandwidth(topology_type: str, device_type: str = "A2") -> float:
    """Get effective bandwidth for CP communication based on topology.
279
280
281
282
283
284
285
286
287
288
289
290
291

    Returns:
        Bandwidth in GB/s
    """
    device = device_map.get(device_type, Device_A2)

    if topology_type == "intra-node":
        return device.level_bandwidth[0] if device.level_bandwidth else 300.0
    return device.level_bandwidth[1] if len(device.level_bandwidth) > 1 else 25.0


def recommend_cp_max_by_attention(attention_type: str) -> int:
    """Recommend maximum CP degree based on attention type.
295
296
297
298
299
300
301
302
303
304

    Returns:
        Recommended maximum CP degree
    """
    attention_type_upper = attention_type.upper()
    if attention_type_upper == "MLA":
        return 16
    if attention_type_upper == "GQA":
        return 8
    return 4
hyper_parallel/auto_parallel/sapp_nd/nd/dimensions.py
228
229
230
231
232
233
234
235
236
    @staticmethod
    def _check_mbn_pp(dims_val, all_dims):
        """Return True if MBN/PP combination is valid."""
        if MBN not in dims_val or PP not in all_dims:
            return True
        valid = dims_val[MBN] >= dims_val[PP]
        valid = valid and not (dims_val[PP] == 1 and dims_val[MBN] > 1)
        if not valid:
            logger.warning("PP and MBN were deemed not suitable")
258
259
260
261
262
263
264
265
266
            if self.dims_val[SP] and self.dims_val[CP] > 1:
                logger.warning("SP & CP cannot coexist")
                return False
        if OP in self.all_dims and not self._check_power_of_two(OP, self.dims_val[OP]):
            return False
        return True

    def val(self, dim):
        """Get Dimension value"""
372
373
374
375
376
377
378
379
380
381
382
        )

    attn_upper = p.attention_type_str.upper()
    if attn_upper == "MLA":
        recommended_cp_max = 16
    elif attn_upper == "GQA":
        recommended_cp_max = 8
    else:
        recommended_cp_max = 4

    if recommended_cp_max is not None and p.cp_degree > recommended_cp_max:
458
459
460
461
462
463
464
465
466
        cp_algo, attention_heads, sp_enabled,
    )

    if p.cp_degree <= 1:
        return _cp_ok_result()

    if p.sp_enabled:
        return _cp_ok_result(
            is_valid=False,
hyper_parallel/auto_parallel/sapp_nd/nd/parallelize.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
            logger.warning("Config manually filtered out")
            return False

        if hasattr(parallel_config, 'dims_val') and Dim.CP in parallel_config.dims_val:
            cp_degree = parallel_config.dims_val[Dim.CP]
            if cp_degree > 1:
                seq_len = self.config.ccfg.s
                tp_degree = parallel_config.dims_val.get(Dim.TP, 1)
                pp_degree = parallel_config.dims_val.get(Dim.PP, 1)
                device_per_node = self.machine.device.intra_node_num()
                total_devices = self.machine.number

                attention_type = detect_attention_type(self.config.ccfg).name.lower()

                bw_intra = self.config.ccfg.bw_intra
                bw_inter = self.config.ccfg.bw_inter

                sp_enabled = bool(parallel_config.dims_val.get(Dim.SP, False))

                cp_result = validate_cp_constraints(
                    seq_len=seq_len,
                    cp_degree=cp_degree,
                    tp_degree=tp_degree,
                    pp_degree=pp_degree,
191
192
193
194
195
196
197
198
199
200
201
202
203
204
                    cp_algo=getattr(self.config.ccfg, 'cp_algo', 'colossalai_cp'),
                    attention_heads=self.config.ccfg.a,
                )

                if not cp_result.is_valid:
                    logger.warning("CP constraints violated: %s", cp_result.error_message)
                    return False

                if cp_result.warning_message:
                    logger.info("CP warning: %s", cp_result.warning_message)

        gbs = self.config.global_batch_size(parallel_config)
        if not gbs == self.global_batch_size:
            logger.error(
hyper_parallel/auto_parallel/sapp_nd/perf_estimation/comm_time.py
59
60
61
62
63
64
65
66
67
    intra_ranks = min(int(cp), int(device_per_node))
    if cp <= device_per_node:
        return "intra-node", bw_intra
    if intra_ranks == 1:
        return "cross-node", bw_inter
    intra_fraction = (intra_ranks - 1) / (cp - 1)
    cross_fraction = 1.0 - intra_fraction
    bw = intra_fraction * bw_intra + cross_fraction * bw_inter
    return "mixed", bw
128
129
130
131
132
133
134
135
136
    cp = ccfg.cp
    t = max(1, ccfg.t)

    if ccfg.a <= 0:
        raise ValueError(f"Number of attention heads must be positive, got {ccfg.a}")

    kv_dim = compute_kv_dim(ccfg)
    attention_type = detect_attention_type(ccfg)
    cp_algo = _resolve_cp_algo(ccfg)
492
493
494
495
496
497
498
499
500
                comm[Dim.EP] += EvalLayerComm.ep_comm_layer(
                    param["cfg"], param["ctx"], 1
                )  # * param["cfg"].ep
                if param["cfg"].cp > 1:
                    cp_comm_details = cp_comm_layer_detailed(param["cfg"], param["ctx"])
                    comm[Dim.CP] += cp_comm_details.comm_volume
                # min(device_type.level_bound_number[0], param["cfg"].ep)
                # comm_cp += EvalLayerComm.cp_comm_layer
                # (param["cfg"], param["ctx"])
501
502
503
504
505
506
507
508
509



        if param["ccfg"].ttype == PerformanceType.TIME:
            for dim, ov in zip([Dim.DP, Dim.TP, Dim.CP], [0.0, 0, 0.0]):
                comm[dim] = estimate_comm_score(
                    param["cfg"],
                    comm[dim],
                    dim,
556
557
558
559
560
561
562
563
564
565
566
567
568
        param["debugger"].info[PerfParts.EP_COMM] = comms[Dim.EP]
        param["debugger"].info[PerfParts.CP_COMM] = comms[Dim.CP]

        if param["cfg"].cp > 1:
            cp_comm_details = cp_comm_layer_detailed(param["cfg"], param["ctx"])
            param["debugger"].info["CP_KV_VOLUME"] = cp_comm_details.total_kv_volume
            param["debugger"].info["CP_EXPOSED_TIME"] = cp_comm_details.exposed_comm_time
            param["debugger"].info["CP_TOPOLOGY"] = cp_comm_details.topology
            param["debugger"].info["CP_BANDWIDTH"] = cp_comm_details.effective_bandwidth

    res = []
    for i, c in enumerate(comms[Dim.TP]):
        res += [c + comms[Dim.DP][i] + comms[Dim.EP][i] + comms[Dim.CP][i]]