Coverage for hyper_parallel / core / pipeline_parallel / scheduler.py: 51%
476 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""pipeline schedule"""
16from abc import ABC
17from enum import Enum, auto
18from collections import defaultdict
19import itertools
20import bisect
21import logging
22import re
23import hyper_parallel
24from hyper_parallel.platform import get_platform
25platform = get_platform()
26logger = logging.getLogger(__name__)
class MetaStepType(Enum):
    """Enumerates the kinds of steps a pipeline schedule can execute."""
    # Computation steps.
    FWD = auto()       # forward pass of one micro-batch
    BWD = auto()       # backward pass of one micro-batch
    # Point-to-point communication steps.
    FWD_RECV = auto()  # receive activations from the previous stage
    FWD_SEND = auto()  # send activations to the next stage
    BWD_RECV = auto()  # receive gradients from the next stage
    BWD_SEND = auto()  # send gradients to the previous stage
class MetaStep:
    """
    Meta step of PipelineSchedule.

    An execution list composed of MetaStep can be constructed
    and fed into the PipelineSchedule for execution.

    Args:
        micro_index (int): The index of micro-batch.
        meta_type (MetaStepType): Specify the type of current step.
        stage_index (int): Specify the stage index of current step.
    """

    # Matches the output of __str__/__repr__, e.g.
    # "MetaStep(type=MetaStepType.FWD, micro_index=0, stage_index=1)".
    _STR_PATTERN = re.compile(
        r"MetaStep\(type=MetaStepType\.(\w+), micro_index=(\d+), stage_index=(\d+)\)"
    )

    def __init__(self, micro_index, meta_type, stage_index):
        self._type = meta_type
        self._micro_index = micro_index
        self._stage_index = stage_index

    @property
    def micro_index(self):
        """Index of the micro-batch this step operates on."""
        return self._micro_index

    @property
    def stage_index(self):
        """Pipeline stage index this step belongs to."""
        return self._stage_index

    @property
    def type(self):
        """The MetaStepType of this step."""
        return self._type

    def __eq__(self, value):
        if not isinstance(value, MetaStep):
            return NotImplemented
        return (self.type, self.micro_index, self.stage_index) == \
               (value.type, value.micro_index, value.stage_index)

    def __ne__(self, value):
        # Derive inequality from __eq__ so the two can never disagree.
        result = self.__eq__(value)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Hash on exactly the fields compared by __eq__ to keep the
        # eq/hash contract.
        return hash((self.type, self.micro_index, self.stage_index))

    def __str__(self):
        return f"MetaStep(type={self.type}, micro_index={self.micro_index}, stage_index={self.stage_index})"

    # repr is intentionally identical to str.
    __repr__ = __str__

    @staticmethod
    def from_str(step_str):
        """Parse a MetaStep from the string produced by __str__/__repr__.

        Args:
            step_str (str): String such as
                "MetaStep(type=MetaStepType.FWD, micro_index=0, stage_index=1)".

        Returns:
            MetaStep: The reconstructed step.

        Raises:
            ValueError: If `step_str` does not match the expected format.
        """
        match = MetaStep._STR_PATTERN.fullmatch(step_str.strip())
        if match is None:
            raise ValueError(f"Cannot parse MetaStep from: {step_str!r}")
        type_name, micro_index, stage_index = match.groups()
        return MetaStep(int(micro_index), MetaStepType[type_name], int(stage_index))
class PipelineScheduleRuntime(ABC):
    """
    Base class for pipeline schedule.

    Implements the `split_microbatches` and `run_microbatches` method.
    Derived classes should implement `run_microbatches` method and `run` method.

    Args:
        stages (list[PipelineStage], PipelineStage): PipelineStage used to run_microbatches.
        micro_batch_num (int): The number of micro-batch.
        args_batch_dim (list, optional): Specify the batch dim of the args.
            Default ``None``.
        kwargs_batch_dim (dict, optional): Specify the batch dim of the kwargs.
            Default ``None``.
        output_concat_dim (optional): Dim along which outputs are concatenated.
            Default ``None``.
        overlap_p2p (bool, optional): When True, point-to-point communication is
            issued asynchronously and waited on lazily. Default ``False``.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None,
                 overlap_p2p=False):
        self.stages = self._check_stages(stages)
        self.micro_batch_num = micro_batch_num
        self._args_batch_dim = args_batch_dim
        self._kwargs_batch_dim = kwargs_batch_dim
        self._output_concat_dim = output_concat_dim
        # Platform-provided callable that splits (args, kwargs) into micro-batches.
        self.split_micro_batch = platform.micro_batch(self.micro_batch_num,
                                                      self._args_batch_dim, self._kwargs_batch_dim)
        self.n_local_stages = len(self.stages)
        self._stage_dict = self.convert_stages_dict()
        # Number of physical ranks; each rank hosts `n_local_stages` stages.
        self.real_stage_num = self.stages[0].stage_num // self.n_local_stages
        self._stage_num = self.stages[0].stage_num
        self._overlap_p2p = overlap_p2p
        # rank index -> ordered list of MetaStep (filled by subclasses).
        self.exec_order = {}
        self._init_stages()
        # (stage_index, micro_index) -> pending async recv handles.
        self.fwd_handle_cache = {}
        self.bwd_handle_cache = {}

    def convert_stages_dict(self):
        """Return a mapping from stage_index to its PipelineStage."""
        stage_dict = {}
        for stage in self.stages:
            stage_dict[stage.stage_index] = stage
        return stage_dict

    def split_microbatches(self, args, kwargs):
        """Split inputs into per-micro-batch args/kwargs lists."""
        if args or kwargs:
            args_split, kwargs_split = self.split_micro_batch(args, kwargs)
            return args_split, kwargs_split
        # Build independent empty containers per micro-batch: `[[]] * n`
        # would alias one shared list/dict across every micro-batch.
        return ([[] for _ in range(self.micro_batch_num)],
                [{} for _ in range(self.micro_batch_num)])

    def _check_stages(self, stages):
        """Validate `stages` and normalize it to a list of PipelineStage."""
        if isinstance(stages, hyper_parallel.PipelineStage):
            return [stages]
        if isinstance(stages, (list, tuple)):
            for stage in stages:
                if not isinstance(stage, hyper_parallel.PipelineStage):
                    raise TypeError("Argument 'stages' must be type of PipelineStage, "
                                    f"list or tuple of PipelineStage, but got list or tuple of {type(stage)}.")
            return stages
        raise TypeError("Argument 'stages' must be type of PipelineStage, "
                        f"list or tuple of PipelineStage, but got type of {type(stages)}.")

    def _init_stages(self):
        """Initialize every local stage with the local stage count."""
        for stage in self.stages:
            stage.init(self.n_local_stages)

    def run(self, *args, **kwargs):
        """Split the inputs, execute the schedule, and return collected losses."""
        split_args, split_kwargs = self.split_microbatches(args, kwargs)
        losses = []
        self.run_microbatches(split_args, split_kwargs, losses)
        return losses

    def sync_shared_parameters_grad(self):
        """Synchronize gradients of parameters shared across stages."""
        for stage in self.stages:
            stage.sync_shared_parameters_grad()

    def update_losses(self, stage, loss, losses):
        """Collect `loss` when `stage` is the last pipeline stage."""
        if stage.is_last_stage:
            losses.append(loss)

    def _wait_p2p(self, handles):
        """Wait on every non-None communication handle."""
        for handle in handles:
            if handle is not None:
                handle.wait()

    def run_microbatches(self, arg_mbs, kwarg_mbs, losses):
        """Replay the precomputed execution order for the local rank."""
        real_stage_index = self.stages[0].stage_index % self.real_stage_num
        send_handle = []
        for cur_step in self.exec_order[real_stage_index]:
            if cur_step is None:
                continue  # schedule bubble
            stage = self._stage_dict[cur_step.stage_index]
            micro_index = cur_step.micro_index
            key = (cur_step.stage_index, micro_index)
            # Each step has exactly one type, so the branches are exclusive.
            if cur_step.type == MetaStepType.FWD_RECV:
                comm_handle = stage.exec_fwd_recv_ops(micro_index)
                if not self._overlap_p2p:
                    self._wait_p2p(comm_handle)
                else:
                    # Defer the wait until the matching FWD step runs.
                    self.fwd_handle_cache[key] = comm_handle
            elif cur_step.type == MetaStepType.FWD:
                if self._overlap_p2p and key in self.fwd_handle_cache:
                    self._wait_p2p(self.fwd_handle_cache.pop(key))
                out = stage.forward_one_chunk(micro_index, arg_mbs[micro_index], kwarg_mbs[micro_index])
                self.update_losses(stage, out, losses)
            elif cur_step.type == MetaStepType.FWD_SEND:
                comm_handle = stage.exec_fwd_send_ops(micro_index)
                if not self._overlap_p2p:
                    self._wait_p2p(comm_handle)
                else:
                    send_handle.append(comm_handle)
            elif cur_step.type == MetaStepType.BWD_RECV:
                comm_handle = stage.exec_bwd_recv_ops(micro_index)
                if not self._overlap_p2p:
                    self._wait_p2p(comm_handle)
                else:
                    self.bwd_handle_cache[key] = comm_handle
            elif cur_step.type == MetaStepType.BWD:
                if self._overlap_p2p and key in self.bwd_handle_cache:
                    self._wait_p2p(self.bwd_handle_cache.pop(key))
                if micro_index == self.micro_batch_num - 1:
                    # Last micro-batch of the run — presumably triggers
                    # gradient finalization in the stage; confirm against
                    # PipelineStage.backward_one_chunk.
                    stage.backward_one_chunk(micro_index, True)
                else:
                    stage.backward_one_chunk(micro_index)
            elif cur_step.type == MetaStepType.BWD_SEND:
                comm_handle = stage.exec_bwd_send_ops(micro_index)
                if not self._overlap_p2p:
                    self._wait_p2p(comm_handle)
                else:
                    send_handle.append(comm_handle)
        self.sync_shared_parameters_grad()
        # Drain any outstanding asynchronous sends before returning.
        while send_handle:
            self._wait_p2p(send_handle.pop())
