# baishanyang: init project
# Commit 5f1c8c3b, created 4 days ago (from the repository's commit history view)
"""
Tests for MCP operator fix tool
"""

import unittest
import sys
import os
import asyncio
from pathlib import Path
from unittest.mock import patch, MagicMock

# Directory holding this test module; everything below is resolved from it.
_tests_dir = Path(__file__).parent

# Prepend the project root (one level above the tests directory) so that
# `ohos_model_claw` is importable straight from a source checkout.
sys.path.insert(0, str(_tests_dir.parent))

# Fixture inputs and generated-artifact locations.
TEST_MODEL_DIR = _tests_dir / "model"
TEST_OUTPUT_DIR = _tests_dir.joinpath("output", "mcp_operator_fix")
_op_fix_models = TEST_MODEL_DIR / "operator_fix"
TEST_OP_FIX_HARDSIGMOID_ONNX = _op_fix_models / "test_op_fix_hardsigmoid.onnx"
TEST_OP_FIX_HARDSWISH_ONNX = _op_fix_models / "test_op_fix_hardswish.onnx"

# Drop the scratch names so the module namespace matches the plain version.
del _tests_dir, _op_fix_models


def check_onnx_dependencies():
    """Report whether `onnx` and `onnx_graphsurgeon` are both importable.

    The modules are imported in that order, and the check stops at the first
    ImportError.

    Returns:
        True if both imports succeed, otherwise False.
    """
    for module_name in ("onnx", "onnx_graphsurgeon"):
        try:
            __import__(module_name)
        except ImportError:
            return False
    return True


