Copyright (c) 2026 Huawei Technologies Co., Ltd.
This program is free software, you can redistribute it and/or modify it under the terms and conditions of
CANN Open Software License Agreement Version 2.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
*/
#include <acl/acl.h>
#include <pto/pto-inst.hpp>
#include <pto/common/fifo.hpp>
#include "fa_performance_kernel.h"
#include <pto/npu/kernels/Pto_prefetch.hpp>
#include <pto/npu/a5/custom/TSyncCVID.hpp>
#include <pto/npu/a5/custom/TSync_Custom.hpp>
#define UF_ENABLE 0
#include "pto_macro_matmul.hpp"
#include "pto_macro_fa_softmax.hpp"
#include "pto_macro_fa_gu.hpp"
using namespace std;
using namespace pto;
// FIFO_MODE selects which data path the kernels in this file are compiled for.
// The accepted values are named in the #error text below:
//   0 = ALL_GM_PATH, 1 = ALL_UB_PATH, 2 = QK_PV_UB_ONLY (the default when nothing is predefined).
#ifndef FIFO_MODE
#define FIFO_MODE 2
#endif
#if (FIFO_MODE < 0) || (FIFO_MODE > 2)
#error "FIFO_MODE must be 0 (ALL_GM_PATH), 1 (ALL_UB_PATH), or 2 (QK_PV_UB_ONLY)"
#endif
// Modes 1 and 2 are rejected at compile time unless UF_ENABLE is 0.
#if ((FIFO_MODE == 1) || (FIFO_MODE == 2)) && (UF_ENABLE != 0)
#error "UF_ENABLE must be 0 for mode 1 (ALL_UB_PATH), and mode 2 (QK_PV_UB_ONLY)"
#endif
// Expand FIFO_MODE into the three path switches that the #if blocks further down test:
//   mode 0 -> all three off, mode 1 -> all three on, mode 2 -> QK and PV on, UB_TO_L1 off.
#if FIFO_MODE == 0
#define USE_L0C_TO_DUAL_UB_PATH_QK 0
#define USE_L0C_TO_UB_PV_PATH 0
#define USE_UB_TO_L1_PATH 0
#elif FIFO_MODE == 1
#define USE_L0C_TO_DUAL_UB_PATH_QK 1
#define USE_L0C_TO_UB_PV_PATH 1
#define USE_UB_TO_L1_PATH 1
#else
#define USE_L0C_TO_DUAL_UB_PATH_QK 1
#define USE_L0C_TO_UB_PV_PATH 1
#define USE_UB_TO_L1_PATH 0
#endif
// Named flag ids, guarded so the enum is defined only once even if another header supplies it.
// The values are spaced two apart (0, 2, 4, 6, 8, 10).
// NOTE(review): presumably each flag reserves a pair of consecutive ids — confirm against the sync helpers;
// none of these enumerators is referenced in the part of the file visible here.
#ifndef FFTS_BUFFER_FLAG_ENUM
#define FFTS_BUFFER_FLAG_ENUM
enum FftsBufferFlag : uint32_t
{
BUF0_QK_READY = 0,
BUF1_SM_READY = 2,
UPDATE_READY = 4,
UB_BUF_READY = 6,
PV_UB_BUF_READY = 8,
CV_BLOCK_END = 10,
};
#endif
// Four consecutive event ids (0..3): two named for QK and two for PV.
// NOTE(review): presumably the two ids per stage correspond to double buffering — confirm at the call sites,
// which are not visible in this part of the file.
enum CoreEvtID : uint32_t
{
QK_EVENT_ID0,
QK_EVENT_ID1,
PV_EVENT_ID0,
PV_EVENT_ID1,
};
// Force-inline attribute macro; defined here only when no earlier header has already provided PTO_INLINE.
#ifndef PTO_INLINE
#define PTO_INLINE __attribute__((always_inline)) inline
#endif
// constexpr mirrors of the __DAV_CUBE__ / __DAV_VEC__ compiler macros, so the kernels below can select
// their cube-only or vector-only bodies with `if constexpr` instead of #ifdef.
#ifdef __DAV_CUBE__
constexpr bool DAV_CUBE = true;
#else
constexpr bool DAV_CUBE = false;
#endif
#ifdef __DAV_VEC__
constexpr bool DAV_VEC = true;
#else
constexpr bool DAV_VEC = false;
#endif
// Byte budgets (512 KiB and 256 KiB) enforced by the static_asserts in the tile-allocation helpers.
constexpr std::size_t MAX_TILE_L1_BYTES = 512U * 1024U;
constexpr std::size_t MAX_VEC_UB_BYTES = 256U * 1024U;
// Tells the caller whether iteration `sync_iter` has to wait before proceeding.
//   FifoSize   : FIFO depth; iterations below this value never wait.
//   SyncPeriod : wait cadence; any value <= 0 is treated as 1.
// Returns true only when sync_iter >= FifoSize and sync_iter is a multiple of the effective period.
template <int FifoSize, int SyncPeriod>
AICORE inline bool should_wait_consumption(int sync_iter)
{
    static_assert(FifoSize >= 1, "CV FIFO size must be >= 1");
    constexpr int effective_period = (SyncPeriod > 0) ? SyncPeriod : 1;
    static_assert(effective_period >= 1, "CV FIFO consume sync period must be >= 1");

    const bool past_initial_fill = (sync_iter >= FifoSize);
    const bool on_period_boundary = ((sync_iter % effective_period) == 0);
    return past_initial_fill && on_period_boundary;
}
// Tells the caller whether iteration `sync_iter` closes a sync period and therefore has to notify.
//   FifoSize   : only validated (must be >= 1); it does not influence the result.
//   SyncPeriod : notify cadence; any value <= 0 is treated as 1.
// Returns true when (sync_iter + 1) is a multiple of the effective period.
template <int FifoSize, int SyncPeriod>
AICORE inline bool should_notify_consumption(int sync_iter)
{
    static_assert(FifoSize >= 1, "CV FIFO size must be >= 1");
    constexpr int effective_period = (SyncPeriod > 0) ? SyncPeriod : 1;
    static_assert(effective_period >= 1, "CV FIFO consume sync period must be >= 1");

    const int iterations_done = sync_iter + 1;
    const int remainder = iterations_done % effective_period;
    return remainder == 0;
}
// Computes how many notifications are still unmatched by a wait after `tiles_processed` iterations,
// using the same rules as should_notify_consumption / should_wait_consumption:
//   notifications = tiles_processed / sync_period
//   waits         = multiples of sync_period in [fifo_size, tiles_processed - 1]
// The difference is clamped to the range [0, ceil(fifo_size / sync_period)].
// Returns 0 when any argument is <= 0.
AICORE inline int pending_consumption_events(int tiles_processed, int fifo_size, int sync_period)
{
    const bool has_valid_input = (tiles_processed > 0) && (sync_period > 0) && (fifo_size > 0);
    if (!has_valid_input) {
        return 0;
    }

    const int notified = tiles_processed / sync_period;

    // Waits only start once the iteration index has reached fifo_size.
    int waited = 0;
    if (tiles_processed > fifo_size) {
        const int boundaries_up_to_last = (tiles_processed - 1) / sync_period;
        const int boundaries_before_fifo = (fifo_size - 1) / sync_period;
        if (boundaries_up_to_last > boundaries_before_fifo) {
            waited = boundaries_up_to_last - boundaries_before_fifo;
        }
    }

    int outstanding = (notified > waited) ? (notified - waited) : 0;

    // Never report more than one FIFO's worth of periods.
    const int ceiling = (fifo_size + sync_period - 1) / sync_period;
    if (outstanding > ceiling) {
        outstanding = ceiling;
    }
    return outstanding;
}
// Size in bytes of one tile: TileType::Rows * TileType::Cols elements of TileType::DType.
// Each dimension is widened to std::size_t before multiplying, so the product is formed in std::size_t
// rather than in the (possibly narrower or signed) type of Rows/Cols, where it could overflow.
template <typename TileType>
constexpr AICORE std::size_t tile_storage_bytes()
{
    using ElementType = typename TileType::DType;
    constexpr std::size_t rows = static_cast<std::size_t>(TileType::Rows);
    constexpr std::size_t cols = static_cast<std::size_t>(TileType::Cols);
    return rows * cols * sizeof(ElementType);
}
// Total bytes occupied by NumBuffers tiles of TileType laid out back to back.
template <typename TileType, std::size_t NumBuffers>
constexpr AICORE std::size_t tile_buffer_total_bytes()
{
    constexpr std::size_t single_tile_bytes = tile_storage_bytes<TileType>();
    return NumBuffers * single_tile_bytes;
}
// Places every tile of `tiles` at consecutive offsets starting at `base_offset`, one tile_storage_bytes
// apart, by calling TASSIGN on each element.
//   tiles       : array of tiles to place.
//   base_offset : byte offset given to tiles[0].
// Returns the first byte offset after the last tile (base_offset + total size of the array).
// The array's total size is checked at compile time against MAX_TILE_L1_BYTES.
template <typename TileType, std::size_t NumBuffers>
AICORE inline uint32_t assign_tile_buffers(TileType (&tiles)[NumBuffers], uint32_t base_offset)
{
    if constexpr (NumBuffers == 0) {
        return base_offset;
    }
    constexpr std::size_t per_tile_bytes = tile_storage_bytes<TileType>();
    constexpr std::size_t total_storage_bytes = tile_buffer_total_bytes<TileType, NumBuffers>();
    static_assert(total_storage_bytes <= MAX_TILE_L1_BYTES, "Tile buffer L1 allocation exceeds 512KB");

    std::size_t slot = 0;
    while (slot < NumBuffers) {
        const uint32_t tile_offset = base_offset + static_cast<uint32_t>(slot * per_tile_bytes);
        TASSIGN(tiles[slot], tile_offset);
        ++slot;
    }
    const uint32_t end_offset = base_offset + static_cast<uint32_t>(total_storage_bytes);
    return end_offset;
}
// Places two equally long tile arrays on top of each other: tilesA[i] and tilesB[i] receive the same
// offset, and consecutive slots are separated by the larger of the two tile sizes.
//   tilesA, tilesB : arrays that share storage slot by slot (their lengths must match).
//   base_offset    : byte offset given to slot 0.
// Returns the first byte offset after the last shared slot.
// The combined size is checked at compile time against MAX_VEC_UB_BYTES.
template <typename TileA, std::size_t NumA, typename TileB, std::size_t NumB>
AICORE inline uint32_t assign_tile_buffers_union(TileA (&tilesA)[NumA], TileB (&tilesB)[NumB], uint32_t base_offset)
{
    static_assert(NumA == NumB, "Union assignment expects matching buffer counts");
    if constexpr (NumA == 0) {
        return base_offset;
    }
    constexpr std::size_t bytes_a = tile_storage_bytes<TileA>();
    constexpr std::size_t bytes_b = tile_storage_bytes<TileB>();
    constexpr std::size_t stride_bytes = (bytes_b < bytes_a) ? bytes_a : bytes_b;
    constexpr std::size_t total_storage_bytes = stride_bytes * NumA;
    static_assert(total_storage_bytes <= MAX_VEC_UB_BYTES, "Union tile UB allocation exceeds 256KB");

    std::size_t slot = 0;
    while (slot < NumA) {
        const uint32_t shared_offset = base_offset + static_cast<uint32_t>(slot * stride_bytes);
        TASSIGN(tilesA[slot], shared_offset);
        TASSIGN(tilesB[slot], shared_offset);
        ++slot;
    }
    const uint32_t end_offset = base_offset + static_cast<uint32_t>(total_storage_bytes);
    return end_offset;
}
// Lays out the four cube-side tile arrays back to back starting at offset 0, in the order Q, K, P, V.
// The summed size of all four arrays is checked at compile time against MAX_TILE_L1_BYTES.
template <typename TileQType, std::size_t NumQ, typename TileKType, std::size_t NumK, typename TilePType,
          std::size_t NumP, typename TileVType, std::size_t NumV>
