SparseAttnSharedkv

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理系列产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • API功能:SparseAttnSharedKV算子旨在完成以下公式描述的Attention计算,支持Sliding Window Attention、Compressed Attention以及Sparse Compressed Attention。

  • 计算公式:

    O=softmax(Q@K~T⋅softmax_scale)@V~O = \text{softmax}(Q@\tilde{K}^T \cdot \text{softmax\_scale})@\tilde{V}

    其中K~=V~\tilde{K}=\tilde{V}为基于ori_kv、cmp_kv以及cmp_ratio等入参控制的实际参与计算的 KVKV

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
q 输入 对应公式中的QQ BFLOAT16、FLOAT16 ND
ori_kv 可选输入 对应公式中的K~和V~\tilde{K}和\tilde{V}的一部分,为原始不经压缩的KV。 BFLOAT16、FLOAT16 ND
cmp_kv 可选输入 对应公式中的K~和V~\tilde{K}和\tilde{V}的一部分,为经过压缩的KV。 BFLOAT16、FLOAT16 ND
ori_sparse_indices 可选输入 代表离散取oriKvCache的索引。 INT32 ND
cmp_sparse_indices 可选输入 代表离散取cmpKvCache的索引。 INT32 ND
ori_block_table 可选输入 表示PageAttention中oriKvCache存储使用的block映射表。 INT32 ND
cmp_block_table 可选输入 表示PageAttention中cmpKvCache存储使用的block映射表。 INT32 ND
cu_seqlens_q 可选输入 表示不同Batch中q的有效token数。 INT32 ND
cu_seqlens_ori_kv 可选输入 表示不同Batch中ori_kv的有效token数。 INT32 ND
cu_seqlens_cmp_kv 可选输入 表示不同Batch中cmp_kv的有效token数。 INT32 ND
seqused_q 可选输入 表示不同Batch中q实际参与运算的token数。 INT32 ND
seqused_kv 可选输入 表示不同Batch中ori_kv实际参与运算的token数。 INT32 ND
sinks 可选输入 注意力下沉tensor。 FLOAT32 ND
metadata 可选输入 aicpu算子(npu_sparse_attn_sharedkv_metadata)的分核结果。 INT32 ND
softmax_scale 可选属性 代表缩放系数,对应公式中的softmax_scale\text{softmax\_scale},默认值为None。 FLOAT32 -
cmp_ratio 可选属性 表示对ori_kv的压缩率,仅支持输入4或128,默认值为None。 INT32 -
ori_mask_mode 可选属性 表示qori_kv计算的mask模式,仅支持输入默认值4。 INT32 -
cmp_mask_mode 可选属性 表示qcmp_kv计算的mask模式,仅支持输入默认值3。 INT32 -
ori_win_left 可选属性 表示qori_kv计算中q对过去token计算的数量,仅支持输入默认值127。 INT32 -
ori_win_right 可选属性 表示qori_kv计算中q对未来token计算的数量,仅支持输入默认值0。 INT32 -
layout_q 可选属性 用于标识输入q的数据排布格式,支持输入"TND"和"BSND",默认值为"BSND"。 STRING -
layout_kv 可选属性 用于标识输入ori_kvcmp_kv的数据排布格式,支持输入"PA_ND"和"BSND"。 STRING -
return_softmax_lse 可选属性 表示是否返回softmax_lse。True表示返回,False表示不返回,默认值为False。 BOOL -
attention_out 输出 公式中的输出。 BFLOAT16、FLOAT16 ND
softmax_lse 输出 返回的softmax_lse FLOAT32 ND

