import torch
import torch_npu

from torch_npu.npu._recovery import (
    check_npu_tensor_is_safe,
    mark_all_npu_tensor_unsafe,
    set_npu_tensor_unsafe_check_flag,
    get_npu_tensor_unsafe_check_flag,
)
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.npu._recovery import restart_device


class TestNpu(TestCase):
    """Tests for the NPU tensor-safety helpers in ``torch_npu.npu._recovery``.

    The unsafe-check flag is changed through module-level setters and is not
    reset between tests by anything in this file. Every test that changes it
    (directly or through ``restart_device``) therefore registers a cleanup
    that puts the previous value back, even if an assertion fails part-way.
    """

    def _restore_unsafe_check_flag_on_cleanup(self):
        """Register a cleanup restoring the unsafe-check flag's current value."""
        # NOTE(review): relies on the base TestCase exposing unittest's
        # addCleanup -- confirm against torch_npu.testing.testcase.
        self.addCleanup(
            set_npu_tensor_unsafe_check_flag,
            get_npu_tensor_unsafe_check_flag(),
        )

    def test_catch_data_check_error(self):
        """Adding two tensors marked unsafe raises while the check flag is on."""
        torch.npu.set_device(0)
        self._restore_unsafe_check_flag_on_cleanup()
        tensor_a = torch.randn(2, 255, 255, device="npu:0")
        tensor_b = torch.randn(2, 255, 255, device="npu:0")
        set_npu_tensor_unsafe_check_flag(True)
        mark_all_npu_tensor_unsafe(0)
        self.assertFalse(check_npu_tensor_is_safe(tensor_a))
        self.assertFalse(check_npu_tensor_is_safe(tensor_b))
        with self.assertRaisesRegex(RuntimeError, "There is unsafe data in the input tensor."):
            # Only the raised error matters; the sum itself is never used.
            _ = tensor_a + tensor_b

    def test_mark_all_npu_tensors_unsafe_and_update_safe(self):
        """Marked tensors report unsafe; fresh data and d2d copies report safe."""
        torch.npu.set_device(0)
        tensor_a = torch.randn(2, 255, 255, device="npu:0")
        tensor_b = torch.randn(2, 255, 255, device="npu:0")
        self.assertTrue(check_npu_tensor_is_safe(tensor_a))
        self.assertTrue(check_npu_tensor_is_safe(tensor_b))
        # Mark every tensor on device 0 as unsafe.
        mark_all_npu_tensor_unsafe(0)
        self.assertFalse(check_npu_tensor_is_safe(tensor_a))
        self.assertFalse(check_npu_tensor_is_safe(tensor_b))
        # Release tensor_a and allocate again: the new tensor is safe.
        del tensor_a
        tensor_a_new = torch.randn(2, 255, 255, device="npu:0")
        self.assertTrue(check_npu_tensor_is_safe(tensor_a_new))
        # A device-to-device copy from safe data flips the unsafe tag to safe.
        tensor_b.copy_(tensor_a_new)
        self.assertTrue(check_npu_tensor_is_safe(tensor_b))

    def test_restart_device_with_rebuild(self):
        """Smoke test: restarting device 0 with a full rebuild must not raise."""
        torch.npu.set_device(0)
        # restart_device with default arguments is expected to leave the check
        # flag on (asserted in the disable_tensor_unsafe_check test below), so
        # restore the previous value once this test finishes.
        self._restore_unsafe_check_flag_on_cleanup()
        restart_device(0, rebuild_all_resources=True)

    def test_check_npu_tensor_is_safe_invalid_type(self):
        """Passing a non-tensor to check_npu_tensor_is_safe raises RuntimeError."""
        with self.assertRaises(RuntimeError):
            check_npu_tensor_is_safe("invalid_tensor")

    def test_restart_device_with_disable_tensor_unsafe_check(self):
        """restart_device honours ``disable_tensor_unsafe_check``.

        With the option set to True the tensor stays safe and the flag stays
        off; with False, or with the option omitted, the tensor becomes unsafe
        and the flag is switched on.
        """
        torch.npu.set_device(0)
        # Registered up front so the flag is restored even when an assertion
        # below fails; this replaces the original trailing reset, which was
        # skipped on failure.
        self._restore_unsafe_check_flag_on_cleanup()
        tensor = torch.randn(2, 3, device="npu:0")

        set_npu_tensor_unsafe_check_flag(False)
        self.assertTrue(check_npu_tensor_is_safe(tensor))
        self.assertFalse(get_npu_tensor_unsafe_check_flag())

        restart_device(0, rebuild_all_resources=True, disable_tensor_unsafe_check=True)
        self.assertTrue(check_npu_tensor_is_safe(tensor))
        self.assertFalse(get_npu_tensor_unsafe_check_flag())

        set_npu_tensor_unsafe_check_flag(False)
        restart_device(0, rebuild_all_resources=True, disable_tensor_unsafe_check=False)
        self.assertFalse(check_npu_tensor_is_safe(tensor))
        self.assertTrue(get_npu_tensor_unsafe_check_flag())

        set_npu_tensor_unsafe_check_flag(False)
        restart_device(0, rebuild_all_resources=True)
        self.assertFalse(check_npu_tensor_is_safe(tensor))
        self.assertTrue(get_npu_tensor_unsafe_check_flag())


# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()