import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestNPUGroupNormSwish(TestCase):
    """Compare ``torch_npu.npu_group_norm_swish`` against a CPU reference.

    The reference runs ``torch.ops.aten.native_group_norm`` followed by the
    swish activation ``out * sigmoid(swish_scale * out)``. The test compares
    the forward outputs (out, mean, rstd) and the gradients of ``x``,
    ``weight`` and ``bias`` obtained by back-propagating ``ones_like(out)``.
    """

    def supported_op_exec(self, x, weight, bias, num_groups, eps, swish_scale):
        """Compute the reference result with aten group norm plus swish.

        Returns ``(out, mean_out, rstd_out, x_grad, weight_grad, bias_grad)``.
        The inputs are copied into fresh leaf tensors, so the caller's
        tensors keep their original ``requires_grad`` state.
        """
        # Work on fresh leaves instead of calling requires_grad_ on the
        # caller's tensors. Otherwise any tensor the caller derives from them
        # afterwards (e.g. via .npu()) would not be a leaf and its .grad
        # would stay None.
        x = x.detach().clone().requires_grad_(True)
        weight = weight.detach().clone().requires_grad_(True)
        bias = bias.detach().clone().requires_grad_(True)

        n, c = x.size(0), x.size(1)
        # All dimensions after (N, C) are flattened into one HxW extent;
        # for a 2-D input the product is empty and HxW == 1.
        hw = 1
        for size in x.size()[2:]:
            hw *= size

        out, mean_out, rstd_out = torch.ops.aten.native_group_norm(
            x, weight=weight, bias=bias, N=n, C=c, HxW=hw, group=num_groups, eps=eps)
        # Swish activation applied to the normalized output.
        out = out * torch.sigmoid(swish_scale * out)
        out.backward(torch.ones_like(out))
        return out, mean_out, rstd_out, x.grad, weight.grad, bias.grad

    def custom_op_exec(self, x, weight, bias, num_groups, data_format, eps, swish_scale):
        """Run the fused ``npu_group_norm_swish`` kernel and its backward.

        Returns ``(out, mean_out, rstd_out, x_grad, weight_grad, bias_grad)``
        in the same order as ``supported_op_exec``.

        ``data_format`` is accepted for call-site symmetry only. It is not
        forwarded, because the ``npu_group_norm_swish`` call below takes no
        such argument.
        """
        # Same reasoning as in supported_op_exec: do not mutate caller tensors.
        x = x.detach().clone().requires_grad_(True)
        weight = weight.detach().clone().requires_grad_(True)
        bias = bias.detach().clone().requires_grad_(True)

        out, mean_out, rstd_out = torch_npu.npu_group_norm_swish(
            x, num_groups, weight, bias, eps=eps, swish_scale=swish_scale)
        out.backward(torch.ones_like(out))
        return out, mean_out, rstd_out, x.grad, weight.grad, bias.grad

    # The device gate belongs on the test method itself (not on a helper),
    # so the whole case is skipped before any NPU work on unsupported devices.
    @SupportedDevices(['Ascend910B'])
    def test_npu_group_norm_swish(self):
        """Check forward outputs and gradients on 2-D, 3-D and 4-D inputs."""
        torch.manual_seed(123)
        shape_list = [[3, 3], [3, 6, 7, 2], [24, 35, 76]]
        dtype_list = [torch.float32]
        eps = 1e-5
        swish_scale = 1.0
        data_format = 'NCHW'
        for shape in shape_list:
            for dtype in dtype_list:
                # subTest reports which shape/dtype failed instead of only the first mismatch.
                with self.subTest(shape=shape, dtype=dtype):
                    x = torch.randn(shape, dtype=dtype)
                    weight = torch.randn(x.size(1), dtype=dtype)
                    bias = torch.randn(x.size(1), dtype=dtype)
                    # One group per channel (num_groups == C).
                    num_groups = x.size(1)

                    cpuout = self.supported_op_exec(x, weight, bias, num_groups, eps, swish_scale)
                    npuout = self.custom_op_exec(x.npu(), weight.npu(), bias.npu(),
                                                 num_groups, data_format, eps, swish_scale)
                    # Forward results: out, mean, rstd.
                    self.assertRtolEqual(cpuout[:3], npuout[:3])
                    # Backward results: grads of x, weight, bias.
                    self.assertRtolEqual(cpuout[3:], npuout[3:])
if __name__ == "__main__":
    # Running this file directly hands control to torch_npu's test runner.
    run_tests()