/*
 * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ecmascript/jit/jit.h"
#include "ecmascript/base/config.h"
#include "ecmascript/jit/jit_task.h"
#if ECMASCRIPT_ENABLE_ARK_STEED
#include "ecmascript/arksteed/arksteed_task.h"
#endif
#include "ecmascript/compiler/bytecodes.h"
#include "ecmascript/dfx/vmstat/jit_warmup_profiler.h"
#include "ecmascript/ohos/jit_tools.h"
#include "ecmascript/platform/os.h"
#include "ecmascript/checkpoint/thread_state_transition.h"

namespace panda::ecmascript {

// Returns the process-wide Jit object. The function-local static is constructed on first use.
Jit *Jit::GetInstance()
{
    static Jit singleton;
    return &singleton;
}

// Lazily creates the JitResources holder and resolves the JIT library.
// Calling it again after the holder exists is a no-op.
void Jit::CreateJitResources()
{
    if (jitResources_ != nullptr) {
        return;
    }
    jitResources_ = std::make_unique<JitResources>();
    jitResources_->ResolveLib();
}

// True only when the resources holder exists and reports that its library has been resolved.
bool Jit::IsLibResourcesResolved() const
{
    return jitResources_ != nullptr && jitResources_->IsLibResolved();
}

void Jit::PreFork()
{
    CreateJitResources();
}

// Decides, after fork, whether JIT may be enabled for the app `bundleName`, and configures it if so.
// JIT is allowed only when every gate below passes. The gates are combined with `&=`, so every check is
// still evaluated, in this order, even after an earlier one has failed.
void Jit::SetJitEnablePostFork(EcmaVM *vm, const std::string &bundleName)
{
    JSRuntimeOptions &options = vm->GetJSOptions();

    bool allowed = ohos::JitTools::GetJitEscapeDisable() || !AotCrashInfo::IsJitEscape();
    allowed &= ohos::EnableAotJitListHelper::GetInstance()->IsEnableJit(bundleName);
    allowed &= !options.GetAOTHasException();
    allowed &= ohos::JitTools::IsSupportJitCodeSigner();
    allowed &= HasJitFortACL();
    if (!allowed) {
        return;
    }

    // Each tier additionally requires the asm interpreter.
    bool wantFastJit = options.IsEnableJIT() && options.GetEnableAsmInterpreter();
    bool wantBaselineJit = options.IsEnableBaselineJIT() && options.GetEnableAsmInterpreter();
    options.SetEnableJitFrame(ohos::JitTools::GetJitFrameEnable());
    options.SetEnableAPPJIT(true);
    isApp_ = true;

    // App hotness threshold; 150 is handed to JitTools as the default value.
    constexpr uint32_t defaultHotnessThreshold = 150;
    uint32_t appThreshold = ohos::JitTools::GetJitHotnessThreshold(defaultHotnessThreshold);
    options.SetJitHotnessThreshold(appThreshold);
    hotnessThreshold_ = appThreshold;
    bundleName_ = bundleName;

    SetEnableOrDisable(options, wantFastJit, wantBaselineJit);
    bool anyTierOn = fastJitEnable_ || baselineJitEnable_;
    if (anyTierOn) {
        ConfigJit(vm);
    }
}

// Selects the interpreter profile stubs used for JIT PGO (only when interpreter JIT stubs are compiled in).
// If JIT PGO is disabled or the VM has no PGO profiler, the JIT profile stubs are switched off.
// Otherwise, when AOT PGO is not active, profiling is enabled without dumping and the JIT profile stubs
// are switched on. In both of those cases the profiler is then initialized for JIT.
void Jit::SwitchProfileStubs(EcmaVM *vm)
{
#if ECMASCRIPT_ENABLE_INTERPRETER_JIT_STUBS
    JSThread *thread = vm->GetAssociatedJSThread();
    JSRuntimeOptions &options = vm->GetJSOptions();
    std::shared_ptr<PGOProfiler> pgoProfiler = vm->GetPGOProfiler();

    bool useJitPgo = options.IsEnableJITPGO() && pgoProfiler != nullptr;
    if (!useJitPgo) {
        thread->SwitchJitProfileStubs(false);
        return;
    }

    bool aotPgoActive = pgo::PGOProfilerManager::GetInstance()->IsEnable();
    if (!aotPgoActive) {
        // Profile for JIT only: never dump the profile to disk.
        options.SetEnableProfileDump(false);
        SetProfileNeedDump(false);
        options.SetEnablePGOProfiler(true);
        pgoProfiler->Reset(true);
        thread->SwitchJitProfileStubs(true);
    }
    pgoProfiler->InitJITProfiler();
#else
    (void)vm;
#endif
}

// Writes the JIT-related runtime options into the VM's JSRuntimeOptions.
// Most values are passed through an ohos::JitTools override hook that receives the current value.
// The statements run in a fixed order and finish with a summary log line.
void Jit::ConfigOptions(EcmaVM *vm) const
{
    JSRuntimeOptions &options = vm->GetJSOptions();

    options.SetEnableAPPJIT(isApp_);
    options.SetEnableProfileDump(isProfileNeedDump_);

    bool useLiteCG = ohos::JitTools::IsJitEnableLitecg(options.IsCompilerEnableLiteCG());
    options.SetCompilerEnableLiteCG(useLiteCG);

    uint16_t callThreshold = ohos::JitTools::GetJitCallThreshold(options.GetJitCallThreshold());
    options.SetJitCallThreshold(callThreshold);

    uint32_t hotnessThreshold = GetHotnessThreshold();
    options.SetJitHotnessThreshold(hotnessThreshold);

    bool disableCodeSign = ohos::JitTools::GetCodeSignDisable(options.GetDisableCodeSign());
    options.SetDisableCodeSign(disableCodeSign);

    // These values are not logged below, so they are applied directly.
    options.SetEnableJitFort(ohos::JitTools::GetEnableJitFort(options.GetEnableJitFort()));
    options.SetEnableJitVerifyPass(ohos::JitTools::GetEnableJitVerifyPass(options.IsEnableJitVerifyPass()));
    options.SetEnableAsyncCopyToFort(ohos::JitTools::GetEnableAsyncCopyToFort(options.GetEnableAsyncCopyToFort()));

    vm->SetEnableJitLogSkip(ohos::JitTools::GetSkipJitLogEnable());

    std::string methodDichotomy = ohos::JitTools::GetJitMethodDichotomy(options.GetJitMethodDichotomy());
    options.SetJitMethodDichotomy(methodDichotomy);

    LOG_JIT(INFO) << "enable jit bundle:" << bundleName_ <<
        ", litecg:" << useLiteCG <<
        ", call threshold:" << static_cast<int>(callThreshold) <<
        ", hotness threshold:" << hotnessThreshold <<
        ", disable codesigner:" << disableCodeSign;
}

// Applies the full JIT configuration to `vm`: profile stubs, runtime options, JIT fort flags,
// and the method-name collector/filter used for JIT method dichotomy.
void Jit::ConfigJit(EcmaVM *vm)
{
    this->SwitchProfileStubs(vm);
    this->ConfigOptions(vm);
    // Must follow ConfigOptions: it reads the code-sign / JIT fort options that ConfigOptions just wrote.
    this->ConfigJitFortOptions(vm);

    CompileDecision::GetMethodNameCollector().Init(vm);
    CompileDecision::GetMethodNameFilter().Init(vm);
}

// Copies the VM's code-sign and JIT fort option values into this Jit instance's own flags.
void Jit::ConfigJitFortOptions(EcmaVM *vm)
{
    JSRuntimeOptions &options = vm->GetJSOptions();
    SetDisableCodeSign(options.GetDisableCodeSign());
    SetEnableJitFort(options.GetEnableJitFort());
    SetEnableAsyncCopyToFort(options.GetEnableAsyncCopyToFort());
}

// Initializes the JIT environment (when its library resources are resolved) and records which tiers
// are enabled. The tier flags and hotness threshold change only once the JIT has been initialized.
// Serialized by setEnableLock_.
void Jit::SetEnableOrDisable(const JSRuntimeOptions &options, bool isEnableFastJit, bool isEnableBaselineJit)
{
    LockHolder holder(setEnableLock_);

    // Load the JIT library only if at least one tier was requested.
    if (isEnableFastJit || isEnableBaselineJit) {
        CreateJitResources();
    }

    if (IsLibResourcesResolved()) {
        jitDfx_ = JitDfx::GetInstance();
        jitDfx_->Init(options, bundleName_);
        // When JIT is started by an application, JIT fort is initialized in app spawn.
        // When JIT is started by ark_js_vm, JIT fort must be initialized here on the main thread.
        // Every thread that uses JIT fort needs access to JIT fort memory, so threads needing that memory
        // should preferably come from a thread that has already initialized it.
        if (!IsAppJit()) {
            JitFort::InitJitFort();
        }
        jitResources_->InitJitEnv(options);
        initialized_ = true;
    }

    if (!initialized_) {
        return;
    }
    fastJitEnable_ = isEnableFastJit;
    baselineJitEnable_ = isEnableBaselineJit;
    hotnessThreshold_ = options.GetJitHotnessThreshold();
}

// Tears down an initialized JIT: clears the state flags, destroys the resources and releases them.
// Does nothing if the JIT was never initialized. Serialized by setEnableLock_.
void Jit::Destroy()
{
    LockHolder holder(setEnableLock_);
    if (initialized_) {
        initialized_ = false;
        fastJitEnable_ = false;
        baselineJitEnable_ = false;
        ASSERT(jitResources_ != nullptr);
        jitResources_->Destroy();
        jitResources_.reset();
    }
}

bool Jit::IsEnableFastJit() const
{
    return fastJitEnable_;
}

bool Jit::IsEnableBaselineJit() const
{
    return baselineJitEnable_;
}

bool Jit::IsEnableJitFort() const
{
    return isEnableJitFort_;
}

void Jit::SetEnableJitFort(bool isEnableJitFort)
{
    isEnableJitFort_ = isEnableJitFort;
}

bool Jit::IsDisableCodeSign() const
{
    return isDisableCodeSign_;
}

void Jit::SetDisableCodeSign(bool isDisableCodeSign)
{
    isDisableCodeSign_ = isDisableCodeSign;
}

bool Jit::IsEnableAsyncCopyToFort() const
{
    return isEnableAsyncCopyToFort_;
}

void Jit::SetEnableAsyncCopyToFort(bool isEnableAsyncCopyToFort)
{
    isEnableAsyncCopyToFort_ = isEnableAsyncCopyToFort;
}

// Nothing to release explicitly; members clean up after themselves.
Jit::~Jit() = default;

// Registers the function in the JIT warm-up profiler map under the key
// "<panda file desc>:<record name>.<method name>", with value false.
// An existing entry is left untouched.
void Jit::CountInterpExecFuncs(JSThread *jsThread, JSHandle<JSFunction> &jsFunction)
{
    Method *method = Method::Cast(jsFunction->GetMethod(jsThread).GetTaggedObject());
    auto pandaFile = method->GetJSPandaFile(jsThread);
    ASSERT(pandaFile != nullptr);

    CString key = pandaFile->GetJSPandaFileDesc();
    key += ":";
    key += method->GetRecordNameStr(jsThread);
    key += ".";
    key += CString(method->GetMethodName(jsThread));

    auto &warmupMap = JitWarmupProfiler::GetInstance()->profMap_;
    if (warmupMap.count(key) == 0) {
        warmupMap.insert({key, false});
    }
}

// Entry point for a JIT compile request. The request is dropped if its tier is disabled, if it is an
// OSR request while OSR is off, or if CompileDecision rejects it. Otherwise it is forwarded to the
// instance-level Compile.
void Jit::Compile(EcmaVM *vm, JSHandle<JSFunction> &jsFunction, CompilerTier tier,
                  int32_t osrOffset, JitCompileMode mode)
{
    Jit *jit = Jit::GetInstance();

    bool tierDisabled = (!jit->IsEnableBaselineJit() && tier.IsBaseLine()) ||
                        (!jit->IsEnableFastJit() && tier.IsFastJit());
    if (tierDisabled) {
        return;
    }

    bool isOsrRequest = osrOffset != MachineCode::INVALID_OSR_OFFSET;
    if (!vm->IsEnableOsr() && isOsrRequest) {
        return;
    }

    CompileDecision decision(vm, jsFunction, tier, osrOffset, mode);
    if (decision.Decision()) {
        jit->Compile(vm, decision);
    }
}

// Builds a JitTask for an accepted decision and posts it to the JIT task pool.
// In sync mode it waits for the task and installs the code before returning.
// Afterwards it records the time spent on the JS thread.
void Jit::Compile(EcmaVM *vm, const CompileDecision &decision)
{
    JSThread *hostThread = vm->GetJSThread();
    [[maybe_unused]] EcmaHandleScope handleScope(hostThread);
    auto tier = decision.GetTier();
    auto jsFunction = decision.GetJsFunction();
    auto methodInfo = decision.GetMethodInfo();
    auto methodName = decision.GetMethodName();
    auto osrOffset = decision.GetOsrOffset();
    auto mode = decision.GetCompileMode();

    CString msg = "compile method:" + methodInfo + ", in work thread";
    TimeScope timeScope(vm, msg, tier, true, true);

    ECMA_BYTRACE_NAME(HITRACE_LEVEL_COMMERCIAL, HITRACE_TAG_ARK,
        ConvertToStdString("JIT::Compile:" + methodInfo).c_str(), "");

    // Flag the function as compiling for the requested tier.
    if (!tier.IsFastJit()) {
        ASSERT(tier.IsBaseLine());
        jsFunction->SetBaselinejitCompilingFlag(true);
    } else {
        jsFunction->SetJitCompilingFlag(true);
    }
    GetJitDfx()->SetTriggerCount(tier);

    {
        auto taskpool = JitTaskpool::GetCurrentTaskpool();
        {
            // Wait in native state until the task pool is ready.
            ThreadNativeScope nativeScope(hostThread);
            taskpool->WaitForJitTaskPoolReady();
        }
        EcmaVM *compilerVm = taskpool->GetCompilerVm();
        // GetJSThreadNoCheck: avoid check fail when enable multi-thread check
        auto jitTask = std::make_shared<JitTask>(hostThread, compilerVm->GetJSThreadNoCheck(), this, jsFunction,
                                                 tier, methodName, osrOffset, mode);

        jitTask->PrepareCompile();
        taskpool->PostTask(std::make_unique<JitTask::AsyncTask>(jitTask, hostThread->GetThreadId()));
        if (mode.IsSync()) {
            // Sync mode still compiles in the task pool (litecg does not support parallel compile),
            // so wait for the task to finish, then install the code.
            jitTask->WaitFinish();
            jitTask->InstallCode();
        }
        int mainThreadTime = timeScope.TotalSpentTimeInMicroseconds();
        jitTask->SetMainThreadCompilerTime(mainThreadTime);
        GetJitDfx()->RecordSpentTimeAndPrintStatsLogInJsThread(mainThreadTime);
    }
}

void Jit::RequestInstallCode(std::shared_ptr<JitTask> jitTask)
{
    LockHolder holder(threadTaskInfoLock_);
    ThreadTaskInfo &info = threadTaskInfo_[jitTask->GetHostThread()];
    if (info.skipInstallTask_) {
        return;
    }
    info.installJitTasks_.push_back(jitTask);

    // set
    jitTask->GetHostThread()->SetInstallMachineCode(true);
    jitTask->GetHostThread()->SetCheckSafePointStatus();
}

// Number of JIT tasks currently counted for the VM's JS thread (creates the per-thread entry if missing).
uint32_t Jit::GetRunningTaskCnt(EcmaVM *vm)
{
    LockHolder holder(threadTaskInfoLock_);
    return threadTaskInfo_[vm->GetJSThread()].jitTaskCnt_.load();
}

// Installs the machine code of every task queued for `jsThread` by RequestInstallCode.
// The pending queue is taken under threadTaskInfoLock_ and the installs run after the lock is released.
void Jit::InstallTasks(JSThread *jsThread)
{
    // Install tasks is only possible for JSThread in running state
    ASSERT(jsThread->IsJSThread() && jsThread->IsInRunningState());
    std::deque<std::shared_ptr<JitTask>> taskQueue;
    {
        LockHolder holder(threadTaskInfoLock_);
        ThreadTaskInfo &info = threadTaskInfo_[jsThread];
        // Take the pending tasks in O(1) and leave the per-thread queue empty. Copying the deque would bump
        // every shared_ptr's atomic ref-count while the lock is held, only to clear the source right after.
        taskQueue.swap(info.installJitTasks_);
    }
    ECMA_BYTRACE_NAME(HITRACE_LEVEL_COMMERCIAL, HITRACE_TAG_ARK,
        ConvertToStdString("Jit::InstallTasks count:" + ToCString(taskQueue.size())).c_str(), "");

    for (const std::shared_ptr<JitTask> &task : taskQueue) {
        task->InstallCode();
    }
}

bool Jit::JitCompile(void *compiler, JitTask *jitTask)
{
    return jitResources_->Compile(compiler, jitTask);
}

bool Jit::JitFinalize(void *compiler, JitTask *jitTask)
{
    return jitResources_->Finalize(compiler, jitTask);
}

void *Jit::CreateJitCompilerTask(JitTask *jitTask)
{
    return jitResources_->CreateJitCompilerTask(jitTask);
}

void Jit::DeleteJitCompilerTask(void *compiler)
{
    jitResources_->DeleteJitCompilerTask(compiler);
}

void Jit::ClearTask(const std::function<bool(common::Task *task)> &checkClear)
{
    JitTaskpool::GetCurrentTaskpool()->ForEachTask([&checkClear](common::Task *task) {
        JitTask::AsyncTask *asyncTask = static_cast<JitTask::AsyncTask*>(task);
        if (checkClear(asyncTask)) {
            asyncTask->Terminated();
        }
    });
}

// Cancels all JIT work for `vm`: terminates its pooled tasks, drops pending installs, stops further
// installs for its JS thread, and blocks until that thread's running-task count drops to zero.
void Jit::ClearTaskWithVm(EcmaVM *vm)
{
    ClearTask([vm](common::Task *task) {
        JitTask::AsyncTask *asyncTask = static_cast<JitTask::AsyncTask*>(task);
        return vm == asyncTask->GetHostVM();
    });

    {
        LockHolder holder(threadTaskInfoLock_);
        ThreadTaskInfo &info = threadTaskInfo_[vm->GetJSThread()];
        // RequestInstallCode checks this flag and drops tasks for this thread from now on.
        info.skipInstallTask_ = true;
        auto &taskQueue = info.installJitTasks_;
        taskQueue.clear();

        if (info.jitTaskCnt_.load() != 0) {
            // Switch to native state while blocked (as Compile does around WaitForJitTaskPoolReady).
            ThreadNativeScope threadNativeScope(vm->GetJSThread());
            // Re-check after every wake-up. A condition-variable wait may return spuriously, and
            // DecJitTaskCnt signals only when the count reaches zero, so a bare `if` could return early.
            while (info.jitTaskCnt_.load() != 0) {
                info.jitTaskCntCv_.Wait(&threadTaskInfoLock_);
            }
        }
    }
}

// Increments the running JIT task count for `thread` (creating its entry if needed).
void Jit::IncJitTaskCnt(JSThread *thread)
{
    LockHolder holder(threadTaskInfoLock_);
    threadTaskInfo_[thread].jitTaskCnt_.fetch_add(1);
}

// Decrements the running JIT task count for `thread`. When the count drops from 1 to 0, it signals
// jitTaskCntCv_, which ClearTaskWithVm waits on.
void Jit::DecJitTaskCnt(JSThread *thread)
{
    LockHolder holder(threadTaskInfoLock_);
    ThreadTaskInfo &info = threadTaskInfo_[thread];
    bool reachedZero = info.jitTaskCnt_.fetch_sub(1) == 1;
    if (reachedZero) {
        info.jitTaskCntCv_.Signal();
    }
}

// Clears the thread's machine-code low-memory flag once more than MIN_CODE_SPACE_SIZE is available again.
// Nothing happens if the flag is not set.
void Jit::CheckMechineCodeSpaceMemory(JSThread *thread, int remainSize)
{
    if (thread->IsMachineCodeLowMemory() && remainSize > MIN_CODE_SPACE_SIZE) {
        thread->SetMachineCodeLowMemory(false);
    }
}

// Sets the JIT task pool's thread priority to BACKGROUND or FOREGROUND.
// Skipped when no JIT tier is enabled.
void Jit::ChangeTaskPoolState(bool inBackground)
{
    if (!fastJitEnable_ && !baselineJitEnable_) {
        return;
    }
    auto priority = inBackground ? common::PriorityMode::BACKGROUND : common::PriorityMode::FOREGROUND;
    JitTaskpool::GetCurrentTaskpool()->SetThreadPriority(priority);
}

// Starts a timed scope and, if outPutLog is set, emits a "begin." line.
// DEBUG-level scopes always log. INFO-level scopes skip the line when JIT log skipping is on, the VM
// has a bundle name, and the message does not mention that bundle name.
Jit::TimeScope::TimeScope(EcmaVM *vm, CString message, CompilerTier tier, bool outPutLog, bool isDebugLevel)
    : vm_(vm), message_(message), tier_(tier), outPutLog_(outPutLog), isDebugLevel_(isDebugLevel)
{
    if (!outPutLog_) {
        return;
    }
    if (isDebugLevel_) {
        LOG_JIT(DEBUG) << tier_ << message_ << " begin.";
        return;
    }
    const auto &bundleName = vm_->GetBundleName();
    bool skipLog = vm_->GetEnableJitLogSkip() && bundleName != "" &&
        message_.find(bundleName) == std::string::npos;
    if (!skipLog) {
        LOG_JIT(INFO) << tier_ << message_ << " begin.";
    }
}

// Ends the timed scope and, if outPutLog is set, logs the elapsed time in ms.
// Uses the same DEBUG/INFO and bundle-name skip rules as the constructor.
Jit::TimeScope::~TimeScope()
{
    if (!outPutLog_) {
        return;
    }
    if (!isDebugLevel_) {
        const auto &bundleName = vm_->GetBundleName();
        bool skipLog = vm_->GetEnableJitLogSkip() && bundleName != "" &&
            message_.find(bundleName) == std::string::npos;
        if (!skipLog) {
            LOG_JIT(INFO) << tier_ << message_ << ", compile time: " << TotalSpentTime() << "ms";
        }
        return;
    }
    LOG_JIT(DEBUG) << tier_ << message_ << ": " << TotalSpentTime() << "ms";
}

#if ECMASCRIPT_ENABLE_ARK_STEED
// Scans the function's bytecode instruction by instruction. Returns false (and logs) at the first opcode
// that ArkSteed does not support; returns true if every opcode is supported.
bool Jit::IsArkSteedBytecodeSupported(EcmaVM *vm, JSHandle<JSFunction> &jsFunction)
{
    JSThread *hostThread = vm->GetJSThread();
    Method *method = Method::Cast(jsFunction->GetMethod(hostThread).GetTaggedObject());
    uint32_t codeSize = method->GetCodeSize(hostThread);
    const uint8_t *cursor = method->GetBytecodeArray();
    const uint8_t *const limit = cursor + codeSize;
    while (cursor < limit) {
        kungfu::EcmaOpcode op = kungfu::Bytecodes::GetOpcode(cursor);
        if (!kungfu::IsArkSteedSupportedOpcode(op)) {
            LOG_JIT(INFO) << "ArkSteed: skip compilation due to unsupported bytecode";
            return false;
        }
        cursor += BytecodeInstruction::Size(op);
    }
    return true;
}

// ArkSteed compilation entry point.
// Applies the OSR gate, CompileDecision and the ArkSteed bytecode whitelist, then builds an
// arksteed::ArkSteedTask and posts it to the JIT task pool. In sync mode it waits for the task and
// installs the code. Afterwards it records the time spent on the JS thread.
void Jit::CompileArkSteed(EcmaVM *vm, JSHandle<JSFunction> &jsFunction,
                          CompilerTier tier, int32_t osrOffset, JitCompileMode mode)
{
    auto jit = Jit::GetInstance();

    if (!vm->IsEnableOsr() && osrOffset != MachineCode::INVALID_OSR_OFFSET) {
        return;
    }

    CompileDecision compileDecision(vm, jsFunction, tier, osrOffset, mode);
    if (!compileDecision.Decision()) {
        return;
    }

    // Check if bytecode is supported by ArkSteed
    if (!IsArkSteedBytecodeSupported(vm, jsFunction)) {
        return;
    }

    [[maybe_unused]] EcmaHandleScope handleScope(vm->GetJSThread());
    // Use the method description computed by CompileDecision (as Jit::Compile does) so the time log and
    // trace identify the method being compiled, rather than a fixed placeholder string.
    auto methodInfo = compileDecision.GetMethodInfo();
    TimeScope scope(vm, "ArkSteed compile method:" + methodInfo, tier, true, true);

    ECMA_BYTRACE_NAME(HITRACE_LEVEL_COMMERCIAL, HITRACE_TAG_ARK,
        ConvertToStdString("ArkSteed::Compile:" + methodInfo).c_str(), "");

    if (tier.IsArkSteed()) {
        jsFunction->SetJitCompilingFlag(true);
    } else {
        // TODO: baseline tier is not supported by ArkSteed yet; the code after UNREACHABLE() is kept
        // for when it is.
        UNREACHABLE();
        ASSERT(tier.IsBaseLine());
        jsFunction->SetBaselinejitCompilingFlag(true);
    }
    jit->GetJitDfx()->SetTriggerCount(tier);

    {
        {
            // Wait in native state until the task pool is ready.
            ThreadNativeScope nativeScope(vm->GetJSThread());
            JitTaskpool::GetCurrentTaskpool()->WaitForJitTaskPoolReady();
        }

        EcmaVM *compilerVm = JitTaskpool::GetCurrentTaskpool()->GetCompilerVm();
        JSThread *hostThread = vm->GetJSThread();
        // GetJSThreadNoCheck: avoid check fail when enable multi-thread check
        JSThread *compilerThread = compilerVm->GetJSThreadNoCheck();

        Method *method = Method::Cast(jsFunction->GetMethod(hostThread).GetTaggedObject());
        CString methodName = method->GetMethodName(hostThread);

        auto arkSteedTask = std::make_shared<arksteed::ArkSteedTask>(
            hostThread, compilerThread, jit, jsFunction, tier, methodName, osrOffset, mode);

        arkSteedTask->PrepareCompile();

        JitTaskpool::GetCurrentTaskpool()->PostTask(
            std::make_unique<arksteed::ArkSteedTask::AsyncTask>(
                arkSteedTask, vm->GetJSThread()->GetThreadId()));

        if (mode.IsSync()) {
            arkSteedTask->WaitFinish();
            arkSteedTask->InstallCode();
        }

        int spendTime = scope.TotalSpentTimeInMicroseconds();
        arkSteedTask->SetMainThreadCompilerTime(spendTime);
        jit->GetJitDfx()->RecordSpentTimeAndPrintStatsLogInJsThread(spendTime);
    }
}
#endif
}  // namespace panda::ecmascript