/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* !
 * \file mmad_s8_f16_f32_with_A_B_transpose_option.asc
 * \brief
 */

#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"

constexpr uint32_t tilingKey = TILING_KEY;
constexpr uint32_t M = M_SIZE;
constexpr uint32_t N = N_SIZE;
constexpr uint32_t K = K_SIZE;

// A矩阵转置,则L1-->L0A时方块a需要转置;B矩阵转置,则L1-->L0B时方块b不需要转置。
template <class T, class U, bool isAtranspose, bool isBtranspose>
class KernelMmad {
public:
    __aicore__ inline KernelMmad()
    {
        // 左矩阵分形的shape
        fractalShape[0] = 16;
        fractalShape[1] = 32 / sizeof(T);
        // 右矩阵的shape:[32 / sizeof(T), 16]
        // 左、右矩阵分形的size,单位是元素数目
        fractalSize = 16 * fractalShape[1];
        // 转置只能以方块形式,因此不同位宽下,方块中包含的分形个数不同
        if constexpr (sizeof(T) == 2) {
            fractalNum = 1;
        } else {
            fractalNum = 2;
        }
        // 对齐后的shape
        // 计算不同场景下,A矩阵、B矩阵各个方向对齐的shape参数
        if constexpr (AscendC::IsSameType<T, int8_t>::value && AscendC::IsSameType<U, int32_t>::value) {
            if constexpr (!isAtranspose) {
                // L1上,A矩阵的对齐
                // GM上A矩阵的shape为[m,k]
                mAlignL1 = CeilAlign(m, fractalShape[0]); // 高度
                kAlignL1 = CeilAlign(k, fractalShape[1]); // 宽度
                aSizeAlignL1 = mAlignL1 * kAlignL1;

                // L0上,A矩阵的对齐
                // 由于L0上a矩阵也是Z排布
                mAlignL0 = CeilAlign(m, fractalShape[0]); // 高度
                kAlignL0 = CeilAlign(k, fractalShape[1]); // 宽度
                aSizeAlignL0 = mAlignL0 * kAlignL0;
            } else {
                // L1上,A矩阵的对齐
                // GM上A矩阵的shape为[k,m]
                kAlignL1 = CeilAlign(k, fractalShape[0] * fractalNum); // 高度
                mAlignL1 = CeilAlign(m, fractalShape[1]); // 宽度
                aSizeAlignL1 = kAlignL1 * mAlignL1;

                // L0上,A矩阵的对齐
                // 由于L0上a矩阵也是Z排布
                mAlignL0 = CeilAlign(m, fractalShape[0] * fractalNum); // 高度
                kAlignL0 = CeilAlign(k, fractalShape[1]); // 宽度
                aSizeAlignL0 = mAlignL0 * kAlignL0;
            }
            if constexpr (!isBtranspose) {
                // GM上B矩阵的shape为[k,n]
                kAlignL1 = CeilAlign(k, fractalShape[0] * fractalNum); // 高度
                nAlignL1 = CeilAlign(n, fractalShape[1]); // 宽度
                bSizeAlignL1 = kAlignL1 * nAlignL1;

                kAlignL0 = CeilAlign(k, fractalShape[1]); // 高度
                nAlignL0 = CeilAlign(n, fractalShape[0] * fractalNum); // 宽度
                bSizeAlignL0 = kAlignL0 * nAlignL0;
            } else {
                // L1上,B矩阵的对齐
                // L1上, b矩阵是Z排布
                nAlignL1 = CeilAlign(n, fractalShape[0]); // 高度
                kAlignL1 = CeilAlign(k, fractalShape[1]); // 宽度
                bSizeAlignL1 = nAlignL1 * kAlignL1;

                // L0上,B矩阵的对齐
                // L0上, b矩阵是N排布
                kAlignL0 = CeilAlign(k, fractalShape[1]); // 高度
                nAlignL0 = CeilAlign(n, fractalShape[0]); // 宽度
                bSizeAlignL0 = kAlignL0 * nAlignL0;
            }
        } else if constexpr (AscendC::IsSameType<T, half>::value && AscendC::IsSameType<U, float>::value) {
            if constexpr (!isAtranspose) {
                // L1上,A矩阵的对齐
                mAlignL1 = CeilAlign(m, fractalShape[0]); // 高度
                kAlignL1 = CeilAlign(k, fractalShape[1]); // 宽度
                aSizeAlignL1 = mAlignL1 * kAlignL1;

                // L0上,A矩阵的对齐
                // 由于L0上a矩阵也是Z排布
                mAlignL0 = CeilAlign(m, fractalShape[0]); // 高度
                kAlignL0 = CeilAlign(k, fractalShape[1]); // 宽度
                aSizeAlignL0 = mAlignL0 * kAlignL0;
            } else {
                // L1上,A矩阵的对齐
                // GM上A矩阵的shape为[k,m]
                kAlignL1 = CeilAlign(k, fractalShape[0]); // 高度
                mAlignL1 = CeilAlign(m, fractalShape[1]); // 宽度
                aSizeAlignL1 = kAlignL1 * mAlignL1;

                // L0上,A矩阵的对齐
                // 由于L0上a矩阵也是Z排布
                mAlignL0 = CeilAlign(m, fractalShape[0]); // 高度
                kAlignL0 = CeilAlign(k, fractalShape[1]); // 宽度
                aSizeAlignL0 = mAlignL0 * kAlignL0;
            }
            if constexpr (!isBtranspose) {
                // GM上B矩阵的shape为[k,n]
                kAlignL1 = CeilAlign(k, fractalShape[0]); // 高度
                nAlignL1 = CeilAlign(n, fractalShape[1]); // 宽度
                bSizeAlignL1 = kAlignL1 * nAlignL1;

                kAlignL0 = CeilAlign(k, fractalShape[1]); // 高度
                nAlignL0 = CeilAlign(n, fractalShape[0]); // 宽度
                bSizeAlignL0 = kAlignL0 * nAlignL0;
            } else {
                // L1上,B矩阵的对齐
                // L1上, b矩阵是Z排布
                // GM上A矩阵的shape为[n,k]
                nAlignL1 = CeilAlign(n, fractalShape[0]); // 高度
                kAlignL1 = CeilAlign(k, fractalShape[1]); // 宽度
                bSizeAlignL1 = nAlignL1 * kAlignL1;

                // L0上,B矩阵的对齐
                // L0上, b矩阵是N排布
                kAlignL0 = CeilAlign(k, fractalShape[1]); // 高度
                nAlignL0 = CeilAlign(n, fractalShape[0]); // 宽度
                bSizeAlignL0 = kAlignL0 * nAlignL0;
            }
        } else if constexpr (AscendC::IsSameType<T, float>::value && AscendC::IsSameType<U, float>::value) {
            if constexpr (!isAtranspose) {
                // L1上,A矩阵的对齐
                mAlignL1 = CeilAlign(m, fractalShape[0]); // 高度
                kAlignL1 = CeilAlign(k, fractalShape[1]); // 宽度
                aSizeAlignL1 = mAlignL1 * kAlignL1;

                // L0上,A矩阵的对齐
                // 由于L0上a矩阵也是Z排布
                mAlignL0 = CeilAlign(m, fractalShape[0]); // 高度
                kAlignL0 = CeilAlign(k, fractalShape[1]); // 宽度
                aSizeAlignL0 = mAlignL0 * kAlignL0;
            } else {
                // L1上,A矩阵的对齐
                // GM上A矩阵的shape为[k,m]
                kAlignL1 = CeilAlign(k, fractalShape[0]); // 高度
                mAlignL1 = CeilAlign(m, fractalShape[1]); // 宽度
                aSizeAlignL1 = kAlignL1 * mAlignL1;

                // L0上,A矩阵的对齐
                // 由于L0上a矩阵也是Z排布
                mAlignL0 = CeilAlign(m, fractalShape[1]); // 高度
                kAlignL0 = CeilAlign(k, fractalShape[1] * fractalNum); // 宽度
                aSizeAlignL0 = mAlignL0 * kAlignL0;
            }
            if constexpr (!isBtranspose) {
                // GM上B矩阵的shape为[k,n]
                kAlignL1 = CeilAlign(k, fractalShape[0]); // 高度
                nAlignL1 = CeilAlign(n, fractalShape[1]); // 宽度
                bSizeAlignL1 = kAlignL1 * nAlignL1;

                kAlignL0 = CeilAlign(k, fractalShape[0]); // 高度
                nAlignL0 = CeilAlign(n, fractalShape[1]); // 宽度
                bSizeAlignL0 = kAlignL0 * nAlignL0;
            } else {
                // L1上,B矩阵的对齐
                // L1上, b矩阵是Z排布
                nAlignL1 = CeilAlign(n, fractalShape[0]); // 高度
                kAlignL1 = CeilAlign(k, fractalShape[1]); // 宽度
                bSizeAlignL1 = nAlignL1 * kAlignL1;

                // L0上,B矩阵的对齐
                // L0上, b矩阵是N排布
                kAlignL0 = CeilAlign(k, fractalShape[1]); // 高度
                nAlignL0 = CeilAlign(n, fractalShape[0]); // 宽度
                bSizeAlignL0 = kAlignL0 * nAlignL0;
            }
        }
        // C矩阵无视数据类型和a、b是否转置,m,n都向16对齐
        cSizeAlignL0 = CeilAlign(m, fractalShape[0]) * CeilAlign(n, fractalShape[0]);
    }
    __aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR c,  AscendC::TPipe* pipeIn)
    {
        // set cube only
        KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIC_ONLY);

        pipe = pipeIn;
        aGM.SetGlobalBuffer((__gm__ T *)a);
        bGM.SetGlobalBuffer((__gm__ T *)b);
        cGM.SetGlobalBuffer((__gm__ U *)c);
        pipe->InitBuffer(inQueueA1, 1, aSizeAlignL1 * sizeof(T));
        pipe->InitBuffer(inQueueA2, 1, aSizeAlignL0 * sizeof(T));
        pipe->InitBuffer(inQueueB1, 1, bSizeAlignL1 * sizeof(T));
        pipe->InitBuffer(inQueueB2, 1, bSizeAlignL0 * sizeof(T));
        pipe->InitBuffer(outQueueCO1, 1, cSizeAlignL0 * sizeof(U));
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        if constexpr (!isAtranspose) {
            SplitALoad3Dv2();
        } else {
            SplitATransposeLoad3Dv2();
        }
        if constexpr (!isBtranspose) {
            SplitBTransposeLoad3Dv2();
        } else {
            SplitB();
        }
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline uint16_t CeilDivision(uint16_t numerator, uint16_t denominator) 
    {
        return (numerator + denominator - 1) / denominator;
    }

    __aicore__ inline uint16_t CeilAlign(uint16_t numerator, uint16_t denominator) 
    {
        return (numerator + denominator - 1) / denominator * denominator;
    }
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> a1Local = inQueueA1.AllocTensor<T>();
        AscendC::LocalTensor<T> b1Local = inQueueB1.AllocTensor<T>();
        // GM-->L1,搬运A矩阵
        AscendC::Nd2NzParams nd2nzA1Params;
        // 不同的数据类型,A、B矩阵在高度方向的对齐不同
        if constexpr (!isAtranspose) {
            // 传输ND矩阵的数目
            nd2nzA1Params.ndNum = 1;
            // ND矩阵的行数
            nd2nzA1Params.nValue = m;
            // ND矩阵的列数
            nd2nzA1Params.dValue = k;
            // 只传输了1个ND矩阵,该参数无效
            nd2nzA1Params.srcNdMatrixStride = 0;
            // 源操作数同一ND矩阵的相邻行起始地址间的偏移
            nd2nzA1Params.srcDValue = k;

            // 以下这个参数取A矩阵在L1上,高度方向的对齐后的长度
            // 由于A不转置,因此对于三种数据类型该参数均相同
            nd2nzA1Params.dstNzC0Stride = CeilAlign(m, fractalShape[0]);

            nd2nzA1Params.dstNzNStride = 1;
            nd2nzA1Params.dstNzMatrixStride = 0;
        } else {
            nd2nzA1Params.ndNum = 1;
            nd2nzA1Params.nValue = k;
            nd2nzA1Params.dValue = m;
            nd2nzA1Params.srcNdMatrixStride = 0;
            nd2nzA1Params.srcDValue = m;

            // 以下这个参数取A矩阵在L1上,高度方向的对齐后的长度
            // 由于A转置,因此三种数据类型下,该参数的配置不相同
            if constexpr (AscendC::IsSameType<T, int8_t>::value && AscendC::IsSameType<U, int32_t>::value) {
                nd2nzA1Params.dstNzC0Stride = CeilAlign(k, fractalShape[0] * fractalNum);
            } else if constexpr (AscendC::IsSameType<T, half>::value && AscendC::IsSameType<U, float>::value) {
                nd2nzA1Params.dstNzC0Stride = CeilAlign(k, fractalShape[0]);
            } else if constexpr (AscendC::IsSameType<T, float>::value && AscendC::IsSameType<U, float>::value) {
                nd2nzA1Params.dstNzC0Stride = CeilAlign(k, fractalShape[0]);
            }

            nd2nzA1Params.dstNzNStride = 1;
            nd2nzA1Params.dstNzMatrixStride = 0;
        }
        AscendC::DataCopy(a1Local, aGM, nd2nzA1Params);

        // GM-->L1,搬运B矩阵
        AscendC::Nd2NzParams nd2nzB1Params;
        if constexpr (!isBtranspose) {
            nd2nzB1Params.ndNum = 1;
            nd2nzB1Params.nValue = k;
            nd2nzB1Params.dValue = n;
            nd2nzB1Params.srcNdMatrixStride = 0;
            nd2nzB1Params.srcDValue = n;

            // 以下这个参数取B矩阵在L1上,高度方向的对齐后的长度
            // 由于A转置,因此三种数据类型下,该参数的配置不相同
            if constexpr (AscendC::IsSameType<T, int8_t>::value && AscendC::IsSameType<U, int32_t>::value) {
                nd2nzB1Params.dstNzC0Stride = CeilAlign(k, fractalShape[0] * fractalNum);
            } else if constexpr (AscendC::IsSameType<T, half>::value && AscendC::IsSameType<U, float>::value) {
                nd2nzB1Params.dstNzC0Stride = CeilAlign(k, fractalShape[0]);
            } else if constexpr (AscendC::IsSameType<T, float>::value && AscendC::IsSameType<U, float>::value) {
                nd2nzB1Params.dstNzC0Stride = CeilAlign(k, fractalShape[0]);
            }

            nd2nzB1Params.dstNzNStride = 1;
            nd2nzB1Params.dstNzMatrixStride = 0;
        } else {
            nd2nzB1Params.ndNum = 1;
            nd2nzB1Params.nValue = n;
            nd2nzB1Params.dValue = k;
            nd2nzB1Params.srcNdMatrixStride = 0;
            nd2nzB1Params.srcDValue = k;

            // 以下这个参数取B矩阵在L1上,高度方向的对齐后的长度
            // 由于B转置,因此三种数据类型下,该参数的配置相同
            nd2nzB1Params.dstNzC0Stride = CeilAlign(n, fractalShape[0]);
            nd2nzB1Params.dstNzNStride = 1;
            nd2nzB1Params.dstNzMatrixStride = 0;
        }
        AscendC::DataCopy(b1Local, bGM, nd2nzB1Params);

        inQueueA1.EnQue(a1Local);
        inQueueB1.EnQue(b1Local);
    }

    // A矩阵转置,调用Load3Dv2
    __aicore__ inline void SplitATransposeLoad3Dv2()
    {
        AscendC::LocalTensor<T> a1Local = inQueueA1.DeQue<T>();
        AscendC::LocalTensor<T> a2Local = inQueueA2.AllocTensor<T>();
        
        // 使用load3d接口,实现NZ2ZZ
        AscendC::LoadData3DParamsV2<T> loadDataParams;
        // 源操作数height
        loadDataParams.l1H = CeilAlign(k, fractalShape[0]);
        // 源操作数wight
        loadDataParams.l1W = 1;
        // 源操作数的通道数,
        // img2col的结果矩阵高度为ho * wo,根据ho和wo的计算公式,代入卷积核宽度、卷积核滑动步长、卷积核膨胀系数等参数可知:ho * wo = loadDataParams.l1H * loadDataParams.l1w
        // img2col的结果矩阵宽度为ci * kh * kw,代入kh=1,kw=1,可知结果矩阵的宽度为ci=loadDataParams.channelSize = m
        loadDataParams.channelSize = CeilAlign(m, fractalShape[1]);
        // 该指令在目的操作数width维度的传输长度,如果不覆盖最右侧的分形,对于half类型,应为16的倍数,对于int8_t/uint8_t应为32的倍数;覆盖的情况则无倍数要求。
        loadDataParams.kExtension = CeilAlign(m, fractalShape[1]);
        // 该指令在目的操作数height维度的传输长度,如果不覆盖最下侧的分形,对于half/int8_t/uint8_t,应为16的倍数;覆盖的情况则无倍数要求。
        loadDataParams.mExtension = CeilAlign(k, fractalShape[1] * fractalNum);
        // 卷积核在源操作数width维度滑动的步长
        loadDataParams.strideW = 1;
        // 卷积核在源操作数height维度滑动的步长
        loadDataParams.strideH = 1;
        // 卷积核width
        loadDataParams.filterW = 1;
        // 卷积核height
        loadDataParams.filterH = 1;
        // 卷积核width膨胀系数
        loadDataParams.dilationFilterW = 1;
        // 卷积核height膨胀系数
        loadDataParams.dilationFilterH = 1;
        loadDataParams.filterSizeW = false;
        loadDataParams.filterSizeH = false;
        loadDataParams.enTranspose = true;
        loadDataParams.fMatrixCtrl = false;
        AscendC::LoadData(a2Local, a1Local, loadDataParams);

        inQueueA2.EnQue<T>(a2Local);
        inQueueA1.FreeTensor(a1Local);
    }

    // A矩阵不转置,调用Load3Dv2
    __aicore__ inline void SplitALoad3Dv2()
    {
        AscendC::LocalTensor<T> a1Local = inQueueA1.DeQue<T>();
        AscendC::LocalTensor<T> a2Local = inQueueA2.AllocTensor<T>();
        
        // 使用load3d接口,实现NZ2ZZ
        AscendC::LoadData3DParamsV2<T> loadDataParams;
        // 源操作数height
        loadDataParams.l1H = CeilAlign(m, fractalShape[0]);
        // 源操作数wight
        loadDataParams.l1W = 1;
        // 源操作数的通道数,
        // img2col的结果矩阵高度为ho * wo,根据ho和wo的计算公式,代入卷积核宽度、卷积核滑动步长、卷积核膨胀系数等参数可知:ho * wo = loadDataParams.l1H * loadDataParams.l1w
        // img2col的结果矩阵宽度为ci * kh * kw,代入kh=1,kw=1,可知结果矩阵的宽度为ci=loadDataParams.channelSize = m
        loadDataParams.channelSize = CeilAlign(k, fractalShape[1]);
        // 该指令在目的操作数width维度的传输长度,如果不覆盖最右侧的分形,对于half类型,应为16的倍数,对于int8_t/uint8_t应为32的倍数;覆盖的情况则无倍数要求。
        loadDataParams.kExtension = CeilAlign(k, fractalShape[1]);
        // 该指令在目的操作数height维度的传输长度,如果不覆盖最下侧的分形,对于half/int8_t/uint8_t,应为16的倍数;覆盖的情况则无倍数要求。
        loadDataParams.mExtension = CeilAlign(m, fractalShape[0]);
        // 卷积核在源操作数width维度滑动的步长
        loadDataParams.strideW = 1;
        // 卷积核在源操作数height维度滑动的步长
        loadDataParams.strideH = 1;
        // 卷积核width
        loadDataParams.filterW = 1;
        // 卷积核height
        loadDataParams.filterH = 1;
        // 卷积核width膨胀系数
        loadDataParams.dilationFilterW = 1;
        // 卷积核height膨胀系数
        loadDataParams.dilationFilterH = 1;
        loadDataParams.filterSizeW = false;
        loadDataParams.filterSizeH = false;
        loadDataParams.enTranspose = false;
        loadDataParams.fMatrixCtrl = false;
        AscendC::LoadData(a2Local, a1Local, loadDataParams);

        inQueueA2.EnQue<T>(a2Local);
        inQueueA1.FreeTensor(a1Local);
    }

    // B矩阵转置,调用Load3Dv2
    __aicore__ inline void SplitBTransposeLoad3Dv2()
    {
        AscendC::LocalTensor<T> b1Local = inQueueB1.DeQue<T>();
        AscendC::LocalTensor<T> b2Local = inQueueB2.AllocTensor<T>();
        AscendC::LoadData3DParamsV2<T> loadDataParams;
        loadDataParams.l1H = CeilAlign(k, fractalShape[0]);
        loadDataParams.l1W = 1;
        loadDataParams.channelSize = CeilAlign(n, fractalShape[1]);
        loadDataParams.kExtension = CeilAlign(n, fractalShape[1]);
        loadDataParams.mExtension = CeilAlign(k, fractalShape[0]);
        loadDataParams.strideW = 1;
        loadDataParams.strideH = 1;
        loadDataParams.filterW = 1;
        loadDataParams.filterH = 1;
        loadDataParams.dilationFilterW = 1;
        loadDataParams.dilationFilterH = 1;
        loadDataParams.filterSizeW = false;
        loadDataParams.filterSizeH = false;
        // 对于Load3Dv2接口,当目的地址为L0B时,b矩阵会自动转置,loadDataParams.enTranspose仅在目的地址为L0A时生效。
        loadDataParams.enTranspose = true;
        loadDataParams.fMatrixCtrl = false;
        AscendC::LoadData(b2Local, b1Local, loadDataParams);
        inQueueB1.FreeTensor(b1Local);
        inQueueB2.EnQue<T>(b2Local);
    }

    __aicore__ inline void SplitB()
    {
        AscendC::LocalTensor<T> b1Local = inQueueB1.DeQue<T>();
        AscendC::LocalTensor<T> b2Local = inQueueB2.AllocTensor<T>();
        uint32_t dstOffset = CeilDivision(n, fractalShape[0]) * fractalSize;
        uint32_t srcOffset = CeilDivision(n, fractalShape[0]) * fractalSize;
        // Nz -> Zz
        AscendC::LoadData2DParams loadDataParams;     
        loadDataParams.repeatTimes = CeilDivision(n, fractalShape[0]);
        loadDataParams.srcStride = 1;
        loadDataParams.dstGap = 0;
        loadDataParams.ifTranspose = false;
        for (int i = 0; i < CeilDivision(k, fractalShape[1]); ++i) {
            AscendC::LoadData(b2Local[i * dstOffset], b1Local[i * srcOffset], loadDataParams);
        }
        inQueueB1.FreeTensor(b1Local);
        inQueueB2.EnQue<T>(b2Local);
    }

    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> a2Local = inQueueA2.DeQue<T>();
        AscendC::LocalTensor<T> b2Local = inQueueB2.DeQue<T>();
        AscendC::LocalTensor<U> c1Local = outQueueCO1.AllocTensor<U>();
        AscendC::MmadParams mmadParams;
        mmadParams.m = m;
        mmadParams.n = n;
        mmadParams.k = k;
        if constexpr (AscendC::IsSameType<T, float>::value && AscendC::IsSameType<U, float>::value) {
            if (isAtranspose) {
                mmadParams.kDirectionAlign = true;
            }
        }
        AscendC::Mmad(c1Local, a2Local, b2Local, mmadParams);
        outQueueCO1.EnQue<U>(c1Local);
        inQueueA2.FreeTensor(a2Local);
        inQueueB2.FreeTensor(b2Local);
    }
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<U> c1Local = outQueueCO1.DeQue<U>();
        AscendC::FixpipeParamsV220 fixpipeParams;
        fixpipeParams.nSize = n;
        fixpipeParams.mSize = m;

        // 源操作数来源于L0c,因此m只需要向16对齐,与数据类型无关
        fixpipeParams.srcStride = CeilAlign(m, fractalShape[0]);
        fixpipeParams.dstStride = n;

        fixpipeParams.ndNum = 1;
        fixpipeParams.srcNdStride = 0;
        fixpipeParams.dstNdStride = 0;
        AscendC::Fixpipe(cGM, c1Local, fixpipeParams);
        outQueueCO1.FreeTensor(c1Local);
    }

