Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/pipeline_parallel/scheduler.py: 13%
705 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""pipeline schedule"""
16from abc import ABC
17from enum import Enum, auto
18from collections import defaultdict
19import itertools
20import bisect
21import logging
22import re
23import hyper_parallel
24from hyper_parallel.platform import get_platform
25platform = get_platform()
26logger = logging.getLogger(__name__)
class MetaStepType(Enum):
    """Enumeration of the step kinds a pipeline schedule may contain.

    Member order is significant: ``auto()`` assigns values 1..8 in
    declaration order, so do not reorder members.
    """
    # Compute steps executed on the local stage.
    FWD = auto()
    BWD = auto()
    # Point-to-point communication steps paired with the compute steps.
    FWD_RECV = auto()
    FWD_SEND = auto()
    BWD_RECV = auto()
    BWD_SEND = auto()
    # Composite steps: two sub-steps intended to overlap comm/compute.
    OVERLAP_F_B = auto()
    OVERLAP_B_F = auto()
class MetaStep:
    """
    Meta step of PipelineSchedule.
    An execution list composed of MetaStep can be constructed
    and fed into the PipelineSchedule for execution.

    Args:
        micro_index (int | None): The index of micro-batch. ``None`` for
            composite types (``OVERLAP_F_B`` / ``OVERLAP_B_F``) whose real
            micro index lives in each ``sub_steps`` entry.
        meta_type (MetaStepType): Specify the type of current step.
        stage_index (int | None): Stage index of current step. ``None``
            for composite types; use ``sub_steps`` to get each direction's
            stage.
        sub_steps (tuple[MetaStep, MetaStep] | None): For composite types
            only: ``(fwd, bwd)`` for ``OVERLAP_F_B``, ``(bwd, fwd)`` for
            ``OVERLAP_B_F``.
    """
    def __init__(self, micro_index, meta_type, stage_index, sub_steps=None):
        self._type = meta_type
        self._micro_index = micro_index
        self._stage_index = stage_index
        self._sub_steps = sub_steps

    @property
    def micro_index(self):
        """int | None: Micro-batch index (``None`` for composite steps)."""
        return self._micro_index

    @property
    def stage_index(self):
        """int | None: Stage index (``None`` for composite steps)."""
        return self._stage_index

    @property
    def type(self):
        """MetaStepType: The kind of this step."""
        return self._type

    @property
    def sub_steps(self):
        """Sub-steps for composite types: ``(fwd, bwd)`` for OVERLAP_F_B,
        ``(bwd, fwd)`` for OVERLAP_B_F, or ``None``."""
        return self._sub_steps

    def __eq__(self, value):
        # Equality ignores sub_steps on purpose: identity of a step is its
        # (type, micro_index, stage_index) triple, matching __hash__.
        if not isinstance(value, MetaStep):
            return NotImplemented
        return self.type == value.type and \
               self.micro_index == value.micro_index and \
               self.stage_index == value.stage_index

    def __ne__(self, value):
        if not isinstance(value, MetaStep):
            return NotImplemented
        return self.type != value.type or \
               self.micro_index != value.micro_index or \
               self.stage_index != value.stage_index

    def __hash__(self):
        return hash((self.type, self.micro_index, self.stage_index))

    def __str__(self):
        return f"MetaStep(type={self.type}, micro_index={self.micro_index}, stage_index={self.stage_index})"

    def __repr__(self):
        return f"MetaStep(type={self.type}, micro_index={self.micro_index}, stage_index={self.stage_index})"

    @staticmethod
    def from_str(step_str):
        """Parse a ``str(MetaStep)`` representation back into a MetaStep.

        Only simple (non-composite) steps round-trip: ``sub_steps`` is not
        encoded by ``__str__`` and is always ``None`` on the result.

        Args:
            step_str (str): A string in the exact format produced by
                ``__str__``, e.g.
                ``"MetaStep(type=MetaStepType.FWD, micro_index=0, stage_index=1)"``.

        Returns:
            MetaStep: The parsed step.

        Raises:
            ValueError: If ``step_str`` does not match the expected format.
            KeyError: If the type name is not a valid :class:`MetaStepType`.
        """
        match = re.fullmatch(
            r"MetaStep\(type=MetaStepType\.(\w+), "
            r"micro_index=(None|-?\d+), stage_index=(None|-?\d+)\)",
            step_str.strip())
        if match is None:
            raise ValueError(f"Cannot parse MetaStep from string: {step_str!r}")
        meta_type = MetaStepType[match.group(1)]
        micro_index = None if match.group(2) == "None" else int(match.group(2))
        stage_index = None if match.group(3) == "None" else int(match.group(3))
        return MetaStep(micro_index, meta_type, stage_index)
class PipelineContext:
    """Execution context handed to custom step handlers registered via
    :meth:`PipelineScheduleRuntime.register_custom_function`.

    Exposes the schedule's internal state so that a handler (e.g. an
    OVERLAP_F_B callback) can drive P2P communication, call
    ``forward_one_chunk`` / ``backward_one_chunk``, append losses, etc.

    Attributes:
        schedule: The :class:`PipelineScheduleRuntime` instance.
        arg_mbs: Per-micro-batch positional args.
        kwarg_mbs: Per-micro-batch keyword args.
        losses: Mutable list for loss collection.
        fwd_recv_ops: ``{(stage_index, micro_index): [handle, ...]}``
            cached forward recv handles (when ``overlap_p2p=True``).
        bwd_recv_ops: Same for backward recv handles.
        send_handles: Mutable list of outstanding send handles.
    """

    def __init__(self, schedule, arg_mbs, kwarg_mbs, losses, send_handles):
        # Keep a direct reference to the runtime so handlers can invoke
        # its stages and P2P helpers.
        self.schedule = schedule
        self.arg_mbs = arg_mbs
        self.kwarg_mbs = kwarg_mbs
        self.losses = losses
        self.send_handles = send_handles
        # Aliases of the runtime's recv-handle caches — shared, not copied,
        # so mutations by handlers are visible to the runtime.
        self.fwd_recv_ops = schedule.fwd_handle_cache
        self.bwd_recv_ops = schedule.bwd_handle_cache
class PipelineScheduleRuntime(ABC):
    """
    Base class for pipeline schedule.
    Implements the `split_microbatches` and `run_microbatches` method.
    Derived classes populate ``self.exec_order`` (in this file, via a
    ``construct_exec_order`` method) before :meth:`run` is called.

    Supports registering **custom execution functions** for any
    :class:`MetaStepType` via :meth:`register_custom_function`. When
    ``run_microbatches`` encounters a step whose type has a registered
    handler, it creates a :class:`PipelineContext` and delegates execution
    to the handler instead of using the built-in logic.

    Args:
        stages (list[PipelineStage], PipelineStage): PipelineStage used to run_microbatches.
        micro_batch_num (int): The number of micro-batch.
        args_batch_dim (list, optional): Specify the batch dim of the args.
            Default ``None``.
        kwargs_batch_dim (dict, optional): Specify the batch dim of the kwargs.
            Default ``None``.
        output_concat_dim (optional): Dim along which outputs are concatenated.
            Stored but not used in this class. Default ``None``.
        overlap_p2p (bool, optional): When ``True``, recv waits are deferred
            to the consuming FWD/BWD step and send waits to the end of the
            schedule, letting P2P overlap with compute. Default ``False``.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None,
                 overlap_p2p=False):
        self.stages = self._check_stages(stages)
        self.micro_batch_num = micro_batch_num
        self._args_batch_dim = args_batch_dim
        self._kwargs_batch_dim = kwargs_batch_dim
        self._output_concat_dim = output_concat_dim
        # Platform-provided splitter: (args, kwargs) -> per-micro-batch lists.
        self.split_micro_batch = platform.micro_batch(self.micro_batch_num,
                                                      self._args_batch_dim, self._kwargs_batch_dim)
        self.n_local_stages = len(self.stages)
        self._stage_dict = self.convert_stages_dict()
        # Number of physical ranks = total virtual stages / stages on this rank.
        self.real_stage_num = self.stages[0].stage_num // self.n_local_stages
        self._stage_num = self.stages[0].stage_num
        self._overlap_p2p = overlap_p2p
        # {real_stage_index: [MetaStep | None, ...]} filled in by subclasses.
        self.exec_order = {}
        self._init_stages()
        # {(stage_index, micro_index): [handle, ...]} — recv handles whose
        # wait() is deferred until the consuming step (overlap_p2p mode).
        self.fwd_handle_cache = {}
        self.bwd_handle_cache = {}
        # {MetaStepType: callable} — user handlers, see register_custom_function.
        self._custom_fn_map = {}

    def register_custom_function(self, step_type: MetaStepType, fn) -> None:
        """Register a custom execution function for the given step type.

        When :meth:`run_microbatches` encounters a :class:`MetaStep` whose
        ``type`` matches ``step_type``, it calls ``fn(step, ctx)`` instead
        of the built-in logic.

        Args:
            step_type: The :class:`MetaStepType` to intercept.
            fn: A callable with signature ``(step: MetaStep, ctx: PipelineContext) -> None``.

        Example:
            >>> def my_overlap_callback(step, ctx):
            ...     fwd_step, bwd_step = step.sub_steps
            ...     # custom parallel execution logic
            >>> schedule.register_custom_function(MetaStepType.OVERLAP_F_B, my_overlap_callback)
        """
        # Last registration for a type wins.
        self._custom_fn_map[step_type] = fn

    def convert_stages_dict(self):
        """Convert stages to a ``{stage_index: stage}`` dict for O(1) lookup."""
        stage_dict = {}
        for stage in self.stages:
            stage_dict[stage.stage_index] = stage
        return stage_dict

    def split_microbatches(self, args, kwargs):
        """Split inputs into per-micro-batch args/kwargs lists.

        With no inputs at all, returns ``micro_batch_num`` empty
        placeholders so the execution loop can still index them.
        """
        if args or kwargs:
            args_split, kwargs_split = self.split_micro_batch(args, kwargs)
            return args_split, kwargs_split
        return [[] for _ in range(self.micro_batch_num)], [{} for _ in range(self.micro_batch_num)]

    def _check_stages(self, stages):
        """Validate ``stages`` and normalize a single stage to a one-item list.

        Raises:
            TypeError: If ``stages`` is not a PipelineStage or a list/tuple
                of PipelineStage.
        """
        if isinstance(stages, hyper_parallel.PipelineStage):
            return [stages]
        if isinstance(stages, (list, tuple)):
            for stage in stages:
                if not isinstance(stage, hyper_parallel.PipelineStage):
                    raise TypeError(f"Argument 'stages' must be type of PipelineStage, \
list or tuple of PipelineStage, but got list or tuple of {type(stage)}.")
            return stages
        raise TypeError(f"Argument 'stages' must be type of PipelineStage, \
list or tuple of PipelineStage, but got type of {type(stages)}.")

    def _init_stages(self):
        """Initialize each stage with the number of local stages on this rank."""
        for stage in self.stages:
            stage.init(self.n_local_stages)

    def run(self, *args, **kwargs):
        """Split inputs into micro-batches, execute the schedule, return losses."""
        split_args, split_kwargs = self.split_microbatches(args, kwargs)
        losses = []
        self.run_microbatches(split_args, split_kwargs, losses)
        return losses

    def sync_shared_parameters_grad(self):
        """Synchronize gradients of parameters shared across stages."""
        for stage in self.stages:
            stage.sync_shared_parameters_grad()

    def update_losses(self, stage, loss, losses):
        """Collect ``loss`` only on the last stage (where the loss is produced)."""
        if stage.is_last_stage:
            losses.append(loss)

    def _wait_p2p(self, handles):
        # Handles may contain None placeholders; skip them.
        for handle in handles:
            if handle is not None:
                handle.wait()

    def _exec_step(self, cur_step, arg_mbs, kwarg_mbs, losses, send_handles):
        """Execute a single built-in step (FWD/BWD/SEND/RECV)."""
        stage = self._stage_dict[cur_step.stage_index]
        stage_index = cur_step.stage_index
        micro_index = cur_step.micro_index

        if cur_step.type == MetaStepType.FWD_RECV:
            comm_handle = stage.exec_fwd_recv_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
                # Defer the wait to the consuming FWD step.
                self.fwd_handle_cache[(stage_index, micro_index)] = comm_handle

        elif cur_step.type == MetaStepType.FWD:
            key = (stage_index, micro_index)
            if self._overlap_p2p and key in self.fwd_handle_cache:
                # Consume the deferred recv before touching its data.
                self._wait_p2p(self.fwd_handle_cache.pop(key))
            out = stage.forward_one_chunk(micro_index, arg_mbs[micro_index], kwarg_mbs[micro_index])
            # Only the last stage records a loss (see update_losses).
            self.update_losses(stage, out, losses)

        elif cur_step.type == MetaStepType.FWD_SEND:
            comm_handle = stage.exec_fwd_send_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
                # Drained at the end of run_microbatches.
                send_handles.append(comm_handle)

        elif cur_step.type == MetaStepType.BWD_RECV:
            comm_handle = stage.exec_bwd_recv_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
                self.bwd_handle_cache[(stage_index, micro_index)] = comm_handle

        elif cur_step.type == MetaStepType.BWD:
            key = (stage_index, micro_index)
            if self._overlap_p2p and key in self.bwd_handle_cache:
                self._wait_p2p(self.bwd_handle_cache.pop(key))
            # Flag the final micro-batch's backward so the stage can finalize.
            last_bwd = micro_index == self.micro_batch_num - 1
            stage.backward_one_chunk(micro_index, last_bwd)

        elif cur_step.type == MetaStepType.BWD_SEND:
            comm_handle = stage.exec_bwd_send_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
                send_handles.append(comm_handle)

    def run_microbatches(self, arg_mbs, kwarg_mbs, losses):
        """Execute the schedule step by step.

        Steps whose :attr:`MetaStep.type` has a registered custom function
        are delegated to that function with a :class:`PipelineContext`.
        Composite ``OVERLAP_F_B`` / ``OVERLAP_B_F`` steps without a
        registered handler fall back to executing their ``sub_steps``
        sequentially via :meth:`_exec_step` — correct but without
        comm/compute overlap. All other steps are executed by
        :meth:`_exec_step`.
        """
        # Map this rank's local stage to its row in exec_order.
        real_stage_index = self.stages[0].stage_index % self.real_stage_num
        send_handles = []
        ctx = None  # lazily created

        for cur_step in self.exec_order[real_stage_index]:
            # None entries are bubble slots — skip.
            if cur_step is None:
                continue

            # Check for registered custom function
            custom_fn = self._custom_fn_map.get(cur_step.type)
            if custom_fn is not None:
                if ctx is None:
                    ctx = PipelineContext(self, arg_mbs, kwarg_mbs, losses, send_handles)
                custom_fn(cur_step, ctx)
                continue

            # Default for composite OVERLAP steps: run sub_steps sequentially.
            # P2P send/recv around these steps are already laid out in two
            # virtual slots by ``add_send_recv``, so sequential execution is
            # semantically equivalent to non-overlapped 1F1B.
            if (cur_step.type in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F)
                    and cur_step.sub_steps):
                for sub in cur_step.sub_steps:
                    self._exec_step(sub, arg_mbs, kwarg_mbs, losses, send_handles)
                continue

            self._exec_step(cur_step, arg_mbs, kwarg_mbs, losses, send_handles)

        self.sync_shared_parameters_grad()
        # Drain any outstanding async send handles (overlap_p2p mode).
        while send_handles:
            self._wait_p2p(send_handles.pop())
