(beta)torch_npu.fast_gelu

产品支持情况

产品 是否支持
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 训练系列产品
Atlas 推理系列产品

功能说明

快速高斯误差线性单元激活函数(Fast Gaussian Error Linear Units activation function),对输入的每个元素计算FastGelu。支持FakeTensor模式。

函数原型

torch_npu.fast_gelu(input) -> Tensor

参数说明

input (Tensor):对应公式中的xx。数据格式支持NDND,支持非连续的Tensor。输入最大支持8维。支持空Tensor。

  • Atlas 训练系列产品:数据类型支持float16float32
  • Atlas A2 训练系列产品/Atlas A2 推理系列产品:数据类型支持float16float32bfloat16
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16float32bfloat16
  • Atlas 推理系列产品:数据类型仅支持float16float32

调用示例

示例一:

>>> import torch
>>> import torch_npu
>>> x = torch.rand(2).npu()
>>> x
tensor([0.5991, 0.4094], device='npu:0')
>>> torch_npu.fast_gelu(x)
tensor([0.4403, 0.2733], device='npu:0')

示例二:

>>> import torch
>>> import torch_npu
# FakeTensor模式
>>> from torch._subclasses.fake_tensor import FakeTensorMode
>>> with FakeTensorMode():
...     x = torch.rand(2).npu()
...     torch_npu.fast_gelu(x)
>>> FakeTensor(..., device='npu:0', size=(2,))