/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/


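/*!
 * \file abs_pad.asc
 * \brief Sample kernel and host driver: absolute value of half data whose rows are narrower than 32 bytes,
 *        staged through 32-byte-wide slots with the Pad API.
 */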

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"

#include "graph/tensor.h"
#include "tiling/tiling_api.h"
#include "tiling/platform/platform_ascendc.h"

// Geometry shared by the kernel and the host driver. A row is BLOCK_BYTE_SIZE bytes (BLOCK_ELEMENT_NUM halfs)
// wide in global memory and is staged on chip in a slot of BLOCKLEN_CEIL halfs.
constexpr int32_t BLOCK_BYTE_SIZE = 14; // equivalent to the definition of blockLen of DataCopyPad
constexpr int32_t BLOCK_GROUP_NUM = 16; // equivalent to the definition of blockCount of DataCopyPad
constexpr int32_t BLOCK_ELEMENT_NUM = BLOCK_BYTE_SIZE / sizeof(half); // valid half elements per row
constexpr int32_t BLOCKLEN_CEIL = 32 / sizeof(half); // half elements in one 32-byte slot, since BLOCK_BYTE_SIZE<32
constexpr int32_t USE_CORE_NUM = 8;                  // num of core used
constexpr int32_t TILE_NUM = 8;                      // split data into 8 tiles for each core (each split again by BUFFER_NUM)
constexpr int32_t BUFFER_NUM = 2;                    // tensor num for each queue
// total number of half elements processed by all cores together
constexpr int32_t TOTAL_LENGTH = USE_CORE_NUM * TILE_NUM * BUFFER_NUM * BLOCK_GROUP_NUM * BLOCK_ELEMENT_NUM;
constexpr int32_t BLOCK_LENGTH = TOTAL_LENGTH / USE_CORE_NUM; // length computed of each core
// elements handled by one CopyIn/Compute/CopyOut step: BLOCK_GROUP_NUM rows of BLOCK_ELEMENT_NUM elements
constexpr int32_t TILE_LENGTH = BLOCK_LENGTH / TILE_NUM / BUFFER_NUM;

/**
 * Copies a PadTiling structure from global memory into a kernel-side object.
 *
 * The structure is transferred as sizeof(PadTiling) / sizeof(uint32_t) 32-bit words; any
 * trailing bytes beyond a whole number of words are not copied.
 *
 * \param tiling   destination structure, written word by word
 * \param tilingGM global-memory address holding the serialized tiling data
 */
__aicore__ inline void CopyTiling(PadTiling *tiling, GM_ADDR tilingGM)
{
    const uint32_t wordCount = sizeof(PadTiling) / sizeof(uint32_t);
    auto srcWords = reinterpret_cast<__gm__ uint32_t *>(tilingGM);
    uint32_t *dstWords = reinterpret_cast<uint32_t *>(tiling);

    for (uint32_t idx = 0; idx < wordCount; ++idx) {
        dstWords[idx] = srcWords[idx];
    }
}

/**
 * Per-core kernel that applies Abs to half data stored as rows of BLOCK_ELEMENT_NUM elements,
 * i.e. rows that are narrower than one 32-byte slot of BLOCKLEN_CEIL elements.
 *
 * Each core owns BLOCK_LENGTH contiguous elements of the input and of the output. Process()
 * walks that slice in TILE_NUM * BUFFER_NUM steps of BLOCK_GROUP_NUM rows:
 *   CopyIn  - stages every row into its own BLOCKLEN_CEIL-wide slot of a queue tensor;
 *   Compute - runs Pad() and then Abs() over the whole staged tile;
 *   CopyOut - writes every slot back at its packed row offset with atomic add enabled.
 */
class KernelAbsPad {
public:
    __aicore__ inline KernelAbsPad() {}

    /**
     * Binds this core's slice of the input/output buffers and allocates the staging queues.
     *
     * \param inputGM  global-memory address of the whole input (half elements)
     * \param outputGM global-memory address of the whole output (half elements)
     * \param tiling   Pad tiling data forwarded to AscendC::Pad in Compute()
     */
    __aicore__ inline void Init(GM_ADDR inputGM, GM_ADDR outputGM, PadTiling tiling)
    {
        this->tiling = tiling;
        // Every core works on its own BLOCK_LENGTH-element window selected by the block index.
        srcGlobal.SetGlobalBuffer((__gm__ half *)(inputGM) + BLOCK_LENGTH * AscendC::GetBlockIdx(), BLOCK_LENGTH);
        dstGlobal.SetGlobalBuffer((__gm__ half *)(outputGM) + BLOCK_LENGTH * AscendC::GetBlockIdx(), BLOCK_LENGTH);
        // One queue tensor holds BLOCK_GROUP_NUM slots of BLOCKLEN_CEIL half elements.
        pipe.InitBuffer(inQueue, BUFFER_NUM, BLOCK_GROUP_NUM * BLOCKLEN_CEIL * sizeof(half));
        pipe.InitBuffer(outQueue, BUFFER_NUM, BLOCK_GROUP_NUM * BLOCKLEN_CEIL * sizeof(half));
        // CopyOut accumulates into the output with atomic add, so the slice must start from zero.
        AscendC::Fill(dstGlobal, BLOCK_LENGTH, half(0.0));
    }

    /** Runs CopyIn -> Compute -> CopyOut for each of the TILE_NUM * BUFFER_NUM tiles of this core. */
    __aicore__ inline void Process()
    {
        const int32_t loopCount = TILE_NUM * BUFFER_NUM;
        for (int32_t i = 0; i < loopCount; i++) {
            CopyIn(i);
            Compute(i);
            CopyOut(i);
        }
    }

private:
    /**
     * Stages tile `progress` from global memory into a freshly allocated input-queue tensor.
     *
     * Rows are BLOCK_ELEMENT_NUM elements apart in global memory but BLOCKLEN_CEIL elements are
     * copied per row, so each local slot holds the row followed by elements that belong to the
     * following rows (or lie past the tile).
     */
    __aicore__ inline void CopyIn(int32_t progress)
    {
        AscendC::LocalTensor<half> inputLocal = inQueue.AllocTensor<half>();
        for (int32_t i = 0; i < BLOCK_GROUP_NUM; i++) {
            const uint32_t srcGmIdx = progress * TILE_LENGTH + BLOCK_ELEMENT_NUM * i;
            AscendC::DataCopy(inputLocal[BLOCKLEN_CEIL * i], srcGlobal[srcGmIdx], BLOCKLEN_CEIL);
        }
        inQueue.EnQue(inputLocal);
    }

    /**
     * Pads the staged tile and takes the absolute value of every element in it.
     *
     * `progress` is unused; it is kept so the three pipeline stages share one signature.
     */
    __aicore__ inline void Compute(int32_t progress)
    {
        AscendC::LocalTensor<half> inputLocal = inQueue.DeQue<half>();
        AscendC::LocalTensor<half> outputLocal = outQueue.AllocTensor<half>();
        // Presumably {leftPad, rightPad, padValue}: pad the BLOCKLEN_CEIL - BLOCK_ELEMENT_NUM trailing
        // elements of each slot with 0 - confirm the field order against the AscendC Pad documentation.
        AscendC::PadParams padParams = {0, BLOCKLEN_CEIL - BLOCK_ELEMENT_NUM, 0};
        AscendC::Pad(outputLocal, inputLocal, padParams, this->tiling);
        AscendC::Abs(outputLocal, outputLocal, BLOCK_GROUP_NUM * BLOCKLEN_CEIL); // main calculation
        outQueue.EnQue<half>(outputLocal);
        inQueue.FreeTensor(inputLocal);
    }

    /**
     * Writes tile `progress` back to global memory.
     *
     * Each slot is written BLOCKLEN_CEIL elements wide while rows are only BLOCK_ELEMENT_NUM
     * elements apart, so consecutive writes overlap in the destination. Atomic add is enabled
     * around the copies so overlapping writes accumulate instead of overwriting each other.
     */
    __aicore__ inline void CopyOut(int32_t progress)
    {
        AscendC::LocalTensor<half> outputLocal = outQueue.DeQue<half>();
        AscendC::SetAtomicAdd<half>();
        for (int32_t i = 0; i < BLOCK_GROUP_NUM; i++) {
            const uint32_t dstGmIdx = progress * TILE_LENGTH + i * BLOCK_ELEMENT_NUM;
            // Local slots are BLOCKLEN_CEIL elements apart, matching the layout produced by CopyIn.
            AscendC::DataCopy<half>(dstGlobal[dstGmIdx], outputLocal[i * BLOCKLEN_CEIL], BLOCKLEN_CEIL);
        }
        AscendC::SetAtomicNone();
        outQueue.FreeTensor(outputLocal);
    }

private:
    AscendC::GlobalTensor<half> srcGlobal; // this core's window of the input
    AscendC::GlobalTensor<half> dstGlobal; // this core's window of the output
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueue;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> outQueue;
    PadTiling tiling; // copy of the tiling data handed to Init()
};


/**
 * Kernel entry point: loads the Pad tiling from global memory, then initializes and runs
 * KernelAbsPad on the calling core.
 *
 * \param inputGM    global-memory address of the input data
 * \param outputGM   global-memory address of the output data
 * \param tilingData global-memory address of the serialized PadTiling
 */
extern "C" __global__ __aicore__ void abs_pad_custom(GM_ADDR inputGM, GM_ADDR outputGM, GM_ADDR tilingData)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);

    KernelAbsPad absPadKernel;
    PadTiling padTiling;
    CopyTiling(&padTiling, tilingData);

    absPadKernel.Init(inputGM, outputGM, padTiling);
    absPadKernel.Process();
}

void GenerateTiling(const std::vector<int64_t> shapePad, const std::vector<int64_t> shapeUsed, uint8_t *tilingBuf)
{
    ge::Shape srcShape(shapePad);
    ge::Shape oriSrcShape(shapeUsed);
    uint32_t tmpMinSize, tmpMaxSize;
    AscendC::GetPadMaxMinTmpSize(srcShape, sizeof(int16_t), tmpMaxSize, tmpMinSize);
    optiling::PadTiling tilingData;
    AscendC::PadTilingFunc(srcShape, oriSrcShape, tmpMaxSize, sizeof(int16_t), tilingData);
    uint32_t tilingSize = tilingData.GetDataSize();
    tilingData.SaveToBuffer(tilingBuf, tilingSize);
    return;
}

int32_t main(int32_t argc, char *argv[])
{
    const std::vector<int64_t> shapeUsed({16, 7}); // shape of valid data
    const std::vector<int64_t> shapePad({16, 16}); // original shape
    uint32_t numBlocks = 8;

    // 14336 is the length of input data
    uint32_t oriLength = 14336;
    // we must allocate more space to prevent invalid address access
    uint32_t padLength = oriLength + shapePad[1] - shapeUsed[1];
    size_t inputByteSize = padLength * sizeof(int16_t);
    size_t outputByteSize = padLength * sizeof(int16_t);
    // however, original length must be used when output to file
    size_t outputFileSize = oriLength * sizeof(int16_t);
    size_t tilingSize = sizeof(PadTiling);
    uint8_t *tilingBuf = (uint8_t *)malloc(tilingSize);
    GenerateTiling(shapePad, shapeUsed, tilingBuf);

    aclInit(nullptr);
    aclrtContext context;
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtCreateContext(&context, deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    uint8_t *xHost, *zHost;
    uint8_t *xDevice, *zDevice, *tilingDevice;


    aclrtMallocHost((void **)(&xHost), inputByteSize);
    aclrtMallocHost((void **)(&zHost), outputByteSize);
    aclrtMalloc((void **)&xDevice, inputByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void **)&zDevice, outputByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void **)&tilingDevice, tilingSize, ACL_MEM_MALLOC_HUGE_FIRST);

    ReadFile("./input/input_x.bin", inputByteSize, xHost, inputByteSize);

    aclrtMemcpy(xDevice, inputByteSize, xHost, inputByteSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(tilingDevice, tilingSize, tilingBuf, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);


    abs_pad_custom<<<numBlocks, nullptr, stream>>>(xDevice, zDevice, tilingDevice);
    aclrtSynchronizeStream(stream);

    aclrtMemcpy(zHost, outputByteSize, zDevice, outputByteSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output.bin", zHost, outputFileSize);

    aclrtFree(xDevice);
    aclrtFree(zDevice);
    aclrtFree(tilingDevice);
    aclrtFreeHost(xHost);
    aclrtFreeHost(zHost);

    aclrtDestroyStream(stream);
    aclrtDestroyContext(context);
    aclrtResetDevice(deviceId);
    aclFinalize();
    free(tilingBuf);
    return 0;
}