351class _OverlapPhantom:
352 """Internal marker used by :func:`add_send_recv` to expand an
353 ``OVERLAP_F_B`` or ``OVERLAP_B_F`` step into two virtual time slots.
355 An overlap step composes two sub-steps (``B + F`` or ``F + B``) that
356 execute concurrently on the GPU but occupy **two** logical time slots
357 in the column-scan sender timeline — the sender can only finish
358 emitting the second sub-step's output after the first sub-step has
359 completed. Treating an overlap step as a single slot places the RECV
360 triggered by the second sub-step too early on the receiver.
362 Each overlap step is expanded into two phantoms:
363 * ``is_first_half=True`` — represents the first sub-step's emission
364 slot; the original overlap step is emitted into the output
365 schedule here (only once).
366 * ``is_first_half=False`` — represents the second sub-step's emission
367 slot; only its send/recv comms are inserted.
368 """
370 __slots__ = ('obf_step', 'sub_step', 'is_first_half')
372 def __init__(self, obf_step, sub_step, is_first_half: bool):
373 self.obf_step = obf_step
374 self.sub_step = sub_step
375 self.is_first_half = is_first_half
378def _expand_overlap_slots(scheduler, real_stage_num):
379 """Expand OVERLAP steps in a per-rank schedule into 2 virtual time slots.
381 Returns a new ``{rank: [MetaStep | _OverlapPhantom | None, ...]}`` dict
382 where each OVERLAP step is replaced by a pair of phantoms. Non-OVERLAP
383 entries pass through unchanged.
384 """
385 expanded = {}
386 for rank in range(real_stage_num):
387 order = scheduler[rank]
388 exp = []
389 for op in order:
390 if (op is not None
391 and op.type in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F)
392 and op.sub_steps):
393 exp.append(_OverlapPhantom(op, op.sub_steps[0], is_first_half=True))
394 exp.append(_OverlapPhantom(op, op.sub_steps[1], is_first_half=False))
395 else:
396 exp.append(op)
397 expanded[rank] = exp
398 return expanded
401def _process_rank_items(real_stage_num, current_items, insert_step_comms, new_schedule):
402 """Run ``insert_step_comms`` for each rank's current item, even ranks first.
404 Even-before-odd ordering avoids P2P deadlocks between adjacent ranks.
405 """
406 for rank in range(0, real_stage_num, 2):
407 item = current_items.get(rank)
408 if item is not None:
409 sub = item.sub_step if isinstance(item, _OverlapPhantom) else item
410 insert_step_comms(sub, rank, new_schedule)
411 for rank in range(1, real_stage_num, 2):
412 item = current_items.get(rank)
413 if item is not None:
414 sub = item.sub_step if isinstance(item, _OverlapPhantom) else item
415 insert_step_comms(sub, rank, new_schedule)
def _column_scan_insert_comms(expanded, real_stage_num, insert_step_comms):
    """Column-scan over an OVERLAP-expanded schedule to insert SEND/RECV.

    Processes ``expanded`` one time slot at a time. Emits the original
    overlap step into ``new_schedule`` only once (at the first-half
    phantom). Delegates comm insertion to ``insert_step_comms`` for each
    plain step or phantom's underlying sub-step.

    Even ranks are processed before odd ranks at each time step to avoid
    P2P deadlocks between adjacent ranks.

    Args:
        expanded: Result of :func:`_expand_overlap_slots`.
        real_stage_num: Number of physical ranks.
        insert_step_comms: Callable ``(step, rank, new_schedule) -> None``
            that inserts SEND/RECV for a single FWD/BWD step.

    Returns:
        ``{rank: [MetaStep, ...]}`` final schedule.
    """
    max_length = max(len(order) for order in expanded.values())
    new_schedule = {rank: [] for rank in range(real_stage_num)}

    for time_step in range(max_length):
        # What each rank executes at this slot; None means idle/exhausted.
        current_items = {}
        for rank in range(real_stage_num):
            if time_step < len(expanded[rank]):
                item = expanded[rank][time_step]
                current_items[rank] = item
                if item is None:
                    # Preserve bubble slots to keep per-rank time-step
                    # indexing aligned with the column scan. The runtime
                    # loop skips ``None`` entries, so this is execution-
                    # semantics-neutral.
                    new_schedule[rank].append(None)
                    continue
                if isinstance(item, _OverlapPhantom):
                    # Emit the overlap step only once, at the first-half slot.
                    if item.is_first_half:
                        new_schedule[rank].append(item.obf_step)
                else:
                    new_schedule[rank].append(item)
            else:
                # This rank's schedule is exhausted; nothing at this slot.
                current_items[rank] = None

        # Insert the comms for this column, even ranks before odd ranks.
        _process_rank_items(
            real_stage_num, current_items, insert_step_comms, new_schedule,
        )

    return new_schedule
