/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/


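/*!
 * \file softmaxgradfront.asc
 * \brief SoftmaxGradFront custom operator sample: host-side tiling computation, the Ascend C kernel, and a host
 *        test harness that launches the kernel and compares its output against a golden file.
 */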

#include "acl/acl.h"
#include "data_utils.h"
#include "kernel_operator.h"
#include "tiling/tiling_api.h"

namespace optiling {
// Host-side tiling data for SoftmaxgradfrontCustom. It is filled by SoftmaxgradfrontCustomTiling::ComputeTiling and
// serialized with SaveToBuffer. The kernel copies the serialized buffer word-by-word into MyCustomKernel::VecTiling,
// so the field order and types below must stay in sync with that struct.
BEGIN_TILING_DATA_DEF(SoftmaxgradfrontCustomTilingData)
TILING_DATA_FIELD_DEF(uint32_t, columnLength);                 // number of columns (reduce-axis length)
TILING_DATA_FIELD_DEF(uint32_t, rowLength);                    // total number of rows
TILING_DATA_FIELD_DEF(uint32_t, sharedTmpBufferSize);          // bytes the kernel allocates as the softmax tmp buffer
TILING_DATA_FIELD_DEF(uint32_t, usedNumBlocks);                // rowLength / coreRowNum; also the index of the tail core
TILING_DATA_FIELD_DEF(uint32_t, coreRowNum);                   // rows assigned per core: ceil(rowLength / coreNum)
TILING_DATA_FIELD_DEF(uint32_t, tailCoreRowNum);               // rows left for the tail core: rowLength % coreRowNum
TILING_DATA_FIELD_DEF(uint32_t, singleLoopCoreRowNum);         // rows per loop iteration on a regular core
TILING_DATA_FIELD_DEF(uint32_t, singleCoreLoopCount);          // full loop iterations on a regular core
TILING_DATA_FIELD_DEF(uint32_t, singleCoreLoopTail);           // rows in the final partial iteration on a regular core
TILING_DATA_FIELD_DEF(uint32_t, tailCoreSingleLoopCoreRowNum); // the same three loop values, for the tail core
TILING_DATA_FIELD_DEF(uint32_t, tailCoreSingleCoreLoopCount);
TILING_DATA_FIELD_DEF(uint32_t, tailCoreSingleCoreLoopTail);
TILING_DATA_FIELD_DEF_STRUCT(SoftMaxTiling, softmaxTilingData); // filled by AscendC::SoftMaxGradTilingFunc
END_TILING_DATA_DEF;

// Registers the tiling data class under the SoftmaxgradfrontCustom operator name.
REGISTER_TILING_DATA_CLASS(SoftmaxgradfrontCustom, SoftmaxgradfrontCustomTilingData)
} // namespace optiling
namespace SoftmaxgradfrontCustomTiling {
// Default size in bytes (60 KiB) of the temporary buffer reserved for the softmax-grad computation.
// ComputeTiling raises it when the API reports a larger minimum requirement.
constexpr uint32_t SHARED_TMP_BUFFER_SIZE = 61440;

// Per-core loop decomposition: a core's rows are processed in chunks of singleLoopCoreRowNum rows.
struct SingleCoreLoopParam {
    uint32_t singleLoopCoreRowNum{0}; // row num processed in single loop
    uint32_t singleCoreLoopCount{0};  // number of full loops in single core
    uint32_t singleCoreLoopTail{0};   // row num of last (partial) loop in single core
};

// {minimum reduce-axis length, rows per loop iteration}, ordered by descending threshold.
// The first entry whose threshold is <= the column count wins; the {0, 64} entry matches everything.
const std::vector<std::pair<uint32_t, uint32_t>> SLICE_TABLE = {
    {8192, 1}, {4096, 2}, {2048, 4}, {1024, 8}, {512, 16}, {256, 32}, {0, 64}};

/**
 * Splits coreRowNum rows into loop iterations sized by the reduce-axis length.
 *
 * @param colNum     reduce-axis (column) length
 * @param coreRowNum rows the core has to process
 * @return rows per iteration, number of full iterations and rows left for the final iteration
 */
SingleCoreLoopParam GetSingleCoreLoopParam(const uint32_t colNum, const uint32_t coreRowNum)
{
    for (const auto& entry : SLICE_TABLE) {
        if (colNum >= entry.first) {
            SingleCoreLoopParam singleCoreLoopParam;
            singleCoreLoopParam.singleLoopCoreRowNum = entry.second;
            singleCoreLoopParam.singleCoreLoopCount = coreRowNum / entry.second;
            singleCoreLoopParam.singleCoreLoopTail = coreRowNum % entry.second;
            return singleCoreLoopParam;
        }
    }
    return {};
}

/**
 * Computes the multi-core tiling for a rowNum x colNum float SoftmaxGradFront problem.
 *
 * Rows are distributed evenly: every core gets ceil(rowNum / coreNum) rows, and core index usedNumBlocks (if any)
 * receives the remaining tailCoreRowNum rows.
 *
 * @param rowNum  total number of rows
 * @param colNum  reduce-axis (column) length
 * @param coreNum number of cores the kernel is launched on (0 is treated as 1)
 * @param tiling  output tiling data
 */
void ComputeTiling(const uint32_t rowNum, const uint32_t colNum, const uint32_t coreNum,
                   optiling::SoftmaxgradfrontCustomTilingData& tiling)
{
    // A zero core count would make the divisions below undefined; treat it as a single core.
    const uint32_t validCoreNum = (coreNum == 0) ? 1 : coreNum;
    // ceil(rowNum / coreNum) without the overflow-prone (rowNum + coreNum - 1) form.
    const uint32_t coreRowNum = rowNum / validCoreNum + ((rowNum % validCoreNum != 0) ? 1 : 0);
    // coreRowNum is 0 only when rowNum is 0: nothing to distribute, and avoid dividing by zero.
    const uint32_t tailCoreRowNum = (coreRowNum == 0) ? 0 : rowNum % coreRowNum; // last core process the tail rows
    const uint32_t usedNumBlocks = (coreRowNum == 0) ? 0 : rowNum / coreRowNum;  // cores with a full coreRowNum share

    const SingleCoreLoopParam mainCoreLoopParam = GetSingleCoreLoopParam(colNum, coreRowNum);
    SingleCoreLoopParam tailCoreLoopParam;
    if (usedNumBlocks == validCoreNum && tailCoreRowNum == 0) {
        tailCoreLoopParam = GetSingleCoreLoopParam(colNum, coreRowNum);
    } else {
        tailCoreLoopParam = GetSingleCoreLoopParam(colNum, tailCoreRowNum);
    }

    tiling.set_columnLength(colNum);
    tiling.set_rowLength(rowNum);
    tiling.set_usedNumBlocks(usedNumBlocks);
    tiling.set_coreRowNum(coreRowNum);
    tiling.set_tailCoreRowNum(tailCoreRowNum);

    tiling.set_singleLoopCoreRowNum(mainCoreLoopParam.singleLoopCoreRowNum);
    tiling.set_singleCoreLoopCount(mainCoreLoopParam.singleCoreLoopCount);
    tiling.set_singleCoreLoopTail(mainCoreLoopParam.singleCoreLoopTail);
    tiling.set_tailCoreSingleLoopCoreRowNum(tailCoreLoopParam.singleLoopCoreRowNum);
    tiling.set_tailCoreSingleCoreLoopCount(tailCoreLoopParam.singleCoreLoopCount);
    tiling.set_tailCoreSingleCoreLoopTail(tailCoreLoopParam.singleCoreLoopTail);

    // The softmax tiling is computed for one full loop iteration (singleLoopCoreRowNum x colNum).
    ge::Shape softmaxComputeShape({mainCoreLoopParam.singleLoopCoreRowNum, colNum});
    const uint32_t apiNeedMinTmpSize =
        AscendC::GetSoftMaxGradMinTmpSize(softmaxComputeShape, sizeof(float), true, true);
    const uint32_t localworkspaceSize =
        (apiNeedMinTmpSize > SHARED_TMP_BUFFER_SIZE) ? apiNeedMinTmpSize : SHARED_TMP_BUFFER_SIZE;
    // Publish the FINAL size: the kernel allocates exactly sharedTmpBufferSize bytes, and that must match the
    // workspace size the softmax tiling below is computed for.
    tiling.set_sharedTmpBufferSize(localworkspaceSize);
    AscendC::SoftMaxGradTilingFunc(softmaxComputeShape, sizeof(float), localworkspaceSize, tiling.softmaxTilingData,
                                   true);
}
} // namespace SoftmaxgradfrontCustomTiling

void GenerateTiling(const uint32_t rowNum, const uint32_t colNum, const uint32_t coreNum, const uint32_t tilingSize,
                    uint8_t* tilingBuffer)
{
    optiling::SoftmaxgradfrontCustomTilingData tiling;
    SoftmaxgradfrontCustomTiling::ComputeTiling(rowNum, colNum, coreNum, tiling);

    // Copy tiling to tilingBuffer
    tiling.SaveToBuffer(tilingBuffer, tilingSize);
}

namespace MyCustomKernel {
constexpr int32_t BUFFER_NUM = 1;                        // depth of each TQue (single buffering)
constexpr uint32_t FLOAT_NUM_OF_SINGEL_BLOCK = 8;        // floats written to z per input row
constexpr uint32_t BASIC_BLOCK_ROW_FACTOR = 8;           // basic-block path: row count must be a multiple of this
constexpr uint32_t BASIC_BLOCK_COLUMN_FACTOR = 64;       // basic-block path: column count must be a multiple of this
constexpr uint32_t BASIC_BLOCK_MAX_COLUMN_LENGTH = 2048; // basic-block path: column count must be below this

// Device-side mirror of optiling::SoftmaxgradfrontCustomTilingData. CopyTiling fills it word-by-word from the
// serialized tiling buffer, so the field order and types must match that definition exactly.
struct VecTiling {
    uint32_t columnLength = 0;
    uint32_t rowLength = 0;
    uint32_t sharedTmpBufferSize = 0;
    uint32_t usedNumBlocks = 0;
    uint32_t coreRowNum = 0;
    uint32_t tailCoreRowNum = 0;
    uint32_t singleLoopCoreRowNum = 0;
    uint32_t singleCoreLoopCount = 0;
    uint32_t singleCoreLoopTail = 0;
    uint32_t tailCoreSingleLoopCoreRowNum = 0;
    uint32_t tailCoreSingleCoreLoopCount = 0;
    uint32_t tailCoreSingleCoreLoopTail = 0;
    SoftMaxTiling softmaxTilingData;
};

// Runs SoftmaxGradFront over float rows of length columnLength, split row-wise across cores.
// Core i handles coreRowNum consecutive rows starting at row coreRowNum * i. Core index usedNumBlocks is the tail
// core and uses the tailCore* loop parameters; higher core indices do no work. For every input row the kernel
// writes FLOAT_NUM_OF_SINGEL_BLOCK floats to z.
class KernelSoftmax {
public:
    __aicore__ inline KernelSoftmax() {}

    // Copies the scalar tiling fields and the softmax API tiling into members.
    __aicore__ inline void InitTiling(const VecTiling& tilingData)
    {
        rowLength = tilingData.rowLength;
        sharedTmpBufferSize = tilingData.sharedTmpBufferSize;
        columnLength = tilingData.columnLength;
        usedNumBlocks = tilingData.usedNumBlocks;
        coreRowNum = tilingData.coreRowNum;
        softmaxTiling = tilingData.softmaxTilingData;
        singleLoopCoreRowNum = tilingData.singleLoopCoreRowNum;
        singleCoreLoopCount = tilingData.singleCoreLoopCount;
        leftRow = tilingData.singleCoreLoopTail;
        tailCoreSingleLoopCoreRowNum = tilingData.tailCoreSingleLoopCoreRowNum;
        tailCoreSingleCoreLoopCount = tilingData.tailCoreSingleCoreLoopCount;
        tailCoreSingleCoreLoopTail = tilingData.tailCoreSingleCoreLoopTail;
    }

    // Binds this core's slices of x, y (input) and z (output) and allocates the queues and the tmp buffer.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, const VecTiling& tiling, AscendC::TPipe* pipeIn)
    {
        ASSERT(AscendC::GetBlockNum() != 0 && "block dim can not be zero!");
        InitTiling(tiling);
        pipe = pipeIn;

        if (AscendC::GetBlockIdx() == this->usedNumBlocks) { // tail core
            this->singleLoopCoreRowNum = this->tailCoreSingleLoopCoreRowNum;
            this->singleCoreLoopCount = this->tailCoreSingleCoreLoopCount;
            this->leftRow = this->tailCoreSingleCoreLoopTail;
        }

        this->blockLength = this->coreRowNum * this->columnLength;
        uint32_t offset1 = this->blockLength * AscendC::GetBlockIdx();
        uint32_t offset2 = this->coreRowNum * FLOAT_NUM_OF_SINGEL_BLOCK * AscendC::GetBlockIdx();

        xGm.SetGlobalBuffer((__gm__ float*)x + offset1, this->blockLength);
        yGm.SetGlobalBuffer((__gm__ float*)y + offset1, this->blockLength);
        zGm.SetGlobalBuffer((__gm__ float*)z + offset2, this->coreRowNum * FLOAT_NUM_OF_SINGEL_BLOCK);

        this->tileLength = this->singleLoopCoreRowNum * this->columnLength;
        pipe->InitBuffer(queueX, BUFFER_NUM, this->tileLength * sizeof(float));
        pipe->InitBuffer(queueY, BUFFER_NUM, this->tileLength * sizeof(float));
        pipe->InitBuffer(queueZ, BUFFER_NUM, this->singleLoopCoreRowNum * FLOAT_NUM_OF_SINGEL_BLOCK * sizeof(float));

        pipe->InitBuffer(sharedTmpBuffer, sharedTmpBufferSize);
    }

    // Processes singleCoreLoopCount full tiles, then one partial tile of leftRow rows if any remain.
    __aicore__ inline void Process()
    {
        // Cores past the tail core have no rows assigned.
        if (AscendC::GetBlockIdx() > usedNumBlocks) {
            return;
        }

        for (uint32_t i = 0; i < this->singleCoreLoopCount; i++) {
            CopyIn(i, this->singleLoopCoreRowNum);
            Compute(i, this->singleLoopCoreRowNum);
            CopyOut(i, this->singleLoopCoreRowNum);
        }

        if (this->leftRow > 0) {
            CopyIn(this->singleCoreLoopCount, this->leftRow);
            Compute(this->singleCoreLoopCount, this->leftRow);
            CopyOut(this->singleCoreLoopCount, this->leftRow);
        }
    }

private:
    // Returns whether a tile of rowNum rows satisfies the basic-block constraints used to pick the
    // SoftmaxGradFront<float, true> path. The check must use the row count actually passed to the API:
    // the final partial tile (leftRow rows) may not be a multiple of the row factor even when full tiles are.
    __aicore__ inline bool IsBasicBlock(uint32_t rowNum) const
    {
        return rowNum % BASIC_BLOCK_ROW_FACTOR == 0 && this->columnLength % BASIC_BLOCK_COLUMN_FACTOR == 0 &&
               this->columnLength < BASIC_BLOCK_MAX_COLUMN_LENGTH;
    }

    // Loads rowNum rows of tile `progress` from x and y into the input queues.
    __aicore__ inline void CopyIn(uint32_t progress, uint32_t rowNum)
    {
        AscendC::LocalTensor<float> xLocal = queueX.AllocTensor<float>();
        AscendC::LocalTensor<float> yLocal = queueY.AllocTensor<float>();
        AscendC::DataCopy(xLocal, xGm[progress * this->tileLength], rowNum * this->columnLength);
        AscendC::DataCopy(yLocal, yGm[progress * this->tileLength], rowNum * this->columnLength);
        queueX.EnQue(xLocal);
        queueY.EnQue(yLocal);
    }

    // Runs SoftmaxGradFront on the dequeued tile of rowNum rows and enqueues the result.
    __aicore__ inline void Compute(uint32_t /* progress */, uint32_t rowNum)
    {
        AscendC::LocalTensor<float> xLocal = queueX.DeQue<float>();
        AscendC::LocalTensor<float> yLocal = queueY.DeQue<float>();
        AscendC::LocalTensor<float> zLocal = queueZ.AllocTensor<float>();
        AscendC::LocalTensor<uint8_t> tmpBuffer = sharedTmpBuffer.Get<uint8_t>();

        AscendC::SoftMaxShapeInfo srcShape = {rowNum, this->columnLength, rowNum, this->columnLength};
        if (IsBasicBlock(rowNum)) {
            AscendC::SoftmaxGradFront<float, true>(zLocal, xLocal, yLocal, tmpBuffer, softmaxTiling, srcShape);
        } else {
            AscendC::SoftmaxGradFront<float>(zLocal, xLocal, yLocal, tmpBuffer, softmaxTiling, srcShape);
        }

        queueZ.EnQue(zLocal);
        queueX.FreeTensor(xLocal);
        queueY.FreeTensor(yLocal);
    }

    // Writes rowNum rows (FLOAT_NUM_OF_SINGEL_BLOCK floats each) of tile `progress` to z.
    __aicore__ inline void CopyOut(uint32_t progress, uint32_t rowNum)
    {
        AscendC::LocalTensor<float> zLocal = queueZ.DeQue<float>();
        AscendC::DataCopy(zGm[progress * this->singleLoopCoreRowNum * FLOAT_NUM_OF_SINGEL_BLOCK], zLocal,
                          rowNum * FLOAT_NUM_OF_SINGEL_BLOCK);
        queueZ.FreeTensor(zLocal);
    }

private:
    AscendC::TPipe* pipe = nullptr;
    AscendC::TBuf<AscendC::TPosition::VECCALC> sharedTmpBuffer;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> queueX;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> queueY;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> queueZ;
    AscendC::GlobalTensor<float> xGm;
    AscendC::GlobalTensor<float> yGm;
    AscendC::GlobalTensor<float> zGm;

    uint32_t blockLength = 0;
    uint32_t usedNumBlocks = 0;
    uint32_t rowLength = 0;
    uint32_t columnLength = 0;
    uint32_t coreRowNum = 0;
    uint32_t tileLength = 0;
    uint32_t sharedTmpBufferSize = 0;
    uint32_t singleLoopCoreRowNum = 0;
    uint32_t singleCoreLoopCount = 0;
    uint32_t leftRow = 0;
    uint32_t tailCoreSingleLoopCoreRowNum = 0;
    uint32_t tailCoreSingleCoreLoopCount = 0;
    uint32_t tailCoreSingleCoreLoopTail = 0;
    SoftMaxTiling softmaxTiling;
};
} // namespace MyCustomKernel

// Copies the serialized tiling from global memory into *tiling, one uint32_t word at a time.
// Assumes VecTiling's layout matches the host-side tiling buffer word for word.
__aicore__ inline void CopyTiling(MyCustomKernel::VecTiling* tiling, GM_ADDR tilingGM)
{
    constexpr uint32_t wordCount = sizeof(MyCustomKernel::VecTiling) / sizeof(uint32_t);
    uint32_t* dst = reinterpret_cast<uint32_t*>(tiling);
    const auto src = reinterpret_cast<__gm__ uint32_t*>(tilingGM);

    for (uint32_t idx = 0; idx < wordCount; ++idx) {
        dst[idx] = src[idx];
    }
}

// Kernel entry point (vector cores only). It loads the tiling from global memory, then runs the
// SoftmaxGradFront pipeline for this core's rows.
// x, y: input tensors; z: output tensor; workspace: not referenced here; tiling: serialized tiling data.
extern "C" __global__ __aicore__ void softmaxgradfront_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace,
                                                              GM_ADDR tiling)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
    AscendC::TPipe tpipe;
    MyCustomKernel::VecTiling deviceTiling;
    CopyTiling(&deviceTiling, tiling);
    MyCustomKernel::KernelSoftmax kernel;
    kernel.Init(x, y, z, deviceTiling, &tpipe);
    kernel.Process();
}

// ---- Host test-harness configuration ----
// Problem shape: the inputs are ROW_NUM rows of COLUMN_NUM floats (COLUMN_NUM is the reduce axis).
constexpr uint32_t ROW_NUM = 960;
constexpr uint32_t COLUMN_NUM = 960;
// Number of cores (block dim) the kernel is launched with.
constexpr uint32_t USED_CORE_NUM = 40;
// Workspace element count; the byte size is WORKSPACE_SIZE * sizeof(float).
constexpr uint32_t WORKSPACE_SIZE = 1024;
// uint32_t words in SoftmaxgradfrontCustomTilingData: 12 scalar fields plus the embedded SoftMaxTiling words.
constexpr uint32_t TILINGDATA_SIZE = 28;
// Floats in the output z per input row.
constexpr uint32_t FLOAT_NUM_PER_BLOCK = 8;

static int64_t CompareResult(void* outputData, const int64_t outSize)
{
    void* goldenData;
    aclrtMallocHost((void**)(&goldenData), outSize);
    size_t goldenSize = outSize;
    bool ret = ReadFile("./output/golden.bin", goldenSize, goldenData, goldenSize);
    if (ret) {
        printf("ReadFile golden success!\n");
    } else {
        aclrtFreeHost(goldenData);
        return -1;
    }
    constexpr float EPS = 1e-5;
    int64_t wrongNum = 0;

    for (int i = 0; i < outSize / sizeof(float); i++) {
        float a = ((float*)outputData)[i];
        float b = ((float*)goldenData)[i];
        float ae = std::abs(a - b);
        float re = ae / std::abs(b);
        if (ae > EPS && re > EPS) {
            printf("CompareResult failed output is %lf, golden is %lf\n", a, b);
            wrongNum++;
        }
    }
    aclrtFreeHost(goldenData);
    return wrongNum;
}

/**
 * Host test driver. It loads the inputs, computes the tiling, launches softmaxgradfront_custom on
 * USED_CORE_NUM cores, writes z to ./output/output_z.bin and compares it with ./output/golden.bin.
 *
 * @return 0 when the output matches the golden data, 1 otherwise (so scripts can detect failures).
 */
int32_t main(int32_t argc, char* argv[])
{
    // x and y are ROW_NUM x COLUMN_NUM float matrices; z holds FLOAT_NUM_PER_BLOCK floats per row.
    size_t xSize = static_cast<size_t>(ROW_NUM) * COLUMN_NUM * sizeof(float);
    size_t workspaceSize = WORKSPACE_SIZE * sizeof(float);
    size_t tilingSize = TILINGDATA_SIZE * sizeof(uint32_t);
    size_t ySize = static_cast<size_t>(ROW_NUM) * COLUMN_NUM * sizeof(float);
    size_t zSize = static_cast<size_t>(ROW_NUM) * FLOAT_NUM_PER_BLOCK * sizeof(float);
    int64_t wrongNum = -1;
    // Initialize resources
    aclInit(nullptr);
    aclrtContext context;
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtCreateContext(&context, deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    uint8_t *xHost, *zHost, *yHost, *workspaceHost, *tilingHost;
    uint8_t *xDevice, *zDevice, *yDevice, *workspaceDevice, *tilingDevice;

    // Allocate host memory and device memory
    aclrtMallocHost(reinterpret_cast<void**>(&xHost), xSize);
    aclrtMallocHost(reinterpret_cast<void**>(&yHost), ySize);
    aclrtMallocHost(reinterpret_cast<void**>(&zHost), zSize);
    aclrtMallocHost(reinterpret_cast<void**>(&workspaceHost), workspaceSize);
    aclrtMallocHost(reinterpret_cast<void**>(&tilingHost), tilingSize);

    aclrtMalloc(reinterpret_cast<void**>(&xDevice), xSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(reinterpret_cast<void**>(&yDevice), ySize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(reinterpret_cast<void**>(&zDevice), zSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(reinterpret_cast<void**>(&workspaceDevice), workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(reinterpret_cast<void**>(&tilingDevice), tilingSize, ACL_MEM_MALLOC_HUGE_FIRST);

    ReadFile("./input/input_x.bin", xSize, xHost, xSize);
    ReadFile("./input/input_y.bin", ySize, yHost, ySize);
    ReadFile("./input/workspace.bin", workspaceSize, workspaceHost, workspaceSize);

    GenerateTiling(ROW_NUM, COLUMN_NUM, USED_CORE_NUM, tilingSize, tilingHost);

    // Copy host memory to device memory
    aclrtMemcpy(xDevice, xSize, xHost, xSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(yDevice, ySize, yHost, ySize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(workspaceDevice, workspaceSize, workspaceHost, workspaceSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(tilingDevice, tilingSize, tilingHost, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);

    // Execute the kernel
    softmaxgradfront_custom<<<USED_CORE_NUM, nullptr, stream>>>(xDevice, yDevice, zDevice, workspaceDevice,
                                                                tilingDevice);

    // Wait for the kernel to complete
    aclrtSynchronizeStream(stream);

    // Copy result to host memory and write to output file
    aclrtMemcpy(zHost, zSize, zDevice, zSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output_z.bin", zHost, zSize);

    // Compare the result with the golden result
    wrongNum = CompareResult(zHost, zSize);

    // Clean up memory
    aclrtFree(xDevice);
    aclrtFree(zDevice);
    aclrtFree(yDevice);
    aclrtFree(workspaceDevice);
    aclrtFree(tilingDevice);
    aclrtFreeHost(xHost);
    aclrtFreeHost(zHost);
    aclrtFreeHost(yHost);
    aclrtFreeHost(workspaceHost);
    aclrtFreeHost(tilingHost);

    aclrtDestroyStream(stream);
    aclrtDestroyContext(context);
    aclrtResetDevice(deviceId);
    aclFinalize();
    if (wrongNum != 0) {
        printf("test failed!\n");
        return 1;
    }
    printf("test pass!\n");
    return 0;
}