#!/usr/bin/env python3
"""
NPU 一键测试脚本
在华为 800I A2 NPU 环境下运行,验证:
1. torch_npu 可用性
2. 单元测试通过率
3. FP32/FP16 精度一致
4. 端到端推理流程

用法:
    python test_npu.py                          # 跑全部
    python test_npu.py --skip-precision          # 跳过精度验证
    python test_npu.py --skip-tests              # 跳过单元测试
    python test_npu.py --skip-inference          # 跳过推理
    python test_npu.py --dataroot /path/to/dataset  # 指定数据集路径
"""

import argparse
import subprocess
import sys
import time
import torch

# Status labels printed next to each step's outcome.
PASS = "✅ PASS"
FAIL = "❌ FAIL"
SKIP = "⏭️  SKIP"

# Running tally of subprocess steps: incremented by run(), reported by
# summary(), and FAILED also determines the exit status in main().
TOTAL = 0
PASSED = 0
FAILED = 0


def header(title: str):
    """Print *title* as a section banner framed by two rows of ``=``.

    Each row is ``len(title) + 8`` characters wide; the banner is preceded by a
    blank line and the title is indented by three spaces.
    """
    rule = "=" * (len(title) + 8)
    print(f"\n{rule}\n   {title}\n{rule}")


def run(cmd: list[str], desc: str, cwd: str | None = None) -> bool:
    """Run one test step as a subprocess and record it in the global tally.

    The step banner and command line are echoed, the child is run to
    completion with its output captured, and a PASS/FAIL line with the elapsed
    time is printed. On a non-zero exit the last 2000 characters of the
    child's stdout and stderr are printed.

    Args:
        cmd: Argument vector handed to ``subprocess.run`` (no shell involved).
        desc: Human-readable step name shown in the banner.
        cwd: Working directory for the child, or ``None`` to inherit.

    Returns:
        ``True`` if the child exited with status 0; ``False`` if it exited
        non-zero or could not be started at all.
    """
    global TOTAL, PASSED, FAILED
    TOTAL += 1
    print(f"\n--- {desc} ---")
    print(f"$ {' '.join(cmd)}")
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments while the child is running.
    started = time.perf_counter()
    try:
        # errors="replace" keeps undecodable bytes in the child's output from
        # raising UnicodeDecodeError and aborting the whole runner.
        proc = subprocess.run(
            cmd, cwd=cwd, capture_output=True, text=True, errors="replace"
        )
    except OSError as exc:
        # The child could not be launched (missing executable, bad cwd, ...).
        # Record it as a failed step so the tally stays consistent and the
        # remaining steps and the summary still run.
        elapsed = time.perf_counter() - started
        print(f"  {FAIL}  ({elapsed:.1f}s)")
        print(f"  {exc}")
        FAILED += 1
        return False
    elapsed = time.perf_counter() - started

    if proc.returncode == 0:
        print(f"  {PASS}  ({elapsed:.1f}s)")
        PASSED += 1
        return True

    print(f"  {FAIL}  ({elapsed:.1f}s)")
    FAILED += 1
    # Only the tail of each stream is shown to keep the log readable.
    if proc.stdout:
        print(proc.stdout[-2000:])
    if proc.stderr:
        print(proc.stderr[-2000:])
    return False


def check_npu():
    """Probe the NPU environment and print what was found.

    Imports ``torch_npu`` and prints its version, checks
    ``torch.npu.is_available()``, then lists every device with its name and
    capability.

    Returns:
        ``True`` when ``torch_npu`` imports and an NPU is available, ``False``
        otherwise (the reason is printed).
    """
    header("NPU 环境检测")

    try:
        import torch_npu
        print(f"  torch_npu 版本: {torch_npu.__version__}")
    except ImportError:
        print(f"  {FAIL} torch_npu 未安装!")
        return False

    npu = torch.npu
    if not npu.is_available():
        print(f"  {FAIL} torch.npu.is_available() = False")
        return False

    device_total = npu.device_count()
    print(f"  NPU 设备数: {device_total}")
    for idx in range(device_total):
        print(
            f"    Device {idx}: {npu.get_device_name(idx)}"
            f"  Capability: {npu.get_device_capability(idx)}"
        )
    print(f"  {PASS}")
    return True


