/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file util_regbase.h
 * \brief Constants, enums and plain parameter structs declared in namespace regbaseutil.
 */
#ifndef FLASH_ATTENTION_UTIL_REGBASE_H
#define FLASH_ATTENTION_UTIL_REGBASE_H
#include "util.h"
using AscendC::TQue;
using AscendC::QuePosition;
namespace regbaseutil {
// NOTE(review): presumably the byte width of one vector register -- confirm against the hardware spec.
constexpr uint16_t regBytes{256};
// INT32_MAX stored in a 64-bit constant; RunParamStr<true> uses it as the default value of
// preTokensPerBatch and nextTokensPerBatch.
constexpr int64_t MAX_PRE_NEXT_TOKENS{0x7FFFFFFF};
// Large-magnitude finite negative float (not an IEEE infinity); default value of the sinkValue
// member declared by COMMON_CONST_INFO.
constexpr float SINK_MIN_INF{-3.40E+38};
// Four selector indices numbered 0..3. The underlying type is int (the scoped-enum default).
// NOTE(review): the names read as two size buckets (0 < x <= 64 and 64 < x <= 128) plus a DN and
// an NZ case; what the index selects is not visible in this header -- confirm at the use sites.
enum class VselrIndexEnum : int {
    GT_64_AND_LTE_128_INDEX = 0,
    GT_0_AND_LTE_64_INDEX = 1,
    DN_INDEX = 2,
    NZ_INDEX = 3,
};
// Every AlignedN enumerator has the numeric value N. NotAligned carries no explicit value and
// therefore evaluates to 769 (one past Aligned768). Underlying type is int (the scoped-enum default).
// NOTE(review): presumably used as a template selector for the D dimension -- confirm at the use sites.
enum class DTemplateType : int {
    Aligned16 = 16,
    Aligned32 = 32,
    Aligned48 = 48,
    Aligned64 = 64,
    Aligned80 = 80,
    Aligned96 = 96,
    Aligned128 = 128,
    Aligned160 = 160,
    Aligned192 = 192,
    Aligned256 = 256,
    Aligned512 = 512,
    Aligned576 = 576,
    Aligned768 = 768,
    NotAligned,
};
// Every AlignedN enumerator has the numeric value N. NotAligned carries no explicit value and
// therefore evaluates to 513 (one past Aligned512). Underlying type is int (the scoped-enum default).
// NOTE(review): presumably used as a template selector for the S1 dimension -- confirm at the use sites.
enum class S1TemplateType : int {
    Aligned16 = 16,
    Aligned64 = 64,
    Aligned128 = 128,
    Aligned256 = 256,
    Aligned512 = 512,
    NotAligned,
};
// Every AlignedN enumerator has the numeric value N. NotAligned carries no explicit value and
// therefore evaluates to 1025 (one past Aligned1024). Underlying type is int (the scoped-enum default).
// NOTE(review): presumably used as a template selector for the S2 dimension -- confirm at the use sites.
enum class S2TemplateType : int {
    Aligned16 = 16,
    Aligned32 = 32,
    Aligned64 = 64,
    Aligned128 = 128,
    Aligned256 = 256,
    Aligned512 = 512,
    Aligned1024 = 1024,
    NotAligned,
};
// Sparse-mode tag stored in a single byte (underlying type uint8_t), values 0..3.
// CASUAL looks like a misspelling of CAUSAL. The original enumerator is kept so existing code
// keeps compiling, and CAUSAL is provided as an alias with the same value (1) for new code.
enum class SparseType : uint8_t {
    DENSE = 0,
    CASUAL = 1,       // original spelling, retained for backward compatibility
    CAUSAL = CASUAL,  // correctly spelled alias; compares equal to CASUAL
    BAND = 2,
    UNSUPPORTED = 3
};
// Primary template, declared but never defined here: only the explicit specializations
// RunParamStr<false> and RunParamStr<true> that follow provide a definition.
// isInfer defaults to false, so RunParamStr<> names RunParamStr<false>.
template<bool isInfer = false>
struct RunParamStr;
// Member declarations shared by RunParamStr<false> and RunParamStr<true>.
// - s2LineStartIdx is the only member with a default member initializer (0).
// - The expansion already ends with ';', so the ';' written after COMMON_RUN_PARAM at the
//   use sites is a redundant (harmless) empty declaration.
// - Declaration order defines the struct layout; do not reorder.
// - NOTE(review): s1RealSize, s1RealSizeAlign32/64, halfS1RealSize and firstHalfS1RealSize are
//   uint32_t here but int32_t in COMMON_RUN_INFO, and the b1SS* offsets are uint64_t here but
//   int64_t there. Confirm the signedness difference is intended before copying values across.
// - NOTE(review): s2LoopEndIdx is int32_t while the neighbouring indices are int64_t -- confirm.
// - NOTE(review): names suggest boIdx/s1oIdx/n2oIdx/goIdx are outer-loop indices -- TODO confirm.
#define COMMON_RUN_PARAM \
int64_t boIdx; \
int64_t s1oIdx; \
int64_t n2oIdx; \
int64_t goIdx; \
int32_t s2LoopEndIdx; \
int64_t s2LineStartIdx = 0; \
int64_t s2LineEndIdx; \
\
uint32_t s1RealSize; \
uint32_t s1RealSizeAlign32; \
uint32_t s1RealSizeAlign64; \
uint32_t halfS1RealSize; \
uint32_t firstHalfS1RealSize; \
int64_t tensorQOffset; \
int64_t attentionOutOffset; \
int64_t actualS1Size; \
int64_t actualS2Size; \
uint64_t b1SSOffset; \
uint64_t b1SSAttenMaskOffset; \
uint64_t b1SSOffsetAlign16; \
int64_t qRopeNBGOffset; \
int64_t kRopeNBGOffset;
// RunParamStr<false>: exactly the members declared by COMMON_RUN_PARAM, nothing else.
template<>
struct RunParamStr<false> {
COMMON_RUN_PARAM;
};
// RunParamStr<true>: the COMMON_RUN_PARAM members followed by the additional members below.
// Members without an initializer have no default value.
template<>
struct RunParamStr<true> {
COMMON_RUN_PARAM;
int64_t s1LoopTimes;
int64_t gS1Idx;
int64_t s2InCurrentBatch;
// Both per-batch token windows default to MAX_PRE_NEXT_TOKENS (0x7FFFFFFF).
int64_t preTokensPerBatch = MAX_PRE_NEXT_TOKENS;
int64_t nextTokensPerBatch = MAX_PRE_NEXT_TOKENS;
int64_t sOuterOffset;
int64_t cubeSOuterOffset;
int64_t keyCoreOffset;
int64_t valueCoreOffset;
// NOTE(review): the only unsigned offset in this struct's own member list -- confirm intent.
uint64_t pseShiftCoreOffset;
int64_t keyOffset;
int64_t qBOffset;
int64_t qRopeBOffset;
int64_t queryLeftPaddingSize;
int64_t kvLeftPaddingSize;
int64_t softmaxLseOffset;
// The remaining four members all default to 0.
int64_t actualSeqLengthOfMlaPerBatch = 0;
int64_t nextTokensOfMlaPerBatch = 0;
int64_t preTokensOfMlaPerBatch = 0;
int64_t prefixCoreOffset = 0;
};
// Member declarations shared by RunInfo<true> and RunInfo<false>.
// - Unlike COMMON_RUN_PARAM, the expansion does NOT end with ';': the use site must supply it.
// - Members with a default member initializer (all 0): s1oIdx, boIdx, n2oIdx, goIdx,
//   multiCoreInnerIdx, nextTokensOfMlaPerBatch, preTokensOfMlaPerBatch, multiCoreIdxMod2,
//   multiCoreIdxMod3. Every other member has no default.
// - Declaration order defines the struct layout; do not reorder.
// - NOTE(review): the *RealSize members are int32_t here but uint32_t in COMMON_RUN_PARAM, and
//   b1SSOffset / b1SSAttenMaskOffset are int64_t here but uint64_t there -- confirm.
// - NOTE(review): taskIdMod2/taskIdMod3 and multiCoreIdxMod2/multiCoreIdxMod3 presumably cache
//   the corresponding index modulo 2 and 3; nothing in this header computes them -- confirm.
#define COMMON_RUN_INFO \
int64_t s2StartIdx; \
int64_t s2EndIdx; \
int64_t s2LoopCount; \
int64_t s2LoopLimit; \
int64_t s1oIdx = 0; \
int64_t boIdx = 0; \
int64_t n2oIdx = 0; \
int64_t goIdx = 0; \
int32_t s1RealSize; \
int32_t s1RealSizeAlign32; \
int32_t s1RealSizeAlign64; \
int32_t halfS1RealSize; \
int32_t firstHalfS1RealSize; \
int32_t s2RealSize; \
int64_t s2AlignedSize; \
int32_t vec2S1BaseSize; \
int32_t vec2S1RealSize; \
int64_t vecCoreOffset; \
int64_t queryOffset; \
int64_t keyOffset; \
int64_t valueOffset; \
int64_t qRopeOffset; \
int64_t kRopeOffset; \
\
int64_t taskId; \
int64_t multiCoreInnerIdx = 0; \
\
int64_t attentionOutOffset; \
uint64_t s1ScaleNumAcc; \
uint64_t s2ScaleNumAcc; \
int64_t s1SizeAcc; \
int64_t s2SizeAcc; \
int64_t actualS1Size; \
int64_t actualS2Size; \
int64_t preTokensPerBatch; \
int64_t nextTokensPerBatch; \
int64_t b1SSOffset; \
\
int64_t b1SSAttenMaskOffset; \
int64_t b1SSOffsetAlign; \
int64_t deScaleKvOffset; \
int64_t nextTokensOfMlaPerBatch = 0; \
int64_t preTokensOfMlaPerBatch = 0; \
uint8_t taskIdMod2; \
uint8_t taskIdMod3; \
uint8_t multiCoreIdxMod2 = 0; \
uint8_t multiCoreIdxMod3 = 0; \
int64_t sOuterOffset
// Primary template, declared but never defined here; only the two explicit specializations
// below provide a definition. isInfer defaults to false, so RunInfo<> names RunInfo<false>.
template<bool isInfer = false>
struct RunInfo;
// RunInfo<true>: the COMMON_RUN_INFO members followed by the additional members below.
// Of the additional members only actualSeqLengthOfMlaPerBatch has a default (0).
template <>
struct RunInfo<true> {
COMMON_RUN_INFO;
int64_t gS1Idx;
// NOTE(review): unsigned, while every other offset in this struct is int64_t -- confirm intent.
uint64_t pseShiftOffset;
int64_t queryLeftPaddingSize;
int64_t kvLeftPaddingSize;
int64_t actualSeqLengthOfMlaPerBatch = 0;
int64_t softmaxLseOffset;
int64_t flashDecodeS2Idx;
int64_t s2InCurrentBatch;
int64_t prefixOffset;
};
// RunInfo<false>: exactly the members declared by COMMON_RUN_INFO, nothing else.
template<>
struct RunInfo<false> {
COMMON_RUN_INFO;
};
// Member declarations shared by all four ConstInfo specializations.
// - Only two members have a default member initializer: learnableSinkFlag (false) and
//   sinkValue (SINK_MIN_INF). Every other member has no default.
// - The expansion already ends with ';', so the ';' written after COMMON_CONST_INFO at the
//   use sites is a redundant (harmless) empty declaration.
// - Declaration order defines the struct layout; do not reorder.
// - NOTE(review): the members ending in "Dr" (s1Dr, gS1Dr, n2GS1Dr, s2Dr, n2S2Dr, gDr, n2Dr,
//   bN2Dr, n2GDr, bN2GDr) differ only in the case of the last letter from the "DR" members that
//   ROPE_INFO declares (s1DR, gS1DR, ...). Both sets coexist in the hasRope specializations of
//   ConstInfo, so check which one a given use site really means.
// - NOTE(review): s2BaseN2D, s1BaseN2GD, s1BaseD, s2BaseD, s1BaseDv and s2BaseDv are int32_t,
//   while their siblings (e.g. s2BaseN2Dv, s1BaseN2GDv) are int64_t -- confirm the narrower
//   width cannot overflow.
// - NOTE(review): the compound names (s1D, gS1D, n2GS1D, ...) read like precomputed products of
//   the dimension sizes; nothing in this header computes them -- confirm at the init code.
#define COMMON_CONST_INFO \
\
uint32_t s1BaseSize; \
uint32_t s2BaseSize; \
int64_t bSize; \
int64_t t1Size; \
int64_t t2Size; \
int64_t dSize; \
int64_t dSizeV; \
int64_t dBasicBlock; \
int64_t dSizeRope; \
int64_t gSize; \
int64_t n2Size; \
int64_t s1Size; \
int64_t s2Size; \
\
int64_t s1D; \
int64_t gS1D; \
int64_t n2GS1D; \
int64_t s2D; \
int64_t n2S2D; \
int64_t s1Dv; \
int64_t gS1Dv; \
int64_t n2GS1Dv; \
int64_t s2Dv; \
int64_t n2S2Dv; \
int64_t s1S2; \
int64_t gS1; \
int64_t gD; \
int64_t n2D; \
int64_t bN2D; \
int64_t gDv; \
int64_t n2Dv; \
int64_t bN2Dv; \
int64_t n2G; \
int64_t n2GD; \
int64_t bN2GD; \
int64_t n2GDv; \
int64_t bN2GDv; \
int64_t gS2; \
int64_t s1Dr; \
int64_t gS1Dr; \
int64_t n2GS1Dr; \
int64_t s2Dr; \
int64_t n2S2Dr; \
int64_t gDr; \
int64_t n2Dr; \
int64_t bN2Dr; \
int64_t n2GDr; \
int64_t bN2GDr; \
int32_t s2BaseN2D; \
int32_t s1BaseN2GD; \
int64_t s2BaseBN2D; \
int64_t s1BaseBN2GD; \
int32_t s1BaseD; \
int32_t s2BaseD; \
int64_t s2BaseN2Dv; \
int64_t s2BaseBN2Dv; \
int64_t s1BaseN2GDv; \
int64_t s1BaseBN2GDv; \
int32_t s1BaseDv; \
int32_t s2BaseDv; \
int64_t s1OuterSize; \
\
int64_t mm1Ka; \
int64_t mm1Kb; \
int64_t mm2Kb; \
\
int64_t attentionOutStride; \
uint32_t aivIdx; \
uint8_t layoutType; \
uint8_t subBlockIdx;\
bool softMaxCheckRes; \
float keepProb; \
float scaleValue; \
int64_t matmulMSize; \
bool learnableSinkFlag = false; \
float pScale;\
float sinkValue = SINK_MIN_INF;
// Eighteen int64_t member declarations added to the hasRope specializations of ConstInfo
// (ConstInfo<true, true> and ConstInfo<false, true>).
// - No member has a default value.
// - The expansion does NOT end with ';': the use site must supply it.
// - NOTE(review): the first ten names (s1DR ... bN2GDR) differ only in the case of the final
//   letter from the "Dr" members declared by COMMON_CONST_INFO (s1Dr ... bN2GDr); both sets end
//   up in the same struct, so double-check which one each use site intends.
#define ROPE_INFO \
\
int64_t s1DR; \
int64_t gS1DR; \
int64_t n2GS1DR; \
int64_t s2DR; \
int64_t n2S2DR; \
int64_t gDR; \
int64_t n2DR; \
int64_t bN2DR; \
int64_t n2GDR; \
int64_t bN2GDR; \
int64_t s2BaseN2DR; \
int64_t s2BaseBN2DR; \
int64_t s1BaseN2GDR; \
int64_t s1BaseBN2GDR; \
int64_t s1BaseDR; \
int64_t s2BaseDR; \
int64_t mm1RopeKa; \
int64_t mm1RopeKb
// Four member declarations added to the isInfer specializations of ConstInfo
// (ConstInfo<true, true> and ConstInfo<true, false>).
// - Every member has a default: isActualSharedPrefixLenNull is true, the three sizes/counts are 0.
// - The expansion does NOT end with ';': the use site must supply it.
#define KVPREFIX_INFO \
\
bool isActualSharedPrefixLenNull = true; \
int64_t actualKVPrefixSize = 0; \
int64_t kvPrefixSize = 0; \
int64_t prefixLoopCount = 0
// Member declarations added to the isInfer specializations of ConstInfo
// (ConstInfo<true, true> and ConstInfo<true, false>).
// - No member has a default value.
// - The expansion does NOT end with ';': the use site must supply it.
// - Declaration order defines the struct layout; do not reorder.
// - NOTE(review): isKvContinuous is a uint32_t here although it is named like a flag (it is a
//   1-bit field in CVSharedParams<true, *>).
// - NOTE(review): rsvd1 looks like a reserved/padding slot -- confirm it is unused before reuse.
// - NOTE(review): several members use a different type than the same-named member of
//   CVSharedParams<true, *>: queryRightPaddingSize, kvRightPaddingSize and splitKVNum are int64_t
//   here but uint32_t there; blockSize, blockTableDim2 and paBlockNumSum are uint32_t here but
//   int32_t in CVSharedParams<true, true>. Watch for implicit conversions when copying.
#define INFER_CONST_INFO \
\
bool isRowInvalid; \
bool isActualLenDimsNull; \
bool isActualLenDimsKVNull; \
bool isGqa; \
bool isPfaGS1Merge; \
\
uint32_t actualSeqLenSize; \
uint32_t actualSeqLenKVSize; \
uint32_t isKvContinuous; \
\
uint32_t blockTableDim2; \
uint32_t blockSize; \
uint32_t paLayoutType; \
uint32_t paBlockNumSum; \
uint32_t transposeLayout; \
\
\
uint32_t headNumRatio; \
bool rsvd1; \
bool isSoftmaxLseEnable; \
\
bool isQHasLeftPadding; \
bool isKVHasLeftPadding; \
int64_t queryRightPaddingSize; \
int64_t kvRightPaddingSize; \
\
int64_t sInnerLoopSize; \
int64_t actualCombineLoopSize; \
int64_t splitKVNum; \
\
bool isPostQuantPerChnl; \
bool isPostQuantBF16; \
bool isPostQuantOffsetExist; \
float postQuantScaleValue; \
float postQuantOffsetValue
// Member declarations shared by every CVSharedParams specialization.
// - No member has a default value.
// - The expansion does NOT end with ';': the use site must supply it.
// - Declaration order and bit-field widths define the struct layout; do not reorder or resize.
// - Bit-field groups: dSize(16) + dSizeV(16) = 32 bits, and compressMode(4) + implMode(4) +
//   layoutType(4) + sparseType(8) + dSizeRope(11) + splitCoreMode(1) = 32 bits. The widths cap
//   the storable values, e.g. dSize/dSizeV at 65535 and dSizeRope at 2047.
// - multiCoreInnerOffset and multiCoreInnerLimit are the only volatile members.
//   NOTE(review): volatile suggests they are updated outside the reading code's control flow;
//   volatile by itself is not a synchronization primitive -- confirm how accesses are ordered.
// - NOTE(review): bSize, n2Size, gSize, s1Size, s2Size and s1OuterSize are uint32_t here but
//   int64_t in COMMON_CONST_INFO -- watch for implicit conversions when copying.
// - NOTE(review): sparseType is 8 bits wide, matching SparseType's uint8_t underlying type;
//   presumably it stores a SparseType value -- confirm.
#define CV_SHARED_PARAMS \
\
uint32_t bSize; \
int64_t t1Size; \
int64_t t2Size; \
uint32_t n2Size; \
uint32_t gSize; \
uint32_t s1Size; \
uint32_t s2Size; \
uint32_t dSize : 16; \
uint32_t dSizeV : 16; \
\
int64_t preTokens; \
int64_t nextTokens; \
uint32_t attenMaskS1Size; \
uint32_t attenMaskS2Size; \
int64_t s1SparseValidSize; \
int64_t s2SparseValidSize; \
\
volatile int64_t multiCoreInnerOffset; \
volatile int64_t multiCoreInnerLimit; \
uint32_t s1OuterSize; \
uint32_t bandIndex; \
uint32_t compressMode : 4; \
uint32_t implMode : 4; \
uint32_t layoutType : 4; \
uint32_t sparseType : 8; \
uint32_t dSizeRope : 11; \
uint32_t splitCoreMode : 1; \
uint32_t coreNum
// Declares a single float member, qScaleDs, with no default value.
// The expansion does NOT end with ';': the use site must supply it.
#define FAG_CV_SHARED_PARAMS \
\
float qScaleDs
// Plain struct whose only member is the one declared by FAG_CV_SHARED_PARAMS.
struct FagCVSharedParams {
FAG_CV_SHARED_PARAMS;
};
// Primary template, declared but never defined here; all four <isInfer, hasRope> combinations
// are provided as explicit specializations below. Both parameters default to false.
//   isInfer adds INFER_CONST_INFO and KVPREFIX_INFO; hasRope adds ROPE_INFO;
//   the two isInfer == false specializations additionally declare n2GS1o and gS1o.
template<bool isInfer = false, bool hasRope = false>
struct ConstInfo;
// isInfer = true, hasRope = true: common + infer + rope + KV-prefix members.
template<>
struct ConstInfo<true, true> {
COMMON_CONST_INFO;
INFER_CONST_INFO;
ROPE_INFO;
KVPREFIX_INFO;
};
// isInfer = true, hasRope = false: common + infer + KV-prefix members.
template <>
struct ConstInfo<true, false> {
COMMON_CONST_INFO;
INFER_CONST_INFO;
KVPREFIX_INFO;
};
// isInfer = false, hasRope = true: common + rope members, then n2GS1o and gS1o (no defaults).
template <>
struct ConstInfo<false, true> {
COMMON_CONST_INFO;
ROPE_INFO;
int64_t n2GS1o;
int64_t gS1o;
};
// isInfer = false, hasRope = false: common members, then n2GS1o and gS1o (no defaults).
template <>
struct ConstInfo<false, false> {
COMMON_CONST_INFO;
int64_t n2GS1o;
int64_t gS1o;
};
// Primary template, declared but never defined here. Both parameters default to false.
// This header specializes <false, false>, <true, false> and <true, true>; no <false, true>
// specialization is defined in this header.
// NOTE(review): the "CV" prefix presumably refers to data shared between cube and vector
// cores -- confirm; the fixed field order and bit-field packing should be treated as a layout
// contract either way.
template <bool isInfer = false, bool isPa = false>
struct CVSharedParams;
// <false, false>: the CV_SHARED_PARAMS members followed by three extra members (no defaults).
template<>
struct CVSharedParams<false, false> {
CV_SHARED_PARAMS;
int64_t firstFullLoadS1OuterIdx;
int64_t totalSize;
float scaleValue;
};
// <true, false>: the CV_SHARED_PARAMS members followed by the members below (no defaults).
// Declaration order and bit-field widths define the layout; do not reorder or resize.
template<>
struct CVSharedParams<true, false> {
CV_SHARED_PARAMS;
// Twelve 1-bit flags plus the 20-bit headNumRatio add up to exactly 32 bits.
// headNumRatio is therefore capped at 2^20 - 1.
uint32_t fromFused : 1;
uint32_t isGqa : 1;
uint32_t isPfaGS1Merge : 1;
uint32_t isKvContinuous : 1;
uint32_t isRowInvalid : 1;
uint32_t isActualSeqLengthsNull : 1;
uint32_t isActualSeqLengthsKVNull : 1;
uint32_t isQHasLeftPadding : 1;
uint32_t isKVHasLeftPadding : 1;
uint32_t needInit : 1;
uint32_t isPostQuantPerChnl : 1;
uint32_t isPostQuantBF16 : 1;
uint32_t headNumRatio : 20;
uint32_t transposeLayout;
uint32_t actualSeqLengthsSize;
uint32_t actualSeqLengthsKVSize;
uint32_t splitKVNum;
uint32_t bnStartIdx;
uint32_t bnEndIdx;
// NOTE(review): uint32_t here, int64_t for the same-named members of INFER_CONST_INFO.
uint32_t queryRightPaddingSize;
uint32_t kvRightPaddingSize;
bool isActualSharedPrefixLenNull;
int64_t kvPrefixSize;
int64_t totalSize;
};
// <true, true>: same member list as CVSharedParams<true, false>, with four extra members
// (blockSize, blockTableDim2, paBlockNumSum, paLayoutType) inserted after kvRightPaddingSize.
// No member has a default value.
// Declaration order and bit-field widths define the layout; do not reorder or resize.
template<>
struct CVSharedParams<true, true> {
CV_SHARED_PARAMS;
// Twelve 1-bit flags plus the 20-bit headNumRatio add up to exactly 32 bits.
// headNumRatio is therefore capped at 2^20 - 1.
uint32_t fromFused : 1;
uint32_t isGqa : 1;
uint32_t isPfaGS1Merge : 1;
uint32_t isKvContinuous : 1;
uint32_t isRowInvalid : 1;
uint32_t isActualSeqLengthsNull : 1;
uint32_t isActualSeqLengthsKVNull : 1;
uint32_t isQHasLeftPadding : 1;
uint32_t isKVHasLeftPadding : 1;
uint32_t needInit : 1;
uint32_t isPostQuantPerChnl : 1;
uint32_t isPostQuantBF16 : 1;
uint32_t headNumRatio : 20;
uint32_t transposeLayout;
uint32_t actualSeqLengthsSize;
uint32_t actualSeqLengthsKVSize;
uint32_t splitKVNum;
uint32_t bnStartIdx;
uint32_t bnEndIdx;
// NOTE(review): uint32_t here, int64_t for the same-named members of INFER_CONST_INFO.
uint32_t queryRightPaddingSize;
uint32_t kvRightPaddingSize;
// Members present only in this isPa = true specialization.
// NOTE(review): the first three are int32_t here but uint32_t in INFER_CONST_INFO -- confirm.
int32_t blockSize;
int32_t blockTableDim2;
int32_t paBlockNumSum;
uint32_t paLayoutType;
bool isActualSharedPrefixLenNull;
int64_t kvPrefixSize;
int64_t totalSize;
};
}
#endif