/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/


/*!
 * \file mean.asc
 * \brief Sample that runs a PhiloxRandom kernel and compares its output with ./output/golden.bin.
 */

#include "acl/acl.h"
#include "data_utils.h"
#include "kernel_operator.h"
#include "tiling/tiling_api.h"

namespace MyCustomKernel {
// Number of extra elements reserved after the row * column region. Compute() zero-fills them before the
// generator runs. NOTE(review): presumably head-room for PhiloxRandom's strided output — confirm against the API.
constexpr uint32_t TAIL_ELEMENTS = 256;

/**
 * Kernel helper that fills a local tensor through AscendC::PhiloxRandom and copies it to global memory.
 *
 * \tparam Rounds template argument forwarded unchanged to AscendC::PhiloxRandom.
 * \tparam T      element type of the local and global output tensors.
 */
template <uint16_t Rounds, typename T>
class KernelPhiloxStride {
public:
    __aicore__ inline KernelPhiloxStride() {}

    /**
     * Stores the generation parameters, binds the global output buffer and reserves the output queue buffer.
     *
     * \param dstGm       global-memory address of the output; dstSize elements of T are written to it.
     * \param paramStride first value of the {stride, row, column} argument passed to PhiloxRandom.
     * \param paramRow    second value of that argument; row * column is the generated element count.
     * \param paramColumn third value of that argument.
     * \param pipeIn      pipe used to allocate the output queue buffer; stored, not owned.
     */
    __aicore__ inline void Init(GM_ADDR dstGm, uint32_t paramStride, uint32_t paramRow, uint32_t paramColumn,
                                AscendC::TPipe* pipeIn)
    {
        ASCENDC_ASSERT(AscendC::GetBlockNum() != 0, { KERNEL_LOG(KERNEL_ERROR, "block dim can not be zero!"); });

        stride = paramStride;
        row = paramRow;
        column = paramColumn;
        count = row * column;
        // Round (count + tail) up to a whole number of data blocks so the copy-out length is block-aligned.
        const uint32_t alignSize = AscendC::GetDataBlockSizeInBytes() / sizeof(T);
        dstSize = (count + TAIL_ELEMENTS + alignSize - 1) / alignSize * alignSize;
        dstGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGm), dstSize);
        pipe = pipeIn;
        pipe->InitBuffer(outQueue, 1, dstSize * sizeof(T));
    }

    /**
     * Generates the random data and writes it to global memory.
     *
     * seed0..seed1 form the first brace-initialised argument of PhiloxRandom and seed2..seed5 the second.
     */
    __aicore__ inline void Process(uint32_t seed0, uint32_t seed1, uint32_t seed2, uint32_t seed3, uint32_t seed4,
                                   uint32_t seed5)
    {
        Compute(seed0, seed1, seed2, seed3, seed4, seed5);
        CopyOut();
    }

private:
    // Zero-fills the tail region of a freshly allocated tensor, runs PhiloxRandom on it and enqueues the result.
    __aicore__ inline void Compute(uint32_t seed0, uint32_t seed1, uint32_t seed2, uint32_t seed3, uint32_t seed4,
                                   uint32_t seed5)
    {
        AscendC::LocalTensor<T> dstLocal = outQueue.AllocTensor<T>();

        __ubuf__ uint8_t* dstLocalUB = (__ubuf__ uint8_t*)dstLocal.GetPhyAddr();
        // Byte range of elements [count, count + TAIL_ELEMENTS). Offsets are scaled by sizeof(T) so they match
        // the element type the buffer was sized with in Init().
        const uint32_t tailBegin = count * sizeof(T);
        const uint32_t tailEnd = (count + TAIL_ELEMENTS) * sizeof(T);
        for (uint32_t i = tailBegin; i < tailEnd; i++) { dstLocalUB[i] = 0; }
        AscendC::PipeBarrier<PIPE_V>();
        AscendC::PhiloxRandom<Rounds>(dstLocal, {seed0, seed1}, {seed2, seed3, seed4, seed5}, {stride, row, column});

        outQueue.EnQue<T>(dstLocal);
    }

    // Dequeues the generated tensor, copies all dstSize elements to global memory and releases the tensor.
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> dstLocal = outQueue.DeQue<T>();
        AscendC::DataCopy(dstGlobal, dstLocal, dstSize);
        outQueue.FreeTensor(dstLocal);
    }

