KvRmsNormRopeCache

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×
Kirin X90 处理器系列产品
Kirin 9030 处理器系列产品

功能说明

  • 算子功能:对输入张量(kv)的尾轴,拆分出左半边用于rms_norm计算,右半边用于rope计算,再将计算结果分别scatter到两块cache中。

  • 计算公式:

    (1) interleaveRope:

    x=kv[...,Dv:]x=kv[...,Dv:]

    x1=x[...,::2]x1=x[...,::2]

    x2=x[...,1::2]x2=x[...,1::2]

    x_part1=torch.cat((x1,x2),dim=−1)x\_part1=torch.cat((x1,x2),dim=-1)

    x_part2=torch.cat((−x2,x1),dim=−1)x\_part2=torch.cat((-x2,x1),dim=-1)

    y=x_part1∗cos+x_part2∗siny=x\_part1*cos+x\_part2*sin

    (2) rmsNorm:

    x=kv[...,:Dv]x=kv[...,:Dv]

    square_x=x∗xsquare\_x=x*x

    mean_square_x=square_x.mean(dim=−1,keepdim=True)mean\_square\_x=square\_x.mean(dim=-1,keepdim=True)

    rms=torch.sqrt(mean_square_x+epsilon)rms=torch.sqrt(mean\_square\_x+epsilon)

    y=(x/rms)∗gammay=(x/rms)*gamma

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
kv 输入 用于切分出rms_norm计算所需数据Dv和rope计算所需数据Dk的输入数据,对应公式中的`kv`。 FLOAT16、BFLOAT16 ND
gamma 输入 用于rms_norm计算的输入数据,对应公式中的`gamma`。 FLOAT16、BFLOAT16 ND
cos 输入 用于rope计算的输入数据,对输入张量进行余弦变换,对应公式中的`cos`。 FLOAT16、BFLOAT16 ND
sin 输入 用于rope计算的输入数据,对输入张量进行正弦变换,对应公式中的`sin`。 FLOAT16、BFLOAT16 ND
index 输入 用于指定写入cache的具体索引位置。 INT64 ND
k_cache 输入/输出 提前申请的cache,输入输出同地址复用。 FLOAT16、BFLOAT16、INT8、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN ND
ckv_cache 输入/输出 提前申请的cache,输入输出同地址复用。 FLOAT16、BFLOAT16、INT8、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN ND
k_rope_scale 可选属性 当kCacheRef数据类型为INT8、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN时需要此输入参数。 FLOAT32 ND
c_kv_scale 可选属性 当ckv_cache数据类型为INT8、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN时需要此输入参数。 FLOAT32 ND
k_rope_offset 可选属性 当k_cache数据类型为INT8、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN且对应的k_rope_scale输入存在并量化场景为非对称量化时,需要此参数输入。 FLOAT32 ND
c_kv_offset 可选属性 当ckv_cache数据类型为INT8、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN且对应的c_kv_scale输入存在并量化场景为非对称量化时,需要此参数输入。 FLOAT32 ND
epsilon 可选属性
  • 用于防止rms_norm计算除0错误,对应公式中的eps。
  • 默认值为1e-5。
FLOAT32 -
cache_mode 可选属性 cache格式的选择标记。类型有Norm、PA、PA_BNSD、PA_NZ、PA_BLK_BNSD、PA_BLK_NZ。 CHAR* -
is_output_kv 可选属性 kRopeOut和cKvOut输出控制标记。 BOOL -
k_rope 输出 rope计算结果,对应interleaveRope计算公式中的`y`。由isOutputKv控制,当isOutputKv为true时,需输出。 FLOAT16、BFLOAT16 ND
c_kv 输出 rms_norm计算结果,对应rmsNorm计算公式中的`y`。由isOutputKv控制,当isOutputKv为true时,需输出。 FLOAT16、BFLOAT16 ND
  • Kirin X90/Kirin 9030 处理器系列产品: 不支持BFLOAT16、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN。

约束说明

  • 输入shape限制:
    • kv为四维张量,shape为[Bkv,N,Skv,D],Bkv为输入kv的batch size,Skv为输入kv的sequence length,大小由用户输入场景决定,无明确限制。
    • N为输入kv的head number。此算子与DeepSeekV3网络结构强相关,仅支持N=1的场景,不存在N非1的场景。
    • D为输入kv的head dim。rms_norm计算所需数据Dv和rope计算所需数据Dk由输入kv的D切分而来。故Dk、Dv大小需满足Dk+Dv=D。同时,Dk需满足rope规则。根据rope规则,Dk为偶数。
    • 若cache_mode为PA场景(cache_mode为PA、PA_BNSD、PA_NZ、PA_BLK_BNSD、PA_BLK_NZ),其shape[BlockNum,BlockSize,N,Dk]中BlockSize需32B对齐。
    • 输入张量均不支持空Tensor。
  • 其他限制:
    • 对于index,当cache_mode为Norm时,shape为2维[Bkv,Skv],要求index的value值范围为[-1,Scache)。不同的Bkv下,value数值可以重复。
    • 当cache_mode为PA_BNSD、PA_NZ时,shape为1维[Bkv * Skv],要求index的value值范围为[-1,BlockNum * BlockSize)。value数值不能重复。
    • 当cache_mode为PA_BLK_BNSD、PA_BLK_NZ时,shape为1维[Bkv * ceil_div(Skv,BlockSize)],要求index的value的数值范围为[-1,BlockNum * BlockSize)。value/BlockSize的值不能重复。

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_kv_rms_norm_rope_cache 通过aclnnKvRmsNormRopeCache接口方式调用KvRmsNormRopeCache算子。
图模式 test_geir_kv_rms_norm_rope_cache 通过算子IR构图方式调用KvRmsNormRopeCache算子。