"""
===============================================================
A script to generate FileCheck statements for mlir unit tests.
===============================================================
This script is a utility to add FileCheck patterns to an mlir file.
NOTE: The input ``.mlir`` is expected to be the output from the parser, not a
stripped down variant.
Example usage:
.. code-block:: shell
$ generate-test-checks.py foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'
The script will heuristically generate CHECK/CHECK-LABEL commands for each line
within the file. By default this script will also try to insert string
substitution blocks for all SSA value names. If ``--source file`` is specified, the
script will attempt to insert the generated CHECKs to the source file by looking
for line positions matched by ``--source_delim_regex``.
The script is designed to make adding checks to a test case fast, it is *not*
designed to be authoritative about what constitutes a good test!
"""
import argparse
import os
import re
import sys
from typing import Optional
# Header written at the top of every generated file. The script path is
# inserted between ADVERT_BEGIN and ADVERT_END.
ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "
ADVERT_END = """
// The script is designed to make adding checks to
// a test case fast, it is *not* designed to be authoritative
// about what constitutes a good test! The CHECK should be
// minimized and named to reflect the test intent.
"""

# Regex text for an SSA value name *without* its leading '%': either a purely
# numeric id or an identifier. NOTE: this is a bare top-level alternation, so
# it must be wrapped in a group before being embedded in a larger pattern.
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)

# A single operation result: '%' + name, optionally followed by ':<count>'
# (the result-pack form, e.g. '%0:2'). The name alternation is wrapped in a
# non-capturing group so that the '%' prefix applies to both alternatives and
# the capture-group numbering of SSA_RESULTS_STR is unchanged.
_SSA_RESULT_STR = r'%(?:' + SSA_RE_STR + r')(?::[0-9]+)?'

# Regex matching the left-hand side of an assignment: a comma-separated list
# of results followed by '='.
SSA_RESULTS_STR = r'\s*(' + _SSA_RESULT_STR + r')(\s*,\s*(' + _SSA_RESULT_STR + r'))*\s*='
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)

# Regex matching an attribute alias reference such as '#map0'.
ATTR_RE_STR = r'(#[a-zA-Z._-][a-zA-Z0-9._-]*)'
ATTR_RE = re.compile(ATTR_RE_STR)

# Regex matching the start of an attribute alias definition: '#name ='.
ATTR_DEF_RE_STR = r'\s*' + ATTR_RE_STR + r'\s*='
ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR)
class VariableNamer:
    """Hands out FileCheck variable names for SSA values, tracked per name scope.

    ``variable_names`` is a comma-separated string of user-chosen names that
    are consumed in order, one per generated variable. An empty entry (or
    running out of entries) falls back to an auto-numbered ``VAL_<n>`` name.
    """

    def __init__(self, variable_names):
        # Stack of {ssa_name: check_variable_name} dicts; innermost scope last.
        self.scopes = []
        # Next index used for default "VAL_<n>" names.
        self.name_counter = 0
        # How many upcoming names must be recorded in the parent scope.
        self.generate_in_parent_scope_left = 0
        # Strip each entry so that "DIM, SUM" does not yield the name " SUM",
        # and so that a whitespace-only entry counts as "use the default".
        self.variable_names = [name.strip().upper() for name in variable_names.split(',')]
        self.used_variable_names = set()

    def generate_in_parent_scope(self, n):
        """Record the next ``n`` generated names in the parent of the current scope."""
        self.generate_in_parent_scope_left = n

    def generate_name(self, source_variable_name):
        """Create, record and return a new check name for ``source_variable_name``.

        Raises:
            RuntimeError: if the chosen name was already handed out since the
                last ``clear_names()``.
        """
        # Prefer the next user-supplied name; fall back to an auto-numbered one.
        variable_name = self.variable_names.pop(0) if self.variable_names else ''
        if not variable_name:
            variable_name = "VAL_" + str(self.name_counter)
            self.name_counter += 1

        # Pick the scope that will own the mapping: the innermost one, or its
        # parent while a generate_in_parent_scope() budget is still pending.
        scope = len(self.scopes) - 1
        if self.generate_in_parent_scope_left > 0:
            self.generate_in_parent_scope_left -= 1
            scope = len(self.scopes) - 2
        assert scope >= 0, "no name scope available to record the variable in"

        if variable_name in self.used_variable_names:
            raise RuntimeError(variable_name + ': duplicate variable name')
        self.scopes[scope][source_variable_name] = variable_name
        self.used_variable_names.add(variable_name)
        return variable_name

    def push_name_scope(self):
        """Open a new, empty innermost name scope."""
        self.scopes.append({})

    def pop_name_scope(self):
        """Discard the innermost name scope."""
        self.scopes.pop()

    def num_scopes(self):
        """Return the current scope nesting depth."""
        return len(self.scopes)

    def clear_names(self):
        """Restart default numbering and forget which names were used.

        The scope stack and any remaining user-supplied names are untouched.
        """
        self.name_counter = 0
        self.used_variable_names = set()
