/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* !
 * \file batch_mmad.asc
 * \brief Batched float matrix multiply (Mmad) sample on the Cube unit, with its host-side launcher.
 */

#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"


constexpr uint32_t M = M_SIZE;
constexpr uint32_t N = N_SIZE;
constexpr uint32_t K = K_SIZE;
constexpr uint32_t B = B_SIZE;

/**
 * Batched matrix multiply on the Cube unit: for each of the B batches,
 * C[b] (m x n) = A[b] (m x k) * B[b] (k x n), float inputs and float output.
 *
 * Data path:
 *   GM --DataCopy(ND2NZ)--> L1 (A1/B1) --LoadData--> L0A/L0B (A2/B2) --Mmad--> L0C (CO1)
 *   L0C --Fixpipe(NZ2ND)--> GM, plus a debug copy L0C --Fixpipe--> L1 (C1, half) that is dumped.
 * All B batches are moved into L1 at once, and every batch's result stays in L0C until one
 * batched Fixpipe writes them all out.
 */
class KernelMmad {
public:
    __aicore__ inline KernelMmad()
    {
        // Shape of the left-matrix (A) fractal: 16 rows x one 32-byte C0 block of floats.
        cubeShape[0] = 16;
        cubeShape[1] = 32 / sizeof(float);
        // The right-matrix (B) fractal shape is [32 / sizeof(float), 16].
        // Size of one left/right fractal, in elements.
        cubeSize = 16 * cubeShape[1];

        // Keep the A/B data volumes on L1 and L0 within the L1, L0A, L0B and L0C capacities.
        // These are element counts held in uint32_t: the batched L1 sizes (B * one aligned matrix)
        // can exceed 65535 elements and would wrap if stored in uint16_t.
        aSizeAlignL0 = CeilAlign(m, cubeShape[0]) * CeilAlign(k, cubeShape[1]);
        aSizeAlignL1 = B * aSizeAlignL0;
        bSizeAlignL0 = CeilAlign(k, cubeShape[0]) * CeilAlign(n, cubeShape[1]);
        bSizeAlignL1 = B * bSizeAlignL0;
        // C: m and n are both aligned to 16 regardless of data type or of A/B being transposed.
        cSizeAlignL0 = B * CeilAlign(m, cubeShape[0]) * CeilAlign(n, cubeShape[0]);
    }

    /**
     * Bind the GM tensors and allocate one buffer per queue from pipeIn.
     * @param a      GM address of A: B row-major m x k float matrices, packed back to back.
     * @param b      GM address of B: B row-major k x n float matrices, packed back to back.
     * @param c      GM address of C: receives B row-major m x n float matrices.
     * @param pipeIn pipe used to initialise every queue buffer; must outlive Process().
     */
    __aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR c, AscendC::TPipe* pipeIn)
    {
        // Run on the Cube core only.
        KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIC_ONLY);

        pipe = pipeIn;
        aGM.SetGlobalBuffer((__gm__ float *)a);
        bGM.SetGlobalBuffer((__gm__ float *)b);
        cGM.SetGlobalBuffer((__gm__ float *)c);
        pipe->InitBuffer(inQueueA1, 1, aSizeAlignL1 * sizeof(float));
        pipe->InitBuffer(inQueueA2, 1, aSizeAlignL0 * sizeof(float));
        pipe->InitBuffer(inQueueB1, 1, bSizeAlignL1 * sizeof(float));
        pipe->InitBuffer(inQueueB2, 1, bSizeAlignL0 * sizeof(float));
        pipe->InitBuffer(outQueueCO1, 1, cSizeAlignL0 * sizeof(float));
        pipe->InitBuffer(outQueueC1, 1, cSizeAlignL0 * sizeof(half));
    }

    /** Run the whole batched multiply: load, per-batch split + Mmad, then write out. */
    __aicore__ inline void Process()
    {
        // GM -> L1 with on-the-fly ND2NZ, all batches in one transfer. A and B are unaligned in
        // GM; in L1 the rows are padded to multiples of 16 and the columns to a 32-byte C0 block.
        CopyIn();
        int32_t batchSize = B;
        AscendC::LocalTensor<float> a1Local = inQueueA1.DeQue<float>();
        AscendC::LocalTensor<float> b1Local = inQueueB1.DeQue<float>();
        AscendC::LocalTensor<float> c1Local = outQueueCO1.AllocTensor<float>();
        // One iteration per batch: each batch's A/B pair is its own padded NZ matrix in L1.
        for (int32_t batchIndex = 0; batchIndex < batchSize; batchIndex++) {
            SplitA(a1Local[batchIndex * aSizeAlignL0]);
            SplitBTranspose(b1Local[batchIndex * bSizeAlignL0]);
            Compute(batchIndex, c1Local);
        }
        // L0C -> GM and L0C -> L1 with on-the-fly NZ2ND, all batches at once.
        CopyOut();
        inQueueA1.FreeTensor(a1Local);
        inQueueB1.FreeTensor(b1Local);
    }

