/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "catlass/arch/arch.hpp"
#include "catlass/arch/cross_core_sync.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/catlass.hpp"
#include "catlass/debug.hpp"
#include "catlass/epilogue/block/block_epilogue.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm/block/block_mmad.hpp"
#include "catlass/gemm/dispatch_policy.hpp"
#include "catlass/gemm/gemm_type.hpp"
#include "catlass/layout/layout.hpp"
#include "kernel_common.hpp"
#include "kernel_operator.h"
using namespace Catlass;
/**
 * Flash-attention inference kernel.
 *
 * The call operator is specialised per core type:
 *  - the AIC specialisation issues the Q*K^T and P*V block matmuls and signals `qkReady` / `pvReady`;
 *  - the AIV specialisation runs the online-softmax and rescale-O epilogues, waiting on `qkReady` /
 *    `pvReady` and signalling `softmaxReady`.
 *
 * Both specialisations walk the same task list (taskIdx = coreIdx, coreIdx + coreNum, ...) and decode every
 * task into (batch, q-sequence block, q-head block). Within a task the KV sequence is consumed in stacks of
 * `blockStackNum` paged blocks; the PV / rescale-O stage of a stack is issued `preLaunch` stacks after its
 * QK / softmax stage, and the GM workspaces are rotated over (preLaunch + 1) slots per core.
 *
 * @tparam BlockMmadQK           block matmul used for Q*K^T on unmasked stacks
 * @tparam BlockMmadPV           block matmul used for P*V on unmasked stacks
 * @tparam BlockMmadQKTail       block matmul used for Q*K^T on the final (masked) stack
 * @tparam BlockMmadPVTail       block matmul used for P*V on the final (masked) stack
 * @tparam EpilogueOnlineSoftmax online-softmax epilogue
 * @tparam EpilogueRescaleO      output rescale epilogue
 * @tparam PAGED_CACHE_FLAG      when true the block table is offset per batch; otherwise K/V are offset
 *                               per batch by kvSeqlen * strideKV
 */
template <
    class BlockMmadQK,
    class BlockMmadPV,
    class BlockMmadQKTail,
    class BlockMmadPVTail,
    class EpilogueOnlineSoftmax,
    class EpilogueRescaleO,
    bool PAGED_CACHE_FLAG>
class FAInferKernel {
public:
    using ArchTag = typename BlockMmadQK::ArchTag;
    using L1TileShape = typename BlockMmadQK::L1TileShape;
    using ElementQ = typename BlockMmadQK::ElementA;
    using LayoutQ = typename BlockMmadQK::LayoutA;
    using ElementK = typename BlockMmadQK::ElementB;
    using LayoutK = typename BlockMmadQK::LayoutB;
    using ElementS = typename BlockMmadQK::ElementC;
    using LayoutS = typename BlockMmadQK::LayoutC;
    using ElementP = typename BlockMmadPV::ElementA;
    using LayoutP = typename BlockMmadPV::LayoutA;
    using ElementV = typename BlockMmadPV::ElementB;
    using LayoutV = typename BlockMmadPV::LayoutB;
    using ElementMask = typename EpilogueOnlineSoftmax::ElementMask;
    using LayoutMask = typename EpilogueOnlineSoftmax::LayoutMask;
    using ElementO = typename EpilogueRescaleO::ElementOutput;
    using LayoutO = typename EpilogueRescaleO::LayoutOutput;
    using ElementOTmp = typename EpilogueRescaleO::ElementInput;
    using LayoutOTmp = typename EpilogueRescaleO::LayoutInput;

    CATLASS_DEVICE
    FAInferKernel() {
    }

    /// Primary template; only the AIC and AIV specialisations below are defined in this class.
    template <int32_t CORE_TYPE = g_coreType>
    CATLASS_DEVICE void operator()(FAIKernelParams const &params);

