* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This software is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
use std::{
env,
sync::atomic::{AtomicUsize, Ordering},
thread::{self, sleep},
time::{Duration, Instant},
};
use anyhow::Context;
use tracing::debug;
use crate::{
agent,
fault_guard::{bootstrapx, journal, recovery},
};
/// Minibatch index at which the fault-recovery simulation fires when the
/// `MINIBATCH_HOOK` environment variable is unset or does not parse as `usize`.
const DEFAULT_HOOK_POINT: usize = 50;
/// Process-wide minibatch counter; starts at 0 and is incremented by one on
/// every call to `on_minibatch_begin`.
static MINIBATCH_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Runs the per-minibatch hook: quiesces the agent, rotates the request
/// journal and, on one configured minibatch, drives a simulated fault recovery.
///
/// Steps, in order:
/// 1. Raise the agent's recovery flag and poll every 100 ms until
///    `agent::get_active_request_count()` reports zero.
/// 2. Atomically take this call's minibatch index and bump the counter.
/// 3. Swap out the journal's minibatch request queue and dump the old entries
///    on a detached background thread.
/// 4. If recovery has not been triggered yet and the index equals the hook
///    point (`MINIBATCH_HOOK`, falling back to `DEFAULT_HOOK_POINT`), run the
///    recovery simulation and mark recovery as triggered.
/// 5. Lower the recovery flag.
///
/// # Errors
/// Returns an error if resetting the journal queue fails or if any step of the
/// recovery simulation fails. NOTE: on those early returns the recovery flag
/// raised in step 1 is left set.
fn on_minibatch_begin() -> anyhow::Result<()> {
    agent::set_recovery_flag(true);
    while agent::get_active_request_count() > 0 {
        thread::sleep(Duration::from_millis(100));
    }

    // `fetch_add` returns the value held before the increment, so reading the
    // index and bumping the counter happen as one atomic step. A separate
    // `load` followed by `fetch_add` would let two overlapping calls observe
    // the same index.
    let minibatch_count = MINIBATCH_COUNT.fetch_add(1, Ordering::Relaxed);
    debug!("[Hook] Minibatch {} begins", minibatch_count);

    let old_queue = journal::reset_minibatch_requests()?;
    // The thread is detached: a failed dump panics only that thread and the
    // entries still left in the old queue are not dumped.
    thread::spawn(move || {
        while let Some(req_bytes) = old_queue.pop() {
            journal::dump_request(req_bytes).expect("Failed to log minibatch request");
        }
    });

    let hook_point = configured_hook_point();
    if !agent::get_recovery_triggered() && minibatch_count == hook_point {
        simulate_fault_recovery()?;
        agent::set_recovery_triggered(true);
    }

    agent::set_recovery_flag(false);
    Ok(())
}

/// Reads the hook point from the `MINIBATCH_HOOK` environment variable,
/// falling back to `DEFAULT_HOOK_POINT` when the variable cannot be read or
/// its value does not parse as a `usize`.
fn configured_hook_point() -> usize {
    env::var("MINIBATCH_HOOK")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or(DEFAULT_HOOK_POINT)
}

/// Drives one simulated fault recovery: finalizes the journal, recovers the
/// NCCL communicators, lets the last rank replay requests and recover the
/// failed rank, and synchronizes all ranks with barriers between the phases.
///
/// Rank 0 prints a start message and the elapsed time of the recovery.
///
/// # Errors
/// Returns an error if NCCL recovery, request-player initialization, or
/// failed-rank recovery fails.
fn simulate_fault_recovery() -> anyhow::Result<()> {
    let rank_before = bootstrapx::get_global_rank();
    // Hard-coded rank handed to `recovery_nccl`.
    // NOTE(review): presumably the rank whose failure is simulated — confirm.
    let excluded_rank = 2;
    if rank_before == 0 {
        println!("Start fault recovery simulation...");
    }
    sleep(Duration::from_secs(2));

    // Timing starts after the fixed two-second pause, so it covers only the
    // recovery work itself.
    let start = Instant::now();
    journal::finalize();
    recovery::barrier(rank_before, "nccl_ready");
    recovery::recovery_nccl(excluded_rank).context("Failed to recover nccl communicators")?;

    // The rank id is read again after NCCL recovery.
    // NOTE(review): presumably it can change once the communicators are
    // rebuilt — confirm.
    let rank = bootstrapx::get_global_rank();
    recovery::barrier(rank, "replay_ready");
    // Only the highest-numbered rank replays requests and recovers the failed rank.
    if rank == bootstrapx::get_rank_count() as i32 - 1 {
        journal::request_player_init().context("Failed to initialize request player")?;
        recovery::recovery_failed().context("Failed to recover failed rank")?;
    }
    recovery::barrier(rank, "done");

    let duration = start.elapsed();
    if rank == 0 {
        println!("Fault recovery simulation completed in {:?}", duration);
    }
    Ok(())
}
/// C-ABI entry point, exported under its unmangled name, that runs
/// `on_minibatch_begin` once.
///
/// # Panics
/// Panics with the message "Failed to execute on_minibatch_begin" (followed by
/// the error's debug output) if the hook returns an error.
/// NOTE(review): a panic escaping an `extern "C"` function aborts the process
/// on recent Rust toolchains — confirm this is the intended failure mode.
#[unsafe(no_mangle)]
pub extern "C" fn minibatch_begin_hook() {
    let outcome = on_minibatch_begin();
    outcome.expect("Failed to execute on_minibatch_begin");
}