from functools import wraps, partialmethod
import os
import inspect
import itertools
import torch
def feed_data(func, new_name, *args, **kwargs):
"""
This internal method decorator feeds the test data item to the test.
"""
@wraps(func)
def wrapper(self):
return func(self, *args, **kwargs)
wrapper.__name__ = new_name
wrapper.__wrapped__ = func
return wrapper
def instantiate_tests(arg=None, **kwargs):
def wrapper(cls):
def gen_testcase(cls, func, name, key_list, func_args, value):
new_kwargs = dict(device="npu") if "device" in func_args else {}
test_name = name
for k, v in zip(key_list, value):
func_key = None
if k == "format":
test_name += ("_" + str(v))
elif k == "dtype":
test_name += ("_" + str(v).split('.')[1])
for _func_key in func_args:
if k in _func_key:
if func_key is not None:
raise RuntimeError(f"Multiple matches for {k}")
func_key = _func_key
new_kwargs[func_key] = v
setattr(cls, test_name, feed_data(func, test_name, **new_kwargs))
for name, func in list(cls.__dict__.items()):
data = {}
if hasattr(func, "dtypes"):
data['dtype'] = func.dtypes
if hasattr(func, "formats"):
data['format'] = func.formats
key_list = data.keys()
if not key_list:
continue
func_args = inspect.getfullargspec(func).args
value_list = [data.get(key) for key in key_list]
for value in itertools.product(*value_list):
gen_testcase(cls, func, name, key_list, func_args, value)
delattr(cls, name)
return cls
return wrapper(arg)
def gen_ops_testcase(cls, func, name, keys, value, op_info):
new_kwargs = {}
test_name = f'{func.__name__}_{name}'
for k, v in zip(keys, value):
if k == "npu_format":
test_name += ("_" + str(v))
elif k == "dtype":
test_name += ("_" + str(v).split('.')[1])
new_kwargs[k] = v
new_kwargs['op'] = op_info
new_func = partialmethod(func, **new_kwargs)
setattr(cls, test_name, new_func)
for decorator in op_info.get_decorators(cls.__name__, func.__name__, 'cpu', value[0], {}):
setattr(cls, test_name, decorator(new_func))
def gen_op_input(testcase, func, op_info):
data = {
'dtype': func.dtypes if hasattr(func, "dtypes") else op_info.dtypesIfNPU,
'npu_format': func.formats if hasattr(func, "formats") else op_info.formats
}
if 'test_variant_consistency_eager' in testcase:
if torch.float32 in op_info.dtypesIfNPU:
data['dtype'] = {torch.float32}
else:
data['dtype'] = {list(op_info.dtypesIfNPU)[-1]}
return data
def instantiate_ops_tests(op_db):
def wrapper(cls):
testcases = [x for x in dir(cls) if x.startswith('test_')]
for testcase in testcases:
if hasattr(cls, testcase):
func = getattr(cls, testcase)
for op_info in op_db:
data = gen_op_input(testcase, func, op_info)
keys = data.keys()
values = [data.get(key) for key in keys]
for value in itertools.product(*values):
gen_ops_testcase(cls, func, op_info.name, keys, value, op_info)
delattr(cls, testcase)
return cls
return wrapper
class Dtypes(object):
def __init__(self, *args):
if (args is None or len(args) == 0):
raise RuntimeError("No dtypes given")
if not all(isinstance(arg, torch.dtype) for arg in args):
raise RuntimeError("Unknown dtype in {0}".format(str(args)))
self.args = args
def __call__(self, fn):
fn.dtypes = self.args
return fn
class Formats(object):
def __init__(self, *args):
if args is None or len(args) == 0:
raise RuntimeError("No formats given")
self.args = args
def __call__(self, fn):
fn.formats = self.args
return fn