private:
    // ceil(numerator / denominator) for positive values.
    __aicore__ inline uint16_t CeilDivision(uint16_t numerator, uint16_t denominator)
    {
        return (numerator + denominator - 1) / denominator;
    }

    // Round numerator up to the next multiple of denominator.
    __aicore__ inline uint16_t CeilAlign(uint16_t numerator, uint16_t denominator)
    {
        return (numerator + denominator - 1) / denominator * denominator;
    }

    // Move every batch of A and B from GM into L1, converting ND to NZ on the way.
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<float> a1Local = inQueueA1.AllocTensor<float>();
        AscendC::LocalTensor<float> b1Local = inQueueB1.AllocTensor<float>();
        // Nd2NzParams fields (ranges from the Nd2NzParams parameter table):
        //   ndNum             number of ND matrices to transfer, [0, 4095].
        //   nValue            rows of each ND matrix, [0, 16384].
        //   dValue            columns of each ND matrix, [0, 65535].
        //   srcNdMatrixStride offset between the starts of adjacent source ND matrices,
        //                     [0, 65535], in elements.
        //   srcDValue         offset between the starts of adjacent rows of one source ND matrix,
        //                     [1, 65535], in elements.
        //   dstNzC0Stride     ND->NZ turns one source row into several destination rows; this is the
        //                     offset between those rows, [1, 16384], in C0_SIZE (32 B) units.
        //   dstNzNStride      offset between adjacent rows of the Z-shaped destination matrix,
        //                     [1, 16384], in C0_SIZE (32 B) units.
        //   dstNzMatrixStride offset between the starts of adjacent destination NZ matrices,
        //                     [1, 65535], in elements.

        // GM -> L1: A (B matrices of m x k).
        AscendC::Nd2NzParams nd2nzA1Params;
        nd2nzA1Params.ndNum = B;
        nd2nzA1Params.nValue = m;
        nd2nzA1Params.dValue = k;
        nd2nzA1Params.srcNdMatrixStride = m * k;
        nd2nzA1Params.srcDValue = k;
        nd2nzA1Params.dstNzC0Stride = CeilAlign(m, cubeShape[0]);
        nd2nzA1Params.dstNzNStride = 1;
        nd2nzA1Params.dstNzMatrixStride = aSizeAlignL0;
        AscendC::DataCopy(a1Local, aGM, nd2nzA1Params);

        // GM -> L1: B (B matrices of k x n).
        AscendC::Nd2NzParams nd2nzB1Params;
        nd2nzB1Params.ndNum = B;
        nd2nzB1Params.nValue = k;
        nd2nzB1Params.dValue = n;
        nd2nzB1Params.srcNdMatrixStride = k * n;
        nd2nzB1Params.srcDValue = n;
        nd2nzB1Params.dstNzC0Stride = CeilAlign(k, cubeShape[0]);
        nd2nzB1Params.dstNzNStride = 1;
        nd2nzB1Params.dstNzMatrixStride = bSizeAlignL0;
        AscendC::DataCopy(b1Local, bGM, nd2nzB1Params);

        inQueueA1.EnQue(a1Local);
        inQueueB1.EnQue(b1Local);
    }

    // L1 -> L0A for one batch of A: reorder the NZ layout into Zz, one 16-row strip at a time.
    __aicore__ inline void SplitA(const AscendC::LocalTensor<float>& a1Tensor)
    {
        AscendC::LocalTensor<float> a2Local = inQueueA2.AllocTensor<float>();
        const uint16_t rowBlocks = CeilDivision(m, cubeShape[0]);
        uint32_t dstOffset = CeilDivision(k, cubeShape[1]) * cubeSize;
        uint32_t srcOffset = cubeSize; // c0_size=32
        // Nz -> Zz
        AscendC::LoadData2DParams loadDataParams;
        loadDataParams.repeatTimes = CeilDivision(k, cubeShape[1]);
        loadDataParams.srcStride = rowBlocks;
        loadDataParams.dstGap = 0;
        loadDataParams.ifTranspose = false;
        for (uint16_t i = 0; i < rowBlocks; ++i) {
            AscendC::LoadData(a2Local[i * dstOffset], a1Tensor[i * srcOffset], loadDataParams);
        }
        inQueueA2.EnQue<float>(a2Local);
    }

    // L1 -> L0B for one batch of B, using the 3D load with transpose enabled.
    __aicore__ inline void SplitBTranspose(const AscendC::LocalTensor<float>& b1Tensor)
    {
        AscendC::LocalTensor<float> b2Local = inQueueB2.AllocTensor<float>();
        AscendC::LoadData3DParamsV2<float> loadDataParams;
        loadDataParams.l1H = CeilAlign(k, cubeShape[0]);
        loadDataParams.l1W = 1;
        loadDataParams.channelSize = CeilAlign(n, cubeShape[1]);
        loadDataParams.kExtension = CeilAlign(n, cubeShape[1]);
        loadDataParams.mExtension = CeilAlign(k, cubeShape[0]);
        loadDataParams.strideW = 1;
        loadDataParams.strideH = 1;
        loadDataParams.filterW = 1;
        loadDataParams.filterH = 1;
        loadDataParams.dilationFilterW = 1;
        loadDataParams.dilationFilterH = 1;
        loadDataParams.filterSizeW = false;
        loadDataParams.filterSizeH = false;
        loadDataParams.enTranspose = true;
        loadDataParams.fMatrixCtrl = false;
        AscendC::LoadData(b2Local, b1Tensor, loadDataParams);

        inQueueB2.EnQue<float>(b2Local);
    }

    // Mmad one batch into its own 16x16-aligned slot of the shared L0C tensor.
    __aicore__ inline void Compute(int32_t batchIndex, AscendC::LocalTensor<float>& c1Local)
    {
        AscendC::LocalTensor<float> a2Local = inQueueA2.DeQue<float>();
        AscendC::LocalTensor<float> b2Local = inQueueB2.DeQue<float>();
        AscendC::MmadParams mmadParams;
        mmadParams.m = m;
        mmadParams.n = n;
        mmadParams.k = k;
        AscendC::Mmad(c1Local[batchIndex * CeilAlign(m, cubeShape[0]) * CeilAlign(n, cubeShape[0])],
            a2Local, b2Local, mmadParams);
        // The L0C tensor holds every batch; hand it to CopyOut only after the last one is written.
        if (batchIndex == static_cast<int32_t>(B) - 1) {
            outQueueCO1.EnQue<float>(c1Local);
        }
        inQueueA2.FreeTensor(a2Local);
        inQueueB2.FreeTensor(b2Local);
    }

    // Write all batches from L0C to GM (NZ2ND), then copy them to L1 as half and dump them.
    __aicore__ inline void CopyOut()
    {
        // L0C -> GM with on-the-fly NZ2ND.
        AscendC::LocalTensor<float> c1Local = outQueueCO1.DeQue<float>();
        AscendC::FixpipeParamsV220 fixpipeParams;
        fixpipeParams.nSize = n;
        fixpipeParams.mSize = m;
        fixpipeParams.srcStride = CeilAlign(m, cubeShape[0]);
        fixpipeParams.dstStride = n;
        fixpipeParams.ndNum = B;
        // Interval between the starts of adjacent NZ matrices, srcNdStride in [1, 512], in 1024 B units.
        fixpipeParams.srcNdStride = (CeilAlign(m, cubeShape[0]) * CeilAlign(n, cubeShape[0]))
                                    / (cubeShape[0] * cubeShape[0]);
        // Offset between the starts of adjacent destination ND matrices, dstNdStride in [1, 65535],
        // in elements.
        fixpipeParams.dstNdStride = m * n;
        AscendC::Fixpipe(cGM, c1Local, fixpipeParams);

        // L0C -> L1 (debug copy).
        // NOTE(review): this reuses the float->float GM parameters for a float->half destination,
        // leaving quantPre at its default; confirm whether
        // fixpipeParams.quantPre = QuantMode_t::F322F16 is required for this conversion.
        AscendC::LocalTensor<half> c1L1Local = outQueueC1.AllocTensor<half>();
        AscendC::Fixpipe(c1L1Local, c1Local, fixpipeParams);
        // Debug: print the matrix that now sits in L1.
        AscendC::printf("C in L1 buffer:\n");
        AscendC::DumpTensor(c1L1Local, 1, 10);

        outQueueCO1.FreeTensor(c1Local);
        outQueueC1.FreeTensor(c1L1Local);
    }

