/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/


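/* !
 * \file load_data.asc
 * \brief Sample kernel and host driver: copies a feature map and a weight tensor to the device, rearranges them
 *        with LoadData, multiplies them with Mmad, and writes the result to ./output/output.bin.
 */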

#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"

/*!
 * \brief Single-pass kernel that stages a feature map and a weight tensor through the A1/B1 and A2/B2 queues,
 *        multiplies them with Mmad into the CO1 queue, and copies the result through the CO2 queue to global
 *        memory.
 *
 * \tparam DST_T    element type of the output in global memory and in the CO2 queue
 * \tparam FM_T     element type of the feature map
 * \tparam WEIGHT_T element type of the weights
 * \tparam DSTCO1_T element type of the Mmad result held in the CO1 queue
 *
 * Every shape parameter (C1, H, W, Kh, Kw, Cout, C0, padding, stride, dilation) is a member with a fixed default
 * value; the constructor derives all buffer sizes from them.
 */
template <typename DST_T, typename FM_T, typename WEIGHT_T, typename DSTCO1_T>
class LoadDataKernel {
public:
    /*! \brief Derives output extents, buffer element counts, matmul dimensions and repeat counts. */
    __aicore__ inline LoadDataKernel()
    {
        // Granularity used for all round-up computations below.
        constexpr int32_t fractalEdge = 16;

        // Output height/width from the padded input extent and the dilated kernel span.
        const int32_t paddedH = H + padTop + padBottom;
        const int32_t paddedW = W + padLeft + padRight;
        const int32_t kernelSpanH = dilationH * (Kh - 1) + 1;
        const int32_t kernelSpanW = dilationW * (Kw - 1) + 1;
        ho = (paddedH - kernelSpanH) / strideH + 1;
        wo = (paddedW - kernelSpanW) / strideW + 1;
        howo = ho * wo;
        howoRound = ((howo + fractalEdge - 1) / fractalEdge) * fractalEdge;
        coutBlocks = (Cout + fractalEdge - 1) / fractalEdge;

        // Length of the reduction axis shared by both matmul operands.
        const int32_t reduceLen = C1 * Kh * Kw * C0;
        m = howo;
        k = reduceLen;
        n = Cout;

        // Element counts of the staged tensors.
        featureMapA1Size = C1 * H * W * C0;      // shape: [C1, H, W, C0]
        weightA1Size = C1 * Kh * Kw * Cout * C0; // shape: [C1, Kh, Kw, Cout, C0]
        featureMapA2Size = howoRound * reduceLen;
        weightB2Size = reduceLen * coutBlocks * fractalEdge;
        dstSize = coutBlocks * howo * fractalEdge; // shape: [coutBlocks, howo, 16]
        dstCO1Size = coutBlocks * howoRound * fractalEdge;

        // Repeat counts: one repeat per group of (16 * C0) elements.
        const int32_t elemsPerRepeat = fractalEdge * C0;
        fmRepeat = featureMapA2Size / elemsPerRepeat;
        weRepeat = weightB2Size / elemsPerRepeat;
    }

    /*!
     * \brief Binds the global-memory tensors and reserves one buffer per queue.
     * \param fmGm  global address of the feature map, read as FM_T
     * \param weGm  global address of the weights, read as WEIGHT_T
     * \param dstGm global address of the output, written as DST_T
     */
    __aicore__ inline void Init(__gm__ uint8_t* fmGm, __gm__ uint8_t* weGm, __gm__ uint8_t* dstGm)
    {
        fmGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ FM_T*>(fmGm));
        weGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ WEIGHT_T*>(weGm));
        dstGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ DST_T*>(dstGm));

        // Buffer lengths are given in bytes; the InitBuffer order is kept as-is on purpose.
        pipe.InitBuffer(inQueueFmA1, 1, featureMapA1Size * sizeof(FM_T));
        pipe.InitBuffer(inQueueFmA2, 1, featureMapA2Size * sizeof(FM_T));
        pipe.InitBuffer(inQueueWeB1, 1, weightA1Size * sizeof(WEIGHT_T));
        pipe.InitBuffer(inQueueWeB2, 1, weightB2Size * sizeof(WEIGHT_T));
        pipe.InitBuffer(outQueueCO1, 1, dstCO1Size * sizeof(DSTCO1_T));
        pipe.InitBuffer(outQueueUB, 1, dstSize * sizeof(DST_T));
    }

    /*! \brief Runs the five stages once, in order: CopyIn, Split, Compute, CopyUB, CopyOut. */
    __aicore__ inline void Process()
    {
        CopyIn();
        Split();
        Compute();
        CopyUB();
        CopyOut();
    }

private:
    /*! \brief Copies the feature map and the weights from global memory into the A1 and B1 queues. */
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<FM_T> fmA1 = inQueueFmA1.AllocTensor<FM_T>();
        AscendC::LocalTensor<WEIGHT_T> weB1 = inQueueWeB1.AllocTensor<WEIGHT_T>();

        // Copy lengths are the byte sizes divided by 32, i.e. expressed in 32-byte units.
        const uint16_t fmBlockLen = static_cast<uint16_t>(featureMapA1Size * sizeof(FM_T) / 32);
        const uint16_t weBlockLen = static_cast<uint16_t>(weightA1Size * sizeof(WEIGHT_T) / 32);
        AscendC::DataCopy(fmA1, fmGlobal, {1, fmBlockLen, 0, 0});
        AscendC::DataCopy(weB1, weGlobal, {1, weBlockLen, 0, 0});

        inQueueFmA1.EnQue(fmA1);
        inQueueWeB1.EnQue(weB1);
    }

    /*!
     * \brief Moves the feature map A1 -> A2 with LoadData3DParamsV1 and the weights B1 -> B2 with
     *        LoadData2DParams, then releases the A1/B1 tensors.
     */
    __aicore__ inline void Split()
    {
        AscendC::LocalTensor<FM_T> fmA1 = inQueueFmA1.DeQue<FM_T>();
        AscendC::LocalTensor<WEIGHT_T> weB1 = inQueueWeB1.DeQue<WEIGHT_T>();
        AscendC::LocalTensor<FM_T> fmA2 = inQueueFmA2.AllocTensor<FM_T>();
        AscendC::LocalTensor<WEIGHT_T> weB2 = inQueueWeB2.AllocTensor<WEIGHT_T>();

        // Feature map: padding is passed as {left, right, top, bottom}; the pad value is zero.
        uint8_t padList[4] = {padLeft, padRight, padTop, padBottom};
        AscendC::LoadData3DParamsV1<FM_T> fmLoadCfg(padList, H, W, 0, 0, 0, -1, -1, strideW, strideH, Kw, Kh,
                                                    dilationW, dilationH, 1, 0, fmRepeat, 0,
                                                    static_cast<FM_T>(0));
        AscendC::LoadData(fmA2, fmA1, fmLoadCfg);

        // Weights: weRepeat repeats, no transpose.
        AscendC::LoadData2DParams weLoadCfg(0, weRepeat, 1, 0, 0, false, 0);
        AscendC::LoadData(weB2, weB1, weLoadCfg);

        inQueueFmA2.EnQue<FM_T>(fmA2);
        inQueueWeB2.EnQue<WEIGHT_T>(weB2);
        inQueueFmA1.FreeTensor(fmA1);
        inQueueWeB1.FreeTensor(weB1);
    }

    /*! \brief Multiplies the A2 feature map by the B2 weights with Mmad (dimensions m, n, k) into CO1. */
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<FM_T> fmA2 = inQueueFmA2.DeQue<FM_T>();
        AscendC::LocalTensor<WEIGHT_T> weB2 = inQueueWeB2.DeQue<WEIGHT_T>();
        AscendC::LocalTensor<DSTCO1_T> accCO1 = outQueueCO1.AllocTensor<DSTCO1_T>();

        AscendC::Mmad(accCO1, fmA2, weB2, {m, n, k, 0, false, true});

        outQueueCO1.EnQue<DSTCO1_T>(accCO1);
        inQueueFmA2.FreeTensor(fmA2);
        inQueueWeB2.FreeTensor(weB2);
    }

    /*! \brief Copies the Mmad result from the CO1 queue to the CO2 queue using BLOCK_MODE_MATRIX. */
    __aicore__ inline void CopyUB()
    {
        AscendC::LocalTensor<DSTCO1_T> accCO1 = outQueueCO1.DeQue<DSTCO1_T>();
        AscendC::LocalTensor<DST_T> outUB = outQueueUB.AllocTensor<DST_T>();

        // One burst whose length is the m * n result size in bytes divided by 1024.
        AscendC::DataCopyParams burstCfg;
        burstCfg.blockCount = 1;
        burstCfg.blockLen = m * n * sizeof(DSTCO1_T) / 1024;
        AscendC::DataCopyEnhancedParams modeCfg;
        modeCfg.blockMode = AscendC::BlockMode::BLOCK_MODE_MATRIX;
        AscendC::DataCopy(outUB, accCO1, burstCfg, modeCfg);

        outQueueUB.EnQue<DST_T>(outUB);
        outQueueCO1.FreeTensor(accCO1);
    }

    /*! \brief Writes the m * n result elements from the CO2 queue to global memory. */
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<DST_T> outUB = outQueueUB.DeQue<DST_T>();
        AscendC::DataCopy(dstGlobal, outUB, m * n);
        outQueueUB.FreeTensor(outUB);
    }

private:
    AscendC::TPipe pipe;
    // feature map queues (A1 -> A2)
    AscendC::TQue<AscendC::TPosition::A1, 1> inQueueFmA1;
    AscendC::TQue<AscendC::TPosition::A2, 1> inQueueFmA2;
    // weight queues (B1 -> B2)
    AscendC::TQue<AscendC::TPosition::B1, 1> inQueueWeB1;
    AscendC::TQue<AscendC::TPosition::B2, 1> inQueueWeB2;
    // result queues (CO1 -> CO2)
    AscendC::TQue<AscendC::TPosition::CO1, 1> outQueueCO1;
    AscendC::TQue<AscendC::TPosition::CO2, 1> outQueueUB;
    // global-memory views bound in Init()
    AscendC::GlobalTensor<FM_T> fmGlobal;
    AscendC::GlobalTensor<WEIGHT_T> weGlobal;
    AscendC::GlobalTensor<DST_T> dstGlobal;

    // Fixed problem shape: feature map [C1, H, W, C0], weights [C1, Kh, Kw, Cout, C0].
    uint16_t C1 = 2;
    uint16_t H = 4, W = 4;
    uint8_t Kh = 2, Kw = 2;
    uint16_t Cout = 16;
    uint16_t C0 = 16;
    uint8_t dilationH = 2, dilationW = 2;
    uint8_t padTop = 1, padBottom = 1, padLeft = 1, padRight = 1;
    uint8_t strideH = 1, strideW = 1;
    // Values derived in the constructor.
    uint16_t coutBlocks, ho, wo, howo, howoRound;
    uint32_t featureMapA1Size, weightA1Size, featureMapA2Size, weightB2Size, dstSize, dstCO1Size;
    uint16_t m, k, n;
    uint8_t fmRepeat, weRepeat;
};

/*!
 * \brief Kernel entry point: builds a LoadDataKernel, binds the three global buffers and runs it once.
 * \param fmGm  global address of the feature map
 * \param weGm  global address of the weights
 * \param dstGm global address of the output
 */
extern "C" __global__ __aicore__ void load_data_custom(__gm__ uint8_t* fmGm, __gm__ uint8_t* weGm,
                                                       __gm__ uint8_t* dstGm)
{
    // Output, feature map and weights are half; the CO1 result type is float.
    LoadDataKernel<half, half, half, float> kernel;
    kernel.Init(fmGm, weGm, dstGm);
    kernel.Process();
}

/*!
 * \brief Host driver: reads the two input files, copies them to the device, launches load_data_custom on one
 *        block, copies the result back and writes it to ./output/output.bin.
 *
 * Every ACL runtime call is checked. On the first failure the remaining steps are skipped, whatever was
 * acquired so far is released, and a non-zero exit code is returned.
 *
 * \param argc unused
 * \param argv unused
 * \return 0 if all checked ACL calls succeeded, 1 otherwise
 */
int32_t main(int32_t argc, char* argv[])
{
    // Kept non-const: ReadFile is declared in data_utils.h (not visible here) and may take the size by reference.
    size_t aFileSize = (2 * 4 * 4 * 16) * sizeof(int16_t);      // C1 H W C0
    size_t bFileSize = (2 * 2 * 2 * 16 * 16) * sizeof(int16_t); // C1 Kh Kw Cout C0
    size_t cFileSize = (1 * 4 * 4 * 16) * sizeof(int16_t);      // Cout/16 ho*wo 16
    uint32_t numBlocks = 1;

    // ACL runtime calls report success as status 0 (the value of ACL_SUCCESS).
    constexpr int32_t aclOk = 0;

    if (aclInit(nullptr) != aclOk) {
        return 1;
    }
    int32_t deviceId = 0;
    if (aclrtSetDevice(deviceId) != aclOk) {
        aclFinalize();
        return 1;
    }

    // Everything starts as nullptr so the cleanup below can tell what was actually acquired.
    aclrtStream stream = nullptr;
    uint8_t* aHost = nullptr;
    uint8_t* aDevice = nullptr;
    uint8_t* bHost = nullptr;
    uint8_t* bDevice = nullptr;
    uint8_t* cHost = nullptr;
    uint8_t* cDevice = nullptr;

    int32_t exitCode = 1;
    do {
        if (aclrtCreateStream(&stream) != aclOk) {
            break;
        }

        // Feature map: host buffer -> file contents -> device buffer.
        if (aclrtMallocHost(reinterpret_cast<void**>(&aHost), aFileSize) != aclOk) {
            break;
        }
        if (aclrtMalloc(reinterpret_cast<void**>(&aDevice), aFileSize, ACL_MEM_MALLOC_HUGE_FIRST) != aclOk) {
            break;
        }
        ReadFile("./input/x1_gm.bin", aFileSize, aHost, aFileSize);
        if (aclrtMemcpy(aDevice, aFileSize, aHost, aFileSize, ACL_MEMCPY_HOST_TO_DEVICE) != aclOk) {
            break;
        }

        // Weights: host buffer -> file contents -> device buffer.
        if (aclrtMallocHost(reinterpret_cast<void**>(&bHost), bFileSize) != aclOk) {
            break;
        }
        if (aclrtMalloc(reinterpret_cast<void**>(&bDevice), bFileSize, ACL_MEM_MALLOC_HUGE_FIRST) != aclOk) {
            break;
        }
        ReadFile("./input/x2_gm.bin", bFileSize, bHost, bFileSize);
        if (aclrtMemcpy(bDevice, bFileSize, bHost, bFileSize, ACL_MEMCPY_HOST_TO_DEVICE) != aclOk) {
            break;
        }

        // Output buffers.
        if (aclrtMallocHost(reinterpret_cast<void**>(&cHost), cFileSize) != aclOk) {
            break;
        }
        if (aclrtMalloc(reinterpret_cast<void**>(&cDevice), cFileSize, ACL_MEM_MALLOC_HUGE_FIRST) != aclOk) {
            break;
        }

        load_data_custom<<<numBlocks, nullptr, stream>>>(aDevice, bDevice, cDevice);
        if (aclrtSynchronizeStream(stream) != aclOk) {
            break;
        }

        if (aclrtMemcpy(cHost, cFileSize, cDevice, cFileSize, ACL_MEMCPY_DEVICE_TO_HOST) != aclOk) {
            break;
        }
        WriteFile("./output/output.bin", cHost, cFileSize);
        exitCode = 0;
    } while (false);

    // Release in the original order, skipping anything that was never allocated.
    if (aDevice != nullptr) {
        aclrtFree(aDevice);
    }
    if (aHost != nullptr) {
        aclrtFreeHost(aHost);
    }
    if (bDevice != nullptr) {
        aclrtFree(bDevice);
    }
    if (bHost != nullptr) {
        aclrtFreeHost(bHost);
    }
    if (cDevice != nullptr) {
        aclrtFree(cDevice);
    }
    if (cHost != nullptr) {
        aclrtFreeHost(cHost);
    }

    if (stream != nullptr) {
        aclrtDestroyStream(stream);
    }
    aclrtResetDevice(deviceId);
    aclFinalize();
    return exitCode;
}