AICORE inline void allocate_cube_tile_buffers(TileQType (&qTiles)[NumQ], TileKType (&kTiles)[NumK],
                                              TilePType (&pTiles)[NumP], TileVType (&vTiles)[NumV])
{
    constexpr std::size_t q_bytes = tile_buffer_total_bytes<TileQType, NumQ>();
    constexpr std::size_t k_bytes = tile_buffer_total_bytes<TileKType, NumK>();
    constexpr std::size_t p_bytes = tile_buffer_total_bytes<TilePType, NumP>();
    constexpr std::size_t v_bytes = tile_buffer_total_bytes<TileVType, NumV>();
    constexpr std::size_t total_bytes = q_bytes + k_bytes + p_bytes + v_bytes;
    static_assert(total_bytes <= MAX_TILE_L1_BYTES, "Total cube L1 allocation exceeds 512KB");

    // Each call returns the end of its region, which becomes the start of the next one.
    const uint32_t start_offset = 0;
    const uint32_t after_q = assign_tile_buffers(qTiles, start_offset);
    const uint32_t after_k = assign_tile_buffers(kTiles, after_q);
    const uint32_t after_p = assign_tile_buffers(pTiles, after_k);
    const uint32_t after_v = assign_tile_buffers(vTiles, after_p);
    (void)after_v;
}
// Assigns offsets, starting from 0, to every vector-side tile passed in. The total footprint is checked at
// compile time against MAX_VEC_UB_BYTES.
// Layout when USE_L0C_TO_DUAL_UB_PATH_QK is set:
//   srcTiles[], runningOTile, pvTile[]
// Layout otherwise:
//   runningOTile, then srcTiles[] and pvTile[] sharing the same slots (assign_tile_buffers_union)
// Followed in both cases by:
//   m1_local_max, m2_global_max, input_reduce_tmp, l1_local_sum, l2_global_sum, l1_exp_max[], x_expT[]
template <typename TileDataF_T, typename ReduceTileF_T, typename TileDataH_T, typename TileOutT, std::size_t SrcBuffers,
          std::size_t XexpBuffers, std::size_t pvVecBuffers, std::size_t ExpMaxBuffers>
