torch_npu.npu_fast_gelu
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 推理系列产品 | √ |
| Atlas 训练系列产品 | √ |
功能说明
-
API功能:快速高斯误差线性单元激活函数(Fast Gaussian Error Linear Units activation function),对输入的每个元素计算
FastGelu的前向结果。 -
计算公式:
- Atlas 训练系列产品、Atlas 推理系列产品上,计算公式如下:
fast_gelu(x)=x1+e−1.702∣x∣e0.851x(x−∣x∣)fast\_gelu(x)=\frac{x}{1+e^{-1.702 \mid x\mid}} e^{0.851 x(x- \mid x\mid)}
- Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品产品上,计算公式如下:
fast_gelu(x)=x1+e−1.702xfast\_gelu(x)=\frac{x}{1+e^{-1.702x}}
函数原型
torch_npu.npu_fast_gelu(input) -> Tensor
参数说明
input (Tensor):对应公式中的xx。数据格式支持NDND,支持非连续的Tensor。输入最大支持8维。
- Atlas 训练系列产品:数据类型支持
float16、float32。 - Atlas A2 训练系列产品/Atlas A2 推理系列产品:数据类型支持
float16、float32、bfloat16。 - Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持
float16、float32、bfloat16。 - Atlas 推理系列产品:数据类型仅支持
float16、float32。
返回值说明
Tensor
代表fast_gelu的计算结果。
约束说明
- 该接口支持推理、训练场景下使用。
- 该接口支持图模式。
input输入不能为None。
调用示例
-
单算子调用
import os import torch import torch_npu import numpy as np data_var = np.random.uniform(0, 1, [4, 2048, 16, 128]).astype(np.float32) x = torch.from_numpy(data_var).to(torch.float32).npu() y = torch_npu.npu_fast_gelu(x).cpu().numpy() -
图模式调用
import os import torch import torch_npu import numpy as np import torch.nn as nn import torchair as tng from torchair.configs.compiler_config import CompilerConfig os.environ["ENABLE_ACLNN"] = "false" torch_npu.npu.set_compile_mode(jit_compile=True) class Network(nn.Module): def __init__(self): super(Network, self).__init__() def forward(self, x): y = torch_npu.npu_fast_gelu(x) return y npu_mode = Network() config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) npu_mode = torch.compile(npu_mode, fullgraph=True, backend=npu_backend, dynamic=False) data_var = np.random.uniform(0, 1, [4, 2048, 16, 128]).astype(np.float32) x = torch.from_numpy(data_var).to(torch.float32).npu() y =npu_mode(x).cpu().numpy() print("shape of y:",y.shape) print("dtype of y:",y.dtype) # 执行上述代码的输出类似如下 shape of y: (4, 2048, 16, 128) dtype of y: float32