import argparse
import bisect
import difflib
import logging
import platform
import re
import sys
from functools import partial
from textwrap import indent

from utils import complete_path, process_condition, parse_condition_from_str_to_set, escape, safe_print
# Default rule prefixes recognised in test-case comments (overridable from the CLI).
ASSERT_FLAG = "ASSERT"
EXPECTED_FLAG = "EXPECTED"

# Prefix prepended to every EXPECTED pattern; formatted with the rule's source line
# number, it requires ":<line_num>:" followed later by another ":" in the output.
EXPECTED_REGEX = r"\:{line_num}\:.*\:.*"

# Heuristic for "this text contains a /unix/style/path:<digits>" (used to decide
# whether '/' should be rewritten to '\' on Windows).
PATH_PATTERN = r'(/[^<>"\\|?]+)+\:\d+'

# Modifier keywords accepted after the `scan` / `regex` mode word.
SCAN_KEYWORDS = ["full", "not", "begin", "next", "end", "after", "txt"]
REGEX_KEYWORDS = ["auto", "not", "begin", "next", "end", "after"]

# Pattern flags allowed on EXPECTED rules.
EXPECTED_KEYWORDS = ["regex", "regex-auto", "regex-not"]

# SGR colour numbers for the {COLOR=<name>} directive (insertion order is shown in errors).
COLOR_CODES = {
    "RED": 31,
    "GREEN": 32,
    "YELLOW": 33,
    "BLUE": 34,
    "MAGENTA": 35,
    "CYAN": 36,
    "GRAY": 90,
}

# Output fragments that make an automatic --transfer rewrite unsafe.
TRANSFER_BLACK_LIST = ["Internal error"]

# Named-group captures recorded by regex compares, substituted later by update_pattern.
global_group_dict = {}

IS_WINDOWS = platform.system() == "Windows"

# Header of a multi-line SCAN comment: SCAN optionally followed by -IN / -TXT modifiers.
multiline_pattern = r'(SCAN(-IN|-TXT)*)'
# Start of any /* SCAN comment.
multiscan_pattern = r'/\*\s*SCAN'
class CompareError(Exception):
    """Raised when a compare rule in a test case is malformed or unsupported."""
def check_condition(compare_lines, conditions):
    """Keep the compare lines that apply under the active *conditions*.

    Each entry of *compare_lines* is a ``[text, line_number]`` list.  Entries
    whose text starts with the default ASSERT/EXPECTED flag are kept unchanged.
    Otherwise, a leading ``(cond)`` guard is evaluated with ``process_condition``.
    If it holds, a copy of the entry with the guard removed is kept; entries
    without a satisfied guard are dropped.

    Returns a new list; the input entries are not mutated.
    """
    res = []
    for line in compare_lines:
        src = line[0]
        if src.startswith((ASSERT_FLAG, EXPECTED_FLAG)):
            res.append(line)
            continue
        cur_con = re.match(r'\([0-9a-zA-Z_\|\&\!\(\) ]+\)', src)
        if cur_con and process_condition(cur_con.group(), conditions):
            tmp = line.copy()
            # Strip only the leading guard. str.replace would also delete any later
            # occurrence of the same text inside the rule itself.
            tmp[0] = src[cur_con.end():]
            res.append(tmp)
    return res