def run_tests():
    """Run the unit-test suite under ``test/`` and return whether it passed.

    The suite is launched through pytest (rather than unittest directly) as a
    module of the current interpreter, with verbose output and short
    tracebacks.
    """
    header("单元测试")
    pytest_cmd = [sys.executable, "-m", "pytest", "-v", "--tb=short", "test/"]
    return run(pytest_cmd, "pytest test/ (18 cases)")


def run_precision():
    """Run ``bin/verify_precision.py`` with the current interpreter.

    Returns:
        The boolean result of ``run`` for that script.
    """
    header("精度验证 (NPU vs FaissNN)")
    script = "bin/verify_precision.py"
    return run([sys.executable, script], script)


def run_inference(dataroot: str | None):
    """Run ``inference.py`` against *dataroot*, or skip when none is given.

    Args:
        dataroot: Dataset path forwarded as ``--data_dir``; a falsy value
            (``None`` or an empty string) skips the step.

    Returns:
        The boolean result of ``run`` when the step is executed, or ``None``
        when it is skipped (a skipped step is not counted in the tally).
    """
    header("端到端推理测试")
    if dataroot:
        inference_cmd = [sys.executable, "inference.py", "--data_dir", dataroot]
        return run(inference_cmd, "inference.py")

    print(f"  {SKIP}  未指定 --dataroot,跳过推理测试")
    print("  可用: python test_npu.py --dataroot /path/to/mvtec")
    return None


def summary():
    """Print the final tally and a one-line verdict.

    Reads the module-level ``TOTAL``, ``PASSED`` and ``FAILED`` counters; the
    verdict depends only on whether any step failed.
    """
    header("结果汇总")
    print(f"  总测试: {TOTAL}")
    print(f"  {PASS}: {PASSED}")
    print(f"  {FAIL}: {FAILED}")
    if FAILED:
        print(f"\n❌ 失败 {FAILED} 项,请检查上方日志。")
    else:
        print("\n🎉 全部通过!NPU 准备就绪。")


def main():
    """Parse the command line, run the selected steps, and exit.

    The process exits with status 1 immediately if the NPU environment check
    fails or the ``patchcore`` package cannot be imported. Otherwise the unit
    tests, precision check and inference run are executed in that order
    (each can be disabled with its ``--skip-*`` flag), the summary is printed,
    and the exit status is 0 only when no step failed.
    """
    parser = argparse.ArgumentParser(description="NPU 一键测试")
    # The three opt-out switches share the same shape, so register them together.
    for flag, help_text in (
        ("--skip-precision", "跳过精度验证"),
        ("--skip-tests", "跳过单元测试"),
        ("--skip-inference", "跳过推理测试"),
    ):
        parser.add_argument(flag, action="store_true", help=help_text)
    parser.add_argument("--dataroot", type=str, default=None, help="MVTec-AD 数据集路径")
    args = parser.parse_args()

    print(f"Python: {sys.version}")
    print(f"PyTorch: {torch.__version__}")

    # Step 1: nothing else is attempted without a usable NPU.
    if not check_npu():
        print("\n❌ NPU 环境不可用,跳过后续测试")
        sys.exit(1)

    # The patchcore package must be importable before any later step is run.
    try:
        import patchcore  # noqa: F401
    except ImportError:
        print(f"\n  {FAIL} patchcore 模块未安装,请先执行: pip install -e .")
        print("  然后重新运行: python test_npu.py")
        sys.exit(1)

    # Step 2: unit tests.
    if not args.skip_tests:
        run_tests()

    # Step 3: precision verification.
    if not args.skip_precision:
        run_precision()

    # Step 4: end-to-end inference.
    if not args.skip_inference:
        run_inference(args.dataroot)

    summary()
    sys.exit(1 if FAILED else 0)


# Run the test sequence only when this file is executed as a script.
if __name__ == "__main__":
    main()