import os
import gc
import unittest
import numpy as np
import torch.multiprocessing as mp
import torch
import torch_npu
from torch_npu.testing.common_utils import SupportedDevices
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
def is_ipc_event_supported():
    """Report whether an interprocess NPU event can be created here.

    The check is a live probe: it constructs
    ``torch.npu.Event(enable_timing=False, interprocess=True)`` and discards
    the result immediately.

    Returns:
        bool: ``False`` when the constructor raises ``RuntimeError``,
        ``True`` otherwise. Any other exception type propagates to the caller.
    """
    try:
        # Only the success or failure of construction matters, so the event
        # object is deliberately not bound to a name.
        torch.npu.Event(enable_timing=False, interprocess=True)
    except RuntimeError:
        return False
    return True
# Evaluated once at import time: True when is_ipc_event_supported() reports
# that an interprocess event cannot be created. The IPC-event tests in this
# file check this flag and bail out instead of exercising the feature.
skip_ipc_event_case = not is_ipc_event_supported()
class Test_ipc_event(TestCase):
    """Tests for cross-device copies and interprocess (IPC) NPU events.

    The IPC tests start a child with the ``spawn`` multiprocessing context and
    coordinate with it through queues. Child entry points are static methods
    so they can be referenced as ``Test_ipc_event.<name>`` process targets.
    """

    def _require_ipc_event(self):
        """Skip the running test when interprocess events are unavailable.

        Reporting a skip (rather than returning early) keeps an unsupported
        environment from being counted as a passing run.
        """
        if skip_ipc_event_case:
            self.skipTest("interprocess NPU events are not supported")

    def _join_child(self, proc):
        """Wait for ``proc`` to exit and fail the test on a non-zero exit code.

        A failed ``assert`` inside a spawned child only affects the child's
        exit status, so the parent has to inspect it explicitly.
        """
        proc.join()
        if proc.exitcode != 0:
            self.fail(f"child process exited with code {proc.exitcode}")

    @skipIfUnsupportMultiNPU(2)
    def test_d2d_copy1(self):
        """``copy_`` from an npu:0 scalar into an npu:1 scalar."""
        src = torch.tensor(1.).to('npu:0')
        dst = torch.tensor(1.).to('npu:1')
        dst.copy_(src)
        self.assertEqual(src.cpu(), dst.cpu())

    @skipIfUnsupportMultiNPU(2)
    def test_d2d_copy2(self):
        """``Tensor.to`` from npu:0 to npu:1 for a scalar."""
        src = torch.tensor(1.).to('npu:0')
        dst = src.to('npu:1')
        self.assertEqual(src.cpu(), dst.cpu())

    @skipIfUnsupportMultiNPU(2)
    def test_d2d_copy3(self):
        """Non-blocking ``Tensor.to`` from npu:0 to npu:1 for a large tensor."""
        # 2 * 1024**3 elements; 8 GiB if the default dtype is float32.
        src = torch.ones(2, 1024, 1024, 1024).to('npu:0')
        dst = src.to('npu:1', non_blocking=True)
        self.assertEqual(src.cpu(), dst.cpu())

    @SupportedDevices(['Ascend910B'])
    def test_ipc_event_pickle(self):
        """Put an interprocess event on a spawn-context queue."""
        self._require_ipc_event()
        ev = torch.npu.Event(enable_timing=False, interprocess=True)
        ctx = mp.get_context("spawn")
        q = ctx.Queue()
        # NOTE(review): multiprocessing queues pickle items on a background
        # feeder thread, so a serialization failure may not raise here.
        # Confirm whether this test should also read the item back.
        q.put(ev)

    @staticmethod
    def _child_proc1(q):
        """Child for ``test_ipc_event_1``: receive an event and wait on it."""
        ev = q.get()
        assert ev.device.type == 'npu'
        assert ev.device.index == 0
        ev.wait()
        ev.synchronize()

    @SupportedDevices(['Ascend910B'])
    def test_ipc_event_1(self):
        """Record an event in the parent and hand it to a child via a queue."""
        self._require_ipc_event()
        ctx = mp.get_context("spawn")
        q = ctx.Queue()
        p = ctx.Process(target=Test_ipc_event._child_proc1, args=(q,))
        p.start()
        dev = torch.device("npu:0")
        with torch.npu.device(dev):
            stream = torch.npu.Stream()
            with torch.npu.stream(stream):
                ev = torch.npu.Event(enable_timing=False, interprocess=True)
                ev.record(stream)
                q.put(ev)
        self._join_child(p)

    @staticmethod
    def _child_proc2(q1, q2):
        """Child for ``test_ipc_event_2``: wait on, then re-record, an event.

        Protocol: receive the event on ``q1``, wait on it, record it on a
        local stream, send ``'x'`` on ``q2``, then block until the parent
        sends ``'y'`` on ``q1``.
        """
        dev = torch.device("npu:0")
        with torch.npu.device(dev):
            stream = torch.npu.Stream()
            with torch.npu.stream(stream):
                ev = q1.get()
                assert ev.device.type == 'npu'
                assert ev.device.index == 0
                ev.wait()
                ev.record(stream)
                q2.put('x')
                # Blocks until the parent has finished its own wait and
                # synchronize, which keeps ``ev`` in scope for that long.
                assert q1.get() == 'y'

    @SupportedDevices(['Ascend910B'])
    def test_ipc_event_2(self):
        """Round trip: parent records, child re-records, parent synchronizes."""
        self._require_ipc_event()
        ctx = mp.get_context("spawn")
        q1 = ctx.Queue()
        q2 = ctx.Queue()
        p = ctx.Process(target=Test_ipc_event._child_proc2, args=(q1, q2))
        p.start()
        dev = torch.device("npu:0")
        with torch.npu.device(dev):
            stream = torch.npu.Stream()
            with torch.npu.stream(stream):
                ev = torch.npu.Event(enable_timing=False, interprocess=True)
                ev.record(stream)
                q1.put(ev)
                # The child sends 'x' after it has recorded the event.
                self.assertEqual(q2.get(), 'x')
                ev.wait()
                ev.synchronize()
                # Release the child, which is blocked waiting for 'y'.
                q1.put('y')
        self._join_child(p)

    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend910B'])
    def test_event_handle_multi_npu(self):
        """Request IPC handles while a different device is current."""
        self._require_ipc_event()
        d0 = torch.device("npu:0")
        d1 = torch.device("npu:1")
        with torch.npu.device(d0):
            e0 = torch.npu.Event(enable_timing=False, interprocess=True)
        with torch.npu.device(d1):
            # Handle for an event that has not been recorded.
            e0.ipc_handle()
        with torch.npu.device(d0):
            e1 = torch.npu.Event(enable_timing=False, interprocess=True)
            stream = torch.npu.Stream()
            e1.record(stream)
        with torch.npu.device(d1):
            # Handle for an event that has been recorded.
            e1.ipc_handle()

    @staticmethod
    def _test_event_handle_importer_consumer(handle, p2c, c2p):
        """Child for ``test_event_handle_importer``.

        Rebuilds the event from ``handle`` on device 0 and synchronizes on it
        once the parent signals. ``p2c`` carries parent-to-child messages and
        ``c2p`` carries child-to-parent messages.
        """
        e1 = torch.npu.Event.from_ipc_handle(0, handle)
        c2p.put(0)  # tell the parent the handle has been imported
        p2c.get()  # wait until the parent has recorded the event
        e1.synchronize()
        c2p.put(1)  # tell the parent synchronization has finished
        p2c.get()  # wait for the parent's final signal before returning

    @SupportedDevices(['Ascend910B'])
    def test_event_handle_importer(self):
        """Export a handle from the parent; the child imports and syncs on it."""
        self._require_ipc_event()
        e0 = torch.npu.Event(enable_timing=False, interprocess=True)
        self.assertTrue(e0.query())
        ctx = mp.get_context("spawn")
        p2c = ctx.SimpleQueue()
        c2p = ctx.SimpleQueue()
        p = ctx.Process(
            target=Test_ipc_event._test_event_handle_importer_consumer,
            args=(e0.ipc_handle(), p2c, c2p),
        )
        p.start()
        c2p.get()  # child has imported the handle
        e0.record()
        p2c.put(0)  # let the child synchronize
        c2p.get()  # child has finished synchronizing
        self.assertTrue(e0.query())
        p2c.put(1)  # let the child exit
        self._join_child(p)

    @staticmethod
    def _test_event_handle_exporter_consumer(handle, p2c, c2p):
        """Child for ``test_event_handle_exporter``.

        Rebuilds the event from ``handle`` on the current device, records it
        on a fresh stream, then waits for the parent before returning.
        """
        stream = torch.npu.Stream()
        with torch.npu.stream(stream):
            e1 = torch.npu.Event.from_ipc_handle(torch.npu.current_device(), handle)
            e1.record()
            c2p.put(0)  # tell the parent the event has been recorded
            # Wait for the parent's signal; presumably this keeps e1 alive
            # until the parent has finished synchronizing. Confirm.
            p2c.get()

    @SupportedDevices(['Ascend910B'])
    def test_event_handle_exporter(self):
        """Child records an imported event; the parent synchronizes on it."""
        self._require_ipc_event()
        e0 = torch.npu.Event(enable_timing=False, interprocess=True)
        ctx = mp.get_context("spawn")
        p2c = ctx.SimpleQueue()
        c2p = ctx.SimpleQueue()
        p = ctx.Process(
            target=Test_ipc_event._test_event_handle_exporter_consumer,
            args=(e0.ipc_handle(), p2c, c2p),
        )
        p.start()
        c2p.get()  # child has recorded the event
        e0.synchronize()
        self.assertTrue(e0.query())
        p2c.put(0)  # let the child exit
        self._join_child(p)
# Script entry point: hand control to the run_tests helper imported from
# torch_npu.testing.testcase.
if __name__ == "__main__":
    run_tests()