// SPDX-License-Identifier: Mulan PSL v2
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This software is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *         http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

use std::ffi::{c_int, c_longlong, c_void};

use smallvec::smallvec;
use tracing::{debug, info};

use cudax::cublas::*;
use xgpu_common::{
    api_name::ApiFuncName,
    ipc::message::{Argument, ArgumentFlag, Request},
};

use crate::{agent::Agent, fault_guard::virt, hook::cublas::CublasApi};

/// Returns the number of bytes read from the `alpha`/`beta` host scalars for a given
/// cuBLAS compute type.
///
/// * 16-bit float compute types map to 2 bytes.
/// * 32-bit float and 32-bit integer compute types map to 4 bytes.
/// * 64-bit float compute types map to 8 bytes.
/// * Any value not listed falls back to 8 bytes.
///
/// NOTE(review): the width is derived from `compute_type` alone. The cuBLAS documentation
/// lists complex scale types for some data-type combinations, which would need more bytes
/// than returned here — confirm whether those combinations must be supported.
fn get_compute_type_size(compute_type: u32) -> usize {
    // Kept from the original: silences the lint in case one of the names below is not a
    // constant in scope and is therefore parsed as a binding pattern.
    #[allow(non_upper_case_globals)]
    match compute_type {
        // Half-precision scalars.
        CUBLAS_COMPUTE_16F | CUBLAS_COMPUTE_16F_PEDANTIC => std::mem::size_of::<u16>(),

        // Single-precision scalars, including the reduced-precision "fast" modes.
        CUBLAS_COMPUTE_32F
        | CUBLAS_COMPUTE_32F_PEDANTIC
        | CUBLAS_COMPUTE_32F_FAST_16F
        | CUBLAS_COMPUTE_32F_FAST_16BF
        | CUBLAS_COMPUTE_32F_FAST_TF32 => std::mem::size_of::<f32>(),

        // Double-precision scalars.
        CUBLAS_COMPUTE_64F | CUBLAS_COMPUTE_64F_PEDANTIC => std::mem::size_of::<f64>(),

        // 32-bit integer scalars.
        CUBLAS_COMPUTE_32I | CUBLAS_COMPUTE_32I_PEDANTIC => std::mem::size_of::<i32>(),

        // Unknown compute type: use the widest size handled above.
        _ => std::mem::size_of::<f64>(),
    }
}

/// Stateless unit type that carries the `CublasApi` implementation in this file; each
/// trait method builds a `Request` and sends it through `Agent::get_instance().invoke_api`.
pub struct CublasApiImpl;

impl CublasApi for CublasApiImpl {
    /// Forwards `cublasCreate_v2` through `Agent` and stores the received handle in `*handle`.
    ///
    /// The value currently stored at `*handle` is copied into a local that is sent as the
    /// single out argument. After the call returns, the local is written back to `*handle`
    /// and the handle is registered with `virt::handle_insert` (mapped to itself) and
    /// `virt::req_id_vhandle_insert` (keyed by the request id).
    ///
    /// Returns the `cublasStatus_t` produced by `invoke_api`.
    ///
    /// # Panics
    ///
    /// Panics if `invoke_api` or either `virt` registration fails.
    ///
    /// NOTE(review): the registrations run regardless of the returned status — confirm
    /// whether they should be skipped when the remote call reports a failure.
    fn cublasCreate_v2(&self, handle: *mut cublasHandle_t) -> cublasStatus_t {
        info!("[Hooked] api_name: cublasCreate_v2");
        debug!("before: handle ptr:{:p},  *handle:{:?}", handle, unsafe {
            *handle
        });

        // SAFETY: not checked here — relies on the caller passing a non-null, readable and
        // writable `handle` pointer.
        let mut new_handle = unsafe { *handle };

        let req = Request::with_args(
            ApiFuncName::CublascreateV2 as u64,
            smallvec![Argument::from_mut(
                &mut new_handle,
                ArgumentFlag::ARG_OUT | ArgumentFlag::ARG_VIRT,
            )],
        );
        // The id must be captured before `req` is moved into `invoke_api`.
        let request_id = req.request_id();
        let res = Agent::get_instance()
            .invoke_api::<cublasStatus_t>(req)
            .expect("call invoke_api failed");

        // Publish the received handle to the caller first, so that the "after" log line
        // below reports the new value rather than the stale pre-call contents.
        // SAFETY: same pointer contract as the read above.
        unsafe {
            *handle = new_handle;
        }
        debug!("after: handle ptr:{:p},  *handle:{:?}", handle, unsafe {
            *handle
        });

        // SAFETY: same pointer contract as the read above.
        unsafe {
            virt::handle_insert(*handle as *mut c_void, *handle as *mut c_void, 0)
                .expect("handle_insert failed");
            virt::req_id_vhandle_insert(request_id, *handle as *mut c_void)
                .expect("req_id_vhandle_insert failed");
        }
        res
    }

