# baishanyang: init project
# Commit 5f1c8c3b (created 4 days ago)
"""
Tests for model benchmark functionality
"""

import unittest
import sys
import os
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

# Directory containing this test module; fixtures and outputs are resolved relative to it.
_HERE = Path(__file__).parent

# Input fixtures (ONNX models) and generated artifacts (.ms models, benchmark logs).
TEST_MODEL_DIR = _HERE / "model"
TEST_OUTPUT_DIR = _HERE.joinpath("output", "benchmark")
TEST_BENCHMARK_ONNX = TEST_MODEL_DIR.joinpath("benchmark", "test_benchmark.onnx")


def tools_available():
    """Report whether the MindSpore Lite tools needed by these tests are present.

    Returns the (falsy) result of ``tool_manager.check_tool_exists()`` when the
    tool is missing; otherwise returns ``tool_manager.check_lib_dirs_exist()``.
    The project module is imported when this function is called.
    """
    from ohos_model_claw.tool_manager import tool_manager as manager

    tool_present = manager.check_tool_exists()
    if not tool_present:
        return tool_present
    return manager.check_lib_dirs_exist()


class TestModelBenchmark(unittest.TestCase):
    """Integration tests for the ``benchmark`` sub-command of ohos_model_claw.

    ``setUpClass`` converts the test ONNX model to a MindSpore Lite ``.ms``
    file once via the ``convert`` sub-command. Each benchmark test then runs
    the CLI against that file and saves the captured output to a log file next
    to it. Tests that need the converted model are skipped when it is missing.
    """

    # CLI entry point exercised by every test.
    _SCRIPT_PATH = Path(__file__).parent.parent / "ohos_model_claw" / "ohos_model_claw.py"

    @staticmethod
    def _get_output_file_path(onnx_path, output_dir):
        """Return the ``.ms`` path expected for *onnx_path* inside *output_dir*.

        The name is the ONNX file's stem with a ``.ms`` suffix, e.g.
        ``some/dir/model.onnx`` -> ``<output_dir>/model.ms``.
        """
        onnx_name = Path(onnx_path).stem
        return str(Path(output_dir) / f"{onnx_name}.ms")

    @classmethod
    def _run_cli(cls, *cli_args, timeout=60, env=None):
        """Run the ohos_model_claw CLI with *cli_args* and capture its output.

        Args:
            *cli_args: Sub-command and its arguments, e.g. ``"benchmark", path``.
            timeout: Seconds before ``subprocess.TimeoutExpired`` is raised.
            env: Environment for the child process; ``None`` inherits ``os.environ``.

        Returns:
            The ``subprocess.CompletedProcess`` with text ``stdout``/``stderr``
            (decoded as UTF-8, undecodable bytes replaced).
        """
        import subprocess

        cmd = [sys.executable, str(cls._SCRIPT_PATH), *cli_args]
        return subprocess.run(
            cmd,
            env=env,
            capture_output=True,
            text=True,
            timeout=timeout,
            encoding="utf-8",
            errors="replace",
        )

    @classmethod
    def setUpClass(cls):
        """Convert the test ONNX model to ``.ms`` once for the whole class.

        ``cls.ms_file`` is set to the converted model's path on success. It stays
        ``None`` when the tools or the ONNX fixture are missing, or when the
        conversion fails or times out.
        """
        cls.output_dir = TEST_OUTPUT_DIR
        os.makedirs(cls.output_dir, exist_ok=True)
        cls.ms_file = None

        if not (tools_available() and TEST_BENCHMARK_ONNX.exists()):
            return

        import subprocess
        from ohos_model_claw.tool_manager import tool_manager

        tool_manager.apply_env_vars()

        output_dir = str(cls.output_dir / "ms_output")
        os.makedirs(output_dir, exist_ok=True)

        onnx_path = str(TEST_BENCHMARK_ONNX)
        ms_path = Path(cls._get_output_file_path(onnx_path, output_dir))

        # Remove a model left over from a previous run. Otherwise a failed
        # conversion would go unnoticed and the tests would benchmark a stale file.
        if ms_path.exists():
            ms_path.unlink()

        env = os.environ.copy()
        env["CONVERTER_LITE_PATH"] = str(tool_manager.converter_path)
        env["PACKAGE_ROOT_PATH"] = str(tool_manager.package_root)

        try:
            result = cls._run_cli(
                "convert", onnx_path, "--output_dir", output_dir, timeout=120, env=env
            )
        except subprocess.TimeoutExpired:
            # Treat a hung converter like any other conversion failure: ms_file
            # stays None, so test_benchmark_requires_ms_file fails and the rest skip.
            print(f"[WARN] Conversion of {onnx_path} timed out after 120s")
            return

        if ms_path.exists():
            cls.ms_file = str(ms_path)
            print(f"[INFO] Benchmark test model: {ms_path}")
        else:
            # Surface the converter's output so the failure can be diagnosed.
            print(f"[WARN] Conversion produced no {ms_path} (exit code {result.returncode})")
            print(f"STDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}")

    @classmethod
    def tearDownClass(cls):
        """Nothing to clean up: generated models and logs are left under TEST_OUTPUT_DIR."""
        pass

    def _require_ms_file(self):
        """Skip the current test when setUpClass did not produce a ``.ms`` model."""
        if self.ms_file is None:
            self.skipTest("No valid .ms file available for benchmark")

    def _assert_benchmark_output(self, result, context):
        """Assert that *result*'s stdout looks like benchmark output.

        The check passes if stdout contains ``"Benchmark"`` or ``"OK"``, or
        contains ``"avg"`` in any case. *context* (e.g. ``"with warmup"``) is
        included in the failure message.
        """
        # NOTE(review): loose heuristic; the CLI's exit code is not checked.
        stdout = result.stdout or ""
        self.assertTrue(
            "Benchmark" in stdout or "avg" in stdout.lower() or "OK" in stdout,
            f"Expected benchmark output {context}. stdout: {result.stdout}, stderr: {result.stderr}",
        )

    def test_benchmark_requires_ms_file(self):
        """The converted ``.ms`` model must be available for the benchmark tests."""
        if not tools_available():
            self.skipTest("MindSpore Lite tools not installed")

        # Report a missing fixture explicitly rather than blaming the conversion.
        self.assertTrue(
            TEST_BENCHMARK_ONNX.exists(),
            f"Benchmark test ONNX model not found: {TEST_BENCHMARK_ONNX}",
        )
        self.assertIsNotNone(self.ms_file, "Benchmark test requires a valid .ms file (conversion may have failed)")

    def _save_benchmark_log(self, result, test_name, params_desc):
        """Write *result*'s stdout/stderr to a timestamped log next to the model.

        Args:
            result: Completed CLI run; ``stdout``/``stderr`` may be ``None``.
            test_name: Short tag used in the file name and the header.
            params_desc: Human-readable summary of the CLI parameters used.
        """
        from datetime import datetime

        # Take a single timestamp so the file name and the header agree.
        now = datetime.now()
        ms_output_dir = Path(self.ms_file).parent
        log_file = ms_output_dir / f"benchmark_{test_name}_{now.strftime('%Y%m%d_%H%M%S')}.txt"

        with open(log_file, "w", encoding="utf-8") as f:
            f.write("=" * 60 + "\n")
            f.write(f"Benchmark Test: {test_name}\n")
            f.write(f"Time: {now.strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Model: {self.ms_file}\n")
            f.write(f"Params: {params_desc}\n")
            f.write("=" * 60 + "\n\n")
            f.write("STDOUT:\n")
            f.write((result.stdout or "") + "\n")
            f.write("\nSTDERR:\n")
            f.write((result.stderr or "") + "\n")

        print(f"[INFO] Log saved: {log_file}")

    @unittest.skipIf(not tools_available(), "MindSpore Lite tools not installed")
    def test_benchmark_with_valid_ms(self):
        """Benchmark the converted model with 2 threads and 10 loops."""
        self._require_ms_file()

        result = self._run_cli("benchmark", str(self.ms_file), "--threads", "2", "--loop", "10")
        self._save_benchmark_log(result, "valid_ms", "Threads=2, Loop=10")

        stdout_lower = (result.stdout or "").lower()
        # "benchmark" in any case (this also covers "Benchmark") or an "avg" figure.
        self.assertTrue(
            "benchmark" in stdout_lower or "avg" in stdout_lower,
            f"Expected benchmark output. stdout: {result.stdout}, stderr: {result.stderr}",
        )

    @unittest.skipIf(not tools_available(), "MindSpore Lite tools not installed")
    def test_benchmark_invalid_file(self):
        """Benchmarking a nonexistent model must still produce CLI output."""
        invalid_path = str(Path(__file__).parent / "nonexistent_model.ms")

        result = self._run_cli("benchmark", invalid_path, timeout=30)

        combined_output = (result.stdout or "") + (result.stderr or "")
        # NOTE(review): any output mentioning "benchmark" also passes, so this does
        # not strictly require an error to be reported.
        is_error = (
            "ERROR" in combined_output
            or "Error" in combined_output
            or "benchmark" in combined_output.lower()
        )
        self.assertTrue(is_error, f"Expected benchmark output for invalid file: {combined_output}")

    @unittest.skipIf(not tools_available(), "MindSpore Lite tools not installed")
    def test_benchmark_with_different_threads(self):
        """Benchmark with 1 and 4 threads; each count is reported as its own subtest."""
        self._require_ms_file()

        for threads in (1, 4):
            with self.subTest(threads=threads):
                result = self._run_cli(
                    "benchmark", str(self.ms_file), "--threads", str(threads), "--loop", "5"
                )
                self._save_benchmark_log(result, f"threads_{threads}", f"Threads={threads}, Loop=5")
                self._assert_benchmark_output(result, f"with threads={threads}")

    @unittest.skipIf(not tools_available(), "MindSpore Lite tools not installed")
    def test_benchmark_with_warmup(self):
        """Benchmark with a custom warmup count (``--warmup 2``)."""
        self._require_ms_file()

        result = self._run_cli(
            "benchmark", str(self.ms_file), "--threads", "2", "--loop", "5", "--warmup", "2"
        )
        self._save_benchmark_log(result, "warmup", "Threads=2, Loop=5, Warmup=2")
        self._assert_benchmark_output(result, "with warmup")

    @unittest.skipIf(not tools_available(), "MindSpore Lite tools not installed")
    def test_benchmark_with_fp16(self):
        """Benchmark with the ``--fp16`` flag."""
        self._require_ms_file()

        result = self._run_cli(
            "benchmark", str(self.ms_file), "--threads", "2", "--loop", "5", "--fp16"
        )
        self._save_benchmark_log(result, "fp16", "Threads=2, Loop=5, FP16=True")
        self._assert_benchmark_output(result, "with FP16")

    @unittest.skipIf(not tools_available(), "MindSpore Lite tools not installed")
    def test_benchmark_with_time_profiling(self):
        """Benchmark with the ``--timeProfiling`` flag."""
        self._require_ms_file()

        result = self._run_cli(
            "benchmark", str(self.ms_file), "--threads", "2", "--loop", "5", "--timeProfiling"
        )
        self._save_benchmark_log(result, "time_profiling", "Threads=2, Loop=5, TimeProfiling=True")
        self._assert_benchmark_output(result, "with time profiling")

    @unittest.skipIf(not tools_available(), "MindSpore Lite tools not installed")
    def test_benchmark_with_high_loop_count(self):
        """Benchmark with 50 loops (longer 120 s timeout)."""
        self._require_ms_file()

        result = self._run_cli(
            "benchmark", str(self.ms_file), "--threads", "2", "--loop", "50", timeout=120
        )
        self._save_benchmark_log(result, "high_loop", "Threads=2, Loop=50")
        self._assert_benchmark_output(result, "with high loop count")


if __name__ == "__main__":
    unittest.main()