def add_send_recv(scheduler, stage_num, real_stage_num, style='loop'):
    """Insert P2P send/recv operations into a per-rank compute schedule.

    For each FWD or BWD step that requires cross-rank communication, a
    ``FWD_SEND`` / ``BWD_SEND`` is appended to the sender's schedule and a
    ``FWD_RECV`` / ``BWD_RECV`` is appended to the receiver's schedule.

    ``OVERLAP_F_B`` / ``OVERLAP_B_F`` composite steps are expanded into
    **two** virtual time slots during the column scan so that the RECV
    triggered by the **second** sub-step lands in the receiver's schedule
    one slot later — matching the fact that the sender can only finish
    emitting the second sub-step's output after the first completes.

    Even ranks are processed before odd ranks at each time step to avoid
    P2P deadlocks between adjacent ranks.

    Args:
        scheduler: ``{rank: [MetaStep | None, ...]}`` — compute schedule
            with ``None`` for bubble slots.
        stage_num: Total number of virtual pipeline stages.
        real_stage_num: Number of physical ranks.
        style: Topology mapping — ``'loop'`` or ``'v'``.

    Returns:
        ``{rank: [MetaStep, ...]}`` — schedule with communication ops inserted.
    """
    def stage_to_rank(stage_index: int) -> int:
        """Map a virtual stage index to its physical rank."""
        if style == 'loop':
            # Interleaved: stages wrap around the ranks round-robin.
            return stage_index % real_stage_num
        if style == 'v':
            # V-shape: second half of stages folds back onto the ranks.
            if stage_index < real_stage_num:
                return stage_index
            return stage_num - 1 - stage_index
        raise ValueError(f"Argument 'style' must be 'loop' or 'v', but got {style!r}.")

    def _fwd_peer(stage_index: int):
        """Return the rank that receives this stage's forward output, or None."""
        if stage_index >= stage_num - 1:
            return None  # last virtual stage: no forward successor
        peer = stage_to_rank(stage_index + 1)
        # No comm needed when the successor lives on the same rank.
        return peer if peer != stage_to_rank(stage_index) else None

    def _bwd_peer(stage_index: int):
        """Return the rank that receives this stage's backward gradient, or None."""
        if stage_index <= 0:
            return None  # first virtual stage: no backward successor
        peer = stage_to_rank(stage_index - 1)
        return peer if peer != stage_to_rank(stage_index) else None

    def _insert_comms_for_step(step, rank, new_schedule):
        """Insert send/recv for a single FWD, BWD, or composite OVERLAP step."""
        if step is None:
            return

        if step.type == MetaStepType.FWD:
            peer = _fwd_peer(step.stage_index)
            if peer is not None:
                # SEND goes on the producer's rank; RECV on the consumer's,
                # tagged with the consumer's stage index (stage_index + 1).
                new_schedule[rank].append(
                    MetaStep(step.micro_index, MetaStepType.FWD_SEND, step.stage_index))
                new_schedule[peer].append(
                    MetaStep(step.micro_index, MetaStepType.FWD_RECV, step.stage_index + 1))

        elif step.type == MetaStepType.BWD:
            peer = _bwd_peer(step.stage_index)
            if peer is not None:
                new_schedule[rank].append(
                    MetaStep(step.micro_index, MetaStepType.BWD_SEND, step.stage_index))
                new_schedule[peer].append(
                    MetaStep(step.micro_index, MetaStepType.BWD_RECV, step.stage_index - 1))

        elif step.type in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F) and step.sub_steps:
            # Composite: recurse into both directions' sub-steps.
            for sub in step.sub_steps:
                _insert_comms_for_step(sub, rank, new_schedule)

    # --- Main logic: expand OVERLAP steps into 2 virtual slots, then scan ---
    expanded = _expand_overlap_slots(scheduler, real_stage_num)
    return _column_scan_insert_comms(expanded, real_stage_num, _insert_comms_for_step)
