/**
Copyright (c) 2026 Huawei Technologies Co., Ltd.
This program is free software, you can redistribute it and/or modify it under the terms and conditions of
CANN Open Software License Agreement Version 2.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
*/
#include <acl/acl.h>
#include <pto/pto-inst.hpp>
#include "fa_performance_kernel.h"
#include <pto/npu/kernels/Pto_prefetch.hpp>
#include <pto/npu/a5/custom/TSyncCVID.hpp>
#include <pto/npu/a5/custom/TSync_Custom.hpp>
#include <pto/npu/a5/TMov.hpp>
#define UF_ENABLE 0
#include "pto_macro_dn_matmul.hpp"
#include "pto_macro_fa_dn_softmax.hpp"
#include "pto_macro_fa_dn_gu.hpp"
using namespace std;
using namespace pto;
// FIFO_MODE selects which data path the stages below use to hand tiles to each other.
// It defaults to 2 when the build does not define it. The mode names come from the #error text:
//   0 = ALL_GM_PATH, 1 = ALL_UB_PATH, 2 = QK_PV_UB_ONLY.
#ifndef FIFO_MODE
#define FIFO_MODE 2
#endif
// Reject any value outside the three supported modes at preprocessing time.
#if (FIFO_MODE < 0) || (FIFO_MODE > 2)
#error "FIFO_MODE must be 0 (ALL_GM_PATH), 1 (ALL_UB_PATH), or 2 (QK_PV_UB_ONLY)"
#endif
// Modes 1 and 2 cannot be combined with UF_ENABLE (UF_ENABLE is defined as 0 earlier in this file).
#if ((FIFO_MODE == 1) || (FIFO_MODE == 2)) && (UF_ENABLE != 0)
#error "UF_ENABLE must be 0 for mode 1 (ALL_UB_PATH), and mode 2 (QK_PV_UB_ONLY)"
#endif
// Per-mode path switches consumed by #if blocks in the compute_* functions:
//   USE_L0C_TO_DUAL_UB_PATH_QK : compute_qk TMOVs the QK accumulator into a vec tile instead of
//                                TSTOREing it to qk_tile_fifo.
//   USE_L0C_TO_UB_PV_PATH      : compute_pv TMOVs the PV accumulator into a vec tile instead of
//                                TSTOREing it to pv_tile_fifo.
//   USE_UB_TO_L1_PATH          : compute_p TINSERTs P into pMatTile instead of TSTOREing it to
//                                p_tile_fifo (and compute_pv then skips the TLOAD of pMatTile).
#if FIFO_MODE == 0
#define USE_L0C_TO_DUAL_UB_PATH_QK 0
#define USE_L0C_TO_UB_PV_PATH 0
#define USE_UB_TO_L1_PATH 0
#elif FIFO_MODE == 1
#define USE_L0C_TO_DUAL_UB_PATH_QK 1
#define USE_L0C_TO_UB_PV_PATH 1
#define USE_UB_TO_L1_PATH 1
#else
#define USE_L0C_TO_DUAL_UB_PATH_QK 1
#define USE_L0C_TO_UB_PV_PATH 1
#define USE_UB_TO_L1_PATH 0
#endif
// Guarded so the enum is defined only once if another header already provided it.
#ifndef FFTS_BUFFER_FLAG_ENUM
#define FFTS_BUFFER_FLAG_ENUM
// Flag identifiers with explicit even values 0, 2, ..., 10.
// NOTE(review): the stride of 2 presumably leaves the odd value free for a paired flag — confirm against
// the sync primitives that consume these values (not visible in this part of the file).
enum FftsBufferFlag : uint32_t
{
BUF0_QK_READY = 0,
BUF1_SM_READY = 2,
UPDATE_READY = 4,
UB_BUF_READY = 6,
PV_UB_BUF_READY = 8,
CV_BLOCK_END = 10,
};
#endif
// Event identifiers with implicit consecutive values 0..3: two for the QK stage and two for the PV stage.
// NOTE(review): presumably these are the values passed as the *EventId / accTileEvtID arguments of
// compute_qk / compute_pv — confirm at the call sites.
enum CoreEvtID : uint32_t
{
QK_EVENT_ID0,
QK_EVENT_ID1,
PV_EVENT_ID0,
PV_EVENT_ID1,
};
// Force-inline helper attribute; only defined here when no earlier header provided PTO_INLINE.
#ifndef PTO_INLINE
#define PTO_INLINE __attribute__((always_inline)) inline
#endif
// DAV_CUBE mirrors the __DAV_CUBE__ build macro as a constant so that cube-only code can be selected
// with `if constexpr (DAV_CUBE)` instead of #ifdef blocks.
#ifdef __DAV_CUBE__
constexpr bool DAV_CUBE = true;
#else
constexpr bool DAV_CUBE = false;
#endif
// DAV_VEC mirrors the __DAV_VEC__ build macro in the same way for vector-only code.
#ifdef __DAV_VEC__
constexpr bool DAV_VEC = true;
#else
constexpr bool DAV_VEC = false;
#endif
// Byte budgets checked by the static_asserts in the tile-allocation helpers: 512 KiB and 256 KiB.
constexpr std::size_t MAX_TILE_L1_BYTES = 512U * 1024U;
constexpr std::size_t MAX_VEC_UB_BYTES = 256U * 1024U;
// Tells the producing side whether it has to wait for the consumer at iteration `sync_iter`.
//
// FifoSize   : number of FIFO slots, must be >= 1.
// SyncPeriod : iterations between synchronisation points; any value <= 0 is treated as 1.
// sync_iter  : iteration index being produced.
// Returns true only when sync_iter >= FifoSize and sync_iter is a multiple of the effective period.
template <int FifoSize, int SyncPeriod>
AICORE inline bool should_wait_consumption(int sync_iter)
{
    static_assert(FifoSize >= 1, "CV FIFO size must be >= 1");
    constexpr int kEffectivePeriod = (SyncPeriod <= 0) ? 1 : SyncPeriod;
    static_assert(kEffectivePeriod >= 1, "CV FIFO consume sync period must be >= 1");

    // The first FifoSize iterations never wait.
    const bool past_initial_fill = (sync_iter >= FifoSize);
    const bool on_sync_boundary = ((sync_iter % kEffectivePeriod) == 0);
    return past_initial_fill && on_sync_boundary;
}
// Tells the consuming side whether it has to signal the producer after finishing iteration `sync_iter`.
//
// FifoSize   : number of FIFO slots, must be >= 1 (only checked, not used in the result).
// SyncPeriod : iterations between synchronisation points; any value <= 0 is treated as 1.
// sync_iter  : iteration index that was just consumed.
// Returns true when (sync_iter + 1) is a multiple of the effective period.
template <int FifoSize, int SyncPeriod>
AICORE inline bool should_notify_consumption(int sync_iter)
{
    static_assert(FifoSize >= 1, "CV FIFO size must be >= 1");
    constexpr int kEffectivePeriod = (SyncPeriod <= 0) ? 1 : SyncPeriod;
    static_assert(kEffectivePeriod >= 1, "CV FIFO consume sync period must be >= 1");

    const int completed_iterations = sync_iter + 1;
    const int remainder = completed_iterations % kEffectivePeriod;
    return remainder == 0;
}
// Runtime count of consumption notifications that have been issued but not yet waited on after
// `tiles_processed` iterations.
//
// notifications = tiles_processed / sync_period
// waits         = number of sync boundaries in [fifo_size, tiles_processed - 1], or 0 when
//                 tiles_processed <= fifo_size
// The difference is clamped to [0, ceil(fifo_size / sync_period)].
// Returns 0 when any argument is <= 0.
AICORE inline int pending_consumption_events(int tiles_processed, int fifo_size, int sync_period)
{
    const bool inputs_valid = (tiles_processed > 0) && (sync_period > 0) && (fifo_size > 0);
    if (!inputs_valid) {
        return 0;
    }

    const int notifications = tiles_processed / sync_period;

    int waits = 0;
    if (tiles_processed > fifo_size) {
        const int last_boundary = (tiles_processed - 1) / sync_period;
        const int boundary_before_fifo_full = (fifo_size - 1) / sync_period;
        waits = last_boundary - boundary_before_fifo_full;
        waits = (waits < 0) ? 0 : waits;
    }

    int outstanding = notifications - waits;
    outstanding = (outstanding < 0) ? 0 : outstanding;

    const int ceiling = (fifo_size + sync_period - 1) / sync_period;
    if (outstanding > ceiling) {
        return ceiling;
    }
    return outstanding;
}
// Compile-time size in bytes of one tile: TileType::Rows * TileType::Cols elements of TileType::DType.
template <typename TileType>
constexpr AICORE std::size_t tile_storage_bytes()
{
    constexpr std::size_t element_count = static_cast<std::size_t>(TileType::Rows * TileType::Cols);
    return element_count * sizeof(typename TileType::DType);
}
// Compile-time size in bytes of NumBuffers back-to-back tiles of type TileType.
template <typename TileType, std::size_t NumBuffers>
constexpr AICORE std::size_t tile_buffer_total_bytes()
{
    constexpr std::size_t bytes_per_tile = tile_storage_bytes<TileType>();
    return bytes_per_tile * NumBuffers;
}
// Places every tile of `tiles` back-to-back starting at `base_offset` via TASSIGN.
//
// tiles       : array of NumBuffers tiles; tile i is assigned base_offset + i * tile_storage_bytes<TileType>().
// base_offset : byte offset of the first tile.
// Returns the byte offset just past the last tile (base_offset + total size of the array).
// The total size of this one array is checked at compile time against MAX_TILE_L1_BYTES.
template <typename TileType, std::size_t NumBuffers>
AICORE inline uint32_t assign_tile_buffers(TileType (&tiles)[NumBuffers], uint32_t base_offset)
{
    if constexpr (NumBuffers == 0) {
        return base_offset;
    }
    constexpr std::size_t bytes_per_tile = tile_storage_bytes<TileType>();
    constexpr std::size_t total_storage_bytes = tile_buffer_total_bytes<TileType, NumBuffers>();
    static_assert(total_storage_bytes <= MAX_TILE_L1_BYTES, "Tile buffer L1 allocation exceeds 512KB");

    std::size_t slot = 0;
    for (TileType &tile : tiles) {
        const uint32_t tile_offset = base_offset + static_cast<uint32_t>(slot * bytes_per_tile);
        TASSIGN(tile, tile_offset);
        ++slot;
    }
    return base_offset + static_cast<uint32_t>(total_storage_bytes);
}
// Places two equally long tile arrays on top of each other: tilesA[i] and tilesB[i] receive the same
// byte offset, and consecutive slots are separated by the larger of the two tile sizes.
//
// tilesA, tilesB : arrays with the same number of entries (enforced at compile time).
// base_offset    : byte offset of slot 0.
// Returns the byte offset just past the last slot.
// The total size is checked at compile time against MAX_VEC_UB_BYTES.
template <typename TileA, std::size_t NumA, typename TileB, std::size_t NumB>
AICORE inline uint32_t assign_tile_buffers_union(TileA (&tilesA)[NumA], TileB (&tilesB)[NumB], uint32_t base_offset)
{
    static_assert(NumA == NumB, "Union assignment expects matching buffer counts");
    if constexpr (NumA == 0) {
        return base_offset;
    }
    constexpr std::size_t bytes_a = tile_storage_bytes<TileA>();
    constexpr std::size_t bytes_b = tile_storage_bytes<TileB>();
    constexpr std::size_t stride_bytes = (bytes_a > bytes_b) ? bytes_a : bytes_b;
    constexpr std::size_t total_storage_bytes = stride_bytes * NumA;
    static_assert(total_storage_bytes <= MAX_VEC_UB_BYTES, "Union tile UB allocation exceeds 256KB");

    std::size_t slot = 0;
    while (slot < NumA) {
        const uint32_t shared_offset = base_offset + static_cast<uint32_t>(slot * stride_bytes);
        TASSIGN(tilesA[slot], shared_offset);
        TASSIGN(tilesB[slot], shared_offset);
        ++slot;
    }
    return base_offset + static_cast<uint32_t>(total_storage_bytes);
}
// Lays out the four cube-side tile arrays contiguously from offset 0, in the order Q, K, P, V.
// The combined size is checked at compile time against MAX_TILE_L1_BYTES.
template <typename TileQType, std::size_t NumQ, typename TileKType, std::size_t NumK, typename TilePType,
          std::size_t NumP, typename TileVType, std::size_t NumV>
