import functools

import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
class TestDevice(TestCase):
    """Tests that NPU device arguments are accepted in their various spellings."""

    def device_monitor(func):
        """Decorate a test so its body runs once per NPU device spelling.

        Defined in the class body and applied as a decorator there, hence no
        ``self`` parameter. The decorated test must accept a ``device``
        keyword and return a tensor.

        For each spelling of device 0 (``"npu"``, ``"npu:0"``, a
        ``torch.device``, that device's ``.type`` string, and the device of
        an existing NPU tensor), the returned tensor must be on type
        ``"npu"`` with index 0. Finally the body is called once with
        ``device=None``. That call only has to complete without raising;
        its result is not checked.
        """
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            device_id = 0
            # Make index 0 current: the index-less spellings below ("npu")
            # are expected to resolve to it.
            torch.npu.set_device(device_id)
            npu_device = torch.randn(2).npu(device_id).device
            device_types = [
                "npu",
                "npu:" + str(device_id),
                torch.device("npu:" + str(device_id)),
                # ``.type`` evaluates to the bare string "npu".
                torch.device("npu:" + str(device_id)).type,
                npu_device,
            ]
            for device_type in device_types:
                kwargs["device"] = device_type
                npu_tensor = func(self, *args, **kwargs)
                # Tag each check with the spelling under test, so a failure
                # identifies which iteration broke.
                self.assertEqual(npu_tensor.device.type, "npu",
                                 msg=f"device={device_type!r}")
                self.assertEqual(npu_tensor.device.index, device_id,
                                 msg=f"device={device_type!r}")
            # device=None must still be accepted; smoke check only.
            kwargs["device"] = None
            func(self, *args, **kwargs)

        return wrapper

    def _assert_npu_device(self, device, index=0):
        """Assert ``device`` is a ``torch.device`` of type "npu" with ``index``."""
        self.assertIsInstance(device, torch.device)
        self.assertEqual(device.type, "npu")
        self.assertEqual(device.index, index)

    @device_monitor
    def test_torch_tensor_to_device(self, device=None):
        """``Tensor.to(device, dtype)`` moves a CPU tensor to ``device``."""
        cpu_tensor = torch.randn(2, 3)
        return cpu_tensor.to(device, torch.int64)

    @device_monitor
    def test_torch_tensor_new_empty_with_device_input(self, device=None):
        """``Tensor.new_empty`` honours an explicit ``device`` keyword."""
        npu_tensor = torch.ones(2, 3).to(device)
        return npu_tensor.new_empty((2, 3), dtype=torch.float16, device=device)

    @device_monitor
    def test_torch_func_arange_with_device_input(self, device=None):
        """``torch.arange`` honours an explicit ``device`` keyword."""
        return torch.arange(5, dtype=torch.float32, device=device)

    @device_monitor
    def test_torch_func_zeros_with_device_input(self, device=None):
        """``torch.zeros`` honours an explicit ``device`` keyword."""
        return torch.zeros((2, 3), dtype=torch.int8, device=device)

    @device_monitor
    def test_tensor_method_npu_with_device_input(self, device=None):
        """``Tensor.npu(device)`` accepts a ``torch.device`` argument."""
        # String spellings are converted first, so ``.npu`` always receives
        # a torch.device (or None, or the tensor-derived device).
        if isinstance(device, str):
            device = torch.device(device)
        cpu_input = torch.randn(2, 3)
        return cpu_input.npu(device)

    @device_monitor
    def test_torch_func_tensor_with_device_input(self, device=None):
        """``torch.tensor`` honours an explicit ``device`` keyword."""
        return torch.tensor((2, 3), device=device)

    def test_device_argument_as_input(self):
        """``set_device`` and ``torch.device`` accept both str and device forms."""
        device_str = "npu:0"
        torch.npu.set_device(device_str)
        device = torch.device(device_str)
        self.assertIsInstance(device, torch.device)
        torch.npu.set_device(device)

        tensor = torch.rand(2, 3).npu()
        self._assert_npu_device(tensor.device)

        # Every torch.device constructor form must yield npu:0.
        self._assert_npu_device(torch.device(device))
        self._assert_npu_device(torch.device(device=device))
        self._assert_npu_device(torch.device(device=device_str))
        self._assert_npu_device(torch.device(type="npu", index=0))

    def test_torch_npu_device(self):
        """A bare integer passed to ``torch.device`` is expected to mean an npu device."""
        device = torch.device(0)
        self.assertEqual(device.type, "npu")
        device = torch.device(device=0)
        self.assertEqual(device.type, "npu")
        self.assertIsInstance(device, torch._C.device)
        self.assertIsInstance(device, torch.device)

    def test_multithread_device(self):
        """``current_device()`` reports 0 in a worker thread after ``set_device('npu:0')``."""
        import threading

        errors = []

        def _worker():
            try:
                cur = torch_npu.npu.current_device()
                self.assertEqual(cur, 0)
            except Exception as exc:  # hand the failure back to the main thread
                errors.append(exc)

        torch.npu.set_device("npu:0")
        thread = threading.Thread(target=_worker)
        thread.start()
        thread.join()
        if errors:
            self.fail(f"worker thread raised: {errors[0]!r}")
# Run the test suite when executed as a script.
if __name__ == '__main__':
    run_tests()