import unittest
import os
import re
import argparse
from typing import List
from importlib import import_module
from incremental_test_finder import IncrementalTestFinder
# Locate the two vendor operator directories relative to this script: they
# live under <parent of this script's directory>/mindiesd/ops/vendors/.
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(script_dir)
custom_op_path1 = os.path.join(parent_dir, "mindiesd/ops/vendors/aie_ascendc")
custom_op_path2 = os.path.join(parent_dir, "mindiesd/ops/vendors/customize")

# Put both directories in front of whatever ASCEND_CUSTOM_OPP_PATH already
# holds. The previous value is appended only when it is non-empty, so an unset
# variable no longer yields a trailing ":" (i.e. an empty path entry).
# NOTE(review): presumably this variable is how the Ascend runtime finds custom
# operator packages -- confirm against the CANN documentation.
old_custom_op_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", "")
new_custom_op_path = f"{custom_op_path1}:{custom_op_path2}"
if old_custom_op_path:
    new_custom_op_path = f"{new_custom_op_path}:{old_custom_op_path}"
os.environ["ASCEND_CUSTOM_OPP_PATH"] = new_custom_op_path
def load_tests_from_list(test_files: List[str], base_path: str):
"""
从文件列表加载测试用例
Args:
test_files: 测试文件路径列表
base_path: 基础路径
Returns:
unittest.TestSuite
"""
test_suite = unittest.TestSuite()
for test_file in test_files:
if test_file.startswith("tests/"):
test_file = test_file[6:]
file_path = os.path.join(base_path, test_file)
if not os.path.isfile(file_path):
print(f"Warning: Test file not found: {file_path}")
continue
filename = os.path.basename(file_path)
if not (re.match(r"^test_", filename) and filename.endswith(".py")):
print(f"Warning: Not a valid test file: {filename}")
continue
try:
module_name = os.path.splitext(test_file)[0].replace(os.path.sep, '.')
module = import_module(f'{module_name}')
tests = unittest.TestLoader().loadTestsFromModule(module)
test_suite.addTests(tests)
print(f" ✅ Loaded: {test_file}")
except Exception as e:
print(f" ❌ Failed to load {test_file}: {e}")
return test_suite
def run_incremental_tests(
    base_branch: str = "main", include_staged: bool = True, include_unstaged: bool = True, dry_run: bool = False
):
    """Find the tests affected by recent changes and load them into a suite.

    An ``IncrementalTestFinder`` rooted at the parent of this script's
    directory is asked for the incremental tests relative to ``base_branch``;
    the resulting plan is always printed.

    Args:
        base_branch: Reference the changes are compared against.
        include_staged: Forwarded to the finder as ``include_staged``.
        include_unstaged: Forwarded to the finder as ``include_unstaged``.
        dry_run: When true, only print the plan and load nothing.

    Returns:
        ``None`` for a dry run, an empty ``unittest.TestSuite`` when the
        finder reports no tests, otherwise the suite built by
        ``load_tests_from_list``.
    """
    # NOTE(review): this default is "main" while the CLI default in
    # parse_args is "master"; main() always passes the CLI value explicitly.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    finder = IncrementalTestFinder(repo_root=os.path.dirname(current_dir), base_branch=base_branch)
    tests, deletion_info, changes = finder.get_incremental_tests(
        since_ref=base_branch, include_staged=include_staged, include_unstaged=include_unstaged
    )
    finder.print_test_plan(tests, deletion_info, changes)
    if dry_run:
        return None
    if not tests:
        print("\n没有需要运行的测试。")
        return unittest.TestSuite()
    print("\n🚀 正在加载测试用例...\n")
    # Hand the paths over untouched: load_tests_from_list drops the single
    # leading "tests/" itself. Using str.replace("tests/", "") here removed
    # every occurrence, corrupting paths such as "tests/unit_tests/test_x.py"
    # (it became "unit_test_x.py").
    return load_tests_from_list(list(tests), current_dir)
def parse_args():
    """Parse the command-line options of the incremental unit-test runner.

    Returns:
        argparse.Namespace with ``base_branch`` (str, default ``"master"``)
        and the boolean switches ``no_staged``, ``no_unstaged``, ``verbose``
        and ``dry_run`` (all default ``False``).
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="MindIE SD 增量UT测试",
    )
    cli.add_argument("--base-branch", "-b", default="master", help="增量测试的基分支 (默认: master)")

    # All remaining options are plain on/off switches; register them in the
    # order they should appear in --help.
    switches = (
        (("--no-staged",), "增量测试时不包含暂存区变更"),
        (("--no-unstaged",), "增量测试时不包含未暂存变更"),
        (("--verbose", "-v"), "详细输出模式"),
        (("--dry-run", "-n"), "仅显示测试计划,不执行测试"),
    )
    for flags, help_text in switches:
        cli.add_argument(*flags, action="store_true", help=help_text)

    return cli.parse_args()
def main():
    """Run the incremental unit tests selected by the command-line options.

    Returns:
        Exit status for the process: 0 for a dry run, for an empty suite, or
        when every test passed; 1 when the test run was not successful.
    """
    args = parse_args()
    suite = run_incremental_tests(
        base_branch=args.base_branch,
        include_staged=not args.no_staged,
        include_unstaged=not args.no_unstaged,
        dry_run=args.dry_run,
    )
    # run_incremental_tests returns None for a dry run: the plan has already
    # been printed and there is nothing to execute.
    if suite is None:
        return 0
    test_count = suite.countTestCases()
    if test_count == 0:
        return 0
    print(f"\n共加载 {test_count} 个测试用例\n")
    verbosity = 2 if args.verbose else 1
    result = unittest.TextTestRunner(verbosity=verbosity).run(suite)
    # Surface failures to the caller (e.g. CI) instead of always exiting 0.
    return 0 if result.wasSuccessful() else 1


if __name__ == "__main__":
    raise SystemExit(main())