import os
import unittest
import torch
from torch import nn
# Load the mindiesd custom-ops library unless running in CPU-only test mode.
# The tests below call operators through `torch.ops.mindiesd.*`; presumably this
# loader is what registers them (inferred from the module name — confirm).
# The "CPU" check mirrors the skipIf condition on TestLayerNorm, so the library
# is never loaded in the mode where those tests are skipped.
if os.environ.get("MINDIE_TEST_MODE", "ALL") != "CPU":
    from mindiesd.layers.register_ops import _load_mindie_ops_library
    _load_mindie_ops_library()
@unittest.skipIf(
os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU."
)
class TestLayerNorm(unittest.TestCase):
def setUp(self):
self.device = torch.device("npu:0")
torch.npu.set_device(self.device)
self.x_shape = (2, 48, 128)
self.dtype = torch.bfloat16
self.layernorm_origin = nn.LayerNorm(normalized_shape=128).npu()
self.x = torch.randn(self.x_shape, device=self.device, dtype=self.dtype)
def test_layernorm_output_shape(self):
output = torch.ops.mindiesd.layernorm(
self.x,
list(self.layernorm_origin.normalized_shape),
self.layernorm_origin.weight,
self.layernorm_origin.bias,
self.layernorm_origin.eps,
impl_mode=0,
)[0]
expected_shape = self.x_shape
self.assertEqual(output.shape, expected_shape, "Output shape does not match expected shape.")
def test_layernorm(self):
output_0 = torch.ops.mindiesd.layernorm(
self.x,
list(self.layernorm_origin.normalized_shape),
self.layernorm_origin.weight,
self.layernorm_origin.bias,
self.layernorm_origin.eps,
impl_mode=0,
)[0].reshape(1, -1)
output_1 = torch.ops.mindiesd.layernorm(
self.x,
list(self.layernorm_origin.normalized_shape),
self.layernorm_origin.weight,
self.layernorm_origin.bias,
self.layernorm_origin.eps,
impl_mode=1,
)[0].reshape(1, -1)
origin = self.layernorm_origin(self.x).reshape(1, -1)
self.assertGreater(torch.cosine_similarity(output_0, origin)[0], 2**-7)
self.assertGreater(torch.cosine_similarity(output_1, origin)[0], 2**-7)
def test_normalized_shape_too_large(self):
"""Test that normalized_shape with more dims than input raises RuntimeError from TORCH_CHECK."""
bad_normalized_shape = list(self.x_shape) + [64]
with self.assertRaises(RuntimeError) as ctx:
torch.ops.mindiesd.layernorm(
self.x,
bad_normalized_shape,
self.layernorm_origin.weight,
self.layernorm_origin.bias,
self.layernorm_origin.eps,
impl_mode=0,
)
self.assertIn("normalized_shape must fit within input dimensions", str(ctx.exception))
def test_normalized_shape_equal_dims(self):
"""Test edge case where normalized_shape ndim equals input ndim."""
x = torch.randn(self.x_shape, device=self.device, dtype=self.dtype)
normalized_shape = list(self.x_shape)
layernorm = nn.LayerNorm(normalized_shape=normalized_shape).to(device=self.device, dtype=self.dtype)
output = torch.ops.mindiesd.layernorm(
x, normalized_shape, layernorm.weight, layernorm.bias, layernorm.eps, impl_mode=0
)[0]
self.assertEqual(output.shape, x.shape)
if __name__ == "__main__":
    # Use unittest's default argv/exit handling so command-line options
    # (e.g. -v, -k) are honoured and a failing run exits with a non-zero status.
    # The previous argv=[''] / exit=False form ignored CLI options and always
    # exited with status 0, hiding failures from shell callers.
    unittest.main()