约束说明

  • 该接口支持推理场景下使用。

  • 该接口支持aclgraph模式。

  • 该接口当前支持三种计算场景:场景一,仅传入ori_kv时为Sliding Window Attention计算;场景二,传入ori_kvcmp_kv时为Sliding Window Attention + Compressed Attention计算;场景三,传入ori_kvcmp_kvcmp_sparse_indices时为Sliding Window Attention + Sparse Compressed Attention计算。

  • layout_q为TND时,功能使用限制如下:

    • q的shape需要为[T1,N1,D],其中N1仅支持64。
    • ori_sparse_indices的shape需要为[Q_T, KV_N, K1],其中K1为对ori_kv一次离散选取的token数,K1仅支持512。
    • cmp_sparse_indices的shape需要为[Q_T, KV_N, K2],其中K2为对cmp_kv一次离散选取的token数,K2仅支持512。
    • cu_seqlens_q必须传入,输入维度为B+1,大小为参数中每个元素的值表示当前batch与之前所有batch的token数总和,即前缀和,因此后一个元素的值必须>=前一个元素的值。
  • layout_q为BSND时,功能使用限制如下:

    • q的shape需要为[B, Q_S,N1,D],其中N1仅支持64。
    • ori_sparse_indices的shape需要为[B, Q_S, KV_N, K1],其中K1为对ori_kv一次离散选取的token数,K1仅支持512。
    • cmp_sparse_indices的shape需要为[B, Q_S, KV_N, K2],其中K2为对cmp_kv一次离散选取的token数,K2仅支持512。
  • PageAttention场景下,功能使用限制如下:

    • ori_kvcmp_kv的shape分别为[ori_block_num, ori_block_size, KV_N, D]和[cmp_block_num, cmp_block_size, KV_N, D],其中ori_block_num和cmp_block_num为PageAttention时block总数,ori_block_size和cmp_block_size为一个block的token数,ori_block_size和cmp_block_size取值为16的倍数,最大支持1024,KV_N仅支持1。
    • ori_block_tablecmp_block_table的shape为2维,其中第一维长度为B,第二维长度不小于所有batch中最大的S2和S3对应的block数量,即S2_max / block_size和S3_max / block_size向上取整。
  • metadata为算子实际需要使用的分核结果,目前该参数必传,shape大小固定为[1024]。

  • layout_kv仅支持输入"PA_ND"和"BSND"。

    • 当输入为PA_ND时,设置cu_seqlens_ori_kvcu_seqlens_cmp_kv无效。
    • 当输入为BSND时,ori_kvcmp_kv的layout都必须为BSND,ori_kv的shape为[B, S2, N2,D],cmp_kv的shape为[B, S3, N2,D]。
  • 目前暂不支持返回softmax_lsereturn_softmax_lse仅支持输入False,返回值softmax_lse为无效值。

  • ori_mask_mode及cmp_mask_mode所表示的mask模式的详细介绍见sparse_mode参数说明

  • 目前暂不支持指定q中参与运算的token数,因此设置seqused_q无效。

  • 目前暂不支持对ori_kv进行稀疏计算,因此设置ori_sparse_indices无效。

  • 目前所有输入不支持传入空tensor。

  • qori_kvcmp_kv数据排布格式支持从多种维度解读,B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Hidden-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。

  • Q_S和S1表示q shape中的S,S2表示ori_kv shape中的S,S3表示cmp_kv shape中的S;Q_N和N1表示num_q_heads,KV_N和N2表示num_ori_kv_heads和num_cmp_kv_heads;Q_T和T1表示q shape中的输入样本序列长度的累加和。