class AttributeNamer:
    """Hands out FileCheck variable names for attribute aliases.

    ``attribute_names`` is a comma-separated string of user-chosen names that
    are consumed in order, one per generated attribute. An empty entry (or
    running out of entries) falls back to an auto-numbered ``ATTR_<n>`` name.
    Every generated name is prefixed with ``$``.
    """

    def __init__(self, attribute_names):
        # Next index used for default "ATTR_<n>" names.
        self.name_counter = 0
        # Strip each entry so that "MAP0, MAP1" does not yield the name
        # " MAP1", and so that a whitespace-only entry means "use the default".
        self.attribute_names = [name.strip().upper() for name in attribute_names.split(',')]
        # Maps the attribute name as written in the input to its check name.
        self.map = {}
        self.used_attribute_names = set()

    def generate_name(self, source_attribute_name):
        """Create, record and return a new check name for ``source_attribute_name``.

        Raises:
            RuntimeError: if the chosen name has already been handed out.
        """
        # Prefer the next user-supplied name; fall back to an auto-numbered one.
        attribute_name = self.attribute_names.pop(0) if self.attribute_names else ''
        if not attribute_name:
            attribute_name = "ATTR_" + str(self.name_counter)
            self.name_counter += 1
        attribute_name = '$' + attribute_name

        if attribute_name in self.used_attribute_names:
            raise RuntimeError(attribute_name + ': duplicate attribute name')
        self.map[source_attribute_name] = attribute_name
        self.used_attribute_names.add(attribute_name)
        return attribute_name

    def get_name(self, source_attribute_name) -> Optional[str]:
        """Return the check name recorded for the attribute, or None if unknown."""
        return self.map.get(source_attribute_name)
def get_num_ssa_results(input_line):
    """Count the '%' characters in the result list at the start of ``input_line``.

    The result list is whatever SSA_RESULTS_RE matches at the beginning of the
    line; 0 is returned when the line does not start with one.
    """
    results = SSA_RESULTS_RE.match(input_line)
    if results is None:
        return 0
    return results.group().count('%')
def process_line(line_chunks, variable_namer):
    """Rewrite the '%'-separated chunks of a line as FileCheck variable uses.

    Each chunk is the text that followed a '%' in the original line. The SSA
    name at its start is replaced by ``%[[NAME]]`` when some scope already
    knows the name, or by a fresh definition ``%[[NAME:.*]]`` otherwise; the
    remainder of the chunk is copied through unchanged.

    Returns the rewritten text, right-stripped and newline-terminated.
    """
    pieces = []
    for chunk in line_chunks:
        name_match = SSA_RE.match(chunk)
        ssa_name = '' if name_match is None else name_match.group(0)

        # First scope, in stack order, that already has a name for this value.
        candidates = (scope.get(ssa_name) for scope in variable_namer.scopes)
        known = next((name for name in candidates if name is not None), None)

        if known is None:
            fresh = variable_namer.generate_name(ssa_name)
            pieces.append("%[[" + fresh + ":.*]]")
        else:
            pieces.append("%[[" + known + "]]")

        # Whatever followed the SSA name is passed through verbatim.
        pieces.append(chunk[len(ssa_name):])

    return "".join(pieces).rstrip() + "\n"
