/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file mmad_s8_f16_f32_with_A_B_transpose_option.asc
* \brief
*/
#include <cstdio>

#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"
constexpr uint32_t tilingKey = TILING_KEY;
constexpr uint32_t M = M_SIZE;
constexpr uint32_t N = N_SIZE;
constexpr uint32_t K = K_SIZE;
// KernelMmad: cube-only matrix multiply C[m, n] = A[m, k] * B[k, n], staged as
// GM -> L1 (A1/B1) -> L0A/L0B (A2/B2) -> L0C (CO1) -> GM.
//
// Template parameters:
//   T            - element type of A and B (int8_t, half and float are handled).
//   U            - element type of C (int32_t for int8_t inputs, float otherwise).
//   isAtranspose - A is stored in GM as [k, m] instead of [m, k].
//   isBtranspose - B is stored in GM as [n, k] instead of [k, n].
//
// If A is transposed, each square block "a" must be transposed on the L1 -> L0A move.
// If B is transposed, each square block "b" does NOT need to be transposed on the L1 -> L0B move.
template <class T, class U, bool isAtranspose, bool isBtranspose>
class KernelMmad {
public:
    __aicore__ inline KernelMmad()
    {
        // Shape of one left-matrix fractal: 16 rows x (32 bytes / sizeof(T)) columns.
        fractalShape[0] = 16;
        fractalShape[1] = 32 / sizeof(T);
        // The right-matrix fractal shape is [32 / sizeof(T), 16].
        // Size of a left/right fractal, in elements.
        fractalSize = 16 * fractalShape[1];
        // Transposition works only on square blocks, so the number of fractals per square block
        // depends on the element width.
        if constexpr (sizeof(T) == 2) {
            fractalNum = 1;
        } else {
            fractalNum = 2;
        }
        // Aligned shapes of A and B in L1 and L0 for each supported (T, U) pair.
        if constexpr (AscendC::IsSameType<T, int8_t>::value && AscendC::IsSameType<U, int32_t>::value) {
            if constexpr (!isAtranspose) {
                // A in L1; GM holds A as [m, k].
                mAlignL1 = CeilAlign(m, fractalShape[0]); // height
                kAlignL1 = CeilAlign(k, fractalShape[1]); // width
                aSizeAlignL1 = mAlignL1 * kAlignL1;
                // A in L0; a is also laid out in Z order in L0.
                mAlignL0 = CeilAlign(m, fractalShape[0]); // height
                kAlignL0 = CeilAlign(k, fractalShape[1]); // width
                aSizeAlignL0 = mAlignL0 * kAlignL0;
            } else {
                // A in L1; GM holds A as [k, m].
                kAlignL1 = CeilAlign(k, fractalShape[0] * fractalNum); // height
                mAlignL1 = CeilAlign(m, fractalShape[1]); // width
                aSizeAlignL1 = kAlignL1 * mAlignL1;
                // A in L0; a is also laid out in Z order in L0.
                mAlignL0 = CeilAlign(m, fractalShape[0] * fractalNum); // height
                kAlignL0 = CeilAlign(k, fractalShape[1]); // width
                aSizeAlignL0 = mAlignL0 * kAlignL0;
            }
            if constexpr (!isBtranspose) {
                // GM holds B as [k, n].
                kAlignL1 = CeilAlign(k, fractalShape[0] * fractalNum); // height
                nAlignL1 = CeilAlign(n, fractalShape[1]); // width
                bSizeAlignL1 = kAlignL1 * nAlignL1;
                kAlignL0 = CeilAlign(k, fractalShape[1]); // height
                nAlignL0 = CeilAlign(n, fractalShape[0] * fractalNum); // width
                bSizeAlignL0 = kAlignL0 * nAlignL0;
            } else {
                // B in L1; b is laid out in Z order in L1.
                nAlignL1 = CeilAlign(n, fractalShape[0]); // height
                kAlignL1 = CeilAlign(k, fractalShape[1]); // width
                bSizeAlignL1 = nAlignL1 * kAlignL1;
                // B in L0; b is laid out in N order in L0.
                kAlignL0 = CeilAlign(k, fractalShape[1]); // height
                nAlignL0 = CeilAlign(n, fractalShape[0]); // width
                bSizeAlignL0 = kAlignL0 * nAlignL0;
            }
        } else if constexpr (AscendC::IsSameType<T, half>::value && AscendC::IsSameType<U, float>::value) {
            if constexpr (!isAtranspose) {
                // A in L1.
                mAlignL1 = CeilAlign(m, fractalShape[0]); // height
                kAlignL1 = CeilAlign(k, fractalShape[1]); // width
                aSizeAlignL1 = mAlignL1 * kAlignL1;
                // A in L0; a is also laid out in Z order in L0.
                mAlignL0 = CeilAlign(m, fractalShape[0]); // height
                kAlignL0 = CeilAlign(k, fractalShape[1]); // width
                aSizeAlignL0 = mAlignL0 * kAlignL0;
            } else {
                // A in L1; GM holds A as [k, m].
                kAlignL1 = CeilAlign(k, fractalShape[0]); // height
                mAlignL1 = CeilAlign(m, fractalShape[1]); // width
                aSizeAlignL1 = kAlignL1 * mAlignL1;
                // A in L0; a is also laid out in Z order in L0.
                mAlignL0 = CeilAlign(m, fractalShape[0]); // height
                kAlignL0 = CeilAlign(k, fractalShape[1]); // width
                aSizeAlignL0 = mAlignL0 * kAlignL0;
            }
            if constexpr (!isBtranspose) {
                // GM holds B as [k, n].
                kAlignL1 = CeilAlign(k, fractalShape[0]); // height
                nAlignL1 = CeilAlign(n, fractalShape[1]); // width
                bSizeAlignL1 = kAlignL1 * nAlignL1;
                kAlignL0 = CeilAlign(k, fractalShape[1]); // height
                nAlignL0 = CeilAlign(n, fractalShape[0]); // width
                bSizeAlignL0 = kAlignL0 * nAlignL0;
            } else {
                // B in L1; b is laid out in Z order in L1. GM holds B as [n, k].
                nAlignL1 = CeilAlign(n, fractalShape[0]); // height
                kAlignL1 = CeilAlign(k, fractalShape[1]); // width
                bSizeAlignL1 = nAlignL1 * kAlignL1;
                // B in L0; b is laid out in N order in L0.
                kAlignL0 = CeilAlign(k, fractalShape[1]); // height
                nAlignL0 = CeilAlign(n, fractalShape[0]); // width
                bSizeAlignL0 = kAlignL0 * nAlignL0;
            }
        } else if constexpr (AscendC::IsSameType<T, float>::value && AscendC::IsSameType<U, float>::value) {
            if constexpr (!isAtranspose) {
                // A in L1.
                mAlignL1 = CeilAlign(m, fractalShape[0]); // height
                kAlignL1 = CeilAlign(k, fractalShape[1]); // width
                aSizeAlignL1 = mAlignL1 * kAlignL1;
                // A in L0; a is also laid out in Z order in L0.
                mAlignL0 = CeilAlign(m, fractalShape[0]); // height
                kAlignL0 = CeilAlign(k, fractalShape[1]); // width
                aSizeAlignL0 = mAlignL0 * kAlignL0;
            } else {
                // A in L1; GM holds A as [k, m].
                kAlignL1 = CeilAlign(k, fractalShape[0]); // height
                mAlignL1 = CeilAlign(m, fractalShape[1]); // width
                aSizeAlignL1 = kAlignL1 * mAlignL1;
                // A in L0; a is also laid out in Z order in L0.
                // NOTE(review): m is aligned to fractalShape[1] (8 for float) here, whereas the
                // non-transposed path aligns it to 16 -- confirm against the Load3Dv2 transpose layout.
                mAlignL0 = CeilAlign(m, fractalShape[1]); // height
                kAlignL0 = CeilAlign(k, fractalShape[1] * fractalNum); // width
                aSizeAlignL0 = mAlignL0 * kAlignL0;
            }
            if constexpr (!isBtranspose) {
                // GM holds B as [k, n].
                kAlignL1 = CeilAlign(k, fractalShape[0]); // height
                nAlignL1 = CeilAlign(n, fractalShape[1]); // width
                bSizeAlignL1 = kAlignL1 * nAlignL1;
                kAlignL0 = CeilAlign(k, fractalShape[0]); // height
                nAlignL0 = CeilAlign(n, fractalShape[1]); // width
                bSizeAlignL0 = kAlignL0 * nAlignL0;
            } else {
                // B in L1; b is laid out in Z order in L1.
                nAlignL1 = CeilAlign(n, fractalShape[0]); // height
                kAlignL1 = CeilAlign(k, fractalShape[1]); // width
                bSizeAlignL1 = nAlignL1 * kAlignL1;
                // B in L0; b is laid out in N order in L0.
                kAlignL0 = CeilAlign(k, fractalShape[1]); // height
                nAlignL0 = CeilAlign(n, fractalShape[0]); // width
                bSizeAlignL0 = kAlignL0 * nAlignL0;
            }
        } else {
            // Any other (T, U) pair would leave every buffer size at 0; reject it at compile time.
            static_assert(sizeof(T) == 0,
                "KernelMmad supports only (int8_t, int32_t), (half, float) and (float, float)");
        }
        // Regardless of data type and of whether A/B are transposed, m and n of C are aligned to 16.
        cSizeAlignL0 = CeilAlign(m, fractalShape[0]) * CeilAlign(n, fractalShape[0]);
    }