    /**
     * AIC specialisation: issues the Q*K^T and P*V matmuls for every task owned by this core.
     *
     * @param params kernel arguments (GM addresses of q/k/v, block tables, sequence lengths, workspaces
     *               s/p/oTemp and the tiling data)
     */
    template <>
    CATLASS_DEVICE void operator()<AscendC::AIC>(FAIKernelParams const &params) {
        // Pre-set the hardware event flags; each one is matched by a WaitFlag at the end of this function.
        // Presumably the block operators wait on these IDs before their first use - TODO confirm.
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID1);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID2);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID3);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID4);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID5);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID6);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID7);
        AscendC::SetFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::FIX_M>(EVENT_ID1);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID1);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID2);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID3);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID4);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID5);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID6);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID7);

        // Bytes of one Q tile plus two K tiles; passed to the PV block operators as their start offset.
        static constexpr uint32_t L1_QK_SIZE = BlockMmadQK::L1TileShape::M * BlockMmadQK::L1TileShape::K
                                                   * sizeof(ElementQ)
                                               + BlockMmadQK::L1TileShape::N * BlockMmadQK::L1TileShape::K
                                                     * sizeof(ElementK) * 2;
        BlockMmadQK blockMmadQK(resource);
        BlockMmadPV blockMmadPV(resource, L1_QK_SIZE);
        BlockMmadQKTail blockMmadQKTail(resource);
        BlockMmadPVTail blockMmadPVTail(resource, L1_QK_SIZE);

        // Tiling parameters.
        __gm__ FATilingData *fATilingData = reinterpret_cast<__gm__ FATilingData *>(params.tiling);
        uint32_t qHeads = fATilingData->numHeads;
        uint32_t kvHeads = fATilingData->kvHeads;
        uint32_t embed = fATilingData->embeddingSize;
        uint32_t pagedBlockSize = fATilingData->blockSize;
        uint32_t maxNumBlocksPerBatch = fATilingData->maxNumBlocksPerBatch;
        uint32_t totalTaskNum = fATilingData->totalTaskNum;
        uint32_t maskType = fATilingData->maskType;

        // Global-memory views.
        AscendC::GlobalTensor<ElementQ> gQ;
        gQ.SetGlobalBuffer((__gm__ ElementQ *)params.q);
        AscendC::GlobalTensor<ElementK> gK;
        gK.SetGlobalBuffer((__gm__ ElementK *)params.k);
        AscendC::GlobalTensor<ElementV> gV;
        gV.SetGlobalBuffer((__gm__ ElementV *)params.v);
        AscendC::GlobalTensor<int32_t> gBlockTable;
        gBlockTable.SetGlobalBuffer((__gm__ int32_t *)(params.blockTables));
        AscendC::GlobalTensor<int64_t> gActualQseqlen;
        gActualQseqlen.SetGlobalBuffer((__gm__ int64_t *)params.actualQseqlen);
        AscendC::GlobalTensor<int64_t> gActualKvseqlen;
        gActualKvseqlen.SetGlobalBuffer((__gm__ int64_t *)params.actualKvseqlen);
        AscendC::GlobalTensor<ElementS> gS;
        gS.SetGlobalBuffer((__gm__ ElementS *)params.s);
        AscendC::GlobalTensor<ElementP> gP;
        gP.SetGlobalBuffer((__gm__ ElementP *)params.p);
        AscendC::GlobalTensor<ElementOTmp> gOTmp;
        gOTmp.SetGlobalBuffer((__gm__ ElementOTmp *)params.oTemp);

        uint64_t strideQO = qHeads * embed;
        uint64_t strideKV = kvHeads * embed;
        uint32_t groupSize = qHeads / kvHeads;
        uint32_t coreIdx = AscendC::GetBlockIdx();
        uint32_t coreNum = AscendC::GetBlockNum();

        // Running per-batch state; advanced lazily inside the task loop.
        uint32_t preTotalTaskNum = 0;
        uint32_t curBatch = 0;
        uint64_t qBOffset = 0;
        uint64_t kBOffset = 0;
        uint64_t vBOffset = 0;
        uint64_t blockBOffset = 0;
        int64_t qSeqlen = static_cast<int64_t>(gActualQseqlen.GetValue(curBatch));
        int64_t kvSeqlen = static_cast<int64_t>(gActualKvseqlen.GetValue(curBatch));
        int64_t curQSBlockTile = GetQSBlockTile(kvSeqlen);
        uint32_t curQNBlockTile = GetQNBlockTile(qSeqlen, groupSize);
        uint32_t qNBlockNumPerGroup = CeilDiv(groupSize, curQNBlockTile);
        uint32_t curQNBlockNum = qNBlockNumPerGroup * kvHeads;
        uint32_t curQSBlockNum = CeilDiv(qSeqlen, curQSBlockTile);
        // Exclusive upper bound of the task indices that belong to batches [0, curBatch].
        uint32_t curTotalTaskNum = curQNBlockNum * curQSBlockNum;

        for (uint32_t taskIdx = coreIdx; taskIdx < totalTaskNum; taskIdx += uint32_t(coreNum)) {
            // Advance to the batch that contains taskIdx, accumulating the per-batch GM offsets.
            while (taskIdx >= curTotalTaskNum) {
                ++curBatch;
                preTotalTaskNum = curTotalTaskNum;
                qBOffset += qSeqlen * strideQO;
                if constexpr (!PAGED_CACHE_FLAG) {
                    kBOffset += kvSeqlen * strideKV;
                    vBOffset += kvSeqlen * strideKV;
                } else {
                    blockBOffset += maxNumBlocksPerBatch;
                }
                qSeqlen = static_cast<int64_t>(gActualQseqlen.GetValue(curBatch));
                kvSeqlen = static_cast<int64_t>(gActualKvseqlen.GetValue(curBatch));
                curQSBlockTile = GetQSBlockTile(kvSeqlen);
                curQNBlockTile = GetQNBlockTile(qSeqlen, groupSize);
                qNBlockNumPerGroup = CeilDiv(groupSize, curQNBlockTile);
                curQNBlockNum = qNBlockNumPerGroup * kvHeads;
                curQSBlockNum = CeilDiv(qSeqlen, curQSBlockTile);
                curTotalTaskNum += curQNBlockNum * curQSBlockNum;
            }

            // Decode the task into (q-sequence block, kv head, q-head block within the group).
            uint32_t taskIdxCurBatch = taskIdx - preTotalTaskNum;
            uint32_t qSBlockIdx = taskIdxCurBatch / curQNBlockNum;
            uint32_t qNBlockIdx = taskIdxCurBatch - qSBlockIdx * curQNBlockNum;
            uint32_t qNBlockIdxCurGroup = qNBlockIdx % qNBlockNumPerGroup;
            uint32_t kvHeadIdx = qNBlockIdx / qNBlockNumPerGroup;
            uint32_t qHeadIdx = kvHeadIdx * groupSize + qNBlockIdxCurGroup * curQNBlockTile;
            uint64_t gmQOffset = qBOffset + qSBlockIdx * curQSBlockTile * strideQO + qHeadIdx * embed;
            uint64_t gmKOffset = kBOffset + kvHeadIdx * embed;
            uint64_t gmVOffset = vBOffset + kvHeadIdx * embed;

            // The last block along each axis may be partial.
            uint32_t qSBlockSize = (qSBlockIdx == (curQSBlockNum - 1)) ? (qSeqlen - qSBlockIdx * curQSBlockTile)
                                                                       : curQSBlockTile;
            uint32_t qNBlockSize = (qNBlockIdxCurGroup == (qNBlockNumPerGroup - 1))
                                       ? (groupSize - qNBlockIdxCurGroup * curQNBlockTile)
                                       : curQNBlockTile;
            uint32_t rowNum = qSBlockSize * qNBlockSize;

            // noSkipKvS: KV length that is processed at all; noMaskKvS: the prefix that needs no mask.
            uint32_t noSkipKvS = kvSeqlen;
            uint32_t noMaskKvS = kvSeqlen;
            uint32_t noMaskTailS = 0;
            if (maskType != 0) {
                uint32_t diffS = kvSeqlen - qSeqlen;
                noSkipKvS = (qSBlockIdx + 1) * curQSBlockTile + diffS;
                noSkipKvS = Min(static_cast<uint32_t>(kvSeqlen), noSkipKvS);
                noMaskKvS = noSkipKvS - qSBlockSize;
                noMaskTailS = noMaskKvS % pagedBlockSize;
            }
            uint32_t maskedKvS = qSBlockSize;
            uint32_t kvSLoopNumNoMask = CeilDiv(noMaskKvS, pagedBlockSize);
            uint32_t kvSLoopNumTotal = CeilDiv(noSkipKvS, pagedBlockSize);
            uint32_t blockStackNum = 4;
            uint32_t stackSeqTile;
            // NOTE(review): 128 is hard-coded here while the AIV side pads with pagedBlockSize - confirm the
            // two agree when pagedBlockSize != 128.
            uint32_t stackSeqTileRound = blockStackNum * 128;
            int32_t preLaunch = 2;
            int32_t totalStackSeqNum = (maskType != 0) ? (CeilDiv(noMaskKvS, blockStackNum * pagedBlockSize) + 1)
                                                       : CeilDiv(noMaskKvS, blockStackNum * pagedBlockSize);
            int32_t stackSeqCount = 0;

            LayoutQ layoutQTemp(rowNum, embed);
            LayoutK layoutKTemp(strideKV, blockStackNum * pagedBlockSize);
            LayoutV layoutVTemp(blockStackNum * pagedBlockSize, strideKV);
            blockMmadQK.loadQGM(gQ[gmQOffset], layoutQTemp, rowNum, qNBlockSize, qHeads);

            // Primary loop over the unmasked stacks: QK for the current stack, PV for the stack issued
            // preLaunch iterations earlier.
            for (uint32_t kvSIdx = 0; kvSIdx < kvSLoopNumNoMask; kvSIdx += blockStackNum) {
                // --- stage 1: S = Q * K^T for the stack starting at kvSIdx ---
                if (kvSIdx + blockStackNum > kvSLoopNumNoMask - 1) {
                    stackSeqTile = noMaskKvS - kvSIdx * pagedBlockSize;
                } else {
                    stackSeqTile = pagedBlockSize * blockStackNum;
                }
                uint32_t SWorkSpacePingPongFlag = stackSeqCount % (preLaunch + 1);
                uint64_t gmSOffset = coreIdx * WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1)
                                     + SWorkSpacePingPongFlag * WORKSPACE_BLOCK_SIZE_DB;
                GemmCoord actualBlockShapeQK{rowNum, stackSeqTile, embed};
                if constexpr (!PAGED_CACHE_FLAG) {
                    blockMmadQK(
                        gQ[gmQOffset], gK[gmKOffset], gS[gmSOffset], gBlockTable, layoutQTemp, layoutKTemp,
                        actualBlockShapeQK, kvSIdx, kvSLoopNumNoMask, pagedBlockSize, noMaskKvS, strideKV
                    );
                } else {
                    blockMmadQK(
                        gQ[gmQOffset], gK[gmKOffset], gS[gmSOffset], gBlockTable[blockBOffset], layoutQTemp,
                        layoutKTemp, actualBlockShapeQK, kvSIdx, kvSLoopNumNoMask, pagedBlockSize, noMaskKvS,
                        strideKV
                    );
                }
                Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(qkReady);

                // --- stage 2: OTmp = P * V for the stack issued preLaunch iterations ago ---
                if (kvSIdx >= preLaunch * blockStackNum) {
                    uint32_t nowkvSIdx = kvSIdx - preLaunch * blockStackNum;
                    if (nowkvSIdx + blockStackNum > kvSLoopNumNoMask - 1) {
                        stackSeqTile = noMaskKvS - nowkvSIdx * pagedBlockSize;
                    } else {
                        stackSeqTile = pagedBlockSize * blockStackNum;
                    }
                    uint32_t PVWorkSpacePingPongFlag = (stackSeqCount - preLaunch) % (preLaunch + 1);
                    uint64_t gmPOffset = coreIdx * WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1)
                                         + PVWorkSpacePingPongFlag * WORKSPACE_BLOCK_SIZE_DB;
                    uint64_t gmOTmpOffset = coreIdx * WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1)
                                            + PVWorkSpacePingPongFlag * WORKSPACE_BLOCK_SIZE_DB;
                    LayoutP layoutPTemp(rowNum, stackSeqTileRound);
                    GemmCoord actualBlockShapePV{rowNum, embed, stackSeqTile};
                    if constexpr (!PAGED_CACHE_FLAG) {
                        blockMmadPV(
                            gP[gmPOffset], gV[gmVOffset], gOTmp[gmOTmpOffset], gBlockTable, layoutPTemp, layoutVTemp,
                            actualBlockShapePV, nowkvSIdx, kvSLoopNumNoMask, pagedBlockSize, noMaskKvS, strideKV,
                            softmaxReady
                        );
                    } else {
                        blockMmadPV(
                            gP[gmPOffset], gV[gmVOffset], gOTmp[gmOTmpOffset], gBlockTable[blockBOffset], layoutPTemp,
                            layoutVTemp, actualBlockShapePV, nowkvSIdx, kvSLoopNumNoMask, pagedBlockSize, noMaskKvS,
                            strideKV, softmaxReady
                        );
                    }
                    Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(pvReady);
                }
                stackSeqCount++;
            }

            /*
             * for the secondary loop
             * while masked, it deals the CV stage1(Qk^t/SMOnline) of the final base block(typical shape [128, 512]),
             * and the CV stage2(PV/rescaleO) of the last (prelaunch+1) base blocks while not masked, it deals only the
             * CV stage2(PV/rescaleO) of the last (prelaunch) base blocks
             */
            uint32_t maskedStartIdx = (maskType != 0) ? ((noMaskTailS != 0) ? (kvSLoopNumNoMask - 1) : kvSLoopNumNoMask)
                                                      : AlignUp(kvSLoopNumNoMask, blockStackNum);
            uint32_t noMaskTailInteStackNum = (noMaskKvS / pagedBlockSize) % blockStackNum;
            noMaskTailInteStackNum = (noMaskTailInteStackNum != 0) ? noMaskTailInteStackNum
                                                                   : ((noMaskTailS != 0) ? 0 : blockStackNum);
            uint32_t preLaunchStackNum = (maskType != 0) ? ((preLaunch - 1) * blockStackNum + noMaskTailInteStackNum)
                                                         : (preLaunch * blockStackNum);
            for (uint32_t kvSIdx = maskedStartIdx; kvSIdx < kvSLoopNumTotal + preLaunchStackNum;) {
                // --- stage 1 of the masked tail stack ---
                if ((kvSIdx < kvSLoopNumTotal) && (stackSeqCount <= totalStackSeqNum - 1)) {
                    stackSeqTile = maskedKvS;
                    uint32_t SWorkSpacePingPongFlag = stackSeqCount % (preLaunch + 1);
                    uint64_t gmSOffset = coreIdx * WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1)
                                         + SWorkSpacePingPongFlag * WORKSPACE_BLOCK_SIZE_DB;
                    GemmCoord actualBlockShapeQK{rowNum, stackSeqTile, embed};
                    if constexpr (!PAGED_CACHE_FLAG) {
                        blockMmadQKTail(
                            gQ[gmQOffset], gK[gmKOffset], gS[gmSOffset], gBlockTable, layoutQTemp, layoutKTemp,
                            actualBlockShapeQK, kvSIdx, kvSLoopNumTotal, pagedBlockSize, noSkipKvS, strideKV,
                            noMaskTailS, 1
                        );
                    } else {
                        blockMmadQKTail(
                            gQ[gmQOffset], gK[gmKOffset], gS[gmSOffset], gBlockTable[blockBOffset], layoutQTemp,
                            layoutKTemp, actualBlockShapeQK, kvSIdx, kvSLoopNumTotal, pagedBlockSize, noSkipKvS,
                            strideKV, noMaskTailS, 1
                        );
                    }
                    Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(qkReady);
                }

                // --- stage 2 of the stacks still in flight ---
                if (kvSIdx >= preLaunchStackNum) {
                    uint32_t delayedKvSIdx = kvSIdx - preLaunchStackNum;
                    if (delayedKvSIdx + blockStackNum > kvSLoopNumTotal - 1 && (maskType != 0)) {
                        stackSeqTile = maskedKvS;
                    } else if (delayedKvSIdx + blockStackNum > kvSLoopNumNoMask - 1) {
                        stackSeqTile = noMaskKvS - delayedKvSIdx * pagedBlockSize;
                    } else {
                        stackSeqTile = pagedBlockSize * blockStackNum;
                    }
                    uint32_t PVWorkSpacePingPongFlag = (stackSeqCount - preLaunch) % (preLaunch + 1);
                    uint64_t gmPOffset = coreIdx * WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1)
                                         + PVWorkSpacePingPongFlag * WORKSPACE_BLOCK_SIZE_DB;
                    uint64_t gmOTmpOffset = coreIdx * WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1)
                                            + PVWorkSpacePingPongFlag * WORKSPACE_BLOCK_SIZE_DB;
                    LayoutP layoutPTemp(rowNum, stackSeqTileRound);
                    GemmCoord actualBlockShapePV{rowNum, embed, stackSeqTile};
                    if ((stackSeqCount - preLaunch == totalStackSeqNum - 1) && (maskType != 0)) {
                        // Final stack under masking goes through the tail operator.
                        if constexpr (!PAGED_CACHE_FLAG) {
                            blockMmadPVTail(
                                gP[gmPOffset], gV[gmVOffset], gOTmp[gmOTmpOffset], gBlockTable, layoutPTemp,
                                layoutVTemp, actualBlockShapePV, delayedKvSIdx, kvSLoopNumTotal, pagedBlockSize,
                                noSkipKvS, strideKV, softmaxReady, noMaskTailS, 1
                            );
                        } else {
                            blockMmadPVTail(
                                gP[gmPOffset], gV[gmVOffset], gOTmp[gmOTmpOffset], gBlockTable[blockBOffset],
                                layoutPTemp, layoutVTemp, actualBlockShapePV, delayedKvSIdx, kvSLoopNumTotal,
                                pagedBlockSize, noSkipKvS, strideKV, softmaxReady, noMaskTailS, 1
                            );
                        }
                    } else {
                        if constexpr (!PAGED_CACHE_FLAG) {
                            blockMmadPV(
                                gP[gmPOffset], gV[gmVOffset], gOTmp[gmOTmpOffset], gBlockTable, layoutPTemp,
                                layoutVTemp, actualBlockShapePV, delayedKvSIdx, kvSLoopNumNoMask, pagedBlockSize,
                                noMaskKvS, strideKV, softmaxReady
                            );
                        } else {
                            blockMmadPV(
                                gP[gmPOffset], gV[gmVOffset], gOTmp[gmOTmpOffset], gBlockTable[blockBOffset],
                                layoutPTemp, layoutVTemp, actualBlockShapePV, delayedKvSIdx, kvSLoopNumNoMask,
                                pagedBlockSize, noMaskKvS, strideKV, softmaxReady
                            );
                        }
                    }
                    Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(pvReady);
                }

                // The step before the last delayed stack is shortened to the partial stack width.
                if ((maskType != 0) && (stackSeqCount - preLaunch == totalStackSeqNum - 2)) {
                    kvSIdx += noMaskTailInteStackNum;
                } else {
                    kvSIdx += blockStackNum;
                }
                stackSeqCount++;
            }
        }

        // Consume the event flags that were pre-set at entry.
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID1);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID2);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID3);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID4);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID5);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID6);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID7);
        AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(EVENT_ID1);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID1);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID2);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID3);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID4);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID5);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID6);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID7);
    }

    /**
     * AIV specialisation: runs online softmax on S (producing P) and rescales / writes O for every task
     * owned by this core.
     *
     * @param params kernel arguments (GM addresses of mask, o, sequence lengths, workspaces s/p/oTemp and
     *               the tiling data)
     */
    template <>
    CATLASS_DEVICE void operator()<AscendC::AIV>(FAIKernelParams const &params) {
        // Pre-set the hardware event flags; each one is matched by a WaitFlag at the end of this function.
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID2);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID3);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID4);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID5);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID2);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);

        // Tiling parameters.
        __gm__ FATilingData *fATilingData = reinterpret_cast<__gm__ FATilingData *>(params.tiling);
        uint32_t qHeads = fATilingData->numHeads;
        uint32_t kvHeads = fATilingData->kvHeads;
        uint32_t embed = fATilingData->embeddingSize;
        uint32_t pagedBlockSize = fATilingData->blockSize;
        uint32_t firstBatchTaskNum = fATilingData->firstBatchTaskNum;
        uint32_t totalTaskNum = fATilingData->totalTaskNum;
        uint32_t maskType = fATilingData->maskType;
        float scaleValue = fATilingData->scaleValue;

        // Global-memory views.
        AscendC::GlobalTensor<ElementMask> gMask;
        gMask.SetGlobalBuffer((__gm__ ElementMask *)params.mask);
        AscendC::GlobalTensor<int64_t> gActualQseqlen;
        gActualQseqlen.SetGlobalBuffer((__gm__ int64_t *)params.actualQseqlen);
        AscendC::GlobalTensor<int64_t> gActualKvseqlen;
        gActualKvseqlen.SetGlobalBuffer((__gm__ int64_t *)params.actualKvseqlen);
        AscendC::GlobalTensor<ElementO> gO;
        gO.SetGlobalBuffer((__gm__ ElementO *)params.o);
        AscendC::GlobalTensor<ElementS> gS;
        gS.SetGlobalBuffer((__gm__ ElementS *)params.s);
        AscendC::GlobalTensor<ElementP> gP;
        gP.SetGlobalBuffer((__gm__ ElementP *)params.p);
        AscendC::GlobalTensor<ElementOTmp> gOTmp;
        gOTmp.SetGlobalBuffer((__gm__ ElementOTmp *)params.oTemp);

        uint32_t groupSize = qHeads / kvHeads;
        uint32_t embedRound = RoundUp(embed, BLOCK_SIZE);
        EpilogueOnlineSoftmax epilogueOnlineSoftmax(resource, scaleValue);
        EpilogueRescaleO epilogueRescaleO(resource);

        // Running per-batch state; advanced lazily inside the task loop.
        uint32_t preTotalTaskNum = 0;
        uint32_t curBatch = 0;
        // 64-bit so the output offset cannot wrap for large (tokens * heads * embed) products; this matches
        // the 64-bit Q offset used by the AIC specialisation.
        uint64_t oBatchOffset = 0;
        uint32_t qSeqlen = static_cast<uint32_t>(gActualQseqlen.GetValue(curBatch));
        uint32_t kvSeqlen = static_cast<uint32_t>(gActualKvseqlen.GetValue(curBatch));
        uint32_t curQNBlockTile = GetQNBlockTile(qSeqlen, groupSize);
        uint32_t qNBlockNumPerGroup = CeilDiv(groupSize, curQNBlockTile);
        uint32_t curQNBlockNum = qNBlockNumPerGroup * kvHeads;
        uint32_t curQSBlockTile = GetQSBlockTile(kvSeqlen);
        uint32_t curQSBlockNum = CeilDiv(qSeqlen, curQSBlockTile);
        uint32_t curTotalTaskNum = firstBatchTaskNum;
        // Sub-blocks of the same AI core share one task stream.
        uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum();
        uint32_t coreNum = AscendC::GetBlockNum();

        for (uint32_t taskIdx = coreIdx; taskIdx < totalTaskNum; taskIdx += uint32_t(coreNum)) {
            // Advance to the batch that contains taskIdx.
            while (taskIdx >= curTotalTaskNum) {
                curBatch++;
                oBatchOffset += static_cast<uint64_t>(qSeqlen) * qHeads * embed;
                preTotalTaskNum = curTotalTaskNum;
                qSeqlen = static_cast<uint32_t>(gActualQseqlen.GetValue(curBatch));
                kvSeqlen = static_cast<uint32_t>(gActualKvseqlen.GetValue(curBatch));
                curQNBlockTile = GetQNBlockTile(qSeqlen, groupSize);
                qNBlockNumPerGroup = CeilDiv(groupSize, curQNBlockTile);
                curQNBlockNum = qNBlockNumPerGroup * kvHeads;
                curQSBlockTile = GetQSBlockTile(kvSeqlen);
                curQSBlockNum = CeilDiv(qSeqlen, curQSBlockTile);
                curTotalTaskNum += curQNBlockNum * curQSBlockNum;
            }

            // Decode the task into (q-sequence block, kv head, q-head block within the group).
            uint32_t taskIdxCurBatch = taskIdx - preTotalTaskNum;
            uint32_t qSBlockIdx = taskIdxCurBatch / curQNBlockNum;
            uint32_t qNBlockIdx = taskIdxCurBatch % curQNBlockNum;
            uint32_t qNBlockIdxCurGroup = qNBlockIdx % qNBlockNumPerGroup;
            uint64_t oSOffset = static_cast<uint64_t>(qSBlockIdx) * curQSBlockTile * qHeads * embed;
            uint32_t kvNIdx = qNBlockIdx / qNBlockNumPerGroup;
            uint32_t qStartNIdx = kvNIdx * groupSize + qNBlockIdxCurGroup * curQNBlockTile;
            uint64_t oNOffset = static_cast<uint64_t>(qStartNIdx) * embed;
            uint64_t gmOffsetO = oBatchOffset + oSOffset + oNOffset;

            // The last block along each axis may be partial.
            uint32_t qSBlockSize = (qSBlockIdx == (curQSBlockNum - 1)) ? (qSeqlen - qSBlockIdx * curQSBlockTile)
                                                                       : curQSBlockTile;
            uint32_t qNBlockSize = (qNBlockIdxCurGroup == (qNBlockNumPerGroup - 1))
                                       ? (groupSize - qNBlockIdxCurGroup * curQNBlockTile)
                                       : curQNBlockTile;
            uint32_t rowNum = qSBlockSize * qNBlockSize;

            // noSkipKvS: KV length that is processed at all; noMaskKvS: the prefix that needs no mask.
            uint32_t noSkipKvS = kvSeqlen;
            uint32_t noMaskKvS = kvSeqlen;
            uint32_t noMaskTailS = 0;
            if (maskType != 0) {
                uint32_t diffS = kvSeqlen - qSeqlen;
                noSkipKvS = (qSBlockIdx + 1) * curQSBlockTile + diffS;
                noSkipKvS = Min(kvSeqlen, noSkipKvS);
                noMaskKvS = noSkipKvS - qSBlockSize;
                noMaskTailS = noMaskKvS % pagedBlockSize;
            }
            uint32_t maskedKvS = qSBlockSize;
            uint32_t kvSLoopNumTotal = CeilDiv(noSkipKvS, pagedBlockSize);
            uint32_t kvSLoopNumNoMask = CeilDiv(noMaskKvS, pagedBlockSize);
            uint32_t blockStackNum = 4;
            uint32_t stackSeqTilePad = blockStackNum * pagedBlockSize;
            uint32_t stackSeqTile;
            int32_t preLaunch = 2;
            int32_t totalStackSeqNum = (maskType != 0) ? (CeilDiv(noMaskKvS, blockStackNum * pagedBlockSize) + 1)
                                                       : CeilDiv(noMaskKvS, blockStackNum * pagedBlockSize);
            int32_t stackSeqCount = 0;

            // Primary loop over the unmasked stacks: softmax for the current stack, rescale-O for the stack
            // issued preLaunch iterations earlier.
            for (uint32_t kvSIdx = 0; kvSIdx < kvSLoopNumNoMask; kvSIdx += blockStackNum) {
                // --- stage 1: P = online_softmax(S) ---
                if (kvSIdx + blockStackNum > kvSLoopNumNoMask - 1) {
                    stackSeqTile = noMaskKvS - kvSIdx * pagedBlockSize;
                } else {
                    stackSeqTile = pagedBlockSize * blockStackNum;
                }
                LayoutS layOutS(rowNum, stackSeqTile, stackSeqTilePad);
                LayoutP layOutP(rowNum, stackSeqTile, stackSeqTilePad);
                GemmCoord actualBlockShapeQK{rowNum, stackSeqTile, embed};
                uint32_t curStackTileMod = stackSeqCount % (preLaunch + 1);
                uint64_t gmOffsetS = coreIdx * WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1) +
                                     curStackTileMod * WORKSPACE_BLOCK_SIZE_DB;
                uint64_t gmOffsetP = gmOffsetS;
                Arch::CrossCoreWaitFlag(qkReady);
                epilogueOnlineSoftmax(
                    gP[gmOffsetP], gS[gmOffsetS], layOutP, layOutS, actualBlockShapeQK, (stackSeqCount == 0),
                    qSBlockSize, qNBlockSize, curStackTileMod
                );
                Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(softmaxReady);

                // --- stage 2: rescale O with the PV result of the delayed stack ---
                if (kvSIdx >= preLaunch * blockStackNum) {
                    uint32_t delayedKvSIdx = kvSIdx - preLaunch * blockStackNum;
                    if (delayedKvSIdx + blockStackNum > kvSLoopNumNoMask - 1) {
                        // Sized from the delayed stack (not the current one), as in the tail loop below.
                        stackSeqTile = noMaskKvS - delayedKvSIdx * pagedBlockSize;
                    } else {
                        stackSeqTile = pagedBlockSize * blockStackNum;
                    }
                    // NOTE(review): the tail loop builds layoutO with qSBlockSize rows while this loop uses
                    // qSeqlen - confirm which is intended.
                    LayoutO layoutO(qSeqlen, embed * qHeads);
                    LayoutOTmp layoutOTmp(rowNum, embed, embedRound);
                    GemmCoord actualBlockShapePV{rowNum, embed, stackSeqTile};
                    uint32_t pvStackTileMod = (stackSeqCount - preLaunch) % (preLaunch + 1);
                    uint64_t gmOffsetOTmp = coreIdx * WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1)
                                            + pvStackTileMod * WORKSPACE_BLOCK_SIZE_DB;
                    Arch::CrossCoreWaitFlag(pvReady);
                    epilogueRescaleO(
                        gO[gmOffsetO], gOTmp[gmOffsetOTmp], layoutO, layoutOTmp, actualBlockShapePV, qSBlockSize,
                        qNBlockSize, (stackSeqCount - preLaunch == 0), 0, pvStackTileMod
                    );
                }
                stackSeqCount++;
            }

            /*
             * for the secondary loop
             * while masked, it deals the CV stage1(Qk^t/SMOnline) of the final base block(typical shape [128, 512]),
             * and the CV stage2(PV/rescaleO) of the last (prelaunch+1) base blocks while unmasked, it deals the CV
             * stage1(Qk^t/SMOnline) of the last (prelaunch+1) base blocks
             */
            uint32_t maskedStartIdx = (maskType != 0) ? ((noMaskTailS != 0) ? (kvSLoopNumNoMask - 1) : kvSLoopNumNoMask)
                                                      : AlignUp(kvSLoopNumNoMask, blockStackNum);
            uint32_t noMaskTailInteStackNum = (noMaskKvS / pagedBlockSize) % blockStackNum;
            noMaskTailInteStackNum = (noMaskTailInteStackNum != 0) ? noMaskTailInteStackNum
                                                                   : ((noMaskTailS != 0) ? 0 : blockStackNum);
            uint32_t preLaunchStackNum = (maskType != 0) ? ((preLaunch - 1) * blockStackNum + noMaskTailInteStackNum)
                                                         : (preLaunch * blockStackNum);
            for (uint32_t kvSIdx = maskedStartIdx; kvSIdx < kvSLoopNumTotal + preLaunchStackNum;) {
                // --- stage 1 of the masked tail stack (qkReady is handed to the epilogue here) ---
                if ((kvSIdx < kvSLoopNumTotal) && (stackSeqCount <= totalStackSeqNum - 1)) {
                    stackSeqTile = maskedKvS;
                    LayoutS layOutS(rowNum, stackSeqTile, stackSeqTilePad);
                    LayoutP layOutP(rowNum, stackSeqTile, stackSeqTilePad);
                    LayoutMask layOutMask(1024, 1024, 1024);
                    GemmCoord actualBlockShapeQK{rowNum, stackSeqTile, embed};
                    uint32_t curStackTileMod = stackSeqCount % (preLaunch + 1);
                    uint64_t gmOffsetS = coreIdx * WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1) +
                                         curStackTileMod * WORKSPACE_BLOCK_SIZE_DB;
                    uint64_t gmOffsetP = gmOffsetS;
                    epilogueOnlineSoftmax(
                        gP[gmOffsetP], gS[gmOffsetS], gMask, layOutP, layOutS, layOutMask, actualBlockShapeQK,
                        (stackSeqCount == 0), qSBlockSize, qNBlockSize, curStackTileMod, qkReady
                    );
                    Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(softmaxReady);
                }

                // --- stage 2 of the stacks still in flight ---
                if (kvSIdx >= preLaunchStackNum) {
                    uint32_t delayedKvSIdx = kvSIdx - preLaunchStackNum;
                    if (delayedKvSIdx + blockStackNum > kvSLoopNumTotal - 1 && (maskType != 0)) {
                        stackSeqTile = maskedKvS;
                    } else if (delayedKvSIdx + blockStackNum > kvSLoopNumNoMask - 1) {
                        stackSeqTile = noMaskKvS - delayedKvSIdx * pagedBlockSize;
                    } else {
                        stackSeqTile = pagedBlockSize * blockStackNum;
                    }
                    LayoutO layoutO(qSBlockSize, embed * qHeads);
                    LayoutOTmp layoutOTmp(rowNum, embed, embedRound);
                    GemmCoord actualBlockShapePV{rowNum, embed, stackSeqTile};
                    uint32_t curStackTileMod = (stackSeqCount - preLaunch) % (preLaunch + 1);
                    uint64_t gmOffsetOTmp = coreIdx * WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1)
                                            + curStackTileMod * WORKSPACE_BLOCK_SIZE_DB;
                    Arch::CrossCoreWaitFlag(pvReady);
                    epilogueRescaleO(
                        gO[gmOffsetO], gOTmp[gmOffsetOTmp], layoutO, layoutOTmp, actualBlockShapePV, qSBlockSize,
                        qNBlockSize, (stackSeqCount - preLaunch == 0),
                        (stackSeqCount - preLaunch == totalStackSeqNum - 1), curStackTileMod
                    );
                }

                // The step before the last delayed stack is shortened to the partial stack width.
                if ((maskType != 0) && (stackSeqCount - preLaunch == totalStackSeqNum - 2)) {
                    kvSIdx += noMaskTailInteStackNum;
                } else {
                    kvSIdx += blockStackNum;
                }
                stackSeqCount++;
            }
        }

        // Consume the event flags that were pre-set at entry.
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID2);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID3);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID4);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID5);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID2);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
    }

