import os
import threading
from ctypes import byref, c_int, c_void_p, CDLL
import torch
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import run_tests
import torch_npu
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
class TestDevice(MultiProcessTestCase):
    """Check which NPU devices get initialised by device/stream/event/tensor APIs.

    Most multi-device tests touch one device (usually ``npu:1``) and then call
    ``_check_not_npu`` to assert that the ACL primary context of another device
    is still reported inactive afterwards.
    """

    def setUp(self):
        super().setUp()
        # MultiProcessTestCase runs each test in spawned worker process(es);
        # presumably this gives every test a clean, uninitialised device state.
        self._spawn_processes()

    @property
    def world_size(self):
        # One spawned process per test is enough; no collectives are used here.
        return 1

    @staticmethod
    def _run_in_thread(fn, *args):
        """Run ``fn(*args)`` in a new thread, wait for it, and re-raise its exception.

        An exception escaping a ``threading.Thread`` target is only printed by
        ``threading.excepthook``; ``join()`` does not raise it. Capturing it
        here and re-raising in the calling thread means failures inside the
        worker (including ``assertEqual`` failures) actually fail the test.
        """
        errors = []

        def runner():
            try:
                fn(*args)
            except BaseException as exc:  # re-raised in the calling thread below
                errors.append(exc)

        worker = threading.Thread(target=runner)
        worker.start()
        worker.join()
        if errors:
            raise errors[0]

    def _check_not_npu(self, device_id=0):
        """Assert that the ACL primary context on ``device_id`` is reported inactive.

        Queries ``aclrtGetPrimaryCtxState`` from ``libascendcl.so`` directly via
        ctypes (presumably so the query itself does not go through torch_npu's
        device initialisation).

        Raises:
            RuntimeError: if ``aclrtGetPrimaryCtxState`` returns a non-zero code.
        """
        ascendcl_h = CDLL("libascendcl.so")
        device_id = c_int(device_id)
        # Pre-set to 1 so that a call that never writes the flag fails the check.
        activate = c_int(1)
        # Second argument is passed as NULL (presumably the optional flags output).
        rc = ascendcl_h.aclrtGetPrimaryCtxState(device_id, c_void_p(), byref(activate))
        if rc != 0:
            raise RuntimeError("call aclrtGetPrimaryCtxState error")
        del ascendcl_h
        self.assertEqual(activate.value, 0)

    def test_event_create(self):
        """An event/stream created in the main thread can be used from another thread."""
        # Presumably initialises npu:0 in the main thread before the worker runs.
        a = torch.full((3, 4), float(0), device='npu:0')
        e = torch.npu.Event()
        s = torch.npu.Stream()

        def target_fuc(result):
            e.record(s)
            e.synchronize()
            result[0] = 1

        result = [0]
        self._run_in_thread(target_fuc, result)
        self.assertEqual(result[0], 1)

    def test_stream_create(self):
        """Smoke test: query the current stream of device 0 via the C binding."""
        s = torch_npu._C._npu_getCurrentStream(0)

    def test_tensor(self):
        """A tensor can be created on npu:0 from a second thread."""
        a = torch.full((3, 4), float(0), device='npu:0')

        def target_fuc(result):
            b = torch.full((3, 4), float(0), device='npu:0')
            result[0] = 1

        result = [0]
        self._run_in_thread(target_fuc, result)
        self.assertEqual(result[0], 1)

    def test_storage(self):
        """An NPU FloatStorage can be created from a second thread."""
        s = torch.npu.Stream()

        def target_fuc(result):
            b = torch.npu.FloatStorage(10)
            result[0] = 1

        result = [0]
        self._run_in_thread(target_fuc, result)
        self.assertEqual(result[0], 1)

    @skipIfUnsupportMultiNPU(2)
    def test_set_device(self):
        """set_device('npu:1') must not activate device 0 and must switch the current device."""
        torch.npu.set_device('npu:1')
        self._check_not_npu()
        device = torch.npu.current_device()
        self.assertEqual(device, 1)

    @skipIfUnsupportMultiNPU(2)
    def test_with_device(self):
        """Allocating inside ``torch.npu.device('npu:1')`` must not activate device 0."""
        with torch.npu.device('npu:1'):
            a = torch.rand(1).npu()
            self.assertEqual(a.device.index, 1)
        self._check_not_npu()
        # Re-entering the same device context must still leave device 0 untouched.
        with torch.npu.device('npu:1'):
            b = torch.rand(1).npu()
            self.assertEqual(b.device.index, 1)
        self._check_not_npu()

    @skipIfUnsupportMultiNPU(2)
    def test_stream(self):
        """A default-constructed stream must not activate device 1."""
        s = torch.npu.Stream()
        self._check_not_npu(1)

    @skipIfUnsupportMultiNPU(2)
    def test_stream0(self):
        """A stream on device 0 must not activate device 1."""
        s = torch.npu.Stream(0)
        self._check_not_npu(1)

    @skipIfUnsupportMultiNPU(2)
    def test_stream1(self):
        """A stream on device 1 must not activate device 0."""
        s = torch.npu.Stream(1)
        self._check_not_npu(0)

    def test_event(self):
        """Smoke test: record and wait on a default event."""
        s = torch.npu.Event()
        s.record()
        s.wait()

    @skipIfUnsupportMultiNPU(2)
    def test_storage0(self):
        """Smoke test: FloatStorage on the default device."""
        s1 = torch.npu.FloatStorage(10)

    @skipIfUnsupportMultiNPU(2)
    def test_storage1(self):
        """An UntypedStorage on npu:1 must not activate device 0."""
        s = torch.UntypedStorage(10, device='npu:1')
        self._check_not_npu()

    @skipIfUnsupportMultiNPU(2)
    def test_empty0(self):
        """Smoke test: torch.empty on npu:0."""
        a = torch.empty(2, device='npu:0')

    @skipIfUnsupportMultiNPU(2)
    def test_empty1(self):
        """torch.empty on npu:1 must not activate device 0."""
        a = torch.empty(2, device='npu:1')
        self._check_not_npu()

    @skipIfUnsupportMultiNPU(2)
    def test_rand0(self):
        """Smoke test: torch.rand on npu:0."""
        a = torch.rand(2, device='npu:0')

    @skipIfUnsupportMultiNPU(2)
    def test_rand1(self):
        """torch.rand on npu:1 must not activate device 0."""
        a = torch.rand(2, device='npu:1')
        self._check_not_npu()

    @skipIfUnsupportMultiNPU(2)
    def test_npu0(self):
        """Smoke test: Tensor.npu() with no device argument."""
        a = torch.rand(2).npu()

    @skipIfUnsupportMultiNPU(2)
    def test_npu1(self):
        """Tensor.npu(1) must not activate device 0."""
        a = torch.rand(2).npu(1)
        self._check_not_npu()

    @skipIfUnsupportMultiNPU(2)
    def test_to0(self):
        """Smoke test: Tensor.to('npu:0')."""
        a = torch.rand(2).to('npu:0')

    @skipIfUnsupportMultiNPU(2)
    def test_to1(self):
        """Tensor.to('npu:1') must not activate device 0."""
        a = torch.rand(2).to('npu:1')
        self._check_not_npu()

    @skipIfUnsupportMultiNPU(2)
    def test_pin_memory(self):
        """Pinned allocation + copy to npu:1 in a worker thread leaves device 0 untouched."""
        def worker_function():
            pinmemory_tensor = torch.empty(32, pin_memory=True)
            a = torch.rand(2).to('npu:1')

        # Via _run_in_thread so an exception in the worker fails the test
        # instead of being silently printed by threading.excepthook.
        self._run_in_thread(worker_function)
        device = torch.npu.current_device()
        self.assertEqual(device, 0)
        self._check_not_npu()

    # NOTE(review): ``thread`` and ``thread1`` lack the ``test_`` prefix, so
    # unittest never collects or runs them. Rename to ``test_thread`` /
    # ``test_thread1`` if they are meant to run; left unchanged here in case
    # they were disabled on purpose.
    @skipIfUnsupportMultiNPU(2)
    def thread(self):
        """Allocating on npu:1 in a worker thread must not activate device 0."""
        def ff():
            a = torch.rand(2, device='npu:1')
            self.assertEqual(a.device.index, 1)

        self._run_in_thread(ff)
        self._check_not_npu()

    @skipIfUnsupportMultiNPU(2)
    def thread1(self):
        """In a worker thread, ``.npu()`` after allocating on npu:1 is expected to target device 0."""
        def ff():
            a = torch.rand(2, device='npu:1')
            self.assertEqual(a.device.index, 1)
            b = torch.rand(2).npu()
            self.assertEqual(b.device.index, 0)

        self._run_in_thread(ff)
if __name__ == "__main__":
    # Export the ACL setting before running, so it is already present in this
    # process's environment and is inherited by the test processes it spawns.
    os.environ.update({"ACL_OP_INIT_MODE": "1"})
    run_tests()