// SPDX-License-Identifier: Mulan PSL v2
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This software is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *         http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

use std::{
    ffi::c_void,
    mem, process,
    sync::{
        atomic::{AtomicBool, AtomicUsize, Ordering},
        LazyLock,
    },
    thread,
    time::Duration,
};

use anyhow::{bail, Context};
use parking_lot::Mutex;
use tracing::debug;

use xgpu_common::{
    api_name::ApiFuncName,
    ipc::{
        framer::Framer,
        message::{Argument, ArgumentFlag, Request},
        peer::Server,
        transport::Transport,
    },
};

use crate::fault_guard::{bootstrapx, journal, recovery::NCCLAPI_COMMIDX_MAP, virt};

/// Process-wide agent, built lazily on first access (it is reached through
/// `Agent::get_instance`). If `Agent::new` fails, that first access panics with
/// "FATAL: Failed to initialize agent".
static GLOBAL_AGENT: LazyLock<Agent> =
    LazyLock::new(|| Agent::new().expect("FATAL: Failed to initialize agent"));

/// Recovery gate: while `true`, `Agent::invoke_api` does not send new requests and
/// instead polls this flag, sleeping 1 ms between checks. Written by `set_recovery_flag`.
static RECOVERY_FLAG: AtomicBool = AtomicBool::new(false);
/// Flag written by `set_recovery_triggered` and read by `get_recovery_triggered`.
static RECOVERY_TRIGGERED: AtomicBool = AtomicBool::new(false);
/// Number of `Agent::invoke_api` requests registered as in flight; read by
/// `get_active_request_count`.
static ACTIVE_REQUEST_COUNT: AtomicUsize = AtomicUsize::new(0);

/// Sends API requests through an IPC `Server` and applies the responses.
///
/// Every IPC round-trip made by the methods below locks `server` first, so
/// round-trips issued through one agent are serialized.
#[derive(Debug)]
pub struct Agent {
    /// IPC endpoint created in `Agent::new`, addressed by this process's id.
    server: Mutex<Server>,
}

