# baishanyang: init project
# Commit 5f1c8c3b, created 4 days ago (commit history)
"""
Tests for MCP convert tool with new parameters
"""

import unittest
import sys
import os
from pathlib import Path

# Directory that contains this test module; all fixture paths hang off it.
_TESTS_DIR = Path(__file__).parent

# Put the directory above the tests first on sys.path -- presumably so the
# ``ohos_model_claw`` package imported by the tests is picked up from there.
sys.path.insert(0, str(_TESTS_DIR.parent))

# Fixture locations used by the tests in this module.
TEST_MODEL_DIR = _TESTS_DIR / "model"
TEST_OUTPUT_DIR = _TESTS_DIR / "output" / "mcp_convert"
TEST_CONVERT_ONNX = TEST_MODEL_DIR / "convert" / "test_convert.onnx"


def tools_available():
    """Check if MindSpore Lite tools are installed.

    This function is evaluated at import time by the ``skipIf`` decorator on
    the integration test, so it must not raise: an ``ImportError`` here would
    abort loading of the whole test module instead of skipping one test.

    Returns:
        ``False`` when ``ohos_model_claw.tool_manager`` cannot be imported;
        otherwise the result of
        ``tool_manager.check_tool_exists() and tool_manager.check_lib_dirs_exist()``.
    """
    try:
        from ohos_model_claw.tool_manager import tool_manager
    except ImportError:
        # Package not importable -> the tools cannot be used from here.
        return False
    return tool_manager.check_tool_exists() and tool_manager.check_lib_dirs_exist()