private:
    AscendC::TPipe* pipe;
    AscendC::TQue<AscendC::TPosition::A1, 1> inQueueA1;
    AscendC::TQue<AscendC::TPosition::A2, 1> inQueueA2;
    AscendC::TQue<AscendC::TPosition::B1, 1> inQueueB1;
    AscendC::TQue<AscendC::TPosition::B2, 1> inQueueB2;
    AscendC::TQue<AscendC::TPosition::CO1, 1> outQueueCO1;

    AscendC::GlobalTensor<T> aGM;
    AscendC::GlobalTensor<T> bGM;
    AscendC::GlobalTensor<U> cGM;
    uint16_t m = M, k = K, n = N;

    uint16_t mAlignL1 = M, kAlignL1 = K, nAlignL1 = N;
    uint16_t mAlignL0 = M, kAlignL0 = K, nAlignL0 = N;

    uint16_t aSizeAlignL1, bSizeAlignL1;
    uint16_t aSizeAlignL0, bSizeAlignL0, cSizeAlignL0;
    uint16_t fractalShape[2] = {0, 0};
    uint16_t fractalSize = 0;
    uint16_t fractalNum = 0;
};

extern "C" __global__ __aicore__ void mmad_custom(GM_ADDR a, GM_ADDR b, GM_ADDR c)
{
    AscendC::TPipe pipe;        
    // load3dv2接口只能用作A转置、不转置
    // Load3Dv2接口,当目的地址为L0B时,b矩阵会自动转置(b是B矩阵中的分形)
    // Load3Dv2接口,对于目的地址为L0A或L0B时,其支持的数据类型不同:
    // 目的地址为L0A,支持数据类型为:uint8_t/int8_t/half/bfloat16_t/uint32_t/int32_t/float/int4b_t
    // 目的地址为L0B,支持数据类型为:half/bfloat16_t/uint32_t/int32_t/float
    // 综上,可以得出Load3Dv2接口适用的场景有以下五个:
    if constexpr (tilingKey == 1) {
        // 输入为half类型
        // A矩阵不转置(a不转置),B矩阵不转置(b转置)
        KernelMmad<half, float, false, false> op;                                                        
        op.Init(a, b, c, &pipe);
        op.Process();
    } else if constexpr (tilingKey == 2) {
        // 输入为half类型
        // A矩阵转置(a转置),B矩阵不转置(b转置)
        KernelMmad<half, float, true, false> op;  
        op.Init(a, b, c, &pipe);
        op.Process();
    } else if constexpr (tilingKey == 3) {
        // 输入为float类型
        // A矩阵不转置(a不转置),B矩阵不转置(b转置)
        KernelMmad<float, float, false, false> op;                                                        
        op.Init(a, b, c, &pipe);
        op.Process();
    } else if constexpr (tilingKey == 4) {
        // 输入为float类型
        // A矩阵转置(a转置),B矩阵不转置(b转置)
        KernelMmad<float, float, true, false> op;  
        op.Init(a, b, c, &pipe);
        op.Process();
    } else if constexpr (tilingKey == 5) {
        // 输入为int8_t类型
        // A矩阵不转置(a不转置),B矩阵转置(b不转置)
        KernelMmad<int8_t, int32_t, false, true> op;  
        op.Init(a, b, c, &pipe);
        op.Process();
    }
}