class TestMCPOperatorFix(unittest.TestCase):
    """Test MCP operator fix tool

    Unit tests drive `handle_op_fix` with `tool_manager` and `task_manager`
    in `ohos_model_claw.mcp_server` patched out, and inspect the command it
    hands to `run_task_sync`. Integration tests (skipped when the ONNX
    dependencies or fixture models are missing) run the real `op_fix` CLI in
    a subprocess.
    """

    @classmethod
    def setUpClass(cls):
        cls.output_dir = TEST_OUTPUT_DIR
        cls.output_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Helpers (not collected as tests: names do not start with "test")
    # ------------------------------------------------------------------

    def _require_model(self, model_path):
        """Skip the current test when the fixture model file is missing."""
        if not model_path.exists():
            self.skipTest(f"Test model not found: {model_path}")

    def _run_handle_op_fix_mocked(self, arguments, task_output=""):
        """Run `handle_op_fix` with its tool/task managers patched out.

        `run_task_sync` on the patched task manager reports a completed task
        whose output is *task_output*.

        Returns:
            (result, mock_task): the handler's return value and the task-manager
            mock. Its recorded calls remain inspectable after the patch exits.
        """
        from ohos_model_claw.mcp_server import handle_op_fix

        with patch('ohos_model_claw.mcp_server.tool_manager') as mock_tm:
            mock_tm.apply_env_vars.return_value = None
            with patch('ohos_model_claw.mcp_server.task_manager') as mock_task:
                mock_task.run_task_sync.return_value = {
                    "status": "completed",
                    "output": task_output,
                }
                result = asyncio.run(handle_op_fix(arguments))
        return result, mock_task

    def _called_cmd_str(self, mock_task):
        """Return the `cmd` keyword passed to `run_task_sync`, joined by spaces.

        If the handler never reached `run_task_sync`, the test fails with a
        readable message instead of a TypeError on `call_args` being None.
        """
        self.assertTrue(mock_task.run_task_sync.called, "run_task_sync was not called")
        called_cmd = mock_task.run_task_sync.call_args[1]['cmd']
        return ' '.join(str(part) for part in called_cmd)

    @staticmethod
    def _parse_result_text(text):
        """Parse the handler's text payload without executing it.

        JSON is tried first. A Python-literal repr (e.g. a dict printed with
        single quotes) is accepted via `ast.literal_eval`. Unlike `eval`,
        neither path can run arbitrary code.
        """
        import ast
        import json

        try:
            return json.loads(text)
        except ValueError:
            return ast.literal_eval(text)

    def _run_op_fix_cli(self, onnx_path, output_path, operators=None):
        """Run the real `op_fix` CLI in a subprocess and return the result.

        Any pre-existing *output_path* is removed first. Otherwise an artifact
        from an earlier run could satisfy the callers' existence checks even
        when this run fails.

        Args:
            onnx_path: input model, passed as the positional argument.
            output_path: value for `--output_path`.
            operators: optional operator names appended after `--operators`.

        Returns:
            subprocess.CompletedProcess with stdout/stderr decoded as UTF-8.
        """
        import subprocess
        from ohos_model_claw.tool_manager import tool_manager

        output_file = Path(output_path)
        if output_file.exists():
            output_file.unlink()

        # The child inherits a snapshot of this process's environment taken
        # after apply_env_vars(). Presumably that call populates os.environ;
        # TODO confirm.
        tool_manager.apply_env_vars()
        env = os.environ.copy()

        script_path = Path(__file__).parent.parent / "ohos_model_claw" / "ohos_model_claw.py"
        cmd = [sys.executable, str(script_path), "op_fix", str(onnx_path),
               "--output_path", str(output_path)]
        if operators:
            cmd += ["--operators", *operators]

        return subprocess.run(cmd, env=env, capture_output=True, text=True, timeout=120,
                              encoding='utf-8', errors='replace')

    def _assert_cli_succeeded(self, proc):
        """Assert that the CLI's combined output contains an accepted success marker."""
        combined_output = (proc.stdout or "") + (proc.stderr or "")
        # Accepted markers: "修复完成" ("fix complete"), "OK", or "success" in any case.
        self.assertTrue("修复完成" in combined_output or "OK" in combined_output or "success" in combined_output.lower(),
                        f"Operator fix failed. stdout: {proc.stdout}, stderr: {proc.stderr}")

    # ------------------------------------------------------------------
    # Schema / handler unit tests
    # ------------------------------------------------------------------

    def test_mcp_op_fix_schema_has_parameters(self):
        """Test that ohos_op_fix Tool schema includes expected parameters"""
        from ohos_model_claw.mcp_server import TOOLS

        op_fix_tool = next((tool for tool in TOOLS if tool.name == "ohos_op_fix"), None)
        self.assertIsNotNone(op_fix_tool, "ohos_op_fix tool not found in TOOLS")

        schema = op_fix_tool.inputSchema
        properties = schema.get("properties", {})

        for param in ["onnx_path", "output_path", "operators"]:
            self.assertIn(param, properties, f"Parameter '{param}' not found in ohos_op_fix schema")

        self.assertEqual(properties["operators"].get("type"), "array", "operators should be array type")

        required = schema.get("required", [])
        self.assertIn("onnx_path", required, "onnx_path should be required parameter")

    def test_mcp_handle_op_fix_passes_parameters(self):
        """Test that handle_op_fix function passes parameters to backend script"""
        self._require_model(TEST_OP_FIX_HARDSWISH_ONNX)

        arguments = {
            "onnx_path": str(TEST_OP_FIX_HARDSWISH_ONNX),
            "output_path": str(self.output_dir / "test_fixed.onnx"),
            "operators": ["HardSwish"],
        }
        _, mock_task = self._run_handle_op_fix_mocked(
            arguments, task_output="operators_replaced: ['HardSwish']")
        cmd_str = self._called_cmd_str(mock_task)

        self.assertIn("op_fix", cmd_str, "op_fix command not in cmd")
        self.assertIn("--operators HardSwish", cmd_str, "operators parameter not passed")
        self.assertIn("--output_path", cmd_str, "output_path parameter not passed")

    def test_mcp_handle_op_fix_auto_output_path(self):
        """Test that handle_op_fix generates default output_path when not provided"""
        self._require_model(TEST_OP_FIX_HARDSWISH_ONNX)

        arguments = {"onnx_path": str(TEST_OP_FIX_HARDSWISH_ONNX)}
        _, mock_task = self._run_handle_op_fix_mocked(arguments)
        cmd_str = self._called_cmd_str(mock_task)

        # Default is "<input stem>_fixed.onnx" next to the input model.
        expected_output = str(TEST_OP_FIX_HARDSWISH_ONNX.parent / "test_op_fix_hardswish_fixed.onnx")
        self.assertIn(expected_output, cmd_str, "auto-generated output_path not correct")

    def test_mcp_handle_op_fix_missing_onnx_path(self):
        """Test that handle_op_fix returns error when onnx_path is missing"""
        from ohos_model_claw.mcp_server import handle_op_fix

        result = asyncio.run(handle_op_fix({}))

        result_text = result[0].text
        self.assertIn("Error", result_text, "Expected error for missing onnx_path")
        self.assertIn("onnx_path is required", result_text, "Expected specific error message")

    def test_mcp_handle_op_fix_nonexistent_file(self):
        """Test that handle_op_fix returns error for nonexistent ONNX file"""
        from ohos_model_claw.mcp_server import handle_op_fix

        nonexistent_path = str(self.output_dir / "nonexistent.onnx")

        result = asyncio.run(handle_op_fix({"onnx_path": nonexistent_path}))

        result_text = result[0].text
        self.assertIn("Error", result_text, "Expected error for nonexistent file")
        self.assertIn("ONNX file not found", result_text, "Expected specific error message")

    def test_mcp_handle_op_fix_multiple_operators(self):
        """Test that handle_op_fix passes multiple operators correctly"""
        self._require_model(TEST_OP_FIX_HARDSWISH_ONNX)

        arguments = {
            "onnx_path": str(TEST_OP_FIX_HARDSWISH_ONNX),
            "output_path": str(self.output_dir / "test_multi_fixed.onnx"),
            "operators": ["HardSwish", "HardSigmoid"],
        }
        _, mock_task = self._run_handle_op_fix_mocked(arguments)
        cmd_str = self._called_cmd_str(mock_task)

        self.assertIn("--operators", cmd_str, "operators flag not passed")
        self.assertIn("HardSwish", cmd_str, "HardSwish not in operators list")
        self.assertIn("HardSigmoid", cmd_str, "HardSigmoid not in operators list")

    def test_mcp_handle_op_fix_result_structure(self):
        """Test that handle_op_fix returns result with expected structure"""
        self._require_model(TEST_OP_FIX_HARDSWISH_ONNX)

        arguments = {
            "onnx_path": str(TEST_OP_FIX_HARDSWISH_ONNX),
            "output_path": str(self.output_dir / "test_result_fixed.onnx"),
        }
        result, _ = self._run_handle_op_fix_mocked(
            arguments, task_output="operators_replaced: ['HardSwish']")

        # Parsed safely: never eval() text produced by the code under test.
        result_dict = self._parse_result_text(result[0].text)

        self.assertIsInstance(result_dict, dict, "result text should describe a dict")
        self.assertIn("status", result_dict, "status not in result")
        self.assertIn("output_path", result_dict, "output_path not in result")
        self.assertIn("operators_replaced", result_dict, "operators_replaced not in result")
        self.assertEqual(result_dict["status"], "completed", "status should be completed")

    def test_mcp_handle_op_fix_sync_execution(self):
        """Test that handle_op_fix uses synchronous execution (run_task_sync)"""
        self._require_model(TEST_OP_FIX_HARDSWISH_ONNX)

        arguments = {
            "onnx_path": str(TEST_OP_FIX_HARDSWISH_ONNX),
            "output_path": str(self.output_dir / "test_sync_fixed.onnx"),
        }
        _, mock_task = self._run_handle_op_fix_mocked(arguments)

        mock_task.run_task_sync.assert_called_once()
        mock_task.create_task.assert_not_called()
        mock_task.start_task.assert_not_called()

    # ------------------------------------------------------------------
    # Integration tests (real CLI subprocess)
    # ------------------------------------------------------------------

    @unittest.skipIf(not check_onnx_dependencies(), "ONNX dependencies not available")
    @unittest.skipIf(not TEST_OP_FIX_HARDSWISH_ONNX.exists(), f"Test model not found: {TEST_OP_FIX_HARDSWISH_ONNX}")
    def test_mcp_op_fix_integration_hardswish(self):
        """Integration test: MCP op_fix with HardSwish model"""
        output_path = str(self.output_dir / "integration_hardswish_fixed.onnx")

        proc = self._run_op_fix_cli(TEST_OP_FIX_HARDSWISH_ONNX, output_path, operators=["HardSwish"])

        self._assert_cli_succeeded(proc)
        self.assertTrue(Path(output_path).exists(), f"Fixed ONNX file not created: {output_path}")

    @unittest.skipIf(not check_onnx_dependencies(), "ONNX dependencies not available")
    @unittest.skipIf(not TEST_OP_FIX_HARDSIGMOID_ONNX.exists(), f"Test model not found: {TEST_OP_FIX_HARDSIGMOID_ONNX}")
    def test_mcp_op_fix_integration_hardsigmoid(self):
        """Integration test: MCP op_fix with HardSigmoid model"""
        output_path = str(self.output_dir / "integration_hardsigmoid_fixed.onnx")

        proc = self._run_op_fix_cli(TEST_OP_FIX_HARDSIGMOID_ONNX, output_path, operators=["HardSigmoid"])

        self._assert_cli_succeeded(proc)
        self.assertTrue(Path(output_path).exists(), f"Fixed ONNX file not created: {output_path}")

    @unittest.skipIf(not check_onnx_dependencies(), "ONNX dependencies not available")
    @unittest.skipIf(not TEST_OP_FIX_HARDSWISH_ONNX.exists(), f"Test model not found: {TEST_OP_FIX_HARDSWISH_ONNX}")
    def test_mcp_op_fix_output_valid_onnx(self):
        """Test that fixed output is a valid ONNX model"""
        import onnx

        output_path = str(self.output_dir / "validation_fixed.onnx")

        # No --operators: let the CLI choose. Any stale output is removed by
        # the helper, so only a file written by this run gets validated.
        self._run_op_fix_cli(TEST_OP_FIX_HARDSWISH_ONNX, output_path)

        if not Path(output_path).exists():
            self.skipTest("Output file not created, possibly no operators to fix")

        try:
            model = onnx.load(output_path)
            onnx.checker.check_model(model)
        except Exception as ex:
            self.fail(f"Fixed ONNX model is not valid: {ex}")


if __name__ == "__main__":
    # Allow running this file directly; discovers the tests in this module.
    unittest.main(module="__main__")