// SPDX-License-Identifier: Mulan PSL v2
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This software is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *         http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

use xgpu_macros::api_hook;

// Declarations of the cuBLASLt C entry points handled by this crate's hook layer.
//
// The whole module is passed to the `api_hook` attribute macro (imported from `xgpu_macros`)
// together with the name `CublasLtApi` and a `backend` path pointing at
// `crate::hook_impl::ipc::cublaslt::CublasLtApiImpl`. `CublasLtApi` is not written out anywhere
// in the module body, yet it is re-exported from this module at the bottom of the file, so it is
// expected to be produced by the macro expansion.
//
// NOTE(review): what the macro generates per function, and whether it depends on declaration
// order or on the parameter names used here, is not visible from this file — confirm against
// `xgpu_macros` before reordering or renaming anything in this module.
#[api_hook(CublasLtApi, backend = crate::hook_impl::ipc::cublaslt::CublasLtApiImpl)]
mod api {
    // C-compatible primitive types used in the signatures below.
    use std::ffi::{c_int, c_void};

    // All handle, descriptor, enum and status types named below (`cublasLtHandle_t`,
    // `cublasStatus_t`, `cudaStream_t`, ...) are brought in by this glob import.
    use cudax::cublaslt::*;

    // Foreign declarations using the C ABI; no function bodies are defined in this file.
    // Every function reports its outcome through a `cublasStatus_t` return value.
    // The function names match the cuBLASLt C API; see NVIDIA's cuBLASLt documentation for the
    // authoritative semantics of each call.
    unsafe extern "C" {
        // Handle creation. `light_handle` is a mutable pointer to a handle — presumably an
        // out-parameter that receives the newly created handle (confirm against cuBLASLt docs).
        pub fn cublasLtCreate(light_handle: *mut cublasLtHandle_t) -> cublasStatus_t;

        // Counterpart of `cublasLtCreate`; takes the handle by value.
        pub fn cublasLtDestroy(light_handle: cublasLtHandle_t) -> cublasStatus_t;

        // Matrix multiplication entry point.
        // - `alpha` / `beta`: read-only untyped pointers (scaling factors per the cuBLASLt docs;
        //   NOTE(review): whether they live in host or device memory is not determined here).
        // - `a`, `b`, `c`: read-only untyped pointers, each paired with a layout descriptor.
        // - `d`: mutable untyped pointer paired with `d_desc` — presumably the output matrix.
        // - `algo`: read-only pointer to an algorithm description.
        // - `workspace` / `workspace_size_in_bytes`: mutable scratch pointer and its byte size.
        // - `stream`: the CUDA stream passed along with the call.
        pub fn cublasLtMatmul(
            light_handle: cublasLtHandle_t,
            compute_desc: cublasLtMatmulDesc_t,
            alpha: *const c_void,
            a: *const c_void,
            a_desc: cublasLtMatrixLayout_t,
            b: *const c_void,
            b_desc: cublasLtMatrixLayout_t,
            beta: *const c_void,
            c: *const c_void,
            c_desc: cublasLtMatrixLayout_t,
            d: *mut c_void,
            d_desc: cublasLtMatrixLayout_t,
            algo: *const cublasLtMatmulAlgo_t,
            workspace: *mut c_void,
            workspace_size_in_bytes: usize,
            stream: cudaStream_t,
        ) -> cublasStatus_t;

        // Matrix layout descriptor creation. `mat_layout` is a mutable pointer to a layout
        // descriptor (presumably the out-parameter). `rows` and `cols` are unsigned 64-bit,
        // while `ld` is signed 64-bit.
        // NOTE(review): `type_x` is spelled `cudaDataType` here whereas
        // `cublasLtMatmulDescCreate` uses `cudaDataType_t`; presumably both names refer to the
        // same type in `cudax::cublaslt` — confirm.
        pub fn cublasLtMatrixLayoutCreate(
            mat_layout: *mut cublasLtMatrixLayout_t,
            type_x: cudaDataType,
            rows: u64,
            cols: u64,
            ld: i64,
        ) -> cublasStatus_t;

        // Counterpart of `cublasLtMatrixLayoutCreate`; takes the descriptor by value.
        pub fn cublasLtMatrixLayoutDestroy(mat_layout: cublasLtMatrixLayout_t) -> cublasStatus_t;

        // Matmul operation descriptor creation. `matmul_desc` is a mutable pointer to a
        // descriptor (presumably the out-parameter); the compute type and scale type are passed
        // by value.
        pub fn cublasLtMatmulDescCreate(
            matmul_desc: *mut cublasLtMatmulDesc_t,
            compute_type: cublasComputeType_t,
            scale_type: cudaDataType_t,
        ) -> cublasStatus_t;

        // Counterpart of `cublasLtMatmulDescCreate`; takes the descriptor by value.
        pub fn cublasLtMatmulDescDestroy(matmul_desc: cublasLtMatmulDesc_t) -> cublasStatus_t;

        // Sets one attribute on a matmul descriptor. The value is supplied as a read-only
        // untyped buffer `buf` together with its length `size_in_bytes`; the expected buffer
        // contents for each `attr` are defined by cuBLASLt, not by this file.
        pub fn cublasLtMatmulDescSetAttribute(
            matmul_desc: cublasLtMatmulDesc_t,
            attr: cublasLtMatmulDescAttributes_t,
            buf: *const c_void,
            size_in_bytes: usize,
        ) -> cublasStatus_t;

        // Matmul preference object creation. `pref` is a mutable pointer to a preference
        // (presumably the out-parameter).
        pub fn cublasLtMatmulPreferenceCreate(
            pref: *mut cublasLtMatmulPreference_t,
        ) -> cublasStatus_t;

        // Counterpart of `cublasLtMatmulPreferenceCreate`; takes the preference by value.
        pub fn cublasLtMatmulPreferenceDestroy(pref: cublasLtMatmulPreference_t) -> cublasStatus_t;

        // Sets one attribute on a preference object, using the same `buf` / `size_in_bytes`
        // convention as `cublasLtMatmulDescSetAttribute`.
        pub fn cublasLtMatmulPreferenceSetAttribute(
            pref: cublasLtMatmulPreference_t,
            attr: cublasLtMatmulPreferenceAttributes_t,
            buf: *const c_void,
            size_in_bytes: usize,
        ) -> cublasStatus_t;

        // Heuristic algorithm query for a matmul described by the operation descriptor, the four
        // matrix layouts and a preference object.
        // NOTE(review): presumably `heuristic_results_array` must have room for
        // `requested_algo_count` entries and the number of entries actually written is stored
        // through `return_algo_count` — confirm against the cuBLASLt documentation.
        pub fn cublasLtMatmulAlgoGetHeuristic(
            light_handle: cublasLtHandle_t,
            operation_desc: cublasLtMatmulDesc_t,
            a_desc: cublasLtMatrixLayout_t,
            b_desc: cublasLtMatrixLayout_t,
            c_desc: cublasLtMatrixLayout_t,
            d_desc: cublasLtMatrixLayout_t,
            preference: cublasLtMatmulPreference_t,
            requested_algo_count: c_int,
            heuristic_results_array: *mut cublasLtMatmulHeuristicResult_t,
            return_algo_count: *mut c_int,
        ) -> cublasStatus_t;
    }
}

pub use api::CublasLtApi;