import os
import sys
import gc
import unittest
from typing import NamedTuple
import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfRocm, skipCUDANonDefaultStreamIf, NoTest, TEST_CUDA
# Append the parent of this file's directory to sys.path so that modules
# located there can be imported.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)

# Capability flags consumed by the skip decorators further down.
if TEST_CUDA:
    TEST_MULTIGPU = torch.npu.device_count() >= 2
    # Run one tiny op on the device before querying its properties.
    torch.ones(1).npu()
    # "Large tensor" tests need at least 5e9 bytes of total device memory.
    TEST_LARGE_TENSOR = torch.npu.get_device_properties(0).total_memory >= 5e9
else:
    print('NPU not available, skipping tests', file=sys.stderr)
    # Swap the base class so the suite below is not run.
    JitTestCase = NoTest  # noqa: F811
    TEST_MULTIGPU = TEST_CUDA
    TEST_LARGE_TENSOR = TEST_CUDA

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
class TestCUDA(JitTestCase):
    """
    A suite of tests for the NPU API in TorchScript.
    """

    def tearDown(self):
        """Run a garbage collection and `torch.npu.empty_cache()`, then the parent tearDown."""
        gc.collect()
        torch.npu.empty_cache()
        super().tearDown()

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one NPU")
    def test_npu_synchronize(self):
        """Script `torch.npu.synchronize` with several argument forms and run it.

        Each scripted helper returns whether the current device index is the
        same before and after the synchronize calls; the compiled graph is
        also checked for an `npu::synchronize(` node.
        """
        @torch.jit.script
        def test_device_synchronize():
            prev_current_device_index = torch.npu.current_device()
            torch.npu.synchronize()
            torch.npu.synchronize('npu')
            torch.npu.synchronize('npu:0')
            torch.npu.synchronize(0)
            torch.npu.synchronize(torch.device('npu:1'))
            after_current_device_index = torch.npu.current_device()
            return prev_current_device_index == after_current_device_index

        @torch.jit.script
        def test_multi_device_synchronize():
            torch.npu.synchronize(torch.device('npu:0'))
            prev_current_device_index = torch.npu.current_device()
            torch.npu.synchronize(1)
            after_current_device_index = torch.npu.current_device()
            return prev_current_device_index == after_current_device_index

        # Call the scripted functions: asserting on the function objects
        # themselves would pass unconditionally without executing anything.
        self.assertTrue(test_device_synchronize())
        FileCheck().check("npu::synchronize(").run(test_device_synchronize.graph)
        self.assertTrue(test_multi_device_synchronize())
        FileCheck().check("npu::synchronize(").run(test_multi_device_synchronize.graph)

    def test_stream_args(self):
        """Create `torch.npu.Stream` objects with different argument combinations.

        Every scripted helper returns True when the new stream reports the
        expected device index.
        """
        # No arguments at all.
        @torch.jit.script
        def stream_default_args() -> bool:
            s = torch.npu.Stream()
            return s.device_index() == torch.npu.current_device()

        # Device left at its default, priority given explicitly.
        @torch.jit.script
        def stream_default_args_for_device() -> bool:
            s = torch.npu.Stream(priority=0)
            return s.device_index() == torch.npu.current_device()

        # Priority left at its default, device given explicitly (device 1).
        @torch.jit.script
        def stream_default_args_for_priority() -> bool:
            d = torch.device("npu:1")
            s = torch.npu.Stream(d)
            return s.device_index() == 1

        # Both device and priority given explicitly.
        @torch.jit.script
        def stream_args_all() -> bool:
            d = torch.device("npu:0")
            s = torch.npu.Stream(d, 0)
            return s.device_index() == 0

        # Call the scripted functions: asserting on the function objects
        # themselves would pass unconditionally without executing anything.
        self.assertTrue(stream_default_args())
        self.assertTrue(stream_default_args_for_device())
        # This helper hard-codes "npu:1"; this test has no multi-device skip,
        # so only execute it when a second device was detected.
        if TEST_MULTIGPU:
            self.assertTrue(stream_default_args_for_priority())
        self.assertTrue(stream_args_all())

    def test_event_args(self):
        """Create a `torch.npu.Event` with no arguments inside a scripted function."""
        @torch.jit.script
        def event_default_args() -> bool:
            e = torch.npu.Event()
            return e is not None

        # Call the scripted function rather than asserting on the (always
        # truthy) function object.
        self.assertTrue(event_default_args())

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one NPU")
    def test_current_stream(self):
        """Check `torch.npu.current_stream` with device objects and with integer indices."""
        @torch.jit.script
        def fn():
            device_index = torch.npu.current_device()
            device = torch.device("npu:" + str(device_index))
            s0 = torch.npu.current_stream(device)
            s1 = torch.npu.current_stream(torch.device("npu:1"))
            s2 = torch.npu.current_stream(torch.device("npu:0"))
            return s0.device_index(), s1.device_index(), s2.device_index()

        d0, d1, d2 = fn()
        # The assertions below expect the current device to be device 0.
        self.assertEqual(0, d0)
        self.assertEqual(1, d1)
        self.assertEqual(0, d2)
        self.assertEqual(d0, d2)

        # Same checks, passing plain integer device indices.
        @torch.jit.script
        def fn_with_device_index_args():
            device_index = torch.npu.current_device()
            s0 = torch.npu.current_stream(device_index)
            s1 = torch.npu.current_stream(1)
            s2 = torch.npu.current_stream(0)
            return s0.device_index(), s1.device_index(), s2.device_index()

        d0, d1, d2 = fn_with_device_index_args()
        self.assertEqual(0, d0)
        self.assertEqual(1, d1)
        self.assertEqual(0, d2)
        self.assertEqual(d0, d2)

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one NPU")
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    @skipCUDANonDefaultStreamIf(True)
    def test_streams_and_events(self):
        """Exercise scripted streams, events and the stream/device context managers.

        The method defines a series of scripted helpers and asserts on the
        value each one returns. Statement order inside the helpers matters
        (record / wait / synchronize), so it is kept exactly as written.
        """
        # Default streams looked up by integer device index.
        @torch.jit.script
        def test_default_streams_with_device_index_args():
            s0 = torch.npu.default_stream(0)
            s1 = torch.npu.default_stream(1)
            return s0.device_index(), s1.device_index()

        d0, d1 = test_default_streams_with_device_index_args()
        self.assertEqual(d0, 0)
        self.assertEqual(d1, 1)

        # Default streams looked up by device object, plus the device
        # context manager.
        @torch.jit.script
        def test_default_streams():
            s0 = torch.npu.default_stream(torch.device('npu:0'))
            s1 = torch.npu.default_stream(torch.device('npu:1'))
            d = torch.device('npu:1')
            # Current stream on device 0 is compared with its default stream.
            s2 = torch.npu.current_stream(torch.device('npu:0'))
            check_s2 = s2.id() == s0.id()
            check_d0 = torch.npu.current_device() == s2.device_index()
            # Inside the context the current device is compared with device 1.
            with torch.npu.device(d):
                s3 = torch.npu.current_stream(d)
                check_s3 = s3.id() == s1.id()
                check_d1 = torch.npu.current_device() == s3.device_index()
            # After leaving the context the current device is compared with s2's device.
            is_device_d0 = torch.npu.current_device() == s2.device_index()
            return s0.device_index(), s1.device_index(), check_s2, check_s3, check_d0, check_d1, is_device_d0

        d0, d1, check_s2, check_s3, check_d0, check_d1, is_device_d0 = test_default_streams()
        self.assertEqual(d0, 0)
        self.assertEqual(d1, 1)
        self.assertTrue(check_s2)
        self.assertTrue(check_s3)
        self.assertTrue(check_d0)
        self.assertTrue(check_d1)
        self.assertTrue(is_device_d0)

        # `with torch.npu.stream(None)` must leave device and streams unchanged.
        @torch.jit.script
        def test_set_none_stream():
            device_index = torch.npu.current_device()
            device = torch.device("npu:" + str(device_index))
            current_stream = torch.npu.current_stream(device)
            default_stream = torch.npu.default_stream(device)
            with torch.npu.stream(None):
                cur_device_index = torch.npu.current_device()
                is_device_index_same = cur_device_index == device_index
                is_current_stream_same = torch.npu.current_stream(device).id() == current_stream.id()
                is_default_stream_same = torch.npu.default_stream(device).id() == default_stream.id()
            are_streams_same = is_device_index_same and is_current_stream_same and is_default_stream_same
            return are_streams_same

        self.assertTrue(test_set_none_stream())

        # `with torch.npu.device(None)` must leave the current device unchanged.
        @torch.jit.script
        def test_set_device_none():
            device_index = torch.npu.current_device()
            with torch.npu.device(None):
                is_device_same = torch.npu.current_device() == device_index
            return is_device_same

        self.assertTrue(test_set_device_none())

        # A stream created without arguments reports the current device.
        @torch.jit.script
        def test_simple_stream():
            device_index = torch.npu.current_device()
            s = torch.npu.Stream()
            return device_index == s.device_index()

        self.assertTrue(test_simple_stream(), "Could not create Stream!")

        # Container for everything `test_get_stream` reports back.
        class Result(NamedTuple):
            t1: torch.Tensor
            t2: torch.Tensor
            is_current_and_default_stream_same: bool
            is_default_and_user_stream_not_same: bool
            is_stream_set: bool
            is_stream_reset: bool
            default_stream_query: bool
            default_stream_id: int
            user_stream_id: int

        @torch.jit.script
        def test_get_stream():
            device_index = torch.npu.current_device()
            device = torch.device("npu:" + str(device_index))
            current_stream = torch.npu.current_stream(device)
            default_stream = torch.npu.default_stream(device)
            user_stream = torch.npu.Stream()
            is_current_and_default_stream_same = current_stream.id() == default_stream.id()
            is_default_and_user_stream_not_same = default_stream.id() != user_stream.id()
            with torch.npu.stream(user_stream):
                is_stream_set = torch.npu.current_stream(device).id() == user_stream.id()
            # Outside the context the current stream is compared with the one seen before it.
            is_stream_reset = torch.npu.current_stream(device).id() == current_stream.id()
            tensor1 = torch.rand(10000, 10000, device="npu")
            tensor2 = torch.mm(tensor1, tensor1).to("npu")
            default_stream.synchronize()
            default_stream_query = default_stream.query()
            res = Result(
                tensor1, tensor2, is_current_and_default_stream_same,
                is_default_and_user_stream_not_same, is_stream_set,
                is_stream_reset, default_stream_query, default_stream.id(), user_stream.id())
            return res

        result = test_get_stream()
        self.assertEqual(torch.matmul(result.t1, result.t1), result.t2)
        self.assertTrue(result.is_current_and_default_stream_same)
        self.assertTrue(result.is_default_and_user_stream_not_same)
        self.assertTrue(result.is_stream_set)
        self.assertTrue(result.is_stream_reset)
        self.assertTrue(result.default_stream_query)
        # The default stream is expected to have id 0, a user stream a non-zero id.
        self.assertEqual(result.default_stream_id, 0)
        self.assertNotEqual(result.user_stream_id, 0)

        # Work issued inside `with torch.npu.stream(user_stream)`.
        @torch.jit.script
        def test_stream_context():
            device_index = torch.npu.current_device()
            device = torch.device("npu:" + str(device_index))
            current_stream = torch.npu.current_stream(device)
            user_stream = torch.npu.Stream()
            A = torch.rand(1000, 1000, device="npu")
            with torch.npu.stream(user_stream):
                check = torch.npu.current_stream(device).id() == user_stream.id()
                B = torch.mm(A, A).to("npu")
            # Synchronize the user stream before B is returned.
            user_stream.synchronize()
            is_stream_reset = torch.npu.current_stream(device).id() == current_stream.id()
            return A, B, check, is_stream_reset

        A, B, is_stream_set, is_stream_reset = test_stream_context()
        self.assertEqual(torch.matmul(A, A), B)
        self.assertTrue(is_stream_set, "Error: Current stream was not set to user stream!")
        self.assertTrue(is_stream_reset, "Error: The stream was not restored to previous stream!")

        # Nested stream contexts on two different devices.
        @torch.jit.script
        def test_multiple_stream():
            prev_device_index = torch.npu.current_device()
            device = torch.device("npu:" + str(prev_device_index))
            prev_current_stream = torch.npu.current_stream(device)
            d1 = torch.device("npu:0")
            d2 = torch.device("npu:1")
            s1 = torch.npu.Stream(d1, 0)
            s2 = torch.npu.Stream(d2, 0)
            A = torch.rand(1000, 1000, device="npu")
            B = torch.rand(1000, 1000, device="npu")
            with torch.npu.stream(s1):
                C = torch.mm(A, A).to("npu")
                is_stream_s1 = torch.npu.current_stream(d1).id() == s1.id()
                is_device_s1 = torch.npu.current_device() == s1.device_index()
                with torch.npu.stream(s2):
                    is_stream_s2 = torch.npu.current_stream(d2).id() == s2.id()
                    is_device_s2 = torch.npu.current_device() == s2.device_index()
                    D = torch.mm(B, B).to("npu")
                # Back in the outer context: stream and device are compared with s1 again.
                is_stream_s1_after = torch.npu.current_stream(d1).id() == s1.id()
                is_device_s1_after = torch.npu.current_device() == s1.device_index()
                s2.synchronize()
            s1.synchronize()
            # Outside both contexts: compare with the device/stream seen at entry.
            is_device_current = torch.npu.current_device() == prev_device_index
            is_stream_current = torch.npu.current_stream(device).id() == prev_current_stream.id()
            check_stream = is_stream_s1 and is_stream_s2 and is_stream_s1_after and is_stream_current
            check_device = is_device_s1 and is_device_s2 and is_device_s1_after and is_device_current
            return A, B, C, D, check_stream, check_device

        A, B, C, D, check_stream, check_device = test_multiple_stream()
        self.assertEqual(torch.matmul(A, A), C)
        self.assertEqual(torch.matmul(B, B), D)
        self.assertTrue(check_stream)
        self.assertTrue(check_device)

        # s2 consumes B, which is produced on s1; an event orders the two.
        @torch.jit.script
        def test_data_dependency_between_streams():
            device_index = torch.npu.current_device()
            device = torch.device("npu:" + str(device_index))
            prev_current_stream = torch.npu.current_stream(device)
            d = torch.device("npu:0")
            s1 = torch.npu.Stream(d, 0)
            s2 = torch.npu.Stream(d, 0)
            event = torch.npu.Event(False, False, False)
            A = torch.rand(1000, 1000, device="npu")
            with torch.npu.stream(s1):
                is_stream_s1 = torch.npu.current_stream(device).id() == s1.id()
                B = torch.mm(A, A).to("npu")
            s1.record_event(event)
            is_current_stream_1 = torch.npu.current_stream(device).id() == prev_current_stream.id()
            # s2 waits on the event recorded on s1 before its own work is issued.
            s2.wait_event(event)
            with torch.npu.stream(s2):
                is_stream_s2 = torch.npu.current_stream(device).id() == s2.id()
                C = torch.mm(B, B).to("npu")
            s2.synchronize()
            is_current_stream_2 = torch.npu.current_stream(device).id() == prev_current_stream.id()
            check_stream = is_current_stream_1 and is_current_stream_2 and is_stream_s1 and is_stream_s2
            return A, B, C, check_stream

        A, B, C, check_stream = test_data_dependency_between_streams()
        self.assertEqual(torch.matmul(A, A), B)
        self.assertEqual(torch.matmul(B, B), C)
        self.assertTrue(check_stream)

        # An event can be constructed with explicit positional arguments.
        @torch.jit.script
        def test_simple_event():
            e = torch.npu.Event(True, False, False)
            return e is not None

        self.assertTrue(test_simple_event(), "Could not create NPU Event!")

        # NOTE(review): in the helpers below the results of the
        # 1000000000 x 1000000000 rand/mm calls are never read afterwards;
        # presumably the test relies on TorchScript eliminating these unused
        # ops -- confirm before changing the sizes or using the results.

        # Event query / record / synchronize and elapsed time on the current stream.
        @torch.jit.script
        def test_event():
            device_index = torch.npu.current_device()
            device = torch.device("npu:" + str(device_index))
            stream = torch.npu.current_stream(device)
            event = torch.npu.Event(True, False, False)
            is_true_event_query = event.query()
            start_event = torch.npu.Event(True, False, False)
            stream.record_event(start_event)
            tensor1 = torch.rand(1000000000, 1000000000, device="npu")
            tensor2 = torch.mm(tensor1, tensor1).to("npu")
            stream.record_event(event)
            event.synchronize()
            is_again_true_event_query = event.query()
            # -1.0 signals a failed query so that the assertGreater below fails.
            if not (is_true_event_query and is_again_true_event_query):
                return -1.0
            return start_event.elapsed_time(event)

        self.assertGreater(test_event(), 0)

        # Stream.synchronize followed by an elapsed-time measurement.
        @torch.jit.script
        def test_stream_synchronize() -> float:
            device_index = torch.npu.current_device()
            s = torch.npu.Stream()
            e_tik = torch.npu.Event(True, False, False)
            e_tok = torch.npu.Event(True, False, False)
            e_tik.record(s)
            tensor1 = torch.rand(1000000000, 1000000000, device="npu")
            with torch.npu.stream(s):
                tensor2 = torch.mm(tensor1, tensor1).to("npu")
            s.synchronize()
            e_tok.record(s)
            e_tok.synchronize()
            if not s.query():
                return -1.0
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_stream_synchronize(), 0)

        # Event.synchronize followed by an elapsed-time measurement.
        @torch.jit.script
        def test_event_synchronize() -> float:
            s = torch.npu.Stream()
            e_tik = torch.npu.Event(True, False, False)
            e_tok = torch.npu.Event(True, False, False)
            e_tik.record(s)
            tensor1 = torch.rand(1000000000, 1000000000, device="npu")
            with torch.npu.stream(s):
                tensor = torch.mm(tensor1, tensor1).to("npu")
            s.record_event(e_tok)
            e_tok.synchronize()
            s.synchronize()
            if not s.query():
                return -1.0
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_event_synchronize(), 0)

        # Event.wait: s1 waits on an event recorded on the current stream.
        @torch.jit.script
        def test_event_wait() -> float:
            device_index = torch.npu.current_device()
            device = torch.device("npu:" + str(device_index))
            s0 = torch.npu.current_stream(device)
            s1 = torch.npu.Stream()
            e_tik = torch.npu.Event(True, True, False)
            e_tok = torch.npu.Event(True, True, False)
            e_tik.record(s0)
            tensor1 = torch.rand(1000000000, 1000000000, device="npu")
            with torch.npu.stream(s0):
                tensor2 = torch.mm(tensor1, tensor1).npu()
            e_sync = torch.npu.Event(True, False, False)
            e_sync.record(torch.npu.current_stream(device))
            e_sync.wait(s1)
            with torch.npu.stream(s1):
                tensor3 = torch.rand(1000000000, 1000000000, device="npu")
                tensor4 = torch.mm(tensor3, tensor3).npu()
            s1.synchronize()
            e_tok.record(torch.npu.current_stream(device))
            e_tok.synchronize()
            s0.synchronize()
            if not s0.query() or not s1.query() or not e_sync.query():
                return -1.0
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_event_wait(), 0)

        # Stream.wait_event across devices: the event is recorded on device 1's
        # current stream and waited on by device 0's current stream.
        @torch.jit.script
        def test_wait_event():
            d1 = torch.device('npu:1')
            with torch.npu.device(d1):
                s0 = torch.npu.current_stream(d1)
                tensor1 = torch.rand(1000000000, 1000000000, device="npu")
                tensor2 = torch.mm(tensor1, tensor1).to("npu")
                e0 = torch.npu.Event(False, False, False)
                s0.record_event(e0)
            s1 = torch.npu.current_stream(torch.device('npu:0'))
            s1.wait_event(e0)
            s1.synchronize()
            return e0.query() and s0.query() and s1.query()

        self.assertTrue(test_wait_event())

    def test_save_load(self):
        """Script a module that uses an NPU stream, run it, then run an export/import copy."""
        class Model(torch.nn.Module):
            def forward(self):
                s = torch.npu.Stream()
                a = torch.rand(3, 4, device="npu")
                b = torch.rand(3, 4, device="npu")
                with torch.npu.stream(s):
                    is_stream_s = torch.npu.current_stream(s.device).id() == s.id()
                    c = torch.cat((a, b), 0).npu()
                s.synchronize()
                return is_stream_s, a, b, c

        model = Model()
        script_model = torch.jit.script(model)
        is_stream_s, a, b, c = script_model()
        self.assertTrue(is_stream_s)
        self.assertEqual(torch.cat((a, b), 0), c)

        # Round-trip the scripted module and repeat the same checks on the copy.
        load_model = self.getExportImportCopy(script_model)
        is_stream_s, a_load, b_load, c_load = load_model()
        self.assertTrue(is_stream_s)
        self.assertEqual(torch.cat((a_load, b_load), 0), c_load)

    @unittest.skipIf(not TEST_CUDA, "Cuda not available")
    def test__exchange_device_op(self):
        """Check that `npu::_exchange_device(` is in the scripted graph before and after inlining."""
        def fn(device: int, tensor):
            torch.npu._exchange_device(device)
            return tensor.cos().relu()

        fn_s = torch.jit.script(fn)
        # Only the graph is inspected; the scripted function is not executed.
        g = fn_s.graph
        FileCheck().check("npu::_exchange_device(").run(g)
        torch._C._jit_pass_inline(g)
        FileCheck().check("npu::_exchange_device(").run(g)

    @unittest.skipIf(not TEST_CUDA, "Cuda not available")
    def test__maybe_exchange_device_op(self):
        """Check that `npu::_maybe_exchange_device(` is in the scripted graph before and after inlining."""
        def fn(device: int, tensor):
            torch.npu._maybe_exchange_device(device)
            return tensor.cos().relu()

        fn_s = torch.jit.script(fn)
        # Only the graph is inspected; the scripted function is not executed.
        g = fn_s.graph
        FileCheck().check("npu::_maybe_exchange_device(").run(g)
        torch._C._jit_pass_inline(g)
        FileCheck().check("npu::_maybe_exchange_device(").run(g)