AICORE inline void allocate_cube_tile_buffers(TileQType (&qTiles)[NumQ], TileKType (&kTiles)[NumK],
                                              TilePType (&pTiles)[NumP], TileVType (&vTiles)[NumV])
{
    constexpr std::size_t q_bytes = tile_buffer_total_bytes<TileQType, NumQ>();
    constexpr std::size_t k_bytes = tile_buffer_total_bytes<TileKType, NumK>();
    constexpr std::size_t p_bytes = tile_buffer_total_bytes<TilePType, NumP>();
    constexpr std::size_t v_bytes = tile_buffer_total_bytes<TileVType, NumV>();
    constexpr std::size_t total_bytes = q_bytes + k_bytes + p_bytes + v_bytes;
    static_assert(total_bytes <= MAX_TILE_L1_BYTES, "Total cube L1 allocation exceeds 512KB");

    // Each call returns the offset just past the array it placed, which seeds the next array.
    const uint32_t after_q = assign_tile_buffers(qTiles, 0);
    const uint32_t after_k = assign_tile_buffers(kTiles, after_q);
    const uint32_t after_p = assign_tile_buffers(pTiles, after_k);
    const uint32_t after_v = assign_tile_buffers(vTiles, after_p);
    (void)after_v;
}
// Lays out all vector-side tiles from offset 0 and checks the total against MAX_VEC_UB_BYTES.
//
// Layout when USE_L0C_TO_DUAL_UB_PATH_QK is non-zero:
//   srcTiles[] | runningOTile | pvTile[] | <common tail>
// Layout otherwise:
//   runningOTile | union(srcTiles[i], pvTile[i]) | <common tail>
// Common tail, in order:
//   m1_local_max | m2_global_max | input_reduce_tmp | l1_local_sum | l2_global_sum | l1_exp_max[] | x_expT[]
//
// The budget counts four single reduce tiles (m1_local_max, m2_global_max, l1_local_sum, l2_global_sum)
// plus the ExpMaxBuffers entries of l1_exp_max, so that the compile-time check covers every tile that is
// actually placed.
template <typename TileDataF_T, typename ReduceTileF_T, typename TileDataH_T, typename TileOutT, std::size_t SrcBuffers,
          std::size_t XexpBuffers, std::size_t pvVecBuffers, std::size_t ExpMaxBuffers>