private:
    AscendC::TPipe* pipe;                                   // pipe passed to Init(); not owned
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueue;  // output queue, initialised with one buffer
    AscendC::GlobalTensor<T> dstGlobal;                     // view of the global output buffer

    uint32_t count;    // row * column
    uint32_t stride;   // forwarded to PhiloxRandom
    uint32_t row;      // forwarded to PhiloxRandom
    uint32_t column;   // forwarded to PhiloxRandom
    uint32_t dstSize;  // count + TAIL_ELEMENTS rounded up to a whole data block, in elements
};
} // namespace MyCustomKernel

/**
 * Kernel entry point: runs KernelPhiloxStride<10, float> with fixed parameters and writes the result to dstGm.
 *
 * \param dstGm global-memory address of the output buffer.
 */
__vector__ __global__ void philoxrandom_custom(GM_ADDR dstGm)
{
    // Fixed sample configuration: stride, row and column are all 32 and every seed word is zero.
    constexpr uint32_t strideValue = 32;
    constexpr uint32_t rowValue = 32;
    constexpr uint32_t columnValue = 32;
    constexpr uint32_t seedWord = 0;

    AscendC::TPipe pipe;
    MyCustomKernel::KernelPhiloxStride<10, float> op;
    op.Init(dstGm, strideValue, rowValue, columnValue, &pipe);
    op.Process(seedWord, seedWord, seedWord, seedWord, seedWord, seedWord);
}

// Block count passed as the first launch argument of philoxrandom_custom.
constexpr uint32_t NUM_BLOCKS{8};
// Output length in float elements: the kernel's row * column (32 * 32) plus its 256-element tail.
// NOTE(review): assumes this sum is already a multiple of the kernel's data-block alignment — confirm.
constexpr uint32_t COUNT{32U * 32U + 256U};

static bool CompareResult(const void* outputData, uint32_t outSize)
{
    void* goldenData;
    aclrtMallocHost((void**)(&goldenData), outSize);
    size_t goldenSize = outSize;
    bool ret = ReadFile("./output/golden.bin", goldenSize, goldenData, goldenSize);
    if (ret) {
        printf("ReadFile golden.bin success!\n");
    } else {
        printf("test failed!\n");
        return false;
    }
    constexpr float EPS = 1e-4;
    int64_t wrongNum = 0;

    for (size_t i = 0; i < outSize / sizeof(float); i++) {
        float a = (reinterpret_cast<const float*>(outputData))[i];
        float b = (reinterpret_cast<const float*>(goldenData))[i];
        float ae = std::abs(a - b);
        float re = ae / std::abs(b);
        if (ae > EPS && re > EPS) {
            printf("CompareResult golden.bin failed output is %lf, golden is %lf\n", a, b);
            wrongNum++;
        }
    }
    aclrtFreeHost(goldenData);
    if (wrongNum != 0) {
        printf("wrongNum: %ld\n", wrongNum);
        return false;
    } else {
        printf("CompareResult golden.bin success!\n");
        return true;
    }
}

int32_t main(int32_t argc, char* argv[])
{
    uint32_t numBlocks = NUM_BLOCKS;
    size_t outputSize = sizeof(float) * COUNT;

    aclInit(nullptr);
    aclrtContext context;
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtCreateContext(&context, deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    uint8_t* yHost;
    uint8_t* yDevice;

    aclrtMallocHost((void**)(&yHost), outputSize);
    aclrtMalloc((void**)&yDevice, outputSize, ACL_MEM_MALLOC_HUGE_FIRST);

    // Execute the kernel
    philoxrandom_custom<<<numBlocks, nullptr, stream>>>(yDevice);

    // Wait for the stop event to complete
    aclrtSynchronizeStream(stream);
    // Copy result to host memory and write to output file
    aclrtMemcpy(yHost, outputSize, yDevice, outputSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output.bin", yHost, outputSize);

    // Compare the result with the golden result
    bool goldenResult = true;
    goldenResult = CompareResult(yHost, outputSize);

    // Clean up memory
    aclrtFree(yDevice);
    aclrtFreeHost(yHost);

    aclrtDestroyStream(stream);
    aclrtDestroyContext(context);
    aclrtResetDevice(deviceId);
    aclFinalize();

    if (goldenResult) {
        printf("test pass!\n");
    } else {
        printf("test failed!\n");
    }
    return 0;
}