import fnmatch
import os
import queue
import re
import subprocess
import sys
import threading
import warnings
from abc import ABCMeta, abstractmethod
from pathlib import Path
# Presumably the number of devices available to tests; not referenced in this
# file. TODO(review): confirm where it is consumed.
NUM_DEVICE = 8
# Parent of the directory containing this script (symlinks resolved). Relative
# modified-file paths and "modify_files.txt" are resolved against it.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# Root searched for unit-test files; also the directory tests are launched from.
TEST_DIR = os.path.join(BASE_DIR, "tests", "torch")
def check_path_owner_consistent(path: str):
    """Ensure *path* exists and warn when another user owns it.

    Args:
        path: Filesystem path to inspect.

    Raises:
        RuntimeError: If *path* does not exist.

    An owner mismatch is reported through ``warnings.warn`` only; it is not
    treated as an error.
    """
    if os.path.exists(path):
        owner_uid = os.stat(path).st_uid
        if owner_uid != os.getuid():
            warnings.warn(f"Warning: The {path} owner does not match the current user.")
        return
    raise RuntimeError(f"The path does not exist: {path}")
def check_directory_path_readable(path):
    """Validate that *path* exists, is not a symlink, and is readable.

    Despite the name, no check here requires *path* to be a directory.
    Existence and ownership are checked first by ``check_path_owner_consistent``.
    That call raises if the path is missing and warns if another user owns it.

    Raises:
        RuntimeError: If the path is missing, is a symbolic link, or is not
            readable by the current process.
    """
    check_path_owner_consistent(path)
    problem = None
    if os.path.islink(path):
        problem = f"Invalid path is a soft chain: {path}"
    elif not os.access(path, os.R_OK):
        problem = f"The path permission check failed: {path}"
    if problem is not None:
        raise RuntimeError(problem)
class AccurateTest(metaclass=ABCMeta):
    """Base class for strategies that map a modified file to unit-test files."""

    @abstractmethod
    def identify(self, modify_file):
        """Return the unit-test file paths relevant to *modify_file*.

        Subclasses must override this. ABCMeta refuses to instantiate a
        subclass that does not.
        """
        raise NotImplementedError("abstract method. Subclasses should implement it.")

    @staticmethod
    def find_ut_by_regex(regex, search_dir=None):
        """Recursively find Python test files whose base name matches *regex*.

        Despite its name, *regex* is a shell-style glob (``fnmatch`` syntax,
        e.g. ``"*add*"``). It is matched case-sensitively against each file's
        base name, like ``find -name``. The search is done in-process, so
        the pattern is never expanded by a shell and file names never reach
        a shell command line.

        Args:
            regex: Glob pattern matched against file base names.
            search_dir: Directory to search; defaults to ``TEST_DIR``.

        Returns:
            Sorted list of matching ``.py`` paths, excluding the ``run_test.py``
            driver. Empty if nothing matches or the directory does not exist.
        """
        root = str(TEST_DIR if search_dir is None else search_dir)
        ut_files = []
        # os.walk does not descend into symlinked directories by default,
        # matching find's default behaviour; a missing root yields nothing.
        for dirpath, _dirnames, filenames in os.walk(root):
            for name in filenames:
                # Compare the whole base name: a suffix check on the full path
                # would also drop real tests such as "test_dry_run_test.py".
                if name == "run_test.py" or not name.endswith(".py"):
                    continue
                if fnmatch.fnmatchcase(name, regex):
                    ut_files.append(os.path.join(dirpath, name))
        # Sort so the result does not depend on directory listing order.
        return sorted(ut_files)