impl Agent {
    /// Returns the process-wide agent, creating it on the first call.
    ///
    /// # Panics
    /// Panics with "FATAL: Failed to initialize agent" if the first-time
    /// construction via [`Agent::new`] fails.
    #[inline]
    pub fn get_instance() -> &'static Self {
        &GLOBAL_AGENT
    }

    /// Creates an agent whose IPC server is addressed by this process's id.
    ///
    /// Both the `Transport` and the `Framer` are constructed with `BUFFER_SIZE`
    /// (512 MiB).
    ///
    /// # Errors
    /// Returns an error if `Server::create` fails for that address.
    pub fn new() -> anyhow::Result<Self> {
        const BUFFER_SIZE: usize = 512 * 1024 * 1024;

        let addr = process::id().to_string();
        let transport = Transport::new(BUFFER_SIZE);
        let framer = Framer::new(BUFFER_SIZE);

        let server = Server::create(&transport, framer, &addr)
            .with_context(|| format!("Failed to create ipc server on '{}'", addr))?;
        debug!("{:#?}", server);

        Ok(Self {
            server: Mutex::new(server),
        })
    }

    /// Sends `req` over IPC and returns the response's return value as `T`.
    ///
    /// The call first waits out any in-progress recovery (see `RECOVERY_FLAG`) and
    /// registers itself in `ACTIVE_REQUEST_COUNT` for its whole duration. It then
    /// serializes `req` and records it with `journal::add_minibatch_request` before
    /// sending. After the response arrives, its request/method ids are checked against
    /// `req`, the return value is downcast to `T`, and `req` is updated from the
    /// response via `Request::update_from`.
    ///
    /// # Errors
    /// Fails if serializing or journaling the request fails, the IPC call fails, the
    /// response's request id or method id differs from the request's, the return value
    /// cannot be downcast to `T`, or updating `req` from the response fails.
    pub fn invoke_api<T: Copy + 'static>(&self, mut req: Request) -> anyhow::Result<T> {
        /// Keeps this call counted in `ACTIVE_REQUEST_COUNT`. Dropping it releases the
        /// count on every exit path: normal return, `?`, `bail!`, or an unwinding panic.
        struct ActiveRequestGuard;

        impl Drop for ActiveRequestGuard {
            fn drop(&mut self) {
                ACTIVE_REQUEST_COUNT.fetch_sub(1, Ordering::SeqCst);
            }
        }

        // Register first, then check the flag. Checking before registering leaves a
        // window: recovery could set the flag, see a zero count, and start while this
        // call still goes ahead. If recovery is already underway, back out (so it can
        // observe a drained count) and wait for the flag to clear before retrying.
        // For the handshake to be airtight, the recovery side's store of RECOVERY_FLAG
        // and its load of ACTIVE_REQUEST_COUNT must also be SeqCst.
        let _active = loop {
            ACTIVE_REQUEST_COUNT.fetch_add(1, Ordering::SeqCst);
            let guard = ActiveRequestGuard;
            if !RECOVERY_FLAG.load(Ordering::SeqCst) {
                break guard;
            }
            drop(guard);
            while RECOVERY_FLAG.load(Ordering::SeqCst) {
                thread::sleep(Duration::from_millis(1));
            }
        };

        debug!(
            "[<-agent] invoke_api, request_id: {}, method_id: {}, arg_num: {}",
            req.request_id(),
            req.method_id(),
            req.args().len()
        );

        let req_bytes = journal::serialize_request(&req)
            .with_context(|| format!("Failed to serialize request {}", req.request_id()))?;
        journal::add_minibatch_request(req_bytes)
            .with_context(|| format!("Failed to record request {}", req.request_id()))?;

        // The server lock is held until this function returns, so the response is
        // consumed entirely under the lock.
        let server = self.server.lock();
        let resp = server
            .invoke(&req)
            .with_context(|| format!("Failed to invoke request {}", req.request_id()))?;
        if req.request_id() != resp.request_id() {
            bail!("Request id mismatch");
        }
        if req.method_id() != resp.method_id() {
            bail!("Method id mismatch");
        }

        let ret_arg = *resp.ret_value();
        let ret_value = ret_arg.downcast::<T>().with_context(|| {
            format!(
                "Failed to downcast return value of request {}",
                req.request_id()
            )
        })?;
        req.update_from(&resp)
            .with_context(|| format!("Failed to update request {} parameters", req.request_id()))?;

        debug!("[->agent] get response ok, update OUT args ok");

        // `_active` is dropped after `server`, so the count is released once the
        // lock has been let go.
        Ok(ret_value)
    }

    /// Re-sends a previously recorded request, translating virtual values first.
    ///
    /// Before sending, every argument flagged `ARG_IN | ARG_VIRT` is rewritten:
    /// - for `ApiFuncName::Cudalaunchkernel`, arguments at index >= 5 are treated as a
    ///   packed byte buffer. Each complete pointer-sized word that
    ///   `virt::is_vhandle_valid` accepts is replaced in place by its
    ///   `virt::handle_map` value;
    /// - arguments listed in `NCCLAPI_COMMIDX_MAP` for this method are replaced by
    ///   `bootstrapx::get_comm()`;
    /// - every other such argument is mapped with `virt::handle_map`. For
    ///   `ApiFuncName::Cudapointergetattributes`, a mapping failure leaves the argument
    ///   unchanged instead of failing.
    ///
    /// After the response is applied to `req`, each `ARG_OUT | ARG_VIRT` argument's
    /// value is registered with `virt::handle_insert` under the virtual handle that
    /// `virt::req_id_vhandle_map` returns for this request id. For
    /// `ApiFuncName::Cudamalloc` / `ApiFuncName::Cudahostalloc`, the size recorded is
    /// argument 1; otherwise it is 0.
    /// NOTE(review): every such OUT argument uses the same per-request virtual handle,
    /// which assumes at most one of them per request; confirm.
    ///
    /// # Errors
    /// Fails if an argument cannot be downcast, a non-kernel-parameter handle cannot be
    /// mapped (outside the `Cudapointergetattributes` exception), an argument update
    /// fails, or the IPC call or the update from its response fails.
    ///
    /// # Panics
    /// Panics if a kernel-parameter word accepted by `virt::is_vhandle_valid` fails to
    /// map, if `virt::req_id_vhandle_map` or `virt::handle_insert` fails, if an OUT
    /// argument cannot be downcast to `usize`, or if the allocation size argument is
    /// missing or not a `usize`.
    pub fn replay_api(&self, req: &mut Request) -> anyhow::Result<()> {
        // Width of one pointer-sized word inside a packed kernel-parameter buffer.
        const WORD: usize = mem::size_of::<usize>();

        let req_id = req.request_id();
        let method_id = req.method_id();

        // Pass 1: translate virtual IN arguments before sending.
        let args = req.args_mut();
        for (i, arg) in args.iter_mut().enumerate() {
            if arg
                .flag()
                .contains(ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT)
            {
                if method_id == ApiFuncName::Cudalaunchkernel as u64 && i >= 5 {
                    // SAFETY: NOTE(review) — the invariant `downcast_mut_slice` requires
                    // is not stated here; document it against its definition.
                    let param_data = unsafe {
                        arg.downcast_mut_slice::<u8>()
                            .context("Failed to downcast kernel params argument")
                    }?;
                    // Visit every complete word; a trailing partial word is left as is.
                    for word in param_data.chunks_exact_mut(WORD) {
                        let mut raw = [0u8; WORD];
                        raw.copy_from_slice(word);
                        let key = usize::from_ne_bytes(raw);
                        let value = if virt::is_vhandle_valid(key as *mut c_void) {
                            virt::handle_map(key as *mut c_void).unwrap_or_else(|_| {
                                panic!(
                                    "Failed to map virtual handle: req_id {}, arg_idx: {}",
                                    req_id, i
                                )
                            }) as usize
                        } else {
                            key
                        };
                        word.copy_from_slice(&value.to_ne_bytes());
                    }
                    let new_arg = Argument::from_slice(
                        param_data,
                        ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                    );
                    arg.update_from(&new_arg)
                        .context("Failed to update kernel params argument from virtual handles")?;
                } else if NCCLAPI_COMMIDX_MAP
                    .iter()
                    .any(|(mid, comm_idx)| *mid == method_id && *comm_idx == i)
                {
                    let new_comm = bootstrapx::get_comm();
                    let mut new_comm_usize = new_comm as usize;
                    let new_arg = Argument::from_mut(
                        &mut new_comm_usize,
                        ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                    );
                    arg.update_from(&new_arg)
                        .context("Failed to update ncclcomm argument from new communicator")?;
                } else {
                    let vh = arg
                        .downcast::<usize>()
                        .context("Failed to downcast argument to usize")?;
                    let ph = match virt::handle_map(vh as *mut c_void) {
                        Ok(p) => p,
                        Err(e) => {
                            // Mapping failures are tolerated for this one API; the
                            // argument is then sent unchanged.
                            if method_id == ApiFuncName::Cudapointergetattributes as u64 {
                                continue;
                            }
                            eprintln!("Error mapping virtual handle: {}", e);
                            bail!(
                                "Failed to map virtual handle: req_id: {}, method_id: {}, arg_idx: {}",
                                req_id,
                                method_id,
                                i
                            )
                        }
                    };
                    let mut ph_usize = ph as usize;
                    let new_arg = Argument::from_mut(
                        &mut ph_usize,
                        ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                    );
                    arg.update_from(&new_arg)
                        .context("Failed to update argument from virtual handle")?;
                }
            }
        }

        let server = self.server.lock();
        let resp = server.invoke(req).context("server_invoke failed")?;
        req.update_from(&resp).context("update should succeed")?;

        // Pass 2: register the physical values returned in virtual OUT arguments.
        let args = req.args();
        for (i, arg) in args.iter().enumerate() {
            if arg
                .flag()
                .contains(ArgumentFlag::ARG_OUT | ArgumentFlag::ARG_VIRT)
            {
                let vh = virt::req_id_vhandle_map(req.request_id()).unwrap_or_else(|_| {
                    panic!(
                        "Failed to map request_id to virtual handle: req_id {}",
                        req_id
                    )
                });
                // SAFETY: NOTE(review) — relies on the argument storage holding one
                // pointer-sized value; confirm against `downcast_mut`'s contract.
                let ph = unsafe { arg.downcast_mut::<usize>().unwrap() } as *mut usize
                    as *mut *mut c_void;
                let size = if req.method_id() == ApiFuncName::Cudamalloc as u64
                    || req.method_id() == ApiFuncName::Cudahostalloc as u64
                {
                    *req.args()[1].downcast_ref::<usize>().unwrap()
                } else {
                    0
                };
                virt::handle_insert(vh, unsafe { *ph }, size).unwrap_or_else(|_| {
                    panic!(
                        "Failed to insert virtual handle: req_id: {}, method_id: {}, arg_idx: {}",
                        req_id, method_id, i
                    )
                });
            }
        }

        Ok(())
    }

    /// Sends `req` to the server unmodified and applies the response to it.
    ///
    /// Unlike [`Agent::invoke_api`], this does not journal the request, does not wait
    /// on the recovery flag, and does not count it as an active request.
    ///
    /// # Errors
    /// Returns an error if the IPC call or `Request::update_from` fails.
    pub fn send_request(&self, req: &mut Request) -> anyhow::Result<()> {
        let server = self.server.lock();
        let resp = server.invoke(req).context("server_invoke failed")?;
        req.update_from(&resp).context("update should succeed")?;

        Ok(())
    }
}

