// SPDX-License-Identifier: Mulan PSL v2
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This software is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *         http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

use std::{
    env,
    sync::atomic::{AtomicUsize, Ordering},
    thread::{self, sleep},
    time::{Duration, Instant},
};

use anyhow::Context;
use tracing::debug;

use crate::{
    agent,
    fault_guard::{bootstrapx, journal, recovery},
};

/// Default minibatch index at which the fault-recovery simulation fires.
///
/// Used when the `MINIBATCH_HOOK` environment variable is unset or cannot be
/// parsed as a `usize`.
const DEFAULT_HOOK_POINT: usize = 50;

/// Number of minibatches that have begun so far (the next minibatch's 0-based
/// index); incremented once per call to `on_minibatch_begin`.
static MINIBATCH_COUNT: AtomicUsize = AtomicUsize::new(0);

/// Per-minibatch hook: journals the previous minibatch's requests and, at a
/// configured minibatch, runs the simulated fault-recovery sequence.
///
/// Steps:
/// 1. Sets the agent's recovery flag and polls every 100 ms until
///    `agent::get_active_request_count()` reports zero.
/// 2. Atomically claims the current 0-based minibatch index from `MINIBATCH_COUNT`.
/// 3. Swaps out the journal's minibatch request queue and dumps the old entries
///    on a background thread.
/// 4. If `agent::get_recovery_triggered()` is false and the index equals
///    [`hook_point`], runs [`simulate_fault_recovery`].
/// 5. Clears the recovery flag.
///
/// # Errors
///
/// Returns an error if resetting the journal queue fails or if any step of the
/// recovery simulation fails. On an error the recovery flag is left set, because
/// the early return skips step 5.
fn on_minibatch_begin() -> anyhow::Result<()> {
    agent::set_recovery_flag(true);

    // Wait until no requests are in flight before touching the journal.
    while agent::get_active_request_count() > 0 {
        thread::sleep(Duration::from_millis(100));
    }

    // `fetch_add` returns the previous value, so claiming and incrementing the
    // index is one atomic step; a separate `load` followed by `fetch_add` could
    // hand the same index to two concurrent callers.
    let minibatch_count = MINIBATCH_COUNT.fetch_add(1, Ordering::Relaxed);
    debug!("[Hook] Minibatch {} begins", minibatch_count);

    let old_queue = journal::reset_minibatch_requests()?;
    let dump_thread = thread::spawn(move || {
        while let Some(req_bytes) = old_queue.pop() {
            journal::dump_request(req_bytes).expect("Failed to log minibatch request");
        }
    });

    if !agent::get_recovery_triggered() && minibatch_count == hook_point() {
        simulate_fault_recovery(dump_thread)?;
    }
    // When the simulation does not run, `dump_thread` is dropped at the end of
    // this function, which detaches the thread (no blocking).

    agent::set_recovery_flag(false);

    Ok(())
}

/// Minibatch index at which the fault-recovery simulation fires.
///
/// Read from the `MINIBATCH_HOOK` environment variable on each call; an unset or
/// unparsable value falls back to `DEFAULT_HOOK_POINT`.
fn hook_point() -> usize {
    env::var("MINIBATCH_HOOK")
        .ok()
        .and_then(|val| val.parse::<usize>().ok())
        .unwrap_or(DEFAULT_HOOK_POINT)
}

/// Runs the simulated fault-recovery sequence across ranks, then marks recovery
/// as triggered via `agent::set_recovery_triggered(true)`.
///
/// `dump_thread` is the background thread writing the previous minibatch's
/// requests to the journal. It is joined before `journal::finalize()` so that
/// finalization cannot race with those in-flight writes.
///
/// # Errors
///
/// Returns an error if the dump thread panicked, or if recovering the NCCL
/// communicators, initializing the request player, or recovering the failed
/// rank fails.
fn simulate_fault_recovery(dump_thread: thread::JoinHandle<()>) -> anyhow::Result<()> {
    let rank = bootstrapx::get_global_rank();
    // Hard-coded rank passed to `recovery_nccl`; presumably the rank treated as
    // failed in this simulation — TODO confirm.
    let excluded_rank = 2;

    if rank == 0 {
        println!("Start fault recovery simulation...");
    }
    sleep(Duration::from_secs(2));
    let start = Instant::now();

    // Make sure every request from the previous minibatch has been written
    // before the journal is finalized.
    dump_thread
        .join()
        .map_err(|_| anyhow::anyhow!("Minibatch request dump thread panicked"))?;
    journal::finalize();

    recovery::barrier(rank, "nccl_ready");
    recovery::recovery_nccl(excluded_rank).context("Failed to recover nccl communicators")?;
    // Re-query the global rank after the communicators are rebuilt; presumably
    // it can change once a rank is excluded.
    let rank = bootstrapx::get_global_rank();

    recovery::barrier(rank, "replay_ready");
    // Only the highest-numbered rank initializes the request player and
    // recovers the failed rank.
    if rank == bootstrapx::get_rank_count() as i32 - 1 {
        journal::request_player_init().context("Failed to initialize request player")?;
        recovery::recovery_failed().context("Failed to recover failed rank")?;
    }

    recovery::barrier(rank, "done");

    let duration = start.elapsed();
    if rank == 0 {
        println!("Fault recovery simulation completed in {:?}", duration);
    }

    agent::set_recovery_triggered(true);
    Ok(())
}

/// C-ABI entry point run at the start of each minibatch.
///
/// Exported under an unmangled symbol name so foreign code can call it. Any
/// error returned by `on_minibatch_begin` is turned into a panic. Because this
/// is an `extern "C"` function, the panic cannot unwind across the boundary and
/// the process aborts.
#[unsafe(no_mangle)]
pub extern "C" fn minibatch_begin_hook() {
    if let Err(err) = on_minibatch_begin() {
        panic!("Failed to execute on_minibatch_begin: {err:?}");
    }
}