/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/


/* !
 * \file tiling_strategy_custom.asc
 * \brief
 */

#include <iostream>
#include <vector>
#include <algorithm>
#include <iterator>
#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"

#ifndef ADD_CUSTOM_TILING_H
#define ADD_CUSTOM_TILING_H
#include <cstdint>

/**
 * Tiling description shared by the host-side tiling code and the kernel.
 *
 * The struct consists of 17 uint32_t fields. The host fills totalLength and dataType from a file,
 * GenerateTilingData derives the remaining fields, and the struct is then passed by value to the kernel.
 * All lengths are element counts, not byte counts.
 */
struct AddCustomTilingData {
    // Total number of elements to compute (input to GenerateTilingData).
    uint32_t totalLength;
    // Data type code (input); used as the index into DATA_TYPE_SIZE and to select the kernel instantiation.
    uint32_t dataType;

    // --- Even split: every core processes the same amount (valid when isEvenCore != 0) ---
    // Number of elements handled by each core.
    uint32_t blockLength;
    // Number of full tiles per core.
    uint32_t tileNum;
    // Length of one full tile; the kernel further divides it by BUFFER_NUM.
    uint32_t tileLength;
    // Length of the trailing partial tile, 0 if there is none.
    uint32_t lastTileLength;

    // --- Uneven split, "former" cores: block indices [0, formerNum) (valid when isEvenCore == 0) ---
    // Number of former cores.
    uint32_t formerNum;
    // Number of elements handled by each former core.
    uint32_t formerLength;
    // Number of full tiles on a former core.
    uint32_t formerTileNum;
    // Length of one full tile on a former core.
    uint32_t formerTileLength;
    // Length of the trailing partial tile on a former core, 0 if there is none.
    uint32_t formerLastTileLength;

    // --- Uneven split, "tail" cores: block indices [formerNum, formerNum + tailNum) ---
    // Number of tail cores.
    uint32_t tailNum;
    // Number of elements handled by each tail core.
    uint32_t tailLength;
    // Number of full tiles on a tail core.
    uint32_t tailTileNum;
    // Length of one full tile on a tail core.
    uint32_t tailTileLength;
    // Length of the trailing partial tile on a tail core, 0 if there is none.
    uint32_t tailLastTileLength;

    // Non-zero when the even-split fields are in use, 0 when the former/tail fields are in use.
    uint32_t isEvenCore;
};
#endif // ADD_CUSTOM_TILING_H

// Element size in bytes for each supported data type, indexed by the data type code
// (0: bfloat16, 1: float16, 2: float32, 3: int8, 4: int16, 5: int32).
constexpr uint32_t DATA_TYPE_SIZE[] = {2, 2, 4, 1, 2, 4};
// Size in bytes of one block; all tile lengths are kept a multiple of it.
constexpr uint32_t BLOCK_SIZE = 32;
// Number of buffers per queue.
constexpr uint32_t BUFFER_NUM = 2;
constexpr uint32_t UB_BLOCK_NUM = 100; // maximum number of blocks that may be used in UB
// UB_BLOCK_NUM rounded down to a multiple of BUFFER_NUM so that a full tile splits evenly across the buffers.
constexpr uint32_t MAX_AVAILABLE_UB_BLOCK_NUM = UB_BLOCK_NUM / BUFFER_NUM * BUFFER_NUM;

namespace {
/**
 * Splits the work of one core into full tiles plus an optional trailing tile.
 *
 * A full tile holds MAX_AVAILABLE_UB_BLOCK_NUM blocks of alignNum elements each. The length is first rounded
 * up to a multiple of alignNum, so the trailing tile is always block-aligned and no element is dropped even
 * when an unaligned length is passed in.
 *
 * @param length          number of elements the core has to process
 * @param alignNum        number of elements in one block
 * @param tileNum         [out] number of full tiles
 * @param tileLength      [out] length of one full tile, 0 when there is no full tile
 * @param lastTileLength  [out] length of the trailing tile, 0 when the full tiles cover everything
 */
void TilingParamsCalc(uint32_t length, uint32_t alignNum, uint32_t& tileNum, uint32_t& tileLength,
                      uint32_t& lastTileLength)
{
    // A zero block size cannot be tiled (and would divide by zero below): report "no work".
    if (alignNum == 0U) {
        tileNum = 0U;
        tileLength = 0U;
        lastTileLength = 0U;
        return;
    }

    const uint32_t fullTileLength = MAX_AVAILABLE_UB_BLOCK_NUM * alignNum;
    // Round up without the "+ alignNum - 1" idiom so that lengths close to UINT32_MAX do not wrap.
    const uint32_t alignedLength = (length / alignNum + ((length % alignNum != 0U) ? 1U : 0U)) * alignNum;

    tileNum = length / fullTileLength;
    // With no full tile the whole (aligned) length is handled as the trailing tile.
    tileLength = (tileNum == 0U) ? 0U : fullTileLength;
    // Whatever the full tiles do not cover; 0 when the length is an exact multiple of a full tile.
    lastTileLength = alignedLength - tileNum * fullTileLength;
}
} // namespace

