/* -------------------------------------------------------------------------
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is part of the MindStudio project.
 *
 * MindStudio is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *    http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 * -------------------------------------------------------------------------
*/
#include "csrc/common/inject/mstx_inject.h"

#include <cstring>
#include <mutex>

#include "csrc/activity/activity_manager.h"
#include "csrc/activity/ascend/parser/mstx_parser.h"
#include "csrc/common/utils.h"
#include "csrc/common/plog_manager.h"
#include "csrc/common/inject/mstx_inject.h"
#include "csrc/common/function_loader.h"

using namespace Mspti::Activity;
using namespace Mspti::Common;

namespace Mspti {

static std::string g_defaultDomainName = "default";
std::once_flag MstxDomainMgr::onceFlag;
mstxDomainHandle_t MstxDomainMgr::defaultDomainHandle_t;
std::unordered_set<std::string> MstxDomainMgr::disableDomains_;
std::unordered_map<mstxDomainHandle_t, std::shared_ptr<MstxDomainAttr>> MstxDomainMgr::domainHandleMap_;

MstxDomainMgr* MstxDomainMgr::GetInstance()
{
    static MstxDomainMgr instance;
    std::call_once(onceFlag, [] {
        std::shared_ptr<MstxDomainAttr> domainHandlePtr;
        Mspti::Common::MsptiMakeSharedPtr(domainHandlePtr);
        if (domainHandlePtr == nullptr) {
            MSPTI_LOGE("Failed to malloc memory for domain attribute, Init DomainMgr Fail");
            return;
        }
        Mspti::Common::MsptiMakeSharedPtr(domainHandlePtr->handle);
        if (domainHandlePtr->handle == nullptr) {
            MSPTI_LOGE("Failed to malloc memory for domain handle, Init DomainMgr Fail");
            return;
        }
        Mspti::Common::MsptiMakeSharedPtr(domainHandlePtr->name, g_defaultDomainName);
        if (domainHandlePtr->name == nullptr) {
            MSPTI_LOGE("Failed to malloc memory for domain %s, Init DomainMgr Fail", g_defaultDomainName.c_str());
            return;
        }
        defaultDomainHandle_t = domainHandlePtr->handle.get();
        domainHandleMap_.insert(std::make_pair(domainHandlePtr->handle.get(), domainHandlePtr));
    });
    return &instance;
}

mstxDomainHandle_t MstxDomainMgr::CreateDomainHandle(const char *name)
{
    if (strcmp(g_defaultDomainName.c_str(), name) == 0) {
        MSPTI_LOGE("domain name can not be 'default'!");
        return nullptr;
    }
    std::lock_guard<std::mutex> lk(domainMutex_);
    if (domainHandleMap_.size() > MARK_MAX_CACHE_NUM) {
        MSPTI_LOGE("Cache domain name failed, current size: %zu, limit size: %u",
            domainHandleMap_.size(), MARK_MAX_CACHE_NUM);
        return nullptr;
    }
    for (const auto &iter : domainHandleMap_) {
        if (strncmp(iter.second->name->c_str(), name, iter.second->name->size()) == 0) {
            iter.second->isDestroyed = false;
            return iter.first;
        }
    }
    std::shared_ptr<MstxDomainAttr> domainHandlePtr;
    Mspti::Common::MsptiMakeSharedPtr(domainHandlePtr);
    if (domainHandlePtr == nullptr) {
        MSPTI_LOGE("Failed to malloc memory for domain attribute.");
        return nullptr;
    }
    Mspti::Common::MsptiMakeSharedPtr(domainHandlePtr->handle);
    if (domainHandlePtr->handle == nullptr) {
        MSPTI_LOGE("Failed to malloc memory for domain handle.");
        return nullptr;
    }
    Mspti::Common::MsptiMakeSharedPtr(domainHandlePtr->name, name);
    if (domainHandlePtr->name == nullptr) {
        MSPTI_LOGE("Failed to malloc memory for domain %s.", name);
        return nullptr;
    }
    domainHandleMap_.insert(std::make_pair(domainHandlePtr->handle.get(), domainHandlePtr));
    // domain默认使能,除非用户提前调用msptiActivityDisableMarkerDomain去使能该domain
    if (disableDomains_.count(*domainHandlePtr->name) != 0) {
        domainHandlePtr->isEnabled = false;
    }
    return domainHandlePtr->handle.get();
}

void MstxDomainMgr::DestroyDomainHandle(mstxDomainHandle_t domain)
{
    std::lock_guard<std::mutex> lk(domainMutex_);
    auto iter = domainHandleMap_.find(domain);
    if (iter == domainHandleMap_.end() || iter->second->isDestroyed) {
        MSPTI_LOGW("Input domain is invalid");
        return;
    }
    iter->second->isDestroyed = true;
}

std::shared_ptr<std::string> MstxDomainMgr::GetDomainNameByHandle(mstxDomainHandle_t domain)
{
    std::lock_guard<std::mutex> lk(domainMutex_);
    auto iter = domainHandleMap_.find(domain);
    if (iter == domainHandleMap_.end() || iter->second->isDestroyed) {
        MSPTI_LOGW("Input domain is invalid.");
        return nullptr;
    }
    return iter->second->name;
}

msptiResult MstxDomainMgr::SetMstxDomainEnableStatus(const char* name, bool flag)
{
    std::string nameStr(name);
    std::lock_guard<std::mutex> lk(domainMutex_);
    if (flag) {
        disableDomains_.erase(nameStr);
    } else {
        disableDomains_.insert(nameStr);
    }
    for (const auto &iter : domainHandleMap_) {
        if (nameStr == *(iter.second->name)) {
            iter.second->isEnabled = flag;
            break;
        }
    }
    return MSPTI_SUCCESS;
}

bool MstxDomainMgr::isDomainEnable(mstxDomainHandle_t domainHandle)
{
    std::lock_guard<std::mutex> lk(domainMutex_);
    auto iter = domainHandleMap_.find(domainHandle);
    return iter == domainHandleMap_.end() ? false : iter->second->isEnabled;
}

bool MstxDomainMgr::isDomainEnable(const char* name)
{
    if (!name) {
        return false;
    }
    std::string nameStr(name);
    std::lock_guard<std::mutex> lk(domainMutex_);
    return disableDomains_.count(nameStr) == 0 ? true : false;
}
}

