/**
Copyright (c) 2025 Huawei Technologies Co., Ltd.
This program is free software, you can redistribute it and/or modify it under the terms and conditions of
CANN Open Software License Agreement Version 2.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
*/

#include <pto/pto-inst.hpp>
#include <pto/common/constants.hpp>
#include <limits>
#include <algorithm>

using namespace std;
using namespace pto;

#define LOGSIZE 128
#define PRINTLOG 4
#define DEBUGLOG
#ifdef DEBUGLOG
#define LOG(x) (*(gLog++) = x)
#else
#define LOG(x)
#endif

// case shape is static, but testing would do dynamic or static test
template <int shape0, int shape1, int shape2, int shape3, int shape4>
AICORE __inline__ auto getOptDynShape(int gShape0, int gShape1, int gShape2, int gShape3, int gShape4)
{
    if constexpr (shape0 == 1) {
        using DynShapeDim5 = Shape<1, -1, -1, -1, -1>;
        DynShapeDim5 dynShape(gShape1, gShape2, gShape3, gShape4);
        return dynShape;
    } else if constexpr (shape0 == 1 && shape1 == 1) {
        using DynShapeDim5 = Shape<1, 1, -1, -1, -1>;
        DynShapeDim5 dynShape(gShape2, gShape3, gShape4);
        return dynShape;
    } else if constexpr (shape0 == 1 && shape1 == 1 && shape2 == 1) {
        using DynShapeDim5 = Shape<1, 1, 1, -1, -1>;
        DynShapeDim5 dynShape(gShape3, gShape4);
        return dynShape;
    } else if constexpr (shape0 == 1 && shape1 == 1 && shape2 == 1 && shape3 == 1) {
        using DynShapeDim5 = Shape<1, 1, 1, 1, -1>;
        DynShapeDim5 dynShape(gShape4);
        return dynShape;
    } else {
        using DynShapeDim5 = Shape<-1, -1, -1, -1, -1>;
        DynShapeDim5 dynShape(gShape0, gShape1, gShape2, gShape3, gShape4);
        return dynShape;
    }
}

// case shape is static, but testing would do dynamic or static test
template <typename T, int shape0, int shape1, int shape2, int shape3, int shape4, int tRows, int tCols, BLayout major,
          int dyn>
AICORE __inline__ auto getGlobalTensor(__gm__ T *addr, int gShape0, int gShape1, int gShape2, int gShape3, int gShape4)
{
    if constexpr (dyn) {
        int stride0 = gShape1 * gShape2 * shape3 * shape4;
        int stride1 = gShape2 * shape3 * shape4;
        int stride2 = shape3 * shape4;

        using DynStrideDim5 = pto::Stride<-1, -1, -1, -1, -1>;
        auto dynShape =
            getOptDynShape<shape0, shape1, shape2, shape3, shape4>(gShape0, gShape1, gShape2, gShape3, gShape4);
        using GlobalData = GlobalTensor<T, decltype(dynShape), DynStrideDim5>;

        if constexpr (major == BLayout::RowMajor) {
            GlobalData srcGlobal(addr, dynShape, DynStrideDim5(stride0, stride1, stride2, shape4, 1));
            return srcGlobal;
        } else {
            GlobalData srcGlobal(addr, dynShape, DynStrideDim5(stride0, stride1, stride2, 1, shape4));
            return srcGlobal;
        }
    } else // static
    {
        constexpr int stride0 = shape1 * shape2 * shape3 * shape4;
        constexpr int stride1 = shape2 * shape3 * shape4;
        constexpr int stride2 = shape3 * shape4;
        using StaticShapeDim5 = Shape<shape0, shape1, shape2, tRows, tCols>;

        if constexpr (major == BLayout::RowMajor) {
            using StaticStrideDim5 = pto::Stride<stride0, stride1, stride2, shape4, 1>;
            using GlobalData = GlobalTensor<T, StaticShapeDim5, StaticStrideDim5>;
            GlobalData srcGlobal(addr);
            return srcGlobal;
        } else {
            using StaticStrideDim5 = pto::Stride<stride0, stride1, stride2, 1, shape4>;
            using GlobalData = GlobalTensor<T, StaticShapeDim5, StaticStrideDim5>;
            GlobalData srcGlobal(addr);
            return srcGlobal;
        }
    }
}

