/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
 *
 * MindIE is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#define __FILENAME__ (strrchr("/" __FILE__, '/') + 1)
#include <ATen/Tensor.h>
#include <torch_npu/csrc/framework/utils/CalcuOpUtil.h>
#include <dlfcn.h>

#include <fstream>

#include "find_op_path.h"

/**
 * Canonicalizes a filesystem path with realpath(3).
 *
 * @param path candidate path (relative or absolute).
 * @return the resolved absolute path, or an empty string when `path` is empty,
 *         longer than PATH_MAX, or cannot be resolved by realpath().
 */
std::string RealPath(const std::string &path)
{
    const bool unusable = path.empty() || path.size() > PATH_MAX;
    if (unusable) {
        return std::string();
    }
    char resolved[PATH_MAX] = {};
    const char *outcome = realpath(path.c_str(), resolved);
    return (outcome != nullptr) ? std::string(resolved) : std::string();
}

/**
 * Splits `s` on every occurrence of the delimiter `del`.
 *
 * Empty fields are preserved ("a::b" split on ":" gives {"a", "", "b"}), and the
 * result always holds at least one element (an empty input yields {""}).
 * Multi-character delimiters are skipped in full. An empty delimiter cannot
 * split anything and returns {s}.
 *
 * @param s   the string to split.
 * @param del the delimiter.
 * @return the fields between delimiters, in order.
 */
std::vector<std::string> SplitStr(std::string s, const std::string &del)
{
    std::vector<std::string> path_list;
    if (del.empty()) {
        // find("") matches at every position; without this guard the scan
        // below would never advance.
        path_list.push_back(s);
        return path_list;
    }
    // Use size_type/npos instead of int/-1 to avoid narrowing find()'s result.
    std::string::size_type start = 0;
    std::string::size_type end = s.find(del);
    while (end != std::string::npos) {
        path_list.push_back(s.substr(start, end - start));
        // Skip the whole delimiter, not just one character.
        start = end + del.size();
        end = s.find(del, start);
    }
    path_list.push_back(s.substr(start));
    return path_list;
}
/**
 * Splits a PATH-style, ':'-separated list into its individual entries.
 *
 * @param pathStr the colon-separated list.
 * @return the entries as produced by SplitStr (empty fields are kept).
 */
std::vector<std::string> ProcessPathList(const std::string& pathStr)
{
    const std::string separator(":");
    std::vector<std::string> entries = SplitStr(pathStr, separator);
    return entries;
}

/**
 * Appends the op-api library sub-directory "/op_api/lib/" to every entry, in place.
 *
 * @param pathList directories to extend; modified in place.
 */
void AppendLibPathSuffix(std::vector<std::string>& pathList)
{
    static const char kLibSuffix[] = "/op_api/lib/";
    for (std::size_t idx = 0; idx < pathList.size(); ++idx) {
        pathList[idx].append(kLibSuffix);
    }
}

/**
 * Builds the list of custom op-api library directories from a ':'-separated
 * ASCEND_CUSTOM_OPP_PATH value.
 *
 * @param ascendCustomOppPath raw environment value; nullptr is tolerated and
 *        yields an empty list.
 * @return "<entry>/op_api/lib/" for every non-empty entry, in input order.
 */
std::vector<std::string> ProcessCustomLibPath(const char* ascendCustomOppPath)
{
    if (ascendCustomOppPath == nullptr) {
        return std::vector<std::string>();
    }
    std::vector<std::string> customLibPathList;
    // ProcessPathList never returns an empty vector (an empty value splits into
    // {""}), so empty components from "", "a::b" or a trailing ':' are dropped
    // here. Left in, they would turn into the root path "/op_api/lib/".
    for (const auto &entry : ProcessPathList(std::string(ascendCustomOppPath))) {
        if (!entry.empty()) {
            customLibPathList.push_back(entry);
        }
    }
    AppendLibPathSuffix(customLibPathList);
    return customLibPathList;
}

