DequantRopeQuantKvcache
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | × |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
| Kirin X90 处理器系列产品 | √ |
| Kirin 9030 处理器系列产品 | √ |
功能说明
-
算子功能:对输入张量(x)进行dequant(可选)后,按
sizeSplits(为切分的长度)对尾轴进行切分,划分为q、k、vOut,对q、k进行旋转位置编码,生成qOut和kOut,之后对kOut和vOut进行量化并按照indices更新到kCacheRef和vCacheRef上。 -
计算公式:
dequantX=Dequant(x,weightScaleOptional,activationScaleOptional,biasOptional)dequantX = Dequant(x,weightScaleOptional,activationScaleOptional,biasOptional)
q,k,vOut=SplitTensor(dequantX,dim=−1,‘sizeSplits‘)q,k,vOut = SplitTensor(dequantX,dim=-1,`sizeSplits`)
qOut,kOut=ApplyRotaryPosEmb(q,k,cos,sin)qOut,kOut = ApplyRotaryPosEmb(q,k,cos,sin)
quantK=Quant(kOut,scaleK,offsetKOptional)quantK = Quant(kOut,scaleK,offsetKOptional)
quantV=Quant(vOut,scaleV,offsetVOptional)quantV = Quant(vOut,scaleV,offsetVOptional)
如果cacheModeOptional为contiguous则:
kCacheRef[i][indices[i]]=quantK[i]kCacheRef[i][indices[i]]=quantK[i]
vCacheRef[i][indices[i]]=quantV[i]vCacheRef[i][indices[i]]=quantV[i]
如果cacheModeOptional为page则:
kCacheRefView=kCacheRef.view(−1,kCacheRef[−2],kCacheRef[−1])kCacheRefView=kCacheRef.view(-1,kCacheRef[-2],kCacheRef[-1])
vCacheRefView=vCacheRef.view(−1,vCacheRef[−2],vCacheRef[−1])vCacheRefView=vCacheRef.view(-1,vCacheRef[-2],vCacheRef[-1])
kCacheRefView[indices[i]]=quantK[i]kCacheRefView[indices[i]]=quantK[i]
vCacheRefView[indices[i]]=quantV[i]vCacheRefView[indices[i]]=quantV[i]
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x | 输入 | 公式中的用于切分的输入`x`,Device侧的aclTensor。 | FLOAT16、INT32、BFLOAT16 | ND |
| cos | 输入 | 公式中的用于位置编码的输入`cos`,Device侧的aclTensor。 | FLOAT16、BFLOAT16 | ND |
| sin | 输入 | 公式中的用于位置编码的输入`sin`,Device侧的aclTensor。 | BFLOAT16、FLOAT16、FLOAT32 | ND |
| kCacheRef | 输入 | 公式中用于缓存k的输入`kCacheRef`,Device侧的aclTensor。 | INT8 | ND |
| vCacheRef | 输入 | 公式中用于缓存v的输入`vCacheRef`,Device侧的aclTensor。 | INT8 | ND |
| indices | 输入 | 公式中表示Kvcache的token位置信息的输入`indices`。 | INT32 | ND |
| scaleK | 输入 | 公式中的输入`scaleK`用于量化`k`的scale因子,Device侧的aclTensor。 | FLOAT | ND |
| scaleV | 输入 | 公式中的输入`scaleV`用于量化`v`的scale因子,Device侧的aclTensor。 | FLOAT | ND |
| offsetKOptional | 可选输入 | 公式中的输入`offsetKOptional`用于量化k的offset因子,Device侧的aclTensor。 | FLOAT | ND |
| offsetVOptional | 可选输入 | 公式中的输入`offsetVoptional`用于量化的offset因子,Device侧的aclTensor。 | FLOAT | ND |
| weightScaleOptional | 可选输入 | 公式中的输入`weightScaleoptional`用于反量化的权重scale因子,Device侧的aclTensor。 | FLOAT | ND |
| activationScaleOptional | 可选输入 | 公式中的输入`activationScaleOptional`用于反量化的激活scale因子,Device侧的aclTensor。 | FLOAT | ND |
| biasOptional | 可选输入 | 公式中的输入用于反量化的偏置`biasOptional`,Device侧的aclTensor。 | FLOAT、FLOAT16(HALF)、INT32、BFLOAT16 | ND |
| sizeSplits | 输入 | 表示输入的qkv进行切分的长度。 | INT64 | - |
| quantModeOptional | 可选输入 | Host侧表达式字符串。表示支持的量化类型,目前仅支持static。 | String | - |
| layoutOptional | 可选输入 | Host侧表达式字符串。表示支持的数据格式,目前仅支持BSND。 | String | - |
| kvOutput | 输入 | Host侧表达式布尔值。表示是否输出kOut和vOut。 | BOOL | - |
| cacheModeOptional | 输入 | Host侧表达式字符串。表示kCacheRef的更新方式,目前仅支持page和contiguous,默认为contiguous。 | String | - |
| qOut | 输出 | 表示经过处理的q,Device侧的aclTensor。 | FLOAT16、BFLOAT16 | ND |
| kOut | 输出 | 表示输入的qkv进行切分的长度。 | FLOAT16、BFLOAT16 | ND |
| vOut | 输出 | 表示经过处理的v,Device侧的aclTensor。 | FLOAT16、BFLOAT16 | ND |
- Kirin X90/Kirin 9030 处理器系列产品: 不支持BFLOAT16。
约束说明
- cacheModeOptional为contiguous时:kCacheRef的第0维大于x的第0维,indices数据值大于等于0且小于等于vCacheRef的第1维([b,s,n,d]格式中的s)减x的第1维;cacheModeOptional为page时:indices 数据值大于等于0,小于kCacheRef的第0维*第1维且不重复。
- x的尾轴小于等于4096,且按64对齐。
- 输入x不为int32时,x、cos、sin与输出qOut、kOut、vOut的数据类型保持一致,此时activationScaleOptional,weightScaleOptional、biasOptional不生效;x为int32时,cos、sin与输出qOut、kOut、vOut的数据类型保持一致,此时weightScaleOptional必选,activationScaleOptional、biasOptional可选(biasOptional不需要与其他输入类型一致)。
调用说明
| 调用方式 | 调用样例 | 说明 |
|---|---|---|
| aclnn调用 | test_aclnn_dequant_rope_quant_kvcache | 通过接口方式调用DequantRopeQuantKvcache算子。 |