def multiline_check(input_str, expect_str: str, in_mode, txt_mode):
    """Compare a multi-line SCAN body (*input_str*) with the tool output (*expect_str*).

    On Windows, outside TXT mode, when the SCAN text contains a path-like token,
    every '/' in both strings becomes '\\' before comparing.

    * in_mode (SCAN-IN): the stripped SCAN text must be non-empty and appear as
      a substring of the output.
    * otherwise: both texts must have exactly the same lines.

    A line diff is printed on failure.  Returns True/False.
    """
    def show_diff(left_lines, right_lines):
        # Best-effort diagnostic: a print failure must not mask the compare result.
        try:
            safe_print('\n'.join(list(difflib.Differ().compare(left_lines, right_lines))))
        except Exception as e:
            safe_print('\nMultiLine SCAN-IN Compare Failed!')
            safe_print("[ERROR] Output string print error. {}".format(e.args))

    if IS_WINDOWS and not txt_mode and re.search(PATH_PATTERN, input_str):
        input_str = input_str.replace('/', '\\')
        expect_str = expect_str.replace('/', '\\')

    if in_mode:
        input_str = input_str.strip()
        if input_str and input_str in expect_str:
            return True
        show_diff(input_str.splitlines(), expect_str.splitlines())
        return False

    scan_lines = input_str.splitlines()
    output_lines = expect_str.splitlines()
    if scan_lines == output_lines:
        return True
    show_diff(scan_lines, output_lines)
    return False
def main():
    """Check the captured tool output against every compare rule of a test case.

    The output comes from stdin, or from ``--compare_object`` when that names a file.
    Single-line ASSERT/EXPECTED rules and multi-line ``/* SCAN ... */`` blocks are
    extracted from the case file and checked.  With ``--transfer`` the case file
    is then rewritten into a multi-line SCAN case.

    Prints ``[Compare Pass]`` and returns 0 when everything matches.  Otherwise
    prints ``[Compare Fail]`` and exits with status 1.  Also exits with status 1
    when the case has no rules, and with 253 when asked to read from a terminal.
    """
    opts = parse_cli()
    case_path = opts.case_path
    com_opt = opts.com_opt
    comment = opts.comment
    condition = parse_condition_from_str_to_set(opts.condition)
    compare_object = opts.compare_object
    assert_flags = opts.assert_flag
    transfer = opts.transfer
    compare_number = opts.compare_number
    if not assert_flags:
        assert_flags.append(ASSERT_FLAG)
    expected_flags = opts.expected_flag
    if not expected_flags:
        expected_flags.append(EXPECTED_FLAG)

    content = _read_compare_object(compare_object)
    content_line_map = gen_line_map(content)
    safe_print("compare.py input:")
    safe_print(indent(content, "\t", lambda line: True))
    safe_print("compare.py input end\n")

    compare_line_regex = gen_compare_regex(comment, assert_flags, expected_flags)
    compare_lines = extract_compare_lines(case_path, compare_line_regex)
    compare_lines = check_condition(compare_lines, condition)
    multiline_compares = parse_all_multiline_comment(case_path, compare_number)
    if not compare_lines and not multiline_compares:
        safe_print("[ERROR]`ASSERT` or `EXPECTED` or 'SCAN' key words not found in your case!")
        sys.exit(1)

    if com_opt == 'O2':
        _relax_line_numbers(compare_lines)

    compare_result = True
    if compare_lines:
        compare_result &= _run_single_line_compares(
            compare_lines, content, content_line_map, assert_flags, expected_flags
        )
    if multiline_compares and not _run_multiline_compares(multiline_compares, content):
        compare_result = False

    if transfer:
        _transfer_to_multiline_case(transfer, content, multiline_compares)

    if compare_result is True:
        safe_print("[Compare Pass]")
        return 0
    safe_print("[Compare Fail]")
    sys.exit(1)


def _read_compare_object(compare_object):
    """Return the full text of the output to check.

    *compare_object* is an open text stream (stdin by default) or a file path
    string, which is what argparse produces for ``--compare_object <path>``.
    """
    if compare_object is None:
        # `--compare_object` given without a value: fall back to stdin.
        compare_object = sys.stdin
    if isinstance(compare_object, str):
        with open(compare_object, "r", encoding="utf-8") as stream:
            return _read_compare_stream(stream)
    return _read_compare_stream(compare_object)


def _read_compare_stream(stream):
    """Read *stream* to EOF; exit 253 on a terminal, exit 1 on a decoding error."""
    # Check before reading: read() on a TTY would block until the user sends EOF.
    if stream.isatty():
        sys.stderr.write("ERROR: require compare objects, filepath or stdin \n")
        sys.exit(253)
    try:
        return stream.read()
    except UnicodeDecodeError as e:
        safe_print("Wrong stdin encoding:{}\n{}".format(e.object, e))
        sys.exit(1)


