torch_npu.npu_fast_gelu

产品支持情况

产品 是否支持
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 推理系列产品
Atlas 训练系列产品

功能说明

  • API功能:快速高斯误差线性单元激活函数(Fast Gaussian Error Linear Units activation function),对输入的每个元素计算FastGelu的前向结果。

  • 计算公式:

    • Atlas 训练系列产品、Atlas 推理系列产品上,计算公式如下:

    fast_gelu(x)=x1+e−1.702∣x∣e0.851x(x−∣x∣)fast\_gelu(x)=\frac{x}{1+e^{-1.702 \mid x\mid}} e^{0.851 x(x- \mid x\mid)}

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品产品上,计算公式如下:

    fast_gelu(x)=x1+e−1.702xfast\_gelu(x)=\frac{x}{1+e^{-1.702x}}

函数原型

torch_npu.npu_fast_gelu(input) -> Tensor

参数说明

input (Tensor):对应公式中的xx。数据格式支持NDND,支持非连续的Tensor。输入最大支持8维。

  • Atlas 训练系列产品:数据类型支持float16float32
  • Atlas A2 训练系列产品/Atlas A2 推理系列产品:数据类型支持float16float32bfloat16
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16float32bfloat16
  • Atlas 推理系列产品:数据类型仅支持float16float32

返回值说明

Tensor

代表fast_gelu的计算结果。

约束说明

  • 该接口支持推理、训练场景下使用。
  • 该接口支持图模式。
  • input输入不能为None。

调用示例

  • 单算子调用

    import os
    import torch
    import torch_npu
    import numpy as np
    
    data_var = np.random.uniform(0, 1, [4, 2048, 16, 128]).astype(np.float32)
    x = torch.from_numpy(data_var).to(torch.float32).npu()
    y = torch_npu.npu_fast_gelu(x).cpu().numpy()
    
  • 图模式调用

    import os
    import torch
    import torch_npu
    import numpy as np
    import torch.nn as nn
    import torchair as tng
    from torchair.configs.compiler_config import CompilerConfig
    
    os.environ["ENABLE_ACLNN"] = "false"
    torch_npu.npu.set_compile_mode(jit_compile=True)
    
    class Network(nn.Module):
        def __init__(self):
            super(Network, self).__init__()
        def forward(self, x): 
            y = torch_npu.npu_fast_gelu(x)
            return y
            
    npu_mode = Network()
    config = CompilerConfig()
    npu_backend = tng.get_npu_backend(compiler_config=config)
    npu_mode = torch.compile(npu_mode, fullgraph=True, backend=npu_backend, dynamic=False)
    data_var = np.random.uniform(0, 1, [4, 2048, 16, 128]).astype(np.float32)
    x = torch.from_numpy(data_var).to(torch.float32).npu()
    y =npu_mode(x).cpu().numpy()
    print("shape of y:",y.shape)
    print("dtype of y:",y.dtype)
    
    # 执行上述代码的输出类似如下
    shape of y: (4, 2048, 16, 128)
    dtype of y: float32