QkvRmsNormRopeCache
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | × |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:输入qkv融合张量,通过SplitVD拆分q、k、v张量,执行RmsNorm、ApplyRotaryPosEmb、Quant、Scatter融合操作,输出q_out、k_cache、v_cache、q_out_before_quant(可选)、k_out_before_quant(可选)、v_out_before_quant(可选)。
-
本算子目前支持的场景如下表:
场景类型 情况概要 - cache_mode为PA_NZ
- q无量化
- k和v支持无量化、对称量化和非对称量化
- q_out_before_quant/k_out_before_quant/v_out_before_quant不输出
qkv Shape为[BqkvB_{qkv} * SqkvS_{qkv}, NqkvN_{qkv} * DqkvD_{qkv}],q、k、v具有完全相同的D维度。主要计算过程与输出对应关系: - qkv 经过SplitVD->q、k、v
- q经过RmsNorm、RoPE->q_out
- k经过RmsNorm、RoPE、Quant(可选)、Scatter->k_cache
- v经过Quant(可选)、Scatter->v_cache
-
计算公式:
(1) SplitVD:
下式中,NqN_q、NkN_k、NvN_v分别表示q、k、v分量的注意力头数量,必须满足:
{Nk=NvNqkv=Nk+Nv+NqDqkv=Dq=Dk=Dv\begin{cases} N_k = N_v \\ N_{qkv} = N_k + N_v + N_q \\ D_{qkv} = D_q = D_k = D_v \end{cases}
q=qkv[...,[:Nq]∗Dqkv]k=qkv[...,[Nq:−Nv]∗Dqkv]v=qkv[...,[−Nv:]∗Dqkv]\begin{aligned} q &= qkv[..., [:N_q] * D_{qkv}] \\ k &= qkv[..., [N_q:-N_v] * D_{qkv}] \\ v &= qkv[..., [-N_v:] * D_{qkv}] \end{aligned}
(2) RmsNorm:
此处x和y分别表示RmsNorm的输入张量和输出张量,归一化沿最后一维(feature dimension)进行,该计算规则通用于q、k分量。
squareX=x∗xsquareX = x * x
meanSquareX=squareX.mean(dim=−1,keepdim=True)meanSquareX = squareX.mean(dim = -1, keepdim = True)
rms=meanSquareX+epsilonrms = \sqrt{meanSquareX + epsilon}
y=(x/rms)∗gammay = (x / rms) * gamma
(3) RoPE (Half-and-Half):
此处的y指代完成RmsNorm计算的输出结果。
y1=y[…,:d/2]y1 = y[\ldots, :d/2]
y2=y[…,d/2:]y2 = y[\ldots, d/2:]
y_RoPE=torch.cat((−y2,y1),dim=−1)y\_RoPE = torch.cat((-y2, y1), dim = -1)
y_embed=(y∗cos)+y_RoPE∗siny\_embed = (y * cos) + y\_RoPE * sin
(4) Quant:
无量化:
kQuant=kRoPEvQuant=vkQuant = kRoPE \\ vQuant = v
对称量化部分:
kQuant=kRoPE/kScalevQuant=v/vScalekQuant = kRoPE / kScale \\ vQuant = v / vScale
非对称量化部分:
kQuant=kRoPE/kScale+kOffsetvQuant=v/vScale+vOffsetkQuant = kRoPE / kScale + kOffset \\ vQuant = v / vScale + vOffset
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| qkv | 输入 | 用于切分出q、k、v的输入数据,对应公式中的qkv。shape为[Bqkv * Sqkv, Nqkv * Dqkv]。 | FLOAT16、BFLOAT16 | ND |
| q_gamma | 输入 | 用于q的rms_norm计算的输入数据,对应公式中的gamma。与输入qkv的数据类型相同,shape为[Dqkv]。 | FLOAT16、BFLOAT16 | ND |
| k_gamma | 输入 | 用于k的rms_norm计算的输入数据,对应公式中的gamma。与输入qkv的数据类型相同,shape为[Dqkv]。 | FLOAT16、BFLOAT16 | ND |
| cos | 输入 | 用于rope计算的输入数据,对输入张量进行余弦变换,对应公式中的cos。与输入qkv的数据类型相同,shape为[Bqkv * Sqkv, 1 * D_rope],D_rope = Dqkv且要求(Dqkv * qkv数据类型所占字节数)可以被32整除。 | FLOAT16、BFLOAT16 | ND |
| sin | 输入 | 用于rope计算的输入数据,对输入张量进行正弦变换,对应公式中的sin。与输入cos的数据类型、格式保持一致。 | FLOAT16、BFLOAT16 | ND |
| index | 输入 | 用于指定写入cache的具体索引位置。shape为[Bqkv * Sqkv]。 | INT64 | ND |
| q_out | 输入/输出 | 提前申请的cache,输入输出同地址复用。与输入qkv的数据类型相同,shape为[Bqkv * Sqkv, Nq * Dqkv]。 | FLOAT16、BFLOAT16 | ND |
| k_cache | 输入/输出 | 提前申请的cache,输入输出同地址复用。与输入qkv的数据类型相同(k不量化),或者INT8(k量化)。shape为[BlockNum, Nk * Dqkv // 16, BlockSize, 16](k不量化),或者[BlockNum, Nk * Dqkv // 32, BlockSize, 32](k量化)。 | FLOAT16、BFLOAT16、INT8 | ND |
| v_cache | 输入/输出 | 提前申请的cache,输入输出同地址复用。与输入qkv的数据类型相同(v不量化),或者INT8(v量化)。shape为[BlockNum, Nk * Dqkv // 16, BlockSize, 16](v不量化),或者[BlockNum, Nk * Dqkv // 32, BlockSize, 32](v量化)。 | FLOAT16、BFLOAT16、INT8 | ND |
| k_scale | 可选输入 | 当k_cache数据类型为INT8时需要此输入参数,对应公式中的kScale。shape为[Nk, Dqkv]。 | FLOAT32 | ND |
| v_scale | 可选输入 | 当v_cache数据类型为INT8时需要此输入参数,对应公式中的vScale。shape为[Nv, Dqkv]。 | FLOAT32 | ND |
| k_offset | 可选输入 | 当k_cache数据类型为INT8且对应的k_scale输入存在并量化场景为非对称量化时,需要此参数输入,对应公式中的kOffset。shape为[Nk, Dqkv]。 | FLOAT32 | ND |
| v_offset | 可选输入 | 当v_cache数据类型为INT8且对应的v_scale输入存在并量化场景为非对称量化时,需要此参数输入,对应公式中的vOffset。shape为[Nv, Dqkv]。 | FLOAT32 | ND |
| q_out_before_quant | 可选输出 | 即将写入到q_out中的数据。 | FLOAT16、BFLOAT16 | ND |
| k_out_before_quant | 可选输出 | 即将写入到k_cache中的数据,在未经量化和Scatter前的中间计算结果。 | FLOAT16、BFLOAT16 | ND |
| v_out_before_quant | 可选输出 | 即将写入到v_cache中的数据,在未经量化和Scatter前的中间计算结果。 | FLOAT16、BFLOAT16 | ND |
| qkv_size | 属性 | 按[Bqkv, Sqkv, Nqkv, Dqkv]顺序传入,提供输入参数qkv矩阵的B,S,N,D维度具体尺寸。 | INT64 | - |
| head_nums | 属性 | 按[Nq, Nk, Nv]顺序传入,提供输入参数qkv矩阵中,qkv分量单元中分的N维度具体尺寸。 | INT64 | - |
| epsilon | 可选属性 |
|
FLOAT32 | - |
| cache_mode | 可选属性 |
|
CHAR* | - |
| is_output_qkv | 可选属性 |
|
BOOL | - |
约束说明
- 输入shape限制:
- Bqkv为输入qkv的batch_size,Sqkv为输入qkv的sequence length,大小由qkvSize决定。
- Nqkv为输入qkv的head number。Dqkv为输入qkv的head dim,目前仅支持128。Dq、Dk和Dk分别为q、k、v的head dim,要求Dqkv = Dq = Dk = Dv,Dqkv需要满足(Dqkv*qkv数据类型占字节数)可以被32整除。
- 根据rope规则,Dk和Dq为偶数。若cache_mode为PA_NZ场景下,Dk、Dq需32B对齐;BlockSize需32B对齐。
- 关于上述32B对齐的情形,对齐值由cache的数据类型决定。以BlockSize为例,若cache的数据类型为int8,则需要满足BlockSize % 32 = 0;若cache的数据类型为float16,则需要满足BlockSize % 16 = 0;若k_cache与v_cache参数的dtype不一致,BlockSize需同时满足BlockSize % 32 = 0和BlockSize % 16 = 0。
- BlockNum为写入cache的内存块数,大小由用户输入场景决定,要求BlockNum >= Ceil(Sqkv / BlockSize) * Bqkv。
- 使用requireMemory表示存放数据所需的空间大小,需满足:requireMemory >= (Bqkv * Sqkv * Nqkv * Dqkv + 2 * Dqkv + 2 * Bqkv * Sqkv * Dqkv + Bqkv * Sqkv * Nq * Dqkv + BlockNum * BlockSize * Nv * Dqkv + BlockNum * BlockSize * Nk * Dqkv) * sizeof(FLOAT16) + Bqkv * Sqkv * sizeof(INT64) + (2 * Nk * Dqkv + 2 * Nv) * sizeof(FLOAT),当计算出requireMemory的大小超过当前AI处理器的GM空间总大小,不支持使用该算子。
- 其他限制:
- 对于index,要求index的value值范围为[-1, BlockNum * BlockSize)。value数值不可以重复,index为-1时,代表跳过更新。
- k_scale, v_scale表示对称量化的缩放因子,因此若传参,则值不能为0。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_qkv_rms_norm_rope_cache | 通过aclnnQkvRmsNormRopeCache接口方式调用QkvRmsNormRopeCache算子。 |