AICORE inline void allocate_vec_tile_buffers(TileDataF_T (&srcTiles)[SrcBuffers], ReduceTileF_T &m1_local_max,
                                             TileDataF_T &input_reduce_tmp, ReduceTileF_T &l1_local_sum,
                                             ReduceTileF_T &m2_global_max, ReduceTileF_T &l2_global_sum,
                                             ReduceTileF_T (&l1_exp_max)[ExpMaxBuffers],
                                             TileDataH_T (&x_expT)[XexpBuffers], TileOutT (&pvTile)[pvVecBuffers],
                                             TileOutT &runningOTile)
{
    constexpr std::size_t float_tile_bytes = tile_storage_bytes<TileDataF_T>();
    constexpr std::size_t reduce_tile_bytes = tile_storage_bytes<ReduceTileF_T>();
    constexpr std::size_t out_tile_bytes = tile_storage_bytes<TileOutT>();
    constexpr std::size_t xexp_bytes = tile_buffer_total_bytes<TileDataH_T, XexpBuffers>();
    static_assert(SrcBuffers == pvVecBuffers, "src/pv buffer counts must match");

    // Bytes that are laid out identically in both configurations: the x_exp tiles, the four single reduce
    // tiles minus one... i.e. (3 + ExpMaxBuffers) reduce tiles as counted by the budget, one float scratch
    // tile and the running output tile.
    constexpr std::size_t shared_bytes =
        xexp_bytes + (reduce_tile_bytes * (3U + ExpMaxBuffers)) + (float_tile_bytes * 1U) + out_tile_bytes;

#if USE_L0C_TO_DUAL_UB_PATH_QK
    constexpr std::size_t src_bytes = tile_buffer_total_bytes<TileDataF_T, SrcBuffers>();
    constexpr std::size_t pv_bytes = tile_buffer_total_bytes<TileOutT, pvVecBuffers>();
    constexpr std::size_t total_bytes = src_bytes + pv_bytes + shared_bytes;
    static_assert(total_bytes <= MAX_VEC_UB_BYTES, "Vec tile UB allocation exceeds 256KB");

    uint32_t cursor = 0;
    cursor = assign_tile_buffers(srcTiles, cursor);
    TASSIGN(runningOTile, cursor);
    cursor += out_tile_bytes;
    cursor = assign_tile_buffers(pvTile, cursor);
#else
    constexpr std::size_t src_tile_bytes = tile_storage_bytes<TileDataF_T>();
    constexpr std::size_t pv_tile_bytes = tile_storage_bytes<TileOutT>();
    constexpr std::size_t union_stride = (pv_tile_bytes < src_tile_bytes) ? src_tile_bytes : pv_tile_bytes;
    constexpr std::size_t union_bytes = union_stride * SrcBuffers;
    constexpr std::size_t total_bytes = union_bytes + shared_bytes;
    static_assert(total_bytes <= MAX_VEC_UB_BYTES, "Vec tile UB allocation exceeds 256KB");

    uint32_t cursor = 0;
    TASSIGN(runningOTile, cursor);
    cursor += out_tile_bytes;
    cursor = assign_tile_buffers_union(srcTiles, pvTile, cursor);
#endif

    constexpr uint32_t reduce_step = static_cast<uint32_t>(reduce_tile_bytes);
    constexpr uint32_t float_step = static_cast<uint32_t>(float_tile_bytes);

    TASSIGN(m1_local_max, cursor);
    cursor += reduce_step;

    TASSIGN(m2_global_max, cursor);
    cursor += reduce_step;

    // The float scratch tile sits between the max tiles and the sum tiles.
    const uint32_t scratch_offset = cursor;
    TASSIGN(input_reduce_tmp, scratch_offset);
    cursor += float_step;

    TASSIGN(l1_local_sum, cursor);
    cursor += reduce_step;

    TASSIGN(l2_global_sum, cursor);
    cursor += reduce_step;

    cursor = assign_tile_buffers(l1_exp_max, cursor);
    const uint32_t end_of_layout = assign_tile_buffers(x_expT, cursor);
    (void)end_of_layout;
}
// Binds `accTile` to one of two fixed base addresses (0x0 for slot 0, 0x10000 for slot 1) and flips the
// slot that the next call will use.
//   accTile    : tile passed to TASSIGN.
//   initial_id : 0 or 1 forces that slot for this call; any other value (default -1) keeps the current one.
// Returns the slot id used for this call.
// The slot counter is a function-local static, so it is kept separately per AccTileT instantiation.
template <typename AccTileT>
AICORE inline int assign_running_acc_tile(AccTileT &accTile, int initial_id = -1)
{
    static int next_slot = 0;

    const bool caller_picked_slot = (initial_id == 0) || (initial_id == 1);
    if (caller_picked_slot) {
        next_slot = initial_id;
    }

    const int chosen_slot = next_slot;
    uint32_t base_addr = 0x10000u;
    if (chosen_slot == 0) {
        base_addr = 0x0u;
    }
    TASSIGN(accTile, base_addr);

    next_slot = chosen_slot ^ 1;
    return chosen_slot;
}
// Cube-side QK step for one (tile_id, sub_tile_id) slice. The whole body is compiled only when DAV_CUBE is
// true; otherwise the function does nothing.
// Steps in source order: optional causal skip, load of K (and of Q on the very first call only) from global
// memory, matmul of qMatTile and kMatTile into qkAccTile, optional debug store, TPUSH of qkAccTile into qkPipe.
//   qkPipe           only the producer side (qkPipe.prod) is driven here; depth/period come from QKPipe::DataFiFo.
//   tile_id          index of the TILE_S1-wide tile; also used as the FIFO sync iteration.
//   sub_tile_id      index of the CUBE_S1-wide slice inside the tile; the last slice is TILE_S1 / CUBE_S1 - 1.
//   q, k             global inputs; k is advanced by s1_index * HEAD_SIZE elements before the load.
//   qk_tile_fifo     global buffer written only under INTERMEDIATE_CHECK with USE_L0C_TO_DUAL_UB_PATH_QK set.
//   qMatTile/kMatTile/qkAccTile  tiles used as load targets and matmul operands/result.
//   qkMatTileEventId event id of the MTE1->MTE2 flag waited on before the loads and set again after the matmul.
//   accTileEvtID     event id of the FIX->M flag waited on before the matmul and set after the push
//                    (both only when UF_ENABLE is 0).
//   ubBufSync        allocate() is called on it at sub_tile_id == 0 once tile_id >= SRC_VEC_TN_BUFFERS
//                    (only with USE_L0C_TO_DUAL_UB_PATH_QK).
//   blk_idx          multiplied by CUBE_S0 to obtain s0_index for the causal comparison.
// S0 and S1 are not referenced in the body.
template <typename QKPipe, int S0, int HEAD_SIZE, int S1, int CUBE_S0, int CUBE_S1, int TILE_S1,
bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, int SRC_VEC_TN_BUFFERS, typename TileMatQData,
typename TileMatKData, typename TileQKData, typename TSyncUBBuf>
AICORE inline void compute_qk(QKPipe &qkPipe, int tile_id, int sub_tile_id, __gm__ half *q, __gm__ half *k,
__gm__ float *qk_tile_fifo, TileMatQData &qMatTile, TileMatKData &kMatTile,
TileQKData &qkAccTile, uint64_t qkMatTileEventId, int accTileEvtID, TSyncUBBuf &ubBufSync,
int blk_idx)
{
if constexpr (DAV_CUBE) {
constexpr uint32_t Cube_S0 = CUBE_S0;
constexpr uint32_t Cube_S1 = CUBE_S1;
constexpr uint32_t Tile_S1 = TILE_S1;
// Number of CUBE_S1-wide slices that make up one TILE_S1-wide tile.
constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1;
constexpr uint32_t Cube_HEAD = HEAD_SIZE;
constexpr int QKP_CV_FIFO = QKPipe::DataFiFo::fifoDepth;
constexpr int CV_FIFO_CONS_SYNC_PERIOD = QKPipe::DataFiFo::fifoPeriod;
static_assert(QKP_CV_FIFO >= 1, "QKP_CV_FIFO must be >= 1");
static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1");
// s0_index: first row handled by this block; s1_index: first column of this slice.
const int s0_index = blk_idx * CUBE_S0;
const int s1_index = tile_id * static_cast<int>(Tile_S1) + sub_tile_id * static_cast<int>(Cube_S1);
const int sync_iter = tile_id;
const bool should_wait_consume = should_wait_consumption<QKP_CV_FIFO, CV_FIFO_CONS_SYNC_PERIOD>(sync_iter);
if constexpr (CAUSAL_MASK) {
// A slice that starts beyond s0_index is skipped: nothing is loaded or multiplied, but the producer
// allocate/record calls are still issued at the same slice positions as in the non-skipped path
// (presumably so the consumer's wait/free pairing stays balanced — confirm against the pipe implementation).
if (s1_index > s0_index) {
if (sub_tile_id == 0 && should_wait_consume)
qkPipe.prod.allocate();
if (sub_tile_id == static_cast<int>(kTileFactor) - 1)
qkPipe.prod.record();
return;
}
}
using GlobalDataQ =
GlobalTensor<half, pto::Shape<1, 1, 1, Cube_S0, HEAD_SIZE>, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>;
using GlobalDataK = GlobalTensor<half, pto::Shape<1, 1, 1, HEAD_SIZE, Cube_S1>,
pto::Stride<1, 1, 1, 1, HEAD_SIZE>, Layout::DN>;
GlobalDataQ qGlobal(q);
GlobalDataK kGlobal(k + s1_index * HEAD_SIZE);
// Wait for the MTE1->MTE2 flag of this K buffer before overwriting it with a new load.
wait_flag(PIPE_MTE1, PIPE_MTE2, qkMatTileEventId);
// Q is loaded once, on the very first slice of the first tile, and reused afterwards.
if (tile_id == 0 && sub_tile_id == 0) {
TLOAD(qMatTile, qGlobal);
}
TLOAD(kMatTile, kGlobal);
// Set-then-wait on the same event: the following MTE1 work does not start until the loads have finished.
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
#if !UF_ENABLE
// Wait for the FIX->M flag of this accumulator before the matmul writes into it.
wait_flag(PIPE_FIX, PIPE_M, accTileEvtID);
#endif
pto_macro_matmul<Cube_S0, Cube_HEAD, Cube_S1>(qMatTile, kMatTile, qkAccTile,
UF_ENABLE ? AccMode::InitFinalSum : AccMode::Init);
// Hand the K buffer back to the loader, then make the FIX pipe wait for the matmul result.
set_flag(PIPE_MTE1, PIPE_MTE2, qkMatTileEventId);
set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
#if USE_L0C_TO_DUAL_UB_PATH_QK
if constexpr (INTERMEDIATE_CHECK) {
// Debug copy of the accumulator to global memory; the slot is chosen by tile_id % kFaCvFifoSize and
// the slice position inside the slot by sub_tile_id.
using GlobalDataQK =
GlobalTensor<float, pto::Shape<1, 1, 1, Cube_S0, Cube_S1>, pto::Stride<1, 1, 1, Cube_S1, 1>>;
const uint32_t buf_idx = static_cast<uint32_t>(tile_id % kFaCvFifoSize);
const size_t base_elems =
static_cast<size_t>(buf_idx) * static_cast<size_t>(kTileFactor) * static_cast<size_t>(Cube_S0) *
static_cast<size_t>(Cube_S1) +
static_cast<size_t>(sub_tile_id) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(Cube_S1);
GlobalDataQK qkGlobalTile(qk_tile_fifo + base_elems);
TSTORE<UF_ENABLE ? STPhase::Final : STPhase::Unspecified>(qkGlobalTile, qkAccTile);
}
// The first SRC_VEC_TN_BUFFERS tiles skip this allocate; later tiles call it once per tile.
if (sub_tile_id == 0 && tile_id >= static_cast<int>(SRC_VEC_TN_BUFFERS)) {
ubBufSync.allocate();
}
#endif
// Allocate on the first slice (only when a wait is due), record on the last slice, and address this
// slice inside the FIFO entry by its byte offset.
bool isAllocate = (sub_tile_id == 0 && should_wait_consume);
bool isRecord = (sub_tile_id == (static_cast<int>(kTileFactor) - 1));
qkPipe.prod.setAllocateStatus(isAllocate);
qkPipe.prod.setRecordStatus(isRecord);
qkPipe.prod.setEntryOffset(sub_tile_id * Cube_S0 * Cube_S1 * sizeof(float));
TPUSH(qkAccTile, qkPipe);
#if !UF_ENABLE
// Set the FIX->M flag that the next matmul on this accumulator waits for.
set_flag(PIPE_FIX, PIPE_M, accTileEvtID);
#endif
}
}
// Cube-side PV step for one (tile_id, sub_tile_id) slice. The work is compiled only when DAV_CUBE is true.
// Steps in source order: optional causal skip, load of V from global memory, TPOP of the P slice from pPipe,
// matmul of pMatTile and vMatTile into pvAccTile (initialised on slice 0, accumulated on later slices), and,
// on the last contributing slice of the tile, hand-off of pvAccTile through pvPipe.
//   pPipe            consumer side (pPipe.cons) is driven here; depth/period come from PPipe::DataFiFo.
//   pvPipe           producer side (pvPipe.prod) is driven here; depth/period come from PVPipe::DataFiFo.
//   tile_id          index of the TILE_S1-wide tile; also used as the FIFO sync iteration.
//   sub_tile_id      index of the CUBE_S1-wide slice inside the tile; the last slice is TILE_S1 / CUBE_S1 - 1.
//   v                global input; advanced by s1_index * HEAD_SIZE elements before the load.
//   pv_tile_fifo     global buffer written only under INTERMEDIATE_CHECK with USE_L0C_TO_UB_PV_PATH set.
//   runningOTile     direct TMOV target for tile 0 when USE_L0C_TO_UB_PV_PATH is set.
//   svMatTileEventId event id of the MTE1->MTE2 flag waited on before the V load and set again after the matmul.
//   accTileEvtID     event id of the FIX->M flag around the accumulator.
//   blk_idx          multiplied by Cube_S0 to obtain s0_index for the causal comparisons.
// S0, S1, OUT_O_TILE_NBUFFERS and the local constant TileElems are not referenced in the body.
// NOTE(review): the wait_flag(PIPE_FIX, PIPE_M, accTileEvtID) at sub_tile_id == 0 is unconditional, while
// the matching set_flag at the end is compiled only when UF_ENABLE is 0 (compute_qk guards both). This has
// no effect while UF_ENABLE is 0 — confirm the intended pairing before enabling UF_ENABLE.
template <typename PPipe, typename PVPipe, int S0, int HEAD_SIZE, int S1, int CUBE_S0, int CUBE_S1, int TILE_S1,
bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, int OUT_O_TILE_NBUFFERS, typename TileMatPData,
typename TileMatVData, typename TilePVData, typename TileOutT>
AICORE inline void compute_pv(PPipe &pPipe, PVPipe &pvPipe, int tile_id, int sub_tile_id, __gm__ half *v,
__gm__ float *pv_tile_fifo, TileMatPData &pMatTile, TileMatVData &vMatTile,
TilePVData &pvAccTile, TileOutT &runningOTile, uint64_t svMatTileEventId,
int accTileEvtID, int blk_idx)
{
constexpr uint32_t Cube_S0 = CUBE_S0;
constexpr uint32_t Cube_S1 = CUBE_S1;
constexpr uint32_t Tile_S1 = TILE_S1;
// Number of CUBE_S1-wide slices that make up one TILE_S1-wide tile.
constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1;
constexpr uint32_t Cube_HEAD = HEAD_SIZE;
constexpr uint32_t TileElems = Cube_S0 * Tile_S1;
constexpr int P_CV_FIFO = PPipe::DataFiFo::fifoDepth;
constexpr int P_CV_FIFO_CONS_SYNC_PERIOD = PPipe::DataFiFo::fifoPeriod;
constexpr int PV_CV_FIFO = PVPipe::DataFiFo::fifoDepth;
constexpr int PV_CV_FIFO_CONS_SYNC_PERIOD = PVPipe::DataFiFo::fifoPeriod;
static_assert(P_CV_FIFO >= 1, "P_CV_FIFO must be >= 1");
static_assert(PV_CV_FIFO >= 1, "PV_CV_FIFO must be >= 1");
static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1");
// s0_index: first row handled by this block; s1_index: first column of this slice.
const int s0_index = blk_idx * Cube_S0;
const int s1_index = tile_id * static_cast<int>(Tile_S1) + sub_tile_id * static_cast<int>(Cube_S1);
const int sync_iter = tile_id;
const bool should_wait_pv_consume = should_wait_consumption<PV_CV_FIFO, PV_CV_FIFO_CONS_SYNC_PERIOD>(sync_iter);
const bool should_notify_p_consume = should_notify_consumption<P_CV_FIFO, P_CV_FIFO_CONS_SYNC_PERIOD>(sync_iter);
const bool is_last_subtile = (sub_tile_id + 1 == static_cast<int>(kTileFactor));
// True when the slice after this one starts beyond s0_index, i.e. it would take the causal skip below.
const bool next_will_be_skipped = (s1_index + static_cast<int>(Cube_S1)) > s0_index && CAUSAL_MASK;
if constexpr (DAV_CUBE) {
if constexpr (CAUSAL_MASK) {
// Skipped slice: no load or matmul, but the P-pipe wait/free calls are still issued at the same
// slice positions as in the non-skipped path.
if (s1_index > s0_index) {
if (sub_tile_id == 0)
pPipe.cons.wait();
if (sub_tile_id == static_cast<int>(kTileFactor) - 1 && should_notify_p_consume)
pPipe.cons.free();
return;
}
}
using GlobalVT =
GlobalTensor<half, pto::Shape<1, 1, 1, Cube_S1, HEAD_SIZE>, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>;
// Wait for the MTE1->MTE2 flag of this V buffer before overwriting it with a new load.
wait_flag(PIPE_MTE1, PIPE_MTE2, svMatTileEventId);
GlobalVT vLoad((__gm__ half *)(v + s1_index * HEAD_SIZE));
TLOAD(vMatTile, vLoad);
// Wait on the P pipe at the first slice, free it at the last slice (only when a notify is due), and
// address this slice inside the FIFO entry by its byte offset.
bool isWait = (sub_tile_id == 0);
bool isFree = (sub_tile_id == static_cast<int>(kTileFactor) - 1 && should_notify_p_consume);
pPipe.cons.setWaitStatus(isWait);
pPipe.cons.setFreeStatus(isFree);
pPipe.cons.setEntryOffset(sub_tile_id * Cube_S0 * Cube_S1 * sizeof(half));
TPOP(pMatTile, pPipe);
// Set-then-wait on the same event: the following MTE1 work does not start until the loads have finished.
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
// The accumulator is (re)initialised on slice 0, so the FIX->M flag is awaited only there.
if (sub_tile_id == 0) {
wait_flag(PIPE_FIX, PIPE_M, accTileEvtID);
}
#if UF_ENABLE
// With UF_ENABLE the mode also encodes whether this is the final contribution to the tile
// (last slice, or the next slice will be causally skipped).
const AccMode accMode =
(sub_tile_id == 0) ?
(is_last_subtile || next_will_be_skipped ? AccMode::InitFinalSum : AccMode::InitPartialSum) :
(is_last_subtile || next_will_be_skipped ? AccMode::AccFinalSum : AccMode::AccPartialSum);
pto_macro_matmul<Cube_S0, Cube_S1, Cube_HEAD>(pMatTile, vMatTile, pvAccTile, accMode);
#else
const AccMode accMode = (sub_tile_id == 0) ? AccMode::Init : AccMode::Acc;
pto_macro_matmul<Cube_S0, Cube_S1, Cube_HEAD>(pMatTile, vMatTile, pvAccTile, accMode);
#endif
// Hand the V buffer back to the loader.
set_flag(PIPE_MTE1, PIPE_MTE2, svMatTileEventId);
// The accumulator leaves the cube only after the last slice that contributes to this tile.
if (sub_tile_id == static_cast<int>(kTileFactor) - 1 || next_will_be_skipped) {
set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
#if USE_L0C_TO_UB_PV_PATH
#if USE_UB_TO_L1_PATH
TFREE(pPipe);
#endif
// Tile 0 is moved straight into runningOTile with explicit allocate/record calls; later tiles go
// through TPUSH.
if (tile_id == 0) {
if (should_wait_pv_consume)
pvPipe.prod.allocate();
TMOV<TileOutT, TilePVData, AccToVecMode::DualModeSplitM>(runningOTile, pvAccTile);
pvPipe.prod.record();
} else {
pvPipe.prod.setAllocateStatus(should_wait_pv_consume);
pvPipe.prod.setRecordStatus(true);
pvPipe.prod.setEntryOffset(0);
TPUSH(pvAccTile, pvPipe);
}
if constexpr (INTERMEDIATE_CHECK) {
// Debug copy of the accumulator to global memory; the slot is chosen by tile_id % kFaCvFifoSize.
using GlobalDataPV =
GlobalTensor<float, pto::Shape<1, 1, 1, Cube_S0, HEAD_SIZE>, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>;
const uint32_t buf_idx_pv = static_cast<uint32_t>(tile_id % kFaCvFifoSize);
const size_t base_elems_pv =
static_cast<size_t>(buf_idx_pv) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(HEAD_SIZE);
GlobalDataPV pvGlobalTile((__gm__ float *)(pv_tile_fifo + base_elems_pv));
TSTORE<UF_ENABLE ? STPhase::Final : STPhase::Unspecified>(pvGlobalTile, pvAccTile);
}
#else
pvPipe.prod.setAllocateStatus(should_wait_pv_consume);
pvPipe.prod.setRecordStatus(true);
pvPipe.prod.setEntryOffset(0);
TPUSH(pvAccTile, pvPipe);
#endif
#if !UF_ENABLE
// Set the FIX->M flag that the next slice-0 matmul on this accumulator waits for.
set_flag(PIPE_FIX, PIPE_M, accTileEvtID);
#endif
}
}
}
// Vector-side softmax step for one (tile_id, row_slice). The work is compiled only when DAV_VEC is true.
// Steps in source order: TPOP of a QK row slice from qkPipe, pto_macro_fa_softmax on that slice (the
// "init" variant for tile 0), release of the QK buffer, then TPUSH of the half-precision result into pPipe
// (via nzConvBuffer when USE_UB_TO_L1_PATH is set, directly from x_expT otherwise), plus optional debug stores.
//   qkPipe / pPipe     consumer side of qkPipe and producer side of pPipe are driven here.
//   tile_id            index of the TILE_S1-wide tile; also used as the FIFO sync iteration.
//   row_slice          index of the Vec_S0-row slice handled by this call; the last slice is kTileFactor - 1.
//   p_tile_fifo        global debug buffer, written only under INTERMEDIATE_CHECK with USE_UB_TO_L1_PATH set.
//   exp_max_ififo      global debug buffer, written only under INTERMEDIATE_CHECK on the last row slice.
//   qkVecTile, x_expT  softmax input and half-precision output tiles.
//   input_reduce_tmp   scratch tile handed to the softmax macro.
//   m1_local_max, l1_local_sum, m2_global_max, l2_global_sum, l1_exp_max_ififo
//                      per-row reduce tiles; this call works on Vec_S0-row views into them.
//   nzConvBuffer       staging tile used only when USE_UB_TO_L1_PATH is set.
//   pTileEventId       event id shared by the V->MTE2 and MTE3->V flags waited on at entry and set at exit.
//   ubBufSync          free() is called on it after the last row slice (only with USE_L0C_TO_DUAL_UB_PATH_QK).
//   blk_idx            multiplied by Cube_S0 to obtain the block's first row for s0_index.
// S0 and S1 are not referenced in the body; pMatTile is not used by any visible statement.
// NOTE(review): should_wait_qk_consume, buf_idx and base_elems are computed but never read in this function.
template <typename QKPipe, typename PPipe, int S0, int HEAD_SIZE, int S1, int CUBE_S0, int CUBE_S1, int TILE_S1,
bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, typename TileDataF_T, typename TileDataH_T,
typename TileDataH_NZ_T, typename ReduceTileF_T, typename TileMatPData, typename TSyncUBBuf>
AICORE inline void compute_p(QKPipe &qkPipe, PPipe &pPipe, int tile_id, int row_slice, __gm__ half *p_tile_fifo,
__gm__ float *exp_max_ififo, TileDataF_T &qkVecTile, TileDataH_T &x_expT,
TileDataF_T &input_reduce_tmp, ReduceTileF_T &m1_local_max, ReduceTileF_T &l1_local_sum,
ReduceTileF_T &m2_global_max, ReduceTileF_T &l2_global_sum,
ReduceTileF_T &l1_exp_max_ififo, TileMatPData &pMatTile, TileDataH_NZ_T &nzConvBuffer,
uint64_t pTileEventId, TSyncUBBuf &ubBufSync, int blk_idx)
{
constexpr uint32_t Cube_S0 = CUBE_S0;
constexpr uint32_t Cube_S1 = CUBE_S1;
constexpr uint32_t Tile_S1 = TILE_S1;
constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1;
// Rows per call: the block's rows are split across VEC_CORES and then across kTileFactor row slices.
constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor;
const bool initFlag = (tile_id == 0);
constexpr int QK_CV_FIFO = QKPipe::DataFiFo::fifoDepth;
constexpr int QK_CV_FIFO_CONS_SYNC_PERIOD = QKPipe::DataFiFo::fifoPeriod;
constexpr int P_CV_FIFO = PPipe::DataFiFo::fifoDepth;
constexpr int P_CV_FIFO_CONS_SYNC_PERIOD = PPipe::DataFiFo::fifoPeriod;
static_assert(QK_CV_FIFO >= 1, "QK_CV_FIFO must be >= 1");
static_assert(P_CV_FIFO >= 1, "P_CV_FIFO must be >= 1");
static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1");
static_assert(Cube_S0 % (VEC_CORES * kTileFactor) == 0, "Vec rows must divide evenly across tile slices");
if constexpr (DAV_VEC) {
// First row owned by this sub-block, and first row of the slice handled by this call.
const size_t subblock_base_rows =
static_cast<size_t>(Cube_S0 / VEC_CORES) * static_cast<size_t>(get_subblockid());
const size_t row_offset = subblock_base_rows + static_cast<size_t>(row_slice * Vec_S0);
// Row / column position of this slice, passed to the softmax macro.
const int s0_index = blk_idx * Cube_S0 + row_offset;
const int s1_index = tile_id * static_cast<int>(Tile_S1);
const int sync_iter = tile_id;
const bool should_wait_qk_consume = should_wait_consumption<QK_CV_FIFO, QK_CV_FIFO_CONS_SYNC_PERIOD>(sync_iter);
const bool should_notify_qk_consume =
should_notify_consumption<QK_CV_FIFO, QK_CV_FIFO_CONS_SYNC_PERIOD>(sync_iter);
// Wait for the V->MTE2 flag set at the end of the previous call that used this event id.
wait_flag(PIPE_V, PIPE_MTE2, pTileEventId);
const uint32_t buf_idx = static_cast<uint32_t>(tile_id % QK_CV_FIFO);
const size_t base_elems = static_cast<size_t>(buf_idx) * static_cast<size_t>(kTileFactor) *
static_cast<size_t>(Cube_S0) * static_cast<size_t>(Cube_S1);
// Wait on the QK pipe at the first row slice, free it at the last one (only when a notify is due), and
// address this slice inside the FIFO entry by its row byte offset.
bool isWait = (row_slice == 0);
bool isFree = (row_slice == static_cast<int>(kTileFactor) - 1 && should_notify_qk_consume);
qkPipe.cons.setWaitStatus(isWait);
qkPipe.cons.setFreeStatus(isFree);
qkPipe.cons.setEntryOffset(row_offset * Cube_S1 * sizeof(float));
TPOP(qkVecTile, qkPipe);
// Set-then-wait on the same event: vector work does not start until the MTE2 transfer has finished.
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
// Vec_S0-row views into the per-row reduce tiles, offset by row_slice * Vec_S0 floats.
using ReduceSliceTile = Tile<TileType::Vec, float, Vec_S0, 1, BLayout::ColMajor, Vec_S0, 1>;
const size_t reduce_slice_rows = static_cast<size_t>(row_slice * Vec_S0);
const uint64_t reduce_row_byte_offset = reduce_slice_rows * sizeof(float);
ReduceSliceTile m1_local_max_slice;
ReduceSliceTile l1_local_sum_slice;
ReduceSliceTile m2_global_max_slice;
ReduceSliceTile l2_global_sum_slice;
ReduceSliceTile l1_exp_max_slice;
TASSIGN(m1_local_max_slice, (uint64_t)m1_local_max.data() + reduce_row_byte_offset);
TASSIGN(l1_local_sum_slice, (uint64_t)l1_local_sum.data() + reduce_row_byte_offset);
TASSIGN(m2_global_max_slice, (uint64_t)m2_global_max.data() + reduce_row_byte_offset);
TASSIGN(l2_global_sum_slice, (uint64_t)l2_global_sum.data() + reduce_row_byte_offset);
TASSIGN(l1_exp_max_slice, (uint64_t)l1_exp_max_ififo.data() + reduce_row_byte_offset);
// Wait for the MTE3->V flag set at the end of the previous call that used this event id.
wait_flag(PIPE_MTE3, PIPE_V, pTileEventId);
// The first template argument selects the init variant for tile 0. qkVecTile and input_reduce_tmp are
// each passed twice to the macro.
if (initFlag) {
pto_macro_fa_softmax<true, HEAD_SIZE, CAUSAL_MASK>(
x_expT, qkVecTile, m1_local_max_slice, l1_local_sum_slice, m2_global_max_slice, l2_global_sum_slice,
l1_exp_max_slice, input_reduce_tmp, qkVecTile, input_reduce_tmp, s0_index, s1_index);
} else {
pto_macro_fa_softmax<false, HEAD_SIZE, CAUSAL_MASK>(
x_expT, qkVecTile, m1_local_max_slice, l1_local_sum_slice, m2_global_max_slice, l2_global_sum_slice,
l1_exp_max_slice, input_reduce_tmp, qkVecTile, input_reduce_tmp, s0_index, s1_index);
}
#if USE_L0C_TO_DUAL_UB_PATH_QK
// On this path the free status is overridden to "last row slice" (regardless of the notify period)
// before TFREE, and ubBufSync is released once per tile.
qkPipe.cons.setFreeStatus((row_slice == static_cast<int>(kTileFactor) - 1));
TFREE(qkPipe);
if (row_slice == static_cast<int>(kTileFactor) - 1) {
ubBufSync.free();
}
#endif
// Set the V->MTE2 flag awaited at the top of this function, then make MTE3 wait for the vector results.
set_flag(PIPE_V, PIPE_MTE2, pTileEventId);
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
const bool should_wait_sv_consumed = should_wait_consumption<P_CV_FIFO, P_CV_FIFO_CONS_SYNC_PERIOD>(sync_iter);
#if USE_UB_TO_L1_PATH
if constexpr (INTERMEDIATE_CHECK) {
// Debug copy of this slice of x_expT to global memory; the slot is chosen by tile_id % kFaCvFifoSize
// and the row position by row_offset.
using GlobalPTileHalfSub =
GlobalTensor<half, pto::Shape<1, 1, 1, Vec_S0, Cube_S1>, pto::Stride<1, 1, 1, Cube_S1, 1>>;
using TileDataH_Sub = Tile<TileType::Vec, half, Vec_S0, Tile_S1, BLayout::RowMajor, Vec_S0, Cube_S1>;
const uint32_t debug_buf_idx = static_cast<uint32_t>(tile_id % kFaCvFifoSize);
const size_t debug_base_elems = static_cast<size_t>(debug_buf_idx) * static_cast<size_t>(kTileFactor) *
static_cast<size_t>(Cube_S0) * static_cast<size_t>(Cube_S1);
__gm__ half *p_ptr = p_tile_fifo + debug_base_elems + row_offset * static_cast<size_t>(Cube_S1);
GlobalPTileHalfSub pTileHalfSub((__gm__ half *)(p_ptr));
TileDataH_Sub xExpSub;
TASSIGN(xExpSub, (uint64_t)x_expT.data());
TSTORE(pTileHalfSub, xExpSub);
}
// Stage x_expT into nzConvBuffer and push that buffer instead of x_expT.
TMOV(nzConvBuffer, x_expT);
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
bool isAllocate = (row_slice == 0 && should_wait_sv_consumed);
bool isRecord = (row_slice == static_cast<int>(kTileFactor) - 1);
pPipe.prod.setAllocateStatus(isAllocate);
pPipe.prod.setRecordStatus(isRecord);
TPUSH(nzConvBuffer, pPipe);
#else
// Allocate on the first row slice (only when a wait is due), record on the last one, and address this
// slice inside the FIFO entry by its row byte offset.
bool isAllocate = (row_slice == 0 && should_wait_sv_consumed);
const bool isRecord = (row_slice == (static_cast<int>(kTileFactor) - 1));
pPipe.prod.setAllocateStatus(isAllocate);
pPipe.prod.setRecordStatus(isRecord);
pPipe.prod.setEntryOffset(row_offset * Cube_S1 * sizeof(half));
TPUSH(x_expT, pPipe);
(void)nzConvBuffer;
(void)pMatTile;
#endif
if constexpr (INTERMEDIATE_CHECK) {
// After the last row slice, store the whole sub-block's l1_exp_max_ififo (reshaped to a single row)
// to global memory for inspection.
if (row_slice == static_cast<int>(kTileFactor) - 1) {
constexpr uint32_t SubblockRows = Cube_S0 / VEC_CORES;
using GlobalPMaxFloatSub =
GlobalTensor<float, pto::Shape<1, 1, 1, 1, SubblockRows>, pto::Stride<1, 1, 1, Cube_S0, 1>>;
using ExpMaxSub = Tile<TileType::Vec, float, 1, SubblockRows, BLayout::RowMajor, 1, SubblockRows>;
const uint32_t debug_buf_idx = static_cast<uint32_t>(tile_id % kFaCvFifoSize);
const size_t base_elems_pmax =
static_cast<size_t>(debug_buf_idx) * static_cast<size_t>(Cube_S0) + subblock_base_rows;
__gm__ float *p_ptr_fp32 = exp_max_ififo + base_elems_pmax;
GlobalPMaxFloatSub pMaxGlobal(p_ptr_fp32);
ExpMaxSub l1_exp_max_rowmajor;
TRESHAPE(l1_exp_max_rowmajor, l1_exp_max_ififo);
TSTORE(pMaxGlobal, l1_exp_max_rowmajor);
}
}
// Set the MTE3->V flag awaited before the softmax of the next call that uses this event id.
set_flag(PIPE_MTE3, PIPE_V, pTileEventId);
}
}
// Vector-side output update for one tile. The work is compiled only when DAV_VEC is true.
// Steps in source order: TPOP of the PV result from pvPipe, merge into runningOTile through one of the
// pto_macro_fa_gu* macros (which one depends on tile_id and num_tiles), and, on the last tile, a store of
// runningOTile to o_out at this sub-block's row offset.
//   pvPipe            consumer side (pvPipe.cons) is driven here; depth/period come from PVPipe::DataFiFo.
//   tile_id           index of the current tile; also used as the FIFO sync iteration.
//   num_tiles         total tile count; tile_id == num_tiles - 1 selects the "last" handling and the store.
//   o_out             global output; written only on the last tile.
//   runningOTile      running output tile, updated in place.
//   pvVecTile         landing tile for the popped PV result.
//   l1_exp_max_ififo  reduce tile passed to the gu macros for tiles after the first.
//   l2_global_sum     reduce tile passed to the "last" / "single and last" macros.
//   guEventId         event id of the V->MTE2 flag waited on at entry and set again after the update.
// S0, S1, TILE_S1, INTERMEDIATE_CHECK, SRC_VEC_TN_BUFFERS, OUT_O_TILE_NBUFFERS and the alias GlobalDataPV_VEC
// are not referenced in the body.
// NOTE(review): for tile_id == 0 the "single and last tile" macro is invoked only when CAUSAL_MASK is true.
// With CAUSAL_MASK false and num_tiles == 1 no finalising macro runs before the store — confirm that
// callers never use a single tile without the causal mask, or that this is intended.
template <typename PVPipe, int S0, int HEAD_SIZE, int S1, int CUBE_S0, int TILE_S1, bool INTERMEDIATE_CHECK,
bool CAUSAL_MASK, int SRC_VEC_TN_BUFFERS, int OUT_O_TILE_NBUFFERS, typename TileOutT, typename ReduceTileF_T>
AICORE inline void compute_gu(PVPipe &pvPipe, int tile_id, int num_tiles, __gm__ float *o_out, TileOutT &runningOTile,
TileOutT &pvVecTile, ReduceTileF_T &l1_exp_max_ififo, ReduceTileF_T &l2_global_sum,
uint64_t guEventId)
{
constexpr uint32_t Cube_S0 = CUBE_S0;
// Rows handled by one vector sub-block.
constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES;
using GlobalDataPV_VEC =
GlobalTensor<float, pto::Shape<1, 1, 1, Vec_S0, HEAD_SIZE>, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>;
constexpr int PV_CV_FIFO = PVPipe::DataFiFo::fifoDepth;
constexpr int CV_FIFO_CONS_SYNC_PERIOD = PVPipe::DataFiFo::fifoPeriod;
if constexpr (DAV_VEC) {
// First row owned by this sub-block.
const size_t subblock_base_rows =
static_cast<size_t>(Cube_S0 / VEC_CORES) * static_cast<size_t>(get_subblockid());
const bool should_notify_consume = should_notify_consumption<PV_CV_FIFO, CV_FIFO_CONS_SYNC_PERIOD>(tile_id);
// Wait for the V->MTE2 flag set at the end of the previous call that used this event id.
wait_flag(PIPE_V, PIPE_MTE2, guEventId);
#if USE_L0C_TO_UB_PV_PATH
// On this path every tile (including tile 0) is popped into pvVecTile with entry offset 0; tile 0 is
// not merged here (compute_pv moves it into runningOTile directly when this path is enabled).
pvPipe.cons.setWaitStatus(true);
pvPipe.cons.setFreeStatus(should_notify_consume);
pvPipe.cons.setEntryOffset(0);
TPOP(pvVecTile, pvPipe);
if (tile_id > 0) {
if (tile_id < num_tiles - 1) {
pto_macro_fa_gu<ReduceTileF_T, TileOutT>(runningOTile, pvVecTile, l1_exp_max_ififo);
} else {
pto_macro_fa_gu_last<ReduceTileF_T, TileOutT>(runningOTile, pvVecTile, l1_exp_max_ififo, l2_global_sum);
}
} else {
if constexpr (CAUSAL_MASK) {
if (tile_id == num_tiles - 1)
pto_macro_fa_gu_single_and_last_tile(runningOTile, l2_global_sum);
}
}
TFREE(pvPipe);
#else
// On this path the entry offset selects this sub-block's rows, tile 0 is popped straight into
// runningOTile, and later tiles are popped into pvVecTile and merged.
pvPipe.cons.setWaitStatus(true);
pvPipe.cons.setFreeStatus(should_notify_consume);
pvPipe.cons.setEntryOffset(subblock_base_rows * HEAD_SIZE * sizeof(float));
if (tile_id == 0) {
TPOP(runningOTile, pvPipe);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
if constexpr (CAUSAL_MASK) {
if (tile_id == num_tiles - 1)
pto_macro_fa_gu_single_and_last_tile(runningOTile, l2_global_sum);
}
} else {
TPOP(pvVecTile, pvPipe);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
if (tile_id < num_tiles - 1) {
pto_macro_fa_gu<ReduceTileF_T, TileOutT>(runningOTile, pvVecTile, l1_exp_max_ififo);
} else {
pto_macro_fa_gu_last<ReduceTileF_T, TileOutT>(runningOTile, pvVecTile, l1_exp_max_ififo, l2_global_sum);
}
}
#endif
// Set the V->MTE2 flag awaited at the top of this function.
set_flag(PIPE_V, PIPE_MTE2, guEventId);
// After the last tile, make MTE3 wait for the vector results and store this sub-block's rows to o_out.
if (tile_id == num_tiles - 1) {
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
using GlobalOutT =
GlobalTensor<float, pto::Shape<1, 1, 1, Vec_S0, HEAD_SIZE>, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>;
GlobalOutT outGlobal((__gm__ float *)(o_out + subblock_base_rows * HEAD_SIZE));
TSTORE(outGlobal, runningOTile);
}
}
}
template <int S0, int HEAD_SIZE, int S1, int CUBE_S0, int CUBE_S1, int TILE_S1, int QK_PRELOAD, int CV_FIFO_SIZE,
bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, int CV_FIFO_CONS_SYNC_PERIOD>
__global__ AICORE void runTFA(__gm__ uint64_t *ffts_addr, __gm__ half *q, __gm__ half *k, __gm__ half *v,
__gm__ half *p_tile_fifo, __gm__ float *exp_max_ififo, __gm__ float *global_sum_out,
__gm__ float *exp_max_out, __gm__ float *o_out, __gm__ float *o_parts_out,
__gm__ float *qk_tile_fifo, __gm__ float *pv_tile_fifo, __gm__ uint8_t *cv_comm_buf,
__gm__ uint8_t *profile_buf)
{
uint64_t tStart = get_sys_cnt();
set_ffts_base_addr((uint64_t)ffts_addr);
if constexpr (DAV_CUBE) {
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
}
constexpr uint32_t Cube_S0 = CUBE_S0;
constexpr uint32_t block_rows = S0 / CUBE_S0;
constexpr uint32_t Cube_S1 = CUBE_S1;
constexpr uint32_t Tile_S1 = TILE_S1;
static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1");
constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1;
constexpr uint32_t Cube_HEAD = HEAD_SIZE;
constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor;
constexpr uint32_t VecGuRows = Cube_S0 / VEC_CORES;
static_assert(Cube_S0 % (VEC_CORES * kTileFactor) == 0, "Vec rows must divide evenly across tile slices");
constexpr uint32_t qkPreloadNum = QK_PRELOAD;
constexpr uint32_t srcVecTNBuffers = 2;
constexpr uint32_t xexpVecTNBuffers = 2;
constexpr uint32_t outOTileNBuffers = 2;
constexpr uint32_t qMatTNBuffers = 1;
constexpr uint32_t kMatTNBuffers = 2;
constexpr uint32_t pMatTNBuffers = 2;
constexpr uint32_t vMatTNBuffers = 2;
constexpr uint32_t qkp_tile_fifo_size = CV_FIFO_SIZE;
constexpr uint32_t pv_tile_fifo_size = CV_FIFO_SIZE;
static_assert(qkPreloadNum >= 1, "qkPreloadNum must be >= 1");
static_assert(CV_FIFO_CONS_SYNC_PERIOD >= 1, "CV_FIFO_CONS_SYNC_PERIOD must be >= 1");
static_assert((qkPreloadNum > 1) || (kTileFactor == 1), "qkPreloadNum must be > 1 unless kTileFactor == 1");
#if USE_UB_TO_L1_PATH
static_assert(qkPreloadNum <= pMatTNBuffers,
"USE_UB_TO_L1_PATH requires qkPreloadNum <= pMatTNBuffers (2) to avoid buffer races. "
"Use --qk-preload 2 when running with UB mode enabled.");
#endif
using TileMatQData =
Tile<TileType::Mat, half, Cube_S0, HEAD_SIZE, BLayout::ColMajor, Cube_S0, HEAD_SIZE, SLayout::RowMajor, 512>;
using TileMatKData =
Tile<TileType::Mat, half, HEAD_SIZE, Cube_S1, BLayout::RowMajor, HEAD_SIZE, Cube_S1, SLayout::ColMajor, 512>;
using TileQKData = TileAcc<float, Cube_S0, Cube_S1, Cube_S0, Cube_S1>;
TileMatQData qMatTile[qMatTNBuffers];
TileMatKData kMatTile[kMatTNBuffers];
TileQKData qkAccTile;
using TileMatPData =
Tile<TileType::Mat, half, Cube_S0, Cube_S1, BLayout::ColMajor, Cube_S0, Cube_S1, SLayout::RowMajor, 512>;
using TileMatVData =
Tile<TileType::Mat, half, Cube_S1, HEAD_SIZE, BLayout::ColMajor, Cube_S1, HEAD_SIZE, SLayout::RowMajor, 512>;
using TilePVData = TileAcc<float, Cube_S0, HEAD_SIZE, Cube_S0, HEAD_SIZE>;
TileMatPData pMatTile[pMatTNBuffers];
TileMatVData vMatTile[vMatTNBuffers];
TilePVData pvAccTile;
allocate_cube_tile_buffers(qMatTile, kMatTile, pMatTile, vMatTile);
assign_running_acc_tile(qkAccTile, 0);
assign_running_acc_tile(pvAccTile, 1);
using TileDataF_T = Tile<TileType::Vec, float, Vec_S0, Tile_S1, BLayout::RowMajor, Vec_S0, Tile_S1>;
using TileDataH_T = Tile<TileType::Vec, half, Vec_S0, Tile_S1, BLayout::RowMajor, Vec_S0, Tile_S1>;
constexpr uint32_t NzBufRows = Vec_S0 + 1;
using TileDataH_NZ_T = Tile<TileType::Vec, half, NzBufRows, Cube_S1, BLayout::ColMajor, Vec_S0, Cube_S1,
SLayout::RowMajor, 512, PadValue::Null, CompactMode::RowPlusOne>;
constexpr uint32_t SubblockRows = Cube_S0 / VEC_CORES;
using ReduceTileF_T = Tile<TileType::Vec, float, SubblockRows, 1, BLayout::ColMajor, SubblockRows, 1>;
TileDataF_T qkVecTile[srcVecTNBuffers];
ReduceTileF_T m1_local_max;
TileDataF_T input_reduce_tmp;
ReduceTileF_T l1_local_sum;
ReduceTileF_T m2_global_max;
ReduceTileF_T l2_global_sum;
ReduceTileF_T l1_exp_max_ififo[qkp_tile_fifo_size];
TileDataH_T x_expT[xexpVecTNBuffers];
TileDataH_NZ_T nzConvBuffer;
using TileOutGuT = Tile<TileType::Vec, float, VecGuRows, HEAD_SIZE, BLayout::RowMajor, VecGuRows, HEAD_SIZE>;
TileOutGuT pvVecTile[outOTileNBuffers];
TileOutGuT runningOTile;
allocate_vec_tile_buffers<TileDataF_T, ReduceTileF_T, TileDataH_T, TileOutGuT, srcVecTNBuffers, xexpVecTNBuffers,
outOTileNBuffers>(qkVecTile, m1_local_max, input_reduce_tmp, l1_local_sum, m2_global_max,
l2_global_sum, l1_exp_max_ififo, x_expT, pvVecTile, runningOTile);
constexpr uint32_t nzBufSize = NzBufRows * Cube_S1 * sizeof(half);
constexpr uint32_t nzBufOffset = MAX_VEC_UB_BYTES - nzBufSize;
if constexpr (DAV_VEC) {
TASSIGN(nzConvBuffer, nzBufOffset);
}
const int block_offset_rows = block_idx * static_cast<int>(Cube_S0);
constexpr bool use_cv_comm = (!INTERMEDIATE_CHECK) && (block_rows >= static_cast<uint32_t>(pto::kCvMaxCores));
int comm_slot = block_idx;
if constexpr (use_cv_comm) {
comm_slot = pto::TSYNC_CVID(block_idx, cv_comm_buf);
}
__gm__ uint64_t *profile_entry = nullptr;
if (profile_buf != nullptr) {
std::size_t profile_block_base = static_cast<std::size_t>(block_idx) * kFaProfileBytesPerBlock;
std::size_t profile_offset = profile_block_base;
if constexpr (DAV_VEC) {
profile_offset +=
(static_cast<std::size_t>(get_subblockid()) + 1U) * 1024U;
}
profile_entry = reinterpret_cast<__gm__ uint64_t *>(profile_buf + profile_offset);
profile_entry[0] = tStart;
}
const size_t p_fifo_block_stride =
static_cast<size_t>(qkp_tile_fifo_size) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(Tile_S1);
const size_t p_max_fifo_block_stride = static_cast<size_t>(qkp_tile_fifo_size) * static_cast<size_t>(Cube_S0);
const size_t qk_fifo_block_stride = p_fifo_block_stride;
const size_t pv_fifo_block_stride =
static_cast<size_t>(pv_tile_fifo_size) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(HEAD_SIZE);
__gm__ half *q_block = q + block_offset_rows * HEAD_SIZE;
__gm__ half *p_tile_fifo_block = p_tile_fifo + static_cast<size_t>(comm_slot) * p_fifo_block_stride;
__gm__ float *exp_max_ififo_block = exp_max_ififo + static_cast<size_t>(comm_slot) * p_max_fifo_block_stride;
__gm__ float *global_sum_block = global_sum_out + block_offset_rows;
__gm__ float *exp_max_block = exp_max_out + block_offset_rows;
__gm__ float *o_out_block = o_out + static_cast<size_t>(block_offset_rows) * static_cast<size_t>(HEAD_SIZE);
__gm__ float *o_parts_block = o_parts_out + static_cast<size_t>(block_offset_rows) * static_cast<size_t>(HEAD_SIZE);
__gm__ float *qk_tile_fifo_block = qk_tile_fifo + static_cast<size_t>(comm_slot) * qk_fifo_block_stride;
__gm__ float *pv_tile_fifo_block = pv_tile_fifo + static_cast<size_t>(comm_slot) * pv_fifo_block_stride;
constexpr TSync_Custom<SyncOpType::TMOV_C2UB, SyncOpType::TLOAD> ubBufSync = {UB_BUF_READY};
int num_tiles_s1 = S1 / Tile_S1;
if constexpr (CAUSAL_MASK)
num_tiles_s1 = (1 + ((block_idx * CUBE_S0) / Tile_S1));
if constexpr (DAV_CUBE) {
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0);
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2);
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3);
set_flag(PIPE_FIX, PIPE_M, EVENT_ID0);
set_flag(PIPE_FIX, PIPE_M, EVENT_ID1);
}
if constexpr (DAV_VEC) {
set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1);
set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1);
}
int p_gu_src_pingpong_id = 0;
int k_src_pingpong_id = 0;
int pv_src_pingpong_id = 0;
int qkAccTileEvtID = 0;
int pvAccTileEvtID = 0;
#if USE_L0C_TO_DUAL_UB_PATH_QK
constexpr uint8_t QKFiFoDepth = 2;
constexpr uint8_t QKFiFoSyncT = 2;
using QKPipe =
TMPipe<BUF0_QK_READY, FIFOType::VEC_FIFO, QKFiFoDepth, QKFiFoSyncT, TileQKData, TileDataF_T, false, 0>;
QKPipe qkPipe((uint32_t)(uint64_t)qkVecTile[0].data());
#else
constexpr uint8_t QKFiFoDepth = CV_FIFO_SIZE;
constexpr uint8_t QKFiFoSyncT = CV_FIFO_CONS_SYNC_PERIOD;
using QKPipe = TMPipe<BUF0_QK_READY, FIFOType::GM_FIFO, QKFiFoDepth, QKFiFoSyncT, TileQKData, TileDataF_T,
UF_ENABLE ? true : false, 0>;
QKPipe qkPipe(qk_tile_fifo_block, (uint32_t)(uint64_t)qkVecTile[0].data());
#endif
#if USE_UB_TO_L1_PATH
constexpr uint8_t PFiFoDepth = 2;
constexpr uint8_t PFiFoSyncT = 2;
using PPipe =
TMPipe<BUF1_SM_READY, FIFOType::MAT_FIFO, PFiFoDepth, PFiFoSyncT, TileDataH_NZ_T, TileMatPData, false, 0>;
PPipe pPipe((uint32_t)(uint64_t)pMatTile[0].data());
#else
constexpr uint8_t PFiFoDepth = CV_FIFO_SIZE;
constexpr uint8_t PFiFoSyncT = CV_FIFO_CONS_SYNC_PERIOD;
using PPipe = TMPipe<BUF1_SM_READY, FIFOType::GM_FIFO, PFiFoDepth, PFiFoSyncT, TileDataH_T, TileMatPData, false, 0>;
PPipe pPipe(p_tile_fifo_block, (uint32_t)(uint64_t)pMatTile[0].data());
#endif
#if USE_L0C_TO_UB_PV_PATH
constexpr uint8_t PVFiFoDepth = 2;
constexpr uint8_t PVFiFoSyncT = 2;
using PVPipe = TMPipe<UPDATE_READY, FIFOType::VEC_FIFO, PVFiFoDepth, PVFiFoSyncT, TilePVData, TileOutGuT, false, 0>;
PVPipe pvPipe((uint32_t)(uint64_t)pvVecTile[0].data());
#else
constexpr uint8_t PVFiFoDepth = CV_FIFO_SIZE;
constexpr uint8_t PVFiFoSyncT = CV_FIFO_CONS_SYNC_PERIOD;
using PVPipe = TMPipe<UPDATE_READY, FIFOType::GM_FIFO, PVFiFoDepth, PVFiFoSyncT, TilePVData, TileOutGuT,
UF_ENABLE ? true : false, 0>;
PVPipe pvPipe(pv_tile_fifo_block, (uint32_t)(uint64_t)pvVecTile[0].data());
#endif
for (int preload_tile = 0; preload_tile < static_cast<int>(qkPreloadNum) && preload_tile < num_tiles_s1;
++preload_tile) {
if constexpr (DAV_CUBE) {
for (int sub_tile = 0; sub_tile < static_cast<int>(kTileFactor); ++sub_tile) {
qkAccTileEvtID = assign_running_acc_tile(qkAccTile);
qkPipe.prod.setTileId(preload_tile, sub_tile);
compute_qk<QKPipe, S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, INTERMEDIATE_CHECK, CAUSAL_MASK,
srcVecTNBuffers>(qkPipe, preload_tile, sub_tile, q_block, k, qk_tile_fifo_block, qMatTile[0],
kMatTile[k_src_pingpong_id % kMatTNBuffers], qkAccTile,
k_src_pingpong_id % kMatTNBuffers, qkAccTileEvtID, ubBufSync, block_idx);
k_src_pingpong_id++;
}
}
if constexpr (DAV_VEC) {
for (int row_slice = 0; row_slice < static_cast<int>(kTileFactor); ++row_slice) {
pPipe.prod.setTileId(preload_tile, row_slice);
qkPipe.cons.setTileId(preload_tile, row_slice);
#if USE_L0C_TO_DUAL_UB_PATH_QK
const int tile_buf_idx = preload_tile % srcVecTNBuffers;
#else
const int tile_buf_idx = p_gu_src_pingpong_id % srcVecTNBuffers;
#endif
compute_p<QKPipe, PPipe, S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, INTERMEDIATE_CHECK, CAUSAL_MASK>(
qkPipe, pPipe, preload_tile, row_slice, p_tile_fifo_block, exp_max_ififo_block,
qkVecTile[tile_buf_idx], x_expT[p_gu_src_pingpong_id % xexpVecTNBuffers], input_reduce_tmp,
m1_local_max, l1_local_sum, m2_global_max, l2_global_sum,
l1_exp_max_ififo[preload_tile % qkp_tile_fifo_size], pMatTile[preload_tile % pMatTNBuffers],
nzConvBuffer, p_gu_src_pingpong_id % xexpVecTNBuffers, ubBufSync, block_idx);
p_gu_src_pingpong_id++;
}
}
}
for (int tile_id = 0; tile_id < num_tiles_s1; ++tile_id) {
int next_qk_tile = (tile_id + static_cast<int>(qkPreloadNum) >= num_tiles_s1) ?
-1 :
(tile_id + static_cast<int>(qkPreloadNum));
if (next_qk_tile != -1)
qkAccTileEvtID = assign_running_acc_tile(qkAccTile);
pvAccTileEvtID = assign_running_acc_tile(pvAccTile);
for (int sub_tile = 0; sub_tile < static_cast<int>(kTileFactor); ++sub_tile) {
if constexpr (DAV_CUBE) {
if (next_qk_tile != -1) {
qkPipe.prod.setTileId(next_qk_tile, sub_tile);
compute_qk<QKPipe, S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, INTERMEDIATE_CHECK, CAUSAL_MASK,
srcVecTNBuffers>(qkPipe, next_qk_tile, sub_tile, q_block, k, qk_tile_fifo_block,
qMatTile[0], kMatTile[k_src_pingpong_id % kMatTNBuffers], qkAccTile,
k_src_pingpong_id % kMatTNBuffers, qkAccTileEvtID, ubBufSync,
block_idx);
k_src_pingpong_id++;
}
}
if constexpr (DAV_VEC) {
if (next_qk_tile != -1) {
pPipe.prod.setTileId(next_qk_tile, sub_tile);
qkPipe.cons.setTileId(next_qk_tile, sub_tile);
#if USE_L0C_TO_DUAL_UB_PATH_QK
const int tile_buf_idx = next_qk_tile % srcVecTNBuffers;
#else
const int tile_buf_idx = p_gu_src_pingpong_id % srcVecTNBuffers;
#endif
compute_p<QKPipe, PPipe, S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, INTERMEDIATE_CHECK,
CAUSAL_MASK>(
qkPipe, pPipe, next_qk_tile, sub_tile, p_tile_fifo_block, exp_max_ififo_block,
qkVecTile[tile_buf_idx], x_expT[p_gu_src_pingpong_id % xexpVecTNBuffers], input_reduce_tmp,
m1_local_max, l1_local_sum, m2_global_max, l2_global_sum,
l1_exp_max_ififo[next_qk_tile % qkp_tile_fifo_size], pMatTile[next_qk_tile % pMatTNBuffers],
nzConvBuffer, p_gu_src_pingpong_id % xexpVecTNBuffers, ubBufSync, block_idx);
p_gu_src_pingpong_id++;
}
}
if constexpr (DAV_CUBE) {
pPipe.cons.setTileId(tile_id, sub_tile);
pvPipe.prod.setTileId(tile_id, sub_tile);
compute_pv<PPipe, PVPipe, S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, INTERMEDIATE_CHECK, CAUSAL_MASK,
outOTileNBuffers>(pPipe, pvPipe, tile_id, sub_tile, v, pv_tile_fifo_block,
pMatTile[pv_src_pingpong_id % pMatTNBuffers],
vMatTile[pv_src_pingpong_id % vMatTNBuffers], pvAccTile, runningOTile,
pv_src_pingpong_id % vMatTNBuffers + PV_EVENT_ID0, pvAccTileEvtID,
block_idx);
pv_src_pingpong_id++;
}
}
if constexpr (DAV_VEC) {
pvPipe.cons.setTileId(tile_id, -1);
compute_gu<PVPipe, S0, HEAD_SIZE, S1, CUBE_S0, Tile_S1, INTERMEDIATE_CHECK, CAUSAL_MASK, srcVecTNBuffers,
outOTileNBuffers>(
pvPipe, tile_id, num_tiles_s1, o_out_block, runningOTile, pvVecTile[tile_id % outOTileNBuffers],
l1_exp_max_ififo[tile_id % qkp_tile_fifo_size], l2_global_sum, tile_id % outOTileNBuffers);
p_gu_src_pingpong_id++;
}
}
const int pending_qk_sm_consumed =
pending_consumption_events(num_tiles_s1, QKPipe::DataFiFo::fifoDepth, QKPipe::DataFiFo::fifoPeriod);
const int pending_sv_consumed =
pending_consumption_events(num_tiles_s1, PPipe::DataFiFo::fifoDepth, PPipe::DataFiFo::fifoPeriod);
const int pending_update_consumed =
pending_consumption_events(num_tiles_s1, PVPipe::DataFiFo::fifoDepth, PVPipe::DataFiFo::fifoPeriod);
if constexpr (DAV_CUBE) {
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3);
wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0);
wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1);
for (int i = 0; i < pending_qk_sm_consumed; ++i)
qkPipe.prod.allocate();
for (int i = 0; i < pending_update_consumed; ++i)
pvPipe.prod.allocate();
}
if constexpr (DAV_VEC) {
wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1);
wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1);
for (int i = 0; i < pending_sv_consumed; ++i)
pPipe.prod.allocate();
}
pipe_barrier(PIPE_ALL);
uint64_t tEnd = get_sys_cnt();
if (profile_entry != nullptr) {
profile_entry[1] = tEnd;
}
#ifdef _DEBUG
if constexpr (DAV_CUBE) {
cce::printf("Core %d Cube Block %d, Start @%d End @%d (%d us)\n", get_coreid(), block_idx, int(tStart),
int(tEnd), int(tEnd - tStart) * 20 / 1000);
} else {
cce::printf("Core %d Vec Block %d, SubBlock %d, Start @%d End @%d (%d us)\n", get_coreid(), block_idx,
int(get_subblockid()), int(tStart), int(tEnd), int(tEnd - tStart) * 20 / 1000);
}
#endif
}
// Kernel with an intentionally empty body: it performs no work on the device.
// LaunchTFA enqueues it on 32 blocks ahead of the prefetches and the main kernel.
// NOTE(review): going by the name, the purpose is to warm up the launch path / cores
// before the timed kernel runs — confirm with the author.
__global__ AICORE __attribute__((aic)) void warmup_kernel()
{}
/**
 * Host-side launcher for the runTFA kernel (overload that takes a profiling buffer).
 *
 * Work enqueued on `stream`, in order:
 *   1. warmup_kernel on 32 blocks;
 *   2. a PTO_PREFETCH for each of q, k and v;
 *   3. runTFA on S0 / CUBE_S0 blocks.
 *
 * Prefetch sizes: q is sized from S0 rows. k and v are sized from S1 rows, because the
 * kernel walks them in S1 / Tile_S1 tiles; sizing them from S0 (as was done before) is
 * only correct when S0 == S1 and otherwise prefetches too little or a range larger than
 * the k/v extent.
 *
 * @param ffts           Passed to the kernel reinterpreted as `__gm__ uint64_t *`.
 * @param q, k, v        fp16 input tensors; passed to the kernel as `half *`.
 * @param p_tile_fifo    Passed to the kernel as `half *`.
 * @param exp_max_ififo, global_sum_out, exp_max_out, o_out, o_parts_out,
 *        qk_tile_fifo, pv_tile_fifo
 *                       Forwarded to the kernel unchanged.
 * @param profile_data   Forwarded as the last kernel argument. May be nullptr (the
 *                       profiling-free overload passes nullptr here).
 * @param stream         Stream on which the warmup, prefetches and kernel are enqueued.
 * @param cv_comm_buf    Forwarded to the kernel (placed before profile_data).
 */
