import os
import unittest
import torch
import torch_npu
from torch import nn
from device import DEVICE_ID
from mindiesd import fast_layernorm
from mindiesd.utils import ParametersInvalid
@unittest.skipIf(
    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU",
    "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU.",
)
class TestLayerNorm(unittest.TestCase):
    """Check ``fast_layernorm`` against ``torch.nn.LayerNorm`` on an NPU device.

    The whole class is skipped when the ``MINDIE_TEST_MODE`` environment
    variable is set to ``"CPU"``, because every test moves tensors and modules
    to the NPU.
    """

    # Lower bound for the cosine similarity between the fast kernel output and
    # the reference output. Identical directions give 1.0, so the allowed
    # deviation is 2**-7. A bare ``2**-7`` bound would accept almost any
    # positively correlated output and would not detect a wrong result.
    MIN_COSINE_SIMILARITY = 1 - 2**-7

    def setUp(self):
        """Create a float32 input and two LayerNorm modules, all on the NPU."""
        self.x = torch.randn([2, 1024, 128], dtype=torch.float32).npu()
        # Default LayerNorm: carries learnable weight and bias.
        self.layernorm_have_param = nn.LayerNorm(normalized_shape=128).npu()
        # LayerNorm without any affine parameters.
        self.layernorm_non_param = nn.LayerNorm(
            normalized_shape=128, elementwise_affine=False, bias=False
        ).npu()

    def _assert_matches_reference(self, layernorm, x, impl_mode):
        """Assert that ``fast_layernorm`` agrees with ``layernorm(x)``.

        Both outputs are flattened to a single row so that one cosine
        similarity value summarises the whole tensor.

        Args:
            layernorm: The ``nn.LayerNorm`` module used both as the argument of
                ``fast_layernorm`` and as the reference implementation.
            x: Input tensor passed to both implementations.
            impl_mode: Third positional argument forwarded to
                ``fast_layernorm``.
        """
        out_npu = fast_layernorm(layernorm, x, impl_mode).reshape(1, -1)
        origin = layernorm(x).reshape(1, -1)
        # ``.item()`` yields a plain float, which gives a readable failure
        # message instead of a device tensor repr.
        similarity = torch.cosine_similarity(out_npu, origin)[0].item()
        self.assertGreater(similarity, self.MIN_COSINE_SIMILARITY)

    def test_layernorm_have_param(self):
        """Mode 0 matches the reference for a LayerNorm with weight and bias."""
        self._assert_matches_reference(self.layernorm_have_param, self.x, 0)

    def test_layernorm_non_param(self):
        """Mode 0 matches the reference for a LayerNorm without parameters."""
        self._assert_matches_reference(self.layernorm_non_param, self.x, 0)

    def test_impl_mode(self):
        """Invalid mode arguments are rejected and mode 1 matches the reference."""
        # 5 is expected to be rejected as a mode value.
        with self.assertRaises(ParametersInvalid):
            fast_layernorm(self.layernorm_have_param, self.x, 5)
        # Mode 2 combined with a bfloat16 input is expected to be rejected.
        with self.assertRaises(ParametersInvalid):
            fast_layernorm(self.layernorm_have_param, self.x.to(torch.bfloat16), 2)
        self._assert_matches_reference(self.layernorm_have_param, self.x, 1)

    def test_normalized_shape_too_large(self):
        """Test that normalized_shape with more dims than input raises ParametersInvalid."""
        bad_layernorm = nn.LayerNorm(normalized_shape=(2, 1024, 128, 64)).npu()
        with self.assertRaises(ParametersInvalid) as ctx:
            fast_layernorm(bad_layernorm, self.x, 0)
        self.assertIn(
            "normalized_shape must fit within input dimensions",
            str(ctx.exception),
        )

    def test_normalized_shape_equal_dims(self):
        """Test edge case where normalized_shape ndim equals input ndim."""
        layernorm = nn.LayerNorm(normalized_shape=(2, 1024, 128)).npu()
        x = torch.randn([2, 1024, 128], dtype=torch.float32).npu()
        out = fast_layernorm(layernorm, x, 0)
        self.assertEqual(out.shape, x.shape)

    def test_normalized_shape_less_than_dims(self):
        """Test normal case where normalized_shape ndim is less than input ndim."""
        layernorm = nn.LayerNorm(normalized_shape=(128,)).npu()
        x = torch.randn([2, 1024, 128], dtype=torch.float32).npu()
        out = fast_layernorm(layernorm, x, 0)
        self.assertEqual(out.shape, x.shape)
# Script entry point: select the NPU device before any test case runs, then
# hand control to the unittest runner.
if __name__ == '__main__':
    torch_npu.npu.set_device(DEVICE_ID)
    unittest.main()