torch_npu.npu_mla_prolog_v2

[!NOTICE]
该接口中kv_cache和kr_cache进行原地计算,未按in-place算子实现接口,推荐使用torch_npu.npu_mla_prolog_v3接口进行替换。

产品支持情况

产品 是否支持
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品

功能说明

  • API功能:推理场景下,Multi-Head Latent Attention(MLA)前处理的计算。主要计算过程分为五路;

    • 首先对输入x乘以WDQ进行下采样和RmsNorm后分成两路,第一路乘以WUQ和WUK经过两次上采样后得到qN
    • 第二路乘以WQR后经过旋转位置编码(ROPE)得到qR
    • 第三路是输入x乘以WDKV进行下采样和RmsNorm后传入Cache中得到kC
    • 第四路是输入x乘以WKR后经过旋转位置编码后传入另一个Cache中得到kR
    • 第五路是输出qN经过DynamicQuant后得到的量化参数。
  • 计算公式:

    • RmsNorm公式

      RmsNorm(x)=γ⋅x1N∑i=1Nxi2+ϵRmsNorm(x) = \gamma \cdot \frac{x} {\sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon}}

    • Query计算公式

      cQ=RmsNorm(x⋅WDQ)c^Q = RmsNorm(x \cdot W^{DQ})

      qC=cQ⋅WUQq^C = c^Q \cdot W^{UQ}

      qN=qC⋅WUKq^N = q^C \cdot W^{UK}

    • Query ROPE旋转位置编码

      qR=ROPE(cQ⋅WQR)q^R = ROPE(c^Q \cdot W^{QR})

    • Key计算公式

      kC=Cache(RmsNorm(x⋅WDKV))k^C = Cache(RmsNorm(x \cdot W^{DKV}))

    • Key ROPE旋转位置编码

      kR=Cache(ROPE(x⋅WKR))k^R = Cache(ROPE(x \cdot W^{KR}))

    • 反量化缩放因子(Dequant Scale Query Nope)计算公式

      dequantScaleQNope=RowMax(abs(qN))127\text{dequantScaleQNope} = \frac{\text{RowMax}(\text{abs}(q^N))}{127}

      qN=round(qNdequantScaleQNope)q^N = \text{round}(\frac{q^N}{\text{dequantScaleQNope}})

函数原型

torch_npu.npu_mla_prolog_v2(token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq, rmsnorm_gamma_ckv, rope_sin, rope_cos, cache_index, kv_cache, kr_cache, *, dequant_scale_x=None, dequant_scale_w_dq=None, dequant_scale_w_uq_qr=None, dequant_scale_w_dkv_kr=None, quant_scale_ckv=None, quant_scale_ckr=None, smooth_scales_cq=None, rmsnorm_epsilon_cq=1e-05, rmsnorm_epsilon_ckv=1e-05, cache_mode="PA_BSND") -> (Tensor, Tensor, Tensor, Tensor, Tensor)