def process_source_lines(source_lines, note, args):
    """Split an existing test file into segments, dropping generated content.

    Args:
        source_lines: lines of the source file, without trailing whitespace.
        note: the (possibly multi-line) autogenerated header text.
        args: namespace providing ``source_delim_regex`` and ``check_prefix``.

    Returns:
        A list of segments, each a list of newline-terminated lines. A new
        segment starts at every line matching ``args.source_delim_regex``;
        the first segment holds everything before the first delimiter.
    """
    source_split_re = re.compile(args.source_delim_regex)

    # The note spans several lines while ``source_lines`` holds single lines,
    # so a previous header can only be recognised line by line. Blank lines
    # are excluded here, otherwise every empty source line would be dropped.
    note_lines = {note_line.rstrip() for note_line in note.splitlines()}
    note_lines.discard("")

    source_segments = [[]]
    previous_was_note = False
    for line in source_lines:
        # Remove a previously generated note.
        if line == note or line in note_lines:
            previous_was_note = True
            continue
        # The header is emitted with a blank line after each of its
        # paragraphs; drop that single blank line too so that regenerating
        # a file does not accumulate empty lines.
        if previous_was_note and not line:
            previous_was_note = False
            continue
        previous_was_note = False

        # Remove previous CHECK lines.
        if line.find(args.check_prefix) != -1:
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])
        source_segments[-1].append(line + "\n")
    return source_segments
def process_attribute_definition(line, attribute_namer, output):
    """Append a CHECK line to ``output`` if ``line`` defines an attribute alias.

    When ``line`` starts with an attribute definition (per ATTR_DEF_RE), a new
    check name is generated for it and a ``// CHECK:`` line capturing the
    alias is appended to ``output``. Other lines leave ``output`` untouched.
    """
    definition = ATTR_DEF_RE.match(line)
    if not definition:
        return
    check_name = attribute_namer.generate_name(definition.group(1))
    # Everything after the matched '#name =' prefix is kept verbatim.
    remainder = line[len(definition.group(0)):]
    output.append('// CHECK: #[[' + check_name + ':.+]] =' + remainder + '\n')
def process_attribute_references(line, attribute_namer):
    """Replace known attribute aliases in ``line`` with FileCheck variable uses.

    The line is split on ATTR_RE; every piece that is an alias with a recorded
    check name becomes ``#[[NAME]]``, and all other pieces are copied as-is.
    """
    rewritten = []
    for piece in ATTR_RE.split(line):
        hit = ATTR_RE.match(piece)
        check_name = attribute_namer.get_name(hit.group(1)) if hit else None
        if check_name is not None:
            rewritten.append('#[[' + check_name + ']]')
            # Keep any text of the piece that follows the alias itself.
            rewritten.append(piece[len(hit.group()):])
        else:
            rewritten.append(piece)
    return ''.join(rewritten)
def preprocess_line(line):
    """Return ``line`` with '[[' and '[%' rewritten as '{{...}}' blocks.

    The replacements are applied in order: every '[[' first, then every '[%'.
    """
    escapes = (
        ("[[", "{{\\[\\[}}"),
        ("[%", "{{\\[}}%"),
    )
    for needle, escaped in escapes:
        line = line.replace(needle, escaped)
    return line
def _generate_check_segments(input_lines, args):
    """Turn parser output lines into segments of FileCheck directive lines.

    Args:
        input_lines: lines of the input, without trailing whitespace.
        args: parsed command-line arguments (uses ``variable_names``,
            ``attribute_names``, ``starts_from_scope`` and ``check_prefix``).

    Returns:
        A list of segments, each a list of newline-terminated check lines.
        A new segment is started whenever a scope is opened at nesting level
        ``args.starts_from_scope``.
    """
    output_segments = [[]]
    variable_namer = VariableNamer(args.variable_names)
    attribute_namer = AttributeNamer(args.attribute_names)

    for input_line in input_lines:
        if not input_line:
            continue

        # Check if this is an attribute definition and process it.
        process_attribute_definition(input_line, attribute_namer, output_segments[-1])

        # Lines starting with '^' have everything from the last '//' onwards
        # stripped before further processing.
        lstripped_input_line = input_line.lstrip()
        if lstripped_input_line.startswith("^"):
            input_line = input_line.rsplit("//", 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line.startswith("}"):
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line.endswith("{"):
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

            # Results defined on this line must still be recorded in the
            # parent scope rather than in the scope that was just opened.
            num_ssa_results = get_num_ssa_results(input_line)
            variable_namer.generate_in_parent_scope(num_ssa_results)

        # Omit lines above the requested nesting level.
        if cur_level < args.starts_from_scope:
            continue

        # Restart name numbering at the beginning of every segment.
        if len(output_segments[-1]) == 0:
            variable_namer.clear_names()

        # Rewrite '[[' and '[%' sequences before emitting the line.
        input_line = preprocess_line(input_line)

        # Process uses of attributes in this line.
        input_line = process_attribute_references(input_line, attribute_namer)

        # Split the line at each SSA value name.
        ssa_split = input_line.split("%")

        # The first line of a segment becomes a '-LABEL' line (followed by a
        # '-SAME' line) unless it starts with an SSA value; every other line
        # is a plain check line.
        if len(output_segments[-1]) != 0 or not ssa_split[0]:
            output_line = "// " + args.check_prefix + ": "
            # Pad by the width of the '-LABEL' suffix.
            output_line += " " * len("-LABEL")
            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]
            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)
        else:
            # Output the first line chunk that does not contain an SSA name
            # for the label.
            output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n"
            # Process the rest of the input line on a separate check line.
            output_line += "// " + args.check_prefix + "-SAME: "
            output_line += process_line(ssa_split[1:], variable_namer)

        output_segments[-1].append(output_line)

    return output_segments


