import argparse
import os
import shlex
import subprocess
import sys
from pathlib import Path
# Process exit codes produced by run_tests(); a non-zero one is passed to sys.exit().
TEST_RESULT_SUCCESS, TEST_RESULT_FAILURE, TEST_RESULT_INVALID_INPUT = 0, 1, 2
def read_files_from_txt(txt_file):
    """Return every line of the UTF-8 text file *txt_file*, stripped of surrounding whitespace.

    Blank lines are kept as empty strings.
    """
    with open(txt_file, "r", encoding='utf-8') as handle:
        stripped = [raw.strip() for raw in handle]
    return stripped
def is_examples(file):
    """Return True for a non-Python file under the top-level ``examples/`` directory."""
    if not file.startswith("examples/"):
        return False
    return not file.endswith(".py")
def is_markdown(file):
    """Return True when *file* ends with the Markdown suffix ``.md``."""
    markdown_suffix = ".md"
    return file.endswith(markdown_suffix)
def is_txt(file):
    """Return True for plain-text documents (``.txt`` or reStructuredText ``.rst``)."""
    return file.endswith((".txt", ".rst"))
def is_image(file):
    """Return True for ``.jpg`` or ``.png`` files (case-sensitive suffix match)."""
    return file.endswith((".jpg", ".png"))
def is_vedio(file):
    """Return True for ``.gif`` files.

    The misspelled name ("vedio") is kept because ``alter_skip_ci`` refers to it.
    """
    animated_suffixes = (".gif",)
    return file.endswith(animated_suffixes)
def is_owners(file):
    """Return True when the path begins with ``OWNERS``."""
    owners_prefix = "OWNERS"
    return file.startswith(owners_prefix)
def is_license(file):
    """Return True when the path begins with ``LICENSE``."""
    license_prefix = "LICENSE"
    return file.startswith(license_prefix)
def is_no_suffix(file):
    """Return True when ``os.path.splitext`` finds no extension on *file*.

    An empty string and dot-files such as ``.gitignore`` count as suffix-less.
    """
    _, extension = os.path.splitext(file)
    return not extension
def skip_ci_file(files, skip_cond):
    """Return True when every path in *files* satisfies at least one predicate in *skip_cond*.

    Evaluation stops at the first path that matches no predicate. An empty
    *files* yields True.
    """
    return all(
        any(predicate(path) for predicate in skip_cond)
        for path in files
    )
def alter_skip_ci():
    """Decide whether CI can be skipped for the current change set.

    Reads ``modify.txt``, located two directories above this script's directory
    (``Path(__file__).absolute().parents[2]``). The file is expected to hold one
    changed path per line. Returns True only when every listed path matches one
    of the skip predicates. Returns False when ``modify.txt`` does not exist.
    """
    anchor_dir = Path(__file__).absolute().parents[2]
    modify_list = os.path.join(anchor_dir, "modify.txt")
    if os.path.exists(modify_list):
        changed_paths = read_files_from_txt(modify_list)
        predicates = [
            is_examples,
            is_markdown,
            is_txt,
            is_image,
            is_vedio,
            is_owners,
            is_license,
            is_no_suffix,
        ]
        return skip_ci_file(changed_paths, predicates)
    return False
def acquire_exitcode(command):
    """Run *command* without a shell, echo its output live, and return its exit code.

    The child's stderr is merged into its stdout. Each line is printed (and
    flushed) to this process's stdout as soon as it is read.

    Args:
        command: Either a command-line string, tokenized with ``shlex.split``
            (POSIX quoting rules; no shell is ever invoked), or an already-split
            sequence of arguments.

    Returns:
        The child's exit status from ``Popen.wait()``. On POSIX this is
        negative when the child was terminated by a signal.

    Raises:
        OSError (e.g. FileNotFoundError) if the executable cannot be started.
    """
    if isinstance(command, str):
        cmd_args = shlex.split(command)
    else:
        cmd_args = list(command)
    with subprocess.Popen(
        cmd_args,
        shell=False,  # no shell: arguments are never re-interpreted or globbed
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        encoding='utf-8',
        errors='replace',  # undecodable bytes must not abort output streaming
        bufsize=1,  # line-buffered text I/O
    ) as process:
        # Iterating the pipe blocks until a line or EOF arrives, so there is no
        # busy-wait. The previous readline/poll loop spun when the child closed
        # stdout before exiting.
        for line in process.stdout:
            print(line, end='', flush=True)
        return process.wait()
