r"""Importing this file must **not** initialize NPU context. test_distributed
relies on this assumption to properly run. This means that when this is imported
no NPU calls shall be made, including torch.npu.device_count(), etc.
torch.testing._internal.common_npu.py can freely initialize NPU context when imported.
"""
from collections import OrderedDict
from collections.abc import Sequence
from contextlib import contextmanager
from numbers import Number
from unittest.result import TestResult
from unittest.util import strclass
import sys
import os
import re
import unittest
import time
import warnings
import random
import inspect
import torch
import expecttest
import numpy as np
from torch import inf
from torch_npu.testing.common_utils import set_npu_device, is_iterable, iter_indices
from torch_npu.testing.common_utils import PERF_TEST_ENABLE, PerfBaseline
# CI mode is on only when the IN_CI environment variable is exactly "1".
IS_IN_CI = os.environ.get('IN_CI') == '1'
# Directory under which run_tests() writes XML reports; overridable through
# the TEST_REPORT_PATH environment variable.
TEST_REPORT_PATH = os.environ.get("TEST_REPORT_PATH", "test-reports")
def run_tests():
    """Run unittest for the calling module, writing XML reports when in CI.

    Outside CI (``IS_IN_CI`` false) this is a plain ``unittest.main`` call.
    In CI the tests are run through ``xmlrunner.XMLTestRunner`` and reports
    are written below ``TEST_REPORT_PATH`` in a directory named after the
    caller's file path, with the ``.py`` suffix dropped and every ``/``
    replaced by ``.``.
    """
    argv = sys.argv
    if IS_IN_CI:
        # Imported lazily so xmlrunner is only required on CI machines.
        import xmlrunner
        # Frame 1 is the module that called run_tests().
        caller_file = inspect.getfile(sys._getframe(1))
        # The dot must be escaped: an unescaped '.' matches any character and
        # would also strip the tail of names that merely end in "<char>py".
        module_path = re.sub(r'\.py$', '', caller_file)
        test_filename = module_path.replace('/', '.')
        test_report_path = os.path.join(TEST_REPORT_PATH, test_filename)
        verbose = '--verbose' in argv or '-v' in argv
        if verbose:
            print(f'Test results will be stored in {test_report_path}')
        runner = xmlrunner.XMLTestRunner(output=test_report_path,
                                         verbosity=2 if verbose else 1)
        unittest.main(argv=argv, testRunner=runner)
    else:
        unittest.main(argv=argv)
