import os
import torch
import numpy as np

import torch_npu
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices

os.environ['PYTORCH_NPU_ALLOC_CONF'] = 'expandable_segments:False'


class TestNpu(TestCase):

    @skipIfUnsupportMultiNPU(2)
    def test_creat_tensor(self):
        device_number = torch.npu.device_count()
        for device_idx in range(device_number):
            t = torch.randn(2, 255, 255, device=f"npu:{device_idx}")
            self.assertTrue(str(t.device) == f"npu:{device_idx}")

    @skipIfUnsupportMultiNPU(2)
    def test_creat_stream(self):
        device_number = torch.npu.device_count()
        stream_instance = set()
        for i in range(device_number):
            torch.npu.set_device(i)
            default_stream = torch.npu.default_stream()
            current_stream = torch.npu.current_stream()
            stream_instance.add(current_stream)
        self.assertTrue(len(stream_instance) == device_number)

    def _test_host_to_device(self, t_cpu):
        t_device_0 = t_cpu.to("npu:0")
        self.assertTrue(str(t_device_0.device) == "npu:0")
        t_device_1 = t_cpu.to("npu:1")
        self.assertTrue(str(t_device_1.device) == "npu:1")
        self.assertRtolEqual(t_cpu.numpy(), t_device_0.cpu().numpy())
        self.assertRtolEqual(t_cpu.numpy(), t_device_1.cpu().numpy())

    def _test_device_to_device(self, t_cpu):
        t_device_0 = t_cpu.to("npu:0")
        self.assertTrue(str(t_device_0.device) == "npu:0")
        t_device_1 = t_device_0.to("npu:1")
        self.assertTrue(str(t_device_1.device) == "npu:1")
        self.assertRtolEqual(t_device_1.cpu(
        ).numpy(), t_device_0.cpu().numpy())

    def _test_device_copy(self):
        t_cpu = torch.rand(2, 255, 255)
        self._test_host_to_device(t_cpu)
        self._test_device_to_device(t_cpu)

    def _test_module(self):

        class MyModule(torch.nn.Module):

            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(2, 2)
                self.fc2 = torch.nn.Linear(2, 2)

            def forward(self, x):
                if (x.device.type == 'cpu'):
                    x = self.fc1(x)
                    return self.fc2(x)
                else:
                    self.fc1 = self.fc1.to("npu:0")
                    self.fc2 = self.fc2.to("npu:1")
                    x = self.fc1(x)
                    return self.fc2(x.to("npu:1"))

        module = MyModule()
        module.eval()
        input_t = torch.rand(8, 2)
        output_cpu = module(input_t)
        output_npu = module(input_t.to("npu:0"))
        self.assertTrue(str(output_npu.device) == "npu:1")
        self.assertRtolEqual(output_cpu.detach().numpy(), output_npu.detach().cpu().numpy(), prec=1.e-3)

    @skipIfUnsupportMultiNPU(2)
    def test_aclop_with_multi_device(self):
        torch.npu.set_compile_mode(jit_compile=True)
        self._test_device_copy()
        self._test_module()

    @skipIfUnsupportMultiNPU(2)
    def test_opapi_with_multi_device(self):
        torch.npu.set_compile_mode(jit_compile=False)
        self._test_device_copy()
        self._test_module()


class TestOp(TestCase):

    def _cpu_op_exec(self, input1):
        output = torch.abs(input1)
        output = output.cpu().numpy()
        return output

    def _npu_op_exec(self, input1):
        output = torch.abs(input1)
        output = output.cpu().numpy()
        return output

    def _test_abs(self, device="npu:1"):
        torch.npu.set_device(0)
        cpu_input = torch.Tensor([1, -2, -10])
        npu_input = cpu_input.to(device)
        cpu_output = self._cpu_op_exec(cpu_input)
        npu_output = self._npu_op_exec(npu_input)
        self.assertRtolEqual(cpu_output, npu_output)

    def _test_isfinite(self, device="npu:1"):
        torch.npu.set_device(0)
        x = torch.Tensor([1, 2, -10]).to(device)
        output = torch.isfinite(x)
        self.assertTrue(output.all())

    def _test_unique_dim(self, device="npu:1", dtype=torch.float):
        torch.npu.set_device(0)
        self.assertFalse(hasattr(torch, "unique_dim"))

        x = torch.tensor([[[1., 1.],
                               [0., 1.],
                               [2., 1.],
                               [0., 1.]],
                              [[1., 1.],
                               [0., 1.],
                               [2., 1.],
                               [0., 1.]]],
                             dtype=dtype,
                             device=device)
        expected_unique_dim0 = torch.tensor([[[1., 1.],
                                                  [0., 1.],
                                                  [2., 1.],
                                                  [0., 1.]]],
                                                dtype=dtype,
                                                device=device)
        expected_inverse_dim0 = torch.tensor([0, 0])
        expected_counts_dim0 = torch.tensor([2])

        x_unique, x_inverse, x_counts = torch.unique(
                x,
                return_inverse=True,
                return_counts=True,
                dim=0)
        self.assertEqual(expected_unique_dim0, x_unique)
        self.assertEqual(expected_inverse_dim0, x_inverse)
        self.assertEqual(expected_counts_dim0, x_counts)

    def _supported_op_exec(self, query_states1, past_key, past_value, head_dim):
        attn_weights1 = torch.matmul(query_states1, past_key.transpose(2, 3)) / 0.0078125
        attn_weights1 = torch.max(attn_weights1, torch.full(
            (1, 1), torch.finfo(attn_weights1.dtype).min, device=attn_weights1.device))
        attn_weights1 = torch.nn.functional.softmax(attn_weights1, dim=-1, dtype=torch.float32).to(query_states1.dtype)
        attn_output1 = torch.matmul(attn_weights1, past_value)
        return attn_output1

    def _custom_op_exec(self, query, key, value, head_dim):
        scale = 1 / 0.0078125
        return torch_npu.npu_prompt_flash_attention(
            query, key, value, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535, sparse_mode=0)

    @SupportedDevices(['Ascend910B'])
    def _test_npu_prompt_flash_attention(self, device="npu:1"):
        torch.npu.set_device(0)
        query = torch.randn(1, 32, 2048, 128, dtype=torch.float16).to(device)
        key = torch.randn(1, 32, 2048, 128, dtype=torch.float16).to(device)
        value = torch.randn(1, 32, 2048, 128, dtype=torch.float16).to(device)

        head_dim = 128

        supported_output = self._supported_op_exec(query, key, value, head_dim)
        custom_output = self._custom_op_exec(query, key, value, head_dim)
        self.assertRtolEqual(supported_output, custom_output)

    @skipIfUnsupportMultiNPU(2)
    def test_aclop_op_with_multi_device(self):
        torch.npu.set_compile_mode(jit_compile=True)
        self._test_abs()
        self._test_isfinite()
        self._test_unique_dim()
        self._test_npu_prompt_flash_attention()

    @skipIfUnsupportMultiNPU(2)
    def test_opapi_op_with_multi_device(self):
        torch.npu.set_compile_mode(jit_compile=False)
        self._test_abs()
        self._test_isfinite()
        self._test_unique_dim()
        self._test_npu_prompt_flash_attention()


if __name__ == '__main__':
    run_tests()