# Unique object compared by identity (``item is _ALIGN_PAD``): distinguishes
# padding inserted at the warmup -> 1F1B boundary from ordinary ``None`` bubbles.
_ALIGN_PAD = object()
"""Sentinel marking a forced 1F1B-boundary bubble produced during alignment."""
def _step_dep_ready(step, rank, t, done, stage_num, stage_to_rank):
    """Cross-rank data dependency check used by the alignment simulator.

    A FWD step at stage ``s`` depends on FWD at stage ``s-1`` (on a
    different rank); BWD at stage ``s`` depends on BWD at stage ``s+1``.
    Steps at boundaries or whose producer lives on the same rank are
    always ready. The producer must have finished strictly before ``t``.
    """
    stage, micro = step.stage_index, step.micro_index
    if step.type == MetaStepType.FWD:
        producer_stage = stage - 1
        boundary = stage == 0
    elif step.type == MetaStepType.BWD:
        producer_stage = stage + 1
        boundary = stage == stage_num - 1
    else:
        # Non-compute steps carry no cross-rank dependency here.
        return True
    if boundary or stage_to_rank(producer_stage) == rank:
        return True
    producer_key = (step.type, producer_stage, micro)
    # Ready only if the producer completed on an earlier time step.
    return producer_key in done and done[producer_key] < t
def _simulate_aligned_schedule(padded, stage_num, real_stage_num, stage_to_rank):
    """Simulate execution time-step by time-step, inserting bubbles where
    a step is not yet ready (cross-rank dep) or where the cooldown
    rhythm requires it.

    Args:
        padded: ``{rank: [step | _ALIGN_PAD | None, ...]}`` after
            1F1B-boundary padding.
        stage_num: Total number of virtual pipeline stages.
        real_stage_num: Number of physical ranks.
        stage_to_rank: Topology mapping from stage to rank.

    Returns:
        ``{rank: [step | None, ...]}`` ready for the column-scan SEND/RECV
        insertion phase.
    """
    # Per-rank count of FWD steps not yet emitted; reaching 0 means the
    # rank is in pure-BWD cooldown.
    remaining_fwd = {
        rank: sum(
            1 for s in padded[rank]
            if s is not _ALIGN_PAD and s is not None and s.type == MetaStepType.FWD
        )
        for rank in range(real_stage_num)
    }
    # Next unconsumed index into each rank's padded list.
    cursors = {r: 0 for r in range(real_stage_num)}
    aligned = {r: [] for r in range(real_stage_num)}
    # {(type, stage_index, micro_index): completion time} for dep checks.
    done = {}
    # Tracks the alternate None/BWD rhythm during cooldown.
    last_was_cooldown_bwd = {r: False for r in range(real_stage_num)}
    # Generous upper bound on simulation length; loop exits early once all
    # cursors are exhausted.
    max_t = sum(len(v) for v in padded.values()) + real_stage_num * 20

    def _emit_bubble(rank):
        # Emit a bubble WITHOUT advancing the cursor — the step is retried
        # at the next time step.
        aligned[rank].append(None)
        last_was_cooldown_bwd[rank] = False

    def _emit_step(rank, step, t, in_cooldown):
        aligned[rank].append(step)
        done[(step.type, step.stage_index, step.micro_index)] = t
        cursors[rank] += 1
        if step.type == MetaStepType.FWD:
            remaining_fwd[rank] -= 1
        last_was_cooldown_bwd[rank] = in_cooldown and step.type == MetaStepType.BWD

    def _step_rank_at(t, rank):
        # Rank already finished its schedule: nothing to emit.
        if cursors[rank] >= len(padded[rank]):
            return
        item = padded[rank][cursors[rank]]
        if item is _ALIGN_PAD:
            # Boundary padding becomes a plain bubble and is consumed.
            _emit_bubble(rank)
            cursors[rank] += 1
            return
        in_cooldown = remaining_fwd[rank] == 0
        # Cooldown rhythm: alternate None / BWD in pure-BWD phase.
        cooldown_skip = (
            in_cooldown
            and item.type == MetaStepType.BWD
            and last_was_cooldown_bwd[rank]
        )
        if cooldown_skip:
            _emit_bubble(rank)
            return
        if not _step_dep_ready(item, rank, t, done, stage_num, stage_to_rank):
            # Producer not finished yet — stall this rank one slot.
            _emit_bubble(rank)
            return
        _emit_step(rank, item, t, in_cooldown)

    for t in range(max_t):
        if all(cursors[r] >= len(padded[r]) for r in range(real_stage_num)):
            break
        for rank in range(real_stage_num):
            _step_rank_at(t, rank)
    return aligned
def auto_align_and_add_send_recv(scheduler, stage_num, real_stage_num, style='loop'):
    """Auto-insert bubble alignment and P2P send/recv into a pure-compute schedule.

    Unlike :func:`add_send_recv` which requires the caller to pre-insert
    ``None`` bubble slots for time-step alignment, this function accepts a
    **pure compute order** (``FWD`` / ``BWD`` only, no ``None`` needed) and
    automatically determines bubble placement via execution simulation.

    Three constraints are enforced:

    1. **Data dependency** — a ``FWD(stage_k)`` cannot execute until
       ``FWD(stage_{k-1})`` on its source rank has completed (and
       analogously for ``BWD``).
    2. **1F1B transition alignment** — ``real_stage_num - 1 - rank`` padding
       slots are inserted at the warmup → 1F1B boundary (detected as the
       first ``FWD`` immediately followed by a ``BWD`` in the compute order)
       so that all ranks enter the 1F1B steady state in lockstep.
    3. **Cooldown rhythm** — once a rank exhausts its ``FWD`` ops and enters
       pure-``BWD`` cooldown, consecutive ``BWD`` steps are separated by a
       ``None`` slot, maintaining the column-phase-sync property (no rank
       does ``BWD`` while another does ``FWD`` at the same time step).

    After alignment, a column-scan pass inserts ``FWD_SEND`` / ``FWD_RECV``
    and ``BWD_SEND`` / ``BWD_RECV`` with the same prefetch semantics as
    :func:`add_send_recv`.

    Args:
        scheduler: ``{rank: [MetaStep, ...]}`` — pure compute schedule.
            ``None`` entries are silently stripped before processing.
        stage_num: Total number of virtual pipeline stages.
        real_stage_num: Number of physical ranks.
        style: Topology mapping — ``'loop'`` or ``'v'``.

    Returns:
        ``{rank: [MetaStep, ...]}`` — fully aligned schedule with bubbles
        and communication ops inserted.
    """
    # ---- topology helpers (shared with column-scan phase) ----

    def stage_to_rank(stage_index: int) -> int:
        if style == 'loop':
            return stage_index % real_stage_num
        if style == 'v':
            if stage_index < real_stage_num:
                return stage_index
            return stage_num - 1 - stage_index
        raise ValueError(f"Argument 'style' must be 'loop' or 'v', but got {style!r}.")

    def _fwd_peer(stage_index: int):
        # Rank receiving this stage's forward output, or None at the last
        # stage / when the successor is co-located on the same rank.
        if stage_index >= stage_num - 1:
            return None
        peer = stage_to_rank(stage_index + 1)
        return peer if peer != stage_to_rank(stage_index) else None

    def _bwd_peer(stage_index: int):
        # Rank receiving this stage's backward gradient, or None at the
        # first stage / when the predecessor is co-located.
        if stage_index <= 0:
            return None
        peer = stage_to_rank(stage_index - 1)
        return peer if peer != stage_to_rank(stage_index) else None

    # ---- Phase 1: strip None, detect 1F1B boundary, insert transition padding ----

    def _find_1f1b_boundary(order):
        """Index of the first FWD followed by BWD; ``len(order)`` if absent."""
        for i in range(len(order) - 1):
            if (order[i].type == MetaStepType.FWD
                    and order[i + 1].type == MetaStepType.BWD):
                return i
        return len(order)

    padded = {}
    for rank in range(real_stage_num):
        order = [s for s in scheduler[rank] if s is not None]
        boundary = _find_1f1b_boundary(order)
        # Rank r enters steady state (real_stage_num - 1 - r) slots after
        # the last rank, so pad proportionally at the boundary.
        pad_count = real_stage_num - 1 - rank
        padded[rank] = order[:boundary] + [_ALIGN_PAD] * pad_count + order[boundary:]

    # ---- Phase 2: simulate execution with data deps + cooldown rhythm ----

    aligned = _simulate_aligned_schedule(padded, stage_num, real_stage_num, stage_to_rank)

    # ---- Phase 3: column-scan SEND/RECV insertion (same as add_send_recv) ----

    def _insert_comms_for_step(step, rank, new_schedule):
        if step is None:
            return
        if step.type == MetaStepType.FWD:
            peer = _fwd_peer(step.stage_index)
            if peer is not None:
                new_schedule[rank].append(
                    MetaStep(step.micro_index, MetaStepType.FWD_SEND, step.stage_index))
                new_schedule[peer].append(
                    MetaStep(step.micro_index, MetaStepType.FWD_RECV, step.stage_index + 1))
        elif step.type == MetaStepType.BWD:
            peer = _bwd_peer(step.stage_index)
            if peer is not None:
                new_schedule[rank].append(
                    MetaStep(step.micro_index, MetaStepType.BWD_SEND, step.stage_index))
                new_schedule[peer].append(
                    MetaStep(step.micro_index, MetaStepType.BWD_RECV, step.stage_index - 1))
        elif step.type in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F) and step.sub_steps:
            for sub in step.sub_steps:
                _insert_comms_for_step(sub, rank, new_schedule)

    # Expand OVERLAP steps into 2 virtual slots before the column scan so
    # the RECV triggered by an overlap's second sub-step lands one slot
    # later on the receiver — matching the fact that the sender can only
    # finish emitting the second sub-step after the first completes.
    expanded = _expand_overlap_slots(aligned, real_stage_num)
    return _column_scan_insert_comms(expanded, real_stage_num, _insert_comms_for_step)
