# baishanyang: init project
# Commit 5f1c8c3b, created 4 days ago (commit history).
"""
Tests for MCP runtime operator fix tool
"""

import unittest
import sys
import os
import asyncio
from pathlib import Path
from unittest.mock import patch, MagicMock

# Directory holding this test module; its parent is the project root.
_TESTS_DIR = Path(__file__).parent
# Put the project root first on the import path so ``ohos_model_claw``
# resolves to the in-tree package.
sys.path.insert(0, str(_TESTS_DIR.parent))

# Fixture models shipped alongside the tests.
TEST_MODEL_DIR = _TESTS_DIR / "model"
# Scratch directory for files produced by this module's tests.
TEST_OUTPUT_DIR = _TESTS_DIR.joinpath("output", "mcp_runtime_operator_fix")
# ONNX model fed to the runtime operator fix tool.
TEST_RUNTIME_FIX_ONNX = TEST_MODEL_DIR.joinpath("operator_fix", "test_op_fix_hardswish.onnx")


def check_onnx_dependencies():
    """Report whether the optional ONNX toolchain can be imported.

    Returns:
        bool: True when both ``onnx`` and ``onnx_graphsurgeon`` import
        successfully; False as soon as either raises ``ImportError``.
    """
    try:
        import onnx  # noqa: F401  (imported only to probe availability)
        import onnx_graphsurgeon  # noqa: F401
    except ImportError:
        return False
    return True