/**
 * Derives the per-core split from the totalLength and dataType already stored in the tiling buffer.
 *
 * The total length is rounded up to whole blocks (BLOCK_SIZE bytes each). If the block count divides evenly
 * by numBlocks, the even-split fields (blockLength, tileNum, tileLength, lastTileLength) are written and
 * isEvenCore is set to 1. Otherwise the first `remainder` cores ("former") each take one block more than the
 * remaining cores ("tail"), the former and tail fields are written, and isEvenCore is set to 0.
 *
 * If tilingBuf is null nothing is done. If numBlocks is 0 or dataType is not a valid index into
 * DATA_TYPE_SIZE, an empty even split (all lengths 0) is written instead of reading out of bounds or
 * dividing by zero.
 *
 * @param tilingBuf  buffer holding an AddCustomTilingData; totalLength and dataType are read, the rest written
 * @param numBlocks  number of cores the work is distributed over
 */
void GenerateTilingData(uint8_t* tilingBuf, uint32_t numBlocks)
{
    if (tilingBuf == nullptr) {
        return;
    }
    AddCustomTilingData* tiling = reinterpret_cast<AddCustomTilingData*>(tilingBuf);

    constexpr uint32_t dataTypeCount = sizeof(DATA_TYPE_SIZE) / sizeof(DATA_TYPE_SIZE[0]);
    if (numBlocks == 0U || tiling->dataType >= dataTypeCount) {
        // Invalid request: describe "no work" so that a consumer of the tiling touches no memory.
        tiling->blockLength = 0U;
        tiling->tileNum = 0U;
        tiling->tileLength = 0U;
        tiling->lastTileLength = 0U;
        tiling->isEvenCore = 1U;
        return;
    }

    const uint32_t totalLength = tiling->totalLength; // total number of elements to compute
    const uint32_t dataTypeSize = DATA_TYPE_SIZE[tiling->dataType];
    const uint32_t alignNum = BLOCK_SIZE / dataTypeSize; // number of elements in one block

    // Work in whole blocks: a length that is not 32-byte aligned is rounded up to the next block.
    const uint32_t totalBlocks = totalLength / alignNum + ((totalLength % alignNum != 0U) ? 1U : 0U);
    const uint32_t blocksPerCore = totalBlocks / numBlocks;
    const uint32_t extraBlocks = totalBlocks % numBlocks;

    if (extraBlocks == 0U) {
        // The blocks divide evenly between the cores.
        const uint32_t blockLength = blocksPerCore * alignNum;
        uint32_t tileNum = 0U;
        uint32_t tileLength = 0U;
        uint32_t lastTileLength = 0U;
        TilingParamsCalc(blockLength, alignNum, tileNum, tileLength, lastTileLength);

        tiling->blockLength = blockLength;
        tiling->tileNum = tileNum;
        tiling->tileLength = tileLength;
        tiling->lastTileLength = lastTileLength;
        tiling->isEvenCore = 1U;
        return;
    }

    // Uneven split: the first extraBlocks cores take one block more than the others. This equals the
    // "ceil(total / numBlocks) rounded up to a block" and "floor(total / numBlocks) rounded down to a block"
    // lengths, because the remainder is strictly between 0 and numBlocks.
    const uint32_t formerNum = extraBlocks;
    const uint32_t tailNum = numBlocks - formerNum;
    const uint32_t formerLength = (blocksPerCore + 1U) * alignNum;
    const uint32_t tailLength = blocksPerCore * alignNum;

    uint32_t formerTileNum = 0U;
    uint32_t formerTileLength = 0U;
    uint32_t formerLastTileLength = 0U;
    TilingParamsCalc(formerLength, alignNum, formerTileNum, formerTileLength, formerLastTileLength);

    uint32_t tailTileNum = 0U;
    uint32_t tailTileLength = 0U;
    uint32_t tailLastTileLength = 0U;
    TilingParamsCalc(tailLength, alignNum, tailTileNum, tailTileLength, tailLastTileLength);

    tiling->formerNum = formerNum;
    tiling->formerLength = formerLength;
    tiling->formerTileNum = formerTileNum;
    tiling->formerTileLength = formerTileLength;
    tiling->formerLastTileLength = formerLastTileLength;

    tiling->tailNum = tailNum;
    tiling->tailLength = tailLength;
    tiling->tailTileNum = tailTileNum;
    tiling->tailTileLength = tailTileLength;
    tiling->tailLastTileLength = tailLastTileLength;
    tiling->isEvenCore = 0U;
}

