import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
# Select NPU index 0 as the current device at import time, before any test
# in this module runs.
device = "npu:0"
torch.npu.set_device(device)
class TestParallelism(TestCase):
    """Checks for torch's intra-op and inter-op thread-count controls.

    Each test calls one of the ``torch.{set,get}_num[_interop]_threads``
    functions and asserts on the observable result instead of printing it,
    so a wrong value fails the test.
    """

    def test_set_num_threads(self):
        """Setting the intra-op thread count is visible through the getter."""
        previous = torch.get_num_threads()
        try:
            torch.set_num_threads(2)
            self.assertEqual(torch.get_num_threads(), 2)
        finally:
            # Restore the original value so later tests in the same process
            # are not affected by this one.
            torch.set_num_threads(previous)

    def test_get_num_threads(self):
        """The intra-op thread count is a positive integer."""
        output = torch.get_num_threads()
        self.assertIsInstance(output, int)
        self.assertGreaterEqual(output, 1)

    def test_set_num_interop_threads(self):
        """Setting the inter-op thread count is visible through the getter.

        The value is not restored afterwards: PyTorch documents that
        ``set_num_interop_threads`` can only be called once, before any
        inter-op parallel work has started.
        """
        torch.set_num_interop_threads(2)
        self.assertEqual(torch.get_num_interop_threads(), 2)

    def test_get_num_interop_threads(self):
        """The inter-op thread count is a positive integer."""
        output = torch.get_num_interop_threads()
        self.assertIsInstance(output, int)
        self.assertGreaterEqual(output, 1)
if __name__ == "__main__":
    # Run this module's tests through the torch_npu runner when the file is
    # executed directly as a script.
    run_tests()