namespace MsptiMstxApi {
class MstxApiInject {
public:
    MstxApiInject() noexcept
    {
        Mspti::Common::RegisterFunction("libprofapi", "ProfRegisterMstxFunc");
        Mspti::Common::RegisterFunction("libprofapi", "ProfEnableMstxFunc");
        MstxRegisterMstxFunc();
    }
    ~MstxApiInject() = default;
};

MstxApiInject g_mstxApiInject;

static bool IsMsgValid(const char* msg)
{
    if (msg == nullptr) {
        MSPTI_LOGE("Input Params msg is null");
        return false;
    }
    if (strnlen(msg, MAX_MARK_MSG_LEN) == MAX_MARK_MSG_LEN) {
        MSPTI_LOGE("Input Params msg length exceeds the maximum value %d", MAX_MARK_MSG_LEN);
        return false;
    }
    return true;
}

void MstxMarkAFunc(const char* msg, AclrtStream stream)
{
    if (!ActivityManager::GetInstance()->IsActivityKindEnable(MSPTI_ACTIVITY_KIND_MARKER)) {
        return;
    }
    if (!IsMsgValid(msg)) {
        return;
    }
    if (!Mspti::MstxDomainMgr::GetInstance()->isDomainEnable(Mspti::g_defaultDomainName.c_str())) {
        return;
    }
    if (Mspti::Parser::MstxParser::GetInstance()->ReportMark(msg, stream, Mspti::g_defaultDomainName.c_str()) !=
        MSPTI_SUCCESS) {
        MSPTI_LOGE("Report Mark data failed.");
    }
}

uint64_t MstxRangeStartAFunc(const char* msg, AclrtStream stream)
{
    if (!ActivityManager::GetInstance()->IsActivityKindEnable(MSPTI_ACTIVITY_KIND_MARKER)) {
        return MSTX_INVALID_RANGE_ID;
    }
    if (!IsMsgValid(msg)) {
        return MSTX_INVALID_RANGE_ID;
    }
    if (!Mspti::MstxDomainMgr::GetInstance()->isDomainEnable((Mspti::g_defaultDomainName.c_str()))) {
        return MSTX_INVALID_RANGE_ID;
    }
    uint64_t markId = MSTX_INVALID_RANGE_ID;
    if (Mspti::Parser::MstxParser::GetInstance()->ReportRangeStartA(msg, stream,
        markId, Mspti::g_defaultDomainName.c_str()) != MSPTI_SUCCESS) {
        MSPTI_LOGE("Report RangeStart data failed.");
        return MSTX_INVALID_RANGE_ID;
    }
    return markId;
}

void MstxRangeEndFunc(uint64_t rangeId)
{
    if (!ActivityManager::GetInstance()->IsActivityKindEnable(MSPTI_ACTIVITY_KIND_MARKER)) {
        return;
    }
    if (!Mspti::MstxDomainMgr::GetInstance()->isDomainEnable((Mspti::g_defaultDomainName.c_str()))) {
        return;
    }
    if (Mspti::Parser::MstxParser::GetInstance()->ReportRangeEnd(rangeId) != MSPTI_SUCCESS) {
        MSPTI_LOGE("Report RangeEnd data failed.");
    }
}

mstxDomainHandle_t MstxDomainCreateAFunc(const char* name)
{
    if (!IsMsgValid(name)) {
        return nullptr;
    }
    return Mspti::MstxDomainMgr::GetInstance()->CreateDomainHandle(name);
}

void MstxDomainDestroyFunc(mstxDomainHandle_t domain)
{
    Mspti::MstxDomainMgr::GetInstance()->DestroyDomainHandle(domain);
}

void MstxDomainMarkAFunc(mstxDomainHandle_t domain, const char* msg, AclrtStream stream)
{
    if (!ActivityManager::GetInstance()->IsActivityKindEnable(MSPTI_ACTIVITY_KIND_MARKER)) {
        return;
    }
    if (!Mspti::MstxDomainMgr::GetInstance()->isDomainEnable(domain)) {
        return;
    }
    if ((!IsMsgValid(msg))) {
        return;
    }
    auto namePtr = Mspti::MstxDomainMgr::GetInstance()->GetDomainNameByHandle(domain);
    if (namePtr == nullptr) {
        return;
    }
    if (Mspti::Parser::MstxParser::GetInstance()->ReportMark(msg, stream, namePtr->c_str()) !=
        MSPTI_SUCCESS) {
        MSPTI_LOGE("Report Mark data failed.");
    }
}

uint64_t MstxDomainRangeStartAFunc(mstxDomainHandle_t domain, const char* msg, AclrtStream stream)
{
    if (!ActivityManager::GetInstance()->IsActivityKindEnable(MSPTI_ACTIVITY_KIND_MARKER)) {
        return MSTX_INVALID_RANGE_ID;
    }
    if (!Mspti::MstxDomainMgr::GetInstance()->isDomainEnable(domain)) {
        return MSTX_INVALID_RANGE_ID;
    }
    if ((!IsMsgValid(msg))) {
        return MSTX_INVALID_RANGE_ID;
    }
    auto namePtr = Mspti::MstxDomainMgr::GetInstance()->GetDomainNameByHandle(domain);
    if (namePtr == nullptr) {
        return MSTX_INVALID_RANGE_ID;
    }
    uint64_t markId = MSTX_INVALID_RANGE_ID;
    if (Mspti::Parser::MstxParser::GetInstance()->ReportRangeStartA(msg, stream, markId, namePtr->c_str()) !=
        MSPTI_SUCCESS) {
        MSPTI_LOGE("Report RangeStart data failed.");
        return MSTX_INVALID_RANGE_ID;
    }
    return markId;
}

void MstxDomainRangeEndFunc(mstxDomainHandle_t domain, uint64_t rangeId)
{
    if (!ActivityManager::GetInstance()->IsActivityKindEnable(MSPTI_ACTIVITY_KIND_MARKER)) {
        return;
    }
    if (!Mspti::MstxDomainMgr::GetInstance()->isDomainEnable(domain)) {
        MSPTI_LOGW("Domain is disable, range end will not take effect");
        return;
    }
    if (Mspti::MstxDomainMgr::GetInstance()->GetDomainNameByHandle(domain) == nullptr) {
        return;
    }
    if (Mspti::Parser::MstxParser::GetInstance()->ReportRangeEnd(rangeId) != MSPTI_SUCCESS) {
        MSPTI_LOGE("Report RangeEnd data failed.");
    }
}

void SetMstxModuleCoreApi(MstxFuncTable outTable, unsigned int size)
{
    if (size >= static_cast<unsigned int>(MSTX_FUNC_MARKA)) {
        *(outTable[MSTX_FUNC_MARKA]) = (MstxFuncPointer)MstxMarkAFunc;
    }
    if (size >= static_cast<unsigned int>(MSTX_FUNC_RANGE_STARTA)) {
        *(outTable[MSTX_FUNC_RANGE_STARTA]) = (MstxFuncPointer)MstxRangeStartAFunc;
    }
    if (size >= static_cast<unsigned int>(MSTX_FUNC_RANGE_END)) {
        *(outTable[MSTX_FUNC_RANGE_END]) = (MstxFuncPointer)MstxRangeEndFunc;
    }
}

void SetMstxModuleCoreDomainApi(MstxFuncTable outTable, unsigned int size)
{
    if (size >= static_cast<unsigned int>(MSTX_FUNC_DOMAIN_CREATEA)) {
        *(outTable[MSTX_FUNC_DOMAIN_CREATEA]) = (MstxFuncPointer)MstxDomainCreateAFunc;
    }
    if (size >= static_cast<unsigned int>(MSTX_FUNC_DOMAIN_DESTROY)) {
        *(outTable[MSTX_FUNC_DOMAIN_DESTROY]) = (MstxFuncPointer)MstxDomainDestroyFunc;
    }
    if (size >= static_cast<unsigned int>(MSTX_FUNC_DOMAIN_MARKA)) {
        *(outTable[MSTX_FUNC_DOMAIN_MARKA]) = (MstxFuncPointer)MstxDomainMarkAFunc;
    }
    if (size >= static_cast<unsigned int>(MSTX_FUNC_DOMAIN_RANGE_STARTA)) {
        *(outTable[MSTX_FUNC_DOMAIN_RANGE_STARTA]) = (MstxFuncPointer)MstxDomainRangeStartAFunc;
    }
    if (size >= static_cast<unsigned int>(MSTX_FUNC_DOMAIN_RANGE_END)) {
        *(outTable[MSTX_FUNC_DOMAIN_RANGE_END]) = (MstxFuncPointer)MstxDomainRangeEndFunc;
    }
}

int GetModuleTableFunc(MstxGetModuleFuncTableFunc getFuncTable)
{
    int retVal = MSTX_SUCCESS;
    unsigned int outSize = 0;
    MstxFuncTable outTable;
    static std::vector<unsigned int> CheckOutTableSizes = {
        0,
        MSTX_FUNC_END,
        MSTX_FUNC_DOMAIN_END,
        0
    };
    for (size_t i = MSTX_API_MODULE_CORE; i < MSTX_API_MODULE_SIZE; i++) {
        if (getFuncTable(static_cast<MstxFuncModule>(i), &outTable, &outSize) != MSPTI_SUCCESS) {
            MSPTI_LOGW("Failed to get func table for module %zu", i);
            continue;
        }
        switch (i) {
            case MSTX_API_MODULE_CORE:
                SetMstxModuleCoreApi(outTable, outSize);
                break;
            case MSTX_API_MODULE_CORE_DOMAIN:
                SetMstxModuleCoreDomainApi(outTable, outSize);
                break;
            default:
                MSPTI_LOGE("Invalid func module type");
                retVal = MSTX_FAIL;
                break;
        }
        if (retVal == MSTX_FAIL) {
            break;
        }
        MSPTI_LOGI("Succeed to get func table for module %zu", i);
    }
    return retVal;
}

msptiResult MstxRegisterMstxFunc()
{
    MSPTI_LOGI("mspti registerMstxFunc to libprofapi.so");
    ProfRegisterMstxFunc(InitInjectionMstx, PROF_MODULE_MSPTI);
    return MSPTI_SUCCESS;
}

void MsptiEnableMstxFunc()
{
    MSPTI_LOGI("mspti enable mstxFuncs");
    ProfEnableMstxFunc(PROF_MODULE_MSPTI);
}

void MsptiDisableMstxFunc()
{
    MSPTI_LOGI("mspti disable mstxFuncs");
    ProfEnableMstxFunc(PROF_MODULE_MSPROF);
}

int InitInjectionMstx(MstxGetModuleFuncTableFunc getFuncTable)
{
    if (getFuncTable == nullptr) {
        MSPTI_LOGE("Input null mstx getfunctable pointer");
        return MSTX_FAIL;
    }
    if (MsptiMstxApi::GetModuleTableFunc(getFuncTable) != MSPTI_SUCCESS) {
        MSPTI_LOGE("Failed to init mstx funcs.");
        return MSTX_FAIL;
    }
    return MSPTI_SUCCESS;
}
}

