/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include <acl/acl.h>
#include <iostream>

#include "catlass/catlass.hpp"
#include "catlass/arch/arch.hpp"
#include "catlass/layout/layout.hpp"

#include "catlass/gemm/block/block_mmad.hpp"
#include "catlass/gemm/dispatch_policy.hpp"
#include "catlass/gemm/gemm_type.hpp"

#include "catlass/arch/cross_core_sync.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/epilogue/block/block_epilogue.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"

#include "tiling/kernel_common.hpp"
#include "tiling/fai_tiling.h"
#include "kernel_operator.h"

#include "catlass_kernel_prebuilt.h"
#include "../common/workspace_alloc.h"



namespace CatlassKernel {
using namespace Catlass;
// using namespace FAInferTiling;

// Report a failing ACL call: when `status` is not ACL_ERROR_NONE, print the source location and the
// aclError code to std::cerr. Execution continues afterwards -- the error is only logged, it is NOT
// returned, thrown or otherwise propagated, so callers must not rely on this macro to stop on failure.
// `status` is evaluated exactly once. It is parenthesized and stored in a macro-private local so that an
// argument expression mentioning a caller variable (e.g. ACL_CHECK(f(error))) cannot be captured by the
// macro's own local, which would otherwise be a self-initialization reading an indeterminate value.
#define ACL_CHECK(status)                                                                         \
    do {                                                                                          \
        aclError aclCheckStatus_ = (status);                                                      \
        if (aclCheckStatus_ != ACL_ERROR_NONE) {                                                  \
            std::cerr << __FILE__ << ":" << __LINE__ << " aclError:" << aclCheckStatus_ << std::endl; \
        }                                                                                         \
    } while (0)

// Report a failing runtime (rt) call: when `status` is non-zero, print the source location and the
// int32_t error code to std::cerr. Like ACL_CHECK it only logs; execution continues.
// `status` is evaluated exactly once and kept in a macro-private local (see ACL_CHECK for why).
#define RT_CHECK(status)                                                                          \
    do {                                                                                          \
        int32_t rtCheckStatus_ = (status);                                                        \
        if (rtCheckStatus_ != 0) {                                                                \
            std::cerr << __FILE__ << ":" << __LINE__ << " rtError:" << rtCheckStatus_ << std::endl; \
        }                                                                                         \
    } while (0)

 template <
    class BlockMmadQK,
    class BlockMmadPV,
    class BlockMmadQKTail,
    class BlockMmadPVTail,
    class EpilogueOnlineSoftmax,
    class EpilogueRescaleO,
    bool PAGED_CACHE_FLAG>
class FAInferKernel {
public:
    // Element / layout types are taken from the block-level components so the kernel and its
    // building blocks always agree. Naming follows the attention dataflow used by operator():
    //   S     = Q * K^T  -> BlockMmadQK  (A = Q, B = K, C = S)
    //   O_tmp = P * V    -> BlockMmadPV  (A = P, B = V)
    //   S -> P           -> EpilogueOnlineSoftmax (also owns the mask types)
    //   O_tmp -> O       -> EpilogueRescaleO
    using ArchTag = typename BlockMmadQK::ArchTag;
    using L1TileShape = typename BlockMmadQK::L1TileShape;
    using ElementQ = typename BlockMmadQK::ElementA;
    using LayoutQ = typename BlockMmadQK::LayoutA;
    using ElementK = typename BlockMmadQK::ElementB;
    using LayoutK = typename BlockMmadQK::LayoutB;
    using ElementS = typename BlockMmadQK::ElementC;
    using LayoutS = typename BlockMmadQK::LayoutC;

    // Operands of the second matmul (softmax output P times V).
    using ElementP = typename BlockMmadPV::ElementA;
    using LayoutP = typename BlockMmadPV::LayoutA;
    using ElementV = typename BlockMmadPV::ElementB;
    using LayoutV = typename BlockMmadPV::LayoutB;

    // Attention-mask tensor type, as consumed by the online-softmax epilogue.
    using ElementMask = typename EpilogueOnlineSoftmax::ElementMask;
    using LayoutMask = typename EpilogueOnlineSoftmax::LayoutMask;

    // Final output O written by the rescale epilogue.
    using ElementO = typename EpilogueRescaleO::ElementOutput;
    using LayoutO = typename EpilogueRescaleO::LayoutOutput;

    // Intermediate P*V result (O_tmp) that the rescale epilogue reads from workspace.
    using ElementOTmp = typename EpilogueRescaleO::ElementInput;
    using LayoutOTmp = typename EpilogueRescaleO::LayoutInput;

    // Methods
    // Empty constructor; per-core state is built inside operator().
    CATLASS_DEVICE
    FAInferKernel() {}

    // Kernel entry point. CORE_TYPE defaults to g_coreType (presumably the core type this
    // translation is compiled for -- confirm), which selects one of the explicit specializations:
    // the AscendC::AIC version runs the matmul stages (BlockMmadQK / BlockMmadPV) and the
    // AscendC::AIV version runs the softmax / rescale epilogues.
    template <int32_t CORE_TYPE = g_coreType>
    CATLASS_DEVICE void operator()(FAIKernelParams const &params);

