import os
import unittest
from unittest.mock import Mock
import torch
import torch_npu
from torch import nn
from device import DEVICE_ID
from utils.utils.precision_compare import data_compare
from mindiesd import layernorm_scale_shift
from mindiesd.utils import ParametersInvalid
from mindiesd.utils.get_platform import NPUDevice, get_npu_device
@unittest.skipIf(
os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU."
)
class TestAdaLayerNorm(unittest.TestCase):
def setUp(self):
self.norm_eps = 1e-5
def test_layernorm_type(self):
device = "npu"
layernorm = nn.GroupNorm(4, 64).to(device)
x = torch.randn([2, 1024, 128], dtype=torch.float32).to(device)
scale = torch.randn([2, 128], dtype=torch.float32).to(device)
shift = torch.randn([2, 128], dtype=torch.float32).to(device)
fused = True
with self.assertRaises(ParametersInvalid):
layernorm_scale_shift(layernorm, x, scale, shift, fused)
def test_x_type(self):
device = "npu"
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = [2, 1024, 128]
scale = torch.randn([2, 128], dtype=torch.float32).to(device)
shift = torch.randn([2, 128], dtype=torch.float32).to(device)
fused = True
with self.assertRaises(ParametersInvalid):
layernorm_scale_shift(layernorm, x, scale, shift, fused)
def test_scale_type(self):
device = "npu"
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([2, 1024, 128], dtype=torch.float32).to(device)
scale = [2, 128]
shift = torch.randn([2, 128], dtype=torch.float32).to(device)
fused = True
with self.assertRaises(ParametersInvalid):
layernorm_scale_shift(layernorm, x, scale, shift, fused)
def test_shift_type(self):
device = "npu"
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([2, 1024, 128], dtype=torch.float32).to(device)
scale = torch.randn([2, 128], dtype=torch.float32).to(device)
shift = [2, 128]
fused = True
with self.assertRaises(ParametersInvalid):
layernorm_scale_shift(layernorm, x, scale, shift, fused)
def test_fused_type(self):
device = "npu"
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([2, 1024, 128], dtype=torch.float32).to(device)
scale = torch.randn([2, 128], dtype=torch.float32).to(device)
shift = torch.randn([2, 128], dtype=torch.float32).to(device)
fused = "True"
with self.assertRaises(ParametersInvalid):
layernorm_scale_shift(layernorm, x, scale, shift, fused)
def test_x_dim(self):
device = "npu"
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([2, 1, 1024, 128], dtype=torch.float32).to(device)
scale = torch.randn([2, 128], dtype=torch.float32).to(device)
shift = torch.randn([2, 128], dtype=torch.float32).to(device)
fused = True
with self.assertRaises(ParametersInvalid):
layernorm_scale_shift(layernorm, x, scale, shift, fused)
def test_scale_dim(self):
device = "npu"
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([2, 1024, 128], dtype=torch.float32).to(device)
scale = torch.randn([2, 1, 1024, 128], dtype=torch.float32).to(device)
shift = torch.randn([2, 128], dtype=torch.float32).to(device)
fused = True
with self.assertRaises(ParametersInvalid):
layernorm_scale_shift(layernorm, x, scale, shift, fused)
def test_scale_invalid_second_dim(self):
device = "npu"
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([2, 1024, 128], dtype=torch.float32).to(device)
scale = torch.randn([2, 512, 128], dtype=torch.float32).to(device)
shift = torch.randn([2, 128], dtype=torch.float32).to(device)
fused = True
with self.assertRaises(ParametersInvalid):
layernorm_scale_shift(layernorm, x, scale, shift, fused)
def test_shift_dim(self):
device = "npu"
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([2, 1024, 128], dtype=torch.float32).to(device)
scale = torch.randn([2, 128], dtype=torch.float32).to(device)
shift = torch.randn([2, 1, 1024, 128], dtype=torch.float32).to(device)
fused = True
with self.assertRaises(ParametersInvalid):
layernorm_scale_shift(layernorm, x, scale, shift, fused)
def test_shift_invalid_second_dim(self):
device = "npu"
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([2, 1024, 128], dtype=torch.float32).to(device)
scale = torch.randn([2, 128], dtype=torch.float32).to(device)
shift = torch.randn([2, 512, 128], dtype=torch.float32).to(device)
fused = True
with self.assertRaises(ParametersInvalid):
layernorm_scale_shift(layernorm, x, scale, shift, fused)
def test_x_scale_dim_equal(self):
device = "npu"
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([2, 1024, 128], dtype=torch.float32).to(device)
scale = torch.randn([2, 64], dtype=torch.float32).to(device)
shift = torch.randn([2, 128], dtype=torch.float32).to(device)
fused = True
with self.assertRaises(ParametersInvalid):
layernorm_scale_shift(layernorm, x, scale, shift, fused)
def test_scale_shift_dim_equal(self):
device = "npu"
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([2, 1024, 128], dtype=torch.float32).to(device)
scale = torch.randn([2, 128], dtype=torch.float32).to(device)
shift = torch.randn([2, 64], dtype=torch.float32).to(device)
fused = True
with self.assertRaises(ParametersInvalid):
layernorm_scale_shift(layernorm, x, scale, shift, fused)
@torch.no_grad()
def test_layernorm_scale_shift_2d_non_affine(self):
device = "npu"
batch_size = 2
sentence_length = 1024
hidden_size = 128
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
scale = torch.randn([batch_size, hidden_size], dtype=torch.float32).to(device)
shift = torch.randn([batch_size, hidden_size], dtype=torch.float32).to(device)
out_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=True)
out_non_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=False)
self.assertEqual(out_non_fused.shape, out_fused.shape)
result, _, max_err = data_compare(out_fused.cpu(), out_non_fused.cpu())
self.assertEqual(result, "success", msg=f"Data compare failed. Max error is: {max_err}")
@torch.no_grad()
def test_layernorm_scale_shift_2d_use_affine(self):
device = "npu"
batch_size = 2
sentence_length = 1024
hidden_size = 128
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=True).to(device)
x = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
scale = torch.randn([batch_size, hidden_size], dtype=torch.float32).to(device)
shift = torch.randn([batch_size, hidden_size], dtype=torch.float32).to(device)
out_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=True)
out_non_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=False)
self.assertEqual(out_non_fused.shape, out_fused.shape)
result, _, max_err = data_compare(out_fused.cpu(), out_non_fused.cpu())
self.assertEqual(result, "success", msg=f"Data compare failed. Max error is: {max_err}")
@torch.no_grad()
def test_layernorm_scale_shift_3d_non_affine(self):
device = "npu"
batch_size = 2
sentence_length = 1024
hidden_size = 128
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
scale = torch.randn([batch_size, 1, hidden_size], dtype=torch.float32).to(device)
shift = torch.randn([batch_size, 1, hidden_size], dtype=torch.float32).to(device)
out_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=True)
out_non_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=False)
self.assertEqual(out_non_fused.shape, out_fused.shape)
result, _, max_err = data_compare(out_fused.cpu(), out_non_fused.cpu())
self.assertEqual(result, "success", msg=f"Data compare failed. Max error is: {max_err}")
@torch.no_grad()
def test_layernorm_scale_shift_3d_use_affine(self):
device = "npu"
batch_size = 2
sentence_length = 1024
hidden_size = 128
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=True).to(device)
x = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
scale = torch.randn([batch_size, 1, hidden_size], dtype=torch.float32).to(device)
shift = torch.randn([batch_size, 1, hidden_size], dtype=torch.float32).to(device)
out_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=True)
out_non_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=False)
self.assertEqual(out_non_fused.shape, out_fused.shape)
result, _, max_err = data_compare(out_fused.cpu(), out_non_fused.cpu())
self.assertEqual(result, "success", msg=f"Data compare failed. Max error is: {max_err}")
@torch.no_grad()
def test_layernorm_scale_shift_bsh_non_affine(self):
device = "npu"
batch_size = 2
sentence_length = 1024
hidden_size = 128
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=False).to(device)
x = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
scale = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
shift = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
out_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=True)
out_non_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=False)
self.assertEqual(out_non_fused.shape, out_fused.shape)
result, _, max_err = data_compare(out_fused.cpu(), out_non_fused.cpu())
self.assertEqual(result, "success", msg=f"Data compare failed. Max error is: {max_err}")
@torch.no_grad()
def test_layernorm_scale_shift_bsh_use_affine(self):
device = "npu"
batch_size = 2
sentence_length = 1024
hidden_size = 128
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=True).to(device)
x = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
scale = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
shift = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
out_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=True)
out_non_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=False)
self.assertEqual(out_non_fused.shape, out_fused.shape)
result, _, max_err = data_compare(out_fused.cpu(), out_non_fused.cpu())
self.assertEqual(result, "success", msg=f"Data compare failed. Max error is: {max_err}")
@torch.no_grad()
def test_layernorm_scale_shift_bsh_mixed_modulation(self):
device = "npu"
batch_size = 2
sentence_length = 1024
hidden_size = 128
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=True).to(device)
x = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
scale = torch.randn([batch_size, hidden_size], dtype=torch.float32).to(device)
shift = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
out_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=True)
out_non_fused = layernorm_scale_shift(layernorm, x, scale, shift, fused=False)
self.assertEqual(out_non_fused.shape, out_fused.shape)
result, _, max_err = data_compare(out_fused.cpu(), out_non_fused.cpu())
self.assertEqual(result, "success", msg=f"Data compare failed. Max error is: {max_err}")
@torch.no_grad()
def test_layernorm_scale_shift_3d_use_affine_and_a5(self):
device = "npu"
batch_size = 1
sentence_length = 1024
hidden_size = 128
layernorm = nn.LayerNorm(128, self.norm_eps, elementwise_affine=True).to(device)
x = torch.randn([batch_size, sentence_length, hidden_size], dtype=torch.float32).to(device)
scale = torch.randn([batch_size, hidden_size], dtype=torch.float32).to(device)
shift = torch.randn([batch_size, hidden_size], dtype=torch.float32).to(device)
origin_ops_v2 = torch.ops.mindiesd.adaln_v2
origin_ops_v1 = torch.ops.mindiesd.adaln
ops_mock_v2 = Mock()
ops_mock_v1 = Mock()
def mock_ops_v2(*args, **kwargs):
ops_mock_v2()
return origin_ops_v2(*args, **kwargs)
def mock_ops_v1(*args, **kwargs):
ops_mock_v1()
return origin_ops_v1(*args, **kwargs)
torch.ops.mindiesd.adaln_v2 = mock_ops_v2
torch.ops.mindiesd.adaln = mock_ops_v1
try:
out = layernorm_scale_shift(layernorm, x, scale, shift, fused=True)
if get_npu_device() == NPUDevice.A5:
ops_mock_v2.assert_called_once()
else:
ops_mock_v1.assert_called_once()
finally:
torch.ops.mindiesd.adaln_v2 = origin_ops_v2
torch.ops.mindiesd.adaln = origin_ops_v1
if __name__ == "__main__":
torch_npu.npu.set_device(DEVICE_ID)
unittest.main()