参数说明

  • token_xTensor):必选参数,对应公式中x。shape支持2维和3维,格式为(T, He)和(B, S, He),dtype支持bfloat16int8,数据格式支持ND。
  • weight_dqTensor):必选参数,表示计算Query的下采样权重矩阵,即公式中WDQ。shape支持2维,格式为(He, Hcq),dtype支持bfloat16int8,数据格式支持FRACTAL_NZ(可通过torch_npu.npu_format_cast将ND格式转为FRACTAL_NZ格式)。
  • weight_uq_qrTensor):必选参数,表示计算Query的上采样权重矩阵和Query的位置编码权重矩阵,即公式中WUQ和WQR。shape支持2维,格式为(Hcq, N*(D+Dr)),dtype支持bfloat16int8,数据格式支持FRACTAL_NZ。
  • weight_ukTensor):必选参数,表示计算Key的上采样权重,即公式中WUK。shape支持3维,格式为(N, D, Hckv),dtype支持bfloat16,数据格式支持ND。
  • weight_dkv_krTensor):必选参数,表示计算Key的下采样权重矩阵和Key的位置编码权重矩阵,即公式中WDKV和WKR。shape支持2维,格式为(He, Hckv+Dr),dtype支持bfloat16int8,数据格式支持FRACTAL_NZ。
  • rmsnorm_gamma_cqTensor):必选参数,表示计算cQ的RmsNorm公式中的_γ_参数。shape支持1维,格式为(Hcq,),dtype支持bfloat16,数据格式支持ND。
  • rmsnorm_gamma_ckvTensor):必选参数,表示计算cKV的RmsNorm公式中的_γ_参数。shape支持1维,格式为(Hckv,),dtype支持bfloat16,数据格式支持ND。
  • rope_sinTensor):必选参数,表示用于计算旋转位置编码的正弦参数矩阵。shape支持2维和3维,格式为(T, Dr)和(B, S, Dr),dtype支持bfloat16,数据格式支持ND。
  • rope_cosTensor):必选参数,表示用于计算旋转位置编码的余弦参数矩阵。shape支持2维和3维,格式为(T, Dr)和(B, S, Dr),dtype支持bfloat16,数据格式支持ND。
  • cache_indexTensor):必选参数,表示用于存储kv_cachekr_cache的索引。shape支持1维和2维,格式为(T)和(B, S),dtype支持int64,数据格式支持ND。cache_index的取值范围为[0, BlockNum*BlockSize),当前不会对cache_index传入值的合法性进行校验,需用户自行保证。
  • kv_cacheTensor):必选参数,表示用于cache索引的aclTensor。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16int8,数据格式支持ND。
  • kr_cacheTensor):必选参数,表示用于key位置编码的cache。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16int8,数据格式支持ND。
  • *:必选参数,代表其之前的变量是位置相关的,必须按照顺序输入;之后的变量是可选参数,位置无关,需要使用键值对赋值,不赋值会使用默认值。
  • dequant_scale_xTensor):可选参数,输入token_x为int8类型时下采样后进行反量化操作时的参数,token_x量化方式为pertoken。其shape支持2维,格式为(T, 1)和(BS, 1),dtype支持float,数据格式支持ND。
  • dequant_scale_w_dqTensor):可选参数,输入token_x为int8类型时下采样后进行反量化操作时的参数,token_x量化方式为perchannel。其shape支持2维,格式为(1, Hcq),dtype支持float,数据格式支持ND。
  • dequant_scale_w_uq_qrTensor):可选参数,用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化方式为perchannel。shape支持2维,格式为(1, N*(D+Dr)),dtype支持float,数据格式支持ND。
  • dequant_scale_w_dkv_krTensor):可选参数,用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化方式为perchannel。其shape支持2维,格式为(1, Hckv+Dr),dtype支持float,数据格式支持ND。
  • quant_scale_ckvTensor):可选参数,用于对输出到kv_cache_out中的数据做量化操作时的参数。shape支持2维,格式为(1, Hckv),dtype支持float,数据格式支持ND。
  • quant_scale_ckrTensor):可选参数,用于对输出到kr_cache_out中的数据做量化操作时的参数。shape支持2维,格式为(1, Dr),dtype支持float,数据格式支持ND。
  • smooth_scales_cqTensor):可选参数,用于对RmsNormCq输出做动态量化操作时的参数。shape支持2维,格式为(1, Hcq),dtype支持float,数据格式支持ND。
  • rmsnorm_epsilon_cqfloat):可选参数,表示计算cQ的RmsNorm公式中的ε参数,默认值为1e-05。
  • rmsnorm_epsilon_ckvfloat):可选参数,表示计算cKV的RmsNorm公式中的ε参数,默认值为1e-05。
  • cache_modestr):可选参数,表示kv_cache的模式,支持"PA_BSND"、"PA_NZ",默认值为“PA_BSND”。

