import torch
import torch_npu
from torch_npu.npu._recovery import (
check_npu_tensor_is_safe,
mark_all_npu_tensor_unsafe,
set_npu_tensor_unsafe_check_flag,
get_npu_tensor_unsafe_check_flag,
)
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.npu._recovery import restart_device
class TestNpu(TestCase):
    """Tests for the tensor-safety and device-restart helpers imported from
    ``torch_npu.npu._recovery``.

    Every test targets device 0 and allocates its tensors on ``"npu:0"``.
    """

    def test_catch_data_check_error(self):
        """Adding two tensors that were marked unsafe raises ``RuntimeError``
        while the unsafe check flag is enabled."""
        torch.npu.set_device(0)
        tensor_a = torch.randn(2, 255, 255, device="npu:0")
        tensor_b = torch.randn(2, 255, 255, device="npu:0")

        # NOTE(review): the flag is left enabled and the tensors stay marked
        # unsafe when this test ends; other tests may rely on that leaked
        # state -- confirm before adding a reset here.
        set_npu_tensor_unsafe_check_flag(True)
        mark_all_npu_tensor_unsafe(0)

        self.assertFalse(check_npu_tensor_is_safe(tensor_a))
        self.assertFalse(check_npu_tensor_is_safe(tensor_b))
        with self.assertRaisesRegex(RuntimeError, "There is unsafe data in the input tensor."):
            # Only the raise matters; the sum itself is discarded.
            _ = tensor_a + tensor_b

    def test_mark_all_npu_tensors_unsafe_and_update_safe(self):
        """Existing tensors report unsafe after ``mark_all_npu_tensor_unsafe(0)``;
        a newly allocated tensor reports safe, and so does an unsafe tensor
        once it has been overwritten via ``copy_`` from a safe one."""
        torch.npu.set_device(0)
        tensor_a = torch.randn(2, 255, 255, device="npu:0")
        tensor_b = torch.randn(2, 255, 255, device="npu:0")
        self.assertTrue(check_npu_tensor_is_safe(tensor_a))
        self.assertTrue(check_npu_tensor_is_safe(tensor_b))

        mark_all_npu_tensor_unsafe(0)
        self.assertFalse(check_npu_tensor_is_safe(tensor_a))
        self.assertFalse(check_npu_tensor_is_safe(tensor_b))

        # Drop the unsafe tensor before allocating its replacement.
        del tensor_a
        tensor_a_new = torch.randn(2, 255, 255, device="npu:0")
        self.assertTrue(check_npu_tensor_is_safe(tensor_a_new))

        # Writing safe data into the unsafe tensor is expected to clear its mark.
        tensor_b.copy_(tensor_a_new)
        self.assertTrue(check_npu_tensor_is_safe(tensor_b))

    def test_restart_device_with_rebuild(self):
        """``restart_device`` with ``rebuild_all_resources=True`` completes
        without raising (the test has no other assertion)."""
        torch.npu.set_device(0)
        restart_device(0, rebuild_all_resources=True)

    def test_check_npu_tensor_is_safe_invalid_type(self):
        """Passing a non-tensor (a ``str``) to ``check_npu_tensor_is_safe``
        raises ``RuntimeError``."""
        with self.assertRaises(RuntimeError):
            check_npu_tensor_is_safe("invalid_tensor")

    def test_restart_device_with_disable_tensor_unsafe_check(self):
        """Check how ``disable_tensor_unsafe_check`` affects ``restart_device``.

        With ``disable_tensor_unsafe_check=True`` the tensor still reports safe
        and the flag stays ``False``. With ``False``, or with the argument
        omitted, the tensor reports unsafe and the flag reads ``True``.
        """
        torch.npu.set_device(0)
        tensor = torch.randn(2, 3, device="npu:0")
        set_npu_tensor_unsafe_check_flag(False)
        # Guarantee the flag is switched back off even when an assertion or
        # restart_device call below fails part-way through the test.
        self.addCleanup(set_npu_tensor_unsafe_check_flag, False)
        self.assertTrue(check_npu_tensor_is_safe(tensor))
        self.assertFalse(get_npu_tensor_unsafe_check_flag())

        # Phase 1: unsafe check explicitly disabled for the restart.
        restart_device(0, rebuild_all_resources=True, disable_tensor_unsafe_check=True)
        self.assertTrue(check_npu_tensor_is_safe(tensor))
        self.assertFalse(get_npu_tensor_unsafe_check_flag())
        set_npu_tensor_unsafe_check_flag(False)

        # Phase 2: unsafe check explicitly kept for the restart.
        restart_device(0, rebuild_all_resources=True, disable_tensor_unsafe_check=False)
        self.assertFalse(check_npu_tensor_is_safe(tensor))
        self.assertTrue(get_npu_tensor_unsafe_check_flag())
        set_npu_tensor_unsafe_check_flag(False)

        # Phase 3: argument omitted; asserted to behave like phase 2.
        restart_device(0, rebuild_all_resources=True)
        self.assertFalse(check_npu_tensor_is_safe(tensor))
        self.assertTrue(get_npu_tensor_unsafe_check_flag())
        # The final reset to False is performed by the cleanup registered above.
if __name__ == "__main__":
    # Executed as a script: hand control to the torch_npu test runner.
    run_tests()