/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/


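/*!
 * \file reduce_custom.asc
 * \brief Sum sample: host-side tiling, kernel launch and golden comparison around a kernel that calls AscendC::Sum.
 */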

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <iterator>
#include <vector>
#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"

// Tiling parameters computed on the host by GenerateTilingData and passed by value to sum_custom.
// The layout is five consecutive uint32_t fields; the host sizes its tiling buffer from that
// (TILINGDATA_SIZE), so do not reorder or resize fields without updating the host side.
struct SumCustomTilingData {
    uint32_t inner;      // N rounded up so that one row occupies a multiple of PADDING_BYTE bytes (in elements)
    uint32_t outter;     // M as passed to GenerateTilingData
    uint32_t n;          // N as passed to GenerateTilingData (unpadded)
    uint32_t tmpBufSize; // minimum temp size reported by AscendC::GetSumMaxMinTmpSize; the kernel allocates it as uint8_t
    uint32_t out_inner;  // M rounded up with the same PADDING_BYTE rule (in elements); sizes the output buffer
};

namespace {
// Byte granularity that GenerateTilingData rounds row / output sizes up to.
constexpr uint32_t PADDING_BYTE{32U};
} // namespace

/**
 * Fill a host-side tiling buffer for the sum_custom kernel.
 *
 * The buffer is interpreted as a SumCustomTilingData and populated with:
 *   - outter     = M
 *   - n          = N
 *   - inner      = N rounded up so a row spans a multiple of PADDING_BYTE bytes
 *   - out_inner  = M rounded up with the same rule
 *   - tmpBufSize = minimum temporary size reported by AscendC::GetSumMaxMinTmpSize
 *
 * @param tilingBuf  Destination buffer, at least sizeof(SumCustomTilingData) bytes. A null
 *                   pointer is reported on stderr and the function returns without writing.
 * @param M          Outer dimension.
 * @param N          Actual (unpadded) inner dimension.
 */
void GenerateTilingData(uint8_t* tilingBuf, const uint32_t M, const uint32_t N)
{
    if (tilingBuf == nullptr) {
        // The caller allocates this buffer; refuse to dereference a failed allocation.
        std::cerr << "GenerateTilingData: tilingBuf is null, tiling not generated" << std::endl;
        return;
    }

    // The kernel is instantiated for float, so size everything by the float element width.
    constexpr uint32_t ELEM_SIZE = static_cast<uint32_t>(sizeof(float));

    uint32_t minValue = 0;
    uint32_t maxValue = 0;
    AscendC::GetSumMaxMinTmpSize(N, ELEM_SIZE, false, maxValue, minValue);

    // Round `count` elements up so that count * typeSize is a multiple of PADDING_BYTE.
    // The byte count is computed in 64 bits so that large counts cannot wrap before rounding.
    auto alignToBlock = [](const uint32_t count, const uint32_t typeSize) -> uint32_t {
        if (typeSize == 0) {
            return 0;
        }
        const uint64_t rawBytes = static_cast<uint64_t>(count) * typeSize;
        const uint64_t paddedBytes = (rawBytes + PADDING_BYTE - 1U) / PADDING_BYTE * PADDING_BYTE;
        return static_cast<uint32_t>(paddedBytes / typeSize);
    };

    SumCustomTilingData* tiling = reinterpret_cast<SumCustomTilingData*>(tilingBuf);
    tiling->outter = M;
    tiling->inner = alignToBlock(N, ELEM_SIZE);
    tiling->n = N;
    tiling->tmpBufSize = minValue; // only the minimum requirement is used; maxValue is ignored
    tiling->out_inner = alignToBlock(M, ELEM_SIZE);
}

namespace MyCustomKernel {
/**
 * Single-pass kernel wrapper around AscendC::Sum.
 *
 * Process() runs three stages in order: CopyIn (global -> local), Compute (Duplicate + Sum),
 * CopyOut (local -> global). All `inner * outter` input elements are moved in one DataCopy,
 * so the whole input has to fit in the single buffer of `inQueue`.
 *
 * @tparam T element type of the input and output tensors.
 */
template <typename T>
class KernelSum {
public:
    __aicore__ inline KernelSum() {}
    /**
     * Store the tiling values, bind the global tensors and size the local buffers.
     *
     * @param x          global address of the input data
     * @param y          global address of the output data
     * @param tilingData tiling values produced on the host
     * @param pipeIn     pipe used to initialise the queues and the temp buffer; only the pointer is kept
     */
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, SumCustomTilingData tilingData, AscendC::TPipe* pipeIn)
    {
        ASCENDC_ASSERT(AscendC::GetBlockNum() != 0, { KERNEL_LOG(KERNEL_ERROR, "block dim can not be zero!"); });
        inner = tilingData.inner;
        outter = tilingData.outter;
        n = tilingData.n;
        tmpBufSize = tilingData.tmpBufSize;
        out_inner = tilingData.out_inner;

        // Shape description handed to AscendC::Sum in Compute().
        params.inner = inner;
        params.outter = outter;
        params.n = n;

        xGm.SetGlobalBuffer((__gm__ T*)x);
        yGm.SetGlobalBuffer((__gm__ T*)y);

        pipe = pipeIn;
        // One buffer each: the full padded input, the padded output, and tmpBufSize bytes of scratch.
        pipe->InitBuffer(inQueue, 1, inner * outter * sizeof(T));
        pipe->InitBuffer(outQueue, 1, out_inner * sizeof(T));
        pipe->InitBuffer(tmpBuf, tmpBufSize * sizeof(uint8_t));
    }
    /** Run the CopyIn -> Compute -> CopyOut sequence once. */
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    /** Copy `inner * outter` elements from xGm into a tensor from inQueue and enqueue it. */
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> xLocal = inQueue.AllocTensor<T>();
        AscendC::DataCopy(xLocal, xGm, inner * outter);
        inQueue.EnQue(xLocal);
    }
    /** Dequeue the input, compute the sum into a tensor from outQueue, and enqueue the result. */
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> xLocal = inQueue.DeQue<T>();
        AscendC::LocalTensor<T> yLocal = outQueue.AllocTensor<T>();
        AscendC::LocalTensor<uint8_t> sharedTmpBuffer = tmpBuf.AllocTensor<uint8_t>();

        // Zero all out_inner output elements before Sum writes its results.
        // NOTE(review): presumably this keeps the padded tail (out_inner beyond outter) at zero - confirm.
        T scalar(0);
        AscendC::Duplicate<T>(yLocal, scalar, out_inner);
        AscendC::Sum(yLocal, xLocal, sharedTmpBuffer, params);

        outQueue.EnQue<T>(yLocal);
        inQueue.FreeTensor(xLocal);
        tmpBuf.FreeTensor(sharedTmpBuffer);
    }
    /** Dequeue the result and copy `out_inner` elements to yGm. */
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> yLocal = outQueue.DeQue<T>();
        AscendC::DataCopy(yGm, yLocal, out_inner);
        outQueue.FreeTensor(yLocal);
    }

private:
    AscendC::TPipe* pipe; // not owned; set in Init()
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueue;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueue;
    AscendC::TBuf<AscendC::TPosition::VECCALC> tmpBuf; // scratch passed to AscendC::Sum
    AscendC::GlobalTensor<T> xGm;
    AscendC::GlobalTensor<T> yGm;

    // Copies of the SumCustomTilingData fields of the same names.
    uint32_t inner = 0;
    uint32_t outter = 0;
    uint32_t n = 0;
    uint32_t tmpBufSize = 0;
    uint32_t out_inner = 0;
    AscendC::SumParams params;
};
} // namespace MyCustomKernel

/*!
 * \brief Kernel entry point: builds a float KernelSum and runs it once.
 * \param srcGm  global address of the input data
 * \param dstGm  global address of the output data
 * \param tiling tiling values, passed by value from the host
 */
__global__ __aicore__ void sum_custom(GM_ADDR srcGm, GM_ADDR dstGm, SumCustomTilingData tiling)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
    // The task type is AIV-only, so return immediately wherever ASCEND_IS_AIC holds.
    if ASCEND_IS_AIC {
        return;
    }
    AscendC::TPipe tpipe;
    MyCustomKernel::KernelSum<float> sumOp;
    sumOp.Init(srcGm, dstGm, tiling, &tpipe);
    sumOp.Process();
}

namespace {
// Host-side launch configuration and problem shape for this sample.
constexpr uint32_t USED_CORE_NUM{1U};   // block count used for the kernel launch
constexpr uint32_t TILINGDATA_SIZE{5U}; // number of uint32_t words in the tiling buffer
constexpr uint32_t M{7U};               // outter
constexpr uint32_t N{2023U};            // inner_actual
} // namespace

static bool CompareResult(const void* outputData, uint32_t outSize)
{
    void* goldenData;

    aclrtMallocHost((void**)(&goldenData), outSize);

    size_t goldenSize = outSize;
    bool ret = ReadFile("./output/golden.bin", goldenSize, goldenData, goldenSize);
    if (ret) {
        printf("ReadFile golden.bin success!\n");
    } else {
        printf("test failed!\n");
        return false;
    }
    constexpr float EPS = 1e-4;
    int64_t wrongNum = 0;

    for (size_t i = 0; i < outSize / sizeof(float); i++) {
        float a = (reinterpret_cast<const float*>(outputData))[i];
        float b = (reinterpret_cast<const float*>(goldenData))[i];
        float ae = std::abs(a - b);
        float re = ae / abs(b);
        if (ae > EPS && re > EPS) {
            printf("CompareResult golden.bin failed output is %lf, golden is %lf\n", a, b);
            wrongNum++;
        }
    }

    aclrtFreeHost(goldenData);

    if (wrongNum != 0) {
        printf("wrongNum: %ld\n", wrongNum);
        return false;
    } else {
        printf("CompareResult golden.bin success!\n");
        return true;
    }
}

int32_t main(int32_t argc, char* argv[])
{
    uint8_t* tiling = nullptr;
    size_t tilingSize = TILINGDATA_SIZE * sizeof(uint32_t);

    aclInit(nullptr);
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    uint8_t *xHost, *yHost;
    uint8_t *xDevice, *yDevice;

    aclrtMallocHost((void**)(&tiling), tilingSize);

    GenerateTilingData(tiling, M, N);

    auto tilingData = reinterpret_cast<SumCustomTilingData*>(tiling);

    size_t inputSize = tilingData->outter * tilingData->inner * sizeof(uint32_t);
    size_t outputSize = tilingData->out_inner * sizeof(uint32_t);

    aclrtMallocHost((void**)(&xHost), inputSize);
    aclrtMallocHost((void**)(&yHost), outputSize);

    aclrtMalloc((void**)&xDevice, inputSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&yDevice, outputSize, ACL_MEM_MALLOC_HUGE_FIRST);

    ReadFile("./input/input_x.bin", inputSize, xHost, inputSize);

    // Copy host memory to device memory
    aclrtMemcpy(xDevice, inputSize, xHost, inputSize, ACL_MEMCPY_HOST_TO_DEVICE);

    // Execute the kernel
    sum_custom<<<USED_CORE_NUM, nullptr, stream>>>(xDevice, yDevice, *reinterpret_cast<SumCustomTilingData*>(tiling));

    // Wait for the stop event to complete
    aclrtSynchronizeStream(stream);

    // Copy result to host memory and write to output file
    aclrtMemcpy(yHost, outputSize, yDevice, outputSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output.bin", yHost, outputSize);

    // Compare the result with the golden result
    bool goldenResult = true;
    goldenResult = CompareResult(yHost, outputSize);

    // Clean up memory
    aclrtFree(xDevice);
    aclrtFree(yDevice);
    aclrtFreeHost(xHost);
    aclrtFreeHost(yHost);
    aclrtFreeHost(tiling);

    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();

    if (goldenResult) {
        printf("test pass!\n");
    } else {
        printf("test failed!\n");
    }
    return 0;
}