class OpStrategy(AccurateTest):
    """Select tests whose file names contain the leading words of a modified file's name."""

    def split_string(self, filename):
        """Split *filename* into lower-case words.

        An upper-case letter starts a new word, and ``_`` ends the current word
        and is dropped. Every other character (digits, ``.``, the extension)
        stays attached to the current word. For example, ``"AddKernel_npu.cpp"``
        becomes ``["add", "kernel", "npu.cpp"]``.
        """
        words = []
        word = ""
        for char in filename:
            if char.isupper() or char == "_":
                if word:
                    words.append(word.lower())
                # Upper-case letters begin the next word; underscores are dropped.
                word = char if char.isupper() else ""
            else:
                word += char
        if word:
            words.append(word.lower())
        return words

    def identify(self, modify_file):
        """Return test files whose names contain the words of *modify_file*'s name, in order."""
        filename = Path(modify_file).name
        features = self.split_string(filename)
        if not features:
            # An empty name (e.g. a blank input line) would produce the
            # pattern "**", which matches every test and runs the full suite.
            return []
        if len(features) > 1:
            # "." is not a separator, so the last word normally carries the
            # file extension; it is dropped from the pattern.
            features = features[:-1]
        # split_string already lower-cases every word.
        pattern = "*" + "*".join(features) + "*"
        return AccurateTest.find_ut_by_regex(pattern)
class DirectoryStrategy(AccurateTest):
    """Select a modified file directly when it is itself a torch unit test."""

    def identify(self, modify_file):
        """Return ``[BASE_DIR/modify_file]`` for a ``tests/torch/**/test_*.py`` path, else ``[]``.

        *modify_file* is expected to be relative to ``BASE_DIR`` (it is joined
        onto it). Empty or single-component paths yield ``[]``.
        """
        path_modify_file = Path(modify_file)
        # Slice instead of indexing so paths with fewer than two components
        # (e.g. "" or "tests") return [] instead of raising IndexError.
        in_test_dir = path_modify_file.parts[:2] == ("tests", "torch")
        # fullmatch with an escaped dot rejects names such as "test_a.py.bak"
        # or "test_a_py", which the unanchored "test_(.+).py" accepted.
        is_test_file = in_test_dir and re.fullmatch(r"test_.+\.py", path_modify_file.name) is not None
        return [str(os.path.join(BASE_DIR, modify_file))] if is_test_file else []
class TestMgr:
    """Collect modified files, map them to unit tests, and print the selection."""

    def __init__(self):
        # Paths read by load(), one per non-blank input line.
        self.modify_files = []
        # analyze() fills "ut_files"; nothing in this class fills "op_ut_files".
        self.test_files = {"ut_files": [], "op_ut_files": []}

    def load(self, modify_files):
        """Append the paths listed in the text file *modify_files*.

        Lines are stripped and blank lines are skipped. An empty path would
        otherwise be passed to the selection strategies in analyze().

        Raises:
            RuntimeError: From check_directory_path_readable when the file is
                missing, a symbolic link, or unreadable.
        """
        check_directory_path_readable(modify_files)
        with open(modify_files, encoding="utf-8") as f:
            for line in f:
                path = line.strip()
                if path:
                    self.modify_files.append(path)

    def analyze(self):
        """Map each modified file to candidate tests; keep unique, existing paths, sorted."""
        directory_strategy = DirectoryStrategy()
        op_strategy = OpStrategy()
        candidates = list(self.test_files["ut_files"])
        for modify_file in self.modify_files:
            candidates += directory_strategy.identify(modify_file)
            candidates += op_strategy.identify(modify_file)
        self.test_files["ut_files"] = [
            candidate for candidate in sorted(set(candidates)) if Path(candidate).exists()
        ]

    def get_test_files(self):
        """Return the selection dict (the live object, not a copy)."""
        return self.test_files

    @staticmethod
    def _print_section(title, items):
        # Print a header line followed by one item per line.
        print(title)
        for item in items:
            print(item)

    def print_modify_files(self):
        self._print_section("modify files:", self.modify_files)

    def print_ut_files(self):
        self._print_section("ut files:", self.test_files["ut_files"])

    def print_op_ut_files(self):
        self._print_section("op ut files:", self.test_files["op_ut_files"])
