// SPDX-License-Identifier: Mulan PSL v2
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This software is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *         http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

use std::cmp::Ordering;
use std::ffi::{c_int, c_void};
use std::fs;
use std::thread::sleep;
use std::time::Duration;

use cudax::nccl;
use smallvec::smallvec;
use tracing::debug;
use xgpu_common::ipc::message::{ArgumentFlag, Request};
use xgpu_common::{api_name::ApiFuncName, ipc::message::Argument};

use crate::agent::Agent;
use crate::fault_guard::bootstrapx::get_rank_count;
use crate::fault_guard::{bootstrapx, journal, virt};

/// APIs that are *not* replayed during stage 1 of [`recovery_failed`].
///
/// A journaled request whose API appears in this list is still logged in
/// stage 1, but it is not handed to `Agent::replay_api`.
const RECOVERY_FAILED_SKIP_APIS: &[ApiFuncName] = &[
    ApiFuncName::Cudasetdevice,
    ApiFuncName::Ncclcomminitrankconfig,
    ApiFuncName::Ncclbcast,
    ApiFuncName::Ncclallreduce,
    ApiFuncName::Ncclallgather,
    ApiFuncName::Cudaeventquery,
    ApiFuncName::Cudeviceprimaryctxgetstate,
    ApiFuncName::Loaddynlibrary,
];

/// Table of `(api id, index)` pairs for NCCL APIs, keyed by the
/// `ApiFuncName` discriminant cast to `u64`.
///
/// NOTE(review): the indices match the zero-based position of the
/// `ncclComm_t` parameter in the public NCCL signatures (e.g. `comm` is the
/// sixth parameter of `ncclAllReduce`, hence `5`), so this is presumably used
/// to locate the communicator argument inside a request's argument list —
/// confirm against the consumers of this table.
pub const NCCLAPI_COMMIDX_MAP: &[(u64, usize)] = &[
    (ApiFuncName::Ncclcommdestroy as u64, 0),
    (ApiFuncName::Ncclcommabort as u64, 0),
    (ApiFuncName::Ncclcommsplit as u64, 0),
    (ApiFuncName::Ncclcommgetasyncerror as u64, 0),
    (ApiFuncName::Ncclbcast as u64, 4),
    (ApiFuncName::Ncclallreduce as u64, 5),
    (ApiFuncName::Ncclallgather as u64, 4),
    (ApiFuncName::Ncclsend as u64, 4),
    (ApiFuncName::Ncclrecv as u64, 4),
    (ApiFuncName::Ncclreduce as u64, 6),
    (ApiFuncName::Ncclcommfinalize as u64, 0),
];

/// File-based rendezvous between ranks.
///
/// Writes the marker file `/tmp/rank_<rank>_<comment>` for the calling rank,
/// then polls every 100 ms until a marker with the same `comment` exists for
/// every other rank in `0..get_rank_count()`.
///
/// The result of writing the caller's own marker is discarded, and this
/// function never removes marker files. There is no timeout: the call blocks
/// until all peer markers are present.
pub fn barrier(rank: i32, comment: &str) {
    // Single place that knows the marker naming scheme.
    let marker_path = |r: i32| format!("/tmp/rank_{}_{}", r, comment);

    // Announce that this rank has reached the barrier.
    fs::write(marker_path(rank), b"ready").ok();

    loop {
        // `all` stops at the first peer whose marker is still missing.
        let peers_arrived = (0..get_rank_count() as i32)
            .filter(|&peer| peer != rank)
            .all(|peer| fs::metadata(marker_path(peer)).is_ok());

        if peers_arrived {
            return;
        }
        sleep(Duration::from_millis(100));
    }
}

