(beta)torch_npu.contrib.npu_fused_attention_with_layernorm

[!NOTICE]
该接口计划废弃,可以使用torch_npu.npu_fusion_attentiontorch.nn.LayerNorm接口进行替换。

产品支持情况

产品 是否支持
Atlas A3 训练系列产品
Atlas A2 训练系列产品
Atlas 训练系列产品
Atlas 推理系列产品

功能说明

bert自注意力与层归一化的融合实现。

函数原型

torch_npu.contrib.npu_fused_attention_with_layernorm(hidden_states, attention_mask, query_kernel, key_kernel, value_kernel, query_bias, key_bias, value_bias, gamma, beta, scale=1, keep_prob=0)

参数说明

  • hidden_statesTensor):最后一层的hidden_states。
  • attention_maskTensor):attention mask。
  • query_kernelTensor):query的权重。
  • key_kernelTensor):key的权重。
  • value_kernelTensor):value的权重。
  • query_biasTensor):query的偏差值。
  • key_biasTensor):key的偏差值。
  • value_biasTensor):value的偏差值。
  • gammaTensor):torch.nn.LayerNorm.weight类型的tensor。
  • betaTensor):torch.nn.LayerNorm.bias类型的tensor。
  • scaledouble):计算score的缩放系数。默认值为1。
  • keep_probdouble):计算中保留数据的概率,值等于1-drop rate。默认值为0。

返回值说明

torch.Tensor

self attention的结果。