import ast
import inspect
import logging
import textwrap
import pypto.frontend.parser.pil as pil
# Module-level logger; run_ast uses it to dump the generated source before
# executing it and to report the numbered source when execution fails.
LOGGER = logging.getLogger(__name__)
class Expr:
    """Instrumented value that records every operation performed on it.

    Construction, item access, access to the generated ``val`` property and
    calls to the static helpers each append a tuple to the class-level
    ``trace`` list, so two executions of the same statements can be compared
    by comparing their traces.
    """

    # Shared by all instances and mutated in place; reset with ``Expr.clear()``.
    trace = []

    def __init__(self, value):
        """Store *value* and record an ``('init', value)`` trace entry."""
        self._item_dict = {}
        self._attr_dict = {}
        self._value = value
        Expr.trace.append(('init', self._value))

    def __getitem__(self, item):
        """Record the lookup and return the value stored under *item*.

        Raises:
            KeyError: if nothing was stored under the normalized key.
        """
        Expr.trace.append(('getitem', self._value, item))
        return self._item_dict[Expr._normalize_item_key(item)]

    def __setitem__(self, item, value):
        """Record the assignment and store *value* under *item*."""
        Expr.trace.append(('setitem', self._value, item, value))
        self._item_dict[Expr._normalize_item_key(item)] = value

    def __delitem__(self, item):
        """Record the deletion and drop the value stored under *item*.

        Raises:
            KeyError: if nothing was stored under the normalized key.
        """
        Expr.trace.append(('delitem', self._value, item))
        del self._item_dict[Expr._normalize_item_key(item)]

    def __eq__(self, other):
        """Compare attributes, items and value; equality is not traced.

        Returns ``NotImplemented`` for non-``Expr`` operands so that comparing
        against a foreign object (e.g. an ``int`` inside a trace tuple)
        evaluates to "not equal" instead of raising ``AttributeError``.
        """
        if not isinstance(other, Expr):
            return NotImplemented
        return (
            self._attr_dict == other._attr_dict
            and self._item_dict == other._item_dict
            and self._value == other._value
        )

    # Defining __eq__ already makes instances unhashable; state it explicitly.
    __hash__ = None

    @property
    def value(self):
        """The value passed to the constructor."""
        return self._value

    @property
    def attr_dict(self):
        """Backing storage for properties generated by ``attr``."""
        return self._attr_dict

    @staticmethod
    def clear():
        """Empty the shared trace in place."""
        Expr.trace.clear()

    @staticmethod
    def true(n):
        """Record ``('true', n)`` and return ``True``."""
        Expr.trace.append(('true', n))
        return True

    @staticmethod
    def false(n):
        """Record ``('false', n)`` and return ``False``."""
        Expr.trace.append(('false', n))
        return False

    @staticmethod
    def str(n):
        """Record ``('str', n)`` and return the text ``str(<n>)``."""
        Expr.trace.append(('str', n))
        return f'str({n})'

    @staticmethod
    def int(n):
        """Record ``('int', n)`` and return *n* unchanged."""
        Expr.trace.append(('int', n))
        return n

    @staticmethod
    def decorate(n):
        """Return a traced pass-through decorator.

        Building the decorator records ``('decorate', n)``; applying it
        records ``('decorate.wrapper', n)`` and returns the function as is.
        """
        Expr.trace.append(('decorate', n))

        def wrapper(func):
            Expr.trace.append(('decorate.wrapper', n))
            return func

        return wrapper

    @staticmethod
    def attr(method_dict, name):
        """Install a traced property called *name* into *method_dict*.

        The property reads, writes and deletes ``self.attr_dict[name]`` and
        records ``getattr`` / ``setattr`` / ``delattr`` trace entries.
        """

        def read(self):
            Expr.trace.append(('getattr', self.value, name))
            return self.attr_dict[name]

        def write(self, value):
            Expr.trace.append(('setattr', self.value, name, value))
            self.attr_dict[name] = value

        def remove(self):
            Expr.trace.append(('delattr', self.value, name))
            del self.attr_dict[name]

        method_dict[name] = property(read, write, remove)

    @staticmethod
    def _normalize_item_key(item):
        """Turn a subscript into a hashable dictionary key.

        Slices become ``('slice', start, stop, step)`` tuples and tuples are
        normalized element by element; anything else is returned unchanged.
        """
        if isinstance(item, slice):
            return (
                'slice',
                Expr._normalize_item_key(item.start),
                Expr._normalize_item_key(item.stop),
                Expr._normalize_item_key(item.step),
            )
        if isinstance(item, tuple):
            return tuple(Expr._normalize_item_key(sub_item) for sub_item in item)
        return item

    # Runs while the class body is executing: ``locals()`` is the class
    # namespace, so this defines the ``val`` property on ``Expr``.
    # ``__func__`` unwraps the staticmethod object so it can be called here.
    attr.__func__(locals(), 'val')

    class ContextManager:
        """Context manager that traces construction, entry and exit.

        For each of *init_n*, *enter_n* and *exit_n* that is not ``None``,
        ``Expr.str`` is called with it at the corresponding moment.
        """

        def __init__(self, enter_n=None, exit_n=None, init_n=None):
            self._enter_n = enter_n
            self._exit_n = exit_n
            if init_n is not None:
                Expr.str(init_n)

        def __enter__(self):
            if self._enter_n is not None:
                Expr.str(self._enter_n)
            return self

        def __exit__(self, *a):
            # Returns None, so exceptions raised in the body are not suppressed.
            if self._exit_n is not None:
                Expr.str(self._exit_n)

        def __eq__(self, other):
            """Compare the enter/exit markers; foreign types are unequal."""
            if not isinstance(other, Expr.ContextManager):
                return NotImplemented
            return self._enter_n == other._enter_n and self._exit_n == other._exit_n

        # Defining __eq__ already makes instances unhashable; state it explicitly.
        __hash__ = None

    class ValueError(Exception):
        """Exception that records ``('error', value)`` when constructed."""

        def __init__(self, value):
            super().__init__(value)
            Expr.trace.append(('error', value))
            self._value = value

        def __eq__(self, other):
            """Compare by carried value only, regardless of the subclass.

            The isinstance check must be spelled ``Expr.ValueError``: inside a
            method a bare ``ValueError`` would resolve to the builtin.
            """
            if not isinstance(other, Expr.ValueError):
                return NotImplemented
            return self._value == other._value

        # Defining __eq__ already makes instances unhashable; state it explicitly.
        __hash__ = None

        @property
        def value(self):
            """The value passed to the constructor."""
            return self._value

    class TypeA(ValueError):
        """Distinct subclass of the traced error."""

    class TypeB(ValueError):
        """Distinct subclass of the traced error."""

    class TypeC(ValueError):
        """Distinct subclass of the traced error."""
