GatherPaKvCache

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:根据blockTables中的blockId值、seqLens中key/value的seqLen从keyCache/valueCache中将内存不连续的token搬运、拼接成连续的key/value序列。

  • 计算逻辑:

    • keyRef/valueRef的第一个维度取决于seq_lens大小。
    • 如果isSeqLensCumsum为true,则seqLens中最后一个值即为keyRef/valueRef的第一个维度大小: keyRef[dim0] = seqLens[-1]
    • 如果isSeqLensCumsum为false,则seqLens中所有值的累加和即为keyRef/valueRef的第一个维度大小:keyRef[dim0] = sum(seqLens)

    关于keyRefvalueRef的一些限制条件如下:

    • 每个token大小控制在148k以内,例如,对于fp16/bf16类型,num_heads * head_size(keyRef/valueRef)取128*576。
  • 示例:

      keyCache_shape: [128, 128, 16, 144]
      valueCache_shape: [128, 128, 16, 128]
      blockTables_shape: [16, 12]
      seqLens_shape: [16]
      keyRef_shape: [8931, 16, 144]
      valueRef_shape: [8931, 16, 128]
      seqOffset_shape: [16]
      out1_shape: [8931, 16, 144]  
      out2_shape: [8931, 16, 128]        
    

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
keyCache 输入 当前层存储的key向量缓存 INT8, FLOAT16, BFLOAT16, FLOAT, UINT8, INT16, UINT16, INT32, UINT32, HIFLOAT8, FLOAT8_E5M2, FLOAT8_E4M3FN ND
valueCache 输入 当前层存储的value向量缓存 INT8, FLOAT16, BFLOAT16, FLOAT, UINT8, INT16, UINT16, INT32, UINT32, HIFLOAT8, FLOAT8_E5M2, FLOAT8_E4M3FN FRACTAL_NZ
blockTables 输入 每个batch中KV Cache的逻辑块到物理块的映射关系 INT32、INT64 ND
seqLens 输入 每个batch对应的序列长度 INT32、INT64 ND
keyRef 输入/输出 当前层的key向量 INT8, FLOAT16, BFLOAT16, FLOAT, UINT8, INT16, UINT16, INT32, UINT32, HIFLOAT8, FLOAT8_E5M2, FLOAT8_E4M3FN ND
valueRef 输入/输出 当前层的value向量 INT8, FLOAT16, BFLOAT16, FLOAT, UINT8, INT16, UINT16, INT32, UINT32, HIFLOAT8, FLOAT8_E5M2, FLOAT8_E4M3FN ND
seqOffset 输入 blockTables获取blockId时存在的首偏移 INT32、INT64 ND
cacheMode 输入 表示输入的数据排布格式,支持Norm、PA_NZ String ND
isSeqLensCumsum 输入 表示seqLens是否为累加和。false表示非累加和 BOOL ND

约束说明

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_gather_pa_kv_cache 通过aclnnGatherPaKvCache调用GatherPaKvCache算子