// Data type codes carried in AddCustomTilingData::dataType. The kernel entry point dispatches on these
// values, and the same value is used as the index into DATA_TYPE_SIZE, so the two must stay in sync.
constexpr uint32_t ADD_BFLOAT16 = 0;
constexpr uint32_t ADD_FLOAT16 = 1;
constexpr uint32_t ADD_FLOAT32 = 2;
constexpr uint32_t ADD_INT8 = 3;
constexpr uint32_t ADD_INT16 = 4;
constexpr uint32_t ADD_INT32 = 5;
// NOTE(review): LAST_TWO_TILE is not referenced in the visible part of this file — confirm it is still needed.
constexpr uint32_t LAST_TWO_TILE = 2;

// Declaration of the primary template; the bfloat16_t and int8_t specializations and the generic
// definition follow below.
template <typename T>
class KernelAdd;
/**
 * bfloat16_t specialization of the element-wise add kernel.
 *
 * Both inputs are cast to float, added in float, and the sum is cast back to bfloat16_t with CAST_RINT.
 * Two float scratch buffers hold the intermediate values.
 */
template <>
class KernelAdd<bfloat16_t> {
public:
    __aicore__ inline KernelAdd() {}

    /**
     * Binds this core's slice of x, y and z and sizes the local buffers.
     *
     * @param x       global-memory address of the first input
     * @param y       global-memory address of the second input
     * @param z       global-memory address of the output
     * @param tiling  split parameters; the even, former or tail fields are used depending on
     *                isEvenCore and the block index of this core
     * @param pipeIn  pipe used to initialise the queues and scratch buffers; stored, not owned
     */
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, AddCustomTilingData tiling, AscendC::TPipe* pipeIn)
    {
        pipe = pipeIn;

        const uint32_t blockIdx = static_cast<uint32_t>(AscendC::GetBlockIdx());
        uint32_t fullTileLength = 0;
        uint64_t offset = 0;
        // All offsets are computed in 64 bits so that the products cannot wrap in 32-bit arithmetic.
        if (tiling.isEvenCore) {
            // Every core handles the same number of elements.
            this->blockLength = tiling.blockLength;
            this->tileNum = tiling.tileNum;
            fullTileLength = tiling.tileLength;
            this->lastTileLength = tiling.lastTileLength;
            offset = static_cast<uint64_t>(tiling.blockLength) * blockIdx;
        } else if (blockIdx < tiling.formerNum) {
            // Former cores come first; each one handles formerLength elements.
            this->blockLength = tiling.formerLength;
            this->tileNum = tiling.formerTileNum;
            fullTileLength = tiling.formerTileLength;
            this->lastTileLength = tiling.formerLastTileLength;
            offset = static_cast<uint64_t>(tiling.formerLength) * blockIdx;
        } else {
            // Tail cores start after the data of all former cores.
            this->blockLength = tiling.tailLength;
            this->tileNum = tiling.tailTileNum;
            fullTileLength = tiling.tailTileLength;
            this->lastTileLength = tiling.tailLastTileLength;
            offset = static_cast<uint64_t>(tiling.formerLength) * tiling.formerNum
                     + static_cast<uint64_t>(tiling.tailLength) * (blockIdx - tiling.formerNum);
        }
        // A full tile is processed as BUFFER_NUM sub-tiles (see Process).
        this->tileLength = fullTileLength / BUFFER_NUM;

        xGm.SetGlobalBuffer((__gm__ bfloat16_t*)x + offset, this->blockLength);
        yGm.SetGlobalBuffer((__gm__ bfloat16_t*)y + offset, this->blockLength);
        zGm.SetGlobalBuffer((__gm__ bfloat16_t*)z + offset, this->blockLength);

        // One buffer must hold either a sub-tile or the trailing tile, whichever is larger.
        this->initBufferLength = AscendC::Std::max(this->tileLength, this->lastTileLength);
        pipe->InitBuffer(inQueueX, BUFFER_NUM, this->initBufferLength * sizeof(bfloat16_t));
        pipe->InitBuffer(inQueueY, BUFFER_NUM, this->initBufferLength * sizeof(bfloat16_t));
        pipe->InitBuffer(outQueueZ, BUFFER_NUM, this->initBufferLength * sizeof(bfloat16_t));

        // Scratch space for the float intermediates.
        pipe->InitBuffer(tmpBuf0, this->initBufferLength * sizeof(float));
        pipe->InitBuffer(tmpBuf1, this->initBufferLength * sizeof(float));
    }

    /**
     * Processes this core's slice: tileNum * BUFFER_NUM sub-tiles of tileLength elements, followed by the
     * trailing tile of lastTileLength elements if there is one.
     */
    __aicore__ inline void Process()
    {
        const uint32_t loopCount = this->tileNum * BUFFER_NUM;
        for (uint32_t i = 0; i < loopCount; ++i) {
            CopyIn(i, this->tileLength);
            Compute(i, this->tileLength);
            CopyOut(i, this->tileLength);
        }

        // Trailing tile: handled in one piece, not split across the buffers.
        if (this->lastTileLength > 0) {
            CopyIn(loopCount, this->lastTileLength);
            Compute(loopCount, this->lastTileLength);
            CopyOut(loopCount, this->lastTileLength);
        }
    }

