MlaPrologV3

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品

功能说明

  • 功能更新:(相对与aclnnMlaPrologV2weightNz的差异)

    • 新增query与key的尺度矫正因子,分别对应qcQrScale(αq\alpha_q)与kcScale(αkv\alpha_{kv})。
    • 新增可选输入与参数,将cache_mode由必选改为可选。具体包括:
      • actualSeqLenOptional:用于BS合轴且CacheMode="PA_BLK_BSND"/"PA_BLK_NZ"时,指定当前batch中实际的序列长度。
      • kNopeClipAlphaOptional:表示对kv_cache做clip操作时的缩放因子。
      • queryNormFlag:表示是否输出query_norm,以及量化场景下的dequant_scale_q_norm。
      • weightQuantMode:表示weight_dq、weight_uq_qr、weight_uk、weight_dkv_kr的量化模式。
      • kvCacheQuantMode:表示kv_cache的量化模式。
      • queryQuantMode:表示query的量化模式。
      • ckvkrRepoMode:表示kv_cache和kr_cache的存储模式。
      • quantScaleRepoMode:表示量化scale的存储模式。
      • tileSize:表示per-tile量化时每个tile的大小。
      • queryNormOptional:公式中tokenX做rmsNorm后的输出tensor(对应cQc^Q)。
      • dequantScaleQNormOptional:query_norm的输出tensor的量化参数。
    • 调整cacheIndex参数的名称与位置,对应当前的cacheIndexOptional。
  • 算子功能:推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程分为四路,首先对输入xx乘以WDQW^{DQ}进行下采样和RmsNorm后分为两路,第一路乘以WUQW^{UQ}WUKW^{UK}经过两次上采样后得到qNq^N;第二路乘以WQRW^{QR}后经过旋转位置编码(ROPE)得到qRq^R;第三路是输入xx乘以WDKVW^{DKV}进行下采样和RmsNorm后传入Cache中得到kCk^C;第四路是输入xx乘以WKRW^{KR}后经过旋转位置编码后传入另一个Cache中得到kRk^R

  • 计算公式

    RmsNorm公式

    RmsNorm(x)=γ⋅xiRMS(x)\mathrm{RmsNorm}(x) = \gamma \cdot \frac{x_i}{\mathrm{RMS}(x)}

    RMS(x)=1N∑i=1Nxi2+ϵ\mathrm{RMS}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon}

    Query计算公式,包括下采样,RmsNorm和两次上采样

    cQ=αq⋅RmsNorm(x⋅WDQ)c^Q = \alpha_q\cdot\mathrm{RmsNorm}(x \cdot W^{DQ})

    qC=cQ⋅WUQq^C = c^Q \cdot W^{UQ}

    qN=qC⋅WUKq^N = q^C \cdot W^{UK}

    其中 αq\alpha_q 是 Query 的尺度矫正参数。

    对Query进行ROPE旋转位置编码

    qR=ROPE(cQ⋅WQR)q^R = \mathrm{ROPE}(c^Q \cdot W^{QR})

    Key计算公式,包括下采样和RmsNorm,将计算结果存入cache

    cKV=αkv⋅RmsNorm(x⋅WDKV)c^{KV} = \alpha_{kv}\cdot\mathrm{RmsNorm}(x \cdot W^{DKV})

    kC=Cache(cKV)k^C = \mathrm{Cache}(c^{KV})

    其中 αkv\alpha_{kv} 是 Key 的尺度矫正参数。

    对Key进行ROPE旋转位置编码,并将结果存入cache

    kR=Cache(ROPE(x⋅WKR))k^R = \mathrm{Cache}(\mathrm{ROPE}(x \cdot W^{KR}))

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
token_x 输入 公式中计算Query和Key的输入tensor。 BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 ND
weight_dq 输入 公式中计算Query的下采样权重矩阵WDQW^{DQ}
不转置的情况下各个维度的表示:(k,n)。
BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 FRACTAL_NZ
weight_uq_qr 输入 公式中计算Query的上采样权重矩阵WUQW^{UQ}和位置编码权重矩阵WQRW^{QR}
不转置的情况下各个维度的表示:(k, n)。
BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 FRACTAL_NZ
weight_uk 输入 公式中计算Key的上采样权重WUKW^{UK} BFLOAT16 ND
weight_dkv_kr 输入 公式中计算Key的下采样权重矩阵WDKVW^{DKV}和位置编码权重矩阵WKRW^{KR}
不转置的情况下各个维度的表示:(k, n)。
BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 FRACTAL_NZ
rmsnorm_gamma_cq 输入 计算cQc^Q的RmsNorm公式中γ\gamma参数。 BFLOAT16 ND
rmsnorm_gamma_ckv 输入 计算cKVc^{KV}的RmsNorm公式中γ\gamma参数。 BFLOAT16 ND
rope_sin 输入 旋转位置编码的正弦参数矩阵。 BFLOAT16 ND
rope_cos 输入 旋转位置编码的余弦参数矩阵。 BFLOAT16 ND
kv_cache 输入/ 输出 cache索引的aclTensor,计算结果原地更新(对应kCk^C)。 BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 ND
kr_cache 输入/ 输出 key位置编码的cache,计算结果原地更新(对应kRk^R)。 BFLOAT16、INT8 ND
cache_index 输入 存储kvCache和krCache的索引。 INT64 ND
dequant_scale_x 输入 token_x的反量化参数。 FLOAT8_E8M0、FLOAT ND
dequant_scale_w_dq 输入 weight_dq的反量化参数。 FLOAT8_E8M0、FLOAT ND
dequant_scale_w_uq_qr 输入 MatmulQcQr矩阵乘后反量化的per-channel参数。 FLOAT8_E8M0、FLOAT ND
dequant_scale_w_dkv_kr 输入 weight_dkv_kr的反量化参数。 FLOAT8_E8M0、FLOAT ND
quant_scale_ckv 输入 KVCache输出量化参数。 FLOAT ND
quant_scale_ckr 输入 KRCache输出量化参数。 FLOAT ND
smooth_scales_cq 输入 RmsNormCq输出动态量化参数。 FLOAT ND
actual_seq_len 输入 预留参数,当前版本暂未使用,必须传入空指针。 INT32 ND
k_nope_clip_alpha 输入 对kv_cache做clip操作时的缩放因子。 FLOAT ND
rmsnorm_epsilon_cq 输入 计算cQc^Q的RmsNorm公式中ϵ\epsilon参数。 DOUBLE -
rmsnorm_epsilon_ckv 输入 计算cKVc^{KV}的RmsNorm公式中ϵ\epsilon参数。 DOUBLE -
cache_mode 输入 kvCache模式。 CHAR* -
query_norm_flag 输入 表示是否输出query_norm,Host侧参数。 BOOL -
weight_quant_mode 输入 表示weight_dq、weight_uq_qr、weight_uk、weight_dkv_kr的量化模式。 INT64 -
kv_cache_quant_mode 输入 表示kv_cache的量化模式。 INT64 -
query_quant_mode 输入 表示query的量化模式。 INT64 -
ckvkr_repo_mode 输入 表示kv_cache和kr_cache的存储模式。 INT64 -
quant_scale_repo_mode 输入 表示量化scale的存储模式。 INT64 -
tile_size 输入 表示per-tile量化时每个tile的大小,需要传入128。 INT64 -
qc_qr_scale 输入 Query的尺度矫正参数,对应αq\alpha_q,默认传1.0。 DOUBLE -
kc_scale 输入 Key的尺度矫正参数,对应αkv\alpha_{kv},默认传1.0。 DOUBLE -
query 输出 公式中Query的输出tensor(对应qNq^N)。 BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 ND
query_rope 输出 公式中Query位置编码的输出tensor(对应qRq^R)。 BFLOAT16 ND
dequant_scale_q_nope 输出 表示Query的输出tensor的量化参数。 FLOAT ND
query_norm 输出 公式中tokenX做rmsNorm后的输出tensor(对应cQc^Q)。 BFLOAT16、FLOAT8_E4M3FN、INT8、HIFLOAT8 ND
dequant_scale_q_norm 输出 query_norm的输出tensor的量化参数。 FLOAT、FLOAT8_E8M0 ND
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:
    • token_x、weight_dq、weight_uq_qr、weight_dkv_kr、kv_cache、query、query_norm不支持FLOAT8_E4M3FN、HIFLOAT8数据类型。
    • dequant_scale_x、dequant_scale_w_dq、dequant_scale_w_uq_qr、dequant_scale_w_dkv_kr、dequant_scale_q_norm不支持FLOAT8_E8M0数据类型。