template <typename T, int shape0, int shape1, int shape2, int shape3, int shape4, int tRows, int tCols, BLayout major,
          int dyn>
AICORE __inline__ auto getGlobalTensorNZ(__gm__ T *addr, int gShape0, int gShape1, int gShape2, int gShape3,
                                         int gShape4)
{
    constexpr int blockSize = 32 / sizeof(T);
    constexpr int innerCols = blockSize;
    constexpr int innerRows = 16;

    if constexpr (dyn) {
        int stride0 = gShape1 * gShape2 * innerRows * innerCols;
        int stride1 = gShape2 * innerRows * innerCols;
        int stride2 = innerRows * innerCols;

        using DynShapeDim5 = Shape<-1, -1, -1, -1, -1>;
        DynShapeDim5 dynShape(gShape0, gShape1, gShape2, innerRows, innerCols);
        using DynStrideDim5 = pto::Stride<-1, -1, -1, -1, -1>;
        using GlobalData = GlobalTensor<T, DynShapeDim5, DynStrideDim5, Layout::NZ>;

        GlobalData srcGlobal(addr, dynShape, DynStrideDim5(stride0, stride1, stride2, innerCols, 1));
        return srcGlobal;
    } else // static
    {
        constexpr int stride0 = shape1 * shape2 * innerRows * innerCols;
        constexpr int stride1 = shape2 * innerRows * innerCols;
        constexpr int stride2 = innerRows * innerCols;

        using StaticShapeDim5 = Shape<shape0, shape1, shape2, innerRows, innerCols>;
        using StaticStrideDim5 = pto::Stride<stride0, stride1, stride2, innerCols, 1>;
        using GlobalData = GlobalTensor<T, StaticShapeDim5, StaticStrideDim5, Layout::NZ>;

        GlobalData srcGlobal(addr);
        return srcGlobal;
    }
}

#define type_32_aligned(T) (32 / sizeof(T))
#define align_to_32B(x, T) ((((x) + type_32_aligned(T) - 1) / type_32_aligned(T)) * (type_32_aligned(T)))

template <typename T, int shape0, int shape1, int shape2, int shape3, int shape4, int kTRows_, int kTCols_, int dyn_,
          PadValue PadVal_ = PadValue::Null>
AICORE void runTLOADND(__gm__ T *out, __gm__ T *src, int gShape0, int gShape1, int gShape2, int gRows, int gCols,
                       __gm__ uint64_t *gLog)
{
    using TileData =
        Tile<TileType::Vec, T, kTRows_, kTCols_, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512, PadVal_>;
    TileData vecTile(kTRows_, gCols);

    constexpr int kGTRows = kTRows_ / shape0 / shape1 / shape2; // Dst Tile Rows, merged all shape0*shape1*shape2 row
    constexpr int shape4_aligned = align_to_32B(shape4, T);
    auto srcGlobal =
        getGlobalTensor<T, shape0, shape1, shape2, kGTRows, shape4, kGTRows, shape4, BLayout::RowMajor, dyn_>(
            src, gShape0, gShape1, gShape2, kGTRows, shape4);

    TASSIGN(vecTile, 0);

    TLOAD(vecTile, srcGlobal);
    for (size_t i = 0; i < TileData::Rows * TileData::Cols; i++) {
        out[i] = vecTile.data()[i];
    }
}

template <typename T, int shape0, int shape1, int shape2, int shape3, int shape4, int kTRows_, int kTCols_, int dyn_,
          PadValue PadVal_ = PadValue::Null>
AICORE void runTLOADDN(__gm__ T *out, __gm__ T *src, int gShape0, int gShape1, int gShape2, int gRows, int gCols,
                       __gm__ uint64_t *gLog)
{
    using TileData =
        Tile<TileType::Vec, T, kTRows_, kTCols_, BLayout::ColMajor, -1, -1, SLayout::NoneBox, 512, PadVal_>;
    TileData vecTile(gRows, gCols);

    constexpr int kGTCols = kTCols_ / shape0 / shape1 / shape2; // Dst Tile Rows, merged all shape0*shape1*shape2 row
    auto srcGlobal =
        getGlobalTensor<T, shape0, shape1, shape2, shape3, kGTCols, shape3, kGTCols, BLayout::RowMajor, dyn_>(
            src, gShape0, gShape1, gShape2, shape3, kGTCols);

    TASSIGN(vecTile, 0);

    TLOAD(vecTile, srcGlobal);
    for (size_t i = 0; i < TileData::Rows * TileData::Cols; i++) {
        out[i] = vecTile.data()[i];
    }
}