class ScheduleGPipe(PipelineScheduleRuntime):
    """
    The Gpipe schedule.
    It first executes all forward micro batches and then execute all backward micro batches.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None):
        super().__init__(stages,
                         micro_batch_num,
                         args_batch_dim=args_batch_dim,
                         kwargs_batch_dim=kwargs_batch_dim,
                         output_concat_dim=output_concat_dim)
        self.construct_exec_order()

    def construct_exec_order(self):
        """Build the GPipe order: all forwards first, then all backwards.

        Each non-boundary stage wraps its compute with the matching
        RECV/SEND steps; the first stage never receives/sends backward-side
        data and the last stage never sends/receives forward-side data.
        """
        first_stage = 0
        last_stage = self.real_stage_num - 1
        for stage_index in range(self.real_stage_num):
            steps = []
            # Forward phase: every micro-batch flows through this stage.
            for mb in range(self.micro_batch_num):
                if stage_index != first_stage:
                    steps.append(MetaStep(mb, MetaStepType.FWD_RECV, stage_index))
                steps.append(MetaStep(mb, MetaStepType.FWD, stage_index))
                if stage_index != last_stage:
                    steps.append(MetaStep(mb, MetaStepType.FWD_SEND, stage_index))
            # Backward phase mirrors the forward micro-batch order.
            for mb in range(self.micro_batch_num):
                if stage_index != last_stage:
                    steps.append(MetaStep(mb, MetaStepType.BWD_RECV, stage_index))
                steps.append(MetaStep(mb, MetaStepType.BWD, stage_index))
                if stage_index != first_stage:
                    steps.append(MetaStep(mb, MetaStepType.BWD_SEND, stage_index))
            self.exec_order[stage_index] = steps
class Schedule1F1B(PipelineScheduleRuntime):
    """
    The 1F1B schedule.
    It will perform one forward and one backward on the micro batches in steady state.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None):
        super().__init__(stages,
                         micro_batch_num,
                         args_batch_dim=args_batch_dim,
                         kwargs_batch_dim=kwargs_batch_dim,
                         output_concat_dim=output_concat_dim)
        self.construct_exec_order()

    def construct_exec_order(self):
        """Build the 1F1B order: warmup forwards, 1F1B steady state, cooldown backwards.

        The number of warmup forwards shrinks with the stage index so
        later stages enter the steady state sooner.
        """
        for stage_index in range(self.real_stage_num):
            order_list = []
            fwd_index = 0
            bwd_index = 0
            # warmup phase
            warmup_micro_batches = min(self.real_stage_num - stage_index, self.micro_batch_num)
            for _ in range(warmup_micro_batches):
                if stage_index != 0:
                    order_list.append(MetaStep(fwd_index, MetaStepType.FWD_RECV, stage_index))
                # Even and odd stages stagger their FWD_SEND relative to the
                # FWD compute (send-after vs send-previous-before) —
                # presumably to pair sends/recvs across adjacent ranks
                # without deadlock; TODO(review) confirm intent.
                if stage_index % 2 == 0:
                    order_list.append(MetaStep(fwd_index, MetaStepType.FWD, stage_index))
                    # The last warmup micro's send is deferred to the steady
                    # (or cooldown) phase.
                    if fwd_index != warmup_micro_batches - 1:
                        order_list.append(MetaStep(fwd_index, MetaStepType.FWD_SEND, stage_index))
                else:
                    if fwd_index > 0:
                        order_list.append(MetaStep(fwd_index - 1, MetaStepType.FWD_SEND, stage_index))
                    order_list.append(MetaStep(fwd_index, MetaStepType.FWD, stage_index))
                fwd_index += 1

            # if warmup phase cannot filled up, then we need to execute fwd send in advance
            if self.real_stage_num - stage_index > self.micro_batch_num:
                order_list.append(MetaStep(fwd_index - 1, MetaStepType.FWD_SEND, stage_index))
                fwd_index += 1
            # steady phase
            steady_micro_batches = self.micro_batch_num - warmup_micro_batches
            for _ in range(steady_micro_batches):
                if stage_index != self.real_stage_num - 1:
                    order_list.append(MetaStep(bwd_index, MetaStepType.BWD_RECV, stage_index))
                    # Flush the previous micro's deferred forward send.
                    order_list.append(MetaStep(fwd_index - 1, MetaStepType.FWD_SEND, stage_index))
                order_list.append(MetaStep(bwd_index, MetaStepType.BWD, stage_index))

                if stage_index != 0:
                    order_list.append(MetaStep(bwd_index, MetaStepType.BWD_SEND, stage_index))
                    order_list.append(MetaStep(fwd_index, MetaStepType.FWD_RECV, stage_index))
                order_list.append(MetaStep(fwd_index, MetaStepType.FWD, stage_index))
                fwd_index += 1
                bwd_index += 1

            # cooldown phase
            cooldown_micro_batches = warmup_micro_batches
            for _ in range(cooldown_micro_batches):
                if stage_index != self.real_stage_num - 1:
                    order_list.append(MetaStep(bwd_index, MetaStepType.BWD_RECV, stage_index))
                    # First cooldown iteration: flush the last forward send
                    # still pending from warmup/steady.
                    if bwd_index == self.micro_batch_num - warmup_micro_batches and fwd_index <= self.micro_batch_num:
                        order_list.append(MetaStep(fwd_index - 1, MetaStepType.FWD_SEND, stage_index))
                order_list.append(MetaStep(bwd_index, MetaStepType.BWD, stage_index))

                if stage_index != 0:
                    order_list.append(MetaStep(bwd_index, MetaStepType.BWD_SEND, stage_index))
                bwd_index += 1
            self.exec_order[stage_index] = order_list
