/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/


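/*!
 * \file add.asc
 * \brief Element-wise float addition (z = x + y) using statically placed ping/pong local tensors
 *        synchronized manually with SetFlag/WaitFlag, plus a host driver that runs it via AscendCL.
 */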

#include <cstdint>
#include <iostream>
#include <vector>
#include <algorithm>
#include <iterator>
#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"

#ifndef ADD_CUSTOM_TILING_H
#define ADD_CUSTOM_TILING_H
#include <cstdint>

/**
 * Tiling parameters handed from host to kernel.
 *
 * The host copies sizeof(AddCustomTilingData) bytes of this struct verbatim into device memory
 * (see KernelCall) and the kernel reads it back through a __gm__ pointer (see add_custom), so the
 * layout must stay identical on both sides; keep it a plain aggregate.
 */
struct AddCustomTilingData {
    uint32_t singleCoreLength; // number of float elements processed by each block
};
#endif // ADD_CUSTOM_TILING_H

using AscendC::TPosition;
namespace {
// Number of float elements moved and added per tile (16 KiB of float data per local tensor).
constexpr uint32_t TILE_LENGTH = 4U * 1024U;
} // namespace

/**
 * Element-wise float addition z = x + y over one block's slice of global memory.
 *
 * Six local tensors are placed at fixed VECCALC offsets (x/y/z for a "ping" set and a "pong" set).
 * Pipe dependencies (MTE2 copy-in, V compute, MTE3 copy-out) are expressed by hand with
 * SetFlag/WaitFlag instead of being managed automatically.
 */
class KernelAddV2 {
public:
    __aicore__ inline KernelAddV2() = default;
    /**
     * Binds xGm/yGm/zGm to this block's slice and computes the number of tiles.
     *
     * @param x, y             Global-memory inputs, interpreted as float.
     * @param z                Global-memory output, interpreted as float.
     * @param singleCoreLength Number of float elements handled by each block. The slice starts at
     *                         element GetBlockIdx() * singleCoreLength.
     *
     * NOTE: loopCount uses integer division, so elements beyond loopCount * TILE_LENGTH in the slice
     * are neither read nor written. singleCoreLength is expected to be a multiple of TILE_LENGTH.
     */
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t singleCoreLength)
    {
        xGm.SetGlobalBuffer((__gm__ float*)x + AscendC::GetBlockIdx() * singleCoreLength, singleCoreLength);
        yGm.SetGlobalBuffer((__gm__ float*)y + AscendC::GetBlockIdx() * singleCoreLength, singleCoreLength);
        zGm.SetGlobalBuffer((__gm__ float*)z + AscendC::GetBlockIdx() * singleCoreLength, singleCoreLength);
        loopCount = singleCoreLength / TILE_LENGTH;
    }
    /**
     * Processes all loopCount tiles of the slice.
     *
     * Loop iteration i handles tile 2*i with the ping tensors (event id EVENT_ID0) and tile 2*i+1 with
     * the pong tensors (event id EVENT_ID1). If loopCount is odd, the last tile is handled after the
     * loop with the ping tensors.
     *
     * For each tile the order is: MTE2 copies x/y in -> V adds -> MTE3 copies z out. Within a tile,
     * MTE2_V and V_MTE3 are set and immediately waited on. Across iterations, V_MTE2 (reuse of x/y)
     * and MTE3_V (reuse of z) are set at the end of every iteration except the last and waited on at
     * the start of every iteration except the first, so every SetFlag has exactly one WaitFlag.
     */
    __aicore__ inline void Process()
    {
        // ping tensors: fixed VECCALC offsets xAddrPing/yAddrPing/zAddrPing, TILE_LENGTH floats each
        AscendC::LocalTensor<float> xLocalPing(AscendC::TPosition::VECCALC, xAddrPing, TILE_LENGTH);
        AscendC::LocalTensor<float> yLocalPing(AscendC::TPosition::VECCALC, yAddrPing, TILE_LENGTH);
        AscendC::LocalTensor<float> zLocalPing(AscendC::TPosition::VECCALC, zAddrPing, TILE_LENGTH);
        // pong tensors: fixed VECCALC offsets xAddrPong/yAddrPong/zAddrPong, TILE_LENGTH floats each
        AscendC::LocalTensor<float> xLocalPong(AscendC::TPosition::VECCALC, xAddrPong, TILE_LENGTH);
        AscendC::LocalTensor<float> yLocalPong(AscendC::TPosition::VECCALC, yAddrPong, TILE_LENGTH);
        AscendC::LocalTensor<float> zLocalPong(AscendC::TPosition::VECCALC, zAddrPong, TILE_LENGTH);

        // Double buffering: each iteration processes one ping tile and one pong tile, so the two
        // buffer sets are intended to let one set's copies overlap the other set's work.
        for (uint32_t i = 0; i < loopCount / 2; i++) {
            // ping part (tile 2 * i)
            // dependency of PIPE_V & PIPE_MTE2 caused by xLocalPing/yLocalPing between 2 sequential loops:
            // the previous iteration's Add must finish reading x/y before they are overwritten
            if (i != 0) {
                AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
            }
            AscendC::DataCopy(xLocalPing, xGm[2 * i * TILE_LENGTH], TILE_LENGTH);
            AscendC::DataCopy(yLocalPing, yGm[2 * i * TILE_LENGTH], TILE_LENGTH);
            // dependency of PIPE_MTE2 & PIPE_V caused by xLocalPing/yLocalPing in one single loop:
            // the Add must not start before both copy-ins have completed
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
            if (i != 0) {
                // dependency of PIPE_MTE3 & PIPE_V caused by zLocalPing between 2 sequential loops:
                // the previous iteration's copy-out of z must finish before the Add overwrites it
                AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
            }
            AscendC::Add(zLocalPing, xLocalPing, yLocalPing, TILE_LENGTH);
            if (i != (loopCount / 2 - 1)) {
                // dependency of PIPE_V & PIPE_MTE2 caused by xLocalPing/yLocalPing between 2 sequential loops
                // (waited on at the top of the next iteration; skipped on the last iteration)
                AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
            }
            // dependency of PIPE_V & PIPE_MTE3 caused by zLocalPing in one single loop:
            // the copy-out must not start before the Add has written z
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
            AscendC::DataCopy(zGm[2 * i * TILE_LENGTH], zLocalPing, TILE_LENGTH);
            if (i != (loopCount / 2 - 1)) {
                // dependency of PIPE_MTE3 & PIPE_V caused by zLocalPing between 2 sequential loops
                // (waited on before the next iteration's ping Add; skipped on the last iteration)
                AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
            }

            // pong part (tile 2 * i + 1); same protocol as the ping part, using EVENT_ID1
            // dependency of PIPE_V & PIPE_MTE2 caused by xLocalPong/yLocalPong between 2 sequential loops
            if (i != 0) {
                AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
            }
            AscendC::DataCopy(xLocalPong, xGm[(2 * i + 1) * TILE_LENGTH], TILE_LENGTH);
            AscendC::DataCopy(yLocalPong, yGm[(2 * i + 1) * TILE_LENGTH], TILE_LENGTH);
            // dependency of PIPE_MTE2 & PIPE_V caused by xLocalPong/yLocalPong in one single loop
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID1);
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID1);
            if (i != 0) {
                // dependency of PIPE_MTE3 & PIPE_V caused by zLocalPong between 2 sequential loops
                AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
            }
            AscendC::Add(zLocalPong, xLocalPong, yLocalPong, TILE_LENGTH);
            if (i != (loopCount / 2 - 1)) {
                // dependency of PIPE_V & PIPE_MTE2 caused by xLocalPong/yLocalPong between 2 sequential loops
                AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
            }
            // dependency of PIPE_V & PIPE_MTE3 caused by zLocalPong in one single loop
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID1);
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID1);
            AscendC::DataCopy(zGm[(2 * i + 1) * TILE_LENGTH], zLocalPong, TILE_LENGTH);
            if (i != (loopCount / 2 - 1)) {
                // dependency of PIPE_MTE3 & PIPE_V caused by zLocalPong between 2 sequential loops
                AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
            }
        }

        // tail block: the last tile when loopCount is odd, processed with the ping tensors.
        // The loop's final iteration did not set its cross-iteration flags, so each dependency here
        // uses an immediate SetFlag/WaitFlag pair to order after all earlier V / MTE3 work.
        if (loopCount % 2 != 0) {
            // dependency of PIPE_V & PIPE_MTE2 caused by xLocalPing/yLocalPing with the previous for loop
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
            AscendC::DataCopy(xLocalPing, xGm[(loopCount - 1) * TILE_LENGTH], TILE_LENGTH);
            AscendC::DataCopy(yLocalPing, yGm[(loopCount - 1) * TILE_LENGTH], TILE_LENGTH);
            // dependency of PIPE_MTE2 & PIPE_V caused by xLocalPing/yLocalPing in one loop
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
            // dependency of PIPE_MTE3 & PIPE_V caused by zLocalPing with the previous for loop
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
            AscendC::Add(zLocalPing, xLocalPing, yLocalPing, TILE_LENGTH);
            // dependency of PIPE_V & PIPE_MTE3 caused by zLocalPing in one loop
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
            AscendC::DataCopy(zGm[(loopCount - 1) * TILE_LENGTH], zLocalPing, TILE_LENGTH);
        }
    }

private:
    // Fixed offsets, in bytes, of the six local tensors. Each tensor holds TILE_LENGTH floats
    // (16 KiB), so together they span 6 * 16 KiB = 96 KiB.
    // NOTE(review): assumes 96 KiB fits in the local buffer of the target SoC — confirm.
    static constexpr uint32_t xAddrPing = 0;
    static constexpr uint32_t yAddrPing = TILE_LENGTH * sizeof(float);
    static constexpr uint32_t zAddrPing = TILE_LENGTH * sizeof(float) * 2;
    static constexpr uint32_t xAddrPong = TILE_LENGTH * sizeof(float) * 3;
    static constexpr uint32_t yAddrPong = TILE_LENGTH * sizeof(float) * 4;
    static constexpr uint32_t zAddrPong = TILE_LENGTH * sizeof(float) * 5;
    AscendC::GlobalTensor<float> xGm; // this block's slice of x
    AscendC::GlobalTensor<float> yGm; // this block's slice of y
    AscendC::GlobalTensor<float> zGm; // this block's slice of z
    uint32_t loopCount;               // number of whole TILE_LENGTH tiles in the slice
};

/**
 * Device entry point: every block adds its slice of x and y and stores the result in z.
 *
 * @param x, y   Global-memory float inputs.
 * @param z      Global-memory float output.
 * @param tiling Global-memory pointer to an AddCustomTilingData holding the per-block length.
 */
__global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR tiling)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
    AscendC::InitSocState();

    // Only the per-block element count is needed from the tiling data.
    __gm__ AddCustomTilingData* tilingData = (__gm__ AddCustomTilingData*)tiling;
    const uint32_t perBlockLength = tilingData->singleCoreLength;

    KernelAddV2 kernel;
    kernel.Init(x, y, z, perBlockLength);
    kernel.Process();
}

/**
 * Host-side description of one kernel argument that is backed by a binary file.
 * Kept as a plain aggregate so it can be brace-initialized as {fileName, length}.
 */
struct ArgInfo {
    std::string fileName; // file the buffer is read from (inputs) or written to (outputs)
    size_t length;        // buffer size in bytes
};

void KernelCall(uint32_t numBlocks, void* stream, std::vector<ArgInfo>& inputsInfo, std::vector<ArgInfo>& outputsInfo,
                uint8_t* tiling)
{
    std::vector<uint8_t*> inputHost(inputsInfo.size());
    std::vector<uint8_t*> inputDevice(inputsInfo.size());
    std::vector<uint8_t*> outputHost(outputsInfo.size());
    std::vector<uint8_t*> outputDevice(outputsInfo.size());
    uint8_t* tilingDevice;

    aclrtMalloc((void**)(&tilingDevice), sizeof(AddCustomTilingData), ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMemcpy(tilingDevice, sizeof(AddCustomTilingData), tiling, sizeof(AddCustomTilingData),
                ACL_MEMCPY_HOST_TO_DEVICE);

    for (uint32_t i = 0; i < inputsInfo.size(); i++) {
        aclrtMallocHost((void**)(&inputHost[i]), inputsInfo[i].length);
        aclrtMalloc((void**)(&inputDevice[i]), inputsInfo[i].length, ACL_MEM_MALLOC_HUGE_FIRST);
        ReadFile(inputsInfo[i].fileName, inputsInfo[i].length, inputHost[i], inputsInfo[i].length);
        aclrtMemcpy(inputDevice[i], inputsInfo[i].length, inputHost[i], inputsInfo[i].length,
                    ACL_MEMCPY_HOST_TO_DEVICE);
    }

    for (uint32_t i = 0; i < outputsInfo.size(); i++) {
        aclrtMallocHost((void**)(&outputHost[i]), outputsInfo[i].length);
        aclrtMalloc((void**)(&outputDevice[i]), outputsInfo[i].length, ACL_MEM_MALLOC_HUGE_FIRST);
    }

    add_custom<<<numBlocks, nullptr, stream>>>(inputDevice[0], inputDevice[1], outputDevice[0], tilingDevice);
    aclrtSynchronizeStream(stream);

    aclrtFree(tilingDevice);
    for (uint32_t i = 0; i < outputsInfo.size(); i++) {
        aclrtMemcpy(outputHost[i], outputsInfo[i].length, outputDevice[i], outputsInfo[i].length,
                    ACL_MEMCPY_DEVICE_TO_HOST);
        WriteFile(outputsInfo[i].fileName, outputHost[i], outputsInfo[i].length);
        aclrtFree(outputDevice[i]);
        aclrtFreeHost(outputHost[i]);
    }

    for (uint32_t i = 0; i < inputsInfo.size(); i++) {
        aclrtFree(inputDevice[i]);
        aclrtFreeHost(inputHost[i]);
    }
}

int32_t main(int32_t argc, char* argv[])
{
    uint32_t numBlocks = 8;
    // set data length, in this case we use 8 cores and length of each core is 4096 * 9
    uint32_t dataLen = 4096 * 9 * numBlocks;
    size_t inputByteSize = dataLen * sizeof(float);
    size_t outputByteSize = dataLen * sizeof(float);
    AddCustomTilingData tiling;
    tiling.singleCoreLength = dataLen / numBlocks;

    std::vector<ArgInfo> inputsInfo = {{"./input/input_x.bin", inputByteSize}, {"./input/input_y.bin", inputByteSize}};
    std::vector<ArgInfo> outputsV1Info = {{"./output/output.bin", outputByteSize}};

    aclInit(nullptr);
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    KernelCall(numBlocks, stream, inputsInfo, outputsV1Info, (uint8_t*)&tiling);

    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();

    return 0;
}