import unittest
import torch
import numpy as np
from torch_npu.testing.testcase import TestCase, run_tests
# Every test below places its tensors on the first NPU; select it as the
# current device once, when the module is imported.
device = "npu:0"
torch.npu.set_device(device)
class TestUtilities(TestCase):
    """Exercise torch's dtype utility functions against NPU tensors.

    Covers ``torch.result_type``, ``torch.can_cast`` and
    ``torch.promote_types``; ``torch.compiled_with_cxx11_abi`` is kept but
    skipped because its value depends on how torch was built.
    """

    @unittest.skip("Different compile parameters will cause different results")
    def test_compiled_with_cxx11_abi(self):
        # Only meaningful on builds known to use the C++11 ABI.
        self.assertTrue(torch.compiled_with_cxx11_abi())

    def test_result_type(self):
        def on_npu(data, dtype=None):
            # dtype=None lets torch.tensor infer the dtype, matching a call
            # that omits the argument entirely.
            return torch.tensor(data, dtype=dtype, device=device)

        fallback = torch.get_default_dtype()
        # Each entry: (operands passed to torch.result_type, expected dtype).
        cases = [
            # A Python int never widens an int tensor, in either position.
            ((on_npu(1, torch.int), 1), torch.int),
            ((1, on_npu(1, torch.int)), torch.int),
            # Any Python float pulls integer operands up to the default float.
            ((1, 1.), fallback),
            ((on_npu(1), 1.), fallback),
            # Same category: the dimensioned tensor wins over the 0-dim one.
            ((on_npu(1, torch.long), on_npu([1, 1], torch.int)), torch.int),
            ((on_npu([1., 1.], torch.float), 1.), torch.float),
            # Two 0-dim tensors of the same category promote normally.
            ((on_npu(1., torch.float), on_npu(1, torch.double)), torch.double),
        ]
        for operands, expected in cases:
            self.assertEqual(torch.result_type(*operands), expected)

    def test_can_cast(self):
        wide, narrow = torch.double, torch.float
        # Narrowing inside the floating-point kind is permitted ...
        self.assertTrue(torch.can_cast(wide, narrow))
        # ... but dropping from floating point to an integer kind is not.
        self.assertFalse(torch.can_cast(narrow, torch.int))

    def test_promote_types(self):
        # (first, second, expected promotion) triples.
        table = [
            (torch.float, torch.int, torch.float),
            (torch.float, torch.double, torch.double),
            (torch.int, torch.uint8, torch.int),
        ]
        for first, second, promoted in table:
            self.assertEqual(torch.promote_types(first, second), promoted)
if __name__ == "__main__":
    # Allow running this file directly through the torch_npu test runner.
    run_tests()