/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/


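/*!
 * \file layernorm.asc
 * \brief LayerNorm custom operator sample: host-side tiling generation, the AscendC kernel, and a host test driver.
 */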

#include "acl/acl.h"
#include "data_utils.h"
#include "kernel_operator.h"
#include "tiling/tiling_api.h"

namespace optiling {
// Host-side tiling data for the LayernormCustom kernel: the LayerNorm API tiling
// followed by epsilon. On the device, CopyTiling reinterprets the serialized
// buffer as MyCustomKernel::VecTiling, which declares the same two fields in the
// same order. Presumably the serialized layout must match that struct exactly,
// so keep the two definitions in sync.
BEGIN_TILING_DATA_DEF(LayernormCustomTilingData)
TILING_DATA_FIELD_DEF_STRUCT(LayerNormTiling, layernormTilingData);
TILING_DATA_FIELD_DEF(float, epsilon);
END_TILING_DATA_DEF;

// Presumably associates LayernormCustomTilingData with the LayernormCustom operator name.
REGISTER_TILING_DATA_CLASS(LayernormCustom, LayernormCustomTilingData)
} // namespace optiling

void ComputeTiling(const uint32_t bLength, const uint32_t sLength, const uint32_t hLength,
                   optiling::LayernormCustomTilingData& tiling)
{
    ge::Shape geShape({bLength, sLength, hLength, hLength});

    uint32_t maxTmpSize = 0;
    uint32_t minTmpSize = 0;
    bool isReuseSource = false;
    AscendC::GetLayerNormMaxMinTmpSize(geShape, sizeof(float), isReuseSource, maxTmpSize, minTmpSize);
    uint32_t localWorkspaceSize = minTmpSize;
    // get layernorm Tiling
    AscendC::GetLayerNormNDTilingInfo(geShape, localWorkspaceSize, sizeof(float), isReuseSource,
                                      tiling.layernormTilingData);
    tiling.set_epsilon(0.0001);
}

/**
 * Serialize `tilingData` into a newly malloc'd buffer of
 * sizeof(optiling::LayernormCustomTilingData) bytes.
 *
 * @param tilingData tiling to serialize; may not be null.
 * @return the buffer (release it with free()), or nullptr if `tilingData` is
 *         null or the allocation fails.
 */
uint8_t* GetTilingBuf(optiling::LayernormCustomTilingData* tilingData)
{
    if (tilingData == nullptr) {
        return nullptr;
    }
    const uint32_t tilingSize = sizeof(optiling::LayernormCustomTilingData);
    uint8_t* buf = static_cast<uint8_t*>(malloc(tilingSize));
    if (buf == nullptr) {
        // Never hand a null destination to SaveToBuffer.
        return nullptr;
    }
    tilingData->SaveToBuffer(buf, tilingSize);
    return buf;
}

/**
 * Compute the tiling for a bLength x sLength x hLength input and return it
 * serialized by GetTilingBuf. The returned buffer is malloc'd; the caller
 * releases it with free().
 */
uint8_t* GenerateTiling(uint32_t bLength, uint32_t sLength, uint32_t hLength)
{
    optiling::LayernormCustomTilingData hostTiling;
    ComputeTiling(bLength, sLength, hLength, hostTiling);
    uint8_t* serialized = GetTilingBuf(&hostTiling);
    return serialized;
}

namespace MyCustomKernel {
// Device-side view of the tiling data. CopyTiling fills it word-by-word from the
// tiling buffer in global memory, so its field order must match the host
// definition optiling::LayernormCustomTilingData (layernormTilingData, then epsilon).
struct VecTiling {
    LayerNormTiling layernormTilingData;
    float epsilon = 0;
};

// LayerNorm over a float tensor of bLength * sLength * hLength elements, with the
// lengths taken from the LayerNorm tiling. The whole input, gamma, beta and the
// mean/variance outputs are staged in local buffers in a single pass
// (CopyIn -> Compute -> CopyOut). Work is not split by block index, so every block
// that runs this kernel processes the full tensor.
// isReuseSource is forwarded unchanged to AscendC::LayerNorm.
template <bool isReuseSource = false>
class KernelLayernorm {
public:
    __aicore__ inline KernelLayernorm() {}
    // Binds the global tensors and reserves one local buffer per queue.
    // Element counts as passed to SetGlobalBuffer:
    //   inputXGm / outputGm:              bLength * sLength * hLength floats
    //   gammGm / betaGm:                  hLength floats
    //   outputMeanGm / outputVarianceGm:  bLength * sLength floats
    // pipeIn is stored, and all queue buffers are allocated from it.
    __aicore__ inline void Init(GM_ADDR inputXGm, GM_ADDR gammGm, GM_ADDR betaGm, GM_ADDR outputGm,
                                GM_ADDR outputMeanGm, GM_ADDR outputVarianceGm, VecTiling tilingData,
                                AscendC::TPipe* pipeIn)
    {
        pipe = pipeIn;
        this->epsilon = tilingData.epsilon;
        tiling_ = tilingData.layernormTilingData;
        this->bLength = tiling_.bLength;
        this->sLength = tiling_.sLength;
        this->hLength = tiling_.hLength;

        bshLength = bLength * sLength * hLength;
        bsLength = bLength * sLength;

        inputXGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(inputXGm), bshLength);
        gammGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(gammGm), hLength);
        betaGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(betaGm), hLength);

        outputGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(outputGm), bshLength);
        outputMeanGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(outputMeanGm), bsLength);
        outputVarianceGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(outputVarianceGm), bsLength);

        // One buffer per queue, each sized for the whole tensor it holds.
        pipe->InitBuffer(inQueueX, 1, sizeof(float) * bshLength);
        pipe->InitBuffer(inQueueGamma, 1, sizeof(float) * hLength);
        pipe->InitBuffer(inQueueBeta, 1, sizeof(float) * hLength);
        pipe->InitBuffer(outQueue, 1, sizeof(float) * bshLength);
        pipe->InitBuffer(outQueueMean, 1, sizeof(float) * bsLength);
        pipe->InitBuffer(outQueueVariance, 1, sizeof(float) * bsLength);
    }
    // Runs the three stages once over the whole tensor.
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    // Copies x, gamma and beta from global memory into the input queues.
    // NOTE(review): the DataCopy lengths are not padded. DataCopy presumably
    // requires 32-byte-aligned transfers, which holds for the sizes main uses;
    // confirm before using other shapes.
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<float> inputXLocal = inQueueX.AllocTensor<float>();
        AscendC::LocalTensor<float> gammaLocal = inQueueGamma.AllocTensor<float>();
        AscendC::LocalTensor<float> betaLocal = inQueueBeta.AllocTensor<float>();

        AscendC::DataCopy(inputXLocal, inputXGlobal, bshLength);
        AscendC::DataCopy(gammaLocal, gammGlobal, hLength);
        AscendC::DataCopy(betaLocal, betaGlobal, hLength);

        // EnQue hands the filled tensors to Compute, which DeQues them.
        inQueueX.EnQue(inputXLocal);
        inQueueGamma.EnQue(gammaLocal);
        inQueueBeta.EnQue(betaLocal);
    }
    // Runs AscendC::LayerNorm, producing the normalized output, mean and variance.
    // No explicit temporary buffer is passed; presumably the API takes its scratch
    // space from local memory according to the host-computed tiling — confirm.
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<float> inputXLocal = inQueueX.DeQue<float>();
        AscendC::LocalTensor<float> gammaLocal = inQueueGamma.DeQue<float>();
        AscendC::LocalTensor<float> betaLocal = inQueueBeta.DeQue<float>();

        AscendC::LocalTensor<float> outputLocal = outQueue.AllocTensor<float>();
        AscendC::LocalTensor<float> meanLocal = outQueueMean.AllocTensor<float>();
        AscendC::LocalTensor<float> varianceLocal = outQueueVariance.AllocTensor<float>();

        AscendC::LayerNorm<float, isReuseSource>(outputLocal, meanLocal, varianceLocal, inputXLocal, gammaLocal,
                                                 betaLocal, (float)epsilon, tiling_);

        outQueue.EnQue<float>(outputLocal);
        outQueueMean.EnQue<float>(meanLocal);
        outQueueVariance.EnQue<float>(varianceLocal);

        // Inputs are no longer needed once the results are queued.
        inQueueX.FreeTensor(inputXLocal);
        inQueueGamma.FreeTensor(gammaLocal);
        inQueueBeta.FreeTensor(betaLocal);
    }
    // Copies output, mean and variance back to global memory, then frees the local tensors.
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<float> outputLocal = outQueue.DeQue<float>();
        AscendC::LocalTensor<float> meanLocal = outQueueMean.DeQue<float>();
        AscendC::LocalTensor<float> varianceLocal = outQueueVariance.DeQue<float>();

        AscendC::DataCopy(outputGlobal, outputLocal, bshLength);
        AscendC::DataCopy(outputMeanGlobal, meanLocal, bsLength);
        AscendC::DataCopy(outputVarianceGlobal, varianceLocal, bsLength);

        outQueue.FreeTensor(outputLocal);
        outQueueMean.FreeTensor(meanLocal);
        outQueueVariance.FreeTensor(varianceLocal);
    }

private:
    // Global-memory views of the kernel arguments (gammGlobal holds gamma).
    AscendC::GlobalTensor<float> inputXGlobal;
    AscendC::GlobalTensor<float> gammGlobal;
    AscendC::GlobalTensor<float> betaGlobal;
    AscendC::GlobalTensor<float> outputGlobal;
    AscendC::GlobalTensor<float> outputMeanGlobal;
    AscendC::GlobalTensor<float> outputVarianceGlobal;

    // Pipe supplied by the caller; every queue below has depth 1.
    AscendC::TPipe* pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueX;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueGamma;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueBeta;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueue;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueMean;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueVariance;

    // Copied from the tiling in Init.
    uint32_t bLength;
    uint32_t sLength;
    uint32_t hLength;
    float epsilon;
    LayerNormTiling tiling_;

    // Derived element counts: bLength * sLength * hLength and bLength * sLength.
    uint32_t bshLength;
    uint32_t bsLength;
};
} // namespace MyCustomKernel

/**
 * Copy the tiling data from global memory into a local VecTiling, one 32-bit
 * word at a time.
 *
 * @param tiling   destination struct.
 * @param tilingGM global-memory address of the serialized tiling; it must hold at
 *                 least sizeof(MyCustomKernel::VecTiling) bytes, all of which are read.
 */
__aicore__ inline void CopyTiling(MyCustomKernel::VecTiling* tiling, GM_ADDR tilingGM)
{
    // A word-wise copy would silently drop trailing bytes of a struct that is
    // not a whole number of 32-bit words.
    static_assert(sizeof(MyCustomKernel::VecTiling) % sizeof(uint32_t) == 0,
                  "VecTiling must be a whole number of 32-bit words");
    constexpr uint32_t wordCount = sizeof(MyCustomKernel::VecTiling) / sizeof(uint32_t);

    uint32_t* dst = reinterpret_cast<uint32_t*>(tiling);
    auto src = reinterpret_cast<__gm__ uint32_t*>(tilingGM);
    // Unsigned index: avoids the signed/unsigned comparison with the word count.
    for (uint32_t i = 0; i < wordCount; ++i) {
        dst[i] = src[i];
    }
}

/**
 * Kernel entry point. Loads the tiling from global memory and runs one
 * KernelLayernorm<false> instance over the whole tensor.
 *
 * `workspace` is not used by this kernel body.
 */
extern "C" __global__ __aicore__ void layernorm_custom(GM_ADDR inputXGm, GM_ADDR gammaGm, GM_ADDR betaGm,
                                                       GM_ADDR outputGm, GM_ADDR outputMeanGm, GM_ADDR outputVarianceGm,
                                                       GM_ADDR workspace, GM_ADDR tiling)
{
    // Presumably selects vector-core-only execution for this kernel.
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
    AscendC::TPipe tpipe;
    MyCustomKernel::VecTiling deviceTiling;
    CopyTiling(&deviceTiling, tiling);
    MyCustomKernel::KernelLayernorm<false> layernorm;
    layernorm.Init(inputXGm, gammaGm, betaGm, outputGm, outputMeanGm, outputVarianceGm, deviceTiling, &tpipe);
    layernorm.Process();
}

// Problem size for the host test: input X has BLENGTH * SLENGTH * HLENGTH floats,
// gamma/beta have HLENGTH floats, mean/variance have BLENGTH * SLENGTH floats.
constexpr uint32_t BLENGTH = 2;
constexpr uint32_t SLENGTH = 32;
constexpr uint32_t HLENGTH = 16;
// layernorm_custom does not split work by block index: every launched block
// processes the entire tensor and writes the same outputs. One block is enough;
// extra blocks only repeat identical work and race on identical writes.
constexpr uint32_t NUM_BLOCKS = 1;
// Number of 32-bit words of tiling data copied to the device.
// NOTE(review): should equal sizeof(MyCustomKernel::VecTiling) / sizeof(uint32_t),
// which CopyTiling reads in full — confirm against the LayerNormTiling definition.
constexpr uint32_t TILINGDATA_SIZE = 27;
// Workspace size in floats; the kernel does not read the workspace.
constexpr uint32_t WORKSPACE_SIZE = 1024;

/**
 * Compare a host copy of a kernel output with ./output/golden_output_<goldenName>.bin.
 *
 * Both buffers are read as float arrays. An element is a mismatch unless its
 * absolute error or its relative error (relative to the golden value) is within
 * EPS. A NaN error therefore counts as a mismatch.
 *
 * @param outputData host copy of the kernel output.
 * @param outSize    size of outputData in bytes; the golden file must be exactly this size.
 * @param goldenName suffix used to build the golden file name.
 * @return true when the golden file was read, its size matched, and every element matched.
 */
static bool CompareResult(const void* outputData, int64_t outSize, const std::string& goldenName)
{
    const size_t byteSize = static_cast<size_t>(outSize);
    void* goldenData = nullptr;
    aclrtMallocHost(&goldenData, byteSize);
    if (goldenData == nullptr) {
        printf("aclrtMallocHost for golden_output_%s.bin failed!\n", goldenName.c_str());
        return false;
    }
    size_t goldenSize = byteSize;
    bool ret = ReadFile("./output/golden_output_" + goldenName + ".bin", goldenSize, goldenData, byteSize);
    if (ret) {
        printf("ReadFile golden_output_%s.bin success!\n", goldenName.c_str());
    } else {
        aclrtFreeHost(goldenData);
        return false;
    }
    // A short golden file would leave the tail of goldenData uninitialized.
    if (goldenSize != byteSize) {
        printf("golden_output_%s.bin size %zu does not match output size %zu\n", goldenName.c_str(), goldenSize,
               byteSize);
        aclrtFreeHost(goldenData);
        return false;
    }
    constexpr float EPS = 1e-5;
    int64_t wrongNum = 0;

    const float* actual = static_cast<const float*>(outputData);
    const float* golden = static_cast<const float*>(goldenData);
    const size_t elemCount = byteSize / sizeof(float);
    for (size_t i = 0; i < elemCount; ++i) {
        float a = actual[i];
        float b = golden[i];
        float ae = std::abs(a - b);
        float re = ae / std::abs(b);
        // Equivalent to "ae > EPS && re > EPS" for finite errors, but written as
        // "not within tolerance" so a NaN error is reported instead of passing.
        if (!(ae <= EPS || re <= EPS)) {
            printf("CompareResult golden_output_%s.bin failed output is %lf, golden is %lf\n", goldenName.c_str(), a,
                   b);
            wrongNum++;
        }
    }
    aclrtFreeHost(goldenData);
    if (wrongNum != 0) {
        return false;
    }
    printf("CompareResult golden_output_%s.bin success!\n", goldenName.c_str());
    return true;
}

/**
 * Host test driver.
 *
 * Builds the tiling and loads the inputs from ./input/. It then runs
 * layernorm_custom, writes the outputs to ./output/, and compares them with the
 * golden files.
 *
 * @return 0 when every output matches its golden file. Returns 1 otherwise,
 *         including when the tiling buffer or an input file cannot be prepared,
 *         so scripts and CI can detect failures.
 */
int32_t main(int32_t argc, char* argv[])
{
    uint32_t numBlocks = NUM_BLOCKS;
    size_t workspaceSize = WORKSPACE_SIZE * sizeof(float);
    size_t xSize = BLENGTH * SLENGTH * HLENGTH * sizeof(float);
    size_t gammaSize = HLENGTH * sizeof(float);
    size_t betaSize = HLENGTH * sizeof(float);
    size_t outputSize = BLENGTH * SLENGTH * HLENGTH * sizeof(float);
    size_t meanSize = BLENGTH * SLENGTH * sizeof(float);
    size_t varianceSize = BLENGTH * SLENGTH * sizeof(float);
    size_t tilingFileSize = TILINGDATA_SIZE * sizeof(uint32_t);
    bool goldenResult = true;
    uint8_t* tilingBuf = GenerateTiling(BLENGTH, SLENGTH, HLENGTH);
    if (tilingBuf == nullptr) {
        printf("GenerateTiling failed!\n");
        printf("test failed!\n");
        return 1;
    }

    aclInit(nullptr);
    aclrtContext context;
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtCreateContext(&context, deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    uint8_t *inputXHost, *gammaHost, *betaHost, *resultHost, *meanHost, *varianceHost, *workspaceHost;
    uint8_t *inputXDevice, *gammaDevice, *betaDevice, *resultDevice, *meanDevice, *varianceDevice, *workspaceDevice,
        *tilingDevice;

    aclrtMallocHost((void**)(&inputXHost), xSize);
    aclrtMallocHost((void**)(&gammaHost), gammaSize);
    aclrtMallocHost((void**)(&betaHost), betaSize);
    aclrtMallocHost((void**)(&resultHost), outputSize);
    aclrtMallocHost((void**)(&meanHost), meanSize);
    aclrtMallocHost((void**)(&varianceHost), varianceSize);
    aclrtMallocHost((void**)(&workspaceHost), workspaceSize);
    aclrtMalloc((void**)&inputXDevice, xSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&gammaDevice, gammaSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&betaDevice, betaSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&resultDevice, outputSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&meanDevice, meanSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&varianceDevice, varianceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&workspaceDevice, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&tilingDevice, tilingFileSize, ACL_MEM_MALLOC_HUGE_FIRST);

    // Read one input file and require it to fill the whole buffer. A separate
    // readSize keeps the transfer sizes below from being changed by the read.
    auto loadInput = [](const std::string& path, uint8_t* buffer, size_t expectedSize) {
        size_t readSize = expectedSize;
        if (!ReadFile(path, readSize, buffer, expectedSize) || readSize != expectedSize) {
            printf("ReadFile %s failed!\n", path.c_str());
            return false;
        }
        return true;
    };
    const bool inputsLoaded = loadInput("./input/input_inputX.bin", inputXHost, xSize) &&
                              loadInput("./input/input_gamma.bin", gammaHost, gammaSize) &&
                              loadInput("./input/input_beta.bin", betaHost, betaSize);

    if (inputsLoaded) {
        aclrtMemcpy(workspaceDevice, workspaceSize, workspaceHost, workspaceSize, ACL_MEMCPY_HOST_TO_DEVICE);
        aclrtMemcpy(tilingDevice, tilingFileSize, tilingBuf, tilingFileSize, ACL_MEMCPY_HOST_TO_DEVICE);

        aclrtMemcpy(inputXDevice, xSize, inputXHost, xSize, ACL_MEMCPY_HOST_TO_DEVICE);
        aclrtMemcpy(gammaDevice, gammaSize, gammaHost, gammaSize, ACL_MEMCPY_HOST_TO_DEVICE);
        aclrtMemcpy(betaDevice, betaSize, betaHost, betaSize, ACL_MEMCPY_HOST_TO_DEVICE);

        layernorm_custom<<<numBlocks, nullptr, stream>>>(inputXDevice, gammaDevice, betaDevice, resultDevice,
                                                        meanDevice, varianceDevice, workspaceDevice, tilingDevice);
        aclrtSynchronizeStream(stream);
        aclrtMemcpy(resultHost, outputSize, resultDevice, outputSize, ACL_MEMCPY_DEVICE_TO_HOST);
        aclrtMemcpy(meanHost, meanSize, meanDevice, meanSize, ACL_MEMCPY_DEVICE_TO_HOST);
        aclrtMemcpy(varianceHost, varianceSize, varianceDevice, varianceSize, ACL_MEMCPY_DEVICE_TO_HOST);

        WriteFile("./output/output_result.bin", resultHost, outputSize);
        WriteFile("./output/output_mean.bin", meanHost, meanSize);
        WriteFile("./output/output_variance.bin", varianceHost, varianceSize);

        goldenResult &= CompareResult(resultHost, outputSize, "result");
        goldenResult &= CompareResult(meanHost, meanSize, "mean");
        goldenResult &= CompareResult(varianceHost, varianceSize, "variance");
    } else {
        // Running on partially loaded inputs would compare garbage; fail instead.
        goldenResult = false;
    }

    aclrtFree(inputXDevice);
    aclrtFree(gammaDevice);
    aclrtFree(betaDevice);
    aclrtFree(resultDevice);
    aclrtFree(meanDevice);
    aclrtFree(varianceDevice);
    aclrtFree(workspaceDevice);
    aclrtFree(tilingDevice);
    aclrtFreeHost(inputXHost);
    aclrtFreeHost(gammaHost);
    aclrtFreeHost(betaHost);
    aclrtFreeHost(resultHost);
    aclrtFreeHost(meanHost);
    aclrtFreeHost(varianceHost);
    aclrtFreeHost(workspaceHost);

    aclrtDestroyStream(stream);
    aclrtDestroyContext(context);
    aclrtResetDevice(deviceId);
    aclFinalize();
    free(tilingBuf);
    if (goldenResult) {
        printf("test pass!\n");
        return 0;
    }
    printf("test failed!\n");
    return 1;
}