// SPDX-License-Identifier: Mulan PSL v2
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This software is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *         http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

use xgpu_macros::api_hook;

// Hooked NCCL entry points.
//
// The `api_hook` attribute (from `xgpu_macros`) is applied to this module with
// `NcclApi` as its first argument and `crate::hook_impl::ipc::nccl::NcclApiImpl`
// as the backend. After expansion the module must expose an item named `NcclApi`,
// because this file re-exports it with `pub use api::NcclApi`.
// NOTE(review): presumably the macro turns each foreign fn below into a hook that
// forwards to the backend; the exact expansion lives in `xgpu_macros` — confirm there.
//
// Only plain `//` comments are used inside this module so that no extra
// attributes (e.g. `#[doc]` from `///`) are added to the proc macro's input.
#[api_hook(NcclApi, backend = crate::hook_impl::ipc::nccl::NcclApiImpl)]
mod api {
    use std::ffi::{c_int, c_void};

    // NCCL handle/enum types (`ncclComm_t`, `ncclResult_t`, `ncclDataType_t`,
    // `ncclRedOp_t`, `ncclConfig_t`, `ncclUniqueId`, `cudaStream_t`, ...) come
    // from this glob import.
    use cudax::nccl::*;

    // Each declaration is expected to match the C prototype of the same name
    // exported by the NCCL library; a mismatched foreign signature is undefined
    // behavior at the FFI boundary, so keep these in sync with the `nccl.h`
    // version being linked. Every function returns an `ncclResult_t` status code.
    unsafe extern "C" {
        // Writes the library version code into `*version`.
        pub fn ncclGetVersion(version: *mut c_int) -> ncclResult_t;

        // --- Communicator lifecycle -------------------------------------------

        pub fn ncclCommDestroy(comm: ncclComm_t) -> ncclResult_t;

        pub fn ncclCommAbort(comm: ncclComm_t) -> ncclResult_t;

        // Creates `*newcomm` from `comm`, grouping ranks by `color` and ordering
        // them by `key` (per NCCL docs). `config` is a raw pointer; NCCL accepts
        // NULL here for default settings — confirm for the linked version.
        pub fn ncclCommSplit(
            comm: ncclComm_t,
            color: c_int,
            key: c_int,
            newcomm: *mut ncclComm_t,
            config: *mut ncclConfig_t,
        ) -> ncclResult_t;

        // Stores the communicator's asynchronous error state into `*async_error`.
        pub fn ncclCommGetAsyncError(
            comm: ncclComm_t,
            async_error: *mut ncclResult_t,
        ) -> ncclResult_t;

        // --- Collectives (enqueued on `stream`) -------------------------------

        // In-place broadcast: `buff` is read on `root` and written on other ranks.
        pub fn ncclBcast(
            buff: *mut c_void,
            count: usize,
            datatype: ncclDataType_t,
            root: c_int,
            comm: ncclComm_t,
            stream: cudaStream_t,
        ) -> ncclResult_t;

        pub fn ncclAllReduce(
            sendbuff: *const c_void,
            recvbuff: *mut c_void,
            count: usize,
            datatype: ncclDataType_t,
            op: ncclRedOp_t,
            comm: ncclComm_t,
            stream: cudaStream_t,
        ) -> ncclResult_t;

        // `sendcount` is the per-rank element count (per NCCL docs `recvbuff`
        // must hold `nranks * sendcount` elements).
        pub fn ncclAllGather(
            sendbuff: *const c_void,
            recvbuff: *mut c_void,
            sendcount: usize,
            datatype: ncclDataType_t,
            comm: ncclComm_t,
            stream: cudaStream_t,
        ) -> ncclResult_t;

        // --- Point-to-point ---------------------------------------------------

        pub fn ncclSend(
            sendbuff: *const c_void,
            count: usize,
            datatype: ncclDataType_t,
            peer: c_int,
            comm: ncclComm_t,
            stream: cudaStream_t,
        ) -> ncclResult_t;

        pub fn ncclRecv(
            recvbuff: *mut c_void,
            count: usize,
            datatype: ncclDataType_t,
            peer: c_int,
            comm: ncclComm_t,
            stream: cudaStream_t,
        ) -> ncclResult_t;

        // --- Group calls --------------------------------------------------------

        pub fn ncclGroupStart() -> ncclResult_t;

        pub fn ncclGroupEnd() -> ncclResult_t;

        // --- Bootstrap ----------------------------------------------------------

        // Fills `*unique_id`, which is then shared with all ranks for init.
        pub fn ncclGetUniqueId(unique_id: *mut ncclUniqueId) -> ncclResult_t;

        // Note: `comm_id` is passed by value, as in the C prototype.
        pub fn ncclCommInitRankConfig(
            comm: *mut ncclComm_t,
            nranks: c_int,
            comm_id: ncclUniqueId,
            rank: c_int,
            config: *mut ncclConfig_t,
        ) -> ncclResult_t;

        // Reduction whose result lands in `recvbuff` on `root` only.
        pub fn ncclReduce(
            sendbuff: *const c_void,
            recvbuff: *mut c_void,
            count: usize,
            datatype: ncclDataType_t,
            op: ncclRedOp_t,
            root: c_int,
            comm: ncclComm_t,
            stream: cudaStream_t,
        ) -> ncclResult_t;

        // --- Elastic / membership changes ----------------------------------------
        //
        // NOTE(review): the parameters below are `_`-prefixed, unlike the rest of
        // this block. They are left as-is because the `api_hook` expansion may
        // reuse these names (e.g. to silence unused-variable lints in generated
        // code) — confirm before renaming.

        // Builds `*newcomm` from `comm` without the `_exclude_ranks_count` ranks
        // listed in `_exclude_ranks_list`. NOTE(review): this prototype appears
        // only in recent NCCL releases (2.27+) — confirm the linked version
        // exports it.
        pub fn ncclCommShrink(
            _comm: ncclComm_t,
            _exclude_ranks_list: *mut c_int,
            _exclude_ranks_count: c_int,
            _newcomm: *mut ncclComm_t,
            _config: *mut ncclConfig_t,
            _shrink_flags: c_int,
        ) -> ncclResult_t;

        // NOTE(review): `ncclCommInitNewRank`, `ncclCommAddNewRank` and
        // `ncclCommSetupNewRank` do not appear to be upstream NCCL APIs;
        // presumably a vendor/project extension — verify which library
        // provides these symbols and their exact C prototypes.
        pub fn ncclCommInitNewRank(_comm: *mut ncclComm_t, _nranks: c_int) -> ncclResult_t;

        pub fn ncclCommAddNewRank(_comm: ncclComm_t) -> ncclResult_t;

        pub fn ncclCommSetupNewRank(_comm: ncclComm_t) -> ncclResult_t;

        pub fn ncclCommFinalize(comm: ncclComm_t) -> ncclResult_t;
    }
}

pub use api::NcclApi;