Coverage for hyper_parallel / core / pipeline_parallel / scheduler.py: 51%

476 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""pipeline schedule""" 

16from abc import ABC 

17from enum import Enum, auto 

18from collections import defaultdict 

19import itertools 

20import bisect 

21import logging 

22import re 

23import hyper_parallel 

24from hyper_parallel.platform import get_platform 

25platform = get_platform() 

26logger = logging.getLogger(__name__) 

27 

28 

class MetaStepType(Enum):
    """Specify the enumeration type for MetaStep."""
    FWD = auto()        # forward computation of one micro-batch
    BWD = auto()        # backward computation of one micro-batch
    FWD_RECV = auto()   # receive forward activations from the previous stage
    FWD_SEND = auto()   # send forward activations to the next stage
    BWD_RECV = auto()   # receive backward gradients from the next stage
    BWD_SEND = auto()   # send backward gradients to the previous stage

37 

38 

class MetaStep:
    """
    Meta step of PipelineSchedule.
    An execution list composed of MetaStep can be constructed
    and fed into the PipelineSchedule for execution.

    Args:
        micro_index (int): The index of micro-batch.
        meta_type (MetaStepType): Specify the type of current step.
        stage_index (int): Specify the stage index of current step.
    """
    # Matches the text produced by __str__/__repr__, e.g.
    # "MetaStep(type=MetaStepType.FWD, micro_index=0, stage_index=1)".
    _STEP_RE = re.compile(
        r"MetaStep\(type=MetaStepType\.(?P<type>\w+), "
        r"micro_index=(?P<micro>-?\d+), stage_index=(?P<stage>-?\d+)\)")

    def __init__(self, micro_index, meta_type, stage_index):
        self._type = meta_type
        self._micro_index = micro_index
        self._stage_index = stage_index

    @property
    def micro_index(self):
        """Index of the micro-batch this step operates on."""
        return self._micro_index

    @property
    def stage_index(self):
        """Global stage index this step belongs to."""
        return self._stage_index

    @property
    def type(self):
        """MetaStepType of this step."""
        return self._type

    def __eq__(self, value):
        if not isinstance(value, MetaStep):
            return NotImplemented
        return self.type == value.type and \
               self.micro_index == value.micro_index and \
               self.stage_index == value.stage_index

    def __ne__(self, value):
        if not isinstance(value, MetaStep):
            return NotImplemented
        return self.type != value.type or \
               self.micro_index != value.micro_index or \
               self.stage_index != value.stage_index

    def __hash__(self):
        return hash((self.type, self.micro_index, self.stage_index))

    def __str__(self):
        return f"MetaStep(type={self.type}, micro_index={self.micro_index}, stage_index={self.stage_index})"

    def __repr__(self):
        return f"MetaStep(type={self.type}, micro_index={self.micro_index}, stage_index={self.stage_index})"

    @staticmethod
    def from_str(step_str):
        """Parse a string produced by ``str()``/``repr()`` back into a MetaStep.

        Previously an unimplemented stub; implemented to round-trip the
        canonical ``__str__`` format.

        Args:
            step_str (str): Text such as
                "MetaStep(type=MetaStepType.FWD, micro_index=0, stage_index=1)".

        Returns:
            MetaStep: The reconstructed step.

        Raises:
            ValueError: If `step_str` does not match the expected format or
                names an unknown MetaStepType member.
        """
        match = MetaStep._STEP_RE.fullmatch(step_str.strip())
        if match is None:
            raise ValueError(f"Cannot parse MetaStep from string: {step_str!r}")
        try:
            meta_type = MetaStepType[match.group("type")]
        except KeyError:
            raise ValueError(f"Unknown MetaStepType in string: {step_str!r}") from None
        return MetaStep(int(match.group("micro")),
                        meta_type,
                        int(match.group("stage")))

93 

94 

class PipelineScheduleRuntime(ABC):
    """
    Base class for pipeline schedule.
    Implements the `split_microbatches` and `run_microbatches` method.
    Derived classes should implement `run_microbatches` method and `run` method.

    Args:
        stages (list[PipelineStage], PipelineStage): PipelineStage used to run_microbatches.
        micro_batch_num (int): The number of micro-batch.
        args_batch_dim (list, optional): Specify the batch dim of the args.
            Default ``None``.
        kwargs_batch_dim (dict, optional): Specify the batch dim of the kwargs.
            Default ``None``.
        output_concat_dim (int, optional): Stored but not used by this base class;
            presumably the dim along which subclasses concatenate outputs — TODO confirm.
            Default ``None``.
        overlap_p2p (bool, optional): If ``True``, point-to-point communication is
            issued asynchronously and waited on just before the matching compute
            step (or drained at the end for sends) instead of immediately.
            Default ``False``.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None,
                 overlap_p2p=False):
        self.stages = self._check_stages(stages)
        self.micro_batch_num = micro_batch_num
        self._args_batch_dim = args_batch_dim
        self._kwargs_batch_dim = kwargs_batch_dim
        self._output_concat_dim = output_concat_dim
        # Platform-provided callable that splits (args, kwargs) into micro-batches.
        self.split_micro_batch = platform.micro_batch(self.micro_batch_num,
                                                      self._args_batch_dim, self._kwargs_batch_dim)
        # Number of pipeline stages hosted on this rank (virtual pipeline chunks).
        self.n_local_stages = len(self.stages)
        self._stage_dict = self.convert_stages_dict()
        # Number of physical ranks in the pipeline.
        self.real_stage_num = self.stages[0].stage_num // self.n_local_stages
        self._stage_num = self.stages[0].stage_num
        self._overlap_p2p = overlap_p2p
        # stage_index -> list[MetaStep]; populated by subclasses.
        self.exec_order = {}
        self._init_stages()
        # (stage_index, micro_index) -> pending async recv handles (overlap mode only).
        self.fwd_handle_cache = {}
        self.bwd_handle_cache = {}

    def convert_stages_dict(self):
        """Return a mapping from global stage_index to the local PipelineStage."""
        stage_dict = {}
        for stage in self.stages:
            stage_dict[stage.stage_index] = stage
        return stage_dict

    def split_microbatches(self, args, kwargs):
        """Split (args, kwargs) into per-micro-batch lists of args and kwargs."""
        if args or kwargs:
            args_split, kwargs_split = self.split_micro_batch(args, kwargs)
            return args_split, kwargs_split
        # No inputs at all: every micro-batch gets empty args/kwargs.
        return [[]] * self.micro_batch_num, [{}] * self.micro_batch_num

    def _check_stages(self, stages):
        """Validate `stages` is a PipelineStage or list/tuple of them; normalize to a list."""
        if isinstance(stages, hyper_parallel.PipelineStage):
            return [stages]
        if isinstance(stages, (list, tuple)):
            for stage in stages:
                if not isinstance(stage, hyper_parallel.PipelineStage):
                    raise TypeError(f"Argument 'stages' must be type of PipelineStage, \
                        list or tuple of PipelineStage, but got list or tuple of {type(stage)}.")
            return stages
        raise TypeError(f"Argument 'stages' must be type of PipelineStage, \
            list or tuple of PipelineStage, but got type of {type(stages)}.")

    def _init_stages(self):
        """Initialize every local stage with the number of local chunks."""
        for stage in self.stages:
            stage.init(self.n_local_stages)

    def run(self, *args, **kwargs):
        """Split the inputs into micro-batches, run the schedule, return the losses."""
        split_args, split_kwargs = self.split_microbatches(args, kwargs)
        losses = []
        self.run_microbatches(split_args, split_kwargs, losses)
        return losses

    def sync_shared_parameters_grad(self):
        """Synchronize gradients of parameters shared across stages."""
        for stage in self.stages:
            stage.sync_shared_parameters_grad()

    def update_losses(self, stage, loss, losses):
        """Append `loss` to `losses` when `stage` is the last pipeline stage."""
        if stage.is_last_stage:
            losses.append(loss)

    def _wait_p2p(self, handles):
        # Wait on each non-None communication handle.
        for handle in handles:
            if handle is not None:
                handle.wait()

    def run_microbatches(self, arg_mbs, kwarg_mbs, losses):
        """Execute this rank's MetaStep schedule over the given micro-batches.

        Args:
            arg_mbs (list): Per-micro-batch positional args.
            kwarg_mbs (list): Per-micro-batch keyword args.
            losses (list): Output list; the last stage's outputs are appended to it.
        """
        # exec_order is keyed by the physical (real) stage index of this rank.
        real_stage_index = self.stages[0].stage_index % self.real_stage_num
        send_handle = []
        for cur_step in self.exec_order[real_stage_index]:
            if cur_step is None:
                # None marks a bubble (idle slot) in the schedule.
                continue
            stage = self._stage_dict[cur_step.stage_index]
            stage_index = cur_step.stage_index
            micro_index = cur_step.micro_index
            if cur_step.type == MetaStepType.FWD_RECV:
                comm_handle = stage.exec_fwd_recv_ops(micro_index)
                if not self._overlap_p2p:
                    # Blocking mode: wait for the activations immediately.
                    self._wait_p2p(comm_handle)
                else:
                    # Overlap mode: defer the wait until the matching FWD step.
                    key = (stage_index, micro_index)
                    self.fwd_handle_cache[key] = comm_handle
            if cur_step.type == MetaStepType.FWD:
                key = (stage_index, micro_index)
                if self._overlap_p2p and key in self.fwd_handle_cache:
                    comm_handle = self.fwd_handle_cache.pop(key)
                    self._wait_p2p(comm_handle)
                out = stage.forward_one_chunk(micro_index, arg_mbs[micro_index], kwarg_mbs[micro_index])
                self.update_losses(stage, out, losses)
            if cur_step.type == MetaStepType.FWD_SEND:
                comm_handle = stage.exec_fwd_send_ops(micro_index)
                if not self._overlap_p2p:
                    self._wait_p2p(comm_handle)
                else:
                    # Async sends are drained at the end of the schedule.
                    send_handle.append(comm_handle)
            if cur_step.type == MetaStepType.BWD_RECV:
                comm_handle = stage.exec_bwd_recv_ops(micro_index)
                if not self._overlap_p2p:
                    self._wait_p2p(comm_handle)
                else:
                    key = (stage_index, micro_index)
                    self.bwd_handle_cache[key] = comm_handle
            if cur_step.type == MetaStepType.BWD:
                key = (stage_index, micro_index)
                if self._overlap_p2p and key in self.bwd_handle_cache:
                    comm_handle = self.bwd_handle_cache.pop(key)
                    self._wait_p2p(comm_handle)
                if micro_index == self.micro_batch_num - 1:
                    # Last micro-batch passes an extra flag — presumably marks
                    # gradient finalization in backward_one_chunk; confirm
                    # against PipelineStage.
                    stage.backward_one_chunk(micro_index, True)
                else:
                    stage.backward_one_chunk(micro_index)
            if cur_step.type == MetaStepType.BWD_SEND:
                comm_handle = stage.exec_bwd_send_ops(micro_index)
                if not self._overlap_p2p:
                    self._wait_p2p(comm_handle)
                else:
                    send_handle.append(comm_handle)
        self.sync_shared_parameters_grad()
        # Drain any outstanding async sends before returning.
        while send_handle:
            self._wait_p2p(send_handle.pop())

