e55cdff6创建于 2025年12月30日历史提交

torch_npu.npu_prompt_flash_attention

产品支持情况

产品 是否支持
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 推理系列加速卡产品

功能说明

  • API功能:全量FA实现。

  • 计算公式:

atten_out=softmax(scale⋅(Q⋅K)+atten_mask)⋅Vatten\_out = softmax\left(scale \cdot (Q \cdot K) + atten\_mask\right) \cdot V

函数原型

torch_npu.npu_prompt_flash_attention(query, key, value, *, pse_shift=None, padding_mask=None, atten_mask=None, actual_seq_lengths=None, deq_scale1=None, quant_scale1=None, deq_scale2=None, quant_scale2=None, quant_offset2=None, num_heads=1, scale_value=1.0, pre_tokens=2147483647, next_tokens=0, input_layout="BSH",num_key_value_heads=0, actual_seq_lengths_kv=None, sparse_mode=0) -> Tensor

参数说明

  • query (Tensor):必选参数,对应公式中的输入QQ,数据类型与key的数据类型需满足数据类型推导规则,即保持与keyvalue的数据类型一致。不支持非连续的Tensor,数据格式支持NDND

    • Atlas 推理系列加速卡产品:数据类型支持float16
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:数据类型支持float16bfloat16int8
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16bfloat16int8
  • key (Tensor):必选参数,对应公式中的输入KK,数据类型与query的数据类型需满足数据类型推导规则,即保持与queryvalue的数据类型一致。不支持非连续的Tensor,数据格式支持NDND

    • Atlas 推理系列加速卡产品:数据类型支持float16
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:数据类型支持float16bfloat16int8
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16bfloat16int8
  • value (Tensor):必选参数,对应公式中的输入VV,数据类型与query的数据类型需满足数据类型推导规则,即保持与querykey的数据类型一致。不支持非连续的Tensor,数据格式支持NDND

    • Atlas 推理系列加速卡产品:数据类型支持float16
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:数据类型支持float16bfloat16int8
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16bfloat16int8
  • *:必选参数,代表其之前的变量是位置相关的,必须按照顺序输入;之后的变量是可选参数,位置无关,需要使用键值对赋值,不赋值会使用默认值。

  • pse_shift (Tensor):可选参数。不支持非连续的Tensor,数据格式支持NDND。输入shape类型需为(B,N,Q_S,KV_S)(B, N, Q\_S, KV\_S)(1,N,Q_S,KV_S)(1, N, Q\_S, KV\_S),其中Q_SQ\_Squery的shape中的SSKV_SKV\_Skeyvalue的shape中的SS。对于pse_shiftKV_SKV\_S为非32字节对齐的场景,建议padding到32字节来提高性能,多余部分的填充值不做要求。如不使用该功能时可传入None。综合约束请见约束说明

    • Atlas 推理系列加速卡产品:暂不支持该参数。
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:数据类型支持float16bfloat16。当pse_shiftfloat16时,要求queryfloat16int8;当pse_shiftbfloat16时,要求querybfloat16。在querykeyvaluefloat16pse_shift存在的情况下,默认走高精度模式。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16bfloat16。当pse_shiftfloat16时,要求queryfloat16int8;当pse_shiftbfloat16时,要求querybfloat16。在querykeyvaluefloat16pse_shift存在的情况下,默认走高精度模式。
  • padding_mask:预留参数,暂未使用,默认值为None

  • atten_mask (Tensor):可选参数,对应公式中atten_maskatten\_mask,代表下三角全为0上三角全为负无穷的倒三角mask矩阵,数据类型支持boolint8uint8。数据格式支持NDND,不支持非连续的Tensor。如果不使用该功能可传入None。通常建议shape输入(Q_S,KV_S)(Q\_S, KV\_S)(B,Q_S,KV_S)(B, Q\_S, KV\_S)(1,Q_S,KV_S)(1, Q\_S, KV\_S)(B,1,Q_S,KV_S)(B, 1, Q\_S, KV\_S)(1,1,Q_S,KV_S)(1, 1, Q\_S, KV\_S),其中Q_SQ\_Squery的shape中的SSKV_SKV\_Skeyvalue的shape中的SS,对于atten_maskKV_SKV\_S为非32字节对齐的场景,建议padding到32字节对齐来提高性能,多余部分填充成1。综合约束请见约束说明

  • actual_seq_lengths (List[int]):可选参数,代表不同Batch中query的有效Sequence Length,数据类型支持int64。如果不指定seqlen可以传入None,表示和query的shape的s长度相同。限制:该入参中每个batch的有效Sequence Length应该不大于query中对应batch的Sequence Length。seqlen的传入长度为1时,每个Batch使用相同seqlen;传入长度大于等于Batch数时取seqlen的前Batch个数。其它长度不支持。

    • Atlas 推理系列加速卡产品:暂不支持该参数。
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:支持TND格式。当queryinput_layout为TND时,该入参必须传入,且以该入参元素的数量作为Batch值。该入参中每个元素的值表示当前Batch与之前所有Batch的seqlen和,因此后一个元素的值必须大于等于前一个元素的值,且不能出现负值。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持TND格式。当queryinput_layout为TND时,该入参必须传入,且以该入参元素的数量作为Batch值。该入参中每个元素的值表示当前Batch与之前所有Batch的seqlen和,因此后一个元素的值必须大于等于前一个元素的值,且不能出现负值。
  • deq_scale1 (Tensor):可选参数,表示BMM1后面的反量化因子,支持pertensor。数据类型支持uint64float32,数据格式支持NDND。如不使用该功能时可传入None。Atlas 推理系列加速卡产品暂不支持该参数。

  • quant_scale1 (Tensor):可选参数,数据类型支持float32。数据格式支持NDND,表示BMM2前面的量化因子,支持pertensor。如不使用该功能时可传入None。Atlas 推理系列加速卡产品暂不支持该参数。

  • deq_scale2 (Tensor):可选参数,数据类型支持uint64float32。数据格式支持NDND,表示BMM2后面的反量化因子,支持pertensor。如不使用该功能时可传入None。Atlas 推理系列加速卡产品暂不支持该参数。

  • quant_scale2 (Tensor):可选参数,数据格式支持NDND,表示输出的量化因子,支持pertensor、perchannel。如不使用该功能时可传入None

    • Atlas 推理系列加速卡产品:暂不支持该参数。
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:数据类型支持float32bfloat16。当输入为bfloat16时,同时支持float32bfloat16,否则仅支持float32。perchannel格式,当输出layout为BSHBSH时,要求quant_scale2所有维度的乘积等于HH;其他layout要求乘积等于N∗DN*D。当输出layout为BSHBSHquant_scale2 shape建议传入(1,1,H)(1, 1, H)(H,)(H,);输出为BNSDBNSD时,建议传入(1,N,1,D)(1, N, 1, D)(N,D)(N, D);输出为BSNDBSND时,建议传入(1,1,N,D)(1, 1, N, D)(N,D)(N, D)
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float32bfloat16。当输入为bfloat16时,同时支持float32bfloat16,否则仅支持float32。perchannel格式,当输出layout为BSHBSH时,要求quant_scale2所有维度的乘积等于HH;其他layout要求乘积等于N∗DN*D。当输出layout为BSHBSHquant_scale2 shape建议传入(1,1,H)(1, 1, H)(H,)(H,);输出为BNSDBNSD时,建议传入(1,N,1,D)(1, N, 1, D)(N,D)(N, D);输出为BSNDBSND时,建议传入(1,1,N,D)(1, 1, N, D)(N,D)(N, D)
  • quant_offset2 (Tensor):可选参数,数据格式支持NDND,表示输出的量化偏移,支持pertensor、perchannel。若传入 quant_offset2,需保证其类型和shape信息与quant_scale2一致。如不使用该功能时可传入None

    • Atlas 推理系列加速卡产品:暂不支持该参数。
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:数据类型支持float32bfloat16
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float32bfloat16
  • num_heads (List[int]):可选参数,代表query的head个数,数据类型支持int64

  • scale_value (float):可选参数,对应公式中scalescale,值通常是dd开根号的倒数,代表缩放系数,作为计算流中Muls的scalar值,数据类型支持float。数据类型与query的数据类型需满足数据类型推导规则。默认值为1.0

  • pre_tokens (int):可选参数,用于稀疏计算,表示Attention(注意力机制)需要和前几个Token计算关联,数据类型支持int64。默认值为2147483647。Atlas 推理系列加速卡产品仅支持默认值2147483647。

  • next_tokens (int):可选参数,用于稀疏计算,表示Attention需要和后几个Token计算关联。数据类型支持int64。默认值为0。Atlas 推理系列加速卡产品仅支持0和2147483647。

  • input_layout (str):可选参数,用于标识输入querykeyvalue的数据排布格式,当前支持BSHBSHBSNDBSNDBNSDBNSDBNSD_BSNDBNSD\_BSND(输入为BNSDBNSD时,输出格式为BSNDBSND)。默认值为"BSH"

  • num_key_value_heads:可选参数,代表keyvalue中head个数,用于支持GQA(Grouped-Query Attention,分组查询注意力)场景,数据类型支持int64。默认值为0,表示key/valuequery的head个数相等。限制:需要满足num_heads整除num_key_value_headsnum_headsnum_key_value_heads的比值不能大于64,且在BSNDBSNDBNSDBNSDBNSD_BSNDBNSD\_BSND场景下,需要与shape中的key/valueNN轴shape值相同,否则报错。Atlas 推理系列加速卡产品仅支持默认值0

  • actual_seq_lengths_kv (int):可选参数,代表不同batch中key/value的有效seqlenKV。数据类型支持int64。限制:该入参中每个batch的有效seqlenKV应该不大于key/value中对应batch的seqlenKV。seqlenKV的传入长度为1时,每个Batch使用相同seqlenKV;传入长度大于等于Batch数时取seqlenKV的前Batch个数,其它长度不支持。

    • Atlas 推理系列加速卡产品:暂不支持该参数。
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:支持TND格式。当key/value的input_layout为TND时,该入参必须传入,且以该入参元素的数量作为Batch值。该入参中每个元素的值表示当前Batch与之前所有Batch的seqlenKV和,因此后一个元素的值必须大于等于前一个元素的值,且不能出现负值。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持TND格式。当key/value的input_layout为TND时,该入参必须传入,且以该入参元素的数量作为Batch值。该入参中每个元素的值表示当前Batch与之前所有Batch的seqlenKV和,因此后一个元素的值必须大于等于前一个元素的值,且不能出现负值。
  • sparse_mode (int):可选参数,表示sparse的模式,数据类型支持int64。默认值为0,综合约束请见约束说明。Atlas 推理系列加速卡产品仅支持默认值0

    • sparse_mode为0时,代表defaultMask模式,如果atten_mask未传入则不做mask操作,忽略pre_tokensnext_tokens(内部赋值为INT_MAX);如果传入,则需要传入完整的atten_mask矩阵(S1∗S2)(S1 * S2),表示pre_tokensnext_tokens之间的部分需要计算。不支持传入的mask矩阵中参与计算的部分整行为1的情况。
    • sparse_mode为1时,代表allMask。不支持传入的mask矩阵中参与计算的部分整行为1的情况。
    • sparse_mode为2时,代表leftUpCausal模式的mask,需要传入优化后的atten_mask矩阵(2048*2048)。
    • sparse_mode为3时,代表rightDownCausal模式的mask,均对应以左顶点为划分的下三角场景,需要传入优化后的atten_mask矩阵(2048*2048)。
    • sparse_mode为4时,代表band模式的mask,需要传入优化后的atten_mask矩阵(2048*2048)。
    • sparse_mode为5、6、7、8时,分别代表prefix、global、dilated、block_local,均暂不支持。