class TestMCPConvertParameters(unittest.TestCase):
    """Test MCP convert tool parameter handling.

    Three things are covered:

    * the ``ohos_convert`` entry in ``TOOLS`` exposes the new parameters in
      its input schema;
    * ``handle_convert`` forwards those parameters as command-line flags in
      the ``cmd`` it hands to ``task_manager.start_task``;
    * when the MindSpore Lite tools are installed, running the backend script
      with those flags produces a ``.ms`` file.
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared output directory used by the tests."""
        cls.output_dir = TEST_OUTPUT_DIR
        os.makedirs(cls.output_dir, exist_ok=True)

    def _skip_if_model_missing(self):
        """Skip the current test when the ONNX fixture is not on disk."""
        if not TEST_CONVERT_ONNX.exists():
            self.skipTest(f"Test model not found: {TEST_CONVERT_ONNX}")

    def _captured_convert_cmd(self, arguments, task_id):
        """Run ``handle_convert`` against mocked managers and return its command.

        Args:
            arguments: Argument dict passed to ``handle_convert``.
            task_id: Value returned by the mocked ``task_manager.create_task``.

        Returns:
            The ``cmd`` keyword argument that ``handle_convert`` gave to the
            mocked ``task_manager.start_task``, joined with single spaces.
        """
        import asyncio
        from unittest.mock import patch
        from ohos_model_claw.mcp_server import handle_convert

        self._skip_if_model_missing()

        with patch('ohos_model_claw.mcp_server.tool_manager') as mock_tm, \
                patch('ohos_model_claw.mcp_server.task_manager') as mock_task:
            mock_tm.ensure_tool_ready.return_value = (True, "ready")
            mock_tm.apply_env_vars.return_value = None
            mock_task.create_task.return_value = task_id
            mock_task.start_task.return_value = None

            asyncio.run(handle_convert(arguments))

            # Fail with a readable message instead of a TypeError on
            # ``None[1]`` when start_task was never reached.
            self.assertTrue(mock_task.start_task.called,
                            "handle_convert did not call task_manager.start_task")
            called_cmd = mock_task.start_task.call_args[1]['cmd']

        # str() each part so a Path element in cmd does not break the join.
        return ' '.join(str(part) for part in called_cmd)

    def test_mcp_convert_schema_has_new_parameters(self):
        """Test that ohos_convert Tool schema includes new parameters"""
        from ohos_model_claw.mcp_server import TOOLS

        convert_tool = next((tool for tool in TOOLS if tool.name == "ohos_convert"), None)
        self.assertIsNotNone(convert_tool, "ohos_convert tool not found in TOOLS")

        properties = convert_tool.inputSchema.get("properties", {})

        expected_params = ["fp16", "input_data_format", "output_data_format", "input_shape", "optimize"]
        for param in expected_params:
            self.assertIn(param, properties, f"Parameter '{param}' not found in ohos_convert schema")

        # input_shape is free-form, so only the other four carry an enum.
        expected_enums = {
            "fp16": ["on", "off"],
            "input_data_format": ["NHWC", "NCHW"],
            "output_data_format": ["NHWC", "NCHW"],
            "optimize": ["none", "general", "gpu_oriented", "ascend_oriented"],
        }
        for param, values in expected_enums.items():
            self.assertEqual(properties[param].get("enum"), values, f"{param} enum values incorrect")

    def test_mcp_handle_convert_passes_parameters(self):
        """Test that handle_convert function passes new parameters to backend script"""
        arguments = {
            "onnx_path": str(TEST_CONVERT_ONNX),
            "output_dir": str(self.output_dir / "test_mcp_params"),
            "fp16": "on",
            "input_data_format": "NCHW",
            "output_data_format": "NCHW",
            "input_shape": "input:1,3,224,224",
            "optimize": "general",
            "async_mode": True
        }

        cmd_str = self._captured_convert_cmd(arguments, "test_task_123")

        self.assertIn("--fp16 on", cmd_str, "fp16 parameter not passed")
        self.assertIn("--inputDataFormat NCHW", cmd_str, "inputDataFormat parameter not passed")
        self.assertIn("--outputDataFormat NCHW", cmd_str, "outputDataFormat parameter not passed")
        self.assertIn("--inputShape input:1,3,224,224", cmd_str, "inputShape parameter not passed")
        self.assertIn("--optimize general", cmd_str, "optimize parameter not passed")

    def test_mcp_handle_convert_fp16_off_default(self):
        """Test that fp16=off is passed correctly"""
        arguments = {
            "onnx_path": str(TEST_CONVERT_ONNX),
            "async_mode": True
        }

        cmd_str = self._captured_convert_cmd(arguments, "test_task_456")

        self.assertIn("--fp16 off", cmd_str, "fp16 off not passed as default")

    def test_mcp_handle_convert_optimize_gpu_oriented(self):
        """Test that optimize=gpu_oriented is passed correctly"""
        arguments = {
            "onnx_path": str(TEST_CONVERT_ONNX),
            "optimize": "gpu_oriented",
            "async_mode": True
        }

        cmd_str = self._captured_convert_cmd(arguments, "test_task_789")

        self.assertIn("--optimize gpu_oriented", cmd_str, "optimize gpu_oriented not passed")

    @unittest.skipIf(not tools_available(), "MindSpore Lite tools not installed")
    def test_mcp_convert_integration(self):
        """Integration test: MCP convert with all parameters produces valid .ms file"""
        import subprocess
        from ohos_model_claw.tool_manager import tool_manager

        # Same guard as the mocked tests: without the fixture the converter
        # can only fail, which would be reported as a conversion bug.
        self._skip_if_model_missing()

        tool_manager.apply_env_vars()

        output_dir = str(self.output_dir / "test_mcp_integration")
        os.makedirs(output_dir, exist_ok=True)

        onnx_path = str(TEST_CONVERT_ONNX)
        onnx_name = Path(onnx_path).stem
        ms_path = str(Path(output_dir) / f"{onnx_name}.ms")

        # The backend script is launched with the converter locations
        # supplied through its environment.
        env = os.environ.copy()
        env["CONVERTER_LITE_PATH"] = str(tool_manager.converter_path)
        env["PACKAGE_ROOT_PATH"] = str(tool_manager.package_root)

        script_path = Path(__file__).parent.parent / "ohos_model_claw" / "ohos_model_claw.py"
        cmd = [
            sys.executable, str(script_path), "convert", onnx_path,
            "--output_dir", output_dir,
            "--fp16", "on",
            "--inputDataFormat", "NCHW",
            "--outputDataFormat", "NCHW",
            "--optimize", "general"
        ]

        result = subprocess.run(cmd, env=env, capture_output=True, text=True, timeout=120, encoding='utf-8', errors='replace')

        stdout = result.stdout or ""
        self.assertTrue("Success" in stdout or "OK" in stdout,
                        f"Conversion failed. stdout: {result.stdout}, stderr: {result.stderr}")
        self.assertTrue(Path(ms_path).exists(), f"Output .ms file not created: {ms_path}")


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()