def _relax_line_numbers(compare_lines):
    """For the O2 compile option, replace literal ':<n>:<n>:' style numbers with digit wildcards.

    Mutates *compare_lines* in place.  Rules containing 'scan' are turned into
    'regex' rules so the inserted wildcards are interpreted as a regex.
    """
    for pattern in (r':[0-9]+:[0-9]+:', r':[0-9]+:', r':[0-9]+'):
        matcher = re.compile(pattern)
        for entry in compare_lines:
            if 'regex' in entry[0]:
                entry[0] = matcher.sub(pattern, entry[0])
            elif 'scan' in entry[0]:
                entry[0] = matcher.sub(pattern, entry[0]).replace('scan', 'regex')


def _run_single_line_compares(compare_lines, content, content_line_map, assert_flags, expected_flags):
    """Check each single-line ASSERT/EXPECTED rule in order; return True when all pass."""
    compare_result = True
    start = 0  # character offset in the output where cursor-relative rules continue
    safe_print("Starting SingleLine Compare:")
    for compare_line, line_num in compare_lines:
        output_line_num = text_index_to_line_num(content_line_map, start)
        flag, compare_pattern = split_compare_line(compare_line)
        pattern_flag, pattern = split_pattern_line(compare_pattern)
        info = ""
        keywords = pattern_flag.split("-")
        if flag.strip() in assert_flags:
            info = "It's a assert, "
        elif flag.strip() in expected_flags:
            if "auto" in keywords:
                pattern = r"\s+".join([re.escape(word) for word in pattern.split()])
                keywords.remove("auto")
            # EXPECTED rules are anchored to the case's own source line number.
            pattern = EXPECTED_REGEX.format(line_num=line_num) + pattern
            if pattern_flag.strip() not in EXPECTED_KEYWORDS:
                raise CompareError(
                    "Unsupport expected keywords: {!r}".format(pattern_flag)
                )
        else:
            raise CompareError("Unsupport flag: {!r}".format(flag))
        count = -1
        for keyword in keywords:
            if keyword.isnumeric():
                count = int(keyword)
        match_func = gen_match_func(keywords)
        # Only next/end/after continue from the previous match; everything else restarts.
        if "next" not in keywords and "end" not in keywords and "after" not in keywords:
            start = 0
        if "scan" in keywords:
            single_txt = "txt" in keywords
            if IS_WINDOWS and not single_txt:
                if re.findall(PATH_PATTERN, pattern):
                    pattern = pattern.replace('/', '\\\\')
        pattern = update_pattern(pattern)
        result, start = match_func(content, content_line_map, pattern, start, count=count)
        info += "flag: {}, pattern: {} , result: {}, matched at output line: {} ".format(
            pattern_flag, pattern, result, output_line_num
        )
        safe_print(info.encode(encoding='utf-8', errors='ignore'))
        safe_print('back reference dict:{}'.format(global_group_dict))
        if result is False:
            safe_print("SingleLine Compare Failed")
        else:
            safe_print('SingleLine Compare End')
        compare_result &= result
    return compare_result


def _run_multiline_compares(multiline_compares, content):
    """Check every multi-line SCAN comment against *content*; return True when all pass."""
    compare_result = True
    header_regex = r'/\*\s*' + multiline_pattern + r'\n'
    for index, block in enumerate(multiline_compares, start=1):
        safe_print("Start MultiLine Compare {}: {} ...".format(index, block[:20]))
        header = re.match(header_regex, block)
        if header is None:
            raise CompareError(
                "Malformed SCAN comment, expected a line break right after the SCAN header: {!r}".format(block[:20])
            )
        # Body between the header line and the closing '*/'.
        multi_str = block[len(header.group()):-2]
        multi_str = handle_ansi_attributes(multi_str)
        res = multiline_check(input_str=multi_str, expect_str=content,
                              in_mode='-IN' in header.group(),
                              txt_mode='-TXT' in header.group())
        if not res:
            compare_result = False
    return compare_result