class TestParser:
    """Harness that runs a function body twice and compares the outcomes.

    Each registered target's statements are executed once as written and once
    after being passed through ``pil.parse_stmts``; the recorded ``Expr``
    traces and the resulting ``var_*`` globals of both runs must match.
    """

    # pytest honours ``__test__ = False`` and will not collect this helper
    # even though its name starts with ``Test``.
    __test__ = False
    # Functions registered through ``test``; shared at class level.
    target_list = []

    def __init__(self):
        """Reset the shared trace and start with an empty target list."""
        Expr.trace.clear()
        TestParser.target_list = []

    @staticmethod
    def __enter__():
        """Empty the target list on entry; the ``as`` target receives None."""
        TestParser.target_list.clear()

    def __exit__(self, exc_type, exc, tb):
        """Run the registered targets when the ``with`` body finished cleanly.

        If the body itself raised, the targets are skipped so that a failure
        inside ``run`` cannot mask the original exception, which propagates
        because this method returns None.
        """
        if exc_type is None:
            self.run()

    @staticmethod
    def test(target):
        """Register *target* for the next ``run``. Returns None."""
        TestParser.target_list.append(target)

    @staticmethod
    def run_test(global_dict):
        """Call every callable in *global_dict* whose name starts with ``test_``.

        The names are snapshotted first because a test may add or remove
        entries in the mapping (typically a module's ``globals()``) while it
        runs; names added during the run are not executed. Entries that are
        not callable, or that were removed meanwhile, are skipped.
        """
        for test_name in list(global_dict):
            if not test_name.startswith('test_'):
                continue
            candidate = global_dict.get(test_name)
            if callable(candidate):
                candidate()

    @staticmethod
    def run_ast(stmt_list):
        """Execute *stmt_list* and return ``(trace, var_dict)``.

        The statements are rendered with ``ast.unparse`` and executed with
        ``Expr`` as the only pre-defined global. ``trace`` is a copy of
        ``Expr.trace`` after execution and ``var_dict`` holds every resulting
        global whose name starts with ``var_``.

        Raises:
            Exception: whatever the executed source raises, re-raised after
                the line-numbered source has been logged.
        """
        Expr.clear()
        src = ast.unparse(stmt_list)
        exec_global = {'Expr': Expr}
        LOGGER.debug('%s', '-' * 100)
        LOGGER.debug('%s', src)
        try:
            # The source comes from the test functions themselves, not from
            # external input.
            exec(src, exec_global)
        except Exception:
            numbered_src = '\n'.join(
                f'{lineno:3d} | {line}'
                for lineno, line in enumerate(src.strip().split('\n'), start=1)
            )
            LOGGER.exception("Failed to exec generated source:\n%s", numbered_src)
            raise
        run_trace = Expr.trace[:]
        run_vardict = {
            name: value
            for name, value in exec_global.items()
            if name.startswith('var_')
        }
        return run_trace, run_vardict

    def run(self):
        """Check every registered target against its ``pil``-parsed form.

        Raises:
            AssertionError: if the traces or the ``var_*`` globals differ.
        """
        for target in TestParser.target_list:
            source_lines, _ = inspect.getsourcelines(target)
            source = textwrap.dedent(''.join(source_lines))
            # The first top-level node is the target's own definition; its
            # body is the statement list under test.
            stmt_list = ast.parse(source).body[0].body
            python_trace, python_vardict = self.run_ast(stmt_list)
            # NOTE(review): presumably parse_stmts returns statements that
            # ast.unparse can render - confirm against the pil module.
            pil_trace, pil_vardict = self.run_ast(pil.parse_stmts(stmt_list))
            assert python_trace == pil_trace, f'{target.__name__}: {python_trace=} {pil_trace=}'
            assert python_vardict == pil_vardict, f'{target.__name__}: {python_vardict=} {pil_vardict=}'