private:
    Arch::Resource<ArchTag> resource;
    // Cross-core flags: AIC sets qkReady / pvReady, AIV sets softmaxReady.
    Arch::CrossCoreFlag qkReady{QK_READY_ID};
    Arch::CrossCoreFlag softmaxReady{SOFTMAX_READY_ID};
    Arch::CrossCoreFlag pvReady{PV_READY_ID};
};
/**
 * Kernel entry point for flash-attention inference with `half` inputs and outputs.
 *
 * Q, K, V, P, O and the mask use `half`; S, the temporary O and the O-update buffer use `float`.
 * K is column-major, every other operand is row-major. The kernel is instantiated with
 * PAGED_CACHE_FLAG = true.
 *
 * @param fftsAddr        base address passed to AscendC::SetSyncBaseAddr
 * @param q, k, v         GM addresses of the query / key / value tensors
 * @param mask            GM address of the attention mask
 * @param blockTables     GM address of the block table
 * @param o               GM address of the output tensor
 * @param actualQseqlen   GM address of the per-batch query lengths
 * @param actualKvseqlen  GM address of the per-batch key/value lengths
 * @param s, p            GM workspaces for the Q*K^T result and the softmax result
 * @param oTemp, oUpdate  GM workspaces for the P*V result and the O update
 * @param tiling          GM address of the tiling data
 */