242 

243 

def add_send_recv(scheduler, stage_num, real_stage_num, style='loop'):
    """
    Create schedule for each rank and automatically add communication operations.

    For every FWD/BWD compute step whose producer and consumer live on
    different ranks, a FWD_SEND/BWD_SEND step is appended to the producing
    rank and the matching FWD_RECV/BWD_RECV step to the consuming rank.

    Args:
        scheduler: Compute schedule table (rank -> list of MetaStep or None bubbles).
        stage_num: Total number of pipeline stages (including virtual stages).
        real_stage_num: Number of actual physical stages/ranks.
        style: Communication style ('loop' or 'v').

    Returns:
        Complete schedule table for each rank (including communication operations).
    """

    def stage_to_rank(stage_index, style, stage_num, real_stage_num):
        """Map stage index to rank."""
        if style == 'loop':
            return stage_index % real_stage_num
        if style == 'v':
            if stage_index < real_stage_num:
                return stage_index
            return stage_num - 1 - stage_index
        raise ValueError("Invalid style")

    def _need_com(action, style, stage_num):
        """Return True when the step's peer stage lives on a different rank."""
        if action.type == MetaStepType.FWD:
            if action.stage_index == stage_num - 1:
                return False  # Last stage doesn't need forward communication
            next_stage_rank = stage_to_rank(action.stage_index + 1, style, stage_num, real_stage_num)
            current_rank = stage_to_rank(action.stage_index, style, stage_num, real_stage_num)
            return next_stage_rank != current_rank
        if action.type == MetaStepType.BWD:
            if action.stage_index == 0:
                return False  # First stage doesn't need backward communication
            prev_stage_rank = stage_to_rank(action.stage_index - 1, style, stage_num, real_stage_num)
            current_rank = stage_to_rank(action.stage_index, style, stage_num, real_stage_num)
            return prev_stage_rank != current_rank
        return False

    def process_rank_communication(rank, operation, new_schedule, style, stage_num, real_stage_num):
        """Append the SEND/RECV pair implied by `operation` (if any) to both ranks."""
        if operation is None:
            return  # Bubble slot: nothing to communicate.

        stage_index = operation.stage_index
        # Peer ranks; None marks "no such neighbour" at either end of the pipeline.
        # Bug fix: the last stage (stage_num - 1) has no next stage, so the
        # guard is `stage_index < stage_num - 1` (was `< stage_num`, which
        # computed a meaningless peer rank for the last stage).
        pre_rank = stage_to_rank(stage_index - 1, style, stage_num, real_stage_num) \
            if stage_index > 0 else None
        nxt_rank = stage_to_rank(stage_index + 1, style, stage_num, real_stage_num) \
            if stage_index < stage_num - 1 else None

        if (operation.type == MetaStepType.FWD and
                _need_com(operation, style, stage_num) and nxt_rank is not None):
            # Producer sends forward activations ...
            new_schedule[rank].append(MetaStep(
                micro_index=operation.micro_index,
                meta_type=MetaStepType.FWD_SEND,
                stage_index=stage_index
            ))
            # ... and the next stage's rank receives them.
            new_schedule[nxt_rank].append(MetaStep(
                micro_index=operation.micro_index,
                meta_type=MetaStepType.FWD_RECV,
                stage_index=stage_index + 1
            ))
        elif (operation.type == MetaStepType.BWD and
              _need_com(operation, style, stage_num) and pre_rank is not None):
            # Producer sends backward gradients ...
            new_schedule[rank].append(MetaStep(
                micro_index=operation.micro_index,
                meta_type=MetaStepType.BWD_SEND,
                stage_index=stage_index
            ))
            # ... and the previous stage's rank receives them.
            new_schedule[pre_rank].append(MetaStep(
                micro_index=operation.micro_index,
                meta_type=MetaStepType.BWD_RECV,
                stage_index=stage_index - 1
            ))

    # Main logic: walk the compute table column by column (one time step at a
    # time) and interleave the implied communication steps.
    max_length = max(len(schedule) for schedule in scheduler.values())
    new_schedule = {rank: [] for rank in range(real_stage_num)}

    for time_step in range(max_length):
        current_operations = {}
        for rank in range(real_stage_num):
            if time_step < len(scheduler[rank]):
                operation = scheduler[rank][time_step]
                current_operations[rank] = operation
                if operation is not None:
                    new_schedule[rank].append(operation)
            else:
                current_operations[rank] = None

        # Process even ranks first, then odd ranks, so neighbouring ranks
        # enqueue matching SEND/RECV pairs in a consistent relative order.
        for rank in range(0, real_stage_num, 2):
            process_rank_communication(rank, current_operations[rank], new_schedule,
                                       style, stage_num, real_stage_num)

        for rank in range(1, real_stage_num, 2):
            process_rank_communication(rank, current_operations[rank], new_schedule,
                                       style, stage_num, real_stage_num)

    return new_schedule

