// Copyright (c) 2025 Huawei Technologies Co., Ltd
// All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef INC_EXTERNAL_ATB_INFEROPPARAM_H
#define INC_EXTERNAL_ATB_INFEROPPARAM_H
#include <cstdint>
#include <string>
#include <limits>
#include <hccl/hccl_types.h>
#include <acl/acl.h>
#include "./svector.h"
//!
//! \file infer_op_params.h
//!
//! \brief 定义加速库所有推理算子参数
//!
//!
//! \namespace atb
//!
//! \brief 加速库命名空间.
//!
namespace atb {
namespace infer {
//!
//! \enum InputLayout
//!
//! \brief 数据排布类型
//!
enum InputLayout : int {
TYPE_BSND = 0, //!< 默认值,表示数据排布为BSND
TYPE_BNSD //!< 表示数据排布为BNSD
};
//!
//! \enum QuantType
//!
//! \brief 量化支持的类型
//!
enum QuantType : int {
QUANT_UNDEFINED = 0, //!< 不量化
QUANT_INT4, //!< 当前不支持
QUANT_INT8, //!< int8量化
QUANT_INT16, //!< 当前不支持
QUANT_FLOAT8, //!< 当前不支持
QUANT_FLOAT16, //!< 当前不支持
};
//!
//! \enum DynamicQuantType
//!
//! \brief 动态量化支持的类型
//!
enum DynamicQuantType : int {
DYNAMIC_QUANT_UNDEFINED = 0, //!< 非动态量化
DYNAMIC_QUANT_SYMMETRIC, //!< 对称动态量化
DYNAMIC_QUANT_ASYMMETRIC, //!< 非对称动态量化,暂不支持
};
//!
//! \enum ActivationType
//!
//! \brief 激活支持的类型
//! ACTIVATION_FAST_GELU:快速运算的Gelu激活函数,对Tensor内每个element做Gelu激活函数近似计算,计算速度更快,同时保持较高的准确性。
//! ACTIVATION_SWIGLU_FORWARD: Swiglu正向激活函数。Atlas 推理系列产品中只支持32位对齐的数据。
//! ACTIVATION_FASTER_GELU_FORWARD: 简化后的FastGelu激活函数,计算速度更快。
//! ACTIVATION_SWIGLU_BACKWARD: Swiglu正向激活函数的反向,求梯度时使用。只支持Atlas 800I A2推理产品。
//!
enum ActivationType : int {
ACTIVATION_UNDEFINED = 0, //!< 未定义
ACTIVATION_RELU, //!< RELU激活类型
ACTIVATION_GELU, //!< GELU激活类型
ACTIVATION_FAST_GELU, //!< FAST_GELU激活类型
ACTIVATION_SWISH, //!< SWISH激活类型
ACTIVATION_LOG, //!< LOG激活类型
ACTIVATION_SWIGLU_FORWARD, //!< SWIGLU_FORWARD激活类型
ACTIVATION_SWIGLU_BACKWARD, //!< SWIGLU_BACKWARD激活类型
ACTIVATION_SIGMOID, //!< SIGMOID激活类型
ACTIVATION_FASTER_GELU_FORWARD, //!< FASTER_GELU_FORWARD激活类型
ACTIVATION_MAX, //!< 枚举最大值, 非激活类型
};
//!
//! \enum CommMode
//!
//! \brief 通信算子支持的通信模式.
//!
enum CommMode : int {
COMM_UNDEFINED = -1, //!< 未定义
COMM_MULTI_PROCESS, //!< 指定多进程通信
COMM_MULTI_THREAD, //!< 指定多线程通信
};
//!
//! \struct RmsNormParam
//!
//! \brief RMS归一化处理。
//!
//! \warning 所有输入输出Tensor的最后一维大小相等。
//! Atlas 推理系列产品中不支持bf16类型数据。
//!
struct RmsNormParam {
//!
//! \brief RmsNormType
//!
enum RmsNormType : int {
RMS_NORM_UNDEFINED = 0, //!< 默认值,未定义
RMS_NORM_NORM, //!< NORM参数。
RMS_NORM_PRENORM, //!< PRENORM参数。
RMS_NORM_POSTNORM, //!< POSTNORM参数
};
//!
//! \brief PrecisionMode
//!
enum PrecisionMode : int {
HIGH_PRECISION_MODE = 0, //!< 中间计算使用float类型
HIGH_PERFORMANCE_MODE, //!< 中间计算使用float16类型
};
//!
//! \brief ModelType
//!
enum ModelType : int {
LLAMA_MODEL = 0, //!< 默认值,使用Llama rmsnorm的公式
GEMMA_MODEL, //!< 使用Gemma rmsnorm的公式
};
//!
//! \brief NormParam
//!
struct NormParam {
//! \brief 量化类型。
//! 当前支持以下类型。
//! QUANT_UNDEINFED, QUANT_INT8
QuantType quantType = QUANT_UNDEFINED;
//! \brief Epsilon,归一化时加在分母上防止除零。
float epsilon = 1e-5;
//! \brief Epsilon,默认为1e-5,暂时不使用。
double layerNormEps = 1e-5;
//! \brief 默认为False,设置为true时会使用训练的rmsnormforward算子。仅在Atlas 800I A2推理产品上支持该设置。
//! 不支持和“precisionMode”,“modelType”同时设置。量化场景下不支持使用“rstd”。
bool rstd = false;
//! \brief 默认为HIGH_PRECISION_MODE。
//! 支持参数如下:
//! HIGH_PRECISION_MODE:默认值,中间计算使用float类型
//! HIGH_PERFORMANCE_MODE: 中间计算使用float16类型
//! 不支持和“rstd”,“modelType”同时设置。输入类型只支持float16。
//! 量化场景下不支持使用“precisionMode”,该场景下配置该参数将返回报错ERROR_INVALID_PARAM。
PrecisionMode precisionMode = HIGH_PRECISION_MODE;
//! \brief 默认为LLAMA_MODEL,设置为GEMMA_MODEL时使用gemma模型的rmsnorm计算公式。
//! 支持参数如下:
//! LLAMA_MODEL:默认值, Llama的rms norm计算公式。
//! GEMMA_MODEL:Gemma的rms norm计算公式。
//! 不支持和“rstd”,“precisionMode”同时启用。
//! 量化场景下不支持使用“modelType”,该场景下配置该参数将返回报错ERROR_INVALID_PARAM。
ModelType modelType = LLAMA_MODEL;
//! \brief 动态量化类型。默认为DYNAMIC_QUANT_UNDEFINED非动态量化。当前版本暂不支持非对称动态量化。
DynamicQuantType dynamicQuantType = DYNAMIC_QUANT_UNDEFINED;
//!
//! \brief 预留参数
//!
uint8_t rsv[32] = {0};
};
//!
//! \brief PreNormParam
//!
struct PreNormParam {
//! \brief 量化类型。
//! 当前支持以下类型。
//! QUANT_UNDEINFED
//! QUANT_INT8
QuantType quantType = QUANT_UNDEFINED;
//! \brief Epsilon,归一化时加在分母上防止除零。
float epsilon = 1e-5;
//! \brief 是否叠加偏置。默认为False,当需要输入beta时设置为True。量化场景下不支持使用“hasBias”,该场景下配置该参数将返回报错ERROR_INVALID_PARAM。
bool hasBias = false;
//!
//! \brief 预留参数
//!
uint8_t rsv[23] = {0};
};
//!
//! \brief PostNormParam
//!
struct PostNormParam {
//! \brief 量化类型。
//! 当前仅支持QUANT_UNDEINFED。
QuantType quantType = QUANT_UNDEFINED;
//! \brief Epsilon,归一化时加在分母上防止除零。
float epsilon = 1e-5;
//! \brief 是否叠加偏置。默认为False,当需要输入beta时设置为True。
bool hasBias = false;
//!
//! \brief 预留参数
//!
uint8_t rsv[23] = {0};
};
//! \brief 归一化类型,参数如下:
//! RMS_NORM_UNDEFINED:默认值,未定义。
//! RMS_NORM_NORM:NORM参数。
//! RMS_NORM_PRENORM:PRENORM参数。
//! RMS_NORM_POSTNORM:POSTNORM参数。
RmsNormType layerType = RMS_NORM_UNDEFINED;
//! \brief NORM参数。
NormParam normParam;
//! \brief PRENORM参数。
PreNormParam preNormParam;
//! \brief POSTNORM参数。
PostNormParam postNormParam;
//!
//! \brief 预留参数
//!
uint8_t rsv[8] = {0};
};
//!
//! \struct LinearParam
//!
//! \brief 将A、B两个矩阵进行矩阵乘运算,同时可以选择对矩阵乘的运算结果进行叠加偏置、InplaceAdd融合或反量化操作。
//!
//! \note 算子本质上是接收x和weight两个输入tensor作为A矩阵和B矩阵进行矩阵乘运算,可通过参数transposeA与transposeB控制做矩
//! 阵乘前是否需要对A矩阵和B矩阵进行行列转置,根据参数转置后的A矩阵和B矩阵需满足矩阵乘维度关系。例如,当transposeA为false,
//! transposeB为true时,x和weight的shape可以分别为[m, k]和[n, k]。
//!
//! \note 该算子支持浮点和量化场景,当参数outDataType值为ACL_DT_UNDEFINED时为浮点场景,否则为量化场景。
//!
struct LinearParam {
//!
//! \brief 是否转置A矩阵。
//!
//! \note 默认值为false,不转置。
//!
//! \warning 在量化场景下,非Atlas 800I A2推理产品仅支持配置为false。
//!
bool transposeA = false;
//!
//! \brief 是否转置B矩阵。
//!
//! \note 默认值为true,转置。
//!
//! \warning 在量化场景下,非Atlas 800I A2推理产品仅支持配置为true。
//!
bool transposeB = true;
//!
//! \brief 是否叠加偏置。
//!
//! \note 默认值为true,叠加偏置。
//!
//! \warning 在量化场景下,非Atlas 800I A2推理产品仅支持配置为true。
//!
//! \warning enAccum为true时,仅支持配置为false。
//!
bool hasBias = true;
//!
//! \brief 输出数据类型。
//!
//! \note 默认值为ACL_DT_UNDEFINED。
//!
//! \warning 浮点场景下:支持配置为ACL_DT_UNDEFINED。
//!
//! \warning 量化场景下:Atlas 800I A2推理产品支持配置为ACL_FLOAT16/ACL_BF16,否则,仅支持配置为ACL_FLOAT16。
//!
aclDataType outDataType = ACL_DT_UNDEFINED;
//!
//! \brief 是否使能累加。
//!
//! \note 默认值为false,不使能累加。
//!
//! \warning 仅在Atlas 800I A2推理产品支持配置为true。
//!
//! \warning hasBias为true时,仅支持配置为false。
//!
//! \warning 量化场景下,仅支持配置为false。
//!
bool enAccum = false;
//!
//! \brief 预留参数
//!
uint8_t rsv[23] = {0};
};
struct GroupTopkParam {
//!
//! \brief 每个token分组数量。注:“专家总数”为inTensor0Desc.shape.dims[1]的值。
//!
//! \note 必传,默认值为1,取值范围为[1, 专家总数]。
//!
//! \warning groupNum需要保证可以被inTensor0Desc.shape.dims[1]整除。
//!
int32_t groupNum = 1;
//!
//! \brief 选择top K组数量。
//!
//! \note 必传,默认值为0,取值范围为[1, groupNum]。
//!
//! \warning
//!
int32_t k = 0;
//!
//! \enum GroupMultiFlag
//!
//! \brief 指定GroupTopk每组中取值计算的方式。
//!
//! \warning
//!
enum GroupMultiFlag : uint16_t {
UNDEFINED = 0, //!< 默认方式,每组内取最大值。
SUM_MULTI_MAX //!< 每组内取n个最大值求和,需要设置参数n
};
//!
//! \brief 指定GroupTopk每组中取值计算的方式。
//!
//! \note 默认值为UNDEFINED。
//!
//! \warning 取值为SUM_MULTI_MAX时需要传入参数n。
//!
GroupMultiFlag groupMultiFlag = UNDEFINED;
//!
//! \brief 每组内取值的个数。
//!
//! \note 默认值为1,取值范围为[1,expert_num/groupNum]。
//!
//! \warning 只有当groupMultiFlag为SUM_MULTI_MAX时有效
//!
uint16_t n = 1;
//!
//! \brief 预留参数
//!
uint8_t rsv[12] = {0};
};
//!
//! \brief PagedAttention.
//!
//! 一个Q有多个token,一个token对应多个KV的token,以token0为例,block_table代表其对应的KV的block_id,-1代表截止,
//! 所以第二行和第四行为其目标block,context_lens则表示KV有多少个token,则代表仅有block_id为(3,4,5,9,10)是需要与Q进行计算的。
//!
struct PagedAttentionParam {
//! query 头大小
int32_t headNum = 0;
//! 算子tor值, 在Q*K^T后乘
float qkScale = 1.0;
//! kv头数量
int32_t kvHeadNum = 0;
//!
//! \enum MaskType
//!
//! \brief The type values of MaskType.
//!
enum MaskType : int {
UNDEFINED = 0, //!< 默认值,全0的mask
MASK_TYPE_NORM, //!< 倒三角mask
MASK_TYPE_ALIBI, //!< alibi mask
MASK_TYPE_SPEC //!< 并行解码mask
};
//! mask类型
MaskType maskType = UNDEFINED;
//! 是否开启动态batch
bool batchRunStatusEnable = false;
//!
//! \enum QuantType
//!
//! \brief quant类型
//!
enum QuantType : int {
TYPE_QUANT_UNDEFINED = 0, //!< 默认值,不与量化融合,此
TYPE_DEQUANT_FUSION, //!< 与反量化融合, 只支持Atlas 800I A2推理产品
TYPE_QUANT_QKV_OFFLINE, //!< 离线INT8量化, 只支持Atlas 800I A2推理产品
TYPE_QUANT_QKV_ONLINE //!< 在线INT8量化, 只支持Atlas 800I A2推理产品
};
//!
//! 量化类型:
//! 为TYPE_QUANT_UNDEFINED时q,keyCache,valueCache为bf16/float16。
//! 为TYPE_DEQUANT_FUSION时q为bf16/float16,keyCache,valueCache为int8。
//! 为TYPE_QUANT_QKV_OFFLINE或TYPE_QUANT_QKV_ONLINE时q,keyCache,valueCache为int8。
//! keyCache,valueCache的headsize等长,范围为(0, 256],且block_size * head_size ≤ 128 * 128。
//! outdatatype需要配置,只能是ACL_FLOAT16或ACL_BF16。inputLayout只支持TYPE_BSND。
QuantType quantType = TYPE_QUANT_UNDEFINED;
//! output数据类型(格式为aclDataType)
aclDataType outDataType = ACL_DT_UNDEFINED;
//! 开启量化功能后是否使用offset
bool hasQuantOffset = false;
//!
//! \enum CompressType
//!
//! \brief 压缩类型
//!
enum CompressType : int {
COMPRESS_TYPE_UNDEFINED = 0, //!< 默认值,不压缩
COMPRESS_TYPE_KVHEAD, //!< 压缩key_cache, value_cache的kvHead维度, 只支持Atlas 800I A2推理产品。
COMPRESS_TYPE_KVHEAD_ROPE, //!< rope场景压缩key_cache, value_cache的kvHead维度, 只支持Atlas 800I A2推理产品。
COMPRESS_TYPE_MAX //!< 压缩类型边界值,仅用于判断是否出界,所有情况不能取该值。
};
//!
//! 压缩方式
//! 为COMPRESS_TYPE_KVHEAD时,不支持quanttype为2和3。
//! 为COMPRESS_TYPE_KVHEAD_ROPE时, maskType需传0。不支持quanttype为2和3。
CompressType compressType = COMPRESS_TYPE_UNDEFINED;
//!
//! \enum CalcType
//!
//! \brief The type values of CalcType.
//!
enum CalcType : int {
CALC_TYPE_UNDEFINED = 0, //!< 默认值,不开启并行解码
CALC_TYPE_SPEC //!< 并行解码功能,此时只支持quantType = 0
};
//! 计算类型
CalcType calcType = CALC_TYPE_UNDEFINED;
//!
//! \enum ScaleType
//!
//! \brief The type values of ScaleType.
//!
enum ScaleType : int {
SCALE_TYPE_TOR = 0, //!< 默认值,不开启LogN缩放
SCALE_TYPE_LOGN, //!< 注意力使用LogN缩放
SCALE_TYPE_MAX //!< 边界值,仅用于判断是否出界
};
//! scale类型
//! 为SCALE_TYPE_LOGN时,不支持quanttype为2和3。
ScaleType scaleType = SCALE_TYPE_TOR;
//! 数据排布格式默认为BSND
InputLayout inputLayout = TYPE_BSND;
//! \brief 大于0时开启MLA合并kvcache功能,表示kv合并传入时v的head_size
//! \note 默认值为0
//! \warning 取值范围为[0,576]
uint32_t mlaVHeadSize = 0;
//!
//! \brief 预留参数
//!
uint8_t rsv[68] = {0};
};
//!
//! \brief 遍历每个key和value,将key和value(num_heads, head_size)按照slotmapping填入key_cache/value_cache指定位置
//!
struct ReshapeAndCacheParam {
//!
//! \enum CompressType
//!
//! \brief 压缩类型
//!
//! \note 默认值为COMPRESS_TYPE_UNDEFINED(0),不开启压缩功能。
//!
//! \warning 仅在Atlas 800I A2推理产品上支持设置为非COMPRESS_TYPE_UNDEFINED(0)的值
//!
enum CompressType : int {
COMPRESS_TYPE_UNDEFINED = 0, //!< 默认值,不压缩
COMPRESS_TYPE_KVHEAD, //!< alibi场景下压缩key_cache, value_cahe的kvHead维度
COMPRESS_TYPE_KVHEAD_ROPE //!< rope场景下压缩key_cache, value_cahe的kvHead维度
};
//!
//! \enum KvCacheCfg
//!
//! \brief KvCache配置
//!
//! \note 默认值为K_CACHE_V_CACHE(0),传入key_cache和value_cache
//!
//! \warning 仅在Atlas 800I A2推理产品上支持设置为K_CACHE_V_BYPASS(1)
//!
enum KvCacheCfg : int {
K_CACHE_V_CACHE = 0, //!< 默认值,传入key_cache和value_cache
K_CACHE_V_BYPASS, //!< 只传入key_cache
K_CACHE_V_CACHE_NZ //!< 传入key_cache和value_cache,且为NZ格式
};
//! 压缩方式
CompressType compressType = COMPRESS_TYPE_UNDEFINED;
//! kvcache配置
KvCacheCfg kvCacheCfg = K_CACHE_V_CACHE;
//!
//! \brief 预留参数
//!
uint8_t rsv[16] = {0};
};
//!
//! \brief 旋转位置编码。hiddenSizeQ必须是hiddenSizeK的整数倍且满足hiddenSizeQ = headDim * headNum。
//!
struct RopeParam {
//! \brief rope,旋转系数,对半旋转是2,支持配置2、4或headDim / 2。
int32_t rotaryCoeff = 4;
//! \brief 训练用参数,支持配置0或1
int32_t cosFormat = 0;
//!
//! \brief 预留参数
//!
uint8_t rsv[8] = {0};
};
//!
//! \brief 判断参数是否相同
//!
//! \param left
//! \param right
//! \return bool
//!
inline bool operator==(const RopeParam &left, const RopeParam &right)
{
return left.rotaryCoeff == right.rotaryCoeff && left.cosFormat == right.cosFormat;
}
//!
//! \brief KVCache+KVCache+Muls+FlashAttention.
//!
struct SelfAttentionParam {
//!
//! \enum CalcType
//!
//! \brief 计算类型
//!
enum CalcType : int {
UNDEFINED = 0, //!< decoder&encoder for flashAttention
ENCODER, //!< encoder for flashAttention
DECODER, //!< decoder for flashAttention
PA_ENCODER, //!< encoder for pagedAttention
PREFIX_ENCODER, //!< prefix encoder for flashAttention
};
//!
//! \enum KernelType
//!
//! \brief 算子内核精度类型
//!
enum KernelType : int {
KERNELTYPE_DEFAULT = 0, //!< i:float16, bmm:float16, o:float16
KERNELTYPE_HIGH_PRECISION //!< i:float16, bmm:float, o:float16
};
//!
//! \enum ClampType
//!
//! \brief clamp类型
//!
enum ClampType : int {
CLAMP_TYPE_UNDEFINED = 0, //!< 不做clamp
CLAMP_TYPE_MIN_MAX //!< 做clamp,同时指定最大最小值
};
//!
//! \enum MaskType
//!
//! \brief mask类型
//!
enum MaskType : int {
MASK_TYPE_UNDEFINED = 0, //!< 默认值,全0mask
MASK_TYPE_NORM, //!< 倒三角mask
MASK_TYPE_ALIBI, //!< alibi mask
MASK_TYPE_NORM_COMPRESS, //!< 倒三角压缩mask
MASK_TYPE_ALIBI_COMPRESS, //!< alibi压缩mask
MASK_TYPE_ALIBI_COMPRESS_SQRT, //!< alibi压缩开平方mask
MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN, //!< alibi压缩mask左对齐,只支持Atlas 800I A2推理产品
MASK_TYPE_SLIDING_WINDOW_NORM, //!< sliding window attention mask
MASK_TYPE_SLIDING_WINDOW_COMPRESS //!< sliding window attention压缩mask
};
//!
//! \enum KvCacheCfg
//!
//! \brief KvCache配置,不支持calcType为PA_ENCODER
//!
enum KvCacheCfg : int {
K_CACHE_V_CACHE = 0, //!< 默认值,进行kvcache处理
K_BYPASS_V_BYPASS, //!< 直接传入kvcache
};
//!
//! \enum ScaleType
//!
//! \brief The type values of ScaleType.
//!
enum ScaleType : int {
SCALE_TYPE_TOR = 0, //!< 默认值,不开启LogN缩放
SCALE_TYPE_LOGN, //!< 注意力使用LogN缩放,quantType只能是0
SCALE_TYPE_MAX //!< 边界值,仅用于判断是否出界
};
//! \enum QuantType
//!
//! \brief quant类型
//!
enum QuantType : int {
TYPE_QUANT_UNDEFINED = 0, //!< 默认值,不与量化融合,此时q,k,v为bf16/float16
TYPE_QUANT_UNQUANT = 0, //!< 默认值,不与量化融合,此时q,k,v为bf16/float16
TYPE_DEQUANT_FUSION = 1, //!< 与反量化融合, 预留类型,当前不能够取此值。
TYPE_QUANT_QKV_OFFLINE = 2, //!< 离线INT8量化, 只支持Atlas 800I A2推理产品
TYPE_QUANT_QKV_ONLINE = 3 //!< 在线INT8量化, 只支持Atlas 800I A2推理产品
};
//!
//! \enum CacheType
//!
//! \brief cache内部排布类型, 为CACHE_TYPE_SWA开启SWA KVCache优化,只储存后windowSize个token的KVCache,
//! 控制KVCache的长度不超过windowSize, 以此减少显存占用
//!
enum CacheType : int8_t {
CACHE_TYPE_NORM = 0, //!< 正常cache
CACHE_TYPE_SWA = 1 //!< 固定长度cache
};
//!
//! 量化类型(只支持PA_ENCODER):
//! 当值为TYPE_QUANT_QKV_OFFLINE或TYPE_QUANT_QKV_ONLINE时q,k,v为int8。key,value的headsize等长,范围为(0, 256],
//! 且32对齐。outdatatype需要配置,只能是ACL_FLOAT16或ACL_BF16。inputLayout只支持TYPE_BSND,calcType只能为PA_ENCODER。
QuantType quantType = TYPE_QUANT_UNQUANT;
//! output数据类型:只支持PA_ENCODER,且QuantType不为TYPE_QUANT_UNQUANT(格式为aclDataType)
aclDataType outDataType = ACL_DT_UNDEFINED;
//! query头大小, 需大于0
int32_t headNum = 0;
//! kv头数量, 该值需要用户根据使用的模型实际情况传入
//! kvHeadNum = 0时,keyCache的k_head_num,valueCache的v_head_num与query的num_heads一致,均为num_heads的数值
//! kvHeadNum != 0时,keyCache的k_head_num, valueCache的v_head_num与kvHeadNum值相同
int32_t kvHeadNum = 0;
//! query缩放系数
float qScale = 1;
//! 算子tor值, 在Q*K^T后乘
float qkScale = 1;
//! 是否开启动态batch
bool batchRunStatusEnable = false;
//! 是否开启倒三角优化, 只有mask为倒三角的时候才能开启优化
uint32_t isTriuMask = 0;
//! 计算类型
CalcType calcType = UNDEFINED;
//! 内核精度类型
KernelType kernelType = KERNELTYPE_DEFAULT;
//! clamp类型
ClampType clampType = CLAMP_TYPE_UNDEFINED;
//! clamp功能最小值
float clampMin = 0;
//! clamp功能最大值
float clampMax = 0;
//! mask类型
MaskType maskType = MASK_TYPE_UNDEFINED;
//! kvcache配置
KvCacheCfg kvcacheCfg = K_CACHE_V_CACHE;
//! scale类型
ScaleType scaleType = SCALE_TYPE_TOR;
//! 数据排布格式默认为BSND
InputLayout inputLayout = TYPE_BSND;
//! \brief 大于0时开启MLA合并kvcache功能,表示kv合并传入时v的head_size
//! \note 默认值为0
//! \warning 取值范围为[0,576]
uint32_t mlaVHeadSize = 0;
//! \brief cache内部排布,开启SWA特性并设置为CACHE_TYPE_SWA可以开启SWA cache优化
//! \note 默认值为CACHE_TYPE_NORM
//! \warning 只有开启SWA特性后才可以是CACHE_TYPE_SWA
CacheType cacheType = CACHE_TYPE_NORM;
//! \brief windowSize大于0时开启SWA特性,开启SWA特性后表示sliding window 大小
//! \note 默认值为0
//! \warning windowSize大于0时需要将maskType设置为MASK_TYPE_SLIDING_WINDOW_NORM或MASK_TYPE_SLIDING_WINDOW_COMPRESS
uint32_t windowSize = 0;
//!
//! \brief 预留参数
//!
uint8_t rsv[64] = {0};
};
//!
//! \struct ElewiseParam
//!
//! \brief 常用的逐元素数值计算集合
//!
//! ELEWISE_ADD、ELEWISE_MUL、ELEWISE_REALDIV、ELEWISE_SUB计算类型将会对输入进行广播后再进行指定操作。
//! 输入x、y对应维度的对应值要求相同或至少其中一个为1
//!
struct ElewiseParam {
//!
//! \enum ElewiseType
//!
//! \brief 计算类型
//!
enum ElewiseType : int {
ELEWISE_UNDEFINED = 0, //!< 默认值,未定义
ELEWISE_CAST, //!< 数据类型转换
ELEWISE_MULS, //!< 向量逐元素乘值
ELEWISE_COS, //!< 逐元素计算余弦值
ELEWISE_SIN, //!< 逐元素计算正弦值
ELEWISE_NEG, //!< 逐元素取相反数
ELEWISE_QUANT, //!< 量化, 仅在Atlas 800I A2推理产品上支持
ELEWISE_LOGICAL_NOT, //!< 逐元素逻辑非
ELEWISE_ADD, //!< 逐元素相加
ELEWISE_MUL, //!< 向量与向量逐元素相乘
ELEWISE_REALDIV, //!< 向量与向量逐元素相除
ELEWISE_LOGICAL_AND, //!< 逐元素逻辑与
ELEWISE_LOGICAL_OR, //!< 逐元素逻辑或
ELEWISE_LESS, //!< 逐元素判断是否小于
ELEWISE_GREATER, //!< 逐元素判断是否大于
ELEWISE_SUB, //!< 逐元素相减
ELEWISE_EQUAL, //!< 逐元素判断是否相等
ELEWISE_QUANT_PER_CHANNEL, //!< 每个通道量化
ELEWISE_DEQUANT_PER_CHANNEL, //!< 每个通道反量化
ELEWISE_DYNAMIC_QUANT, //!< 逐行动态量化
ELEWISE_TANH, //!< 逐元素计算双曲正切值
ELEWISE_TYPE_MAX //!< 边界值,仅用于判断是否出界,所有情况不能取该值
};
//! 量化(非每通道)所需参数
struct QuantParam {
//! 量化的步长
float inputScale = 1.0f;
//! 动态量化的是否为非对称量化
bool asymmetric = false; //!< false : symmetric,true : asymmetric
//! 量化的偏移度
int inputOffset = 0;
//!
//! \brief 预留参数
//!
uint8_t rsv[20] = {0};
};
//! 向量乘值所需参数
struct MulsParam {
//! 向量乘的值
float varAttr = 0.0f;
//!
//! \brief 预留参数
//!
uint8_t rsv[12] = {0};
};
//! 计算方式
ElewiseType elewiseType = ELEWISE_UNDEFINED;
//! 量化参数
QuantParam quantParam;
//! 乘值参数
MulsParam mulsParam;
//! 指定数据类型转换输出的数据类型
aclDataType outTensorType = ACL_DT_UNDEFINED;
//!
//! \brief 预留参数
//!
uint8_t rsv[8] = {0};
};
} // namespace infer
} // namespace atb
#endif