private:
    // Copies `count` elements of x and y, starting at sub-tile index `progress`, into the input queues.
    __aicore__ inline void CopyIn(uint32_t progress, uint32_t count)
    {
        AscendC::LocalTensor<bfloat16_t> xLocal = inQueueX.AllocTensor<bfloat16_t>();
        AscendC::LocalTensor<bfloat16_t> yLocal = inQueueY.AllocTensor<bfloat16_t>();
        // The offset always uses the sub-tile length, so the trailing tile starts right after the last sub-tile.
        AscendC::DataCopy(xLocal, xGm[progress * this->tileLength], count);
        AscendC::DataCopy(yLocal, yGm[progress * this->tileLength], count);
        inQueueX.EnQue(xLocal);
        inQueueY.EnQue(yLocal);
    }

    // Adds `count` queued elements via float intermediates; `progress` is unused and kept for symmetry.
    __aicore__ inline void Compute(uint32_t progress, uint32_t count)
    {
        AscendC::LocalTensor<bfloat16_t> xLocal = inQueueX.DeQue<bfloat16_t>();
        AscendC::LocalTensor<bfloat16_t> yLocal = inQueueY.DeQue<bfloat16_t>();
        AscendC::LocalTensor<bfloat16_t> zLocal = outQueueZ.AllocTensor<bfloat16_t>();

        AscendC::LocalTensor<float> xFloat = tmpBuf0.Get<float>();
        AscendC::LocalTensor<float> yFloat = tmpBuf1.Get<float>();

        AscendC::Cast(xFloat, xLocal, AscendC::RoundMode::CAST_NONE, count);
        AscendC::Cast(yFloat, yLocal, AscendC::RoundMode::CAST_NONE, count);

        // The sum is accumulated in place in the first scratch buffer.
        AscendC::Add(xFloat, xFloat, yFloat, count);
        AscendC::Cast(zLocal, xFloat, AscendC::RoundMode::CAST_RINT, count);

        outQueueZ.EnQue<bfloat16_t>(zLocal);
        inQueueX.FreeTensor(xLocal);
        inQueueY.FreeTensor(yLocal);
    }

    // Copies `count` result elements from the output queue to z at sub-tile index `progress`.
    __aicore__ inline void CopyOut(uint32_t progress, uint32_t count)
    {
        AscendC::LocalTensor<bfloat16_t> zLocal = outQueueZ.DeQue<bfloat16_t>();
        AscendC::DataCopy(zGm[progress * this->tileLength], zLocal, count);
        outQueueZ.FreeTensor(zLocal);
    }

private:
    AscendC::TPipe* pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueX;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueY;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> outQueueZ;
    AscendC::TBuf<AscendC::TPosition::VECCALC> tmpBuf0;
    AscendC::TBuf<AscendC::TPosition::VECCALC> tmpBuf1;

    uint32_t initBufferLength; // element capacity of each local buffer
    AscendC::GlobalTensor<bfloat16_t> xGm;
    AscendC::GlobalTensor<bfloat16_t> yGm;
    AscendC::GlobalTensor<bfloat16_t> zGm;

    uint32_t blockLength;    // number of elements handled by this core
    uint32_t tileNum;        // number of full tiles
    uint32_t tileLength;     // length of one sub-tile (full tile / BUFFER_NUM)
    uint32_t lastTileLength; // length of the trailing tile, 0 if none
};