344 

345 

class ScheduleGPipe(PipelineScheduleRuntime):
    """
    The Gpipe schedule.
    It first executes all forward micro batches and then execute all backward micro batches.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None):
        super().__init__(stages,
                         micro_batch_num,
                         args_batch_dim=args_batch_dim,
                         kwargs_batch_dim=kwargs_batch_dim,
                         output_concat_dim=output_concat_dim)
        self.construct_exec_order()

    def construct_exec_order(self):
        """Build the GPipe execution order: all forwards first, then all backwards.

        For each physical stage: every forward micro-batch is preceded by a
        FWD_RECV (except on the first stage) and followed by a FWD_SEND
        (except on the last stage); backwards mirror this with
        BWD_RECV/BWD_SEND.
        """
        for stage_index in range(self.real_stage_num):
            order_list = []
            for mb_index in range(self.micro_batch_num):
                if stage_index != 0:
                    order_list.append(MetaStep(mb_index, MetaStepType.FWD_RECV, stage_index))
                order_list.append(MetaStep(mb_index, MetaStepType.FWD, stage_index))
                # Bug fix: `self.stage.stage_num` raised AttributeError (no
                # `self.stage` attribute exists); use `self.real_stage_num`
                # to mirror the BWD_RECV condition below.
                if stage_index != self.real_stage_num - 1:
                    order_list.append(MetaStep(mb_index, MetaStepType.FWD_SEND, stage_index))
            for mb_index in range(self.micro_batch_num):
                if stage_index != self.real_stage_num - 1:
                    order_list.append(MetaStep(mb_index, MetaStepType.BWD_RECV, stage_index))
                order_list.append(MetaStep(mb_index, MetaStepType.BWD, stage_index))
                if stage_index != 0:
                    order_list.append(MetaStep(mb_index, MetaStepType.BWD_SEND, stage_index))
            self.exec_order[stage_index] = order_list

381 

382 

class Schedule1F1B(PipelineScheduleRuntime):
    """
    The 1F1B schedule.
    It will perform one forward and one backward on the micro batches in steady state.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None):
        super().__init__(stages,
                         micro_batch_num,
                         args_batch_dim=args_batch_dim,
                         kwargs_batch_dim=kwargs_batch_dim,
                         output_concat_dim=output_concat_dim)
        self.construct_exec_order()

    def construct_exec_order(self):
        """Build the 1F1B execution order: warmup forwards, steady 1F1B, cooldown backwards."""
        for stage_index in range(self.real_stage_num):
            order_list = []
            fwd_index = 0
            bwd_index = 0
            # warmup phase
            # Earlier stages run more warmup forwards; capped by the number of micro-batches.
            warmup_micro_batches = min(self.real_stage_num - stage_index, self.micro_batch_num)
            for _ in range(warmup_micro_batches):
                if stage_index != 0:
                    order_list.append(MetaStep(fwd_index, MetaStepType.FWD_RECV, stage_index))
                if stage_index % 2 == 0:
                    # Even stages: compute first, then send (the final warmup
                    # send is deferred to the steady phase).
                    order_list.append(MetaStep(fwd_index, MetaStepType.FWD, stage_index))
                    if fwd_index != warmup_micro_batches - 1:
                        order_list.append(MetaStep(fwd_index, MetaStepType.FWD_SEND, stage_index))
                else:
                    # Odd stages: send the previous micro-batch before computing
                    # the current one — presumably to pair SEND/RECV ordering
                    # with the neighbouring even stages (TODO confirm).
                    if fwd_index > 0:
                        order_list.append(MetaStep(fwd_index - 1, MetaStepType.FWD_SEND, stage_index))
                    order_list.append(MetaStep(fwd_index, MetaStepType.FWD, stage_index))
                fwd_index += 1

            # if warmup phase cannot filled up, then we need to execute fwd send in advance
            if self.real_stage_num - stage_index > self.micro_batch_num:
                order_list.append(MetaStep(fwd_index - 1, MetaStepType.FWD_SEND, stage_index))
                fwd_index += 1
            # steady phase: one backward paired with one forward per iteration.
            steady_micro_batches = self.micro_batch_num - warmup_micro_batches
            for _ in range(steady_micro_batches):
                if stage_index != self.real_stage_num - 1:
                    order_list.append(MetaStep(bwd_index, MetaStepType.BWD_RECV, stage_index))
                # Flush the forward send left over from warmup / previous iteration.
                order_list.append(MetaStep(fwd_index - 1, MetaStepType.FWD_SEND, stage_index))
                order_list.append(MetaStep(bwd_index, MetaStepType.BWD, stage_index))

                if stage_index != 0:
                    order_list.append(MetaStep(bwd_index, MetaStepType.BWD_SEND, stage_index))
                    order_list.append(MetaStep(fwd_index, MetaStepType.FWD_RECV, stage_index))
                order_list.append(MetaStep(fwd_index, MetaStepType.FWD, stage_index))
                fwd_index += 1
                bwd_index += 1

            # cooldown phase: drain the remaining backwards.
            cooldown_micro_batches = warmup_micro_batches
            for _ in range(cooldown_micro_batches):
                if stage_index != self.real_stage_num - 1:
                    order_list.append(MetaStep(bwd_index, MetaStepType.BWD_RECV, stage_index))
                # Flush the last outstanding forward send once, at the start of cooldown.
                if bwd_index == self.micro_batch_num - warmup_micro_batches and fwd_index <= self.micro_batch_num:
                    order_list.append(MetaStep(fwd_index - 1, MetaStepType.FWD_SEND, stage_index))
                order_list.append(MetaStep(bwd_index, MetaStepType.BWD, stage_index))

                if stage_index != 0:
                    order_list.append(MetaStep(bwd_index, MetaStepType.BWD_SEND, stage_index))
                bwd_index += 1
            self.exec_order[stage_index] = order_list