    // Binds the GM tensors and allocates one buffer per on-chip queue, sized by the constructor.
    __aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR c, AscendC::TPipe* pipeIn)
    {
        // Run on the cube (AIC) core only.
        KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIC_ONLY);
        pipe = pipeIn;
        aGM.SetGlobalBuffer((__gm__ T *)a);
        bGM.SetGlobalBuffer((__gm__ T *)b);
        cGM.SetGlobalBuffer((__gm__ U *)c);
        pipe->InitBuffer(inQueueA1, 1, aSizeAlignL1 * sizeof(T));
        pipe->InitBuffer(inQueueA2, 1, aSizeAlignL0 * sizeof(T));
        pipe->InitBuffer(inQueueB1, 1, bSizeAlignL1 * sizeof(T));
        pipe->InitBuffer(inQueueB2, 1, bSizeAlignL0 * sizeof(T));
        pipe->InitBuffer(outQueueCO1, 1, cSizeAlignL0 * sizeof(U));
    }

    // Runs the whole pipeline once: load, split A and B into L0, multiply, write C back.
    __aicore__ inline void Process()
    {
        CopyIn();
        if constexpr (!isAtranspose) {
            SplitALoad3Dv2();
        } else {
            SplitATransposeLoad3Dv2();
        }
        // A non-transposed B needs its blocks transposed (done by Load3Dv2 into L0B);
        // a transposed B is loaded block-for-block with LoadData2D.
        if constexpr (!isBtranspose) {
            SplitBTransposeLoad3Dv2();
        } else {
            SplitB();
        }
        Compute();
        CopyOut();
    }