template <typename T, int shape0, int shape1, int shape2, int shape3, int shape4, int kTRows_, int kTCols_, int dyn_,
          PadValue PadVal_ = PadValue::Null>
AICORE void runTLOADNZ(__gm__ T *out, __gm__ T *src, int gShape0, int gShape1, int gShape2, int gRows, int gCols,
                       __gm__ uint64_t *gLog)
{
    constexpr int Rows = shape2 * shape3;
    constexpr int Cols = shape0 * shape1 * shape4;
    constexpr int blockSize = 32 / sizeof(T);
    constexpr int alignedCols = (Cols + blockSize - 1) / blockSize * blockSize;

    using TileData =
        Tile<TileType::Vec, T, Rows, alignedCols, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadVal_>;
    TileData vecTile(Rows, Cols);
    auto srcGlobal =
        getGlobalTensorNZ<T, shape0, shape1, shape2, shape3, shape4, Rows, alignedCols, BLayout::RowMajor, dyn_>(
            src, gShape0, gShape1, gShape2, shape3, shape4);

    TASSIGN(vecTile, 0);

    TLOAD(vecTile, srcGlobal);
    for (size_t i = 0; i < TileData::Rows * TileData::Cols; i++) {
        out[i] = vecTile.data()[i];
    }
}

extern "C" __global__ AICORE void launchTLOAD_1(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADND<float, 1, 1, 1, 128, 128, 128, 128, 1, PadValue::Null>((__gm__ float *)out, (__gm__ float *)src, gShape0,
                                                                      gShape1, gShape2, gRows, gCols, gLog);
}

extern "C" __global__ AICORE void launchTLOAD_2(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADND<float, 2, 2, 2, 256, 64, 256, 64, 1, PadValue::Null>((__gm__ float *)out, (__gm__ float *)src, gShape0,
                                                                    gShape1, gShape2, gRows, gCols, gLog);
}

extern "C" __global__ AICORE void launchTLOAD_3(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADND<float, 1, 1, 1, 128, 127, 128, 128, 1, PadValue::Max>((__gm__ float *)out, (__gm__ float *)src, gShape0,
                                                                     gShape1, gShape2, gRows, gCols, gLog);
}

extern "C" __global__ AICORE void launchTLOAD_4(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADND<int16_t, 1, 1, 1, 128, 127, 128, 128, 1, PadValue::Max>((__gm__ int16_t *)out, (__gm__ int16_t *)src,
                                                                       gShape0, gShape1, gShape2, gRows, gCols, gLog);
}

extern "C" __global__ AICORE void launchTLOAD_5(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADND<uint8_t, 1, 1, 1, 128, 127, 128, 128, 1, PadValue::Min>((__gm__ uint8_t *)out, (__gm__ uint8_t *)src,
                                                                       gShape0, gShape1, gShape2, gRows, gCols, gLog);
}

extern "C" __global__ AICORE void launchTLOAD_6(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADND<int16_t, 1, 1, 32, 64, 128, 64, 128, 1, PadValue::Null>((__gm__ int16_t *)out, (__gm__ int16_t *)src,
                                                                       gShape0, gShape1, gShape2, gRows, gCols, gLog);
}

extern "C" __global__ AICORE void launchTLOAD_7(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADND<int16_t, 1, 1, 32, 64, 128, 64, 128, 0, PadValue::Null>((__gm__ int16_t *)out, (__gm__ int16_t *)src,
                                                                       gShape0, gShape1, gShape2, gRows, gCols, gLog);
}

extern "C" __global__ AICORE void launchTLOAD_8(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADND<float, 2, 2, 2, 256, 60, 256, 64, 1, PadValue::Max>((__gm__ float *)out, (__gm__ float *)src, gShape0,
                                                                   gShape1, gShape2, gRows, gCols, gLog);
}

extern "C" __global__ AICORE void launchTLOAD_9(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADDN<float, 1, 1, 32, 64, 128, 64, 128, 1, PadValue::Null>((__gm__ float *)out, (__gm__ float *)src, gShape0,
                                                                     gShape1, gShape2, gRows, gCols, gLog);
}