int32_t main(int32_t argc, char *argv[])
{
    size_t aFileSize = 0;
    size_t bFileSize = 0;
    size_t cFileSize = 0;

    if constexpr (tilingKey <= 2) {
        aFileSize = M * K * sizeof(half); 
        bFileSize = K * N * sizeof(half); 
        cFileSize = M * N * sizeof(float);
    } else if constexpr (tilingKey <= 4) {
        aFileSize = M * K * sizeof(float); 
        bFileSize = K * N * sizeof(float); 
        cFileSize = M * N * sizeof(float);
    } else {
        aFileSize = M * K * sizeof(int8_t);
        bFileSize = K * N * sizeof(int8_t);
        cFileSize = M * N * sizeof(int32_t);
    }
    uint32_t numBlocks = 1;

    aclInit(nullptr);
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    uint8_t *aHost;
    uint8_t *aDevice;
    aclrtMallocHost((void **)(&aHost), aFileSize);
    aclrtMalloc((void **)&aDevice, aFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
    ReadFile("./input/x1_gm.bin", aFileSize, aHost, aFileSize);
    aclrtMemcpy(aDevice, aFileSize, aHost, aFileSize, ACL_MEMCPY_HOST_TO_DEVICE);

    uint8_t *bHost;
    uint8_t *bDevice;
    aclrtMallocHost((void **)(&bHost), bFileSize);
    aclrtMalloc((void **)&bDevice, bFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
    ReadFile("./input/x2_gm.bin", bFileSize, bHost, bFileSize);
    aclrtMemcpy(bDevice, bFileSize, bHost, bFileSize, ACL_MEMCPY_HOST_TO_DEVICE);

    uint8_t *cHost;
    uint8_t *cDevice;
    aclrtMallocHost((void **)(&cHost), cFileSize);
    aclrtMalloc((void **)&cDevice, cFileSize, ACL_MEM_MALLOC_HUGE_FIRST);

    mmad_custom<<<numBlocks, nullptr, stream>>>(aDevice, bDevice, cDevice);
    aclrtSynchronizeStream(stream);

    aclrtMemcpy(cHost, cFileSize, cDevice, cFileSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output.bin", cHost, cFileSize);

    aclrtFree(aDevice);
    aclrtFreeHost(aHost);
    aclrtFree(bDevice);
    aclrtFreeHost(bHost);
    aclrtFree(cDevice);
    aclrtFreeHost(cHost);

    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();
    return 0;
}