/* -------------------------------------------------------------------------
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is part of the MindStudio project.
 *
 * MindStudio is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *    http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 * -------------------------------------------------------------------------
*/
#include "csrc/common/inject/driver_inject.h"

#include <functional>
#include <pthread.h>
#include "csrc/common/utils.h"
#include "csrc/common/function_loader.h"

namespace {
enum DriverFunctionIndex {
    FUNC_PROF_DRV_GET_CHANNELS,
    FUNC_DRV_GET_DEV_IDS,
    FUNC_DRV_GET_DEV_NUM,
    FUNC_PROF_DRV_START,
    FUNC_PROF_STOP,
    FUNC_PROF_CHANNEL_READ,
    FUNC_PROF_CHANNEL_POLL,
    FUNC_HAL_GET_DEVICE_INFO,
    FUNC_HAL_GET_API_VERSION,
    FUNC_HAL_PROF_DATA_FLUSH,
    FUNC_DRIVER_COUNT
};

pthread_once_t g_once = PTHREAD_ONCE_INIT;
void *g_driverFuncArray[FUNC_DRIVER_COUNT];

void LoadDriverFunction()
{
    g_driverFuncArray[FUNC_PROF_DRV_GET_CHANNELS] =
        Mspti::Common::RegisterFunction("libascend_hal", "prof_drv_get_channels");
    g_driverFuncArray[FUNC_DRV_GET_DEV_IDS] = Mspti::Common::RegisterFunction("libascend_hal", "drvGetDevIDs");
    g_driverFuncArray[FUNC_DRV_GET_DEV_NUM] = Mspti::Common::RegisterFunction("libascend_hal", "drvGetDevNum");
    g_driverFuncArray[FUNC_PROF_DRV_START] = Mspti::Common::RegisterFunction("libascend_hal", "prof_drv_start");
    g_driverFuncArray[FUNC_PROF_STOP] = Mspti::Common::RegisterFunction("libascend_hal", "prof_stop");
    g_driverFuncArray[FUNC_PROF_CHANNEL_READ] = Mspti::Common::RegisterFunction("libascend_hal", "prof_channel_read");
    g_driverFuncArray[FUNC_PROF_CHANNEL_POLL] = Mspti::Common::RegisterFunction("libascend_hal", "prof_channel_poll");
    g_driverFuncArray[FUNC_HAL_GET_DEVICE_INFO] = Mspti::Common::RegisterFunction("libascend_hal", "halGetDeviceInfo");
    g_driverFuncArray[FUNC_HAL_GET_API_VERSION] = Mspti::Common::RegisterFunction("libascend_hal", "halGetAPIVersion");
    g_driverFuncArray[FUNC_HAL_PROF_DATA_FLUSH] = Mspti::Common::RegisterFunction("libascend_hal", "halProfDataFlush");
}
}

int ProfDrvGetChannels(unsigned int deviceId, ChannelListT *channelList)
{
    pthread_once(&g_once, LoadDriverFunction);
    void *voidFunc = g_driverFuncArray[FUNC_PROF_DRV_GET_CHANNELS];
    using ProfDrvGetChannelsFunc = std::function<decltype(ProfDrvGetChannels)>;
    ProfDrvGetChannelsFunc func = Mspti::Common::ReinterpretConvert<decltype(&ProfDrvGetChannels)>(voidFunc);
    if (func == nullptr) {
        Mspti::Common::GetFunction("libascend_hal", "prof_drv_get_channels", func);
    }
    THROW_FUNC_NOTFOUND(func, "prof_drv_get_channels", "libascend_hal.so");
    return func(deviceId, channelList);
}

DrvError DrvGetDevIDs(uint32_t *devices, uint32_t len)
{
    pthread_once(&g_once, LoadDriverFunction);
    void *voidFunc = g_driverFuncArray[FUNC_DRV_GET_DEV_IDS];
    using DrvGetDevIDsFunc = std::function<decltype(DrvGetDevIDs)>;
    DrvGetDevIDsFunc func = Mspti::Common::ReinterpretConvert<decltype(&DrvGetDevIDs)>(voidFunc);
    if (func == nullptr) {
        Mspti::Common::GetFunction("libascend_hal", "drvGetDevIDs", func);
    }
    THROW_FUNC_NOTFOUND(func, "drvGetDevIDs", "libascend_hal.so");
    return func(devices, len);
}

DrvError DrvGetDevNum(uint32_t *count)
{
    pthread_once(&g_once, LoadDriverFunction);
    void *voidFunc = g_driverFuncArray[FUNC_DRV_GET_DEV_NUM];
    using DrvGetDevNumFunc = std::function<decltype(DrvGetDevNum)>;
    DrvGetDevNumFunc func = Mspti::Common::ReinterpretConvert<decltype(&DrvGetDevNum)>(voidFunc);
    if (func == nullptr) {
        Mspti::Common::GetFunction("libascend_hal", "drvGetDevNum", func);
    }
    THROW_FUNC_NOTFOUND(func, "drvGetDevNum", "libascend_hal.so");
    return func(count);
}

