/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/


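/* !
 * \file batch_matmul.asc
 * \brief Batch matmul sample: host-side tiling generation, an Ascend C kernel that calls
 *        Matmul::IterateBatch over BSNGD-layout A/B/C tensors, and a host main that launches it via ACL.
 */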

#include "data_utils.h"
#include "kernel_tiling/kernel_tiling.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"
#include "acl/acl.h"
#include "kernel_operator.h"
#include "lib/matmul_intf.h"

// Single-matmul problem shape (C[M, N] = A[M, K] * B^T) and the fixed base-block split.
constexpr uint32_t M{32};
constexpr uint32_t N{256};
constexpr uint32_t K{64};
constexpr int32_t baseM{32};
constexpr int32_t baseN{32};

// BSNGD layout extents (B, S, G, D; N is passed as 1 by the tiling code) for A, B and C.
constexpr int32_t A_BNUM{2};
constexpr int32_t A_SNUM{32};
constexpr int32_t A_GNUM{3};
constexpr int32_t A_DNUM{64};
constexpr int32_t B_BNUM{2};
constexpr int32_t B_SNUM{256};
constexpr int32_t B_GNUM{3};
constexpr int32_t B_DNUM{64};
constexpr int32_t C_BNUM{2};
constexpr int32_t C_SNUM{32};
constexpr int32_t C_GNUM{3};
constexpr int32_t C_DNUM{256};

// Number of batches handled by one IterateBatch call (fed to SetBatchNum on the host).
constexpr int32_t BATCH_NUM{3};
// Number of block indices the kernel splits the outer batch axis across.
constexpr int USED_CORE_NUM{2};
// Buffer sizes (bytes) handed to the matmul object as its shared L1 / L0C budget.
constexpr int32_t FULL_L1_SIZE{512 * 1024};
constexpr int32_t FULL_L0C_SIZE{128 * 1024};


/**
 * Configures the host matmul tiling API for this BSNGD batch-matmul sample and extracts the tiling.
 *
 * A and B are fp16 ND in GM (B marked transposed); C and bias are fp32 ND in GM. Shape, base-block
 * split, layouts and batch count all come from the file-level constants.
 *
 * @param tiling     receives the computed tiling when GetTiling succeeds
 * @param cubeTiling tiling API object to configure; must not be null
 * @param isBias     forwarded to EnableBias
 * @return false if cubeTiling is null or GetTiling returns -1, true otherwise
 */
bool ComputeTiling(optiling::TCubeTiling& tiling, matmul_tiling::MultiCoreMatmulTiling* cubeTiling, bool isBias)
{
    if (cubeTiling == nullptr) {
        return false;
    }
    // Derive the shape from the file-level constants rather than re-declaring shadowing locals.
    const int32_t shapeM = static_cast<int32_t>(M);
    const int32_t shapeN = static_cast<int32_t>(N);
    const int32_t shapeK = static_cast<int32_t>(K);

    cubeTiling->SetDim(1);
    cubeTiling->SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
    cubeTiling->SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16, true);
    cubeTiling->SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
    cubeTiling->SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
    cubeTiling->SetShape(shapeM, shapeN, shapeK);
    cubeTiling->SetOrgShape(shapeM, shapeN, shapeK);
    cubeTiling->SetFixSplit(baseM, baseN, -1);
    cubeTiling->EnableBias(isBias);

    // Layout arguments are (B, S, N, G, D); N is fixed to 1 in this sample.
    cubeTiling->SetALayout(A_BNUM, A_SNUM, 1, A_GNUM, A_DNUM);
    cubeTiling->SetBLayout(B_BNUM, B_SNUM, 1, B_GNUM, B_DNUM);
    cubeTiling->SetCLayout(C_BNUM, C_SNUM, 1, C_GNUM, C_DNUM);
    cubeTiling->SetBatchNum(BATCH_NUM);
    // Buffer space is set once, after all shape/layout configuration.
    cubeTiling->SetBufferSpace(-1, -1, -1);
    return cubeTiling->GetTiling(tiling) != -1;
}

/**
 * Serializes a host tiling object into a freshly malloc'd byte buffer.
 *
 * @param tilingData tiling object to serialize (may be null)
 * @return a buffer of GetDataSize() bytes owned by the caller (release with free()), or nullptr when
 *         tilingData is null, reports a zero size, or the allocation fails
 */
uint8_t *GetTilingBuf(optiling::TCubeTiling *tilingData)
{
    if (tilingData == nullptr) {
        return nullptr;
    }
    const uint32_t dataSize = tilingData->GetDataSize();
    if (dataSize == 0) {
        return nullptr;
    }
    auto *buffer = static_cast<uint8_t *>(malloc(dataSize));
    if (buffer != nullptr) {
        tilingData->SaveToBuffer(buffer, dataSize);
    }
    return buffer;
}

/**
 * Computes the sample's matmul tiling (without bias) and returns it serialized.
 *
 * @return a malloc'd buffer produced by GetTilingBuf (caller must free()), or nullptr when tiling
 *         computation or serialization fails. On tiling failure no buffer is returned, so an
 *         invalid tiling is never handed to the kernel.
 */
uint8_t *GenerateTiling()
{
    optiling::TCubeTiling tilingData;
    matmul_tiling::MultiCoreMatmulTiling tilingApi;
    if (!ComputeTiling(tilingData, &tilingApi, false)) {
        std::cout << "gen tiling failed" << std::endl;
        return nullptr;
    }
    return GetTilingBuf(&tilingData);
}

template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
class BatchMatmulKernel {
public:
    __aicore__ inline BatchMatmulKernel() {}

    // Binds the GM tensors, stores a copy of the tiling and applies this block's batch offsets.
    __aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c, GM_ADDR workspace,
                                const TCubeTiling& tiling);

    // Issues the IterateBatch loop; hasBias controls whether SetBias is called each iteration.
    template <bool hasBias = false>
    __aicore__ inline void Process(AscendC::TPipe* pipe, int32_t batchA, int32_t batchB);

    // Public because the kernel entry point registers it directly via REGIST_MATMUL_OBJ.
    AscendC::Matmul<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE> matmulObj;

private:
    // Computes per-block element offsets into A, B, C and bias for BSNGD layouts.
    __aicore__ inline void CalcOffset(int32_t blockIdx, const TCubeTiling& tiling, int32_t& offsetA,
                                      int32_t& offsetB, int32_t& offsetC, int32_t& offsetBias);

    // Element types carried by the MatmulType descriptors.
    using aType = typename A_TYPE::T;
    using bType = typename B_TYPE::T;
    using cType = typename C_TYPE::T;
    using biasType = typename BIAS_TYPE::T;

    AscendC::GlobalTensor<aType> aGlobal;
    AscendC::GlobalTensor<bType> bGlobal;
    AscendC::GlobalTensor<cType> cGlobal;
    AscendC::GlobalTensor<biasType> biasGlobal;
    TCubeTiling tiling;
};

/**
 * Binds the kernel's GM tensors and advances them to this block's slice of the batch axis.
 *
 * Buffer extents are derived from the layout fields of the tiling. They are expressed as element
 * counts, because GlobalTensor::SetGlobalBuffer takes a number of elements rather than bytes. They
 * are computed in 64 bits so large layouts cannot overflow a 32-bit int.
 *
 * @param a, b, c   GM base addresses of the A, B and C tensors
 * @param bias      GM base address of the bias (may be nullptr when bias is unused)
 * @param workspace not referenced by this method
 * @param tiling    tiling copied into the member used by Process()
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
__aicore__ inline void BatchMatmulKernel<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias,
                                         GM_ADDR c, GM_ADDR workspace, const TCubeTiling& tiling)
{
    this->tiling = tiling;
    const uint64_t sizeA = static_cast<uint64_t>(tiling.ALayoutInfoB) * tiling.ALayoutInfoS * tiling.ALayoutInfoN *
                           tiling.ALayoutInfoG * tiling.ALayoutInfoD;
    const uint64_t sizeB = static_cast<uint64_t>(tiling.BLayoutInfoB) * tiling.BLayoutInfoS * tiling.BLayoutInfoN *
                           tiling.BLayoutInfoG * tiling.BLayoutInfoD;
    const uint64_t sizeC = static_cast<uint64_t>(tiling.CLayoutInfoB) * tiling.CLayoutInfoS1 * tiling.CLayoutInfoN *
                           tiling.CLayoutInfoG * tiling.CLayoutInfoS2;
    // Bias holds one S2-length row per (B, N, G) batch of C.
    const uint64_t sizeBias = static_cast<uint64_t>(tiling.CLayoutInfoB) * tiling.CLayoutInfoN *
                              tiling.CLayoutInfoG * tiling.CLayoutInfoS2;

    aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ aType*>(a), sizeA);
    bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ bType*>(b), sizeB);
    cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ cType*>(c), sizeC);
    biasGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ biasType*>(bias), sizeBias);

    int32_t offsetA = 0;
    int32_t offsetB = 0;
    int32_t offsetC = 0;
    int32_t offsetBias = 0;
    CalcOffset(AscendC::GetBlockIdx(), tiling, offsetA, offsetB, offsetC, offsetBias);
    aGlobal = aGlobal[offsetA];
    bGlobal = bGlobal[offsetB];
    cGlobal = cGlobal[offsetC];
    biasGlobal = biasGlobal[offsetBias];
}

/**
 * Runs the batched matmul over this block's share of the batch axis.
 *
 * The loop count is (ALayoutInfoB / USED_CORE_NUM) * ALayoutInfoN * max(ALayoutInfoG, BLayoutInfoG) / BatchNum.
 * Each iteration points A/B (and optionally bias) at the next group of batches, then calls IterateBatch
 * into the matching C position, followed by a PIPE_FIX barrier.
 *
 * For each tensor, the per-iteration step depends on whether BatchNum equals that tensor's N*G:
 * if it does, the step includes the S (or S1) extent; otherwise only D (or S2) times the batch count.
 *
 * @param pipe   not referenced here
 * @param batchA number of A batches per IterateBatch call
 * @param batchB number of B batches per IterateBatch call
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
template <bool hasBias>
__aicore__ inline void BatchMatmulKernel<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::Process(AscendC::TPipe* pipe, int32_t batchA, int32_t batchB)
{
    const int batchC = (batchA > batchB) ? batchA : batchB;
    const int maxG = (tiling.ALayoutInfoG > tiling.BLayoutInfoG) ? tiling.ALayoutInfoG : tiling.BLayoutInfoG;
    // Multi-core split is taken from the outer batch axis.
    const int loopCount = (tiling.ALayoutInfoB / USED_CORE_NUM) * tiling.ALayoutInfoN * maxG / tiling.BatchNum;

    // These depend only on the tiling, so they are evaluated once instead of per iteration.
    const bool aSpansNG = (tiling.BatchNum == tiling.ALayoutInfoN * tiling.ALayoutInfoG);
    const bool bSpansNG = (tiling.BatchNum == tiling.BLayoutInfoN * tiling.BLayoutInfoG);
    const bool cSpansNG = (tiling.BatchNum == tiling.CLayoutInfoN * tiling.CLayoutInfoG);

    const int strideA = aSpansNG ? tiling.ALayoutInfoD * tiling.ALayoutInfoS * batchA
                                 : tiling.ALayoutInfoD * batchA;
    const int strideB = bSpansNG ? tiling.BLayoutInfoD * tiling.BLayoutInfoS * batchB
                                 : tiling.BLayoutInfoD * batchB;
    const int strideC = cSpansNG ? tiling.CLayoutInfoS2 * tiling.CLayoutInfoS1
                                 : tiling.CLayoutInfoS2;

    for (int iter = 0; iter < loopCount; ++iter) {
        matmulObj.SetTensorA(aGlobal[iter * strideA], false);
        matmulObj.SetTensorB(bGlobal[iter * strideB], true); // B is supplied transposed

        const int firstBatchC = iter * batchC;
        if constexpr (hasBias) {
            matmulObj.SetBias(biasGlobal[firstBatchC * tiling.CLayoutInfoS2]);
        }

        matmulObj.IterateBatch(cGlobal[firstBatchC * strideC], batchA, batchB, false);
        AscendC::PipeBarrier<PIPE_FIX>();
    }
}

/**
 * Computes this block's element offsets into A, B, C and bias.
 *
 * For each tensor whose MatmulType layout is BSNGD, the offset is blockIdx * singleCoreBatch times
 * one full outer-batch slice of that tensor. For other layouts, the corresponding output is left untouched.
 * The bias offset is updated together with C.
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
__aicore__ inline void BatchMatmulKernel<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::CalcOffset(int32_t blockIdx, const TCubeTiling& param,
                                        int32_t& offsetA, int32_t& offsetB, int32_t& offsetC, int32_t& offsetBias)
{
    // Number of outer-batch entries assigned to each block.
    constexpr int singleCoreBatch = 1;
    const int firstBatch = blockIdx * singleCoreBatch;

    if constexpr (A_TYPE::layout == LayoutMode::BSNGD) {
        offsetA = firstBatch * param.ALayoutInfoS * param.ALayoutInfoN * param.ALayoutInfoG * param.ALayoutInfoD;
    }
    if constexpr (B_TYPE::layout == LayoutMode::BSNGD) {
        offsetB = firstBatch * param.BLayoutInfoS * param.BLayoutInfoN * param.BLayoutInfoG * param.BLayoutInfoD;
    }
    if constexpr (C_TYPE::layout == LayoutMode::BSNGD) {
        offsetC = firstBatch * param.CLayoutInfoS1 * param.CLayoutInfoN * param.CLayoutInfoG * param.CLayoutInfoS2;
        offsetBias = firstBatch * param.CLayoutInfoN * param.CLayoutInfoG * param.CLayoutInfoS2;
    }
}

/**
 * Copies a TCubeTiling struct from GM into a local struct, one 32-bit word at a time.
 *
 * @param tiling   destination struct
 * @param tilingGM GM address of the serialized tiling (at least sizeof(TCubeTiling) bytes)
 */
__aicore__ inline void CopyTiling(TCubeTiling* tiling, GM_ADDR tilingGM)
{
    // A trailing partial word would otherwise be silently skipped by the word-wise copy.
    static_assert(sizeof(TCubeTiling) % sizeof(uint32_t) == 0, "TCubeTiling must be a whole number of 32-bit words");
    constexpr uint32_t wordCount = sizeof(TCubeTiling) / sizeof(uint32_t);

    uint32_t* dst = reinterpret_cast<uint32_t*>(tiling);
    auto src = reinterpret_cast<__gm__ uint32_t*>(tilingGM);
    // Unsigned index avoids the signed/unsigned comparison against the size_t-derived bound.
    for (uint32_t i = 0; i < wordCount; ++i) {
        dst[i] = src[i];
    }
}

/**
 * Kernel entry point: batched fp16 x fp16^T -> fp32 matmul over BSNGD-layout tensors, without bias.
 *
 * @param a, b, c   GM addresses of A, B (transposed) and C
 * @param workspace KFC workspace forwarded to Init
 * @param tilingGm  GM address of the host-computed TCubeTiling
 */
extern "C" __global__ __aicore__ void batch_matmul_custom(GM_ADDR a, GM_ADDR b, GM_ADDR c, __kfc_workspace__ GM_ADDR workspace,
                                                        GM_ADDR tilingGm)
{
    // Bring the host-computed tiling into a local struct.
    TCubeTiling tiling;
    CopyTiling(&tiling, tilingGm);

    // Matmul type descriptors: A/B are fp16 ND in BSNGD layout (B transposed), C is fp32 BSNGD, bias is fp32.
    using A_TYPE = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half, false, LayoutMode::BSNGD>;
    using B_TYPE = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half, true, LayoutMode::BSNGD>;
    using C_TYPE = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float, false, LayoutMode::BSNGD>;
    using BIAS_TYPE = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float>;
    BatchMatmulKernel<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE> batchMatmulKernel;
    AscendC::TPipe pipe;

    tiling.shareMode = 0;                // 0, share mode
    tiling.shareL1Size = FULL_L1_SIZE;   // full L1
    tiling.shareL0CSize = FULL_L0C_SIZE; // full L0C
    tiling.shareUbSize = 0;              // no UB
    REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), batchMatmulKernel.matmulObj, &tiling);

    // No bias tensor is supplied, matching Process<false> below.
    batchMatmulKernel.Init(a, b, nullptr, c, workspace, tiling);
    // Batches per IterateBatch call: the same BATCH_NUM the host tiling passes to SetBatchNum.
    batchMatmulKernel.Process<false>(&pipe, BATCH_NUM, BATCH_NUM);
}

/**
 * Host-side launcher for batch_matmul_custom.
 *
 * Launches numBlocks blocks on the given stream, with nullptr as the second launch parameter,
 * and forwards all device addresses unchanged.
 */
void batch_matmul_custom_do(uint32_t numBlocks, void* stream, GM_ADDR a, GM_ADDR b,
                         GM_ADDR c, GM_ADDR workspace, GM_ADDR tilingGm)
{
    batch_matmul_custom<<<numBlocks, nullptr, stream>>>(
        a, b, c,
        workspace,
        tilingGm);
}

int32_t main(int32_t argc, char* argv[])
{
    size_t aFileSize = 192 * 64 * sizeof(uint16_t);   // uint16_t represent half
    size_t bFileSize = 64 * 1536 * sizeof(uint16_t);  // uint16_t represent half
    size_t cFileSize = 192 * 256 * sizeof(float);
    uint32_t workspaceSize = 16 * 1024 * 1024;
    size_t tilingFileSize = sizeof(TCubeTiling);
    uint32_t numBlocks = 1;
    int64_t wrongNum = -1;
    uint8_t *tilingBuf = GenerateTiling();

    aclInit(nullptr);
    aclrtContext context;
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtCreateContext(&context, deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    uint8_t *aHost;
    uint8_t *aDevice;
    aclrtMallocHost((void **)(&aHost), aFileSize);
    
        aclrtMalloc((void **)&aDevice, aFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
    ReadFile("./input/x1_gm.bin", aFileSize, aHost, aFileSize);
    aclrtMemcpy(aDevice, aFileSize, aHost, aFileSize,
                          ACL_MEMCPY_HOST_TO_DEVICE);

    uint8_t *bHost;
    uint8_t *bDevice;
    aclrtMallocHost((void **)(&bHost), bFileSize);
    
        aclrtMalloc((void **)&bDevice, bFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
    ReadFile("./input/x2_gm.bin", bFileSize, bHost, bFileSize);
    aclrtMemcpy(bDevice, bFileSize, bHost, bFileSize,
                          ACL_MEMCPY_HOST_TO_DEVICE);

    uint8_t *workspaceHost;
    uint8_t *workspaceDevice;
    aclrtMallocHost((void **)(&workspaceHost), workspaceSize);
    aclrtMalloc((void **)&workspaceDevice, workspaceSize,
                          ACL_MEM_MALLOC_HUGE_FIRST);

    uint8_t *tilingHost;
    uint8_t *tilingDevice;
    aclrtMallocHost((void **)(&tilingHost), tilingFileSize);
    aclrtMalloc((void **)&tilingDevice, tilingFileSize,
                          ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMemcpy(tilingHost, tilingFileSize, tilingBuf,
                          tilingFileSize, ACL_MEMCPY_HOST_TO_HOST);
    aclrtMemcpy(tilingDevice, tilingFileSize, tilingHost,
                          tilingFileSize, ACL_MEMCPY_HOST_TO_DEVICE);

    uint8_t *cHost;
    uint8_t *cDevice;
    aclrtMallocHost((void **)(&cHost), cFileSize);
    
        aclrtMalloc((void **)&cDevice, cFileSize, ACL_MEM_MALLOC_HUGE_FIRST);

    batch_matmul_custom_do(numBlocks, stream, aDevice, bDevice, cDevice, workspaceDevice, tilingDevice);
    aclrtSynchronizeStream(stream);

    aclrtMemcpy(cHost, cFileSize, cDevice, cFileSize,
                          ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output.bin", cHost, cFileSize);

    aclrtFree(cDevice);
    aclrtFreeHost(cHost);

    aclrtDestroyStream(stream);
    aclrtDestroyContext(context);
    aclrtResetDevice(deviceId);
    aclFinalize();
}