private:
    AscendC::TPipe* pipe = nullptr;
    AscendC::TQue<AscendC::TPosition::A1, 1> inQueueA1;
    AscendC::TQue<AscendC::TPosition::A2, 1> inQueueA2;
    AscendC::TQue<AscendC::TPosition::B1, 1> inQueueB1;
    AscendC::TQue<AscendC::TPosition::B2, 1> inQueueB2;
    AscendC::TQue<AscendC::TPosition::CO1, 1> outQueueCO1;
    AscendC::TQue<AscendC::TPosition::C1, 1> outQueueC1;

    AscendC::GlobalTensor<float> aGM;
    AscendC::GlobalTensor<float> bGM;
    AscendC::GlobalTensor<float> cGM;
    uint16_t m = M, k = K, n = N;
    // Aligned buffer sizes in elements; uint32_t because the batched L1 sizes can exceed 65535.
    uint32_t aSizeAlignL1 = 0, bSizeAlignL1 = 0;
    uint32_t aSizeAlignL0 = 0, bSizeAlignL0 = 0, cSizeAlignL0 = 0;
    uint16_t cubeShape[2] = {0, 0};
    uint16_t cubeSize = 0;
};

/**
 * Kernel entry point: C = A * B for every batch, with A/B/C given as GM addresses.
 * The TPipe is handed to Init, which sets up every queue buffer through it, so it is kept
 * alive on the kernel's stack for the whole run.
 */