返回值说明

  • queryTensor):表示Query的输出Tensor,即公式中qN。shape支持3维和4维,格式为(T, N, Hckv)和(B, S, N, Hckv),dtype支持bfloat16int8,数据格式支持ND。
  • query_ropeTensor):表示Query位置编码的输出Tensor,即公式中qR。shape支持3维和4维,格式为(T, N, Dr)和(B, S, N, Dr),dtype支持bfloat16,数据格式支持ND。
  • kv_cache_outTensor):表示Key输出到kv_cache中的Tensor(本质in-place更新),即公式中kC。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16int8,数据格式支持ND。
  • kr_cache_outTensor):表示Key的位置编码输出到kr_cache中的Tensor(本质in-place更新),即公式中kR。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16int8,数据格式支持ND。
  • dequant_scale_q_nopeTensor):表示Query的输出Tensor的反量化参数。其shape支持1维和3维,全量化kv_cache量化场景下,其shape为(T, N, 1)和(B*S, N, 1);其他场景下,其shape为(1),dtype支持float,数据格式支持ND。

约束说明

  • 该接口支持推理场景下使用。

  • 该接口支持图模式。

  • 接口参数中shape格式字段含义:

    • B:Batch表示输入样本批量大小,取值范围为0~65536。

    • S:Seq-Length表示输入样本序列长度,取值范围为0~16。

    • He:Head-Size表示隐藏层的大小,取值为7168。

    • Hcq:q低秩矩阵维度,取值为1536。

    • N:Head-Num表示多头数,取值范围为1、2、4、8、16、32、64、128。

    • Hckv:kv低秩矩阵维度,取值为512。

    • D:qk不含位置编码维度,取值为128。

    • Dr:qk位置编码维度,取值为64。

    • Nkv:kv的head数,取值为1。

    • BlockNum:PagedAttention场景下的块数,取值为计算B*Skv/BlockSize的值后再向上取整,其中Skv表示kv的序列长度,该值允许取0。

    • BlockSize:PagedAttention场景下的块大小,取值范围为16、128。

    • T:BS合轴后的大小,取值范围:0~1048576。

  • shape约束:

    • token_x的维度采用BS合轴,即(T, He),则rope_sin和rope_cos的shape为(T, Dr),cache_index的shape为(T,),dequant_scale_x的shape为(T, 1),query的shape为(T, N, Hckv),query_rope的shape为(T, N, Dr)。全量化kv_cache量化场景下,dequant_scale_q_nope的shape为(T, N, 1),其他场景下dequant_scale_q_nope的shape为(1)。
    • token_x的维度不采用BS合轴,即(B, S, He),则rope_sin和rope_cos的shape为(B, S, Dr),cache_index的shape为(B, S),dequant_scale_x的shape为(B*S, 1),query的shape为(B, S, N, Hckv),query_rope的shape为(B, S, N, Dr)。全量化kv_cache量化场景下,dequant_scale_q_nope的shape为(B*S, N, 1),其他场景下dequant_scale_q_nope的shape为(1)。
    • B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
      • 如果B、S、T取值为0,则query、query_rope、dequant_scale_q_nope输出空Tensor,kv_cache、kr_cache、kv_cache_out、kr_cache_out不更新。
      • 如果Skv取值为0,则query、query_rope、dequant_scale_q_nope正常计算,kv_cache、kr_cache、kv_cache_out、kr_cache_out不更新,即输出空Tensor。
  • 本算子支持以下场景:

    场景

    含义

    非量化

    算子所有入参全传入非量化数据,出参全返回非量化数据。

    部分量化

    kv_cache非量化

    入参:weight_uq_qr传入pertoken量化数据,其他入参传入非量化数据。

    出参:全返回非量化数据。

    kv_cache量化

    入参:weight_uq_qr传入pertoken量化数据,kv_cache、kr_cache传入perchannel量化数据,其他入参全传入非量化数据。

    出参:kv_cache_out、kr_cache_out返回perchannel量化数据,其他出参返回非量化数据。

    全量化

    kv_cache非量化

    入参:token_x传入pertoken量化数据,weight_dq、weight_uq_qr、weight_dkv_kr传入perchannel量化数据,其他入参传入非量化数据。

    出参:全返回非量化数据。

    kv_cache量化

    入参:token_x传入pertoken量化数据,weight_dq、weight_uq_qr、weight_dkv_kr传入perchannel量化数据,kv_cache传入pertensor量化数据,其他入参传入非量化数据。

    出参:query返回pertoken_head动态量化数据、kv_cache_out返回pertensor量化数据,其他出参返回非量化数据。

  • 在不同量化场景下,参数的dtype和shape组合需满足如下条件:

    参数名

    非量化场景

    部分量化场景

    全量化场景

    kv_cache非量化

    kv_cache量化

    kv_cache非量化

    kv_cache量化

    dtype

    shape

    dtype

    shape

    dtype

    shape

    dtype

    shape

    dtype

    shape

    token_x

    bfloat16

    • (B,S,He)
    • (T,He)

    bfloat16

    • (B,S,He)
    • (T,He)

    bfloat16

    • (B,S,He)
    • (T,He)

    int8

    • (B,S,He)
    • (T,He)

    int8

    • (B,S,He)
    • (T,He)

    weight_dq

    bfloat16

    (He,Hcq)

    bfloat16

    (He,Hcq)

    bfloat16

    He,Hcq

    int8

    (He,Hcq)

    int8

    (He,Hcq)

    weight_uq_qr

    bfloat16

    (Hcq,N*(D+Dr))

    int8

    (Hcq,N*(D+Dr))

    int8

    (Hcq,N*(D+Dr))

    int8

    (Hcq,N*(D+Dr))

    int8

    (Hcq,N*(D+Dr))

    weight_uk

    bfloat16

    (N,D,Hckv)

    bfloat16

    (N,D,Hckv)

    bfloat16

    (N,D,Hckv)

    bfloat16

    (N,D,Hckv)

    bfloat16

    (N,D,Hckv)

    weight_dkv_kr

    bfloat16

    (He,Hckv+Dr)

    bfloat16

    (He,Hckv+Dr)

    bfloat16

    (He,Hckv+Dr)

    int8

    (He,Hckv+Dr)

    int8

    (He,Hckv+Dr)

    rmsnorm_gamma_cq

    bfloat16

    (Hcq)

    bfloat16

    (Hcq)

    bfloat16

    (Hcq)

    bfloat16

    (Hcq)

    bfloat16

    (Hcq)

    rmsnorm_gamma_ckv

    bfloat16

    (Hckv)

    bfloat16

    (Hckv)

    bfloat16

    (Hckv)

    bfloat16

    (Hckv)

    bfloat16

    (Hckv)

    rope_sin

    bfloat16

    • (B,S,Dr)
    • (T,Dr)

    bfloat16

    • (B,S,Dr)
    • (T,Dr)

    bfloat16

    • (B,S,Dr)
    • (T,Dr)

    bfloat16

    • (B,S,Dr)
    • (T,Dr)

    bfloat16

    • (B,S,Dr)
    • (T,Dr)

    rope_cos

    bfloat16

    • (B,S,Dr)
    • (T,Dr)

    bfloat16

    • (B,S,Dr)
    • (T,Dr)

    bfloat16

    • (B,S,Dr)
    • (T,Dr)

    bfloat16

    • (B,S,Dr)
    • (T,Dr)

    bfloat16

    • (B,S,Dr)
    • (T,Dr)

    cache_index

    int64

    • (B,S)
    • (T)

    int64

    • (B,S)
    • (T)

    int64

    • (B,S)
    • (T)

    int64

    • (B,S)
    • (T)

    int64

    • (B,S)
    • (T)

    kv_cache

    bfloat16

    (BlockNum,BlockSize,Nkv,Hckv)

    bfloat16

    (BlockNum,BlockSize,Nkv,Hckv)

    int8

    (BlockNum,BlockSize,Nkv,Hckv)

    bfloat16

    (BlockNum,BlockSize,Nkv,Hckv)

    int8

    (BlockNum,BlockSize,Nkv,Hckv)

    kr_cache

    bfloat16

    (BlockNum,BlockSize,Nkv,Dr)

    bfloat16

    (BlockNum,BlockSize,Nkv,Dr)

    int8

    (BlockNum,BlockSize,Nkv,Dr)

    bfloat16

    (BlockNum,BlockSize,Nkv,Dr)

    bfloat16

    (BlockNum,BlockSize,Nkv,Dr)

    dequant_scale_x

    无需赋值

    /

    无需赋值

    /

    无需赋值

    /

    float

    • (B*S,1)
    • (T,1)

    float

    • (B*S,1)
    • (T,1)

    dequant_scale_w_dq

    无需赋值

    /

    无需赋值

    /

    无需赋值

    /

    float

    1,Hcq

    float

    1,Hcq

    dequant_scale_w_uq_qr

    无需赋值

    /

    float

    (1,N*(D+Dr))

    float

    (1,N*(D+Dr))

    float

    (1,N*(D+Dr))

    float

    (1,N*(D+Dr))

    dequant_scale_w_dkv_kr

    无需赋值

    /

    无需赋值

    /

    无需赋值

    /

    float

    (1,Hckv+Dr)

    float

    (1,Hckv+Dr)

    quant_scale_ckv

    无需赋值

    /

    无需赋值

    /

    float

    (1,Hckv)

    无需赋值

    /

    float

    (1,Hckv)

    quant_scale_ckr

    无需赋值

    /

    无需赋值

    /

    float

    (1,Dr)

    无需赋值

    /

    无需赋值

    /

    smooth_scales_cq

    无需赋值

    /

    float

    (1,Hcq)

    float

    (1,Hcq)

    float

    (1,Hcq)

    float

    (1,Hcq)

    query

    bfloat16

    • (B,S,N,Hckv)
    • (T,N,Hckv)

    bfloat16

    • (B,S,N,Hckv)
    • (T,N,Hckv)

    bfloat16

    • (B,S,N,Hckv)
    • (T,N,Hckv)

    bfloat16

    • (B,S,N,Hckv)
    • (T,N,Hckv)

    int8

    • (B,S,N,Hckv)
    • (T,N,Hckv)

    query_rope

    bfloat16

    • (B,S,N,Dr)
    • (T,N,Dr)

    bfloat16

    • (B,S,N,Dr)
    • (T,N,Dr)

    bfloat16

    • (B,S,N,Dr)
    • (T,N,Dr)

    bfloat16

    • (B,S,N,Dr)
    • (T,N,Dr)

    bfloat16

    • (B,S,N,Dr)
    • (T,N,Dr)

    kv_cache

    bfloat16

    (BlockNum,BlockSize,Nkv,Hckv)

    bfloat16

    (BlockNum,BlockSize,Nkv,Hckv)

    int8

    (BlockNum,BlockSize,Nkv,Hckv)

    bfloat16

    (BlockNum,BlockSize,Nkv,Hckv)

    int8

    (BlockNum,BlockSize,Nkv,Hckv)

    kr_cache

    bfloat16

    (BlockNum,BlockSize,Nkv,Dr)

    bfloat16

    (BlockNum,BlockSize,Nkv,Dr)

    int8

    (BlockNum,BlockSize,Nkv,Dr)

    bfloat16

    (BlockNum,BlockSize,Nkv,Dr)

    bfloat16

    (BlockNum,BlockSize,Nkv,Dr)

    dequant_scale_q_nope

    float

    (1,)

    float

    (1,)

    float

    (1,)

    float

    (1,)

    float

    • (B*S,N,1)
    • (T,N,1)

