(beta)torch_npu.contrib.module.MultiheadAttention

产品支持情况

产品 是否支持
Atlas A3 训练系列产品
Atlas A2 训练系列产品
Atlas 推理系列产品
Atlas 训练系列产品

功能说明

Multi-head attention。

函数原型

torch_npu.contrib.module.MultiheadAttention(embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, q_noise=0.0, qn_block_size=8)

参数说明

  • embed_dim (int):模型总维度。
  • num_heads (int):并行attention head。
  • kdim(int):key的特性总数。默认值为None。
  • vdim(int):value的特性总数。默认值为None。
  • dropout (float):Dropout概率。
  • bias (bool):如果指定此参数,则向输入/输出投影层添加偏置。默认值为True。
  • add_bias_kv (bool):如果指定此参数,则在dim=0处向键值序列添加偏置。默认值为False。
  • add_zero_attn (bool):如果指定此参数,则在dim=1处向键值序列新加一批零。默认值为False。
  • self_attention(bool):表示是否为自注意力层,若取值为True,要求embed_dimkdimvdim取值相等。默认值为False。
  • encoder_decoder_attention (bool):输入为编码器输出和解码器self-attention输出,其中编码器self-attention用作key和value,解码器self-attention用作查询。默认值为False。
  • q_noise(float):量化噪声量。
  • qn_block_size(int):用于后续iPQ量化的块大小。

返回值说明

Tensor

Multi-head attention的计算结果。

调用示例

>>> from torch_npu.testing.common_utils import create_common_tensor
>>> from torch_npu.contrib.module import MultiheadAttention
>>> import numpy as np
>>> from torch_npu.contrib.module.multihead_attention import _MHAConfig
>>> _MHAConfig.set_fussion()
>>> model = MultiheadAttention(embed_dim=1024,num_heads=16,dropout=0.1,kdim=1024,vdim=1024,self_attention=True,encoder_decoder_attention=True)
>>> _, query = create_common_tensor([np.float16, 29, (1024,1024)], -1, 1)
>>> _, key = create_common_tensor([np.float16, 29, (1024,1024)], -1, 1)
>>> _, value = create_common_tensor([np.float16, 29, (1024,1024)], -1, 1)
>>> _, key_padding_mask = create_common_tensor([np.float16, 29, (1024,1024)], -1, 1)
>>> bsz = 16
>>> tgt_len = 64
>>> s_len=64
>>> model = model.to("npu")
>>> output = model(query, key, value, bsz, tgt_len, s_len, key_padding_mask)
>>> output
(tensor([[-0.0385,  0.0441,  0.2432,  ...,  0.0627,  0.0254,  0.0400],
        [-0.0999, -0.0258,  0.1002,  ...,  0.0632,  0.0344,  0.0573],
        [-0.0830,  0.0405,  0.1694,  ...,  0.0787, -0.0089,  0.0544],
        ...,
        [-0.0146, -0.0303, -0.1011,  ...,  0.0689,  0.1722, -0.1125],
        [-0.0305, -0.1129, -0.0944,  ...,  0.0280,  0.1777,  0.0410],
        [-0.0035, -0.1030, -0.0957,  ...,  0.0093,  0.1171,  0.0009]],
       device='npu:0', dtype=torch.float16,
       grad_fn=<NpuMultiHeadAttentionBackward0>), None)