/**
 * Reads ASCEND_CUSTOM_OPP_PATH and converts it into custom op-api library
 * directories via ProcessCustomLibPath.
 *
 * @return the directories, or an empty list (with a warning) when the
 *         environment variable is unset.
 */
std::vector<std::string> GetCustomLibPath()
{
    const char *envValue = std::getenv("ASCEND_CUSTOM_OPP_PATH");
    if (envValue != nullptr) {
        return ProcessCustomLibPath(envValue);
    }
    ASCEND_LOGW("ASCEND_CUSTOM_OPP_PATH is not exists");
    return {};
}

/**
 * Resolves "<vendorsPath>/config.ini" to a canonical path.
 *
 * @param vendorsPath the vendors directory.
 * @return the canonical path, or an empty string if it cannot be resolved.
 */
std::string GetVendorsConfigFilePath(const std::string& vendorsPath)
{
    std::string candidate(vendorsPath);
    candidate += "/config.ini";
    return RealPath(candidate);
}

/**
 * Checks whether `path` exists, using access(2) with F_OK.
 *
 * @param path the path to probe.
 * @return true if the path exists; false if it does not, is empty, or is
 *         longer than PATH_MAX.
 */
bool IsFileExist(const std::string &path)
{
    if (!path.empty() && path.size() <= PATH_MAX) {
        return access(path.c_str(), F_OK) == 0;
    }
    return false;
}

/**
 * Checks that the vendors config.ini path is non-empty and exists, logging a
 * warning otherwise.
 *
 * @param configFile resolved path of config.ini (may be empty).
 * @return true when the file can be used.
 */
bool ValidateVendorsConfigFile(const std::string& configFile)
{
    const bool usable = !configFile.empty() && IsFileExist(configFile);
    if (!usable) {
        ASCEND_LOGW("config.ini is not exists or the path length is more than %d", PATH_MAX);
    }
    return usable;
}

/**
 * Returns the first line of `configFile` that starts with "load_priority=".
 *
 * @param configFile path of the vendors config.ini.
 * @return the full matching line, or an empty string when the file cannot be
 *         opened or contains no such line.
 */
std::string ReadLoadPriorityLine(const std::string& configFile)
{
    const std::string key = "load_priority=";
    std::ifstream ifs(configFile);
    std::string line;
    while (std::getline(ifs, line)) {
        if (line.compare(0, key.size(), key) == 0) {
            return line;
        }
    }
    // Do not return `line` here. When the file has no trailing newline, the
    // final failed getline() leaves the last line in the buffer, and that
    // unrelated text would otherwise be parsed as a vendor list.
    return std::string();
}

/**
 * Extracts the value part of a "load_priority=<value>" line.
 *
 * @param line a config line, expected to start with "load_priority=".
 * @return the text after the prefix with trailing '\r' characters removed (so
 *         CRLF-edited files do not corrupt the last vendor name), or an empty
 *         string when the line does not start with the prefix.
 */
std::string ExtractLoadPriorityValue(const std::string& line)
{
    const std::string head = "load_priority=";
    if (line.compare(0, head.size(), head) != 0) {
        // A line without the key is not a priority list; returning it verbatim
        // would feed arbitrary config text into the vendor-path resolution.
        return std::string();
    }
    std::string result = line.substr(head.size());
    while (!result.empty() && result.back() == '\r') {
        result.pop_back();
    }
    return result;
}

/**
 * Resolves a ','-separated vendor priority list into canonical op-api library
 * directories: "<vendorsPath>/<vendor>/op_api/lib/".
 *
 * @param vendorsPath the vendors root directory.
 * @param line        the priority value, e.g. "vendorA,vendorB".
 * @return resolved directories in priority order. Empty vendor names and
 *         directories RealPath cannot resolve are skipped.
 */
