#include "torch_npu/csrc/framework/interface/MsProfilerInterface.h"
#include "torch_npu/csrc/core/npu/NPUException.h"
#include "torch_npu/csrc/core/npu/register/FunctionLoader.h"
#include "third_party/acl/inc/acl/acl_prof.h"

namespace at_npu {
namespace native {

// Rebind the generic loader macros for this file so that they target libmsprofiler.
// The #undef lines drop any definition an earlier header may already have provided.
#undef TORCH_NPU_LOAD_FUNC
// TORCH_NPU_LOAD_FUNC(f) expands to TORCH_NPU_REGISTER_FUNCTION bound to libmsprofiler.
#define TORCH_NPU_LOAD_FUNC(funcName) \
  TORCH_NPU_REGISTER_FUNCTION(libmsprofiler, funcName)

#undef TORCH_NPU_GET_FUNC
// TORCH_NPU_GET_FUNC(f) expands to TORCH_NPU_GET_FUNCTION bound to libmsprofiler.
// The wrappers in this file treat a nullptr result as "symbol not available".
#define TORCH_NPU_GET_FUNC(funcName)              \
  TORCH_NPU_GET_FUNCTION(libmsprofiler, funcName)


// Register libmsprofiler (dlopen flags RTLD_LAZY | RTLD_GLOBAL) and every aclprof*
// symbol that the wrapper functions below look up via TORCH_NPU_GET_FUNC.
// NOTE(review): presumably registration only records names and the actual symbol
// resolution happens on lookup -- confirm against FunctionLoader.h.
TORCH_NPU_REGISTER_LIBRARY(libmsprofiler, RTLD_LAZY | RTLD_GLOBAL)
TORCH_NPU_LOAD_FUNC(aclprofWarmup)
TORCH_NPU_LOAD_FUNC(aclprofSetConfig)
// Both the V2 and the original GetSupportedFeatures exports are registered;
// AclprofGetSupportedFeatures prefers V2 and falls back to the original.
TORCH_NPU_LOAD_FUNC(aclprofGetSupportedFeatures)
TORCH_NPU_LOAD_FUNC(aclprofGetSupportedFeaturesV2)
TORCH_NPU_LOAD_FUNC(aclprofRegisterDeviceCallback)
TORCH_NPU_LOAD_FUNC(aclprofMarkEx)

/**
 * Calls aclprofRegisterDeviceCallback from libmsprofiler.
 *
 * The symbol is looked up on the first call and cached. While it cannot be
 * found, the cache stays empty and the lookup is repeated on the next call.
 *
 * @return The result of aclprofRegisterDeviceCallback, or
 *         ACL_ERROR_PROF_MODULES_UNSUPPORTED when the symbol is unavailable.
 */
aclError AclProfilingRegisterDeviceCallback()
{
    using RegisterDeviceCallbackFn = aclError (*)();
    static RegisterDeviceCallbackFn registerFn = nullptr;

    // Fast path: symbol already resolved by an earlier call.
    if (registerFn != nullptr) {
        return registerFn();
    }

    registerFn = reinterpret_cast<RegisterDeviceCallbackFn>(
        TORCH_NPU_GET_FUNC(aclprofRegisterDeviceCallback));
    if (registerFn == nullptr) {
        return ACL_ERROR_PROF_MODULES_UNSUPPORTED;
    }
    return registerFn();
}

/**
 * Calls aclprofWarmup from libmsprofiler.
 *
 * The symbol is looked up on the first call and cached. While it cannot be
 * found, the cache stays empty and the lookup is repeated on the next call.
 * A missing symbol is reported through the return code, not an exception.
 *
 * @param profilerConfig Passed through unchanged to aclprofWarmup.
 * @return The result of aclprofWarmup, or ACL_ERROR_PROF_MODULES_UNSUPPORTED
 *         when the symbol is unavailable.
 */
aclError AclProfilingWarmup(const aclprofConfig *profilerConfig)
{
    using AclProfWarmupFunc = aclError (*)(const aclprofConfig *);
    static AclProfWarmupFunc func = nullptr;
    if (func == nullptr) {
        func = reinterpret_cast<AclProfWarmupFunc>(TORCH_NPU_GET_FUNC(aclprofWarmup));
        if (func == nullptr) {
            return ACL_ERROR_PROF_MODULES_UNSUPPORTED;
        }
    }
    // func is guaranteed non-null here; the former TORCH_CHECK on it was unreachable.
    return func(profilerConfig);
}

/**
 * Calls aclprofSetConfig from libmsprofiler.
 *
 * The symbol is looked up on the first call and cached. While it cannot be
 * found, the cache stays empty and the lookup is repeated on the next call.
 * A missing symbol is reported through the return code, not an exception.
 *
 * @param configType   Passed through unchanged to aclprofSetConfig.
 * @param config       Passed through unchanged to aclprofSetConfig.
 * @param configLength Passed through unchanged to aclprofSetConfig.
 * @return The result of aclprofSetConfig, or ACL_ERROR_PROF_MODULES_UNSUPPORTED
 *         when the symbol is unavailable.
 */
aclError AclprofSetConfig(aclprofConfigType configType, const char* config, size_t configLength)
{
    using AclprofSetConfigFunc = aclError (*)(aclprofConfigType, const char *, size_t);
    static AclprofSetConfigFunc func = nullptr;
    if (func == nullptr) {
        func = reinterpret_cast<AclprofSetConfigFunc>(TORCH_NPU_GET_FUNC(aclprofSetConfig));
        if (func == nullptr) {
            return ACL_ERROR_PROF_MODULES_UNSUPPORTED;
        }
    }
    // func is guaranteed non-null here; the former TORCH_CHECK on it was unreachable.
    return func(configType, config, configLength);
}

/**
 * Queries the profiler's supported features through libmsprofiler.
 *
 * aclprofGetSupportedFeaturesV2 is preferred. The original
 * aclprofGetSupportedFeatures is used only when V2 cannot be resolved. The
 * resolved pointer is cached. While neither symbol is available, both lookups
 * are repeated on the next call.
 *
 * @param featuresSize Passed through unchanged to the resolved function.
 * @param featuresData Passed through unchanged to the resolved function.
 * @return The resolved function's result, or ACL_ERROR_PROF_MODULES_UNSUPPORTED
 *         when neither symbol is available.
 */
aclError AclprofGetSupportedFeatures(size_t* featuresSize, void** featuresData)
{
    // NOTE(review): both exports are called through this one signature, which assumes
    // V2 and the original share the same parameter list -- confirm against acl_prof.h.
    using GetSupportedFeaturesFn = aclError (*)(size_t*, void**);
    static GetSupportedFeaturesFn getFeaturesFn = nullptr;

    if (getFeaturesFn == nullptr) {
        getFeaturesFn = reinterpret_cast<GetSupportedFeaturesFn>(
            TORCH_NPU_GET_FUNC(aclprofGetSupportedFeaturesV2));
    }
    if (getFeaturesFn == nullptr) {
        // V2 missing: fall back to the original entry point.
        getFeaturesFn = reinterpret_cast<GetSupportedFeaturesFn>(
            TORCH_NPU_GET_FUNC(aclprofGetSupportedFeatures));
    }

    if (getFeaturesFn == nullptr) {
        return ACL_ERROR_PROF_MODULES_UNSUPPORTED;
    }
    return getFeaturesFn(featuresSize, featuresData);
}

/**
 * Calls aclprofMarkEx from libmsprofiler.
 *
 * The symbol is looked up on the first call and cached. While it cannot be
 * found, the cache stays empty and the lookup is repeated on the next call.
 * A missing symbol is reported through the return code, not an exception.
 *
 * @param msg    Passed through unchanged to aclprofMarkEx.
 * @param msgLen Passed through unchanged to aclprofMarkEx.
 * @param stream Passed through unchanged to aclprofMarkEx.
 * @return The result of aclprofMarkEx, or ACL_ERROR_PROF_MODULES_UNSUPPORTED
 *         when the symbol is unavailable.
 */
aclError AclProfilingMarkEx(const char *msg, size_t msgLen, aclrtStream stream)
{
    // Named like the sibling wrappers' function-pointer aliases (AclProf...Func).
    using AclProfMarkExFunc = aclError (*)(const char *, size_t, aclrtStream);
    static AclProfMarkExFunc func = nullptr;
    if (func == nullptr) {
        func = reinterpret_cast<AclProfMarkExFunc>(TORCH_NPU_GET_FUNC(aclprofMarkEx));
        if (func == nullptr) {
            return ACL_ERROR_PROF_MODULES_UNSUPPORTED;
        }
    }
    // func is guaranteed non-null here; the former TORCH_CHECK on it was unreachable.
    return func(msg, msgLen, stream);
}

}
}