extern "C" __global__ AICORE void launchTLOAD_10(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                 int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADDN<float, 2, 2, 2, 255, 60, 256, 64, 1, PadValue::Null>((__gm__ float *)out, (__gm__ float *)src, gShape0,
                                                                    gShape1, gShape2, gRows, gCols, gLog);
}

extern "C" __global__ AICORE void launchTLOAD_11(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                 int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADNZ<float, 1, 1, 1, 16, 8, 16, 8, 1, PadValue::Null>((__gm__ float *)out, (__gm__ float *)src, gShape0,
                                                                gShape1, gShape2, gRows, gCols, gLog);
}

extern "C" __global__ AICORE void launchTLOAD_12(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                 int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADNZ<int16_t, 2, 2, 2, 16, 16, 16, 16, 1, PadValue::Null>((__gm__ int16_t *)out, (__gm__ int16_t *)src,
                                                                    gShape0, gShape1, gShape2, gRows, gCols, gLog);
}

extern "C" __global__ AICORE void launchTLOAD_13(__gm__ uint8_t *out, __gm__ uint8_t *src, int gShape0, int gShape1,
                                                 int gShape2, int gRows, int gCols, __gm__ uint64_t *gLog)
{
    runTLOADNZ<uint8_t, 1, 2, 1, 16, 32, 16, 32, 1, PadValue::Null>((__gm__ uint8_t *)out, (__gm__ uint8_t *)src,
                                                                    gShape0, gShape1, gShape2, gRows, gCols, gLog);
}

template <int32_t testKey>
void launchTLOAD(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream)
{
    if constexpr (testKey == 1) {
        launchTLOAD_1(out, src, 1, 1, 1, 128, 128, gLog);
    } else if constexpr (testKey == 2) {
        launchTLOAD_2(out, src, 2, 2, 2, 256, 64, gLog);
    } else if constexpr (testKey == 3) {
        launchTLOAD_3(out, src, 1, 1, 1, 128, 127, gLog);
    } else if constexpr (testKey == 4) {
        launchTLOAD_4(out, src, 1, 1, 1, 128, 127, gLog);
    } else if constexpr (testKey == 5) {
        launchTLOAD_5(out, src, 1, 1, 1, 128, 127, gLog);
    } else if constexpr (testKey == 6) {
        launchTLOAD_6(out, src, 1, 1, 32, 64, 128, gLog);
    } else if constexpr (testKey == 7) {
        launchTLOAD_7(out, src, 1, 1, 32, 64, 128, gLog);
    } else if constexpr (testKey == 8) {
        launchTLOAD_8(out, src, 2, 2, 2, 256, 60, gLog);
    } else if constexpr (testKey == 9) {
        launchTLOAD_9(out, src, 1, 1, 32, 64, 128, gLog);
    } else if constexpr (testKey == 10) {
        launchTLOAD_10(out, src, 2, 2, 2, 255, 64, gLog);
    } else if constexpr (testKey == 11) {
        launchTLOAD_11(out, src, 1, 1, 1, 16, 8, gLog);
    } else if constexpr (testKey == 12) {
        launchTLOAD_12(out, src, 2, 2, 2, 16, 16, gLog);
    } else if constexpr (testKey == 13) {
        launchTLOAD_13(out, src, 1, 2, 1, 16, 32, gLog);
    }
}

template <typename T, int Shape0, int Shape1, int Shape2, int Shape3, int Shape4, int kTRows_, int kTCols_,
          PadValue PadVal_ = PadValue::Null>