/// Rebuilds the NCCL communicator after `excluded_rank` has failed.
///
/// The sequence, with a file-based [`barrier`] between phases, is:
///
/// 1. **shrink** – the excluded rank sends a `Cudadevicereset` request; every
///    other rank sends `Ncclcommshrink` on the current communicator, excluding
///    `excluded_rank` and receiving a new communicator.
/// 2. **renumber** – ranks above `excluded_rank` move down by one and the
///    excluded rank takes the last id (`get_rank_count() - 1`); the global
///    rank is updated through `bootstrapx::set_global_rank`.
/// 3. **add** – the rank that now holds the last id sends
///    `Ncclcomminitnewrank`; all other ranks send `Ncclcommaddnewrank`.
/// 4. **setup** – every rank sends `Ncclcommsetupnewrank`.
///
/// Finally the new communicator is registered with `virt::handle_insert` and
/// stored via `bootstrapx::set_comm`.
///
/// # Errors
///
/// Returns the first error produced by `virt::handle_map`,
/// `Agent::send_request` or `virt::handle_insert`.
pub fn recovery_nccl(excluded_rank: i32) -> anyhow::Result<()> {
    let mut rank = bootstrapx::get_global_rank();
    debug!("[Recovery] Rank {} starts nccl recovery...", rank);

    debug!(
        "[Recovery] Rank {} exclude failed rank {}",
        rank, excluded_rank
    );
    // From here on `old_comm` holds the result of `virt::handle_map`, not the
    // raw value returned by `bootstrapx::get_comm`.
    let mut old_comm = bootstrapx::get_comm();
    old_comm = virt::handle_map(old_comm as *mut c_void)? as nccl::ncclComm_t;
    // Stays at its default value on the excluded rank until the
    // `Ncclcomminitnewrank` request below writes it.
    let mut new_comm = nccl::ncclComm_t::default();
    if rank == excluded_rank {
        // The failed rank does not take part in the shrink; it only issues a
        // device reset request.
        let mut device_reset_req =
            Request::with_args(ApiFuncName::Cudadevicereset as u64, smallvec![]);
        Agent::get_instance().send_request(&mut device_reset_req)?;
    } else {
        let excluded_ranks = &mut [excluded_rank];
        let exclude_ranks_count = excluded_ranks.len() as i32;
        let shrink_flags = nccl::NCCL_SHRINK_DEFAULT as i32;
        {
            // NOTE(review): `old_comm` has already gone through
            // `virt::handle_map` above yet is still flagged `ARG_VIRT` here —
            // confirm the handle is not translated twice.
            let mut shrink_req = Request::with_args(
                ApiFuncName::Ncclcommshrink as u64,
                smallvec![
                    Argument::from_ref(&old_comm, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                    Argument::from_mut_slice(excluded_ranks, ArgumentFlag::ARG_IN),
                    Argument::from_ref(&exclude_ranks_count, ArgumentFlag::ARG_IN),
                    Argument::from_mut(
                        &mut new_comm,
                        ArgumentFlag::ARG_OUT | ArgumentFlag::ARG_VIRT,
                    ),
                    // Presumably the optional config parameter of
                    // `ncclCommShrink`, left unset — TODO confirm.
                    Argument::empty(),
                    Argument::from_ref(&shrink_flags, ArgumentFlag::ARG_IN),
                ],
            );
            Agent::get_instance().send_request(&mut shrink_req)?;
        }
    }
    // This barrier still uses the pre-renumbering rank id; the later barriers
    // use the renumbered one.
    barrier(rank, "shrink");

    // The replacement rank always takes the last id.
    let new_rank = get_rank_count() as i32 - 1;
    match rank.cmp(&excluded_rank) {
        Ordering::Greater => {
            // Ranks above the excluded one shift down to close the gap.
            debug!("[Recovery] Rank {} updating rank id to {}", rank, rank - 1);
            rank -= 1;
            bootstrapx::set_global_rank(rank);
        }
        Ordering::Equal => {
            debug!(
                "[Recovery] Rank {} will be replaced by a new id {}",
                rank, new_rank
            );
            rank = new_rank;
            bootstrapx::set_global_rank(rank);
        }
        // Ranks below the excluded one keep their id.
        Ordering::Less => {}
    }

    debug!("[Recovery] Rank {} adding new rank {}", rank, new_rank);
    if rank == new_rank {
        // The joining rank creates its communicator; `new_comm` is an output.
        let rank_count = bootstrapx::get_rank_count() as c_int;
        let mut init_new_req = Request::with_args(
            ApiFuncName::Ncclcomminitnewrank as u64,
            smallvec![
                Argument::from_mut(
                    &mut new_comm,
                    ArgumentFlag::ARG_OUT | ArgumentFlag::ARG_VIRT,
                ),
                Argument::from_ref(&rank_count, ArgumentFlag::ARG_IN),
            ],
        );
        Agent::get_instance().send_request(&mut init_new_req)?;
    } else {
        // Surviving ranks pass the communicator obtained from the shrink.
        let mut add_new_req = Request::with_args(
            ApiFuncName::Ncclcommaddnewrank as u64,
            smallvec![Argument::from_ref(
                &new_comm,
                ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
            )],
        );
        Agent::get_instance().send_request(&mut add_new_req)?;
    }
    barrier(rank, "add");

    // Executed by every rank, including the joining one.
    let mut setup_new_req = Request::with_args(
        ApiFuncName::Ncclcommsetupnewrank as u64,
        smallvec![Argument::from_ref(
            &new_comm,
            ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
        )],
    );
    Agent::get_instance().send_request(&mut setup_new_req)?;
    barrier(rank, "setup");

    debug!(
        "[Recovery] Rank {}, old comm 0x{:x}, new comm 0x{:x}",
        rank, old_comm, new_comm
    );
    // NOTE(review): the key passed here is the mapped `old_comm` (the result
    // of `virt::handle_map`), not the value originally returned by
    // `bootstrapx::get_comm` — confirm that is the intended key.
    virt::handle_insert(old_comm as *mut c_void, new_comm as *mut c_void, 0)?;
    bootstrapx::set_comm(new_comm);

    debug!(
        "[Recovery] Rank {} nccl recovery completed successfully",
        rank
    );

    Ok(())
}

pub fn recovery_failed() -> anyhow::Result<()> {
    let rank = bootstrapx::get_global_rank();
    debug!("[Recovery] Rank {} starts replaying (failed rank)", rank);

    while let Some(mut req_bytes) = journal::load_next_request()? {
        let mut req: Request = journal::deserialize_request(&mut req_bytes)?;
        let req_id = req.request_id();
        let method_id = req.method_id();
        let api = ApiFuncName::try_from(method_id as i32)?;
        debug!(
            "[Recovery] Rank {} Stage 1, replaying api: request_id={}, api={:?}",
            rank, req_id, api
        );

        if !RECOVERY_FAILED_SKIP_APIS.contains(&api) {
            Agent::get_instance().replay_api(&mut req)?;
        }
    }

    while let Some(mut req_bytes) = journal::fetch_minibatch_request()? {
        let mut req: Request = journal::deserialize_request(&mut req_bytes)?;
        let req_id = req.request_id();
        let method_id = req.method_id();
        let api = ApiFuncName::try_from(method_id as i32)?;
        debug!(
            "[Recovery] Rank {} Stage 2, replaying api: request_id={}, api={:?}",
            rank, req_id, api
        );

        Agent::get_instance().replay_api(&mut req)?;
    }

    Ok(())
}