454 

455 

class ScheduleInterleaved1F1B(PipelineScheduleRuntime):
    """
    The Interleaved 1F1B schedule.
    Support multiple stages per rank. It will perform one forward and one backward
    on the micro batches in steady state.
    We support cases where num_microbatch is less than or equal, or greater than the
    stage num, as well as cases where num_microbatch can't be evenly divided by the
    stage num.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None):
        super().__init__(stages,
                         micro_batch_num,
                         args_batch_dim=args_batch_dim,
                         kwargs_batch_dim=kwargs_batch_dim,
                         output_concat_dim=output_concat_dim)
        # Number of scheduling rounds; one round cycles through every physical stage.
        self.n_rounds = max(1, self.micro_batch_num // self.real_stage_num)
        if self.micro_batch_num < self.real_stage_num:
            # Fewer micro-batches than ranks: `base` is negative so each round
            # shrinks to exactly micro_batch_num entries.
            base = self.micro_batch_num - self.real_stage_num
            remainder = 0
        else:
            # Spread the left-over micro-batches across rounds.
            n_extra_microbatch = self.micro_batch_num % self.real_stage_num
            base = n_extra_microbatch // self.n_rounds
            remainder = n_extra_microbatch % self.n_rounds
        # Micro-batches handled in each round (the first `remainder` rounds get one extra).
        self.n_microbatch_per_round = \
            [self.real_stage_num + base + 1 if i < remainder else
             self.real_stage_num + base for i in range(self.n_rounds)]
        # Cumulative op counts per round (scaled by local chunks), 0-prefixed;
        # used by the bisect lookups in forward/backward_stage_index.
        self.n_microbatch_per_round_accu = \
            [x * self.n_local_stages for x in itertools.accumulate(self.n_microbatch_per_round)]
        self.n_microbatch_per_round_accu.insert(0, 0)
        for stage_index in range(self.real_stage_num):
            self.exec_order[stage_index] = self.construct_stage_exec_order(stage_index)
        # Insert the matching SEND/RECV steps between ranks.
        self.exec_order = add_send_recv(self.exec_order, self._stage_num, self.real_stage_num, style='loop')

    def warmup_ops(self, stage_index):
        """Number of warmup forward ops executed by `stage_index` before 1F1B starts."""
        warmup_ops_last_stage = (self.n_local_stages - 1) * self.n_microbatch_per_round[0]
        # Each rank of distance from the last rank adds two extra warmup ops.
        warmup_ops = warmup_ops_last_stage + 2 * (self.real_stage_num - 1 - stage_index)
        return min(warmup_ops, self.micro_batch_num * self.n_local_stages)

    def forward_stage_index(self, op_index, stage_index):
        """obtain forward stage_index based on op_index."""
        # Locate which round `op_index` falls into, then which local chunk inside it.
        accu_index = bisect.bisect_right(self.n_microbatch_per_round_accu, op_index) - 1
        local_index = (op_index - self.n_microbatch_per_round_accu[accu_index]) // \
            self.n_microbatch_per_round[accu_index]
        return (local_index * self.real_stage_num) + stage_index

    def backward_stage_index(self, op_index, stage_index):
        """obtain backward stage_index based on op_index."""
        accu_index = bisect.bisect_right(self.n_microbatch_per_round_accu, op_index) - 1
        local_index = (op_index - self.n_microbatch_per_round_accu[accu_index]) // \
            self.n_microbatch_per_round[accu_index]
        # Backward visits the local chunks in reverse order.
        local_index = self.n_local_stages - 1 - local_index
        return (local_index * self.real_stage_num) + stage_index

    def construct_stage_exec_order(self, stage_index):
        """construct the execution order of specified stage_index.

        Returns a list of MetaStep (FWD/BWD compute steps) interleaved with
        None entries representing pipeline bubbles.
        """
        warmup_ops = self.warmup_ops(stage_index)
        fwd_bwd_ops = self.n_local_stages * self.micro_batch_num - warmup_ops
        cooldown_ops = warmup_ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # Pre-padding bubbles, stage starts with no-ops based on the warmup.
        order_list = [None for _ in range(stage_index)]
        # Per-global-stage counters of the next forward/backward micro index.
        fwd_stage_micro_index = defaultdict(int)
        bwd_stage_micro_index = defaultdict(int)
        # WarmUp Phase
        for op_idx in range(warmup_ops):
            fwd_stage_idx = self.forward_stage_index(op_idx, stage_index)
            fwd_micro_idx = fwd_stage_micro_index[fwd_stage_idx]
            order_list.append(MetaStep(fwd_micro_idx, MetaStepType.FWD, fwd_stage_idx))
            # If micro is less than stage num, there will be additional bubbles during warmup phase.
            if self.micro_batch_num < self.real_stage_num and fwd_micro_idx == self.micro_batch_num - 1:
                if op_idx != warmup_ops - 1 or stage_index == self.real_stage_num - 1:
                    order_list.extend([None] * (self.real_stage_num - self.micro_batch_num))
            fwd_stage_micro_index[fwd_stage_idx] += 1
        # If micro is less than 2 * (self.real_stage_num - stage_index - 1),
        # there will be additional bubbles during warmup phase.
        if self.micro_batch_num < 2 * (self.real_stage_num - stage_index - 1):
            order_list.extend([None] * (2 * (self.real_stage_num - stage_index - 1) - self.micro_batch_num))
        # Bubble from the end of warmup to the start of backward.
        order_list.extend([None] * (self.real_stage_num - 1 - stage_index))

        # 1f1b phase
        for op_idx in range(warmup_ops, warmup_ops+fwd_bwd_ops):
            fwd_stage_idx = self.forward_stage_index(op_idx, stage_index)
            fwd_micro_idx = fwd_stage_micro_index[fwd_stage_idx]
            order_list.append(MetaStep(fwd_micro_idx, MetaStepType.FWD, fwd_stage_idx))
            fwd_stage_micro_index[fwd_stage_idx] += 1
            # Backward op index restarts from 0 at the start of the 1f1b phase.
            bwd_stage_idx = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            bwd_micro_idx = bwd_stage_micro_index[bwd_stage_idx]
            order_list.append(MetaStep(bwd_micro_idx, MetaStepType.BWD, bwd_stage_idx))
            # If micro is less than 2 * (self.real_stage_num - stage_index - 1),
            # there will be additional bubbles after 1f1b phase in last stage.
            if self.micro_batch_num < self.real_stage_num and bwd_micro_idx == self.micro_batch_num - 1:
                if stage_index == self.real_stage_num - 1:
                    order_list.extend([None] * (self.real_stage_num - self.micro_batch_num))
            bwd_stage_micro_index[bwd_stage_idx] += 1
        # cooldown phase
        for op_idx in range(warmup_ops+fwd_bwd_ops, total_ops):
            # Bubble where a forward would have been in steady state.
            order_list.append(None)
            bwd_stage_idx = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            bwd_micro_idx = bwd_stage_micro_index[bwd_stage_idx]
            order_list.append(MetaStep(bwd_micro_idx, MetaStepType.BWD, bwd_stage_idx))
            # If micro is less than 2 * (self.real_stage_num - stage_index - 1),
            # there will be additional bubbles during cooldown phase.
            if self.micro_batch_num < self.real_stage_num and bwd_micro_idx == self.micro_batch_num - 1:
                order_list.extend([None] * (self.real_stage_num - self.micro_batch_num))
            bwd_stage_micro_index[bwd_stage_idx] += 1
        return order_list

569 

570 

def detect_cycle_in_graph(ranks_map):
    """
    Detects a cycle in the directed graph constructed from ranks_map.

    Edges are formed by consecutive pairs within each rank's node list;
    an iterative DFS then looks for a back edge.

    Args:
        ranks_map: A dictionary where keys are rank names and values are lists of nodes.

    Returns:
        tuple: (cycle_path, cycle_ranks) where cycle_path lists the nodes forming
        the cycle (first node repeated at the end) and cycle_ranks lists the
        rank transitions along it, or (None, None) when the graph is acyclic.
    """
    adjacency = defaultdict(list)
    edge_owner = {}

    # Consecutive nodes within a rank's list form a directed edge owned by that rank.
    for rank, nodes in ranks_map.items():
        for u, v in zip(nodes, nodes[1:]):
            adjacency[u].append(v)
            edge_owner[(u, v)] = rank

    seen = set()
    dfs_path = []
    pos_in_path = {}

    work = []
    for start in list(adjacency.keys()):
        if start in seen:
            continue
        work.append((start, False))
        while work:
            node, leaving = work.pop()

            if leaving:
                # Post-visit marker: node's subtree is done, unwind the path.
                dfs_path.pop()
                del pos_in_path[node]
                continue

            if node in pos_in_path:
                # Back edge onto the current DFS path: reconstruct the cycle.
                cycle_start = pos_in_path[node]
                cycle_nodes = dfs_path[cycle_start:] + [node]
                cycle_edges = []
                for i in range(cycle_start, len(dfs_path)):
                    u = dfs_path[i]
                    v = dfs_path[i + 1] if i + 1 < len(dfs_path) else node
                    cycle_edges.append(f"{edge_owner[(u, v)]} {u} -> {v}")
                return cycle_nodes, cycle_edges

            if node in seen:
                # Cross edge to an already-finished node: safe to skip.
                continue

            seen.add(node)
            pos_in_path[node] = len(dfs_path)
            dfs_path.append(node)

            # Schedule the post-visit marker, then the children (reversed so
            # they are popped in list order).
            work.append((node, True))
            for nxt in reversed(adjacency[node]):
                work.append((nxt, False))

    return None, None

630 

631 

def output_cycle_results(cycle_path, cycle_ranks):
    """
    Helper function to output cycle detection results.

    Logs the cycle (if any) and the ranks involved; logs a success warning
    otherwise.

    Args:
        cycle_path (list): List of nodes forming a cycle, if any.
        cycle_ranks (list): List of ranks involved in the cycle.

    Returns:
        None: Outputs results to the console.
    """
    if not cycle_path:
        logger.warning("Cycle Check succeeded. There is no cycle in the graph.")
        return

    logger.error("Cycle detected:")
    joined = " -> ".join(str(node) for node in cycle_path)
    # Repeat the first node to show the cycle closing on itself.
    logger.error("%s -> %s", joined, cycle_path[0])
    logger.error("Involving ranks:")
    for entry in cycle_ranks:
        logger.error(entry)

652 

653 

def parse_and_validate(data: dict, all_rank: bool = True):
    """
    Parse and validate execution orders in a directed graph structure.

    Checks that every rank referenced inside an operation string exists as a
    key of `data` (when `all_rank` is True) and that the operation string is
    mirrored in each referenced rank's own list.

    Args:
        data (dict): A dictionary where keys are string identifiers and values are lists of strings.
            Each value represents a dependency or reference to other keys.
        all_rank (bool): If True, checks that all elements referenced in the data are present as keys
            in the dictionary. If False, only checks intersections.

    Returns:
        None: Log error messages to the console if validation fails, otherwise completes silently.

    Raises:
        ValueError: Raised indirectly if element parsing encounters malformed input strings.
        TypeError: Raised indirectly if data contains unexpected types.
    """

    def _extract_refs(text: str, max_groups: int = 2) -> set:
        """Collect the unique numbers inside the first `max_groups` parentheses of `text`."""
        found = re.findall(r'\((\d+)\)', text)
        return {token.strip() for token in found[:max_groups]}

    if not isinstance(data, dict):
        logger.error("Input must be a dictionary with string keys and lists of strings as values.")
        return

    # Keep only well-formed entries (list-of-str values), as membership sets.
    valid_lists = {}
    for key, values in data.items():
        if isinstance(values, list) and all(isinstance(v, str) for v in values):
            valid_lists[key] = set(values)

    for key, values in data.items():
        if not (isinstance(values, list) and all(isinstance(v, str) for v in values)):
            logger.error("Values for key '%s' must be a list of strings.", key)
            continue

        for value in values:
            try:
                refs = _extract_refs(value)
            except (ValueError, TypeError, AttributeError) as e:
                logger.error("Unable to parse elements from value '%s' in key '%s'. Error: %s", value, key, e)
                continue

            # Check for missing keys if all_rank is True
            if all_rank:
                absent = refs - valid_lists.keys()
                if absent:
                    logger.error("The following keys are missing for value '%s': %s", value, absent)
                    continue
                targets = refs
            else:
                targets = refs & valid_lists.keys()

            # The operation string must also appear under every referenced key.
            for ref in targets:
                if value not in valid_lists[ref]:
                    logger.error("Key '%s' is missing the value '%s'.", ref, value)

715 

def generate_operations(order_list: dict[int, list[MetaStep]],
                        chunk_num: int,
                        com_type: str = 'loop') -> dict[str, list[str]]:
    """
    Generate formatted operations dictionary from pipeline execution order.

    Each SEND/RECV step is rendered as a "Send_Receive_(src)->(dst)_microM_Kth"
    string so the same communication event appears identically on both ranks.

    Args:
        order_list (dict): Dictionary where keys are rank IDs and values are MetaStep execution sequences
        chunk_num (int): Number of chunks (virtual pipeline stages)
        com_type (str): Stage-to-rank mapping type ('loop' for cyclic, 'v' for V-shaped)

    Returns:
        Dictionary where keys are rank IDs (as strings) and values are lists of formatted operation strings
    """

    def _rank_of_stage(stage_idx, style, stage_total, phys_num):
        """Map a (possibly virtual) stage index onto a physical rank."""
        if style == 'loop':
            return stage_idx % phys_num
        if style == 'v':
            return stage_idx if stage_idx < phys_num else stage_total - 1 - stage_idx
        raise ValueError("Invalid style")

    phys_num = len(order_list)
    stage_total = phys_num * chunk_num

    # Peer tables: fwd_peer[i] is the stage after i, bwd_peer[i] the stage before i.
    fwd_peer = {i: i + 1 for i in range(stage_total - 1)}
    bwd_peer = {i: i - 1 for i in range(1, stage_total)}

    def _send_target(stage_idx, op_type):
        """Target stage of a SEND issued at `stage_idx` (None at the pipeline edge)."""
        table = fwd_peer if op_type == MetaStepType.FWD_SEND else bwd_peer
        return table.get(stage_idx)

    def _recv_source(stage_idx, op_type):
        """Source stage of a RECV arriving at `stage_idx` (reverse peer lookup)."""
        table = fwd_peer if op_type == MetaStepType.FWD_RECV else bwd_peer
        for src, dst in table.items():
            if dst == stage_idx:
                return src
        return None

    ops_by_rank = defaultdict(list)

    for rank, steps in order_list.items():
        # Counts repeated transfers of the same (src, dst, micro) triple.
        pair_counts = defaultdict(int)

        for step in steps:
            if step.type in (MetaStepType.FWD_SEND, MetaStepType.BWD_SEND):
                peer_stage = _send_target(step.stage_index, step.type)
                if peer_stage is None:
                    continue
                peer_rank = _rank_of_stage(peer_stage, com_type, stage_total, phys_num)
                key = (rank, peer_rank, step.micro_index)
                pair_counts[key] += 1
                ops_by_rank[str(rank)].append(
                    f"Send_Receive_({rank})->({peer_rank})_micro{step.micro_index}_{pair_counts[key]}th")

            elif step.type in (MetaStepType.FWD_RECV, MetaStepType.BWD_RECV):
                peer_stage = _recv_source(step.stage_index, step.type)
                if peer_stage is None:
                    continue
                peer_rank = _rank_of_stage(peer_stage, com_type, stage_total, phys_num)
                key = (peer_rank, rank, step.micro_index)
                pair_counts[key] += 1
                ops_by_rank[str(rank)].append(
                    f"Send_Receive_({peer_rank})->({rank})_micro{step.micro_index}_{pair_counts[key]}th")

    # Convert defaultdict to a plain dict before returning.
    return dict(ops_by_rank)

802 

803 

804def validate_pipeline_execution(order_list: dict[int, list[MetaStep]], 

805 chunk_num: int, 

806 com_type: str = 'loop') -> dict[str, any]: 

807 """ 