def _transfer_to_multiline_case(transfer, content, multiline_compares):
    """Rewrite the case file *transfer* so its checks become SCAN-based.

    Single-line ASSERT/EXPECTED comments and existing SCAN blocks are removed.
    If the output has no ``"DiagKind"`` entries, it is appended as a new
    ``/* SCAN ... */`` block.  Otherwise a ``// ASSERT: scan <kind>`` line is
    added for each kind not already present in the file.
    """
    diagkind = re.findall(r'("DiagKind": .*")', str(content))
    safe_print("[WARNING] trying to transfer cur case {} into multi line compare case.".format(transfer))
    with open(transfer, 'r', encoding='utf-8') as f:
        src = f.read()
    pattern = r'(//.*ASSERT.*|//.*EXPECTED.*)'
    new_content = re.sub(pattern, "", src).rstrip()
    # Pass re.M via flags=: the fourth positional argument of re.sub is `count`.
    new_content = re.sub(r'/\*.*SCAN[\s\S]*\*/', "", new_content, flags=re.M).rstrip()
    content = re.sub(pattern, "", str(content))
    for multiline_compare in multiline_compares or []:
        new_content = new_content.replace(multiline_compare, "").rstrip()
    for bad_words in TRANSFER_BLACK_LIST:
        if bad_words in content:
            raise Exception('Unexpected output <{}> found! Please update case manually!'.format(bad_words))
    if len(diagkind) == 0:
        new_content = new_content + "\n\n/* SCAN\n{}*/\n".format(content)
    else:
        for diag in diagkind:
            diag = diag.split("\": \"")[1].rstrip("\"")
            if diag not in new_content:
                new_content = new_content.rstrip() + "\n// ASSERT: scan {}\n".format(diag)
    with open(transfer, 'w', encoding='utf-8') as f:
        f.write(new_content)
    safe_print("[WARNING] {} transfer into multi line compare case done.".format(transfer))
def handle_ansi_attributes(input_str):
    """Expand colour/escape directives in *input_str* into ANSI escape sequences.

    Supported directives:
      * ``{COLOR=<NAME>}`` pushes a colour (a name from COLOR_CODES) and emits its SGR code;
      * ``{/COLOR}`` pops it, emits a reset, then re-emits the enclosing colour if any;
      * ``{ESC}`` emits a raw ESC character.

    Raises ValueError for an unknown colour or an unbalanced ``{/COLOR}``.
    """
    active_colors = []

    def expand(found):
        color = found.group("color")
        if color:
            if color not in COLOR_CODES:
                raise ValueError("Unsupported color '%s' in color directive. Supported colors: %s"
                                 % (color, ", ".join(COLOR_CODES.keys())))
            active_colors.append(color)
            return "\x1b[%dm" % COLOR_CODES[color]
        if found.group("unset_color"):
            if not active_colors:
                raise ValueError("There is no corresponding set color directive for the unset directive {/COLOR}.")
            active_colors.pop()
            restore = "\x1b[%dm" % COLOR_CODES[active_colors[-1]] if active_colors else ""
            return "\x1b[0m%s" % restore
        if found.group("escape"):
            return "\x1b"
        raise ValueError("The ansi code directive should be either {COLOR=<color_name>} or {/COLOR}")

    directive = r"{(?:COLOR=(?P<color>\w+)|(?P<unset_color>/COLOR)|(?P<escape>ESC))}"
    return re.sub(directive, expand, input_str)
def split_compare_line(compare_line):
    """Split a rule line into ``[flag, rest]`` at its first ':'.

    Leading whitespace is dropped first.  Raises CompareError (after printing the
    same message) when the line has no ':' and therefore no compare content.
    """
    parts = compare_line.lstrip().split(":", 1)
    if len(parts) >= 2:
        return parts
    message = "Please check compare line, found compare flag but no actual compare content!!!"
    safe_print(message)
    raise CompareError(message)
