/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file codegen_npu.cpp
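 * \brief Code generation for NPU leaf functions (AICore kernel source, AICPU op encoding) and their compilation with bisheng.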
 */
#include "codegen_npu.h"

#include <cstring>
#include <error.h>
#include <fstream>

#include "codegen/utils/parallel_execute.h"
#include "codegen_op_npu.h"
#include "interface/utils/file_utils.h"
#include "interface/tensor/logical_tensor.h"
#include "interface/function/function.h"
#include "interface/compiler_monitor/monitor_stage_scope.h"
#include "interface/compiler_monitor/monitor_manager.h"
#include "interface/configs/config_manager.h"
#include "interface/utils/op_info_manager.h"
#include "interface/operation/distributed/distributed_common.h"
#include "tilefwk/tilefwk.h"
#include "securec.h"

namespace npu::tile_fwk {

void FloatSpecValMgr::UpdateByOp(const Operation& op)
{
    std::vector<Element> eles;
    if (op.HasAttr(OpAttributeKey::scalar)) {
        eles.emplace_back(op.GetElementAttribute(OpAttributeKey::scalar));
    }
    if (op.HasAttr(OpAttributeKey::vectorScalar)) {
        auto vecScalars = op.GetVectorElementAttribute(OpAttributeKey::vectorScalar);
        eles.insert(eles.end(), vecScalars.begin(), vecScalars.end());
    }

    if (eles.empty()) {
        return;
    }

    for (const auto& e : eles) {
        if (e.GetDataType() == DataType::DT_FP16 || e.GetDataType() == DataType::DT_FP32 ||
            e.GetDataType() == DataType::DT_BF16) {
            double value = e.Cast<float>();
            if (std::isinf(value) || std::isnan(value)) {
                floatSpecVals_.insert({e.GetDataType(), value});
            }
        }
    }
}

/**
 * Emits one type-punning union declaration per recorded special float value, e.g.
 *   union {float f; uint32_t u;} float_inf = {.u = 0x7F800000};
 */
void FloatSpecValMgr::PrintFloatSpecVal(std::ostringstream& oss)
{
    for (const auto& specVal : floatSpecVals_) {
        const std::string cceType = DataType2CCEStr(specVal.dtype);
        oss << "union {" << cceType << " f; uint32_t u;} " << specVal.GetFsVarName()
            << " = {.u = " << specVal.GetFsValueStr() << "};\n";
    }
}

/**
 * Logs an operand being inserted into the symbol table: its role (`operIO`), magic,
 * dump, and memory range [start, end, memId].
 */
void PrintOperand(const std::string& operIO, std::shared_ptr<LogicalTensor> operand)
{
    const auto& range = operand->memoryrange;
    CODEGEN_LOGI(
        "insert %s magic: %d, tensor: %s, memory map is: ", operIO.c_str(), operand->GetMagic(),
        operand->Dump().c_str());
    CODEGEN_LOGI("range is [%zu, %zu, %d]\n", range.start, range.end, range.memId);
}

/**
 * Returns the tensor's `needAlloc` attribute, or false when GetAttr does not set it
 * (presumably because the attribute is absent — confirm against GetAttr).
 */
bool HasAllocAttr(const std::shared_ptr<LogicalTensor>& tensor)
{
    bool allocRequested{false};
    tensor->GetAttr(OpAttributeKey::needAlloc, allocRequested);
    return allocRequested;
}

/**
 * Writes the #include preamble of a generated kernel.
 *
 * For dynamic top functions (unless the compile stage is CS_EXECUTE_GRAPH), the
 * machine-generated expression header for the current tiling key is included first.
 * That header depends on __TILE_FWK_AICORE__, so the macro is defined right before it.
 */
void CodeGenNPU::GenInclude(const Function& topFunc, std::ostringstream& oss) const
{
    const bool needExpressionHeader =
        config::GetHostOption<int64_t>(COMPILE_STAGE) != CS_EXECUTE_GRAPH &&
        topFunc.IsFunctionType({FunctionType::DYNAMIC, FunctionType::DYNAMIC_LOOP, FunctionType::DYNAMIC_LOOP_PATH});
    if (needExpressionHeader) {
        const uint64_t opTilingKey = OpInfoManager::GetInstance().GetOpTilingKey();
        const std::string headerPath = "../kernel_aicpu/expression_" + std::to_string(opTilingKey) + ".h";
        oss << "#define __TILE_FWK_AICORE__ 1\n"
            << "#include \"" << headerPath << "\"\n";
    }

    oss << "#include \"TileOpImpl.h\"\n"
        << "#include \"tilefwk/aicpu_common.h\"\n\n";
}

/**
 * Emits a "// funcHash: <hash>" comment ahead of the kernel signature and logs the hash.
 */
void CodeGenNPU::GenCommentBeforeFuncHeader(Function& subFunc, std::ostringstream& oss) const
{
    const auto funcHash = subFunc.GetFunctionHash();
    oss << "// funcHash: " << funcHash << "\n\n";
    CODEGEN_LOGI("function hash is: %s", funcHash.c_str());
}

/**
 * Builds a kernel name of the form <topFuncMagicName>_<programId>_<subTilingKey>.
 * Each call draws a new sub tiling key from OpInfoManager::GetNewSubTilingKey().
 */
std::string CodeGenNPU::GenKernelName(Function& topFunc, uint64_t programId)
{
    std::string name = topFunc.GetMagicName();
    name += "_";
    name += std::to_string(programId);

    const uint64_t subTilingKey = OpInfoManager::GetInstance().GetNewSubTilingKey();
    name += "_";
    name += std::to_string(subTilingKey);
    return name;
}

/**
 * Produces the kernel signature for program `programId` and records on `compileInfo`
 * both the generated kernel name and the prototype declaration (signature + ";").
 *
 * @return the signature followed by " {\n", i.e. the opening of the kernel body
 */
std::string CodeGenNPU::GenFuncHeader(uint64_t programId, Function& topFunc, CompileInfo& compileInfo) const
{
    std::string kernelName = GenKernelName(topFunc, programId);
    compileInfo.SetKernelName(kernelName);

    std::ostringstream signature;
    signature << "extern \"C\" [aicore] void " << kernelName << "(" << GM_PARAM_TYPE_FOR_DYN
              << "* param, int64_t GMStackBase, __gm__ int64_t *hcclContext, __gm__ TaskStat* taskStat)";
    const std::string prototype = signature.str();

    std::string declaration = prototype + ";";
    compileInfo.SetFuncDeclare(declaration);
    return prototype + " {\n";
}

/**
 * Emits the device-side helper CheckInvalidAccessOfDDR when GM out-of-bounds debugging
 * is enabled (debug mode CFG_DEBUG_GM_OUT_OF_BOUNDS or CFG_DEBUG_ALL).
 *
 * The emitted helper traps when [access_offset, access_offset + access_extent) is not
 * fully inside [0, ddr_size). The comparison is written so it cannot wrap around.
 * The former `access_offset + access_extent > ddr_size` test overflowed for huge
 * offsets, such as a negative offset converted to uint64_t, and silently accepted them.
 * The former `access_offset < 0` test was always false for an unsigned value.
 * `read_or_write` stays in the signature so existing generated call sites keep compiling.
 * Both directions trap identically.
 */
void CodeGenNPU::GenDDRChecker(std::ostringstream& oss) const
{
    const auto debugMode = config::GetDebugOption<int64_t>(CFG_RUNTIME_DBEUG_MODE);
    if (debugMode != CFG_DEBUG_GM_OUT_OF_BOUNDS && debugMode != CFG_DEBUG_ALL) {
        return;
    }

    std::string checkFunc = R"!!!(
inline __aicore__ void CheckInvalidAccessOfDDR(uint64_t ddr_size, uint64_t access_offset, uint64_t access_extent, uint32_t read_or_write) {
    (void)read_or_write;
    if (access_offset > ddr_size || access_extent > ddr_size - access_offset) {
        trap();
    }
}
)!!!";

    oss << checkFunc << "\n";
}

/**
 * Writes everything that precedes the kernel body, in this order: includes, the
 * function-hash comment, the optional DDR checker helper, and the kernel signature.
 */
void CodeGenNPU::GenFuncBodyBefore(
    const std::pair<uint64_t, Function*>& subFuncPair, Function& topFunc, CompileInfo& compileInfo,
    std::ostringstream& oss) const
{
    const uint64_t programId = subFuncPair.first;
    Function& leafFunc = *subFuncPair.second;

    GenInclude(topFunc, oss);
    GenCommentBeforeFuncHeader(leafFunc, oss);
    GenDDRChecker(oss);
    oss << GenFuncHeader(programId, topFunc, compileInfo);
}

/**
 * Closes the kernel body opened by the " {\n" that GenFuncHeader appends.
 */
void CodeGenNPU::GenFuncEnd(std::ostringstream& oss) const
{
    oss << "}\n";
}

std::string CodeGenNPU::GenAllocForLocalBuffer(
    const Operation& op, const std::shared_ptr<SymbolManager>& symbolMgr) const
{
    std::string allocSourceCode;
    auto genExtraAllocForTensor = [this, &symbolMgr](const std::shared_ptr<LogicalTensor>& operand) -> std::string {
        if (HasAllocAttr(operand)) {
            CODEGEN_LOGI("operand has an alloc attr, need to gen extra alloc, operand is: %s", operand->Dump().c_str());
            std::optional<std::string> allocCodeMaybe = GenExtraAlloc(symbolMgr, operand);
            if (allocCodeMaybe.has_value()) {
                return allocCodeMaybe.value();
            }
        }
        return "";
    };
    for (const std::shared_ptr<LogicalTensor>& operand : op.GetIOperands()) {
        symbolMgr->AddToTensorMap(operand->GetMagic(), operand);
        PrintOperand("IOperand", operand);
        allocSourceCode.append(genExtraAllocForTensor(operand));
    }
    for (const std::shared_ptr<LogicalTensor>& operand : op.GetOOperands()) {
        symbolMgr->AddToTensorMap(operand->GetMagic(), operand);
        PrintOperand("OOperand", operand);
        allocSourceCode.append(genExtraAllocForTensor(operand));
    }

    return allocSourceCode;
}

std::string BuildDynParamInfo(const DynParamInfo& info)
{
    std::vector<std::string> params{
        "param", std::to_string(info.tensorIndex), std::to_string(info.tensorBaseAddrCoaIndex),
        std::to_string(info.dimSize), std::to_string(info.dimIndex)};
    auto res = WrapParamByParentheses(params);
    return res;
}

// GET_PARAM_OFFSET_BY_IDX(param, n, base, dim, idx)
// GET_PARAM_VALID_SHAPE_BY_IDX(param, n, base, dim, idx)
/**
 * Emits a `uint64_t` declaration for every entry of the function's dynamic-parameter table.
 *
 * Entries without a replacedSymbol are emitted first, in table order. They read the
 * value through GET_PARAM_VALID_SHAPE_BY_IDX / GET_PARAM_OFFSET_BY_IDX. When a valid
 * symbolic dim exists, its expression becomes the initializer and the accessor call is
 * left only as a trailing comment. Entries with a replacedSymbol are emitted afterwards,
 * in table order, initialised from that symbol (presumably so the aliased name is
 * already declared — confirm).
 *
 * @return the declarations, or an empty string if `func` is not under a dynamic function
 */
std::string CodeGenNPU::GenDynParamForExpr(const Function& func) const
{
    if (!func.IsUnderDynamicFunction()) {
        return {};
    }

    std::string directDecls;
    std::string aliasDecls;
    for (const auto& entry : func.GetDynParamTable()) {
        const auto& paramName = entry.first;
        if (!entry.second.replacedSymbol.empty()) {
            aliasDecls.append("uint64_t " + paramName + " = ").append(entry.second.replacedSymbol).append(STMT_END);
            continue;
        }

        DynParamInfo info = entry.second;
        std::string decl = "uint64_t " + paramName + " = ";
        if (info.dim.IsValid()) {
            decl.append(SymbolicExpressionTable::BuildExpression(info.dim)).append("; //");
        }
        if (info.type == DynParamInfoType::VALID_SHAPE) {
            decl.append(GET_PARAM_VALID_SHAPE_BY_IDX);
        } else if (info.type == DynParamInfoType::OFFSET) {
            decl.append(GET_PARAM_OFFSET_BY_IDX);
        }
        decl.append(BuildDynParamInfo(info)).append(STMT_END);
        directDecls.append(decl);
    }
    return directDecls + aliasDecls;
}

/**
 * Generates (and, in CANN builds with ASCEND_HOME_PATH set, compiles) the kernel source
 * for every leaf program of `topFunc`.
 *
 * One task per program is executed on GetCGThreadNum() threads and all tasks are joined
 * before returning. AICPU leaf functions are encoded by HandleForAICpuSubFunc instead of
 * producing kernel source. In CANN builds the collected compile jobs are then run in
 * parallel via ExecuteParallelCompile.
 */
void CodeGenNPU::GenCode(
    Function& topFunc, [[maybe_unused]] const std::map<uint64_t, std::list<InvokeParaOffset>>& invokeParaOffset)
{
    COMPILER_LOGI(
        "Start Generate AI_CORE code for topFunc: %s, hash: %s", topFunc.GetMagicName().c_str(),
        topFunc.GetFunctionHash().c_str());

    Prepare(topFunc);

    std::deque<std::function<void(void)>> tasks;
    for (auto& subFuncPair : topFunc.rootFunc_->programs_) {
        tasks.emplace_back([this, subFuncPair, &topFunc]() {
            CODEGEN_LOGI(" ----- subprogram id [%lu] -----", subFuncPair.first);
            auto leafFunc = subFuncPair.second;
            if (HandleForAICpuSubFunc(*leafFunc)) {
                return;
            }

            const bool isCube = leafFunc->IsCube();
            CompileInfo compileInfo(topFunc, ctx, subFuncPair, isCube);
            std::ostringstream kernelSource;
            GenFuncBodyBefore(subFuncPair, topFunc, compileInfo, kernelSource);
            GenFuncBody(*leafFunc, topFunc, kernelSource);
            GenFuncEnd(kernelSource);
#ifdef BUILD_WITH_CANN
            if (std::getenv(ENV_ASCEND_HOME_PATH.c_str()) != nullptr) {
                GenCodeToBinaryTask(kernelSource, compileInfo, "");
            }
#endif
            UpdateSubFunc(subFuncPair, compileInfo);
        });
    }

    const unsigned workerCount = GetCGThreadNum();
    ParallelExecuteAndWait(workerCount, tasks);

#ifdef BUILD_WITH_CANN
    if (std::getenv(ENV_ASCEND_HOME_PATH.c_str()) != nullptr) {
        ExecuteParallelCompile(topFunc);
    }
#endif
}

/**
 * Stores the compile results of a leaf function in its LeafFuncAttribute, creating the
 * attribute if absent.
 *
 * The kernel name, binary path and declaration go into the *MainBlock fields when
 * ctx.isMainBlock is set, and into the regular fields otherwise. The core type is AIC
 * for cube kernels and AIV otherwise.
 */
void CodeGenNPU::UpdateSubFunc(std::pair<uint64_t, Function*> subFuncPair, const CompileInfo& compileInfo) const
{
    Function* leafFunc = subFuncPair.second;
    std::shared_ptr<LeafFuncAttribute> attr = leafFunc->GetLeafFuncAttribute();
    if (!attr) {
        attr = std::make_shared<LeafFuncAttribute>();
    }

    if (!ctx.isMainBlock) {
        attr->kernelName = compileInfo.GetKernelName();
        attr->binPath = compileInfo.GetBinAbsPath();
        attr->kernelDeclare = compileInfo.GetFuncDeclare();
    } else {
        attr->kernelNameMainBlock = compileInfo.GetKernelName();
        attr->binPathMainBlock = compileInfo.GetBinAbsPath();
        attr->kernelDeclareMainBlock = compileInfo.GetFuncDeclare();
    }
    attr->coreType = compileInfo.IsCube() ? CoreType::AIC : CoreType::AIV;
    leafFunc->SetLeafFuncAttribute(attr);
}

/**
 * Rejects a command string containing characters that allow shell command injection.
 *
 * Forbidden characters:
 *   ';' '|' '&' '\n' '\r' - command separators / chaining / background jobs
 *   '`' '$'               - command and variable substitution
 *   '<' '>'               - I/O redirection
 *
 * @param cmdStr  command text to inspect; may be nullptr
 * @param strLen  number of leading characters of cmdStr to inspect
 * @return 0 if none of the first strLen characters is forbidden; -1 otherwise or if cmdStr is nullptr
 */
int CheckInjectStr(const char cmdStr[], size_t strLen)
{
    if (cmdStr == nullptr) {
        return -1;
    }
    static constexpr char kForbiddenChars[] = {';', '|', '&', '`', '$', '>', '<', '\n', '\r'};
    for (size_t i = 0; i < strLen; ++i) {
        for (const char forbidden : kForbiddenChars) {
            if (cmdStr[i] == forbidden) {
                return -1;
            }
        }
    }
    return 0;
}

/**
 * Assembles the bisheng command line that compiles the kernel's CCE source into its
 * object file.
 *
 * The command is built from the fixed front-end flags, then the arch options, include
 * paths and extra options (each helper appends its own trailing space), and finally
 * "-o <bin> <cce>".
 */
std::string CodeGenNPU::PrepareCmd(const CompileInfo& compileInfo, const std::string& compileOptions) const
{
    std::ostringstream cmd;
    cmd << "bisheng -c -O3 -g -x cce -std=c++17 ";
    BuildArchOptions(cmd, compileInfo);
    BuildIncludes(cmd);
    BuildExtraOptions(cmd, compileOptions);
    cmd << "-o " << compileInfo.GetBinAbsPath() << " " << compileInfo.GetCCEAbsPath();

    std::string compileCmd = cmd.str();
    CODEGEN_LOGI_FULL("compile kernel...\n%s", compileCmd.c_str());
    return compileCmd;
}

/**
 * Finalizes a kernel source and queues it for compilation.
 *
 * The compile command is appended to `code` as a trailing comment. The source is then
 * written to its CCE path, and a compile task (cce -> bin) is registered for the later
 * parallel build.
 */
void CodeGenNPU::GenCodeToBinaryTask(
    std::ostringstream& code, const CompileInfo& compileInfo, const std::string& compileOptions) const
{
    const std::string compileCmd = PrepareCmd(compileInfo, compileOptions);
    const std::string cceFile = compileInfo.GetCCEAbsPath();

    code << "\n\n\n// kernel compilation command:\n// " << compileCmd << "\n";
    DumpCode(cceFile, code);

    CompileTaskInfo task;
    task.inputPath = cceFile;
    task.outputPath = compileInfo.GetBinAbsPath();
    task.compileCmd = compileCmd;
    CollectCompileTask(task);
}

/**
 * Decides whether the kernel source should be written to `inputFile`.
 *
 * When KEY_FORCE_OVERWRITE is enabled (the default is true), it always returns true.
 * Otherwise it returns true only if the file does not exist yet.
 */
bool CodeGenNPU::IsNeedDumpCode(const std::string& inputFile) const
{
    const bool forceOverwrite = ConfigManager::Instance().GetCodeGenConfig(KEY_FORCE_OVERWRITE, true);
    return forceOverwrite || !FileExist(inputFile);
}

/**
 * Writes the generated kernel source in `code` to `fileName`, subject to IsNeedDumpCode.
 *
 * Stream failures (open/write/flush/close) are logged with errno and the possibly
 * partial file is removed. The ios_base::failure is not propagated to the caller.
 */
void CodeGenNPU::DumpCode(const std::string& fileName, std::ostringstream& code) const
{
    if (!IsNeedDumpCode(fileName)) {
        return;
    }

    std::ofstream codeFile;
    try {
        codeFile.exceptions(std::ofstream::failbit | std::ofstream::badbit);
        codeFile.open(fileName);
        codeFile << code.str();
        codeFile.flush();
        codeFile.close();
    } catch (const std::ofstream::failure& e) {
        // Capture errno before any other call can overwrite it.
        const int savedErrno = errno;
        CODEGEN_LOGE(
            CmpCodeErr::FILE_IO_FAILED, "Code file operation failed: %s, error: %s, errno: %d", fileName.c_str(),
            e.what(), savedErrno);
        // Disable stream exceptions before cleanup. close() on a stream whose file is not
        // open (e.g. open() failed) sets failbit. With exceptions still enabled, that would
        // throw a second ios_base::failure out of this handler.
        codeFile.exceptions(std::ofstream::goodbit);
        if (codeFile.is_open()) {
            codeFile.close();
        }
        std::remove(fileName.c_str());
    }
}

/**
 * Generates the buffer allocation statement for `tensor` from its original memory type
 * and memory range.
 *
 * @return the allocation code (possibly empty when GenAlloc reuses an existing binding),
 *         or std::nullopt if the memory type has no OPERAND_TYPE_TO_MEMORY_TYPE mapping
 */
std::optional<std::string> CodeGenNPU::GenExtraAlloc(
    const std::shared_ptr<SymbolManager>& symbolMgr, const std::shared_ptr<LogicalTensor>& tensor) const
{
    const auto memType = tensor->GetMemoryTypeOriginal();
    const auto mappedBuffer = OPERAND_TYPE_TO_MEMORY_TYPE.find(memType);
    if (mappedBuffer == OPERAND_TYPE_TO_MEMORY_TYPE.end()) {
        CODEGEN_LOGE(
            OperErr::OPERAND_TYPE_UNSUPPORTED, " memory type(%u) of tensor from PASS is invalid, tensor is: %s",
            ToUnderlying(memType), tensor->Dump().c_str());
        return std::nullopt;
    }

    return GenAlloc(symbolMgr, mappedBuffer->second, tensor->Datatype(), tensor->memoryrange);
}

// NEXTNEXT: After TileTensor mode is applied to all TileOp, retain just one
/**
 * Derives the variable names bound to a buffer range.
 *
 * @return {normal-mode name, TileTensor-mode name}, e.g. {"UB_S0_E1024", "UB_S0_E1024_T"}
 */
std::pair<std::string, std::string> GenAllocVarName(const std::string& prefix, const TileRange& range)
{
    // range start/end are always positive
    std::ostringstream nameStream;
    nameStream << prefix << "_S" << range.start << "_E" << range.end;

    std::string plainName = nameStream.str();
    std::string tileTensorName = plainName + "_T";
    return {std::move(plainName), std::move(tileTensorName)};
}

std::string CodeGenNPU::GenAlloc(
    const std::shared_ptr<SymbolManager>& sm, BufferType bufferType, DataType dataType, const TileRange& range) const
{
    if ((BUFFER_TYPE_TO_PREFIX.count(bufferType) == 0) || (OPERAND_TYPE_TO_ADDR_TYPE.count(bufferType) == 0)) {
        ASSERT(OperErr::OPERAND_TYPE_UNSUPPORTED, false) << "invalid bufferType: " << static_cast<size_t>(bufferType);
        return "";
    }

    const std::string prefix = BUFFER_TYPE_TO_PREFIX.at(bufferType);
    const std::string& addrSpaceQualifier = OPERAND_TYPE_TO_ADDR_TYPE.at(bufferType);
    auto [allocVarName, allocVarNameTileTensor] = GenAllocVarName(prefix, range);

    // must conform to CodeGenOpNPU::createAllocKey
    AllocKey key = AllocKey(bufferType, range.start, range.end);
    bool reuse = sm->BindAddrWithVariableName(key, allocVarName, allocVarNameTileTensor);
    if (reuse) {
        return "";
    }

    CODEGEN_LOGI("bind key to name: %s->%s", sm->FormatAllocKey(key).c_str(), allocVarName.c_str());

    std::string dataTypeStr = DataType2CCEStr(dataType);

    std::ostringstream oss;
    oss << dataTypeStr << " " << addrSpaceQualifier << " *" << allocVarName << " = (" << dataTypeStr << " "
        << addrSpaceQualifier << " *)get_imm(0x" << std::hex << static_cast<unsigned>(range.start) << "); // size: 0x"
        << std::hex << static_cast<unsigned>(range.Size()) << "\n";

    if (ConfigManager::Instance().GetCodeGenConfig(KEY_CODEGEN_SUPPORT_TILE_TENSOR, false)) {
        oss << dataTypeStr << " *" << allocVarNameTileTensor << " = (" << dataTypeStr << " *)get_imm(0x" << std::hex
            << static_cast<unsigned>(range.start) << "); // size: 0x" << std::hex << static_cast<unsigned>(range.Size())
            << "\n";
    }

    return oss.str();
}

/**
 * Runs a single kernel compile command and asserts that it succeeded.
 * It does nothing (apart from logging) when the compile stage stops after instruction codegen.
 */
void CodeGenNPU::CompileCode(const std::string& compileCmd) const
{
    if (config::GetHostOption<int64_t>(COMPILE_STAGE) != CS_CODEGEN_INSTRUCTION) {
        int ret = DoCompileCmd(compileCmd);
        ASSERT(CmpCodeErr::COMPILE_CODE_FAILED, ret == 0)
            << "DoCompileCmd failed. errCode = " << ret << "\n******** bisheng compiling cmd start ********\n"
            << compileCmd << "\n******** bisheng compiling cmd end ********\n";
        return;
    }
    CODEGEN_LOGI("Compile stage terminates after codegen instruction.");
}

/**
 * Returns "<GetCurrentSharedLibPath()>/include" if that directory exists, otherwise "".
 */
std::string GetIncludePathByLib()
{
    const std::string libPath = GetCurrentSharedLibPath();
    if (libPath.empty()) {
        return "";
    }

    std::string includePath = libPath + "/include";
    CODEGEN_LOGI("includePath by lib is %s", includePath.c_str());
    return IsPathExist(includePath) ? includePath : std::string();
}

std::string CodeGenNPU::GetIncludePathForCompileCCE() const
{
    if (!ctx.IsIncludePathEmpty()) {
        CODEGEN_LOGI("include path from ctx is %s", ctx.includePath.c_str());
        return ctx.includePath;
    }

    std::string includePathByLib = GetIncludePathByLib();
    CODEGEN_LOGI("includePathByLib is %s", includePathByLib.c_str());
    if (!includePathByLib.empty()) {
        return includePathByLib;
    }

    ASSERT(CmpCodeErr::INCLUDE_FILE_NOT_FOUND, false) << "include path for compiling cce is unavailable";
    return "";
}

/**
 * Locates the pto-isa include root; it is only needed in tile-tensor mode.
 *
 * @return "" when KEY_CODEGEN_SUPPORT_TILE_TENSOR is disabled; otherwise "<root>/include"
 *         for the first root that is set, asserting that it contains "pto"
 */
std::string CodeGenNPU::GetPtoTileLibPathByEnv() const
{
    if (!ConfigManager::Instance().GetCodeGenConfig(KEY_CODEGEN_SUPPORT_TILE_TENSOR, false)) {
        return "";
    }

    // Priority 1: pto-isa from the path given by the environment variable "PTO_TILE_LIB_CODE_PATH".
    if (const char* ptoRoot = std::getenv(ENV_PTO_TILE_LIB_CODE_PATH.c_str()); ptoRoot != nullptr) {
        std::string envPath = std::string(ptoRoot) + "/include";
        ASSERT(CmpCodeErr::PTO_ISA_NOT_FOUND, IsPathExist(envPath + "/pto"))
            << "Pto-isa path " << envPath << "/pto not found! please check.";
        return envPath;
    }

    // Priority 2: pto-isa from the installed CANN package (ASCEND_HOME_PATH).
    if (const char* cannRoot = std::getenv(ENV_ASCEND_HOME_PATH.c_str()); cannRoot != nullptr) {
        std::string cannPath = std::string(cannRoot) + "/include";
        ASSERT(CmpCodeErr::PTO_ISA_NOT_FOUND, IsPathExist(cannPath + "/pto"))
            << "Pto-isa path " << cannPath << "/pto not found! please check.";
        return cannPath;
    }

    ASSERT(CmpCodeErr::PTO_ISA_NOT_FOUND, false) << "Pto-isa path not found. please install pto-isa properly.";
    return "";
}

void CodeGenNPU::BuildArchOptions(std::ostringstream& oss, const CompileInfo& compileInfo) const
{
    const std::string corePredefine = compileInfo.IsCube() ? "-D__AIC__" : "-D__AIV__";

    std::vector<std::string> compileOpts{corePredefine};
    if (ConfigManager::Instance().GetCodeGenConfig(KEY_CODEGEN_SUPPORT_TILE_TENSOR, false)) {
        compileOpts.emplace_back("-DSUPPORT_TILE_TENSOR");
    }
    if (config::GetPlatformConfig(KEY_ENABLE_PROF_AICORE_TIME, false) ||
        config::GetDebugOption<int64_t>(CFG_RUNTIME_DBEUG_MODE) == CFG_DEBUG_ALL) {
        compileOpts.emplace_back("-DOPEN_MIX_PERF");
    }

    if (platform_ == NPUArch::DAV_2201) {
        compileOpts.emplace_back("-D__DAV_V220");
        compileOpts.emplace_back("-DMEMORY_BASE");
    } else {
        compileOpts.emplace_back("-D__DAV_V310");
        compileOpts.emplace_back("-DREGISTER_BASE");
    }

    compileOpts.emplace_back("--cce-aicore-only");
    std::string coreArch = GetCoreArch(compileInfo);
    compileOpts.emplace_back("--cce-aicore-arch=" + coreArch);

    if (platform_ == NPUArch::DAV_3510 && !compileInfo.IsCube()) {
        compileOpts.emplace_back("--cce-long-scbz=true");
    }

    std::string allCompileOpts = JoinString(compileOpts, " ");
    oss << allCompileOpts << " ";
}

/**
 * Appends the -I options used to compile CCE sources.
 *
 * The pto-isa root (only non-empty in tile-tensor mode) comes first. It is followed by
 * the framework include root's tilefwk, tileop and tileop/arch32 subdirectories, and
 * then the root itself.
 */
void CodeGenNPU::BuildIncludes(std::ostringstream& oss) const
{
    const std::string ptoTileLibPath = GetPtoTileLibPathByEnv();
    if (!ptoTileLibPath.empty()) {
        oss << "-I" << ptoTileLibPath << " ";
    }

    const std::string includePath = GetIncludePathForCompileCCE();
    for (const char* suffix : {"/tilefwk ", "/tileop ", "/tileop/arch32 ", " "}) {
        oss << "-I" << includePath << suffix;
    }
}

/**
 * Appends vector-function (VF) fusion flags. Only DAV_3510 is affected.
 *
 * If KEY_ENABLE_VF is on (default), PTO tile fusion is enabled with its tuning knobs,
 * and post-fusion unrolling is added when KEY_ENABLE_VF_UNROLL is on. Otherwise SIMD
 * VF fusion is disabled.
 */
void CodeGenNPU::AppendVFOptions(NPUArch platform, std::ostringstream& oss)
{
    if (platform == NPUArch::DAV_3510) {
        if (config::GetPassGlobalConfig(KEY_ENABLE_VF, true)) {
            oss << "--enable-pto-tile-fusion "
                << "-mllvm --tile-fusion-skip-shape-inference=true "
                << "-mllvm --tile-fusion-skip-reduceop-fusion=false "
                << "-mllvm --tile-fusion-skip-legality-check=false "
                << "-mllvm -cce-vf-fusion-max-candidate-set-threshold=32 ";
            if (config::GetPassGlobalConfig(KEY_ENABLE_VF_UNROLL, false)) {
                oss << "-mllvm -enable-unroll-after-fused=true ";
            }
        } else {
            oss << "--cce-simd-vf-fusion=false ";
        }
    }
}

/**
 * Appends the fixed aicore backend flags, then the VF options for the current platform,
 * then the caller-supplied `compileOptions` followed by a space.
 */
void CodeGenNPU::BuildExtraOptions(std::ostringstream& oss, const std::string& compileOptions) const
{
    static constexpr const char* kFixedBackendOptions[] = {
        "-mllvm -cce-aicore-stack-size=0x8000 ",
        "-mllvm -cce-aicore-function-stack-size=0x8000 ",
        "-mllvm -cce-aicore-record-overflow=false ",
        "-mllvm -cce-aicore-addr-transform ",
        "-mllvm -cce-aicore-dcci-insert-for-scalar=false ",
    };
    for (const char* option : kFixedBackendOptions) {
        oss << option;
    }
    AppendVFOptions(platform_, oss);
    oss << compileOptions << " ";
}

/**
 * Returns the --cce-aicore-arch value: "dav-c220-*" on DAV_2201, "dav-c310-*" on any
 * other platform, with suffix "-cube" for cube kernels and "-vec" otherwise.
 */
std::string CodeGenNPU::GetCoreArch(const CompileInfo& compileInfo) const
{
    const std::string family = (platform_ == NPUArch::DAV_2201) ? "dav-c220" : "dav-c310";
    return family + (compileInfo.IsCube() ? "-cube" : "-vec");
}

/**
 * Executes a compile-related shell command (a bisheng compile or a make invocation)
 * through std::system, after validating it with CheckInjectStr.
 *
 * @param compileCmd  complete command line to execute
 * @return the non-zero CheckInjectStr result if the command was rejected (it is then
 *         never executed); otherwise the raw std::system status (0 on success)
 */
int CodeGenNPU::DoCompileCmd(const std::string& compileCmd) const
{
    int ret = CheckInjectStr(compileCmd.c_str(), compileCmd.length());
    ASSERT(CmpCodeErr::CMD_CHECK_FAILED, ret == 0)
        << "CheckInjectStr failed. errCode = " << ret << ", compileCmd is " << compileCmd;
    if (ret != 0) {
        // Defensive: a command rejected by the injection check must never reach the
        // shell, even in a configuration where ASSERT does not abort.
        return ret;
    }

    int rootFuncIdx = MonitorManager::Instance().PrepareNextRootFunc();
    {
        // Run the command inside a STAGE_FUNC_TO_BIN monitor scope.
        MonitorStageScope compileCmdScope(STAGE_FUNC_TO_BIN, rootFuncIdx, rootFuncName_, rootFuncOpSize_);
        // rootFuncOpSize_ is leafFuncOpSize
        MonitorManager::Instance().SetFuncSumOpSize(rootFuncOpSize_);
        ret = std::system(compileCmd.c_str());
    }
    if (ret != 0) {
        CODEGEN_LOGE(
            CmpCodeErr::COMPILE_CODE_FAILED, "kernel compilation failed, ret = %d\ncompile cmd is:\n %s", ret,
            compileCmd.c_str());
    }
    return ret;
}

void CodeGenNPU::Prepare(const Function& topFunc)
{
    compileTasks_.clear();
    rootFuncName_ = topFunc.GetMagicName();
    int leafFuncOpSize = 0;
    for (auto& [psgId, leaf] : topFunc.rootFunc_->programs_) {
        (void)psgId;
        leafFuncOpSize += static_cast<int>(leaf->GetOperationSize());
    }
    rootFuncOpSize_ = leafFuncOpSize;
}

// Appends the OP_SHMEM_WAIT_UNTIL payload to `code`. The opcode itself is pushed by the
// caller (HandleForAICpuSubFunc). Layout, in order:
//   [2 * #outputs] then (rank, attr offset) for each output operand;
//   [2 * #inputs]  then (rank, attr offset) for each input operand;
//   [2 * rank(signal rawshape)] then the signal's raw-shape dims, then its shape dims;
//   if a distOpAttr attribute exists: [#attrs] then the ShmemWaitUntilAttr fields.
// All values are narrowed to int32_t as they are stored in `code`.
// NOTE(review): the raw-shape section header counts 2 * rawshape rank, which matches the
// payload only if the signal's shape has the same rank as its raw shape — confirm.
// NOTE(review): input index 1 is accessed unconditionally; the op is assumed to have >= 2 inputs.
void EncodeWaitUntilInfo(const Operation& op, std::vector<int32_t>& code)
{
    constexpr int32_t paramSizePerOperand = 2; // waitUntil encodes 2 attributes per operand: dim and coaIndex
    code.push_back(op.GetOOperands().size() * paramSizePerOperand);
    for (size_t i = 0; i < op.GetOOperands().size(); ++i) {
        code.push_back(op.GetOutputOperand(i)->shape.size());
        code.push_back(op.GetOOpAttrOffset(i));
    }

    code.push_back(op.GetIOperands().size() * paramSizePerOperand);
    for (size_t i = 0; i < op.GetIOperands().size(); ++i) {
        code.push_back(op.GetInputOperand(i)->shape.size());
        code.push_back(op.GetIOpAttrOffset(i));
    }
    // The waitUntil op has 2 inputs: index 0 is a dummy control edge, index 1 is the signal.
    // Encode the signal's rawShape.
    code.push_back(op.GetInputOperand(1)->GetRawTensor()->rawshape.size() * paramSizePerOperand);
    for (auto dimShape : op.GetInputOperand(1)->GetRawTensor()->GetRawShape()) {
        code.push_back(dimShape);
    }
    // Encode the signal's shape.
    for (auto dimShape : op.GetInputOperand(1)->GetShape()) {
        code.push_back(dimShape);
    }
    // Encode the waitUntil attributes; the field order is fixed and must match the decoder.
    std::map<std::string, Any> map = op.GetAllAttribute();
    auto it = map.find(OpAttributeKey::distOpAttr);
    if (it != map.end()) {
        Distributed::ShmemWaitUntilAttr distAttr = AnyCast<Distributed::ShmemWaitUntilAttr>(it->second);
        std::vector<int32_t> attrs;
        attrs.push_back(static_cast<int32_t>(distAttr.expectedSum));
        attrs.push_back(static_cast<int32_t>(distAttr.signalStride));
        attrs.push_back(static_cast<int32_t>(distAttr.resetSignal));
        // tileShape is length-prefixed; the view vectors below are not.
        attrs.push_back(static_cast<int32_t>(distAttr.tileShape.size()));
        for (size_t i = 0; i < distAttr.tileShape.size(); ++i) {
            attrs.push_back(static_cast<int32_t>(distAttr.tileShape[i]));
        }

        for (size_t i = 0; i < distAttr.viewshapes.size(); ++i) {
            attrs.push_back(static_cast<int32_t>(distAttr.viewshapes[i]));
        }

        for (size_t i = 0; i < distAttr.viewTileStrides.size(); ++i) {
            attrs.push_back(static_cast<int32_t>(distAttr.viewTileStrides[i]));
        }

        for (size_t i = 0; i < distAttr.viewIndexStrides.size(); ++i) {
            attrs.push_back(static_cast<int32_t>(distAttr.viewIndexStrides[i]));
        }

        attrs.push_back(static_cast<int32_t>(distAttr.viewTileNum));
        attrs.push_back(static_cast<int32_t>(distAttr.totalTileNum));

        // The attribute section is prefixed with its own length.
        code.push_back(static_cast<int32_t>(attrs.size()));
        code.insert(code.end(), attrs.begin(), attrs.end());
    }
}

/**
 * Encodes an AICPU leaf function into an int32 code buffer stored on its
 * LeafFuncAttribute (core type AICPU).
 *
 * Each AICPU-core operation contributes its opcode. OP_SHMEM_WAIT_UNTIL additionally
 * contributes its EncodeWaitUntilInfo payload.
 *
 * @return false (and does nothing) if `subFunc` is not an AICPU sub-function; true otherwise
 */
bool CodeGenNPU::HandleForAICpuSubFunc(Function& subFunc)
{
    if (!subFunc.IsAicpuSubFunction().first) {
        return false;
    }

    std::vector<int32_t> leafCode;
    auto operationList = subFunc.Operations(false);
    for (const auto& op : operationList) {
        if (op.GetCoreType() == CoreType::AICPU) {
            const auto opcode = op.GetOpcode();
            leafCode.push_back(static_cast<int32_t>(opcode));
            if (opcode == Opcode::OP_SHMEM_WAIT_UNTIL) {
                EncodeWaitUntilInfo(op, leafCode);
            }
        }
    }

    // Pad to an even number of int32 entries, which indirectly guarantees the encoded
    // buffer occupies a multiple of 8 bytes.
    if ((leafCode.size() & 1U) != 0) {
        leafCode.push_back(0);
    }

    auto attr = std::make_shared<LeafFuncAttribute>();
    attr->coreType = CoreType::AICPU;
    attr->aicpuLeafCode = std::move(leafCode);
    subFunc.SetLeafFuncAttribute(attr);
    return true;
}

/**
 * Queues a compile task for the later parallel build. This can be reached concurrently
 * from the per-program codegen tasks, so the queue is guarded by compileTasksMutex_.
 */
void CodeGenNPU::CollectCompileTask(const CompileTaskInfo& task) const
{
    const std::lock_guard<std::mutex> guard(compileTasksMutex_);
    compileTasks_.emplace_back(task);
}

/**
 * Returns the directory part (everything before the last '/') of the first queued
 * task's output path. Returns "." when no task is queued or the path has no '/'.
 */
std::string CodeGenNPU::GetOutputDir() const
{
    if (compileTasks_.empty()) {
        return ".";
    }
    const std::string& firstOutput = compileTasks_.front().outputPath;
    const auto slashPos = firstOutput.rfind('/');
    return slashPos == std::string::npos ? std::string(".") : firstOutput.substr(0, slashPos);
}

void CodeGenNPU::GenerateMakefile(const std::string& makefilePath) const
{
    std::ofstream makefile(makefilePath);
    if (!makefile.is_open()) {
        ASSERT(CmpCodeErr::FILE_IO_FAILED, false) << "Failed to create Makefile: " << makefilePath;
        return;
    }

    makefile << "# Auto-generated Makefile for parallel compilation\n";

    makefile << "all: ";
    for (const auto& task : compileTasks_) {
        makefile << task.outputPath << " ";
    }
    makefile << "\n\n";

    for (const auto& task : compileTasks_) {
        makefile << task.outputPath << ": " << task.inputPath << "\n";
        makefile << "\t@" << task.compileCmd << "\n\n";
    }

    makefile << ".PHONY: all clean\n\n";
    makefile << "clean:\n";
    makefile << "\t@rm -f ";
    for (const auto& task : compileTasks_) {
        makefile << task.outputPath << " ";
    }
    makefile << "\n";

    makefile.close();
    CODEGEN_LOGI("Generated Makefile: %s with %zu tasks", makefilePath.c_str(), compileTasks_.size());
}

/**
 * Compiles all queued kernel sources in parallel.
 *
 * It writes "<outputDir>/Makefile_<magic>_<hash>.compile", runs
 * "make -j<GetCGThreadNum()> -f <makefile>", logs the elapsed wall time, asserts
 * success and clears the queue. It is skipped when the compile stage stops after
 * instruction codegen or when nothing is queued.
 */
void CodeGenNPU::ExecuteParallelCompile(const Function& topFunc)
{
    if (config::GetHostOption<int64_t>(COMPILE_STAGE) == CS_CODEGEN_INSTRUCTION) {
        CODEGEN_LOGI("Compile stage terminates after codegen instruction.");
        return;
    }
    if (compileTasks_.empty()) {
        CODEGEN_LOGI("No compile tasks, skip parallel compilation");
        return;
    }

    const unsigned parallelJobs = GetCGThreadNum();
    const std::string makefilePath = GetOutputDir() + "/Makefile_" + std::to_string(topFunc.GetFuncMagic()) + "_" +
                                     topFunc.GetFunctionHash() + ".compile";
    GenerateMakefile(makefilePath);

    const std::string makeCmd = "make -j" + std::to_string(parallelJobs) + " -f " + makefilePath;

    CODEGEN_LOGI(
        "Top Function magic: %d, hash: %s: Starting parallel compilation: %u jobs, %zu tasks", topFunc.GetFuncMagic(),
        topFunc.GetFunctionHash().c_str(), parallelJobs, compileTasks_.size());
    CODEGEN_LOGI("Execute: %s", makeCmd.c_str());

    const auto startTime = std::chrono::high_resolution_clock::now();
    int ret = DoCompileCmd(makeCmd);
    const std::chrono::duration<double, std::milli> elapsed = std::chrono::high_resolution_clock::now() - startTime;
    CODEGEN_LOGI(
        "Top Function magic: %d, hash: %s: Parallel compilation finished in %f ms", topFunc.GetFuncMagic(),
        topFunc.GetFunctionHash().c_str(), elapsed.count());

    ASSERT(CmpCodeErr::COMPILE_CODE_FAILED, ret == 0) << "Parallel compilation failed with return code: " << ret;

    compileTasks_.clear();
}

} // namespace npu::tile_fwk