template <int S0, int HEAD_SIZE, int S1, int CUBE_S0, int CUBE_S1, int TILE_S1, int QK_PRELOAD, int CV_FIFO_SIZE,
          bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, int CV_FIFO_CONS_SYNC_PERIOD>
void LaunchTFA(uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v, aclFloat16 *p_tile_fifo,
               float *exp_max_ififo, float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out,
               float *qk_tile_fifo, float *pv_tile_fifo, uint8_t *profile_data, aclrtStream stream,
               uint8_t *cv_comm_buf)
{
    static_assert(S0 % CUBE_S0 == 0, "S0 must be divisible by CUBE_S0");
    // One kernel block per CUBE_S0 rows of the S0 dimension.
    constexpr uint32_t block_rows = S0 / CUBE_S0;

    warmup_kernel<<<32, nullptr, stream>>>();

    // Byte extents of the inputs: q has S0 rows, k and v have S1 rows, all HEAD_SIZE wide.
    const uint64_t q_bytes = static_cast<uint64_t>(S0) * static_cast<uint64_t>(HEAD_SIZE) * sizeof(half);
    const uint64_t kv_bytes = static_cast<uint64_t>(S1) * static_cast<uint64_t>(HEAD_SIZE) * sizeof(half);

    // Compile-time switch between the default PTO_PREFETCH form and the <false, N> form.
    constexpr bool kPrefetchUseSdma = true;
    constexpr int kPrefetchAivCores = 64;
    if constexpr (kPrefetchUseSdma) {
        PTO_PREFETCH((__gm__ void *)q, q_bytes, stream);
        PTO_PREFETCH((__gm__ void *)k, kv_bytes, stream);
        PTO_PREFETCH((__gm__ void *)v, kv_bytes, stream);
    } else {
        PTO_PREFETCH<false, kPrefetchAivCores>((__gm__ void *)q, q_bytes, stream);
        PTO_PREFETCH<false, kPrefetchAivCores>((__gm__ void *)k, kv_bytes, stream);
        PTO_PREFETCH<false, kPrefetchAivCores>((__gm__ void *)v, kv_bytes, stream);
    }

    // Note the kernel's trailing argument order: cv_comm_buf first, then profile_data.
    runTFA<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CV_FIFO_SIZE, INTERMEDIATE_CHECK, CAUSAL_MASK,
           CV_FIFO_CONS_SYNC_PERIOD><<<block_rows, nullptr, stream>>>(
        (__gm__ uint64_t *)ffts, (half *)q, (half *)k, (half *)v, (half *)p_tile_fifo, exp_max_ififo, global_sum_out,
        exp_max_out, o_out, o_parts_out, qk_tile_fifo, pv_tile_fifo, cv_comm_buf, profile_data);
}
/**
 * Convenience overload of LaunchTFA for callers that have no profiling buffer.
 *
 * Forwards every argument to the overload that takes `profile_data`, supplying a null
 * pointer in that position. All other arguments are passed through untouched.
 */