Atlas A3 推理系列产品 调用说明

  • 单算子模式调用

    import torch
    import torch_npu
    import numpy as np
    import random
    import math
    import custom_ops
    
    data_type = torch.bfloat16
    softmax_scale = 0.041666666666666664
    b = 4
    s1 = 128
    s2 = 8192
    n1 = 64
    n2 = 1
    dn = 512
    k = 512
    ori_block_size = 128
    cmp_block_size = 128
    s2_act = 4096
    cmp_ratio = 4
    ori_win_left = 127
    ori_win_right = 0
    layout_q = 'TND'
    layout_kv = 'PA_ND'
    ori_mask_mode = 4
    cmp_mask_mode = 3
    q = torch.tensor(np.random.uniform(-10, 10, (b*s1, n1, dn))).to(data_type).npu()
    
    cu_seqlens_q = torch.arange(0, (b + 1) * s1, step=s1).to(torch.int32).npu()
    t = cu_seqlens_q[-1].item()
    seqused_kv = torch.tensor([s2_act]*b).to(torch.int32).npu()
    
    cmp_kv_len = s2_act // cmp_ratio
    idxs = random.sample(range(cmp_kv_len - s1 + 1),  k)
    cmp_sparse_indices = torch.tensor([idxs for _ in range(t * n2)]).reshape(t, n2, k). \
        to(torch.int32).npu()
    
    ori_block_num =  math.ceil(s2_act/ori_block_size) * b
    ori_block_table = torch.tensor(np.random.permutation(range(ori_block_num))).to(torch.int32).reshape(b, -1).npu()
    ori_kv = torch.tensor(np.random.uniform(-5, 10, (ori_block_num, ori_block_size, n2, dn))).to(data_type).npu()
    
    block_num2 =  math.ceil(cmp_kv_len/ori_block_size) * b
    cmp_block_table = torch.tensor(np.random.permutation(range(block_num2))).to(torch.int32).reshape(b, -1).npu()
    cmp_kv = torch.tensor(np.random.uniform(-5, 10, (block_num2, cmp_block_size, n2, dn))).to(data_type).npu()
    sinks = torch.rand(n1).to(torch.float32).npu()
    metadata = torch.ops.custom.npu_sparse_attn_sharedkv_metadata(
        num_heads_q=n1,
        num_heads_kv=n2,
        head_dim=dn,
        cu_seqlens_q=cu_seqlens_q,
        seqused_kv=seqused_kv,
        batch_size=b,
        max_seqlen_q=s1,
        max_seqlen_kv=s2,
        cmp_topk=k,
        cmp_ratio=cmp_ratio,
        ori_mask_mode=ori_mask_mode,
        cmp_mask_mode=cmp_mask_mode,
        ori_win_left=ori_win_left,
        ori_win_right=ori_win_right,
        layout_q=layout_q,
        layout_kv=layout_kv,
        has_ori_kv=True,
        has_cmp_kv=True
    )
    attn_out, softmax_lse = torch.ops.custom.npu_sparse_attn_sharedkv(
        q,
        ori_kv=ori_kv,
        cmp_kv=cmp_kv,
        ori_sparse_indices=None,
        cmp_sparse_indices=cmp_sparse_indices,
        ori_block_table=ori_block_table,
        cmp_block_table=cmp_block_table,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_ori_kv=None,
        cu_seqlens_cmp_kv=None,
        seqused_q=None,
        seqused_kv=seqused_kv,
        sinks=sinks,
        metadata=metadata,
        softmax_scale=softmax_scale,
        cmp_ratio=cmp_ratio,
        ori_mask_mode=ori_mask_mode,
        cmp_mask_mode=cmp_mask_mode,
        ori_win_left=ori_win_left,
        ori_win_right=ori_win_right,
        layout_q=layout_q,
        layout_kv=layout_kv,
        return_softmax_lse=False)
    
  • aclgraph模式调用

    import torch
    import torch_npu
    import numpy as np
    import random
    import math
    import torchair
    import custom_ops
    
    data_type = torch.bfloat16
    softmax_scale = 0.041666666666666664
    b = 4
    s1 = 128
    s2 = 8192
    n1 = 64
    n2 = 1
    dn = 512
    k = 512
    ori_block_size = 128
    cmp_block_size = 128
    s2_act = 4096
    cmp_ratio = 4
    ori_win_left = 127
    ori_win_right = 0
    layout_q = 'TND'
    layout_kv = 'PA_ND'
    ori_mask_mode = 4
    cmp_mask_mode = 3
    q = torch.tensor(np.random.uniform(-10, 10, (b*s1, n1, dn))).to(data_type).npu()
    
    cu_seqlens_q = torch.arange(0, (b + 1) * s1, step=s1).to(torch.int32).npu()
    t = cu_seqlens_q[-1].item()
    seqused_kv = torch.tensor([s2_act]*b).to(torch.int32).npu()
    
    cmp_kv_len = s2_act // cmp_ratio
    idxs = random.sample(range(cmp_kv_len - s1 + 1),  k)
    cmp_sparse_indices = torch.tensor([idxs for _ in range(t * n2)]).reshape(t, n2, k). \
        to(torch.int32).npu()
    
    ori_block_num =  math.ceil(s2_act/ori_block_size) * b
    ori_block_table = torch.tensor(np.random.permutation(range(ori_block_num))).to(torch.int32).reshape(b, -1).npu()
    ori_kv = torch.tensor(np.random.uniform(-5, 10, (ori_block_num, ori_block_size, n2, dn))).to(data_type).npu()
    
    block_num2 =  math.ceil(cmp_kv_len/ori_block_size) * b
    cmp_block_table = torch.tensor(np.random.permutation(range(block_num2))).to(torch.int32).reshape(b, -1).npu()
    cmp_kv = torch.tensor(np.random.uniform(-5, 10, (block_num2, cmp_block_size, n2, dn))).to(data_type).npu()
    sinks = torch.rand(n1).to(torch.float32).npu()
    
    from torchair.configs.compiler_config import CompilerConfig
    config = CompilerConfig()
    config.mode = "reduce-overhead"
    npu_backend = torchair.get_npu_backend(compiler_config=config)
    
    class Network(torch.nn.Module):
        def __init__(self):
            super(Network, self).__init__()
    
        def forward(self, num_heads_q, num_heads_kv, head_dim, batch_size, max_seqlen_q, max_seqlen_kv,
            topk, has_ori_kv, has_cmp_kv, q, ori_kv, cmp_kv, cmp_sparse_indices, ori_block_table,
            cmp_block_table, cu_seqlens_q, seqused_kv, softmax_scale, cmp_ratio, sinks,
            ori_mask_mode, cmp_mask_mode, ori_win_left, ori_win_right, layout_q, layout_kv):
            metadata = torch.ops.custom.npu_sparse_attn_sharedkv_metadata(
                num_heads_q=num_heads_q,
                num_heads_kv=num_heads_kv,
                head_dim=head_dim,
                cu_seqlens_q=cu_seqlens_q,
                seqused_kv=seqused_kv,
                batch_size=batch_size,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                cmp_topk=topk,
                cmp_ratio=cmp_ratio,
                ori_mask_mode=ori_mask_mode,
                cmp_mask_mode=cmp_mask_mode,
                ori_win_left=ori_win_left,
                ori_win_right=ori_win_right,
                layout_q=layout_q,
                layout_kv=layout_kv,
                has_ori_kv=has_ori_kv,
                has_cmp_kv=has_cmp_kv,
                device="npu:0"
            )
            npu_out = torch.ops.custom.npu_sparse_attn_sharedkv(
                q,
                ori_kv=ori_kv,
                cmp_kv=cmp_kv,
                ori_sparse_indices=None,
                cmp_sparse_indices=cmp_sparse_indices,
                ori_block_table=ori_block_table,
                cmp_block_table=cmp_block_table,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_ori_kv=None,
                cu_seqlens_cmp_kv=None,
                seqused_q=None,
                seqused_kv=seqused_kv,
                sinks=sinks,
                metadata=metadata,
                softmax_scale=softmax_scale,
                cmp_ratio=cmp_ratio,
                ori_mask_mode=ori_mask_mode,
                cmp_mask_mode=cmp_mask_mode,
                ori_win_left=ori_win_left,
                ori_win_right=ori_win_right,
                layout_q=layout_q,
                layout_kv=layout_kv,
                return_softmax_lse=False)
            return npu_out
    
    mod = torch.compile(Network().npu(), backend=npu_backend, fullgraph=True)
    attn_out, softmax_lse = mod(
        num_heads_q=n1,
        num_heads_kv=n2,
        head_dim=dn,
        batch_size=b,
        max_seqlen_q=s1,
        max_seqlen_kv=s2,
        topk=k,
        has_ori_kv=True,
        has_cmp_kv=True,
        q=q,
        ori_kv=ori_kv,
        cmp_kv=cmp_kv,
        cmp_sparse_indices=cmp_sparse_indices,
        ori_block_table=ori_block_table,
        cmp_block_table=cmp_block_table,
        cu_seqlens_q=cu_seqlens_q,
        seqused_kv=seqused_kv,
        softmax_scale=softmax_scale,
        cmp_ratio=cmp_ratio,
        sinks=sinks,
        ori_mask_mode=ori_mask_mode,
        cmp_mask_mode=cmp_mask_mode,
        ori_win_left=ori_win_left,
        ori_win_right=ori_win_right,
        layout_q=layout_q,
        layout_kv=layout_kv)
    

更多使用示例见pytest示例