def add_send_recv(scheduler, stage_num, real_stage_num, style='loop'):
    """
    Create schedule for each rank and automatically add communication operations.

    Args:
        scheduler: Compute schedule table with None bubbles, keyed by rank.
        stage_num: Total number of pipeline stages.
        real_stage_num: Number of actual physical stages/ranks.
        style: Communication style ('loop' or 'v').

    Returns:
        Complete schedule table for each rank (including communication operations).
    """

    def stage_to_rank(stage_index, style, stage_num, real_stage_num):
        """Map stage index to rank"""
        if style == 'loop':
            return stage_index % real_stage_num
        if style == 'v':
            if stage_index < real_stage_num:
                return stage_index
            return stage_num - 1 - stage_index
        raise ValueError("Invalid style")

    def _need_com(action, style, stage_num):
        """Determine if communication is needed"""
        if action.type == MetaStepType.FWD:
            if action.stage_index == stage_num - 1:
                return False  # Last stage doesn't need forward communication
            next_stage_rank = stage_to_rank(action.stage_index + 1, style, stage_num, real_stage_num)
            current_rank = stage_to_rank(action.stage_index, style, stage_num, real_stage_num)
            return next_stage_rank != current_rank
        if action.type == MetaStepType.BWD:
            if action.stage_index == 0:
                return False  # First stage doesn't need backward communication
            prev_stage_rank = stage_to_rank(action.stage_index - 1, style, stage_num, real_stage_num)
            current_rank = stage_to_rank(action.stage_index, style, stage_num, real_stage_num)
            return prev_stage_rank != current_rank
        return False

    def process_rank_communication(rank, operation, new_schedule, style, stage_num, real_stage_num):
        """Process communication operations for single rank"""
        if operation is None:
            return

        stage_index = operation.stage_index
        # Peer ranks; None when no stage exists in that direction.
        # (Fix: the previous code used `stage_index < stage_num` — which is
        # always true — and defaulted pre_rank to 0 instead of None.)
        pre_rank = stage_to_rank(stage_index - 1, style, stage_num, real_stage_num) \
            if stage_index > 0 else None
        nxt_rank = stage_to_rank(stage_index + 1, style, stage_num, real_stage_num) \
            if stage_index < stage_num - 1 else None

        if (operation.type == MetaStepType.FWD and
                _need_com(operation, style, stage_num) and nxt_rank is not None):
            # SEND on this rank paired with the matching RECV on the next rank.
            new_schedule[rank].append(MetaStep(
                micro_index=operation.micro_index,
                meta_type=MetaStepType.FWD_SEND,
                stage_index=stage_index
            ))
            new_schedule[nxt_rank].append(MetaStep(
                micro_index=operation.micro_index,
                meta_type=MetaStepType.FWD_RECV,
                stage_index=stage_index + 1
            ))
        elif (operation.type == MetaStepType.BWD and
              _need_com(operation, style, stage_num) and pre_rank is not None):
            # SEND on this rank paired with the matching RECV on the previous rank.
            new_schedule[rank].append(MetaStep(
                micro_index=operation.micro_index,
                meta_type=MetaStepType.BWD_SEND,
                stage_index=stage_index
            ))
            new_schedule[pre_rank].append(MetaStep(
                micro_index=operation.micro_index,
                meta_type=MetaStepType.BWD_RECV,
                stage_index=stage_index - 1
            ))

    # Main logic: walk the compute schedule one time-step at a time, copying
    # compute ops and inserting the communication ops they require.
    max_length = max(len(schedule) for schedule in scheduler.values())
    new_schedule = {rank: [] for rank in range(real_stage_num)}

    for time_step in range(max_length):
        current_operations = {}
        for rank in range(real_stage_num):
            if time_step < len(scheduler[rank]):
                operation = scheduler[rank][time_step]
                current_operations[rank] = operation
                if operation is not None:
                    new_schedule[rank].append(operation)
            else:
                current_operations[rank] = None

        # Process even ranks first, then odd ranks, so the send/recv
        # interleaving stays consistent across neighbouring ranks.
        for rank in range(0, real_stage_num, 2):
            process_rank_communication(rank, current_operations[rank], new_schedule,
                                       style, stage_num, real_stage_num)
        for rank in range(1, real_stage_num, 2):
            process_rank_communication(rank, current_operations[rank], new_schedule,
                                       style, stage_num, real_stage_num)

    return new_schedule