template <int S0, int HEAD_SIZE, int S1, int CUBE_S0, int CUBE_S1, int TILE_S1, int QK_PRELOAD, int CV_FIFO_SIZE,
          bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, int CV_FIFO_CONS_SYNC_PERIOD>
void LaunchTFA(uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v, aclFloat16 *p_tile_fifo,
               float *exp_max_ififo, float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out,
               float *qk_tile_fifo, float *pv_tile_fifo, aclrtStream stream, uint8_t *cv_comm_buf)
{
    // Null profile buffer: selects the 15-argument overload by arity.
    uint8_t *const no_profile_data = nullptr;
    LaunchTFA<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CV_FIFO_SIZE, INTERMEDIATE_CHECK, CAUSAL_MASK,
              CV_FIFO_CONS_SYNC_PERIOD>(
        ffts, q, k, v, p_tile_fifo, exp_max_ififo, global_sum_out, exp_max_out, o_out, o_parts_out, qk_tile_fifo,
        pv_tile_fifo, no_profile_data, stream, cv_comm_buf);
}
#include "generated_cases.h"
// Explicit instantiations of LaunchTFA.
//
// INSTANTIATE_TFA_CHECK_MODE emits, for one INTERMEDIATE_CHECK value (CHECK), the two
// LaunchTFA overloads: first the one taking a profile buffer, then the one without.
// The FIFO template arguments are pinned to kFaCvFifoSize / kFaCvFifoConsSyncPeriod.
//
// INSTANTIATE_TFA expands that pair for CHECK = false and then CHECK = true, i.e. four
// instantiations per case. TFA_FOR_EACH_CASE applies it to every case in its list
// (presumably supplied by generated_cases.h — confirm).
#define INSTANTIATE_TFA_CHECK_MODE(S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CAUSAL_MASK, CHECK) \
    template void LaunchTFA<S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, kFaCvFifoSize, CHECK, \
                            CAUSAL_MASK, kFaCvFifoConsSyncPeriod>( \
        uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v, aclFloat16 *p_tile_fifo, \
        float *exp_max_ififo, float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, \
        float *qk_tile_fifo, float *pv_tile_fifo, uint8_t *profile_data, aclrtStream stream, \
        uint8_t *cv_comm_buf); \
    template void LaunchTFA<S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, kFaCvFifoSize, CHECK, \
                            CAUSAL_MASK, kFaCvFifoConsSyncPeriod>( \
        uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v, aclFloat16 *p_tile_fifo, \
        float *exp_max_ififo, float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, \
        float *qk_tile_fifo, float *pv_tile_fifo, aclrtStream stream, uint8_t *cv_comm_buf);
#define INSTANTIATE_TFA(S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CAUSAL_MASK) \
    INSTANTIATE_TFA_CHECK_MODE(S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CAUSAL_MASK, false) \
    INSTANTIATE_TFA_CHECK_MODE(S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CAUSAL_MASK, true)
TFA_FOR_EACH_CASE(INSTANTIATE_TFA)
#undef INSTANTIATE_TFA
#undef INSTANTIATE_TFA_CHECK_MODE