def main():
    """Command-line entry point: read the input, generate checks, write them.

    The complete output is assembled in memory before any destination file is
    opened for writing, so that a failure while generating the checks cannot
    leave an in-place (``-i``) target truncated.

    Raises:
        RuntimeError: if ``--source`` is given and the number of generated
            check segments differs from the number of source segments.
    """
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--check-prefix", default="CHECK", help="Prefix to use from check file.")
    parser.add_argument("-o", "--output", nargs="?", type=argparse.FileType("w"), default=None)
    parser.add_argument("input", nargs="?", type=argparse.FileType("r"), default=sys.stdin)
    parser.add_argument(
        "--source",
        type=str,
        help="Print each CHECK chunk before each delimiter line in the source "
        "file, respectively. The delimiter lines are identified by "
        "--source_delim_regex.",
    )
    parser.add_argument("--source_delim_regex", type=str, default="func @")
    parser.add_argument(
        "--starts_from_scope",
        type=int,
        default=1,
        help="Omit the top specified level of content. For example, by default "
        'it omits "module {"',
    )
    parser.add_argument("-i", "--inplace", action="store_true", default=False)
    parser.add_argument(
        "--variable_names", type=str, default='',
        help="Names to be used in FileCheck regular expression to represent SSA "
        "variables in the order they are encountered. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')")
    parser.add_argument(
        "--attribute_names", type=str, default='',
        help="Names to be used in FileCheck regular expression to represent "
        "attributes in the order they are defined. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')")
    args = parser.parse_args()

    # Report invalid flag combinations as usage errors. An assert would be
    # skipped under -O, and open(None) would fail with an unhelpful error.
    if args.inplace:
        if not args.source:
            parser.error("-i/--inplace requires --source")
        if args.output is not None:
            parser.error("-i/--inplace cannot be combined with -o/--output")

    # Read the given input.
    input_lines = [l.rstrip() for l in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END

    source_segments = None
    if args.source:
        with open(args.source, "r") as source_file:
            source_lines = [l.rstrip() for l in source_file]
        source_segments = process_source_lines(source_lines, autogenerated_note, args)

    output_segments = _generate_check_segments(input_lines, args)

    # Assemble the whole output before touching the destination.
    rendered = [autogenerated_note + "\n"]
    if source_segments:
        if len(output_segments) != len(source_segments):
            raise RuntimeError(
                "generated %d check segment(s) but the source file has %d segment(s); "
                "check --source_delim_regex and --starts_from_scope"
                % (len(output_segments), len(source_segments))
            )
        # Interleave: each check segment goes right before its source segment.
        for check_segment, source_segment in zip(output_segments, source_segments):
            rendered.extend(check_segment)
            rendered.extend(source_segment)
    else:
        for segment in output_segments:
            rendered.append("\n")
            rendered.extend(segment)
            rendered.append("\n")

    if args.inplace:
        with open(args.source, "w") as output:
            output.writelines(rendered)
    elif args.output is None:
        # Leave sys.stdout open; it is not ours to close.
        sys.stdout.writelines(rendered)
    else:
        with args.output as output:
            output.writelines(rendered)
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()