aclnnMlaPrologV3WeightNz


Product support

Product | Supported
Ascend 950PR/Ascend 950DT
Atlas A3 Training Series Products/Atlas A3 Inference Series Products
Atlas A2 Training Series Products/Atlas A2 Inference Series Products

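Function description

  • Changes relative to aclnnMlaPrologV2WeightNz:

    • Adds scale correction factors for Query and Key: qcQrScale ($\alpha_q$) and kcScale ($\alpha_{kv}$).
    • Adds optional parameters (for example actualSeqLenOptional, kNopeClipAlphaOptional, queryNormFlag, weightQuantMode, kvCacheQuantMode, queryQuantMode, ckvkrRepoMode, quantScaleRepoMode, tileSize, queryNormOutOptional, and dequantScaleQNormOutOptional) and changes the cache mode from a required parameter to an optional one (cacheModeOptional).
    • Renames and repositions the cacheIndex parameter; it is now cacheIndexOptional.
  • Interface function: computes the Multi-Head Latent Attention (MLA) pre-processing stage in inference scenarios. The computation has five paths:

    • The input $x$ is multiplied by $W^{DQ}$ for down-projection, passed through RmsNorm, and then split into two paths. The first path is multiplied by $W^{UQ}$ and $W^{UK}$ (two up-projections) and then by the Query scale correction factor $\alpha_q$ to give $q^N$. The second path is multiplied by $W^{QR}$ and passed through rotary position embedding (ROPE) to give $q^R$.
    • The third path multiplies the input $x$ by $W^{DKV}$ for down-projection, applies RmsNorm, multiplies by the Key scale correction factor $\alpha_{kv}$, and writes the result into the cache to give $k^C$.
    • The fourth path multiplies the input $x$ by $W^{KR}$, applies rotary position embedding, and writes the result into a second cache to give $k^R$.
    • The fifth path is the quantization parameter obtained by applying DynamicQuant to the output $q^N$.
    • The weight parameters weightDq, weightUqQr, and weightDkvKr must be passed in NZ format.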
  • Formulas

    RmsNorm:

    $$\text{RmsNorm}(x) = \gamma \cdot \frac{x_i}{\text{RMS}(x)}$$

    $$\text{RMS}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon}$$

    Query computation, including down-projection, RmsNorm, and two up-projections:

    $$c^Q = \alpha_q\cdot\mathrm{RmsNorm}(x \cdot W^{DQ})$$

    $$q^C = c^Q \cdot W^{UQ}$$

    $$q^N = q^C \cdot W^{UK}$$

    ROPE rotary position embedding for Query:

    $$q^R = \mathrm{ROPE}(c^Q \cdot W^{QR})$$

    Key computation, including down-projection and RmsNorm; the result is stored in the cache:

    $$c^{KV} = \alpha_{kv}\cdot\mathrm{RmsNorm}(x \cdot W^{DKV})$$

    $$k^C = \mathrm{Cache}(c^{KV})$$

    ROPE rotary position embedding for Key; the result is stored in the cache:

    $$k^R = \mathrm{Cache}(\mathrm{ROPE}(x \cdot W^{KR}))$$

    Dequant Scale Query Nope:

    $$\mathrm{dequantScaleQNope} = \mathrm{RowMax}(\mathrm{abs}(q^{N})) / 127$$

    $$q^{N} = \mathrm{round}(q^{N} / \mathrm{dequantScaleQNope})$$
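
The RmsNorm and Dequant Scale Query Nope formulas can be checked on the host with a small reference sketch. The helper names RmsNormRef and DynamicQuantRowRef are hypothetical and not part of the aclnn API, the arithmetic is plain float rather than BFLOAT16, and mapping an all-zero row to zeros is an assumption that the formulas do not specify.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference: RmsNorm over one row,
// y_i = gamma_i * x_i / sqrt(mean(x^2) + eps).
std::vector<float> RmsNormRef(const std::vector<float>& x,
                              const std::vector<float>& gamma, float eps) {
    assert(!x.empty() && x.size() == gamma.size());
    double sumSq = 0.0;
    for (float v : x) {
        sumSq += static_cast<double>(v) * v;
    }
    const double rms = std::sqrt(sumSq / static_cast<double>(x.size()) + eps);
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = static_cast<float>(gamma[i] * x[i] / rms);
    }
    return y;
}

struct QuantRow {
    std::vector<int8_t> q;  // quantized values, round(row / scale)
    float scale;            // dequant scale, RowMax(abs(row)) / 127
};

// Hypothetical host-side reference for the per-row dynamic quantization formulas.
QuantRow DynamicQuantRowRef(const std::vector<float>& row) {
    assert(!row.empty());
    float maxAbs = 0.0f;
    for (float v : row) {
        maxAbs = std::max(maxAbs, std::fabs(v));
    }
    QuantRow out;
    out.scale = maxAbs / 127.0f;
    out.q.resize(row.size());
    for (std::size_t i = 0; i < row.size(); ++i) {
        // Assumption: an all-zero row quantizes to zeros instead of dividing by zero.
        out.q[i] = (out.scale == 0.0f)
                       ? static_cast<int8_t>(0)
                       : static_cast<int8_t>(std::lround(row[i] / out.scale));
    }
    return out;
}
```

A row such as {-127, 64, 0} yields a scale of 1.0 and quantized values -127, 64, 0, which makes it a convenient sanity check.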

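Function prototype

Each operator uses a two-stage interface. First call aclnnMlaPrologV3WeightNzGetWorkspaceSize, which takes the input parameters and computes the required workspace size according to the computation flow. Then call aclnnMlaPrologV3WeightNz to run the computation.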

aclnnStatus aclnnMlaPrologV3WeightNzGetWorkspaceSize(
    const aclTensor *tokenX,
    const aclTensor *weightDq,
    const aclTensor *weightUqQr,
    const aclTensor *weightUk,
    const aclTensor *weightDkvKr,
    const aclTensor *rmsnormGammaCq,
    const aclTensor *rmsnormGammaCkv,
    const aclTensor *ropeSin,
    const aclTensor *ropeCos,
    aclTensor       *kvCacheRef,
    aclTensor       *krCacheRef,
    const aclTensor *cacheIndexOptional,
    const aclTensor *dequantScaleXOptional,
    const aclTensor *dequantScaleWDqOptional,
    const aclTensor *dequantScaleWUqQrOptional,
    const aclTensor *dequantScaleWDkvKrOptional,
    const aclTensor *quantScaleCkvOptional,
    const aclTensor *quantScaleCkrOptional,
    const aclTensor *smoothScalesCqOptional,
    const aclTensor *actualSeqLenOptional,
    const aclTensor *kNopeClipAlphaOptional,
    double           rmsnormEpsilonCq,
    double           rmsnormEpsilonCkv,
    char            *cacheModeOptional,
    int64_t          weightQuantMode,
    int64_t          kvCacheQuantMode,
    int64_t          queryQuantMode,
    int64_t          ckvkrRepoMode,
    int64_t          quantScaleRepoMode,
    int64_t          tileSize,
    double           qcQrScale,
    double           kcScale,
    const aclTensor *queryOut,
    const aclTensor *queryRopeOut,
    const aclTensor *dequantScaleQNopeOutOptional,
    const aclTensor *queryNormOutOptional,
    const aclTensor *dequantScaleQNormOutOptional,
    uint64_t        *workspaceSize,
    aclOpExecutor  **executor)
aclnnStatus aclnnMlaPrologV3WeightNz(
  void             *workspace,
  uint64_t          workspaceSize,
  aclOpExecutor    *executor,
  const aclrtStream stream)

aclnnMlaPrologV3WeightNzGetWorkspaceSize

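  • Parameters

    Parameter | Input/Output | Description | Usage notes | Data type | Data format | Shape | Non-contiguous tensor
    tokenX | Input | Input tensor used to compute Query and Key in the formulas. | Supports empty tensors with B=0, S=0, or T=0 | BFLOAT16, FLOAT8_E4M3FN, INT8, HIFLOAT8 | ND | BS merged: (T, He); BS not merged: (B, S, He) | ×
    weightDq | Input | Down-projection weight matrix $W^{DQ}$ used to compute Query. Dimensions without transposition: (k, n) | Empty tensors are not supported | BFLOAT16, FLOAT8_E4M3FN, INT8, HIFLOAT8 | FRACTAL_NZ | (He, Hcq) | ×
    weightUqQr | Input | Up-projection weight matrix $W^{UQ}$ and position-encoding weight matrix $W^{QR}$ used to compute Query. Dimensions without transposition: (k, n) | Empty tensors are not supported | BFLOAT16, INT8, FLOAT8_E4M3FN, HIFLOAT8 | FRACTAL_NZ | (Hcq, N*(D+Dr)) | ×
    weightUk | Input | Up-projection weight $W^{UK}$ used to compute Key. | Empty tensors are not supported | BFLOAT16 | ND | (N, D, Hckv) | ×
    weightDkvKr | Input | Down-projection weight matrix $W^{DKV}$ and position-encoding weight matrix $W^{KR}$ used to compute Key. Dimensions without transposition: (k, n) | Empty tensors are not supported | BFLOAT16, FLOAT8_E4M3FN, INT8, HIFLOAT8 | FRACTAL_NZ | (He, Hckv+Dr) | ×
    rmsnormGammaCq | Input | The γ parameter in the RmsNorm formula for $c^Q$. | Empty tensors are not supported | BFLOAT16 | ND | (Hcq) | ×
    rmsnormGammaCkv | Input | The γ parameter in the RmsNorm formula for $c^{KV}$. | Empty tensors are not supported | BFLOAT16 | ND | (Hckv) | ×
    ropeSin | Input | Sine parameter matrix used for rotary position embedding. | Supports empty tensors with B=0, S=0, or T=0 | BFLOAT16 | ND | BS merged: (T, Dr); BS not merged: (B, S, Dr) | ×
    ropeCos | Input | Cosine parameter matrix used for rotary position embedding. | Supports empty tensors with B=0, S=0, or T=0 | BFLOAT16 | ND | BS merged: (T, Dr); BS not merged: (B, S, Dr) | ×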
    kvCacheRef | Input/Output | aclTensor holding the kv cache; the result is updated in place (corresponds to $k^C$ in the formulas). | Supports empty tensors with B=0 or Skv=0; Nkv is tied to N, which is a hyperparameter, so Nkv=0 is not supported | BFLOAT16, INT8, FLOAT8_E4M3FN, HIFLOAT8 | ND | cacheMode="PA_BSND"/"PA_NZ"/"PA_BLK_BSND"/"PA_BLK_NZ": (BlockNum, BlockSize, Nkv, Dtile); cacheMode="BSND": (B, S, Nkv, Dtile); cacheMode="TND": (T, Nkv, Dtile) | ×
    krCacheRef | Input/Output | Cache for the Key position encoding; the result is updated in place (corresponds to $k^R$ in the formulas). | Supports empty tensors with B=0 or Skv=0; Nkv is tied to N, which is a hyperparameter, so Nkv=0 is not supported | BFLOAT16, INT8 | ND | cacheMode="PA_BSND"/"PA_NZ"/"PA_BLK_BSND"/"PA_BLK_NZ": (BlockNum, BlockSize, Nkv, Dr); cacheMode="BSND": (B, S, Nkv, Dr); cacheMode="TND": (T, Nkv, Dr); ckvkrRepoMode=1: the shape must contain a 0 dimension, and shape (0) is supported | ×
    cacheIndexOptional | Input | Indices at which kvCache and krCache are stored. | Supports empty tensors with B=0, S=0, or T=0; cacheMode="PA_BSND"/"PA_NZ": values must lie in [0, BlockNum*BlockSize); cacheMode="PA_BLK_BSND"/"PA_BLK_NZ": values must lie in [0, BlockNum); cacheMode="TND"/"BSND": pass nullptr | INT64 | ND | cacheMode="PA_BSND"/"PA_NZ": (T) if BS merged, (B, S) if BS not merged; cacheMode="PA_BLK_BSND"/"PA_BLK_NZ": (Sum(⌈Si/BlockSize⌉)) if BS merged, where Si is the S length of each batch, (B, ⌈S/BlockSize⌉) if BS not merged; cacheMode="TND"/"BSND": nullptr | ×
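    dequantScaleXOptional | Input | Dequantization parameter for tokenX. | Supports empty tensors with B=0, S=0, or T=0 (required when weightQuantMode=2/3/4/5) | FLOAT8_E8M0, FLOAT | ND | weightQuantMode=2/4/5: (T, 1) if BS merged, (B*S, 1) if BS not merged; weightQuantMode=3: (T, He/32) if BS merged, (B*S, He/32) if BS not merged | ×
    dequantScaleWDqOptional | Input | Dequantization parameter for weightDq. | Non-empty tensors only (required when weightQuantMode=2/3/4/5) | FLOAT8_E8M0, FLOAT | ND | weightQuantMode=2/4/5: (1, Hcq); weightQuantMode=3: (Hcq, He/32) | ×
    dequantScaleWUqQrOptional | Input | Per-channel parameter used for dequantization after the MatmulQcQr matrix multiplication. | Non-empty tensors only (required when weightQuantMode=1/2/3/4/5) | FLOAT, FLOAT8_E8M0 | ND | weightQuantMode=1/2/4/5: (1, N*(D+Dr)); weightQuantMode=3: (N*(D+Dr), Hcq/32) | ×
    dequantScaleWDkvKrOptional | Input | Dequantization parameter for weightDkvKr. | Non-empty tensors only (required when weightQuantMode=2/3/4/5) | FLOAT8_E8M0, FLOAT | ND | weightQuantMode=2/4/5: (1, Hckv+Dr); weightQuantMode=3: (Hckv+Dr, He/32) | ×
    quantScaleCkvOptional | Input | Parameter used to quantize the kvCache output data. | Non-empty tensors only (required when kvCacheQuantMode=1/2) | FLOAT | ND | kvCacheQuantMode=1: (1); kvCacheQuantMode=2: (1, Hckv) | ×
    quantScaleCkrOptional | Input | Parameter used to quantize the krCache output data. | Non-empty tensors only (required when kvCacheQuantMode=2) | FLOAT | ND | (1, Dr) | ×
    smoothScalesCqOptional | Input | Parameter used for dynamic quantization of the RmsNormCq output. | Non-empty tensors only (optional when weightQuantMode=1/2/4/5) | FLOAT | ND | (1, Hcq) | ×
    actualSeqLenOptional | Input | Sequence length of each batch, stored as prefix sums. | Required when BS is merged and cacheMode="PA_BLK_BSND"/"PA_BLK_NZ" | INT32 | ND | (B) | ×
    kNopeClipAlphaOptional | Input | Scaling factor used when clipping kvCache. | Has shape (1) in the partial-quantization per-tile and int8 full-quantization per-tile scenarios; may be omitted in other scenarios; empty tensors are not supported | FLOAT | ND | (1) | ×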
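    rmsnormEpsilonCq | Input | The ε parameter in the RmsNorm formula for $c^Q$. | If no specific value is needed, 1e-05 is recommended; only the double type is supported | DOUBLE | - | - | -
    rmsnormEpsilonCkv | Input | The ε parameter in the RmsNorm formula for $c^{KV}$. | If no specific value is needed, 1e-05 is recommended; only the double type is supported | DOUBLE | - | - | -
    cacheModeOptional | Input | The kvCache mode. | If no specific value is needed, "PA_BSND" is recommended; only the char* type is supported; allowed values: "PA_BSND", "PA_NZ", "PA_BLK_BSND", "PA_BLK_NZ", "BSND", "TND" | CHAR* | - | - | -
    queryNormFlag | Input | Whether to output queryNormOutOptional and dequantScaleQNormOutOptional. | false: not output; true: output; default: false | BOOL | - | - | -
    weightQuantMode | Input | Quantization mode of weightDq, weightUqQr, weightUk, and weightDkvKr. | 0: non-quantized; 1: weightUqQr quantized; 2: weightDq, weightUqQr, weightDkvKr int8-quantized; 3: weightDq, weightUqQr, weightDkvKr mxfp8-quantized; 4: weightDq, weightUqQr, weightDkvKr fp8-quantized; 5: weightDq, weightUqQr, weightDkvKr hif8-quantized; default: 0 | INT64 | - | - | -
    kvCacheQuantMode | Input | Quantization mode of kvCache. | 0: non-quantized; 1: per-tensor; 2: per-channel; 3: per-tile; default: 0 | INT64 | - | - | -
    queryQuantMode | Input | Quantization mode of query. | 0: non-quantized; 1: per-token-head; default: 0 | INT64 | - | - | -
    ckvkrRepoMode | Input | Storage mode of kvCache and krCache. | 0: kvCache and krCache are stored separately; 1: kvCache and krCache are stored together; default: 0 | INT64 | - | - | -
    quantScaleRepoMode | Input | Storage mode of the quantization scale. | 0: scale and data are stored separately; 1: scale and data are stored together and output as kvCacheRef; default: 0 | INT64 | - | - | -
    tileSize | Input | Size of each tile for per-tile quantization; only takes effect when kvCacheQuantMode is 3. | Default: 128 | INT64 | - | - | -
    qcQrScale | Input | Scale correction coefficient for Query. | Pass 1.0 if no specific value is needed | DOUBLE | - | - | -
    kcScale | Input | Scale correction coefficient for Key. | Pass 1.0 if no specific value is needed | DOUBLE | - | - | -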
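    queryOut | Output | Query output tensor in the formulas (corresponds to $q^N$). | - | BFLOAT16, FLOAT8_E4M3FN, INT8, HIFLOAT8 | ND | BS merged: (T, N, Hckv); BS not merged: (B, S, N, Hckv) | ×
    queryRopeOut | Output | Query position-encoding output tensor in the formulas (corresponds to $q^R$). | - | BFLOAT16 | ND | BS merged: (T, N, Dr); BS not merged: (B, S, N, Dr) | ×
    dequantScaleQNopeOutOptional | Output | Dequantization parameter of the Query output in the formulas. | Output when weightQuantMode=2/3/4/5; nullptr when weightQuantMode=0/1 | FLOAT | ND | BS merged: (T, N, 1); BS not merged: (B*S, N, 1) | ×
    queryNormOutOptional | Output | Output tensor after RmsNorm is applied to tokenX (corresponds to $c^Q$). | Output when queryNormFlag=true | BFLOAT16, INT8, FLOAT8_E4M3FN, HIFLOAT8 | ND | BS merged: (T, Hcq); BS not merged: (B, S, Hcq) | ×
    dequantScaleQNormOutOptional | Output | Dequantization parameter of queryNormOutOptional. | Output when queryNormFlag=true and weightQuantMode=1/2/3/4/5; nullptr when weightQuantMode=0 | FLOAT, FLOAT8_E8M0 | ND | weightQuantMode=1/2/4/5: (T, 1) if BS merged, (B*S, 1) if BS not merged; weightQuantMode=3: (T, Hcq/32) if BS merged, (B*S, Hcq/32) if BS not merged | ×
    workspaceSize | Output | Returns the workspace size that must be allocated on the device. | Output only, no input configuration required; the type is uint64_t* | - | - | - | -
    executor | Output | Returns the op executor, which contains the operator computation flow. | Output only, no input configuration required; the type is aclOpExecutor** | - | - | - | -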
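  • Atlas A3 Training Series Products/Atlas A3 Inference Series Products, Atlas A2 Training Series Products/Atlas A2 Inference Series Products:

    • tokenX, weightDq, weightUqQr, weightDkvKr, kvCacheRef, queryOut, and queryNormOutOptional do not support the FLOAT8_E4M3FN and HIFLOAT8 data types.
    • dequantScaleXOptional, dequantScaleWDqOptional, dequantScaleWUqQrOptional, dequantScaleWDkvKrOptional, and dequantScaleQNormOutOptional do not support the FLOAT8_E8M0 data type.
  • Return value

    aclnnStatus: the status code. For details, see the aclnn return codes.

    The first-stage interface validates the input parameters and reports an error in the following cases:

    Return value | Error code | Description
    ACLNN_ERR_PARAM_NULLPTR | 161001 | A required parameter (such as an input or output that the interface depends on) is a null pointer.
    ACLNN_ERR_PARAM_INVALID | 161002 | The shape (dimensions or sizes) or dtype (data type) of an input parameter is outside the range supported by the interface.
    ACLNN_ERR_RUNTIME_ERROR | 361001 | An exception occurred when the API internally called an NPU Runtime interface (for example, the Runtime service is not started or memory allocation failed).
    ACLNN_ERR_INNER_TILING_ERROR | 561002 | Tiling failed because the dtype or shape of an input parameter is wrong.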

aclnnMlaPrologV3WeightNz

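  • Parameters

    Parameter | Type | Meaning
    workspace | void* | Address of the workspace memory allocated on the device.
    workspaceSize | uint64_t | Size of the workspace allocated on the device, obtained from the first-stage interface aclnnMlaPrologV3WeightNzGetWorkspaceSize.
    executor | aclOpExecutor* | The op executor, which contains the operator computation flow.
    stream | aclrtStream | The stream on which the task is executed.
  • Return value

    aclnnStatus: the status code. For details, see the aclnn return codes.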

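Constraints

  • Deterministic computation:
    • aclnnMlaPrologV3WeightNz is non-deterministic by default; deterministic computation can be enabled through aclrtCtxSetSysParamOpt.
Shape field definitions
Field | Full name/meaning | Allowed values and notes
B | Batch (input batch size) | Range: 0 to 65536
S | Seq-Length (input sequence length) | Range: unrestricted
He | Head-Size (hidden size) | One of: 1024, 2048, 3072, 4096, 5120, 6144, 7168, 7680, 8192
Hcq | Low-rank dimension of q | One of: 1536, 2048
N | Head-Num (number of heads) | One of: 1, 2, 4, 8, 16, 32, 64, 128
Hckv | Low-rank dimension of kv | Fixed at 512
D | qk dimension without position encoding | One of: 128, 192
Dr | qk position-encoding dimension | Fixed at 64
Nkv | Number of kv heads | Fixed at 1
BlockNum | Number of blocks in the PagedAttention scenario
  • When cacheMode="PA_BSND"/"PA_NZ", the value must be at least ⌈(B * S) / BlockSize⌉.
  • When cacheMode="PA_BLK_BSND"/"PA_BLK_NZ", the value must be at least B * ⌈S / BlockSize⌉.

Note: when B and S are merged, the S length may differ between batches, so BlockNum must be at least the sum over all batches of ⌈Si / BlockSize⌉.

Example: if actualSeqLenOptional is [47, 151, 261, 422] and BlockSize=128, the per-batch lengths are [47, 104, 110, 161], so the minimum BlockNum is ⌈47/128⌉ + ⌈104/128⌉ + ⌈110/128⌉ + ⌈161/128⌉ = 5.

BlockSize | Block size in the PagedAttention scenario | Range: 16 to 1024, and a multiple of 16
T | Size after merging B and S
  • Range: unrestricted
  • Note: when B and S are merged, tokenX, ropeSin, and ropeCos are 2-D, cacheIndexOptional is 1-D, and queryOut and queryRopeOut are 3-D
Dtile | Size of the D dimension of kvCache
  • Per-tile quantization: fixed at 656, that is 512 (Hckv) + 64 (Dr) * 2 + 4 (Hckv/tileSize) * 4
  • Other scenarios: fixed at Hckv (512)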
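
The BlockNum example and the per-tile Dtile value above can be reproduced with a short sketch. The helper names MinBlockNumFromPrefixSums and PerTileDtile are hypothetical and not part of the aclnn API; the first assumes actualSeqLenOptional holds non-decreasing prefix sums, as described in the parameter table.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: minimum BlockNum when B and S are merged.
// actualSeqLen holds prefix sums of the per-batch sequence lengths.
int64_t MinBlockNumFromPrefixSums(const std::vector<int32_t>& actualSeqLen,
                                  int64_t blockSize) {
    assert(blockSize > 0);
    int64_t blocks = 0;
    int64_t prev = 0;
    for (int32_t cum : actualSeqLen) {
        assert(cum >= prev);  // prefix sums must not decrease
        const int64_t len = cum - prev;               // per-batch length Si
        blocks += (len + blockSize - 1) / blockSize;  // ceil(Si / BlockSize)
        prev = cum;
    }
    return blocks;
}

// Hypothetical helper: Dtile in the per-tile quantization scenario,
// following Hckv + Dr * 2 + (Hckv / tileSize) * 4 from the table above.
int64_t PerTileDtile(int64_t hckv, int64_t dr, int64_t tileSize) {
    assert(tileSize > 0 && hckv % tileSize == 0);
    return hckv + dr * 2 + (hckv / tileSize) * 4;
}
```

Calling MinBlockNumFromPrefixSums with {47, 151, 261, 422} and a block size of 128 gives 5, and PerTileDtile(512, 64, 128) gives 656, matching the values in the table.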
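Shape constraints
  • If tokenX merges B and S, that is, has shape (T, He):
    • ropeSin and ropeCos have shape (T, Dr)
    • When cacheMode is PA_BSND or PA_NZ, cacheIndexOptional has shape (T)
    • When cacheMode is PA_BLK_BSND or PA_BLK_NZ, cacheIndexOptional has shape (Sum(⌈Si/BlockSize⌉)), where Si is the S length of each batch
    • When cacheMode is PA_BLK_BSND or PA_BLK_NZ, actualSeqLenOptional must be passed, with shape (B)
    • In the int8/fp8/hif8 full-quantization scenarios, dequantScaleXOptional has shape (T, 1); in the mxfp8 full-quantization scenario, it has shape (T, He/32)
    • queryOut has shape (T, N, Hckv)
    • queryRopeOut has shape (T, N, Dr)
    • In the int8/mxfp8/fp8/hif8 full-quantization scenarios, dequantScaleQNopeOutOptional has shape (T, N, 1); in other scenarios it is nullptr
  • If tokenX does not merge B and S, that is, has shape (B, S, He):
    • ropeSin and ropeCos have shape (B, S, Dr)
    • When cacheMode is PA_BSND or PA_NZ, cacheIndexOptional has shape (B, S)
    • When cacheMode is PA_BLK_BSND or PA_BLK_NZ, cacheIndexOptional has shape (B, ⌈S/BlockSize⌉)
    • In the int8/fp8/hif8 full-quantization scenarios, dequantScaleXOptional has shape (B*S, 1); in the mxfp8 full-quantization scenario, it has shape (B*S, He/32)
    • queryOut has shape (B, S, N, Hckv)
    • queryRopeOut has shape (B, S, N, Dr)
    • In the int8/mxfp8/fp8/hif8 full-quantization scenarios, dequantScaleQNopeOutOptional has shape (B*S, N, 1); in other scenarios it is nullptr
  • One or more of B, S, T, and Skv may be 0. Inputs whose shape depends on B, S, T, or Skv may therefore be empty tensors; all other inputs must not be empty tensors.
    • If B, S, or T is 0, queryOut and queryRopeOut are empty tensors, and kvCacheRef and krCacheRef are not updated.
    • If Skv is 0, queryOut, queryRopeOut, and dequantScaleQNopeOutOptional are computed normally, while kvCacheRef and krCacheRef are not updated, that is, they are output as empty tensors.
  • When cacheMode is BSND:
    • tokenX must not merge B and S, that is, its shape is (B, S, He)
    • kvCacheRef has shape (B, S, Nkv, Dtile) and krCacheRef has shape (B, S, Nkv, Dr)
  • When cacheMode is TND:
    • tokenX must merge B and S, that is, its shape is (T, He)
    • kvCacheRef has shape (T, Nkv, Dtile) and krCacheRef has shape (T, Nkv, Dr)
  • When ckvkrRepoMode=1:
    • The krCacheRef shape must contain a 0 dimension; shape (0) is supported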
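
As a sketch of how the shape rules for the non-merged (B, S, He) layout fit together, the helper below derives the expected queryOut, queryRopeOut, and cacheIndexOptional shapes. ExpectedShapes and ExpectedShapesNotMerged are hypothetical names for illustration only; an empty cacheIndex vector stands for "pass nullptr", and the TND mode is not handled because it requires the merged layout.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;

struct ExpectedShapes {
    Shape queryOut;      // (B, S, N, Hckv)
    Shape queryRopeOut;  // (B, S, N, Dr)
    Shape cacheIndex;    // empty means "pass nullptr"
};

// Integer ceiling division for non-negative a and positive b.
int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Hypothetical helper for the case where tokenX has shape (B, S, He).
ExpectedShapes ExpectedShapesNotMerged(const std::string& cacheMode, int64_t b, int64_t s,
                                       int64_t n, int64_t hckv, int64_t dr,
                                       int64_t blockSize) {
    assert(blockSize > 0);
    ExpectedShapes out;
    out.queryOut = {b, s, n, hckv};
    out.queryRopeOut = {b, s, n, dr};
    if (cacheMode == "PA_BSND" || cacheMode == "PA_NZ") {
        out.cacheIndex = {b, s};
    } else if (cacheMode == "PA_BLK_BSND" || cacheMode == "PA_BLK_NZ") {
        out.cacheIndex = {b, CeilDiv(s, blockSize)};
    }
    // cacheMode == "BSND": cacheIndex stays empty (nullptr).
    return out;
}
```

With B=8, S=1, N=32, Hckv=512, Dr=64 and cacheMode "PA_BSND", this reproduces the (8, 1, 32, 512), (8, 1, 32, 64), and (8, 1) shapes used in the sample code.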
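Special constraints
  • When actualSeqLenOptional is passed, its last element must equal T.
  • In per-tile quantization mode, ckvkrRepoMode and quantScaleRepoMode must both be 1; in all other quantization modes and in the non-quantized scenario, they must both be 0.
  • In per-tile quantization mode, cacheMode supports only PA_BSND, BSND, and TND.
  • When ckvkrRepoMode is 1, krCacheRef must be an empty tensor (the product of its shape dimensions is 0).
  • In kvCache per-tensor quantization mode, kvCacheQuantMode and queryQuantMode must both be 1.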
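
The mode-related rules above can be pre-checked on the host before calling the first-stage interface. This is a minimal sketch: MlaPrologModes and CheckModeConstraints are hypothetical names, only the rules stated in this section are covered, and the operator's own validation remains authoritative.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical bundle of the mode attributes passed to the first-stage interface.
struct MlaPrologModes {
    std::string cacheMode;
    int64_t kvCacheQuantMode;
    int64_t queryQuantMode;
    int64_t ckvkrRepoMode;
    int64_t quantScaleRepoMode;
};

// Returns an empty string when the special constraints hold,
// otherwise a short description of the first violated rule.
std::string CheckModeConstraints(const MlaPrologModes& m) {
    assert(m.kvCacheQuantMode >= 0 && m.kvCacheQuantMode <= 3);
    const bool perTile = (m.kvCacheQuantMode == 3);
    const int64_t expectedRepo = perTile ? 1 : 0;
    if (m.ckvkrRepoMode != expectedRepo || m.quantScaleRepoMode != expectedRepo) {
        return "ckvkrRepoMode and quantScaleRepoMode must both be 1 for per-tile and 0 otherwise";
    }
    if (perTile && m.cacheMode != "PA_BSND" && m.cacheMode != "BSND" && m.cacheMode != "TND") {
        return "per-tile quantization supports only PA_BSND, BSND, and TND";
    }
    if (m.kvCacheQuantMode == 1 && m.queryQuantMode != 1) {
        return "per-tensor kvCache quantization requires queryQuantMode = 1";
    }
    return "";
}
```

For example, the sample configuration (cacheMode "PA_BSND", kvCacheQuantMode 1, queryQuantMode 1, both repo modes 0) passes, while per-tile quantization with cacheMode "PA_NZ" is rejected.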
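Scenarios supported by aclnnMlaPrologV3WeightNz
  • Atlas A2 Training Series Products/Atlas A2 Inference Series Products, Atlas A3 Training Series Products/Atlas A3 Inference Series Products: the fp8/hif8/mxfp8 full-quantization scenarios are currently not supported
  • Ascend 950PR/Ascend 950DT: all quantization scenarios are currently supported
Scenario | Meaning
Non-quantized | weightQuantMode=0, kvCacheQuantMode=0, queryQuantMode=0
Inputs: all inputs are non-quantized data
Outputs: all outputs are non-quantized data
Partial quantization | kvCache non-quantized | weightQuantMode=1, kvCacheQuantMode=0, queryQuantMode=0
Inputs: weightUqQr is per-token quantized data; all other inputs are non-quantized data. dequantScaleWUqQrOptional is required; smoothScalesCqOptional is optional
Outputs: all outputs are non-quantized data
kvCache per-channel quantization | weightQuantMode=1, kvCacheQuantMode=2, queryQuantMode=0
Inputs: weightUqQr is per-token quantized data; kvCacheRef and krCacheRef are per-channel quantized data; all other inputs are non-quantized data. dequantScaleWUqQrOptional, quantScaleCkvOptional, and quantScaleCkrOptional are required; smoothScalesCqOptional is optional
Outputs: kvCacheRef and krCacheRef are per-channel quantized data; all other outputs are non-quantized data
kvCache per-tile quantization | weightQuantMode=1, kvCacheQuantMode=3, queryQuantMode=0
Inputs: weightUqQr is per-token quantized data; all other inputs are non-quantized data. dequantScaleWUqQrOptional is required; smoothScalesCqOptional is optional
Outputs: kvCacheRef is per-tile quantized data; all other outputs are non-quantized data
int8/fp8/hif8 full quantization | kvCache non-quantized | weightQuantMode=2/4/5, kvCacheQuantMode=0, queryQuantMode=0
Inputs: tokenX is per-token quantized data; weightDq, weightUqQr, and weightDkvKr are per-channel quantized data; all other inputs are non-quantized data. dequantScaleXOptional, dequantScaleWDqOptional, dequantScaleWUqQrOptional, and dequantScaleWDkvKrOptional are required; smoothScalesCqOptional is optional
Outputs: all outputs are non-quantized data
kvCache per-tensor quantization | weightQuantMode=2/4/5, kvCacheQuantMode=1, queryQuantMode=1
Inputs: tokenX is per-token quantized data; weightDq, weightUqQr, and weightDkvKr are per-channel quantized data; kvCacheRef is per-tensor quantized data; all other inputs are non-quantized data. dequantScaleXOptional, dequantScaleWDqOptional, dequantScaleWUqQrOptional, dequantScaleWDkvKrOptional, and quantScaleCkvOptional are required; smoothScalesCqOptional is optional
Outputs: queryOut is per-token-head quantized data; kvCacheRef is per-tensor quantized data; all other outputs are non-quantized data
kvCache per-tile quantization | weightQuantMode=2/4/5, kvCacheQuantMode=3, queryQuantMode=0
Inputs: tokenX is per-token quantized data; weightDq, weightUqQr, and weightDkvKr are per-channel quantized data; all other inputs are non-quantized data. dequantScaleXOptional, dequantScaleWDqOptional, dequantScaleWUqQrOptional, and dequantScaleWDkvKrOptional are required; smoothScalesCqOptional is optional
Outputs: kvCacheRef is per-tile quantized data; all other outputs are non-quantized data
mxfp8 full quantization | kvCache non-quantized | weightQuantMode=3, kvCacheQuantMode=0, queryQuantMode=0
Inputs: tokenX is per-token quantized data; weightDq, weightUqQr, and weightDkvKr are per-channel quantized data; all other inputs are non-quantized data. dequantScaleXOptional, dequantScaleWDqOptional, dequantScaleWUqQrOptional, and dequantScaleWDkvKrOptional are required
Outputs: all outputs are non-quantized data
kvCache per-tensor quantization | weightQuantMode=3, kvCacheQuantMode=1, queryQuantMode=1
Inputs: tokenX is per-token quantized data; weightDq, weightUqQr, and weightDkvKr are per-channel quantized data; kvCacheRef is per-tensor quantized data; all other inputs are non-quantized data. dequantScaleXOptional, dequantScaleWDqOptional, dequantScaleWUqQrOptional, dequantScaleWDkvKrOptional, and quantScaleCkvOptional are required
Outputs: queryOut is per-token-head quantized data; kvCacheRef is per-tensor quantized data; all other outputs are non-quantized data
kvCache per-tile quantization | weightQuantMode=3, kvCacheQuantMode=3, queryQuantMode=0
Inputs: tokenX is per-token quantized data; weightDq, weightUqQr, and weightDkvKr are per-channel quantized data; all other inputs are non-quantized data. dequantScaleXOptional, dequantScaleWDqOptional, dequantScaleWUqQrOptional, and dequantScaleWDkvKrOptional are required
Outputs: kvCacheRef is per-tile quantized data; all other outputs are non-quantized data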
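Data type constraints per quantization scenario
Parameter | Non-quantized | Partial quantization | int8 full quantization | mxfp8 full quantization | fp8 full quantization | hif8 full quantization
(sub-columns after Non-quantized) kvCache non-quantized | kvCache per-channel | kvCache per-tile | kvCache non-quantized | kvCache per-tensor | kvCache per-tile | kvCache non-quantized | kvCache per-tensor | kvCache per-tile | kvCache non-quantized | kvCache per-tensor | kvCache per-tile | kvCache non-quantized | kvCache per-tensor | kvCache per-tile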
dtype dtype dtype dtype dtype dtype dtype dtype dtype dtype dtype dtype dtype dtype dtype dtype
tokenX BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 INT8 INT8 INT8 FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN HIFLOAT8 HIFLOAT8 HIFLOAT8
weightDq BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 INT8 INT8 INT8 FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN HIFLOAT8 HIFLOAT8 HIFLOAT8
weightUqQr BFLOAT16 INT8 INT8 INT8 INT8 INT8 INT8 FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN HIFLOAT8 HIFLOAT8 HIFLOAT8
weightUk BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16
weightDkvKr BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 INT8 INT8 INT8 FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN HIFLOAT8 HIFLOAT8 HIFLOAT8
rmsnormGammaCq BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16
rmsnormGammaCkv BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16
ropeSin BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16
ropeCos BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16
kvCacheRef BFLOAT16 BFLOAT16 INT8 INT8 BFLOAT16 INT8 INT8 BFLOAT16 FLOAT8_E4M3FN FLOAT8_E4M3FN BFLOAT16 FLOAT8_E4M3FN FLOAT8_E4M3FN BFLOAT16 HIFLOAT8 HIFLOAT8
krCacheRef BFLOAT16 BFLOAT16 INT8 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16
cacheIndexOptional INT64 INT64 INT64 INT64 INT64 INT64 INT64 INT64 INT64 INT64 INT64 INT64 INT64 INT64 INT64 INT64
dequantScaleXOptional NULLPTR NULLPTR NULLPTR NULLPTR FLOAT FLOAT FLOAT FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT FLOAT FLOAT FLOAT FLOAT FLOAT
dequantScaleWDqOptional NULLPTR NULLPTR NULLPTR NULLPTR FLOAT FLOAT FLOAT FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT FLOAT FLOAT FLOAT FLOAT FLOAT
dequantScaleWUqQrOptional NULLPTR FLOAT FLOAT FLOAT FLOAT FLOAT FLOAT FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT FLOAT FLOAT FLOAT FLOAT FLOAT
dequantScaleWDkvKrOptional NULLPTR NULLPTR NULLPTR NULLPTR FLOAT FLOAT FLOAT FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT FLOAT FLOAT FLOAT FLOAT FLOAT
quantScaleCkvOptional NULLPTR NULLPTR FLOAT NULLPTR NULLPTR FLOAT NULLPTR NULLPTR FLOAT NULLPTR NULLPTR FLOAT NULLPTR NULLPTR FLOAT NULLPTR
quantScaleCkrOptional NULLPTR NULLPTR FLOAT NULLPTR NULLPTR NULLPTR NULLPTR NULLPTR NULLPTR NULLPTR NULLPTR NULLPTR NULLPTR NULLPTR NULLPTR NULLPTR
smoothScalesCqOptional NULLPTR FLOAT FLOAT FLOAT FLOAT FLOAT FLOAT NULLPTR NULLPTR NULLPTR FLOAT FLOAT FLOAT FLOAT FLOAT FLOAT
kNopeClipAlphaOptional NULLPTR NULLPTR NULLPTR FLOAT NULLPTR NULLPTR FLOAT NULLPTR NULLPTR NULLPTR NULLPTR NULLPTR FLOAT NULLPTR NULLPTR FLOAT
queryOut BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 INT8 BFLOAT16 BFLOAT16 FLOAT8_E4M3FN BFLOAT16 BFLOAT16 FLOAT8_E4M3FN BFLOAT16 BFLOAT16 HIFLOAT8 BFLOAT16
queryRopeOut BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16 BFLOAT16
dequantScaleQNopeOutOptional NULLPTR NULLPTR NULLPTR NULLPTR NULLPTR FLOAT NULLPTR NULLPTR FLOAT NULLPTR NULLPTR FLOAT NULLPTR NULLPTR FLOAT NULLPTR
queryNormOutOptional BFLOAT16 INT8 INT8 INT8 INT8 INT8 INT8 FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E4M3FN HIFLOAT8 HIFLOAT8 HIFLOAT8
dequantScaleQNormOutOptional NULLPTR FLOAT FLOAT FLOAT FLOAT FLOAT FLOAT FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT FLOAT FLOAT FLOAT FLOAT FLOAT
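
The tokenX and dequantScaleXOptional rows of the table above depend only on weightQuantMode, so they can be encoded as a small lookup for host-side sanity checks. TokenXDtype and DequantScaleXDtype are hypothetical helper names; the returned strings mirror the table entries and are not aclDataType values.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical lookup mirroring the tokenX row of the dtype table.
std::string TokenXDtype(int64_t weightQuantMode) {
    assert(weightQuantMode >= 0 && weightQuantMode <= 5);
    switch (weightQuantMode) {
        case 2: return "INT8";            // int8 full quantization
        case 3:                           // mxfp8 full quantization
        case 4: return "FLOAT8_E4M3FN";   // fp8 full quantization
        case 5: return "HIFLOAT8";        // hif8 full quantization
        default: return "BFLOAT16";       // 0: non-quantized, 1: partial quantization
    }
}

// Hypothetical lookup mirroring the dequantScaleXOptional row of the dtype table.
std::string DequantScaleXDtype(int64_t weightQuantMode) {
    assert(weightQuantMode >= 0 && weightQuantMode <= 5);
    switch (weightQuantMode) {
        case 0:
        case 1: return "NULLPTR";         // tokenX is not quantized
        case 3: return "FLOAT8_E8M0";     // mxfp8 uses E8M0 scales
        default: return "FLOAT";          // int8, fp8, hif8
    }
}
```

A caller could compare these strings against the data types it is about to pass and fail early with a readable message instead of waiting for ACLNN_ERR_PARAM_INVALID.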

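Usage example

Sample code for Atlas A2 Training Series Products/Atlas A2 Inference Series Products and Atlas A3 Training Series Products/Atlas A3 Inference Series Products is shown below for reference only. For the build and run procedure, see Compiling and Running the Sample.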

#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
#include <cstdint>
#include "acl/acl.h"
#include "aclnnop/aclnn_mla_prolog_v3_weight_nz.h"
#include <unistd.h>

#define CHECK_RET(cond, return_expr) \
  do {                               \
    if (!(cond)) {                   \
      return_expr;                   \
    }                                \
  } while (0)

#define LOG_PRINT(message, ...)      \
do {                               \
  printf(message, ##__VA_ARGS__);  \
} while (0)

int64_t GetShapeSize(const std::vector<int64_t>& shape) {
    int64_t shape_size = 1;
    for (auto i : shape) {
        shape_size *= i;
    }
    return shape_size;
}

int Init(int32_t deviceId, aclrtStream* stream) {
    // Boilerplate: initialize ACL, the device, and the stream
    auto ret = aclInit(nullptr);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
    ret = aclrtSetDevice(deviceId);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
    ret = aclrtCreateStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
    return 0;
}

template <typename T>
int CreateAclTensorND(const std::vector<T>& shape, void** deviceAddr, void** hostAddr,
                    aclDataType dataType, aclTensor** tensor) {
    // Buffer size in bytes: element count times the element size of dataType
    // (sizeof(T) would be the size of a shape entry, not of a tensor element)
    auto size = GetShapeSize(shape) * aclDataTypeSize(dataType);
    // Allocate device memory with aclrtMalloc
    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
    // Allocate host memory with aclrtMallocHost
    ret = aclrtMallocHost(hostAddr, size);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMallocHost failed. ERROR: %d\n", ret); return ret);
    memset(*hostAddr, 0, size);
    // Create the aclTensor with aclCreateTensor
    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, nullptr, 0, aclFormat::ACL_FORMAT_ND,
                              shape.data(), shape.size(), *deviceAddr);
    // Copy the host data to device memory with aclrtMemcpy
    ret = aclrtMemcpy(*deviceAddr, size, *hostAddr, size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
    return 0;
}

template <typename T>
int CreateAclTensorNZ(const std::vector<T>& shape, void** deviceAddr, void** hostAddr,
                    aclDataType dataType, aclTensor** tensor) {
    // Buffer size in bytes: element count times the element size of dataType
    // (sizeof(T) would be the size of a shape entry, not of a tensor element)
    auto size = GetShapeSize(shape) * aclDataTypeSize(dataType);
    // Allocate device memory with aclrtMalloc
    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
    // Allocate host memory with aclrtMallocHost
    ret = aclrtMallocHost(hostAddr, size);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMallocHost failed. ERROR: %d\n", ret); return ret);
    memset(*hostAddr, 0, size);
    // Create the aclTensor in FRACTAL_NZ format with aclCreateTensor
    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, nullptr, 0, aclFormat::ACL_FORMAT_FRACTAL_NZ,
                              shape.data(), shape.size(), *deviceAddr);
    // Copy the host data to device memory with aclrtMemcpy
    ret = aclrtMemcpy(*deviceAddr, size, *hostAddr, size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
    return 0;
}

int TransToNZShape(std::vector<int64_t> &shapeND, size_t typeSize) {
    if (typeSize == static_cast<size_t>(0)) {
      return 0;
    }
    int64_t h = shapeND[0];
    int64_t w = shapeND[1];
    int64_t h0 = static_cast<int64_t>(16);
    int64_t w0 = static_cast<int64_t>(32) / static_cast<int64_t>(typeSize);
    int64_t h1 = h / h0;
    int64_t w1 = w / w0;
    shapeND[0] = w1;
    shapeND[1] = h1;
    shapeND.emplace_back(h0);
    shapeND.emplace_back(w0);
    return 0;
}

int main() {
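    // 1. Boilerplate: initialize the device and stream; see the AscendCL API list
    // Set deviceId to the device you are actually using
    int32_t deviceId = 0;
    aclrtStream stream;
    auto ret = Init(deviceId, &stream);
    // Handle the check result as your application requires
    CHECK_RET(ret == 0, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);
    // 2. Construct the inputs and outputs according to the API definition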
    std::vector<int64_t> tokenXShape = {8, 1, 7168};            // B,S,He
    std::vector<int64_t> weightDqShape = {7168, 1536};          // He,Hcq
    std::vector<int64_t> weightUqQrShape = {1536, 6144};        // Hcq,N*(D+Dr)
    std::vector<int64_t> weightUkShape = {32, 128, 512};        // N,D,Hckv
    std::vector<int64_t> weightDkvKrShape = {7168, 576};        // He,Hckv+Dr
    std::vector<int64_t> rmsnormGammaCqShape = {1536};          // Hcq
    std::vector<int64_t> rmsnormGammaCkvShape = {512};          // Hckv
    std::vector<int64_t> ropeSinShape = {8, 1, 64};             // B,S,Dr
    std::vector<int64_t> ropeCosShape = {8, 1, 64};             // B,S,Dr
    std::vector<int64_t> cacheIndexShape = {8, 1};              // B,S
    std::vector<int64_t> kvCacheShape = {16, 128, 1, 512};      // BlockNum,BlockSize,Nkv,Hckv
    std::vector<int64_t> krCacheShape = {16, 128, 1, 64};       // BlockNum,BlockSize,Nkv,Dr
    std::vector<int64_t> dequantScaleXShape = {8, 1};           // B*S, 1
    std::vector<int64_t> dequantScaleWDqShape = {1, 1536};      // 1, Hcq
    std::vector<int64_t> dequantScaleWUqQrShape = {1, 6144};    // 1, N*(D+Dr)
    std::vector<int64_t> dequantScaleWDkvKrShape = {1, 576};    // 1, Hckv+Dr
    std::vector<int64_t> quantScaleCkvShape = {1};              // 1
    std::vector<int64_t> smoothScalesCqShape = {1, 1536};       // 1, Hcq
    std::vector<int64_t> queryShape = {8, 1, 32, 512};          // B,S,N,Hckv
    std::vector<int64_t> queryRopeShape = {8, 1, 32, 64};       // B,S,N,Dr
    std::vector<int64_t> dequantScaleQNopeShape = {8, 32, 1};   // B*S, N, 1
    double rmsnormEpsilonCq = 1e-5;
    double rmsnormEpsilonCkv = 1e-5;
    char cacheMode[] = "PA_BSND";

    void* tokenXDeviceAddr = nullptr;
    void* weightDqDeviceAddr = nullptr;
    void* weightUqQrDeviceAddr = nullptr;
    void* weightUkDeviceAddr = nullptr;
    void* weightDkvKrDeviceAddr = nullptr;
    void* rmsnormGammaCqDeviceAddr = nullptr;
    void* rmsnormGammaCkvDeviceAddr = nullptr;
    void* ropeSinDeviceAddr = nullptr;
    void* ropeCosDeviceAddr = nullptr;
    void* cacheIndexDeviceAddr = nullptr;
    void* kvCacheDeviceAddr = nullptr;
    void* krCacheDeviceAddr = nullptr;
    void* dequantScaleXDeviceAddr = nullptr;
    void* dequantScaleWDqDeviceAddr = nullptr;
    void* dequantScaleWUqQrDeviceAddr = nullptr;
    void* dequantScaleWDkvKrDeviceAddr = nullptr;
    void* quantScaleCkvDeviceAddr = nullptr;
    void* smoothScalesCqDeviceAddr = nullptr;
    void* queryDeviceAddr = nullptr;
    void* queryRopeDeviceAddr = nullptr;
    void* dequantScaleQNopeDeviceAddr = nullptr;

    void* tokenXHostAddr = nullptr;
    void* weightDqHostAddr = nullptr;
    void* weightUqQrHostAddr = nullptr;
    void* weightUkHostAddr = nullptr;
    void* weightDkvKrHostAddr = nullptr;
    void* rmsnormGammaCqHostAddr = nullptr;
    void* rmsnormGammaCkvHostAddr = nullptr;
    void* ropeSinHostAddr = nullptr;
    void* ropeCosHostAddr = nullptr;
    void* cacheIndexHostAddr = nullptr;
    void* kvCacheHostAddr = nullptr;
    void* krCacheHostAddr = nullptr;
    void* dequantScaleXHostAddr = nullptr;
    void* dequantScaleWDqHostAddr = nullptr;
    void* dequantScaleWUqQrHostAddr = nullptr;
    void* dequantScaleWDkvKrHostAddr = nullptr;
    void* quantScaleCkvHostAddr = nullptr;
    void* smoothScalesCqHostAddr = nullptr;
    void* queryHostAddr = nullptr;
    void* queryRopeHostAddr = nullptr;
    void* dequantScaleQNopeHostAddr = nullptr;

    aclTensor* tokenX = nullptr;
    aclTensor* weightDq = nullptr;
    aclTensor* weightUqQr = nullptr;
    aclTensor* weightUk = nullptr;
    aclTensor* weightDkvKr = nullptr;
    aclTensor* rmsnormGammaCq = nullptr;
    aclTensor* rmsnormGammaCkv = nullptr;
    aclTensor* ropeSin = nullptr;
    aclTensor* ropeCos = nullptr;
    aclTensor* cacheIndex = nullptr;
    aclTensor* kvCache = nullptr;
    aclTensor* krCache = nullptr;
    aclTensor* dequantScaleX = nullptr;
    aclTensor* dequantScaleWDq = nullptr;
    aclTensor* dequantScaleWUqQr = nullptr;
    aclTensor* dequantScaleWDkvKr = nullptr;
    aclTensor* quantScaleCkv = nullptr;
    aclTensor* smoothScalesCq = nullptr;
    int64_t weightQuantMode = 2;
    int64_t kvQuantMode = 1;
    int64_t queryQuantMode = 1;
    int64_t ckvkrRepoMode = 0;
    int64_t quantScaleRepoMode = 0;
    int64_t tileSize = 128;
    double kNopeClipAlpha = 1.0f;
    double qcQrScale = 1.0f;
    double kcScale = 1.0f;
    aclTensor* query = nullptr;
    aclTensor* queryRope = nullptr;
    aclTensor* dequantScaleQNope = nullptr;

    // Convert the shapes of the three NZ-format weights
    constexpr size_t EXAMPLE_INT8_SIZE = sizeof(int8_t);
    constexpr size_t EXAMPLE_BFLOAT16_SIZE = sizeof(int16_t);
    ret = TransToNZShape(weightDqShape, EXAMPLE_INT8_SIZE);
    CHECK_RET(ret == 0, LOG_PRINT("trans NZ shape failed.\n"); return ret);
    ret = TransToNZShape(weightUqQrShape, EXAMPLE_INT8_SIZE);
    CHECK_RET(ret == 0, LOG_PRINT("trans NZ shape failed.\n"); return ret);
    ret = TransToNZShape(weightDkvKrShape, EXAMPLE_INT8_SIZE);
    CHECK_RET(ret == 0, LOG_PRINT("trans NZ shape failed.\n"); return ret);

    // Create the tokenX aclTensor
    ret = CreateAclTensorND(tokenXShape, &tokenXDeviceAddr, &tokenXHostAddr, aclDataType::ACL_INT8, &tokenX);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the weightDq aclTensor
    ret = CreateAclTensorNZ(weightDqShape, &weightDqDeviceAddr, &weightDqHostAddr, aclDataType::ACL_INT8, &weightDq);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the weightUqQr aclTensor
    ret = CreateAclTensorNZ(weightUqQrShape, &weightUqQrDeviceAddr, &weightUqQrHostAddr, aclDataType::ACL_INT8, &weightUqQr);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the weightUk aclTensor
    ret = CreateAclTensorND(weightUkShape, &weightUkDeviceAddr, &weightUkHostAddr, aclDataType::ACL_BF16, &weightUk);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the weightDkvKr aclTensor
    ret = CreateAclTensorNZ(weightDkvKrShape, &weightDkvKrDeviceAddr, &weightDkvKrHostAddr, aclDataType::ACL_INT8, &weightDkvKr);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the rmsnormGammaCq aclTensor
    ret = CreateAclTensorND(rmsnormGammaCqShape, &rmsnormGammaCqDeviceAddr, &rmsnormGammaCqHostAddr, aclDataType::ACL_BF16, &rmsnormGammaCq);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the rmsnormGammaCkv aclTensor
    ret = CreateAclTensorND(rmsnormGammaCkvShape, &rmsnormGammaCkvDeviceAddr, &rmsnormGammaCkvHostAddr, aclDataType::ACL_BF16, &rmsnormGammaCkv);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the ropeSin aclTensor
    ret = CreateAclTensorND(ropeSinShape, &ropeSinDeviceAddr, &ropeSinHostAddr, aclDataType::ACL_BF16, &ropeSin);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the ropeCos aclTensor
    ret = CreateAclTensorND(ropeCosShape, &ropeCosDeviceAddr, &ropeCosHostAddr, aclDataType::ACL_BF16, &ropeCos);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the cacheIndex aclTensor
    ret = CreateAclTensorND(cacheIndexShape, &cacheIndexDeviceAddr, &cacheIndexHostAddr, aclDataType::ACL_INT64, &cacheIndex);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the kvCache aclTensor
    ret = CreateAclTensorND(kvCacheShape, &kvCacheDeviceAddr, &kvCacheHostAddr, aclDataType::ACL_INT8, &kvCache);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the krCache aclTensor
    ret = CreateAclTensorND(krCacheShape, &krCacheDeviceAddr, &krCacheHostAddr, aclDataType::ACL_BF16, &krCache);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the dequantScaleX aclTensor
    ret = CreateAclTensorND(dequantScaleXShape, &dequantScaleXDeviceAddr, &dequantScaleXHostAddr, aclDataType::ACL_FLOAT, &dequantScaleX);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the dequantScaleWDq aclTensor
    ret = CreateAclTensorND(dequantScaleWDqShape, &dequantScaleWDqDeviceAddr, &dequantScaleWDqHostAddr, aclDataType::ACL_FLOAT, &dequantScaleWDq);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the dequantScaleWUqQr aclTensor
    ret = CreateAclTensorND(dequantScaleWUqQrShape, &dequantScaleWUqQrDeviceAddr, &dequantScaleWUqQrHostAddr, aclDataType::ACL_FLOAT, &dequantScaleWUqQr);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the dequantScaleWDkvKr aclTensor
    ret = CreateAclTensorND(dequantScaleWDkvKrShape, &dequantScaleWDkvKrDeviceAddr, &dequantScaleWDkvKrHostAddr, aclDataType::ACL_FLOAT, &dequantScaleWDkvKr);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the quantScaleCkv aclTensor
    ret = CreateAclTensorND(quantScaleCkvShape, &quantScaleCkvDeviceAddr, &quantScaleCkvHostAddr, aclDataType::ACL_FLOAT, &quantScaleCkv);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the smoothScalesCq aclTensor
    ret = CreateAclTensorND(smoothScalesCqShape, &smoothScalesCqDeviceAddr, &smoothScalesCqHostAddr, aclDataType::ACL_FLOAT, &smoothScalesCq);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the query aclTensor
    ret = CreateAclTensorND(queryShape, &queryDeviceAddr, &queryHostAddr, aclDataType::ACL_INT8, &query);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the queryRope aclTensor
    ret = CreateAclTensorND(queryRopeShape, &queryRopeDeviceAddr, &queryRopeHostAddr, aclDataType::ACL_BF16, &queryRope);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    // Create the dequantScaleQNope aclTensor
    ret = CreateAclTensorND(dequantScaleQNopeShape, &dequantScaleQNopeDeviceAddr, &dequantScaleQNopeHostAddr, aclDataType::ACL_FLOAT, &dequantScaleQNope);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    // 3. Call the CANN operator library API; replace it with the specific API as needed
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor = nullptr;
    // Call the first-stage interface of aclnnMlaPrologV3WeightNz
    ret = aclnnMlaPrologV3WeightNzGetWorkspaceSize(tokenX, weightDq, weightUqQr, weightUk, weightDkvKr, rmsnormGammaCq, rmsnormGammaCkv, ropeSin, ropeCos, kvCache, krCache, cacheIndex,
      dequantScaleX, dequantScaleWDq, dequantScaleWUqQr, dequantScaleWDkvKr, quantScaleCkv, nullptr, smoothScalesCq, nullptr, nullptr, rmsnormEpsilonCq, rmsnormEpsilonCkv, cacheMode,
      weightQuantMode, kvQuantMode, queryQuantMode, ckvkrRepoMode, quantScaleRepoMode, tileSize, qcQrScale, kcScale,
      query, queryRope, dequantScaleQNope, nullptr, nullptr, &workspaceSize, &executor);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnMlaPrologV3WeightNzGetWorkspaceSize failed. ERROR: %d\n", ret); return ret);
    // Allocate device memory according to the workspaceSize computed by the first-stage interface
    void* workspaceAddr = nullptr;
    if (workspaceSize > static_cast<uint64_t>(0)) {
        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret;);
    }
    // Call the second-stage interface of aclnnMlaPrologV3WeightNz
    ret = aclnnMlaPrologV3WeightNz(workspaceAddr, workspaceSize, executor, stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnMlaPrologV3WeightNz failed. ERROR: %d\n", ret); return ret);

    // 4. Boilerplate: synchronize and wait for the task to finish
    ret = aclrtSynchronizeStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

    // 5. Fetch the output: copy the result from device memory to the host; adapt this to the specific API definition
    auto size = GetShapeSize(queryShape);
    // query was created as ACL_INT8 above, so read it back as int8_t rather than float
    std::vector<int8_t> resultData(size, 0);
    ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), queryDeviceAddr, size * sizeof(resultData[0]),
                      ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);
    for (int64_t i = 0; i < size; i++) {
      LOG_PRINT("result[%ld] is: %d\n", i, static_cast<int>(resultData[i]));
    }
    // 6. Destroy the aclTensor and aclScalar objects; adapt this to the specific API definition
    aclDestroyTensor(tokenX);
    aclDestroyTensor(weightDq);
    aclDestroyTensor(weightUqQr);
    aclDestroyTensor(weightUk);
    aclDestroyTensor(weightDkvKr);
    aclDestroyTensor(rmsnormGammaCq);
    aclDestroyTensor(rmsnormGammaCkv);
    aclDestroyTensor(ropeSin);
    aclDestroyTensor(ropeCos);
    aclDestroyTensor(cacheIndex);
    aclDestroyTensor(kvCache);
    aclDestroyTensor(krCache);
    aclDestroyTensor(dequantScaleX);
    aclDestroyTensor(dequantScaleWDq);
    aclDestroyTensor(dequantScaleWUqQr);
    aclDestroyTensor(dequantScaleWDkvKr);
    aclDestroyTensor(quantScaleCkv);
    aclDestroyTensor(smoothScalesCq);
    aclDestroyTensor(query);
    aclDestroyTensor(queryRope);
    aclDestroyTensor(dequantScaleQNope);

    // 7. Free device resources
    aclrtFree(tokenXDeviceAddr);
    aclrtFree(weightDqDeviceAddr);
    aclrtFree(weightUqQrDeviceAddr);
    aclrtFree(weightUkDeviceAddr);
    aclrtFree(weightDkvKrDeviceAddr);
    aclrtFree(rmsnormGammaCqDeviceAddr);
    aclrtFree(rmsnormGammaCkvDeviceAddr);
    aclrtFree(ropeSinDeviceAddr);
    aclrtFree(ropeCosDeviceAddr);
    aclrtFree(cacheIndexDeviceAddr);
    aclrtFree(kvCacheDeviceAddr);
    aclrtFree(krCacheDeviceAddr);
    aclrtFree(dequantScaleXDeviceAddr);
    aclrtFree(dequantScaleWDqDeviceAddr);
    aclrtFree(dequantScaleWUqQrDeviceAddr);
    aclrtFree(dequantScaleWDkvKrDeviceAddr);
    aclrtFree(quantScaleCkvDeviceAddr);
    aclrtFree(smoothScalesCqDeviceAddr);
    aclrtFree(queryDeviceAddr);
    aclrtFree(queryRopeDeviceAddr);
    aclrtFree(dequantScaleQNopeDeviceAddr);

    // 8. Free host resources (host buffers come from aclrtMallocHost, so release them with aclrtFreeHost)
    aclrtFreeHost(tokenXHostAddr);
    aclrtFreeHost(weightDqHostAddr);
    aclrtFreeHost(weightUqQrHostAddr);
    aclrtFreeHost(weightUkHostAddr);
    aclrtFreeHost(weightDkvKrHostAddr);
    aclrtFreeHost(rmsnormGammaCqHostAddr);
    aclrtFreeHost(rmsnormGammaCkvHostAddr);
    aclrtFreeHost(ropeSinHostAddr);
    aclrtFreeHost(ropeCosHostAddr);
    aclrtFreeHost(cacheIndexHostAddr);
    aclrtFreeHost(kvCacheHostAddr);
    aclrtFreeHost(krCacheHostAddr);
    aclrtFreeHost(dequantScaleXHostAddr);
    aclrtFreeHost(dequantScaleWDqHostAddr);
    aclrtFreeHost(dequantScaleWUqQrHostAddr);
    aclrtFreeHost(dequantScaleWDkvKrHostAddr);
    aclrtFreeHost(quantScaleCkvHostAddr);
    aclrtFreeHost(smoothScalesCqHostAddr);
    aclrtFreeHost(queryHostAddr);
    aclrtFreeHost(queryRopeHostAddr);
    aclrtFreeHost(dequantScaleQNopeHostAddr);

    if (workspaceSize > static_cast<uint64_t>(0)) {
        aclrtFree(workspaceAddr);
    }
    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();

    return 0;
}
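
The example converts the three weight shapes to NZ before creating the tensors. As a minimal sketch of that shape arithmetic, the helper below maps a 2-D ND shape `{h, w}` to `{w / w0, h / h0, h0, w0}` with `h0 = 16` and `w0 = 32 / typeSize`, which is the same rule the sample's `TransToNZShape` applies. `ToNzShape` is a hypothetical name introduced here for illustration and is not part of the ACL API. It only computes the shape and assumes both dimensions tile exactly; it does not rearrange the weight data.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the ACL API): computes the 4-D NZ shape
// {w / w0, h / h0, h0, w0} for a 2-D ND shape {h, w}.
// Assumption: h0 = 16 rows per fractal and 32 bytes per fractal row,
// matching the constants used in the sample code.
// Returns false when the input cannot be tiled exactly.
bool ToNzShape(const std::vector<int64_t>& shapeND, size_t typeSize, std::vector<int64_t>& shapeNZ) {
  constexpr int64_t kRows = 16;      // h0
  constexpr int64_t kRowBytes = 32;  // bytes per fractal row
  if (shapeND.size() != 2 || typeSize == 0 || typeSize > static_cast<size_t>(kRowBytes) ||
      kRowBytes % static_cast<int64_t>(typeSize) != 0) {
    return false;
  }
  const int64_t h = shapeND[0];
  const int64_t w = shapeND[1];
  const int64_t w0 = kRowBytes / static_cast<int64_t>(typeSize);
  assert(w0 > 0);
  if (h <= 0 || w <= 0 || h % kRows != 0 || w % w0 != 0) {
    return false;
  }
  shapeNZ = {w / w0, h / kRows, kRows, w0};
  return true;
}
```

For the `weightDq` shape `{7168, 1536}` with 1-byte elements, this yields `{48, 448, 16, 32}`, since 1536 / 32 = 48 and 7168 / 16 = 448.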

Sample code for Ascend 950PR/Ascend 950DT is shown below, for reference only.

#include <cstdio>
#include <iostream>
#include <cstring>
#include <vector>
#include <cstdint>
#include "acl/acl.h"
#include "aclnnop/aclnn_mla_prolog_v3_weight_nz.h"
#include <unistd.h>

#define CHECK_RET(cond, return_expr)    \
  do {                                \
      if (!(cond)) {                  \
          return_expr;                \
      }                               \
  } while (0)

#define LOG_PRINT(message, ...)         \
  do {                                \
      printf(message, ##__VA_ARGS__); \
  } while (0)

int64_t GetShapeSize(const std::vector<int64_t>& shape) {
  int64_t shape_size = 1;
  for (auto i : shape) {
      shape_size *= i;
  }
  return shape_size;
}

int Init(int32_t deviceId, aclrtStream* stream) {
  // Boilerplate: initialize resources
  auto ret = aclInit(nullptr);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
  ret = aclrtSetDevice(deviceId);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
  ret = aclrtCreateStream(stream);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
  return 0;
}

template <typename T>
int CreateAclTensorND(const std::vector<T>& shape, void** deviceAddr, void** hostAddr,
                  aclDataType dataType, aclTensor** tensor) {
  auto size = GetShapeSize(shape) * sizeof(T);
  // Call aclrtMalloc to allocate device memory
  auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
  // Call aclrtMallocHost to allocate host memory
  ret = aclrtMallocHost(hostAddr, size);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMallocHost failed. ERROR: %d\n", ret); return ret);
  memset(*hostAddr, 0, size);
  // Call aclCreateTensor to create the aclTensor
  *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, nullptr, 0, aclFormat::ACL_FORMAT_ND,
                            shape.data(), shape.size(), *deviceAddr);
  // Call aclrtMemcpy to copy the host data to device memory
  ret = aclrtMemcpy(*deviceAddr, size, *hostAddr, GetShapeSize(shape) * aclDataTypeSize(dataType), ACL_MEMCPY_HOST_TO_DEVICE);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
  return 0;
}

template <typename T>
int CreateAclTensorNZ(const std::vector<T>& shape, void** deviceAddr, void** hostAddr,
                  aclDataType dataType, aclTensor** tensor) {
  auto size = GetShapeSize(shape) * sizeof(T);
  // Call aclrtMalloc to allocate device memory
  auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
  // Call aclrtMallocHost to allocate host memory
  ret = aclrtMallocHost(hostAddr, size);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMallocHost failed. ERROR: %d\n", ret); return ret);
  memset(*hostAddr, 0, size);
  // Call aclCreateTensor to create the aclTensor
  *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, nullptr, 0, aclFormat::ACL_FORMAT_FRACTAL_NZ,
                            shape.data(), shape.size(), *deviceAddr);
  // Call aclrtMemcpy to copy the host data to device memory
  ret = aclrtMemcpy(*deviceAddr, size, *hostAddr, GetShapeSize(shape) * aclDataTypeSize(dataType), ACL_MEMCPY_HOST_TO_DEVICE);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
  return 0;
}

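int TransToNZShape(std::vector<int64_t> &shapeND, size_t typeSize) {
  // Expect a 2-D ND shape {h, w} and an element size that fits in the 32-byte fractal row
  if (shapeND.size() != static_cast<size_t>(2) || typeSize == static_cast<size_t>(0) ||
      typeSize > static_cast<size_t>(32)) {
      return -1;
  }
  int64_t h = shapeND[0];
  int64_t w = shapeND[1];
  int64_t h0 = static_cast<int64_t>(16);
  int64_t w0 = static_cast<int64_t>(32) / static_cast<int64_t>(typeSize);
  // The conversion below assumes both dimensions tile exactly
  if (h % h0 != 0 || w % w0 != 0) {
      return -1;
  }
  int64_t h1 = h / h0;
  int64_t w1 = w / w0;
  // Resulting NZ shape: {w1, h1, h0, w0}
  shapeND[0] = w1;
  shapeND[1] = h1;
  shapeND.emplace_back(h0);
  shapeND.emplace_back(w0);
  return 0;
}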

int main() {
  // 1. Boilerplate: initialize the device and stream; see the AscendCL external API list
  // Set deviceId according to your actual device
  int32_t deviceId = 0;
  aclrtStream stream;
  auto ret = Init(deviceId, &stream);
  // Handle the check as needed
  CHECK_RET(ret == 0, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);
  // 2. Construct the inputs and outputs according to the API definition
  std::vector<int64_t> tokenXShape = {8, 1, 7168};            // B,S,He
  std::vector<int64_t> weightDqShape = {7168, 1536};          // He,Hcq
  std::vector<int64_t> weightUqQrShape = {1536, 24576};        // Hcq,N*(D+Dr)
  std::vector<int64_t> weightUkShape = {128, 128, 512};        // N,D,Hckv
  std::vector<int64_t> weightDkvKrShape = {7168, 576};        // He,Hckv+Dr
  std::vector<int64_t> rmsnormGammaCqShape = {1536};          // Hcq
  std::vector<int64_t> rmsnormGammaCkvShape = {512};          // Hckv
  std::vector<int64_t> ropeSinShape = {8, 1, 64};             // B,S,Dr
  std::vector<int64_t> ropeCosShape = {8, 1, 64};             // B,S,Dr
  std::vector<int64_t> kvCacheShape = {1, 16, 1, 512};      // BlockNum,BlockSize,Nkv,Hckv
  std::vector<int64_t> krCacheShape = {1, 16, 1, 64};       // BlockNum,BlockSize,Nkv,Dr
  std::vector<int64_t> cacheIndexShape = {8, 1};              // B,S
  std::vector<int64_t> dequantScaleXShape = {8, 224};           // B*S, He/32
  std::vector<int64_t> dequantScaleWDqShape = {1536, 224};      // Hcq, He/32
  std::vector<int64_t> dequantScaleWUqQrShape = {24576, 48};    // N*(D+Dr), Hcq/32
  std::vector<int64_t> dequantScaleWDkvKrShape = {576, 224};    // Hckv+Dr, He/32
  std::vector<int64_t> quantScaleCkvShape = {1};    // 1
  std::vector<int64_t> queryShape = {8, 1, 128, 512};          // B,S,N,Hckv
  std::vector<int64_t> queryRopeShape = {8, 1, 128, 64};       // B,S,N,Dr
  std::vector<int64_t> dequantScaleQNopeShape = {8, 128, 1};   // B*S, N, 1
  double rmsnormEpsilonCq = 1e-5;
  double rmsnormEpsilonCkv = 1e-5;
  char cacheMode[] = "PA_BSND";

  void* tokenXDeviceAddr = nullptr;
  void* weightDqDeviceAddr = nullptr;
  void* weightUqQrDeviceAddr = nullptr;
  void* weightUkDeviceAddr = nullptr;
  void* weightDkvKrDeviceAddr = nullptr;
  void* rmsnormGammaCqDeviceAddr = nullptr;
  void* rmsnormGammaCkvDeviceAddr = nullptr;
  void* ropeSinDeviceAddr = nullptr;
  void* ropeCosDeviceAddr = nullptr;
  void* cacheIndexDeviceAddr = nullptr;
  void* kvCacheDeviceAddr = nullptr;
  void* krCacheDeviceAddr = nullptr;
  void* dequantScaleXDeviceAddr = nullptr;
  void* dequantScaleWDqDeviceAddr = nullptr;
  void* dequantScaleWUqQrDeviceAddr = nullptr;
  void* dequantScaleWDkvKrDeviceAddr = nullptr;
  void* quantScaleCkvDeviceAddr = nullptr;
  void* queryDeviceAddr = nullptr;
  void* queryRopeDeviceAddr = nullptr;
  void* dequantScaleQNopeDeviceAddr = nullptr;

  void* tokenXHostAddr = nullptr;
  void* weightDqHostAddr = nullptr;
  void* weightUqQrHostAddr = nullptr;
  void* weightUkHostAddr = nullptr;
  void* weightDkvKrHostAddr = nullptr;
  void* rmsnormGammaCqHostAddr = nullptr;
  void* rmsnormGammaCkvHostAddr = nullptr;
  void* ropeSinHostAddr = nullptr;
  void* ropeCosHostAddr = nullptr;
  void* cacheIndexHostAddr = nullptr;
  void* kvCacheHostAddr = nullptr;
  void* krCacheHostAddr = nullptr;
  void* dequantScaleXHostAddr = nullptr;
  void* dequantScaleWDqHostAddr = nullptr;
  void* dequantScaleWUqQrHostAddr = nullptr;
  void* dequantScaleWDkvKrHostAddr = nullptr;
  void* quantScaleCkvHostAddr = nullptr;
  void* queryHostAddr = nullptr;
  void* queryRopeHostAddr = nullptr;
  void* dequantScaleQNopeHostAddr = nullptr;

  aclTensor* tokenX = nullptr;
  aclTensor* weightDq = nullptr;
  aclTensor* weightUqQr = nullptr;
  aclTensor* weightUk = nullptr;
  aclTensor* weightDkvKr = nullptr;
  aclTensor* rmsnormGammaCq = nullptr;
  aclTensor* rmsnormGammaCkv = nullptr;
  aclTensor* ropeSin = nullptr;
  aclTensor* ropeCos = nullptr;
  aclTensor* kvCache = nullptr;
  aclTensor* krCache = nullptr;
  aclTensor* cacheIndex = nullptr;
  aclTensor* dequantScaleX = nullptr;
  aclTensor* dequantScaleWDq = nullptr;
  aclTensor* dequantScaleWUqQr = nullptr;
  aclTensor* dequantScaleWDkvKr = nullptr;
  aclTensor* quantScaleCkv = nullptr;
  int64_t weightQuantMode = 3;
  int64_t kvQuantMode = 1;
  int64_t queryQuantMode = 1;
  int64_t ckvkrRepoMode = 0;
  int64_t quantScaleRepoMode = 0;
  int64_t tileSize = 128;
  double qcQrScale = 1.0;
  double kcScale = 1.0;
  aclTensor* query = nullptr;
  aclTensor* queryRope = nullptr;
  aclTensor* dequantScaleQNope = nullptr;

  // Convert the shapes of the three NZ-format weights
  constexpr size_t EXAMPLE_INT8_SIZE = sizeof(int8_t);
  constexpr size_t EXAMPLE_BFLOAT16_SIZE = sizeof(int16_t);
  ret = TransToNZShape(weightDqShape, EXAMPLE_INT8_SIZE);
  CHECK_RET(ret == 0, LOG_PRINT("trans NZ shape failed.\n"); return ret);
  ret = TransToNZShape(weightUqQrShape, EXAMPLE_INT8_SIZE);
  CHECK_RET(ret == 0, LOG_PRINT("trans NZ shape failed.\n"); return ret);
  ret = TransToNZShape(weightDkvKrShape, EXAMPLE_INT8_SIZE);
  CHECK_RET(ret == 0, LOG_PRINT("trans NZ shape failed.\n"); return ret);

  // Create the tokenX aclTensor
  ret = CreateAclTensorND(tokenXShape, &tokenXDeviceAddr, &tokenXHostAddr, aclDataType::ACL_FLOAT8_E4M3FN, &tokenX);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the weightDq aclTensor
  ret = CreateAclTensorNZ(weightDqShape, &weightDqDeviceAddr, &weightDqHostAddr, aclDataType::ACL_FLOAT8_E4M3FN, &weightDq);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the weightUqQr aclTensor
  ret = CreateAclTensorNZ(weightUqQrShape, &weightUqQrDeviceAddr, &weightUqQrHostAddr, aclDataType::ACL_FLOAT8_E4M3FN, &weightUqQr);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the weightUk aclTensor
  ret = CreateAclTensorND(weightUkShape, &weightUkDeviceAddr, &weightUkHostAddr, aclDataType::ACL_BF16, &weightUk);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the weightDkvKr aclTensor
  ret = CreateAclTensorNZ(weightDkvKrShape, &weightDkvKrDeviceAddr, &weightDkvKrHostAddr, aclDataType::ACL_FLOAT8_E4M3FN, &weightDkvKr);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the rmsnormGammaCq aclTensor
  ret = CreateAclTensorND(rmsnormGammaCqShape, &rmsnormGammaCqDeviceAddr, &rmsnormGammaCqHostAddr, aclDataType::ACL_BF16, &rmsnormGammaCq);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the rmsnormGammaCkv aclTensor
  ret = CreateAclTensorND(rmsnormGammaCkvShape, &rmsnormGammaCkvDeviceAddr, &rmsnormGammaCkvHostAddr, aclDataType::ACL_BF16, &rmsnormGammaCkv);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the ropeSin aclTensor
  ret = CreateAclTensorND(ropeSinShape, &ropeSinDeviceAddr, &ropeSinHostAddr, aclDataType::ACL_BF16, &ropeSin);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the ropeCos aclTensor
  ret = CreateAclTensorND(ropeCosShape, &ropeCosDeviceAddr, &ropeCosHostAddr, aclDataType::ACL_BF16, &ropeCos);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the cacheIndex aclTensor
  ret = CreateAclTensorND(cacheIndexShape, &cacheIndexDeviceAddr, &cacheIndexHostAddr, aclDataType::ACL_INT64, &cacheIndex);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the kvCache aclTensor
  ret = CreateAclTensorND(kvCacheShape, &kvCacheDeviceAddr, &kvCacheHostAddr, aclDataType::ACL_FLOAT8_E4M3FN, &kvCache);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the krCache aclTensor
  ret = CreateAclTensorND(krCacheShape, &krCacheDeviceAddr, &krCacheHostAddr, aclDataType::ACL_BF16, &krCache);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the dequantScaleX aclTensor
  ret = CreateAclTensorND(dequantScaleXShape, &dequantScaleXDeviceAddr, &dequantScaleXHostAddr, aclDataType::ACL_FLOAT8_E8M0, &dequantScaleX);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the dequantScaleWDq aclTensor
  ret = CreateAclTensorND(dequantScaleWDqShape, &dequantScaleWDqDeviceAddr, &dequantScaleWDqHostAddr, aclDataType::ACL_FLOAT8_E8M0, &dequantScaleWDq);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the dequantScaleWUqQr aclTensor
  ret = CreateAclTensorND(dequantScaleWUqQrShape, &dequantScaleWUqQrDeviceAddr, &dequantScaleWUqQrHostAddr, aclDataType::ACL_FLOAT8_E8M0, &dequantScaleWUqQr);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the dequantScaleWDkvKr aclTensor
  ret = CreateAclTensorND(dequantScaleWDkvKrShape, &dequantScaleWDkvKrDeviceAddr, &dequantScaleWDkvKrHostAddr, aclDataType::ACL_FLOAT8_E8M0, &dequantScaleWDkvKr);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the quantScaleCkv aclTensor
  ret = CreateAclTensorND(quantScaleCkvShape, &quantScaleCkvDeviceAddr, &quantScaleCkvHostAddr, aclDataType::ACL_FLOAT, &quantScaleCkv);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the query aclTensor
  ret = CreateAclTensorND(queryShape, &queryDeviceAddr, &queryHostAddr, aclDataType::ACL_FLOAT8_E4M3FN, &query);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the queryRope aclTensor
  ret = CreateAclTensorND(queryRopeShape, &queryRopeDeviceAddr, &queryRopeHostAddr, aclDataType::ACL_BF16, &queryRope);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // Create the dequantScaleQNope aclTensor
  ret = CreateAclTensorND(dequantScaleQNopeShape, &dequantScaleQNopeDeviceAddr, &dequantScaleQNopeHostAddr, aclDataType::ACL_FLOAT, &dequantScaleQNope);
  CHECK_RET(ret == ACL_SUCCESS, return ret);


  // 3. Call the CANN operator library API; replace it with the specific API as needed
  uint64_t workspaceSize = 0;
  aclOpExecutor* executor = nullptr;
  // Call the first-stage interface of aclnnMlaPrologV3WeightNz
  ret = aclnnMlaPrologV3WeightNzGetWorkspaceSize(tokenX, weightDq, weightUqQr, weightUk, weightDkvKr, rmsnormGammaCq, rmsnormGammaCkv, ropeSin, ropeCos, kvCache, krCache, cacheIndex,
      dequantScaleX, dequantScaleWDq, dequantScaleWUqQr, dequantScaleWDkvKr, quantScaleCkv, nullptr, nullptr, nullptr, nullptr, rmsnormEpsilonCq, rmsnormEpsilonCkv, cacheMode,
      weightQuantMode, kvQuantMode, queryQuantMode, ckvkrRepoMode, quantScaleRepoMode, tileSize, qcQrScale, kcScale,
      query, queryRope, dequantScaleQNope, nullptr, nullptr, &workspaceSize, &executor);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnMlaPrologV3WeightNzGetWorkspaceSize failed. ERROR: %d\n", ret); return ret);
  // Allocate device memory according to the workspaceSize computed by the first-stage interface
  void* workspaceAddr = nullptr;
  if (workspaceSize > static_cast<uint64_t>(0)) {
      ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret;);
  }
  // Call the second-stage interface of aclnnMlaPrologV3WeightNz
  ret = aclnnMlaPrologV3WeightNz(workspaceAddr, workspaceSize, executor, stream);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnMlaPrologV3WeightNz failed. ERROR: %d\n", ret); return ret);

  // 4. Boilerplate: synchronize and wait for the task to finish
  ret = aclrtSynchronizeStream(stream);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

  // 5. Fetch the output: copy the result from device memory to the host; adapt this to the specific API definition
  auto size = GetShapeSize(queryShape);
  // query was created as ACL_FLOAT8_E4M3FN (1 byte per element), so read back the raw bytes rather than float
  std::vector<uint8_t> resultData(size, 0);
  ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), queryDeviceAddr, size * sizeof(resultData[0]),
                    ACL_MEMCPY_DEVICE_TO_HOST);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);
  for (int64_t i = 0; i < size; i++) {
      LOG_PRINT("result[%ld] is: 0x%02x\n", i, static_cast<unsigned int>(resultData[i]));
  }
  // 6. Destroy the aclTensor and aclScalar objects; adapt this to the specific API definition
  aclDestroyTensor(tokenX);
  aclDestroyTensor(weightDq);
  aclDestroyTensor(weightUqQr);
  aclDestroyTensor(weightUk);
  aclDestroyTensor(weightDkvKr);
  aclDestroyTensor(rmsnormGammaCq);
  aclDestroyTensor(rmsnormGammaCkv);
  aclDestroyTensor(ropeSin);
  aclDestroyTensor(ropeCos);
  aclDestroyTensor(cacheIndex);
  aclDestroyTensor(kvCache);
  aclDestroyTensor(krCache);
  aclDestroyTensor(dequantScaleX);
  aclDestroyTensor(dequantScaleWDq);
  aclDestroyTensor(dequantScaleWUqQr);
  aclDestroyTensor(dequantScaleWDkvKr);
  aclDestroyTensor(quantScaleCkv);
  aclDestroyTensor(query);
  aclDestroyTensor(queryRope);
  aclDestroyTensor(dequantScaleQNope);

  // 7. Free device resources
  aclrtFree(tokenXDeviceAddr);
  aclrtFree(weightDqDeviceAddr);
  aclrtFree(weightUqQrDeviceAddr);
  aclrtFree(weightUkDeviceAddr);
  aclrtFree(weightDkvKrDeviceAddr);
  aclrtFree(rmsnormGammaCqDeviceAddr);
  aclrtFree(rmsnormGammaCkvDeviceAddr);
  aclrtFree(ropeSinDeviceAddr);
  aclrtFree(ropeCosDeviceAddr);
  aclrtFree(cacheIndexDeviceAddr);
  aclrtFree(kvCacheDeviceAddr);
  aclrtFree(krCacheDeviceAddr);
  aclrtFree(dequantScaleXDeviceAddr);
  aclrtFree(dequantScaleWDqDeviceAddr);
  aclrtFree(dequantScaleWUqQrDeviceAddr);
  aclrtFree(dequantScaleWDkvKrDeviceAddr);
  aclrtFree(quantScaleCkvDeviceAddr);
  aclrtFree(queryDeviceAddr);
  aclrtFree(queryRopeDeviceAddr);
  aclrtFree(dequantScaleQNopeDeviceAddr);

  // 8. Free host resources (host buffers come from aclrtMallocHost, so release them with aclrtFreeHost)
  aclrtFreeHost(tokenXHostAddr);
  aclrtFreeHost(weightDqHostAddr);
  aclrtFreeHost(weightUqQrHostAddr);
  aclrtFreeHost(weightUkHostAddr);
  aclrtFreeHost(weightDkvKrHostAddr);
  aclrtFreeHost(rmsnormGammaCqHostAddr);
  aclrtFreeHost(rmsnormGammaCkvHostAddr);
  aclrtFreeHost(ropeSinHostAddr);
  aclrtFreeHost(ropeCosHostAddr);
  aclrtFreeHost(cacheIndexHostAddr);
  aclrtFreeHost(kvCacheHostAddr);
  aclrtFreeHost(krCacheHostAddr);
  aclrtFreeHost(dequantScaleXHostAddr);
  aclrtFreeHost(dequantScaleWDqHostAddr);
  aclrtFreeHost(dequantScaleWUqQrHostAddr);
  aclrtFreeHost(dequantScaleWDkvKrHostAddr);
  aclrtFreeHost(quantScaleCkvHostAddr);
  aclrtFreeHost(queryHostAddr);
  aclrtFreeHost(queryRopeHostAddr);
  aclrtFreeHost(dequantScaleQNopeHostAddr);

  if (workspaceSize > static_cast<uint64_t>(0)) {
      aclrtFree(workspaceAddr);
  }
  aclrtDestroyStream(stream);
  aclrtResetDevice(deviceId);
  aclFinalize();

  // _exit() does not flush stdio buffers, so flush the printed results first
  fflush(stdout);
  _exit(0);
}
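
In the Ascend 950PR/Ascend 950DT example the `query` output is created as `ACL_FLOAT8_E4M3FN`, so each element occupies one byte. To inspect such values on the host, a byte can be decoded to `float` as sketched below. `DecodeFp8E4M3Fn` is a hypothetical helper introduced for illustration and is not part of the ACL API. It assumes the common E4M3FN bit layout (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, NaN encoded as exponent and mantissa all ones). The decoded value is still the quantized one; recovering the original magnitude would additionally require the matching `dequantScaleQNope` entry.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper (not part of the ACL API): decodes one FP8 E4M3FN byte.
// Assumed layout: [sign:1][exponent:4, bias 7][mantissa:3].
float DecodeFp8E4M3Fn(uint8_t bits) {
  const int sign = (bits >> 7) & 0x1;
  const int exponent = (bits >> 3) & 0xF;
  const int mantissa = bits & 0x7;
  // E4M3FN has no infinities; exponent and mantissa all ones encode NaN.
  if (exponent == 0xF && mantissa == 0x7) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  float value = 0.0f;
  if (exponent == 0) {
    // Subnormal: (mantissa / 8) * 2^(1 - 7) = mantissa * 2^-9
    value = std::ldexp(static_cast<float>(mantissa), -9);
  } else {
    // Normal: (1 + mantissa / 8) * 2^(exponent - 7)
    value = std::ldexp(1.0f + static_cast<float>(mantissa) / 8.0f, exponent - 7);
  }
  return sign ? -value : value;
}
```

Under these assumptions, the byte `0x38` decodes to 1.0 and `0x7E` to 448, the largest finite E4M3FN value.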