/**
 * int8_t specialization of the element-wise add kernel.
 *
 * Both inputs are cast to half, added in half, and the sum is cast back to int8_t.
 * Two half scratch buffers hold the intermediate values.
 */
template <>
class KernelAdd<int8_t> {
public:
    __aicore__ inline KernelAdd() {}

    /**
     * Binds this core's slice of x, y and z and sizes the local buffers.
     *
     * @param x       global-memory address of the first input
     * @param y       global-memory address of the second input
     * @param z       global-memory address of the output
     * @param tiling  split parameters; the even, former or tail fields are used depending on
     *                isEvenCore and the block index of this core
     * @param pipeIn  pipe used to initialise the queues and scratch buffers; stored, not owned
     */
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, AddCustomTilingData tiling, AscendC::TPipe* pipeIn)
    {
        pipe = pipeIn;

        const uint32_t blockIdx = static_cast<uint32_t>(AscendC::GetBlockIdx());
        uint32_t fullTileLength = 0;
        uint64_t offset = 0;
        // All offsets are computed in 64 bits so that the products cannot wrap in 32-bit arithmetic.
        if (tiling.isEvenCore) {
            // Every core handles the same number of elements.
            this->blockLength = tiling.blockLength;
            this->tileNum = tiling.tileNum;
            fullTileLength = tiling.tileLength;
            this->lastTileLength = tiling.lastTileLength;
            offset = static_cast<uint64_t>(tiling.blockLength) * blockIdx;
        } else if (blockIdx < tiling.formerNum) {
            // Former cores come first; each one handles formerLength elements.
            this->blockLength = tiling.formerLength;
            this->tileNum = tiling.formerTileNum;
            fullTileLength = tiling.formerTileLength;
            this->lastTileLength = tiling.formerLastTileLength;
            offset = static_cast<uint64_t>(tiling.formerLength) * blockIdx;
        } else {
            // Tail cores start after the data of all former cores.
            this->blockLength = tiling.tailLength;
            this->tileNum = tiling.tailTileNum;
            fullTileLength = tiling.tailTileLength;
            this->lastTileLength = tiling.tailLastTileLength;
            offset = static_cast<uint64_t>(tiling.formerLength) * tiling.formerNum
                     + static_cast<uint64_t>(tiling.tailLength) * (blockIdx - tiling.formerNum);
        }
        // A full tile is processed as BUFFER_NUM sub-tiles (see Process).
        this->tileLength = fullTileLength / BUFFER_NUM;

        xGm.SetGlobalBuffer((__gm__ int8_t*)x + offset, this->blockLength);
        yGm.SetGlobalBuffer((__gm__ int8_t*)y + offset, this->blockLength);
        zGm.SetGlobalBuffer((__gm__ int8_t*)z + offset, this->blockLength);

        // One buffer must hold either a sub-tile or the trailing tile, whichever is larger.
        this->initBufferLength = AscendC::Std::max(this->tileLength, this->lastTileLength);
        pipe->InitBuffer(inQueueX, BUFFER_NUM, this->initBufferLength * sizeof(int8_t));
        pipe->InitBuffer(inQueueY, BUFFER_NUM, this->initBufferLength * sizeof(int8_t));
        pipe->InitBuffer(outQueueZ, BUFFER_NUM, this->initBufferLength * sizeof(int8_t));

        // Scratch space for the half intermediates.
        pipe->InitBuffer(tmpBuf0, this->initBufferLength * sizeof(half));
        pipe->InitBuffer(tmpBuf1, this->initBufferLength * sizeof(half));
    }

    /**
     * Processes this core's slice: tileNum * BUFFER_NUM sub-tiles of tileLength elements, followed by the
     * trailing tile of lastTileLength elements if there is one.
     */
    __aicore__ inline void Process()
    {
        const uint32_t loopCount = this->tileNum * BUFFER_NUM;
        for (uint32_t i = 0; i < loopCount; ++i) {
            CopyIn(i, this->tileLength);
            Compute(i, this->tileLength);
            CopyOut(i, this->tileLength);
        }

        // Trailing tile: handled in one piece, not split across the buffers.
        if (this->lastTileLength > 0U) {
            CopyIn(loopCount, this->lastTileLength);
            Compute(loopCount, this->lastTileLength);
            CopyOut(loopCount, this->lastTileLength);
        }
    }