约束说明

  • shape约束

    • 若tokenX的维度采用BS合轴,即(T, He):
      • ropeSin和ropeCos的shape为(T, Dr)。
      • 当CacheMode为PA_BSND或PA_NZ时,cacheIndex的shape为(T)。
      • 当CacheMode为PA_BLK_BSND或PA_BLK_NZ时,cacheIndex的shape为(Sum(Ceil(S_i/BlockSize))),S_i为每个Batch中的S的长度。
      • 当CacheMode为PA_BLK_BSND或PA_BLK_NZ时,actualSeqLenOptional需要传入,维度为(B)。
      • int8/fp8/hif8全量化场景下,dequantScaleXOptional的shape为(T, 1);mxfp8全量化场景下,dequantScaleXOptional的shape为(T, He/32)。
      • queryOut的shape为(T, N, Hckv)。
      • queryRopeOut的shape为(T, N, Dr)。
      • int8/mxfp8/fp8/hif8全量化场景下,dequantScaleQNopeOutOptional的shape为(T, N, 1),其他场景下为nullptr。
    • 若tokenX的维度不采用BS合轴,即(B, S, He):
      • ropeSin和ropeCos的shape为(B, S, Dr)。
      • 当CacheMode为PA_BSND或PA_NZ时,cacheIndex的shape为(B, S)。
      • 当CacheMode为PA_BLK_BSND或PA_BLK_NZ时,cacheIndex的shape为(B,Ceil(S/BlockSize))。
      • int8/fp8/hif8全量化场景下,dequantScaleXOptional的shape为(B*S, 1);mxfp8全量化场景下,dequantScaleXOptional的shape为(B*S, He/32)。
      • queryOut的shape为(B, S, N, Hckv)。
      • queryRopeOut的shape为(B, S, N, Dr)。
      • int8/mxfp8/fp8/hif8全量化场景下,dequantScaleQNopeOutOptional的shape为(B*S, N, 1),其他场景下为nullptr。
    • B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
      • 如果B、S、T取值为0,则queryOut、queryRopeOut输出空Tensor,kvCacheRef、krCacheRef不做更新。
      • 如果Skv取值为0,则queryOut、queryRopeOut、dequantScaleQNopeOutOptional正常计算,kvCacheRef、krCacheRef不做更新,即输出空Tensor。
    • 当CacheMode为BSND时:
      • tokenX应不采用BS合轴,即维度为(B, S, He)。
      • kvCache的维度为(B,S,Nkv,Dr)。
    • 当CacheMode为TND时:
      • tokenX应采用BS合轴,即维度为(T, He)。
      • kvCache的维度为(T,Nkv,Dr)。
    • 当ckvkrRepoMode=1时:
      • krCache的维度应包含0,支持shape为(0)。
  • 特殊约束

    • actualSeqLenOptional传入时,actualSeqLenOptional最后一个数需与T保持一致。
    • per-tile量化模式下,ckvkrRepoMode和quantScaleRepoMode必须同时为1;其他量化模式以及非量化场景下,ckvkrRepoMode和quantScaleRepoMode必须同时为0。
    • per-tile量化模式下,CacheMode只支持PA_BSND, BSND和TND。
    • 当ckvkrRepoMode值为1时,krCache必须为空Tensor(即shape的乘积为0)。
    • kvcache per-tensor量化模式下,kvCacheQuantMode和queryQuantMode必须同时为1。
  • aclnnMlaPrologV3WeightNz接口支持场景:

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前不支持 fp8/hif8/mxfp8 全量化场景
    • Ascend 950PR/Ascend 950DT:当前支持所有量化场景
    场景 含义
    非量化 weight_quant_mode=0,kv_cache_quant_mode=0,query_quant_mode=0
    入参:所有入参皆为非量化数据
    出参:所有出参皆为非量化数据
    部分量化 kvCache非量化 weight_quant_mode=1,kv_cache_quant_mode=0,query_quant_mode=0
    入参:weightUqQr传入per-token量化数据,其余入参皆为非量化数据。dequantScaleWUqQr字段必须传入,smoothScalesCq字段可选传入
    出参:所有出参返回非量化数据
    kvCache per-channel量化 weight_quant_mode=1,kv_cache_quant_mode=2,query_quant_mode=0
    入参:weightUqQr传入per-token量化数据,kvCacheRef、krCacheRef传入per-channel量化数据,其余入参皆为非量化数据。dequantScaleWUqQr、quantScaleCkv、quant_scale_ckr字段必须传入,smoothScalesCq字段可选传入
    出参:kvCacheRef、krCacheRef返回per-channel量化数据,其余出参返回非量化数据
    kvCache per-tile量化 weight_quant_mode=1, kv_cache_quant_mode=3, query_quant_mode=0
    入参:weightUqQr传入per-token量化数据,其余入参皆为非量化数据。dequantScaleWUqQr字段必须传入,smoothScalesCq字段可选传入
    出参:kvCacheRef返回per-tile量化数据,其余出参返回非量化数据
    int8/fp8/hif8全量化 kvCache非量化 weight_quant_mode=2/4/5,kv_cache_quant_mode=0,query_quant_mode=0
    入参:tokenX传入per-token量化数据,weightDq、weightUqQr、weightDkvKr传入per-channel量化数据,其余入参皆为非量化数据。dequantScaleX、dequantScaleWDq、dequantScaleWUqQr、dequantScaleWDkvKr字段必须传入,smoothScalesCq字段可选传入
    出参:所有出参返回非量化数据
    kvCache per-tensor量化 weight_quant_mode=2/4/5,kv_cache_quant_mode=1,query_quant_mode=1
    入参:tokenX传入per-token量化数据,weightDq、weightUqQr、weightDkvKr传入per-channel量化数据,kvCacheRef传入per-tensor量化数据,其余入参皆为非量化数据。dequantScaleX、dequantScaleWDq、dequantScaleWUqQr、dequantScaleWDkvKr、quantScaleCkv字段必须传入,smoothScalesCq字段可选传入
    出参:queryOut返回per-token-head量化数据,kvCacheRef出参返回per-tensor量化数据,其余出参返回非量化数据
    kvCache per-tile量化 weight_quant_mode=2/4/5,kv_cache_quant_mode=3,query_quant_mode=0
    入参:tokenX传入per-token量化数据,weightDq、weightUqQr、weightDkvKr传入per-channel量化数据,其余入参皆为非量化数据。dequantScaleX、dequantScaleWDq、dequantScaleWUqQr、dequantScaleWDkvKr字段必须传入,smoothScalesCq字段可选传入
    出参:kvCacheRef出参返回per-tile量化数据,其余出参返回非量化数据
    mxfp8全量化 kvCache非量化 weight_quant_mode=3,kv_cache_quant_mode=0,query_quant_mode=0
    入参:tokenX传入per-token量化数据,weightDq、weightUqQr、weightDkvKr传入per-channel量化数据,其余入参皆为非量化数据。dequantScaleX、dequantScaleWDq、dequantScaleWUqQr、dequantScaleWDkvKr字段必须传入
    出参:所有出参返回非量化数据
    kvCache per-tensor量化 weight_quant_mode=3,kv_cache_quant_mode=1,query_quant_mode=1
    入参:tokenX传入per-token量化数据,weightDq、weightUqQr、weightDkvKr传入per-channel量化数据,kvCacheRef传入per-tensor量化数据,其余入参皆为非量化数据。dequantScaleX、dequantScaleWDq、dequantScaleWUqQr、dequantScaleWDkvKr、quantScaleCkv字段必须传入
    出参:queryOut返回per-token-head量化数据,kvCacheRef出参返回per-tensor量化数据,其余出参返回非量化数据
    kvCache per-tile量化 weight_quant_mode=3,kv_cache_quant_mode=3,query_quant_mode=0
    入参:tokenX传入per-token量化数据,weightDq、weightUqQr、weightDkvKr传入per-channel量化数据,其余入参皆为非量化数据。dequantScaleX、dequantScaleWDq、dequantScaleWUqQr、dequantScaleWDkvKr字段必须传入
    出参:kvCacheRef出参返回per-tile量化数据,其余出参返回非量化数据

调用说明

调用方式 样例代码 说明
aclnn接口 MlaPrologV3接口测试用例代码 通过 aclnnMlaPrologV3WeightNz 接口方式调用算子