* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file grouped_matmul_a4w4.h
* \brief
*/
#ifndef ASCENDC_GROUPED_MATMUL_A4W4_H
#define ASCENDC_GROUPED_MATMUL_A4W4_H
#include "grouped_matmul_utils.h"
#include "grouped_matmul.h"
#include "kernel_operator.h"
#ifdef GMM_A4W4
namespace GROUPED_MATMUL{
using namespace matmul;
using namespace AscendC;
using DTYPE_PERTOKEN_SCALE_A4W4 = float;
#ifdef GMM_A4W4_BF16
using DTYPE_Y_A4W4 = bfloat16_t;
#else
using DTYPE_Y_A4W4 = half;
#endif
using DTYPE_X_A4W4 = int4b_t;
using DTYPE_WEIGHT_A4W4 = int4b_t;
using DTYPE_SCALE_A4W4 = uint64_t;
constexpr uint64_t SYNC_AIV_TO_AIC = 3;
constexpr uint64_t SYNC_AIC_TO_AIV = 5;
constexpr uint32_t BUFFER_NUM = 2;
constexpr int32_t PARALL_NUM = 4;
template <typename T>
__aicore__ inline void DataCopyPad2DA4W4(const LocalTensor<T> dst, const GlobalTensor<T> src, uint32_t dim1, uint32_t dim0,
uint32_t srcDim0) {
DataCopyExtParams params;
params.blockCount = dim1;
params.blockLen = dim0 * sizeof(T);
params.srcStride = (srcDim0 - dim0) * sizeof(T);
params.dstStride = 0;
DataCopyPadExtParams<T> padParams{true, 0, 0, 0};
DataCopyPad(dst, src, params, padParams);
return;
}
template <class mmType>
class GMMA4W4Compute {
public:
using aT = MatmulType<TPosition::GM, CubeFormat::ND, DTYPE_X_A4W4>;
using bT = typename mmType::BT;
using biasT = MatmulType<TPosition::GM, CubeFormat::ND, int32_t>;
using cT = MatmulType<TPosition::GM, CubeFormat::ND, half>;
using DTYPE_OUT = DTYPE_Y_A4W4;
public:
__aicore__ inline GMMA4W4Compute(typename mmType::MT &matmul) : mm(matmul) {}
__aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR scale, GM_ADDR groupList,
GM_ADDR perTokenScale, GM_ADDR y, GM_ADDR workspace, const GMMBaseParams* __restrict gmmBaseParams,
const TCubeTiling* __restrict mmTilingData, TPipe* tPipe);
__aicore__ inline void Process();
protected:
__aicore__ inline void ProcessCommon(
MNConfig &mnConfig, uint32_t splitValue, uint32_t groupIdx, uint32_t &preCount);
__aicore__ inline void InitUbBuffer();
__aicore__ inline void MMCompute(uint32_t groupIdx, MNConfig& mnConfig);
__aicore__ inline void MMComputePerGroup(uint32_t groupIdx, MNConfig& mnConfig, uint32_t curSingleM, uint32_t curSingleN,
uint64_t xOffset, uint64_t weightOffset, uint32_t tailN);
__aicore__ inline void MMComputePerChannel(uint32_t groupIdx, MNConfig& mnConfig, uint32_t curSingleM, uint32_t curSingleN,
uint64_t xOffset, uint64_t weightOffset, uint32_t tailN);
__aicore__ inline void VectorCompute(uint32_t groupIdx, MNConfig& mnConfig);
__aicore__ inline void VectorTilingCalc(MNConfig& mnConfig, uint32_t& curCubeSingleN, uint32_t& curCubeSingleM, uint32_t& vecBaseM);
__aicore__ inline void ComputeDequantAndActivate(MNConfig& mnConfig, uint32_t curVecBaseM, uint32_t alignBaseN,
uint32_t curVecBaseN, uint32_t offsetM);
__aicore__ inline void DataCopyScale(uint32_t curBaseN, uint32_t alignBaseN, uint64_t scaleOffset);
__aicore__ inline void DataCopyPerTokenScaleAndBrcb(MNConfig& mnConfig, uint32_t curBaseM, uint32_t alignBaseN,
uint32_t offsetM);
protected:
typename mmType::MT& mm;
const uint32_t HALF_ALIGN = 16;
GlobalTensor<DTYPE_X_A4W4> xGm;
GlobalTensor<DTYPE_WEIGHT_A4W4> weightGm;
GlobalTensor<cT::T> mmOutGm;
GlobalTensor<DTYPE_SCALE_A4W4> scaleGm;
GlobalTensor<DTYPE_PERTOKEN_SCALE_A4W4> perTokenScaleGm;
GlobalTensor<int64_t> groupListGm;
GlobalTensor<DTYPE_OUT> yGm;
TQue<QuePosition::VECIN, 1> vecInQueue;
TQue<QuePosition::VECOUT, 1> vecOutQueue;
TQue<QuePosition::VECIN, 1> scaleInQueue;
TQue<QuePosition::VECIN, 1> perTokenScaleInQueue;
TBuf<TPosition::VECCALC> tmpBuff;
LocalTensor<float> mmOutFp32Buf;
LocalTensor<float> pertokenBrcbLocal;
LocalTensor<float> perTokenResBuf;
LocalTensor<uint8_t> calcTmpBuf;
uint32_t subBlockIdx;
uint32_t coreIdx;
uint32_t quantGroupSize_;
uint32_t cubeCount = 0;
uint32_t vecCount_ = 0;
uint32_t mmBaseBlockOffset_ = 0;
TPipe *pipe;
const GMMBaseParams *tiling;
const TCubeTiling* mmTilingData;
};
template <typename mmType>
__aicore__ inline void GMMA4W4Compute<mmType>::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR scale, GM_ADDR groupList,
GM_ADDR perTokenScale, GM_ADDR y, GM_ADDR workspace, const GMMBaseParams* __restrict gmmBaseParams,
const TCubeTiling* __restrict mmTilingData, TPipe* tPipe)
{
xGm.SetGlobalBuffer(GetTensorAddr<DTYPE_X_A4W4>(0, x));
weightGm.SetGlobalBuffer(GetTensorAddr<DTYPE_WEIGHT_A4W4>(0, weight));
mmOutGm.SetGlobalBuffer(reinterpret_cast<__gm__ cT::T *>(workspace));
scaleGm.SetGlobalBuffer(GetTensorAddr<DTYPE_SCALE_A4W4>(0, scale));
perTokenScaleGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_PERTOKEN_SCALE_A4W4 *>(perTokenScale));
groupListGm.SetGlobalBuffer(reinterpret_cast<__gm__ int64_t *>(groupList));
yGm.SetGlobalBuffer(GetTensorAddr<DTYPE_OUT>(0, y));
this->tiling = gmmBaseParams;
this->mmTilingData = mmTilingData;
mmBaseBlockOffset_ = mmTilingData->baseM * mmTilingData->baseN;
quantGroupSize_ = tiling->k / tiling->quantGroupNum;
subBlockIdx = GetSubBlockIdx();
coreIdx = GetBlockIdx();
if ASCEND_IS_AIV {
if (GetTaskRation() != 0) {
coreIdx /= GetTaskRation();
}
}
pipe = tPipe;
InitUbBuffer();
}
template <typename mmType>
__aicore__ inline void GMMA4W4Compute<mmType>::InitUbBuffer()
{
if ASCEND_IS_AIC {
return;
}
pipe->InitBuffer(perTokenScaleInQueue, BUFFER_NUM, mmTilingData->baseM * sizeof(float));
pipe->InitBuffer(vecInQueue, BUFFER_NUM, tiling->ubCalSize * sizeof(cT::T));
pipe->InitBuffer(vecOutQueue, BUFFER_NUM, tiling->ubCalSize * sizeof(DTYPE_OUT));
pipe->InitBuffer(tmpBuff, tiling->ubRestBytes);
uint32_t ubCalSizeFloat = tiling->ubCalSize * sizeof(float);
mmOutFp32Buf = tmpBuff.GetWithOffset<float>(tiling->ubCalSize, 0);
uint32_t offset = ubCalSizeFloat;
pertokenBrcbLocal = tmpBuff.GetWithOffset<float>(tiling->ubCalSize, offset);
offset += ubCalSizeFloat;
perTokenResBuf = tmpBuff.GetWithOffset<float>(tiling->ubCalSize, offset);
offset += ubCalSizeFloat;
calcTmpBuf = tmpBuff.GetWithOffset<uint8_t>(ubCalSizeFloat, offset);
offset += ubCalSizeFloat;
}
template <typename mmType>
__aicore__ inline void GMMA4W4Compute<mmType>::Process()
{
MNConfig mnConfig;
mnConfig.baseM = mmTilingData->baseM;
mnConfig.baseN = mmTilingData->baseN;
mnConfig.singleM = mnConfig.baseM;
if (tiling->quantGroupNum == 1 && tiling->singleN != 0) {
mnConfig.singleN = tiling->singleN;
} else {
mnConfig.singleN = mnConfig.baseN;
}
mnConfig.blockDimN = Ceil(tiling->n, mnConfig.singleN);
int32_t preOffset = 0;
for (uint32_t groupIdx = 0, preCount = 0; groupIdx < tiling->groupNum; ++groupIdx) {
int32_t splitValue = GetSplitValueFromGroupList(groupIdx, preOffset, tiling, groupListGm);
if (splitValue <= 0) {
continue;
}
ProcessCommon(mnConfig, splitValue, groupIdx, preCount);
}
}
template <typename mmType>
__aicore__ inline void GMMA4W4Compute<mmType>::ProcessCommon(
MNConfig &mnConfig, uint32_t splitValue, uint32_t groupIdx, uint32_t &preCount)
{
uint32_t coreNum = tiling->coreNum;
mnConfig.m = static_cast<uint32_t>(splitValue);
mnConfig.blockDimM = Ceil(mnConfig.m, mnConfig.singleM);
mm.SetOrgShape(mnConfig.m, tiling->n, tiling->k);
uint32_t curCount = preCount + mnConfig.blockDimN * mnConfig.blockDimM;
uint32_t curBlock = coreIdx >= preCount ? coreIdx : coreIdx + coreNum;
uint32_t thresholdM_dimN = thresholdBlockNum * mnConfig.blockDimN;
while (curBlock < curCount) {
MNBlockIdxCompute(mnConfig, curBlock, preCount, thresholdM_dimN);
MMCompute(groupIdx, mnConfig);
if ASCEND_IS_AIV {
VectorCompute(groupIdx, mnConfig);
}
curBlock += coreNum;
}
preCount = curCount % coreNum;
mnConfig.offsetM += mnConfig.m;
}
template <typename mmType>
__aicore__ inline void GMMA4W4Compute<mmType>::MMComputePerChannel(uint32_t groupIdx, MNConfig& mnConfig, uint32_t curSingleM, uint32_t curSingleN,
uint64_t xOffset, uint64_t weightOffset, uint32_t tailN){
mm.SetSingleShape(curSingleM, curSingleN, quantGroupSize_);
GlobalTensor<DTYPE_WEIGHT_A4W4> weightSlice;
mm.SetTensorA(xGm[xOffset]);
weightSlice = weightGm[weightOffset];
if (mnConfig.blockDimM == 1) {
weightSlice.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE);
}
mm.SetTensorB(weightSlice, mmType::BT::isTrans);
mm.SetQuantVector(scaleGm[groupIdx * tiling->n * tiling->quantGroupNum + tailN]);
#ifndef __CCE_KT_TEST__
while(mm.Iterate()) {
mnConfig.workSpaceOffset = mmBaseBlockOffset_ * \
(coreIdx + (cubeCount % PARALL_NUM) * tiling->coreNum);
if (cubeCount >= PARALL_NUM) {
CrossCoreWaitFlag(SYNC_AIV_TO_AIC);
}
mm.GetTensorC(mmOutGm[mnConfig.workSpaceOffset], 0, true);
CrossCoreSetFlag<2, PIPE_FIX>(SYNC_AIC_TO_AIV);
++cubeCount;
}
#endif
}
template <typename mmType>
__aicore__ inline void GMMA4W4Compute<mmType>::MMComputePerGroup(uint32_t groupIdx, MNConfig& mnConfig, uint32_t curSingleM, uint32_t curSingleN,
uint64_t xOffset, uint64_t weightOffset, uint32_t tailN){
mnConfig.workSpaceOffset = mmBaseBlockOffset_ * \
(coreIdx + (cubeCount % PARALL_NUM) * tiling->coreNum);
if (cubeCount >= PARALL_NUM) {
CrossCoreWaitFlag(SYNC_AIV_TO_AIC);
}
mm.SetSingleShape(curSingleM, curSingleN, quantGroupSize_);
GlobalTensor<DTYPE_WEIGHT_A4W4> weightSlice;
for (uint32_t loopK = 0; loopK < tiling->quantGroupNum; loopK++) {
mm.SetTensorA(xGm[xOffset + loopK * quantGroupSize_]);
if constexpr (mmType::BT::format == CubeFormat::NZ && mmType::BT::isTrans == true) {
weightSlice = weightGm[weightOffset + ((loopK * quantGroupSize_)) * tiling->n ];
} else if constexpr (mmType::BT::format == CubeFormat::NZ && mmType::BT::isTrans == false) {
weightSlice = weightGm[weightOffset + loopK * quantGroupSize_ * 64];
} else {
weightSlice = weightGm[weightOffset + loopK * quantGroupSize_ * tiling->n];
}
if (mnConfig.blockDimM == 1) {
weightSlice.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE);
}
mm.SetTensorB(weightSlice, mmType::BT::isTrans);
mm.SetQuantVector(scaleGm[groupIdx * tiling->n * tiling->quantGroupNum + loopK * tiling->n + tailN]);
#ifndef __CCE_KT_TEST__
mm.Iterate();
mm.GetTensorC(mmOutGm[mnConfig.workSpaceOffset], loopK == 0 ? 0 : 1, true);
#endif
}
CrossCoreSetFlag<2, PIPE_FIX>(SYNC_AIC_TO_AIV);
cubeCount++;
}
template <typename mmType>
__aicore__ inline void GMMA4W4Compute<mmType>::MMCompute(uint32_t groupIdx, MNConfig& mnConfig)
{
if ASCEND_IS_AIC {
uint32_t tailN = mnConfig.nIdx * mnConfig.singleN;
uint32_t curSingleN = mnConfig.singleN;
if (unlikely(mnConfig.nIdx == mnConfig.blockDimN - 1)) {
curSingleN = tiling->n - tailN;
}
uint32_t curSingleM = mnConfig.singleM;
if (unlikely(mnConfig.mIdx == mnConfig.blockDimM - 1)) {
curSingleM = mnConfig.m - mnConfig.mIdx * mnConfig.singleM;
}
uint64_t xOffset = (mnConfig.offsetM + mnConfig.mIdx * mnConfig.singleM) * tiling->k;
uint64_t weightOffset;
if constexpr (mmType::BT::format == CubeFormat::NZ && mmType::BT::isTrans == true) {
weightOffset = groupIdx * tiling->n * tiling->k + tailN * 64;
} else if constexpr (mmType::BT::format == CubeFormat::NZ && mmType::BT::isTrans == false) {
weightOffset = groupIdx * tiling->n * tiling->k + tailN * tiling->k;
} else {
weightOffset = groupIdx * tiling->n * tiling->k + tailN;
}
if(tiling->quantGroupNum == 1) {
MMComputePerChannel(groupIdx, mnConfig, curSingleM, curSingleN, xOffset, weightOffset, tailN);
} else {
MMComputePerGroup(groupIdx, mnConfig, curSingleM, curSingleN, xOffset, weightOffset, tailN);
}
}
}
template <typename mmType>
__aicore__ inline void GMMA4W4Compute<mmType>::VectorTilingCalc(
MNConfig& mnConfig, uint32_t& curCubeSingleN, uint32_t& curCubeSingleM, uint32_t& vecBaseM)
{
curCubeSingleN = mnConfig.nIdx == mnConfig.blockDimN - 1 ?
tiling->n - mnConfig.nIdx * mnConfig.singleN : mnConfig.singleN;
curCubeSingleM = mnConfig.mIdx == mnConfig.blockDimM - 1 ?
mnConfig.m - mnConfig.mIdx * mnConfig.singleM : mnConfig.singleM;
vecBaseM = tiling->ubCalSize / (Ceil(mnConfig.baseN, uint32_t(8)) * 8);
vecBaseM = vecBaseM < curCubeSingleM ? vecBaseM : curCubeSingleM;
}
template <typename mmType>
__aicore__ inline void GMMA4W4Compute<mmType>::VectorCompute(uint32_t groupIdx, MNConfig& mnConfig)
{
uint32_t curCubeSingleN;
uint32_t curCubeSingleM;
uint32_t vecBaseM;
VectorTilingCalc(mnConfig, curCubeSingleN, curCubeSingleM, vecBaseM);
uint32_t mGlobalOffset = mnConfig.offsetM + mnConfig.mIdx * mnConfig.singleM;
uint64_t outOffset = mGlobalOffset * tiling->n + mnConfig.nIdx * mnConfig.singleN;
uint32_t curVecBaseN = mnConfig.baseN;
uint32_t taskRation = GetTaskRation();
uint32_t nCount = 0;
for (uint32_t offsetN = 0; offsetN < curCubeSingleN; offsetN += mnConfig.baseN) {
mnConfig.workSpaceOffset = mmBaseBlockOffset_ * \
(coreIdx + (cubeCount % PARALL_NUM) * tiling->coreNum);
if (unlikely(offsetN + mnConfig.baseN >= curCubeSingleN)) curVecBaseN = curCubeSingleN - offsetN;
uint32_t alignBaseN = Ceil(curVecBaseN, uint32_t(16)) * 16;
uint32_t curVecBaseM = vecBaseM;
uint64_t mmOutOffset = mnConfig.workSpaceOffset;
uint32_t mCount = 0;
CrossCoreWaitFlag(SYNC_AIC_TO_AIV);
for (uint32_t offsetM = 0; offsetM < curCubeSingleM; offsetM += vecBaseM) {
vecCount_++;
if (taskRation != 0 && vecCount_ % taskRation != subBlockIdx) { continue; }
if (unlikely(offsetM + vecBaseM >= curCubeSingleM)) { curVecBaseM = curCubeSingleM - offsetM; }
LocalTensor<cT::T> mmOutLocal = vecInQueue.AllocTensor<cT::T>();
DataCopyPad2DA4W4(mmOutLocal, mmOutGm[mmOutOffset + offsetM * curVecBaseN], curVecBaseM, curVecBaseN, curVecBaseN);
vecInQueue.EnQue(mmOutLocal);
ComputeDequantAndActivate(mnConfig, curVecBaseM, alignBaseN, curVecBaseN, offsetM);
LocalTensor<DTYPE_OUT> yLocal = vecOutQueue.DeQue<DTYPE_OUT>();
DataCopyPad2D(yGm[outOffset + offsetM * tiling->n + offsetN], yLocal,
curVecBaseM, curVecBaseN, alignBaseN, tiling->n);
vecOutQueue.FreeTensor(yLocal);
}
CrossCoreSetFlag<2, PIPE_MTE2>(SYNC_AIV_TO_AIC);
++cubeCount;
}
}
template <typename mmType>
__aicore__ inline void GMMA4W4Compute<mmType>::ComputeDequantAndActivate(MNConfig& mnConfig,
uint32_t curVecBaseM, uint32_t alignBaseN, uint32_t curVecBaseN, uint32_t offsetM)
{
uint32_t computeSize = curVecBaseM * alignBaseN;
LocalTensor<cT::T> mmOutInUb = vecInQueue.DeQue<cT::T>();
uint32_t castSize = 0;
if constexpr (mmType::BT::format == CubeFormat::ND) {
castSize = (computeSize + HALF_ALIGN - 1) / HALF_ALIGN * HALF_ALIGN;
} else {
castSize = computeSize;
}
Cast(mmOutFp32Buf, mmOutInUb, RoundMode::CAST_NONE, castSize);
PipeBarrier<PIPE_V>();
vecInQueue.FreeTensor(mmOutInUb);
DataCopyPerTokenScaleAndBrcb(mnConfig, curVecBaseM, alignBaseN, offsetM);
Mul(perTokenResBuf, mmOutFp32Buf, pertokenBrcbLocal, computeSize);
PipeBarrier<PIPE_V>();
LocalTensor<DTYPE_OUT> yLocalInUb = vecOutQueue.AllocTensor<DTYPE_OUT>();
Cast(yLocalInUb, perTokenResBuf, RoundMode::CAST_RINT, computeSize);
PipeBarrier<PIPE_V>();
vecOutQueue.EnQue(yLocalInUb);
}
template <typename mmType>
__aicore__ inline void GMMA4W4Compute<mmType>::DataCopyPerTokenScaleAndBrcb(MNConfig& mnConfig,
uint32_t curBaseM, uint32_t alignBaseN, uint32_t offsetM)
{
uint64_t perTokenScaleOffset = mnConfig.offsetM + mnConfig.mIdx * mnConfig.singleM + offsetM;
DataCopyPadExtParams<float> padParams;
DataCopyExtParams perTokenScaleParams{1, static_cast<uint32_t>(curBaseM * sizeof(float)), 0, 0, 0};
LocalTensor<float> perTokenScaleLocal = perTokenScaleInQueue.AllocTensor<float>();
DataCopyPad(perTokenScaleLocal, perTokenScaleGm[perTokenScaleOffset], perTokenScaleParams, padParams);
perTokenScaleInQueue.EnQue(perTokenScaleLocal);
perTokenScaleLocal = perTokenScaleInQueue.DeQue<float>();
const uint32_t broadCastDst[2] = {curBaseM, alignBaseN};
const uint32_t broadCastSrc[2] = {curBaseM, 1};
BroadCast<float, 2, 1>(pertokenBrcbLocal, perTokenScaleLocal, broadCastDst, broadCastSrc, calcTmpBuf);
perTokenScaleInQueue.FreeTensor(perTokenScaleLocal);
}
template <class mmType>
class GMMA4W4SparseCompute : public GMMA4W4Compute<mmType> {
public:
__aicore__ inline GMMA4W4SparseCompute(typename mmType::MT &matmul) : GMMA4W4Compute<mmType>(matmul) {}
__aicore__ inline void Process();
};
template <typename mmType>
__aicore__ inline void GMMA4W4SparseCompute<mmType>::Process()
{
MNConfig mnConfig;
mnConfig.baseM = this->mmTilingData->baseM;
mnConfig.baseN = this->mmTilingData->baseN;
mnConfig.singleM = mnConfig.baseM;
if (this->tiling->quantGroupNum == 1 && this->tiling->singleN != 0) {
mnConfig.singleN = this->tiling->singleN;
} else {
mnConfig.singleN = mnConfig.baseN;
}
mnConfig.blockDimN = Ceil(this->tiling->n, mnConfig.singleN);
int32_t preOffset = 0;
uint32_t groupListSplitValueOffset = 1;
uint32_t groupListInnerShape = 2u;
uint32_t groupListShapeSize = this->tiling->groupNum * groupListInnerShape;
for (uint32_t loop = 0, preCount = 0; loop < groupListShapeSize; loop += groupListInnerShape) {
int32_t splitValue = static_cast<int32_t>(this->groupListGm.GetValue(loop + groupListSplitValueOffset));
if (splitValue <= 0) {
break;
}
uint32_t groupIdx = static_cast<uint32_t>(this->groupListGm.GetValue(loop));
this->ProcessCommon(mnConfig, splitValue, groupIdx, preCount);
}
}
}
#endif
#endif