MlaPrologV3
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
功能说明
-
功能更新:(相对与aclnnMlaPrologV2weightNz的差异)
- 新增query与key的尺度矫正因子,分别对应qcQrScale(αq\alpha_q)与kcScale(αkv\alpha_{kv})。
- 新增可选输入与参数,将cache_mode由必选改为可选。具体包括:
- actualSeqLenOptional:用于BS合轴且CacheMode="PA_BLK_BSND"/"PA_BLK_NZ"时,指定当前batch中实际的序列长度。
- kNopeClipAlphaOptional:表示对kv_cache做clip操作时的缩放因子。
- queryNormFlag:表示是否输出query_norm,以及量化场景下的dequant_scale_q_norm。
- weightQuantMode:表示weight_dq、weight_uq_qr、weight_uk、weight_dkv_kr的量化模式。
- kvCacheQuantMode:表示kv_cache的量化模式。
- queryQuantMode:表示query的量化模式。
- ckvkrRepoMode:表示kv_cache和kr_cache的存储模式。
- quantScaleRepoMode:表示量化scale的存储模式。
- tileSize:表示per-tile量化时每个tile的大小。
- queryNormOptional:公式中tokenX做rmsNorm后的输出tensor(对应cQc^Q)。
- dequantScaleQNormOptional:query_norm的输出tensor的量化参数。
- 调整cacheIndex参数的名称与位置,对应当前的cacheIndexOptional。
-
算子功能:推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程分为四路,首先对输入xx乘以WDQW^{DQ}进行下采样和RmsNorm后分为两路,第一路乘以WUQW^{UQ}和WUKW^{UK}经过两次上采样后得到qNq^N;第二路乘以WQRW^{QR}后经过旋转位置编码(ROPE)得到qRq^R;第三路是输入xx乘以WDKVW^{DKV}进行下采样和RmsNorm后传入Cache中得到kCk^C;第四路是输入xx乘以WKRW^{KR}后经过旋转位置编码后传入另一个Cache中得到kRk^R。
-
计算公式:
RmsNorm公式
RmsNorm(x)=γ⋅xiRMS(x)\mathrm{RmsNorm}(x) = \gamma \cdot \frac{x_i}{\mathrm{RMS}(x)}
RMS(x)=1N∑i=1Nxi2+ϵ\mathrm{RMS}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon}
Query计算公式,包括下采样,RmsNorm和两次上采样
cQ=αq⋅RmsNorm(x⋅WDQ)c^Q = \alpha_q\cdot\mathrm{RmsNorm}(x \cdot W^{DQ})
qC=cQ⋅WUQq^C = c^Q \cdot W^{UQ}
qN=qC⋅WUKq^N = q^C \cdot W^{UK}
其中 αq\alpha_q 是 Query 的尺度矫正参数。
对Query进行ROPE旋转位置编码
qR=ROPE(cQ⋅WQR)q^R = \mathrm{ROPE}(c^Q \cdot W^{QR})
Key计算公式,包括下采样和RmsNorm,将计算结果存入cache
cKV=αkv⋅RmsNorm(x⋅WDKV)c^{KV} = \alpha_{kv}\cdot\mathrm{RmsNorm}(x \cdot W^{DKV})
kC=Cache(cKV)k^C = \mathrm{Cache}(c^{KV})
其中 αkv\alpha_{kv} 是 Key 的尺度矫正参数。
对Key进行ROPE旋转位置编码,并将结果存入cache
kR=Cache(ROPE(x⋅WKR))k^R = \mathrm{Cache}(\mathrm{ROPE}(x \cdot W^{KR}))
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| token_x | 输入 | 公式中计算Query和Key的输入tensor。 | BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 | ND |
| weight_dq | 输入 | 公式中计算Query的下采样权重矩阵WDQW^{DQ}。 不转置的情况下各个维度的表示:(k,n)。 |
BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 | FRACTAL_NZ |
| weight_uq_qr | 输入 | 公式中计算Query的上采样权重矩阵WUQW^{UQ}和位置编码权重矩阵WQRW^{QR}。 不转置的情况下各个维度的表示:(k, n)。 |
BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 | FRACTAL_NZ |
| weight_uk | 输入 | 公式中计算Key的上采样权重WUKW^{UK}。 | BFLOAT16 | ND |
| weight_dkv_kr | 输入 | 公式中计算Key的下采样权重矩阵WDKVW^{DKV}和位置编码权重矩阵WKRW^{KR}。 不转置的情况下各个维度的表示:(k, n)。 |
BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 | FRACTAL_NZ |
| rmsnorm_gamma_cq | 输入 | 计算cQc^Q的RmsNorm公式中γ\gamma参数。 | BFLOAT16 | ND |
| rmsnorm_gamma_ckv | 输入 | 计算cKVc^{KV}的RmsNorm公式中γ\gamma参数。 | BFLOAT16 | ND |
| rope_sin | 输入 | 旋转位置编码的正弦参数矩阵。 | BFLOAT16 | ND |
| rope_cos | 输入 | 旋转位置编码的余弦参数矩阵。 | BFLOAT16 | ND |
| kv_cache | 输入/ 输出 | cache索引的aclTensor,计算结果原地更新(对应kCk^C)。 | BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 | ND |
| kr_cache | 输入/ 输出 | key位置编码的cache,计算结果原地更新(对应kRk^R)。 | BFLOAT16、INT8 | ND |
| cache_index | 输入 | 存储kvCache和krCache的索引。 | INT64 | ND |
| dequant_scale_x | 输入 | token_x的反量化参数。 | FLOAT8_E8M0、FLOAT | ND |
| dequant_scale_w_dq | 输入 | weight_dq的反量化参数。 | FLOAT8_E8M0、FLOAT | ND |
| dequant_scale_w_uq_qr | 输入 | MatmulQcQr矩阵乘后反量化的per-channel参数。 | FLOAT8_E8M0、FLOAT | ND |
| dequant_scale_w_dkv_kr | 输入 | weight_dkv_kr的反量化参数。 | FLOAT8_E8M0、FLOAT | ND |
| quant_scale_ckv | 输入 | KVCache输出量化参数。 | FLOAT | ND |
| quant_scale_ckr | 输入 | KRCache输出量化参数。 | FLOAT | ND |
| smooth_scales_cq | 输入 | RmsNormCq输出动态量化参数。 | FLOAT | ND |
| actual_seq_len | 输入 | 预留参数,当前版本暂未使用,必须传入空指针。 | INT32 | ND |
| k_nope_clip_alpha | 输入 | 对kv_cache做clip操作时的缩放因子。 | FLOAT | ND |
| rmsnorm_epsilon_cq | 输入 | 计算cQc^Q的RmsNorm公式中ϵ\epsilon参数。 | DOUBLE | - |
| rmsnorm_epsilon_ckv | 输入 | 计算cKVc^{KV}的RmsNorm公式中ϵ\epsilon参数。 | DOUBLE | - |
| cache_mode | 输入 | kvCache模式。 | CHAR* | - |
| query_norm_flag | 输入 | 表示是否输出query_norm,Host侧参数。 | BOOL | - |
| weight_quant_mode | 输入 | 表示weight_dq、weight_uq_qr、weight_uk、weight_dkv_kr的量化模式。 | INT64 | - |
| kv_cache_quant_mode | 输入 | 表示kv_cache的量化模式。 | INT64 | - |
| query_quant_mode | 输入 | 表示query的量化模式。 | INT64 | - |
| ckvkr_repo_mode | 输入 | 表示kv_cache和kr_cache的存储模式。 | INT64 | - |
| quant_scale_repo_mode | 输入 | 表示量化scale的存储模式。 | INT64 | - |
| tile_size | 输入 | 表示per-tile量化时每个tile的大小,需要传入128。 | INT64 | - |
| qc_qr_scale | 输入 | Query的尺度矫正参数,对应αq\alpha_q,默认传1.0。 | DOUBLE | - |
| kc_scale | 输入 | Key的尺度矫正参数,对应αkv\alpha_{kv},默认传1.0。 | DOUBLE | - |
| query | 输出 | 公式中Query的输出tensor(对应qNq^N)。 | BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 | ND |
| query_rope | 输出 | 公式中Query位置编码的输出tensor(对应qRq^R)。 | BFLOAT16 | ND |
| dequant_scale_q_nope | 输出 | 表示Query的输出tensor的量化参数。 | FLOAT | ND |
| query_norm | 输出 | 公式中tokenX做rmsNorm后的输出tensor(对应cQc^Q)。 | BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 | ND |
| dequant_scale_q_norm | 输出 | query_norm的输出tensor的量化参数。 | FLOAT、FLOAT8_E8M0 | ND |
- Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:
- token_x、weight_dq、weight_uq_qr、weight_dkv_kr、kv_cache、query、query_norm不支持FLOAT8_E4M3FN、HIFLOAT8数据类型。
- dequant_scale_x、dequant_scale_w_dq、dequant_scale_w_uq_qr、dequant_scale_w_dkv_kr、dequant_scale_q_norm不支持FLOAT8_E8M0数据类型。
约束说明
-
shape约束
- 若tokenX的维度采用BS合轴,即(T, He):
- ropeSin和ropeCos的shape为(T, Dr)。
- 当CacheMode为PA_BSND或PA_NZ时,cacheIndex的shape为(T)。
- 当CacheMode为PA_BLK_BSND或PA_BLK_NZ时,cacheIndex的shape为(Sum(Ceil(S_i/BlockSize))),S_i为每个Batch中的S的长度。
- 当CacheMode为PA_BLK_BSND或PA_BLK_NZ时,actualSeqLenOptional需要传入,维度为(B)。
- int8/fp8/hif8全量化场景下,dequantScaleXOptional的shape为(T, 1);mxfp8全量化场景下,dequantScaleXOptional的shape为(T, He/32)。
- queryOut的shape为(T, N, Hckv)。
- queryRopeOut的shape为(T, N, Dr)。
- int8/mxfp8/fp8/hif8全量化场景下,dequantScaleQNopeOutOptional的shape为(T, N, 1),其他场景下为nullptr。
- 若tokenX的维度不采用BS合轴,即(B, S, He):
- ropeSin和ropeCos的shape为(B, S, Dr)。
- 当CacheMode为PA_BSND或PA_NZ时,cacheIndex的shape为(B, S)。
- 当CacheMode为PA_BLK_BSND或PA_BLK_NZ时,cacheIndex的shape为(B,Ceil(S/BlockSize))。
- int8/fp8/hif8全量化场景下,dequantScaleXOptional的shape为(B*S, 1);mxfp8全量化场景下,dequantScaleXOptional的shape为(B*S, He/32)。
- queryOut的shape为(B, S, N, Hckv)。
- queryRopeOut的shape为(B, S, N, Dr)。
- int8/mxfp8/fp8/hif8全量化场景下,dequantScaleQNopeOutOptional的shape为(B*S, N, 1),其他场景下为nullptr。
- B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
- 如果B、S、T取值为0,则queryOut、queryRopeOut输出空Tensor,kvCacheRef、krCacheRef不做更新。
- 如果Skv取值为0,则queryOut、queryRopeOut、dequantScaleQNopeOutOptional正常计算,kvCacheRef、krCacheRef不做更新,即输出空Tensor。
- 当CacheMode为BSND时:
- tokenX应不采用BS合轴,即维度为(B, S, He)。
- kvCache的维度为(B,S,Nkv,Dr)。
- 当CacheMode为TND时:
- tokenX应采用BS合轴,即维度为(T, He)。
- kvCache的维度为(T,Nkv,Dr)。
- 当ckvkrRepoMode=1时:
- krCache的维度应包含0,支持shape为(0)。
- 若tokenX的维度采用BS合轴,即(T, He):
-
特殊约束
- actualSeqLenOptional传入时,actualSeqLenOptional最后一个数需与T保持一致。
- per-tile量化模式下,ckvkrRepoMode和quantScaleRepoMode必须同时为1;其他量化模式以及非量化场景下,ckvkrRepoMode和quantScaleRepoMode必须同时为0。
- per-tile量化模式下,CacheMode只支持PA_BSND, BSND和TND。
- 当ckvkrRepoMode值为1时,krCache必须为空Tensor(即shape的乘积为0)。
- kvcache per-tensor量化模式下,kvCacheQuantMode和queryQuantMode必须同时为1。
-
aclnnMlaPrologV3WeightNz接口支持场景:
- Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前不支持 fp8/hif8/mxfp8 全量化场景
- Ascend 950PR/Ascend 950DT:当前支持所有量化场景
场景 含义 非量化 weight_quant_mode=0,kv_cache_quant_mode=0,query_quant_mode=0
入参:所有入参皆为非量化数据
出参:所有出参皆为非量化数据部分量化 kvCache非量化 weight_quant_mode=1,kv_cache_quant_mode=0,query_quant_mode=0
入参:weightUqQr传入per-token量化数据,其余入参皆为非量化数据。dequantScaleWUqQr字段必须传入,smoothScalesCq字段可选传入
出参:所有出参返回非量化数据kvCache per-channel量化 weight_quant_mode=1,kv_cache_quant_mode=2,query_quant_mode=0
入参:weightUqQr传入per-token量化数据,kvCacheRef、krCacheRef传入per-channel量化数据,其余入参皆为非量化数据。dequantScaleWUqQr、quantScaleCkv、quant_scale_ckr字段必须传入,smoothScalesCq字段可选传入
出参:kvCacheRef、krCacheRef返回per-channel量化数据,其余出参返回非量化数据kvCache per-tile量化 weight_quant_mode=1, kv_cache_quant_mode=3, query_quant_mode=0
入参:weightUqQr传入per-token量化数据,其余入参皆为非量化数据。dequantScaleWUqQr字段必须传入,smoothScalesCq字段可选传入
出参:kvCacheRef返回per-tile量化数据,其余出参返回非量化数据int8/fp8/hif8全量化 kvCache非量化 weight_quant_mode=2/4/5,kv_cache_quant_mode=0,query_quant_mode=0
入参:tokenX传入per-token量化数据,weightDq、weightUqQr、weightDkvKr传入per-channel量化数据,其余入参皆为非量化数据。dequantScaleX、dequantScaleWDq、dequantScaleWUqQr、dequantScaleWDkvKr字段必须传入,smoothScalesCq字段可选传入
出参:所有出参返回非量化数据kvCache per-tensor量化 weight_quant_mode=2/4/5,kv_cache_quant_mode=1,query_quant_mode=1
入参:tokenX传入per-token量化数据,weightDq、weightUqQr、weightDkvKr传入per-channel量化数据,kvCacheRef传入per-tensor量化数据,其余入参皆为非量化数据。dequantScaleX、dequantScaleWDq、dequantScaleWUqQr、dequantScaleWDkvKr、quantScaleCkv字段必须传入,smoothScalesCq字段可选传入
出参:queryOut返回per-token-head量化数据,kvCacheRef出参返回per-tensor量化数据,其余出参返回非量化数据kvCache per-tile量化 weight_quant_mode=2/4/5,kv_cache_quant_mode=3,query_quant_mode=0
入参:tokenX传入per-token量化数据,weightDq、weightUqQr、weightDkvKr传入per-channel量化数据,其余入参皆为非量化数据。dequantScaleX、dequantScaleWDq、dequantScaleWUqQr、dequantScaleWDkvKr字段必须传入,smoothScalesCq字段可选传入
出参:kvCacheRef出参返回per-tile量化数据,其余出参返回非量化数据mxfp8全量化 kvCache非量化 weight_quant_mode=3,kv_cache_quant_mode=0,query_quant_mode=0
入参:tokenX传入per-token量化数据,weightDq、weightUqQr、weightDkvKr传入per-channel量化数据,其余入参皆为非量化数据。dequantScaleX、dequantScaleWDq、dequantScaleWUqQr、dequantScaleWDkvKr字段必须传入
出参:所有出参返回非量化数据kvCache per-tensor量化 weight_quant_mode=3,kv_cache_quant_mode=1,query_quant_mode=1
入参:tokenX传入per-token量化数据,weightDq、weightUqQr、weightDkvKr传入per-channel量化数据,kvCacheRef传入per-tensor量化数据,其余入参皆为非量化数据。dequantScaleX、dequantScaleWDq、dequantScaleWUqQr、dequantScaleWDkvKr、quantScaleCkv字段必须传入
出参:queryOut返回per-token-head量化数据,kvCacheRef出参返回per-tensor量化数据,其余出参返回非量化数据kvCache per-tile量化 weight_quant_mode=3,kv_cache_quant_mode=3,query_quant_mode=0
入参:tokenX传入per-token量化数据,weightDq、weightUqQr、weightDkvKr传入per-channel量化数据,其余入参皆为非量化数据。dequantScaleX、dequantScaleWDq、dequantScaleWUqQr、dequantScaleWDkvKr字段必须传入
出参:kvCacheRef出参返回per-tile量化数据,其余出参返回非量化数据
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | MlaPrologV3接口测试用例代码 | 通过 aclnnMlaPrologV3WeightNz 接口方式调用算子 |