/**

 * Copyright (c) 2026 Huawei Technologies Co., Ltd.

 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of

 * CANN Open Software License Agreement Version 2.0 (the "License").

 * Please refer to the License for details. You may not use this file except in compliance with the License.

 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,

 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.

 * See LICENSE in the root of the software repository for the full text of the License.

 */



/*!

 * \file codegen_mte_gather.cpp

 * \brief

 */

#include <iterator>

#include <string>



#include "codegen_op_npu.h"

#include "codegen/symbol_mgr/codegen_symbol.h"

#include "codegen/utils/codegen_utils.h"

#include "securec.h"



namespace npu::tile_fwk {

/**
 * Emits the gather-in-L1 tile-op call for the tile-tensor (layout) code path.
 *
 * The statement is assembled as: tileOpName, the block size wrapped by
 * WrapParamByAngleBrackets, then WrapParamByParentheses over
 * {dst, src, blockTable, offsets, srcCoord, offsetCoord, blockTableCoord}, then STMT_END.
 *
 * - srcCoord is built from the int64_t start-offset attribute (one-dimensional coordinate).
 * - offsetCoord / blockTableCoord are built from the two offset expressions of
 *   operands ID2 and ID3 respectively.
 *
 * @return the generated statement text.
 */
std::string CodeGenOpNPU::PrintGatherInL1TileTensor() const
{
    std::string srcVar = QueryTileTensorNameByIdx(ToUnderlying(MISOIdx::SRC0_IDX));
    std::string offsetsVar = QueryTileTensorNameByIdx(ToUnderlying(MISOIdx::SRC1_IDX));
    std::string blockTableVar = QueryTileTensorNameByIdx(ToUnderlying(MISOIdx::SRC2_IDX));
    std::string dstVar = QueryTileTensorNameByIdx(ToUnderlying(MISOIdx::DST_IDX));

    // Same presence check as the non-layout path of GenGatherInL1, so a missing attribute
    // is reported through ASSERT instead of surfacing from opAttrs.at().
    ASSERT(OperErr::ATTRIBUTE_INVALID, opAttrs.find("op_attr_blocksize") != opAttrs.end())
        << "GenGatherInL1: There is no blocksize attribute here";
    const int64_t blockSize = AnyCast<int64_t>(opAttrs.at("op_attr_blocksize"));

    auto startOffset = opAttrs.at(OpAttributeKey::startOffset);
    ASSERT(OperErr::ATTRIBUTE_INVALID, startOffset.HasValue() && (startOffset.Type() == typeid(int64_t)))
        << "GenGatherInL1 startOffset must be int64_t!";
    const auto srcColumnStartOffset = AnyCast<int64_t>(startOffset);
    std::string srcCoordCp = WrapParamByParentheses({std::to_string(srcColumnStartOffset)});
    std::string srcCoord = PrintCoord(SHAPE_DIM1, srcCoordCp);

    auto offsetsStartOffsets = GenParamIdxExprByIndex(ID2, SHAPE_DIM2, PREFIX_STR_OFFSET);
    std::string offsetCoordCp = WrapParamByParentheses(offsetsStartOffsets);
    std::string offsetCoord = PrintCoord(SHAPE_DIM2, offsetCoordCp);

    auto blockTableStartOffsets = GenParamIdxExprByIndex(ID3, SHAPE_DIM2, PREFIX_STR_OFFSET);
    std::string blockTableCoordCp = WrapParamByParentheses(blockTableStartOffsets);
    std::string blockTableCoord = PrintCoord(SHAPE_DIM2, blockTableCoordCp);

    std::ostringstream oss;
    oss << tileOpName;
    oss << WrapParamByAngleBrackets({std::to_string(blockSize)});
    oss << WrapParamByParentheses({dstVar, srcVar, blockTableVar, offsetsVar, srcCoord, offsetCoord, blockTableCoord});
    oss << STMT_END;
    return oss.str();
}



std::string CodeGenOpNPU::GenGatherInL1() const

{

    if (isSupportLayout) {

        return PrintGatherInL1TileTensor();

    }

    const DataType dstDtype = operandDtype[ID0];

    const DataType srcDtype = operandDtype[ID1];

    const DataType offsetsDtype = operandDtype[ID2];

    ASSERT(GenCodeErr::DATA_TYPE_MISMATCHED, dstDtype == srcDtype) << "dstDtype and srcDtype must be same!";



    std::string srcVar = GenGmParamVar(ID1);

    std::string offsetsVar = GenGmParamVar(ID2);

    std::string blockTableVar = GenGmParamVar(ID3);

    std::string dstVar = sm->QueryVarNameByTensorMagic(operandWithMagic[ID0]);



    auto dstRawShapes = rawShape[ID0];

    auto srcRawShapes = rawShape[ID1];

    auto offsetsRawShapes = rawShape[ID2];

    auto dstOriShapes = dynamicValidShape[ID0];

    ASSERT(GenCodeErr::TENSOR_DIM_UNSUPPORTED, srcRawShapes.size() == SHAPE_DIM2)

        << "GenGatherInL1 only support 2-dim!";

    ASSERT(GenCodeErr::TENSOR_DIM_UNSUPPORTED, dstRawShapes.size() == SHAPE_DIM2)

        << "GenGatherInL1 only support 2-dim!";

    ASSERT(GenCodeErr::TENSOR_DIM_UNSUPPORTED, offsetsRawShapes.size() == SHAPE_DIM2)

        << "GenGatherInL1 only support 2-dim!";

    ASSERT(GenCodeErr::TENSOR_DIM_UNSUPPORTED, dstOriShapes.size() == SHAPE_DIM2)

        << "GenGatherInL1 only support 2-dim!";



    auto offsetsStartOffsets = GenParamIdxExprByIndex(ID2, SHAPE_DIM2, PREFIX_STR_OFFSET);



    char buffer[BUFFER_SIZE_1024] = "CG_ERROR";

    std::string dstDtypeStr = DataType2CCEStr(dstDtype);

    std::string srcDtypeStr = DataType2CCEStr(srcDtype);

    std::string offsetsDtypeStr = DataType2CCEStr(offsetsDtype);

    std::string blockTableDtypeStr = DataType2CCEStr(operandDtype[ID3]);



    ASSERT(GenCodeErr::DATA_TYPE_MISMATCHED, dstDtypeStr == srcDtypeStr) << "dstDtypeStr and srcDtypeStr must be same!";

    ASSERT(GenCodeErr::DATA_TYPE_UNSUPPORTED, offsetsDtypeStr == "int64_t" || offsetsDtypeStr == "int32_t")

        << "offsetsDtypeStr must be int64_t or int32_t!";

    ASSERT(OperErr::ATTRIBUTE_INVALID, opAttrs.find("op_attr_blocksize") != opAttrs.end())

        << "GenGatherOp: There is nop blocksize attribute here";

    const int64_t blockSize = AnyCast<int64_t>(opAttrs.at("op_attr_blocksize"));



    auto startOffset = opAttrs.at(OpAttributeKey::startOffset);

    ASSERT(OperErr::ATTRIBUTE_INVALID, startOffset.HasValue() && (startOffset.Type() == typeid(int64_t)))

        << "GenGatherInL1 startOffset must be int64_t!";

    auto srcColumnStartOffset = AnyCast<int64_t>(startOffset);

    auto blockTableGMStride = GenParamIdxExprByIndex(ID3, SHAPE_DIM2, PREFIX_STR_RAW_SHAPE);

    auto blockTableStartOffsets = GenParamIdxExprByIndex(ID3, SHAPE_DIM2, PREFIX_STR_OFFSET);



    auto ret = sprintf_s(

        buffer, sizeof(buffer),

        "%s<%s, %s, %s, %lld, %lld, %lld, %lld>((__cbuf__ %s *)%s, %s, %s, (__gm__ %s *)%s, %lld, (__gm__ %s *)%s, "

        "(__gm__ %s *)%s, %s, %s, %s, %s, %s);\n",

        tileOpName.c_str(), dstDtypeStr.c_str(), offsetsDtypeStr.c_str(), blockTableDtypeStr.c_str(), dstRawShapes[ID0],

        offsetsRawShapes[ID1], srcColumnStartOffset, blockSize, dstDtypeStr.c_str(), dstVar.c_str(),

        dstOriShapes[ID0].Dump().c_str(), dstOriShapes[ID1].Dump().c_str(), srcDtypeStr.c_str(), srcVar.c_str(),

        srcRawShapes[1], offsetsDtypeStr.c_str(), offsetsVar.c_str(), blockTableDtypeStr.c_str(), blockTableVar.c_str(),

        offsetsStartOffsets[ID0].c_str(), offsetsStartOffsets[ID1].c_str(), blockTableGMStride[ID1].c_str(),

        blockTableStartOffsets[ID0].c_str(), blockTableStartOffsets[ID1].c_str());



    ASSERT(GenCodeErr::PRINT_FAILED, ret >= 0) << "GenGatherInL1 sprintf_s failed ";

    std::string ostring(buffer);

    return ostring;

}

/**
 * Helper that normalizes the axis argument so that it refers to the parameter shape
 * after that shape has been left-padded to SHAPE_DIM4 dimensions.
 *
 * Example:
 *   param [a, b], axis 0
 * after normalization:
 *   param [1, 1, a, b], axis 2
 *
 * @param axis      axis index relative to the original parameter shape
 * @param paramDim  number of dimensions of the original parameter shape
 * @return axis shifted by (SHAPE_DIM4 - paramDim)
 */
inline int NormalizeAxis(int axis, int paramDim)
{
    const auto leadingPad = SHAPE_DIM4 - paramDim;
    return axis + leadingPad;
}

/**
 * Normalizes the dimensions of a gather result shape in place.
 *
 * Example:
 *   param   [a, b]
 *   indices [c]
 *   axis    1
 *   result  [a, c]
 * After normalization:
 *   param   [1, 1, a, b]
 *   indices [1, c]
 *   axis    3
 *   result  [1, 1, a, 1, c]
 *
 * Processing steps:
 *   1. Recover the param and indices dimensions from the result shape.
 *   2. Normalize param to four dimensions and indices to two dimensions.
 *   3. Re-assemble the result shape.
 *
 * @tparam T          element type of the shape; must be int64_t or SymbolicScalar
 * @param rawShape    gather result shape; replaced by the normalized shape on return
 * @param paramDim    number of dimensions of the original param tensor
 * @param indicesDim  number of dimensions of the original indices tensor
 * @param axis        gather axis relative to the original param shape
 */
template <typename T>
void NormalizeGatherShape(std::vector<T>& rawShape, const int paramDim, const int indicesDim, const int axis)
{
    // The element type is known at compile time, so reject unsupported types while compiling.
    static_assert(
        std::is_same_v<T, int64_t> || std::is_same_v<T, SymbolicScalar>, "T must be int64_t or SymbolicScalar");
    // The slices below are taken with iterator arithmetic; make sure they stay inside rawShape.
    ASSERT(
        GenCodeErr::TENSOR_DIM_UNSUPPORTED,
        axis >= 0 && indicesDim >= 0 &&
            static_cast<size_t>(axis) + static_cast<size_t>(indicesDim) <= rawShape.size())
        << "NormalizeGatherShape: axis and indicesDim do not fit the result shape";

    const auto indicesBegin = rawShape.begin() + axis;
    const auto indicesEnd = indicesBegin + indicesDim;

    // Step 1: the indices dimensions occupy [axis, axis + indicesDim) of the result; everything
    // else belongs to param, with a -1 placeholder standing in for the gathered axis.
    std::vector<T> indicesShape(indicesBegin, indicesEnd);
    std::vector<T> paramShape(rawShape.begin(), indicesBegin);
    paramShape.push_back(-1);
    paramShape.insert(paramShape.end(), indicesEnd, rawShape.end());

    // Step 2: bring param to SHAPE_DIM4 dimensions and indices to SHAPE_DIM2 dimensions.
    if constexpr (std::is_same_v<T, int64_t>) {
        paramShape = NormalizeShape(paramShape, SHAPE_DIM4);
        indicesShape = NormalizeShape(indicesShape, SHAPE_DIM2);
    } else if constexpr (std::is_same_v<T, SymbolicScalar>) {
        FillVecWithDummyInHead<SymbolicScalar>(paramShape, SHAPE_DIM4 - paramDim, 1);
        FillVecWithDummyInHead<SymbolicScalar>(indicesShape, SHAPE_DIM2 - indicesDim, 1);
    }

    // Step 3: replace the placeholder at the normalized axis with the normalized indices shape.
    rawShape = paramShape;
    const int normalizedAxis = NormalizeAxis(axis, paramDim);
    rawShape.erase(rawShape.begin() + normalizedAxis);
    rawShape.insert(rawShape.begin() + normalizedAxis, indicesShape.begin(), indicesShape.end());
}

/**
 * Rearranges an index list for the gather layout path.
 *
 * The element at position (NUM4 - paramDim) is removed, and the value (NUM4 - paramDim)
 * is then inserted at the position given by NormalizeAxis(axis, paramDim).
 *
 * Example (assuming NUM4 and SHAPE_DIM4 are both 4 - TODO confirm):
 *   index {0, 1, 2, 3, 4}, axis 1, paramDim 2  ->  {0, 1, 3, 2, 4}
 *
 * @param index     index list, modified in place
 * @param axis      gather axis relative to the original parameter shape
 * @param paramDim  number of dimensions of the original parameter shape
 */
void HelpNormalize(std::vector<size_t>& index, int axis, int paramDim)
{
    const size_t padCount = NUM4 - paramDim;
    index.erase(std::next(index.begin(), padCount));

    const int shiftedAxis = NormalizeAxis(axis, paramDim);
    index.insert(std::next(index.begin(), shiftedAxis), padCount);
}

/**
 * Emits the gather tile-op call for the dynamic (non-layout) code path.
 *
 * Operand layout used here: ID0 = output, ID1 = param, ID2 = indices.
 *
 * Template arguments, in order:
 *   param dtype, indices dtype, normalized axis, then every entry except the first of the
 *   output raw shape after NormalizeGatherShape.
 * Call arguments, in order:
 *   (__ubuf__ T*)output, (__gm__ T*)param, (__gm__ T*)indices,
 *   every entry of the normalized output valid shape,
 *   param raw-shape expressions (padded to SHAPE_DIM4, first entry dropped), param offsets (padded),
 *   indices raw-shape expressions (padded to SHAPE_DIM2, first entry dropped), indices offsets (padded).
 *
 * @return the generated statement text (terminated by a newline).
 */
std::string CodeGenOpNPU::PrintGatherDynamicUnaligned() const
{
    std::string resultDtypeStr = DataType2CCEStr(operandDtype[ID0]);
    std::string paramDtypeStr = DataType2CCEStr(operandDtype[ID1]);
    std::string indicesDtypeStr = DataType2CCEStr(operandDtype[ID2]);
    ASSERT(GenCodeErr::DATA_TYPE_MISMATCHED, resultDtypeStr == paramDtypeStr)
        << "resultDtypeStr: " << resultDtypeStr << ", paramDtypeStr: " << paramDtypeStr;

    // The helpers below take int arguments; narrow once, explicitly.
    const int axis = static_cast<int>(AnyCast<int64_t>(opAttrs.at("op_attr_axis")));
    const int paramDim = static_cast<int>(rawShape[ID1].size());
    const int indicesDim = static_cast<int>(rawShape[ID2].size());
    constexpr int paramIndex = 1;
    constexpr int indicesIndex = 2;

    // ---- template arguments ----
    auto normalizedOutputRawShapes = rawShape[ID0];
    NormalizeGatherShape<int64_t>(normalizedOutputRawShapes, paramDim, indicesDim, axis);

    std::vector<std::string> paramList;
    paramList.emplace_back(paramDtypeStr);
    paramList.emplace_back(indicesDtypeStr);
    paramList.emplace_back(std::to_string(NormalizeAxis(axis, paramDim)));
    // Shape entries are int64_t; keep the full width when converting to text.
    std::transform(
        normalizedOutputRawShapes.begin() + 1, normalizedOutputRawShapes.end(), back_inserter(paramList),
        [](int64_t dim) { return std::to_string(dim); });

    std::string templateParam = JoinString(paramList, ", ");
    paramList.clear();

    // ---- call arguments ----
    std::string paramVar = GenGmParamVar(paramIndex);
    std::string indicesVar = GenGmParamVar(indicesIndex);
    std::string outputVar = sm->QueryVarNameByTensorMagic(operandWithMagic[ID0]);
    paramList.emplace_back("(__ubuf__ " + resultDtypeStr + "*)" + outputVar);
    paramList.emplace_back("(__gm__ " + paramDtypeStr + "*)" + paramVar);
    paramList.emplace_back("(__gm__ " + indicesDtypeStr + "*)" + indicesVar);

    auto outputValidShapes = dynamicValidShape[ID0];
    NormalizeGatherShape<SymbolicScalar>(outputValidShapes, paramDim, indicesDim, axis);
    std::transform(outputValidShapes.begin(), outputValidShapes.end(), back_inserter(paramList), [](SymbolicScalar x) {
        return SymbolicExpressionTable::BuildExpression(x);
    });

    auto paramGMStride = GenParamIdxExprByIndex(paramIndex, paramDim, PREFIX_STR_RAW_SHAPE);
    auto paramStartOffsets = GenParamIdxExprByIndex(paramIndex, paramDim, PREFIX_STR_OFFSET);
    FillVecWithDummyInHead<std::string>(paramGMStride, SHAPE_DIM4 - paramDim, std::string("1"));
    FillVecWithDummyInHead<std::string>(paramStartOffsets, SHAPE_DIM4 - paramDim, std::string("0"));
    paramList.insert(paramList.end(), paramGMStride.begin() + 1, paramGMStride.end());
    paramList.insert(paramList.end(), paramStartOffsets.begin(), paramStartOffsets.end());

    auto indicesGMStride = GenParamIdxExprByIndex(indicesIndex, indicesDim, PREFIX_STR_RAW_SHAPE);
    auto indicesStartOffsets = GenParamIdxExprByIndex(indicesIndex, indicesDim, PREFIX_STR_OFFSET);
    FillVecWithDummyInHead<std::string>(indicesGMStride, SHAPE_DIM2 - indicesDim, std::string("1"));
    FillVecWithDummyInHead<std::string>(indicesStartOffsets, SHAPE_DIM2 - indicesDim, std::string("0"));
    paramList.insert(paramList.end(), indicesGMStride.begin() + 1, indicesGMStride.end());
    paramList.insert(paramList.end(), indicesStartOffsets.begin(), indicesStartOffsets.end());

    std::string tileOpCallParam = JoinString(paramList, ", ");

    std::ostringstream os;
    os << tileOpName.c_str() << "<" << templateParam << ">"
       << "(" << tileOpCallParam << ");\n";
    return os.str();
}

std::string CodeGenOpNPU::PrintGatherLayout() const

{

    auto outputRawShapes = rawShape[ID0];

    auto paramRawShapes = rawShape[ID1];

    auto indicesRawShapes = rawShape[ID2];

    auto outputValidShapes = dynamicValidShape[ID0];

    auto paramValidShapes = dynamicValidShape[ID1];

    auto indicesValidShapes = dynamicValidShape[ID2];

    size_t paramDim = paramValidShapes.size();

    size_t indicesDim = indicesValidShapes.size();

    const int64_t axis = AnyCast<int64_t>(opAttrs.at("op_attr_axis"));

    std::vector<size_t> helpIndex = {ID0, ID1, ID2, ID3, ID4};

    if (indicesDim == 1 && axis != 0) {

        HelpNormalize(helpIndex, axis, paramDim);

    }

    auto paramOffsetSymbol = GenGetParamMacroPacked(ID1, paramDim, PREFIX_STR_OFFSET);

    auto indicesOffsetSymbol = GenGetParamMacroPacked(ID2, indicesDim, PREFIX_STR_OFFSET);

    std::string coordCpparamOffset = WrapParamByParentheses(paramOffsetSymbol);

    std::string coordCpindicesOffset = WrapParamByParentheses(indicesOffsetSymbol);

    std::string coord4Param = PrintCoord(paramDim, coordCpparamOffset);

    std::string coord4Indices = PrintCoord(indicesDim, coordCpindicesOffset);



    auto tileOpParams = GetTileOpParamsByOrder();

    tileOpParams.insert(tileOpParams.end(), {coord4Param, coord4Indices});



    std::vector<std::string> paramList;

    paramList.emplace_back(std::to_string(NormalizeAxis(axis, paramDim)));

    std::transform(

        helpIndex.begin(), helpIndex.end(), back_inserter(paramList), [](size_t x) { return std::to_string(x); });



    std::ostringstream oss;

    oss << tileOpName << WrapParamByAngleBrackets(paramList) << WrapParamByParentheses(tileOpParams) << STMT_END;

    return oss.str();

}



/**
 * Entry point for generating the gather operator call.
 *
 * Chooses the layout printer when isSupportLayout is set, otherwise the dynamic printer when
 * isDynamicFunction is set. If neither flag is set the ASSERT fires and, should it return,
 * an empty string is produced.
 *
 * @return the generated statement text, or "" for the unsupported static-graph case.
 */
std::string CodeGenOpNPU::GenGatherOp() const
{
    const bool isStaticGraph = !isSupportLayout && !isDynamicFunction;
    if (isStaticGraph) {
        ASSERT(GenCodeErr::PRINT_MODE_ERROR, false) << "Gather operator does not support static graph";
        return "";
    }
    // The layout path takes precedence over the dynamic path.
    return isSupportLayout ? PrintGatherLayout() : PrintGatherDynamicUnaligned();
}



std::string CodeGenOpNPU::PrintGatherInUBLayout() const

{

    constexpr int paramIndex = 1;

    constexpr int indicesIndex = 2;

    constexpr int blockTableIndex = 3;

    constexpr int paramDim = 2;

    constexpr int indicesDim = 2;

    constexpr int blockTableDim = 2;



    auto paramOffsetSymbol = GenGetParamMacroPacked(paramIndex, paramDim, PREFIX_STR_OFFSET);

    auto indicesOffsetSymbol = GenGetParamMacroPacked(indicesIndex, indicesDim, PREFIX_STR_OFFSET);

    auto blockTableOffsetSymbol = GenGetParamMacroPacked(blockTableIndex, blockTableDim, PREFIX_STR_OFFSET);



    std::string coordCpparamOffset = WrapParamByParentheses(paramOffsetSymbol);

    std::string coordCpindicesOffset = WrapParamByParentheses(indicesOffsetSymbol);

    std::string coordCpblockTableOffset = WrapParamByParentheses(blockTableOffsetSymbol);

    std::string coord4Param = PrintCoord(paramDim, coordCpparamOffset);

    std::string coord4Indices = PrintCoord(indicesDim, coordCpindicesOffset);

    std::string coord4BlockTable = PrintCoord(blockTableDim, coordCpblockTableOffset);

    std::vector<std::string> paramList;

    ASSERT(OperErr::ATTRIBUTE_INVALID, opAttrs.find(OpAttributeKey::blockSize) != opAttrs.end())

        << "GenGatherOp: There is nop blockSize attribute here";

    const int64_t blockSize = AnyCast<int64_t>(opAttrs.at(OpAttributeKey::blockSize));

    paramList.emplace_back(std::to_string(blockSize));

    std::string templateParam = JoinString(paramList, CONN_COMMA);



    std::vector<std::string> tileOpParamList = GetTileOpParamsByOrder();

    tileOpParamList.insert(tileOpParamList.end(), {coord4Param, coord4Indices, coord4BlockTable});

    std::ostringstream oss;

    oss << tileOpName << "<" << templateParam << ">" << WrapParamByParentheses(tileOpParamList) << STMT_END;

    return oss.str();

}

std::string CodeGenOpNPU::PrintGatherInUBDynamicUnaligned() const

{

    std::vector dstShape = rawShape[0];

    std::vector src0Shape = rawShape[1];



    std::string resultDtypeStr = DataType2CCEStr(operandDtype[ID0]);

    std::string paramDtypeStr = DataType2CCEStr(operandDtype[ID1]);

    std::string indicesDtypeStr = DataType2CCEStr(operandDtype[ID2]);

    std::string blockTableDtypeStr = DataType2CCEStr(operandDtype[ID3]);

    ASSERT(GenCodeErr::DATA_TYPE_MISMATCHED, resultDtypeStr == paramDtypeStr)

        << "resultDtypeStr and paramDtypeStr must be same!";

    ASSERT(OperErr::ATTRIBUTE_INVALID, opAttrs.find(OpAttributeKey::blockSize) != opAttrs.end())

        << "GenGatherOp: There is nop blockSize attribute here";

    const int64_t blockSize = AnyCast<int64_t>(opAttrs.at(OpAttributeKey::blockSize));

    auto outputRawShapes = rawShape[ID0];

    auto paramRawShapes = rawShape[ID1];

    auto indicesRawShapes = rawShape[ID2];

    auto outputValidShapes = dynamicValidShape[ID0];

    auto paramValidShapes = dynamicValidShape[ID1];

    auto indicesValidShapes = dynamicValidShape[ID2];

    constexpr int paramDim = 2;

    constexpr int indicesDim = 2;

    constexpr int blockTableDim = 2;

    std::ostringstream os;

    std::vector<std::string> paramList;

    paramList.emplace_back(paramDtypeStr);

    paramList.emplace_back(indicesDtypeStr);

    paramList.emplace_back(blockTableDtypeStr);

    paramList.emplace_back(std::to_string(outputRawShapes[0]));

    paramList.emplace_back(std::to_string(outputRawShapes[1]));

    paramList.emplace_back(std::to_string(blockSize));

    std::string templateParam = JoinString(paramList, ", ");

    paramList.clear();

    constexpr int paramIndex = 1;

    constexpr int indicesIndex = 2;

    constexpr int blockTableIndex = 3;

    std::string paramVar = GenGmParamVar(paramIndex);

    std::string indicesVar = GenGmParamVar(indicesIndex);

    std::string blockTableVar = GenGmParamVar(blockTableIndex);

    std::string outputVar = sm->QueryVarNameByTensorMagic(operandWithMagic[ID0]);

    std::string outputParamStr = "(__ubuf__ " + resultDtypeStr + "*)" + outputVar;

    std::string paramParamStr = "(__gm__ " + paramDtypeStr + "*)" + paramVar;

    std::string indicesParamStr = "(__gm__ " + indicesDtypeStr + "*)" + indicesVar;

    std::string blockTableParamStr = "(__gm__ " + blockTableDtypeStr + "*)" + blockTableVar;

    paramList.emplace_back(outputParamStr);

    paramList.emplace_back(paramParamStr);

    paramList.emplace_back(indicesParamStr);

    paramList.emplace_back(blockTableParamStr);

    paramList.emplace_back(SymbolicExpressionTable::BuildExpression(outputValidShapes[1]));



    auto paramGMStride = GenParamIdxExprByIndex(paramIndex, paramDim, PREFIX_STR_RAW_SHAPE);

    auto paramStartOffsets = GenParamIdxExprByIndex(paramIndex, paramDim, PREFIX_STR_OFFSET);

    paramList.emplace_back(paramGMStride[1]);

    paramList.emplace_back(paramStartOffsets[0]);

    paramList.emplace_back(paramStartOffsets[1]);



    paramList.emplace_back(SymbolicExpressionTable::BuildExpression(outputValidShapes[0]));

    auto indicesGMStride = GenParamIdxExprByIndex(indicesIndex, indicesDim, PREFIX_STR_RAW_SHAPE);

    auto indicesStartOffsets = GenParamIdxExprByIndex(indicesIndex, indicesDim, PREFIX_STR_OFFSET);

    paramList.emplace_back(indicesGMStride[1]);

    paramList.emplace_back(indicesStartOffsets[0]);

    paramList.emplace_back(indicesStartOffsets[1]);

    auto blockTableGMStride = GenParamIdxExprByIndex(blockTableIndex, blockTableDim, PREFIX_STR_RAW_SHAPE);

    auto blockTableStartOffsets = GenParamIdxExprByIndex(blockTableIndex, blockTableDim, PREFIX_STR_OFFSET);

    paramList.emplace_back(blockTableGMStride[1]);

    paramList.emplace_back(blockTableStartOffsets[0]);

    paramList.emplace_back(blockTableStartOffsets[1]);



    std::string tiloOpCallParam = JoinString(paramList, ", ");

    paramList.clear();

    os << tileOpName.c_str() << "<" << templateParam << ">"

       << "(" << tiloOpCallParam << ");\n";



    return os.str();

}



/**
 * Entry point for generating the gather-in-UB operator call.
 *
 * Chooses the layout printer when isSupportLayout is set, otherwise the dynamic printer when
 * isDynamicFunction is set. If neither flag is set the ASSERT fires and, should it return,
 * an empty string is produced.
 *
 * @return the generated statement text, or "" for the unsupported static-graph case.
 */
std::string CodeGenOpNPU::GenGatherInUB() const
{
    const bool isStaticGraph = !isSupportLayout && !isDynamicFunction;
    if (isStaticGraph) {
        ASSERT(GenCodeErr::PRINT_MODE_ERROR, false) << "Gather operator does not support static graph";
        return "";
    }
    // The layout path takes precedence over the dynamic path.
    return isSupportLayout ? PrintGatherInUBLayout() : PrintGatherInUBDynamicUnaligned();
}

} // namespace npu::tile_fwk