torch_npu.npu_group_norm_swish
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品 | √ |
| Atlas A2 训练系列产品 | √ |
功能说明
-
API功能:计算输入
input的组归一化结果y,均值mean,标准差的倒数rstd,以及swish的输出。 -
计算公式:
- GroupNorm: 公式中的xx代表
input, E[x]=xˉE[x] = \bar{x} 代表xx的均值,Var[x]=1n∗∑i=1n(xi−E[x])2Var[x] = \frac{1}{n} * \sum_{i=1}^{n} (x_i - E[x])^2 代表xx的方差,γ\gamma代表weight,β\beta代表bias,则公式如下:
{y=x−E[x]Var[x]+eps∗γ+βmean=E[x]rstd=1Var[x]+eps\begin{cases} y & = \frac{x - E[x]}{\sqrt{{Var[x]} + eps}} * \gamma + \beta \\ mean & = E[x] \\ rstd & = \frac{1}{\sqrt{{Var[x]} + eps}} \end{cases}
- swish:swish计算公式的xx为GroupNorm公式得到的yy。
y=x1+e−scale⋅xy = \frac{x}{1 + e^{-scale \cdot x}}
- GroupNorm: 公式中的xx代表
Note
需要计算反向梯度场景时,若需要输出结果排除随机性,则需要设置确定性计算开关。
函数原型
torch_npu.npu_group_norm_swish(input, num_groups, weight, bias, eps=1e-5, swish_scale=1.0) -> (Tensor, Tensor, Tensor)
参数说明
- input(
Tensor):必选参数,表示需要进行组归一化的数据,支持2-8D张量,数据类型支持float16,float32,bfloat16。 - num_groups(
int):必选参数,表示将input的第1维分为num_groups组,input的第1维必须能被num_groups整除。 - weight(
Tensor):必选参数,表示权重,支持1D张量,并且第0维大小与input的第1维相同;数据类型支持float16,float32,bfloat16,并且需要与input一致。 - bias(
Tensor):必选参数,表示偏置,支持1D张量,并且第0维大小与input的第1维相同;数据类型支持float16,float32,bfloat16,并且需要与input一致。 - eps(
float):可选参数,计算组归一化时加到分母上的值,以保证数值的稳定性。默认值为1e-5。 - swish_scale(
float):可选参数,用于进行swish计算的值。默认值为1.0。
返回值说明
y(Tensor):表示组归一化和swish计算的结果。
mean(Tensor):表示分组后的均值。
rstd(Tensor):表示分组后的标准差的倒数。
约束说明
需要计算反向梯度场景时,input的第1维除以num_groups的结果不能超过4000,input、weight、bias参数不支持含有-inf、inf或nan值。
调用示例
import torch
import torch_npu
input = torch.randn(3, 4, 6, dtype=torch.float32).npu()
weight = torch.randn(input.size(1), dtype=torch.float32).npu()
bias = torch.randn(input.size(1), dtype=torch.float32).npu()
num_groups = input.size(1)
eps = 1e-5
swish_scale = 1.0
out, mean, rstd = torch_npu.npu_group_norm_swish(input, num_groups, weight, bias, eps=eps, swish_scale=swish_scale)