    /**
     * Cube-core (AIC) side of the flash-attention inference kernel.
     *
     * A task is one query-sequence block x one query-head block of one batch. Tasks are handed out
     * round-robin: this core processes taskIdx = coreIdx, coreIdx + coreNum, ...
     * For each task the KV sequence is walked in "stacks" of blockStackNum paged blocks. For every stack
     * this core computes S = Q*K^T into a per-core workspace slot and signals qkReady. preLaunch stacks
     * later it computes O_tmp = P*V from the softmax output P of that stack and signals pvReady.
     * The vector-core side waits on qkReady, produces P and signals softmaxReady, which is handed to
     * BlockMmadPV (presumably waited on inside it before P is read -- confirm in BlockMmadPV).
     *
     * @param params kernel arguments; tiling data is read from params.tiling as FATilingData.
     */
    template <>
    CATLASS_DEVICE void operator()<AscendC::AIC>(FAIKernelParams const &params)
    {
        // Pre-set every event flag this pipeline waits on, so the first wait on each one does not block.
        // Each SetFlag here is matched by a WaitFlag on the same event at the end of this function.
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID1);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID2);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID3);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID4);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID5);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID6);
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID7);
        AscendC::SetFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::FIX_M>(EVENT_ID1);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID1);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID2);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID3);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID4);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID5);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID6);
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID7);
        // L1 bytes used by the QK stage: one Q tile (M x K) plus two K tiles (N x K each).
        // It is passed to the PV blocks as a starting offset, presumably so their L1 buffers are
        // placed after the QK buffers -- confirm in BlockMmadPV.
        static constexpr uint32_t L1_QK_SIZE = BlockMmadQK::L1TileShape::M * BlockMmadQK::L1TileShape::K * sizeof(ElementQ) + 
                                             BlockMmadQK::L1TileShape::N * BlockMmadQK::L1TileShape::K * sizeof(ElementK) * 2;
        BlockMmadQK blockMmadQK(resource);
        BlockMmadPV blockMmadPV(resource, L1_QK_SIZE);

        // Tail variants; they are only invoked from the second (masked / drain) KV loop below.
        BlockMmadQKTail blockMmadQKTail(resource);
        BlockMmadPVTail blockMmadPVTail(resource, L1_QK_SIZE);
        // Tiling parameters.
        // NOTE(review): mm1OutSize, smOnlineOutSize, batch, blockSize (same field as pagedBlockSize) and
        // scaleValue are read but never used on the AIC side. curTotalTaskNum is overwritten with 0 below
        // and recomputed from batch 0, so its initialization from firstBatchTaskNum has no effect here.
        __gm__ FATilingData *fATilingData = reinterpret_cast<__gm__ FATilingData *>(params.tiling);
        uint64_t mm1OutSize = fATilingData->mm1OutSize;
        uint64_t smOnlineOutSize = fATilingData->smOnlineOutSize;
        uint32_t batch = fATilingData->batch;
        uint32_t qHeads = fATilingData->numHeads;
        uint32_t kvHeads = fATilingData->kvHeads;
        uint32_t embed = fATilingData->embeddingSize;
        uint32_t pagedBlockSize = fATilingData->blockSize;
        uint32_t maxNumBlocksPerBatch = fATilingData->maxNumBlocksPerBatch;
        uint32_t curTotalTaskNum = fATilingData->firstBatchTaskNum;
        uint32_t totalTaskNum = fATilingData->totalTaskNum;
        uint32_t blockSize = fATilingData->blockSize;
        uint32_t maskType = fATilingData->maskType;
        float scaleValue = fATilingData->scaleValue;

        // Global-memory views of the inputs and of the per-core workspaces (S, P, O_tmp).
        AscendC::GlobalTensor<ElementQ> gQ;
        gQ.SetGlobalBuffer((__gm__ ElementQ *)params.q);
        AscendC::GlobalTensor<ElementK> gK;
        gK.SetGlobalBuffer((__gm__ ElementK *)params.k);
        AscendC::GlobalTensor<ElementK> gV;
        gV.SetGlobalBuffer((__gm__ ElementK *)params.v);
        AscendC::GlobalTensor<int32_t> gBlockTable;
        gBlockTable.SetGlobalBuffer((__gm__ int32_t *)(params.blockTables));
        AscendC::GlobalTensor<int64_t> gActualQseqlen;
        gActualQseqlen.SetGlobalBuffer((__gm__ int64_t *)params.actualQseqlen);
        AscendC::GlobalTensor<int64_t> gActualKvseqlen;
        gActualKvseqlen.SetGlobalBuffer((__gm__ int64_t *)params.actualKvseqlen);
        AscendC::GlobalTensor<ElementS> gS;
        gS.SetGlobalBuffer((__gm__ ElementS *)params.s);
        AscendC::GlobalTensor<ElementP> gP;
        gP.SetGlobalBuffer((__gm__ ElementP *)params.p);
        AscendC::GlobalTensor<ElementOTmp> gOTmp;
        gOTmp.SetGlobalBuffer((__gm__ ElementOTmp *)params.oTemp);

        // Distance between consecutive tokens in Q/O (all query heads) and in K/V (all KV heads).
        uint64_t strideQO = qHeads * embed;
        uint64_t strideKV = kvHeads * embed;
        // NOTE(review): embedRound is not used on the AIC side.
        uint32_t embedRound = RoundUp<BLOCK_SIZE>(embed);
        // Number of query heads sharing one KV head.
        uint32_t groupSize = qHeads / kvHeads;

        uint32_t coreIdx = AscendC::GetBlockIdx();
        uint32_t coreNum = AscendC::GetBlockNum();
        // Running task-range bookkeeping: tasks [preTotalTaskNum, curTotalTaskNum) belong to curBatch.
        curTotalTaskNum = 0;
        uint32_t preTotalTaskNum = 0;
        uint32_t curBatch = 0;
        uint64_t qBOffset = 0;
        uint64_t kBOffset = 0;
        uint64_t vBOffset = 0;
        uint64_t blockBOffset = 0;
        int64_t qSeqlen = 0;
        int64_t kvSeqlen = 0;
        uint32_t curQNBlockTile;
        uint32_t qNBlockNumPerGroup;
        uint32_t curQNBlockNum;
        int64_t curQSBlockTile;
        uint32_t curQSBlockNum;

        // Task tiling of batch 0: tasks = (query-head blocks over all KV heads) x (query-sequence blocks).
        // NOTE(review): reinterpret_cast<int64_t> applied to an int64_t value is an identity conversion.
        preTotalTaskNum = curTotalTaskNum;
        qSeqlen = reinterpret_cast<int64_t>(gActualQseqlen.GetValue(curBatch));
        kvSeqlen = reinterpret_cast<int64_t>(gActualKvseqlen.GetValue(curBatch));
        curQSBlockTile = GetQSBlockTile(kvSeqlen);
        curQNBlockTile = GetQNBlockTile(qSeqlen, groupSize);
        qNBlockNumPerGroup = CeilDiv(groupSize, curQNBlockTile);
        curQNBlockNum = qNBlockNumPerGroup * kvHeads;
        curQSBlockNum = CeilDiv(qSeqlen, curQSBlockTile);
        curTotalTaskNum += curQNBlockNum * curQSBlockNum;
        for (uint32_t taskIdx = coreIdx; taskIdx < totalTaskNum; taskIdx += uint32_t(coreNum)) {
            // taskIdx only grows, so batches are advanced forward until taskIdx falls inside curBatch,
            // accumulating the per-batch base offsets on the way.
            while (taskIdx >= curTotalTaskNum) {
                ++curBatch;
                preTotalTaskNum = curTotalTaskNum;
                qBOffset += qSeqlen * strideQO;
                if constexpr (!PAGED_CACHE_FLAG) {
                    // Contiguous KV cache: K/V of consecutive batches are packed back to back.
                    kBOffset += kvSeqlen * strideKV;
                    vBOffset += kvSeqlen * strideKV;
                } else {
                    // Paged KV cache: advance to this batch's row of the block table.
                    blockBOffset += maxNumBlocksPerBatch;
                }
                qSeqlen = reinterpret_cast<int64_t>(gActualQseqlen.GetValue(curBatch));
                kvSeqlen = reinterpret_cast<int64_t>(gActualKvseqlen.GetValue(curBatch));
                curQSBlockTile = GetQSBlockTile(kvSeqlen);
                curQNBlockTile = GetQNBlockTile(qSeqlen, groupSize);
                qNBlockNumPerGroup = CeilDiv(groupSize, curQNBlockTile);
                curQNBlockNum = qNBlockNumPerGroup * kvHeads;
                curQSBlockNum = CeilDiv(qSeqlen, curQSBlockTile);
                curTotalTaskNum += curQNBlockNum * curQSBlockNum;
            }
            // Decode the task: sequence-block index is the major coordinate, head-block index the minor one.
            uint32_t taskIdxCurBatch = taskIdx - preTotalTaskNum;
            uint32_t qSBlockIdx = taskIdxCurBatch / curQNBlockNum;
            uint32_t qNBlockIdx = taskIdxCurBatch - qSBlockIdx * curQNBlockNum;
            uint32_t qNBlockIdxCurGroup = qNBlockIdx % qNBlockNumPerGroup;
            uint32_t kvHeadIdx = qNBlockIdx / qNBlockNumPerGroup;
            // First query head of this head block.
            uint32_t qHeadIdx = kvHeadIdx * groupSize + qNBlockIdxCurGroup * curQNBlockTile;
            uint64_t gmQOffset = qBOffset + qSBlockIdx * curQSBlockTile * strideQO + qHeadIdx * embed;
            uint64_t gmKOffset = kBOffset + kvHeadIdx * embed;
            uint64_t gmVOffset = vBOffset + kvHeadIdx * embed;
            // The last sequence / head block may be partial.
            uint32_t qSBlockSize = (qSBlockIdx == (curQSBlockNum - 1)) ? (qSeqlen - qSBlockIdx * curQSBlockTile) : curQSBlockTile;
            uint32_t qNBlockSize = (qNBlockIdxCurGroup == (qNBlockNumPerGroup - 1)) ?
                (groupSize - qNBlockIdxCurGroup * curQNBlockTile) : curQNBlockTile;
            // Rows of the Q tile: tokens x heads of this task.
            uint32_t rowNum = qSBlockSize * qNBlockSize;
            // NOTE(review): rowNumRound is not used on the AIC side.
            uint32_t rowNumRound = AlignUp(rowNum, BLOCK_SIZE);
            // KV extents for this task:
            //   noSkipKvS   - KV tokens that must be visited at all,
            //   noMaskKvS   - leading KV tokens that need no mask for any row of the block,
            //   noMaskTailS - length of the partial paged block at the end of the unmasked range.
            uint32_t noSkipKvS = kvSeqlen;
            uint32_t noMaskKvS = kvSeqlen;
            uint32_t noMaskTailS = 0;
            if (maskType != 0) {
                // Visible KV count for the block's last row is (qSBlockIdx + 1) * tile + diffS, which is
                // consistent with a bottom-right-aligned causal mask -- confirm against the softmax epilogue.
                // NOTE(review): diffS is unsigned; assumes kvSeqlen >= qSeqlen.
                uint32_t diffS = kvSeqlen - qSeqlen;
                noSkipKvS = (qSBlockIdx + 1) * curQSBlockTile + diffS;
                noSkipKvS = Min((uint32_t)kvSeqlen, noSkipKvS);
                noMaskKvS = noSkipKvS - qSBlockSize;
                noMaskTailS = noMaskKvS % pagedBlockSize;
            }
            // Width of the single masked stack.
            uint32_t maskedKvS = qSBlockSize;
            // Loop bounds counted in paged blocks.
            uint32_t kvSLoopNumNoMask = CeilDiv(noMaskKvS, pagedBlockSize);
            uint32_t kvSLoopNumTotal = CeilDiv(noSkipKvS, pagedBlockSize);
            // One "stack" = blockStackNum paged blocks processed by one QK / PV stage.
            uint32_t blockStackNum = 4;
            uint32_t stackSeqTile;
            // Column count of the P layout used by the PV stage.
            // NOTE(review): hard-codes 128 while layoutKTemp / layoutVTemp use pagedBlockSize; the two only
            // agree when pagedBlockSize == 128 -- confirm the tiling guarantees that.
            uint32_t stackSeqTileRound = blockStackNum * 128;
            // The QK stage runs preLaunch stacks ahead of the PV stage; the workspace therefore holds
            // preLaunch + 1 slots per core, indexed by stack count modulo (preLaunch + 1).
            int32_t preLaunch = 2;
            // Stacks of this task: the unmasked ones, plus one masked stack when masking is enabled.
            int32_t totalStackSeqNum = (maskType != 0) ?
                (CeilDiv(noMaskKvS, blockStackNum * pagedBlockSize) + 1) :
                CeilDiv(noMaskKvS, blockStackNum * pagedBlockSize);
            int32_t stackSeqCount = 0;

            LayoutQ layoutQTemp(rowNum, embed);
            LayoutK layoutKTemp(strideKV, blockStackNum * pagedBlockSize);
            LayoutV layoutVTemp(blockStackNum * pagedBlockSize, strideKV);
            // Load this task's Q tile once before the KV loops (presumably reused by every QK stage --
            // confirm in BlockMmadQK).
            blockMmadQK.loadQGM(gQ[gmQOffset], layoutQTemp, rowNum, qNBlockSize, qHeads);
            // First KV loop: unmasked stacks. Each iteration issues the QK stage of the current stack and,
            // once preLaunch stacks are in flight, the PV stage of the stack preLaunch positions behind.
            for (uint32_t kvSIdx = 0; kvSIdx < kvSLoopNumNoMask; kvSIdx += blockStackNum) {
                // NOTE(review): always true given the loop condition.
                if (kvSIdx < kvSLoopNumNoMask) {
                    // The last unmasked stack may be shorter than a full stack.
                    if (kvSIdx + blockStackNum > kvSLoopNumNoMask - 1) {
                        stackSeqTile = noMaskKvS - kvSIdx * pagedBlockSize;
                    } else {
                        stackSeqTile = pagedBlockSize * blockStackNum;
                    }
                    // Per-core workspace region of (preLaunch + 1) slots; pick the slot for this stack.
                    uint32_t SWorkSpacePingPongFlag = stackSeqCount % (preLaunch + 1);
                    uint64_t gmSOffset = coreIdx * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1) + SWorkSpacePingPongFlag * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB;
                    GemmCoord actualBlockShapeQK{rowNum, stackSeqTile, embed};
                    if constexpr (!PAGED_CACHE_FLAG) {
                        blockMmadQK(gQ[gmQOffset], gK[gmKOffset], gS[gmSOffset],
                            gBlockTable, layoutQTemp, layoutKTemp, actualBlockShapeQK,
                            kvSIdx, kvSLoopNumNoMask, pagedBlockSize, noMaskKvS, strideKV);
                    } else {
                        blockMmadQK(gQ[gmQOffset], gK[gmKOffset], gS[gmSOffset],
                            gBlockTable[blockBOffset], layoutQTemp, layoutKTemp, actualBlockShapeQK,
                        kvSIdx, kvSLoopNumNoMask, pagedBlockSize, noMaskKvS, strideKV);
                    }
                    // Tell the vector side that S for this stack is available.
                    Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(qkReady);
                }
                // PV stage of the stack issued preLaunch iterations ago.
                if (kvSIdx >= preLaunch * blockStackNum) {
                    uint32_t nowkvSIdx = kvSIdx - preLaunch * blockStackNum;
                    if (nowkvSIdx + blockStackNum > kvSLoopNumNoMask - 1) {
                        stackSeqTile = noMaskKvS - nowkvSIdx * pagedBlockSize;
                    } else {
                        stackSeqTile = pagedBlockSize * blockStackNum;
                    }
                    // Same slot that held S for that stack; P and O_tmp use the same offset in their own buffers.
                    uint32_t PVWorkSpacePingPongFlag = (stackSeqCount - preLaunch) % (preLaunch + 1);
                    uint64_t gmPOffset = coreIdx * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1) + PVWorkSpacePingPongFlag * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB;
                    uint64_t gmOTmpOffset = coreIdx * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1) + PVWorkSpacePingPongFlag * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB;
                    LayoutP layoutPTemp(rowNum, stackSeqTileRound);
                    GemmCoord actualBlockShapePV{rowNum, embed, stackSeqTile};
                    if constexpr (!PAGED_CACHE_FLAG) {
                        blockMmadPV(gP[gmPOffset], gV[gmVOffset], gOTmp[gmOTmpOffset],
                            gBlockTable, layoutPTemp, layoutVTemp, actualBlockShapePV,
                            nowkvSIdx, kvSLoopNumNoMask, pagedBlockSize, noMaskKvS, strideKV, softmaxReady);
                    } else {
                        blockMmadPV(gP[gmPOffset], gV[gmVOffset], gOTmp[gmOTmpOffset],
                            gBlockTable[blockBOffset], layoutPTemp, layoutVTemp, actualBlockShapePV,
                            nowkvSIdx, kvSLoopNumNoMask, pagedBlockSize, noMaskKvS, strideKV, softmaxReady);
                    }
                    // Tell the vector side that O_tmp for that stack is available.
                    Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(pvReady);
                }
                stackSeqCount++;
            }

            /*
             * Second KV loop.
             * With masking, it handles CV stage 1 (QK^T / online softmax) of the final masked stack
             * (typical shape [128, 512]) and CV stage 2 (PV / rescale O) of the last (preLaunch + 1) stacks.
             * Without masking, it only handles CV stage 2 (PV / rescale O) of the last preLaunch stacks
             * that the first loop left in flight.
             */

            // Second-loop bounds:
            //   maskedStartIdx         - first paged-block index of the loop. With masking this is the block
            //                            where masked positions start (the last unmasked block when it is
            //                            partial); without masking it is where the first loop stopped.
            //   noMaskTailInteStackNum - whole unmasked blocks in the last unmasked stack (blockStackNum if
            //                            that stack is full, 0 if it contains only a partial block).
            //   preLaunchStackNum      - lag, in paged blocks, between the QK index and the delayed PV index.
            uint32_t maskedStartIdx = (maskType != 0) ?
                ((noMaskTailS != 0) ? (kvSLoopNumNoMask - 1) : kvSLoopNumNoMask) :
                AlignUp(kvSLoopNumNoMask, blockStackNum);
            uint32_t noMaskTailInteStackNum = (noMaskKvS / pagedBlockSize) % blockStackNum;
            noMaskTailInteStackNum = (noMaskTailInteStackNum != 0) ?
                noMaskTailInteStackNum : ((noMaskTailS != 0) ? 0 : blockStackNum);
            uint32_t preLaunchStackNum = (maskType != 0) ?
                ((preLaunch - 1) * blockStackNum + noMaskTailInteStackNum) : (preLaunch * blockStackNum);
            
            // masked kvSeqlen loop (also drains the pending PV stages)
            for (uint32_t kvSIdx = maskedStartIdx;
                kvSIdx < kvSLoopNumTotal + preLaunchStackNum; ) {
                // QK stage, only while KV blocks and stacks remain (with masking: the masked stack).
                if ((kvSIdx < kvSLoopNumTotal) && (stackSeqCount <= totalStackSeqNum - 1)) {
                    stackSeqTile = maskedKvS;
                    uint32_t SWorkSpacePingPongFlag = stackSeqCount % (preLaunch + 1);
                    uint64_t gmSOffset = coreIdx * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1) + SWorkSpacePingPongFlag * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB;
                    GemmCoord actualBlockShapeQK{rowNum, stackSeqTile, embed};
                    if constexpr (!PAGED_CACHE_FLAG) {
                        blockMmadQKTail(gQ[gmQOffset], gK[gmKOffset], gS[gmSOffset],
                            gBlockTable, layoutQTemp, layoutKTemp, actualBlockShapeQK,
                            kvSIdx, kvSLoopNumTotal, pagedBlockSize, noSkipKvS, strideKV, noMaskTailS, 1);
                    } else {
                        blockMmadQKTail(gQ[gmQOffset], gK[gmKOffset], gS[gmSOffset],
                            gBlockTable[blockBOffset], layoutQTemp, layoutKTemp, actualBlockShapeQK,
                            kvSIdx, kvSLoopNumTotal, pagedBlockSize, noSkipKvS, strideKV, noMaskTailS, 1);
                    }
                    Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(qkReady);
                }
                
                // Delayed PV stage.
                if (kvSIdx >= preLaunchStackNum) {
                    uint32_t delayedKvSIdx = kvSIdx - preLaunchStackNum;

                    // NOTE(review): kvSLoopNumNoMask - 1 wraps to UINT32_MAX when kvSLoopNumNoMask == 0, in
                    // which case the else-if branch below can never be taken.
                    if (delayedKvSIdx + blockStackNum > kvSLoopNumTotal - 1 && (maskType != 0)) {
                        stackSeqTile = maskedKvS;
                    } else if (delayedKvSIdx + blockStackNum > kvSLoopNumNoMask - 1) {
                        stackSeqTile = noMaskKvS - delayedKvSIdx * pagedBlockSize;
                    } else {
                        stackSeqTile = pagedBlockSize * blockStackNum;
                    }
                    uint32_t PVWorkSpacePingPongFlag = (stackSeqCount - preLaunch) % (preLaunch + 1);
                    uint64_t gmPOffset = coreIdx * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1) + PVWorkSpacePingPongFlag * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB;
                    uint64_t gmOTmpOffset = coreIdx * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1) + PVWorkSpacePingPongFlag * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB;
                    LayoutP layoutPTemp(rowNum, stackSeqTileRound);
                    GemmCoord actualBlockShapePV{rowNum, embed, stackSeqTile};

                    if ((stackSeqCount - preLaunch == totalStackSeqNum - 1) && (maskType != 0)) { // final (masked) stack
                        if constexpr (!PAGED_CACHE_FLAG) {
                            blockMmadPVTail(gP[gmPOffset], gV[gmVOffset], gOTmp[gmOTmpOffset],
                                gBlockTable, layoutPTemp, layoutVTemp, actualBlockShapePV,
                                delayedKvSIdx, kvSLoopNumTotal, pagedBlockSize, noSkipKvS, strideKV, softmaxReady, noMaskTailS, 1);
                        } else {
                            blockMmadPVTail(gP[gmPOffset], gV[gmVOffset], gOTmp[gmOTmpOffset],
                                gBlockTable[blockBOffset], layoutPTemp, layoutVTemp, actualBlockShapePV,
                                delayedKvSIdx, kvSLoopNumTotal, pagedBlockSize, noSkipKvS, strideKV, softmaxReady, noMaskTailS, 1);
                        }
                    } else { // unmasked stack
                        if constexpr (!PAGED_CACHE_FLAG) {
                        blockMmadPV(gP[gmPOffset], gV[gmVOffset], gOTmp[gmOTmpOffset],
                                gBlockTable, layoutPTemp, layoutVTemp, actualBlockShapePV,
                                delayedKvSIdx, kvSLoopNumNoMask, pagedBlockSize, noMaskKvS, strideKV, softmaxReady);
                        } else {
                            blockMmadPV(gP[gmPOffset], gV[gmVOffset], gOTmp[gmOTmpOffset],
                                gBlockTable[blockBOffset], layoutPTemp, layoutVTemp, actualBlockShapePV,
                                delayedKvSIdx, kvSLoopNumNoMask, pagedBlockSize, noMaskKvS, strideKV, softmaxReady);
                        }
                    }
                    Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(pvReady);
                }
                // With masking, after the PV of the last unmasked stack (index totalStackSeqNum - 2) the index
                // advances by noMaskTailInteStackNum instead of a full stack, presumably so the next delayed
                // index lines up with the masked stack -- confirm.
                if ((maskType != 0) && (stackSeqCount - preLaunch == totalStackSeqNum - 2)) {
                    kvSIdx += noMaskTailInteStackNum;
                } else {
                    kvSIdx += blockStackNum;
                }
                stackSeqCount++;
            }
        }

        // Consume the flags pre-set at the top of the function so every event is left balanced.
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID1);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID2);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID3);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID4);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID5);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID6);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(EVENT_ID7);

        AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(EVENT_ID1);

        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID1);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID2);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID3);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID4);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID5);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID6);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(EVENT_ID7);
    }
 
    template <>
    CATLASS_DEVICE void operator()<AscendC::AIV>(FAIKernelParams const &params)
    {
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID2);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID3);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID4);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID5);

        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID2);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);

        // Get tiling parameters
        __gm__ FATilingData *fATilingData = reinterpret_cast<__gm__ FATilingData *>(params.tiling);
        uint64_t mm1OutSize = fATilingData->mm1OutSize;
        uint64_t smOnlineOutSize = fATilingData->smOnlineOutSize;
        uint64_t mm2OutSize = fATilingData->mm2OutSize;
        uint32_t batch = fATilingData->batch;
        uint32_t qHeads = fATilingData->numHeads;
        uint32_t kvHeads = fATilingData->kvHeads;
        uint32_t embed = fATilingData->embeddingSize;
        uint32_t pagedBlockSize = fATilingData->blockSize;
        uint32_t maxNumBlocksPerBatch = fATilingData->maxNumBlocksPerBatch;
        uint32_t firstBatchTaskNum = fATilingData->firstBatchTaskNum;
        uint32_t totalTaskNum = fATilingData->totalTaskNum;
        uint32_t maskType = fATilingData->maskType;
        float scaleValue = fATilingData->scaleValue;
        // Get the memory offset address of the input on Global Memory
        AscendC::GlobalTensor<ElementMask> gMask;
        gMask.SetGlobalBuffer((__gm__ ElementMask *)params.mask);
        AscendC::GlobalTensor<int64_t> gActualQseqlen;
        gActualQseqlen.SetGlobalBuffer((__gm__ int64_t *)params.actualQseqlen);
        AscendC::GlobalTensor<int64_t> gActualKvseqlen;
        gActualKvseqlen.SetGlobalBuffer((__gm__ int64_t *)params.actualKvseqlen);
        AscendC::GlobalTensor<ElementO> gO;
        gO.SetGlobalBuffer((__gm__ ElementO *)params.o);
        AscendC::GlobalTensor<ElementS> gS;
        gS.SetGlobalBuffer((__gm__ ElementS *)params.s);
        AscendC::GlobalTensor<ElementP> gP;
        gP.SetGlobalBuffer((__gm__ ElementP *)params.p);
        AscendC::GlobalTensor<ElementOTmp> gOTmp;
        gOTmp.SetGlobalBuffer((__gm__ ElementOTmp *)params.oTemp);
        AscendC::GlobalTensor<ElementOTmp> gOUpdate;
        gOUpdate.SetGlobalBuffer((__gm__ ElementOTmp *)params.oUpdate);

        uint32_t groupSize = qHeads / kvHeads;
        uint32_t embedRound = RoundUp(embed, BLOCK_SIZE);

        EpilogueOnlineSoftmax epilogueOnlineSoftmax(resource, scaleValue);
        EpilogueRescaleO epilogueRescaleO(resource);

        
        // uint32_t curTotalTaskNum = 0;
        uint32_t preTotalTaskNum = 0;
        uint32_t curBatch = 0;
        int64_t oBatchOffset = 0;
        uint32_t qSeqlen = static_cast<uint32_t>(gActualQseqlen.GetValue(curBatch));
        uint32_t kvSeqlen = static_cast<uint32_t>(gActualKvseqlen.GetValue(curBatch));
        uint32_t curQNBlockTile = GetQNBlockTile(qSeqlen, groupSize);
        uint32_t qNBlockNumPerGroup = CeilDiv(groupSize, curQNBlockTile);
        uint32_t curQNBlockNum = qNBlockNumPerGroup * kvHeads;
        uint32_t curQSBlockTile = GetQSBlockTile(kvSeqlen);
        uint32_t curQSBlockNum = CeilDiv(qSeqlen, curQSBlockTile);
        uint32_t curTotalTaskNum = firstBatchTaskNum;

        uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum();
        uint32_t coreNum = AscendC::GetBlockNum();
        // Go through each task.
        for (uint32_t taskIdx = coreIdx; taskIdx < totalTaskNum; taskIdx += uint32_t(coreNum)) {
            // Get the offset of each core on the GM.
            while (taskIdx >= curTotalTaskNum) {
                curBatch++;
                oBatchOffset += static_cast<int64_t>(qSeqlen) * qHeads * embed;
                preTotalTaskNum = curTotalTaskNum;
                qSeqlen = static_cast<uint32_t>(gActualQseqlen.GetValue(curBatch));
                kvSeqlen = static_cast<uint32_t>(gActualKvseqlen.GetValue(curBatch));
                curQNBlockTile = GetQNBlockTile(qSeqlen, groupSize);
                qNBlockNumPerGroup = CeilDiv(groupSize, curQNBlockTile);
                curQNBlockNum = qNBlockNumPerGroup * kvHeads;
                curQSBlockTile = GetQSBlockTile(kvSeqlen);
                curQSBlockNum = CeilDiv(qSeqlen, curQSBlockTile);
                curTotalTaskNum += curQNBlockNum * curQSBlockNum;
            }
            uint32_t taskIdxCurBatch = taskIdx - preTotalTaskNum;
            uint32_t qSBlockIdx = taskIdxCurBatch / curQNBlockNum;
            uint32_t qNBlockIdx = taskIdxCurBatch % curQNBlockNum;
            uint32_t qNBlockIdxCurGroup = qNBlockIdx % qNBlockNumPerGroup;
            
            uint32_t oSOffset = qSBlockIdx * curQSBlockTile * qHeads * embed;
            uint32_t kvNIdx = qNBlockIdx / qNBlockNumPerGroup;
            uint32_t qStartNIdx = kvNIdx * groupSize + qNBlockIdxCurGroup * curQNBlockTile;
            uint32_t oNOffset = qStartNIdx * embed;
            int64_t gmOffsetO = oBatchOffset + oSOffset + oNOffset;

            uint32_t qSBlockSize = (qSBlockIdx == (curQSBlockNum - 1)) ?
                (qSeqlen - qSBlockIdx * curQSBlockTile) : curQSBlockTile;
            uint32_t qNBlockSize = (qNBlockIdxCurGroup == (qNBlockNumPerGroup - 1)) ?
                (groupSize - qNBlockIdxCurGroup * curQNBlockTile) : curQNBlockTile;
            uint32_t rowNum = qSBlockSize * qNBlockSize;
            uint32_t rowNumRound = RoundUp(rowNum, BLOCK_SIZE);

            uint32_t noSkipKvS = kvSeqlen;
            uint32_t noMaskKvS = kvSeqlen;
            uint32_t noMaskTailS = 0;
            if (maskType != 0) {
                uint32_t diffS = kvSeqlen - qSeqlen;
                noSkipKvS = (qSBlockIdx + 1) * curQSBlockTile + diffS;
                noSkipKvS = Min(kvSeqlen, noSkipKvS);
                noMaskKvS = noSkipKvS - qSBlockSize;
                noMaskTailS = noMaskKvS % pagedBlockSize;
            }
            uint32_t maskedKvS = qSBlockSize;
            uint32_t kvSLoopNumTotal = CeilDiv(noSkipKvS, pagedBlockSize);
            uint32_t kvSLoopNumNoMask = CeilDiv(noMaskKvS, pagedBlockSize);
            uint32_t blockStackNum = 4;
            uint32_t stackSeqTilePad = blockStackNum * pagedBlockSize;
            uint32_t stackSeqTile;
            int32_t preLaunch = 2;
            int32_t totalStackSeqNum = (maskType != 0) ?
                (CeilDiv(noMaskKvS, blockStackNum * pagedBlockSize) + 1) :
                CeilDiv(noMaskKvS, blockStackNum * pagedBlockSize);
            int32_t stackSeqCount = 0;

            // no mask kvSeqlen loop
            for (uint32_t kvSIdx = 0;
                kvSIdx < kvSLoopNumNoMask;
                kvSIdx += blockStackNum) {

                if (kvSIdx + blockStackNum > kvSLoopNumNoMask - 1) {
                    stackSeqTile = noMaskKvS - kvSIdx * pagedBlockSize;
                } else {
                    stackSeqTile = pagedBlockSize * blockStackNum;
                }
                uint32_t stackSeqTileRound = RoundUp(stackSeqTile, BLOCK_SIZE);
                LayoutS layOutS(rowNum, stackSeqTile, stackSeqTilePad);
                LayoutP layOutP(rowNum, stackSeqTile, stackSeqTilePad);
                GemmCoord actualBlockShapeQK{rowNum, stackSeqTile, embed};
                uint32_t curStackTileMod = stackSeqCount % (preLaunch + 1);
                uint32_t gmOffsetS = coreIdx * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1) + // cube core offset
                    curStackTileMod * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB; // single cube core db offset
                // vec core offset will be processed within epilogue block
                uint32_t gmOffsetP = gmOffsetS;
                // AscendC::printf("stackSeqCount:%d\n", stackSeqCount);
                Arch::CrossCoreWaitFlag(qkReady);
                // online softmax
                epilogueOnlineSoftmax(
                    gP[gmOffsetP],
                    gS[gmOffsetS],
                    layOutP,
                    layOutS,
                    actualBlockShapeQK,
                    (stackSeqCount == 0),
                    qSBlockSize,
                    qNBlockSize,
                    curStackTileMod);
                Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(softmaxReady);

                if (kvSIdx >= preLaunch * blockStackNum) {
                    uint32_t delayedKvSIdx = kvSIdx - preLaunch * blockStackNum;
                    if (delayedKvSIdx + blockStackNum > kvSLoopNumNoMask - 1) {
                        stackSeqTile = noMaskKvS - kvSIdx * pagedBlockSize;
                    } else {
                        stackSeqTile = pagedBlockSize * blockStackNum;
                    }
                    LayoutO layoutO(qSeqlen, embed * qHeads);
                    LayoutOTmp layoutOTmp(rowNum, embed, embedRound);
                    GemmCoord actualBlockShapePV{rowNum, embed, stackSeqTile};
                    uint32_t curStackTileMod = (stackSeqCount - preLaunch) % (preLaunch + 1);
                    uint32_t gmOffsetOTmp = coreIdx * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1) +
                        curStackTileMod * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB;
                    Arch::CrossCoreWaitFlag(pvReady);
                    // rescale O
                    epilogueRescaleO(
                        gO[gmOffsetO],
                        gOTmp[gmOffsetOTmp],
                        layoutO,
                        layoutOTmp,
                        actualBlockShapePV,
                        qSBlockSize,
                        qNBlockSize,
                        (stackSeqCount - preLaunch == 0),
                        0,
                        curStackTileMod);
                }
                stackSeqCount++;
            }
            /*
            * for the secondary loop
            * while masked, it deals the CV stage1(Qk^t/SMOnline) of the final base block(typical shape [128, 512]), and the CV stage2(PV/rescaleO) of the last (prelaunch+1) base blocks
            * while unmasked, it deals the CV stage1(Qk^t/SMOnline) of the last (prelaunch+1) base blocks
            */
        // deal secondary loop conditions
        uint32_t maskedStartIdx = (maskType != 0) ?
            ((noMaskTailS != 0) ? (kvSLoopNumNoMask - 1) : kvSLoopNumNoMask) :
            AlignUp(kvSLoopNumNoMask, blockStackNum);
            uint32_t noMaskTailInteStackNum = (noMaskKvS / pagedBlockSize) % blockStackNum;
            noMaskTailInteStackNum = (noMaskTailInteStackNum != 0) ?
            noMaskTailInteStackNum : ((noMaskTailS != 0) ? 0 : blockStackNum);
            uint32_t preLaunchStackNum = (maskType != 0) ?
                ((preLaunch - 1) * blockStackNum + noMaskTailInteStackNum) : (preLaunch * blockStackNum);
            // masked kvSeqlen loop
            for (uint32_t kvSIdx = maskedStartIdx;
                kvSIdx < kvSLoopNumTotal + preLaunchStackNum; ){
                if ((kvSIdx < kvSLoopNumTotal) && (stackSeqCount <= totalStackSeqNum - 1)){
                    stackSeqTile = maskedKvS;
                    uint32_t stackSeqTileRound = RoundUp(stackSeqTile, BLOCK_SIZE);
                    LayoutS layOutS(rowNum, stackSeqTile, stackSeqTilePad);
                    LayoutP layOutP(rowNum, stackSeqTile, stackSeqTilePad);
                    LayoutMask layOutMask(1024, 1024, 1024);
                    GemmCoord actualBlockShapeQK{rowNum, stackSeqTile, embed};
                    uint32_t curStackTileMod = stackSeqCount % (preLaunch + 1);
                    uint32_t gmOffsetS = coreIdx * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1) + // cube core offset
                        curStackTileMod * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB; // single cube core db offset
                    // vec core offset will be processed within epilogue block
                    uint32_t gmOffsetP = gmOffsetS;
                    // online softmax
                    epilogueOnlineSoftmax(
                        gP[gmOffsetP],
                        gS[gmOffsetS],
                        gMask,
                        layOutP,
                        layOutS,
                        layOutMask,
                        actualBlockShapeQK,
                        (stackSeqCount == 0),
                        qSBlockSize,
                        qNBlockSize,
                        curStackTileMod,
                        qkReady);
                    Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(softmaxReady);
                }
                if (kvSIdx >= preLaunchStackNum){
                    uint32_t delayedKvSIdx = kvSIdx - preLaunchStackNum;
                    if (delayedKvSIdx + blockStackNum > kvSLoopNumTotal - 1 && (maskType != 0)){
                        stackSeqTile = maskedKvS;
                    }
                    else if(delayedKvSIdx + blockStackNum > kvSLoopNumNoMask - 1){
                        stackSeqTile = noMaskKvS - delayedKvSIdx * pagedBlockSize;
                    }
                    else{
                        stackSeqTile = pagedBlockSize * blockStackNum;
                    }
                    LayoutO layoutO(qSBlockSize, embed * qHeads);
                    LayoutOTmp layoutOTmp(rowNum, embed, embedRound);
                    GemmCoord actualBlockShapePV{rowNum, embed, stackSeqTile};
                    uint32_t curStackTileMod = (stackSeqCount - preLaunch) % (preLaunch + 1);
                    uint32_t gmOffsetOTmp = coreIdx * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * (preLaunch + 1) +
                        curStackTileMod * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB;
                    Arch::CrossCoreWaitFlag(pvReady);
                    // rescale O
                    epilogueRescaleO(
                        gO[gmOffsetO],
                        gOTmp[gmOffsetOTmp],
                        layoutO,
                        layoutOTmp,
                        actualBlockShapePV,
                        qSBlockSize,
                        qNBlockSize,
                        (stackSeqCount - preLaunch == 0),
                        (stackSeqCount - preLaunch == totalStackSeqNum - 1),
                        curStackTileMod);
                }
                if((maskType != 0) && (stackSeqCount - preLaunch == totalStackSeqNum - 2)){
                    kvSIdx += noMaskTailInteStackNum;
                }
                else{
                    kvSIdx += blockStackNum;
                }
                stackSeqCount++;
            }
        }

        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID2);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID3);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID4);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID5);
        
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID2);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
    }