808 Comprehensive validation function for pipeline parallel execution order. 

809 

810 This function validates the execution order of pipeline parallelism by: 

811 1. Checking SEND/RECV communication pair matching 

812 2. Detecting duplicate operations 

813 3. Detecting cycles in communication graphs 

814 4. Verifying computation-SEND matching 

815 

816 Args: 

817 order_list: Dictionary where keys are rank IDs and values are MetaStep execution sequences 

818 chunk_num: Number of chunks (virtual pipeline stages) 

819 com_type: Stage-to-rank mapping type ('loop' for cyclic, 'v' for V-shaped) 

820 

821 Returns: 

822 Dictionary containing validation results with the following keys: 

823 - validation: Communication pair validation results 

824 - cycle_detection: Cycle detection results 

825 - computation_send_matching: Computation-SEND matching validation results 

826 - has_errors: Boolean indicating if any errors were found 

827 - error_messages: List of all error messages found 

828 - formatted_operations: Generated formatted operations 

829 """ 

830 

831 # Generate operations 

832 formatted_operations = generate_operations(order_list, chunk_num, com_type) 

833 

834 parse_and_validate(formatted_operations, True) 

835 

836 # Detect cycles 

837 cycle_path, cycle_ranks = detect_cycle_in_graph(formatted_operations) 

838 

839 # Output results 

840 output_cycle_results(cycle_path, cycle_ranks) 

841 

842 result = { 

843 'formatted_operations': formatted_operations, 

844 'cycle_path': cycle_path, 

845 'cycle_ranks': cycle_ranks, 

846 'has_cycle': bool(cycle_path) 

847 } 

848 return result