extern "C" __global__ __aicore__ void mmad_custom(GM_ADDR a, GM_ADDR b, GM_ADDR c)
{
    AscendC::TPipe tpipe;
    KernelMmad kernel;
    kernel.Init(a, b, c, &tpipe);
    kernel.Process();
}

/**
 * Host driver: read the A/B inputs from ./input, launch mmad_custom on device 0 with a
 * single block, and write C to ./output/output.bin. ACL return codes are not checked.
 */
int32_t main(int32_t argc, char *argv[])
{
    // Byte sizes of the batched operands: A is B x M x K, the second operand B x K x N,
    // C is B x M x N, all float. ReadFile receives the size variable as its second argument
    // (possibly by reference — its declaration is not visible here), so they stay mutable.
    size_t aFileSize = B * M * K * sizeof(float);
    size_t bFileSize = B * K * N * sizeof(float);
    size_t cFileSize = B * M * N * sizeof(float);
    const uint32_t numBlocks = 1;

    aclInit(nullptr);
    const int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    // Allocate the host and device buffers for one input, fill the host copy from `path`,
    // then upload it to the device.
    auto stageInput = [](const char *path, size_t &byteSize, uint8_t *&host, uint8_t *&device) {
        aclrtMallocHost(reinterpret_cast<void **>(&host), byteSize);
        aclrtMalloc(reinterpret_cast<void **>(&device), byteSize, ACL_MEM_MALLOC_HUGE_FIRST);
        ReadFile(path, byteSize, host, byteSize);
        aclrtMemcpy(device, byteSize, host, byteSize, ACL_MEMCPY_HOST_TO_DEVICE);
    };

    uint8_t *aHost = nullptr;
    uint8_t *aDevice = nullptr;
    stageInput("./input/x1_gm.bin", aFileSize, aHost, aDevice);

    uint8_t *bHost = nullptr;
    uint8_t *bDevice = nullptr;
    stageInput("./input/x2_gm.bin", bFileSize, bHost, bDevice);

    // The output only needs buffers; the kernel fills the device copy.
    uint8_t *cHost = nullptr;
    uint8_t *cDevice = nullptr;
    aclrtMallocHost(reinterpret_cast<void **>(&cHost), cFileSize);
    aclrtMalloc(reinterpret_cast<void **>(&cDevice), cFileSize, ACL_MEM_MALLOC_HUGE_FIRST);

    mmad_custom<<<numBlocks, nullptr, stream>>>(aDevice, bDevice, cDevice);
    aclrtSynchronizeStream(stream);

    aclrtMemcpy(cHost, cFileSize, cDevice, cFileSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output.bin", cHost, cFileSize);

    // Release each device buffer followed by its host mirror: A, then B, then C.
    uint8_t *deviceBuffers[] = {aDevice, bDevice, cDevice};
    uint8_t *hostBuffers[] = {aHost, bHost, cHost};
    for (int idx = 0; idx < 3; ++idx) {
        aclrtFree(deviceBuffers[idx]);
        aclrtFreeHost(hostBuffers[idx]);
    }

    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();
    return 0;
}