class ScheduleGPipe(PipelineScheduleRuntime):
    """
    The Gpipe schedule.
    It first executes all forward micro batches and then execute all backward micro batches.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None):
        super().__init__(stages,
                         micro_batch_num,
                         args_batch_dim=args_batch_dim,
                         kwargs_batch_dim=kwargs_batch_dim,
                         output_concat_dim=output_concat_dim)
        self.construct_exec_order()

    def construct_exec_order(self):
        """construct_exec_order of Gpipe."""
        for stage_index in range(self.real_stage_num):
            order_list = []
            # Forward phase: run every micro-batch through this stage.
            for mb_index in range(self.micro_batch_num):
                if stage_index != 0:
                    order_list.append(MetaStep(mb_index, MetaStepType.FWD_RECV, stage_index))
                order_list.append(MetaStep(mb_index, MetaStepType.FWD, stage_index))
                # Fix: was `self.stage.stage_num`, but no `self.stage`
                # attribute exists (AttributeError); the matching BWD_RECV
                # guard below already uses `self.real_stage_num`.
                if stage_index != self.real_stage_num - 1:
                    order_list.append(MetaStep(mb_index, MetaStepType.FWD_SEND, stage_index))
            # Backward phase: only after all forwards completed.
            for mb_index in range(self.micro_batch_num):
                if stage_index != self.real_stage_num - 1:
                    order_list.append(MetaStep(mb_index, MetaStepType.BWD_RECV, stage_index))
                order_list.append(MetaStep(mb_index, MetaStepType.BWD, stage_index))
                if stage_index != 0:
                    order_list.append(MetaStep(mb_index, MetaStepType.BWD_SEND, stage_index))
            self.exec_order[stage_index] = order_list
class Schedule1F1B(PipelineScheduleRuntime):
    """
    The 1F1B schedule.
    It will perform one forward and one backward on the micro batches in steady state.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None):
        super().__init__(stages,
                         micro_batch_num,
                         args_batch_dim=args_batch_dim,
                         kwargs_batch_dim=kwargs_batch_dim,
                         output_concat_dim=output_concat_dim)
        # The whole schedule is precomputed here; run_microbatches replays it.
        self.construct_exec_order()

    def construct_exec_order(self):
        """construct_exec_order of 1F1B."""
        for stage_index in range(self.real_stage_num):
            order_list = []
            fwd_index = 0  # next micro-batch to run forward on this stage
            bwd_index = 0  # next micro-batch to run backward on this stage
            # warmup phase
            # Earlier stages need more in-flight forwards before their first
            # backward can arrive; capped by the number of micro-batches.
            warmup_micro_batches = min(self.real_stage_num - stage_index, self.micro_batch_num)
            for _ in range(warmup_micro_batches):
                if stage_index != 0:
                    order_list.append(MetaStep(fwd_index, MetaStepType.FWD_RECV, stage_index))
                if stage_index % 2 == 0:
                    # Even stages: compute first, then send; the last warmup
                    # forward's send is deferred into the steady phase.
                    order_list.append(MetaStep(fwd_index, MetaStepType.FWD, stage_index))
                    if fwd_index != warmup_micro_batches - 1:
                        order_list.append(MetaStep(fwd_index, MetaStepType.FWD_SEND, stage_index))
                else:
                    # Odd stages: send the previous forward before computing
                    # the next — NOTE(review): presumably this staggers
                    # send/recv pairs between neighbouring ranks; confirm.
                    if fwd_index > 0:
                        order_list.append(MetaStep(fwd_index - 1, MetaStepType.FWD_SEND, stage_index))
                    order_list.append(MetaStep(fwd_index, MetaStepType.FWD, stage_index))
                fwd_index += 1
            # if warmup phase cannot filled up, then we need to execute fwd send in advance
            if self.real_stage_num - stage_index > self.micro_batch_num:
                order_list.append(MetaStep(fwd_index - 1, MetaStepType.FWD_SEND, stage_index))
                fwd_index += 1
            # steady phase
            # Alternate one backward with one forward per iteration (1F1B).
            steady_micro_batches = self.micro_batch_num - warmup_micro_batches
            for _ in range(steady_micro_batches):
                if stage_index != self.real_stage_num - 1:
                    # Non-last stages receive the gradient and flush the
                    # pending forward send from the previous iteration.
                    order_list.append(MetaStep(bwd_index, MetaStepType.BWD_RECV, stage_index))
                    order_list.append(MetaStep(fwd_index - 1, MetaStepType.FWD_SEND, stage_index))
                order_list.append(MetaStep(bwd_index, MetaStepType.BWD, stage_index))

                if stage_index != 0:
                    # Non-first stages forward the gradient and pull the next
                    # micro-batch's activations.
                    order_list.append(MetaStep(bwd_index, MetaStepType.BWD_SEND, stage_index))
                    order_list.append(MetaStep(fwd_index, MetaStepType.FWD_RECV, stage_index))
                order_list.append(MetaStep(fwd_index, MetaStepType.FWD, stage_index))
                fwd_index += 1
                bwd_index += 1

            # cooldown phase
            # Drain the remaining backwards (as many as were warmed up).
            cooldown_micro_batches = warmup_micro_batches
            for _ in range(cooldown_micro_batches):
                if stage_index != self.real_stage_num - 1:
                    order_list.append(MetaStep(bwd_index, MetaStepType.BWD_RECV, stage_index))
                    # First cooldown iteration may still owe the last forward
                    # send from the steady phase.
                    if bwd_index == self.micro_batch_num - warmup_micro_batches and fwd_index <= self.micro_batch_num:
                        order_list.append(MetaStep(fwd_index - 1, MetaStepType.FWD_SEND, stage_index))
                order_list.append(MetaStep(bwd_index, MetaStepType.BWD, stage_index))

                if stage_index != 0:
                    order_list.append(MetaStep(bwd_index, MetaStepType.BWD_SEND, stage_index))
                bwd_index += 1
            self.exec_order[stage_index] = order_list
