# baishanyang: init project
# commit 5f1c8c3b, created 4 days ago (see commit history)
"""
Tests for MCP benchmark tool
"""

import unittest
import sys
import os
import asyncio
from pathlib import Path
from unittest.mock import patch, MagicMock

# Directory holding this test file; everything below is resolved relative to it.
_HERE = Path(__file__).parent
# Put the project root first on the import path so `ohos_model_claw` resolves
# to the in-tree package when the tests are run directly.
sys.path.insert(0, str(_HERE.parent))

# Fixture locations used by the benchmark tests.
TEST_MODEL_DIR = _HERE / "model"
TEST_OUTPUT_DIR = _HERE.joinpath("output", "mcp_benchmark")
TEST_BENCHMARK_ONNX = TEST_MODEL_DIR.joinpath("benchmark", "test_benchmark.onnx")


def tools_available():
    """Report whether the MindSpore Lite tools are installed.

    Asks the project's ``tool_manager`` whether the tool itself exists and,
    only if it does, whether its library directories exist as well.
    """
    from ohos_model_claw.tool_manager import tool_manager as manager

    tool_present = manager.check_tool_exists()
    if not tool_present:
        return tool_present
    return manager.check_lib_dirs_exist()


class TestMCPBenchmark(unittest.TestCase):
    """Tests for the ``ohos_benchmark`` MCP tool.

    Covers the tool's input schema, the command line that ``handle_benchmark``
    hands to the task manager (with ``tool_manager`` and ``task_manager``
    patched out), and - when the MindSpore Lite tools and the test ONNX model
    are available - an end-to-end run of the ``benchmark`` sub-command.
    """

    @classmethod
    def setUpClass(cls):
        """Create the output directory and, when possible, a real ``.ms`` model.

        If the MindSpore Lite tools are installed and ``TEST_BENCHMARK_ONNX``
        exists, the model is converted with the ``convert`` sub-command of
        ``ohos_model_claw.py``. On success ``cls.ms_file`` is the path (str) of
        the generated model; otherwise it stays ``None``.
        """
        cls.output_dir = TEST_OUTPUT_DIR
        os.makedirs(cls.output_dir, exist_ok=True)
        cls.ms_file = None

        if not (tools_available() and TEST_BENCHMARK_ONNX.exists()):
            return

        import subprocess
        from ohos_model_claw.tool_manager import tool_manager

        tool_manager.apply_env_vars()

        output_dir = str(cls.output_dir / "ms_output")
        os.makedirs(output_dir, exist_ok=True)

        onnx_path = str(TEST_BENCHMARK_ONNX)
        ms_path = Path(output_dir) / f"{TEST_BENCHMARK_ONNX.stem}.ms"
        # Remove a model left over from a previous run so a failed conversion
        # cannot be masked by a stale artifact.
        if ms_path.exists():
            ms_path.unlink()

        env = os.environ.copy()
        env["CONVERTER_LITE_PATH"] = str(tool_manager.converter_path)
        env["PACKAGE_ROOT_PATH"] = str(tool_manager.package_root)

        script_path = Path(__file__).parent.parent / "ohos_model_claw" / "ohos_model_claw.py"
        cmd = [sys.executable, str(script_path), "convert", onnx_path, "--output_dir", output_dir]

        try:
            result = subprocess.run(cmd, env=env, capture_output=True, text=True, timeout=120,
                                    encoding='utf-8', errors='replace')
        except (subprocess.TimeoutExpired, OSError) as exc:
            # An exception escaping setUpClass errors out every test in this
            # class, including the fully mocked ones; degrade to "no model".
            print(f"[WARN] MCP Benchmark model conversion did not complete: {exc}")
            return

        if ms_path.exists():
            cls.ms_file = str(ms_path)
            print(f"[INFO] MCP Benchmark test model: {cls.ms_file}")
        else:
            print(f"[WARN] MCP Benchmark model conversion produced no .ms file "
                  f"(exit code {result.returncode}). stderr: {result.stderr}")

    def _run_mocked_benchmark(self, arguments, task_id):
        """Invoke ``handle_benchmark`` with the tool and task managers patched out.

        ``tool_manager.ensure_tool_ready`` reports ``(True, "ready")`` and
        ``task_manager.create_task`` returns ``task_id``, so no real process is
        launched.

        Returns:
            ``(result, mock_task)``: the handler's return value and the
            ``task_manager`` mock, for inspecting ``start_task`` calls.
        """
        from ohos_model_claw.mcp_server import handle_benchmark

        with patch('ohos_model_claw.mcp_server.tool_manager') as mock_tm:
            mock_tm.ensure_tool_ready.return_value = (True, "ready")
            mock_tm.apply_env_vars.return_value = None

            with patch('ohos_model_claw.mcp_server.task_manager') as mock_task:
                mock_task.create_task.return_value = task_id
                mock_task.start_task.return_value = None

                result = asyncio.run(handle_benchmark(arguments))

        return result, mock_task

    def _started_cmd(self, mock_task):
        """Return the ``cmd`` keyword passed to ``task_manager.start_task`` as a list of str.

        Fails with a readable message (instead of a ``TypeError``/``KeyError``)
        when ``start_task`` was never called or received no ``cmd`` keyword.
        """
        call = mock_task.start_task.call_args
        self.assertIsNotNone(call, "handle_benchmark did not call task_manager.start_task")
        self.assertIn('cmd', call[1], "task_manager.start_task was not given a 'cmd' keyword argument")
        return [str(part) for part in call[1]['cmd']]

    @staticmethod
    def _parse_payload(text):
        """Parse the text of a tool response into a Python object without executing it.

        JSON is tried first; Python-literal syntax (e.g. the ``repr`` of a dict)
        is accepted as a fallback. ``eval`` is deliberately avoided: it would
        execute arbitrary text and it rejects JSON literals such as ``true``.
        """
        import ast
        import json

        try:
            return json.loads(text)
        except ValueError:
            return ast.literal_eval(text)

    def test_mcp_benchmark_schema_has_parameters(self):
        """The ``ohos_benchmark`` tool schema exposes the expected parameters and defaults."""
        from ohos_model_claw.mcp_server import TOOLS

        benchmark_tool = next((tool for tool in TOOLS if tool.name == "ohos_benchmark"), None)
        self.assertIsNotNone(benchmark_tool, "ohos_benchmark tool not found in TOOLS")

        schema = benchmark_tool.inputSchema
        properties = schema.get("properties", {})

        expected_params = ["ms_path", "device", "threads", "loop_count", "input_shape"]
        for param in expected_params:
            self.assertIn(param, properties, f"Parameter '{param}' not found in ohos_benchmark schema")

        self.assertEqual(properties["device"].get("enum"), ["CPU", "GPU", "NPU"], "device enum values incorrect")
        self.assertEqual(properties["device"].get("default"), "CPU", "device default incorrect")
        self.assertEqual(properties["threads"].get("default"), 4, "threads default incorrect")
        self.assertEqual(properties["loop_count"].get("default"), 100, "loop_count default incorrect")

        required = schema.get("required", [])
        self.assertIn("ms_path", required, "ms_path should be required parameter")

    def test_mcp_handle_benchmark_passes_parameters(self):
        """handle_benchmark forwards every explicit argument to the backend script."""
        ms_path = self.ms_file or str(TEST_OUTPUT_DIR / "dummy.ms")
        arguments = {
            "ms_path": ms_path,
            "device": "CPU",
            "threads": 2,
            "loop_count": 10,
            "input_shape": "input:1,3,224,224",
        }

        _, mock_task = self._run_mocked_benchmark(arguments, "benchmark_task_123")
        cmd_parts = self._started_cmd(mock_task)
        cmd_str = ' '.join(cmd_parts)

        # Check for a standalone argv element: a substring test on the joined
        # command would always pass because ms_path contains "mcp_benchmark".
        self.assertIn("benchmark", cmd_parts, "benchmark command not in cmd")
        self.assertIn("--threads 2", cmd_str, "threads parameter not passed")
        self.assertIn("--loop 10", cmd_str, "loop_count parameter not passed")
        self.assertIn("--device CPU", cmd_str, "device parameter not passed")
        self.assertIn("--inputShape input:1,3,224,224", cmd_str, "input_shape parameter not passed")

    def test_mcp_handle_benchmark_default_values(self):
        """handle_benchmark fills in the schema defaults for omitted arguments."""
        arguments = {"ms_path": str(TEST_OUTPUT_DIR / "dummy.ms")}

        _, mock_task = self._run_mocked_benchmark(arguments, "benchmark_task_default")
        cmd_str = ' '.join(self._started_cmd(mock_task))

        self.assertIn("--threads 4", cmd_str, "threads default (4) not passed")
        self.assertIn("--loop 100", cmd_str, "loop_count default (100) not passed")
        self.assertIn("--device CPU", cmd_str, "device default (CPU) not passed")

    def test_mcp_handle_benchmark_missing_ms_path(self):
        """handle_benchmark returns an error message when ms_path is missing."""
        from ohos_model_claw.mcp_server import handle_benchmark

        result = asyncio.run(handle_benchmark({}))

        result_text = result[0].text
        self.assertIn("Error", result_text, "Expected error for missing ms_path")
        self.assertIn("ms_path is required", result_text, "Expected specific error message")

    def test_mcp_handle_benchmark_gpu_device(self):
        """handle_benchmark forwards device=GPU to the backend script."""
        arguments = {"ms_path": str(TEST_OUTPUT_DIR / "dummy.ms"), "device": "GPU"}

        _, mock_task = self._run_mocked_benchmark(arguments, "benchmark_task_gpu")
        cmd_str = ' '.join(self._started_cmd(mock_task))

        self.assertIn("--device GPU", cmd_str, "GPU device not passed correctly")

    def test_mcp_handle_benchmark_npu_device(self):
        """handle_benchmark forwards device=NPU to the backend script."""
        arguments = {"ms_path": str(TEST_OUTPUT_DIR / "dummy.ms"), "device": "NPU"}

        _, mock_task = self._run_mocked_benchmark(arguments, "benchmark_task_npu")
        cmd_str = ' '.join(self._started_cmd(mock_task))

        self.assertIn("--device NPU", cmd_str, "NPU device not passed correctly")

    def test_mcp_handle_benchmark_returns_task_id(self):
        """handle_benchmark returns the created task_id with status 'started'."""
        arguments = {"ms_path": str(TEST_OUTPUT_DIR / "dummy.ms")}

        result, _ = self._run_mocked_benchmark(arguments, "test_async_task_id")
        result_dict = self._parse_payload(result[0].text)

        self.assertIn("task_id", result_dict, "task_id not in result")
        self.assertEqual(result_dict["task_id"], "test_async_task_id", "task_id value incorrect")
        self.assertEqual(result_dict["status"], "started", "status should be 'started'")

    @unittest.skipIf(not tools_available(), "MindSpore Lite tools not installed")
    def test_mcp_benchmark_integration(self):
        """Integration test: run the benchmark sub-command on the converted .ms model."""
        if self.ms_file is None:
            self.skipTest("No valid .ms file available for benchmark")

        import subprocess
        from ohos_model_claw.tool_manager import tool_manager

        tool_manager.apply_env_vars()

        env = os.environ.copy()
        env["CONVERTER_LITE_PATH"] = str(tool_manager.converter_path)
        env["PACKAGE_ROOT_PATH"] = str(tool_manager.package_root)
        env["BENCHMARK_PATH"] = str(tool_manager.benchmark_path)

        script_path = Path(__file__).parent.parent / "ohos_model_claw" / "ohos_model_claw.py"
        cmd = [sys.executable, str(script_path), "benchmark", str(self.ms_file),
               "--threads", "2", "--loop", "10"]

        result = subprocess.run(cmd, env=env, capture_output=True, text=True, timeout=60,
                                encoding='utf-8', errors='replace')

        combined_output = (result.stdout or "") + (result.stderr or "")
        self.assertTrue("Benchmark" in combined_output or "avg" in combined_output.lower() or "OK" in combined_output,
                        f"Benchmark failed (exit code {result.returncode}). "
                        f"stdout: {result.stdout}, stderr: {result.stderr}")


if __name__ == "__main__":
    # Running this file as a script discovers and runs the TestCases in this module
    # (verbosity=1 is unittest's default; command-line -v still overrides it).
    unittest.main(verbosity=1)