/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
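/*!
* \file matmul.asc
* \brief Ascend C matmul sample: computes C = A * B + bias, where each result tile is delivered to the
*        vector side and split by rows between two vector sub-blocks before being copied to GM.
*/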
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>
#include <algorithm>
#include <iterator>
#include "acl/acl.h"
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"
#include "kernel_tiling/kernel_tiling.h"
#include "data_utils.h"
// Problem shape used by both the tiling generator and the host buffers:
// A is M x K, B is K x N, C is M x N and bias has N elements.
constexpr uint32_t M{127};
constexpr uint32_t N{127};
constexpr uint32_t K{63};
/**
 * @brief Copy the TCubeTiling struct stored at a GM address into a local TCubeTiling.
 * The copy is done one 32-bit word at a time, sizeof(TCubeTiling) / sizeof(uint32_t) words in total.
 * @param tiling: destination TCubeTiling to fill.
 * @param tilingGM: GM address holding the serialised tiling data.
 * @retval None
 */
__aicore__ inline void CopyTiling(TCubeTiling *tiling, GM_ADDR tilingGM)
{
    constexpr uint32_t wordCount = sizeof(TCubeTiling) / sizeof(uint32_t);
    uint32_t *dst = reinterpret_cast<uint32_t *>(tiling);
    __gm__ uint32_t *src = reinterpret_cast<__gm__ uint32_t *>(tilingGM);
    for (uint32_t word = 0; word < wordCount; ++word) {
        dst[word] = src[word];
    }
}
/**
* @brief Matmul kernel computing C = A * B (+ bias).
* The Matmul object delivers each result tile to VECIN (ND_ALIGN); Process() then has vector
* sub-block 0 and sub-block 1 each copy their share of the tile's rows out to GM.
* @tparam AType: element type of the A matrix.
* @tparam BType: element type of the B matrix.
* @tparam CType: element type of the C matrix.
* @tparam BiasType: element type of the bias vector.
*/
template <typename AType, typename BType, typename CType, typename BiasType>
class MatmulKernel {
public:
__aicore__ inline MatmulKernel(){};
/**
* @brief Initialization before process: binds the GM tensors, offsets them to this core's region
*        and allocates the VECIN/VECOUT queues.
* @param a: A matrix gm addr.
* @param b: B matrix gm addr.
* @param bias: Bias matrix gm addr.
* @param c: C matrix gm addr.
* @param tiling: Matmul tiling struct (copied into the kernel).
* @param isTransA: Whether A matrix is transposed.
* @param isTransB: Whether B matrix is transposed.
* @param pipe: Pipe used to allocate the UB queues.
* @retval None
*/
__aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c, const TCubeTiling& tiling,
bool isTransA, bool isTransB, AscendC::TPipe *pipe);
/**
* @brief Process matrix calculation: iterate over every base tile of this core's C region and
*        copy each sub-block's rows of the tile to GM.
* @retval None
*/
__aicore__ inline void Process();
// High-level Matmul object. It is public because the kernel entry passes it to REGIST_MATMUL_OBJ.
// A, B and bias are read from GM in ND layout; C is delivered to VECIN in ND_ALIGN layout.
// NOTE(review): the trailing `true` on the A/B MatmulType is presumably the IBShare flag
// (A/B shared by the two vector sub-blocks) - confirm against the MatmulType documentation.
// NOTE(review): SplitMMatmulPolicy appears to split each C tile along M between the two
// sub-blocks (Process copies baseM / 2 rows per sub-block) - confirm against the library docs.
AscendC::Matmul<
AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, AType, false, LayoutMode::NONE, true>,
AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, BType, false, LayoutMode::NONE, true>,
AscendC::MatmulType<AscendC::TPosition::VECIN, CubeFormat::ND_ALIGN, CType>,
AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, BiasType>,
CFG_NORM, AscendC::MatmulCallBackFunc<nullptr, nullptr, nullptr>, AscendC::Impl::Detail::SplitMMatmulPolicy
> matmulObj;
private:
/**
* @brief Calculate the gm offset based on the blockIdx.
*        Also records this core's position in the M/N core grid (mCoreIndex, nCoreIndex).
* @param blockIdx: Current Core blockidx.
* @param offsetA: Gm offset of A matrix.
* @param offsetB: Gm offset of B matrix.
* @param offsetC: Gm offset of C matrix.
* @param offsetBias: Gm offset of Bias matrix.
* @retval None
*/
__aicore__ inline void CalcOffset(
int32_t blockIdx, int32_t& offsetA, int32_t& offsetB, int32_t& offsetC, int32_t& offsetBias);
// GM views of the operands; Init() re-bases each one at this core's starting element.
AscendC::GlobalTensor<AType> aGlobal;
AscendC::GlobalTensor<BType> bGlobal;
AscendC::GlobalTensor<CType> cGlobal;
AscendC::GlobalTensor<BiasType> biasGlobal;
// Copy of the tiling passed to Init().
TCubeTiling tiling;
// This core's position in the M / N core grid; written by CalcOffset().
int32_t mCoreIndex;
int32_t nCoreIndex;
// Transpose flags passed to Init(); forwarded to SetTensorA/SetTensorB and used by CalcOffset().
bool isTransA{false};
bool isTransB{false};
// Single-buffer UB queues, each sized to halfBaseSize elements of CType in Init().
AscendC::TQue<AscendC::TPosition::VECIN, 1> vecin;
AscendC::TQue<AscendC::TPosition::VECOUT, 1> vecout;
// baseM / 2 * baseN: element count of one sub-block's half of a base result tile.
int32_t halfBaseSize;
};
/**
 * @brief Bind the operand GM buffers, shift them to this core's region and allocate the UB queues.
 * The tiling and transpose flags are stored on the object first because CalcOffset() reads them
 * through `this`.
 */
template <typename AType, typename BType, typename CType, typename BiasType>
__aicore__ inline void MatmulKernel<AType, BType, CType, BiasType>::Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias,
    GM_ADDR c, const TCubeTiling& tiling, bool isTransA, bool isTransB, AscendC::TPipe *pipe)
{
    this->tiling = tiling;
    this->isTransA = isTransA;
    this->isTransB = isTransB;

    // Bind the whole GM buffer of every operand.
    aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ AType*>(a), tiling.M * tiling.Ka);
    bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ BType*>(b), tiling.Kb * tiling.N);
    cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ CType*>(c), tiling.M * tiling.N);
    biasGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ BiasType*>(bias), tiling.N);

    // Vector sub-blocks V0 and V1 work on the same A/B data, so the vector block index is halved
    // to obtain the index of the region the pair owns.
    int32_t offsetA = 0;
    int32_t offsetB = 0;
    int32_t offsetC = 0;
    int32_t offsetBias = 0;
    CalcOffset(AscendC::GetBlockIdx() / 2, offsetA, offsetB, offsetC, offsetBias);

    // Re-base each GM view at this core's first element.
    aGlobal = aGlobal[offsetA];
    bGlobal = bGlobal[offsetB];
    cGlobal = cGlobal[offsetC];
    biasGlobal = biasGlobal[offsetBias];

    // Each sub-block only handles half of a base tile, so the queues hold (baseM / 2) x baseN elements.
    halfBaseSize = (tiling.baseM / 2) * tiling.baseN;
    const auto halfTileBytes = halfBaseSize * sizeof(CType);
    pipe->InitBuffer(vecin, 1, halfTileBytes);
    pipe->InitBuffer(vecout, 1, halfTileBytes);
}
/**
 * @brief Run the matmul for this core's C region and copy the result out to GM.
 * For every base tile, the result is fetched into VECIN, staged through VECOUT, and then written by
 * vector sub-block 0 (first rows of the tile) or sub-block 1 (remaining rows).
 */
template <typename AType, typename BType, typename CType, typename BiasType>
__aicore__ inline void MatmulKernel<AType, BType, CType, BiasType>::Process()
{
    matmulObj.SetTensorA(aGlobal, isTransA);
    matmulObj.SetTensorB(bGlobal, isTransB);
    if (tiling.isBias) {
        matmulObj.SetBias(biasGlobal);
    }

    // Extent of this core's region; the last core along M or N may own a shorter tail region.
    const int rowsLeft = tiling.M - mCoreIndex * tiling.singleCoreM;
    const int colsLeft = tiling.N - nCoreIndex * tiling.singleCoreN;
    const bool shortM = rowsLeft < tiling.singleCoreM;
    const bool shortN = colsLeft < tiling.singleCoreN;
    const uint32_t curSingleCoreM = shortM ? rowsLeft : tiling.singleCoreM;
    const uint32_t curSingleCoreN = shortN ? colsLeft : tiling.singleCoreN;
    if (shortM || shortN) {
        matmulObj.SetTail(curSingleCoreM, curSingleCoreN);
    }

    // Number of base tiles along each axis, and the size of the last tile on that axis.
    const uint16_t nIter = (curSingleCoreN + tiling.baseN - 1) / tiling.baseN;
    const uint32_t nRem = curSingleCoreN % tiling.baseN;
    const uint32_t tailBaseN = (nRem == 0) ? tiling.baseN : nRem;
    const uint16_t mIter = (curSingleCoreM + tiling.baseM - 1) / tiling.baseM;
    const uint32_t mRem = curSingleCoreM % tiling.baseM;
    const uint32_t tailBaseM = (mRem == 0) ? tiling.baseM : mRem;

    // N-outer / M-inner walk over the base tiles.
    // NOTE(review): assumes the Matmul iterates M-first ("orderm"); confirm against tiling.iterateOrder.
    for (int nIdx = 0; nIdx < nIter; ++nIdx) {
        const bool lastInN = (nIdx == nIter - 1);
        for (int mIdx = 0; mIdx < mIter; ++mIdx) {
            const bool lastInM = (mIdx == mIter - 1);

            // Pull the next result tile into VECIN (each sub-block's buffer holds half a tile).
            matmulObj.Iterate();
            auto cubeOut = vecin.AllocTensor<CType>();
            matmulObj.GetTensorC(cubeOut, 0, true);
            vecin.EnQue(cubeOut);
            auto cubeIn = vecin.DeQue<CType>();

            // Stage through VECOUT; the whole half-tile buffer is copied for simplicity.
            auto stagedOut = vecout.AllocTensor<CType>();
            DataCopy(stagedOut, cubeIn, halfBaseSize);
            vecout.EnQue(stagedOut);
            auto staged = vecout.DeQue<CType>();
            vecin.FreeTensor(cubeIn);

            // Rows per sub-block: a full tile is split evenly; on the last M tile V0 takes the
            // extra odd row, so V1 may end up with none.
            uint16_t blockCountV0 = tiling.baseM / 2;
            uint16_t blockCountV1 = blockCountV0;
            if (lastInM) {
                blockCountV0 = (tailBaseM + 1) / 2;
                blockCountV1 = tailBaseM / 2;
            }

            // Columns per row: the tail width on the last N tile, baseN otherwise.
            // dstStride skips the rest of each GM row (bytes); the UB rows are packed (srcStride 0).
            const uint32_t tileCols = lastInN ? tailBaseN : tiling.baseN;
            const uint32_t blockLen = tileCols * sizeof(CType);
            const uint32_t srcStride = 0;
            const uint32_t dstStride = (tiling.N - tileCols) * sizeof(CType);

            // Element offset of the tile in this core's C region; V1 starts blockCountV0 rows lower.
            const uint64_t cGlobalOffsetV0 = nIdx * tiling.baseN + mIdx * tiling.baseM * tiling.N;
            const uint64_t cGlobalOffsetV1 = cGlobalOffsetV0 + blockCountV0 * tiling.N;

            if (AscendC::GetSubBlockIdx() == 0) {
                DataCopyPad(cGlobal[cGlobalOffsetV0], staged, {blockCountV0, blockLen, srcStride, dstStride, 0});
            } else if (blockCountV1 > 0) {
                DataCopyPad(cGlobal[cGlobalOffsetV1], staged, {blockCountV1, blockLen, srcStride, dstStride, 0});
            }
            vecout.FreeTensor(staged);
        }
    }
    matmulObj.End();
}
/**
 * @brief Locate this core in the M/N core grid and compute GM element offsets of its operands.
 * Block indices walk down M first (blockIdx % mCoreCount), then move to the next N column.
 */
template <typename AType, typename BType, typename CType, typename BiasType>
__aicore__ inline void MatmulKernel<AType, BType, CType, BiasType>::CalcOffset(
    int32_t blockIdx, int32_t& offsetA, int32_t& offsetB, int32_t& offsetC, int32_t& offsetBias)
{
    const TCubeTiling& cfg = this->tiling;
    // Number of cores M is split across: ceil(M / singleCoreM).
    const auto mCoreCount = (cfg.M + cfg.singleCoreM - 1) / cfg.singleCoreM;
    mCoreIndex = blockIdx % mCoreCount;
    nCoreIndex = blockIdx / mCoreCount;

    // First row / column of the C region owned by this core.
    const int32_t rowBase = mCoreIndex * cfg.singleCoreM;
    const int32_t colBase = nCoreIndex * cfg.singleCoreN;

    // Non-transposed A advances Ka elements per row; transposed A advances one element per row.
    offsetA = isTransA ? rowBase : rowBase * cfg.Ka;
    // Non-transposed B advances one element per column; transposed B advances Kb elements per column.
    offsetB = isTransB ? colBase * cfg.Kb : colBase;
    offsetC = rowBase * cfg.N + colBase;
    offsetBias = colBase;
}
/**
* @brief matmul kernel function entry
*        Copies the tiling from GM, registers the Matmul object and runs
*        MatmulKernel<half, half, float, float> without transposing A or B.
* @param a: A matrix gm addr (half).
* @param b: B matrix gm addr (half).
* @param bias: bias matrix gm addr (float).
* @param c: C matrix gm addr (float).
* @param workspace: Temporary gm space addr required by matmul calc. It is not referenced directly in
*        this function; REGIST_MATMUL_OBJ is given GetSysWorkSpacePtr() instead.
* @param tilingGm: Tiling data addr.
* @retval None
*/
__global__ __aicore__ void matmul_custom(GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c,
__kfc_workspace__ GM_ADDR workspace, GM_ADDR tilingGm)
{
// prepare tiling: copy the TCubeTiling struct word by word from GM
TCubeTiling tiling;
CopyTiling(&tiling, tilingGm);
// define matmul kernel: half A/B, float C and bias
MatmulKernel<half, half, float, float> matmulKernel;
AscendC::TPipe pipe;
// NOTE(review): REGIST_MATMUL_OBJ is a library macro; confirm whether it alters control flow
// (e.g. on cube cores) before the Init/Process calls below are reached.
REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), matmulKernel.matmulObj, &tiling);
// init matmul kernel, isTransA=false, isTransB=false
matmulKernel.Init(a, b, bias, c, tiling, false, false, &pipe);
// matmul kernel process
matmulKernel.Process();
}
/**
 * @brief Generate the host-side matmul tiling for the fixed M x N x K problem and serialise it.
 * The operand descriptions must match the MatmulType arguments of MatmulKernel::matmulObj:
 * A/B fp16 ND in GM (not transposed), bias fp32 ND in GM, C fp32 ND_ALIGN in VECIN.
 * On failure a message is printed and the (possibly empty) tiling is still written to tilingBuf.
 * @param ascendcPlatform: platform information used by the tiling API.
 * @param tilingBuf: destination buffer; receives tilingData.GetDataSize() bytes.
 */
void GenerateTiling(platform_ascendc::PlatformAscendC* ascendcPlatform, uint8_t *tilingBuf)
{
    using TPosition = matmul_tiling::TPosition;
    using CubeFormat = matmul_tiling::CubeFormat;
    using DataType = matmul_tiling::DataType;

    optiling::TCubeTiling tilingData;
    matmul_tiling::MultiCoreMatmulTiling tilingApi(*ascendcPlatform);
    tilingApi.SetDim(ascendcPlatform->GetCoreNumAic());
    // A: fp16, ND in GM, not transposed.
    tilingApi.SetAType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT16, false);
    // B: fp16, ND in GM, not transposed.
    tilingApi.SetBType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT16, false);
    // Bias: fp32, ND in GM.
    tilingApi.SetBiasType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT);
    // C: fp32, delivered to VECIN in ND_ALIGN layout.
    tilingApi.SetCType(TPosition::VECIN, CubeFormat::ND_ALIGN, DataType::DT_FLOAT);
    tilingApi.SetOrgShape(M, N, K);
    tilingApi.SetShape(M, N, K);
    tilingApi.EnableBias(true);
    // -1 for every buffer: no explicit L1/L0C/UB limit (see the SetBufferSpace documentation).
    tilingApi.SetBufferSpace(-1, -1, -1);

    const int64_t res = tilingApi.GetTiling(tilingData);
    if (res == -1) {
        std::cout << "gen tiling failed" << std::endl;
    }
    const uint32_t tilingSize = tilingData.GetDataSize();
    tilingData.SaveToBuffer(tilingBuf, tilingSize);
}
int32_t main(int32_t argc, char *argv[])
{
auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance();
size_t aFileSize = M * K * sizeof(uint16_t); // uint16_t represent half
size_t bFileSize = K * N * sizeof(uint16_t); // uint16_t represent half
size_t cFileSize = M * N * sizeof(float);
size_t biasFileSize = N * sizeof(float);
size_t userWorkspaceSize = 0;
size_t systemWorkspaceSize = static_cast<size_t>(ascendcPlatform->GetLibApiWorkSpaceSize());
size_t workspaceSize = userWorkspaceSize + systemWorkspaceSize;
// matmul TCubeTiling
size_t tilingFileSize = sizeof(TCubeTiling);
uint8_t *tilingBuf = (uint8_t *)malloc(tilingFileSize);
GenerateTiling(ascendcPlatform, tilingBuf);
uint32_t numBlocks = reinterpret_cast<TCubeTiling *>(tilingBuf)->usedCoreNum;
aclInit(nullptr);
int32_t deviceId = 0;
aclrtSetDevice(deviceId);
aclrtStream stream = nullptr;
aclrtCreateStream(&stream);
uint8_t *aHost;
uint8_t *aDevice;
aclrtMallocHost((void **)(&aHost), aFileSize);
aclrtMalloc((void **)&aDevice, aFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile("./input/x1_gm.bin", aFileSize, aHost, aFileSize);
aclrtMemcpy(aDevice, aFileSize, aHost, aFileSize, ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t *bHost;
uint8_t *bDevice;
aclrtMallocHost((void **)(&bHost), bFileSize);
aclrtMalloc((void **)&bDevice, bFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile("./input/x2_gm.bin", bFileSize, bHost, bFileSize);
aclrtMemcpy(bDevice, bFileSize, bHost, bFileSize, ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t *biasHost;
uint8_t *biasDevice;
aclrtMallocHost((void **)(&biasHost), biasFileSize);
aclrtMalloc((void **)&biasDevice, biasFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile("./input/bias_gm.bin", biasFileSize, biasHost, biasFileSize);
aclrtMemcpy(biasDevice, biasFileSize, biasHost, biasFileSize, ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t *cHost;
uint8_t *cDevice;
aclrtMallocHost((void **)(&cHost), cFileSize);
aclrtMalloc((void **)&cDevice, cFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
uint8_t *workspaceDevice;
aclrtMalloc((void **)&workspaceDevice, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
uint8_t *tilingHost;
uint8_t *tilingDevice;
aclrtMallocHost((void **)(&tilingHost), tilingFileSize);
aclrtMalloc((void **)&tilingDevice, tilingFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMemcpy(tilingHost, tilingFileSize, tilingBuf, tilingFileSize, ACL_MEMCPY_HOST_TO_HOST);
aclrtMemcpy(tilingDevice, tilingFileSize, tilingHost, tilingFileSize, ACL_MEMCPY_HOST_TO_DEVICE);
matmul_custom<<<numBlocks, nullptr, stream>>>(aDevice, bDevice, biasDevice, cDevice, workspaceDevice, tilingDevice);
aclrtSynchronizeStream(stream);
aclrtMemcpy(cHost, cFileSize, cDevice, cFileSize, ACL_MEMCPY_DEVICE_TO_HOST);
WriteFile("./output/output.bin", cHost, cFileSize);
aclrtFree(aDevice);
aclrtFreeHost(aHost);
aclrtFree(bDevice);
aclrtFreeHost(bHost);
aclrtFree(biasDevice);
aclrtFreeHost(biasHost);
aclrtFree(workspaceDevice);
aclrtFree(tilingDevice);
aclrtFreeHost(tilingHost);
aclrtFree(cDevice);
aclrtFreeHost(cHost);
aclrtDestroyStream(stream);
aclrtResetDevice(deviceId);
aclFinalize();
free(tilingBuf);
return 0;
}