def split_pattern_line(compare_pattern):
    """Split ``"<keywords> <pattern>"`` at the first space after left-stripping.

    Returns ``(pattern_flag, pattern)``; *pattern* is ``""`` when there is no space.
    """
    head, separator, tail = compare_pattern.lstrip().partition(" ")
    if not separator:
        return head, ""
    return head, tail
# Becomes True once any compare line uses `begin`; `next` and `after` require it.
begin_case = False


def gen_match_func(keywords):
    """Build the matcher for a rule from its dash-separated *keywords*.

    ``keywords[0]`` picks the base matcher (``scan`` or ``regex``).  Each later
    keyword, in order, either wraps the current matcher (auto/not/next/after/
    begin/numeric count) or replaces it outright (end/full).  Keywords that are
    valid but need no wrapper (e.g. ``txt``) leave it unchanged.

    Raises CompareError for an unknown mode or keyword, and Exception when
    ``next``/``after`` is used before any ``begin``.
    """
    global begin_case
    assert_mode = keywords[0]
    base_table = {
        "scan": (scan_match, SCAN_KEYWORDS),
        "regex": (regex_match, REGEX_KEYWORDS),
    }
    if assert_mode not in base_table:
        raise CompareError("scan/regex mode: {} is not valid".format(assert_mode))
    match_func, valid_keywords = base_table[assert_mode]

    wrapping = {
        "auto": auto_regex_match,
        "not": not_match,
        "next": next_match,
        "after": after_match,
        "begin": begin_match,
    }
    replacing = {"end": end_match, "full": full_match}
    begin_required = {
        "next": "Please use `next` after `begin`!",
        "after": "Please use `after` after `begin`!",
    }

    for keyword in keywords[1:]:
        is_count = keyword.isnumeric()
        if not is_count and keyword not in valid_keywords:
            raise CompareError(
                "keyword: {} is not valid for {}".format(keyword, assert_mode)
            )
        if keyword in begin_required and not begin_case:
            raise Exception(begin_required[keyword])
        if keyword == "begin":
            begin_case = True
        if keyword in replacing:
            match_func = replacing[keyword]
        elif keyword in wrapping:
            match_func = partial(wrapping[keyword], match_func=match_func)
        elif is_count:
            match_func = partial(num_match, match_func=match_func)
    return match_func
def update_pattern(pattern, group_dict=None):
    """Substitute previously captured named groups into *pattern*.

    Every ``(?P=name)`` placeholder is replaced by the text captured for
    ``name``.  The values come from *group_dict*, which defaults to the module's
    ``global_group_dict``, filled by earlier regex compares.  Groups recorded as
    ``None`` (they did not take part in their match) are skipped, so their
    placeholder is left untouched instead of crashing ``str.replace``.
    """
    if group_dict is None:
        group_dict = global_group_dict
    for name, value in group_dict.items():
        if value is None:
            continue
        pattern = pattern.replace('(?P={})'.format(name), value)
    return pattern
def is_valid_pattern(pattern):
    """Return False if *pattern* fails to compile as a regex, else True.

    A compile error is logged and rejected.  A non-string pattern (TypeError from
    ``re.compile``) is logged but still accepted.  NOTE(review): accepting it
    relies on regex_match passing ``str(pattern)`` to the matcher.
    """
    try:
        re.compile(pattern)
    except re.error:
        logging.error("Error pattern: {!r}".format(pattern))
        return False
    except TypeError:
        # logging treats the first argument as a %-format string, so it must be a str.
        logging.error("Unexpected pattern type %s: %r", type(pattern), pattern)
    return True
