# baishanyang: init project
# commit 5f1c8c3b, created 4 days ago (commit history)
"""
Auto Test Runner for HarmonyOS Model Agent

Features:
- Discover and run all test modules
- Detailed terminal output with colors (Windows compatible)
- Grouped results by test module
- HTML report generation
- Failure details with traceback

Usage:
    python tests/auto_test.py
    python tests/auto_test.py --verbose
    python tests/auto_test.py --module test_operator_fix
"""

import argparse
import html
import os
import sys
import time
import unittest
from datetime import datetime
from io import StringIO
from pathlib import Path
from typing import Any, Dict, List, Optional

# ===== Color configuration (Windows compatible) =====
try:
    from colorama import init, Fore, Style
    init(autoreset=True)
    COLORS_AVAILABLE = True
except ImportError:
    # colorama is not installed: substitute stand-in objects exposing the same
    # attribute names, each mapped to an empty string (i.e. "no color").
    COLORS_AVAILABLE = False
    Fore = type('Fore', (), dict.fromkeys(
        ('GREEN', 'RED', 'YELLOW', 'BLUE', 'CYAN', 'MAGENTA', 'WHITE', 'RESET'),
        '',
    ))()
    Style = type('Style', (), dict.fromkeys(('BRIGHT', 'DIM', 'RESET_ALL'), ''))()

def colorize(text: str, color: str) -> str:
    """Wrap *text* between the *color* code and a style reset.

    Without colorama the text is returned unchanged.
    """
    if not COLORS_AVAILABLE:
        return text
    return "{}{}{}".format(color, text, Style.RESET_ALL)

def green(text: str) -> str:
    """Return *text* colored green (plain text when colors are unavailable)."""
    return colorize(color=Fore.GREEN, text=text)

def red(text: str) -> str:
    """Return *text* colored red (plain text when colors are unavailable)."""
    return colorize(color=Fore.RED, text=text)

def yellow(text: str) -> str:
    """Return *text* colored yellow (plain text when colors are unavailable)."""
    return colorize(color=Fore.YELLOW, text=text)

def cyan(text: str) -> str:
    """Return *text* colored cyan (plain text when colors are unavailable)."""
    return colorize(color=Fore.CYAN, text=text)

def bold(text: str) -> str:
    """Return *text* in the bright/bold style (plain text when colors are unavailable)."""
    return colorize(color=Style.BRIGHT, text=text)