private:
    // Copies `count` elements of x and y, starting at sub-tile index `progress`, into the input queues.
    __aicore__ inline void CopyIn(uint32_t progress, uint32_t count)
    {
        AscendC::LocalTensor<int8_t> xLocal = inQueueX.AllocTensor<int8_t>();
        AscendC::LocalTensor<int8_t> yLocal = inQueueY.AllocTensor<int8_t>();
        // The offset always uses the sub-tile length, so the trailing tile starts right after the last sub-tile.
        AscendC::DataCopy(xLocal, xGm[progress * this->tileLength], count);
        AscendC::DataCopy(yLocal, yGm[progress * this->tileLength], count);
        inQueueX.EnQue(xLocal);
        inQueueY.EnQue(yLocal);
    }

    // Adds `count` queued elements via half intermediates; `progress` is unused and kept for symmetry.
    __aicore__ inline void Compute(uint32_t progress, uint32_t count)
    {
        AscendC::LocalTensor<int8_t> xLocal = inQueueX.DeQue<int8_t>();
        AscendC::LocalTensor<int8_t> yLocal = inQueueY.DeQue<int8_t>();
        AscendC::LocalTensor<int8_t> zLocal = outQueueZ.AllocTensor<int8_t>();

        AscendC::LocalTensor<half> xHalf = tmpBuf0.Get<half>();
        AscendC::LocalTensor<half> yHalf = tmpBuf1.Get<half>();

        AscendC::Cast(xHalf, xLocal, AscendC::RoundMode::CAST_NONE, count);
        AscendC::Cast(yHalf, yLocal, AscendC::RoundMode::CAST_NONE, count);

        // The sum is accumulated in place in the first scratch buffer.
        AscendC::Add(xHalf, xHalf, yHalf, count);
        AscendC::Cast(zLocal, xHalf, AscendC::RoundMode::CAST_NONE, count);

        outQueueZ.EnQue<int8_t>(zLocal);
        inQueueX.FreeTensor(xLocal);
        inQueueY.FreeTensor(yLocal);
    }

    // Copies `count` result elements from the output queue to z at sub-tile index `progress`.
    __aicore__ inline void CopyOut(uint32_t progress, uint32_t count)
    {
        AscendC::LocalTensor<int8_t> zLocal = outQueueZ.DeQue<int8_t>();
        AscendC::DataCopy(zGm[progress * this->tileLength], zLocal, count);
        outQueueZ.FreeTensor(zLocal);
    }

private:
    AscendC::TPipe* pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueX;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueY;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> outQueueZ;
    AscendC::TBuf<AscendC::TPosition::VECCALC> tmpBuf0;
    AscendC::TBuf<AscendC::TPosition::VECCALC> tmpBuf1;

    AscendC::GlobalTensor<int8_t> xGm;
    AscendC::GlobalTensor<int8_t> yGm;
    AscendC::GlobalTensor<int8_t> zGm;

    uint32_t initBufferLength; // element capacity of each local buffer
    uint32_t blockLength;      // number of elements handled by this core
    uint32_t tileNum;          // number of full tiles
    uint32_t tileLength;       // length of one sub-tile (full tile / BUFFER_NUM)
    uint32_t lastTileLength;   // length of the trailing tile, 0 if none
};

/**
 * Generic element-wise add kernel for element types that are added directly in type T
 * (the entry point instantiates it for half, float, int16_t and int32_t).
 */
template <typename T>
class KernelAdd {
public:
    __aicore__ inline KernelAdd() {}