AICORE inline void allocate_vec_tile_buffers(TileDataF_T (&srcTiles)[SrcBuffers], ReduceTileF_T &m1_local_max,
                                             TileDataF_T &input_reduce_tmp, ReduceTileF_T &l1_local_sum,
                                             ReduceTileF_T &m2_global_max, ReduceTileF_T &l2_global_sum,
                                             ReduceTileF_T (&l1_exp_max)[ExpMaxBuffers],
                                             TileDataH_T (&x_expT)[XexpBuffers], TileOutT (&pvTile)[pvVecBuffers],
                                             TileOutT &runningOTile)
{
    constexpr std::size_t float_tile_bytes = tile_storage_bytes<TileDataF_T>();
    constexpr std::size_t reduce_tile_bytes = tile_storage_bytes<ReduceTileF_T>();
    constexpr std::size_t xexp_bytes = tile_buffer_total_bytes<TileDataH_T, XexpBuffers>();
    constexpr std::size_t out_tile_bytes = tile_storage_bytes<TileOutT>();
    // Number of reduce-sized tiles placed in the common tail below.
    constexpr std::size_t reduce_tile_count = 4U + ExpMaxBuffers;
    static_assert(SrcBuffers == pvVecBuffers, "src/pv buffer counts must match");
#if USE_L0C_TO_DUAL_UB_PATH_QK
    constexpr std::size_t src_bytes = tile_buffer_total_bytes<TileDataF_T, SrcBuffers>();
    constexpr std::size_t pv_bytes = tile_buffer_total_bytes<TileOutT, pvVecBuffers>();
    constexpr std::size_t total_bytes = src_bytes + pv_bytes + xexp_bytes + (reduce_tile_bytes * reduce_tile_count) +
                                        (float_tile_bytes * 1U) + out_tile_bytes;
    static_assert(total_bytes <= MAX_VEC_UB_BYTES, "Vec tile UB allocation exceeds 256KB");
    uint32_t offset = 0;
    offset = assign_tile_buffers(srcTiles, offset);
    TASSIGN(runningOTile, offset);
    offset += static_cast<uint32_t>(out_tile_bytes);
    offset = assign_tile_buffers(pvTile, offset);
#else
    // srcTiles[i] and pvTile[i] share storage, so only the larger of the two tile sizes is budgeted per slot.
    constexpr std::size_t union_stride = (tile_storage_bytes<TileDataF_T>() > tile_storage_bytes<TileOutT>()) ?
                                             tile_storage_bytes<TileDataF_T>() :
                                             tile_storage_bytes<TileOutT>();
    constexpr std::size_t union_bytes = union_stride * SrcBuffers;
    constexpr std::size_t total_bytes = union_bytes + xexp_bytes + (reduce_tile_bytes * reduce_tile_count) +
                                        (float_tile_bytes * 1U) + out_tile_bytes;
    static_assert(total_bytes <= MAX_VEC_UB_BYTES, "Vec tile UB allocation exceeds 256KB");
    uint32_t offset = 0;
    TASSIGN(runningOTile, offset);
    offset += static_cast<uint32_t>(out_tile_bytes);
    offset = assign_tile_buffers_union(srcTiles, pvTile, offset);
#endif
    // Common tail: reduce tiles, the float scratch tile, the exp-max array and the x_exp array.
    TASSIGN(m1_local_max, offset);
    offset += static_cast<uint32_t>(reduce_tile_bytes);
    TASSIGN(m2_global_max, offset);
    offset += static_cast<uint32_t>(reduce_tile_bytes);
    uint32_t tmp_float_offset = offset;
    TASSIGN(input_reduce_tmp, tmp_float_offset);
    offset += static_cast<uint32_t>(float_tile_bytes);
    TASSIGN(l1_local_sum, offset);
    offset += static_cast<uint32_t>(reduce_tile_bytes);
    TASSIGN(l2_global_sum, offset);
    offset += static_cast<uint32_t>(reduce_tile_bytes);
    offset = assign_tile_buffers(l1_exp_max, offset);
    uint32_t tail_offset = assign_tile_buffers(x_expT, offset);
    (void)tail_offset;
}
// Binds `accTile` to one of two fixed base addresses (0x0 or 0x10000), alternating on every call.
//
// accTile    : tile to bind via TASSIGN.
// initial_id : pass 0 or 1 to restart the alternation at that slot; any other value (default -1)
//              continues from the previous call.
// Returns the slot id (0 or 1) used for this call.
// The alternation state is a function-local static, so each AccTileT instantiation keeps its own state.
template <typename AccTileT>
AICORE inline int assign_running_acc_tile(AccTileT &accTile, int initial_id = -1)
{
    static int next_slot = 0;
    const bool restart_requested = (initial_id == 0) || (initial_id == 1);
    if (restart_requested) {
        next_slot = initial_id;
    }
    const int id = next_slot;
    const uint32_t base_addr = (id != 0) ? 0x10000u : 0x0u;
    TASSIGN(accTile, base_addr);
    // Flip between 0 and 1 for the next caller.
    next_slot = id ^ 1;
    return id;
}
// Cube-side Q*K stage for one (tile_id, sub_tile_id) step. The whole body is compiled out unless
// DAV_CUBE is true.
//
// Per call it loads a Cube_S1 x HEAD_SIZE slice of K (and Q only on the very first call), runs
// pto_macro_matmul(kMatTile, qMatTile) into qkAccTile, and hands the result on either through a vec tile
// (USE_L0C_TO_DUAL_UB_PATH_QK) or through the qk_tile_fifo global buffer.
//
// tile_id, sub_tile_id : the S1 position is tile_id * TILE_S1 + sub_tile_id * CUBE_S1.
// ub_buf_idx           : not referenced in this function body.
// q, k                 : global-memory inputs; k is offset by s1_index * HEAD_SIZE, q is used as passed.
// qk_tile_fifo         : global FIFO for QK results; the slot is tile_id % QKP_CV_FIFO.
// qMatTile, kMatTile   : destination tiles of the TLOADs and operands of the matmul.
// qkAccTile            : matmul accumulator tile.
// qkVecTile            : vec tile that receives the accumulator on the dual-UB path.
// qkMatTileEventId     : event id used with the PIPE_MTE1 -> PIPE_MTE2 wait_flag/set_flag pair.
// accTileEvtID         : event id used with the PIPE_FIX -> PIPE_M wait_flag/set_flag pair.
// qk2smSync, ubBufSync : project sync objects; only allocate() and record() are called on them here.
// blk_idx              : used to derive s0_index = blk_idx * CUBE_S0 for the causal-mask test.
template <int S0, int HEAD_SIZE, int S1, int CUBE_S0, int CUBE_S1, int TILE_S1, int QKP_CV_FIFO,
int CV_FIFO_CONS_SYNC_PERIOD, bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, int SRC_VEC_TN_BUFFERS,
typename TileMatQData, typename TileMatKData, typename TileQKData, typename TileQKVecData,
typename TSyncQK2SM, typename TSyncUBBuf>
AICORE inline void compute_qk(int tile_id, int sub_tile_id, int ub_buf_idx, __gm__ half *q, __gm__ half *k,
__gm__ float *qk_tile_fifo, TileMatQData &qMatTile, TileMatKData &kMatTile,
TileQKData &qkAccTile, TileQKVecData &qkVecTile, uint64_t qkMatTileEventId,
int accTileEvtID, TSyncQK2SM &qk2smSync, TSyncUBBuf &ubBufSync, int blk_idx)
{
if constexpr (DAV_CUBE) {
constexpr uint32_t Cube_S0 = CUBE_S0;
constexpr uint32_t Cube_S1 = CUBE_S1;
constexpr uint32_t Tile_S1 = TILE_S1;
// Number of CUBE_S1-wide sub-tiles that make up one TILE_S1-wide tile.
constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1;
constexpr uint32_t Cube_HEAD = HEAD_SIZE;
static_assert(QKP_CV_FIFO >= 1, "QKP_CV_FIFO must be >= 1");
static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1");
const int s0_index = blk_idx * CUBE_S0;
const int s1_index = tile_id * static_cast<int>(Tile_S1) + sub_tile_id * static_cast<int>(Cube_S1);
// FIFO synchronisation is tracked per tile, not per sub-tile.
const int sync_iter = tile_id;
const bool should_wait_consume = should_wait_consumption<QKP_CV_FIFO, CV_FIFO_CONS_SYNC_PERIOD>(sync_iter);
if constexpr (CAUSAL_MASK) {
// Sub-tiles that start beyond this block's first row are skipped: no load or matmul is issued, but the
// qk2smSync allocate/record calls are still made so the call pattern matches the non-skipped path.
// NOTE(review): ubBufSync.allocate() is not called on this path although the non-skipped dual-UB path
// calls it — confirm the consumer side accounts for that.
if (s1_index > s0_index) {
if (sub_tile_id == 0 && should_wait_consume)
qk2smSync.allocate();
if (sub_tile_id == static_cast<int>(kTileFactor) - 1)
qk2smSync.record();
return;
}
}
using GlobalDataQ =
GlobalTensor<half, pto::Shape<1, 1, 1, Cube_S0, HEAD_SIZE>, pto::Stride<1, 1, 1, 1, HEAD_SIZE>, Layout::DN>;
using GlobalDataK =
GlobalTensor<half, pto::Shape<1, 1, 1, Cube_S1, HEAD_SIZE>, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>;
GlobalDataQ qGlobal(q);
GlobalDataK kGlobal(k + s1_index * HEAD_SIZE);
// Wait for the matching set_flag(PIPE_MTE1, PIPE_MTE2, qkMatTileEventId) issued after an earlier matmul.
wait_flag(PIPE_MTE1, PIPE_MTE2, qkMatTileEventId);
// Q is loaded once, on the first sub-tile of the first tile, and reused afterwards.
if (tile_id == 0 && sub_tile_id == 0) {
TLOAD(qMatTile, qGlobal);
}
TLOAD(kMatTile, kGlobal);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
// Wait for the matching set_flag(PIPE_FIX, PIPE_M, accTileEvtID) issued once the accumulator was drained.
wait_flag(PIPE_FIX, PIPE_M, accTileEvtID);
#if UF_ENABLE
pto_macro_matmul<Cube_S1, Cube_HEAD, Cube_S0>(kMatTile, qMatTile, qkAccTile, AccMode::InitFinalSum);
#else
pto_macro_matmul<Cube_S1, Cube_HEAD, Cube_S0>(kMatTile, qMatTile, qkAccTile, AccMode::Init);
#endif
set_flag(PIPE_MTE1, PIPE_MTE2, qkMatTileEventId);
set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
if (sub_tile_id == 0 && should_wait_consume)
qk2smSync.allocate();
#if USE_L0C_TO_DUAL_UB_PATH_QK
// Optional debug copy of the accumulator into qk_tile_fifo (slot tile_id % QKP_CV_FIFO, sub-tile offset).
if constexpr (INTERMEDIATE_CHECK) {
const uint32_t buf_idx = static_cast<uint32_t>(tile_id % QKP_CV_FIFO);
const size_t base_elems =
static_cast<size_t>(buf_idx) * static_cast<size_t>(kTileFactor) * static_cast<size_t>(Cube_S0) *
static_cast<size_t>(Cube_S1) +
static_cast<size_t>(sub_tile_id) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(Cube_S1);
using GlobalDataQK =
GlobalTensor<float, pto::Shape<1, 1, 1, Cube_S1, Cube_S0>, pto::Stride<1, 1, 1, Cube_S0, 1>>;
GlobalDataQK qkGlobalTile(qk_tile_fifo + base_elems);
#if UF_ENABLE
TSTORE<STPhase::Final>(qkGlobalTile, qkAccTile);
#else
TSTORE(qkGlobalTile, qkAccTile);
#endif
}
constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor;
// Destination view inside qkVecTile: same base address shifted by sub_tile_id * Cube_S1 floats.
const uint64_t col_byte_offset = static_cast<uint64_t>(sub_tile_id * Cube_S1 * sizeof(float));
using TileDataF_Sub = Tile<TileType::Vec, float, Tile_S1, Vec_S0, BLayout::RowMajor, Tile_S1, Vec_S0>;
TileDataF_Sub qkVecTileSubDN;
TASSIGN(qkVecTileSubDN, (uint64_t)qkVecTile.data() + col_byte_offset);
// From tile SRC_VEC_TN_BUFFERS onwards, ubBufSync.allocate() is called before the vec tile is overwritten.
if (sub_tile_id == 0 && tile_id >= static_cast<int>(SRC_VEC_TN_BUFFERS)) {
ubBufSync.allocate();
}
TMOV<TileDataF_Sub, TileQKData, AccToVecMode::DualModeSplitN>(qkVecTileSubDN, qkAccTile);
set_flag(PIPE_FIX, PIPE_M, accTileEvtID);
// The tile is published once, after its last sub-tile.
if (sub_tile_id == static_cast<int>(kTileFactor) - 1) {
qk2smSync.record();
}
#else
// GM path: store the accumulator into qk_tile_fifo at slot tile_id % QKP_CV_FIFO, sub-tile offset.
using GlobalDataQK =
GlobalTensor<float, pto::Shape<1, 1, 1, Cube_S0, Cube_S1>, pto::Stride<1, 1, 1, Cube_S1, 1>>;
const uint32_t buf_idx = static_cast<uint32_t>(tile_id % QKP_CV_FIFO);
const size_t base_elems =
static_cast<size_t>(buf_idx) * static_cast<size_t>(kTileFactor) * static_cast<size_t>(Cube_S0) *
static_cast<size_t>(Cube_S1) +
static_cast<size_t>(sub_tile_id) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(Cube_S1);
GlobalDataQK qkGlobalTile(qk_tile_fifo + base_elems);
#if UF_ENABLE
TSTORE<STPhase::Final>(qkGlobalTile, qkAccTile);
#else
TSTORE(qkGlobalTile, qkAccTile);
#endif
set_flag(PIPE_FIX, PIPE_M, accTileEvtID);
if (sub_tile_id == static_cast<int>(kTileFactor) - 1)
qk2smSync.record();
(void)ubBufSync;
(void)qkVecTile;
#endif
}
}
// Cube-side P*V stage for one (tile_id, sub_tile_id) step. The index/flag constants are computed
// unconditionally, but all loads, matmuls and sync calls are inside `if constexpr (DAV_CUBE)`.
//
// Per call it loads a Cube_S1 x HEAD_SIZE slice of V, obtains P (TLOAD from p_tile_fifo unless
// USE_UB_TO_L1_PATH is set), accumulates pto_macro_matmul(pMatTile, vMatTile) into pvAccTile, and after
// the last contributing sub-tile drains the accumulator either into a vec tile (USE_L0C_TO_UB_PV_PATH)
// or into pv_tile_fifo.
//
// tile_id, sub_tile_id : the S1 position is tile_id * TILE_S1 + sub_tile_id * CUBE_S1.
// pv_ub_buf_idx        : index into pvVecTile used for tiles other than tile 0 on the UB path.
// p_tile_fifo, v       : global-memory inputs; pv_tile_fifo is the global output FIFO.
// runningOTile         : receives the accumulator for tile_id == 0 on the UB path.
// svMatTileEventId     : event id used with the PIPE_MTE1 -> PIPE_MTE2 wait_flag/set_flag pair.
// accTileEvtID         : event id used with the PIPE_FIX -> PIPE_M wait_flag/set_flag pair.
// sm2pvSync            : wait()/free() are called on it here.
// pv2guSync            : allocate()/record() are called on it here.
// pvUbBufSync          : allocate()/record() are called on it here (UB path only).
// blk_idx              : used to derive s0_index = blk_idx * CUBE_S0 for the causal-mask test.
template <int S0, int HEAD_SIZE, int S1, int CUBE_S0, int CUBE_S1, int TILE_S1, int QKP_CV_FIFO, int PV_CV_FIFO,
int CV_FIFO_CONS_SYNC_PERIOD, bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, int OUT_O_TILE_NBUFFERS,
typename TileMatPData, typename TileMatVData, typename TilePVData, typename TileOutT, typename TSyncSM2PV,
typename TSyncPV2GU, typename TSyncPVUBBuf>
AICORE inline void compute_pv(int tile_id, int sub_tile_id, int pv_ub_buf_idx, __gm__ half *p_tile_fifo, __gm__ half *v,
__gm__ float *pv_tile_fifo, TileMatPData &pMatTile, TileMatVData &vMatTile,
TilePVData &pvAccTile, TileOutT &runningOTile, TileOutT (&pvVecTile)[OUT_O_TILE_NBUFFERS],
uint64_t svMatTileEventId, int accTileEvtID, TSyncSM2PV &sm2pvSync, TSyncPV2GU &pv2guSync,
TSyncPVUBBuf &pvUbBufSync, int blk_idx)
{
constexpr uint32_t Cube_S0 = CUBE_S0;
constexpr uint32_t Cube_S1 = CUBE_S1;
constexpr uint32_t Tile_S1 = TILE_S1;
// Number of CUBE_S1-wide sub-tiles that make up one TILE_S1-wide tile.
constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1;
constexpr uint32_t Cube_HEAD = HEAD_SIZE;
// TileElems is not referenced again in this function.
constexpr uint32_t TileElems = Cube_S0 * Tile_S1;
static_assert(QKP_CV_FIFO >= 1, "QKP_CV_FIFO must be >= 1");
static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1");
const int s0_index = blk_idx * Cube_S0;
const int s1_index = tile_id * static_cast<int>(Tile_S1) + sub_tile_id * static_cast<int>(Cube_S1);
// FIFO synchronisation is tracked per tile, not per sub-tile.
const int sync_iter = tile_id;
// NOTE(review): both cadence flags are derived from QKP_CV_FIFO, and should_wait_consume also gates
// pv2guSync.allocate() below, while compute_gu derives its notify cadence from PV_CV_FIFO. The two only
// agree when QKP_CV_FIFO == PV_CV_FIFO — confirm that is always the case.
const bool should_wait_consume = should_wait_consumption<QKP_CV_FIFO, CV_FIFO_CONS_SYNC_PERIOD>(sync_iter);
const bool should_notify_consume = should_notify_consumption<QKP_CV_FIFO, CV_FIFO_CONS_SYNC_PERIOD>(sync_iter);
const bool is_last_subtile = (sub_tile_id + 1 == static_cast<int>(kTileFactor));
// True when causal masking is on and the sub-tile after this one would take the skip branch below.
const bool next_will_be_skipped = (s1_index + static_cast<int>(Cube_S1)) > s0_index && CAUSAL_MASK;
if constexpr (DAV_CUBE) {
if constexpr (CAUSAL_MASK) {
// Skipped sub-tile: no load or matmul, but the sm2pvSync wait/free calls are still made.
if (s1_index > s0_index) {
if (sub_tile_id == 0)
sm2pvSync.wait();
if (sub_tile_id == static_cast<int>(kTileFactor) - 1 && should_notify_consume)
sm2pvSync.free();
return;
}
}
using GlobalVT =
GlobalTensor<half, pto::Shape<1, 1, 1, Cube_S1, HEAD_SIZE>, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>;
// Wait for the matching set_flag(PIPE_MTE1, PIPE_MTE2, svMatTileEventId) issued after an earlier matmul.
wait_flag(PIPE_MTE1, PIPE_MTE2, svMatTileEventId);
GlobalVT vLoad((__gm__ half *)(v + s1_index * HEAD_SIZE));
TLOAD(vMatTile, vLoad);
// P for this tile must be available before it is consumed; waited on once per tile.
if (sub_tile_id == 0)
sm2pvSync.wait();
#if USE_UB_TO_L1_PATH
// pMatTile is not loaded here on this path (compute_p TINSERTs into it); only the FIFO slot is released.
if (sub_tile_id == static_cast<int>(kTileFactor) - 1 && should_notify_consume)
sm2pvSync.free();
#else
// GM path: P is read back from p_tile_fifo, using either the DN or the NZ global layout.
#ifndef P_FIFO_USE_NZ
using GlobalXexpTileT =
GlobalTensor<half, pto::Shape<1, 1, 1, Cube_S1, Cube_S0>, pto::Stride<1, 1, 1, 1, Cube_S0>, Layout::DN>;
#else
using GlobalXexpTileT = GlobalTensor<half, pto::Shape<1, Cube_S1 / 16, Cube_S0 / 16, 16, 16>,
pto::Stride<Cube_S0 * Cube_S1, Cube_S0 * 16, 16 * 16, 16, 1>, Layout::NZ>;
#endif
const uint32_t buf_idx = static_cast<uint32_t>(tile_id % QKP_CV_FIFO);
const size_t base_elems =
static_cast<size_t>(buf_idx) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(Tile_S1) +
static_cast<size_t>(sub_tile_id) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(Cube_S1);
GlobalXexpTileT xexpLoad(p_tile_fifo + base_elems);
TLOAD(pMatTile, xexpLoad);
if (sub_tile_id == static_cast<int>(kTileFactor) - 1 && should_notify_consume)
sm2pvSync.free();
#endif
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
// The accumulator is (re)initialised on sub-tile 0, so only that sub-tile waits for it to be drained.
if (sub_tile_id == 0) {
wait_flag(PIPE_FIX, PIPE_M, accTileEvtID);
}
#if UF_ENABLE
const AccMode accMode =
(sub_tile_id == 0) ?
(is_last_subtile || next_will_be_skipped ? AccMode::InitFinalSum : AccMode::InitPartialSum) :
(is_last_subtile || next_will_be_skipped ? AccMode::AccFinalSum : AccMode::AccPartialSum);
pto_macro_matmul<Cube_S0, Cube_S1, Cube_HEAD>(pMatTile, vMatTile, pvAccTile, accMode);
#else
// Sub-tile 0 starts a fresh accumulation; later sub-tiles add to it.
const AccMode accMode = (sub_tile_id == 0) ? AccMode::Init : AccMode::Acc;
pto_macro_matmul<Cube_S0, Cube_S1, Cube_HEAD>(pMatTile, vMatTile, pvAccTile, accMode);
#endif
set_flag(PIPE_MTE1, PIPE_MTE2, svMatTileEventId);
// Drain the accumulator after the last sub-tile, or early when the following sub-tile will be skipped.
if (sub_tile_id == static_cast<int>(kTileFactor) - 1 || next_will_be_skipped) {
set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
if (should_wait_consume)
pv2guSync.allocate();
#if USE_L0C_TO_UB_PV_PATH
// From tile OUT_O_TILE_NBUFFERS onwards, pvUbBufSync.allocate() is called before a vec tile is overwritten.
if (tile_id >= static_cast<int>(OUT_O_TILE_NBUFFERS)) {
pvUbBufSync.allocate();
}
// Tile 0 is written straight into runningOTile; later tiles go to the pvVecTile slot chosen by the caller.
if (tile_id == 0) {
TMOV<TileOutT, TilePVData, AccToVecMode::DualModeSplitM>(runningOTile, pvAccTile);
} else {
TMOV<TileOutT, TilePVData, AccToVecMode::DualModeSplitM>(pvVecTile[pv_ub_buf_idx], pvAccTile);
}
pvUbBufSync.record();
// Optional debug copy of the accumulator into pv_tile_fifo (slot tile_id % PV_CV_FIFO).
if constexpr (INTERMEDIATE_CHECK) {
using GlobalDataPV =
GlobalTensor<float, pto::Shape<1, 1, 1, Cube_S0, HEAD_SIZE>, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>;
const uint32_t buf_idx_pv = static_cast<uint32_t>(tile_id % PV_CV_FIFO);
const size_t base_elems_pv =
static_cast<size_t>(buf_idx_pv) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(HEAD_SIZE);
GlobalDataPV pvGlobalTile((__gm__ float *)(pv_tile_fifo + base_elems_pv));
#if UF_ENABLE
TSTORE<STPhase::Final>(pvGlobalTile, pvAccTile);
#else
TSTORE(pvGlobalTile, pvAccTile);
#endif
}
set_flag(PIPE_FIX, PIPE_M, accTileEvtID);
pv2guSync.record();
#else
// GM path: store the accumulator into pv_tile_fifo at slot tile_id % PV_CV_FIFO.
using GlobalDataPV =
GlobalTensor<float, pto::Shape<1, 1, 1, Cube_S0, HEAD_SIZE>, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>;
const uint32_t buf_idx_pv = static_cast<uint32_t>(tile_id % PV_CV_FIFO);
const size_t base_elems_pv =
static_cast<size_t>(buf_idx_pv) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(HEAD_SIZE);
GlobalDataPV pvGlobalTile((__gm__ float *)(pv_tile_fifo + base_elems_pv));
#if UF_ENABLE
TSTORE<STPhase::Final>(pvGlobalTile, pvAccTile);
#else
TSTORE(pvGlobalTile, pvAccTile);
#endif
set_flag(PIPE_FIX, PIPE_M, accTileEvtID);
pv2guSync.record();
(void)pvUbBufSync;
(void)runningOTile;
(void)pvVecTile;
#endif
}
}
}
// Vector-side softmax stage for one row slice of tile `tile_id`. Constants are computed unconditionally,
// but all work is inside `if constexpr (DAV_VEC)`.
//
// Sequence visible below: wait for the QK result (qk2smSync.wait() on row_slice 0), on the GM path TLOAD
// the QK data from qk_tile_fifo into qkVecTile, run pto_macro_fa_softmax_dn, then publish P either into
// pMatTile via TMOV + TINSERT (USE_UB_TO_L1_PATH) or into p_tile_fifo via TSTORE, and finally call
// sm2pvSync.record() on the last row slice.
//
// tile_id, row_slice   : row_slice selects Vec_S0 rows inside this sub-block's share of Cube_S0.
// qk_tile_fifo         : global QK FIFO, read only on the GM path.
// p_tile_fifo          : global P FIFO (slot tile_id % QKP_CV_FIFO).
// exp_max_ififo        : global buffer written only when INTERMEDIATE_CHECK is true.
// global_sum_out, exp_max_out : not used for any load or store in this function body.
// qkVecTile, x_expT, input_reduce_tmp : softmax input, output and scratch tiles.
// m1_local_max .. l1_exp_max_ififo    : reduce tiles; a per-slice view of each is passed to the softmax macro.
// pMatTile, nzConvBuffer : used only when USE_UB_TO_L1_PATH is set.
// pTileEventId         : event id used with the PIPE_V/PIPE_MTE2 and PIPE_MTE3/PIPE_V flag pairs.
// qk2smSync            : wait()/free() are called on it here.
// sm2pvSync            : allocate()/record() are called on it here.
//                        NOTE(review): this is the only sync parameter taken by value rather than by
//                        reference; if the sync type carries mutable state, updates made here would not
//                        reach the caller's object — confirm this is intentional.
// ubBufSync            : free() is called on it on the dual-UB path.
// blk_idx              : used to derive s0_index for the softmax macro.
template <int S0, int HEAD_SIZE, int S1, int CUBE_S0, int CUBE_S1, int TILE_S1, int QKP_CV_FIFO,
int CV_FIFO_CONS_SYNC_PERIOD, bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, typename TileDataF_T,
typename TileDataH_T, typename TileDataH_NZ_T, typename ReduceTileF_T, typename TileMatPData,
typename TSyncQK2SM, typename TSyncSM2PV, typename TSyncUBBuf>
AICORE inline void compute_p(int tile_id, int row_slice, __gm__ float *qk_tile_fifo, __gm__ half *p_tile_fifo,
__gm__ float *exp_max_ififo, __gm__ float *global_sum_out, __gm__ float *exp_max_out,
TileDataF_T &qkVecTile, TileDataH_T &x_expT, TileDataF_T &input_reduce_tmp,
ReduceTileF_T &m1_local_max, ReduceTileF_T &l1_local_sum, ReduceTileF_T &m2_global_max,
ReduceTileF_T &l2_global_sum, ReduceTileF_T &l1_exp_max_ififo, TileMatPData &pMatTile,
TileDataH_NZ_T &nzConvBuffer, uint64_t pTileEventId, TSyncQK2SM &qk2smSync,
TSyncSM2PV sm2pvSync, TSyncUBBuf &ubBufSync, int blk_idx)
{
constexpr uint32_t Cube_S0 = CUBE_S0;
constexpr uint32_t Cube_S1 = CUBE_S1;
constexpr uint32_t Tile_S1 = TILE_S1;
// Number of CUBE_S1-wide sub-tiles per tile; also the number of row slices per sub-block.
constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1;
// Rows handled per call: Cube_S0 split across VEC_CORES sub-blocks and then across kTileFactor slices.
constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor;
// The first tile uses the <true, ...> instantiation of the softmax macro.
const bool initFlag = (tile_id == 0);
static_assert(QKP_CV_FIFO >= 1, "QKP_CV_FIFO must be >= 1");
static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1");
static_assert(Cube_S0 % (VEC_CORES * kTileFactor) == 0, "Vec rows must divide evenly across tile slices");
if constexpr (DAV_VEC) {
// First row owned by this sub-block, and the first row of the requested slice inside it.
const size_t subblock_base_rows =
static_cast<size_t>(Cube_S0 / VEC_CORES) * static_cast<size_t>(get_subblockid());
const size_t row_offset = subblock_base_rows + static_cast<size_t>(row_slice * Vec_S0);
const int s0_index = blk_idx * Cube_S0 + row_offset;
const int s1_index = tile_id * static_cast<int>(Tile_S1);
// FIFO synchronisation is tracked per tile, not per row slice.
const int sync_iter = tile_id;
// should_wait_consume is not read again; the same value is recomputed as should_wait_sv_consumed below.
const bool should_wait_consume = should_wait_consumption<QKP_CV_FIFO, CV_FIFO_CONS_SYNC_PERIOD>(sync_iter);
const bool should_notify_consume = should_notify_consumption<QKP_CV_FIFO, CV_FIFO_CONS_SYNC_PERIOD>(sync_iter);
wait_flag(PIPE_V, PIPE_MTE2, pTileEventId);
// The QK result is waited on once per tile.
if (row_slice == 0)
qk2smSync.wait();
const uint32_t buf_idx = static_cast<uint32_t>(tile_id % QKP_CV_FIFO);
const size_t base_elems = static_cast<size_t>(buf_idx) * static_cast<size_t>(kTileFactor) *
static_cast<size_t>(Cube_S0) * static_cast<size_t>(Cube_S1);
#if USE_L0C_TO_DUAL_UB_PATH_QK
// On this path nothing is loaded from qk_tile_fifo here (compute_qk TMOVs into a vec tile instead).
(void)base_elems;
#else
// GM path: load each CUBE_S1-wide sub-column of this slice from qk_tile_fifo into its place in qkVecTile.
__gm__ float *qk_ptr = qk_tile_fifo + base_elems + row_offset;
using GlobalDataQK_Sub =
GlobalTensor<float, pto::Shape<1, 1, 1, Cube_S1, Vec_S0>, pto::Stride<1, 1, 1, Cube_S0, 1>>;
using TileDataF_Sub = Tile<TileType::Vec, float, Tile_S1, Vec_S0, BLayout::RowMajor, Cube_S1, Vec_S0>;
for (int sub_col = 0; sub_col < static_cast<int>(kTileFactor); ++sub_col) {
__gm__ float *qk_ptr_sub =
qk_ptr + static_cast<size_t>(sub_col) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(Cube_S1);
GlobalDataQK_Sub qkGlobalSub(qk_ptr_sub);
TileDataF_Sub qkVecSub;
const uint64_t col_byte_offset = static_cast<uint64_t>(sub_col * Cube_S1 * sizeof(float));
TASSIGN(qkVecSub, (uint64_t)qkVecTile.data() + col_byte_offset);
TLOAD(qkVecSub, qkGlobalSub);
}
#endif
// The QK FIFO slot is released after the last row slice, on the notify cadence.
if (row_slice == static_cast<int>(kTileFactor) - 1 && should_notify_consume)
qk2smSync.free();
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
// Build 1 x Vec_S0 views into the reduce tiles, offset by row_slice * Vec_S0 floats.
using ReduceSliceTile = Tile<TileType::Vec, float, 1, Vec_S0, BLayout::RowMajor, 1, Vec_S0>;
const size_t reduce_slice_rows = static_cast<size_t>(row_slice * Vec_S0);
const uint64_t reduce_row_byte_offset = reduce_slice_rows * sizeof(float);
ReduceSliceTile m1_local_max_slice;
ReduceSliceTile l1_local_sum_slice;
ReduceSliceTile m2_global_max_slice;
ReduceSliceTile l2_global_sum_slice;
ReduceSliceTile l1_exp_max_slice;
TASSIGN(m1_local_max_slice, (uint64_t)m1_local_max.data() + reduce_row_byte_offset);
TASSIGN(l1_local_sum_slice, (uint64_t)l1_local_sum.data() + reduce_row_byte_offset);
TASSIGN(m2_global_max_slice, (uint64_t)m2_global_max.data() + reduce_row_byte_offset);
TASSIGN(l2_global_sum_slice, (uint64_t)l2_global_sum.data() + reduce_row_byte_offset);
TASSIGN(l1_exp_max_slice, (uint64_t)l1_exp_max_ififo.data() + reduce_row_byte_offset);
// Wait for the matching set_flag(PIPE_MTE3, PIPE_V, pTileEventId) issued at the end of an earlier call.
wait_flag(PIPE_MTE3, PIPE_V, pTileEventId);
// The two branches differ only in the first template argument of the softmax macro.
if (initFlag) {
pto_macro_fa_softmax_dn<true, HEAD_SIZE, CAUSAL_MASK>(
x_expT, qkVecTile, m1_local_max_slice, l1_local_sum_slice, m2_global_max_slice, l2_global_sum_slice,
l1_exp_max_slice, input_reduce_tmp, qkVecTile, input_reduce_tmp, s0_index, s1_index);
} else {
pto_macro_fa_softmax_dn<false, HEAD_SIZE, CAUSAL_MASK>(
x_expT, qkVecTile, m1_local_max_slice, l1_local_sum_slice, m2_global_max_slice, l2_global_sum_slice,
l1_exp_max_slice, input_reduce_tmp, qkVecTile, input_reduce_tmp, s0_index, s1_index);
}
#if USE_L0C_TO_DUAL_UB_PATH_QK
// Release the QK vec buffer once the last row slice of the tile has been processed.
if (row_slice == static_cast<int>(kTileFactor) - 1) {
ubBufSync.free();
}
#else
(void)ubBufSync;
#endif
set_flag(PIPE_V, PIPE_MTE2, pTileEventId);
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
const bool should_wait_sv_consumed = should_wait_consumption<QKP_CV_FIFO, CV_FIFO_CONS_SYNC_PERIOD>(sync_iter);
if (row_slice == 0 && should_wait_sv_consumed)
sm2pvSync.allocate();
#if USE_UB_TO_L1_PATH
// UB -> L1 path: each sub-column of x_expT is moved through nzConvBuffer and inserted into pMatTile.
using GlobalPTileHalfSub =
GlobalTensor<half, pto::Shape<1, 1, 1, Cube_S1, Vec_S0>, pto::Stride<1, 1, 1, Cube_S0, 1>>;
using TileDataH_Sub = Tile<TileType::Vec, half, Tile_S1, Vec_S0, BLayout::RowMajor, Cube_S1, Vec_S0>;
__gm__ half *p_ptr = p_tile_fifo + base_elems + row_offset;
for (int sub_col = 0; sub_col < static_cast<int>(kTileFactor); ++sub_col) {
using TileDataH_Sub_ND = Tile<TileType::Vec, half, Cube_S1, Vec_S0, BLayout::RowMajor, Cube_S1, Vec_S0>;
TileDataH_Sub_ND xExpSubND;
const uint64_t col_byte_offset = static_cast<uint64_t>(sub_col * Cube_S1 * Vec_S0 * sizeof(half));
TASSIGN(xExpSubND, (uint64_t)x_expT.data() + col_byte_offset);
// Optional debug copy of this sub-column into p_tile_fifo.
if constexpr (INTERMEDIATE_CHECK) {
__gm__ half *p_ptr_sub =
p_ptr + static_cast<size_t>(sub_col) * static_cast<size_t>(Cube_S1) * static_cast<size_t>(Cube_S0);
GlobalPTileHalfSub pTileHalfSub((__gm__ half *)(p_ptr_sub));
TileDataH_Sub xExpSub;
TASSIGN(xExpSub, (uint64_t)x_expT.data() + col_byte_offset);
TSTORE(pTileHalfSub, xExpSub);
}
TMOV(nzConvBuffer, xExpSubND);
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
// NOTE(review): the insert position is (0, Vec_S0 * subblock id) and does not depend on sub_col or
// row_slice — confirm that is intended when kTileFactor > 1.
uint16_t col_offset = static_cast<uint16_t>(Vec_S0 * static_cast<size_t>(get_subblockid()));
TINSERT(pMatTile, nzConvBuffer, static_cast<uint16_t>(0), col_offset);
}
(void)global_sum_out;
(void)exp_max_out;
#else
// GM path: store each sub-column of x_expT into p_tile_fifo.
using GlobalPTileHalfSub =
GlobalTensor<half, pto::Shape<1, 1, 1, Cube_S1, Vec_S0>, pto::Stride<1, 1, 1, Cube_S0, 1>>;
using TileDataH_Sub = Tile<TileType::Vec, half, Tile_S1, Vec_S0, BLayout::RowMajor, Cube_S1, Vec_S0>;
__gm__ half *p_ptr = p_tile_fifo + base_elems + row_offset;
for (int sub_col = 0; sub_col < static_cast<int>(kTileFactor); ++sub_col) {
__gm__ half *p_ptr_sub =
p_ptr + static_cast<size_t>(sub_col) * static_cast<size_t>(Cube_S1) * static_cast<size_t>(Cube_S0);
GlobalPTileHalfSub pTileHalfSub((__gm__ half *)(p_ptr_sub));
TileDataH_Sub xExpSub;
const uint64_t col_byte_offset = static_cast<uint64_t>(sub_col * Cube_S1 * Vec_S0 * sizeof(half));
TASSIGN(xExpSub, (uint64_t)x_expT.data() + col_byte_offset);
TSTORE(pTileHalfSub, xExpSub);
}
(void)nzConvBuffer;
(void)pMatTile;
#endif
// Optional debug dump of this sub-block's exp-max values into exp_max_ififo, once per tile.
if constexpr (INTERMEDIATE_CHECK) {
if (row_slice == static_cast<int>(kTileFactor) - 1) {
constexpr uint32_t SubblockRows = Cube_S0 / VEC_CORES;
using GlobalPMaxFloatSub =
GlobalTensor<float, pto::Shape<1, 1, 1, 1, SubblockRows>, pto::Stride<1, 1, 1, Cube_S0, 1>>;
using ExpMaxSub = Tile<TileType::Vec, float, 1, SubblockRows, BLayout::RowMajor, 1, SubblockRows>;
const size_t base_elems_pmax =
static_cast<size_t>(buf_idx) * static_cast<size_t>(Cube_S0) + subblock_base_rows;
__gm__ float *p_ptr_fp32 = exp_max_ififo + base_elems_pmax;
GlobalPMaxFloatSub pMaxGlobal(p_ptr_fp32);
ExpMaxSub l1_exp_max_rowmajor;
TRESHAPE(l1_exp_max_rowmajor, l1_exp_max_ififo);
TSTORE(pMaxGlobal, l1_exp_max_rowmajor);
}
}
// P for this tile is published once, after the last row slice.
if (row_slice == static_cast<int>(kTileFactor) - 1)
sm2pvSync.record();
set_flag(PIPE_MTE3, PIPE_V, pTileEventId);
}
}
// Vector-side update stage: folds the PV result of tile `tile_id` into runningOTile and, on the last
// tile, stores runningOTile to o_out. All work is inside `if constexpr (DAV_VEC)`.
//
// tile_id, num_tiles : tile_id == 0 seeds runningOTile; tile_id == num_tiles - 1 triggers the final macro
//                      and the store to o_out.
// pv_tile_fifo       : global PV FIFO, read only when USE_L0C_TO_UB_PV_PATH is 0 (slot tile_id % PV_CV_FIFO).
// o_out              : global output; this sub-block writes at row offset subblock_base_rows.
// o_parts_out        : not referenced in this function body.
// runningOTile       : running output tile, updated in place by the pto_macro_fa_gu* macros.
// pvVecTile          : PV result for this tile (filled by TLOAD here, or already in place on the UB path).
// l1_exp_max_ififo, l2_global_sum : reduce tiles passed to the update macros.
// guEventId          : event id used with the PIPE_V -> PIPE_MTE2 wait_flag/set_flag pair.
// pv2guSync          : wait()/free() are called on it here.
// ubBufSync          : unused (explicitly voided).
// pvUbBufSync        : wait()/free() are called on it on the UB path.
template <int S0, int HEAD_SIZE, int S1, int CUBE_S0, int TILE_S1, int PV_CV_FIFO, int CV_FIFO_CONS_SYNC_PERIOD,
bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, int SRC_VEC_TN_BUFFERS, int OUT_O_TILE_NBUFFERS, typename TileOutT,
typename ReduceTileF_T, typename TSyncPV2GU, typename TSyncUBBuf, typename TSyncPVUBBuf>
AICORE inline void compute_gu(int tile_id, int num_tiles, __gm__ float *pv_tile_fifo, __gm__ float *o_out,
__gm__ float *o_parts_out, TileOutT &runningOTile, TileOutT &pvVecTile,
ReduceTileF_T &l1_exp_max_ififo, ReduceTileF_T &l2_global_sum, uint64_t guEventId,
TSyncPV2GU &pv2guSync, TSyncUBBuf &ubBufSync, TSyncPVUBBuf &pvUbBufSync)
{
constexpr uint32_t Cube_S0 = CUBE_S0;
// Rows handled by one sub-block: Cube_S0 split across VEC_CORES.
constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES;
using GlobalDataPV_VEC =
GlobalTensor<float, pto::Shape<1, 1, 1, Vec_S0, HEAD_SIZE>, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>;
if constexpr (DAV_VEC) {
const uint32_t buf_idx = static_cast<uint32_t>(tile_id % PV_CV_FIFO);
const size_t base_elems =
static_cast<size_t>(buf_idx) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(HEAD_SIZE);
const size_t subblock_base_rows =
static_cast<size_t>(Cube_S0 / VEC_CORES) * static_cast<size_t>(get_subblockid());
// This sub-block's rows inside the selected FIFO slot.
__gm__ float *pv_out_ptr = pv_tile_fifo + base_elems + subblock_base_rows * HEAD_SIZE;
GlobalDataPV_VEC pvGlobalVec(pv_out_ptr);
pv2guSync.wait();
const bool should_notify_consume = should_notify_consumption<PV_CV_FIFO, CV_FIFO_CONS_SYNC_PERIOD>(tile_id);
wait_flag(PIPE_V, PIPE_MTE2, guEventId);
(void)ubBufSync;
#if USE_L0C_TO_UB_PV_PATH
// UB path: the PV data is already in runningOTile (tile 0) or pvVecTile (later tiles); nothing is loaded.
pvUbBufSync.wait();
if (tile_id > 0) {
if (tile_id < num_tiles - 1) {
pto_macro_fa_gu<ReduceTileF_T, TileOutT>(runningOTile, pvVecTile, l1_exp_max_ififo);
} else {
pto_macro_fa_gu_last<ReduceTileF_T, TileOutT>(runningOTile, pvVecTile, l1_exp_max_ififo, l2_global_sum);
}
} else {
// Tile 0 only needs extra work when it is also the last tile, and only in causal builds.
// NOTE(review): with CAUSAL_MASK == false and num_tiles == 1 no finalisation macro runs — confirm
// that single-tile non-causal runs are not expected.
if constexpr (CAUSAL_MASK) {
if (tile_id == num_tiles - 1)
pto_macro_fa_gu_single_and_last_tile(runningOTile, l2_global_sum);
}
}
pvUbBufSync.free();
#else
// GM path: load this sub-block's PV rows from pv_tile_fifo before applying the same update logic.
(void)pvUbBufSync;
if (tile_id == 0) {
TLOAD(runningOTile, pvGlobalVec);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
if constexpr (CAUSAL_MASK) {
if (tile_id == num_tiles - 1)
pto_macro_fa_gu_single_and_last_tile(runningOTile, l2_global_sum);
}
} else {
TLOAD(pvVecTile, pvGlobalVec);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
if (tile_id < num_tiles - 1) {
pto_macro_fa_gu<ReduceTileF_T, TileOutT>(runningOTile, pvVecTile, l1_exp_max_ififo);
} else {
pto_macro_fa_gu_last<ReduceTileF_T, TileOutT>(runningOTile, pvVecTile, l1_exp_max_ififo, l2_global_sum);
}
}
#endif
set_flag(PIPE_V, PIPE_MTE2, guEventId);
// Release the PV FIFO slot on the notify cadence.
if (should_notify_consume)
pv2guSync.free();
// After the last tile, write this sub-block's rows of the final output.
if (tile_id == num_tiles - 1) {
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
using GlobalOutT =
GlobalTensor<float, pto::Shape<1, 1, 1, Vec_S0, HEAD_SIZE>, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>;
GlobalOutT outGlobal((__gm__ float *)(o_out + subblock_base_rows * HEAD_SIZE));
TSTORE(outGlobal, runningOTile);
}
}
}
// Kernel entry point that drives the compute_qk / compute_p / compute_pv / compute_gu stages for one
// block of Cube_S0 query rows (the block's rows start at block_idx * Cube_S0).
// Code guarded by DAV_CUBE issues compute_qk and compute_pv; code guarded by DAV_VEC issues compute_p
// and compute_gu (presumably the cube-core and vector-core builds of the same kernel - TODO confirm).
//
// Parameters:
//   ffts_addr        - forwarded to set_ffts_base_addr().
//   q, k, v          - half-precision inputs; q is offset to this block's rows, k and v are passed as-is.
//   p_tile_fifo, exp_max_ififo, qk_tile_fifo, pv_tile_fifo
//                    - FIFO buffers; each is offset by comm_slot times a per-block stride.
//   global_sum_out, exp_max_out, o_out, o_parts_out
//                    - outputs; each is offset by this block's starting row.
//   cv_comm_buf      - only read (through TSYNC_CVID) when use_cv_comm is true.
//   profile_buf      - optional; when non-null, get_sys_cnt() at entry and at exit are stored in it.
template <int S0, int HEAD_SIZE, int S1, int CUBE_S0, int CUBE_S1, int TILE_S1, int QK_PRELOAD, int CV_FIFO_SIZE,
bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, int CV_FIFO_CONS_SYNC_PERIOD>
__global__ AICORE void runTFA(__gm__ uint64_t *ffts_addr, __gm__ half *q, __gm__ half *k, __gm__ half *v,
__gm__ half *p_tile_fifo, __gm__ float *exp_max_ififo, __gm__ float *global_sum_out,
__gm__ float *exp_max_out, __gm__ float *o_out, __gm__ float *o_parts_out,
__gm__ float *qk_tile_fifo, __gm__ float *pv_tile_fifo, __gm__ uint8_t *cv_comm_buf,
__gm__ uint8_t *profile_buf)
{
// Timestamp taken first so the profile entry covers the whole kernel body.
uint64_t tStart = get_sys_cnt();
set_ffts_base_addr((uint64_t)ffts_addr);
// These two M->MTE1 flags are set up front; the matching wait_flag calls are in the epilogue below.
if constexpr (DAV_CUBE) {
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
}
// Compile-time tiling constants derived from the template arguments.
constexpr uint32_t Cube_S0 = CUBE_S0;
constexpr uint32_t block_rows = S0 / CUBE_S0;
constexpr uint32_t Cube_S1 = CUBE_S1;
constexpr uint32_t Tile_S1 = TILE_S1;
static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1");
// Number of Cube_S1-wide sub-tiles per Tile_S1 tile; also the trip count of the sub_tile/row_slice loops.
constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1;
constexpr uint32_t Cube_HEAD = HEAD_SIZE;
constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor;
constexpr uint32_t VecGuRows = Cube_S0 / VEC_CORES;
static_assert(Cube_S0 % (VEC_CORES * kTileFactor) == 0, "Vec rows must divide evenly across tile slices");
// Buffer counts for each tile array declared below.
constexpr uint32_t qkPreloadNum = QK_PRELOAD;
constexpr uint32_t srcVecTNBuffers = 2;
constexpr uint32_t xexpVecTNBuffers = 2;
constexpr uint32_t outOTileNBuffers = 2;
constexpr uint32_t qMatTNBuffers = 1;
constexpr uint32_t kMatTNBuffers = 2;
constexpr uint32_t pMatTNBuffers = 2;
constexpr uint32_t vMatTNBuffers = 2;
// Both FIFOs share the same depth, CV_FIFO_SIZE.
constexpr uint32_t qkp_tile_fifo_size = CV_FIFO_SIZE;
constexpr uint32_t pv_tile_fifo_size = CV_FIFO_SIZE;
static_assert(qkPreloadNum >= 1, "qkPreloadNum must be >= 1");
static_assert(CV_FIFO_CONS_SYNC_PERIOD >= 1, "CV_FIFO_CONS_SYNC_PERIOD must be >= 1");
static_assert((qkPreloadNum > 1) || (kTileFactor == 1), "qkPreloadNum must be > 1 unless kTileFactor == 1");
#if USE_UB_TO_L1_PATH
static_assert(qkPreloadNum <= pMatTNBuffers,
"USE_UB_TO_L1_PATH requires qkPreloadNum <= pMatTNBuffers (2) to avoid buffer races. "
"Use --qk-preload 2 when running with UB mode enabled.");
#endif
// Matrix-side tiles for the Q*K stage: Q and K operand tiles plus a float accumulator.
using TileMatQData =
Tile<TileType::Mat, half, HEAD_SIZE, Cube_S0, BLayout::RowMajor, HEAD_SIZE, Cube_S0, SLayout::ColMajor, 512>;
using TileMatKData =
Tile<TileType::Mat, half, Cube_S1, HEAD_SIZE, BLayout::ColMajor, Cube_S1, HEAD_SIZE, SLayout::RowMajor, 512>;
using TileQKData = TileAcc<float, Cube_S1, Cube_S0, Cube_S1, Cube_S0>;
TileMatQData qMatTile[qMatTNBuffers];
TileMatKData kMatTile[kMatTNBuffers];
TileQKData qkAccTile;
// Matrix-side tiles for the P*V stage: P and V operand tiles plus a float accumulator.
using TileMatPData =
Tile<TileType::Mat, half, Cube_S0, Cube_S1, BLayout::RowMajor, Cube_S0, Cube_S1, SLayout::ColMajor, 512>;
using TileMatVData =
Tile<TileType::Mat, half, Cube_S1, HEAD_SIZE, BLayout::ColMajor, Cube_S1, HEAD_SIZE, SLayout::RowMajor, 512>;
using TilePVData = TileAcc<float, Cube_S0, HEAD_SIZE, Cube_S0, HEAD_SIZE>;
TileMatPData pMatTile[pMatTNBuffers];
TileMatVData vMatTile[vMatTNBuffers];
TilePVData pvAccTile;
allocate_cube_tile_buffers(qMatTile, kMatTile, pMatTile, vMatTile);
// The two accumulators start in slots 0 and 1; they are re-assigned inside the loops below.
assign_running_acc_tile(qkAccTile, 0);
assign_running_acc_tile(pvAccTile, 1);
// Vector-side tiles used by compute_p and compute_gu.
using TileDataF_T = Tile<TileType::Vec, float, Tile_S1, Vec_S0, BLayout::RowMajor, Tile_S1, Vec_S0>;
using TileDataH_T = Tile<TileType::Vec, half, Tile_S1, Vec_S0, BLayout::RowMajor, Tile_S1, Vec_S0>;
constexpr uint32_t SubblockRows = Cube_S0 / VEC_CORES;
using ReduceTileF_T = Tile<TileType::Vec, float, 1, SubblockRows, BLayout::RowMajor, 1, SubblockRows>;
// One row more than Cube_S1, matching the CompactMode::RowPlusOne tile declared next.
constexpr uint32_t NzBufRows = Cube_S1 + 1;
using TileDataH_NZ_T = Tile<TileType::Vec, half, NzBufRows, Vec_S0, BLayout::ColMajor, Cube_S1, Vec_S0,
SLayout::RowMajor, 512, PadValue::Null, CompactMode::RowPlusOne>;
TileDataF_T qkVecTile[srcVecTNBuffers];
ReduceTileF_T m1_local_max;
TileDataF_T input_reduce_tmp;
ReduceTileF_T l1_local_sum;
ReduceTileF_T m2_global_max;
ReduceTileF_T l2_global_sum;
ReduceTileF_T l1_exp_max_ififo[qkp_tile_fifo_size];
TileDataH_T x_expT[xexpVecTNBuffers];
TileDataH_NZ_T nzConvBuffer;
using TileOutGuT = Tile<TileType::Vec, float, VecGuRows, HEAD_SIZE, BLayout::RowMajor, VecGuRows, HEAD_SIZE>;
TileOutGuT pvVecTile[outOTileNBuffers];
TileOutGuT runningOTile;
allocate_vec_tile_buffers<TileDataF_T, ReduceTileF_T, TileDataH_T, TileOutGuT, srcVecTNBuffers, xexpVecTNBuffers,
outOTileNBuffers>(qkVecTile, m1_local_max, input_reduce_tmp, l1_local_sum, m2_global_max,
l2_global_sum, l1_exp_max_ififo, x_expT, pvVecTile, runningOTile);
// nzConvBuffer is placed so that it ends exactly at MAX_VEC_UB_BYTES.
constexpr uint32_t nzBufSize = NzBufRows * Vec_S0 * sizeof(half);
constexpr uint32_t nzBufOffset = MAX_VEC_UB_BYTES - nzBufSize;
if constexpr (DAV_VEC) {
TASSIGN(nzConvBuffer, nzBufOffset);
}
// First query row handled by this block.
const int block_offset_rows = block_idx * static_cast<int>(Cube_S0);
// comm_slot selects this block's region in the FIFO buffers. It defaults to block_idx and is replaced by
// the TSYNC_CVID result only when intermediate checking is off and block_rows >= pto::kCvMaxCores.
constexpr bool use_cv_comm = (!INTERMEDIATE_CHECK) && (block_rows >= static_cast<uint32_t>(pto::kCvMaxCores));
int comm_slot = block_idx;
if constexpr (use_cv_comm) {
comm_slot = pto::TSYNC_CVID(block_idx, cv_comm_buf);
}
// Optional profiling: entry [0] receives the start timestamp here, entry [1] the end timestamp at exit.
__gm__ uint64_t *profile_entry = nullptr;
if (profile_buf != nullptr) {
std::size_t profile_block_base = static_cast<std::size_t>(block_idx) * kFaProfileBytesPerBlock;
std::size_t profile_offset = profile_block_base;
// In the DAV_VEC build each sub-block writes at (subblockid + 1) * 1024 bytes past the block base, so
// offset 0 is left to the other build. NOTE(review): 1024 is presumably the per-core profile slot size;
// confirm it agrees with kFaProfileBytesPerBlock.
if constexpr (DAV_VEC) {
profile_offset +=
(static_cast<std::size_t>(get_subblockid()) + 1U) * 1024U;
}
profile_entry = reinterpret_cast<__gm__ uint64_t *>(profile_buf + profile_offset);
profile_entry[0] = tStart;
}
// Per-block strides (in elements) of the FIFO buffers.
const size_t p_fifo_block_stride =
static_cast<size_t>(qkp_tile_fifo_size) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(Tile_S1);
const size_t p_max_fifo_block_stride = static_cast<size_t>(qkp_tile_fifo_size) * static_cast<size_t>(Cube_S0);
const size_t qk_fifo_block_stride = p_fifo_block_stride;
const size_t pv_fifo_block_stride =
static_cast<size_t>(pv_tile_fifo_size) * static_cast<size_t>(Cube_S0) * static_cast<size_t>(HEAD_SIZE);
// NOTE(review): this offset is multiplied in int, whereas o_out_block / o_parts_block below widen to size_t
// first; confirm S0 * HEAD_SIZE cannot exceed INT_MAX for the supported shapes.
__gm__ half *q_block = q + block_offset_rows * HEAD_SIZE;
__gm__ half *p_tile_fifo_block = p_tile_fifo + static_cast<size_t>(comm_slot) * p_fifo_block_stride;
__gm__ float *exp_max_ififo_block = exp_max_ififo + static_cast<size_t>(comm_slot) * p_max_fifo_block_stride;
__gm__ float *global_sum_block = global_sum_out + block_offset_rows;
__gm__ float *exp_max_block = exp_max_out + block_offset_rows;
__gm__ float *o_out_block = o_out + static_cast<size_t>(block_offset_rows) * static_cast<size_t>(HEAD_SIZE);
__gm__ float *o_parts_block = o_parts_out + static_cast<size_t>(block_offset_rows) * static_cast<size_t>(HEAD_SIZE);
__gm__ float *qk_tile_fifo_block = qk_tile_fifo + static_cast<size_t>(comm_slot) * qk_fifo_block_stride;
__gm__ float *pv_tile_fifo_block = pv_tile_fifo + static_cast<size_t>(comm_slot) * pv_fifo_block_stride;
// Sync objects handed to the stage functions, one per FftsBufferFlag value. The producer op type of
// sm2pvSync and pv2guSync depends on the FIFO_MODE-derived USE_* macros.
constexpr TSync_Custom<SyncOpType::TSTORE_C2GM, SyncOpType::TLOAD> qk2smSync = {BUF0_QK_READY};
#if USE_UB_TO_L1_PATH
constexpr TSync_Custom<SyncOpType::TINSERT_V2L1, SyncOpType::TLOAD> sm2pvSync = {BUF1_SM_READY};
#else
constexpr TSync_Custom<SyncOpType::TSTORE_V2GM, SyncOpType::TLOAD> sm2pvSync = {BUF1_SM_READY};
#endif
#if USE_L0C_TO_UB_PV_PATH
constexpr TSync_Custom<SyncOpType::TMOV_C2UB, SyncOpType::TLOAD> pv2guSync = {UPDATE_READY};
#else
constexpr TSync_Custom<SyncOpType::TSTORE_C2GM, SyncOpType::TLOAD> pv2guSync = {UPDATE_READY};
#endif
constexpr TSync_Custom<SyncOpType::TMOV_C2UB, SyncOpType::TLOAD> ubBufSync = {UB_BUF_READY};
constexpr TSync_Custom<SyncOpType::TMOV_C2UB, SyncOpType::TLOAD> pvUbBufSync = {PV_UB_BUF_READY};
// Number of Tile_S1-wide tiles this block processes. With CAUSAL_MASK the count depends on the block's
// first row: 1 + (block_idx * CUBE_S0) / Tile_S1.
int num_tiles_s1 = S1 / Tile_S1;
if constexpr (CAUSAL_MASK)
num_tiles_s1 = (1 + ((block_idx * CUBE_S0) / Tile_S1));
// Flags set here are each waited on once in the epilogue.
if constexpr (DAV_CUBE) {
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0);
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2);
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3);
set_flag(PIPE_FIX, PIPE_M, EVENT_ID0);
set_flag(PIPE_FIX, PIPE_M, EVENT_ID1);
}
if constexpr (DAV_VEC) {
set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1);
set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1);
}
// Running counters; each is reduced modulo a buffer count to pick the buffer for the next call.
int p_gu_src_pingpong_id = 0;
int k_src_pingpong_id = 0;
int pv_src_pingpong_id = 0;
int qkAccTileEvtID = 0;
int pvAccTileEvtID = 0;
// Preload phase: issue compute_qk (DAV_CUBE) and compute_p (DAV_VEC) for the first
// min(qkPreloadNum, num_tiles_s1) tiles before any compute_pv / compute_gu call.
for (int preload_tile = 0; preload_tile < static_cast<int>(qkPreloadNum) && preload_tile < num_tiles_s1;
++preload_tile) {
if constexpr (DAV_CUBE) {
for (int sub_tile = 0; sub_tile < static_cast<int>(kTileFactor); ++sub_tile) {
// Here the QK accumulator is re-assigned for every sub-tile.
qkAccTileEvtID = assign_running_acc_tile(qkAccTile);
#if USE_L0C_TO_DUAL_UB_PATH_QK
// The destination vec tile is chosen by tile index (preload_tile % srcVecTNBuffers).
const int tile_buf_idx = preload_tile % srcVecTNBuffers;
compute_qk<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, qkp_tile_fifo_size, CV_FIFO_CONS_SYNC_PERIOD,
INTERMEDIATE_CHECK, CAUSAL_MASK, srcVecTNBuffers>(
preload_tile, sub_tile, tile_buf_idx, q_block, k, qk_tile_fifo_block, qMatTile[0],
kMatTile[k_src_pingpong_id % kMatTNBuffers], qkAccTile, qkVecTile[tile_buf_idx],
k_src_pingpong_id % kMatTNBuffers, qkAccTileEvtID, qk2smSync, ubBufSync, block_idx);
#else
compute_qk<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, qkp_tile_fifo_size, CV_FIFO_CONS_SYNC_PERIOD,
INTERMEDIATE_CHECK, CAUSAL_MASK, srcVecTNBuffers>(
preload_tile, sub_tile, 0, q_block, k, qk_tile_fifo_block, qMatTile[0],
kMatTile[k_src_pingpong_id % kMatTNBuffers], qkAccTile, qkVecTile[0],
k_src_pingpong_id % kMatTNBuffers, qkAccTileEvtID, qk2smSync, ubBufSync, block_idx);
#endif
k_src_pingpong_id++;
}
}
if constexpr (DAV_VEC) {
for (int row_slice = 0; row_slice < static_cast<int>(kTileFactor); ++row_slice) {
#if USE_L0C_TO_DUAL_UB_PATH_QK
const int tile_buf_idx = preload_tile % srcVecTNBuffers;
compute_p<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, qkp_tile_fifo_size, CV_FIFO_CONS_SYNC_PERIOD,
INTERMEDIATE_CHECK, CAUSAL_MASK>(
preload_tile, row_slice, qk_tile_fifo_block, p_tile_fifo_block, exp_max_ififo_block,
global_sum_block, exp_max_block, qkVecTile[tile_buf_idx],
x_expT[p_gu_src_pingpong_id % xexpVecTNBuffers], input_reduce_tmp, m1_local_max, l1_local_sum,
m2_global_max, l2_global_sum, l1_exp_max_ififo[preload_tile % qkp_tile_fifo_size],
pMatTile[preload_tile % pMatTNBuffers], nzConvBuffer, p_gu_src_pingpong_id % xexpVecTNBuffers,
qk2smSync, sm2pvSync, ubBufSync, block_idx);
#else
compute_p<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, qkp_tile_fifo_size, CV_FIFO_CONS_SYNC_PERIOD,
INTERMEDIATE_CHECK, CAUSAL_MASK>(
preload_tile, row_slice, qk_tile_fifo_block, p_tile_fifo_block, exp_max_ififo_block,
global_sum_block, exp_max_block, qkVecTile[p_gu_src_pingpong_id % srcVecTNBuffers],
x_expT[p_gu_src_pingpong_id % xexpVecTNBuffers], input_reduce_tmp, m1_local_max, l1_local_sum,
m2_global_max, l2_global_sum, l1_exp_max_ififo[preload_tile % qkp_tile_fifo_size],
pMatTile[preload_tile % pMatTNBuffers], nzConvBuffer, p_gu_src_pingpong_id % xexpVecTNBuffers,
qk2smSync, sm2pvSync, ubBufSync, block_idx);
#endif
p_gu_src_pingpong_id++;
}
}
}
// Main loop: for each tile, run QK/P for the tile qkPreloadNum ahead (if any), then PV for the current
// tile, then GU for the current tile.
for (int tile_id = 0; tile_id < num_tiles_s1; ++tile_id) {
// -1 means there is no further tile to run QK/P for.
int next_qk_tile = (tile_id + static_cast<int>(qkPreloadNum) >= num_tiles_s1) ?
-1 :
(tile_id + static_cast<int>(qkPreloadNum));
// Only the QK accumulator assignment is conditional; the PV accumulator is re-assigned every tile.
if (next_qk_tile != -1)
qkAccTileEvtID = assign_running_acc_tile(qkAccTile);
pvAccTileEvtID = assign_running_acc_tile(pvAccTile);
for (int sub_tile = 0; sub_tile < static_cast<int>(kTileFactor); ++sub_tile) {
if constexpr (DAV_CUBE) {
if (next_qk_tile != -1) {
#if USE_L0C_TO_DUAL_UB_PATH_QK
const int tile_buf_idx = next_qk_tile % srcVecTNBuffers;
compute_qk<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, qkp_tile_fifo_size,
CV_FIFO_CONS_SYNC_PERIOD, INTERMEDIATE_CHECK, CAUSAL_MASK, srcVecTNBuffers>(
next_qk_tile, sub_tile, tile_buf_idx, q_block, k, qk_tile_fifo_block, qMatTile[0],
kMatTile[k_src_pingpong_id % kMatTNBuffers], qkAccTile, qkVecTile[tile_buf_idx],
k_src_pingpong_id % kMatTNBuffers, qkAccTileEvtID, qk2smSync, ubBufSync, block_idx);
#else
compute_qk<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, qkp_tile_fifo_size,
CV_FIFO_CONS_SYNC_PERIOD, INTERMEDIATE_CHECK, CAUSAL_MASK, srcVecTNBuffers>(
next_qk_tile, sub_tile, 0, q_block, k, qk_tile_fifo_block, qMatTile[0],
kMatTile[k_src_pingpong_id % kMatTNBuffers], qkAccTile, qkVecTile[0],
k_src_pingpong_id % kMatTNBuffers, qkAccTileEvtID, qk2smSync, ubBufSync, block_idx);
#endif
k_src_pingpong_id++;
}
}
if constexpr (DAV_VEC) {
if (next_qk_tile != -1) {
#if USE_L0C_TO_DUAL_UB_PATH_QK
const int tile_buf_idx = next_qk_tile % srcVecTNBuffers;
compute_p<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, qkp_tile_fifo_size,
CV_FIFO_CONS_SYNC_PERIOD, INTERMEDIATE_CHECK, CAUSAL_MASK>(
next_qk_tile, sub_tile, qk_tile_fifo_block, p_tile_fifo_block, exp_max_ififo_block,
global_sum_block, exp_max_block, qkVecTile[tile_buf_idx],
x_expT[p_gu_src_pingpong_id % xexpVecTNBuffers], input_reduce_tmp, m1_local_max, l1_local_sum,
m2_global_max, l2_global_sum, l1_exp_max_ififo[next_qk_tile % qkp_tile_fifo_size],
pMatTile[next_qk_tile % pMatTNBuffers], nzConvBuffer, p_gu_src_pingpong_id % xexpVecTNBuffers,
qk2smSync, sm2pvSync, ubBufSync, block_idx);
#else
compute_p<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, qkp_tile_fifo_size,
CV_FIFO_CONS_SYNC_PERIOD, INTERMEDIATE_CHECK, CAUSAL_MASK>(
next_qk_tile, sub_tile, qk_tile_fifo_block, p_tile_fifo_block, exp_max_ififo_block,
global_sum_block, exp_max_block, qkVecTile[p_gu_src_pingpong_id % srcVecTNBuffers],
x_expT[p_gu_src_pingpong_id % xexpVecTNBuffers], input_reduce_tmp, m1_local_max, l1_local_sum,
m2_global_max, l2_global_sum, l1_exp_max_ififo[next_qk_tile % qkp_tile_fifo_size],
pMatTile[next_qk_tile % pMatTNBuffers], nzConvBuffer, p_gu_src_pingpong_id % xexpVecTNBuffers,
qk2smSync, sm2pvSync, ubBufSync, block_idx);
#endif
p_gu_src_pingpong_id++;
}
}
// PV for the current tile runs on every sub-tile, whether or not a further QK tile was issued.
if constexpr (DAV_CUBE) {
compute_pv<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, Tile_S1, qkp_tile_fifo_size, pv_tile_fifo_size,
CV_FIFO_CONS_SYNC_PERIOD, INTERMEDIATE_CHECK, CAUSAL_MASK, outOTileNBuffers>(
tile_id, sub_tile, tile_id % outOTileNBuffers, p_tile_fifo_block, v, pv_tile_fifo_block,
pMatTile[pv_src_pingpong_id % pMatTNBuffers], vMatTile[pv_src_pingpong_id % vMatTNBuffers],
pvAccTile, runningOTile, pvVecTile, pv_src_pingpong_id % vMatTNBuffers + PV_EVENT_ID0,
pvAccTileEvtID, sm2pvSync, pv2guSync, pvUbBufSync, block_idx);
pv_src_pingpong_id++;
}
}
// GU runs once per tile, after all of the tile's sub-tiles.
if constexpr (DAV_VEC) {
compute_gu<S0, HEAD_SIZE, S1, CUBE_S0, Tile_S1, pv_tile_fifo_size, CV_FIFO_CONS_SYNC_PERIOD,
INTERMEDIATE_CHECK, CAUSAL_MASK, srcVecTNBuffers, outOTileNBuffers>(
tile_id, num_tiles_s1, pv_tile_fifo_block, o_out_block, o_parts_block, runningOTile,
pvVecTile[tile_id % outOTileNBuffers], l1_exp_max_ififo[tile_id % qkp_tile_fifo_size], l2_global_sum,
tile_id % outOTileNBuffers, pv2guSync, ubBufSync, pvUbBufSync);
// NOTE(review): the counter shared with compute_p is also advanced after GU, which changes which
// x_expT / qkVecTile buffer the next compute_p call selects; confirm this is intended.
p_gu_src_pingpong_id++;
}
}
// Epilogue: number of allocate() calls issued on each sync object before the kernel exits.
const int pending_qk_sm_consumed =
pending_consumption_events(num_tiles_s1, static_cast<int>(qkp_tile_fifo_size), CV_FIFO_CONS_SYNC_PERIOD);
const int pending_sv_consumed = pending_qk_sm_consumed;
// NOTE(review): computed from qkp_tile_fifo_size although it drains pv2guSync; the result is the same today
// because pv_tile_fifo_size is also CV_FIFO_SIZE.
const int pending_update_consumed =
pending_consumption_events(num_tiles_s1, static_cast<int>(qkp_tile_fifo_size), CV_FIFO_CONS_SYNC_PERIOD);
if constexpr (DAV_CUBE) {
// One wait per flag set before the loops.
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3);
wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0);
wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1);
for (int i = 0; i < pending_qk_sm_consumed; ++i)
qk2smSync.allocate();
for (int i = 0; i < pending_update_consumed; ++i)
pv2guSync.allocate();
#if USE_L0C_TO_DUAL_UB_PATH_QK
{
// Drain count is min(num_tiles_s1, srcVecTNBuffers).
const int ub_drain_count =
(num_tiles_s1 < static_cast<int>(srcVecTNBuffers)) ? num_tiles_s1 : static_cast<int>(srcVecTNBuffers);
for (int i = 0; i < ub_drain_count; ++i)
ubBufSync.allocate();
}
#endif
#if USE_L0C_TO_UB_PV_PATH
{
// Drain count is min(num_tiles_s1, outOTileNBuffers).
const int pv_ub_drain_count =
(num_tiles_s1 < static_cast<int>(outOTileNBuffers)) ? num_tiles_s1 : static_cast<int>(outOTileNBuffers);
for (int i = 0; i < pv_ub_drain_count; ++i)
pvUbBufSync.allocate();
}
#endif
}
if constexpr (DAV_VEC) {
// One wait per flag set before the loops.
wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1);
wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1);
for (int i = 0; i < pending_sv_consumed; ++i)
sm2pvSync.allocate();
}
// Full pipeline barrier before the end timestamp is taken.
pipe_barrier(PIPE_ALL);
uint64_t tEnd = get_sys_cnt();
if (profile_entry != nullptr) {
profile_entry[1] = tEnd;
}
#ifdef _DEBUG
// Debug-only timing print. NOTE(review): the "* 20 / 1000" conversion presumably assumes a 20 ns counter
// tick, and the int casts truncate the 64-bit counters; confirm both.
if constexpr (DAV_CUBE) {
cce::printf("Core %d Cube Block %d, Start @%d End @%d (%d us)\n", get_coreid(), block_idx, int(tStart),
int(tEnd), int(tEnd - tStart) * 20 / 1000);
} else {
cce::printf("Core %d Vec Block %d, SubBlock %d, Start @%d End @%d (%d us)\n", get_coreid(), block_idx,
int(get_subblockid()), int(tStart), int(tEnd), int(tEnd - tStart) * 20 / 1000);
}
#endif
}
// Kernel with an empty body. LaunchTFA enqueues it on the stream ahead of the prefetches and the
// runTFA launch; it reads and writes nothing.
__global__ AICORE __attribute__((aic)) void warmup_kernel() {}
/**
 * Host-side launcher for the runTFA kernel (overload with a profiling buffer).
 *
 * Enqueues, in order, on `stream`:
 *   1. warmup_kernel on 32 blocks,
 *   2. one PTO_PREFETCH each for q, k and v,
 *   3. runTFA on S0 / CUBE_S0 blocks.
 *
 * @param ffts          Passed to the kernel as its ffts_addr argument.
 * @param q             Query tensor; prefetched as S0 * HEAD_SIZE half elements.
 * @param k, v          Key / value tensors; prefetched as S1 * HEAD_SIZE half elements each.
 * @param p_tile_fifo, exp_max_ififo, qk_tile_fifo, pv_tile_fifo
 *                      FIFO buffers forwarded to the kernel unchanged.
 * @param global_sum_out, exp_max_out, o_out, o_parts_out
 *                      Output buffers forwarded to the kernel unchanged.
 * @param profile_data  Forwarded as the kernel's profile_buf; may be nullptr (the kernel then skips
 *                      its timestamp writes).
 * @param stream        Stream every launch and prefetch is issued on.
 * @param cv_comm_buf   Forwarded to the kernel unchanged.
 */