调用示例

  • 单算子模式调用

    import torch
    import torch_npu
    import math
    torch.npu.config.allow_internal_format = True
    # 生成随机数据, 并发送到npu
    B = 8
    He = 7168
    Hcq = 1536
    Hckv = 512
    N = 32
    D = 128
    Dr = 64
    Skv = 1024
    S = 2
    Nkv = 1
    BlockSize = 128
    BlockNum = 64
    token_x = torch.rand(B, S, He, dtype=torch.bfloat16).npu()
    w_dq = torch.rand(He, Hcq, dtype=torch.bfloat16).npu()
    w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
    w_uq_qr = torch.rand(Hcq, N * (D + Dr), dtype=torch.bfloat16).npu()
    w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
    w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
    w_dkv_kr = torch.rand(He, Hckv + Dr, dtype=torch.bfloat16).npu()
    w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
    rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
    rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
    rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
    rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
    cache_index = torch.rand(B, S).to(torch.int64).npu()
    kv_cache = torch.rand(BlockNum, BlockSize, Nkv, Hckv, dtype=torch.bfloat16).npu()
    kr_cache = torch.rand(BlockNum, BlockSize, Nkv, Dr, dtype=torch.bfloat16).npu()
    rmsnorm_epsilon_cq = 1.0e-5
    rmsnorm_epsilon_ckv = 1.0e-5
    cache_mode = "PA_BSND"
    
    # 调用MlaProlog算子
    query_mla, query_rope_mla, kv_cache_out_mla, kr_cache_out_mla, dequant_scale_q_nope_mla = torch_npu.npu_mla_prolog_v2(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq, rmsnorm_gamma_ckv, rope_sin, rope_cos, cache_index, kv_cache, kr_cache, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode)
    print(query_mla)
    # 执行上述代码的输出out类似如下
    tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
            [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
            [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
            ..
            [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
            [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
            [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
            device='npu:0', dtype=torch.bfloat16)
    
  • 图模式调用

    # 入图方式
    import torch
    import torch_npu
    import math
    import torchair as tng
    
    from torchair.configs.compiler_config import CompilerConfig
    import torch._dynamo
    TORCHDYNAMO_VERBOSE=1
    TORCH_LOGS="+dynamo"
    
    # 支持入图的打印宏
    import logging
    from torchair.core.utils import logger
    logger.setLevel(logging.DEBUG)
    config = CompilerConfig()
    config.debug.graph_dump.type = "pbtxt"
    npu_backend = tng.get_npu_backend(compiler_config=config)
    from torch.library import Library, impl
    torch.npu.config.allow_internal_format = True
    
    # 数据生成
    B = 8
    He = 7168
    Hcq = 1536
    Hckv = 512
    N = 32
    D = 128
    Dr = 64
    Skv = 1024
    S = 2
    Nkv = 1
    BlockSize = 128
    BlockNum = 64
    token_x = torch.rand(B, S, He, dtype=torch.bfloat16).npu()
    w_dq = torch.rand(He, Hcq, dtype=torch.bfloat16).npu()
    w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
    w_uq_qr = torch.rand(Hcq, N * (D + Dr), dtype=torch.bfloat16).npu()
    w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
    w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
    w_dkv_kr = torch.rand(He, Hckv + Dr, dtype=torch.bfloat16).npu()
    w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
    rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
    rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
    rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
    rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
    cache_index = torch.rand(B, S).to(torch.int64).npu()
    kv_cache = torch.rand(BlockNum, BlockSize, Nkv, Hckv, dtype=torch.bfloat16).npu()
    kr_cache = torch.rand(BlockNum, BlockSize, Nkv, Dr, dtype=torch.bfloat16).npu()
    rmsnorm_epsilon_cq = 1.0e-5
    rmsnorm_epsilon_ckv = 1.0e-5
    cache_mode = "PA_BSND"
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self):
            return torch_npu.npu_mla_prolog_v2(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,rmsnorm_gamma_ckv, rope_sin, rope_cos, cache_index, kv_cache, kr_cache, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode)
    def MetaInfershape():
        with torch.no_grad():
            model = Model()
            model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
            graph_output = model()
        query_mla, query_rope_mla, kv_cache_out_mla, kr_cache_out_mla, dequant_scale_q_nope_mla = torch_npu.npu_mla_prolog_v2(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,rmsnorm_gamma_ckv, rope_sin, rope_cos, cache_index, kv_cache, kr_cache, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode)
        print("single op output:", query_mla)
        print("graph output:", graph_output)
        
    if __name__ == "__main__":
        MetaInfershape()
    
    # 执行上述代码的输出类似如下
    single op output with mask: tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
            [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
            [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
            ...,
            [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
            [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
            [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
            device='npu:0', dtype=torch.bfloat16)
    
    graph output with mask: tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
            [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
            [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
            ...,
            [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
            [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
            [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
            device='npu:0', dtype=torch.bfloat16)