import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
# Device string shared by the tests in this file (NPU with index 0).
device = 'npu:0'
# Select that device once at import time, before any test body runs.
torch.npu.set_device(device)
def get_npu_type(type_name):
    """Return the ``torch.npu`` attribute that mirrors a ``torch`` type.

    Args:
        type_name: Either a class or a dotted name string such as
            ``'torch.FloatTensor'``. A class is first converted to
            ``'<module>.<name>'`` using its ``__module__`` and ``__name__``.

    Returns:
        The attribute of ``torch.npu`` whose name equals the last dotted
        component of ``type_name``.

    Raises:
        ValueError: If the resolved name contains no ``'.'``.
        AssertionError: If the module part is not exactly ``'torch'``.
        AttributeError: If ``torch.npu`` has no attribute with that name.
    """
    if isinstance(type_name, type):
        type_name = '{}.{}'.format(type_name.__module__, type_name.__name__)
    module, sep, name = type_name.rpartition('.')
    if not sep:
        # Previously this surfaced as an opaque tuple-unpacking ValueError.
        raise ValueError(
            "expected a dotted type name like 'torch.FloatTensor', "
            "got {!r}".format(type_name)
        )
    if module != 'torch':
        # Raised explicitly (instead of using an ``assert`` statement) so the
        # check still runs under ``python -O``; the exception type is kept.
        raise AssertionError(
            "expected a type from module 'torch', got {!r}".format(type_name)
        )
    return getattr(torch.npu, name)
class TestGenerators(TestCase):
    """Smoke tests for ``torch.Generator`` objects with the NPU backend loaded."""

    def test_generator(self):
        """A generator created for the module-level ``device`` reports it."""
        g_npu = torch.Generator(device=device)
        self.assertExpectedInline(str(g_npu.device), '''npu:0''')

    def test_default_generator(self):
        """``torch.default_generator`` is accessible and is a ``torch.Generator``."""
        output = torch.default_generator
        # The original test only printed the object and could not fail on a
        # wrong value; assert its type so the test checks something.
        self.assertIsInstance(output, torch.Generator)
# Run the test suite only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    run_tests()