MlaPrologV2
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | × |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
功能说明
-
算子功能:推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程分为四路,首先对输入xx乘以WDQW^{DQ}进行下采样和RmsNorm后分为两路,第一路乘以WUQW^{UQ}和WUKW^{UK}经过两次上采样后得到qNq^N;第二路乘以WQRW^{QR}后经过旋转位置编码(ROPE)得到qRq^R;第三路是输入xx乘以WDKVW^{DKV}进行下采样和RmsNorm后传入Cache中得到kCk^C;第四路是输入xx乘以WKRW^{KR}后经过旋转位置编码后传入另一个Cache中得到kRk^R。
-
计算公式:
RmsNorm公式
RmsNorm(x)=γ⋅xiRMS(x)\text{RmsNorm}(x) = \gamma \cdot \frac{x_i}{\text{RMS}(x)}
RMS(x)=1N∑i=1Nxi2+ϵ\text{RMS}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon}
Query计算公式,包括下采样,RmsNorm和两次上采样
cQ=RmsNorm(x⋅WDQ)c^Q = RmsNorm(x \cdot W^{DQ})
qC=cQ⋅WUQq^C = c^Q \cdot W^{UQ}
qN=qC⋅WUKq^N = q^C \cdot W^{UK}
对Query的进行ROPE旋转位置编码
qR=ROPE(cQ⋅WQR)q^R = ROPE(c^Q \cdot W^{QR})
Key计算公式,包括下采样和RmsNorm,将计算结果存入cache
cKV=RmsNorm(x⋅WDKV)c^{KV} = RmsNorm(x \cdot W^{DKV})
kC=Cache(cKV)k^C = Cache(c^{KV})
对Key进行ROPE旋转位置编码,并将结果存入cache
kR=Cache(ROPE(x⋅WKR))k^R = Cache(ROPE(x \cdot W^{KR}))
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| token_x | 输入 | 公式中计算Query和Key的输入tensor | INT8, BF16 | ND |
| weight_dq | 输入 | 公式中计算Query的下采样权重矩阵WDQW^{DQ} | INT8, BF16 | FRACTAL_NZ |
| weight_uq_qr | 输入 | 公式中计算Query的上采样权重矩阵WUQW^{UQ}和位置编码权重矩阵WQRW^{QR}。 | INT8, BF16 | FRACTAL_NZ |
| weight_uk | 输入 | 公式中计算Key的上采样权重WUKW^{UK} | FLOAT16, BF16 | ND |
| weight_dkv_kr | 输入 | 公式中计算Key的下采样权重矩阵WDKVW^{DKV}和位置编码权重矩阵WKRW^{KR} | INT8, BF16 | FRACTAL_NZ |
| rmsnorm_gamma_cq | 输入 | 计算cQc^Q的RmsNorm公式中γ\gamma参数 | FLOAT16, BF16 | ND |
| rmsnorm_gamma_ckv | 输入 | 计算cKVc^{KV}的RmsNorm公式中γ\gamma参数 | FLOAT16, BF16 | ND |
| rope_sin | 输入 | 旋转位置编码的正弦参数矩阵 | FLOAT16, BF16 | ND |
| rope_cos | 输入 | 旋转位置编码的余弦参数矩阵 | FLOAT16, BF16 | ND |
| cache_index | 输入 | 存储kvCache和krCache的索引 | INT64 | ND |
| kv_cache | 输入/ 输出 | cache索引的aclTensor,计算结果原地更新(对应kCk^C) | FLOAT16, BF16, INT8 | ND |
| kr_cache | 输入/ 输出 | key位置编码的cache,计算结果原地更新(对应kRk^R) | FLOAT16, BF16, INT8 | ND |
| dequant_scale_x | 输入 | 预留参数,当前版本暂未使用,必须传入空指针 | FLOAT | ND |
| dequant_scale_w_dq | 输入 | 预留参数,当前版本暂未使用,必须传入空指针 | FLOAT | ND |
| dequant_scale_w_uq_qr | 输入 | MatmulQcQr矩阵乘后反量化的per-channel参数 | FLOAT | ND |
| dequant_scale_w_dkv_kr | 输入 | 预留参数,当前版本暂未使用,必须传入空指针 | FLOAT | ND |
| quant_scale_ckv | 输入 | KVCache输出量化参数 | FLOAT | ND |
| quant_scale_ckr | 输入 | KRCache输出量化参数 | FLOAT | ND |
| smooth_scales_cq | 输入 | RmsNormCq输出动态量化参数 | FLOAT | ND |
| rmsnorm_epsilon_cq | 输入 | 计算cQc^Q的RmsNorm公式中ϵ\epsilon参数 | DOUBLE | - |
| rmsnorm_epsilon_ckv | 输入 | 计算cKVc^{KV}的RmsNorm公式中ϵ\epsilon参数 | DOUBLE | - |
| cache_mode | 输入 | kvCache模式 | CHAR* | - |
| query | 输出 | 公式中Query的输出tensor(对应qNq^N) | FLOAT16, BF16, INT8 | ND |
| query_rope | 输出 | 公式中Query位置编码的输出tensor(对应qRq^R) | FLOAT16, BF16, INT8 | ND |
| dequant_scale_q_nope | 输出 | 表示Query的输出tensor的量化参数 | FLOAT | ND |
约束说明
-
shape约束
- 若token_x的维度采用BS合轴,即(T, He)
- rope_sin和rope_cos的shape为(T, Dr)
- cache_index的shape为(T,)
- dequant_scale_x的shape为(T, 1)
- query的shape为(T, N, Hckv)
- query_rope的shape为(T, N, Dr)
- 全量化场景下,dequantScaleQNopeOutOptional的shape为(T, N, 1),其他场景下为(1)
- 若token_x的维度不采用BS合轴,即(B, S, He)
- rope_sin和rope_cos的shape为(B, S, Dr)
- cache_index的shape为(B, S)
- dequant_scale_x的shape为(B*S, 1)
- query的shape为(B, S, N, Hckv)
- query_rope的shape为(B, S, N, Dr)
- 全量化场景下,dequantScaleQNopeOutOptional的shape为(B*S, N, 1),其他场景下为(1)
- B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
- 如果B、S、T取值为0,则query、query_rope输出空Tensor,kv_cache、kr_cache不做更新。
- 如果Skv取值为0,则query、query_rope、dequantScaleQNopeOutOptional正常计算,kv_cache、kr_cache不做更新,即输出空Tensor。
- 若token_x的维度采用BS合轴,即(T, He)
-
weight_dq,weight_uq_qr,weight_dkv_kr在不转置的情况下各个维度的表示:(k,n)。
-
aclnnMlaPrologV2WeightNz接口支持场景:
场景 含义 非量化 入参:所有入参皆为非量化数据
出参:所有出参皆为非量化数据部分量化 kv_cache非量化 入参:weight_uq_qr传入pertoken量化数据,其余入参皆为非量化数据
出参:所有出参返回非量化数据kv_cache量化 入参:weight_uq_qr传入pertoken量化数据,kv_cache、kr_cache传入perchannel量化数据,其余入参皆为非量化数据
出参:kv_cache、kr_cache返回perchannel量化数据,其余出参返回非量化数据全量化 kv_cache非量化 入参:token_x传入pertoken量化数据,weight_dq、weight_uq_qr、weight_dkv_kr传入perchannel量化数据,其余入参皆为非量化数据
出参:所有出参皆为非量化数据kv_cache量化 入参:token_x传入pertoken量化数据,weight_dq、weight_uq_qr、weight_dkv_kr传入perchannel量化数据,kv_cache传入pertensor量化数据,其余入参皆为非量化数据
出参:query返回pertoken_head量化数据,kv_cache出参返回pertensor量化数据,其余出参范围非量化数据 -
在不同量化场景下,参数的dtype和shape组合需要满足如下条件:
参数名 非量化场景 部分量化场景 全量化场景 kv_cache非量化 kv_cache量化 kv_cache非量化 kv_cache量化 dtype shape dtype shape dtype shape dtype shape dtype shape token_x BFLOAT16 · (B,S,He)
· (T, He)BFLOAT16 · (B,S,He)
· (T, He)BFLOAT16 · (B,S,He)
· (T, He)INT8 · (B,S,He)
· (T, He)INT8 · (B,S,He)
· (T, He)weight_dq BFLOAT16 (He, Hcq) BFLOAT16 (He, Hcq) BFLOAT16 (He, Hcq) INT8 (He, Hcq) INT8 (He, Hcq) weight_uq_qr BFLOAT16 (Hcq, N*(D+Dr)) INT8 (Hcq, N*(D+Dr)) INT8 (Hcq, N*(D+Dr)) INT8 (Hcq, N*(D+Dr)) INT8 (Hcq, N*(D+Dr)) weight_uk BFLOAT16 (N, D, Hckv) BFLOAT16 (N, D, Hckv) BFLOAT16 (N, D, Hckv) BFLOAT16 (N, D, Hckv) BFLOAT16 (N, D, Hckv) weight_dkv_kr BFLOAT16 (He, Hckv+Dr) BFLOAT16 (He, Hckv+Dr) BFLOAT16 (He, Hckv+Dr) INT8 (He, Hckv+Dr) INT8 (He, Hckv+Dr) rmsnorm_gamma_cq BFLOAT16 (Hcq) BFLOAT16 (Hcq) BFLOAT16 (Hcq) BFLOAT16 (Hcq) BFLOAT16 (Hcq) rmsnorm_gamma_ckv BFLOAT16 (Hckv) BFLOAT16 (Hckv) BFLOAT16 (Hckv) BFLOAT16 (Hckv) BFLOAT16 (Hckv) rope_sin BFLOAT16 · (B,S,Dr)
· (T, Dr )BFLOAT16 · (B,S,Dr)
· (T, Dr )BFLOAT16 · (B,S,Dr)
· (T, Dr )BFLOAT16 · (B,S,Dr)
· (T, Dr )BFLOAT16 · (B,S,Dr)
· (T, Dr )rope_cos BFLOAT16 · (B,S,Dr)
· (T, Dr )BFLOAT16 · (B,S,Dr)
· (T, Dr )BFLOAT16 · (B,S,Dr)
· (T, Dr )BFLOAT16 · (B,S,Dr)
· (T, Dr )BFLOAT16 · (B,S,Dr)
· (T, Dr )cache_index INT64 · (B,S)
· (T)INT64 · (B,S)
· (T)INT64 · (B,S)
· (T)INT64 · (B,S)
· (T)INT64 · (B,S)
· (T)kv_cache BFLOAT16 (BlockNum, BlockSize, Nkv, Hckv) BFLOAT16 (BlockNum, BlockSize, Nkv, Hckv) INT8 (BlockNum, BlockSize, Nkv, Hckv) BFLOAT16 (BlockNum, BlockSize, Nkv, Hckv) INT8 (BlockNum, BlockSize, Nkv, Hckv) kr_cache BFLOAT16 (BlockNum, BlockSize, Nkv, Dr) BFLOAT16 (BlockNum, BlockSize, Nkv, Dr) INT8 (BlockNum, BlockSize, Nkv, Dr) BFLOAT16 (BlockNum, BlockSize, Nkv, Dr) BFLOAT16 (BlockNum, BlockSize, Nkv, Dr) dequant_scale_x 无需赋值 / 无需赋值 / 无需赋值 / FLOAT · (B*S, 1)
· (T, 1)FLOAT · (B*S, 1)
· (T, 1)dequant_scale_w_dq 无需赋值 / 无需赋值 / 无需赋值 / FLOAT (1, Hcq) FLOAT (1, Hcq) dequant_scale_w_uq_qr 无需赋值 / FLOAT (1, N*(D+Dr)) FLOAT (1, N*(D+Dr)) FLOAT (1, N*(D+Dr)) FLOAT (1, N*(D+Dr)) dequant_scale_w_dkv_kr 无需赋值 / 无需赋值 / 无需赋值 / FLOAT (1, Hckv+Dr) FLOAT (1, Hckv+Dr) quant_scale_ckv 无需赋值 / 无需赋值 / FLOAT (1, Hckv) 无需赋值 / FLOAT (1, Hckv) quant_scale_ckr 无需赋值 / 无需赋值 / FLOAT (1, Dr) 无需赋值 / 无需赋值 / smooth_scales_cq 无需赋值 / FLOAT (1, Hcq) FLOAT (1, Hcq) FLOAT (1, Hcq) FLOAT (1, Hcq) query BFLOAT16 · (B, S, N, Hckv)
· (T, N, Hckv)BFLOAT16 · (B, S, N, Hckv)
· (T, N, Hckv)BFLOAT16 · (B, S, N, Hckv)
· (T, N, Hckv)BFLOAT16 · (B, S, N, Hckv)
· (T, N, Hckv)INT8 · (B, S, N, Hckv)
· (T, N, Hckv)query_rope BFLOAT16 · (B, S, N, Dr)
· (T, N, Dr)BFLOAT16 · (B, S, N, Dr)
· (T, N, Dr)BFLOAT16 · (B, S, N, Dr)
· (T, N, Dr)BFLOAT16 · (B, S, N, Dr)
· (T, N, Dr)BFLOAT16 · (B, S, N, Dr)
· (T, N, Dr)dequant_scale_q_nope 无需赋值 / 无需赋值 / 无需赋值 / 无需赋值 / FLOAT · (B*S, N, 1)
· (T, N, 1)