template <int S0, int HEAD_SIZE, int S1, int CUBE_S0, int CUBE_S1, int TILE_S1, int QK_PRELOAD, int CV_FIFO_SIZE,
          bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, int CV_FIFO_CONS_SYNC_PERIOD>
void LaunchTFA(uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v, aclFloat16 *p_tile_fifo,
               float *exp_max_ififo, float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out,
               float *qk_tile_fifo, float *pv_tile_fifo, uint8_t *profile_data, aclrtStream stream,
               uint8_t *cv_comm_buf)
{
    static_assert(S0 % CUBE_S0 == 0, "S0 must be divisible by CUBE_S0");
    constexpr uint32_t block_rows = S0 / CUBE_S0;

    warmup_kernel<<<32, nullptr, stream>>>();

    // Q holds S0 rows, but runTFA walks K and V in S1 / TILE_S1 tiles, so K/V are sized by S1.
    // Using the Q size for all three would over-read K/V when S1 < S0 and under-prefetch when S1 > S0.
    const uint64_t q_bytes = static_cast<uint64_t>(S0) * static_cast<uint64_t>(HEAD_SIZE) * sizeof(half);
    const uint64_t kv_bytes = static_cast<uint64_t>(S1) * static_cast<uint64_t>(HEAD_SIZE) * sizeof(half);

    // Compile-time switch between the default PTO_PREFETCH form and the <false, N> form.
    constexpr bool kPrefetchUseSdma = true;
    constexpr int kPrefetchAivCores = 64;
    if constexpr (kPrefetchUseSdma) {
        PTO_PREFETCH((__gm__ void *)q, q_bytes, stream);
        PTO_PREFETCH((__gm__ void *)k, kv_bytes, stream);
        PTO_PREFETCH((__gm__ void *)v, kv_bytes, stream);
    } else {
        PTO_PREFETCH<false, kPrefetchAivCores>((__gm__ void *)q, q_bytes, stream);
        PTO_PREFETCH<false, kPrefetchAivCores>((__gm__ void *)k, kv_bytes, stream);
        PTO_PREFETCH<false, kPrefetchAivCores>((__gm__ void *)v, kv_bytes, stream);
    }

    // One kernel block per CUBE_S0 rows of Q. Note the kernel takes cv_comm_buf before profile_data.
    runTFA<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CV_FIFO_SIZE, INTERMEDIATE_CHECK, CAUSAL_MASK,
           CV_FIFO_CONS_SYNC_PERIOD><<<block_rows, nullptr, stream>>>(
        (__gm__ uint64_t *)ffts, (half *)q, (half *)k, (half *)v, (half *)p_tile_fifo, exp_max_ififo, global_sum_out,
        exp_max_out, o_out, o_parts_out, qk_tile_fifo, pv_tile_fifo, cv_comm_buf, profile_data);
}
/**
 * Convenience overload of LaunchTFA for callers that do not collect profiling data.
 *
 * Forwards every argument to the overload that takes `profile_data`, supplying a null profiling
 * buffer; runTFA skips its timestamp writes when that pointer is null.
 */