private:
    Arch::Resource<ArchTag> resource;
    Arch::CrossCoreFlag qkReady{QK_READY_ID};
    Arch::CrossCoreFlag softmaxReady{SOFTMAX_READY_ID};
    Arch::CrossCoreFlag pvReady{PV_READY_ID};
};

template <class DType>
/**
 * Device entry point for Flash Attention inference over a paged KV cache.
 *
 * Composes the block-level components into an FAInferKernel and runs it. The
 * components are the Q*K^T and P*V matmuls, their tail variants, the
 * online-softmax epilogue and the O-rescale epilogue.
 *
 * @tparam DType  element type of Q/K/V/P/O/mask. The host dispatcher in this
 *                file instantiates half and bfloat16_t. The S, O-tmp and
 *                O-update intermediates are always float.
 * @param hardwareSyncAddr  base address handed to AscendC::SetSyncBaseAddr.
 * @param q, k, v           query / key / value tensors (global memory).
 * @param mask              attention mask; the host launcher in this file
 *                          passes nullptr unless maskType == 1.
 * @param blockTables       paged-KV block table.
 * @param o                 attention output.
 * @param actualQseqlen     per-batch query sequence lengths.
 * @param actualKvseqlen    per-batch KV sequence lengths.
 * @param s, p, oTemp, oUpdate  workspaces for the S, P, O-tmp and O-update
 *                              intermediates.
 * @param tiling            device copy of FATilingData.
 */