CATLASS_GLOBAL void FAInferFp16(
    uint64_t fftsAddr,
    GM_ADDR q,
    GM_ADDR k,
    GM_ADDR v,
    GM_ADDR mask,
    GM_ADDR blockTables,
    GM_ADDR o,
    GM_ADDR actualQseqlen,
    GM_ADDR actualKvseqlen,
    GM_ADDR s,
    GM_ADDR p,
    GM_ADDR oTemp,
    GM_ADDR oUpdate,
    GM_ADDR tiling
) {
    AscendC::SetSyncBaseAddr(fftsAddr);

    // Element and layout vocabulary for this precision.
    using ElementIO = half;    // Q, K, V, P, O, mask
    using ElementAcc = float;  // S, OTmp, O update
    using RowMajor = layout::RowMajor;
    using ColumnMajor = layout::ColumnMajor;

    // Operand descriptors.
    using QType = Gemm::GemmType<ElementIO, RowMajor>;
    using KType = Gemm::GemmType<ElementIO, ColumnMajor>;
    using VType = Gemm::GemmType<ElementIO, RowMajor>;
    using SType = Gemm::GemmType<ElementAcc, RowMajor>;
    using PType = Gemm::GemmType<ElementIO, RowMajor>;
    using MaskType = Gemm::GemmType<ElementIO, RowMajor>;
    using OType = Gemm::GemmType<ElementIO, RowMajor>;
    using OTmpType = Gemm::GemmType<ElementAcc, RowMajor>;
    using OUpdateType = Gemm::GemmType<ElementAcc, RowMajor>;

    // Tile shapes (L0 tile equals L1 tile).
    using L1TileShape = GemmShape<128, 128, 128>;
    using L0TileShape = L1TileShape;

    // Q * K^T: regular and tail variants.
    using BlockMmadQK =
        Gemm::Block::BlockMmad<Gemm::MmadAtlasA2FAIQK<true, false>, L1TileShape, L0TileShape, QType, KType, SType>;
    using BlockMmadQKTail = Gemm::Block::BlockMmad<
        Gemm::MmadAtlasA2FAITailQK<true, false>, L1TileShape, L0TileShape, QType, KType, SType>;

    // P * V: regular and tail variants.
    using BlockMmadPV =
        Gemm::Block::BlockMmad<Gemm::MmadAtlasA2FAIPV<true, false>, L1TileShape, L0TileShape, PType, VType, OTmpType>;
    using BlockMmadPVTail = Gemm::Block::BlockMmad<
        Gemm::MmadAtlasA2FAITailPV<true, false>, L1TileShape, L0TileShape, PType, VType, OTmpType>;

    // Vector-side epilogues.
    using EpilogueOnlineSoftmax =
        Epilogue::Block::BlockEpilogue<Epilogue::EpilogueAtlasA2OnlineSoftmax, PType, SType, MaskType>;
    using EpilogueRescaleO =
        Epilogue::Block::BlockEpilogue<Epilogue::EpilogueAtlasA2RescaleO, OType, OTmpType, OUpdateType>;

    using Kernel = FAInferKernel<
        BlockMmadQK, BlockMmadPV, BlockMmadQKTail, BlockMmadPVTail, EpilogueOnlineSoftmax, EpilogueRescaleO, true>;

    FAIKernelParams params{q, k, v, mask, blockTables, actualQseqlen, actualKvseqlen, o, s, p, oTemp, oUpdate, tiling};
    Kernel kernel;
    kernel(params);
}
/**
 * Kernel entry point for flash-attention inference with `bfloat16_t` inputs and outputs.
 *
 * Q, K, V, P, O and the mask use `bfloat16_t`; S, the temporary O and the O-update buffer use `float`.
 * K is column-major, every other operand is row-major. The kernel is instantiated with
 * PAGED_CACHE_FLAG = true.
 *
 * @param fftsAddr        base address passed to AscendC::SetSyncBaseAddr
 * @param q, k, v         GM addresses of the query / key / value tensors
 * @param mask            GM address of the attention mask
 * @param blockTables     GM address of the block table
 * @param o               GM address of the output tensor
 * @param actualQseqlen   GM address of the per-batch query lengths
 * @param actualKvseqlen  GM address of the per-batch key/value lengths
 * @param s, p            GM workspaces for the Q*K^T result and the softmax result
 * @param oTemp, oUpdate  GM workspaces for the P*V result and the O update
 * @param tiling          GM address of the tiling data
 */