class UT_Test:
    """Runs the pytest unit-test suites found under ``<base>/tests``.

    ``<base>`` is the parent of the directory containing this script
    (``Path(__file__).absolute().parent.parent``).
    """

    def __init__(self):
        base_dir = Path(__file__).absolute().parent.parent
        test_dir = os.path.join(base_dir, 'tests')
        self.ut_file = os.path.join(test_dir, "ut")
        self.ut_fsdp_file = os.path.join(test_dir, "ut_fsdp")

    @staticmethod
    def _pytest_command(path, local=False):
        """Build the pytest command line for *path*.

        Non-local runs add ``-x`` so pytest stops at the first failure; local
        runs execute the whole suite. The path is shell-quoted so that
        ``acquire_exitcode``'s ``shlex.split`` keeps a path containing spaces
        as a single argument.
        """
        launcher = "pytest" if local else "pytest -x"
        return f"{launcher} {shlex.quote(str(path))}"

    def _run_suite(self, path, label, local):
        """Run pytest on *path*, print a pass/fail line prefixed with *label*, and return the exit code."""
        code = acquire_exitcode(self._pytest_command(path, local))
        if code == 0:
            print(f"{label} test success")
        else:
            print(f"{label} failed")
        return code

    def run_ut(self, local=False):
        """Run the ``tests/ut`` suite and return pytest's exit code."""
        return self._run_suite(self.ut_file, "UT", local)

    def run_ut_fsdp(self, local=False):
        """Run the ``tests/ut_fsdp`` suite and return pytest's exit code."""
        return self._run_suite(self.ut_fsdp_file, "UT(fsdp)", local)
class ST_Test:
    """Runs the system-test shell scripts in ``<base>/tests/st``.

    ``<base>`` is the parent of the directory containing this script
    (``Path(__file__).absolute().parent.parent``).
    """

    def __init__(self):
        base_dir = Path(__file__).absolute().parent.parent
        test_dir = os.path.join(base_dir, 'tests')
        st_dir = "st"
        self.st_shell = os.path.join(test_dir, st_dir, "st_run.sh")
        self.local_st_shell = os.path.join(test_dir, st_dir, 'local_st_run.sh')

    def _bash_command(self, local=False):
        """Build the ``bash <script>`` command line, choosing the local script when *local* is true.

        The script path is shell-quoted so that ``acquire_exitcode``'s
        ``shlex.split`` keeps a path containing spaces as a single argument.
        """
        script = self.local_st_shell if local else self.st_shell
        return f"bash {shlex.quote(str(script))}"

    def run_st(self, local=False):
        """Run the ST script, print a pass/fail line, and return the script's exit code."""
        code = acquire_exitcode(self._bash_command(local))
        if code == 0:
            print("ST test success")
        else:
            print("ST failed")
        return code
def run_ut_tests():
    """Run the UT suite in CI mode, then the FSDP suite only if the UT suite passed.

    Returns the UT suite's non-zero exit code, or otherwise the FSDP suite's code.
    """
    runner = UT_Test()
    first_code = runner.run_ut()
    if first_code == 0:
        return runner.run_ut_fsdp()
    return first_code
def run_ut_local_tests():
    """Run the UT suite in local mode (no ``-x``), then the FSDP suite only if it passed.

    Returns the UT suite's non-zero exit code, or otherwise the FSDP suite's code.
    """
    runner = UT_Test()
    first_code = runner.run_ut(local=True)
    if first_code == 0:
        return runner.run_ut_fsdp(local=True)
    return first_code
def run_st_tests():
    """Run the CI system-test script and return its exit code."""
    return ST_Test().run_st()
def run_st_local_tests():
    """Run the local system-test script and return its exit code."""
    return ST_Test().run_st(local=True)
def run_tests(options):
    """Run the suite(s) selected by ``options.type`` and return a TEST_RESULT_* code.

    ``"st"`` and ``"ut"`` run a single family. ``"all"`` runs UT and then ST
    only if UT passed. ``"all_loss"`` runs the local UT and local ST suites
    regardless of each other's result. Any other type prints an error and
    returns TEST_RESULT_INVALID_INPUT.
    """
    def _to_result(*codes):
        # Success only when every collected exit code is zero.
        if all(code == 0 for code in codes):
            return TEST_RESULT_SUCCESS
        return TEST_RESULT_FAILURE

    def _run_all():
        if run_ut_tests() != 0:
            return TEST_RESULT_FAILURE
        return _to_result(run_st_tests())

    def _run_all_loss():
        # Both local suites run even when the UT suite fails.
        ut_code = run_ut_local_tests()
        st_code = run_st_local_tests()
        return _to_result(st_code, ut_code)

    dispatch = {
        "st": lambda: _to_result(run_st_tests()),
        "ut": lambda: _to_result(run_ut_tests()),
        "all": _run_all,
        'all_loss': _run_all_loss,
    }
    selected = dispatch.get(options.type)
    if selected is None:
        print(f"TEST CASE TYPE ERROR: no type '{options.type}'")
        return TEST_RESULT_INVALID_INPUT
    return selected()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Control needed test cases")
    parser.add_argument(
        "--type",
        type=str,
        default="all",
        choices=["all", "ut", "st", "codecheck", "all_loss"],
        help=(
            'Test cases type. `all`: run all test cases; `ut`: run ut case, '
            '`st`: run st cases; `codecheck`: used for codecheck; '
            '`all_loss`: used for local ci'
        ),
    )
    args = parser.parse_args()
    print(f"options: {args}")

    if alter_skip_ci():
        # Every changed path matched a skip predicate: nothing is run.
        print("Skipping CI: Success")
    else:
        print("Skipping CI: Failed")
        if args.type == "codecheck":
            # `codecheck` never runs tests; it fails when the change set is not skippable.
            sys.exit(1)
        exit_code = run_tests(args)
        if exit_code != 0:
            sys.exit(exit_code)