    /**
     * Binds this core's slice of x, y and z and sizes the local buffers.
     *
     * @param x       global-memory address of the first input
     * @param y       global-memory address of the second input
     * @param z       global-memory address of the output
     * @param tiling  split parameters; the even, former or tail fields are used depending on
     *                isEvenCore and the block index of this core
     * @param pipeIn  pipe used to initialise the queues; stored, not owned
     */
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, AddCustomTilingData tiling, AscendC::TPipe* pipeIn)
    {
        pipe = pipeIn;

        const uint32_t blockIdx = static_cast<uint32_t>(AscendC::GetBlockIdx());
        uint32_t fullTileLength = 0;
        uint64_t offset = 0;
        // All offsets are computed in 64 bits so that the products cannot wrap in 32-bit arithmetic.
        if (tiling.isEvenCore) {
            // Every core handles the same number of elements.
            this->blockLength = tiling.blockLength;
            this->tileNum = tiling.tileNum;
            fullTileLength = tiling.tileLength;
            this->lastTileLength = tiling.lastTileLength;
            offset = static_cast<uint64_t>(tiling.blockLength) * blockIdx;
        } else if (blockIdx < tiling.formerNum) {
            // Former cores come first; each one handles formerLength elements.
            this->blockLength = tiling.formerLength;
            this->tileNum = tiling.formerTileNum;
            fullTileLength = tiling.formerTileLength;
            this->lastTileLength = tiling.formerLastTileLength;
            offset = static_cast<uint64_t>(tiling.formerLength) * blockIdx;
        } else {
            // Tail cores start after the data of all former cores.
            this->blockLength = tiling.tailLength;
            this->tileNum = tiling.tailTileNum;
            fullTileLength = tiling.tailTileLength;
            this->lastTileLength = tiling.tailLastTileLength;
            offset = static_cast<uint64_t>(tiling.formerLength) * tiling.formerNum
                     + static_cast<uint64_t>(tiling.tailLength) * (blockIdx - tiling.formerNum);
        }
        // A full tile is processed as BUFFER_NUM sub-tiles (see Process).
        this->tileLength = fullTileLength / BUFFER_NUM;

        xGm.SetGlobalBuffer((__gm__ T*)x + offset, this->blockLength);
        yGm.SetGlobalBuffer((__gm__ T*)y + offset, this->blockLength);
        zGm.SetGlobalBuffer((__gm__ T*)z + offset, this->blockLength);

        // One buffer must hold either a sub-tile or the trailing tile, whichever is larger.
        this->initBufferLength = AscendC::Std::max(this->tileLength, this->lastTileLength);
        pipe->InitBuffer(inQueueX, BUFFER_NUM, this->initBufferLength * sizeof(T));
        pipe->InitBuffer(inQueueY, BUFFER_NUM, this->initBufferLength * sizeof(T));
        pipe->InitBuffer(outQueueZ, BUFFER_NUM, this->initBufferLength * sizeof(T));
    }

    /**
     * Processes this core's slice: tileNum * BUFFER_NUM sub-tiles of tileLength elements, followed by the
     * trailing tile of lastTileLength elements if there is one.
     */
    __aicore__ inline void Process()
    {
        const uint32_t loopCount = this->tileNum * BUFFER_NUM;
        for (uint32_t i = 0; i < loopCount; ++i) {
            CopyIn(i, this->tileLength);
            Compute(i, this->tileLength);
            CopyOut(i, this->tileLength);
        }

        // Trailing tile: handled in one piece, not split across the buffers.
        if (this->lastTileLength > 0) {
            CopyIn(loopCount, this->lastTileLength);
            Compute(loopCount, this->lastTileLength);
            CopyOut(loopCount, this->lastTileLength);
        }
    }

private:
    // Copies `count` elements of x and y, starting at sub-tile index `progress`, into the input queues.
    __aicore__ inline void CopyIn(uint32_t progress, uint32_t count)
    {
        AscendC::LocalTensor<T> xLocal = inQueueX.AllocTensor<T>();
        AscendC::LocalTensor<T> yLocal = inQueueY.AllocTensor<T>();
        // The offset always uses the sub-tile length, so the trailing tile starts right after the last sub-tile.
        AscendC::DataCopy(xLocal, xGm[progress * this->tileLength], count);
        AscendC::DataCopy(yLocal, yGm[progress * this->tileLength], count);
        inQueueX.EnQue(xLocal);
        inQueueY.EnQue(yLocal);
    }

    // Adds `count` queued elements directly in type T; `progress` is unused and kept for symmetry.
    __aicore__ inline void Compute(uint32_t progress, uint32_t count)
    {
        AscendC::LocalTensor<T> xLocal = inQueueX.DeQue<T>();
        AscendC::LocalTensor<T> yLocal = inQueueY.DeQue<T>();
        AscendC::LocalTensor<T> zLocal = outQueueZ.AllocTensor<T>();

        AscendC::Add(zLocal, xLocal, yLocal, count);

        outQueueZ.EnQue<T>(zLocal);
        inQueueX.FreeTensor(xLocal);
        inQueueY.FreeTensor(yLocal);
    }

    // Copies `count` result elements from the output queue to z at sub-tile index `progress`.
    __aicore__ inline void CopyOut(uint32_t progress, uint32_t count)
    {
        AscendC::LocalTensor<T> zLocal = outQueueZ.DeQue<T>();
        AscendC::DataCopy(zGm[progress * this->tileLength], zLocal, count);
        outQueueZ.FreeTensor(zLocal);
    }