CATLASS_GLOBAL void FAInferBf16(
    uint64_t fftsAddr,
    GM_ADDR q,
    GM_ADDR k,
    GM_ADDR v,
    GM_ADDR mask,
    GM_ADDR blockTables,
    GM_ADDR o,
    GM_ADDR actualQseqlen,
    GM_ADDR actualKvseqlen,
    GM_ADDR s,
    GM_ADDR p,
    GM_ADDR oTemp,
    GM_ADDR oUpdate,
    GM_ADDR tiling
) {
    AscendC::SetSyncBaseAddr(fftsAddr);

    // Element and layout vocabulary for this precision.
    using ElementIO = bfloat16_t;  // Q, K, V, P, O, mask
    using ElementAcc = float;      // S, OTmp, O update
    using RowMajor = layout::RowMajor;
    using ColumnMajor = layout::ColumnMajor;

    // Operand descriptors.
    using QType = Gemm::GemmType<ElementIO, RowMajor>;
    using KType = Gemm::GemmType<ElementIO, ColumnMajor>;
    using VType = Gemm::GemmType<ElementIO, RowMajor>;
    using SType = Gemm::GemmType<ElementAcc, RowMajor>;
    using PType = Gemm::GemmType<ElementIO, RowMajor>;
    using MaskType = Gemm::GemmType<ElementIO, RowMajor>;
    using OType = Gemm::GemmType<ElementIO, RowMajor>;
    using OTmpType = Gemm::GemmType<ElementAcc, RowMajor>;
    using OUpdateType = Gemm::GemmType<ElementAcc, RowMajor>;

    // Tile shapes (L0 tile equals L1 tile).
    using L1TileShape = GemmShape<128, 128, 128>;
    using L0TileShape = L1TileShape;

    // Q * K^T: regular and tail variants.
    using BlockMmadQK =
        Gemm::Block::BlockMmad<Gemm::MmadAtlasA2FAIQK<true, false>, L1TileShape, L0TileShape, QType, KType, SType>;
    using BlockMmadQKTail = Gemm::Block::BlockMmad<
        Gemm::MmadAtlasA2FAITailQK<true, false>, L1TileShape, L0TileShape, QType, KType, SType>;

    // P * V: regular and tail variants.
    using BlockMmadPV =
        Gemm::Block::BlockMmad<Gemm::MmadAtlasA2FAIPV<true, false>, L1TileShape, L0TileShape, PType, VType, OTmpType>;
    using BlockMmadPVTail = Gemm::Block::BlockMmad<
        Gemm::MmadAtlasA2FAITailPV<true, false>, L1TileShape, L0TileShape, PType, VType, OTmpType>;

    // Vector-side epilogues.
    using EpilogueOnlineSoftmax =
        Epilogue::Block::BlockEpilogue<Epilogue::EpilogueAtlasA2OnlineSoftmax, PType, SType, MaskType>;
    using EpilogueRescaleO =
        Epilogue::Block::BlockEpilogue<Epilogue::EpilogueAtlasA2RescaleO, OType, OTmpType, OUpdateType>;

    using Kernel = FAInferKernel<
        BlockMmadQK, BlockMmadPV, BlockMmadQKTail, BlockMmadPVTail, EpilogueOnlineSoftmax, EpilogueRescaleO, true>;

    FAIKernelParams params{q, k, v, mask, blockTables, actualQseqlen, actualKvseqlen, o, s, p, oTemp, oUpdate, tiling};
    Kernel kernel;
    kernel(params);
}