void ProfRegisterMstxFunc(MstxInitInjectionFunc mstxInitFunc, ProfModule module)
{
    using profRegisterMstxFunc = std::function<decltype(ProfRegisterMstxFunc)>;
    static profRegisterMstxFunc func = nullptr;
    if (func == nullptr) {
        Mspti::Common::GetFunction("libprofapi", __FUNCTION__, func);
    }
    if (func == nullptr) {
        MSPTI_LOGE("ProfRegisterMstxFunc can not be founded in libprofapi.so");
        return;
    }
    MSPTI_LOGI("mspti registerMstxFunc to libprofapi.so");
    return func(mstxInitFunc, module);
}

void ProfEnableMstxFunc(ProfModule module)
{
    using enableMstxFunc = std::function<decltype(ProfEnableMstxFunc)>;
    static enableMstxFunc func = nullptr;
    if (func == nullptr) {
        Mspti::Common::GetFunction("libprofapi", __FUNCTION__, func);
    }
    if (func == nullptr) {
        MSPTI_LOGE("ProfEnableMstxFunc can not be founded in libprofapi.so");
        return;
    }
    MSPTI_LOGI("mspti enable mstxFuncs");
    return func(module);
}

msptiResult msptiActivityEnableMarkerDomain(const char* name)
{
    if (!MsptiMstxApi::IsMsgValid(name)) {
        MSPTI_LOGE("domainHandle is nullptr, check your input params");
        return MSPTI_ERROR_INVALID_PARAMETER;
    }
    MSPTI_LOGI("Set domain %s collection on", name);
    return Mspti::MstxDomainMgr::GetInstance()->SetMstxDomainEnableStatus(name, true);
}

msptiResult msptiActivityDisableMarkerDomain(const char* name)
{
    if (!MsptiMstxApi::IsMsgValid(name)) {
        MSPTI_LOGE("domainHandle is nullptr, check your input params");
        return MSPTI_ERROR_INVALID_PARAMETER;
    }
    MSPTI_LOGI("Set domain %s collection off", name);
    return Mspti::MstxDomainMgr::GetInstance()->SetMstxDomainEnableStatus(name, false);
}