CATLASS_GLOBAL void __mix__(1,2) FAInfer(uint64_t hardwareSyncAddr,
                                GM_ADDR q,
                                GM_ADDR k,
                                GM_ADDR v,
                                GM_ADDR mask,
                                GM_ADDR blockTables,
                                GM_ADDR o,
                                GM_ADDR actualQseqlen,
                                GM_ADDR actualKvseqlen,
                                GM_ADDR s,
                                GM_ADDR p,
                                GM_ADDR oTemp,
                                GM_ADDR oUpdate,
                                GM_ADDR tiling)
{
    AscendC::SetSyncBaseAddr(hardwareSyncAddr);

    using ArchTag = Arch::AtlasA2;

    // Element types and layouts of every tensor the kernel touches.
    using ElementQ = DType;
    using LayoutQ = layout::RowMajor;
    using ElementK = DType;
    using LayoutK = layout::ColumnMajor;
    using ElementV = DType;
    using LayoutV = layout::RowMajor;
    using ElementS = float;
    using LayoutS = layout::RowMajor;
    using ElementP = DType;
    using LayoutP = layout::RowMajor;
    using ElementO = DType;
    using LayoutO = layout::RowMajor;
    using ElementMask = DType;
    using LayoutMask = layout::RowMajor;
    using ElementOTmp = float;
    using LayoutOTmp = layout::RowMajor;
    using ElementUpdate = float;
    using LayoutUpdate = layout::RowMajor;

    // L1TileShape::K must match the embedding size.
    using L1TileShape = GemmShape<128, 128, 128>;
    using L0TileShape = L1TileShape;

    // GemmType descriptors, declared once and shared by all block components below.
    using QType = Gemm::GemmType<ElementQ, LayoutQ>;
    using KType = Gemm::GemmType<ElementK, LayoutK>;
    using SType = Gemm::GemmType<ElementS, LayoutS>;
    using PType = Gemm::GemmType<ElementP, LayoutP>;
    using VType = Gemm::GemmType<ElementV, LayoutV>;
    using OTmpType = Gemm::GemmType<ElementOTmp, LayoutOTmp>;
    using OType = Gemm::GemmType<ElementO, LayoutO>;
    using OUpdateType = Gemm::GemmType<ElementUpdate, LayoutUpdate>;
    using MaskGemmType = Gemm::GemmType<ElementMask, LayoutMask>;

    // GEMM block: Q * K^T of Flash Attention inference (regular and tail variants).
    using DispatchPolicyQK = Gemm::MmadAtlasA2FAIQK<true, false>;
    using BlockMmadQK = Gemm::Block::BlockMmad<DispatchPolicyQK, L1TileShape, L0TileShape, QType, KType, SType>;
    using DispatchPolicyQKTail = Gemm::MmadAtlasA2FAITailQK<true, false>;
    using BlockMmadQKTail = Gemm::Block::BlockMmad<DispatchPolicyQKTail, L1TileShape, L0TileShape, QType, KType, SType>;

    // Epilogue block: softmax of the current S base block.
    using DispatchPolicyOnlineSoftmax = Epilogue::EpilogueAtlasA2OnlineSoftmax;
    using EpilogueOnlineSoftmax =
        Epilogue::Block::BlockEpilogue<DispatchPolicyOnlineSoftmax, PType, SType, MaskGemmType>;

    // GEMM block: P * V of Flash Attention inference (regular and tail variants).
    using DispatchPolicyPV = Gemm::MmadAtlasA2FAIPV<true, false>;
    using BlockMmadPV = Gemm::Block::BlockMmad<DispatchPolicyPV, L1TileShape, L0TileShape, PType, VType, OTmpType>;
    using DispatchPolicyPVTail = Gemm::MmadAtlasA2FAITailPV<true, false>;
    using BlockMmadPVTail = Gemm::Block::BlockMmad<DispatchPolicyPVTail, L1TileShape, L0TileShape, PType, VType, OTmpType>;

    // Epilogue block: update of the current O base block.
    using DispatchPolicyRescaleO = Epilogue::EpilogueAtlasA2RescaleO;
    using EpilogueRescaleO =
        Epilogue::Block::BlockEpilogue<DispatchPolicyRescaleO, OType, OTmpType, OUpdateType>;

    // Kernel level: the alias intentionally shadows the class template name in this scope.
    using FAInferKernel = FAInferKernel<BlockMmadQK, BlockMmadPV, BlockMmadQKTail, BlockMmadPVTail,
        EpilogueOnlineSoftmax, EpilogueRescaleO, true>;
    FAIKernelParams params{q, k, v, mask, blockTables,
                           actualQseqlen, actualKvseqlen, o, s, p, oTemp, oUpdate, tiling};

    // call kernel
    FAInferKernel flashAttnInfer;
    flashAttnInfer(params);
}