int get_input_golden_case(uint8_t *input, uint8_t *golden)
{
    constexpr int shape4_aligned = align_to_32B(Shape4, T);
    static_assert((Shape3 % (Shape0 * Shape1 * Shape2)) == 0);
    constexpr int sh3 = Shape3 / (Shape0 * Shape1 * Shape2);
    int in_byteSize = Shape0 * Shape1 * Shape2 * sh3 * Shape4 * sizeof(T);
    int out_byteSize = Shape0 * Shape1 * Shape2 * sh3 * shape4_aligned * sizeof(T);

    T in_arr[Shape0][Shape1][Shape2][sh3][Shape4] = {};
    T gold_arr[Shape0][Shape1][Shape2][sh3][shape4_aligned] = {};

    for (int x0 = 0; x0 < Shape0; x0++)
        for (int x1 = 0; x1 < Shape1; x1++)
            for (int x2 = 0; x2 < Shape2; x2++)
                for (int i = 0; i < sh3; i++) {
                    for (int j = 0; j < shape4_aligned; j++) {
                        if (j < Shape4) {
                            in_arr[x0][x1][x2][i][j] = x0 * Shape1 * Shape2 * sh3 * Shape4 +
                                                       x1 * Shape2 * sh3 * Shape4 + x2 * sh3 * Shape4 + i * Shape4 + j;
                            gold_arr[x0][x1][x2][i][j] = in_arr[x0][x1][x2][i][j];
                        } else {
                            if (std::numeric_limits<T>::has_infinity) {
                                if (PadVal_ == PadValue::Max)
                                    gold_arr[x0][x1][x2][i][j] = std::numeric_limits<T>::infinity();
                                else if (PadVal_ == PadValue::Min)
                                    gold_arr[x0][x1][x2][i][j] = -std::numeric_limits<T>::infinity();
                                else
                                    gold_arr[x0][x1][x2][i][j] = 0;
                            } else {
                                if (PadVal_ == PadValue::Max)
                                    gold_arr[x0][x1][x2][i][j] = std::numeric_limits<T>::max();
                                else if (PadVal_ == PadValue::Min)
                                    gold_arr[x0][x1][x2][i][j] = std::numeric_limits<T>::min();
                                else
                                    gold_arr[x0][x1][x2][i][j] = 0;
                            }
                        }
                    } // j
                } // i

    std::copy((uint8_t *)in_arr, ((uint8_t *)(in_arr)) + in_byteSize, input);
    std::copy((uint8_t *)gold_arr, ((uint8_t *)(gold_arr)) + out_byteSize, golden);
    return sizeof(gold_arr);
}

template <typename T, int Shape0, int Shape1, int Shape2, int Shape3, int Shape4, int kTRows_, int kTCols_,
          PadValue PadVal_ = PadValue::Null>
int get_input_golden_case_DN(uint8_t *input, uint8_t *golden)
{
    constexpr int shape3_aligned = align_to_32B(Shape3, T);
    static_assert((Shape4 % (Shape0 * Shape1 * Shape2)) == 0);
    constexpr int sh4 = Shape4 / (Shape0 * Shape1 * Shape2);
    int in_byteSize = Shape0 * Shape1 * Shape2 * Shape3 * sh4 * sizeof(T);
    int out_byteSize = Shape0 * Shape1 * Shape2 * shape3_aligned * sh4 * sizeof(T);

    T in_arr[Shape0][Shape1][Shape2][Shape3][sh4] = {};
    T gold_arr[Shape0][Shape1][Shape2][sh4][shape3_aligned] = {};

    for (int x0 = 0; x0 < Shape0; x0++)
        for (int x1 = 0; x1 < Shape1; x1++)
            for (int x2 = 0; x2 < Shape2; x2++)
                for (int i = 0; i < shape3_aligned; i++) {
                    for (int j = 0; j < sh4; j++) {
                        if (i < Shape3) {
                            in_arr[x0][x1][x2][i][j] = x0 * Shape1 * Shape2 * Shape3 * sh4 +
                                                       x1 * Shape2 * Shape3 * sh4 + x2 * Shape3 * sh4 + i * sh4 + j;
                            gold_arr[x0][x1][x2][j][i] = in_arr[x0][x1][x2][i][j];
                        } else {
                            if (std::numeric_limits<T>::has_infinity) {
                                if (PadVal_ == PadValue::Max)
                                    gold_arr[x0][x1][x2][j][i] = std::numeric_limits<T>::infinity();
                                else if (PadVal_ == PadValue::Min)
                                    gold_arr[x0][x1][x2][j][i] = -std::numeric_limits<T>::infinity();
                                else
                                    gold_arr[x0][x1][x2][j][i] = 0;
                            } else {
                                if (PadVal_ == PadValue::Max)
                                    gold_arr[x0][x1][x2][j][i] = std::numeric_limits<T>::max();
                                else if (PadVal_ == PadValue::Min)
                                    gold_arr[x0][x1][x2][j][i] = std::numeric_limits<T>::min();
                                else
                                    gold_arr[x0][x1][x2][j][i] = 0;
                            }
                        }
                    } // j
                } // i

    std::copy((uint8_t *)in_arr, ((uint8_t *)(in_arr)) + in_byteSize, input);
    std::copy((uint8_t *)gold_arr, ((uint8_t *)(gold_arr)) + out_byteSize, golden);

    return sizeof(gold_arr);
}