private:
    AscendC::TPipe* pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueX;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueY;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> outQueueZ;

    AscendC::GlobalTensor<T> xGm;
    AscendC::GlobalTensor<T> yGm;
    AscendC::GlobalTensor<T> zGm;

    uint32_t initBufferLength; // element capacity of each local buffer
    uint32_t blockLength;      // number of elements handled by this core
    uint32_t tileNum;          // number of full tiles
    uint32_t tileLength;       // length of one sub-tile (full tile / BUFFER_NUM)
    uint32_t lastTileLength;   // length of the trailing tile, 0 if none
};

__global__ __vector__ void tiling_strategy_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, AddCustomTilingData tiling)
{
    AscendC::TPipe pipe;
    if (tiling.dataType == ADD_BFLOAT16) {
        KernelAdd<bfloat16_t> op;
        op.Init(x, y, z, tiling, &pipe);
        op.Process();
    } else if (tiling.dataType == ADD_FLOAT16) {
        KernelAdd<half> op;
        op.Init(x, y, z, tiling, &pipe);
        op.Process();
    } else if (tiling.dataType == ADD_FLOAT32) {
        KernelAdd<float> op;
        op.Init(x, y, z, tiling, &pipe);
        op.Process();
    } else if (tiling.dataType == ADD_INT8) {
        KernelAdd<int8_t> op;
        op.Init(x, y, z, tiling, &pipe);
        op.Process();
    } else if (tiling.dataType == ADD_INT16) {
        KernelAdd<int16_t> op;
        op.Init(x, y, z, tiling, &pipe);
        op.Process();
    } else if (tiling.dataType == ADD_INT32) {
        KernelAdd<int32_t> op;
        op.Init(x, y, z, tiling, &pipe);
        op.Process();
    } else {
        return;
    }
}

int32_t main(int32_t argc, char* argv[])
{
    constexpr uint32_t NUM_BLOCKS = 8;
    constexpr uint32_t DATA_TYPE_SIZE[] = {2, 2, 4, 1, 2, 4};
    uint8_t* tiling = nullptr;
    size_t tilingSize = 17 * sizeof(uint32_t);

    aclInit(nullptr);
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    uint8_t *xHost, *yHost, *zHost;
    uint8_t *xDevice, *yDevice, *zDevice;

    aclrtMallocHost((void**)(&tiling), tilingSize);
    ReadFile("./input/input_tiling.bin", tilingSize, tiling, tilingSize);

    GenerateTilingData(tiling, NUM_BLOCKS);
    uint32_t dataTypeSize = DATA_TYPE_SIZE[reinterpret_cast<AddCustomTilingData*>(tiling)->dataType];
    size_t inputByteSize = reinterpret_cast<AddCustomTilingData*>(tiling)->totalLength * dataTypeSize;
    size_t outputByteSize = reinterpret_cast<AddCustomTilingData*>(tiling)->totalLength * dataTypeSize;

    aclrtMallocHost((void**)(&xHost), inputByteSize);
    aclrtMallocHost((void**)(&yHost), inputByteSize);
    aclrtMallocHost((void**)(&zHost), outputByteSize);
    aclrtMalloc((void**)&xDevice, inputByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&yDevice, inputByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&zDevice, outputByteSize, ACL_MEM_MALLOC_HUGE_FIRST);

    ReadFile("./input/input_x.bin", inputByteSize, xHost, inputByteSize);
    ReadFile("./input/input_y.bin", inputByteSize, yHost, inputByteSize);

    aclrtMemcpy(xDevice, inputByteSize, xHost, inputByteSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(yDevice, inputByteSize, yHost, inputByteSize, ACL_MEMCPY_HOST_TO_DEVICE);

    tiling_strategy_custom<<<NUM_BLOCKS, nullptr, stream>>>(xDevice, yDevice, zDevice,
                                               *reinterpret_cast<AddCustomTilingData*>(tiling));
    aclrtSynchronizeStream(stream);

    aclrtMemcpy(zHost, outputByteSize, zDevice, outputByteSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output.bin", zHost, outputByteSize);

    aclrtFree(xDevice);
    aclrtFree(yDevice);
    aclrtFree(zDevice);
    aclrtFreeHost(xHost);
    aclrtFreeHost(yHost);
    aclrtFreeHost(zHost);
    aclrtFreeHost(tiling);

    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();

    return 0;
}