ScatterPaKvCache

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:更新KvCache中指定位置的key和value。

  • 输入输出场景根据构造的参数来区别,输入输出支持以下场景,其中场景一、场景二、场景六没有compressLensOptional、seqLensOptional、compressSeqOffsetOptional这三个可选参数,场景四没有compressSeqOffsetOptional可选参数:

    • 场景一:

      key:[batch * seq_len, num_head, k_head_size]
      value:[batch * seq_len, num_head, v_head_size]
      keyCache:[num_blocks, num_head * k_head_size // last_dim_k, block_size, last_dim_k]/[num_blocks, num_head, k_head_size // last_dim_k, block_size, last_dim_k]
      valueCache:[num_blocks, num_head * v_head_size // last_dim_v, block_size, last_dim_v]/[num_blocks, num_head, v_head_size // last_dim_v, block_size, last_dim_v]
      slotMapping:[batch * seq_len]
      cacheMode:"PA_NZ"
      
    • 场景二:

      key:[batch * seq_len, num_head, k_head_size]
      value:[batch * seq_len, num_head, v_head_size]
      keyCache:[num_blocks, block_size, num_head, k_head_size]
      valueCache:[num_blocks, block_size, num_head, v_head_size]
      slotMapping:[batch * seq_len]
      cacheMode:"Norm"
      scatter_mode:"None"/"Nct"
      

      其中k_head_size与v_head_size可以不同,也可以相同。

    • 场景三:

      key:[batch, seq_len, num_head, k_head_size]
      value:[batch, seq_len, num_head, v_head_size]
      keyCache:[num_blocks, block_size, 1, k_head_size]
      valueCache:[num_blocks, block_size, 1, v_head_size]
      slotMapping:[batch, num_head]
      compressLensOptional:[batch, num_head]
      seqLensOptional:[batch]
      compressSeqOffsetOptional:[batch * num_head]
      cacheMode:"Norm"
      
    • 场景四:

      key:[num_tokens, num_head, k_head_size]
      value:[num_tokens, num_head, v_head_size]
      keyCache:[num_blocks, block_size, 1, k_head_size]
      valueCache:[num_blocks, block_size, 1, v_head_size]
      slotMapping:[batch * num_head]
      compressLensOptional:[batch * num_head]
      seqLensOptional:[batch]
      cacheMode:"Norm"
      scatter_mode:"Alibi"
      
    • 场景五:

      key:[num_tokens, num_head, k_head_size]
      value:[num_tokens, num_head, v_head_size]
      keyCache:[num_blocks, block_size, 1, k_head_size]
      valueCache:[num_blocks, block_size, 1, v_head_size]
      slotMapping:[batch * num_head]
      compressLensOptional:[batch * num_head]
      seqLensOptional:[batch]
      compressSeqOffsetOptional:[batch * num_head]
      cacheMode:"Norm"
      scatter_mode:"Rope"/"Omni"
      
    • 场景六:

      key:[batch * seq_len, num_head, k_head_size]
      value:[]
      keyCache:[num_blocks, block_size, num_head, k_head_size]
      valueCache:[]
      slotMapping:[batch * seq_len]
      cacheMode:"Norm"
      scatter_mode:"None"/"Nct"
      
    • 场景七:

      key:[num_tokens, num_head, k_head_size]
      value:[num_tokens, num_head, v_head_size]
      keyCache:[num_blocks, num_head, block_size, k_head_size]
      valueCache:[num_blocks, num_head, block_size, v_head_size]
      slotMapping:[num_tokens]
      cacheMode:"Norm"
      scatter_mode:"NHSD"
      
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:仅支持场景一、二、四、五、六、七。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
key 输入 待更新的key值,当前step多个token的key。 FLOAT16、FLOAT、BFLOAT16、INT8、UINT8、INT16、UINT16、INT32、UINT32、HIFLOAT8、FLOAT8_E5M2、FLOAT8_E4M3FN ND
keyCacheRef 输入/输出 需要更新的key cache,当前layer的key cache。 FLOAT16、FLOAT、BFLOAT16、INT8、UINT8、INT16、UINT16、INT32、UINT32、HIFLOAT8、FLOAT8_E5M2、FLOAT8_E4M3FN ND
slotMapping 输入 每个token key或value在cache中的存储偏移。 INT32、INT64 ND
value 输入 待更新的value值,当前step多个token的value。 FLOAT16、FLOAT、BFLOAT16、INT8、UINT8、INT16、UINT16、INT32、UINT32、HIFLOAT8、FLOAT8_E5M2、FLOAT8_E4M3FN ND
valueCacheRef 输入/输出 需要更新的value cache,当前layer的value cache。 FLOAT16、FLOAT、BFLOAT16、INT8、UINT8、INT16、UINT16、INT32、UINT32、HIFLOAT8、FLOAT8_E5M2、FLOAT8_E4M3FN ND
compressLensOptional 可选输入 压缩量。 INT32、INT64 ND
compressSeqOffsetOptional 可选输入 每个batch每个head的压缩起点。 INT32、INT64 ND
seqLensOptional 可选输入 每个batch的实际seqLens。 INT32、INT64 ND
cacheMode 输入 表示keyCacheRef和valueCacheRef的内存排布格式。 STRING -
scatterMode 输入 表示更新的key和value的状态。 STRING -
strides 输入 key和value在非连续状态下的步长,数组长度为2。其值应该大于0。 INT64 -
offsets 输入 key和value在非连续状态下的偏移,数组长度为2。其值应该大于0。 INT64 -
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:
    • key/value数据类型仅支持:FLOAT16、BFLOAT16、INT8;
    • cacheMode当传空指针或"Norm"时,仅支持ND内存排布格式,当传"PA_NZ"时,仅支持FRACTAL_NZ内存排布格式;
    • scatterMode当传空指针或"None"时,表示更新的key和value是非压缩状态且连续,当传"Alibi"时,表示更新key和value是基于Alibi结构的压缩状态,当传"Rope"时,表示更新key和value是基于Rope结构的压缩状态,当传"Omni"时,表示更新key和value是基于Omni结构的压缩状态,当传"Nct"时,表示更新的key和value是非压缩状态但非连续;
    • strides和offsets仅当scatterMode为"Nct"时生效,分别表示strideK和strideV、offsetK和offsetV。

约束说明

  • 输入shape限制:
    • 除了key和value,输入参数不支持非连续。
    • 当key和value都是3维,则key和value的前两维shape必须相同。
    • 当key和value都是4维,则key和value的前三维shape必须相同,且keyCacheRef和valueCacheRef的第三维必须是1。
    • 当key和value是4维时,compressLensOptional、seqLensOptional为必选参数;当key和value是3维时,compressLensOptional、compressSeqOffsetOptional、seqLensOptional为可选参数。
    • 当cacheMode为“PA_NZ”时,keyCacheRef和valueCacheRef的倒数第二维必须小于UINT16_MAX(对应场景一)。
    • k_head_size和v_head_size必须32字节对齐(对应场景七)。
    • num_head必须小于4095(对应场景七)。
  • 输入值域限制:
    • slotMapping的取值范围[0,num_blocks*block_size-1],且slotMapping内的元素值保证不重复,重复时不保证正确性。
    • 当key和value都是4维时,slotMapping是二维,且slotMapping的第一维值等于key的第一维为batch,slotMapping的第二维值等于key的第三维为num_head(对应场景三)。
    • 当key和value都是4维时,seqLensOptional是一维,且seqLensOptional的值等于key的第一维为batch(对应场景三)。
    • 当key和value是3维且存在seqLensOptional时,seqLensOptional中所有值的和等于key的第一维为num_blocks(对应场景四、五)。
    • seqLensOptional和compressLensOptional里面的每个元素值必须满足公式:reduceSum(seqLensOptional[i] - compressLensOptional[i]) <= num_blocks * block_size(对应场景三、四、五)。
    • block_size * k_head_size和block_size * v_head_size必须小于UINT32_MAX(对应场景七)。
  • 输入属性限制:
    • key、value、keyCacheRef、valueCacheRef的数据类型必须一致。
    • slotMapping、compressLensOptional、compressSeqOffsetOptional、seqLensOptional的数据类型必须一致。

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_ScatterPaKvCache 通过aclnnScatterPaKvCache调用ScatterPaKvCache算子
图模式 test_geir_ScatterPaKvCache 通过算子IR调用ScatterPaKvCache算子