int ProfDrvStart(unsigned int deviceId, unsigned int channelId, struct ProfStartPara *startPara)
{
    pthread_once(&g_once, LoadDriverFunction);
    void *voidFunc = g_driverFuncArray[FUNC_PROF_DRV_START];
    using ProfDrvStartFunc = std::function<decltype(ProfDrvStart)>;
    ProfDrvStartFunc func = Mspti::Common::ReinterpretConvert<decltype(&ProfDrvStart)>(voidFunc);
    if (func == nullptr) {
        Mspti::Common::GetFunction("libascend_hal", "prof_drv_start", func);
    }
    THROW_FUNC_NOTFOUND(func, "prof_drv_start", "libascend_hal.so");
    return func(deviceId, channelId, startPara);
}

int ProfStop(unsigned int deviceId, unsigned int channelId)
{
    pthread_once(&g_once, LoadDriverFunction);
    void *voidFunc = g_driverFuncArray[FUNC_PROF_STOP];
    using ProfStopFunc = std::function<decltype(ProfStop)>;
    ProfStopFunc func = Mspti::Common::ReinterpretConvert<decltype(&ProfStop)>(voidFunc);
    if (func == nullptr) {
        Mspti::Common::GetFunction("libascend_hal", "prof_stop", func);
    }
    THROW_FUNC_NOTFOUND(func, "prof_stop", "libascend_hal.so");
    return func(deviceId, channelId);
}

int ProfChannelRead(unsigned int deviceId, unsigned int channelId, char *outBuf, unsigned int bufSize)
{
    pthread_once(&g_once, LoadDriverFunction);
    void *voidFunc = g_driverFuncArray[FUNC_PROF_CHANNEL_READ];
    using ProfChannelReadFunc = std::function<decltype(ProfChannelRead)>;
    ProfChannelReadFunc func = Mspti::Common::ReinterpretConvert<decltype(&ProfChannelRead)>(voidFunc);
    if (func == nullptr) {
        Mspti::Common::GetFunction("libascend_hal", "prof_channel_read", func);
    }
    THROW_FUNC_NOTFOUND(func, "prof_channel_read", "libascend_hal.so");
    return func(deviceId, channelId, outBuf, bufSize);
}

int ProfChannelPoll(struct ProfPollInfo *outBuf, int num, int timeout)
{
    pthread_once(&g_once, LoadDriverFunction);
    void *voidFunc = g_driverFuncArray[FUNC_PROF_CHANNEL_POLL];
    using ProfChannelPollFunc = std::function<decltype(ProfChannelPoll)>;
    ProfChannelPollFunc func = Mspti::Common::ReinterpretConvert<decltype(&ProfChannelPoll)>(voidFunc);
    if (func == nullptr) {
        Mspti::Common::GetFunction("libascend_hal", "prof_channel_poll", func);
    }
    THROW_FUNC_NOTFOUND(func, "prof_channel_poll", "libascend_hal.so");
    return func(outBuf, num, timeout);
}

DrvError HalGetDeviceInfo(uint32_t deviceId, int32_t moduleType, int32_t infoType, int64_t *value)
{
    pthread_once(&g_once, LoadDriverFunction);
    void *voidFunc = g_driverFuncArray[FUNC_HAL_GET_DEVICE_INFO];
    using HalGetDeviceInfoFunc = std::function<decltype(HalGetDeviceInfo)>;
    HalGetDeviceInfoFunc func = Mspti::Common::ReinterpretConvert<decltype(&HalGetDeviceInfo)>(voidFunc);
    if (func == nullptr) {
        Mspti::Common::GetFunction("libascend_hal", "halGetDeviceInfo", func);
    }
    THROW_FUNC_NOTFOUND(func, "halGetDeviceInfo", "libascend_hal.so");
    return func(deviceId, moduleType, infoType, value);
}

DrvError halGetAPIVersion(int32_t *apiVersion)
{
    pthread_once(&g_once, LoadDriverFunction);
    void *voidFunc = g_driverFuncArray[FUNC_HAL_GET_API_VERSION];
    using halGetAPIVersionFunc = std::function<decltype(halGetAPIVersion)>;
    halGetAPIVersionFunc func = Mspti::Common::ReinterpretConvert<decltype(&halGetAPIVersion)>(voidFunc);
    if (func == nullptr) {
        Mspti::Common::GetFunction("libascend_hal", __FUNCTION__, func);
    }
    THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libascend_hal.so");
    return func(apiVersion);
}

int HalProfDataFlush(unsigned int device_id, unsigned int channel_id, unsigned int *data_len)
{
    pthread_once(&g_once, LoadDriverFunction);
    void *voidFunc = g_driverFuncArray[FUNC_HAL_PROF_DATA_FLUSH];
    using HalProfDataFlushFunc = std::function<decltype(HalProfDataFlush)>;
    HalProfDataFlushFunc func = Mspti::Common::ReinterpretConvert<decltype(&HalProfDataFlush)>(voidFunc);
    if (func == nullptr) {
        Mspti::Common::GetFunction("libascend_hal", "halProfDataFlush", func);
    }
    THROW_FUNC_NOTFOUND(func, "halProfDataFlush", "libascend_hal.so");
    return func(device_id, channel_id, data_len);
}