    /// Forwards `cublasSetWorkspace_v2` through `Agent`.
    ///
    /// `handle` and `workspace` are each passed through `virt::handle_map` before being
    /// sent; `workspace_size_in_bytes` is sent unchanged.
    ///
    /// Returns the `cublasStatus_t` produced by `invoke_api`.
    ///
    /// # Panics
    ///
    /// Panics if a `virt::handle_map` lookup or the `invoke_api` call fails.
    fn cublasSetWorkspace_v2(
        &self,
        handle: cublasHandle_t,
        workspace: *mut c_void,
        workspace_size_in_bytes: usize,
    ) -> cublasStatus_t {
        info!("[Hooked] api_name: cublasSetWorkspace_v2");

        let mapped_handle = virt::handle_map(handle as *mut c_void).expect("handle_map failed");
        let remote_handle = mapped_handle as cublasHandle_t as usize;
        let mapped_workspace = virt::handle_map(workspace).expect("handle_map failed");
        let remote_workspace = mapped_workspace as usize;

        let api_id = ApiFuncName::CublassetworkspaceV2 as u64;
        let req = Request::with_args(
            api_id,
            smallvec![
                Argument::from_ref(
                    &remote_handle,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                ),
                Argument::from_ref(
                    &remote_workspace,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                ),
                Argument::from_ref(&workspace_size_in_bytes, ArgumentFlag::ARG_IN),
            ],
        );
        let status = Agent::get_instance()
            .invoke_api::<cublasStatus_t>(req)
            .expect("call invoke_api failed");
        // Logs the caller-side (unmapped) workspace pointer.
        debug!(
            "workspace:{:p}, size:{}",
            workspace, workspace_size_in_bytes
        );
        status
    }

    /// Forwards `cublasSetStream_v2` through `Agent`.
    ///
    /// Both `handle` and `stream_id` are passed through `virt::handle_map` before being
    /// sent as input arguments.
    ///
    /// Returns the `cublasStatus_t` produced by `invoke_api`.
    ///
    /// # Panics
    ///
    /// Panics if a `virt::handle_map` lookup or the `invoke_api` call fails.
    fn cublasSetStream_v2(
        &self,
        handle: cublasHandle_t,
        stream_id: cudaStream_t,
    ) -> cublasStatus_t {
        info!("[Hooked] api_name: cublasSetStream_v2");

        let mapped_handle = virt::handle_map(handle as *mut c_void).expect("handle_map failed");
        let remote_handle = mapped_handle as cublasHandle_t as usize;
        let mapped_stream = virt::handle_map(stream_id as *mut c_void).expect("handle_map failed");
        let remote_stream = mapped_stream as cudaStream_t as usize;

        let api_id = ApiFuncName::CublassetstreamV2 as u64;
        let req = Request::with_args(
            api_id,
            smallvec![
                Argument::from_ref(
                    &remote_handle,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                ),
                Argument::from_ref(
                    &remote_stream,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                ),
            ],
        );
        let status = Agent::get_instance()
            .invoke_api::<cublasStatus_t>(req)
            .expect("call invoke_api failed");
        // Both log lines report the caller-side values, not the mapped ones.
        debug!(" client-2 handle ptr:{:?}", handle);
        debug!(" client-2 stream_id ptr:{:?}", stream_id);
        status
    }