def exec_ut(files, timeout=2000):
    """Run each selected unit-test file through ``run_test.py`` and print a summary.

    Args:
        files: Mapping of test type to file paths, as returned by
            ``TestMgr.get_test_files()``. ``"op_ut_files"`` entries run as
            ``test_ops -- -k <op>``. Every other key runs the file by its path
            relative to ``TEST_DIR``. Files whose base name does not start
            with ``test_`` are skipped.
        timeout: Seconds each command may run before it is killed and counted
            as failed.

    Returns:
        0 if every executed command succeeded. Otherwise, the non-zero return
        code of the last failing command (1 for a timeout or launch error).
    """

    def get_op_name(ut_file):
        # ".../test_add.py" -> "add". Remove the literal prefix: str.lstrip("test_")
        # strips a *set* of characters and would turn "test_sub" into "ub".
        stem = Path(ut_file).name.split(".")[0]
        prefix = "test_"
        return stem[len(prefix):] if stem.startswith(prefix) else stem

    def get_ut_name(ut_file):
        # Path relative to TEST_DIR without the trailing ".py".
        return str(Path(ut_file).relative_to(TEST_DIR))[:-3]

    def get_ut_cmd(ut_type, ut_file):
        cmd = [sys.executable, "run_test.py", "-v", "-i"]
        if ut_type == "op_ut_files":
            return cmd + ["test_ops", "--", "-k", get_op_name(ut_file)]
        return cmd + [get_ut_name(ut_file)]

    def run_cmd_with_timeout(cmd):
        try:
            # Pass cwd instead of calling os.chdir. "run_test.py" resolves
            # inside TEST_DIR and the parent's working directory is unchanged.
            # With stdin=DEVNULL, a test that reads stdin gets EOF instead of
            # blocking on a pipe nobody writes.
            # Output is not captured; stderr is merged into the inherited stdout.
            completed = subprocess.run(
                cmd,
                cwd=str(TEST_DIR),
                stdin=subprocess.DEVNULL,
                stderr=subprocess.STDOUT,
                timeout=timeout,
                check=False,
            )
        except subprocess.TimeoutExpired:
            # subprocess.run kills and reaps the child before re-raising.
            print("Timeout: Command '{}' timed out after {} seconds".format(" ".join(cmd), timeout))
            return 1
        except OSError as err:
            # Launch failure (e.g. missing TEST_DIR or interpreter): record as failed.
            print(err)
            return 1
        return completed.returncode

    def run_tests(files):
        exec_infos = []
        has_failed = 0
        for ut_type, ut_files in files.items():
            for ut_file in ut_files:
                if not os.path.basename(ut_file).startswith("test_"):
                    continue
                cmd = get_ut_cmd(ut_type, ut_file)
                # Label for the summary: drop interpreter, driver and flags,
                # and the "-- -k" separator of op tests.
                ut_info = " ".join(cmd[4:]).replace(" -- -k", "")
                ret = run_cmd_with_timeout(cmd)
                if ret:
                    has_failed = ret
                    exec_infos.append("exec ut {} failed.".format(ut_info))
                else:
                    exec_infos.append("exec ut {} success.".format(ut_info))
        return has_failed, exec_infos

    ret_status, exec_infos = run_tests(files)
    print("***** Total result:")
    for exec_info in exec_infos:
        print(exec_info)
    return ret_status
if __name__ == "__main__":
    # Script entry point: read BASE_DIR/modify_files.txt, select the matching
    # tests, print the selection, run the tests and exit with exec_ut's status.
    manager = TestMgr()
    manager.load(os.path.join(BASE_DIR, "modify_files.txt"))
    manager.analyze()
    selected_tests = manager.get_test_files()
    for report in (manager.print_modify_files, manager.print_ut_files, manager.print_op_ut_files):
        report()
    sys.exit(exec_ut(selected_tests))