/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

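/* !
 * \file mmad_unitflag_disable.asc
 * \brief Cube-only Mmad sample that splits K into kRound chunks and accumulates every chunk into one
 *        CO1 buffer before a single Fixpipe write-back; MmadParams.unitFlag is not set (left at its default).
 */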

#include <cstdio>

#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"

// Cube fractal for half data: a 16 x 16 block of elements.
constexpr uint32_t CUBE_BLOCK = 16;
constexpr uint32_t CUBE_BLOCK_SIZE = CUBE_BLOCK * CUBE_BLOCK;

// Problem size: A holds M * K half elements, B holds K * N, C holds M * N floats.
// The K dimension is processed in kRound chunks of K / kRound each.
constexpr uint32_t M = 128;
constexpr uint32_t K = 512;
constexpr uint32_t N = 256;
constexpr uint32_t kRound = 8;
// Each chunk covers exactly K / kRound elements of K; a remainder would be silently dropped.
static_assert(K % kRound == 0, "K must be divisible by kRound");

// Compute() inserts a PIPE_M barrier after each Mmad when
// (m / ALIGN_NUM) * (n / ALIGN_NUM) is smaller than LIMIT_MNSIZE.
constexpr static uint16_t LIMIT_MNSIZE = 10;
constexpr static uint16_t ALIGN_NUM = 16;

class KernelMmad {
public:
    __aicore__ inline KernelMmad()
    {
        aSingleSize = m * k;
        bSingleSize = k * n;
        cSingleSize = m * n;
        aSize = kRound * aSingleSize;
        bSize = kRound * bSingleSize;
        cSize = kRound * cSingleSize;
    }
    __aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR c, AscendC::TPipe* pipeIn)
    {
        // set cube only
        KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIC_ONLY);
        pipe = pipeIn;
        aGM.SetGlobalBuffer((__gm__ half *)a);
        bGM.SetGlobalBuffer((__gm__ half *)b);
        cGM.SetGlobalBuffer((__gm__ float *)c);
        pipe->InitBuffer(inQueueA1, 1, aSingleSize * sizeof(half));
        pipe->InitBuffer(inQueueA2, 1, aSingleSize * sizeof(half));
        pipe->InitBuffer(inQueueB1, 1, bSingleSize * sizeof(half));
        pipe->InitBuffer(inQueueB2, 1, bSingleSize * sizeof(half));
        pipe->InitBuffer(outQueueCO1, 1, cSingleSize * sizeof(float));
    }
    __aicore__ inline void Process()
    {
        uint32_t chunkSize = kRound;
        AscendC::LocalTensor<float> c = outQueueCO1.AllocTensor<float>();
        for (uint32_t kIndex = 0; kIndex < chunkSize; kIndex++) {
            CopyIn(kIndex);
            SplitA(kIndex);
            SplitB(kIndex);
            Compute(kIndex, c);
        }
        CopyOut();
    }

private:
    __aicore__ inline uint32_t CeilCubeBlock(uint32_t len) 
    {
        return (len + CUBE_BLOCK - 1) / CUBE_BLOCK;
    }

    __aicore__ inline void CopyIn(uint32_t kIndex)
    {
        AscendC::LocalTensor<half> a1Local = inQueueA1.AllocTensor<half>();
        AscendC::LocalTensor<half> b1Local = inQueueB1.AllocTensor<half>();
        AscendC::Nd2NzParams nd2nzA1Params;
        // 传输ND矩阵的数目
        nd2nzA1Params.ndNum = 1;
        nd2nzA1Params.nValue = m;
        nd2nzA1Params.dValue = k;
        // 源操作数相邻ND矩阵起始地址间的偏移,单位为元素。
        nd2nzA1Params.srcNdMatrixStride = 0;
        nd2nzA1Params.srcDValue = k;
        // 目的NZ矩阵中,来自源操作数同一行的多行数据相邻行起始地址间的偏移,取值范围:dstNzC0Stride∈[1, 16384],单位:C0_SIZE(32B)。
        nd2nzA1Params.dstNzC0Stride = CeilCubeBlock(m) * CUBE_BLOCK;
        // 目的NZ矩阵中,Z型矩阵相邻行起始地址之间的偏移。单位:C0_SIZE(32B)。
        nd2nzA1Params.dstNzNStride = 1;
        // 目的NZ矩阵中,相邻NZ矩阵起始地址间的偏移,单位为元素。
        nd2nzA1Params.dstNzMatrixStride = 0;
        AscendC::DataCopy(a1Local, aGM[kIndex * aSingleSize], nd2nzA1Params);

        AscendC::Nd2NzParams nd2nzB1Params;
        nd2nzB1Params.ndNum = 1;
        nd2nzB1Params.nValue = k;
        nd2nzB1Params.dValue = n;
        nd2nzB1Params.srcNdMatrixStride = 0;
        nd2nzB1Params.srcDValue = n;
        nd2nzB1Params.dstNzC0Stride = CeilCubeBlock(k) * CUBE_BLOCK;
        nd2nzB1Params.dstNzNStride = 1;
        nd2nzB1Params.dstNzMatrixStride = 0;
        AscendC::DataCopy(b1Local, bGM[kIndex * bSingleSize], nd2nzB1Params);

        inQueueA1.EnQue(a1Local);
        inQueueB1.EnQue(b1Local);
    }

    __aicore__ inline void SplitA(uint32_t kIndex)
    {
        AscendC::LocalTensor<half> a1Local = inQueueA1.DeQue<half>();
        AscendC::LocalTensor<half> a = inQueueA2.AllocTensor<half>();
        uint32_t dstOffset = CeilCubeBlock(k) * CUBE_BLOCK_SIZE;
        uint32_t srcOffset = CUBE_BLOCK_SIZE;
        // Nz -> Zz
        AscendC::LoadData2DParams loadDataParams;
        loadDataParams.repeatTimes = CeilCubeBlock(k);
        loadDataParams.srcStride = CeilCubeBlock(m);
        loadDataParams.dstGap = 0;
        loadDataParams.ifTranspose = false;
        for (int i = 0; i < CeilCubeBlock(m); ++i) {
            AscendC::LoadData(a[i * dstOffset], a1Local[i * srcOffset], loadDataParams);
        }
        inQueueA2.EnQue<half>(a);
        inQueueA1.FreeTensor(a1Local);
    }
    __aicore__ inline void SplitB(uint32_t kIndex)
    {
        AscendC::LocalTensor<half> b1Local = inQueueB1.DeQue<half>();
        AscendC::LocalTensor<half> b = inQueueB2.AllocTensor<half>();

        uint32_t dstOffset = CeilCubeBlock(n) * CUBE_BLOCK_SIZE;
        uint32_t srcOffset = CUBE_BLOCK_SIZE;
        // Nz -> Zn
        AscendC::LoadData2DParams loadDataParams;
        loadDataParams.repeatTimes = CeilCubeBlock(n);
        loadDataParams.srcStride = CeilCubeBlock(k);
        loadDataParams.dstGap = 0;
        loadDataParams.ifTranspose = true;
        for (int i = 0; i < CeilCubeBlock(k); ++i) {
            AscendC::LoadData(b[i * dstOffset], b1Local[i * srcOffset], loadDataParams);
        }
        inQueueB2.EnQue<half>(b);
        inQueueB1.FreeTensor(b1Local);
    }
    __aicore__ inline void Compute(uint32_t kIndex, AscendC::LocalTensor<float>& c)
    {
        AscendC::LocalTensor<half> a = inQueueA2.DeQue<half>();
        AscendC::LocalTensor<half> b = inQueueB2.DeQue<half>();
        AscendC::MmadParams mmadParams;
        mmadParams.m = m;
        mmadParams.n = n;
        mmadParams.k = k;
        if (kIndex == 0) {
            mmadParams.cmatrixInitVal = true;
        } else {
            mmadParams.cmatrixInitVal = false;
        }
        AscendC::Mmad(c, a, b, mmadParams);
        if ((m / ALIGN_NUM) * (n / ALIGN_NUM) < LIMIT_MNSIZE) {
            AscendC::PipeBarrier<PIPE_M>();
        }
        if (kIndex == kRound - 1) {
            outQueueCO1.EnQue<float>(c);
        }
        inQueueA2.FreeTensor(a);
        inQueueB2.FreeTensor(b);
    }
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<float> c = outQueueCO1.DeQue<float>();

        // LOC-->GM,随路NZ2ND
        AscendC::FixpipeParamsV220 fixpipeParams;
        fixpipeParams.nSize = n;
        fixpipeParams.mSize = m;
        fixpipeParams.srcStride = m;
        fixpipeParams.dstStride = n;

        fixpipeParams.ndNum = 1;
        fixpipeParams.srcNdStride = 0;
        fixpipeParams.dstNdStride = 0;
        AscendC::Fixpipe(cGM, c, fixpipeParams);
        outQueueCO1.FreeTensor(c);
    }

private:
    AscendC::TPipe* pipe;
    AscendC::TQue<AscendC::TPosition::A1, 1> inQueueA1;
    AscendC::TQue<AscendC::TPosition::A2, 1> inQueueA2;
    AscendC::TQue<AscendC::TPosition::B1, 1> inQueueB1;
    AscendC::TQue<AscendC::TPosition::B2, 1> inQueueB2;
    AscendC::TQue<AscendC::TPosition::CO1, 1> outQueueCO1;

    AscendC::GlobalTensor<half> aGM;
    AscendC::GlobalTensor<half> bGM;
    AscendC::GlobalTensor<float> cGM;
    // 注意沿着K轴切分
    uint16_t m = M, k = K / kRound, n = N;
    uint16_t aSize, bSize, cSize;
    uint16_t aSingleSize, bSingleSize, cSingleSize;
};

/**
 * Kernel entry point for the chunked Mmad.
 * a and b are global addresses of the half-precision inputs; c is the global address of the float output.
 * The kernel does not partition work by block index.
 */
extern "C" __global__ __aicore__ void chunk_mmad_custom(GM_ADDR a, GM_ADDR b, GM_ADDR c)
{
    AscendC::TPipe tPipe;
    KernelMmad kernel;
    kernel.Init(a, b, c, &tPipe);
    kernel.Process();
}

int32_t main(int32_t argc, char *argv[])
{
    size_t aFileSize = M * K * sizeof(int16_t); 
    size_t bFileSize = K * N * sizeof(int16_t); 
    size_t cFileSize = M * N * sizeof(float);
    uint32_t numBlocks = 1;

    aclInit(nullptr);
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    uint8_t *aHost;
    uint8_t *aDevice;
    aclrtMallocHost((void **)(&aHost), aFileSize);
    aclrtMalloc((void **)&aDevice, aFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
    ReadFile("./input/x1_gm.bin", aFileSize, aHost, aFileSize);
    aclrtMemcpy(aDevice, aFileSize, aHost, aFileSize, ACL_MEMCPY_HOST_TO_DEVICE);

    uint8_t *bHost;
    uint8_t *bDevice;
    aclrtMallocHost((void **)(&bHost), bFileSize);
    aclrtMalloc((void **)&bDevice, bFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
    ReadFile("./input/x2_gm.bin", bFileSize, bHost, bFileSize);
    aclrtMemcpy(bDevice, bFileSize, bHost, bFileSize, ACL_MEMCPY_HOST_TO_DEVICE);

    uint8_t *cHost;
    uint8_t *cDevice;
    aclrtMallocHost((void **)(&cHost), cFileSize);
    aclrtMalloc((void **)&cDevice, cFileSize, ACL_MEM_MALLOC_HUGE_FIRST);

    chunk_mmad_custom<<<numBlocks, nullptr, stream>>>(aDevice, bDevice, cDevice);
    aclrtSynchronizeStream(stream);

    aclrtMemcpy(cHost, cFileSize, cDevice, cFileSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output.bin", cHost, cFileSize);

    aclrtFree(aDevice);
    aclrtFreeHost(aHost);
    aclrtFree(bDevice);
    aclrtFreeHost(bHost);
    aclrtFree(cDevice);
    aclrtFreeHost(cHost);

    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();
    return 0;
}