    /// Forwards `cublasGetStream_v2` through `Agent` and stores the received stream in
    /// `*stream_id`.
    ///
    /// `handle` is passed through `virt::handle_map`. The value currently stored at
    /// `*stream_id` is copied into a local that is sent as the out argument. After the call
    /// returns, the local is written back to `*stream_id` and the stream is registered with
    /// `virt::handle_insert` (mapped to itself) and `virt::req_id_vhandle_insert` (keyed by
    /// the request id).
    ///
    /// Returns the `cublasStatus_t` produced by `invoke_api`.
    ///
    /// # Panics
    ///
    /// Panics if the `virt::handle_map` lookup, the `invoke_api` call, or either `virt`
    /// registration fails.
    ///
    /// NOTE(review): the registrations run regardless of the returned status — confirm
    /// whether they should be skipped when the remote call reports a failure.
    fn cublasGetStream_v2(
        &self,
        handle: cublasHandle_t,
        stream_id: *mut cudaStream_t,
    ) -> cublasStatus_t {
        info!("[Hooked] api_name: cublasGetStream_v2");
        debug!(" client handle ptr:{:?}", handle);
        debug!(
            " client1: stream_id ptr:{:p}, *ptr:{:?} ",
            stream_id,
            unsafe { *stream_id }
        );

        let handle_usize = virt::handle_map(handle as *mut c_void).expect("handle_map failed")
            as cublasHandle_t as usize;
        // SAFETY: not checked here — relies on the caller passing a non-null, readable and
        // writable `stream_id` pointer.
        let mut received_stream = unsafe { *stream_id };

        let req = Request::with_args(
            ApiFuncName::CublasgetstreamV2 as u64,
            smallvec![
                Argument::from_ref(&handle_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_mut(
                    &mut received_stream,
                    ArgumentFlag::ARG_OUT | ArgumentFlag::ARG_VIRT,
                ),
            ],
        );
        // The id must be captured before `req` is moved into `invoke_api`.
        let request_id = req.request_id();
        let res = Agent::get_instance()
            .invoke_api::<cublasStatus_t>(req)
            .expect("call invoke_api failed");

        // Publish the received stream to the caller first, so that the "client2" log line
        // below reports the new value rather than the stale pre-call contents.
        // SAFETY: same pointer contract as the read above.
        unsafe {
            *stream_id = received_stream;
        }
        debug!(
            " client2: stream_id ptr:{:p}, *ptr:{:?} ",
            stream_id,
            unsafe { *stream_id }
        );

        // SAFETY: same pointer contract as the read above.
        unsafe {
            virt::handle_insert(*stream_id as *mut c_void, *stream_id as *mut c_void, 0)
                .expect("handle_insert failed");
            virt::req_id_vhandle_insert(request_id, *stream_id as *mut c_void)
                .expect("req_id_vhandle_insert failed");
        }
        res
    }

    /// Forwards `cublasGetMathMode` through `Agent`.
    ///
    /// `handle` is passed through `virt::handle_map`; `mode` is handed to the request as an
    /// out argument built directly from the caller's pointer.
    ///
    /// Returns the `cublasStatus_t` produced by `invoke_api`.
    ///
    /// # Panics
    ///
    /// Panics if the `virt::handle_map` lookup or the `invoke_api` call fails.
    fn cublasGetMathMode(&self, handle: cublasHandle_t, mode: *mut cublasMath_t) -> cublasStatus_t {
        info!("[Hooked] api_name: cublasGetMathMode");
        debug!(
            "before: handle: {}, mode: {:p}, *mode: {:?}",
            handle,
            mode,
            unsafe { *mode }
        );

        let mapped_handle = virt::handle_map(handle as *mut c_void).expect("handle_map failed");
        let remote_handle = mapped_handle as cublasHandle_t as usize;

        let api_id = ApiFuncName::Cublasgetmathmode as u64;
        let req = Request::with_args(
            api_id,
            smallvec![
                Argument::from_ref(
                    &remote_handle,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                ),
                // SAFETY: not checked here — relies on the caller passing a valid, writable
                // `mode` pointer.
                unsafe { Argument::from_mut_ptr(mode, ArgumentFlag::ARG_OUT) },
            ],
        );
        let status = Agent::get_instance()
            .invoke_api::<cublasStatus_t>(req)
            .expect("call invoke_api failed");
        debug!(
            "after: handle: {}, mode: {:p}, *mode: {:?}",
            handle,
            mode,
            unsafe { *mode }
        );
        status
    }

    /// Forwards `cublasSetMathMode` through `Agent`.
    ///
    /// `handle` is passed through `virt::handle_map`; `mode` is sent unchanged as an input
    /// argument.
    ///
    /// Returns the `cublasStatus_t` produced by `invoke_api`.
    ///
    /// # Panics
    ///
    /// Panics if the `virt::handle_map` lookup or the `invoke_api` call fails.
    fn cublasSetMathMode(&self, handle: cublasHandle_t, mode: cublasMath_t) -> cublasStatus_t {
        info!("[Hooked] api_name: cublasSetMathMode");
        debug!("before: handle: {}, mode:{}", handle, mode);

        let mapped_handle = virt::handle_map(handle as *mut c_void).expect("handle_map failed");
        let remote_handle = mapped_handle as cublasHandle_t as usize;

        let api_id = ApiFuncName::Cublassetmathmode as u64;
        let req = Request::with_args(
            api_id,
            smallvec![
                Argument::from_ref(
                    &remote_handle,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                ),
                Argument::from_ref(&mode, ArgumentFlag::ARG_IN),
            ],
        );
        let status = Agent::get_instance()
            .invoke_api::<cublasStatus_t>(req)
            .expect("call invoke_api failed");
        debug!("after: handle: {}, mode:{}", handle, mode);
        status
    }

    /// Forwards `cublasSgemm_v2` through `Agent`.
    ///
    /// `handle` and the matrix pointers `a`, `b` and `c` are each passed through
    /// `virt::handle_map`. The scalars `alpha` and `beta` are handed to the request via
    /// `Argument::from_ptr`; all remaining parameters are sent unchanged. Arguments are
    /// sent in the same order as the parameter list.
    ///
    /// Returns the `cublasStatus_t` produced by `invoke_api`.
    ///
    /// # Panics
    ///
    /// Panics if a `virt::handle_map` lookup or the `invoke_api` call fails.
    fn cublasSgemm_v2(
        &self,
        handle: cublasHandle_t,
        transa: cublasOperation_t,
        transb: cublasOperation_t,
        m: c_int,
        n: c_int,
        k: c_int,
        alpha: *const f32,
        a: *const f32,
        lda: c_int,
        b: *const f32,
        ldb: c_int,
        beta: *const f32,
        c: *mut f32,
        ldc: c_int,
    ) -> cublasStatus_t {
        info!("[Hooked] api_name: cublasSgemm_v2");

        let mapped_handle = virt::handle_map(handle as *mut c_void).expect("handle_map failed");
        let remote_handle = mapped_handle as cublasHandle_t as usize;
        let mapped_a = virt::handle_map(a as *mut c_void).expect("handle_map failed");
        let remote_a = mapped_a as *const f32 as usize;
        let mapped_b = virt::handle_map(b as *mut c_void).expect("handle_map failed");
        let remote_b = mapped_b as *const f32 as usize;
        let mapped_c = virt::handle_map(c as *mut c_void).expect("handle_map failed");
        let remote_c = mapped_c as *mut f32 as usize;

        let api_id = ApiFuncName::CublassgemmV2 as u64;
        let req = Request::with_args(
            api_id,
            smallvec![
                Argument::from_ref(
                    &remote_handle,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                ),
                Argument::from_ref(&transa, ArgumentFlag::ARG_IN),
                Argument::from_ref(&transb, ArgumentFlag::ARG_IN),
                Argument::from_ref(&m, ArgumentFlag::ARG_IN),
                Argument::from_ref(&n, ArgumentFlag::ARG_IN),
                Argument::from_ref(&k, ArgumentFlag::ARG_IN),
                // SAFETY: not checked here — relies on the caller passing a valid, readable
                // `alpha` pointer.
                unsafe { Argument::from_ptr(alpha, ArgumentFlag::ARG_IN) },
                Argument::from_ref(&remote_a, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_ref(&lda, ArgumentFlag::ARG_IN),
                Argument::from_ref(&remote_b, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_ref(&ldb, ArgumentFlag::ARG_IN),
                // SAFETY: not checked here — relies on the caller passing a valid, readable
                // `beta` pointer.
                unsafe { Argument::from_ptr(beta, ArgumentFlag::ARG_IN) },
                Argument::from_ref(&remote_c, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_ref(&ldc, ArgumentFlag::ARG_IN),
            ],
        );

        Agent::get_instance()
            .invoke_api::<cublasStatus_t>(req)
            .expect("call invoke_api failed")
    }

    /// Forwards `cublasGemmEx` through `Agent`.
    ///
    /// `handle` and the matrix pointers `a`, `b` and `c` are each passed through
    /// `virt::handle_map`. The scalars `alpha` and `beta` are sent as byte slices whose
    /// length comes from `get_compute_type_size(compute_type)`; all remaining parameters are
    /// sent unchanged. Arguments are sent in the same order as the parameter list.
    ///
    /// Returns the `cublasStatus_t` produced by `invoke_api`.
    ///
    /// # Panics
    ///
    /// Panics if a `virt::handle_map` lookup or the `invoke_api` call fails.
    fn cublasGemmEx(
        &self,
        handle: cublasHandle_t,
        transa: cublasOperation_t,
        transb: cublasOperation_t,
        m: c_int,
        n: c_int,
        k: c_int,
        alpha: *const c_void,
        a: *const c_void,
        atype: cudaDataType,
        lda: c_int,
        b: *const c_void,
        btype: cudaDataType,
        ldb: c_int,
        beta: *const c_void,
        c: *mut c_void,
        ctype: cudaDataType,
        ldc: c_int,
        compute_type: cublasComputeType_t,
        algo: cublasGemmAlgo_t,
    ) -> cublasStatus_t {
        info!("[Hooked] api_name: cublasGemmEx");

        let mapped_handle = virt::handle_map(handle as *mut c_void).expect("handle_map failed");
        let remote_handle = mapped_handle as cublasHandle_t as usize;
        let mapped_a = virt::handle_map(a as *mut c_void).expect("handle_map failed");
        let remote_a = mapped_a as *const c_void as usize;
        let mapped_b = virt::handle_map(b as *mut c_void).expect("handle_map failed");
        let remote_b = mapped_b as *const c_void as usize;
        let mapped_c = virt::handle_map(c).expect("handle_map failed");
        let remote_c = mapped_c as usize;

        // The scalars are untyped, so their byte width is derived from the compute type.
        let scalar_len = get_compute_type_size(compute_type);
        // SAFETY: not checked here — relies on the caller passing `alpha`/`beta` pointers
        // that are non-null and readable for `scalar_len` bytes.
        let alpha_bytes = unsafe { std::slice::from_raw_parts(alpha as *const u8, scalar_len) };
        let beta_bytes = unsafe { std::slice::from_raw_parts(beta as *const u8, scalar_len) };

        let api_id = ApiFuncName::Cublasgemmex as u64;
        let req = Request::with_args(
            api_id,
            smallvec![
                Argument::from_ref(
                    &remote_handle,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                ),
                Argument::from_ref(&transa, ArgumentFlag::ARG_IN),
                Argument::from_ref(&transb, ArgumentFlag::ARG_IN),
                Argument::from_ref(&m, ArgumentFlag::ARG_IN),
                Argument::from_ref(&n, ArgumentFlag::ARG_IN),
                Argument::from_ref(&k, ArgumentFlag::ARG_IN),
                Argument::from_slice(alpha_bytes, ArgumentFlag::ARG_IN),
                Argument::from_ref(&remote_a, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_ref(&atype, ArgumentFlag::ARG_IN),
                Argument::from_ref(&lda, ArgumentFlag::ARG_IN),
                Argument::from_ref(&remote_b, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_ref(&btype, ArgumentFlag::ARG_IN),
                Argument::from_ref(&ldb, ArgumentFlag::ARG_IN),
                Argument::from_slice(beta_bytes, ArgumentFlag::ARG_IN),
                Argument::from_ref(&remote_c, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_ref(&ctype, ArgumentFlag::ARG_IN),
                Argument::from_ref(&ldc, ArgumentFlag::ARG_IN),
                Argument::from_ref(&compute_type, ArgumentFlag::ARG_IN),
                Argument::from_ref(&algo, ArgumentFlag::ARG_IN),
            ],
        );
        Agent::get_instance()
            .invoke_api::<cublasStatus_t>(req)
            .expect("call invoke_api failed")
    }

    /// Forwards `cublasSgemmStridedBatched` through `Agent`.
    ///
    /// `handle` and the matrix pointers `a`, `b` and `c` are each passed through
    /// `virt::handle_map`. The scalars `alpha` and `beta` are handed to the request via
    /// `Argument::from_ptr`; leading dimensions, strides, `batch_count` and the remaining
    /// parameters are sent unchanged. Arguments are sent in the same order as the parameter
    /// list.
    ///
    /// Returns the `cublasStatus_t` produced by `invoke_api`.
    ///
    /// # Panics
    ///
    /// Panics if a `virt::handle_map` lookup or the `invoke_api` call fails.
    fn cublasSgemmStridedBatched(
        &self,
        handle: cublasHandle_t,
        transa: cublasOperation_t,
        transb: cublasOperation_t,
        m: c_int,
        n: c_int,
        k: c_int,
        alpha: *const f32,
        a: *const f32,
        lda: c_int,
        stride_a: c_longlong,
        b: *const f32,
        ldb: c_int,
        stride_b: c_longlong,
        beta: *const f32,
        c: *mut f32,
        ldc: c_int,
        stride_c: c_longlong,
        batch_count: c_int,
    ) -> cublasStatus_t {
        info!("[Hooked] api_name: cublasSgemmStridedBatched");

        let mapped_handle = virt::handle_map(handle as *mut c_void).expect("handle_map failed");
        let remote_handle = mapped_handle as cublasHandle_t as usize;
        let mapped_a = virt::handle_map(a as *mut c_void).expect("handle_map failed");
        let remote_a = mapped_a as *const f32 as usize;
        let mapped_b = virt::handle_map(b as *mut c_void).expect("handle_map failed");
        let remote_b = mapped_b as *const f32 as usize;
        let mapped_c = virt::handle_map(c as *mut c_void).expect("handle_map failed");
        let remote_c = mapped_c as *mut f32 as usize;

        let api_id = ApiFuncName::Cublassgemmstridedbatched as u64;
        let req = Request::with_args(
            api_id,
            smallvec![
                Argument::from_ref(
                    &remote_handle,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                ),
                Argument::from_ref(&transa, ArgumentFlag::ARG_IN),
                Argument::from_ref(&transb, ArgumentFlag::ARG_IN),
                Argument::from_ref(&m, ArgumentFlag::ARG_IN),
                Argument::from_ref(&n, ArgumentFlag::ARG_IN),
                Argument::from_ref(&k, ArgumentFlag::ARG_IN),
                // SAFETY: not checked here — relies on the caller passing a valid, readable
                // `alpha` pointer.
                unsafe { Argument::from_ptr(alpha, ArgumentFlag::ARG_IN) },
                Argument::from_ref(&remote_a, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_ref(&lda, ArgumentFlag::ARG_IN),
                Argument::from_ref(&stride_a, ArgumentFlag::ARG_IN),
                Argument::from_ref(&remote_b, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_ref(&ldb, ArgumentFlag::ARG_IN),
                Argument::from_ref(&stride_b, ArgumentFlag::ARG_IN),
                // SAFETY: not checked here — relies on the caller passing a valid, readable
                // `beta` pointer.
                unsafe { Argument::from_ptr(beta, ArgumentFlag::ARG_IN) },
                Argument::from_ref(&remote_c, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_ref(&ldc, ArgumentFlag::ARG_IN),
                Argument::from_ref(&stride_c, ArgumentFlag::ARG_IN),
                Argument::from_ref(&batch_count, ArgumentFlag::ARG_IN),
            ],
        );
        Agent::get_instance()
            .invoke_api::<cublasStatus_t>(req)
            .expect("call invoke_api failed")
    }

    /// Forwards `cublasGemmStridedBatchedEx` through `Agent`.
    ///
    /// `handle` and the matrix pointers `a`, `b` and `c` are each passed through
    /// `virt::handle_map`. The scalars `alpha` and `beta` are sent as byte slices whose
    /// length comes from `get_compute_type_size(compute_type)`; data types, leading
    /// dimensions, strides, `batch_count` and the remaining parameters are sent unchanged.
    /// Arguments are sent in the same order as the parameter list.
    ///
    /// Returns the `cublasStatus_t` produced by `invoke_api`.
    ///
    /// # Panics
    ///
    /// Panics if a `virt::handle_map` lookup or the `invoke_api` call fails.
    fn cublasGemmStridedBatchedEx(
        &self,
        handle: cublasHandle_t,
        transa: cublasOperation_t,
        transb: cublasOperation_t,
        m: c_int,
        n: c_int,
        k: c_int,
        alpha: *const c_void,
        a: *const c_void,
        atype: cudaDataType,
        lda: c_int,
        stride_a: c_longlong,
        b: *const c_void,
        btype: cudaDataType,
        ldb: c_int,
        stride_b: c_longlong,
        beta: *const c_void,
        c: *mut c_void,
        ctype: cudaDataType,
        ldc: c_int,
        stride_c: c_longlong,
        batch_count: c_int,
        compute_type: cublasComputeType_t,
        algo: cublasGemmAlgo_t,
    ) -> cublasStatus_t {
        info!("[Hooked] api_name: cublasGemmStridedBatchedEx");

        let mapped_handle = virt::handle_map(handle as *mut c_void).expect("handle_map failed");
        let remote_handle = mapped_handle as cublasHandle_t as usize;
        let mapped_a = virt::handle_map(a as *mut c_void).expect("handle_map failed");
        let remote_a = mapped_a as *const c_void as usize;
        let mapped_b = virt::handle_map(b as *mut c_void).expect("handle_map failed");
        let remote_b = mapped_b as *const c_void as usize;
        let mapped_c = virt::handle_map(c).expect("handle_map failed");
        let remote_c = mapped_c as usize;

        // The scalars are untyped, so their byte width is derived from the compute type.
        let scalar_len = get_compute_type_size(compute_type);
        // SAFETY: not checked here — relies on the caller passing `alpha`/`beta` pointers
        // that are non-null and readable for `scalar_len` bytes.
        let alpha_bytes = unsafe { std::slice::from_raw_parts(alpha as *const u8, scalar_len) };
        let beta_bytes = unsafe { std::slice::from_raw_parts(beta as *const u8, scalar_len) };

        let api_id = ApiFuncName::Cublasgemmstridedbatchedex as u64;
        let req = Request::with_args(
            api_id,
            smallvec![
                Argument::from_ref(
                    &remote_handle,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                ),
                Argument::from_ref(&transa, ArgumentFlag::ARG_IN),
                Argument::from_ref(&transb, ArgumentFlag::ARG_IN),
                Argument::from_ref(&m, ArgumentFlag::ARG_IN),
                Argument::from_ref(&n, ArgumentFlag::ARG_IN),
                Argument::from_ref(&k, ArgumentFlag::ARG_IN),
                Argument::from_slice(alpha_bytes, ArgumentFlag::ARG_IN),
                Argument::from_ref(&remote_a, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_ref(&atype, ArgumentFlag::ARG_IN),
                Argument::from_ref(&lda, ArgumentFlag::ARG_IN),
                Argument::from_ref(&stride_a, ArgumentFlag::ARG_IN),
                Argument::from_ref(&remote_b, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_ref(&btype, ArgumentFlag::ARG_IN),
                Argument::from_ref(&ldb, ArgumentFlag::ARG_IN),
                Argument::from_ref(&stride_b, ArgumentFlag::ARG_IN),
                Argument::from_slice(beta_bytes, ArgumentFlag::ARG_IN),
                Argument::from_ref(&remote_c, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_ref(&ctype, ArgumentFlag::ARG_IN),
                Argument::from_ref(&ldc, ArgumentFlag::ARG_IN),
                Argument::from_ref(&stride_c, ArgumentFlag::ARG_IN),
                Argument::from_ref(&batch_count, ArgumentFlag::ARG_IN),
                Argument::from_ref(&compute_type, ArgumentFlag::ARG_IN),
                Argument::from_ref(&algo, ArgumentFlag::ARG_IN),
            ],
        );
        Agent::get_instance()
            .invoke_api::<cublasStatus_t>(req)
            .expect("call invoke_api failed")
    }
}