def regex_match(content, line_map, pattern, start=0, count=-1):
    """Search *content* for the regex *pattern* (MULTILINE semantics).

    count == -1: use the first match.  Its named groups are recorded in
    ``global_group_dict``.  Returns ``(True, cursor)``, where cursor is the
    match end (offset by *start*) if that falls on the last line of
    *line_map*, else one past the end offset of that line.  Returns
    ``(False, start)`` when nothing matches.

    count >= 0: succeed only if there are exactly *count* non-overlapping matches.

    Raises CompareError for a pattern that does not compile.
    """
    if not is_valid_pattern(pattern):
        raise CompareError("Not valid pattern: {!r}".format(pattern))
    if count != -1:
        # Count with the same flags as the single-match path so ^/$ anchor per line.
        found = sum(1 for _ in re.finditer(str(pattern), content, re.MULTILINE))
        if found != count:
            return False, start
        # NOTE(review): a successful count resets the cursor to 0 (unchanged behaviour).
        return True, 0
    match = next(re.finditer(str(pattern), content, re.MULTILINE), None)
    if match is None:
        return False, start
    global_group_dict.update(match.groupdict())
    end = match.end() + start
    line_num = text_index_to_line_num(line_map, end)
    if line_num + 1 >= len(line_map):
        return True, end
    return True, line_map[line_num] + 1
def scan_match(content, line_map, pattern, start=0, match_func=regex_match, count=-1):
    """Look for *pattern* as a literal substring of *content*.

    Runs of characters outside U+0100–U+FFFF in *pattern* first have backslash
    escapes decoded (presumably so rules can spell escapes such as a tab).

    count == -1: succeed if the text occurs.  The new cursor is one past the end
    offset of the line holding the first occurrence (offset by *start*).
    count >= 0: succeed only if it occurs exactly *count* times.  The cursor is
    then derived from the line at *start*.  *match_func* is unused.
    """
    def decode_escapes(run):
        # The trailing space keeps a final backslash from breaking the decode; it is sliced off again.
        return (run.group(1) + r' ').encode("utf-8").decode("unicode_escape")[:-1]

    cursor_line = text_index_to_line_num(line_map, start)
    pattern = re.sub(r'([^\u0100-\uffff]+)', decode_escapes, pattern)
    if count != -1:
        if content.count(pattern) != count:
            return False, start
        return True, line_map[cursor_line] + 1
    position = content.find(pattern)
    if position == -1:
        return False, start
    hit_line = text_index_to_line_num(line_map, position + start)
    return True, line_map[hit_line] + 1
def begin_match(content, line_map, pattern, start=0, match_func=regex_match, count=-1):
    """Run *match_func* from the very start of the output, ignoring the cursor.

    *count* is forwarded so a numeric keyword combined with ``begin`` still applies.
    """
    return match_func(content, line_map, pattern, start=0, count=count)
def auto_regex_match(content, line_map, pattern, start=0, match_func=regex_match, count=-1):
    """Treat *pattern* as literal words separated by arbitrary whitespace.

    Each whitespace-separated word is regex-escaped and the words are re-joined
    with a one-or-more-whitespace regex before delegating to *match_func*.
    *count* is forwarded so numeric keywords still apply.
    """
    pattern = r"\s+".join([re.escape(word) for word in pattern.split()])
    return match_func(content, line_map, pattern, start, count=count)
def not_match(content, line_map, pattern, start=0, match_func=regex_match, count=-1):
    """Invert *match_func*: succeed when it fails.

    On success the cursor moves one past the end offset of the line at *start*.
    On failure the cursor is left at *start*.  *count* is forwarded to the
    wrapped matcher.
    """
    result, _ = match_func(content, line_map, pattern, start, count=count)
    if result:
        return False, start
    line_num = text_index_to_line_num(line_map, start)
    return True, line_map[line_num] + 1
def next_match(content, line_map, pattern, start=0, match_func=regex_match, count=-1):
    """Run *match_func* against the single output line at the cursor *start*.

    The line is the one text_index_to_line_num resolves *start* to.  When the
    cursor is already past the last line, an empty line is matched instead of
    raising IndexError.  *count* is forwarded to the wrapped matcher.
    """
    line_num = text_index_to_line_num(line_map, start)
    lines = content.splitlines()
    current = lines[line_num] if line_num < len(lines) else ""
    return match_func(current, line_map, pattern, start, count=count)
