"""
Tests for MCP convert tool with new parameters
"""
import unittest
import sys
import os
from pathlib import Path
# Put the parent of this file's directory first on sys.path so that
# ``ohos_model_claw`` is imported from the source tree next to the tests.
sys.path.insert(0, os.fspath(Path(__file__).parent.parent))

# Fixture locations, all relative to the directory containing this file.
TEST_MODEL_DIR = Path(__file__).parent.joinpath("model")
TEST_OUTPUT_DIR = Path(__file__).parent.joinpath("output", "mcp_convert")
TEST_CONVERT_ONNX = TEST_MODEL_DIR.joinpath("convert", "test_convert.onnx")
def tools_available():
    """Report whether the MindSpore Lite converter tools are installed.

    Returns:
        True only when ``ohos_model_claw.tool_manager`` can be imported and its
        ``tool_manager`` reports both the converter tool and its library
        directories as present. False otherwise, including when the package
        itself cannot be imported.
    """
    try:
        from ohos_model_claw.tool_manager import tool_manager
    except ImportError:
        # Without the package the tools cannot be usable either. Returning
        # False lets callers skip cleanly instead of failing test collection.
        return False
    return bool(tool_manager.check_tool_exists() and tool_manager.check_lib_dirs_exist())
class TestMCPConvertParameters(unittest.TestCase):
    """Test MCP convert tool parameter handling.

    The ``handle_convert`` tests patch ``tool_manager`` and ``task_manager`` in
    ``ohos_model_claw.mcp_server``. They only inspect the ``cmd`` keyword
    argument handed to ``task_manager.start_task``. The integration test runs
    the backend script in a subprocess and is skipped when the MindSpore Lite
    tools are not installed.
    """

    @classmethod
    def setUpClass(cls):
        cls.output_dir = TEST_OUTPUT_DIR
        os.makedirs(cls.output_dir, exist_ok=True)

    # ----------------------------------------------------------------- helpers

    def _skip_if_model_missing(self):
        """Skip the current test when the ONNX fixture model is absent."""
        if not TEST_CONVERT_ONNX.exists():
            self.skipTest(f"Test model not found: {TEST_CONVERT_ONNX}")

    def _run_handle_convert(self, arguments, task_id):
        """Run ``handle_convert`` with both managers mocked and return its command.

        Args:
            arguments: Tool arguments passed unchanged to ``handle_convert``.
            task_id: Value returned by the mocked ``task_manager.create_task``.

        Returns:
            The ``cmd`` keyword argument of the last ``task_manager.start_task``
            call, with every element converted to ``str``.
        """
        import asyncio
        from unittest.mock import patch
        from ohos_model_claw.mcp_server import handle_convert

        with patch('ohos_model_claw.mcp_server.tool_manager') as mock_tm, \
                patch('ohos_model_claw.mcp_server.task_manager') as mock_task:
            mock_tm.ensure_tool_ready.return_value = (True, "ready")
            mock_tm.apply_env_vars.return_value = None
            mock_task.create_task.return_value = task_id
            mock_task.start_task.return_value = None
            asyncio.run(handle_convert(arguments))

        # Without this check, a handler that never starts a task would surface as
        # a confusing "'NoneType' object is not subscriptable" error below.
        self.assertIsNotNone(mock_task.start_task.call_args,
                             "handle_convert did not call task_manager.start_task")
        return [str(part) for part in mock_task.start_task.call_args[1]['cmd']]

    def _assert_option(self, cmd, flag, value, msg):
        """Assert that ``flag`` is directly followed by ``value`` in ``cmd``."""
        # Pad with spaces so a match must sit on token boundaries. Otherwise
        # "--optimize general" would also match "--optimize generalized".
        haystack = f" {' '.join(cmd)} "
        self.assertIn(f" {flag} {value} ", haystack, msg)

    # ------------------------------------------------------------------- tests

    def test_mcp_convert_schema_has_new_parameters(self):
        """Test that ohos_convert Tool schema includes new parameters"""
        from ohos_model_claw.mcp_server import TOOLS

        convert_tool = next((tool for tool in TOOLS if tool.name == "ohos_convert"), None)
        self.assertIsNotNone(convert_tool, "ohos_convert tool not found in TOOLS")

        properties = convert_tool.inputSchema.get("properties", {})
        for param in ("fp16", "input_data_format", "output_data_format", "input_shape", "optimize"):
            self.assertIn(param, properties, f"Parameter '{param}' not found in ohos_convert schema")

        expected_enums = {
            "fp16": ["on", "off"],
            "input_data_format": ["NHWC", "NCHW"],
            "output_data_format": ["NHWC", "NCHW"],
            "optimize": ["none", "general", "gpu_oriented", "ascend_oriented"],
        }
        for param, enum in expected_enums.items():
            self.assertEqual(properties[param].get("enum"), enum, f"{param} enum values incorrect")

    def test_mcp_handle_convert_passes_parameters(self):
        """Test that handle_convert passes every new parameter to the backend script"""
        self._skip_if_model_missing()
        cmd = self._run_handle_convert({
            "onnx_path": str(TEST_CONVERT_ONNX),
            "output_dir": str(self.output_dir / "test_mcp_params"),
            "fp16": "on",
            "input_data_format": "NCHW",
            "output_data_format": "NCHW",
            "input_shape": "input:1,3,224,224",
            "optimize": "general",
            "async_mode": True,
        }, "test_task_123")

        self._assert_option(cmd, "--fp16", "on", "fp16 parameter not passed")
        self._assert_option(cmd, "--inputDataFormat", "NCHW", "inputDataFormat parameter not passed")
        self._assert_option(cmd, "--outputDataFormat", "NCHW", "outputDataFormat parameter not passed")
        self._assert_option(cmd, "--inputShape", "input:1,3,224,224", "inputShape parameter not passed")
        self._assert_option(cmd, "--optimize", "general", "optimize parameter not passed")

    def test_mcp_handle_convert_fp16_off_default(self):
        """Test that fp16 is passed as "off" when the argument is omitted"""
        self._skip_if_model_missing()
        cmd = self._run_handle_convert({
            "onnx_path": str(TEST_CONVERT_ONNX),
            "async_mode": True,
        }, "test_task_456")

        self._assert_option(cmd, "--fp16", "off", "fp16 off not passed as default")

    def test_mcp_handle_convert_optimize_gpu_oriented(self):
        """Test that optimize=gpu_oriented is passed correctly"""
        self._skip_if_model_missing()
        cmd = self._run_handle_convert({
            "onnx_path": str(TEST_CONVERT_ONNX),
            "optimize": "gpu_oriented",
            "async_mode": True,
        }, "test_task_789")

        self._assert_option(cmd, "--optimize", "gpu_oriented", "optimize gpu_oriented not passed")

    def test_mcp_convert_integration(self):
        """Integration test: the backend convert script produces a .ms file.

        Runs ``ohos_model_claw.py convert`` in a subprocess with fp16, input and
        output data formats and optimize set. Skipped when the MindSpore Lite
        tools or the ONNX fixture are missing.
        """
        # Checked here instead of in a class-level skipIf decorator so that
        # importing this module never imports ohos_model_claw or probes the tools.
        if not tools_available():
            self.skipTest("MindSpore Lite tools not installed")
        self._skip_if_model_missing()

        import subprocess
        from ohos_model_claw.tool_manager import tool_manager

        tool_manager.apply_env_vars()
        output_dir = str(self.output_dir / "test_mcp_integration")
        os.makedirs(output_dir, exist_ok=True)
        onnx_path = str(TEST_CONVERT_ONNX)
        ms_path = str(Path(output_dir) / f"{Path(onnx_path).stem}.ms")

        env = os.environ.copy()
        env["CONVERTER_LITE_PATH"] = str(tool_manager.converter_path)
        env["PACKAGE_ROOT_PATH"] = str(tool_manager.package_root)

        script_path = Path(__file__).parent.parent / "ohos_model_claw" / "ohos_model_claw.py"
        cmd = [
            sys.executable, str(script_path), "convert", onnx_path,
            "--output_dir", output_dir,
            "--fp16", "on",
            "--inputDataFormat", "NCHW",
            "--outputDataFormat", "NCHW",
            "--optimize", "general",
        ]
        result = subprocess.run(cmd, env=env, capture_output=True, text=True, timeout=120,
                                encoding='utf-8', errors='replace')

        stdout = result.stdout or ""
        self.assertTrue("Success" in stdout or "OK" in stdout,
                        f"Conversion failed. stdout: {result.stdout}, stderr: {result.stderr}")
        self.assertTrue(Path(ms_path).exists(), f"Output .ms file not created: {ms_path}")
if __name__ == "__main__":
    # Allow running this file directly; unittest collects the TestCase defined above.
    unittest.main(module="__main__")