class TestMCPRuntimeOperatorFix(unittest.TestCase):
    """Tests for the ``ohos_runtime_op_fix`` MCP tool.

    The tests fall into three groups:

    * schema checks on the ``Tool`` entry in ``ohos_model_claw.mcp_server.TOOLS``;
    * handler checks that call ``handle_runtime_op_fix`` with ``tool_manager``
      and ``task_manager`` patched out and inspect the command it hands to
      ``task_manager.run_task_sync``;
    * integration checks that run ``ohos_model_claw.py runtime_op_fix`` in a
      subprocess on ``TEST_RUNTIME_FIX_ONNX`` (skipped when the ONNX
      dependencies or the test model are missing).
    """

    @classmethod
    def setUpClass(cls):
        cls.output_dir = TEST_OUTPUT_DIR
        os.makedirs(cls.output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Helpers (no ``test`` prefix, so unittest does not collect them)
    # ------------------------------------------------------------------

    def _skip_if_no_model(self):
        """Skip the current test when the runtime-fix test model is absent."""
        if not TEST_RUNTIME_FIX_ONNX.exists():
            self.skipTest(f"Test model not found: {TEST_RUNTIME_FIX_ONNX}")

    def _run_handler_mocked(self, arguments, task_output=""):
        """Call ``handle_runtime_op_fix`` with its backend patched out.

        ``tool_manager`` and ``task_manager`` in ``ohos_model_claw.mcp_server``
        are replaced by mocks; ``run_task_sync`` returns
        ``{"status": "completed", "output": task_output}``.

        Args:
            arguments: Tool arguments passed straight to the handler.
            task_output: Text reported as the backend task's ``output``.

        Returns:
            tuple: ``(result, mock_task)`` -- the handler's return value and
            the ``task_manager`` mock, whose recorded calls remain
            inspectable after the patches are undone.
        """
        from ohos_model_claw.mcp_server import handle_runtime_op_fix

        with patch('ohos_model_claw.mcp_server.tool_manager') as mock_tm:
            mock_tm.apply_env_vars.return_value = None
            with patch('ohos_model_claw.mcp_server.task_manager') as mock_task:
                mock_task.run_task_sync.return_value = {
                    "status": "completed",
                    "output": task_output,
                }
                result = asyncio.run(handle_runtime_op_fix(arguments))
        return result, mock_task

    @staticmethod
    def _called_cmd(mock_task):
        """Return the ``cmd`` keyword passed to ``run_task_sync`` as a list of str.

        Asserts exactly one call first, so a handler that returns early fails
        with a clear mock assertion instead of a ``TypeError`` from
        ``call_args`` being ``None``.
        """
        mock_task.run_task_sync.assert_called_once()
        # Elements may be Path objects; stringify so they can be joined/compared.
        return [str(part) for part in mock_task.run_task_sync.call_args[1]['cmd']]

    @staticmethod
    def _operator_values(cmd):
        """Collect the values that follow each ``--operators`` flag in *cmd*.

        Values are gathered until the next ``--``-prefixed token and split on
        commas/whitespace, so ``--operators A B`` and ``--operators A,B`` both
        yield ``["A", "B"]``. Checking these values, rather than the whole
        joined command, keeps filesystem paths from producing false matches.
        """
        values = []
        collecting = False
        for token in cmd:
            if token.startswith("--"):
                collecting = token == "--operators"
                continue
            if collecting:
                values.extend(token.replace(",", " ").split())
        return values

    @staticmethod
    def _parse_result_text(text):
        """Parse the handler's text payload into a dict without ``eval``.

        Tries JSON first, then a Python literal (``ast.literal_eval``) for a
        ``str(dict)`` rendering. Unlike ``eval`` this never executes code, and
        it also accepts JSON ``true``/``false``/``null``.
        """
        import ast
        import json

        try:
            return json.loads(text)
        except ValueError:
            return ast.literal_eval(text)

    def _run_cli(self, output_path, *extra_args):
        """Run ``ohos_model_claw.py runtime_op_fix`` on the test model.

        Any file already at *output_path* is removed first, so the caller's
        existence checks reflect this run rather than leftovers from an
        earlier one.

        Args:
            output_path: Value for ``--output_path``.
            *extra_args: Additional CLI arguments (e.g. ``--device NPU``).

        Returns:
            subprocess.CompletedProcess: with decoded ``stdout``/``stderr``.
        """
        import subprocess
        from ohos_model_claw.tool_manager import tool_manager

        tool_manager.apply_env_vars()
        env = os.environ.copy()

        previous_output = Path(output_path)
        if previous_output.exists():
            previous_output.unlink()

        script_path = Path(__file__).parent.parent / "ohos_model_claw" / "ohos_model_claw.py"
        cmd = [sys.executable, str(script_path), "runtime_op_fix", str(TEST_RUNTIME_FIX_ONNX),
               "--output_path", output_path, *extra_args]
        return subprocess.run(cmd, env=env, capture_output=True, text=True, timeout=120,
                              encoding='utf-8', errors='replace')

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    def test_mcp_runtime_op_fix_schema_has_parameters(self):
        """Test that ohos_runtime_op_fix Tool schema includes expected parameters"""
        from ohos_model_claw.mcp_server import TOOLS

        runtime_op_fix_tool = next(
            (tool for tool in TOOLS if tool.name == "ohos_runtime_op_fix"), None)
        self.assertIsNotNone(runtime_op_fix_tool, "ohos_runtime_op_fix tool not found in TOOLS")

        schema = runtime_op_fix_tool.inputSchema
        properties = schema.get("properties", {})

        for param in ("onnx_path", "output_path", "device", "operators"):
            self.assertIn(param, properties, f"Parameter '{param}' not found in ohos_runtime_op_fix schema")

        self.assertEqual(properties["device"].get("enum"), ["CPU", "GPU", "NPU"], "device enum values incorrect")
        self.assertEqual(properties["device"].get("default"), "NPU", "device default should be NPU")
        self.assertEqual(properties["operators"].get("type"), "array", "operators should be array type")

        required = schema.get("required", [])
        self.assertIn("onnx_path", required, "onnx_path should be required parameter")
        # NOTE(review): "device" is expected to be both required and defaulted to
        # "NPU", and test_mcp_handle_runtime_op_fix_default_device omits it.
        # Confirm which contract the schema is meant to express.
        self.assertIn("device", required, "device should be required parameter")

    # ------------------------------------------------------------------
    # Handler (backend mocked)
    # ------------------------------------------------------------------

    def test_mcp_handle_runtime_op_fix_passes_parameters(self):
        """Test that handle_runtime_op_fix function passes parameters to backend script"""
        self._skip_if_no_model()

        arguments = {
            "onnx_path": str(TEST_RUNTIME_FIX_ONNX),
            "output_path": str(self.output_dir / "test_mcp_fixed.onnx"),
            "device": "NPU",
            "operators": ["Split"],
        }
        _, mock_task = self._run_handler_mocked(arguments, task_output="operators_replaced: ['Split']")

        cmd_str = ' '.join(self._called_cmd(mock_task))
        self.assertIn("runtime_op_fix", cmd_str, "runtime_op_fix command not in cmd")
        self.assertIn("--device NPU", cmd_str, "device parameter not passed")
        self.assertIn("--operators Split", cmd_str, "operators parameter not passed")
        self.assertIn("--output_path", cmd_str, "output_path parameter not passed")

    def test_mcp_handle_runtime_op_fix_auto_output_path(self):
        """Test that handle_runtime_op_fix generates default output_path when not provided"""
        self._skip_if_no_model()

        arguments = {
            "onnx_path": str(TEST_RUNTIME_FIX_ONNX),
            "device": "NPU",
        }
        _, mock_task = self._run_handler_mocked(arguments)

        cmd_str = ' '.join(self._called_cmd(mock_task))
        expected_output = str(TEST_RUNTIME_FIX_ONNX.parent / "test_op_fix_hardswish_runtime_fixed.onnx")
        self.assertIn(expected_output, cmd_str, "auto-generated output_path not correct")

    def test_mcp_handle_runtime_op_fix_default_device(self):
        """Test that device defaults to NPU when not specified"""
        self._skip_if_no_model()

        arguments = {
            "onnx_path": str(TEST_RUNTIME_FIX_ONNX),
            "output_path": str(self.output_dir / "test_default_device.onnx"),
        }
        _, mock_task = self._run_handler_mocked(arguments)

        cmd_str = ' '.join(self._called_cmd(mock_task))
        self.assertIn("--device NPU", cmd_str, "device should default to NPU")

    def test_mcp_handle_runtime_op_fix_missing_onnx_path(self):
        """Test that handle_runtime_op_fix returns error when onnx_path is missing"""
        from ohos_model_claw.mcp_server import handle_runtime_op_fix

        result = asyncio.run(handle_runtime_op_fix({"device": "NPU"}))

        result_text = result[0].text
        self.assertIn("Error", result_text, "Expected error for missing onnx_path")
        self.assertIn("onnx_path is required", result_text, "Expected specific error message")

    def test_mcp_handle_runtime_op_fix_nonexistent_file(self):
        """Test that handle_runtime_op_fix returns error for nonexistent ONNX file"""
        from ohos_model_claw.mcp_server import handle_runtime_op_fix

        nonexistent_path = self.output_dir / "nonexistent.onnx"
        # Guard the precondition so a stray file cannot make this test misleading.
        self.assertFalse(nonexistent_path.exists(), f"Precondition failed: {nonexistent_path} exists")

        arguments = {
            "onnx_path": str(nonexistent_path),
            "device": "NPU",
        }
        result = asyncio.run(handle_runtime_op_fix(arguments))

        result_text = result[0].text
        self.assertIn("Error", result_text, "Expected error for nonexistent file")
        self.assertIn("ONNX file not found", result_text, "Expected specific error message")

    def test_mcp_handle_runtime_op_fix_multiple_operators(self):
        """Test that handle_runtime_op_fix passes multiple operators correctly"""
        self._skip_if_no_model()

        arguments = {
            "onnx_path": str(TEST_RUNTIME_FIX_ONNX),
            "output_path": str(self.output_dir / "test_multi_ops.onnx"),
            "device": "NPU",
            "operators": ["Split", "Mod"],
        }
        _, mock_task = self._run_handler_mocked(arguments)

        cmd = self._called_cmd(mock_task)
        self.assertIn("--operators", cmd, "operators flag not passed")
        # Inspect only the --operators values: a substring search over the whole
        # command would also match paths (e.g. a checkout dir containing "Mod").
        operators = self._operator_values(cmd)
        self.assertIn("Split", operators, "Split not in operators list")
        self.assertIn("Mod", operators, "Mod not in operators list")

    def test_mcp_handle_runtime_op_fix_result_structure(self):
        """Test that handle_runtime_op_fix returns result with expected structure"""
        self._skip_if_no_model()

        arguments = {
            "onnx_path": str(TEST_RUNTIME_FIX_ONNX),
            "output_path": str(self.output_dir / "test_result_structure.onnx"),
            "device": "NPU",
        }
        result, _ = self._run_handler_mocked(
            arguments,
            task_output="operators_replaced: ['Split']\nnodes_per_operator: {'Split': 9}",
        )

        result_dict = self._parse_result_text(result[0].text)

        for key in ("status", "device", "output_path", "operators_replaced", "nodes_per_operator"):
            self.assertIn(key, result_dict, f"{key} not in result")
        self.assertEqual(result_dict["status"], "completed", "status should be completed")
        self.assertEqual(result_dict["device"], "NPU", "device should be NPU")

    def test_mcp_handle_runtime_op_fix_sync_execution(self):
        """Test that handle_runtime_op_fix uses synchronous execution (run_task_sync)"""
        self._skip_if_no_model()

        arguments = {
            "onnx_path": str(TEST_RUNTIME_FIX_ONNX),
            "output_path": str(self.output_dir / "test_sync_execution.onnx"),
            "device": "NPU",
        }
        _, mock_task = self._run_handler_mocked(arguments)

        mock_task.run_task_sync.assert_called_once()
        mock_task.create_task.assert_not_called()
        mock_task.start_task.assert_not_called()

    # ------------------------------------------------------------------
    # Integration (real subprocess)
    # ------------------------------------------------------------------

    @unittest.skipIf(not check_onnx_dependencies(), "ONNX dependencies not available")
    @unittest.skipIf(not TEST_RUNTIME_FIX_ONNX.exists(), f"Test model not found: {TEST_RUNTIME_FIX_ONNX}")
    def test_mcp_runtime_op_fix_integration_split(self):
        """Integration test: MCP runtime_op_fix with Split operator"""
        output_path = str(self.output_dir / "integration_split_fixed.onnx")
        result = self._run_cli(output_path, "--device", "NPU", "--operators", "Split")

        combined_output = (result.stdout or "") + (result.stderr or "")
        # Markers the CLI may print on success; the Chinese ones read
        # "fix complete" and "successfully replaced".
        success_markers = ("修复完成", "OK", "成功替换")
        self.assertTrue(
            any(marker in combined_output for marker in success_markers)
            or "success" in combined_output.lower(),
            f"Runtime operator fix failed (exit code {result.returncode}). "
            f"stdout: {result.stdout}, stderr: {result.stderr}")

        self.assertTrue(Path(output_path).exists(), f"Fixed ONNX file not created: {output_path}")

    @unittest.skipIf(not check_onnx_dependencies(), "ONNX dependencies not available")
    @unittest.skipIf(not TEST_RUNTIME_FIX_ONNX.exists(), f"Test model not found: {TEST_RUNTIME_FIX_ONNX}")
    def test_mcp_runtime_op_fix_output_valid_onnx(self):
        """Test that fixed output is a valid ONNX model"""
        import onnx

        output_path = str(self.output_dir / "validation_fixed.onnx")
        self._run_cli(output_path, "--device", "NPU")

        if not Path(output_path).exists():
            self.skipTest("Output file not created, possibly no operators to fix")

        validation_error = None
        try:
            model = onnx.load(output_path)
            onnx.checker.check_model(model)
        except Exception as ex:  # any load/check failure means the output is invalid
            validation_error = ex

        self.assertIsNone(validation_error, f"Fixed ONNX model is not valid: {validation_error}")

    @unittest.skipIf(not check_onnx_dependencies(), "ONNX dependencies not available")
    @unittest.skipIf(not TEST_RUNTIME_FIX_ONNX.exists(), f"Test model not found: {TEST_RUNTIME_FIX_ONNX}")
    def test_mcp_runtime_op_fix_split_replaced_with_slices(self):
        """Test that Split nodes are replaced with Slice nodes"""
        import onnx

        output_path = str(self.output_dir / "split_to_slices.onnx")
        self._run_cli(output_path, "--device", "NPU", "--operators", "Split")

        if not Path(output_path).exists():
            self.skipTest("Output file not created")

        model = onnx.load(output_path)
        op_types = [node.op_type for node in model.graph.node]
        split_count = op_types.count("Split")
        slice_count = op_types.count("Slice")

        self.assertEqual(split_count, 0, f"Should have no Split nodes, found {split_count}")
        self.assertGreater(slice_count, 0, "Should have Slice nodes after replacement")


if __name__ == "__main__":
    # Direct execution (``python <this file>``) discovers and runs the tests
    # defined in this module.
    unittest.main(module="__main__")