def after_match(content, line_map, pattern, start=0, match_func=regex_match, count=-1):
    """Run *match_func* on the output from the cursor's line to the end.

    *count* is forwarded to the wrapped matcher.
    """
    line_num = text_index_to_line_num(line_map, start)
    remaining = '\n'.join(content.splitlines()[line_num:])
    return match_func(remaining, line_map, pattern, start, count=count)
def end_match(content, line_map, pattern, start=0, match_func=regex_match, count=-1):
    """Succeed only when the cursor *start* lies beyond every entry of *line_map*.

    *pattern*, *content*, *match_func* and *count* are ignored; the cursor is
    returned unchanged.
    """
    exhausted = text_index_to_line_num(line_map, start) >= len(line_map)
    return exhausted, start
def full_match(content, line_map, pattern, start=0, match_func=regex_match, count=-1):
    """Succeed only when the whole output equals *pattern* (after escape decoding).

    Escape decoding is applied to runs of characters outside U+0100–U+FFFF, as
    in scan_match.  On success the cursor moves one past the end offset of the
    line at *start*.
    """
    cursor_line = text_index_to_line_num(line_map, start)
    pattern = re.sub(
        r'([^\u0100-\uffff]+)',
        lambda run: (run.group(1) + r' ').encode("utf-8").decode("unicode_escape")[:-1],
        pattern,
    )
    if content == pattern:
        return True, line_map[cursor_line] + 1
    return False, start
def num_match(content, line_map, pattern, start=0, match_func=regex_match, count=-1):
    """Delegate to *match_func*, passing the occurrence *count* through explicitly."""
    result = match_func(content, line_map, pattern, start, count=count)
    return result
def gen_line_map(text):
    """Return the end offset of every line of *text*, newline included.

    The regex also yields a final empty match at the end of the text, so the
    last offset is repeated (e.g. ``"a\\nb"`` -> ``[2, 3, 3]``).
    """
    return [line.end() for line in re.finditer(".*\n?", text)]
def text_index_to_line_num(line_map, index):
    """Return the 0-based line number for character offset *index*.

    *line_map* holds ascending line end offsets (as built by gen_line_map).  The
    result is the first line whose end offset is >= *index*, or
    ``len(line_map)`` when *index* is past all of them.  An empty map yields 0
    instead of raising UnboundLocalError.  Uses binary search instead of a
    linear scan.
    """
    return bisect.bisect_left(line_map, index)
def gen_compare_regex(comment, assert_flags, expected_flag):
    """Build one alternation regex that captures every rule line in a case file.

    Each alternative matches the escaped *comment* leader plus optional
    whitespace, then captures the rest of the line from the flag onwards.
    EXPECTED flags come before ASSERT flags.  Returns ``""`` when no flags are given.
    """
    template = r"(?:{comment}\s*)(.*{flag}[\t ]*\:[\t ]*.*$)"
    alternatives = [
        template.format(comment=escape("\\$()*+.[]?^{}|", comment), flag=flag)
        for flag in (*expected_flag, *assert_flags)
    ]
    return "|".join(alternatives)
