QkvRmsNormRopeCache

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:输入qkv融合张量,通过SplitVD拆分q、k、v张量,执行RmsNorm、ApplyRotaryPosEmb、Quant、Scatter融合操作,输出q_out、k_cache、v_cache、q_out_before_quant(可选)、k_out_before_quant(可选)、v_out_before_quant(可选)。

  • 本算子目前支持的场景如下表:

    场景类型 情况概要
    • cache_mode为PA_NZ
    • q无量化
    • k和v支持无量化、对称量化和非对称量化
    • q_out_before_quant/k_out_before_quant/v_out_before_quant不输出
    qkv Shape为[BqkvB_{qkv} * SqkvS_{qkv}, NqkvN_{qkv} * DqkvD_{qkv}],q、k、v具有完全相同的D维度。主要计算过程与输出对应关系:
    • qkv 经过SplitVD->q、k、v
    • q经过RmsNorm、RoPE->q_out
    • k经过RmsNorm、RoPE、Quant(可选)、Scatter->k_cache
    • v经过Quant(可选)、Scatter->v_cache
  • 计算公式:

    (1) SplitVD:

    下式中,NqN_qNkN_kNvN_v分别表示q、k、v分量的注意力头数量,必须满足:

    {Nk=NvNqkv=Nk+Nv+NqDqkv=Dq=Dk=Dv\begin{cases} N_k = N_v \\ N_{qkv} = N_k + N_v + N_q \\ D_{qkv} = D_q = D_k = D_v \end{cases}

    q=qkv[...,[:Nq]∗Dqkv]k=qkv[...,[Nq:−Nv]∗Dqkv]v=qkv[...,[−Nv:]∗Dqkv]\begin{aligned} q &= qkv[..., [:N_q] * D_{qkv}] \\ k &= qkv[..., [N_q:-N_v] * D_{qkv}] \\ v &= qkv[..., [-N_v:] * D_{qkv}] \end{aligned}

    (2) RmsNorm:

    此处x和y分别表示RmsNorm的输入张量和输出张量,归一化沿最后一维(feature dimension)进行,该计算规则通用于q、k分量。

    squareX=x∗xsquareX = x * x

    meanSquareX=squareX.mean(dim=−1,keepdim=True)meanSquareX = squareX.mean(dim = -1, keepdim = True)

    rms=meanSquareX+epsilonrms = \sqrt{meanSquareX + epsilon}

    y=(x/rms)∗gammay = (x / rms) * gamma

    (3) RoPE (Half-and-Half):

    此处的y指代完成RmsNorm计算的输出结果。

    y1=y[…,:d/2]y1 = y[\ldots, :d/2]

    y2=y[…,d/2:]y2 = y[\ldots, d/2:]

    y_RoPE=torch.cat((−y2,y1),dim=−1)y\_RoPE = torch.cat((-y2, y1), dim = -1)

    y_embed=(y∗cos)+y_RoPE∗siny\_embed = (y * cos) + y\_RoPE * sin

    (4) Quant:

    无量化:

    kQuant=kRoPEvQuant=vkQuant = kRoPE \\ vQuant = v

    对称量化部分:

    kQuant=kRoPE/kScalevQuant=v/vScalekQuant = kRoPE / kScale \\ vQuant = v / vScale

    非对称量化部分:

    kQuant=kRoPE/kScale+kOffsetvQuant=v/vScale+vOffsetkQuant = kRoPE / kScale + kOffset \\ vQuant = v / vScale + vOffset

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
qkv 输入 用于切分出q、k、v的输入数据,对应公式中的qkv。shape为[Bqkv * Sqkv, Nqkv * Dqkv]。 FLOAT16、BFLOAT16 ND
q_gamma 输入 用于q的rms_norm计算的输入数据,对应公式中的gamma。与输入qkv的数据类型相同,shape为[Dqkv]。 FLOAT16、BFLOAT16 ND
k_gamma 输入 用于k的rms_norm计算的输入数据,对应公式中的gamma。与输入qkv的数据类型相同,shape为[Dqkv]。 FLOAT16、BFLOAT16 ND
cos 输入 用于rope计算的输入数据,对输入张量进行余弦变换,对应公式中的cos。与输入qkv的数据类型相同,shape为[Bqkv * Sqkv, 1 * D_rope],D_rope = Dqkv且要求(Dqkv * qkv数据类型所占字节数)可以被32整除。 FLOAT16、BFLOAT16 ND
sin 输入 用于rope计算的输入数据,对输入张量进行正弦变换,对应公式中的sin。与输入cos的数据类型、格式保持一致。 FLOAT16、BFLOAT16 ND
index 输入 用于指定写入cache的具体索引位置。shape为[Bqkv * Sqkv]。 INT64 ND
q_out 输入/输出 提前申请的cache,输入输出同地址复用。与输入qkv的数据类型相同,shape为[Bqkv * Sqkv, Nq * Dqkv]。 FLOAT16、BFLOAT16 ND
k_cache 输入/输出 提前申请的cache,输入输出同地址复用。与输入qkv的数据类型相同(k不量化),或者INT8(k量化)。shape为[BlockNum, Nk * Dqkv // 16, BlockSize, 16](k不量化),或者[BlockNum, Nk * Dqkv // 32, BlockSize, 32](k量化)。 FLOAT16、BFLOAT16、INT8 ND
v_cache 输入/输出 提前申请的cache,输入输出同地址复用。与输入qkv的数据类型相同(v不量化),或者INT8(v量化)。shape为[BlockNum, Nk * Dqkv // 16, BlockSize, 16](v不量化),或者[BlockNum, Nk * Dqkv // 32, BlockSize, 32](v量化)。 FLOAT16、BFLOAT16、INT8 ND
k_scale 可选输入 当k_cache数据类型为INT8时需要此输入参数,对应公式中的kScale。shape为[Nk, Dqkv]。 FLOAT32 ND
v_scale 可选输入 当v_cache数据类型为INT8时需要此输入参数,对应公式中的vScale。shape为[Nv, Dqkv]。 FLOAT32 ND
k_offset 可选输入 当k_cache数据类型为INT8且对应的k_scale输入存在并量化场景为非对称量化时,需要此参数输入,对应公式中的kOffset。shape为[Nk, Dqkv]。 FLOAT32 ND
v_offset 可选输入 当v_cache数据类型为INT8且对应的v_scale输入存在并量化场景为非对称量化时,需要此参数输入,对应公式中的vOffset。shape为[Nv, Dqkv]。 FLOAT32 ND
q_out_before_quant 可选输出 即将写入到q_out中的数据。 FLOAT16、BFLOAT16 ND
k_out_before_quant 可选输出 即将写入到k_cache中的数据,在未经量化和Scatter前的中间计算结果。 FLOAT16、BFLOAT16 ND
v_out_before_quant 可选输出 即将写入到v_cache中的数据,在未经量化和Scatter前的中间计算结果。 FLOAT16、BFLOAT16 ND
qkv_size 属性 按[Bqkv, Sqkv, Nqkv, Dqkv]顺序传入,提供输入参数qkv矩阵的B,S,N,D维度具体尺寸。 INT64 -
head_nums 属性 按[Nq, Nk, Nv]顺序传入,提供输入参数qkv矩阵中,qkv分量单元中分的N维度具体尺寸。 INT64 -
epsilon 可选属性
  • 用于防止RmsNorm计算除0错误,对应公式中的epsilon。
  • 默认值为1e-6。
FLOAT32 -
cache_mode 可选属性
  • cache格式的选择标记,目前只支持PA_NZ。
  • 默认值为PA_NZ。
CHAR* -
is_output_qkv 可选属性
  • 表示是否需要输出各cache输出中对应内容在未经量化和Scatter前的原始值。
  • 默认值为false。
BOOL -

约束说明

  • 输入shape限制:
    • Bqkv为输入qkv的batch_size,Sqkv为输入qkv的sequence length,大小由qkvSize决定。
    • Nqkv为输入qkv的head number。Dqkv为输入qkv的head dim,目前仅支持128。Dq、Dk和Dk分别为q、k、v的head dim,要求Dqkv = Dq = Dk = Dv,Dqkv需要满足(Dqkv*qkv数据类型占字节数)可以被32整除。
    • 根据rope规则,Dk和Dq为偶数。若cache_mode为PA_NZ场景下,Dk、Dq需32B对齐;BlockSize需32B对齐。
    • 关于上述32B对齐的情形,对齐值由cache的数据类型决定。以BlockSize为例,若cache的数据类型为int8,则需要满足BlockSize % 32 = 0;若cache的数据类型为float16,则需要满足BlockSize % 16 = 0;若k_cache与v_cache参数的dtype不一致,BlockSize需同时满足BlockSize % 32 = 0和BlockSize % 16 = 0。
    • BlockNum为写入cache的内存块数,大小由用户输入场景决定,要求BlockNum >= Ceil(Sqkv / BlockSize) * Bqkv
    • 使用requireMemory表示存放数据所需的空间大小,需满足:requireMemory >= (Bqkv * Sqkv * Nqkv * Dqkv + 2 * Dqkv + 2 * Bqkv * Sqkv * Dqkv + Bqkv * Sqkv * Nq * Dqkv + BlockNum * BlockSize * Nv * Dqkv + BlockNum * BlockSize * Nk * Dqkv) * sizeof(FLOAT16) + Bqkv * Sqkv * sizeof(INT64) + (2 * Nk * Dqkv + 2 * Nv) * sizeof(FLOAT),当计算出requireMemory的大小超过当前AI处理器的GM空间总大小,不支持使用该算子。
  • 其他限制:
    • 对于index,要求index的value值范围为[-1, BlockNum * BlockSize)。value数值不可以重复,index为-1时,代表跳过更新。
    • k_scale, v_scale表示对称量化的缩放因子,因此若传参,则值不能为0。

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_qkv_rms_norm_rope_cache 通过aclnnQkvRmsNormRopeCache接口方式调用QkvRmsNormRopeCache算子。