import torch
import torch.nn as nn
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
class TestNormalizationLayers(TestCase):
    """Forward-pass tests for ``torch.nn`` normalization layers moved with ``.npu()``."""

    def _check_forward(self, module, shape):
        """Run ``module`` on a random tensor of ``shape`` and validate the output.

        The input is created with ``torch.randn`` and moved with ``.npu()``.
        Besides checking that a result is produced, the output shape is
        compared with the input shape: the normalization layers exercised in
        this class are documented to return an output of the same shape as
        their input, so this catches failures that a bare "is not None" check
        can never detect.

        Args:
            module: Layer under test, already moved with ``.npu()``.
            shape: Sequence of ints giving the input tensor dimensions.

        Returns:
            The tensor returned by ``module``.
        """
        input1 = torch.randn(*shape).npu()
        output = module(input1)
        self.assertIsNotNone(output)
        # Compare as plain tuples so the check does not depend on how the
        # base class's assertEqual treats torch.Size objects.
        self.assertEqual(tuple(output.shape), tuple(input1.shape))
        return output

    def test_BatchNorm1d(self):
        """BatchNorm1d over a (20, 100) input."""
        self._check_forward(nn.BatchNorm1d(100).npu(), (20, 100))

    def test_BatchNorm2d(self):
        """BatchNorm2d without affine parameters over a 4-D input."""
        self._check_forward(
            nn.BatchNorm2d(100, affine=False).npu(), (20, 100, 35, 45)
        )

    def test_BatchNorm3d(self):
        """BatchNorm3d over a 5-D input."""
        self._check_forward(nn.BatchNorm3d(100).npu(), (20, 100, 35, 45, 10))

    def test_GroupNorm(self):
        """GroupNorm splitting 6 channels into 3 groups."""
        self._check_forward(nn.GroupNorm(3, 6).npu(), (20, 6, 10, 10))

    def test_convert_sync_batchnorm(self):
        """convert_sync_batchnorm replaces BatchNorm1d and leaves InstanceNorm1d."""
        module = torch.nn.Sequential(
            torch.nn.BatchNorm1d(100),
            torch.nn.InstanceNorm1d(100),
        ).npu()
        sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
        children = list(sync_bn_module.children())
        # Exact type identity, matching the original ``__class__`` comparison.
        self.assertIs(type(children[0]), torch.nn.SyncBatchNorm)
        self.assertIs(type(children[1]), torch.nn.InstanceNorm1d)

    def test_InstanceNorm1d(self):
        """InstanceNorm1d over a 3-D input."""
        self._check_forward(nn.InstanceNorm1d(100).npu(), (20, 100, 40))

    def test_InstanceNorm2d(self):
        """InstanceNorm2d over a 4-D input."""
        self._check_forward(nn.InstanceNorm2d(100).npu(), (20, 100, 35, 45))

    def test_InstanceNorm3d(self):
        """InstanceNorm3d over a 5-D input."""
        self._check_forward(nn.InstanceNorm3d(100).npu(), (20, 100, 35, 45, 10))

    def test_LayerNorm(self):
        """LayerNorm normalizing over every dimension except the first."""
        shape = (20, 5, 10, 10)
        # normalized_shape is everything after the leading dimension.
        self._check_forward(nn.LayerNorm(shape[1:]).npu(), shape)

    def test_LocalResponseNorm(self):
        """LocalResponseNorm applied to a 4-D and a 6-D input with one module."""
        lrn = nn.LocalResponseNorm(2).npu()
        self._check_forward(lrn, (32, 5, 24, 24))
        self._check_forward(lrn, (16, 5, 7, 7, 7, 7))
# Run the tests through the imported run_tests helper only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    run_tests()