class TestCase(expecttest.TestCase):
    """Base class for NPU tests with tolerance-aware comparison helpers.

    ``assertEqual`` and ``assertNotEqual`` are overridden so that numbers,
    tensors (dense, sparse and quantized), dicts and iterables are compared
    within an absolute tolerance. ``run`` optionally executes every test a
    second time and checks its wall-clock time against a stored baseline.
    """

    # Default absolute tolerance used when a caller passes no ``prec``.
    _precision = 1e-5
    # unittest setting: never truncate diffs in failure messages.
    maxDiff = None
    # Default for the ``exact_dtype`` argument of ``assertEqual``.
    exact_dtype = False

    def __init__(self, method_name='runTest'):
        """Create the test case for the test method named ``method_name``."""
        super(TestCase, self).__init__(method_name)

    @property
    def precision(self):
        """Default absolute tolerance used by the comparison helpers."""
        return self._precision

    @precision.setter
    def precision(self, prec):
        self._precision = prec

    @classmethod
    def setUpClass(cls):
        """Store the value returned by ``set_npu_device()`` as ``cls.npu_device``."""
        cls.npu_device = set_npu_device()

    def setUp(self):
        """Seed ``torch`` and ``random`` from the ``SEED`` env var (default 666)."""
        seed = int(os.getenv('SEED', "666"))
        torch.manual_seed(seed)
        random.seed(seed)

    def assertTensorsSlowEqual(self, x, y, prec=None, message=''):
        """Compare two tensors element by element.

        Sizes and dtypes must match. Bool tensors are delegated to
        ``assertEqual``; for every other dtype the largest absolute
        element-wise difference must not exceed ``prec``.

        Args:
            x, y: tensors to compare.
            prec: absolute tolerance; ``None`` means ``self.precision``.
            message: failure message for the tolerance check.
        """
        if prec is None:
            # Without this fallback the final assertLessEqual would compare
            # against None and raise TypeError instead of a test failure.
            prec = self.precision
        self.assertEqual(x.size(), y.size())
        self.assertEqual(x.dtype, y.dtype)
        y = y.type_as(x)
        if x.dtype == torch.bool:
            self.assertEqual(x, y)
        else:
            max_err = 0
            for index in iter_indices(x):
                max_err = max(max_err, abs(x[index] - y[index]))
            self.assertLessEqual(max_err, prec, message)

    def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device='cpu'):
        """Build a random sparse COO tensor.

        Args:
            size: full shape of the tensor.
            sparse_dim: number of leading sparse dimensions.
            nnz: number of index/value pairs to generate.
            is_uncoalesced: when true, every index is duplicated so the
                result is not coalesced; otherwise the result is coalesced.
            device: device on which indices and values are created.

        Returns:
            ``(tensor, indices_clone, values_clone)``.

        Raises:
            RuntimeError: if a sparse dimension has size 0 while ``nnz != 0``.
        """
        if not (all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0):
            raise RuntimeError('invalid arguments')
        v_size = [nnz] + list(size[sparse_dim:])
        v = torch.randn(*v_size, device=device)
        # Draw indices uniformly in [0, 1), scale each row by the size of its
        # dimension, then truncate to integer coordinates.
        i = torch.rand(sparse_dim, nnz, device=device)
        i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
        i = i.to(torch.long)
        if is_uncoalesced:
            v = torch.cat([v, torch.randn_like(v)], 0)
            i = torch.cat([i, i], 1)
        x = torch.sparse_coo_tensor(i, v, torch.Size(size))
        if not is_uncoalesced:
            x = x.coalesce()
        else:
            x = x.detach().clone()
        return x, x._indices().clone(), x._values().clone()

    def safeToDense(self, t):
        """Coalesce ``t`` through ``safeCoalesce`` and return it as dense."""
        r = self.safeCoalesce(t)
        return r.to_dense()

    def safeCoalesce(self, t):
        """Coalesce ``t`` and cross-check against a pure-Python coalesce.

        Returns ``t.coalesce()`` when ``t`` has no stored elements, otherwise
        the tensor rebuilt from the Python-side accumulation (after asserting
        that it matches ``t.coalesce()``).
        """
        tc = t.coalesce()
        self.assertEqual(tc.to_dense(), t.to_dense())
        self.assertTrue(tc.is_coalesced())
        # With nnz == 0 there is nothing to accumulate below.
        if t._nnz() == 0:
            self.assertEqual(t._indices(), tc._indices())
            self.assertEqual(t._values(), tc._values())
            return tc
        # Sum the values of duplicated indices by hand.
        value_map = {}
        for idx, val in zip(t._indices().t(), t._values()):
            idx_tup = tuple(idx.tolist())
            if idx_tup in value_map:
                value_map[idx_tup] += val
            else:
                value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val
        new_indices = sorted(value_map)
        new_values = [value_map[idx] for idx in new_indices]
        if t._values().ndimension() < 2:
            new_values = t._values().new(new_values)
        else:
            new_values = torch.stack(new_values)
        new_indices = t._indices().new(new_indices).t()
        tg = t.new(new_indices, new_values, t.size())
        self.assertEqual(tc._indices(), tg._indices())
        self.assertEqual(tc._values(), tg._values())
        if t.is_coalesced():
            self.assertEqual(tc._indices(), t._indices())
            self.assertEqual(tc._values(), t._values())
        return tg

    def assertRtolEqual(self, x, y, prec=1.e-4, prec16=1.e-3, auto_trans_dtype=False, message=None):
        """Assert that ``x`` and ``y`` agree within a relative/absolute tolerance.

        Accepts tensors, numbers, numpy arrays and (nested) sequences of
        those. Shapes and dtypes must match. Bool data must match exactly;
        bfloat16 tensors are checked with ``torch.allclose``; all other
        supported dtypes are checked element-wise after conversion to numpy.

        Args:
            x, y: values to compare.
            prec: tolerance for every dtype except float16/bfloat16.
            prec16: tolerance for float16 and bfloat16.
            auto_trans_dtype: cast tensor ``x`` to the dtype of ``y`` first.
            message: failure message used when NaN positions differ.
        """
        def _assertRtolEqual(x, y, prec, prec16, message):
            def compare_res(pre, minimum):
                diff = y - x
                # NaNs must sit at the same positions on both sides; they are
                # then excluded from the difference.
                nan_mask = np.isnan(x)
                if not np.equal(nan_mask, np.isnan(y)).all():
                    self.fail(message)
                if nan_mask.any():
                    diff[nan_mask] = 0
                result = np.abs(diff)
                deno = np.maximum(np.abs(x), np.abs(y))
                result_atol = np.less_equal(result, pre)
                # ``minimum`` keeps the denominator away from zero.
                result_rtol = np.less_equal(result / np.add(deno, minimum), pre)
                if not result_rtol.all() and not result_atol.all():
                    # Fail only when the number of elements violating the
                    # relative check AND the number violating the absolute
                    # check both exceed ``size * pre``.
                    rtol_misses = np.sum(~result_rtol)
                    atol_misses = np.sum(~result_atol)
                    if rtol_misses > size * pre and atol_misses > size * pre:
                        self.fail("result error")

            minimum16 = 6e-8
            minimum = 10e-10

            if isinstance(x, Sequence) and isinstance(y, Sequence):
                # zip() stops at the shorter input, so a length mismatch has
                # to be rejected explicitly or it would pass silently.
                if len(x) != len(y):
                    self.fail("shape error")
                for x_, y_ in zip(x, y):
                    _assertRtolEqual(x_, y_, prec, prec16, message)
                return

            if isinstance(x, torch.Tensor) and isinstance(y, Sequence):
                y = torch.as_tensor(y, dtype=x.dtype, device=x.device)
            elif isinstance(x, Sequence) and isinstance(y, torch.Tensor):
                x = torch.as_tensor(x, dtype=y.dtype, device=y.device)

            if torch.is_tensor(x) and torch.is_tensor(y):
                if auto_trans_dtype:
                    x = x.to(y.dtype)
                if (x.dtype == torch.bfloat16) and (y.dtype == torch.bfloat16):
                    # bfloat16 is compared with torch.allclose instead of
                    # going through the numpy path below.
                    if x.shape != y.shape:
                        self.fail("shape error!")
                    result = torch.allclose(x.cpu(), y.cpu(), rtol=prec16, atol=prec16)
                    if not result:
                        self.fail("result error!")
                    return
                x = x.detach().cpu().numpy()
                y = y.detach().cpu().numpy()
            elif isinstance(x, Number) and isinstance(y, Number):
                x = np.array(x)
                y = np.array(y)

            # Read by compare_res through the enclosing scope.
            size = x.size
            if x.shape != y.shape:
                self.fail("shape error")
            if x.dtype != y.dtype:
                self.fail("dtype error")
            dtype_list = [np.bool_, np.uint16, np.int16, np.int32, np.float16,
                          np.float32, np.int8, np.uint8, np.int64, np.float64]
            if x.dtype not in dtype_list:
                self.fail("required dtype in [np.bool_, np.uint16, np.int16, " +
                          "np.int32, np.float16, np.float32, np.int8, np.uint8, np.int64]")
            if x.dtype == np.bool_:
                result = np.equal(x, y)
                if not result.all():
                    self.fail("result error")
            elif x.dtype == np.float16:
                compare_res(prec16, minimum16)
            elif x.dtype in [np.float32, np.int8, np.uint8, np.uint16, np.int16,
                             np.int32, np.int64, np.float64]:
                compare_res(prec, minimum)
            else:
                self.fail("required numpy object")

        _assertRtolEqual(x, y, prec, prec16, message)

    def _assert_tensor_equal(self, a, b, message, exact_dtype, allow_inf, prec):
        """Assert that two dense tensors agree within the tolerance ``prec``.

        Sizes must match (and dtypes too when ``exact_dtype`` is true). NaNs
        must occupy the same positions. With ``allow_inf`` infinities must
        occupy the same positions and, for real dtypes, carry the same sign.
        The largest remaining absolute difference must not exceed ``prec``.
        """
        super(TestCase, self).assertEqual(a.size(), b.size(), message)
        if exact_dtype:
            self.assertEqual(a.dtype, b.dtype)
        if a.numel() > 0:
            if (a.device.type == 'cpu' and (a.dtype == torch.float16 or a.dtype == torch.bfloat16)):
                # Do the arithmetic below in float32 for CPU half/bfloat16.
                a = a.to(torch.float32)
            b = b.to(a)
            if (a.dtype == torch.bool) != (b.dtype == torch.bool):
                raise TypeError("Was expecting both tensors to be bool type.")
            if a.dtype == torch.bool and b.dtype == torch.bool:
                # Bool tensors are converted to int before subtracting.
                a = a.to(torch.int)
                b = b.to(torch.int)
            diff = a - b
            if a.dtype.is_complex or a.dtype.is_floating_point:
                nan_mask = torch.isnan(a)
                self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
                diff[nan_mask] = 0
                if allow_inf:
                    inf_mask = torch.isinf(a)
                    self.assertTrue(torch.equal(inf_mask, torch.isinf(b)), message)
                    if not a.dtype.is_complex:
                        # The isinf mask is boolean and carries no sign, so
                        # compare the infinite entries themselves; otherwise
                        # +inf and -inf would be treated as equal.
                        self.assertTrue(torch.equal(a[inf_mask], b[inf_mask]), message)
                    diff[inf_mask] = 0
            if diff.is_signed():
                if diff.dtype == torch.int8:
                    # Widen first: skipping abs() for int8 would let any
                    # negative difference pass the max() check below, and
                    # abs(-128) does not fit in int8.
                    diff = diff.to(torch.int32)
                diff = diff.abs()
            if diff.dtype == torch.complex64:
                diff = diff.to(torch.float)
            elif diff.dtype == torch.complex128:
                diff = diff.to(torch.double)
            max_err = diff.max()
            self.assertLessEqual(max_err, prec, message)

    def _assertNumberEqual(self, x, y, prec=None, message='', allow_inf=False, exact_dtype=None):
        """Compare two scalars (either may be a tensor) within ``prec``.

        A tensor operand is reduced with ``item()``. Infinite values fail
        unless ``allow_inf`` is true, in which case they must be equal.
        """
        if isinstance(x, torch.Tensor) and isinstance(y, Number):
            self._assertNumberEqual(x.item(), y, prec=prec, message=message,
                                    allow_inf=allow_inf, exact_dtype=exact_dtype)
        elif isinstance(y, torch.Tensor) and isinstance(x, Number):
            self._assertNumberEqual(x, y.item(), prec=prec, message=message,
                                    allow_inf=allow_inf, exact_dtype=exact_dtype)
        else:
            if abs(x) == inf or abs(y) == inf:
                if allow_inf:
                    super(TestCase, self).assertEqual(x, y, message)
                else:
                    self.fail("Expected finite numeric values - x={}, y={}".format(x, y))
                return
            super(TestCase, self).assertLessEqual(abs(x - y), prec, message)

    def _assertBoolEqual(self, x, y, prec=None, message='', allow_inf=False, exact_dtype=None):
        """Compare a ``np.bool_`` with another value for exact equality.

        A tensor operand is reduced with ``item()`` first.
        """
        if isinstance(x, torch.Tensor) and isinstance(y, np.bool_):
            self._assertBoolEqual(x.item(), y, prec=prec, message=message,
                                  allow_inf=allow_inf, exact_dtype=exact_dtype)
        elif isinstance(y, torch.Tensor) and isinstance(x, np.bool_):
            self._assertBoolEqual(x, y.item(), prec=prec, message=message,
                                  allow_inf=allow_inf, exact_dtype=exact_dtype)
        else:
            super(TestCase, self).assertEqual(x, y, message)

    def _assertTensorsEqual(self, x, y, prec=None, message='', allow_inf=False, exact_dtype=None):
        """Compare two tensors, handling sparse and quantized layouts.

        Sparse tensors are coalesced and compared by indices and values.
        Quantized tensors are compared by quantization scheme, quantization
        parameters, dtype and integer representation. Everything else goes
        straight to ``_assert_tensor_equal``.
        """
        super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
        super(TestCase, self).assertEqual(x.is_quantized, y.is_quantized, message)
        if x.is_sparse:
            x = self.safeCoalesce(x)
            y = self.safeCoalesce(y)
            self._assert_tensor_equal(x._indices(), y._indices(), message, exact_dtype, allow_inf, prec)
            self._assert_tensor_equal(x._values(), y._values(), message, exact_dtype, allow_inf, prec)
        elif x.is_quantized and y.is_quantized:
            self.assertEqual(x.qscheme(), y.qscheme(), prec=prec,
                             message=message, allow_inf=allow_inf,
                             exact_dtype=exact_dtype)
            if x.qscheme() == torch.per_tensor_affine:
                self.assertEqual(x.q_scale(), y.q_scale(), prec=prec,
                                 message=message, allow_inf=allow_inf,
                                 exact_dtype=exact_dtype)
                self.assertEqual(x.q_zero_point(), y.q_zero_point(),
                                 prec=prec, message=message,
                                 allow_inf=allow_inf, exact_dtype=exact_dtype)
            elif x.qscheme() == torch.per_channel_affine:
                self.assertEqual(x.q_per_channel_scales(), y.q_per_channel_scales(), prec=prec,
                                 message=message, allow_inf=allow_inf,
                                 exact_dtype=exact_dtype)
                self.assertEqual(x.q_per_channel_zero_points(), y.q_per_channel_zero_points(),
                                 prec=prec, message=message,
                                 allow_inf=allow_inf, exact_dtype=exact_dtype)
                self.assertEqual(x.q_per_channel_axis(), y.q_per_channel_axis(),
                                 prec=prec, message=message)
            self.assertEqual(x.dtype, y.dtype)
            self.assertEqual(x.int_repr().to(torch.int32),
                             y.int_repr().to(torch.int32), prec=prec,
                             message=message, allow_inf=allow_inf,
                             exact_dtype=exact_dtype)
        else:
            self._assert_tensor_equal(x, y, message, exact_dtype, allow_inf, prec)

    def assertEqual(self, x, y, prec=None, message='', allow_inf=False, exact_dtype=None):
        """Assert that ``x`` and ``y`` are equal within an absolute tolerance.

        Dispatches on the operand types: numbers, ``np.bool_``, tensors,
        strings/bytes, sets, dicts and generic iterables (compared element
        by element); anything else uses the inherited ``assertEqual``.

        Args:
            x, y: values to compare.
            prec: absolute tolerance; ``None`` means ``self.precision``. A
                string passed here while ``message`` is empty is taken as
                the message instead.
            message: failure message.
            allow_inf: accept infinities when they match on both sides.
            exact_dtype: require equal tensor dtypes; ``None`` means
                ``self.exact_dtype``.
        """
        if exact_dtype is None:
            exact_dtype = self.exact_dtype
        if isinstance(prec, str) and message == '':
            # Support the assertEqual(x, y, "message") calling convention.
            message = prec
            prec = None
        if prec is None:
            prec = self.precision

        def _assertEqual(x, y, prec=None, message='', allow_inf=False, exact_dtype=None):
            if isinstance(x, Number) or isinstance(y, Number):
                self._assertNumberEqual(x, y, prec=prec, message=message,
                                        allow_inf=allow_inf, exact_dtype=exact_dtype)
            elif isinstance(x, np.bool_) or isinstance(y, np.bool_):
                self._assertBoolEqual(x, y, prec=prec, message=message,
                                      allow_inf=allow_inf, exact_dtype=exact_dtype)
            elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
                self._assertTensorsEqual(x, y, prec=prec, message=message,
                                         allow_inf=allow_inf, exact_dtype=exact_dtype)
            elif isinstance(x, (str, bytes)) and isinstance(y, (str, bytes)):
                super(TestCase, self).assertEqual(x, y, message)
            elif isinstance(x, set) and isinstance(y, set):
                super(TestCase, self).assertEqual(x, y, message)
            elif isinstance(x, dict) and isinstance(y, dict):
                if isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
                    _assertEqual(x.items(), y.items(), prec=prec,
                                 message=message, allow_inf=allow_inf,
                                 exact_dtype=exact_dtype)
                else:
                    _assertEqual(set(x.keys()), set(y.keys()), prec=prec,
                                 message=message, allow_inf=allow_inf,
                                 exact_dtype=exact_dtype)
                # Compare the values in one shared key order.
                key_list = list(x.keys())
                _assertEqual([x[k] for k in key_list],
                             [y[k] for k in key_list],
                             prec=prec, message=message,
                             allow_inf=allow_inf, exact_dtype=exact_dtype)
            elif is_iterable(x) and is_iterable(y):
                super(TestCase, self).assertEqual(len(x), len(y), message)
                for x_, y_ in zip(x, y):
                    _assertEqual(x_, y_, prec=prec, message=message,
                                 allow_inf=allow_inf, exact_dtype=exact_dtype)
            else:
                super(TestCase, self).assertEqual(x, y, message)

        _assertEqual(x, y, prec=prec, message=message, allow_inf=allow_inf, exact_dtype=exact_dtype)

    def assertAlmostEqual(self, x, y, places=None, msg=None, delta=None, allow_inf=None):
        """Assert equality within ``delta``, or within ``10**-places``.

        ``places`` takes precedence over ``delta`` when both are given; with
        neither, ``assertEqual`` falls back to ``self.precision``.
        """
        prec = delta
        # Compare against None so that places=0 (tolerance 1) is honoured
        # instead of being mistaken for "not given".
        if places is not None:
            prec = 10**(-places)
        self.assertEqual(x, y, prec, msg, allow_inf)

    def assertNotEqual(self, x, y, prec=None, message=''):
        """Assert that ``x`` and ``y`` differ by at least ``prec``.

        Tensors of different sizes are unequal by definition. Same-sized
        tensors must be non-empty and, when their NaN positions coincide,
        have a maximum absolute difference of at least ``prec``. Strings and
        iterables use the inherited exact comparison. Other values pass when
        ``abs(x - y) >= prec`` and otherwise fall back to the inherited
        ``assertNotEqual``.

        Args:
            x, y: values to compare.
            prec: tolerance; ``None`` means ``self.precision``. A string
                passed here while ``message`` is empty is taken as the
                message instead.
            message: failure message.
        """
        if isinstance(prec, str) and message == '':
            message = prec
            prec = None
        if prec is None:
            prec = self.precision
        if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            if x.size() != y.size():
                super(TestCase, self).assertNotEqual(x.size(), y.size())
                # Different sizes already prove inequality; continuing would
                # wrongly fail on an empty ``x`` in the numel() check below.
                return
            self.assertGreater(x.numel(), 0)
            y = y.type_as(x)
            # NaN is the only value that differs from itself.
            nan_mask = x != x
            if torch.equal(nan_mask, y != y):
                if x.dtype == torch.bool and y.dtype == torch.bool:
                    x = x.to(torch.int)
                    y = y.to(torch.int)
                diff = x - y
                if diff.is_signed():
                    diff = diff.abs()
                diff[nan_mask] = 0
                max_err = diff.max().item()
                self.assertGreaterEqual(max_err, prec, message)
        elif isinstance(x, str) and isinstance(y, str):
            super(TestCase, self).assertNotEqual(x, y)
        elif is_iterable(x) and is_iterable(y):
            super(TestCase, self).assertNotEqual(x, y)
        else:
            try:
                self.assertGreaterEqual(abs(x - y), prec, message)
                return
            except (TypeError, AssertionError):
                # Not subtractable, or closer than ``prec``: fall back to
                # the exact inequality check below.
                pass
            super(TestCase, self).assertNotEqual(x, y, message)

    def assertObjectIn(self, obj, iterable):
        """Assert that ``obj`` itself (by identity) is an element of ``iterable``."""
        if any(elem is obj for elem in iterable):
            return
        raise AssertionError("object not found in iterable")

    def assertExpectedRaises(self, exc_type, call_fn, *args, **kwargs):
        """Assert that ``call_fn`` raises ``exc_type`` with the expected text.

        The ``subname`` keyword, if present, is removed from ``kwargs`` and
        forwarded to ``assertExpected``; all remaining arguments are passed
        to ``call_fn``.
        """
        subname = kwargs.pop('subname', None)
        try:
            call_fn(*args, **kwargs)
        except exc_type as e:
            self.assertExpected(str(e), subname)
            return
        self.fail(msg="Did not raise when expected to")

    def assertNotWarn(self, call_fn, msg=''):
        r"""
        Test if :attr:`call_fn` does not raise a warning.
        """
        with self._reset_warning_registry(), warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")
            call_fn()
            self.assertTrue(len(ws) == 0, msg)

    def assertWarns(self, call_fn, msg=''):
        r"""
        Test if :attr:`call_fn` raises a warning.
        """
        with self._reset_warning_registry(), warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")
            call_fn()
            self.assertTrue(len(ws) > 0, msg)

    @contextmanager
    def maybeWarnsRegex(self, category, regex=''):
        """Context manager for code that *may* warn, e.g. ``TORCH_NPU_WARN_ONCE``.
        This filters expected warnings from the test log and fails the test if
        any unexpected warnings are caught.
        """
        with self._reset_warning_registry(), warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")
            # Warnings matching (category, regex) are expected and dropped;
            # anything else is recorded in ``ws``.
            warnings.filterwarnings("ignore", message=regex, category=category)
            try:
                yield
            finally:
                if ws:
                    formatted = ''.join(
                        warnings.formatwarning(
                            w.message, w.category, w.filename, w.lineno, w.line) + '\n'
                        for w in ws)
                    self.fail('Caught unexpected warnings:\n' + formatted)

    @contextmanager
    def _reset_warning_registry(self):
        r"""
        No-op context manager kept for the warning helpers above.

        Clearing the warning registries by hand was only needed on Python 2,
        where warnings.catch_warnings() missed already registered warnings.
        This module requires Python 3 (it uses f-strings), so there is
        nothing to reset.
        """
        yield

    def assertExpectedStripMangled(self, s, subname=None):
        """Run ``assertExpected`` on ``s`` with ``__torch__...`` names removed."""
        s = re.sub(r'__torch__[^ ]+', '', s)
        self.assertExpected(s, subname)

    def run(self, result=None):
        """Run the test and, if ``PERF_TEST_ENABLE`` is set, time a second run.

        The timed run is collected into a private ``TestResult``. If it
        finishes without errors or failures, its duration is compared with
        the stored baseline for this test method: more than 20% slower is
        reported as a failure, and a missing baseline or a run more than 10%
        faster stores the new time as the baseline.

        Returns:
            Whatever the inherited ``run`` returned for the first run.
        """
        # The first run is the real one and also warms things up for timing.
        outcome = super(TestCase, self).run(result)
        if PERF_TEST_ENABLE:
            perf_result = TestResult()
            start_time = time.perf_counter()
            super(TestCase, self).run(perf_result)
            runtime = time.perf_counter() - start_time
            if not perf_result.errors and not perf_result.failures:
                method_id = strclass(self.__class__) + "." + self._testMethodName
                baseline = PerfBaseline.get_baseline(method_id)
                if baseline and runtime > baseline * 1.2:
                    err_msg = (f"Performance test failed. Performance baseline: "
                               f"{baseline!s}s, current time: {runtime!s}s")
                    perf_err = (self.failureException, self.failureException(err_msg), None)
                    # When the caller passed no result object, report to the
                    # one the inherited run() created and returned.
                    report_to = result if result is not None else outcome
                    if report_to is not None:
                        # Public equivalent of the private _feedErrorsToResult
                        # helper, which no longer exists in Python 3.11+.
                        report_to.addFailure(self, perf_err)
                if baseline is None or runtime < baseline * 0.9:
                    PerfBaseline.set_baseline(method_id, runtime)
        return outcome