# Source snapshot: commit efa172af, created 2023-10-16 (from the repository commit history).
from functools import wraps, partialmethod



import os

import inspect

import itertools

import torch





def feed_data(func, new_name, *args, **kwargs):
    """
    Bind a fixed set of test data to ``func`` and expose it as a new test method.

    The returned callable takes only ``self`` and forwards the captured
    positional and keyword arguments to ``func``. It carries ``func``'s
    metadata (via ``functools.wraps``) except that its ``__name__`` is
    ``new_name``, and ``__wrapped__`` points back at ``func``.
    """
    captured_args = args
    captured_kwargs = kwargs

    def _bound_test(self):
        return func(self, *captured_args, **captured_kwargs)

    _bound_test = wraps(func)(_bound_test)
    _bound_test.__name__ = new_name
    _bound_test.__wrapped__ = func
    return _bound_test





def instantiate_tests(arg=None, **kwargs):



    def wrapper(cls):

        def gen_testcase(cls, func, name, key_list, func_args, value):

            new_kwargs = dict(device="npu") if "device" in func_args else {}

            test_name = name

            for k, v in zip(key_list, value):

                func_key = None

                if k == "format":

                    test_name += ("_" + str(v))

                elif k == "dtype":

                    test_name += ("_" + str(v).split('.')[1])

                for _func_key in func_args:

                    if k in _func_key:

                        if func_key is not None:

                            raise RuntimeError(f"Multiple matches for {k}")

                        func_key = _func_key

                new_kwargs[func_key] = v

            setattr(cls, test_name, feed_data(func, test_name, **new_kwargs))



        for name, func in list(cls.__dict__.items()):

            data = {}

            if hasattr(func, "dtypes"):

                data['dtype'] = func.dtypes

            if hasattr(func, "formats"):

                data['format'] = func.formats



            key_list = data.keys()

            if not key_list:

                continue



            func_args = inspect.getfullargspec(func).args

            value_list = [data.get(key) for key in key_list]

            for value in itertools.product(*value_list):

                gen_testcase(cls, func, name, key_list, func_args, value)



            delattr(cls, name)

        return cls



    return wrapper(arg)





def gen_ops_testcase(cls, func, name, keys, value, op_info):

    new_kwargs = {}

    test_name = f'{func.__name__}_{name}'



    for k, v in zip(keys, value):

        if k == "npu_format":

            test_name += ("_" + str(v))

        elif k == "dtype":

            test_name += ("_" + str(v).split('.')[1])

        new_kwargs[k] = v



    new_kwargs['op'] = op_info

    new_func = partialmethod(func, **new_kwargs)



    setattr(cls, test_name, new_func)

    for decorator in op_info.get_decorators(cls.__name__, func.__name__, 'cpu', value[0], {}):

        setattr(cls, test_name, decorator(new_func))





def gen_op_input(testcase, func, op_info):
    """
    Choose the dtype and npu_format values a test is expanded over for one op.

    Values set on the test function (``func.dtypes`` / ``func.formats``) take
    precedence. Otherwise ``op_info.dtypesIfNPU`` / ``op_info.formats`` are used.

    Tests whose name contains ``'test_variant_consistency_eager'`` are
    restricted to a single dtype:
    - ``torch.float32`` if the op supports it;
    - otherwise the last element of ``list(op_info.dtypesIfNPU)``. If that
      collection is a set, "last" is arbitrary.

    The returned dict lists ``'dtype'`` before ``'npu_format'``; callers may
    rely on that order.
    """
    if hasattr(func, "dtypes"):
        dtype_choices = func.dtypes
    else:
        dtype_choices = op_info.dtypesIfNPU

    if hasattr(func, "formats"):
        format_choices = func.formats
    else:
        format_choices = op_info.formats

    if 'test_variant_consistency_eager' in testcase:
        npu_dtypes = op_info.dtypesIfNPU
        dtype_choices = {
            torch.float32 if torch.float32 in npu_dtypes else list(npu_dtypes)[-1]
        }

    return {'dtype': dtype_choices, 'npu_format': format_choices}





def instantiate_ops_tests(op_db):



    def wrapper(cls):

        testcases = [x for x in dir(cls) if x.startswith('test_')]

        for testcase in testcases: 

            if hasattr(cls, testcase):

                func = getattr(cls, testcase)

                for op_info in op_db:

                    data = gen_op_input(testcase, func, op_info)

                    keys = data.keys()

                    values = [data.get(key) for key in keys]



                    for value in itertools.product(*values):

                        gen_ops_testcase(cls, func, op_info.name, keys, value, op_info)



                delattr(cls, testcase)



        return cls

        

    return wrapper





class Dtypes:
    """
    Decorator that tags a test function with the torch dtypes to run it on.

    The dtypes are stored on the function as the tuple ``fn.dtypes``, which
    ``instantiate_tests`` and ``gen_op_input`` read.

    Raises:
        RuntimeError: If no dtypes are given, or any argument is not a
            ``torch.dtype``.
    """

    def __init__(self, *args):
        if not args:
            raise RuntimeError("No dtypes given")
        if any(not isinstance(candidate, torch.dtype) for candidate in args):
            raise RuntimeError("Unknown dtype in {0}".format(str(args)))
        self.args = args

    def __call__(self, fn):
        """Record the dtypes as ``fn.dtypes`` and return ``fn`` unchanged otherwise."""
        setattr(fn, "dtypes", self.args)
        return fn





class Formats:
    """
    Decorator that tags a test function with the tensor formats to run it on.

    The values are stored unvalidated as the tuple ``fn.formats``, which
    ``instantiate_tests`` and ``gen_op_input`` read.

    Raises:
        RuntimeError: If no formats are given.
    """

    def __init__(self, *args):
        if not args:
            raise RuntimeError("No formats given")
        self.args = args

    def __call__(self, fn):
        """Record the formats as ``fn.formats`` and return ``fn`` itself."""
        setattr(fn, "formats", self.args)
        return fn