std::vector<std::string> ProcessVendorsList(const std::string& vendorsPath, const std::string& line)
{
    std::vector<std::string> vendorLibPaths;
    for (const auto &vendor : SplitStr(line, ",")) {
        if (vendor.empty()) {
            // "" would resolve to "<vendorsPath>//op_api/lib/", which is not a vendor.
            continue;
        }
        std::string resolved = RealPath(vendorsPath + "/" + vendor + "/op_api/lib/");
        // An unresolved entry would be "", and a later "<entry>/libcust_opapi.so"
        // lookup would then probe the filesystem root.
        if (!resolved.empty()) {
            vendorLibPaths.push_back(resolved);
        }
    }
    return vendorLibPaths;
}

std::vector<std::string> ParseVendorsConfig(const std::string& vendorsPath)
{
    std::string vendorsConfigFile = GetVendorsConfigFilePath(vendorsPath);
    if (!ValidateVendorsConfigFile(vendorsConfigFile)) {
        return {};
    }
    std::string line = ReadLoadPriorityLine(vendorsConfigFile);
    std::string priorityValue = ExtractLoadPriorityValue(line);
    return ProcessVendorsList(vendorsPath, priorityValue);
}

/**
 * Returns the vendor op-api library directories configured under
 * "$ASCEND_OPP_PATH/vendors".
 *
 * @return the directories from ParseVendorsConfig, or an empty list (with a
 *         warning) when ASCEND_OPP_PATH is unset.
 */
std::vector<std::string> GetDefaultCustomLibPath()
{
    const char *ascendOppPath = std::getenv("ASCEND_OPP_PATH");
    if (ascendOppPath == nullptr) {
        ASCEND_LOGW("ASCEND_OPP_PATH is not exists");
        return std::vector<std::string>();
    }
    std::string vendorsPath(ascendOppPath);
    vendorsPath += "/vendors";
    return ParseVendorsConfig(vendorsPath);
}

/**
 * @return the file name of the built-in op-api shared library that
 *         GetFuncFromDefaultLib passes to dlopen().
 */
const char *GetOpApiLibName(void)
{
    static constexpr const char kOpApiLibName[] = "libopapi.so";
    return kOpApiLibName;
}

/**
 * @return the file name of the custom op-api shared library, which is appended
 *         to each candidate library directory.
 */
const char *GetCustOpApiLibName(void)
{
    static constexpr const char kCustOpApiLibName[] = "libcust_opapi.so";
    return kCustOpApiLibName;
}

/**
 * Resolves "<libPath>/libcust_opapi.so" to a canonical path.
 *
 * @param libPath a custom op-api library directory.
 * @return the canonical library path, or an empty string if it cannot be resolved.
 */
std::string GetCustomOpApiLibPath(const std::string& libPath)
{
    std::string candidate = libPath;
    candidate.append("/").append(GetCustOpApiLibName());
    return RealPath(candidate);
}

/**
 * dlopen()s `libName` with RTLD_LAZY.
 *
 * @param libName library file name or path.
 * @return the library handle, or nullptr (after logging dlerror()) on failure.
 */
void *GetOpApiLibHandler(const char *libName)
{
    void *libHandle = dlopen(libName, RTLD_LAZY);
    if (libHandle != nullptr) {
        return libHandle;
    }
    ASCEND_LOGW("dlopen %s failed, error:%s.", libName, dlerror());
    return nullptr;
}

template<typename T = void>
void *GetOpApiFuncAddrInLib(T *handler, const char *libName, const std::string& apiName)
{
    auto funcAddr = dlsym(handler, apiName.c_str());
    if (funcAddr == nullptr) {
        ASCEND_LOGW("dlsym %s from %s failed, error:%s.", apiName, libName, dlerror());
    }
    return funcAddr;
}

/**
 * Resolves `apiName` in the built-in op-api library.
 *
 * The library is dlopen()ed once, on the first call, and the handle is cached
 * in a function-local static for all later lookups.
 *
 * @param apiName symbol to resolve.
 * @return the symbol address, or nullptr if the library or symbol is unavailable.
 */
