import argparse
import pathlib
import os
import sys
import signal
import math
from datetime import datetime, timezone
from typing import Optional, List
import torch
from torch.testing._internal.common_utils import shell
import torch_npu
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
SIGNALS_TO_NAMES_DICT = {
getattr(signal, n): n for n in dir(signal) if
n.startswith("SIG") and "_" not in n
}
def print_to_stderr(message):
print(message, file=sys.stderr)
def discover_tests(
base_dir: Optional[pathlib.Path] = None,
blocklisted_patterns: Optional[List[str]] = None,
blocklisted_tests: Optional[List[str]] = None,
extra_tests: Optional[List[str]] = None) -> List[str]:
"""
Searches for all python files starting with test_ excluding one specified by patterns
"""
def skip_test_p(name: str) -> bool:
rc = False
if blocklisted_patterns is not None:
rc |= any(
name.startswith(pattern) for pattern in blocklisted_patterns)
if blocklisted_tests is not None:
rc |= name in blocklisted_tests
return rc
cwd = pathlib.Path(
__file__).resolve().parent if base_dir is None else base_dir
all_py_files = list(cwd.glob('**/test_*.py'))
rc = [str(fname.relative_to(cwd))[:-3] for fname in all_py_files]
rc = [test for test in rc if not skip_test_p(test)]
if extra_tests is not None:
rc += extra_tests
return sorted(rc)
def parse_test_module(test):
return pathlib.Path(test).parts[0]
TESTS = discover_tests(
blocklisted_patterns=['_afd'],
blocklisted_tests=[],
extra_tests=[]
)
TESTS_MODULE = list(set([parse_test_module(test) for test in TESTS]))
TEST_CHOICES = TESTS + TESTS_MODULE
def parse_args():
parser = argparse.ArgumentParser(
description="Run the PyTorch unit test suite",
epilog="where TESTS is any of: {}".format(", ".join(TESTS)),
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"-v",
"--verbose",
action="count",
default=0,
help="print verbose information and test-by-test results",
)
parser.add_argument(
"-i",
"--include",
nargs="+",
choices=TEST_CHOICES,
default=TESTS,
metavar="TESTS",
help="select a set of tests to include (defaults to ALL tests)."
" tests must be a part of the TESTS list defined in run_test.py",
)
parser.add_argument(
"additional_unittest_args",
nargs="*",
help="additional arguments passed through to unittest, e.g., "
"python run_test.py -i sparse -- TestSparse.test_factory_size_check",
)
return parser.parse_args()
def get_selected_tests(options):
selected_tests = []
if options.include:
for item in options.include:
selected_tests.extend(
list(filter(lambda test_name: item == test_name \
or (
item in TESTS_MODULE and test_name.startswith(
item)), TESTS)))
else:
selected_tests = TESTS
return selected_tests
def run_test(test, test_directory, options):
unittest_args = options.additional_unittest_args.copy()
if options.verbose:
unittest_args.append("-v")
executable = [sys.executable]
argv = [test + ".py"] + unittest_args
command = executable
calculate_python_coverage = os.getenv("CALCULATE_PYTHON_COVERAGE")
if calculate_python_coverage and calculate_python_coverage == "1":
command = command + ["-m", "coverage", "run", "-p", "--source=torch_npu", "--branch"]
command = command + argv
print_to_stderr(
"Executing {} ... [{}]".format(command, datetime.now(tz=timezone.utc)))
return shell(command, test_directory)
def run_test_module(test: str, test_directory: str, options) -> Optional[str]:
print_to_stderr(
"Running {} ... [{}]".format(test, datetime.now(tz=timezone.utc)))
return_code = run_test(test, test_directory, options)
if not (isinstance(return_code, int) and not isinstance(return_code, bool)):
raise TypeError("Return code should be an integer")
if return_code == 0:
return None
message = f"exec ut {test} failed!"
if return_code < 0:
signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
message += f" Received signal: {signal_name}"
return message
def main():
options = parse_args()
test_directory = os.path.join(REPO_ROOT, "test")
selected_tests = get_selected_tests(options)
if options.verbose:
print_to_stderr("Selected tests: {}".format(", ".join(selected_tests)))
has_failed = False
failure_msgs = []
for test in selected_tests:
err_msg = run_test_module(test, test_directory, options)
if err_msg is None:
continue
has_failed = True
failure_msgs.append(err_msg)
print_to_stderr(err_msg)
if has_failed:
for err in failure_msgs:
print_to_stderr(err)
return False
return True
if __name__ == "__main__":
if not main():
sys.exit(1)