template <int32_t testKey>
int get_input_golden(uint8_t *input, uint8_t *golden)
{
    if constexpr (testKey == 1) {
        return get_input_golden_case<float, 1, 1, 1, 128, 128, 128, 128, PadValue::Null>(input, golden);
    } else if constexpr (testKey == 2) {
        return get_input_golden_case<float, 2, 2, 2, 256, 64, 256, 64, PadValue::Null>(input, golden);
    } else if constexpr (testKey == 3) {
        return get_input_golden_case<float, 1, 1, 1, 128, 127, 128, 128, PadValue::Max>(input, golden);
    } else if constexpr (testKey == 4) {
        return get_input_golden_case<int16_t, 1, 1, 1, 128, 127, 128, 128, PadValue::Max>(input, golden);
    } else if constexpr (testKey == 5) {
        return get_input_golden_case<uint8_t, 1, 1, 1, 128, 127, 128, 128, PadValue::Min>(input, golden);
    } else if constexpr (testKey == 6 || testKey == 7) {
        return get_input_golden_case<int16_t, 1, 1, 32, 64, 128, 64, 128, PadValue::Null>(input,
                                                                                          golden); // e.g. BNSD->BSH
    } else if constexpr (testKey == 8) {
        return get_input_golden_case<float, 2, 2, 2, 256, 60, 256, 64, PadValue::Max>(input, golden);
    } else if constexpr (testKey == 9) {
        return get_input_golden_case_DN<float, 1, 1, 32, 64, 128, 64, 128, PadValue::Null>(input, golden);
    } else if constexpr (testKey == 10) {
        return get_input_golden_case_DN<float, 2, 2, 2, 255, 64, 256, 64, PadValue::Null>(input, golden);
    } else if constexpr (testKey == 11) {
        return get_input_golden_case<float, 1, 1, 1, 16, 8, 16, 8, PadValue::Null>(input, golden);
    } else if constexpr (testKey == 12) {
        return get_input_golden_case<int16_t, 2, 2, 2, 16, 16, 16, 16, PadValue::Null>(input, golden);
    } else if constexpr (testKey == 13) {
        return get_input_golden_case<uint8_t, 1, 2, 1, 16, 32, 16, 32, PadValue::Null>(input, golden);
    }
    return 0;
}

template void launchTLOAD<1>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream); // 实例化 Key=0 的版本
template void launchTLOAD<2>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream); // 实例化 Key=0 的版本
template void launchTLOAD<3>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream); // 实例化 Key=0 的版本
template void launchTLOAD<4>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream); // 实例化 Key=0 的版本
template void launchTLOAD<5>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream); // 实例化 Key=0 的版本
template void launchTLOAD<6>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream); // 实例化 Key=0 的版本
template void launchTLOAD<7>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream); // 实例化 Key=0 的版本
template void launchTLOAD<8>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream); // 实例化 Key=0 的版本
template void launchTLOAD<9>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream);
template void launchTLOAD<10>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream);
template void launchTLOAD<11>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream);
template void launchTLOAD<12>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream);
template void launchTLOAD<13>(uint8_t *out, uint8_t *src, uint64_t *gLog, void *stream);

template int get_input_golden<1>(uint8_t *input, uint8_t *golden);
template int get_input_golden<2>(uint8_t *input, uint8_t *golden);
template int get_input_golden<3>(uint8_t *input, uint8_t *golden);
template int get_input_golden<4>(uint8_t *input, uint8_t *golden);
template int get_input_golden<5>(uint8_t *input, uint8_t *golden);
template int get_input_golden<6>(uint8_t *input, uint8_t *golden);
template int get_input_golden<7>(uint8_t *input, uint8_t *golden);
template int get_input_golden<8>(uint8_t *input, uint8_t *golden);
template int get_input_golden<9>(uint8_t *input, uint8_t *golden);
template int get_input_golden<10>(uint8_t *input, uint8_t *golden);
template int get_input_golden<11>(uint8_t *input, uint8_t *golden);
template int get_input_golden<12>(uint8_t *input, uint8_t *golden);
template int get_input_golden<13>(uint8_t *input, uint8_t *golden);