返回值说明

Tensor

公式中的atten_outatten\_out,表示计算的最终结果。当input_layoutBNSD_BSNDBNSD\_BSND时,输入query的shape是BNSDBNSD,输出shape为BSNDBSND,其余情况shape与query的shape保持一致。

约束说明

  • 该接口支持推理场景下使用。

  • 该接口支持图模式。

  • 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。

  • 入参为空的处理:算子内部需要判断参数query是否为空,如果是空则直接返回。参数query不为空Tensor,参数keyvalue为空Tensor(即S2S2为0),则填充全零的对应shape的输出(填充atten_out)。atten_out为空Tensor时,框架会处理。

  • querykeyvalue输入,功能使用限制如下:

    产品型号

    轴约束

    Atlas 推理系列加速卡产品

    • 支持B轴小于等于128。
    • 支持N轴小于等于256。
    • 支持S轴小于等于65535(64k)。
    • 支持D轴小于等于512。

    Atlas A2 训练系列产品/Atlas A2 推理系列产品

    Atlas A3 训练系列产品/Atlas A3 推理系列产品

    • 支持B轴小于等于65536(64k),D轴32byte不对齐时仅支持到128。
    • 支持N轴小于等于256。
    • S支持小于等于20971520(20M)。长序列场景下,如果计算量过大可能会导致PFA算子执行超时(aicore error类型报错,errorStr为timeout or trap error),此场景下建议做S切分处理,注:这里计算量会受B、S、N、D等的影响,值越大计算量越大。典型的会超时的长序列(即B、S、N、D的乘积较大)场景包括但不限于:
      • B=1,QN=20,QS=1048576,D = 256,KVN=1,KVS=1048576。
      • B=1,QN=2,QS=10485760,D = 256,KVN=2,KVS=10485760。
      • B=20,QN=1,QS=1048576,D = 256,KVN=1,KVS=1048576。
      • B=1,QN=10,QS=1048576,D = 512,KVN=1,KVS=1048576。
    • 支持D轴小于等于512。`input_layout`为BSH或者BSND时,要求N*D小于65535。
  • 参数sparse_mode当前仅支持值为0、1、2、3、4的场景,取其它值时会报错。

    • sparse_mode为0时,atten_mask如果为空指针,则忽略入参pre_tokensnext_tokens(内部赋值为INT_MAX)。
    • sparse_mode为2、3、4时,atten_mask的shape需要为(S,S)(S, S)(1,S,S)(1, S, S)(1,1,S,S)(1, 1, S, S),其中SS的值需要固定为2048,且需要用户保证传入的atten_mask为下三角,不传入atten_mask或者传入的shape不正确报错。
    • sparse_mode为1、2、3的场景忽略入参pre_tokensnext_tokens并按照相关规则赋值。
  • int8量化相关入参数量与输入、输出数据格式的综合限制:

    • 输入为int8,输出为int8的场景:入参deq_scale1quant_scale1deq_scale2quant_scale2需要同时存在,quant_offset2可选,不传时默认为0
    • 输入为int8,输出为float16的场景:入参deq_scale1quant_scale1deq_scale2需要同时存在,若存在入参quant_offset2quant_scale2(即不为None),则报错并返回。
    • 输入为float16bfloat16,输出为int8的场景:入参quant_scale2需存在,quant_offset2可选,不传时默认为0,若存在入参deq_scale1quant_scale1deq_scale2(即不为None),则报错并返回。
    • 入参 quant_offset2quant_scale2支持pertensor/perchannel两种格式和float32/bfloat16两种数据类型。若传入quant_offset2,需保证其类型和shape信息与quant_scale2一致。当输入为bfloat16时,同时支持float32bfloat16,否则仅支持float32。perchannel格式,当输出layout为BSHBSH时,要求quant_scale2所有维度的乘积等于HH;其他layout要求乘积等于N∗DN*D。当输出layout为BSHBSHquant_scale2 shape传入(1,1,H)(1, 1, H)(H,)(H,);输出为BNSDBNSD时,建议传入(1,N,1,D)(1, N, 1, D)(N,D)(N, D);输出为BSNDBSND时,建议传入(1,1,N,D)(1, 1, N, D)(N,D)(N, D)。pertensor格式,建议DD轴对齐到32Byte。
    • perchannel格式,入参quant_scale2quant_offset2暂不支持左padding、Ring Attention或者DD非32Byte对齐的场景。
    • 输出为int8时,暂不支持sparse为bandpre_tokens/next_tokens为负数。
  • pse_shift功能使用限制如下:

    • 支持query数据类型为float16bfloat16int8场景下使用该功能。
    • querykeyvalue数据类型为float16pse_shift存在时,强制走高精度模式,对应的限制继承自高精度模式的限制。
    • Q_SQ\_S需大于等于querySS长度,KV_SKV\_S需大于等于keySS长度。
  • 输出为int8,入参quant_offset2传入非空指针和非空Tensor值,并且sparse_modepre_tokensnext_tokens满足以下条件,矩阵会存在某几行不参与计算的情况,导致计算结果误差,该场景会拦截:

    • sparse_mode=0atten_mask如果非空指针,每个batch actual_seq_lengths-actual_seq_lengths_kv-pre_tokens>0next_tokens<0时,满足拦截条件。
    • sparse_mode=12,不会出现满足拦截条件的情况。
    • sparse_mode=3,每个batch actual_seq_lengths_kv-actual_seq_lengths<0,满足拦截条件。
    • sparse_mode=4pre_tokens<0或每个batch next_tokens+actual_seq_lengths_kv-actual_seq_lengths<0时,满足拦截条件。
  • kv伪量化参数分离当前暂不支持。

  • 暂不支持维度不对齐场景。