/// Sets the recovery gate checked by `Agent::invoke_api`. While it is `true`, new
/// requests wait instead of being sent.
///
/// Stored with `SeqCst` rather than `Release`. A recovery routine that stores this flag
/// and then reads [`get_active_request_count`] performs a store followed by a load of a
/// different atomic. A `Release` store followed by an `Acquire` load may be reordered
/// with respect to each other. When this store, that load, and the request path's
/// matching accesses are all `SeqCst`, they share one total order, and the flag/count
/// handshake cannot be missed by both sides.
#[inline]
pub fn set_recovery_flag(value: bool) {
    RECOVERY_FLAG.store(value, Ordering::SeqCst);
}

/// Sets the `RECOVERY_TRIGGERED` flag (a `Release` store that pairs with the `Acquire`
/// load in [`get_recovery_triggered`]).
#[inline]
pub fn set_recovery_triggered(value: bool) {
    RECOVERY_TRIGGERED.store(value, Ordering::Release);
}

/// Returns the current value of the `RECOVERY_TRIGGERED` flag.
#[inline]
pub fn get_recovery_triggered() -> bool {
    RECOVERY_TRIGGERED.load(Ordering::Acquire)
}

/// Returns how many `Agent::invoke_api` requests are currently registered in
/// `ACTIVE_REQUEST_COUNT`.
///
/// Loaded with `SeqCst` (not `Acquire`) so it takes part in the same single total order
/// as the `SeqCst` store in [`set_recovery_flag`]; see that function for why.
#[inline]
pub fn get_active_request_count() -> usize {
    ACTIVE_REQUEST_COUNT.load(Ordering::SeqCst)
}