import math
import unittest
import copy
import struct
from struct import pack, unpack
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestQuantScatter(TestCase):
    """Tests for ``torch_npu.npu_apply_rotary_pos_emb`` against a plain-torch reference.

    NOTE(review): the class name looks like a leftover from another test file; it is
    kept unchanged so existing test IDs (CI filters, skip lists) keep matching.
    """

    @staticmethod
    def _rotate_half(x):
        """Return ``cat(-x2, x1)`` along the last dim, where ``x1``/``x2`` are its two halves.

        The split point is derived from the tensor itself instead of being hard-coded,
        and the work stays on ``x``'s device (no CPU/numpy round trip).
        """
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    def supported_op_exec(self, query, key, cos, sin):
        """Reference rotary position embedding built from plain torch ops.

        Computes ``rotate_half(x) * sin + x * cos`` for both ``query`` and ``key``.
        ``rotate_half`` swaps the two halves of the last dimension and negates the
        half that moves to the front.

        Args:
            query: query tensor; its last dimension is split in half.
            key: key tensor; its last dimension is split in half.
            cos: cosine table, broadcast against ``query`` and ``key``.
            sin: sine table, broadcast against ``query`` and ``key``.

        Returns:
            ``[query_out, key_out]``, in the same order as the inputs.
        """
        res0 = self._rotate_half(query) * sin + query * cos
        res1 = self._rotate_half(key) * sin + key * cos
        return [res0, res1]

    @staticmethod
    def _random_fp16_pair(shape):
        """Return two independent NPU float16 copies of uniform [0, 1) data of ``shape``.

        One copy feeds the reference path and the other feeds the fused op. This is
        presumably so that an in-place write by the fused op cannot leak into the
        reference result — confirm against the op's documentation.
        """
        data = np.random.uniform(0, 1, shape).astype(np.float16)
        tensor = torch.from_numpy(data).to(torch.float16).npu()
        return tensor, tensor.clone()

    def _check_against_reference(self, qk_shape, cos_sin_shape, layout):
        """Compare the fused op with ``supported_op_exec`` on random float16 inputs.

        Args:
            qk_shape: shape used for both ``query`` and ``key``.
            cos_sin_shape: shape used for both ``cos`` and ``sin``.
            layout: layout string forwarded to ``npu_apply_rotary_pos_emb``.
        """
        # Draw order (query, key, cos, sin) matches the original inline tests.
        query1, query2 = self._random_fp16_pair(qk_shape)
        key1, key2 = self._random_fp16_pair(qk_shape)
        cos1, cos2 = self._random_fp16_pair(cos_sin_shape)
        sin1, sin2 = self._random_fp16_pair(cos_sin_shape)

        supported_output = self.supported_op_exec(query1, key1, cos1, sin1)
        custom_output = torch_npu.npu_apply_rotary_pos_emb(query2, key2, cos2, sin2, layout)
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @SupportedDevices(['Ascend910B'])
    def test_npu_apply_rotary_pos_emb(self, device="npu"):
        """BSND layout: q/k of shape [4, 1024, 16, 128], cos/sin of shape [4, 1024, 1, 128]."""
        self._check_against_reference([4, 1024, 16, 128], [4, 1024, 1, 128], 'BSND')

    @unittest.skip("Skip until CANN is updated to support layout TND format")
    @SupportedDevices(['Ascend910B'])
    def test_npu_apply_rotary_pos_emb_TND(self, device="npu"):
        """TND layout: q/k of shape [1024, 16, 128], cos/sin of shape [1024, 1, 128]."""
        self._check_against_reference([1024, 16, 128], [1024, 1, 128], 'TND')

if __name__ == "__main__":
    # When this file is executed directly, hand off to torch_npu's test runner.
    run_tests()