调用示例

  • 单算子调用

    >>> import torch
    >>> import torch_npu
    >>> import math
    >>>
    >>> # 生成随机数据,并发送到npu
    >>> q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
    >>> k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
    >>> v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
    >>> scale = 1/math.sqrt(128.0)
    >>> actseqlen = [164]
    >>> actseqlenkv = [1024]
    >>>
    >>> # 调用PFA算子
    >>> out = torch_npu.npu_prompt_flash_attention(q, k, v,
    ... actual_seq_lengths = actseqlen, actual_seq_lengths_kv = actseqlenkv,
    ... num_heads = 8, input_layout = "BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
    >>> out.shape
    torch.Size([1, 8, 164, 128])
    >>> out.dtype
    torch.float16
    
  • 图模式调用

    # 入图方式
    import torch
    import torch_npu
    import math
    
    import torchair as tng    
    from torchair.configs.compiler_config import CompilerConfig
    import torch._dynamo
    TORCHDYNAMO_VERBOSE=1
    TORCH_LOGS="+dynamo"
    
    # 支持入图的打印宏
    import logging
    from torchair.core.utils import logger
    logger.setLevel(logging.DEBUG)
    config = CompilerConfig()
    config.debug.graph_dump.type = "pbtxt"
    npu_backend = tng.get_npu_backend(compiler_config=config)
    from torch.library import Library, impl
    
    # 数据生成
    q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
    k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
    v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
    scale = 1/math.sqrt(128.0)
    
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
    
        def forward(self):
            return torch_npu.npu_prompt_flash_attention(q, k, v, num_heads = 8, input_layout = "BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
    
    def MetaInfershape():
        with torch.no_grad():
            model = Model()
            model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
            graph_output = model()
    
        single_op = torch_npu.npu_prompt_flash_attention(q, k, v, num_heads = 8, input_layout = "BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
        print("single op output with mask:", single_op, single_op.shape)
        print("graph output with mask:", graph_output, graph_output.shape)
        
    if __name__ == "__main__":
        MetaInfershape()
    
    # 执行上述代码的输出类似如下
    single op output with mask: tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
            [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
            [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
            ...,
            [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
            [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
            [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
            device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])
    
    graph output with mask: tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
            [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
            [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
            ...,
            [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
            [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
            [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
            device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])