"""
NPU 一键测试脚本
在华为 800I A2 NPU 环境下运行,验证:
1. torch_npu 可用性
2. 单元测试通过率
3. FP32/FP16 精度一致
4. 端到端推理流程
用法:
python test_npu.py # 跑全部
python test_npu.py --skip-precision # 跳过精度验证
python test_npu.py --skip-tests # 跳过单元测试
python test_npu.py --skip-inference # 跳过推理
python test_npu.py --dataroot /path/to/dataset # 指定数据集路径
"""
import argparse
import subprocess
import sys
import time
import torch
# Status labels printed next to each check result.
PASS = "✅ PASS"
FAIL = "❌ FAIL"
SKIP = "⏭️ SKIP"
# Module-level counters: run() increments TOTAL for every command it starts
# and PASSED/FAILED according to the outcome; summary() prints them.
TOTAL = 0
PASSED = 0
FAILED = 0
def header(title: str):
    """Print a section banner: a blank line, then the title between two '=' rules.

    Each rule is ``len(title) + 8`` characters long, and the title line is
    prefixed with a single space.
    """
    rule = "=" * (len(title) + 8)
    banner = "\n".join((rule, f" {title}", rule))
    print("\n" + banner)
def run(cmd: list[str], desc: str, cwd: str | None = None) -> bool:
global TOTAL, PASSED, FAILED
TOTAL += 1
print(f"\n--- {desc} ---")
print(f"$ {' '.join(cmd)}")
t0 = time.time()
r = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
elapsed = time.time() - t0
if r.returncode == 0:
print(f" {PASS} ({elapsed:.1f}s)")
PASSED += 1
return True
else:
print(f" {FAIL} ({elapsed:.1f}s)")
FAILED += 1
if r.stdout:
print(r.stdout[-2000:])
if r.stderr:
print(r.stderr[-2000:])
return False
def check_npu():
    """Probe the NPU runtime and print what was found.

    Returns:
        False if ``torch_npu`` cannot be imported or
        ``torch.npu.is_available()`` is false; True after printing the name
        and capability of every device reported by ``torch.npu.device_count()``.
    """
    header("NPU 环境检测")

    try:
        import torch_npu

        print(f" torch_npu 版本: {torch_npu.__version__}")
    except ImportError:
        print(f" {FAIL} torch_npu 未安装!")
        return False

    npu = torch.npu
    if not npu.is_available():
        print(f" {FAIL} torch.npu.is_available() = False")
        return False

    device_count = npu.device_count()
    print(f" NPU 设备数: {device_count}")
    for index in range(device_count):
        device_name = npu.get_device_name(index)
        capability = npu.get_device_capability(index)
        print(f" Device {index}: {device_name} Capability: {capability}")

    print(f" {PASS}")
    return True
def run_tests():
    """Run pytest on ``test/`` with the current interpreter as one counted check.

    The "(18 cases)" in the label is fixed text, not derived from the run.

    Returns:
        Whatever ``run`` returns for the pytest command.
    """
    header("单元测试")
    pytest_cmd = [sys.executable, "-m", "pytest", "-v", "--tb=short", "test/"]
    return run(pytest_cmd, "pytest test/ (18 cases)")
def run_precision():
    """Run ``bin/verify_precision.py`` with the current interpreter as one counted check.

    Returns:
        Whatever ``run`` returns for that command.
    """
    script = "bin/verify_precision.py"
    header("精度验证 (NPU vs FaissNN)")
    return run([sys.executable, script], script)
def run_inference(dataroot: str | None):
    """Run ``inference.py`` against *dataroot* as one counted check.

    Args:
        dataroot: Value forwarded as ``--data_dir``. When it is empty or
            ``None`` a skip notice is printed and nothing is run or counted.

    Returns:
        The result of ``run`` when the script was executed, otherwise ``None``.
    """
    header("端到端推理测试")

    if dataroot:
        inference_cmd = [sys.executable, "inference.py", "--data_dir", dataroot]
        return run(inference_cmd, "inference.py")

    print(f" {SKIP} 未指定 --dataroot,跳过推理测试")
    print(" 可用: python test_npu.py --dataroot /path/to/mvtec")
    return None
def summary():
    """Print the TOTAL/PASSED/FAILED counters and a one-line verdict.

    The verdict is the success line when FAILED is 0, otherwise a line
    reporting the number of failed checks.
    """
    header("结果汇总")

    counter_lines = (
        f" 总测试: {TOTAL}",
        f" {PASS}: {PASSED}",
        f" {FAIL}: {FAILED}",
    )
    for line in counter_lines:
        print(line)

    if FAILED != 0:
        verdict = f"\n❌ 失败 {FAILED} 项,请检查上方日志。"
    else:
        verdict = "\n🎉 全部通过!NPU 准备就绪。"
    print(verdict)
def main():
    """Parse the CLI flags, verify the environment, run the selected checks, exit.

    The process exits with status 1 when ``check_npu()`` fails, when
    ``patchcore`` cannot be imported, or when at least one check failed.
    It exits with status 0 otherwise.
    """
    parser = argparse.ArgumentParser(description="NPU 一键测试")
    skip_flags = (
        ("--skip-precision", "跳过精度验证"),
        ("--skip-tests", "跳过单元测试"),
        ("--skip-inference", "跳过推理测试"),
    )
    for flag, help_text in skip_flags:
        parser.add_argument(flag, action="store_true", help=help_text)
    parser.add_argument("--dataroot", type=str, default=None, help="MVTec-AD 数据集路径")
    args = parser.parse_args()

    print(f"Python: {sys.version}")
    print(f"PyTorch: {torch.__version__}")

    if not check_npu():
        print("\n❌ NPU 环境不可用,跳过后续测试")
        sys.exit(1)

    # Imported only to confirm the package is installed before any check runs.
    try:
        import patchcore  # noqa: F401
    except ImportError:
        print(f"\n {FAIL} patchcore 模块未安装,请先执行: pip install -e .")
        print(" 然后重新运行: python test_npu.py")
        sys.exit(1)

    # (skip flag, stage) pairs, in execution order.
    stages = (
        (args.skip_tests, run_tests),
        (args.skip_precision, run_precision),
        (args.skip_inference, lambda: run_inference(args.dataroot)),
    )
    for skipped, stage in stages:
        if not skipped:
            stage()

    summary()
    sys.exit(1 if FAILED else 0)
# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()