void* GetFuncFromDefaultLib(const std::string& apiName)
{
    static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName());
    return (opApiHandler != nullptr)
        ? GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName)
        : nullptr;
}

/**
 * dlopen()s a resolved default-vendor custom op-api library.
 *
 * @param defaultCustOpApiLib canonical library path; empty means "not found".
 * @return the handle, or nullptr when the path is empty or dlopen() fails.
 */
void* LoadDefaultCustomOpApiHandler(const std::string& defaultCustOpApiLib)
{
    return defaultCustOpApiLib.empty() ? nullptr : GetOpApiLibHandler(defaultCustOpApiLib.c_str());
}

/**
 * dlopen()s a resolved custom op-api library from ASCEND_CUSTOM_OPP_PATH.
 *
 * @param custOpApiLib canonical library path; empty means "not found".
 * @return the handle, or nullptr when the path is empty or dlopen() fails.
 */
void* LoadCustomOpApiHandler(const std::string& custOpApiLib)
{
    void *libHandle = nullptr;
    if (!custOpApiLib.empty()) {
        libHandle = GetOpApiLibHandler(custOpApiLib.c_str());
    }
    return libHandle;
}

void* FindFuncInCustomLibPath(const char* apiName, const std::string& libPath)
{
    auto custOpApiLib = GetCustomOpApiLibPath(libPath);
    auto custOpApiHandler = LoadCustomOpApiHandler(custOpApiLib);
    if (custOpApiHandler != nullptr) {
        auto funcAddr = GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName);
        if (funcAddr != nullptr) {
            ASCEND_LOGI("%s is found in %s.", apiName, custOpApiLib.c_str());
            return funcAddr;
        }
    }
    return nullptr;
}

/**
 * Resolves "<libPath>/libcust_opapi.so" inside a default vendor directory.
 *
 * @param libPath a vendor op-api library directory.
 * @return the canonical library path, or an empty string if it cannot be resolved.
 */
std::string GetDefaultCustomOpApiLibPath(const std::string& libPath)
{
    const std::string libFile = libPath + "/" + GetCustOpApiLibName();
    return RealPath(libFile);
}

void* FindFuncInDefaultLibPath(const char* apiName, const std::string& libPath)
{
    auto defaultCustOpApiLib = GetDefaultCustomOpApiLibPath(libPath);
    auto custOpApiHandler = LoadDefaultCustomOpApiHandler(defaultCustOpApiLib);
    if (custOpApiHandler != nullptr) {
        auto funcAddr = GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName);
        if (funcAddr != nullptr) {
            ASCEND_LOGI("%s is found in %s.", apiName, defaultCustOpApiLib.c_str());
            return funcAddr;
        }
    }
    return nullptr;
}

void* AllocateWorkspace(uint64_t workspaceSize, at::Tensor& workspaceTensor)
{
    if (workspaceSize == 0) {
        return nullptr;
    }

    at::TensorOptions options = at::TensorOptions(torch_npu::utils::get_npu_device_type());
    workspaceTensor = at::empty({static_cast<int64_t>(workspaceSize)}, options.dtype(c10::kByte));
    return const_cast<void*>(workspaceTensor.storage().data());
}

// Custom op-api library directories built from ASCEND_CUSTOM_OPP_PATH
// (each entry gets the "/op_api/lib/" suffix via ProcessCustomLibPath).
// NOTE(review): both globals are computed during dynamic initialization, in
// declaration order. If code elsewhere reads them from its own static
// initializers, the initialization order is unspecified; confirm no such use.
const std::vector<std::string> g_customLibPath = GetCustomLibPath();
// Vendor op-api library directories from the "load_priority=" entry of
// $ASCEND_OPP_PATH/vendors/config.ini, in the listed order.
const std::vector<std::string> g_defaultCustomLibPath = GetDefaultCustomLibPath();