# ===== HTML report generator =====
class HTMLReportGenerator:
    """Generate a standalone HTML test report.

    ``results`` items are dicts with the keys produced by
    ``TestResultCollector._add_result``: ``module``, ``test_name``,
    ``status`` (``PASS``/``FAIL``/``SKIP``/``ERROR``), ``duration`` and,
    optionally, ``message`` and ``traceback``.

    All result-derived text is HTML-escaped before it is embedded. Test output
    routinely contains ``<``/``>``/``&`` (e.g. ``in <module>`` in tracebacks,
    ``'<a>' != '<b>'`` in assertion messages), which would otherwise be parsed
    as markup and silently vanish from, or corrupt, the report.
    """
    
    CSS_STYLES = '''
        * { margin: 0; padding: 0; box-sizing: border-box; }
        body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif; background: #f5f5f5; padding: 20px; }
        .container { max-width: 1200px; margin: 0 auto; background: white; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); }
        .header { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 30px; border-radius: 8px 8px 0 0; }
        .header h1 { font-size: 28px; margin-bottom: 10px; }
        .header .timestamp { font-size: 14px; opacity: 0.8; }
        .summary { display: flex; justify-content: space-around; padding: 20px; background: #f8f9fa; border-bottom: 1px solid #e9ecef; }
        .summary-item { text-align: center; }
        .summary-item .count { font-size: 32px; font-weight: bold; }
        .summary-item .label { font-size: 14px; color: #6c757d; }
        .summary-item.pass .count { color: #28a745; }
        .summary-item.fail .count { color: #dc3545; }
        .summary-item.skip .count { color: #6c757d; }
        .summary-item.error .count { color: #fd7e14; }
        .content { padding: 20px; }
        .group { margin-bottom: 30px; }
        .group-header { background: #e9ecef; padding: 15px 20px; border-radius: 4px; cursor: pointer; display: flex; justify-content: space-between; align-items: center; }
        .group-header:hover { background: #dee2e6; }
        .group-header .name { font-weight: bold; font-size: 16px; }
        .group-header .stats { font-size: 14px; color: #6c757d; }
        .group-content { margin-top: 10px; }
        table { width: 100%; border-collapse: collapse; }
        th { background: #f8f9fa; padding: 12px; text-align: left; font-weight: 600; border-bottom: 2px solid #dee2e6; }
        td { padding: 12px; border-bottom: 1px solid #e9ecef; }
        tr:hover { background: #f8f9fa; }
        .status-pass { color: #28a745; font-weight: bold; }
        .status-fail { color: #dc3545; font-weight: bold; }
        .status-skip { color: #6c757d; }
        .status-error { color: #fd7e14; font-weight: bold; }
        .duration { color: #6c757d; font-size: 12px; }
        .failures { margin-top: 30px; }
        .failure-item { background: #fff3cd; border: 1px solid #ffc107; border-radius: 4px; margin-bottom: 15px; }
        .failure-header { padding: 15px; background: #ffc107; color: #856404; font-weight: bold; border-radius: 4px 4px 0 0; }
        .failure-body { padding: 15px; }
        .failure-body pre { background: #f8f9fa; padding: 15px; border-radius: 4px; overflow-x: auto; font-size: 12px; white-space: pre-wrap; }
        .footer { padding: 20px; text-align: center; color: #6c757d; font-size: 12px; border-top: 1px solid #e9ecef; }
        .icon-pass::before { content: "[PASS] "; }
        .icon-fail::before { content: "[FAIL] "; }
        .icon-skip::before { content: "[SKIP] "; }
        .icon-error::before { content: "[ERROR] "; }
'''

    @staticmethod
    def _esc(value: Any) -> str:
        """Return ``str(value)`` with ``& < > " '`` escaped for safe HTML embedding."""
        return html.escape(str(value), quote=True)
    
    def generate(self, results: List[Dict], output_path: Path, total_duration: float) -> None:
        """Render *results* and write the report to *output_path* as UTF-8.

        Args:
            results: Recorded test results (see the class docstring).
            output_path: Destination file; its parent directory must already exist.
            total_duration: Duration of the whole run in seconds.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        
        # Summary counters
        statuses = [r['status'] for r in results]
        passed = statuses.count('PASS')
        failed = statuses.count('FAIL')
        skipped = statuses.count('SKIP')
        errors = statuses.count('ERROR')
        total = len(results)
        
        # Per-module tables
        groups_html = self._generate_groups(results)
        
        # Failure details
        failures_html = self._generate_failures(results)
        
        # Assemble the report
        report = f'''<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Test Report - {timestamp}</title>
    <style>
{self.CSS_STYLES}
    </style>
</head>
<body>
    <div class="container">
        <div class="header">
            <h1>HarmonyOS Model Agent - Test Report</h1>
            <div class="timestamp">Generated: {timestamp}</div>
        </div>
        
        <div class="summary">
            <div class="summary-item pass">
                <div class="count">{passed}</div>
                <div class="label">Passed</div>
            </div>
            <div class="summary-item fail">
                <div class="count">{failed}</div>
                <div class="label">Failed</div>
            </div>
            <div class="summary-item skip">
                <div class="count">{skipped}</div>
                <div class="label">Skipped</div>
            </div>
            <div class="summary-item error">
                <div class="count">{errors}</div>
                <div class="label">Errors</div>
            </div>
            <div class="summary-item">
                <div class="count">{total}</div>
                <div class="label">Total</div>
            </div>
            <div class="summary-item">
                <div class="count">{total_duration:.2f}s</div>
                <div class="label">Duration</div>
            </div>
        </div>
        
        <div class="content">
            {groups_html}
            
            {failures_html}
        </div>
        
        <div class="footer">
            HarmonyOS Model Agent Test Runner | {total} tests executed
        </div>
    </div>
</body>
</html>'''
        
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(report)
    
    def _generate_groups(self, results: List[Dict]) -> str:
        """Render one results table per module, in first-seen module order."""
        groups: Dict[str, List[Dict]] = {}
        for r in results:
            groups.setdefault(r['module'], []).append(r)
        
        parts: List[str] = []
        for module_name, tests in groups.items():
            module_passed = sum(1 for t in tests if t['status'] == 'PASS')
            # ERROR results are counted together with FAIL in the "failed" figure.
            module_failed = sum(1 for t in tests if t['status'] in ('FAIL', 'ERROR'))
            module_skip = sum(1 for t in tests if t['status'] == 'SKIP')
            
            # Group header
            parts.append(f'''
            <div class="group">
                <div class="group-header">
                    <span class="name">{self._esc(module_name)}</span>
                    <span class="stats">{len(tests)} tests | {module_passed} passed | {module_failed} failed | {module_skip} skipped</span>
                </div>
                <div class="group-content">
                    <table>
                        <thead>
                            <tr>
                                <th>Test Name</th>
                                <th>Status</th>
                                <th>Duration</th>
                                <th>Message</th>
                            </tr>
                        </thead>
                        <tbody>
            ''')
            
            # One row per test
            parts.extend(self._render_row(test) for test in tests)
            
            # Group footer
            parts.append('''
                        </tbody>
                    </table>
                </div>
            </div>
            ''')
        
        return ''.join(parts)

    def _render_row(self, test: Dict) -> str:
        """Render the ``<tr>`` for a single test result."""
        status = test['status']
        css = self._esc(status.lower())
        duration_text = f"{test['duration']:.3f}s" if test['duration'] else "-"
        # Truncate *before* escaping so an entity such as "&amp;" is never cut in half.
        message_text = self._esc((test.get('message') or '')[:100])
        return f'''
                        <tr>
                            <td>{self._esc(test['test_name'])}</td>
                            <td class="status-{css} icon-{css}">{self._esc(status)}</td>
                            <td class="duration">{duration_text}</td>
                            <td>{message_text}</td>
                        </tr>
                '''
    
    def _generate_failures(self, results: List[Dict]) -> str:
        """Render the details section for FAIL/ERROR results (or an all-passed note)."""
        failures = [r for r in results if r['status'] in ('FAIL', 'ERROR')]
        
        if not failures:
            return '<div class="failures"><p style="text-align:center; color:#28a745;">✓ All tests passed!</p></div>'
        
        parts = ['<div class="failures"><h2 style="margin-bottom:20px; color:#dc3545;">❌ Failed Tests Details</h2>']
        
        for fail in failures:
            traceback_text = fail.get('traceback') or 'No traceback available'
            message_text = fail.get('message') or ''
            
            parts.append(f'''
            <div class="failure-item">
                <div class="failure-header">{self._esc(fail['module'])}: {self._esc(fail['test_name'])}</div>
                <div class="failure-body">
                    <p><strong>Message:</strong> {self._esc(message_text)}</p>
                    <pre>{self._esc(traceback_text)}</pre>
                </div>
            </div>
            ''')
        
        parts.append('</div>')
        return ''.join(parts)


# ===== 测试结果收集器 =====
class TestResultCollector(unittest.TestResult):
    """Collect test results from unittest - inherits from unittest.TestResult"""
    
    def __init__(self, runner=None):
        super().__init__()
        self.results: List[Dict] = []
        self.test_start_time: float = 0
        self.runner = runner  # 用于实时打印
    
    def startTest(self, test: unittest.TestCase):
        super().startTest(test)
        self.test_start_time = time.time()
    
    def stopTest(self, test: unittest.TestCase):
        super().stopTest(test)
    
    def addSuccess(self, test: unittest.TestCase):
        super().addSuccess(test)
        duration = time.time() - self.test_start_time
        result = self._add_result(test, 'PASS', duration=duration)
        if self.runner:
            self.runner._print_test_result(result)
    
    def addFailure(self, test: unittest.TestCase, err):
        super().addFailure(test, err)
        exc_type, exc_value, tb = err
        result = self._add_result(
            test, 'FAIL',
            message=str(exc_value),
            traceback=self._format_traceback(err),
            duration=time.time() - self.test_start_time
        )
        if self.runner:
            self.runner._print_test_result(result)
    
    def addError(self, test: unittest.TestCase, err):
        super().addError(test, err)
        exc_type, exc_value, tb = err
        result = self._add_result(
            test, 'ERROR',
            message=str(exc_value),
            traceback=self._format_traceback(err),
            duration=time.time() - self.test_start_time
        )
        if self.runner:
            self.runner._print_test_result(result)
    
    def addSkip(self, test: unittest.TestCase, reason: str):
        super().addSkip(test, reason)
        result = self._add_result(test, 'SKIP', message=reason, duration=0)
        if self.runner:
            self.runner._print_test_result(result)
    
    def _add_result(self, test: unittest.TestCase, status: str, 
                    message: str = '', traceback: str = '', duration: float = 0) -> Dict:
        test_id = test.id()
        parts = test_id.split('.')
        
        # 解析测试名和模块名
        if len(parts) >= 3:
            module_name = parts[0]
            test_name = parts[-1]
        elif len(parts) >= 2:
            module_name = parts[0]
            test_name = parts[1]
        else:
            module_name = 'unknown'
            test_name = test_id
        
        result = {
            'module': module_name,
            'test_name': test_name,
            'full_name': test_id,
            'status': status,
            'message': message,
            'traceback': traceback,
            'duration': duration
        }
        self.results.append(result)
        return result
    
    def _format_traceback(self, err) -> str:
        """Format traceback for display"""
        import traceback as tb_module
        exc_type, exc_value, tb = err
        lines = tb_module.format_exception(exc_type, exc_value, tb)
        return ''.join(lines)
        
        # 解析测试名和模块名
        if len(parts) >= 3:
            module_name = parts[0]
            test_name = parts[-1]
        elif len(parts) >= 2:
            module_name = parts[0]
            test_name = parts[1]
        else:
            module_name = 'unknown'
            test_name = test_id
        
        self.results.append({
            'module': module_name,
            'test_name': test_name,
            'full_name': test_id,
            'status': status,
            'message': message,
            'traceback': traceback,
            'duration': duration
        })
    
    def _format_traceback(self, err) -> str:
        """Format traceback for display"""
        import traceback as tb_module
        exc_type, exc_value, tb = err
        lines = tb_module.format_exception(exc_type, exc_value, tb)
        return ''.join(lines)


# ===== 测试运行器 =====
class TestRunner:
    """Main test runner"""
    
    def __init__(self, verbose: bool = False, module_filter: Optional[str] = None):
        self.verbose = verbose
        self.module_filter = module_filter
        self.tests_dir = Path(__file__).parent
        self.html_generator = HTMLReportGenerator()
    
    def discover_tests(self) -> List[Path]:
        """Discover all test files"""
        test_files = list(self.tests_dir.glob("test_*.py"))
        test_files = [f for f in test_files if f.name != "auto_test.py"]
        
        # 排序:operator_fix -> benchmark -> convert
        order_map = {
            'test_operator_fix': 0,
            'test_model_benchmark': 1,
            'test_model_convert': 2
        }
        test_files.sort(key=lambda f: order_map.get(f.stem, 99))
        
        return test_files
    
    def run(self) -> int:
        """Run all tests and return exit code"""
        # 1. 发现测试文件
        test_files = self.discover_tests()
        
        if self.module_filter:
            test_files = [f for f in test_files if self.module_filter in f.stem]
        
        if not test_files:
            print(yellow("No test files found!"))
            return 1
        
        # 2. 创建收集器(传入runner以实时打印)
        collector = TestResultCollector(runner=self)
        
        # 3. 保存原始 sys.path 并确保标准库路径存在
        original_sys_path = sys.path.copy()
        # 确保项目根目录在最前面
        project_root = str(self.tests_dir.parent)
        if project_root in sys.path:
            sys.path.remove(project_root)
        sys.path.insert(0, project_root)
        
        # 4. 打印开始信息
        self._print_header(test_files)
        
        # 5. 运行测试
        start_time = time.time()
        
        for test_file in test_files:
            self._print_module_header(test_file.stem)
            self._run_test_module(test_file, collector)
            # 每个模块运行后恢复 sys.path
            sys.path = original_sys_path.copy()
            sys.path.insert(0, project_root)
        
        total_duration = time.time() - start_time
        
        # 6. 打印结果
        self._print_summary(collector.results, total_duration)
        
        # 7. 生成HTML报告(保存到tests/output目录)
        output_dir = self.tests_dir / "output"
        output_dir.mkdir(exist_ok=True)
        report_path = output_dir / f"test_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.html"
        self.html_generator.generate(collector.results, report_path, total_duration)
        print(f"\n{cyan(f'HTML Report: {report_path}')}")
        
        # 8. 返回状态码
        failed_count = sum(1 for r in collector.results if r['status'] in ['FAIL', 'ERROR'])
        return 1 if failed_count > 0 else 0
    
    def _run_test_module(self, test_file: Path, collector: TestResultCollector) -> None:
        """Run a single test module"""
        # 动态导入测试模块
        import importlib.util
        
        # 确保项目根目录在 sys.path 中
        project_root = str(self.tests_dir.parent)
        if project_root not in sys.path:
            sys.path.insert(0, project_root)
        
        # 使用唯一的模块名避免冲突
        module_name = f"tests.{test_file.stem}"
        
        # 如果模块已加载,先移除
        if module_name in sys.modules:
            del sys.modules[module_name]
        if test_file.stem in sys.modules:
            del sys.modules[test_file.stem]
        
        try:
            spec = importlib.util.spec_from_file_location(module_name, test_file)
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            
            spec.loader.exec_module(module)
        except Exception as e:
            import traceback
            error_detail = traceback.format_exc()
            print(red(f"  [FAIL] Failed to load module: {e}"))
            if self.verbose:
                print(f"    {error_detail[:500]}")
            return
        
        # 获取测试套件
        loader = unittest.TestLoader()
        suite = loader.loadTestsFromModule(module)
        
        # 运行测试(collector会自动处理结果并实时打印)
        suite.run(collector)
    
    def _print_test_result(self, result: Dict) -> None:
        """Print single test result (called by collector)"""
        status = result['status']
        test_name = result['test_name']
        duration = result['duration']
        message = result.get('message', '')
        
        # 状态图标
        if status == 'PASS':
            status_str = green("[PASS]")
        elif status == 'FAIL':
            status_str = red("[FAIL]")
        elif status == 'ERROR':
            status_str = red("[ERROR]")
        elif status == 'SKIP':
            status_str = yellow("[SKIP]")
        else:
            status_str = status
        
        # 打印
        if self.verbose:
            print(f"  {status_str} {test_name:<50} {duration:.3f}s")
            # 显示失败详情
            if status in ['FAIL', 'ERROR'] and message:
                print(f"      {red(message[:100])}")
        else:
            print(f"  {status_str} {test_name}")
        
        sys.path.pop(0)
    
    def _print_header(self, test_files: List[Path]) -> None:
        """Print test header"""
        print("\n" + "=" * 80)
        print(bold("                    HarmonyOS Model Agent - Auto Test Runner"))
        print("=" * 80)
        print(f"\n{cyan(f'Discovered {len(test_files)} test modules:')}")
        for f in test_files:
            print(f"  - {f.stem}")
        print()
    
    def _print_module_header(self, module_name: str) -> None:
        """Print module header"""
        print(f"\n{bold(f'[{module_name}]')}")
    
    def _print_summary(self, results: List[Dict], total_duration: float) -> None:
        """Print summary"""
        passed = sum(1 for r in results if r['status'] == 'PASS')
        failed = sum(1 for r in results if r['status'] == 'FAIL')
        errors = sum(1 for r in results if r['status'] == 'ERROR')
        skipped = sum(1 for r in results if r['status'] == 'SKIP')
        total = len(results)
        
        print("\n" + "=" * 80)
        print(bold("SUMMARY"))
        print("=" * 80)
        
        # 打印统计
        stats = f"Total: {total} | {green(f'Passed: {passed}')} | {red(f'Failed: {failed+errors}')} | {yellow(f'Skipped: {skipped}')}"
        print(stats)
        print(f"Duration: {total_duration:.2f}s")
        
        # 打印失败详情
        if failed + errors > 0:
            print("\n" + red("Failed Tests:"))
            for r in results:
                if r['status'] in ['FAIL', 'ERROR']:
                    fail_name = f'* {r["module"]}/{r["test_name"]}'
                    print(f"  {red(fail_name)}")
                    if self.verbose and r.get('traceback'):
                        # 显示简化的traceback
                        tb_lines = r['traceback'].split('\n')
                        relevant_lines = [l for l in tb_lines if 'AssertionError' in l or 'Error' in l or 'at line' in l.lower()]
                        for line in relevant_lines[:3]:
                            print(f"    {line}")


# ===== 主函数 =====
def main():
    parser = argparse.ArgumentParser(description="HarmonyOS Model Agent Test Runner")
    parser.add_argument('--verbose', '-v', action='store_true', 
                        help='Show detailed output including tracebacks')
    parser.add_argument('--module', '-m', type=str, default=None,
                        help='Filter tests by module name (e.g., "operator" for test_operator_fix)')
    parser.add_argument('--no-html', action='store_true',
                        help='Skip HTML report generation')
    
    args = parser.parse_args()
    
    # 设置路径
    tests_dir = Path(__file__).parent
    project_root = tests_dir.parent
    sys.path.insert(0, str(project_root))
    
    # 运行测试
    runner = TestRunner(verbose=args.verbose, module_filter=args.module)
    exit_code = runner.run()
    
    sys.exit(exit_code)


if __name__ == "__main__":
    main()