import unittest
import torch

import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestRotaryMul(TestCase):
    """Compare ``torch_npu.npu_rotary_mul`` with pure-PyTorch CPU references.

    The helper methods build the reference results and the input data; the
    ``test_*`` methods only declare which shapes, dtypes and rotary mode to
    exercise and delegate the actual comparison to ``_check_cases``.
    """

    def rotary_mul(self, x, r1, r2):
        """Reference implementation of the ``'half'`` rotary mode.

        The last dimension of ``x`` is split into two chunks ``(x1, x2)``,
        rearranged into ``(-x2, x1)`` and combined as ``r1 * x + r2 * x_new``.
        """
        x1, x2 = torch.chunk(x, 2, -1)
        x_new = torch.cat((-x2, x1), dim=-1)
        return r1 * x + r2 * x_new

    def rotary_mul_mode(self, x, r1, r2):
        """Reference implementation used for the ``'interleave'`` rotary mode.

        Even-indexed and odd-indexed elements of the last dimension form the
        pair ``(x1, x2)``; ``(-x2, x1)`` is stacked on a new trailing axis and
        reshaped back to ``x.shape`` so the pair members alternate again.
        """
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        x_new = torch.stack((-x2, x1), dim=-1).reshape(x.shape)
        return x * r1 + x_new * r2

    def gen_data(self, shape, dtype):
        """Return a random CPU tensor and a copy of it moved to the NPU."""
        cpu_input = torch.rand(shape, dtype=dtype)
        npu_input = cpu_input.npu()
        return cpu_input, npu_input

    def get_half_matrix(self, n, dtype):
        """Build the ``n x n`` matrix ``R`` such that ``x @ R == cat(-x2, x1)``.

        ``x1``/``x2`` are the two halves of the last dimension of ``x``: the
        upper-right block of ``R`` is the identity and the lower-left block is
        the negated identity.

        Raises:
            ValueError: if ``n`` is odd, because the matrix cannot then be
                split into four equally sized blocks.
        """
        if n % 2 != 0:
            raise ValueError(f"n must be even to build a half-rotation matrix, but got {n}")
        matrix = torch.zeros(n, n, dtype=dtype)
        half = n // 2
        # The identity blocks are created in the default dtype and converted
        # to ``dtype`` by the slice assignment; 0 and +/-1 convert exactly.
        matrix[:half, half:] = torch.eye(half)
        matrix[half:, :half] = -torch.eye(half)
        return matrix

    def cpu_to_exec(self, x, r1, r2, mode='half'):
        """Compute the reference result for ``mode`` and return it on the CPU.

        Any ``mode`` other than ``'half'`` uses the interleave reference.
        """
        if mode == 'half':
            out = self.rotary_mul(x, r1, r2)
        else:
            out = self.rotary_mul_mode(x, r1, r2)
        return out.cpu()

    def npu_to_exec(self, x, r1, r2, mode='half'):
        """Run ``torch_npu.npu_rotary_mul`` and return its result on the CPU.

        For ``'half'`` the ``rotary_mode`` argument is deliberately omitted so
        that the operator is called without it; other modes pass it explicitly.
        """
        if mode == 'half':
            out = torch_npu.npu_rotary_mul(x, r1, r2)
        else:
            out = torch_npu.npu_rotary_mul(x, r1, r2, rotary_mode=mode)
        return out.cpu()

    def npu_to_exec_matrix(self, x, r1, r2, mode='half', rotate=None):
        """Run ``torch_npu.npu_rotary_mul`` with an explicit ``rotate`` matrix."""
        out = torch_npu.npu_rotary_mul(x, r1, r2, rotary_mode=mode, rotate=rotate)
        return out.cpu()

    def _check_cases(self, shape_list, dtype_list, mode='half', with_rotate=False):
        """Compare NPU output with the CPU reference for every shape/dtype pair.

        Args:
            shape_list: iterable of ``(x_shape, r1_shape, r2_shape)`` triples.
            dtype_list: dtypes to combine with every entry of ``shape_list``.
            mode: rotary mode forwarded to both the reference and the operator.
            with_rotate: when true, call the operator through
                ``npu_to_exec_matrix`` with the matrix from ``get_half_matrix``.
        """
        for shape in shape_list:
            for dtype in dtype_list:
                # subTest labels a failure with the offending combination.
                with self.subTest(shape=shape, dtype=dtype, mode=mode):
                    x_shape, r1_shape, r2_shape = shape
                    cpu_x, npu_x = self.gen_data(x_shape, dtype)
                    cpu_r1, npu_r1 = self.gen_data(r1_shape, dtype)
                    cpu_r2, npu_r2 = self.gen_data(r2_shape, dtype)
                    cpu_out = self.cpu_to_exec(cpu_x, cpu_r1, cpu_r2, mode=mode)
                    if with_rotate:
                        rotate = self.get_half_matrix(x_shape[-1], dtype).npu()
                        npu_out = self.npu_to_exec_matrix(
                            npu_x, npu_r1, npu_r2, mode=mode, rotate=rotate
                        )
                    else:
                        npu_out = self.npu_to_exec(npu_x, npu_r1, npu_r2, mode=mode)
                    self.assertRtolEqual(cpu_out, npu_out)

    @SupportedDevices(['Ascend910B'])
    def test_rotary_mul(self):
        """'half' mode (operator default) on 4-D inputs."""
        dtype_list = [torch.float16, torch.float32]
        shape_list = [
            [[2, 8192, 5, 128], [1, 8192, 1, 128], [1, 8192, 1, 128]],
            [[8192, 2, 5, 128], [8192, 1, 1, 128], [8192, 1, 1, 128]],
            [[2048, 4, 32, 64], [2048, 4, 1, 64], [2048, 4, 1, 64]],
        ]
        self._check_cases(shape_list, dtype_list)

    @SupportedDevices(['Ascend910B'])
    def test_rotary_mul_mode(self):
        """'interleave' mode on 4-D inputs."""
        dtype_list = [torch.float16, torch.float32]
        shape_list = [
            [[2, 2, 5, 128], [1, 2, 1, 128], [1, 2, 1, 128]],
            [[2, 24, 5, 128], [1, 1, 5, 128], [1, 1, 5, 128]],
            [[128, 4, 4, 256], [1, 1, 4, 256], [1, 1, 4, 256]],
            [[64, 8, 8, 512], [1, 1, 8, 512], [1, 1, 8, 512]],
        ]
        self._check_cases(shape_list, dtype_list, mode='interleave')

    @unittest.skip("skip")  # not supported by the CI version
    @SupportedDevices(['Ascend910B'])
    def test_rotary_mul_mode_matrix(self):
        """'half' mode with an explicit rotation matrix passed as ``rotate``."""
        dtype_list = [torch.bfloat16]
        shape_list = [
            [[2, 24, 28800, 128], [1, 1, 28800, 128], [1, 1, 28800, 128]],
        ]
        self._check_cases(shape_list, dtype_list, mode='half', with_rotate=True)

    @unittest.skip("skip")  # not supported by the CI version
    @SupportedDevices(['Ascend910B'])
    def test_rotary_mul_dim3(self):
        """'half' mode (operator default) on 3-D inputs."""
        dtype_list = [torch.float16, torch.float32]
        shape_list = [
            ((2, 2, 4), (2, 2, 4), (2, 2, 4)),
            ((4, 8, 8), (4, 8, 8), (4, 8, 8)),
            ((5, 10, 32), (1, 1, 32), (1, 1, 32)),
            ((8, 4, 64), (1, 1, 64), (1, 1, 64)),
            ((16, 8, 256), (1, 1, 256), (1, 1, 256)),
            ((1, 128, 894), (1, 1, 894), (1, 1, 894)),
            ((1, 256, 512), (1, 1, 512), (1, 1, 512)),
        ]
        self._check_cases(shape_list, dtype_list)

    @unittest.skip("skip")  # not supported by the CI version
    @SupportedDevices(['Ascend910B'])
    def test_rotary_mul_mode_dim3(self):
        """'interleave' mode on 3-D inputs."""
        dtype_list = [torch.float16, torch.float32]
        shape_list = [
            ((1, 1, 2), (1, 1, 2), (1, 1, 2)),
            ((1, 1, 4), (1, 1, 4), (1, 1, 4)),
            ((1, 8, 2), (1, 1, 2), (1, 1, 2)),
            ((1, 8, 4), (1, 8, 4), (1, 8, 4)),
            ((16, 1, 2), (1, 1, 2), (1, 1, 2)),
            ((8, 1, 4), (8, 1, 4), (8, 1, 4)),
            ((16, 8, 2), (1, 1, 2), (1, 1, 2)),
        ]
        self._check_cases(shape_list, dtype_list, mode='interleave')

    @SupportedDevices(['Ascend910B'])
    def test_rotary_mul_error_param(self):
        """An unsupported rotary mode must raise a descriptive RuntimeError."""
        x = torch.rand(2, 2, 5, 128).npu()
        r1 = torch.rand(1, 2, 1, 128).npu()
        r2 = torch.rand(1, 2, 1, 128).npu()
        msg = "The rotary_mode of npu_rotary_mul should be half or interleave, but got "
        with self.assertRaisesRegex(RuntimeError, msg):
            torch_npu.npu_rotary_mul(x, r1, r2, 'quarter')

# When executed as a script, hand control to the run_tests entry point
# imported from torch_npu.testing.testcase.
if __name__ == '__main__':
    run_tests()