MlaPrologV2

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品

功能说明

  • 算子功能:推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程分为四路,首先对输入xx乘以WDQW^{DQ}进行下采样和RmsNorm后分为两路,第一路乘以WUQW^{UQ}WUKW^{UK}经过两次上采样后得到qNq^N;第二路乘以WQRW^{QR}后经过旋转位置编码(ROPE)得到qRq^R;第三路是输入xx乘以WDKVW^{DKV}进行下采样和RmsNorm后传入Cache中得到kCk^C;第四路是输入xx乘以WKRW^{KR}后经过旋转位置编码后传入另一个Cache中得到kRk^R

  • 计算公式

    RmsNorm公式

    RmsNorm(x)=γ⋅xiRMS(x)\text{RmsNorm}(x) = \gamma \cdot \frac{x_i}{\text{RMS}(x)}

    RMS(x)=1N∑i=1Nxi2+ϵ\text{RMS}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon}

    Query计算公式,包括下采样,RmsNorm和两次上采样

    cQ=RmsNorm(x⋅WDQ)c^Q = RmsNorm(x \cdot W^{DQ})

    qC=cQ⋅WUQq^C = c^Q \cdot W^{UQ}

    qN=qC⋅WUKq^N = q^C \cdot W^{UK}

    对Query的进行ROPE旋转位置编码

    qR=ROPE(cQ⋅WQR)q^R = ROPE(c^Q \cdot W^{QR})

    Key计算公式,包括下采样和RmsNorm,将计算结果存入cache

    cKV=RmsNorm(x⋅WDKV)c^{KV} = RmsNorm(x \cdot W^{DKV})

    kC=Cache(cKV)k^C = Cache(c^{KV})

    对Key进行ROPE旋转位置编码,并将结果存入cache

    kR=Cache(ROPE(x⋅WKR))k^R = Cache(ROPE(x \cdot W^{KR}))

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
token_x 输入 公式中计算Query和Key的输入tensor INT8, BF16 ND
weight_dq 输入 公式中计算Query的下采样权重矩阵WDQW^{DQ} INT8, BF16 FRACTAL_NZ
weight_uq_qr 输入 公式中计算Query的上采样权重矩阵WUQW^{UQ}和位置编码权重矩阵WQRW^{QR} INT8, BF16 FRACTAL_NZ
weight_uk 输入 公式中计算Key的上采样权重WUKW^{UK} FLOAT16, BF16 ND
weight_dkv_kr 输入 公式中计算Key的下采样权重矩阵WDKVW^{DKV}和位置编码权重矩阵WKRW^{KR} INT8, BF16 FRACTAL_NZ
rmsnorm_gamma_cq 输入 计算cQc^Q的RmsNorm公式中γ\gamma参数 FLOAT16, BF16 ND
rmsnorm_gamma_ckv 输入 计算cKVc^{KV}的RmsNorm公式中γ\gamma参数 FLOAT16, BF16 ND
rope_sin 输入 旋转位置编码的正弦参数矩阵 FLOAT16, BF16 ND
rope_cos 输入 旋转位置编码的余弦参数矩阵 FLOAT16, BF16 ND
cache_index 输入 存储kvCache和krCache的索引 INT64 ND
kv_cache 输入/ 输出 cache索引的aclTensor,计算结果原地更新(对应kCk^C FLOAT16, BF16, INT8 ND
kr_cache 输入/ 输出 key位置编码的cache,计算结果原地更新(对应kRk^R FLOAT16, BF16, INT8 ND
dequant_scale_x 输入 预留参数,当前版本暂未使用,必须传入空指针 FLOAT ND
dequant_scale_w_dq 输入 预留参数,当前版本暂未使用,必须传入空指针 FLOAT ND
dequant_scale_w_uq_qr 输入 MatmulQcQr矩阵乘后反量化的per-channel参数 FLOAT ND
dequant_scale_w_dkv_kr 输入 预留参数,当前版本暂未使用,必须传入空指针 FLOAT ND
quant_scale_ckv 输入 KVCache输出量化参数 FLOAT ND
quant_scale_ckr 输入 KRCache输出量化参数 FLOAT ND
smooth_scales_cq 输入 RmsNormCq输出动态量化参数 FLOAT ND
rmsnorm_epsilon_cq 输入 计算cQc^Q的RmsNorm公式中ϵ\epsilon参数 DOUBLE -
rmsnorm_epsilon_ckv 输入 计算cKVc^{KV}的RmsNorm公式中ϵ\epsilon参数 DOUBLE -
cache_mode 输入 kvCache模式 CHAR* -
query 输出 公式中Query的输出tensor(对应qNq^N FLOAT16, BF16, INT8 ND
query_rope 输出 公式中Query位置编码的输出tensor(对应qRq^R FLOAT16, BF16, INT8 ND
dequant_scale_q_nope 输出 表示Query的输出tensor的量化参数 FLOAT ND

约束说明

  • shape约束

    • 若token_x的维度采用BS合轴,即(T, He)
      • rope_sin和rope_cos的shape为(T, Dr)
      • cache_index的shape为(T,)
      • dequant_scale_x的shape为(T, 1)
      • query的shape为(T, N, Hckv)
      • query_rope的shape为(T, N, Dr)
      • 全量化场景下,dequantScaleQNopeOutOptional的shape为(T, N, 1),其他场景下为(1)
    • 若token_x的维度不采用BS合轴,即(B, S, He)
      • rope_sin和rope_cos的shape为(B, S, Dr)
      • cache_index的shape为(B, S)
      • dequant_scale_x的shape为(B*S, 1)
      • query的shape为(B, S, N, Hckv)
      • query_rope的shape为(B, S, N, Dr)
      • 全量化场景下,dequantScaleQNopeOutOptional的shape为(B*S, N, 1),其他场景下为(1)
    • B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
      • 如果B、S、T取值为0,则query、query_rope输出空Tensor,kv_cache、kr_cache不做更新。
      • 如果Skv取值为0,则query、query_rope、dequantScaleQNopeOutOptional正常计算,kv_cache、kr_cache不做更新,即输出空Tensor。
  • weight_dq,weight_uq_qr,weight_dkv_kr在不转置的情况下各个维度的表示:(k,n)。

  • aclnnMlaPrologV2WeightNz接口支持场景:

    场景 含义
    非量化 入参:所有入参皆为非量化数据
    出参:所有出参皆为非量化数据
    部分量化 kv_cache非量化 入参:weight_uq_qr传入pertoken量化数据,其余入参皆为非量化数据
    出参:所有出参返回非量化数据
    kv_cache量化 入参:weight_uq_qr传入pertoken量化数据,kv_cache、kr_cache传入perchannel量化数据,其余入参皆为非量化数据
    出参:kv_cache、kr_cache返回perchannel量化数据,其余出参返回非量化数据
    全量化 kv_cache非量化 入参:token_x传入pertoken量化数据,weight_dq、weight_uq_qr、weight_dkv_kr传入perchannel量化数据,其余入参皆为非量化数据
    出参:所有出参皆为非量化数据
    kv_cache量化 入参:token_x传入pertoken量化数据,weight_dq、weight_uq_qr、weight_dkv_kr传入perchannel量化数据,kv_cache传入pertensor量化数据,其余入参皆为非量化数据
    出参:query返回pertoken_head量化数据,kv_cache出参返回pertensor量化数据,其余出参范围非量化数据
  • 在不同量化场景下,参数的dtype和shape组合需要满足如下条件:

    参数名 非量化场景 部分量化场景 全量化场景
    kv_cache非量化 kv_cache量化 kv_cache非量化 kv_cache量化
    dtype shape dtype shape dtype shape dtype shape dtype shape
    token_x BFLOAT16 · (B,S,He)
    · (T, He)
    BFLOAT16 · (B,S,He)
    · (T, He)
    BFLOAT16 · (B,S,He)
    · (T, He)
    INT8 · (B,S,He)
    · (T, He)
    INT8 · (B,S,He)
    · (T, He)
    weight_dq BFLOAT16 (He, Hcq) BFLOAT16 (He, Hcq) BFLOAT16 (He, Hcq) INT8 (He, Hcq) INT8 (He, Hcq)
    weight_uq_qr BFLOAT16 (Hcq, N*(D+Dr)) INT8 (Hcq, N*(D+Dr)) INT8 (Hcq, N*(D+Dr)) INT8 (Hcq, N*(D+Dr)) INT8 (Hcq, N*(D+Dr))
    weight_uk BFLOAT16 (N, D, Hckv) BFLOAT16 (N, D, Hckv) BFLOAT16 (N, D, Hckv) BFLOAT16 (N, D, Hckv) BFLOAT16 (N, D, Hckv)
    weight_dkv_kr BFLOAT16 (He, Hckv+Dr) BFLOAT16 (He, Hckv+Dr) BFLOAT16 (He, Hckv+Dr) INT8 (He, Hckv+Dr) INT8 (He, Hckv+Dr)
    rmsnorm_gamma_cq BFLOAT16 (Hcq) BFLOAT16 (Hcq) BFLOAT16 (Hcq) BFLOAT16 (Hcq) BFLOAT16 (Hcq)
    rmsnorm_gamma_ckv BFLOAT16 (Hckv) BFLOAT16 (Hckv) BFLOAT16 (Hckv) BFLOAT16 (Hckv) BFLOAT16 (Hckv)
    rope_sin BFLOAT16 · (B,S,Dr)
    · (T, Dr )
    BFLOAT16 · (B,S,Dr)
    · (T, Dr )
    BFLOAT16 · (B,S,Dr)
    · (T, Dr )
    BFLOAT16 · (B,S,Dr)
    · (T, Dr )
    BFLOAT16 · (B,S,Dr)
    · (T, Dr )
    rope_cos BFLOAT16 · (B,S,Dr)
    · (T, Dr )
    BFLOAT16 · (B,S,Dr)
    · (T, Dr )
    BFLOAT16 · (B,S,Dr)
    · (T, Dr )
    BFLOAT16 · (B,S,Dr)
    · (T, Dr )
    BFLOAT16 · (B,S,Dr)
    · (T, Dr )
    cache_index INT64 · (B,S)
    · (T)
    INT64 · (B,S)
    · (T)
    INT64 · (B,S)
    · (T)
    INT64 · (B,S)
    · (T)
    INT64 · (B,S)
    · (T)
    kv_cache BFLOAT16 (BlockNum, BlockSize, Nkv, Hckv) BFLOAT16 (BlockNum, BlockSize, Nkv, Hckv) INT8 (BlockNum, BlockSize, Nkv, Hckv) BFLOAT16 (BlockNum, BlockSize, Nkv, Hckv) INT8 (BlockNum, BlockSize, Nkv, Hckv)
    kr_cache BFLOAT16 (BlockNum, BlockSize, Nkv, Dr) BFLOAT16 (BlockNum, BlockSize, Nkv, Dr) INT8 (BlockNum, BlockSize, Nkv, Dr) BFLOAT16 (BlockNum, BlockSize, Nkv, Dr) BFLOAT16 (BlockNum, BlockSize, Nkv, Dr)
    dequant_scale_x 无需赋值 / 无需赋值 / 无需赋值 / FLOAT · (B*S, 1)
    · (T, 1)
    FLOAT · (B*S, 1)
    · (T, 1)
    dequant_scale_w_dq 无需赋值 / 无需赋值 / 无需赋值 / FLOAT (1, Hcq) FLOAT (1, Hcq)
    dequant_scale_w_uq_qr 无需赋值 / FLOAT (1, N*(D+Dr)) FLOAT (1, N*(D+Dr)) FLOAT (1, N*(D+Dr)) FLOAT (1, N*(D+Dr))
    dequant_scale_w_dkv_kr 无需赋值 / 无需赋值 / 无需赋值 / FLOAT (1, Hckv+Dr) FLOAT (1, Hckv+Dr)
    quant_scale_ckv 无需赋值 / 无需赋值 / FLOAT (1, Hckv) 无需赋值 / FLOAT (1, Hckv)
    quant_scale_ckr 无需赋值 / 无需赋值 / FLOAT (1, Dr) 无需赋值 / 无需赋值 /
    smooth_scales_cq 无需赋值 / FLOAT (1, Hcq) FLOAT (1, Hcq) FLOAT (1, Hcq) FLOAT (1, Hcq)
    query BFLOAT16 · (B, S, N, Hckv)
    · (T, N, Hckv)
    BFLOAT16 · (B, S, N, Hckv)
    · (T, N, Hckv)
    BFLOAT16 · (B, S, N, Hckv)
    · (T, N, Hckv)
    BFLOAT16 · (B, S, N, Hckv)
    · (T, N, Hckv)
    INT8 · (B, S, N, Hckv)
    · (T, N, Hckv)
    query_rope BFLOAT16 · (B, S, N, Dr)
    · (T, N, Dr)
    BFLOAT16 · (B, S, N, Dr)
    · (T, N, Dr)
    BFLOAT16 · (B, S, N, Dr)
    · (T, N, Dr)
    BFLOAT16 · (B, S, N, Dr)
    · (T, N, Dr)
    BFLOAT16 · (B, S, N, Dr)
    · (T, N, Dr)
    dequant_scale_q_nope 无需赋值 / 无需赋值 / 无需赋值 / 无需赋值 / FLOAT · (B*S, N, 1)
    · (T, N, 1)

调用说明

调用方式 样例代码 说明
aclnn接口 MlaPrologV2非量化(BSH)接口测试用例代码 通过 aclnnMlaPrologV2WeightNz 接口方式调用算子
MlaPrologV2非量化(TND)接口测试用例代码
MlaPrologV2半量化KV非量化(BSH)接口测试用例代码
MlaPrologV2半量化KV非量化(TND)接口测试用例代码
MlaPrologV2半量化KV量化(BSH)接口测试用例代码
MlaPrologV2半量化KV量化(TND)接口测试用例代码
MlaPrologV2全量化KV非量化(BSH)接口测试用例代码
MlaPrologV2全量化KV非量化(TND)接口测试用例代码
MlaPrologV2全量化KV量化(BSH)接口测试用例代码
MlaPrologV2全量化KV量化(TND)接口测试用例代码