import unittest
import torch
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestTransposeBatchMatmul(TestCase):
    """Tests for ``torch_npu.npu_transpose_batchmatmul``.

    Every case builds ``x1`` with shape ``(M, Batch, K)`` and ``x2`` with shape
    ``(Batch, K, N)``. The numerical cases compare the NPU result against a
    CPU reference computed by the ``supported_op_exec*`` helpers; the
    remaining cases check that invalid arguments raise ``RuntimeError``.

    All test methods are currently disabled through ``unittest.skip``.
    """

    @staticmethod
    def _transposed_bmm(x1, x2):
        """Return the float32 batched matmul of ``x1`` and ``x2`` in x1's layout.

        ``x1`` has its first two dimensions swapped before the matmul, and the
        product has its first two dimensions swapped back afterwards, so for
        ``x1`` of shape ``(M, Batch, K)`` and ``x2`` of shape ``(Batch, K, N)``
        the result has shape ``(M, Batch, N)``.
        """
        product = torch.matmul(x1.transpose(0, 1).float(), x2.float())
        return product.transpose(0, 1)

    def supported_op_exec(self, x1, x2, scale):
        """CPU reference for the quantized (int8 output) case.

        The matmul result is flattened to ``(M, 1, Batch * N)``, multiplied by
        ``scale``, converted to ``torch.int``, clipped to ``[-128, 127]`` and
        returned as ``torch.int8``.
        """
        out = self._transposed_bmm(x1, x2)
        out = out.reshape(out.shape[0], 1, out.shape[1] * out.shape[2])
        # NOTE(review): .to(torch.int) truncates toward zero; confirm this
        # matches the rounding the NPU op applies when quantizing.
        data = (out * scale).to(torch.int)
        return torch.clip(data, -128, 127).to(torch.int8)

    def supported_op_exec_2(self, x1, x2):
        """CPU reference for the unquantized case.

        The result is cast to the dtype of ``x1`` so that the reference has
        the same dtype as the inputs (float16 inputs give a float16 reference,
        bfloat16 inputs give a bfloat16 reference).
        """
        # Previously this was always float16, which left the bfloat16 test
        # comparing tensors of two different dtypes.
        # NOTE(review): assumes the NPU op returns the input dtype when
        # scale is None -- confirm against the operator documentation.
        return self._transposed_bmm(x1, x2).to(x1.dtype)

    def supported_op_exec_3(self, x1, x2, batch_split_factor):
        """CPU reference for the ``batch_split_factor`` case.

        The ``(M, Batch, N)`` matmul result is reshaped to
        ``(M, batch_split_factor, -1)``, the first two dimensions are swapped,
        and the result is returned as ``torch.float16``.
        """
        out = self._transposed_bmm(x1, x2)
        # out.shape[0] is M, the same value the reshape used before.
        output = out.reshape(out.shape[0], batch_split_factor, -1)
        output = output.transpose(0, 1)
        return output.to(torch.float16)

    @unittest.skip("Skipping test_npu_transpose_batchmatmul temporarily")
    @SupportedDevices(["Ascend910B"])
    def test_npu_transpose_batchmatmul(self, device="npu"):
        """float16 inputs with a quantization scale; compares the int8 outputs."""
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randn((M, Batch, K), dtype=torch.float16)
        x2 = torch.randn((Batch, K, N), dtype=torch.float16)
        scale = torch.ones((Batch * N, ), dtype=torch.float32)
        supported_output = self.supported_op_exec(x1, x2, scale)
        # The op is given the scale after it has been passed through
        # npu_trans_quant_param; the CPU reference uses the raw float32 scale.
        quant_scale = torch_npu.npu_trans_quant_param(scale.npu(), round_mode=1)
        custom_output = torch_npu.npu_transpose_batchmatmul(
            x1.npu(), x2.npu(), scale=quant_scale.npu(),
            perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2]).to("cpu")
        self.assertRtolEqual(supported_output, custom_output, 0.01)

    @unittest.skip("Skipping test_npu_transpose_batchmatmul temporarily")
    @SupportedDevices(["Ascend910B"])
    def test_npu_transpose_batchmatmul_2(self, device="npu"):
        """float16 inputs without a scale."""
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randn((M, Batch, K), dtype=torch.float16)
        x2 = torch.randn((Batch, K, N), dtype=torch.float16)
        supported_output = self.supported_op_exec_2(x1, x2)
        custom_output = torch_npu.npu_transpose_batchmatmul(
            x1.npu(), x2.npu(), scale=None,
            perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2]).to("cpu")
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @unittest.skip("Skipping test_npu_transpose_batchmatmul temporarily")
    @SupportedDevices(["Ascend910B"])
    def test_npu_transpose_batchmatmul_3(self, device="npu"):
        """float16 inputs without a scale, using ``batch_split_factor=4``."""
        M, K, N, Batch = 32, 512, 128, 32
        batch_split_factor = 4
        x1 = torch.randn((M, Batch, K), dtype=torch.float16)
        x2 = torch.randn((Batch, K, N), dtype=torch.float16)
        supported_output = self.supported_op_exec_3(x1, x2, batch_split_factor)
        custom_output = torch_npu.npu_transpose_batchmatmul(
            x1.npu(), x2.npu(), scale=None,
            perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2],
            batch_split_factor=batch_split_factor).to("cpu")
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @unittest.skip("Skipping test_npu_transpose_batchmatmul temporarily")
    @SupportedDevices(["Ascend910B"])
    def test_npu_transpose_batchmatmul_4(self, device="npu"):
        """``perm_x1=[1, 1, 2]`` must raise a ``perm_x1 should be`` error."""
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randn((M, Batch, K), dtype=torch.float16)
        x2 = torch.randn((Batch, K, N), dtype=torch.float16)
        with self.assertRaisesRegex(RuntimeError, "perm_x1 should be"):
            torch_npu.npu_transpose_batchmatmul(
                x1.npu(), x2.npu(), scale=None,
                perm_x1=[1, 1, 2], perm_x2=[0, 1, 2],
                perm_y=[1, 0, 2]).to("cpu")

    @unittest.skip("Skipping test_npu_transpose_batchmatmul temporarily")
    @SupportedDevices(["Ascend910B"])
    def test_npu_transpose_batchmatmul_5(self, device="npu"):
        """``perm_x2=[1, 1, 2]`` must raise a ``perm_x2 should be`` error."""
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randn((M, Batch, K), dtype=torch.float16)
        x2 = torch.randn((Batch, K, N), dtype=torch.float16)
        with self.assertRaisesRegex(RuntimeError, "perm_x2 should be"):
            torch_npu.npu_transpose_batchmatmul(
                x1.npu(), x2.npu(), scale=None,
                perm_x1=[1, 0, 2], perm_x2=[1, 1, 2],
                perm_y=[1, 0, 2]).to("cpu")

    @unittest.skip("Skipping test_npu_transpose_batchmatmul temporarily")
    @SupportedDevices(["Ascend910B"])
    def test_npu_transpose_batchmatmul_6(self, device="npu"):
        """``perm_y=[1, 1, 2]`` must raise a ``perm_y should be`` error."""
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randn((M, Batch, K), dtype=torch.float16)
        x2 = torch.randn((Batch, K, N), dtype=torch.float16)
        with self.assertRaisesRegex(RuntimeError, "perm_y should be"):
            torch_npu.npu_transpose_batchmatmul(
                x1.npu(), x2.npu(), scale=None,
                perm_x1=[1, 0, 2], perm_x2=[0, 1, 2],
                perm_y=[1, 1, 2]).to("cpu")

    @unittest.skip("Skipping test_npu_transpose_batchmatmul temporarily")
    @SupportedDevices(["Ascend910B"])
    def test_npu_transpose_batchmatmul_7(self, device="npu"):
        """An int64 ``x1`` must be rejected with the input dtype error."""
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randint(-10, 10, (M, Batch, K), dtype=torch.int64)
        x2 = torch.randn((Batch, K, N), dtype=torch.float16)
        with self.assertRaisesRegex(RuntimeError, "input's type supported for float16, float32 and bfloat16"):
            torch_npu.npu_transpose_batchmatmul(
                x1.npu(), x2.npu(), scale=None,
                perm_x1=[1, 0, 2], perm_x2=[0, 1, 2],
                perm_y=[1, 0, 2]).to("cpu")

    @unittest.skip("Skipping test_npu_transpose_batchmatmul temporarily")
    @SupportedDevices(["Ascend910B"])
    def test_npu_transpose_batchmatmul_8(self, device="npu"):
        """An int64 ``x2`` must be rejected with the weight dtype error."""
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randn((M, Batch, K), dtype=torch.float16)
        x2 = torch.randint(-10, 10, (Batch, K, N), dtype=torch.int64)
        with self.assertRaisesRegex(RuntimeError, "weight's type supported for float16, float32 and bfloat16"):
            torch_npu.npu_transpose_batchmatmul(
                x1.npu(), x2.npu(), scale=None,
                perm_x1=[1, 0, 2], perm_x2=[0, 1, 2],
                perm_y=[1, 0, 2]).to("cpu")

    @unittest.skip("Skipping test_npu_transpose_batchmatmul temporarily")
    @SupportedDevices(["Ascend910B"])
    def test_npu_transpose_batchmatmul_9(self, device="npu"):
        """bfloat16 inputs without a scale, with ``x2`` format-cast on the NPU."""
        M, K, N, Batch = 32, 512, 128, 32
        x1 = torch.randn((M, Batch, K), dtype=torch.bfloat16)
        x2 = torch.randn((Batch, K, N), dtype=torch.bfloat16)
        # NOTE(review): acl_format=29 is presumably the NZ layout, going by the
        # variable name -- confirm against the ACL format enumeration.
        x2_nz = torch_npu.npu_format_cast(x2.npu(), acl_format=29)
        # The reference is produced in bfloat16 (the dtype of x1), so both
        # sides of the comparison are built from the same input dtype.
        supported_output = self.supported_op_exec_2(x1, x2)
        custom_output = torch_npu.npu_transpose_batchmatmul(
            x1.npu(), x2_nz.npu(), scale=None,
            perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2]).to("cpu")
        self.assertRtolEqual(supported_output, custom_output, 0.001)

# When this file is executed as a script, hand control to the run_tests
# helper imported from torch_npu.testing.testcase.
if __name__ == "__main__":
    run_tests()