import unittest
import torch
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestTransposeQuantBatchMatmul(TestCase):

    def supported_op_exec(self, x1, x2, scale):
        x1 = x1.transpose(0, 1)
        out = torch.matmul(x1.float(), x2.float())
        out = out.transpose(0, 1)
        out = out.reshape(out.shape[0], 1, out.shape[1] * out.shape[2])
        data = (out * scale).to(torch.int)
        output = torch.clip(data, -128, 127).to(torch.float16)
        return output

    def supported_op_exec_2(self, x1, x2):
        x1 = x1.transpose(0, 1)
        out = torch.matmul(x1.float(), x2.float())
        out = out.transpose(0, 1)
        return out.to(torch.float16)

    def supported_op_exec_3(self, x1, x2, batch_split_factor):
        x1 = x1.transpose(0, 1)
        out = torch.matmul(x1.float(), x2.float())
        out = out.transpose(0, 1)
        output = out.reshape(x1.shape[1], batch_split_factor, -1)
        output = output.transpose(0, 1)
        return output.to(torch.float16)

    @unittest.skip("Skipping test_npu_transpose_quant_batchmatmul temporarily")
    @SupportedDevices(["Ascend950"])
    def test_npu_transpose_quant_batchmatmul(self, device="npu"):
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(torch.float8_e5m2)
        x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(torch.float8_e5m2)
        x1_scale = torch.randint(-3, 3, (M, ), dtype=torch.float32)
        x2_scale = torch.randint(-3, 3, (N, ), dtype=torch.float32)
        supported_output = self.supported_op_exec(x1, x2, x1_scale, x2_scale)
        custom_output = torch_npu.npu_transpose_quant_batchmatmul(x1.npu(), x2.npu(), dtype=torch.float16,
                                                            x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                                                            perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2]).to("cpu")
        self.assertRtolEqual(supported_output, custom_output, 0.01)

    @unittest.skip("Skipping test_npu_transpose_quant_batchmatmul temporarily")
    @SupportedDevices(["Ascend950"])
    def test_npu_transpose_quant_batchmatmul_1(self, device="npu"):
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(torch.float8_e5m2)
        x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(torch.float8_e5m2)
        x1_scale = torch.randint(-3, 3, (M, ), dtype=torch.float32)
        x2_scale = torch.randint(-3, 3, (N, ), dtype=torch.float32)
        supported_output = self.supported_op_exec(x1, x2, x1_scale, x2_scale)
        custom_output = torch_npu.npu_transpose_quant_batchmatmul(x1.npu(), x2.npu(), dtype=torch.bfloat16,
                                                            x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                                                            perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2]).to("cpu")
        self.assertRtolEqual(supported_output, custom_output, 0.01)

    @unittest.skip("Skipping test_npu_transpose_batchmatmul temporarily")
    @SupportedDevices(["Ascend950"])
    def test_npu_transpose_quant_batchmatmul_2(self, device="npu"):
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(torch.float8_e4m3fn)
        x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(torch.float8_e4m3fn)
        x1_scale = torch.randint(-3, 3, (M, ), dtype=torch.float32)
        x2_scale = torch.randint(-3, 3, (N, ), dtype=torch.float32)
        supported_output = self.supported_op_exec(x1, x2, x1_scale, x2_scale)
        custom_output = torch_npu.npu_transpose_quant_batchmatmul(x1.npu(), x2.npu(), dtype=torch.float16,
                                                            x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                                                            perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2]).to("cpu")
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @unittest.skip("Skipping test_npu_transpose_batchmatmul temporarily")
    @SupportedDevices(["Ascend950"])
    def test_npu_transpose_quant_batchmatmul_3(self, device="npu"):
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(torch.float8_e4m3fn)
        x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(torch.float8_e4m3fn)
        x1_scale = torch.randint(-3, 3, (M, ), dtype=torch.float32)
        x2_scale = torch.randint(-3, 3, (N, ), dtype=torch.float32)
        supported_output = self.supported_op_exec(x1, x2, x1_scale, x2_scale)
        custom_output = torch_npu.npu_transpose_quant_batchmatmul(x1.npu(), x2.npu(), dtype=torch.bfloat16,
                                                            x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                                                            perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2]).to("cpu")
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @unittest.skip("Skipping test_npu_transpose_batchmatmul temporarily")
    @SupportedDevices(['Ascend950'])
    def test_npu_transpose_quant_batchmatmul_mxfp8(self):
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(torch.float8_e4m3fn)
        x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(torch.float8_e4m3fn)
        x1_clone = x1.clone()
        x2_clone = x2.clone()
        x1_scale = torch.full((M, Batch, int(K/64), 2), 2, dtype=torch.uint8)
        x2_scale = torch.full((Batch, int(K/64), N, 2), 2, dtype=torch.uint8)
        x1_scale_clone = torch.full((M, Batch, int(K/64), 2), 2, dtype=torch.float8_e8m0fnu)
        x2_scale_clone = torch.full((Batch, int(K/64), N, 2), 2, dtype=torch.float8_e8m0fnu)
        x1 = x1.to(torch.float32).numpy()
        x2 = x2.to(torch.float32).numpy()
        x1 = x1.transpose(1,0,2)
        x1_scale = x1_scale.numpy().astype(np.float32)
        x2_scale = x2_scale.numpy().astype(np.float32)
        x1_scale = x1_scale.reshape(x1_scale.shape[0], x1_scale.shape[1], x1_scale.shape[2] * 2)
        x1_scale = x1_scale.transpose(1,0,2)
        x2_scale = x2_scale.transpose(0,1,3,2)
        x2_scale = x2_scale.reshape(x2_scale.shape[0], x2_scale.shape[1] * 2, x2_scale.shape[3])
        x1_scale_broadcast = np.repeat(x1_scale, 32, axis=-1)
        x2_scale_broadcast = np.repeat(x2_scale, 32, axis=-2)
        x1 = x1 * x1_scale_broadcast
        x2 = x2 * x2_scale_broadcast
        x1 = torch.from_numpy(x1.astype(np.float32))
        x2 = torch.from_numpy(x2.astype(np.float32))
        supported_output = torch.matmul(x1.to(torch.float32), x2.to(torch.float32))
        supported_output = torch.permute(supported_output, [1, 0, 2])
        custom_output = torch_npu.npu_transpose_quant_batchmatmul(x1_clone.npu(), x2_clone.npu(), dtype=torch.bfloat16,
                                                            x1_scale=x1_scale_clone.npu(), x2_scale=x2_scale_clone.npu(),
                                                            group_sizes=[0, 0, 32], perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2])
        self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)

    @unittest.skip("Skipping test_npu_transpose_quant_batchmatmul temporarily")
    @SupportedDevices(["Ascend950"])
    def test_npu_transpose_quant_batchmatmul_4(self, device="npu"):
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(torch.float8_e4m3fn)
        x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(torch.float8_e4m3fn)
        x1_scale = torch.randint(-3, 3, (M, ), dtype=torch.float32)
        x2_scale = torch.randint(-3, 3, (N, ), dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "perm_x1 should be [1, 0, 2]"):
            torch_npu.npu_transpose_quant_batchmatmul(x1.npu(), x2.npu(), dtype=torch.float16,
                                            x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                                            perm_x1=[1, 1, 2], perm_x2=[0, 1, 2],
                                            perm_y=[1, 0, 2]).to("cpu")

    @unittest.skip("Skipping test_npu_transpose_quant_batchmatmul temporarily")
    @SupportedDevices(["Ascend950"])
    def test_npu_transpose_quant_batchmatmul_5(self, device="npu"):
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(torch.float8_e4m3fn)
        x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(torch.float8_e4m3fn)
        x1_scale = torch.randint(-3, 3, (M, ), dtype=torch.float32)
        x2_scale = torch.randint(-3, 3, (N, ), dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "perm_x2 should be [0, 1, 2]"):
            torch_npu.npu_transpose_quant_batchmatmul(x1.npu(), x2.npu(), dtype=torch.float16,
                                            x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                                            perm_x1=[1, 0, 2], perm_x2=[1, 1, 2],
                                            perm_y=[1, 0, 2]).to("cpu")

    @unittest.skip("Skipping npu_transpose_quant_batchmatmul temporarily")
    @SupportedDevices(["Ascend950"])
    def test_npu_transpose_quant_batchmatmul_6(self, device="npu"):
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(torch.float8_e4m3fn)
        x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(torch.float8_e4m3fn)
        x1_scale = torch.randint(-3, 3, (M, ), dtype=torch.float32)
        x2_scale = torch.randint(-3, 3, (N, ), dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "perm_y should be [1, 0, 2]"):
            torch_npu.npu_transpose_quant_batchmatmul(x1.npu(), x2.npu(), dtype=torch.float16,
                                            x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                                            perm_x1=[1, 0, 2], perm_x2=[0, 1, 2],
                                            perm_y=[1, 1, 2]).to("cpu")

    @unittest.skip("Skipping npu_transpose_quant_batchmatmul temporarily")
    @SupportedDevices(["Ascend950"])
    def test_npu_transpose_quant_batchmatmul_7(self, device="npu"):
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(torch.float8_e4m3fn)
        x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(torch.float8_e4m3fn)
        x1_scale = torch.randint(-3, 3, (M, ), dtype=torch.float32)
        x2_scale = torch.randint(-3, 3, (N, ), dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "x1's type supported for float8_e5m2 or float8_e4m3fn"):
            torch_npu.npu_transpose_quant_batchmatmul(x1.npu(), x2.npu(), dtype=torch.float16,
                                            x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                                            perm_x1=[1, 0, 2], perm_x2=[0, 1, 2],
                                            perm_y=[1, 0, 2]).to("cpu")

    @unittest.skip("Skipping npu_transpose_quant_batchmatmul temporarily")
    @SupportedDevices(["Ascend950"])
    def test_npu_transpose_quant_batchmatmul_8(self, device="npu"):
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(torch.float8_e4m3fn)
        x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(torch.float8_e4m3fn)
        x1_scale = torch.randint(-3, 3, (M, ), dtype=torch.float32)
        x2_scale = torch.randint(-3, 3, (N, ), dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "x2's type supported for float8_e5m2 or float8_e4m3fn"):
            torch_npu.npu_transpose_quant_batchmatmul(x1.npu(), x2.npu(), dtype=torch.float16,
                                            x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                                            perm_x1=[1, 0, 2], perm_x2=[0, 1, 2],
                                            perm_y=[1, 0, 2]).to("cpu")

    @SupportedDevices(["Ascend950"])
    def test_npu_transpose_quant_batchmatmul_nz(self, device="npu"):
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(torch.float8_e4m3fn)
        x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(torch.float8_e4m3fn)
        x1_clone = x1.clone()
        x2_clone = torch_npu.npu_format_cast(x2.npu(), acl_format=29)
        x1_scale = torch.randint(2, 5, (M, Batch, int(K/64), 2), dtype=torch.uint8).view(torch.float8_e8m0fnu)
        x2_scale = torch.randint(2, 5, (Batch, int(K/64), N, 2), dtype=torch.uint8).view(torch.float8_e8m0fnu)
        x1_scale_clone = x1_scale.clone()
        x2_scale_clone = x2_scale.clone()

        x1 = x1.to(torch.float32).numpy()
        x2 = x2.to(torch.float32).numpy()
        x1 = x1.transpose(1,0,2)
        x1_scale = x1_scale.to(torch.float32).numpy()
        x2_scale = x2_scale.to(torch.float32).numpy()
        x1_scale = x1_scale.reshape(x1_scale.shape[0], x1_scale.shape[1], x1_scale.shape[2] * 2)
        x1_scale = x1_scale.transpose(1,0,2)
        x2_scale = x2_scale.transpose(0,1,3,2)
        x2_scale = x2_scale.reshape(x2_scale.shape[0], x2_scale.shape[1] * 2, x2_scale.shape[3])
        x1_scale_broadcast = np.repeat(x1_scale, 32, axis=-1)
        x2_scale_broadcast = np.repeat(x2_scale, 32, axis=-2)
        x1 = x1 * x1_scale_broadcast
        x2 = x2 * x2_scale_broadcast
        x1 = torch.from_numpy(x1.astype(np.float32))
        x2 = torch.from_numpy(x2.astype(np.float32))
        supported_output = torch.matmul(x1.to(torch.float32), x2.to(torch.float32))
        supported_output = torch.permute(supported_output, [1, 0, 2])
        custom_output = torch_npu.npu_transpose_quant_batchmatmul(x1_clone.npu(), x2_clone, dtype=torch.bfloat16,
                                                            x1_scale=x1_scale_clone.npu(), x2_scale=x2_scale_clone.npu(),
                                                            group_sizes=[0, 0, 32], perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2])
        self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)


if __name__ == "__main__":
    run_tests()