torch_npu.npu_mla_prolog_v2
[!NOTICE]
该接口中kv_cache和kr_cache进行原地计算,未按in-place算子实现接口,推荐使用torch_npu.npu_mla_prolog_v3接口进行替换。
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
功能说明
-
API功能:推理场景下,Multi-Head Latent Attention(MLA)前处理的计算。主要计算过程分为五路;
- 首先对输入x乘以WDQ进行下采样和RmsNorm后分成两路,第一路乘以WUQ和WUK经过两次上采样后得到qN;
- 第二路乘以WQR后经过旋转位置编码(ROPE)得到qR。
- 第三路是输入x乘以WDKV进行下采样和RmsNorm后传入Cache中得到kC。
- 第四路是输入x乘以WKR后经过旋转位置编码后传入另一个Cache中得到kR。
- 第五路是输出qN经过DynamicQuant后得到的量化参数。
-
计算公式:
-
RmsNorm公式
RmsNorm(x)=γ⋅x1N∑i=1Nxi2+ϵRmsNorm(x) = \gamma \cdot \frac{x} {\sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon}}
-
Query计算公式
cQ=RmsNorm(x⋅WDQ)c^Q = RmsNorm(x \cdot W^{DQ})
qC=cQ⋅WUQq^C = c^Q \cdot W^{UQ}
qN=qC⋅WUKq^N = q^C \cdot W^{UK}
-
Query ROPE旋转位置编码
qR=ROPE(cQ⋅WQR)q^R = ROPE(c^Q \cdot W^{QR})
-
Key计算公式
kC=Cache(RmsNorm(x⋅WDKV))k^C = Cache(RmsNorm(x \cdot W^{DKV}))
-
Key ROPE旋转位置编码
kR=Cache(ROPE(x⋅WKR))k^R = Cache(ROPE(x \cdot W^{KR}))
-
反量化缩放因子(Dequant Scale Query Nope)计算公式
dequantScaleQNope=RowMax(abs(qN))127\text{dequantScaleQNope} = \frac{\text{RowMax}(\text{abs}(q^N))}{127}
qN=round(qNdequantScaleQNope)q^N = \text{round}(\frac{q^N}{\text{dequantScaleQNope}})
-
函数原型
torch_npu.npu_mla_prolog_v2(token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq, rmsnorm_gamma_ckv, rope_sin, rope_cos, cache_index, kv_cache, kr_cache, *, dequant_scale_x=None, dequant_scale_w_dq=None, dequant_scale_w_uq_qr=None, dequant_scale_w_dkv_kr=None, quant_scale_ckv=None, quant_scale_ckr=None, smooth_scales_cq=None, rmsnorm_epsilon_cq=1e-05, rmsnorm_epsilon_ckv=1e-05, cache_mode="PA_BSND") -> (Tensor, Tensor, Tensor, Tensor, Tensor)
参数说明
- token_x(
Tensor):必选参数,对应公式中x。shape支持2维和3维,格式为(T, He)和(B, S, He),dtype支持bfloat16和int8,数据格式支持ND。 - weight_dq(
Tensor):必选参数,表示计算Query的下采样权重矩阵,即公式中WDQ。shape支持2维,格式为(He, Hcq),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ(可通过torch_npu.npu_format_cast将ND格式转为FRACTAL_NZ格式)。 - weight_uq_qr(
Tensor):必选参数,表示计算Query的上采样权重矩阵和Query的位置编码权重矩阵,即公式中WUQ和WQR。shape支持2维,格式为(Hcq, N*(D+Dr)),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ。 - weight_uk(
Tensor):必选参数,表示计算Key的上采样权重,即公式中WUK。shape支持3维,格式为(N, D, Hckv),dtype支持bfloat16,数据格式支持ND。 - weight_dkv_kr(
Tensor):必选参数,表示计算Key的下采样权重矩阵和Key的位置编码权重矩阵,即公式中WDKV和WKR。shape支持2维,格式为(He, Hckv+Dr),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ。 - rmsnorm_gamma_cq(
Tensor):必选参数,表示计算cQ的RmsNorm公式中的_γ_参数。shape支持1维,格式为(Hcq,),dtype支持bfloat16,数据格式支持ND。 - rmsnorm_gamma_ckv(
Tensor):必选参数,表示计算cKV的RmsNorm公式中的_γ_参数。shape支持1维,格式为(Hckv,),dtype支持bfloat16,数据格式支持ND。 - rope_sin(
Tensor):必选参数,表示用于计算旋转位置编码的正弦参数矩阵。shape支持2维和3维,格式为(T, Dr)和(B, S, Dr),dtype支持bfloat16,数据格式支持ND。 - rope_cos(
Tensor):必选参数,表示用于计算旋转位置编码的余弦参数矩阵。shape支持2维和3维,格式为(T, Dr)和(B, S, Dr),dtype支持bfloat16,数据格式支持ND。 - cache_index(
Tensor):必选参数,表示用于存储kv_cache和kr_cache的索引。shape支持1维和2维,格式为(T)和(B, S),dtype支持int64,数据格式支持ND。cache_index的取值范围为[0, BlockNum*BlockSize),当前不会对cache_index传入值的合法性进行校验,需用户自行保证。 - kv_cache(
Tensor):必选参数,表示用于cache索引的aclTensor。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16和int8,数据格式支持ND。 - kr_cache(
Tensor):必选参数,表示用于key位置编码的cache。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16和int8,数据格式支持ND。 - *:必选参数,代表其之前的变量是位置相关的,必须按照顺序输入;之后的变量是可选参数,位置无关,需要使用键值对赋值,不赋值会使用默认值。
- dequant_scale_x(
Tensor):可选参数,输入token_x为int8类型时下采样后进行反量化操作时的参数,token_x量化方式为pertoken。其shape支持2维,格式为(T, 1)和(BS, 1),dtype支持float,数据格式支持ND。 - dequant_scale_w_dq(
Tensor):可选参数,输入token_x为int8类型时下采样后进行反量化操作时的参数,token_x量化方式为perchannel。其shape支持2维,格式为(1, Hcq),dtype支持float,数据格式支持ND。 - dequant_scale_w_uq_qr(
Tensor):可选参数,用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化方式为perchannel。shape支持2维,格式为(1, N*(D+Dr)),dtype支持float,数据格式支持ND。 - dequant_scale_w_dkv_kr(
Tensor):可选参数,用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化方式为perchannel。其shape支持2维,格式为(1, Hckv+Dr),dtype支持float,数据格式支持ND。 - quant_scale_ckv(
Tensor):可选参数,用于对输出到kv_cache_out中的数据做量化操作时的参数。shape支持2维,格式为(1, Hckv),dtype支持float,数据格式支持ND。 - quant_scale_ckr(
Tensor):可选参数,用于对输出到kr_cache_out中的数据做量化操作时的参数。shape支持2维,格式为(1, Dr),dtype支持float,数据格式支持ND。 - smooth_scales_cq(
Tensor):可选参数,用于对RmsNormCq输出做动态量化操作时的参数。shape支持2维,格式为(1, Hcq),dtype支持float,数据格式支持ND。 - rmsnorm_epsilon_cq(
float):可选参数,表示计算cQ的RmsNorm公式中的ε参数,默认值为1e-05。 - rmsnorm_epsilon_ckv(
float):可选参数,表示计算cKV的RmsNorm公式中的ε参数,默认值为1e-05。 - cache_mode(
str):可选参数,表示kv_cache的模式,支持"PA_BSND"、"PA_NZ",默认值为“PA_BSND”。
返回值说明
- query(
Tensor):表示Query的输出Tensor,即公式中qN。shape支持3维和4维,格式为(T, N, Hckv)和(B, S, N, Hckv),dtype支持bfloat16和int8,数据格式支持ND。 - query_rope(
Tensor):表示Query位置编码的输出Tensor,即公式中qR。shape支持3维和4维,格式为(T, N, Dr)和(B, S, N, Dr),dtype支持bfloat16,数据格式支持ND。 - kv_cache_out(
Tensor):表示Key输出到kv_cache中的Tensor(本质in-place更新),即公式中kC。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16和int8,数据格式支持ND。 - kr_cache_out(
Tensor):表示Key的位置编码输出到kr_cache中的Tensor(本质in-place更新),即公式中kR。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16和int8,数据格式支持ND。 - dequant_scale_q_nope(
Tensor):表示Query的输出Tensor的反量化参数。其shape支持1维和3维,全量化kv_cache量化场景下,其shape为(T, N, 1)和(B*S, N, 1);其他场景下,其shape为(1),dtype支持float,数据格式支持ND。
约束说明
-
该接口支持推理场景下使用。
-
该接口支持图模式。
-
接口参数中shape格式字段含义:
-
B:Batch表示输入样本批量大小,取值范围为0~65536。
-
S:Seq-Length表示输入样本序列长度,取值范围为0~16。
-
He:Head-Size表示隐藏层的大小,取值为7168。
-
Hcq:q低秩矩阵维度,取值为1536。
-
N:Head-Num表示多头数,取值范围为1、2、4、8、16、32、64、128。
-
Hckv:kv低秩矩阵维度,取值为512。
-
D:qk不含位置编码维度,取值为128。
-
Dr:qk位置编码维度,取值为64。
-
Nkv:kv的head数,取值为1。
-
BlockNum:PagedAttention场景下的块数,取值为计算B*Skv/BlockSize的值后再向上取整,其中Skv表示kv的序列长度,该值允许取0。
-
BlockSize:PagedAttention场景下的块大小,取值范围为16、128。
-
T:BS合轴后的大小,取值范围:0~1048576。
-
-
shape约束:
- 若
token_x的维度采用BS合轴,即(T, He),则rope_sin和rope_cos的shape为(T, Dr),cache_index的shape为(T,),dequant_scale_x的shape为(T, 1),query的shape为(T, N, Hckv),query_rope的shape为(T, N, Dr)。全量化kv_cache量化场景下,dequant_scale_q_nope的shape为(T, N, 1),其他场景下dequant_scale_q_nope的shape为(1)。 - 若
token_x的维度不采用BS合轴,即(B, S, He),则rope_sin和rope_cos的shape为(B, S, Dr),cache_index的shape为(B, S),dequant_scale_x的shape为(B*S, 1),query的shape为(B, S, N, Hckv),query_rope的shape为(B, S, N, Dr)。全量化kv_cache量化场景下,dequant_scale_q_nope的shape为(B*S, N, 1),其他场景下dequant_scale_q_nope的shape为(1)。 - B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
- 如果B、S、T取值为0,则query、query_rope、dequant_scale_q_nope输出空Tensor,kv_cache、kr_cache、kv_cache_out、kr_cache_out不更新。
- 如果Skv取值为0,则query、query_rope、dequant_scale_q_nope正常计算,kv_cache、kr_cache、kv_cache_out、kr_cache_out不更新,即输出空Tensor。
- 若
-
在不同量化场景下,参数的dtype和shape组合需满足如下条件:
调用示例
-
单算子模式调用
import torch import torch_npu import math torch.npu.config.allow_internal_format = True # 生成随机数据, 并发送到npu B = 8 He = 7168 Hcq = 1536 Hckv = 512 N = 32 D = 128 Dr = 64 Skv = 1024 S = 2 Nkv = 1 BlockSize = 128 BlockNum = 64 token_x = torch.rand(B, S, He, dtype=torch.bfloat16).npu() w_dq = torch.rand(He, Hcq, dtype=torch.bfloat16).npu() w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29) w_uq_qr = torch.rand(Hcq, N * (D + Dr), dtype=torch.bfloat16).npu() w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29) w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu() w_dkv_kr = torch.rand(He, Hckv + Dr, dtype=torch.bfloat16).npu() w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29) rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu() rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu() rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu() rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu() cache_index = torch.rand(B, S).to(torch.int64).npu() kv_cache = torch.rand(BlockNum, BlockSize, Nkv, Hckv, dtype=torch.bfloat16).npu() kr_cache = torch.rand(BlockNum, BlockSize, Nkv, Dr, dtype=torch.bfloat16).npu() rmsnorm_epsilon_cq = 1.0e-5 rmsnorm_epsilon_ckv = 1.0e-5 cache_mode = "PA_BSND" # 调用MlaProlog算子 query_mla, query_rope_mla, kv_cache_out_mla, kr_cache_out_mla, dequant_scale_q_nope_mla = torch_npu.npu_mla_prolog_v2(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq, rmsnorm_gamma_ckv, rope_sin, rope_cos, cache_index, kv_cache, kr_cache, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode) print(query_mla) # 执行上述代码的输出out类似如下 tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140], [ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117], [ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062], .. [ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039], [ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164], [ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]], device='npu:0', dtype=torch.bfloat16) -
图模式调用
# 入图方式 import torch import torch_npu import math import torchair as tng from torchair.configs.compiler_config import CompilerConfig import torch._dynamo TORCHDYNAMO_VERBOSE=1 TORCH_LOGS="+dynamo" # 支持入图的打印宏 import logging from torchair.core.utils import logger logger.setLevel(logging.DEBUG) config = CompilerConfig() config.debug.graph_dump.type = "pbtxt" npu_backend = tng.get_npu_backend(compiler_config=config) from torch.library import Library, impl torch.npu.config.allow_internal_format = True # 数据生成 B = 8 He = 7168 Hcq = 1536 Hckv = 512 N = 32 D = 128 Dr = 64 Skv = 1024 S = 2 Nkv = 1 BlockSize = 128 BlockNum = 64 token_x = torch.rand(B, S, He, dtype=torch.bfloat16).npu() w_dq = torch.rand(He, Hcq, dtype=torch.bfloat16).npu() w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29) w_uq_qr = torch.rand(Hcq, N * (D + Dr), dtype=torch.bfloat16).npu() w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29) w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu() w_dkv_kr = torch.rand(He, Hckv + Dr, dtype=torch.bfloat16).npu() w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29) rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu() rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu() rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu() rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu() cache_index = torch.rand(B, S).to(torch.int64).npu() kv_cache = torch.rand(BlockNum, BlockSize, Nkv, Hckv, dtype=torch.bfloat16).npu() kr_cache = torch.rand(BlockNum, BlockSize, Nkv, Dr, dtype=torch.bfloat16).npu() rmsnorm_epsilon_cq = 1.0e-5 rmsnorm_epsilon_ckv = 1.0e-5 cache_mode = "PA_BSND" class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self): return torch_npu.npu_mla_prolog_v2(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,rmsnorm_gamma_ckv, rope_sin, rope_cos, cache_index, kv_cache, kr_cache, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode) def MetaInfershape(): with torch.no_grad(): model = Model() model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True) graph_output = model() query_mla, query_rope_mla, kv_cache_out_mla, kr_cache_out_mla, dequant_scale_q_nope_mla = torch_npu.npu_mla_prolog_v2(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,rmsnorm_gamma_ckv, rope_sin, rope_cos, cache_index, kv_cache, kr_cache, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode) print("single op output:", query_mla) print("graph output:", graph_output) if __name__ == "__main__": MetaInfershape() # 执行上述代码的输出类似如下 single op output with mask: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140], [ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117], [ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062], ..., [ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039], [ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164], [ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]], device='npu:0', dtype=torch.bfloat16) graph output with mask: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140], [ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117], [ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062], ..., [ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039], [ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164], [ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]], device='npu:0', dtype=torch.bfloat16)