* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file fia_tiling_info.h
* \brief
*/
#ifndef FIA_TILING_INFO_H
#define FIA_TILING_INFO_H
#include <vector>
#include "fia_tiling_base.h"
namespace optiling {
const std::string ACTUAL_SEQ_KV_LEN_NAME = "the key/value's actual sequence lengths";
const std::string ACTUAL_SEQ_Q_LEN_NAME = "the query's actual sequence lengths";
const std::string ATTEN_MASK_NAME = "atten_mask";
const std::string ATTEN_OUT_NAME = "attention_out";
const std::string BLOCK_SIZE_NAME = "block_size";
const std::string BLOCK_TABLE_NAME = "block_table";
const std::string DEQUANT_SCALE_QUERY_NAME = "the query's dequant scale";
const std::string INNER_PRECISE_NAME = "inner_precise";
const std::string KEY_NAME = "key";
const std::string KEY_ANTIQUANT_MODE_NAME = "the key's quant mode";
const std::string KEY_ANTIQUANT_OFFSET_NAME = "the key's quant offset";
const std::string KEY_ANTIQUANT_SCALE_NAME = "the key's quant scale";
const std::string KEY_ROPE_NAME = "key_rope";
const std::string KEY_ROPE_ANTIQUANT_SCALE_NAME = "the key_rope's dequant scale";
const std::string KV_HEADS_NUM_NAME = "the key/value's heads num";
const std::string NEXT_TOKENS_NAME = "next_tokens";
const std::string PRE_TOKENS_NAME = "pre_tokens";
const std::string PSE_SHIFT_NAME = "pse_shift";
const std::string QUANT_OFFSET2_NAME = "the output's dequant offset";
const std::string QUANT_SCALE2_NAME = "the output's dequant scale";
const std::string QUERY_NAME = "query";
const std::string QUERY_HEADS_NUM_NAME = "the query's heads num";
const std::string QUERY_QUANT_MODE_NAME = "the query's quant mode";
const std::string QUERY_ROPE_NAME = "query_rope";
const std::string SOFTMAX_SCALE_NAME = "the softmax's scale";
const std::string SPARSE_MODE_NAME = "sparse_mode";
const std::string VALUE_NAME = "value";
const std::string VALUE_ANTIQUANT_MODE_NAME = "the value's quant mode";
const std::string VALUE_ANTIQUANT_OFFSET_NAME = "the value's dequant offset";
const std::string VALUE_ANTIQUANT_SCALE_NAME = "the value's dequant scale";
const std::string ANTIQUANT_MODE_NAME = "antiquant_mode";
const std::string ANTIQUANT_SCALE_NAME = "antiquant_scale";
const std::string ANTIQUANT_OFFSET_NAME = "antiquant_offset";
const std::string DEQUANT_SCALE1_NAME = "dequant_scale1";
const std::string DEQUANT_SCALE2_NAME = "dequant_scale2";
const std::string KEY_SHARED_PREFIX_NAME = "key_shared_prefix";
const std::string KV_PADDING_SIZE_NAME = "kv_padding_size";
const std::string QUANT_SCALE1_NAME = "quant_scale1";
const std::string QUERY_PADDING_SIZE_NAME = "query_padding_size";
const std::string SOFTMAX_LSE_NAME = "softmax_lse";
const std::string VALUE_SHARED_PREFIX_NAME = "value_shared_prefix";
const std::string ACTUAL_SHARED_PREFIX_LEN_NAME = "actual_shared_prefix_len";
const std::string LEARNABLE_SINK_NAME = "learnable_sink";
constexpr int32_t SPARSE_MODE_NO_MASK = 0;
constexpr int32_t SPARSE_MODE_ALL_MASK = 1;
constexpr int32_t SPARSE_MODE_LEFT_UP = 2;
constexpr int32_t SPARSE_MODE_RIGHT_DOWN = 3;
constexpr int32_t SPARSE_MODE_BAND = 4;
constexpr int32_t SPARSE_MODE_TREE = 9;
enum class FiaLayout : uint32_t {
BSH = 0,
BSND = 1,
BNSD = 2,
NZ = 3,
TND = 4,
NBSD = 5,
NTD = 6,
S1S2 = 7,
BS2 = 8,
BnBsH = 11,
BnNBsD = 12,
BNS1S2 = 13,
INS1S2 = 14,
BNS11 = 15,
TN1 = 16,
BS1S2 = 18,
B1S1S2 = 19,
IS1S2 = 20,
I1S1S2 = 21,
S1S1 = 22,
};
enum class FiaAxis : uint32_t {
B = 0,
S = 1,
N = 2,
D = 3,
H = 4,
T = 5,
D1 = 6,
D0 = 7,
S1 = 8,
S2 = 9,
Bn = 10,
Bs = 11,
CONST = 12
};
enum class FiaQuantMode : uint32_t {
NO_QUANT = 0,
ANTI_QUANT,
FULL_QUANT
};
enum class KvStorageMode : uint32_t {
BATCH_CONTINUOUS = 0,
TENSOR_LIST = 1,
PAGE_ATTENTION = 2
};
enum class RopeMode : uint32_t {
NO_ROPE = 0,
ROPE_SPLIT = 1,
ROPE_COMBINE = 2
};
enum class MlaMode : uint32_t {
NO_MLA = 0,
ROPE_SPLIT_D512 = 1,
ROPE_SPLIT_D128 = 2,
ROPE_COMBINE_D128 = 3,
};
enum class FiaTilingInOutMode : uint32_t {
IO_INVALID = 0,
INT8_INT8 = 1,
FP16_INT8 = 2,
INT8_FP16 = 3,
FP16_FP16 = 4,
BF16_BF16 = 5,
FP32_FP32 = 6,
FP16_FP16_SPLITKV = 7,
BF16_INT8 = 8,
INT8_BF16 = 9,
};
enum class TilingKeyLayout : uint32_t {
BSH_BSND = 0,
BNSD = 1,
NZ = 2,
TND = 3,
NBSD = 4,
NTD = 5
};
enum class FiaTemplateId : uint32_t {
EMPTY_TENSOR = 0,
HIGH_PERFORMANCE_GQA = 3,
GENERAL_GQA = 4,
HIGH_PERFORMANCE_MLA = 5
};
enum class FiaFullQuantMode : uint32_t{
NO_FULL_QUANT = 0,
PER_TENSOR_FULL_QUANT = 1,
PER_BLOCK_FULL_QUANT = 2,
MXFP8_FULL_QUANT = 3,
PER_TOKEN_HEAD_FULL_QUANT = 4,
};
std::string LayoutToSerialString(FiaLayout layout);
std::string AxisToSerialString(FiaAxis axis);
std::string QuantModeToSerialString(FiaQuantMode fiaQuantMode);
std::string SituationToSerialString(RopeMode ropeMode);
struct FIARequiredParaInfo {
const gert::CompileTimeTensorDesc *desc;
const gert::StorageShape *shape;
};
struct FIAOptionalParaInfo {
const gert::CompileTimeTensorDesc *desc;
const gert::Tensor *tensor;
};
struct FIAParaInfo {
FIARequiredParaInfo query = {nullptr, nullptr};
FIARequiredParaInfo key = {nullptr, nullptr};
FIARequiredParaInfo value = {nullptr, nullptr};
FIAOptionalParaInfo pseShift = {nullptr, nullptr};
FIAOptionalParaInfo attenMask = {nullptr, nullptr};
FIAOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr};
FIAOptionalParaInfo actualSeqLengths = {nullptr, nullptr};
FIAOptionalParaInfo deqScale1 = {nullptr, nullptr};
FIAOptionalParaInfo quantScale1 = {nullptr, nullptr};
FIAOptionalParaInfo deqScale2 = {nullptr, nullptr};
FIAOptionalParaInfo quantScale2 = {nullptr, nullptr};
FIAOptionalParaInfo quantOffset2 = {nullptr, nullptr};
FIAOptionalParaInfo antiquantScale = {nullptr, nullptr};
FIAOptionalParaInfo antiquantOffset = {nullptr, nullptr};
FIAOptionalParaInfo blockTable = {nullptr, nullptr};
FIAOptionalParaInfo queryPaddingSize = {nullptr, nullptr};
FIAOptionalParaInfo kvPaddingSize = {nullptr, nullptr};
FIAOptionalParaInfo keyAntiquantScale = {nullptr, nullptr};
FIAOptionalParaInfo keyAntiquantOffset = {nullptr, nullptr};
FIAOptionalParaInfo valueAntiquantScale = {nullptr, nullptr};
FIAOptionalParaInfo valueAntiquantOffset = {nullptr, nullptr};
FIAOptionalParaInfo keySharedPrefix = {nullptr, nullptr};
FIAOptionalParaInfo valueSharedPrefix = {nullptr, nullptr};
FIAOptionalParaInfo actualSharedPrefixLen = {nullptr, nullptr};
FIAOptionalParaInfo queryRope = {nullptr, nullptr};
FIAOptionalParaInfo keyRope = {nullptr, nullptr};
FIAOptionalParaInfo keyRopeAntiquantScale = {nullptr, nullptr};
FIAOptionalParaInfo dequantScaleQuery = {nullptr, nullptr};
FIAOptionalParaInfo learnableSink = {nullptr, nullptr};
FIAOptionalParaInfo qStartIdx = {nullptr, nullptr};
FIAOptionalParaInfo kvStartIdx = {nullptr, nullptr};
FIARequiredParaInfo attenOut = {nullptr, nullptr};
FIARequiredParaInfo lseOut = {nullptr, nullptr};
const int32_t *numHeads = nullptr;
const int64_t *preToken = nullptr;
const int64_t *nextToken = nullptr;
const float *scaleValue = nullptr;
const int32_t *kvHeadNums = nullptr;
const char *layOut = nullptr;
const int32_t *blockSize = nullptr;
const int32_t *innerPrecise = nullptr;
const int64_t *antiquantMode = nullptr;
const bool *softmaxLseFlag = nullptr;
const int64_t *keyAntiquantMode = nullptr;
const int64_t *valueAntiquantMode = nullptr;
const int32_t *sparseMode = nullptr;
const int64_t *queryQuantMode = nullptr;
const int64_t *pseType = nullptr;
};
class FiaTilingInfo : public TilingInfo {
public:
const char *opName = nullptr;
fe::PlatFormInfos *platformInfo = nullptr;
FIAParaInfo opParamInfo;
platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B;
NpuArch npuArch = NpuArch::DAV_2201;
uint32_t bSize = 0;
uint32_t n1Size = 0;
uint32_t n2Size = 0;
uint32_t s1Size = 0;
int64_t s2Size = 0;
uint32_t qkHeadDim = 0;
uint32_t vHeadDim = 0;
uint32_t gSize = 0;
uint32_t ropeHeadDim = 0;
uint32_t qTSize = 0;
uint32_t kTSize = 0;
float scaleValue = 0;
int32_t innerPrecise = 0;
uint32_t l2CacheOffFlag = 0;
uint64_t l2CacheSize = 0;
std::vector<gert::StorageShape *> kCache = {};
std::vector<gert::StorageShape *> vCache = {};
std::vector<int32_t> qSize = {};
std::vector<int32_t> kvSize = {};
bool emptyTensorFlag = false;
uint64_t totalOutputSize = 0;
uint64_t totalLseSize = 0;
bool pageAttentionFlag = false;
int32_t blockSize = 0;
uint32_t blockTypeSize = 0;
uint32_t maxBlockNumPerBatch = 0;
uint32_t totalBlockNum = 0;
uint32_t kvLayoutType = 0;
bool antiQuantFlag = false;
uint32_t msdIterNum = 1;
uint32_t antiqSeqSize = 0;
uint32_t antiquantMode = 0;
uint32_t keyAntiquantMode = 0;
uint32_t valueAntiquantMode = 0;
bool antiquantSplit = false;
bool sysPrefixFlag = false;
uint32_t systemPrefixMaxLen = 0;
uint32_t systemPrefixLen = 0;
uint32_t actualLenQDims = 0;
int64_t maxActualseq = 0;
bool isAccumQSeq = false;
bool actualSeqLenFlag = false;
bool isSameSeqAllKVTensor = true;
bool isSameActualseq = true;
uint32_t actualLenDims = 0;
std::vector<int64_t> kvListSeqLens {};
bool isAccumKVSeq = false;
bool pseShiftFlag = false;
uint32_t pseShiftByBatch = 0U;
uint32_t pseShiftS1 = 0U;
uint32_t pseShiftS2 = 0U;
bool enableAlibiPse = false;
bool attenMaskFlag = false;
bool isExistRowInvalid = false;
uint32_t attenMaskBatchStride = 0;
uint32_t attenMaskStride = 0;
int32_t sparseMode = 0;
int64_t preToken = 0;
int64_t nextToken = 0;
bool isOutQuantPerChnOut = false;
bool isOutQuantTypeBf16 = false;
bool isOutQuantEnable = false;
bool batchContinuousFlag = true;
bool kvPaddingSizeFlag = false;
bool qPaddingSizeFlag = false;
bool softmaxLseFlag = false;
bool quantFlag = false;
bool isMaxWorkspace = false;
bool isLegacyIfa = false;
bool needInit = false;
bool slidingFlag = false;
bool learnableSinkFlag = false;
bool isQKVDDifferent = false;
FiaTilingInOutMode inOutMode = FiaTilingInOutMode::FP16_FP16;
ge::DataType inputQType = ge::DT_FLOAT16;
ge::DataType inputKvType = ge::DT_FLOAT16;
ge::DataType outputType = ge::DT_FLOAT16;
ge::DataType inputQRopeType = ge::DT_FLOAT16;
ge::DataType inputKRopeType = ge::DT_FLOAT16;
TilingKeyLayout inputKvLayout = TilingKeyLayout::BSH_BSND;
TilingKeyLayout inputLayout = TilingKeyLayout::BSH_BSND;
TilingKeyLayout outputLayout = TilingKeyLayout::BSH_BSND;
KvStorageMode kvStorageMode = KvStorageMode::BATCH_CONTINUOUS;
RopeMode ropeMode = RopeMode::NO_ROPE;
MlaMode mlaMode = MlaMode::NO_MLA;
FiaQuantMode quantMode = FiaQuantMode::NO_QUANT;
FiaFullQuantMode fullQuantMode = FiaFullQuantMode::NO_FULL_QUANT;
FiaLayout qLayout = FiaLayout::BSND;
FiaLayout outLayout = FiaLayout::BSND;
FiaLayout kvLayout = FiaLayout::BSND;
};
}
#endif