/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file batch_matmul.asc
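* \brief Batch matmul sample (BSNGD layout): host-side tiling generation, Ascend C kernel, and ACL host driver.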
*/
#include "data_utils.h"
#include "kernel_tiling/kernel_tiling.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"
#include "acl/acl.h"
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
// Per-batch matmul shape (M x K) * (K x N); ComputeTiling uses these same values for SetShape/SetOrgShape.
constexpr uint32_t M = 32;
constexpr uint32_t N = 256;
constexpr uint32_t K = 64;
// Base block sizes handed to SetFixSplit (baseM, baseN) in ComputeTiling.
constexpr int32_t baseM = 32;
constexpr int32_t baseN = 32;
// Layout of A, passed to SetALayout as (B, S, N=1, G, D).
constexpr int32_t A_BNUM = 2;
constexpr int32_t A_SNUM = 32;
constexpr int32_t A_GNUM = 3;
constexpr int32_t A_DNUM = 64;
// Layout of B, passed to SetBLayout as (B, S, N=1, G, D).
constexpr int32_t B_BNUM = 2;
constexpr int32_t B_SNUM = 256;
constexpr int32_t B_GNUM = 3;
constexpr int32_t B_DNUM = 64;
// Layout of C, passed to SetCLayout with N=1. NOTE(review): C_SNUM / C_DNUM presumably land in
// CLayoutInfoS1 / CLayoutInfoS2 as read by the kernel -- confirm against the SetCLayout signature.
constexpr int32_t C_BNUM = 2;
constexpr int32_t C_SNUM = 32;
constexpr int32_t C_GNUM = 3;
constexpr int32_t C_DNUM = 256;
// Number of batches processed per IterateBatch call (given to SetBatchNum on the host).
constexpr int32_t BATCH_NUM = 3;
// Number of cores the outer B axis is split across in BatchMatmulKernel::Process.
constexpr int USED_CORE_NUM = 2;
// Written into tiling.shareL1Size / tiling.shareL0CSize in the kernel entry ("full L1" / "full L0C").
constexpr int32_t FULL_L1_SIZE = 512 * 1024;
constexpr int32_t FULL_L0C_SIZE = 128 * 1024;
/**
 * Configure the matmul tiling API for the BSNGD batch matmul and produce the tiling data.
 *
 * Uses the file-scope shape/layout constants (M, N, K, baseM, baseN, *_BNUM/*_SNUM/*_GNUM/*_DNUM, BATCH_NUM)
 * instead of re-declaring shadowing locals with the same values.
 *
 * @param tiling      receives the computed tiling on success
 * @param cubeTiling  tiling API instance to configure; must not be null
 * @param isBias      whether the matmul has a bias input
 * @return true if GetTiling succeeded, false on failure or if cubeTiling is null
 */
bool ComputeTiling(optiling::TCubeTiling& tiling, matmul_tiling::MultiCoreMatmulTiling* cubeTiling, bool isBias)
{
    if (cubeTiling == nullptr) {
        return false;
    }
    cubeTiling->SetDim(1);
    // A: fp16, B: fp16 transposed, C: fp32, bias: fp32 -- all ND in GM.
    cubeTiling->SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
    cubeTiling->SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16, true);
    cubeTiling->SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
    cubeTiling->SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
    cubeTiling->SetShape(static_cast<int32_t>(M), static_cast<int32_t>(N), static_cast<int32_t>(K));
    cubeTiling->SetOrgShape(static_cast<int32_t>(M), static_cast<int32_t>(N), static_cast<int32_t>(K));
    cubeTiling->SetFixSplit(baseM, baseN, -1);
    cubeTiling->EnableBias(isBias);
    // -1 for every buffer: let the API choose (set once; the original set it twice with identical arguments).
    cubeTiling->SetBufferSpace(-1, -1, -1);
    // BSNGD layouts with N fixed to 1.
    cubeTiling->SetALayout(A_BNUM, A_SNUM, 1, A_GNUM, A_DNUM);
    cubeTiling->SetBLayout(B_BNUM, B_SNUM, 1, B_GNUM, B_DNUM);
    cubeTiling->SetCLayout(C_BNUM, C_SNUM, 1, C_GNUM, C_DNUM);
    cubeTiling->SetBatchNum(BATCH_NUM);
    return cubeTiling->GetTiling(tiling) != -1;
}
/**
 * Serialize tiling data into a freshly malloc'd byte buffer.
 *
 * @param tilingData tiling to serialize (may be null)
 * @return a buffer of tilingData->GetDataSize() bytes that the caller must free(), or nullptr when
 *         tilingData is null, reports a zero size, or the allocation fails.
 */
uint8_t *GetTilingBuf(optiling::TCubeTiling *tilingData)
{
    const uint32_t byteCount = (tilingData == nullptr) ? 0 : tilingData->GetDataSize();
    if (byteCount == 0) {
        return nullptr;
    }
    uint8_t *out = static_cast<uint8_t *>(malloc(byteCount));
    if (out != nullptr) {
        tilingData->SaveToBuffer(out, byteCount);
    }
    return out;
}
/**
 * Compute the batch-matmul tiling on the host and serialize it.
 *
 * @return a malloc'd buffer holding the serialized tiling (caller frees), or nullptr if the tiling
 *         computation or serialization failed.
 */
uint8_t *GenerateTiling()
{
    optiling::TCubeTiling tilingData;
    matmul_tiling::MultiCoreMatmulTiling tilingApi;
    if (!ComputeTiling(tilingData, &tilingApi, false)) {
        std::cout << "gen tiling failed" << std::endl;
        // Do not serialize a tiling the API rejected; the kernel would run with undefined parameters.
        return nullptr;
    }
    return GetTilingBuf(&tilingData);
}
/**
 * Batched matmul kernel wrapper around AscendC::Matmul.
 *
 * Init binds the global tensors and moves them to this core's offset; Process issues IterateBatch calls
 * over this core's share of the batch axis.
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
class BatchMatmulKernel {
public:
    __aicore__ inline BatchMatmulKernel(){};
    // Bind GM tensors A/B/bias/C and offset them for the current block (via CalcOffset).
    __aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c, GM_ADDR workspace, const TCubeTiling& tiling);
    // Run the batched matmul; batchA/batchB are the per-call batch counts forwarded to IterateBatch.
    template <bool hasBias = false>
    __aicore__ inline void Process(AscendC::TPipe* pipe, int32_t batchA, int32_t batchB);
    // Public because the kernel entry registers it with REGIST_MATMUL_OBJ.
    AscendC::Matmul<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE> matmulObj;
private:
    // Compute per-block element offsets into A, B, C and bias for BSNGD layouts.
    __aicore__ inline void CalcOffset(int32_t blockIdx, const TCubeTiling& tiling, int32_t& offsetA, int32_t& offsetB,
        int32_t& offsetC, int32_t& offsetBias);
    using aType = typename A_TYPE::T;
    using bType = typename B_TYPE::T;
    using cType = typename C_TYPE::T;
    using biasType = typename BIAS_TYPE::T;
    AscendC::GlobalTensor<aType> aGlobal;
    AscendC::GlobalTensor<bType> bGlobal;
    AscendC::GlobalTensor<cType> cGlobal;
    AscendC::GlobalTensor<biasType> biasGlobal;
    // Copy of the tiling passed to Init, read by Process.
    TCubeTiling tiling;
};
/**
 * Bind the global tensors and rebase them to this block's slice.
 *
 * @param a, b, bias, c  GM base addresses of the operands (bias may be null when unused)
 * @param workspace      unused here; the matmul object gets its workspace at registration time
 * @param tiling         tiling data; copied into the member for Process()
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
__aicore__ inline void BatchMatmulKernel<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias,
    GM_ADDR c, GM_ADDR workspace, const TCubeTiling& tiling)
{
    this->tiling = tiling;
    // SetGlobalBuffer takes the tensor length as an ELEMENT count, so the layout products are passed without
    // sizeof() scaling. Computed in 64-bit so larger layouts cannot overflow int32.
    const uint64_t sizeA = static_cast<uint64_t>(tiling.ALayoutInfoB) * tiling.ALayoutInfoS * tiling.ALayoutInfoN *
        tiling.ALayoutInfoG * tiling.ALayoutInfoD;
    const uint64_t sizeB = static_cast<uint64_t>(tiling.BLayoutInfoB) * tiling.BLayoutInfoS * tiling.BLayoutInfoN *
        tiling.BLayoutInfoG * tiling.BLayoutInfoD;
    const uint64_t sizeC = static_cast<uint64_t>(tiling.CLayoutInfoB) * tiling.CLayoutInfoS1 * tiling.CLayoutInfoN *
        tiling.CLayoutInfoG * tiling.CLayoutInfoS2;
    // Bias spans C's layout without the S1 axis.
    const uint64_t sizeBias = static_cast<uint64_t>(tiling.CLayoutInfoB) * tiling.CLayoutInfoN * tiling.CLayoutInfoG *
        tiling.CLayoutInfoS2;
    aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ aType*>(a), sizeA);
    bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ bType*>(b), sizeB);
    cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ cType*>(c), sizeC);
    biasGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ biasType*>(bias), sizeBias);

    // Rebase every tensor to the slice owned by this block.
    int32_t offsetA = 0;
    int32_t offsetB = 0;
    int32_t offsetC = 0;
    int32_t offsetBias = 0;
    CalcOffset(AscendC::GetBlockIdx(), tiling, offsetA, offsetB, offsetC, offsetBias);
    aGlobal = aGlobal[offsetA];
    bGlobal = bGlobal[offsetB];
    cGlobal = cGlobal[offsetC];
    biasGlobal = biasGlobal[offsetBias];
    // The matmul workspace is supplied via GetSysWorkSpacePtr() in REGIST_MATMUL_OBJ, not through this parameter.
    (void)workspace;
}
/**
 * Issue IterateBatch calls over this core's share of the batch axis.
 *
 * @param pipe    pipe handle (not used by this implementation)
 * @param batchA  batches of A consumed per IterateBatch call
 * @param batchB  batches of B consumed per IterateBatch call
 * @tparam hasBias when true, a bias slice is bound before each call
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
template <bool hasBias>
__aicore__ inline void BatchMatmulKernel<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::Process(AscendC::TPipe* pipe, int32_t batchA, int32_t batchB)
{
    // C advances by the larger of the two per-call batch counts.
    const int batchPerIterC = (batchA > batchB) ? batchA : batchB;
    const int groupCount = (tiling.ALayoutInfoG > tiling.BLayoutInfoG) ? tiling.ALayoutInfoG : tiling.BLayoutInfoG;
    // Cores split the outer B axis; each call covers BatchNum of this core's N*G batches.
    const int iterCount = (tiling.ALayoutInfoB / USED_CORE_NUM) * tiling.ALayoutInfoN * groupCount / tiling.BatchNum;

    // Per-batch element strides: a whole S*D (or S1*S2) plane when BatchNum covers every N*G batch, else D (or S2).
    const int strideA = (tiling.BatchNum == tiling.ALayoutInfoN * tiling.ALayoutInfoG) ?
        tiling.ALayoutInfoD * tiling.ALayoutInfoS : tiling.ALayoutInfoD;
    const int strideB = (tiling.BatchNum == tiling.BLayoutInfoN * tiling.BLayoutInfoG) ?
        tiling.BLayoutInfoD * tiling.BLayoutInfoS : tiling.BLayoutInfoD;
    const int strideC = (tiling.BatchNum == tiling.CLayoutInfoN * tiling.CLayoutInfoG) ?
        tiling.CLayoutInfoS2 * tiling.CLayoutInfoS1 : tiling.CLayoutInfoS2;

    for (int iter = 0; iter < iterCount; ++iter) {
        const int firstBatchC = iter * batchPerIterC;
        matmulObj.SetTensorA(aGlobal[iter * strideA * batchA], false);
        matmulObj.SetTensorB(bGlobal[iter * strideB * batchB], true); // B transpose
        if constexpr (hasBias) {
            matmulObj.SetBias(biasGlobal[firstBatchC * tiling.CLayoutInfoS2]);
        }
        matmulObj.IterateBatch(cGlobal[firstBatchC * strideC], batchA, batchB, false);
        AscendC::PipeBarrier<PIPE_FIX>();
    }
}
/**
 * Compute this block's starting element offsets when a tensor uses the BSNGD layout.
 *
 * Multi-core splitting is along the outer batch axis, one outer batch per block. Offsets for tensors
 * in other layouts are left untouched.
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
__aicore__ inline void BatchMatmulKernel<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::CalcOffset(int32_t blockIdx, const TCubeTiling& param,
    int32_t& offsetA, int32_t& offsetB, int32_t& offsetC, int32_t& offsetBias)
{
    constexpr int32_t batchPerCore = 1; // outer-batch slices assigned to one core
    const int32_t firstBatch = blockIdx * batchPerCore;

    if constexpr (A_TYPE::layout == LayoutMode::BSNGD) {
        const int32_t sliceA = param.ALayoutInfoS * param.ALayoutInfoN * param.ALayoutInfoG * param.ALayoutInfoD;
        offsetA = firstBatch * sliceA;
    }
    if constexpr (B_TYPE::layout == LayoutMode::BSNGD) {
        const int32_t sliceB = param.BLayoutInfoS * param.BLayoutInfoN * param.BLayoutInfoG * param.BLayoutInfoD;
        offsetB = firstBatch * sliceB;
    }
    if constexpr (C_TYPE::layout == LayoutMode::BSNGD) {
        const int32_t sliceBias = param.CLayoutInfoN * param.CLayoutInfoG * param.CLayoutInfoS2;
        offsetC = firstBatch * (param.CLayoutInfoS1 * sliceBias);
        offsetBias = firstBatch * sliceBias;
    }
}
/**
 * Copy a TCubeTiling from global memory into *tiling, one 32-bit word at a time.
 *
 * NOTE(review): assumes sizeof(TCubeTiling) is a multiple of 4; any trailing bytes would not be copied.
 */
__aicore__ inline void CopyTiling(TCubeTiling* tiling, GM_ADDR tilingGM)
{
    uint32_t* dst = reinterpret_cast<uint32_t*>(tiling);
    auto src = reinterpret_cast<__gm__ uint32_t*>(tilingGM);
    // Unsigned index avoids the signed/unsigned comparison against the sizeof()-derived bound.
    constexpr uint32_t wordCount = static_cast<uint32_t>(sizeof(TCubeTiling) / sizeof(uint32_t));
    for (uint32_t i = 0; i < wordCount; ++i) {
        dst[i] = src[i];
    }
}
/**
 * Kernel entry: batched matmul C = A * B^T over BSNGD-laid-out tensors, without bias.
 *
 * @param a, b, c    GM operands (A fp16, B fp16 transposed, C fp32)
 * @param workspace  KFC workspace
 * @param tilingGm   serialized TCubeTiling produced on the host
 */
extern "C" __global__ __aicore__ void batch_matmul_custom(GM_ADDR a, GM_ADDR b, GM_ADDR c, __kfc_workspace__ GM_ADDR workspace,
    GM_ADDR tilingGm)
{
    // prepare tiling
    TCubeTiling tiling;
    CopyTiling(&tiling, tilingGm);
    // define matmul kernel; B is transposed, matching SetBType(..., true) on the host
    using A_TYPE = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half, false, LayoutMode::BSNGD>;
    using B_TYPE = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half, true, LayoutMode::BSNGD>;
    using C_TYPE = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float, false, LayoutMode::BSNGD>;
    using BIAS_TYPE = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float>;
    BatchMatmulKernel<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE> batchMatmulKernel;
    AscendC::TPipe pipe;
    tiling.shareMode = 0;                   // 0, share mode
    tiling.shareL1Size = FULL_L1_SIZE;      // full L1
    tiling.shareL0CSize = FULL_L0C_SIZE;    // full L0C
    tiling.shareUbSize = 0;                 // no UB
    REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), batchMatmulKernel.matmulObj, &tiling);
    // init matmul kernel (no bias)
    batchMatmulKernel.Init(a, b, nullptr, c, workspace, tiling);
    // matmul kernel process: per-call batch counts must match the BatchNum configured in the host tiling
    batchMatmulKernel.Process<false>(&pipe, BATCH_NUM, BATCH_NUM);
}
/**
 * Host-side launcher for batch_matmul_custom.
 *
 * @param numBlocks  number of blocks in the launch configuration
 * @param stream     ACL stream the kernel is enqueued on (launch is asynchronous to the caller)
 * @param a, b, c, workspace, tilingGm  device buffers forwarded unchanged to the kernel
 */
void batch_matmul_custom_do(uint32_t numBlocks, void* stream, GM_ADDR a, GM_ADDR b,
    GM_ADDR c, GM_ADDR workspace, GM_ADDR tilingGm)
{
    // invoke the kernel function through the <<<>>> symbol
    batch_matmul_custom<<<numBlocks, nullptr, stream>>>(a, b, c, workspace, tilingGm);
}
int32_t main(int32_t argc, char* argv[])
{
size_t aFileSize = 192 * 64 * sizeof(uint16_t); // uint16_t represent half
size_t bFileSize = 64 * 1536 * sizeof(uint16_t); // uint16_t represent half
size_t cFileSize = 192 * 256 * sizeof(float);
uint32_t workspaceSize = 16 * 1024 * 1024;
size_t tilingFileSize = sizeof(TCubeTiling);
uint32_t numBlocks = 1;
int64_t wrongNum = -1;
uint8_t *tilingBuf = GenerateTiling();
aclInit(nullptr);
aclrtContext context;
int32_t deviceId = 0;
aclrtSetDevice(deviceId);
aclrtCreateContext(&context, deviceId);
aclrtStream stream = nullptr;
aclrtCreateStream(&stream);
uint8_t *aHost;
uint8_t *aDevice;
aclrtMallocHost((void **)(&aHost), aFileSize);
aclrtMalloc((void **)&aDevice, aFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile("./input/x1_gm.bin", aFileSize, aHost, aFileSize);
aclrtMemcpy(aDevice, aFileSize, aHost, aFileSize,
ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t *bHost;
uint8_t *bDevice;
aclrtMallocHost((void **)(&bHost), bFileSize);
aclrtMalloc((void **)&bDevice, bFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile("./input/x2_gm.bin", bFileSize, bHost, bFileSize);
aclrtMemcpy(bDevice, bFileSize, bHost, bFileSize,
ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t *workspaceHost;
uint8_t *workspaceDevice;
aclrtMallocHost((void **)(&workspaceHost), workspaceSize);
aclrtMalloc((void **)&workspaceDevice, workspaceSize,
ACL_MEM_MALLOC_HUGE_FIRST);
uint8_t *tilingHost;
uint8_t *tilingDevice;
aclrtMallocHost((void **)(&tilingHost), tilingFileSize);
aclrtMalloc((void **)&tilingDevice, tilingFileSize,
ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMemcpy(tilingHost, tilingFileSize, tilingBuf,
tilingFileSize, ACL_MEMCPY_HOST_TO_HOST);
aclrtMemcpy(tilingDevice, tilingFileSize, tilingHost,
tilingFileSize, ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t *cHost;
uint8_t *cDevice;
aclrtMallocHost((void **)(&cHost), cFileSize);
aclrtMalloc((void **)&cDevice, cFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
batch_matmul_custom_do(numBlocks, stream, aDevice, bDevice, cDevice, workspaceDevice, tilingDevice);
aclrtSynchronizeStream(stream);
aclrtMemcpy(cHost, cFileSize, cDevice, cFileSize,
ACL_MEMCPY_DEVICE_TO_HOST);
WriteFile("./output/output.bin", cHost, cFileSize);
aclrtFree(cDevice);
aclrtFreeHost(cHost);
aclrtDestroyStream(stream);
aclrtDestroyContext(context);
aclrtResetDevice(deviceId);
aclFinalize();
}