class ScheduleInterleaved1F1B(PipelineScheduleRuntime):
    """
    The Interleaved 1F1B schedule.
    Support multiple stages per rank. It will perform one forward and one backward
    on the micro batches in steady state.
    We support cases where num_microbatch is less than or equal, or greater than the
    stage num, as well as cases where num_microbatch can't be evenly divided by the
    stage num.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None):
        super().__init__(stages,
                         micro_batch_num,
                         args_batch_dim=args_batch_dim,
                         kwargs_batch_dim=kwargs_batch_dim,
                         output_concat_dim=output_concat_dim)
        # Number of interleaving rounds; at least one even when there are
        # fewer micro-batches than ranks.
        self.n_rounds = max(1, self.micro_batch_num // self.real_stage_num)
        if self.micro_batch_num < self.real_stage_num:
            # Fewer micro-batches than ranks: `base` goes negative so each
            # round is sized down to micro_batch_num.
            base = self.micro_batch_num - self.real_stage_num
            remainder = 0
        else:
            # Spread the leftover micro-batches across the first `remainder`
            # rounds, one extra each.
            n_extra_microbatch = self.micro_batch_num % self.real_stage_num
            base = n_extra_microbatch // self.n_rounds
            remainder = n_extra_microbatch % self.n_rounds
        # Micro-batches handled in each round.
        self.n_microbatch_per_round = \
            [self.real_stage_num + base + 1 if i < remainder else
             self.real_stage_num + base for i in range(self.n_rounds)]
        # Cumulative op counts per round (scaled by local stage count),
        # prefixed with 0 so bisect can map an op index to its round.
        self.n_microbatch_per_round_accu = \
            [x * self.n_local_stages for x in itertools.accumulate(self.n_microbatch_per_round)]
        self.n_microbatch_per_round_accu.insert(0, 0)
        for stage_index in range(self.real_stage_num):
            self.exec_order[stage_index] = self.construct_stage_exec_order(stage_index)
        # Insert the matching send/recv operations between ranks.
        self.exec_order = add_send_recv(self.exec_order, self._stage_num, self.real_stage_num, style = 'loop')

    def warmup_ops(self, stage_index):
        """warmup phase."""
        # Forward ops the last rank performs before its first backward, plus
        # two extra per rank of distance from the last rank.
        warmup_ops_last_stage = (self.n_local_stages - 1) * self.n_microbatch_per_round[0]
        warmup_ops = warmup_ops_last_stage + 2 * (self.real_stage_num - 1 - stage_index)
        # Cannot exceed the total number of forward ops on this rank.
        return min(warmup_ops, self.micro_batch_num * self.n_local_stages)

    def forward_stage_index(self, op_index, stage_index):
        """obtain forward stage_index based on op_index."""
        # Locate the round containing op_index, then the local chunk within it.
        accu_index = bisect.bisect_right(self.n_microbatch_per_round_accu, op_index) - 1
        local_index = (op_index - self.n_microbatch_per_round_accu[accu_index]) // \
                      self.n_microbatch_per_round[accu_index]
        return (local_index * self.real_stage_num) + stage_index

    def backward_stage_index(self, op_index, stage_index):
        """obtain backward stage_index based on op_index."""
        accu_index = bisect.bisect_right(self.n_microbatch_per_round_accu, op_index) - 1
        local_index = (op_index - self.n_microbatch_per_round_accu[accu_index]) // \
                      self.n_microbatch_per_round[accu_index]
        # Backward visits local chunks in reverse order.
        local_index = self.n_local_stages - 1 - local_index
        return (local_index * self.real_stage_num) + stage_index

    def construct_stage_exec_order(self, stage_index):
        """construct the execution order of specified stage_index."""
        warmup_ops = self.warmup_ops(stage_index)
        fwd_bwd_ops = self.n_local_stages * self.micro_batch_num - warmup_ops
        cooldown_ops = warmup_ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # Pre-padding bubbles, stage starts with no-ops based on the warmup.
        order_list = [None for _ in range(stage_index)]
        # Per-stage counters of the next micro-batch index to schedule.
        fwd_stage_micro_index = defaultdict(int)
        bwd_stage_micro_index = defaultdict(int)
        # WarmUp Phase
        for op_idx in range(warmup_ops):
            fwd_stage_idx = self.forward_stage_index(op_idx, stage_index)
            fwd_micro_idx = fwd_stage_micro_index[fwd_stage_idx]
            order_list.append(MetaStep(fwd_micro_idx, MetaStepType.FWD, fwd_stage_idx))
            # If micro is less than stage num, there will be additional bubbles during warmup phase.
            if self.micro_batch_num < self.real_stage_num and fwd_micro_idx == self.micro_batch_num - 1:
                if op_idx != warmup_ops - 1 or stage_index == self.real_stage_num - 1:
                    order_list.extend([None] * (self.real_stage_num - self.micro_batch_num))
            fwd_stage_micro_index[fwd_stage_idx] += 1
        # If micro is less than 2 * (self.real_stage_num - stage_index - 1),
        # there will be additional bubbles during warmup phase.
        if self.micro_batch_num < 2 * (self.real_stage_num - stage_index - 1):
            order_list.extend([None] * (2 * (self.real_stage_num - stage_index - 1) - self.micro_batch_num))
        # Bubble from the end of warmup to the start of backward.
        order_list.extend([None] * (self.real_stage_num - 1 - stage_index))

        # 1f1b phase
        # One forward plus one backward per op index.
        for op_idx in range(warmup_ops, warmup_ops+fwd_bwd_ops):
            fwd_stage_idx = self.forward_stage_index(op_idx, stage_index)
            fwd_micro_idx = fwd_stage_micro_index[fwd_stage_idx]
            order_list.append(MetaStep(fwd_micro_idx, MetaStepType.FWD, fwd_stage_idx))
            fwd_stage_micro_index[fwd_stage_idx] += 1
            bwd_stage_idx = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            bwd_micro_idx = bwd_stage_micro_index[bwd_stage_idx]
            order_list.append(MetaStep(bwd_micro_idx, MetaStepType.BWD, bwd_stage_idx))
            # If micro is less than 2 * (self.real_stage_num - stage_index - 1),
            # there will be additional bubbles after 1f1b phase in last stage.
            if self.micro_batch_num < self.real_stage_num and bwd_micro_idx == self.micro_batch_num - 1:
                if stage_index == self.real_stage_num - 1:
                    order_list.extend([None] * (self.real_stage_num - self.micro_batch_num))
            bwd_stage_micro_index[bwd_stage_idx] += 1
        # cooldown phase
        # No more forwards: each step is one bubble plus one backward.
        for op_idx in range(warmup_ops+fwd_bwd_ops, total_ops):
            order_list.append(None)
            bwd_stage_idx = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            bwd_micro_idx = bwd_stage_micro_index[bwd_stage_idx]
            order_list.append(MetaStep(bwd_micro_idx, MetaStepType.BWD, bwd_stage_idx))
            # If micro is less than 2 * (self.real_stage_num - stage_index - 1),
            # there will be additional bubbles during cooldown phase.
            if self.micro_batch_num < self.real_stage_num and bwd_micro_idx == self.micro_batch_num - 1:
                order_list.extend([None] * (self.real_stage_num - self.micro_batch_num))
            bwd_stage_micro_index[bwd_stage_idx] += 1
        return order_list
def detect_cycle_in_graph(ranks_map):
    """
    Detects a cycle in the directed graph constructed from ranks_map.

    Args:
        ranks_map: A dictionary where keys are rank names and values are lists
            of nodes; each consecutive pair of nodes contributes one edge.

    Returns:
        tuple: (cycle_path, cycle_ranks) where cycle_path is a list of nodes
        forming the cycle and cycle_ranks is a list of rank transitions
        corresponding to the cycle path, or (None, None) if the graph is acyclic.
    """
    adjacency = defaultdict(list)
    edge_owner = {}
    for rank, nodes in ranks_map.items():
        for src, dst in zip(nodes, nodes[1:]):
            adjacency[src].append(dst)
            edge_owner[(src, dst)] = rank

    explored = set()   # nodes whose DFS subtree is fully processed
    dfs_path = []      # ancestors of the node currently being expanded
    on_path_at = {}    # node -> its index within dfs_path

    # Iterative DFS; (node, True) entries mark the post-visit step where the
    # node leaves the current path.
    for start in list(adjacency.keys()):
        if start in explored:
            continue
        work = [(start, False)]
        while work:
            node, leaving = work.pop()

            if leaving:
                dfs_path.pop()
                del on_path_at[node]
                continue

            if node in on_path_at:
                # Back edge to an ancestor: the slice of the path from that
                # ancestor onward is the cycle.
                begin = on_path_at[node]
                cycle_nodes = dfs_path[begin:] + [node]
                labelled_edges = []
                for i in range(begin, len(dfs_path)):
                    u = dfs_path[i]
                    v = dfs_path[i + 1] if i + 1 < len(dfs_path) else node
                    labelled_edges.append(f"{edge_owner[(u, v)]} {u} -> {v}")
                return cycle_nodes, labelled_edges

            if node in explored:
                continue

            explored.add(node)
            on_path_at[node] = len(dfs_path)
            dfs_path.append(node)

            work.append((node, True))
            for neighbor in reversed(adjacency[node]):
                work.append((neighbor, False))

    return None, None
def output_cycle_results(cycle_path, cycle_ranks):
    """
    Helper function to output cycle detection results.

    Args:
        cycle_path (list): List of nodes forming a cycle, if any.
        cycle_ranks (list): List of ranks involved in the cycle.

    Returns:
        None: Outputs results through the module logger.
    """
    if not cycle_path:
        logger.warning("Cycle Check succeeded. There is no cycle in the graph.")
        return
    logger.error("Cycle detected:")
    joined = " -> ".join(str(node) for node in cycle_path)
    logger.error("%s -> %s", joined, cycle_path[0])  # Close the cycle
    logger.error("Involving ranks:")
    for rank in cycle_ranks:
        logger.error(rank)
def parse_and_validate(data: dict, all_rank: bool = True):
    """
    Parse and validate execution orders in a directed graph structure.

    Checks the integrity and consistency of the given dataset: every key
    referenced inside a value string must exist (when `all_rank` is True) and
    must itself contain that value.

    Args:
        data (dict): A dictionary where keys are string identifiers and values
            are lists of strings; each value encodes references to other keys
            inside parentheses.
        all_rank (bool): If True, checks that all elements referenced in the
            data are present as keys in the dictionary. If False, only checks
            intersections.

    Returns:
        None: Log error messages to the console if validation fails,
        otherwise completes silently.
    """

    def parse_elements(value: str, max_groups: int = 2) -> set:
        """Extract unique elements inside the first one or two parentheses from a string."""
        found = re.findall(r'\((\d+)\)', value)
        # Limit to the first `max_groups` matches.
        return {token.strip() for token in found[:max_groups]}

    if not isinstance(data, dict):
        logger.error("Input must be a dictionary with string keys and lists of strings as values.")
        return

    key_to_values = {k: set(v) for k, v in data.items()
                     if isinstance(v, list) and all(isinstance(item, str) for item in v)}

    for key, values in data.items():
        if not (isinstance(values, list) and all(isinstance(v, str) for v in values)):
            logger.error("Values for key '%s' must be a list of strings.", key)
            continue

        for value in values:
            try:
                elements = parse_elements(value)
            except (ValueError, TypeError, AttributeError) as e:
                logger.error("Unable to parse elements from value '%s' in key '%s'. Error: %s", value, key, e)
                continue

            if all_rank:
                # Every referenced key must exist in the input mapping.
                missing_keys = elements - key_to_values.keys()
                if missing_keys:
                    logger.error("The following keys are missing for value '%s': %s", value, missing_keys)
                    continue
                candidates = elements
            else:
                candidates = elements & key_to_values.keys()

            # The referencing value must also be listed under each referenced key.
            for element in candidates:
                if value not in key_to_values[element]:
                    logger.error("Key '%s' is missing the value '%s'.", element, value)
def generate_operations(order_list: dict[int, list[MetaStep]],
                        chunk_num: int,
                        com_type: str = 'loop') -> dict[str, list[str]]:
    """
    Generate formatted operations dictionary from pipeline execution order.

    Args:
        order_list (dict): Dictionary where keys are rank IDs and values are MetaStep execution sequences
        chunk_num (int): Number of chunks (virtual pipeline stages)
        com_type (str): Stage-to-rank mapping type ('loop' for cyclic, 'v' for V-shaped)

    Returns:
        Dictionary where keys are rank IDs (as strings) and values are lists of formatted operation strings
    """

    def stage_to_rank(stage_index, style, stage_num, real_stage_num):
        """Map stage index to rank"""
        if style == 'loop':
            return stage_index % real_stage_num
        if style == 'v':
            return stage_index if stage_index < real_stage_num else stage_num - 1 - stage_index
        raise ValueError("Invalid style")

    def find_send_target(stage_idx, op_type):
        """Find target stage for SEND operation"""
        table = forward_comm if op_type == MetaStepType.FWD_SEND else backward_comm
        return table.get(stage_idx)

    def find_recv_source(stage_idx, op_type):
        """Find source stage for RECV operation"""
        table = forward_comm if op_type == MetaStepType.FWD_RECV else backward_comm
        # Reverse lookup: which stage sends to stage_idx.
        for src, dst in table.items():
            if dst == stage_idx:
                return src
        return None

    real_stage = len(order_list)
    total_stages = real_stage * chunk_num

    # Communication rules: stage i forwards activations to i + 1 and sends
    # gradients back to i - 1.
    forward_comm = {i: i + 1 for i in range(total_stages - 1)}
    backward_comm = {i: i - 1 for i in range(1, total_stages)}

    formatted_operations = defaultdict(list)

    for rank, steps in order_list.items():
        # Count repeated transfers for the same (src, dst, micro) triple.
        operation_counter = defaultdict(int)

        for step in steps:
            if step.type in (MetaStepType.FWD_SEND, MetaStepType.BWD_SEND):
                peer_stage = find_send_target(step.stage_index, step.type)
                if peer_stage is None:
                    continue
                peer_rank = stage_to_rank(peer_stage, com_type, total_stages, real_stage)
                pair = (rank, peer_rank, step.micro_index)
            elif step.type in (MetaStepType.FWD_RECV, MetaStepType.BWD_RECV):
                peer_stage = find_recv_source(step.stage_index, step.type)
                if peer_stage is None:
                    continue
                peer_rank = stage_to_rank(peer_stage, com_type, total_stages, real_stage)
                pair = (peer_rank, rank, step.micro_index)
            else:
                continue

            operation_counter[pair] += 1
            src_rank, dst_rank, micro = pair
            count = operation_counter[pair]
            formatted_operations[str(rank)].append(
                f"Send_Receive_({src_rank})->({dst_rank})_micro{micro}_{count}th")

    # Convert defaultdict to dict
    return dict(formatted_operations)
def validate_pipeline_execution(order_list: dict[int, list[MetaStep]],
                                chunk_num: int,
                                com_type: str = 'loop') -> dict[str, any]:
    """
    Comprehensive validation function for pipeline parallel execution order.

    This function validates the execution order of pipeline parallelism by:
    1. Generating formatted communication operations per rank
    2. Checking SEND/RECV communication pair matching
    3. Detecting cycles in the communication graph

    Args:
        order_list: Dictionary where keys are rank IDs and values are MetaStep execution sequences
        chunk_num: Number of chunks (virtual pipeline stages)
        com_type: Stage-to-rank mapping type ('loop' for cyclic, 'v' for V-shaped)

    Returns:
        Dictionary containing validation results with the following keys:
        - formatted_operations: Generated formatted operations
        - cycle_path: Nodes forming a detected cycle, or None
        - cycle_ranks: Rank transitions along the cycle, or None
        - has_cycle: Boolean indicating whether a cycle was found
    """
    # Build the per-rank communication strings from the schedule.
    formatted_operations = generate_operations(order_list, chunk_num, com_type)

    # Validate that every SEND has a matching RECV on the peer rank.
    parse_and_validate(formatted_operations, True)

    # Detect communication cycles and report them via the logger.
    cycle_path, cycle_ranks = detect_cycle_in_graph(formatted_operations)
    output_cycle_results(cycle_path, cycle_ranks)

    return {
        'formatted_operations': formatted_operations,
        'cycle_path': cycle_path,
        'cycle_ranks': cycle_ranks,
        'has_cycle': bool(cycle_path)
    }