DequantRopeQuantKvcache

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×
Kirin X90 处理器系列产品
Kirin 9030 处理器系列产品

功能说明

  • 算子功能:对输入张量(x)进行dequant(可选)后,按sizeSplits(为切分的长度)对尾轴进行切分,划分为q、k、vOut,对q、k进行旋转位置编码,生成qOut和kOut,之后对kOut和vOut进行量化并按照indices更新到kCacheRef和vCacheRef上。

  • 计算公式:

    dequantX=Dequant(x,weightScaleOptional,activationScaleOptional,biasOptional)dequantX = Dequant(x,weightScaleOptional,activationScaleOptional,biasOptional)

    q,k,vOut=SplitTensor(dequantX,dim=−1,‘sizeSplits‘)q,k,vOut = SplitTensor(dequantX,dim=-1,`sizeSplits`)

    qOut,kOut=ApplyRotaryPosEmb(q,k,cos,sin)qOut,kOut = ApplyRotaryPosEmb(q,k,cos,sin)

    quantK=Quant(kOut,scaleK,offsetKOptional)quantK = Quant(kOut,scaleK,offsetKOptional)

    quantV=Quant(vOut,scaleV,offsetVOptional)quantV = Quant(vOut,scaleV,offsetVOptional)

    如果cacheModeOptional为contiguous则:

    kCacheRef[i][indices[i]]=quantK[i]kCacheRef[i][indices[i]]=quantK[i]

    vCacheRef[i][indices[i]]=quantV[i]vCacheRef[i][indices[i]]=quantV[i]

    如果cacheModeOptional为page则:

    kCacheRefView=kCacheRef.view(−1,kCacheRef[−2],kCacheRef[−1])kCacheRefView=kCacheRef.view(-1,kCacheRef[-2],kCacheRef[-1])

    vCacheRefView=vCacheRef.view(−1,vCacheRef[−2],vCacheRef[−1])vCacheRefView=vCacheRef.view(-1,vCacheRef[-2],vCacheRef[-1])

    kCacheRefView[indices[i]]=quantK[i]kCacheRefView[indices[i]]=quantK[i]

    vCacheRefView[indices[i]]=quantV[i]vCacheRefView[indices[i]]=quantV[i]

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x 输入 公式中的用于切分的输入`x`,Device侧的aclTensor。 FLOAT16、INT32、BFLOAT16 ND
cos 输入 公式中的用于位置编码的输入`cos`,Device侧的aclTensor。 FLOAT16、BFLOAT16 ND
sin 输入 公式中的用于位置编码的输入`sin`,Device侧的aclTensor。 BFLOAT16、FLOAT16、FLOAT32 ND
kCacheRef 输入 公式中用于缓存k的输入`kCacheRef`,Device侧的aclTensor。 INT8 ND
vCacheRef 输入 公式中用于缓存v的输入`vCacheRef`,Device侧的aclTensor。 INT8 ND
indices 输入 公式中表示Kvcache的token位置信息的输入`indices`。 INT32 ND
scaleK 输入 公式中的输入`scaleK`用于量化`k`的scale因子,Device侧的aclTensor。 FLOAT ND
scaleV 输入 公式中的输入`scaleV`用于量化`v`的scale因子,Device侧的aclTensor。 FLOAT ND
offsetKOptional 可选输入 公式中的输入`offsetKOptional`用于量化k的offset因子,Device侧的aclTensor。 FLOAT ND
offsetVOptional 可选输入 公式中的输入`offsetVoptional`用于量化的offset因子,Device侧的aclTensor。 FLOAT ND
weightScaleOptional 可选输入 公式中的输入`weightScaleoptional`用于反量化的权重scale因子,Device侧的aclTensor。 FLOAT ND
activationScaleOptional 可选输入 公式中的输入`activationScaleOptional`用于反量化的激活scale因子,Device侧的aclTensor。 FLOAT ND
biasOptional 可选输入 公式中的输入用于反量化的偏置`biasOptional`,Device侧的aclTensor。 FLOAT、FLOAT16(HALF)、INT32、BFLOAT16 ND
sizeSplits 输入 表示输入的qkv进行切分的长度。 INT64 -
quantModeOptional 可选输入 Host侧表达式字符串。表示支持的量化类型,目前仅支持static。 String -
layoutOptional 可选输入 Host侧表达式字符串。表示支持的数据格式,目前仅支持BSND。 String -
kvOutput 输入 Host侧表达式布尔值。表示是否输出kOut和vOut。 BOOL -
cacheModeOptional 输入 Host侧表达式字符串。表示kCacheRef的更新方式,目前仅支持page和contiguous,默认为contiguous。 String -
qOut 输出 表示经过处理的q,Device侧的aclTensor。 FLOAT16、BFLOAT16 ND
kOut 输出 表示输入的qkv进行切分的长度。 FLOAT16、BFLOAT16 ND
vOut 输出 表示经过处理的v,Device侧的aclTensor。 FLOAT16、BFLOAT16 ND
  • Kirin X90/Kirin 9030 处理器系列产品: 不支持BFLOAT16。

约束说明

  • cacheModeOptional为contiguous时:kCacheRef的第0维大于x的第0维,indices数据值大于等于0且小于等于vCacheRef的第1维([b,s,n,d]格式中的s)减x的第1维;cacheModeOptional为page时:indices 数据值大于等于0,小于kCacheRef的第0维*第1维且不重复。
  • x的尾轴小于等于4096,且按64对齐。
  • 输入x不为int32时,x、cos、sin与输出qOut、kOut、vOut的数据类型保持一致,此时activationScaleOptional,weightScaleOptional、biasOptional不生效;x为int32时,cos、sin与输出qOut、kOut、vOut的数据类型保持一致,此时weightScaleOptional必选,activationScaleOptional、biasOptional可选(biasOptional不需要与其他输入类型一致)。

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_dequant_rope_quant_kvcache 通过接口方式调用DequantRopeQuantKvcache算子。