private:
    // Returns ceil(numerator / denominator).
    __aicore__ inline uint16_t CeilDivision(uint16_t numerator, uint16_t denominator)
    {
        return (numerator + denominator - 1) / denominator;
    }

    // Rounds numerator up to the next multiple of denominator.
    __aicore__ inline uint16_t CeilAlign(uint16_t numerator, uint16_t denominator)
    {
        return (numerator + denominator - 1) / denominator * denominator;
    }

    // GM -> L1: copies A and B from ND layout into Nz layout.
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> a1Local = inQueueA1.AllocTensor<T>();
        AscendC::LocalTensor<T> b1Local = inQueueB1.AllocTensor<T>();
        // GM -> L1: move A.
        AscendC::Nd2NzParams nd2nzA1Params;
        // The height alignment of A and B in L1 depends on the data type.
        if constexpr (!isAtranspose) {
            // Number of ND matrices transferred.
            nd2nzA1Params.ndNum = 1;
            // Rows of the ND matrix.
            nd2nzA1Params.nValue = m;
            // Columns of the ND matrix.
            nd2nzA1Params.dValue = k;
            // Only one ND matrix is transferred, so this stride is unused.
            nd2nzA1Params.srcNdMatrixStride = 0;
            // Offset between the starts of adjacent rows of the same source ND matrix.
            nd2nzA1Params.srcDValue = k;
            // Aligned height of A in L1. A is not transposed, so it is the same for all three data types.
            nd2nzA1Params.dstNzC0Stride = CeilAlign(m, fractalShape[0]);
            nd2nzA1Params.dstNzNStride = 1;
            nd2nzA1Params.dstNzMatrixStride = 0;
        } else {
            // GM holds A as [k, m].
            nd2nzA1Params.ndNum = 1;
            nd2nzA1Params.nValue = k;
            nd2nzA1Params.dValue = m;
            nd2nzA1Params.srcNdMatrixStride = 0;
            nd2nzA1Params.srcDValue = m;
            // Aligned height of A in L1. A is transposed, so it differs per data type.
            if constexpr (AscendC::IsSameType<T, int8_t>::value && AscendC::IsSameType<U, int32_t>::value) {
                nd2nzA1Params.dstNzC0Stride = CeilAlign(k, fractalShape[0] * fractalNum);
            } else if constexpr (AscendC::IsSameType<T, half>::value && AscendC::IsSameType<U, float>::value) {
                nd2nzA1Params.dstNzC0Stride = CeilAlign(k, fractalShape[0]);
            } else if constexpr (AscendC::IsSameType<T, float>::value && AscendC::IsSameType<U, float>::value) {
                nd2nzA1Params.dstNzC0Stride = CeilAlign(k, fractalShape[0]);
            }
            nd2nzA1Params.dstNzNStride = 1;
            nd2nzA1Params.dstNzMatrixStride = 0;
        }
        AscendC::DataCopy(a1Local, aGM, nd2nzA1Params);
        // GM -> L1: move B.
        AscendC::Nd2NzParams nd2nzB1Params;
        if constexpr (!isBtranspose) {
            // GM holds B as [k, n].
            nd2nzB1Params.ndNum = 1;
            nd2nzB1Params.nValue = k;
            nd2nzB1Params.dValue = n;
            nd2nzB1Params.srcNdMatrixStride = 0;
            nd2nzB1Params.srcDValue = n;
            // Aligned height of B in L1. B is not transposed, so it differs per data type.
            if constexpr (AscendC::IsSameType<T, int8_t>::value && AscendC::IsSameType<U, int32_t>::value) {
                nd2nzB1Params.dstNzC0Stride = CeilAlign(k, fractalShape[0] * fractalNum);
            } else if constexpr (AscendC::IsSameType<T, half>::value && AscendC::IsSameType<U, float>::value) {
                nd2nzB1Params.dstNzC0Stride = CeilAlign(k, fractalShape[0]);
            } else if constexpr (AscendC::IsSameType<T, float>::value && AscendC::IsSameType<U, float>::value) {
                nd2nzB1Params.dstNzC0Stride = CeilAlign(k, fractalShape[0]);
            }
            nd2nzB1Params.dstNzNStride = 1;
            nd2nzB1Params.dstNzMatrixStride = 0;
        } else {
            // GM holds B as [n, k].
            nd2nzB1Params.ndNum = 1;
            nd2nzB1Params.nValue = n;
            nd2nzB1Params.dValue = k;
            nd2nzB1Params.srcNdMatrixStride = 0;
            nd2nzB1Params.srcDValue = k;
            // Aligned height of B in L1. B is transposed, so it is the same for all three data types.
            nd2nzB1Params.dstNzC0Stride = CeilAlign(n, fractalShape[0]);
            nd2nzB1Params.dstNzNStride = 1;
            nd2nzB1Params.dstNzMatrixStride = 0;
        }
        AscendC::DataCopy(b1Local, bGM, nd2nzB1Params);
        inQueueA1.EnQue(a1Local);
        inQueueB1.EnQue(b1Local);
    }

    // A is transposed in GM: L1 -> L0A with Load3Dv2, transposing each square block.
    __aicore__ inline void SplitATransposeLoad3Dv2()
    {
        AscendC::LocalTensor<T> a1Local = inQueueA1.DeQue<T>();
        AscendC::LocalTensor<T> a2Local = inQueueA2.AllocTensor<T>();
        // Use the load3d interface to convert Nz to Zz.
        AscendC::LoadData3DParamsV2<T> loadDataParams;
        // Source height (L1 holds A as [k, m]).
        loadDataParams.l1H = CeilAlign(k, fractalShape[0]);
        // Source width.
        loadDataParams.l1W = 1;
        // Source channel count.
        // The img2col result height is ho * wo; with filter size, stride and dilation all 1 this gives
        // ho * wo = l1H * l1W. The result width is ci * kh * kw; with kh = kw = 1 it is
        // ci = channelSize, i.e. m aligned up.
        loadDataParams.channelSize = CeilAlign(m, fractalShape[1]);
        // Transfer length along the destination width. If it does not cover the right-most fractal it
        // must be a multiple of 16 for half and of 32 for int8_t/uint8_t; otherwise there is no
        // multiple requirement.
        loadDataParams.kExtension = CeilAlign(m, fractalShape[1]);
        // Transfer length along the destination height. If it does not cover the bottom fractal it
        // must be a multiple of 16 for half/int8_t/uint8_t; otherwise there is no multiple requirement.
        loadDataParams.mExtension = CeilAlign(k, fractalShape[1] * fractalNum);
        // Filter stride along the source width.
        loadDataParams.strideW = 1;
        // Filter stride along the source height.
        loadDataParams.strideH = 1;
        // Filter width.
        loadDataParams.filterW = 1;
        // Filter height.
        loadDataParams.filterH = 1;
        // Filter width dilation.
        loadDataParams.dilationFilterW = 1;
        // Filter height dilation.
        loadDataParams.dilationFilterH = 1;
        loadDataParams.filterSizeW = false;
        loadDataParams.filterSizeH = false;
        loadDataParams.enTranspose = true;
        loadDataParams.fMatrixCtrl = false;
        AscendC::LoadData(a2Local, a1Local, loadDataParams);
        inQueueA2.EnQue<T>(a2Local);
        inQueueA1.FreeTensor(a1Local);
    }

    // A is not transposed in GM: L1 -> L0A with Load3Dv2, without transposition.
    __aicore__ inline void SplitALoad3Dv2()
    {
        AscendC::LocalTensor<T> a1Local = inQueueA1.DeQue<T>();
        AscendC::LocalTensor<T> a2Local = inQueueA2.AllocTensor<T>();
        // Use the load3d interface to convert Nz to Zz.
        AscendC::LoadData3DParamsV2<T> loadDataParams;
        // Source height (L1 holds A as [m, k]).
        loadDataParams.l1H = CeilAlign(m, fractalShape[0]);
        // Source width.
        loadDataParams.l1W = 1;
        // Source channel count.
        // The img2col result height is ho * wo; with filter size, stride and dilation all 1 this gives
        // ho * wo = l1H * l1W. The result width is ci * kh * kw; with kh = kw = 1 it is
        // ci = channelSize, i.e. k aligned up.
        loadDataParams.channelSize = CeilAlign(k, fractalShape[1]);
        // Transfer length along the destination width. If it does not cover the right-most fractal it
        // must be a multiple of 16 for half and of 32 for int8_t/uint8_t; otherwise there is no
        // multiple requirement.
        loadDataParams.kExtension = CeilAlign(k, fractalShape[1]);
        // Transfer length along the destination height. If it does not cover the bottom fractal it
        // must be a multiple of 16 for half/int8_t/uint8_t; otherwise there is no multiple requirement.
        loadDataParams.mExtension = CeilAlign(m, fractalShape[0]);
        // Filter stride along the source width.
        loadDataParams.strideW = 1;
        // Filter stride along the source height.
        loadDataParams.strideH = 1;
        // Filter width.
        loadDataParams.filterW = 1;
        // Filter height.
        loadDataParams.filterH = 1;
        // Filter width dilation.
        loadDataParams.dilationFilterW = 1;
        // Filter height dilation.
        loadDataParams.dilationFilterH = 1;
        loadDataParams.filterSizeW = false;
        loadDataParams.filterSizeH = false;
        loadDataParams.enTranspose = false;
        loadDataParams.fMatrixCtrl = false;
        AscendC::LoadData(a2Local, a1Local, loadDataParams);
        inQueueA2.EnQue<T>(a2Local);
        inQueueA1.FreeTensor(a1Local);
    }

    // B is not transposed in GM: L1 -> L0B with Load3Dv2, which transposes the b blocks.
    __aicore__ inline void SplitBTransposeLoad3Dv2()
    {
        AscendC::LocalTensor<T> b1Local = inQueueB1.DeQue<T>();
        AscendC::LocalTensor<T> b2Local = inQueueB2.AllocTensor<T>();
        AscendC::LoadData3DParamsV2<T> loadDataParams;
        loadDataParams.l1H = CeilAlign(k, fractalShape[0]);
        loadDataParams.l1W = 1;
        loadDataParams.channelSize = CeilAlign(n, fractalShape[1]);
        loadDataParams.kExtension = CeilAlign(n, fractalShape[1]);
        loadDataParams.mExtension = CeilAlign(k, fractalShape[0]);
        loadDataParams.strideW = 1;
        loadDataParams.strideH = 1;
        loadDataParams.filterW = 1;
        loadDataParams.filterH = 1;
        loadDataParams.dilationFilterW = 1;
        loadDataParams.dilationFilterH = 1;
        loadDataParams.filterSizeW = false;
        loadDataParams.filterSizeH = false;
        // With Load3Dv2, blocks are transposed automatically when the destination is L0B;
        // enTranspose only takes effect when the destination is L0A.
        loadDataParams.enTranspose = true;
        loadDataParams.fMatrixCtrl = false;
        AscendC::LoadData(b2Local, b1Local, loadDataParams);
        inQueueB1.FreeTensor(b1Local);
        inQueueB2.EnQue<T>(b2Local);
    }

    // B is transposed in GM: L1 -> L0B with LoadData2D (Nz -> Zz), one k-column block at a time.
    __aicore__ inline void SplitB()
    {
        AscendC::LocalTensor<T> b1Local = inQueueB1.DeQue<T>();
        AscendC::LocalTensor<T> b2Local = inQueueB2.AllocTensor<T>();
        // Each k-column block in L1 holds CeilDivision(n, 16) fractals (CopyIn uses
        // dstNzC0Stride = CeilAlign(n, 16)). L0B uses the same per-block stride, so source and
        // destination advance together.
        const uint32_t blockStride = CeilDivision(n, fractalShape[0]) * fractalSize;
        const uint32_t kBlocks = CeilDivision(k, fractalShape[1]);
        AscendC::LoadData2DParams loadDataParams;
        loadDataParams.repeatTimes = CeilDivision(n, fractalShape[0]);
        loadDataParams.srcStride = 1;
        loadDataParams.dstGap = 0;
        loadDataParams.ifTranspose = false;
        for (uint32_t i = 0; i < kBlocks; ++i) {
            AscendC::LoadData(b2Local[i * blockStride], b1Local[i * blockStride], loadDataParams);
        }
        inQueueB1.FreeTensor(b1Local);
        inQueueB2.EnQue<T>(b2Local);
    }

    // L0A x L0B -> L0C.
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> a2Local = inQueueA2.DeQue<T>();
        AscendC::LocalTensor<T> b2Local = inQueueB2.DeQue<T>();
        AscendC::LocalTensor<U> c1Local = outQueueCO1.AllocTensor<U>();
        AscendC::MmadParams mmadParams;
        mmadParams.m = m;
        mmadParams.n = n;
        mmadParams.k = k;
        if constexpr (AscendC::IsSameType<T, float>::value && AscendC::IsSameType<U, float>::value) {
            if constexpr (isAtranspose) {
                // The float transposed-A path aligns k in L0A to 16 (kAlignL0 in the constructor).
                mmadParams.kDirectionAlign = true;
            }
        }
        AscendC::Mmad(c1Local, a2Local, b2Local, mmadParams);
        outQueueCO1.EnQue<U>(c1Local);
        inQueueA2.FreeTensor(a2Local);
        inQueueB2.FreeTensor(b2Local);
    }

    // L0C -> GM via Fixpipe, writing C as a row-major [m, n] matrix.
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<U> c1Local = outQueueCO1.DeQue<U>();
        AscendC::FixpipeParamsV220 fixpipeParams;
        fixpipeParams.nSize = n;
        fixpipeParams.mSize = m;
        // The source is L0C, so m only needs 16-alignment regardless of the data type.
        fixpipeParams.srcStride = CeilAlign(m, fractalShape[0]);
        fixpipeParams.dstStride = n;
        fixpipeParams.ndNum = 1;
        fixpipeParams.srcNdStride = 0;
        fixpipeParams.dstNdStride = 0;
        AscendC::Fixpipe(cGM, c1Local, fixpipeParams);
        outQueueCO1.FreeTensor(c1Local);
    }