def extract_compare_lines(file_path, regex):
    """Collect every rule captured by *regex* (MULTILINE) in the UTF-8 file *file_path*.

    Returns ``[captured_text, line_number]`` pairs (1-based line numbers), one
    per non-None capture group, in match order.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    # line_ends[i] is the offset just past line i (same construction as gen_line_map).
    line_ends = [m.end() for m in re.finditer(".*\n?", content)]
    compare_lines = []
    for match in re.finditer(regex, content, re.MULTILINE):
        for group_num in range(1, len(match.groups()) + 1):
            text = match.group(group_num)
            if text is None:
                continue
            # First line whose end offset lies strictly after the capture's start.
            line_index = bisect.bisect_right(line_ends, match.start(group_num))
            if line_index < len(line_ends):
                compare_lines.append([text, line_index + 1])
    return compare_lines
# Comment delimiters scanned for by parse_all_multiline_comment.
multi_L = '/*'
multi_R = '*/'


def parse_all_multiline_comment(file, compare_number):
    """Return the text of every top-level ``/* SCAN ... */`` comment in *file*.

    With *compare_number* == -1 the whole file is scanned.  Otherwise only the
    section selected by read_scans is scanned.  Scanning starts at the first
    ``/* SCAN``.  Nested ``/* */`` pairs are tracked with a stack, and only
    outermost comments that start with ``/* SCAN`` are kept.

    The comments are returned last-in-file first.  An empty list (not None) is
    returned when the text contains no SCAN comment.
    """
    with open(file, 'r', encoding='utf-8') as f:
        if compare_number == -1:
            contents = f.read()
        else:
            contents = read_scans(f, compare_number)
    found = re.search(multiscan_pattern, contents)
    if not found:
        return []
    start_index = found.start()
    open_stack = [start_index]
    spans = []
    for i in range(start_index + 1, len(contents) - 1):
        pair = contents[i:i + 2]
        if pair == multi_R:
            if open_stack:
                last_index = open_stack.pop()
                if not open_stack:
                    spans.append((last_index, i + 2))
        elif pair == multi_L:
            open_stack.append(i)
    res = []
    for begin, end in reversed(spans):
        text = contents[begin:end]
        if re.match(multiscan_pattern, text):
            res.append(text)
    return res
def read_scans(file, compare_number):
    """Return the lines belonging to the *compare_number*-th ``| compare %f`` section.

    A section starts at the line containing the marker that brings the count to
    *compare_number* (section 0 is everything before the first marker).  It
    ends just before the next marker line.
    """
    seen_markers = 0
    collected = []
    for text in file:
        if seen_markers > compare_number:
            break
        if "| compare %f" in text:
            seen_markers += 1
        if seen_markers == compare_number:
            collected.append(text)
    return "".join(collected)
def parse_cli():
    """Define and parse the command-line options of compare.py.

    Optional ``nargs='?'`` options declare a ``const`` equal to their default.
    Giving the flag without a value therefore behaves like omitting it,
    instead of producing ``None``.
    """
    parser = argparse.ArgumentParser(prog="compare.py")
    parser.add_argument("--comment", help="Test case comment")
    parser.add_argument(
        "--assert_flag",
        help="Test case assert flag, default ASSERT",
        action="append",
        default=[],
    )
    parser.add_argument(
        "--expected_flag",
        help="Test case expected flag for compile, default EXPECTED",
        action="append",
        default=[],
    )
    parser.add_argument(
        "case_path", type=complete_path, help="Source path: read compare rules"
    )
    parser.add_argument(
        "--compare_number", type=int, nargs='?', default=-1, const=-1,
        help="Number of 'compare %%f' pipe in file after which SCANs will be checked. All SCANs is checked by default."
    )
    parser.add_argument(
        "com_opt", type=str, nargs='?', default='', help="Compile option of this case, default empty."
    )
    parser.add_argument(
        "--condition", type=str, nargs='?', default='', const='',
        help="Compare condition for Keyword, default empty."
    )
    parser.add_argument(
        "--compare_object",
        nargs="?",
        default=sys.stdin,
        const=sys.stdin,
        help="compare object, default stdin",
    )
    parser.add_argument(
        "--transfer",
        help="Base dir for transfer/update normal case to multiline case.",
        type=str,
        default="",
    )
    return parser.parse_args()
if __name__ == "__main__":
    # All log records (DEBUG and up) go to stderr, each prefixed with a tab.
    logging.basicConfig(
        level=logging.DEBUG,
        stream=sys.stderr,
        format="\t%(message)s",
    )
    main()