template <int S0, int HEAD_SIZE, int S1, int CUBE_S0, int CUBE_S1, int TILE_S1, int QK_PRELOAD, int CV_FIFO_SIZE,
          bool INTERMEDIATE_CHECK, bool CAUSAL_MASK, int CV_FIFO_CONS_SYNC_PERIOD>
void LaunchTFA(uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v, aclFloat16 *p_tile_fifo,
               float *exp_max_ififo, float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out,
               float *qk_tile_fifo, float *pv_tile_fifo, aclrtStream stream, uint8_t *cv_comm_buf)
{
    // Named null pointer so the forwarded argument list reads the same as the target signature.
    uint8_t *const no_profile_data = nullptr;
    LaunchTFA<S0, HEAD_SIZE, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CV_FIFO_SIZE, INTERMEDIATE_CHECK,
              CAUSAL_MASK, CV_FIFO_CONS_SYNC_PERIOD>(ffts, q, k, v, p_tile_fifo, exp_max_ififo, global_sum_out,
                                                     exp_max_out, o_out, o_parts_out, qk_tile_fifo, pv_tile_fifo,
                                                     no_profile_data, stream, cv_comm_buf);
}
#include "generated_cases.h"
// Explicit instantiations of LaunchTFA.
//
// INSTANTIATE_TFA_FOR_CHECK emits both LaunchTFA overloads (with and without profile_data) for one
// shape and one INTERMEDIATE_CHECK value, always using kFaCvFifoSize and kFaCvFifoConsSyncPeriod.
// INSTANTIATE_TFA expands it for INTERMEDIATE_CHECK = false and then true, giving four
// instantiations per case. TFA_FOR_EACH_CASE applies INSTANTIATE_TFA to every case it lists
// (presumably the list comes from generated_cases.h - TODO confirm).
#define INSTANTIATE_TFA_FOR_CHECK(S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CAUSAL_MASK, CHECK)          \
    template void                                                                                                    \
    LaunchTFA<S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, kFaCvFifoSize, CHECK, CAUSAL_MASK,               \
              kFaCvFifoConsSyncPeriod>(uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v,                 \
                                       aclFloat16 *p_tile_fifo, float *exp_max_ififo, float *global_sum_out,        \
                                       float *exp_max_out, float *o_out, float *o_parts_out, float *qk_tile_fifo,   \
                                       float *pv_tile_fifo, uint8_t *profile_data, aclrtStream stream,              \
                                       uint8_t *cv_comm_buf);                                                       \
    template void                                                                                                    \
    LaunchTFA<S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, kFaCvFifoSize, CHECK, CAUSAL_MASK,               \
              kFaCvFifoConsSyncPeriod>(uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v,                 \
                                       aclFloat16 *p_tile_fifo, float *exp_max_ififo, float *global_sum_out,        \
                                       float *exp_max_out, float *o_out, float *o_parts_out, float *qk_tile_fifo,   \
                                       float *pv_tile_fifo, aclrtStream stream, uint8_t *cv_comm_buf);
#define INSTANTIATE_TFA(S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CAUSAL_MASK)                           \
    INSTANTIATE_TFA_FOR_CHECK(S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CAUSAL_MASK, false)              \
    INSTANTIATE_TFA_FOR_CHECK(S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CAUSAL_MASK, true)
TFA_FOR_EACH_CASE(INSTANTIATE_TFA)
#undef INSTANTIATE_TFA
#undef INSTANTIATE_TFA_FOR_CHECK