import unittest
import numpy as np
import torch
import torch.nn as nn
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestRope(TestCase):
    """Check ``torch_npu._npu_rotary_embedding`` against a pure-PyTorch reference.

    The reference (``native_rope``) rotates interleaved even/odd element pairs
    of every head, which is the layout these tests request from the NPU op by
    passing ``is_neox=False``. The tests rely on the NPU op writing its result
    into the ``query`` and ``key`` tensors it receives.
    """

    # Base of the geometric progression that defines the rotation frequencies.
    rope_theta = 10000
    # Width of a single attention head; the reference rotates the full head.
    head_dim = 128
    num_heads = 32
    # Number of positions (rows) in the cos/sin lookup table.
    max_position_embeddings = 8192
    batch_size = 1
    seq_length = 4

    def compute_inv_freq(self, base):
        """Return ``1 / base ** (i / head_dim)`` for every even ``i`` in ``[0, head_dim)``.

        The result is a 1-D float32 CPU tensor.
        """
        exponents = torch.arange(0, self.head_dim, 2, dtype=torch.float) / self.head_dim
        return 1.0 / (base ** exponents)

    def compute_cos_sin_cache(self):
        """Build the cos/sin lookup table on the CPU and move it to the NPU.

        Row ``p`` contains ``cos(p * inv_freq)`` followed by ``sin(p * inv_freq)``,
        so the table has ``max_position_embeddings`` rows and twice as many
        columns as there are inverse frequencies.
        """
        inv_freq = self.compute_inv_freq(self.rope_theta)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
        # Outer product: one rotation angle per (position, frequency) pair.
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        return torch.cat((freqs.cos(), freqs.sin()), dim=-1).to('npu')

    def _apply_rotary_emb(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
        """Rotate interleaved ``(even, odd)`` element pairs of ``x`` by the given angles.

        ``cos`` and ``sin`` receive an extra second-to-last axis so they broadcast
        over it in ``x``, and are cast to ``x.dtype`` before the arithmetic. The
        output has the same shape as ``x`` with the rotated pairs re-interleaved.
        """
        cos = cos.unsqueeze(-2).to(x.dtype)
        sin = sin.unsqueeze(-2).to(x.dtype)
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin
        # stack + flatten puts o1/o2 back into alternating even/odd slots.
        return torch.stack((o1, o2), dim=-1).flatten(-2)

    def _rotate_heads(self, x, cos, sin):
        """Apply the reference rotation to every head of ``x`` and restore its shape.

        ``cos``/``sin`` carry one row per token, so their leading dimension is
        used as the token count when splitting ``x`` into heads. After the view
        the last dimension is exactly ``head_dim``, so the whole head is rotated
        and there is no pass-through remainder.
        """
        original_shape = x.shape
        per_head = x.view(cos.shape[0], -1, self.head_dim)
        return self._apply_rotary_emb(per_head, cos, sin).reshape(original_shape)

    def native_rope(self, positions, query, key, cos_sin_cache=None):
        """Reference rotary embedding; returns new ``(query, key)`` tensors.

        Args:
            positions: Integer position ids; flattened to one id per token.
            query: Tensor viewable as ``(num_tokens, -1, head_dim)``.
            key: Tensor viewable as ``(num_tokens, -1, head_dim)``.
            cos_sin_cache: Optional table from ``compute_cos_sin_cache``. Passing
                the table the caller already built avoids recomputing it and
                copying it to the NPU a second time; when omitted, a fresh table
                is built.

        Returns:
            Tuple ``(rotated_query, rotated_key)`` with the input shapes.
        """
        if cos_sin_cache is None:
            cos_sin_cache = self.compute_cos_sin_cache()
        # Pick each token's table row, then split it into its cos and sin halves.
        cos_sin = cos_sin_cache.index_select(0, positions.flatten())
        cos, sin = cos_sin.chunk(2, dim=-1)
        return self._rotate_heads(query, cos, sin), self._rotate_heads(key, cos, sin)

    def _make_inputs(self):
        """Create NPU ``(positions, query, key)`` inputs shared by the tests.

        ``positions`` repeats ``0..seq_length-1`` once per batch entry; ``query``
        and ``key`` are random float16 tensors of shape
        ``(batch_size * seq_length, num_heads * head_dim)``.
        """
        num_tokens = self.batch_size * self.seq_length
        hidden_size = self.num_heads * self.head_dim
        positions = torch.arange(self.seq_length).repeat(self.batch_size).npu()
        query = torch.rand(num_tokens, hidden_size, dtype=torch.float16).npu()
        key = torch.rand(num_tokens, hidden_size, dtype=torch.float16).npu()
        return positions, query, key

    @SupportedDevices(['Ascend910B'])
    def test_rope(self):
        """The NPU op must match the PyTorch reference for ``is_neox=False``."""
        cos_sin_cache = self.compute_cos_sin_cache()
        positions, query, key = self._make_inputs()
        is_neox = False
        # The reference must be computed first: the op's results are read back
        # from ``query``/``key`` below, so they no longer hold the inputs.
        expected_query, expected_key = self.native_rope(positions, query, key, cos_sin_cache)
        torch_npu._npu_rotary_embedding(positions, query, key, self.head_dim, cos_sin_cache, is_neox)

        self.assertRtolEqual(expected_query, query)
        self.assertRtolEqual(expected_key, key)

    @unittest.skipIf(torch.__version__ < '2.5.1', "This compile ut needs torch version >=2.5.1")
    @SupportedDevices(['Ascend910B'])
    def test_rope_compile(self):
        """The op traced by ``torch.compile`` must match a direct eager call."""
        class RopeModel(nn.Module):
            """Thin module wrapper so the op can be traced as a full graph."""

            def forward(self, positions, query, key, head_dim, cos_sin_cache, is_neox):
                torch_npu._npu_rotary_embedding(positions, query, key, head_dim, cos_sin_cache, is_neox)
                return query, key

        cos_sin_cache = self.compute_cos_sin_cache()
        positions, query, key = self._make_inputs()
        # Independent copies for the eager run so both paths start from the same data.
        query1, key1 = query.clone(), key.clone()
        is_neox = False
        compiled_model = torch.compile(
            RopeModel(),
            backend="aot_eager",
            fullgraph=True,
        )
        compiled_output = compiled_model(positions, query, key, self.head_dim, cos_sin_cache, is_neox)
        torch_npu._npu_rotary_embedding(positions, query1, key1, self.head_dim, cos_sin_cache, is_neox)
        self.assertRtolEqual(compiled_output[0], query1)
        self.assertRtolEqual(compiled_output[1], key1)


# When executed as a script, hand control to the runner imported from
# torch_npu.testing.testcase.
if __name__ == '__main__':
    run_tests()