KvRmsNormRopeCache
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
| Kirin X90 处理器系列产品 | √ |
| Kirin 9030 处理器系列产品 | √ |
功能说明
-
算子功能:对输入张量(kv)的尾轴,拆分出左半边用于rms_norm计算,右半边用于rope计算,再将计算结果分别scatter到两块cache中。
-
计算公式:
(1) interleaveRope:
x=kv[...,Dv:]x=kv[...,Dv:]
x1=x[...,::2]x1=x[...,::2]
x2=x[...,1::2]x2=x[...,1::2]
x_part1=torch.cat((x1,x2),dim=−1)x\_part1=torch.cat((x1,x2),dim=-1)
x_part2=torch.cat((−x2,x1),dim=−1)x\_part2=torch.cat((-x2,x1),dim=-1)
y=x_part1∗cos+x_part2∗siny=x\_part1*cos+x\_part2*sin
(2) rmsNorm:
x=kv[...,:Dv]x=kv[...,:Dv]
square_x=x∗xsquare\_x=x*x
mean_square_x=square_x.mean(dim=−1,keepdim=True)mean\_square\_x=square\_x.mean(dim=-1,keepdim=True)
rms=torch.sqrt(mean_square_x+epsilon)rms=torch.sqrt(mean\_square\_x+epsilon)
y=(x/rms)∗gammay=(x/rms)*gamma
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| kv | 输入 | 用于切分出rms_norm计算所需数据Dv和rope计算所需数据Dk的输入数据,对应公式中的`kv`。 | FLOAT16、BFLOAT16 | ND |
| gamma | 输入 | 用于rms_norm计算的输入数据,对应公式中的`gamma`。 | FLOAT16、BFLOAT16 | ND |
| cos | 输入 | 用于rope计算的输入数据,对输入张量进行余弦变换,对应公式中的`cos`。 | FLOAT16、BFLOAT16 | ND |
| sin | 输入 | 用于rope计算的输入数据,对输入张量进行正弦变换,对应公式中的`sin`。 | FLOAT16、BFLOAT16 | ND |
| index | 输入 | 用于指定写入cache的具体索引位置。 | INT64 | ND |
| k_cache | 输入/输出 | 提前申请的cache,输入输出同地址复用。 | FLOAT16、BFLOAT16、INT8、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN | ND |
| ckv_cache | 输入/输出 | 提前申请的cache,输入输出同地址复用。 | FLOAT16、BFLOAT16、INT8、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN | ND |
| k_rope_scale | 可选属性 | 当kCacheRef数据类型为INT8、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN时需要此输入参数。 | FLOAT32 | ND |
| c_kv_scale | 可选属性 | 当ckv_cache数据类型为INT8、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN时需要此输入参数。 | FLOAT32 | ND |
| k_rope_offset | 可选属性 | 当k_cache数据类型为INT8、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN且对应的k_rope_scale输入存在并量化场景为非对称量化时,需要此参数输入。 | FLOAT32 | ND |
| c_kv_offset | 可选属性 | 当ckv_cache数据类型为INT8、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN且对应的c_kv_scale输入存在并量化场景为非对称量化时,需要此参数输入。 | FLOAT32 | ND |
| epsilon | 可选属性 |
|
FLOAT32 | - |
| cache_mode | 可选属性 | cache格式的选择标记。类型有Norm、PA、PA_BNSD、PA_NZ、PA_BLK_BNSD、PA_BLK_NZ。 | CHAR* | - |
| is_output_kv | 可选属性 | kRopeOut和cKvOut输出控制标记。 | BOOL | - |
| k_rope | 输出 | rope计算结果,对应interleaveRope计算公式中的`y`。由isOutputKv控制,当isOutputKv为true时,需输出。 | FLOAT16、BFLOAT16 | ND |
| c_kv | 输出 | rms_norm计算结果,对应rmsNorm计算公式中的`y`。由isOutputKv控制,当isOutputKv为true时,需输出。 | FLOAT16、BFLOAT16 | ND |
- Kirin X90/Kirin 9030 处理器系列产品: 不支持BFLOAT16、HIFLOAT8、FLOAT8E5M2、FLOAT8E4M3FN。
约束说明
- 输入shape限制:
- kv为四维张量,shape为[Bkv,N,Skv,D],Bkv为输入kv的batch size,Skv为输入kv的sequence length,大小由用户输入场景决定,无明确限制。
- N为输入kv的head number。此算子与DeepSeekV3网络结构强相关,仅支持N=1的场景,不存在N非1的场景。
- D为输入kv的head dim。rms_norm计算所需数据Dv和rope计算所需数据Dk由输入kv的D切分而来。故Dk、Dv大小需满足Dk+Dv=D。同时,Dk需满足rope规则。根据rope规则,Dk为偶数。
- 若cache_mode为PA场景(cache_mode为PA、PA_BNSD、PA_NZ、PA_BLK_BNSD、PA_BLK_NZ),其shape[BlockNum,BlockSize,N,Dk]中BlockSize需32B对齐。
- 输入张量均不支持空Tensor。
- 其他限制:
- 对于index,当cache_mode为Norm时,shape为2维[Bkv,Skv],要求index的value值范围为[-1,Scache)。不同的Bkv下,value数值可以重复。
- 当cache_mode为PA_BNSD、PA_NZ时,shape为1维[Bkv * Skv],要求index的value值范围为[-1,BlockNum * BlockSize)。value数值不能重复。
- 当cache_mode为PA_BLK_BNSD、PA_BLK_NZ时,shape为1维[Bkv * ceil_div(Skv,BlockSize)],要求index的value的数值范围为[-1,BlockNum * BlockSize)。value/BlockSize的值不能重复。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_kv_rms_norm_rope_cache | 通过aclnnKvRmsNormRopeCache接口方式调用KvRmsNormRopeCache算子。 |
| 图模式 | test_geir_kv_rms_norm_rope_cache | 通过算子IR构图方式调用KvRmsNormRopeCache算子。 |