import collections
import contextlib
import ctypes
import io
import gc
import queue
import sys
import tempfile
import threading
import unittest
from itertools import repeat, chain
from typing import NamedTuple
import torch
import torch.cuda.comm as comm
from torch.nn.parallel import scatter_gather
import torch_npu
import torch_npu.testing
from torch_npu.testing.common_utils import get_cycles_per_ms
from torch.testing._internal.common_utils import (
IS_JETSON,
IS_REMOTE_GPU,
IS_SANDCASTLE,
NoTest,
TEST_PRIVATEUSE1,
TestCase,
instantiate_parametrized_tests,
run_tests,
skipCUDANonDefaultStreamIf,
skipIfRocm,
)
from torch.testing._internal.common_cuda import _create_scaling_case, _create_scaling_models_optimizers
TEST_MULTINPU = TEST_PRIVATEUSE1 and torch_npu.npu.device_count() >= 2
if not TEST_PRIVATEUSE1:
print('NPU not available, skipping tests', file=sys.stderr)
TestCase = NoTest
class TestNpuMultiNpu(TestCase):
FIFTY_MIL_CYCLES = 50000000
def _check_memory_stat_consistency(self):
snapshot = torch_npu.npu.memory_snapshot()
expected_each_device = collections.defaultdict(lambda: collections.defaultdict(int))
for segment in snapshot:
expandable = segment["is_expandable"]
expected = expected_each_device[segment["device"]]
pool_str = segment["segment_type"] + "_pool"
if not expandable:
expected["segment.all.current"] += 1
expected["segment." + pool_str + ".current"] += 1
expected["allocated_bytes.all.current"] += segment["allocated_size"]
expected["allocated_bytes." + pool_str + ".current"] += segment["allocated_size"]
expected["reserved_bytes.all.current"] += segment["total_size"]
expected["reserved_bytes." + pool_str + ".current"] += segment["total_size"]
expected["active_bytes.all.current"] += segment["active_size"]
expected["active_bytes." + pool_str + ".current"] += segment["active_size"]
expected["requested_bytes.all.current"] += segment["requested_size"]
expected["requested_bytes." + pool_str + ".current"] += segment["requested_size"]
sum_requested = 0
is_split = len(segment["blocks"]) > 1
for block in segment["blocks"]:
if block["state"] == "active_allocated":
expected["allocation.all.current"] += 1
expected["allocation." + pool_str + ".current"] += 1
if block["state"].startswith("active_"):
sum_requested += block["requested_size"]
expected["active.all.current"] += 1
expected["active." + pool_str + ".current"] += 1
if block["state"] == "inactive" and is_split and not expandable:
expected["inactive_split.all.current"] += 1
expected["inactive_split." + pool_str + ".current"] += 1
expected["inactive_split_bytes.all.current"] += block["size"]
expected["inactive_split_bytes." + pool_str + ".current"] += block["size"]
self.assertEqual(sum_requested, segment["requested_size"])
for device, expected in expected_each_device.items():
stats = torch_npu.npu.memory_stats(device)
for k, v in expected.items():
self.assertEqual(v, stats[k])
def test_cuda_synchronize(self):
torch_npu.npu.synchronize()
torch_npu.npu.synchronize('npu')
torch_npu.npu.synchronize('npu:0')
torch_npu.npu.synchronize(0)
torch_npu.npu.synchronize(torch.device('npu:0'))
if TEST_MULTINPU:
torch_npu.npu.synchronize('npu:1')
torch_npu.npu.synchronize(1)
torch_npu.npu.synchronize(torch.device('npu:1'))
with self.assertRaisesRegex(ValueError, "Expected a npu device, but"):
torch_npu.npu.synchronize(torch.device("cpu"))
with self.assertRaisesRegex(ValueError, "Expected a npu device, but"):
torch_npu.npu.synchronize("cpu")
@staticmethod
def _test_memory_stats_generator(self, device=None, N=35):
if device is None:
device = torch_npu.npu.current_device()
m0 = torch_npu.npu.memory_allocated(device)
last_m_arr = [torch_npu.npu.memory_allocated(device)]
max_m_arr = [torch_npu.npu.max_memory_allocated(device)]
last_r_arr = [torch_npu.npu.memory_reserved(device)]
max_r_arr = [torch_npu.npu.max_memory_reserved(device)]
def alloc(*size):
with torch_npu.npu.device(device):
return torch_npu.npu.FloatTensor(*size)
def assert_change(comp=1, empty_cache=False, reset_peak=False):
new_m = torch_npu.npu.memory_allocated(device)
new_max_m = torch_npu.npu.max_memory_allocated(device)
if comp > 0:
self.assertGreater(new_m, last_m_arr[0])
elif comp < 0:
self.assertLess(new_m, last_m_arr[0])
else:
self.assertEqual(new_m, last_m_arr[0])
self.assertLessEqual(new_m, new_max_m)
self.assertGreaterEqual(new_max_m, max_m_arr[0])
last_m_arr[0] = new_m
max_m_arr[0] = new_max_m
new_r = torch_npu.npu.memory_reserved(device)
new_max_r = torch_npu.npu.max_memory_reserved(device)
self.assertLessEqual(new_r, new_max_r)
self.assertGreaterEqual(new_max_r, max_r_arr[0])
last_r_arr[0] = new_r
max_r_arr[0] = new_max_r
if empty_cache:
torch_npu.npu.empty_cache()
new_r = torch_npu.npu.memory_reserved(device)
new_max_r = torch_npu.npu.max_memory_reserved(device)
self.assertLessEqual(new_r, last_r_arr[0])
self.assertLessEqual(new_r, new_max_r)
self.assertEqual(new_max_r, max_r_arr[0])
last_r_arr[0] = new_r
if reset_peak:
torch_npu.npu.reset_peak_memory_stats(device)
self.assertEqual(torch_npu.npu.memory_allocated(device), last_m_arr[0])
self.assertEqual(torch_npu.npu.max_memory_allocated(device), last_m_arr[0])
max_m_arr[0] = last_m_arr[0]
self.assertEqual(torch_npu.npu.memory_reserved(device), last_r_arr[0])
self.assertEqual(torch_npu.npu.max_memory_reserved(device), last_r_arr[0])
max_r_arr[0] = last_r_arr[0]
assert_change(0)
assert_change(0, reset_peak=True)
assert_change(0, empty_cache=True)
assert_change(0, reset_peak=True)
assert_change(0)
yield
tensors1 = [alloc(1), alloc(10, 20), alloc(200, 300, 2000)]
m1 = torch_npu.npu.memory_allocated(device)
assert_change(1)
yield
tensors2 = []
for i in range(1, int(N / 2) + 1):
tensors2.append(alloc(i, i * 4))
assert_change(1)
yield
for i in range(5, int(N / 2) + 5):
tensors2.append(alloc(i, i * 7, i * 9, i * 11))
assert_change(1, reset_peak=(i % 2 == 0))
yield
tensors2.append(alloc(0, 0, 0))
assert_change(0)
yield
permute = []
for i in torch.randperm(len(tensors2)):
permute.append(tensors2[i])
assert_change(0)
yield
del tensors2
assert_change(0)
yield
tensors2 = permute
assert_change(0)
yield
del permute
assert_change(0, reset_peak=True)
yield
for i in range(int(N / 2)):
x = tensors2[i].numel()
del tensors2[i]
assert_change(-x)
yield
for i in range(2, int(2 * N / 3) + 2):
tensors2.append(alloc(i, i * 3, i * 8))
assert_change(1)
yield
del tensors2
assert_change(-1, reset_peak=True)
assert_change(0)
self.assertEqual(torch_npu.npu.memory_allocated(device), m1)
yield True
del tensors1
assert_change(-1, reset_peak=True)
self.assertEqual(torch_npu.npu.memory_allocated(device), m0)
assert_change(0, empty_cache=True)
assert_change(0, reset_peak=True)
def test_memory_stats(self):
gc.collect()
torch_npu.npu.empty_cache()
for _ in self._test_memory_stats_generator(self):
self._check_memory_stat_consistency()
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_memory_stats_multigpu(self):
def advance(gen, end):
if not end:
try:
next(gen)
except StopIteration:
end = True
return end
torch_npu.npu.empty_cache()
gen0 = self._test_memory_stats_generator(self, device='npu:0', N=35)
gen1 = self._test_memory_stats_generator(self, device=torch.device('npu:1'), N=35)
end0 = end1 = False
while not (end0 and end1):
end0 = advance(gen0, end0)
end1 = advance(gen1, end1)
torch_npu.npu.empty_cache()
gen0 = self._test_memory_stats_generator(self, device=0, N=35)
gen1 = self._test_memory_stats_generator(self, device=torch.device('npu:1'), N=35)
end0 = end1 = False
while not (end0 and end1):
end0 = advance(gen0, end0)
if not end0:
gen1_max_times = torch.LongTensor(1).random_(0, 3)[0]
else:
gen1_max_times = torch.inf
t = 0
while t < gen1_max_times and not end1:
end1 = advance(gen1, end1)
t += 1
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_autogpu(self):
x = torch.randn(5, 5).npu()
y = torch.randn(5, 5).npu()
self.assertEqual(x.get_device(), 0)
self.assertEqual(x.get_device(), 0)
with torch_npu.npu.device(1):
z = torch.randn(5, 5).npu()
self.assertEqual(z.get_device(), 1)
q = x.add(y)
self.assertEqual(q.get_device(), 0)
w = torch.randn(5, 5).npu()
self.assertEqual(w.get_device(), 1)
self.assertEqual(y.npu().get_device(), 1)
z = z.npu()
self.assertEqual(z.get_device(), 0)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_new(self):
x = torch.randn(3, 3).npu()
self.assertEqual(x.new([0, 1, 2]).get_device(), 0)
self.assertEqual(x.new([0, 1, 2], device=1).get_device(), 1)
with torch_npu.npu.device(1):
self.assertEqual(x.new([0, 1, 2]).get_device(), 0)
self.assertEqual(x.new([0, 1, 2], device=1).get_device(), 1)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_copy_device(self):
x = torch.randn(5, 5).npu()
with torch_npu.npu.device(1):
y = x.npu()
self.assertEqual(y.get_device(), 1)
self.assertIs(y.npu(), y)
z = y.npu(0)
self.assertEqual(z.get_device(), 0)
self.assertIs(z.npu(0), z)
x = torch.randn(5, 5)
with torch_npu.npu.device(1):
y = x.npu()
self.assertEqual(y.get_device(), 1)
self.assertIs(y.npu(), y)
z = y.npu(0)
self.assertEqual(z.get_device(), 0)
self.assertIs(z.npu(0), z)
def _test_copy_sync_current_stream(self, x, y):
x_plus_one = x + 1
s0 = torch_npu.npu.Stream(device=x.device)
s1 = torch_npu.npu.Stream(device=y.device)
s2 = torch_npu.npu.Stream(device=x.device)
s3 = torch_npu.npu.Stream(device=y.device)
with torch_npu.npu.stream(s0):
torch_npu.npu._sleep(TestNpuMultiNpu.FIFTY_MIL_CYCLES)
with torch_npu.npu.stream(s1):
y.copy_(x_plus_one)
with torch_npu.npu.stream(s2), torch_npu.npu.stream(s1):
y.copy_(x)
s1.synchronize()
self.assertEqual(y, x)
with torch_npu.npu.stream(s1):
torch_npu.npu._sleep(TestNpuMultiNpu.FIFTY_MIL_CYCLES)
with torch_npu.npu.stream(s0):
y.copy_(x_plus_one)
with torch_npu.npu.stream(s3), torch_npu.npu.stream(s0):
y.copy_(x)
s0.synchronize()
self.assertEqual(y, x)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_copy_streams(self):
d0 = torch.device('npu:0')
x0 = torch.zeros(5, 5, device=d0)
d1 = torch.device('npu:1')
x1 = torch.zeros(5, 5, device=d1)
self._test_copy_sync_current_stream(x0, x1)
x2 = torch.zeros(5, 5, device=d0)
self._test_copy_sync_current_stream(x0, x2)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_cat_autogpu(self):
x = torch.randn(4, 4).npu(1)
y = torch.randn(4, 4).npu(1)
z = torch.cat([x, y], 0)
self.assertEqual(z.get_device(), x.get_device())
@unittest.skipIf(torch_npu.npu.device_count() >= 10, "Loading a npu:9 tensor")
def test_load_nonexistent_device(self):
tensor = torch.randn(2, device='npu')
buf = io.BytesIO()
torch.save(tensor, buf)
buf = io.BytesIO(buf.getvalue().replace(b'npu:0', b'npu:9'))
msg = r'Attempting to deserialize object on NPU device 9'
with self.assertRaisesRegex(RuntimeError, msg):
_ = torch.load(buf, weights_only=False)
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_multigpu_serialization_remap(self):
x = [torch.randn(4, 4).npu(0), torch.randn(4, 4).npu(1)]
def gpu_remap(storage, location):
if location == 'npu:1':
return storage.npu(0)
with tempfile.NamedTemporaryFile() as f:
torch.save(x, f)
f.seek(0)
x_copy = torch.load(f, map_location=gpu_remap)
for original, copy in zip(x, x_copy):
self.assertEqual(copy, original)
self.assertIs(type(copy), type(original))
self.assertEqual(copy.get_device(), 0)
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_multigpu_serialization_remap_dict(self):
x = [torch.randn(4, 4).npu(0), torch.randn(4, 4).npu(1)]
with tempfile.NamedTemporaryFile() as f:
torch.save(x, f)
f.seek(0)
x_copy = torch.load(f, map_location={'npu:1': 'npu:0'})
for original, copy in zip(x, x_copy):
self.assertEqual(copy, original)
self.assertIs(type(copy), type(original))
self.assertEqual(copy.get_device(), 0)
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_multigpu_storage_clone(self):
x = torch.randn(4, 4, device='npu:1').storage()
y = x.clone()
self.assertEqual(x.get_device(), y.get_device())
for t in ['byte', 'char', 'short', 'int', 'long', 'half', 'double']:
self.assertEqual(getattr(x, t)().get_device(), x.get_device())
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_cuda_set_device(self):
x = torch.randn(5, 5)
with torch_npu.npu.device(1):
self.assertEqual(x.npu().get_device(), 1)
torch_npu.npu.set_device(0)
self.assertEqual(x.npu().get_device(), 0)
with torch_npu.npu.device(1):
self.assertEqual(x.npu().get_device(), 1)
self.assertEqual(x.npu().get_device(), 0)
torch_npu.npu.set_device(1)
self.assertEqual(x.npu().get_device(), 0)
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_current_stream(self):
d0 = torch.device('npu:0')
d1 = torch.device('npu:1')
s0 = torch_npu.npu.current_stream()
s1 = torch_npu.npu.current_stream(device=1)
s2 = torch_npu.npu.current_stream(device=0)
self.assertEqual(d0, s0.device)
self.assertEqual(d1, s1.device)
self.assertEqual(d0, s2.device)
self.assertEqual(s0, s2)
with torch_npu.npu.device(d1):
s0 = torch_npu.npu.current_stream()
s1 = torch_npu.npu.current_stream(1)
s2 = torch_npu.npu.current_stream(d0)
self.assertEqual(d1, s0.device)
self.assertEqual(d1, s1.device)
self.assertEqual(d0, s2.device)
self.assertEqual(s0, s1)
with self.assertRaisesRegex(ValueError,
"Expected a npu device, but got: cpu"):
torch_npu.npu.current_stream(torch.device('cpu'))
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
@skipCUDANonDefaultStreamIf(True)
def test_default_stream(self):
d0 = torch.device('npu:0')
d1 = torch.device('npu:1')
with torch_npu.npu.device(d0):
s0 = torch_npu.npu.default_stream()
with torch_npu.npu.device(d1):
s1 = torch_npu.npu.default_stream()
s2 = torch_npu.npu.default_stream(device=0)
s3 = torch_npu.npu.default_stream(d1)
self.assertEqual(d0, s0.device)
self.assertEqual(d1, s1.device)
self.assertEqual(d0, s2.device)
self.assertEqual(d1, s3.device)
self.assertEqual(s0, s2)
self.assertEqual(s1, s3)
with torch_npu.npu.device(d0):
self.assertEqual(torch_npu.npu.current_stream(), s0)
with torch_npu.npu.device(d1):
self.assertEqual(torch_npu.npu.current_stream(), s1)
with self.assertRaisesRegex(ValueError,
"Expected a npu device, but got: cpu"):
torch_npu.npu.default_stream(torch.device('cpu'))
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_stream_event_device(self):
d0 = torch.device('npu:0')
d1 = torch.device('npu:1')
e0 = torch_npu.npu.Event()
self.assertEqual(None, e0.device)
with torch_npu.npu.device(d0):
s0 = torch_npu.npu.current_stream()
s0.record_event(e0)
with torch_npu.npu.device(d1):
s1 = torch_npu.npu.Stream()
e1 = s1.record_event()
self.assertEqual(s0.device, torch.device('npu:0'))
self.assertEqual(e0.device, torch.device('npu:0'))
self.assertEqual(s1.device, torch.device('npu:1'))
self.assertEqual(e1.device, torch.device('npu:1'))
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_stream_context(self):
s0 = torch_npu.npu.current_stream()
s1 = torch_npu.npu.Stream(device=1)
s2 = torch_npu.npu.Stream(device=0)
with torch_npu.npu.device(s1.device):
prev_stream_on_cuda1 = torch_npu.npu.current_stream()
self.assertEqual(torch_npu.npu.current_stream(), s0)
self.assertEqual(0, torch_npu.npu.current_device())
with torch_npu.npu.stream(s1):
self.assertEqual(torch_npu.npu.current_stream(), s1)
self.assertEqual(1, torch_npu.npu.current_device())
with torch_npu.npu.stream(s2):
self.assertEqual(torch_npu.npu.current_stream(), s2)
self.assertEqual(0, torch_npu.npu.current_device())
with torch_npu.npu.stream(s0):
self.assertEqual(torch_npu.npu.current_stream(), s0)
self.assertEqual(0, torch_npu.npu.current_device())
self.assertEqual(torch_npu.npu.current_stream(), s2)
self.assertEqual(0, torch_npu.npu.current_device())
self.assertEqual(torch_npu.npu.current_stream(), s1)
self.assertEqual(1, torch_npu.npu.current_device())
with torch_npu.npu.device(s1.device):
self.assertEqual(prev_stream_on_cuda1, torch_npu.npu.current_stream())
self.assertEqual(torch_npu.npu.current_stream(), s0)
self.assertEqual(0, torch_npu.npu.current_device())
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_streams_multi_gpu(self):
default_stream = torch_npu.npu.current_stream()
self.assertEqual(default_stream.device, torch.device('npu:0'))
stream = torch_npu.npu.Stream(device=1)
self.assertEqual(stream.device, torch.device('npu:1'))
with torch_npu.npu.device(1):
self.assertEqual(
torch_npu.npu.current_stream().device, torch.device('npu:1'))
self.assertNotEqual(torch_npu.npu.current_stream(), default_stream)
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_streams_multi_gpu_query(self):
d0 = torch.device('npu:0')
d1 = torch.device('npu:1')
torch_npu.npu.synchronize(d0)
torch_npu.npu.synchronize(d1)
with torch_npu.npu.device(d0):
s0 = torch_npu.npu.current_stream()
with torch_npu.npu.device(d1):
s1 = torch_npu.npu.current_stream()
torch_npu.npu._sleep(TestNpuMultiNpu.FIFTY_MIL_CYCLES)
self.assertTrue(s0.query())
self.assertFalse(s1.query())
with torch_npu.npu.device(d0):
self.assertTrue(s0.query())
self.assertFalse(s1.query())
with torch_npu.npu.device(d1):
self.assertTrue(s0.query())
self.assertFalse(s1.query())
with torch_npu.npu.device(d0):
s1.synchronize()
self.assertTrue(s0.query())
self.assertTrue(s1.query())
with torch_npu.npu.device(d0):
self.assertTrue(s0.query())
self.assertTrue(s1.query())
with torch_npu.npu.device(d1):
self.assertTrue(s0.query())
self.assertTrue(s1.query())
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_streams_multi_gpu_eq(self):
d0 = torch.device('npu:0')
d1 = torch.device('npu:1')
with torch_npu.npu.device(d0):
s0 = torch_npu.npu.current_stream()
s1 = torch_npu.npu.current_stream()
with torch_npu.npu.device(d1):
s2 = torch_npu.npu.current_stream()
s3 = torch_npu.npu.current_stream()
self.assertTrue(s0 == s0)
self.assertTrue(s0 == s1)
self.assertTrue(s2 == s2)
self.assertTrue(s2 == s3)
self.assertFalse(s0 == s2)
self.assertFalse(s1 == s3)
self.assertEqual(s0.device, s1.device)
self.assertEqual(s0.cuda_stream, s1.cuda_stream)
self.assertEqual(s2.device, s3.device)
self.assertEqual(s2.cuda_stream, s3.cuda_stream)
self.assertNotEqual(s0.device, s3.device)
self.assertEqual(hash(s0), hash(s1))
self.assertEqual(hash(s2), hash(s3))
self.assertNotEqual(hash(s0), hash(s3))
@unittest.skipIf(not TEST_MULTINPU, "multi-NPU not supported")
def test_streams_priority(self):
low, high = torch_npu.npu.Stream.priority_range()
s0 = torch_npu.npu.Stream(device=0, priority=low)
self.assertEqual(low, s0.priority)
self.assertEqual(torch.device('npu:0'), s0.device)
s1 = torch_npu.npu.Stream(device=1, priority=high)
self.assertEqual(high, s1.priority)
self.assertEqual(torch.device('npu:1'), s1.device)
@unittest.skipIf(not TEST_MULTINPU, "multi-NPU not supported")
def test_tensor_device(self):
self.assertEqual(torch_npu.npu.FloatTensor(1).get_device(), 0)
self.assertEqual(torch_npu.npu.FloatTensor(1, device=1).get_device(), 1)
with torch_npu.npu.device(1):
self.assertEqual(torch_npu.npu.FloatTensor(1).get_device(), 1)
self.assertEqual(torch_npu.npu.FloatTensor(1, device=0).get_device(), 0)
self.assertEqual(torch_npu.npu.FloatTensor(1, device=None).get_device(), 1)
@staticmethod
def _stream_synchronize(self, spin_time_cycles):
s = torch_npu.npu.current_stream()
e_tik = torch_npu.npu.Event(enable_timing=True)
e_tok = torch_npu.npu.Event(enable_timing=True)
e_tik.record(s)
torch_npu.npu._sleep(spin_time_cycles)
e_tok.record(s)
s.synchronize()
self.assertTrue(s.query())
return e_tik.elapsed_time(e_tok)
@staticmethod
def _event_synchronize(self, spin_time_cycles):
s = torch_npu.npu.current_stream()
e_tik = torch_npu.npu.Event(enable_timing=True)
e_tok = torch_npu.npu.Event(enable_timing=True)
e_tik.record(s)
torch_npu.npu._sleep(spin_time_cycles)
s.record_event(e_tok)
e_tok.synchronize()
self.assertTrue(s.query())
return e_tik.elapsed_time(e_tok)
@staticmethod
def _event_wait(self, spin_time_cycles):
s0 = torch_npu.npu.current_stream()
s1 = torch_npu.npu.Stream()
e_tik = torch_npu.npu.Event(blocking=True, enable_timing=True)
e_tok = torch_npu.npu.Event(blocking=True, enable_timing=True)
e_tik.record(s0)
torch_npu.npu._sleep(spin_time_cycles - 10)
e_sync = torch_npu.npu.Event(blocking=True)
e_sync.record()
e_sync.wait(s1)
with torch_npu.npu.stream(s1):
torch_npu.npu._sleep(10)
s1.synchronize()
e_tok.record()
e_tok.synchronize()
self.assertTrue(s0.query())
self.assertTrue(s1.query())
self.assertTrue(e_sync.query())
return e_tik.elapsed_time(e_tok)
@staticmethod
def _test_stream_event_nogil(self, sync_func, p2c, c2p):
with torch_npu.npu.device('npu:1'):
c2p.put(0)
p2c.get()
c2p.put(sync_func(self, TestNpuMultiNpu.FIFTY_MIL_CYCLES))
@skipIfRocm
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_stream_event_nogil(self):
for sync_func in [TestNpuMultiNpu._stream_synchronize,
TestNpuMultiNpu._event_synchronize,
TestNpuMultiNpu._event_wait]:
p2c = queue.Queue()
c2p = queue.Queue()
e_tik = torch_npu.npu.Event(enable_timing=True)
e_tok = torch_npu.npu.Event(enable_timing=True)
t = threading.Thread(
target=TestNpuMultiNpu._test_stream_event_nogil,
args=(self, sync_func, p2c, c2p))
t.daemon = True
t.start()
c2p.get()
with torch_npu.npu.device('npu:0'):
e_tik.record()
p2c.put(0)
parent_time = sync_func(self, TestNpuMultiNpu.FIFTY_MIL_CYCLES)
child_time = c2p.get()
e_tok.record()
e_tok.synchronize()
total_time = e_tik.elapsed_time(e_tok)
self.assertGreater(parent_time + child_time, total_time * 1.4)
@skipIfRocm
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_events_wait(self):
d0 = torch.device('npu:0')
d1 = torch.device('npu:1')
torch_npu.npu.synchronize(d0)
torch_npu.npu.synchronize(d1)
with torch_npu.npu.device(d0):
s0 = torch_npu.npu.current_stream()
torch_npu.npu._sleep(TestNpuMultiNpu.FIFTY_MIL_CYCLES)
e0 = torch_npu.npu.Event()
s0.record_event(e0)
with torch_npu.npu.device(d1):
s1 = torch_npu.npu.current_stream()
self.assertFalse(s0.query())
self.assertTrue(s1.query())
s1.wait_event(e0)
s1.synchronize()
self.assertTrue(e0.query())
self.assertTrue(s0.query())
self.assertTrue(s1.query())
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_events_multi_gpu_query(self):
d0 = torch.device('npu:0')
d1 = torch.device('npu:1')
with torch_npu.npu.device(d0):
s0 = torch_npu.npu.current_stream()
e0 = s0.record_event()
s0.synchronize()
with torch_npu.npu.device(d1):
s1 = torch_npu.npu.current_stream()
torch_npu.npu._sleep(TestNpuMultiNpu.FIFTY_MIL_CYCLES)
e1 = s1.record_event()
self.assertTrue(e0.query())
self.assertFalse(e1.query())
with torch_npu.npu.device(d0):
self.assertTrue(e0.query())
self.assertFalse(e1.query())
with torch_npu.npu.device(d1):
self.assertTrue(e0.query())
self.assertFalse(e1.query())
with torch_npu.npu.device(d0):
e1.synchronize()
self.assertTrue(e0.query())
self.assertTrue(e1.query())
with torch_npu.npu.device(d0):
self.assertTrue(e0.query())
self.assertTrue(e1.query())
with torch_npu.npu.device(d1):
self.assertTrue(e0.query())
self.assertTrue(e1.query())
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
@skipIfRocm
def test_events_multi_gpu_elapsed_time(self):
d0 = torch.device('npu:0')
d1 = torch.device('npu:1')
with torch_npu.npu.device(d0):
s0 = torch_npu.npu.current_stream()
e0 = torch_npu.npu.Event(enable_timing=True)
torch_npu.npu._sleep(10)
s0.record_event(e0)
with torch_npu.npu.device(d1):
s1 = torch_npu.npu.current_stream()
e1 = torch_npu.npu.Event(enable_timing=True)
torch_npu.npu._sleep(TestNpuMultiNpu.FIFTY_MIL_CYCLES)
s1.record_event(e1)
e0.synchronize()
e1.synchronize()
with torch_npu.npu.device(d0):
with self.assertRaises(RuntimeError):
self.assertGreater(e0.elapsed_time(e1), 0)
with torch_npu.npu.device(d1):
with self.assertRaises(RuntimeError):
self.assertGreater(e0.elapsed_time(e1), 0)
with torch_npu.npu.device(d0):
s0 = torch_npu.npu.current_stream()
e2 = torch_npu.npu.Event(enable_timing=True)
torch_npu.npu._sleep(TestNpuMultiNpu.FIFTY_MIL_CYCLES)
s0.record_event(e2)
s0.synchronize()
self.assertGreater(e0.elapsed_time(e2), 0)
with torch_npu.npu.device(d1):
self.assertGreater(e0.elapsed_time(e2), 0)
@contextlib.contextmanager
def _get_external_stream(self, device):
cudart = torch_npu.npu.cudart()
stream = ctypes.c_ulonglong(0)
stream_p = ctypes.POINTER(ctypes.c_void_p)(stream)
stream_p_int = ctypes.cast(stream_p, ctypes.c_void_p).value
with device:
try:
out = cudart.cudaStreamCreate(stream_p_int)
self.assertEqual(out, 0)
self.assertNotEqual(stream.value, 0)
yield stream.value
finally:
out = cudart.cudaStreamDestroy(stream.value)
self.assertEqual(out, 0)
def test_external_streams(self):
device = torch_npu.npu.device(0)
with self._get_external_stream(device) as stream_v:
ext_stream = torch_npu.npu.ExternalStream(stream_v)
self.assertEqual(stream_v, ext_stream.cuda_stream)
self.assertEqual(ext_stream.device.index, device.idx)
@unittest.skipIf(not TEST_MULTINPU, "detected only one NPU")
def test_external_streams_multi_device(self):
device = torch_npu.npu.device(1)
with self._get_external_stream(device) as stream_v:
ext_stream = torch_npu.npu.ExternalStream(
stream_v, device=device)
self.assertEqual(stream_v, ext_stream.cuda_stream)
self.assertEqual(ext_stream.device.index, device.idx)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_caching_pinned_memory_multi_gpu(self):
cycles_per_ms = get_cycles_per_ms()
t = torch.FloatTensor([1]).pin_memory()
ptr = t.data_ptr()
gpu_tensor0 = torch_npu.npu.FloatTensor([0], device=0)
gpu_tensor1 = torch_npu.npu.FloatTensor([0], device=1)
with torch_npu.npu.device(1):
torch_npu.npu._sleep(int(1000 * cycles_per_ms))
gpu_tensor1.copy_(t, non_blocking=True)
del t
t = torch.FloatTensor([2]).pin_memory()
self.assertNotEqual(t.data_ptr(), ptr, msg='allocation re-used too soon')
with torch_npu.npu.device(0):
gpu_tensor0.copy_(t, non_blocking=True)
self.assertEqual(gpu_tensor1[0], 1)
self.assertEqual(gpu_tensor0[0], 2)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_get_set_rng_state_all(self):
states = torch_npu.npu.get_rng_state_all()
before0 = torch_npu.npu.FloatTensor(100, device=0).normal_()
before1 = torch_npu.npu.FloatTensor(100, device=1).normal_()
torch_npu.npu.set_rng_state_all(states)
after0 = torch_npu.npu.FloatTensor(100, device=0).normal_()
after1 = torch_npu.npu.FloatTensor(100, device=1).normal_()
self.assertEqual(before0, after0, atol=0, rtol=0)
self.assertEqual(before1, after1, atol=0, rtol=0)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_rng_state_offset(self):
before = torch_npu.npu.get_rng_state()
torch_npu.npu._set_rng_state_offset(100)
offset = torch_npu.npu._get_rng_state_offset()
torch_npu.npu.set_rng_state(before)
self.assertEqual(offset, 100)
def test_mem_get_info(self):
def _test(idx):
before_free_bytes, before_available_bytes = torch_npu.npu.mem_get_info(idx)
t = torch.randn(1024 * 1024 * 8, device='npu:' + str(idx))
if IS_JETSON:
torch_npu.npu.synchronize()
after_free_bytes, after_available_bytes = torch_npu.npu.mem_get_info(idx)
self.assertLess(after_free_bytes, before_free_bytes)
self.assertEqual(before_available_bytes, after_available_bytes)
_test(0)
if TEST_MULTINPU:
_test(1)
def test_cuda_memory_leak_detection(self):
ln = []
@self.wrap_with_cuda_memory_check
def no_leak():
pass
@self.wrap_with_cuda_memory_check
def leak_gpu0():
ln.append(torch.randn(1024 * 1024 * 8, device=torch.device("npu:0")))
no_leak()
regex = r"NPU driver API confirmed .+ on device 0.+"
if IS_JETSON:
try:
leak_gpu0()
except RuntimeError as e:
import re
assert re.match(regex, str(e)), str(e) + "\n does not match: \n" + regex
else:
with self.assertRaisesRegex(RuntimeError, regex):
leak_gpu0()
if TEST_MULTINPU:
@self.wrap_with_cuda_memory_check
def leak_gpu1():
ln.append(torch.randn(1024 * 1024 * 8, device=torch.device("npu:1")))
with self.assertRaisesRegex(RuntimeError, r"NPU driver API confirmed .+ on device 1.+"):
leak_gpu1()
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_streaming_backwards_device_transfer(self):
dev0 = torch.device("npu:0")
dev1 = torch.device("npu:1")
size = 2**26
a = torch.full((size,), 1, device=dev1, dtype=torch.float64, requires_grad=True)
b = torch.full((size,), 1, device=dev1, dtype=torch.float64, requires_grad=True)
to_backward_recipient = a * b
s = to_backward_recipient.to(device="npu:0").sum()
torch_npu.npu.synchronize(device=dev0)
torch_npu.npu.synchronize(device=dev1)
s.backward()
self.assertTrue(a.grad.sum().item() == size)
self.assertTrue(b.grad.sum().item() == size)
a.grad = None
b.grad = None
to_backward_recipient = a * b
s0 = to_backward_recipient.to(device="npu:0").sum() * 2.
s1 = to_backward_recipient.to(device="npu:0").sum() * 2.
torch_npu.npu.synchronize(device=dev0)
torch_npu.npu.synchronize(device=dev1)
s0.backward(retain_graph=True)
s1.backward()
self.assertTrue(a.grad.sum().item() == 4 * size)
self.assertTrue(b.grad.sum().item() == 4 * size)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
@unittest.skipIf(IS_SANDCASTLE or IS_REMOTE_GPU, "Does not work on Sandcastle")
def test_cuda_init_race(self):
import subprocess
subprocess.check_call([sys.executable, '-c', """\
import torch
import torch_npu
import threading
def worker(rank):
torch.tensor([1.]).npu(rank)
t1 = threading.Thread(target=worker, args=(0,))
t2 = threading.Thread(target=worker, args=(1,))
t1.start()
t2.start()
"""])
def test_grad_scaling_unscale(self, dtype=torch.float):
inv_scale = torch.full((1,), 0.25, dtype=torch.float, device="npu:0")
found_inf = torch.full((1,), 0.0, dtype=torch.float, device="npu:0")
size = 10
g = torch.full((size, size), 4.0, dtype=dtype, device="npu:0")
ginf = g.clone()
ginf[2, 2] = float('inf')
gnan = g.clone()
gnan[2, 2] = float('nan')
cases = (
([g.clone(), g.clone()], False),
([g.clone(), g.clone().t()], False),
([g.clone(), g.clone()[:, :5]], False),
([g.clone()[:, :5], g.clone()[:, :5]], False),
([g.clone(), ginf.clone()], True),
([g.clone(), gnan.clone()], True),
([g.clone(), ginf.clone()[:, :5]], True),
([g.clone(), gnan.clone()[:, :5]], True),
([ginf.clone(), g.clone()[:, :5]], True),
([ginf.clone()[:, :5], g.clone()[:, :5]], True),
)
for grads, has_inf in cases:
found_inf.zero_()
torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
if has_inf:
self.assertEqual(found_inf, 1.0)
else:
self.assertEqual(found_inf, 0.0)
for grad in grads:
self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)
grads = [g.clone(), g.to(dtype=torch.float16)]
torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
for grad in grads:
self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)
if TEST_MULTINPU:
with self.assertRaisesRegex(RuntimeError, r"Expected all tensors to be on the same device"):
torch._amp_foreach_non_finite_check_and_unscale_([g.clone(), g.to(device="cuda:1")],
found_inf,
inv_scale)
def perfect_storm_grads(inject_inf):
grads = [g.clone(), g.clone()[:, :5], g.to(dtype=torch.float16), g.to(dtype=torch.float16)]
if TEST_MULTINPU:
grads += [g.to(device="npu:1"),
g.to(device="npu:1")[:, :5],
g.to(device="npu:1", dtype=torch.float16),
g.to(device="npu:1", dtype=torch.float16)]
if inject_inf >= 0:
grads[inject_inf][2, 2] = float('inf')
return grads
scaler = torch_npu.npu.amp.GradScaler()
dummy_params = [torch.empty_like(g) for g in perfect_storm_grads(-1)]
dummy_opt = torch.optim.SGD(dummy_params, lr=1.)
for inject_inf in range(-1, len(dummy_params)):
found_inf = torch.full((1,), 0.0, dtype=torch.float, device="npu:0")
grads = perfect_storm_grads(inject_inf)
for i, p in enumerate(dummy_params):
p.grad = grads[i]
found_inf_per_device = scaler._unscale_grads_(dummy_opt, inv_scale, found_inf, True)
if inject_inf < 0:
self.assertTrue(sum(v.item() for v in found_inf_per_device.values()) == 0)
for grad in grads:
self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)
else:
self.assertTrue(sum(v.item() for v in found_inf_per_device.values()) == 1)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_grad_scaling_device_as_key(self):
d = {}
t = torch.empty((1,), device="npu:0")
dev0a = torch.device("npu:0")
dev0b = torch.device("npu:0")
dev1a = torch.device("npu:1")
dev1b = torch.device("npu:1")
self.assertTrue(hash(dev0a) == hash(dev0b))
self.assertTrue(hash(dev1a) == hash(dev1b))
d[dev0a] = "0a"
d[dev0b] = "0b"
self.assertTrue(len(d) == 1)
self.assertTrue(d[dev0a] == "0b")
d[t.device] = "t"
self.assertTrue(len(d) == 1)
self.assertTrue(d[dev0a] == "t")
d[dev1a] = "1a"
d[dev1b] = "1b"
self.assertTrue(len(d) == 2)
self.assertTrue(d[dev1a] == "1b")
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_grad_scaling_scale(self):
scaler = torch_npu.npu.amp.GradScaler(init_scale=2.)
t0 = torch.full((1,), 4.0, dtype=torch.float32, device="npu:0")
t1 = torch.full((1,), 4.0, dtype=torch.float32, device="npu:1")
outputs = (t1.clone(), (t0.clone(), t1.clone()), [t0.clone(), (t1.clone(), t0.clone())])
outputs = scaler.scale(outputs)
self.assertTrue(outputs[0] == 8.0 and outputs[1][0] == 8.0 and outputs[1][1] == 8.0 and
outputs[2][0] == 8.0 and outputs[2][1][0] == 8.0 and outputs[2][1][1] == 8.0)
self.assertTrue(scaler._scale.device == t1.device)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_grad_scaling_multigpu(self):
dev0 = torch.device("npu:0")
dev1 = torch.device("npu:1")
for enabled in True, False:
mod_control0, mod_scaling0, opt_control0, opt_scaling0, data, loss_fn, skip_iter = \
_create_scaling_case()
mod_control1, mod_scaling1, opt_control1, opt_scaling1 = \
_create_scaling_models_optimizers(device=dev1)
scaler = torch_npu.npu.amp.GradScaler(init_scale=128., growth_factor=2.0,
enabled=enabled, growth_interval=1)
def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
for i, (input1, target) in enumerate(data):
optimizer0.zero_grad()
optimizer1.zero_grad()
output0 = model0(input1)
output1 = model1(input1.to(dev1))
loss0 = loss_fn(0.3 * output0 + 0.7 * output1.to(dev0), target)
loss1 = loss_fn(0.6 * output0.to(dev1) - 0.4 * output1, target.to(dev1))
if try_scaling_api:
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()
if i == skip_iter and scaler.is_enabled():
model1[1].weight.grad.data.fill_(float('inf'))
scaler.unscale_(optimizer0)
scaler.step(optimizer0)
scaler.step(optimizer1)
if scaler.is_enabled():
self.assertTrue(len(scaler._found_inf_per_device(optimizer0)) == 1)
self.assertTrue(len(scaler._found_inf_per_device(optimizer1)) == 1)
self.assertTrue(scaler._found_inf_per_device(optimizer0)[dev0].item() == 0.)
self.assertTrue(scaler._found_inf_per_device(optimizer1)[dev1].item() ==
float(i == skip_iter))
scaler.update()
else:
loss0.backward(retain_graph=True)
loss1.backward()
optimizer0.step()
if (not scaler.is_enabled()) or (i != skip_iter):
optimizer1.step()
run(mod_control0, mod_control1, opt_control0, opt_control1, False)
run(mod_scaling0, mod_scaling1, opt_scaling0, opt_scaling1, True)
self.assertTrue(scaler.get_scale() == (128. * scaler.get_growth_factor()**3 *
scaler.get_backoff_factor()**1) if enabled else 1.0)
mod_control1.to(dev0)
mod_scaling1.to(dev0)
for c, s in zip(chain(mod_control0.parameters(), mod_control1.parameters()),
chain(mod_scaling0.parameters(), mod_scaling1.parameters())):
self.assertEqual(c, s, rtol=1e-5, atol=1e-7)
@unittest.skipIf(not TEST_MULTINPU, "Test needs multiple NPUs")
def test_cuda_device_memory_allocated(self):
from torch.cuda import memory_allocated
device_count = torch_npu.npu.device_count()
current_alloc = [memory_allocated(idx) for idx in range(device_count)]
x = torch.ones(10, device="npu:0")
self.assertGreater(memory_allocated(0), current_alloc[0])
self.assertTrue(all(memory_allocated(torch_npu.npu.device(idx)) ==
current_alloc[idx] for idx in range(1, device_count)))
class TestNpuComm(TestCase):
def _test_broadcast(self, input1):
if not TEST_MULTINPU:
raise unittest.SkipTest("only one NPU detected")
results = comm.broadcast(input1, (0, 1))
for i, t in enumerate(results):
self.assertEqual(t.get_device(), i)
self.assertEqual(t, input1)
if input1.is_npu and input1.get_device() == i:
self.assertEqual(t.data_ptr(), input1.data_ptr())
for inplace in [True, False]:
if inplace:
outputs = [torch.empty_like(input1, device=0), torch.empty_like(input1, device=1)]
else:
outputs = [input1.npu(0), torch.empty_like(input1, device=1)]
results = comm.broadcast(input1, out=outputs)
for r, s in zip(results, outputs):
self.assertIs(r, s)
for i, t in enumerate(results):
self.assertEqual(t.get_device(), i)
self.assertEqual(t, input1)
with self.assertRaisesRegex(RuntimeError, r"Exactly one of 'devices' and 'out'"):
comm.broadcast(input1, (0, 1), out=outputs)
with self.assertRaisesRegex(RuntimeError,
r"Expected all output tensors to be NPU tensors, but output tensor at index 1"):
comm.broadcast(input1, out=[input1.npu(0), input1.cpu()])
with self.assertRaisesRegex(RuntimeError,
r"Expected all output tensors to have same shape as the source .+ at index 1"):
comm.broadcast(input1, out=[input1.npu(0), input1.npu(1).unsqueeze(0)])
def test_broadcast_cpu(self):
self._test_broadcast(torch.randn(5, 5))
def test_broadcast_gpu(self):
self._test_broadcast(torch.randn(5, 5).npu())
def _test_broadcast_coalesced(self, tensors, buffer_size):
b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors]
for (_, bt), t in zip(b_tensors, tensors):
self.assertEqual(bt.get_device(), 1)
self.assertEqual(bt, t)
self.assertIsInstance(bt, type(t))
bc_tensors = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=buffer_size)
bc_tensors_t = list(zip(*bc_tensors))
self.assertEqual(b_tensors, bc_tensors_t)
for (_, bt), (_, bct) in zip(b_tensors, bc_tensors_t):
self.assertEqual(bt.get_device(), bct.get_device())
self.assertIsInstance(bct, type(bt))
for out_tensors in (b_tensors, bc_tensors_t):
for inp_t, (out_t, _) in zip(tensors, out_tensors):
self.assertIs(inp_t, out_t)
versions = [t._version for _, t in bc_tensors_t]
for old_version, (_, t) in zip(versions, bc_tensors_t):
self.assertEqual(t._version, old_version)
t.zero_()
self.assertEqual(t._version, old_version + 1)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_broadcast_coalesced(self):
numel = 5
num_bytes = numel * 8
tensors = [
self.genSparseTensor((2, 3), 2, 1, False, 'npu', torch.float64)[0],
torch.randn(numel).long().npu(),
torch.randn(numel).npu(),
self.genSparseTensor((2, 3), 2, 10, False, 'npu', torch.float64)[0],
self.genSparseTensor((2, 3), 2, 5, False, 'npu', torch.float64)[0],
self.genSparseTensor((3, 3), 2, 7, False, 'npu', torch.int64)[0],
self.genSparseTensor((2, 3), 2, 2, False, 'npu', torch.float32)[0],
torch.randn(numel).long().npu(),
torch.randn(numel).long().npu(),
self.genSparseTensor((2, 7), 2, 3, False, 'npu', torch.int64)[0],
torch.randn(numel * 2).int().npu(),
torch.randn(numel).npu(),
]
self._test_broadcast_coalesced(tensors, num_bytes * 5 // 2)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_broadcast_coalesced_dense_only(self):
numel = 5
num_bytes = numel * 8
tensors = [
torch.randn(numel).long().npu(),
torch.randn(numel).npu(),
torch.randn(numel).long().npu(),
torch.randn(numel).long().npu(),
torch.randn(numel * 2).int().npu(),
torch.randn(numel).npu(),
]
self._test_broadcast_coalesced(tensors, num_bytes * 5 // 2)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_broadcast_coalesced_empty_tensors(self):
tensors = [
torch.tensor([]).byte().npu(),
torch.randn(5).npu(),
torch.randn(5).double().npu()
]
self._test_broadcast_coalesced(tensors, 256)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_reduce_add(self):
x = torch.randn(5, 5)
y = torch.randn(5, 5)
x_npu = x.npu(0)
y_npu = y.npu(1)
result = comm.reduce_add((x_npu, y_npu))
self.assertEqual(result.get_device(), 0)
self.assertEqual(result.cpu(), x + y)
def _test_reduce_add_coalesced(self, tensors, buffer_size):
dup_tensors = [tensors, [t.npu(1) for t in tensors]]
r_tensors = [comm.reduce_add(t) for t in zip(*dup_tensors)]
for r, t in zip(r_tensors, tensors):
self.assertEqualTypeString(r, t)
self.assertEqual(r.coalesce() if r.is_sparse else r, t * 2)
rc_tensors = comm.reduce_add_coalesced(dup_tensors, buffer_size=buffer_size)
self.assertEqual(r_tensors, rc_tensors)
for r, rc in zip(r_tensors, rc_tensors):
self.assertEqualTypeString(rc, r)
versions = [t._version for t in rc_tensors]
for old_version, t in zip(versions, rc_tensors):
self.assertEqual(t._version, old_version)
t.zero_()
self.assertEqual(t._version, old_version + 1)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_reduce_add_coalesced(self):
numel = 5
num_bytes = numel * 8
tensors = [
self.genSparseTensor((2, 3), 2, 1, False, 'npu', torch.float64)[0],
torch.randn(numel).long().npu(),
torch.randn(numel).npu(),
self.genSparseTensor((2, 3), 2, 10, False, 'npu', torch.float64)[0],
self.genSparseTensor((2, 3), 2, 5, False, 'npu', torch.float64)[0],
self.genSparseTensor((3, 3), 2, 7, False, 'npu', torch.int64)[0],
self.genSparseTensor((2, 3), 2, 2, False, 'npu', torch.float32)[0],
torch.randn(numel).long().npu(),
torch.randn(numel).long().npu(),
self.genSparseTensor((2, 7), 2, 3, False, 'npu', torch.int64)[0],
torch.randn(numel * 2).int().npu(),
torch.randn(numel).npu(),
]
self._test_reduce_add_coalesced(tensors, num_bytes * 5 // 2)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_reduce_add_coalesced_dense_only(self):
numel = 5
num_bytes = numel * 8
tensors = [
torch.randn(numel).long().npu(),
torch.randn(numel).npu(),
torch.randn(numel).long().npu(),
torch.randn(numel).long().npu(),
torch.randn(numel * 2).int().npu(),
torch.randn(numel).npu(),
]
self._test_reduce_add_coalesced(tensors, num_bytes * 5 // 2)
def _test_scatter(self, input1, chunk_sizes=None, dim=0):
if not TEST_MULTINPU:
raise unittest.SkipTest("only one NPU detected")
if chunk_sizes is None:
ref_chunk_sizes = tuple(repeat(input1.size(dim) // 2, 2))
else:
ref_chunk_sizes = chunk_sizes
result = comm.scatter(input1, (0, 1), chunk_sizes, dim)
self.assertEqual(len(result), 2)
chunk_start = 0
for i, r in enumerate(result):
chunk_end = chunk_start + ref_chunk_sizes[i]
index = [slice(None, None) for _ in range(input1.dim())]
index[dim] = slice(chunk_start, chunk_end)
self.assertEqual(r, input1[tuple(index)], atol=0, rtol=0)
chunk_start = chunk_end
if r.device == input1.device:
self.assertEqual(r.data_ptr(), input1.data_ptr())
out = [torch.empty_like(t) for t in result]
result = comm.scatter(input1, dim=dim, out=out)
self.assertEqual(len(result), 2)
chunk_start = 0
for i, r in enumerate(result):
self.assertIs(r, out[i])
chunk_end = chunk_start + ref_chunk_sizes[i]
index = [slice(None, None) for _ in range(input1.dim())]
index[dim] = slice(chunk_start, chunk_end)
self.assertEqual(r, input1[tuple(index)], atol=0, rtol=0)
chunk_start = chunk_end
if chunk_sizes is not None:
with self.assertRaisesRegex(RuntimeError, r"Expected devices and chunk_sizes to be of same length"):
comm.scatter(input1, [0 for _ in range(len(chunk_sizes) + 1)], dim=dim, chunk_sizes=chunk_sizes)
with self.assertRaisesRegex(RuntimeError, r"'devices' must not be specified"):
comm.scatter(input1, (0, 1), dim=dim, out=out)
with self.assertRaisesRegex(RuntimeError, r"Expected at least one device to scatter to"):
comm.scatter(input1, (), dim=dim)
with self.assertRaisesRegex(RuntimeError, r"Expected at least one output tensor to scatter to"):
comm.scatter(input1, dim=dim, out=[])
with self.assertRaisesRegex(RuntimeError,
r"Expected all output tensors to be NPU tensors, but output tensor at index 0"):
comm.scatter(input1, dim=dim, out=([out[0].cpu()] + out[1:]))
with self.assertRaisesRegex(RuntimeError, r"Output tensor at index 0 has incorrect shape"):
comm.scatter(input1, dim=dim, out=([out[0].unsqueeze(0)] + out[1:]))
with self.assertRaisesRegex(RuntimeError, r"Total size for output tensors along scatter dim \d+ does not match"):
index = [slice(None, None) for _ in range(input1.dim())]
index[dim] = slice(1, None)
comm.scatter(input1, dim=dim, out=([out[0][tuple(index)]] + out[1:]))
def test_scatter_cpu(self):
self._test_scatter(torch.randn(4, 4), dim=0)
def test_scatter_cpu_dim(self):
self._test_scatter(torch.randn(4, 4), dim=1)
def test_scatter_cpu_neg_dim(self):
self._test_scatter(torch.randn(4, 4), dim=-2)
def test_scatter_cpu_sizes(self):
self._test_scatter(torch.randn(6, 4), chunk_sizes=(2, 4))
def test_scatter_gpu(self):
self._test_scatter(torch.randn(4, 4).npu(), dim=0)
def test_scatter_gpu_dim(self):
self._test_scatter(torch.randn(4, 4).npu(), dim=1)
def test_scatter_gpu_neg_dim(self):
self._test_scatter(torch.randn(4, 4).npu(), dim=-2)
def test_scatter_gpu_sizes(self):
self._test_scatter(torch.randn(6, 4).npu(), chunk_sizes=(2, 4))
def _test_gather(self, dim):
if not TEST_MULTINPU:
raise unittest.SkipTest("only one NPU detected")
x = torch.randn(2, 5, device=0)
y = torch.randn(2, 5, device=1)
expected_size = list(x.size())
expected_size[dim] += y.size(dim)
expected_size = torch.Size(expected_size)
destinations = [None, torch.device('npu:0'), torch.device('cpu')]
if torch_npu.npu.device_count() > 2:
destinations.append(torch.device('npu:2'))
with torch_npu.npu.device(1):
for destination in destinations:
if destination is None:
expected_device = torch.device('npu', torch_npu.npu.current_device())
else:
expected_device = destination
for use_out in [True, False]:
if use_out:
out = torch.empty(expected_size, device=expected_device)
result = comm.gather((x, y), dim, out=out)
self.assertIs(out, result)
else:
result = comm.gather((x, y), dim, destination=destination)
self.assertEqual(result.device, expected_device)
self.assertEqual(result.size(), expected_size)
index = [slice(None, None), slice(None, None)]
index[dim] = slice(0, x.size(dim))
self.assertEqual(result[tuple(index)], x)
index[dim] = slice(x.size(dim), x.size(dim) + y.size(dim))
self.assertEqual(result[tuple(index)], y)
with self.assertRaisesRegex(RuntimeError, r"'destination' must not be specified"):
comm.gather((x, y), dim, destination='cpu', out=torch.empty(expected_size, device='cpu'))
with self.assertRaisesRegex(RuntimeError, r"Expected at least one tensor to gather from"):
comm.gather(())
with self.assertRaisesRegex(RuntimeError, r"Expected all input tensors to be NPU tensors, "):
comm.gather((x.cpu(), y))
with self.assertRaisesRegex(RuntimeError, r"Expected all input tensors to have the same number of dimensions"):
comm.gather((x, y.unsqueeze(0)))
with self.assertRaisesRegex(RuntimeError, r"Input tensor at index 1 has invalid shape"):
if dim in [0, -2]:
comm.gather((x, y[:, 1:]), dim=dim)
elif dim in [1, -1]:
comm.gather((x, y[1:, :]), dim=dim)
def test_gather(self):
self._test_gather(0)
def test_gather_dim(self):
self._test_gather(1)
def test_gather_neg_dim(self):
self._test_gather(-1)
@unittest.skipIf(not TEST_MULTINPU, "only one NPU detected")
def test_memory_format_scatter_gather(self):
nhwc = torch.randn((10, 3, 32, 32), device='cpu').contiguous(memory_format=torch.channels_last)
results = torch_npu.npu.comm.scatter(nhwc, (0, 1), None, 0)
for result in results:
self.assertFalse(result.is_contiguous())
self.assertTrue(result.is_contiguous(memory_format=torch.channels_last))
gathered = torch_npu.npu.comm.gather(results)
self.assertTrue(gathered.is_contiguous(memory_format=torch.channels_last))
@unittest.skipIf(not TEST_MULTINPU, "Test needs multiple NPUs")
def test_scatter_namedtuple(self):
fields = ("a", "b")
TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", fields)
num_gpus = torch_npu.npu.device_count()
a = torch.rand(num_gpus * 2, device=0)
b = torch.rand(num_gpus * 2, device=0)
a_tensors_for_gpu = [a[2 * i: 2 * i + 2].to(i) for i in range(num_gpus)]
b_tensors_for_gpu = [b[2 * i: 2 * i + 2].to(i) for i in range(num_gpus)]
inp = TestNamedTupleInput_0(a, b)
target_gpus = [torch.device(i) for i in range(num_gpus)]
scatter_out = scatter_gather.scatter(inp, target_gpus)
for i, x in enumerate(scatter_out):
self.assertTrue(isinstance(x, type(inp)))
self.assertEqual(x._fields, fields)
expected_a = a_tensors_for_gpu[i]
expected_b = b_tensors_for_gpu[i]
self.assertEqual(expected_a, x.a)
self.assertEqual(expected_b, x.b)
class TestNamedTupleInput_1(NamedTuple):
a: torch.tensor
b: torch.tensor
a = torch.rand(num_gpus * 2, device=0)
b = torch.rand(num_gpus * 2, device=0)
a_tensors_for_gpu = [a[2 * i: 2 * i + 2].to(i) for i in range(num_gpus)]
b_tensors_for_gpu = [b[2 * i: 2 * i + 2].to(i) for i in range(num_gpus)]
inp = TestNamedTupleInput_1(a, b)
scatter_out = scatter_gather.scatter(inp, target_gpus)
for i, x in enumerate(scatter_out):
self.assertTrue(isinstance(x, type(inp)))
self.assertEqual(x._fields, fields)
expected_a = a_tensors_for_gpu[i]
expected_b = b_tensors_for_gpu[i]
self.assertEqual(expected_a, x.a)
self.assertEqual(expected_b, x.b)
@unittest.skipIf(not TEST_MULTINPU, "Test needs multiple NPUs")
def test_gather_namedtuple(self):
fields = ['a', 'b']
TestNamedTupleInput_0 = collections.namedtuple('NamedTuple', fields)
num_gpus = torch_npu.npu.device_count()
a = torch.rand(num_gpus * 2, device=0)
b = torch.rand(num_gpus * 2, device=1)
out1 = TestNamedTupleInput_0(a, b)
a = torch.rand(num_gpus * 2, device=1)
b = torch.rand(num_gpus * 2, device=0)
out2 = TestNamedTupleInput_0(a, b)
outputs = [out1, out2]
out = scatter_gather.gather(outputs, 'cpu')
for i, x in enumerate(out):
self.assertTrue(isinstance(x, type(out2[-1])))
cat = torch.cat((outputs[0][i].to('cpu'), outputs[1][i].to('cpu')))
self.assertTrue(torch.equal(x, cat))
out = scatter_gather.gather(outputs, 0)
for i, x in enumerate(out):
self.assertTrue(isinstance(x, type(out2[-1])))
cat = torch.cat((outputs[0][i].to(0), outputs[1][i].to(0)))
self.assertTrue(torch.equal(x, cat))
class TestNamedTupleInput_1(NamedTuple):
a: torch.tensor
b: torch.tensor
a = torch.rand(num_gpus * 2, device=0)
b = torch.rand(num_gpus * 2, device=1)
out1 = TestNamedTupleInput_1(a, b)
a = torch.rand(num_gpus * 2, device=1)
b = torch.rand(num_gpus * 2, device=0)
out2 = TestNamedTupleInput_1(a, b)
outputs = [out1, out2]
out = scatter_gather.gather(outputs, 0)
for i, x in enumerate(out):
self.assertTrue(isinstance(x, type(out2[-1])))
cat = torch.cat((outputs[0][i].to(0), outputs[1][i].to(0)))
self.assertTrue(torch.equal(x, cat))
out = scatter_gather.gather(outputs, 'cpu')
for i, x in enumerate(out):
self.assertTrue(isinstance(x, type(out2[-1])))
cat = torch.cat((outputs[0][i].to('cpu'), outputs[1][i].to('cpu')))
self.assertTrue(torch.equal(x, cat))
instantiate_parametrized_tests(TestNpuMultiNpu)
if __name__ == '__main__':
run_tests()