private:
    AscendC::TPipe* pipe;
    AscendC::TQue<AscendC::TPosition::A1, 1> inQueueA1;
    AscendC::TQue<AscendC::TPosition::A2, 1> inQueueA2;
    AscendC::TQue<AscendC::TPosition::B1, 1> inQueueB1;
    AscendC::TQue<AscendC::TPosition::B2, 1> inQueueB2;
    AscendC::TQue<AscendC::TPosition::CO1, 1> outQueueCO1;
    AscendC::GlobalTensor<T> aGM;
    AscendC::GlobalTensor<T> bGM;
    AscendC::GlobalTensor<U> cGM;
    uint16_t m = M, k = K, n = N;
    uint16_t mAlignL1 = M, kAlignL1 = K, nAlignL1 = N;
    uint16_t mAlignL0 = M, kAlignL0 = K, nAlignL0 = N;
    // Buffer sizes in elements. Each is the product of two aligned uint16_t dimensions and can
    // exceed 65535, so they are 32-bit (a uint16_t would wrap and under-allocate the queues).
    uint32_t aSizeAlignL1 = 0;
    uint32_t bSizeAlignL1 = 0;
    uint32_t aSizeAlignL0 = 0;
    uint32_t bSizeAlignL0 = 0;
    uint32_t cSizeAlignL0 = 0;
    uint16_t fractalShape[2] = {0, 0};
    uint16_t fractalSize = 0;
    uint16_t fractalNum = 0;
};
// Kernel entry point: instantiates the KernelMmad variant selected by the compile-time tilingKey
// and runs it.
//
// Load3Dv2 constraints that determine which variants are offered:
//   - It can move A into L0A either with or without transposing the blocks a.
//   - When the destination is L0B, the blocks b are always transposed.
//   - Supported element types depend on the destination:
//       L0A: uint8_t/int8_t/half/bfloat16_t/uint32_t/int32_t/float/int4b_t
//       L0B: half/bfloat16_t/uint32_t/int32_t/float
//     so int8_t B cannot use Load3Dv2. That variant keeps B transposed in GM and loads it with
//     LoadData2D (blocks b not transposed).
// This leaves the five variants below.
extern "C" __global__ __aicore__ void mmad_custom(GM_ADDR a, GM_ADDR b, GM_ADDR c)
{
    // Any other key would compile to a kernel that silently does nothing.
    static_assert(tilingKey >= 1 && tilingKey <= 5, "TILING_KEY must be in [1, 5]");
    AscendC::TPipe pipe;
    if constexpr (tilingKey == 1) {
        // half inputs; A not transposed (a not transposed), B not transposed (b transposed).
        KernelMmad<half, float, false, false> op;
        op.Init(a, b, c, &pipe);
        op.Process();
    } else if constexpr (tilingKey == 2) {
        // half inputs; A transposed (a transposed), B not transposed (b transposed).
        KernelMmad<half, float, true, false> op;
        op.Init(a, b, c, &pipe);
        op.Process();
    } else if constexpr (tilingKey == 3) {
        // float inputs; A not transposed (a not transposed), B not transposed (b transposed).
        KernelMmad<float, float, false, false> op;
        op.Init(a, b, c, &pipe);
        op.Process();
    } else if constexpr (tilingKey == 4) {
        // float inputs; A transposed (a transposed), B not transposed (b transposed).
        KernelMmad<float, float, true, false> op;
        op.Init(a, b, c, &pipe);
        op.Process();
    } else if constexpr (tilingKey == 5) {
        // int8_t inputs; A not transposed (a not transposed), B transposed (b not transposed).
        KernelMmad<int8_t, int32_t, false, true> op;
        op.Init(a, b, c, &pipe);
        op.Process();
    }
}
int32_t main(int32_t argc, char *argv[])
{
size_t aFileSize = 0;
size_t bFileSize = 0;
size_t cFileSize = 0;
if constexpr (tilingKey <= 2) {
aFileSize = M * K * sizeof(half);
bFileSize = K * N * sizeof(half);
cFileSize = M * N * sizeof(float);
} else if constexpr (tilingKey <= 4) {
aFileSize = M * K * sizeof(float);
bFileSize = K * N * sizeof(float);
cFileSize = M * N * sizeof(float);
} else {
aFileSize = M * K * sizeof(int8_t);
bFileSize = K * N * sizeof(int8_t);
cFileSize = M * N * sizeof(int32_t);
}
uint32_t numBlocks = 1;
aclInit(nullptr);
int32_t deviceId = 0;
aclrtSetDevice(deviceId);
aclrtStream stream = nullptr;
aclrtCreateStream(&stream);
uint8_t *aHost;
uint8_t *aDevice;
aclrtMallocHost((void **)(&aHost), aFileSize);
aclrtMalloc((void **)&aDevice, aFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile("./input/x1_gm.bin", aFileSize, aHost, aFileSize);
aclrtMemcpy(aDevice, aFileSize, aHost, aFileSize, ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t *bHost;
uint8_t *bDevice;
aclrtMallocHost((void **)(&bHost), bFileSize);
aclrtMalloc((void **)&bDevice, bFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile("./input/x2_gm.bin", bFileSize, bHost, bFileSize);
aclrtMemcpy(bDevice, bFileSize, bHost, bFileSize, ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t *cHost;
uint8_t *cDevice;
aclrtMallocHost((void **)(&cHost), cFileSize);
aclrtMalloc((void **)&cDevice, cFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
mmad_custom<<<numBlocks, nullptr, stream>>>(aDevice, bDevice, cDevice);
aclrtSynchronizeStream(stream);
aclrtMemcpy(cHost, cFileSize, cDevice, cFileSize, ACL_MEMCPY_DEVICE_TO_HOST);
WriteFile("./output/output.bin", cHost, cFileSize);
aclrtFree(aDevice);
aclrtFreeHost(aHost);
aclrtFree(bDevice);
aclrtFreeHost(bHost);
aclrtFree(cDevice);
aclrtFreeHost(cHost);
aclrtDestroyStream(stream);
aclrtResetDevice(deviceId);
aclFinalize();
return 0;
}