class ScheduleInterleaved1F1B(PipelineScheduleRuntime):
    """The Interleaved 1F1B schedule.

    Supports multiple stages per rank. In steady state, performs one
    forward followed by one backward on each micro-batch. Handles the
    cases where ``micro_batch_num`` is less than, equal to, or greater
    than the stage count, including non-evenly-divisible micro counts.

    Two orthogonal overlap modes can be enabled via constructor flags:

    * ``overlap_p2p=True``: defer P2P recv ``handle.wait()`` until the
      consuming FWD/BWD step (or the OVERLAP_B_F callback when
      ``overlap_b_f=True``), letting recv overlap with prior compute.
    * ``overlap_b_f=True``: in the 1F1B steady state, pair consecutive
      ``(B_i, F_{i+1})`` steps into ``OVERLAP_B_F`` composite steps so
      a registered callback can drive comm/compute overlap (typically
      via :class:`CommComputeOverlap` for MoE EP A2A). Users register
      the callback through :meth:`register_custom_function`.

    The two flags are independent and can be combined.

    Example:
        >>> # Plain interleaved 1F1B
        >>> sched = ScheduleInterleaved1F1B(stages, 8)
        >>> # With B/F overlap (dual-pipe-style comm/compute overlap)
        >>> sched = ScheduleInterleaved1F1B(stages, 8, overlap_b_f=True)
        >>> sched.register_custom_function(MetaStepType.OVERLAP_B_F, callback)
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None,
                 overlap_p2p=False,
                 overlap_b_f=False):
        super().__init__(stages,
                         micro_batch_num,
                         args_batch_dim=args_batch_dim,
                         kwargs_batch_dim=kwargs_batch_dim,
                         output_concat_dim=output_concat_dim,
                         overlap_p2p=overlap_p2p)
        # _overlap_b_f selects between plain F/B emission and OVERLAP_B_F
        # pairing in the 1F1B steady-state phase. Must be set before
        # ``construct_stage_exec_order`` is called below.
        self._overlap_b_f = overlap_b_f
        # Number of interleaving rounds; at least one even when there are
        # fewer micro-batches than real stages.
        self.n_rounds = max(1, self.micro_batch_num // self.real_stage_num)
        if self.micro_batch_num < self.real_stage_num:
            # Short-micro regime: a single round that carries all micro-batches
            # (real_stage_num + base == micro_batch_num below).
            base = self.micro_batch_num - self.real_stage_num
            remainder = 0
        else:
            # Spread the non-divisible remainder over the first `remainder` rounds.
            n_extra_microbatch = self.micro_batch_num % self.real_stage_num
            base = n_extra_microbatch // self.n_rounds
            remainder = n_extra_microbatch % self.n_rounds
        # Micro-batches assigned to each round; earlier rounds absorb the
        # extra micro-batches first.
        self.n_microbatch_per_round = \
            [self.real_stage_num + base + 1 if i < remainder else
             self.real_stage_num + base for i in range(self.n_rounds)]
        # Cumulative op counts per round (scaled by local stage count), with a
        # leading 0 so bisect-based lookups below can map op index -> round.
        self.n_microbatch_per_round_accu = \
            [x * self.n_local_stages for x in itertools.accumulate(self.n_microbatch_per_round)]
        self.n_microbatch_per_round_accu.insert(0, 0)
        for stage_index in range(self.real_stage_num):
            self.exec_order[stage_index] = self.construct_stage_exec_order(stage_index)
        # NOTE(review): add_send_recv (defined elsewhere in this module)
        # presumably inserts the matching SEND/RECV steps using a loop-style
        # stage-to-rank mapping — confirm against its definition.
        self.exec_order = add_send_recv(self.exec_order, self._stage_num, self.real_stage_num, style = 'loop')

    def warmup_ops(self, stage_index):
        """warmup phase.

        Returns the number of pure-forward warmup ops for ``stage_index``,
        capped at the total forward op count.
        """
        warmup_ops_last_stage = (self.n_local_stages - 1) * self.n_microbatch_per_round[0]
        # Earlier (smaller-index) stages need two extra warmup ops per stage
        # of distance from the last stage.
        warmup_ops = warmup_ops_last_stage + 2 * (self.real_stage_num - 1 - stage_index)
        return min(warmup_ops, self.micro_batch_num * self.n_local_stages)

    def forward_stage_index(self, op_index, stage_index):
        """obtain forward stage_index based on op_index."""
        # Locate the round containing op_index, then the local (virtual)
        # stage within that round.
        accu_index = bisect.bisect_right(self.n_microbatch_per_round_accu, op_index) - 1
        local_index = (op_index - self.n_microbatch_per_round_accu[accu_index]) // \
                      self.n_microbatch_per_round[accu_index]
        return (local_index * self.real_stage_num) + stage_index

    def backward_stage_index(self, op_index, stage_index):
        """obtain backward stage_index based on op_index."""
        accu_index = bisect.bisect_right(self.n_microbatch_per_round_accu, op_index) - 1
        local_index = (op_index - self.n_microbatch_per_round_accu[accu_index]) // \
                      self.n_microbatch_per_round[accu_index]
        # Backward walks the local stages in reverse order.
        local_index = self.n_local_stages - 1 - local_index
        return (local_index * self.real_stage_num) + stage_index

    def _short_micro(self) -> bool:
        """True when ``micro_batch_num < real_stage_num`` (extra-bubble regime)."""
        return self.micro_batch_num < self.real_stage_num

    def _trailing_bubble(self) -> int:
        """Bubble count appended after a BWD with ``micro == micro_batch_num - 1``
        in the short-micro regime.
        """
        return self.real_stage_num - self.micro_batch_num

    def _emit_warmup_ops(self, stage_index, warmup_ops, fwd_stage_micro_index):
        """Emit pure-FWD warmup ops with optional short-micro bubble padding."""
        ops = []
        short = self._short_micro()
        last_micro = self.micro_batch_num - 1
        last_stage = self.real_stage_num - 1
        bubble = self._trailing_bubble()
        for op_idx in range(warmup_ops):
            fwd_stage_idx = self.forward_stage_index(op_idx, stage_index)
            # Per-stage running micro index (advanced as a side effect below).
            fwd_micro_idx = fwd_stage_micro_index[fwd_stage_idx]
            ops.append(MetaStep(fwd_micro_idx, MetaStepType.FWD, fwd_stage_idx))
            # Pad with bubbles after the last micro-batch, except when it is
            # the final warmup op on a non-last rank.
            need_pad = (
                short
                and fwd_micro_idx == last_micro
                and (op_idx != warmup_ops - 1 or stage_index == last_stage)
            )
            if need_pad:
                ops.extend([None] * bubble)
            fwd_stage_micro_index[fwd_stage_idx] += 1
        return ops

    def _emit_cooldown_ops(self, stage_index, warmup_ops, fwd_bwd_ops, total_ops,
                           bwd_stage_micro_index):
        """Emit pure-BWD cooldown ops (each preceded by a bubble) with
        optional short-micro trailing padding.
        """
        ops = []
        short = self._short_micro()
        last_micro = self.micro_batch_num - 1
        bubble = self._trailing_bubble()
        for op_idx in range(warmup_ops + fwd_bwd_ops, total_ops):
            # One bubble before each cooldown backward (no paired forward left).
            ops.append(None)
            bwd_stage_idx = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            bwd_micro_idx = bwd_stage_micro_index[bwd_stage_idx]
            ops.append(MetaStep(bwd_micro_idx, MetaStepType.BWD, bwd_stage_idx))
            if short and bwd_micro_idx == last_micro:
                ops.extend([None] * bubble)
            bwd_stage_micro_index[bwd_stage_idx] += 1
        return ops

    def _emit_1f1b_ops(self, stage_index, warmup_ops, fwd_bwd_ops,
                       fwd_stage_micro_index, bwd_stage_micro_index):
        """Emit interleaved (FWD, BWD) pairs for the 1F1B steady-state phase."""
        ops = []
        short = self._short_micro()
        last_micro = self.micro_batch_num - 1
        last_stage = self.real_stage_num - 1
        bubble = self._trailing_bubble()
        for op_idx in range(warmup_ops, warmup_ops + fwd_bwd_ops):
            fwd_stage_idx = self.forward_stage_index(op_idx, stage_index)
            fwd_micro_idx = fwd_stage_micro_index[fwd_stage_idx]
            ops.append(MetaStep(fwd_micro_idx, MetaStepType.FWD, fwd_stage_idx))
            fwd_stage_micro_index[fwd_stage_idx] += 1
            # Backward op index is offset by warmup_ops relative to forward.
            bwd_stage_idx = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            bwd_micro_idx = bwd_stage_micro_index[bwd_stage_idx]
            ops.append(MetaStep(bwd_micro_idx, MetaStepType.BWD, bwd_stage_idx))
            # Only the last rank pads inside the steady state (short-micro).
            need_pad = (
                short
                and bwd_micro_idx == last_micro
                and stage_index == last_stage
            )
            if need_pad:
                ops.extend([None] * bubble)
            bwd_stage_micro_index[bwd_stage_idx] += 1
        return ops

    @staticmethod
    def _collect_fwd_bwd_steps(emit_fwd, emit_bwd, fwd_bwd_ops, warmup_ops):
        """Walk the 1F1B range collecting parallel ``fwd_steps`` / ``bwd_steps``.

        ``emit_fwd(op_idx)`` and ``emit_bwd(op_idx)`` build a single
        :class:`MetaStep` and advance their respective per-stage micro
        counters as a side effect.
        """
        fwd_steps = []
        bwd_steps = []
        for op_idx in range(warmup_ops, warmup_ops + fwd_bwd_ops):
            fwd_steps.append(emit_fwd(op_idx))
            bwd_steps.append(emit_bwd(op_idx))
        return fwd_steps, bwd_steps

    @staticmethod
    def _pair_into_overlap_b_f(fwd_steps, bwd_steps):
        """Build ``F₁, [B_i, F_{i+1}], B_n`` ordering with OVERLAP_B_F pairs.

        ``sub_steps`` carry the ``(bwd, fwd)`` tuple — callbacks access
        them via ``step.sub_steps`` to recover per-direction stage /
        micro info.
        """
        ops = []
        if fwd_steps:
            ops.append(fwd_steps[0])  # F₁ runs alone
        for i in range(len(bwd_steps) - 1):
            # Composite step: micro/stage are None at the top level; the
            # real indices live in sub_steps.
            ops.append(MetaStep(
                None, MetaStepType.OVERLAP_B_F, None,
                sub_steps=(bwd_steps[i], fwd_steps[i + 1]),
            ))
        if bwd_steps:
            ops.append(bwd_steps[-1])  # B_n runs alone
        return ops

    def _emit_1f1b_overlap_ops(self, stage_index, warmup_ops, fwd_bwd_ops,
                               fwd_stage_micro_index, bwd_stage_micro_index):
        """Emit ``F₁, [B_i, F_{i+1}], B_n`` for the 1F1B phase under
        ``overlap_b_f=True``. Each ``[B_i, F_{i+1}]`` becomes an
        ``OVERLAP_B_F`` composite step; a registered callback drives the
        actual concurrent execution. Short-micro extra-bubble padding
        on the last rank is appended after ``B_n``.
        """
        def emit_fwd(op_idx):
            fwd_si = self.forward_stage_index(op_idx, stage_index)
            fwd_mi = fwd_stage_micro_index[fwd_si]
            fwd_stage_micro_index[fwd_si] += 1
            return MetaStep(fwd_mi, MetaStepType.FWD, fwd_si)

        def emit_bwd(op_idx):
            bwd_si = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            bwd_mi = bwd_stage_micro_index[bwd_si]
            bwd_stage_micro_index[bwd_si] += 1
            return MetaStep(bwd_mi, MetaStepType.BWD, bwd_si)

        fwd_steps, bwd_steps = self._collect_fwd_bwd_steps(
            emit_fwd, emit_bwd, fwd_bwd_ops, warmup_ops,
        )
        ops = self._pair_into_overlap_b_f(fwd_steps, bwd_steps)

        last_stage = self.real_stage_num - 1
        if self._short_micro() and stage_index == last_stage and bwd_steps:
            if bwd_steps[-1].micro_index == self.micro_batch_num - 1:
                ops.extend([None] * self._trailing_bubble())
        return ops

    def construct_stage_exec_order(self, stage_index):
        """Construct the execution order for ``stage_index``.

        Builds: warmup → bubbles → 1F1B steady state → cooldown. The
        1F1B segment switches between :meth:`_emit_1f1b_ops` (plain) and
        :meth:`_emit_1f1b_overlap_ops` (OVERLAP_B_F pairing) based on
        the ``overlap_b_f`` constructor flag.
        """
        warmup_ops = self.warmup_ops(stage_index)
        fwd_bwd_ops = self.n_local_stages * self.micro_batch_num - warmup_ops
        # Every forward op is matched by a backward op, hence the doubling.
        total_ops = 2 * warmup_ops + fwd_bwd_ops
        # Leading bubbles stagger each rank's start by its stage index.
        order_list = [None for _ in range(stage_index)]
        fwd_stage_micro_index = defaultdict(int)
        bwd_stage_micro_index = defaultdict(int)
        order_list.extend(self._emit_warmup_ops(stage_index, warmup_ops, fwd_stage_micro_index))
        # Extra bubbles when warmup was capped by a small micro count.
        bubbles_before_1f1b = max(
            0,
            2 * (self.real_stage_num - stage_index - 1) - self.micro_batch_num,
        )
        order_list.extend([None] * bubbles_before_1f1b)
        order_list.extend([None] * (self.real_stage_num - 1 - stage_index))
        if self._overlap_b_f:
            order_list.extend(self._emit_1f1b_overlap_ops(
                stage_index, warmup_ops, fwd_bwd_ops,
                fwd_stage_micro_index, bwd_stage_micro_index,
            ))
        else:
            order_list.extend(self._emit_1f1b_ops(
                stage_index, warmup_ops, fwd_bwd_ops,
                fwd_stage_micro_index, bwd_stage_micro_index,
            ))
        order_list.extend(self._emit_cooldown_ops(
            stage_index, warmup_ops, fwd_bwd_ops, total_ops, bwd_stage_micro_index,
        ))
        return order_list
def detect_cycle_in_graph(ranks_map):
    """
    Detects a cycle in the directed graph constructed from ranks_map.

    Args:
        ranks_map: A dictionary where keys are rank names and values are lists of nodes.

    Returns:
        tuple: (cycle_path, cycle_ranks) where cycle_path is a list of nodes forming the cycle and cycle_ranks
            is a list of rank transitions corresponding to the cycle path.
    """
    # Build adjacency lists; remember which rank contributed each edge.
    adjacency = defaultdict(list)
    edge_owner = {}
    for rank, nodes in ranks_map.items():
        for src, dst in zip(nodes, nodes[1:]):
            adjacency[src].append(dst)
            edge_owner[(src, dst)] = rank

    done = set()          # fully-explored nodes
    dfs_path = []         # current DFS path (stack of nodes)
    on_path_pos = {}      # node -> its position in dfs_path

    # Iterative DFS with explicit post-visit markers: (node, True) means
    # "node's subtree is finished, pop it off the path".
    work = []
    for root in list(adjacency.keys()):
        if root in done:
            continue
        work.append((root, False))
        while work:
            node, leaving = work.pop()

            if leaving:
                dfs_path.pop()
                del on_path_pos[node]
                continue

            if node in on_path_pos:
                # Back edge onto the current path: reconstruct the cycle.
                start = on_path_pos[node]
                cycle_nodes = dfs_path[start:] + [node]
                cycle_edges = []
                for pos in range(start, len(dfs_path)):
                    u = dfs_path[pos]
                    v = dfs_path[pos + 1] if pos + 1 < len(dfs_path) else node
                    cycle_edges.append(f"{edge_owner[(u, v)]} {u} -> {v}")
                return cycle_nodes, cycle_edges

            if node in done:
                continue

            done.add(node)
            on_path_pos[node] = len(dfs_path)
            dfs_path.append(node)

            work.append((node, True))
            # Reversed push so neighbors are visited in insertion order.
            for neighbor in reversed(adjacency[node]):
                work.append((neighbor, False))

    return None, None
def output_cycle_results(cycle_path, cycle_ranks):
    """
    Log the outcome of cycle detection.

    Args:
        cycle_path (list): Nodes forming the detected cycle, if any.
        cycle_ranks (list): Rank transitions along the cycle.

    Returns:
        None: Results are written to the module logger.
    """
    # Guard clause: no cycle found.
    if not cycle_path:
        logger.warning("Cycle Check succeeded. There is no cycle in the graph.")
        return
    logger.error("Cycle detected:")
    chain = " -> ".join(str(node) for node in cycle_path)
    logger.error("%s -> %s", chain, cycle_path[0])  # Close the cycle
    logger.error("Involving ranks:")
    for rank in cycle_ranks:
        logger.error(rank)
def parse_and_validate(data: dict, all_rank: bool = True):
    """
    Parse and validate execution orders in a directed graph structure.

    Cross-checks every operation string against the rank keys it references:
    each referenced rank must exist (when ``all_rank`` is True) and must list
    the same operation string, so SEND/RECV pairs stay symmetric.

    Args:
        data (dict): Maps string identifiers to lists of operation strings.
        all_rank (bool): If True, every rank referenced inside a value must be
            a key of ``data``; if False, only intersecting keys are checked.

    Returns:
        None: Problems are logged; the function completes silently otherwise.
    """

    def extract_refs(text: str, max_groups: int = 2) -> set:
        """Return the unique digit groups inside the first one or two parentheses."""
        found = re.findall(r'\((\d+)\)', text)
        # Only the first `max_groups` parenthesized groups are significant.
        return {token.strip() for token in found[:max_groups]}

    if not isinstance(data, dict):
        logger.error("Input must be a dictionary with string keys and lists of strings as values.")
        return

    # Index only the well-formed entries (list-of-str values) for lookups.
    key_to_values = {}
    for key, values in data.items():
        if isinstance(values, list) and all(isinstance(v, str) for v in values):
            key_to_values[key] = set(values)

    for key, values in data.items():
        if not (isinstance(values, list) and all(isinstance(v, str) for v in values)):
            logger.error("Values for key '%s' must be a list of strings.", key)
            continue

        for value in values:
            try:
                elements = extract_refs(value)
            except (ValueError, TypeError, AttributeError) as e:
                logger.error("Unable to parse elements from value '%s' in key '%s'. Error: %s", value, key, e)
                continue

            if all_rank:
                # Every referenced key must exist in the dataset.
                missing_keys = elements - key_to_values.keys()
                if missing_keys:
                    logger.error("The following keys are missing for value '%s': %s", value, missing_keys)
                    continue
                candidates = elements
            else:
                candidates = elements & key_to_values.keys()

            # Each referenced key must list this same operation string.
            for element in candidates:
                if value not in key_to_values[element]:
                    logger.error("Key '%s' is missing the value '%s'.", element, value)
def generate_operations(order_list: dict[int, list[MetaStep]],
                        chunk_num: int,
                        com_type: str = 'loop') -> dict[str, list[str]]:
    """
    Generate formatted operations dictionary from pipeline execution order.

    Args:
        order_list (dict): Dictionary where keys are rank IDs and values are MetaStep execution sequences
        chunk_num (int): Number of chunks (virtual pipeline stages)
        com_type (str): Stage-to-rank mapping type ('loop' for cyclic, 'v' for V-shaped)

    Returns:
        Dictionary where keys are rank IDs (as strings) and values are lists of formatted operation strings
    """
    real_stage = len(order_list)
    total_stages = real_stage * chunk_num

    # Communication rules: forward activations flow stage i -> i+1,
    # backward gradients flow stage i -> i-1.
    forward_comm = {i: i + 1 for i in range(total_stages - 1)}
    backward_comm = {i: i - 1 for i in range(1, total_stages)}
    # Both rules are injective, so RECV sources come from the reversed maps.
    forward_src = {dst: src for src, dst in forward_comm.items()}
    backward_src = {dst: src for src, dst in backward_comm.items()}

    def stage_to_rank(stage_index, style, stage_num, real_stage_num):
        """Map stage index to rank"""
        if style == 'loop':
            return stage_index % real_stage_num
        if style == 'v':
            if stage_index < real_stage_num:
                return stage_index
            return stage_num - 1 - stage_index
        raise ValueError("Invalid style")

    formatted_operations = defaultdict(list)

    for rank, steps in order_list.items():
        # Counts repeated (src, dst, micro) transfers seen from this rank.
        operation_counter = defaultdict(int)

        for step in steps:
            if step.type in (MetaStepType.FWD_SEND, MetaStepType.BWD_SEND):
                rule = forward_comm if step.type == MetaStepType.FWD_SEND else backward_comm
                peer_stage = rule.get(step.stage_index)
                if peer_stage is None:
                    continue
                peer_rank = stage_to_rank(peer_stage, com_type, total_stages, real_stage)
                src_rank, dst_rank = rank, peer_rank
            elif step.type in (MetaStepType.FWD_RECV, MetaStepType.BWD_RECV):
                rule = forward_src if step.type == MetaStepType.FWD_RECV else backward_src
                peer_stage = rule.get(step.stage_index)
                if peer_stage is None:
                    continue
                peer_rank = stage_to_rank(peer_stage, com_type, total_stages, real_stage)
                src_rank, dst_rank = peer_rank, rank
            else:
                # Compute steps (FWD/BWD/...) carry no communication.
                continue

            comm_pair = (src_rank, dst_rank, step.micro_index)
            operation_counter[comm_pair] += 1
            count = operation_counter[comm_pair]
            formatted_operations[str(rank)].append(
                f"Send_Receive_({src_rank})->({dst_rank})_micro{step.micro_index}_{count}th")

    # Convert defaultdict to dict
    return dict(formatted_operations)
def validate_pipeline_execution(order_list: dict[int, list[MetaStep]],
                                chunk_num: int,
                                com_type: str = 'loop') -> dict[str, object]:
    """
    Comprehensive validation function for pipeline parallel execution order.

    This function validates the execution order of pipeline parallelism by:
    1. Generating the per-rank formatted SEND/RECV operation strings
    2. Checking communication pair consistency (problems are logged by
       ``parse_and_validate``, not raised)
    3. Detecting cycles in the communication graph and logging the result

    Args:
        order_list: Dictionary where keys are rank IDs and values are MetaStep execution sequences
        chunk_num: Number of chunks (virtual pipeline stages)
        com_type: Stage-to-rank mapping type ('loop' for cyclic, 'v' for V-shaped)

    Returns:
        Dictionary containing validation results with the following keys:
        - formatted_operations: Generated formatted operation strings per rank
        - cycle_path: List of nodes forming a detected cycle, or None
        - cycle_ranks: Rank transitions along the detected cycle, or None
        - has_cycle: Boolean indicating whether a cycle was found
    """
    # Generate operations
    formatted_operations = generate_operations(order_list, chunk_num, com_type)
    # Validate SEND/RECV pairing; errors are logged, execution continues.
    parse_and_validate(formatted_operations, True)
    # Detect cycles
    cycle_path, cycle_ranks = detect_cycle_in_graph(formatted_operations)
    # Output results
    output_cycle_results(cycle_path, cycle_ranks)
    result = {
        'formatted_operations': formatted_operations,
        'cycle_path': cycle_path,
        'cycle_ranks': cycle_ranks,
        'has_cycle': bool(cycle_path)
    }
    return result