template<class DType>
void FAImpl(const uint32_t blockNum, aclrtStream stream, const FlashAttentionParams &params)
{
    // Parameters initialization.
    int32_t batch = params.batch;
    int32_t qSeqlen = params.qSeqlen;
    int32_t kvSeqlen = params.kvSeqlen;
    int32_t numHeads = params.numHeads;
    int32_t kvHeads = params.kvHeads;
    int32_t embeddingSize = params.embeddingSize;
    int32_t blockSize = params.blockSize;
    int32_t maskType = params.maskType;
    int32_t maxKvSeqlen = kvSeqlen;
    int32_t numBlocks = batch * ((maxKvSeqlen + blockSize - 1) / blockSize);

    int32_t numTokens = params.qNtokens;

    uint64_t seqArraySize = batch * sizeof(int64_t);
    uint32_t tilingSize = sizeof(FATilingData);

    uint8_t *qSeqDevice = params.inputAddr.at(0);
    uint8_t *kvSeqDevice = params.inputAddr.at(1);
    uint8_t *qDevice = params.inputAddr.at(2);
    uint8_t *kDevice = params.inputAddr.at(3);
    uint8_t *vDevice = params.inputAddr.at(4);

    uint8_t *maskDevice = nullptr;
    if (maskType == 1) {
        maskDevice = params.inputAddr.at(5);
    }

    uint8_t *blockTableDevice = params.inputAddr.at(6);
    uint8_t *oDevice = params.outputAddr.at(0);

    // Allocate matrices in device memory for workspace.
    // One base workspace block contains 65536 elements.
    uint64_t mm1OutSize = blockNum * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * sizeof(float) * FAInferTiling::NUM3;
    uint64_t smOnlineOutSize = blockNum * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * sizeof(DType) * FAInferTiling::NUM3;
    uint64_t mm2OutSize = blockNum * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * sizeof(float) * FAInferTiling::NUM3;
    uint64_t UpdateSize = blockNum * FAInferTiling::WORKSPACE_BLOCK_SIZE_DB * sizeof(float) * FAInferTiling::NUM3;
    uint64_t workSpaceSize = mm1OutSize + smOnlineOutSize + mm2OutSize + UpdateSize;
    
    uint8_t *sDevice = g_catlassWorkspaceAlloc(mm1OutSize);
    uint8_t *pDevice = g_catlassWorkspaceAlloc(smOnlineOutSize);
    uint8_t *oTempDevice = g_catlassWorkspaceAlloc(mm2OutSize);
    uint8_t *oUpdateDevice = g_catlassWorkspaceAlloc(UpdateSize);

    uint8_t *tilingDevice = g_catlassWorkspaceAlloc(tilingSize);

    // get tiling
    void *tilingHost = nullptr;
    ACL_CHECK(aclrtMallocHost(&tilingHost, tilingSize));
    uint32_t blockDim = blockNum;

    FAInferTiling::FAInfo faInfo;
    faInfo.numTokens = numTokens;
    faInfo.numHeads = numHeads;
    faInfo.embeddingSize = embeddingSize;
    faInfo.numBlocks = numBlocks;
    faInfo.blockSize = blockSize;
    faInfo.kvHeads = kvHeads;
    faInfo.batch = batch;
    faInfo.maskType = static_cast<FAInferTiling::MaskType>(maskType);

    uint8_t *qSeqHost;
    ACL_CHECK(aclrtMallocHost(reinterpret_cast<void **>(&qSeqHost), seqArraySize));
    ACL_CHECK(aclrtMemcpy(qSeqHost, seqArraySize, qSeqDevice, seqArraySize, ACL_MEMCPY_DEVICE_TO_HOST));
    uint8_t *kvSeqHost;
    ACL_CHECK(aclrtMallocHost(reinterpret_cast<void **>(&kvSeqHost), seqArraySize));
    ACL_CHECK(aclrtMemcpy(kvSeqHost, seqArraySize, kvSeqDevice, seqArraySize, ACL_MEMCPY_DEVICE_TO_HOST));

    faInfo.qSeqlenList = reinterpret_cast<int64_t *>(qSeqHost);
    faInfo.kvSeqlenList = reinterpret_cast<int64_t *>(kvSeqHost);

    FATilingData faTilingData;

    FAInferTiling::GetFATilingParam(faInfo, blockDim, faTilingData);

    tilingHost = reinterpret_cast<void *>(&faTilingData);

    uint32_t tilingKey = 0;

    ACL_CHECK(aclrtMemcpy(tilingDevice, tilingSize, tilingHost, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE));

    // Prepare hardware sync address
    uint64_t hardwareSyncAddr{0};
    ACL_CHECK(aclrtGetHardwareSyncAddr(reinterpret_cast<void**>(&hardwareSyncAddr)));
    
    FAInfer<DType><<<blockDim, nullptr, stream>>>(hardwareSyncAddr, qDevice, kDevice, vDevice, maskDevice, blockTableDevice, oDevice, qSeqDevice, kvSeqDevice, sDevice, pDevice, oTempDevice, oUpdateDevice, tilingDevice);
    ACL_CHECK(aclrtSynchronizeStream(stream));
}

/**
 * Dispatches Flash Attention inference on params.dataType.
 *
 * ACL_FLOAT16 runs the half instantiation and ACL_BF16 runs the bfloat16_t
 * instantiation. Any other data type is reported on stderr and no kernel is
 * launched. Unsupported types were previously ignored silently.
 *
 * @param blockNum  number of blocks to launch.
 * @param stream    ACL stream used for the launch.
 * @param params    problem description and device addresses.
 */
void FlashAttentionInfer(uint32_t blockNum, aclrtStream stream, const FlashAttentionParams& params)
{
    if (params.dataType == ACL_FLOAT16) {
        FAImpl<half>(blockNum, stream, params);
    } else if (params.dataType == ACL_BF16) {
        FAImpl<bfloat16_t>(blockNum, stream, params);
    } else {
        std::cerr << __FILE__ << ":" << __LINE__ << " FlashAttentionInfer: unsupported dataType "
                  << static_cast<int>(params.dataType) << std::endl;
    }
}
} // namespace CatlassKernel