"""
Tests for MCP runtime operator fix tool
"""
import unittest
import sys
import os
import asyncio
from pathlib import Path
from unittest.mock import patch, MagicMock
# All fixture/output paths are anchored at this test file's directory.
_TESTS_DIR = Path(__file__).parent
# Put the project root first on sys.path so ``ohos_model_claw.*`` is importable.
sys.path.insert(0, str(_TESTS_DIR.parent))
TEST_MODEL_DIR = _TESTS_DIR / "model"
TEST_OUTPUT_DIR = _TESTS_DIR.joinpath("output", "mcp_runtime_operator_fix")
TEST_RUNTIME_FIX_ONNX = TEST_MODEL_DIR.joinpath("operator_fix", "test_op_fix_hardswish.onnx")
def check_onnx_dependencies():
    """Report whether the optional ONNX tool-chain can be imported.

    Returns:
        True if both ``onnx`` and ``onnx_graphsurgeon`` import cleanly,
        False as soon as either raises ImportError.
    """
    for module_name in ("onnx", "onnx_graphsurgeon"):
        try:
            __import__(module_name)
        except ImportError:
            return False
    return True
class TestMCPRuntimeOperatorFix(unittest.TestCase):
    """Tests for the ``ohos_runtime_op_fix`` MCP tool.

    Handler tests patch ``tool_manager`` and ``task_manager`` in
    ``ohos_model_claw.mcp_server`` and inspect the ``cmd`` keyword passed to
    ``task_manager.run_task_sync``. Integration tests run the real
    ``ohos_model_claw.py`` script in a subprocess and inspect the ONNX file it
    writes. They are skipped when the ONNX dependencies or the test model are
    missing.
    """

    @classmethod
    def setUpClass(cls):
        cls.output_dir = TEST_OUTPUT_DIR
        os.makedirs(cls.output_dir, exist_ok=True)

    # -- helpers (not collected by unittest: names do not start with "test") --

    def _require_test_model(self):
        """Skip the current test when the shared ONNX test model is absent."""
        if not TEST_RUNTIME_FIX_ONNX.exists():
            self.skipTest(f"Test model not found: {TEST_RUNTIME_FIX_ONNX}")

    @staticmethod
    def _run_handler(arguments, task_output=""):
        """Run ``handle_runtime_op_fix`` with the tool and task managers mocked.

        ``task_manager.run_task_sync`` is stubbed to return a completed task
        whose ``output`` is *task_output*.

        Returns:
            ``(result, mock_task)``: the handler's return value and the mock
            that replaced ``task_manager``. A MagicMock keeps its call records
            after the patch is removed, so callers can inspect it afterwards.
        """
        from ohos_model_claw.mcp_server import handle_runtime_op_fix
        with patch('ohos_model_claw.mcp_server.tool_manager') as mock_tm, \
                patch('ohos_model_claw.mcp_server.task_manager') as mock_task:
            mock_tm.apply_env_vars.return_value = None
            mock_task.run_task_sync.return_value = {
                "status": "completed",
                "output": task_output,
            }
            result = asyncio.run(handle_runtime_op_fix(arguments))
        return result, mock_task

    @staticmethod
    def _called_cmd(mock_task):
        """Return the ``cmd`` keyword argument of the last ``run_task_sync`` call."""
        # Fail with a readable assertion instead of a TypeError on ``None[1]``.
        mock_task.run_task_sync.assert_called()
        return mock_task.run_task_sync.call_args[1]['cmd']

    @staticmethod
    def _operators_in_cmd(cmd):
        """Return the operator names that follow ``--operators`` in *cmd*.

        Collects tokens up to the next ``--flag`` and splits comma-separated
        tokens, so both ``--operators A B`` and ``--operators A,B`` are
        recognised. Returns ``[]`` when the flag is absent. Matching whole
        tokens avoids false positives from substrings of file paths.
        """
        if "--operators" not in cmd:
            return []
        names = []
        for token in cmd[cmd.index("--operators") + 1:]:
            if token.startswith("--"):
                break
            names.extend(part for part in token.split(",") if part)
        return names

    @staticmethod
    def _parse_result_text(text):
        """Parse the handler's text payload into a Python object without ``eval``.

        Tries JSON first, then falls back to ``ast.literal_eval`` for a
        Python-literal ``repr``. Neither executes code contained in *text*.
        Parse failures propagate and error the calling test.
        """
        import ast
        import json
        try:
            return json.loads(text)
        except ValueError:  # includes json.JSONDecodeError
            return ast.literal_eval(text)

    @staticmethod
    def _run_cli(output_path, operators=None):
        """Run ``ohos_model_claw.py runtime_op_fix`` on the test model with device NPU.

        Any existing file at *output_path* is removed first, so a file left
        over from an earlier run cannot make a failed run look successful.

        Args:
            output_path: Where the fixed model should be written.
            operators: Optional list of operator names passed via ``--operators``.

        Returns:
            The ``subprocess.CompletedProcess`` with decoded stdout/stderr.
        """
        import subprocess
        from ohos_model_claw.tool_manager import tool_manager
        output = Path(output_path)
        if output.exists():
            output.unlink()
        # NOTE(review): presumably apply_env_vars() updates os.environ for the
        # backend, which is why the environment is copied right after -- confirm.
        tool_manager.apply_env_vars()
        env = os.environ.copy()
        script_path = Path(__file__).parent.parent / "ohos_model_claw" / "ohos_model_claw.py"
        cmd = [sys.executable, str(script_path), "runtime_op_fix", str(TEST_RUNTIME_FIX_ONNX),
               "--output_path", str(output_path), "--device", "NPU"]
        if operators:
            cmd += ["--operators", *operators]
        return subprocess.run(cmd, env=env, capture_output=True, text=True, timeout=120,
                              encoding='utf-8', errors='replace')

    # -- schema ---------------------------------------------------------------

    def test_mcp_runtime_op_fix_schema_has_parameters(self):
        """Test that ohos_runtime_op_fix Tool schema includes expected parameters"""
        from ohos_model_claw.mcp_server import TOOLS
        runtime_op_fix_tool = next(
            (tool for tool in TOOLS if tool.name == "ohos_runtime_op_fix"), None)
        self.assertIsNotNone(runtime_op_fix_tool, "ohos_runtime_op_fix tool not found in TOOLS")
        schema = runtime_op_fix_tool.inputSchema
        properties = schema.get("properties", {})
        for param in ("onnx_path", "output_path", "device", "operators"):
            self.assertIn(param, properties, f"Parameter '{param}' not found in ohos_runtime_op_fix schema")
        self.assertEqual(properties["device"].get("enum"), ["CPU", "GPU", "NPU"], "device enum values incorrect")
        self.assertEqual(properties["device"].get("default"), "NPU", "device default should be NPU")
        self.assertEqual(properties["operators"].get("type"), "array", "operators should be array type")
        required = schema.get("required", [])
        self.assertIn("onnx_path", required, "onnx_path should be required parameter")
        # NOTE(review): "device" is asserted both as required and as having a
        # default, while test_mcp_handle_runtime_op_fix_default_device relies on
        # the handler accepting its absence. Confirm which contract is intended.
        self.assertIn("device", required, "device should be required parameter")

    # -- handler (mocked backend) ---------------------------------------------

    def test_mcp_handle_runtime_op_fix_passes_parameters(self):
        """Test that handle_runtime_op_fix function passes parameters to backend script"""
        self._require_test_model()
        output_path = str(self.output_dir / "test_mcp_fixed.onnx")
        _, mock_task = self._run_handler(
            {
                "onnx_path": str(TEST_RUNTIME_FIX_ONNX),
                "output_path": output_path,
                "device": "NPU",
                "operators": ["Split"],
            },
            task_output="operators_replaced: ['Split']",
        )
        cmd_str = ' '.join(self._called_cmd(mock_task))
        self.assertIn("runtime_op_fix", cmd_str, "runtime_op_fix command not in cmd")
        self.assertIn("--device NPU", cmd_str, "device parameter not passed")
        self.assertIn("--operators Split", cmd_str, "operators parameter not passed")
        self.assertIn("--output_path", cmd_str, "output_path parameter not passed")

    def test_mcp_handle_runtime_op_fix_auto_output_path(self):
        """Test that handle_runtime_op_fix generates default output_path when not provided"""
        self._require_test_model()
        _, mock_task = self._run_handler({
            "onnx_path": str(TEST_RUNTIME_FIX_ONNX),
            "device": "NPU",
        })
        cmd_str = ' '.join(self._called_cmd(mock_task))
        # Expected default: "<stem>_runtime_fixed.onnx" next to the input model.
        expected_output = str(TEST_RUNTIME_FIX_ONNX.with_name(
            f"{TEST_RUNTIME_FIX_ONNX.stem}_runtime_fixed.onnx"))
        self.assertIn(expected_output, cmd_str, "auto-generated output_path not correct")

    def test_mcp_handle_runtime_op_fix_default_device(self):
        """Test that device defaults to NPU when not specified"""
        self._require_test_model()
        output_path = str(self.output_dir / "test_default_device.onnx")
        _, mock_task = self._run_handler({
            "onnx_path": str(TEST_RUNTIME_FIX_ONNX),
            "output_path": output_path,
        })
        cmd_str = ' '.join(self._called_cmd(mock_task))
        self.assertIn("--device NPU", cmd_str, "device should default to NPU")

    def test_mcp_handle_runtime_op_fix_missing_onnx_path(self):
        """Test that handle_runtime_op_fix returns error when onnx_path is missing"""
        from ohos_model_claw.mcp_server import handle_runtime_op_fix
        result = asyncio.run(handle_runtime_op_fix({"device": "NPU"}))
        result_text = result[0].text
        self.assertIn("Error", result_text, "Expected error for missing onnx_path")
        self.assertIn("onnx_path is required", result_text, "Expected specific error message")

    def test_mcp_handle_runtime_op_fix_nonexistent_file(self):
        """Test that handle_runtime_op_fix returns error for nonexistent ONNX file"""
        from ohos_model_claw.mcp_server import handle_runtime_op_fix
        nonexistent_path = str(self.output_dir / "nonexistent.onnx")
        result = asyncio.run(handle_runtime_op_fix({
            "onnx_path": nonexistent_path,
            "device": "NPU",
        }))
        result_text = result[0].text
        self.assertIn("Error", result_text, "Expected error for nonexistent file")
        self.assertIn("ONNX file not found", result_text, "Expected specific error message")

    def test_mcp_handle_runtime_op_fix_multiple_operators(self):
        """Test that handle_runtime_op_fix passes multiple operators correctly"""
        self._require_test_model()
        output_path = str(self.output_dir / "test_multi_ops.onnx")
        _, mock_task = self._run_handler({
            "onnx_path": str(TEST_RUNTIME_FIX_ONNX),
            "output_path": output_path,
            "device": "NPU",
            "operators": ["Split", "Mod"],
        })
        called_cmd = self._called_cmd(mock_task)
        self.assertIn("--operators", called_cmd, "operators flag not passed")
        # Compare whole operator tokens, not substrings of the joined command,
        # so that paths containing "Split"/"Mod" cannot satisfy the check.
        operators = self._operators_in_cmd(called_cmd)
        self.assertIn("Split", operators, "Split not in operators list")
        self.assertIn("Mod", operators, "Mod not in operators list")

    def test_mcp_handle_runtime_op_fix_result_structure(self):
        """Test that handle_runtime_op_fix returns result with expected structure"""
        self._require_test_model()
        output_path = str(self.output_dir / "test_result_structure.onnx")
        result, _ = self._run_handler(
            {
                "onnx_path": str(TEST_RUNTIME_FIX_ONNX),
                "output_path": output_path,
                "device": "NPU",
            },
            task_output="operators_replaced: ['Split']\nnodes_per_operator: {'Split': 9}",
        )
        # Parse safely: eval() would execute whatever the handler returned.
        result_dict = self._parse_result_text(result[0].text)
        for key in ("status", "device", "output_path", "operators_replaced", "nodes_per_operator"):
            self.assertIn(key, result_dict, f"{key} not in result")
        self.assertEqual(result_dict["status"], "completed", "status should be completed")
        self.assertEqual(result_dict["device"], "NPU", "device should be NPU")

    def test_mcp_handle_runtime_op_fix_sync_execution(self):
        """Test that handle_runtime_op_fix uses synchronous execution (run_task_sync)"""
        self._require_test_model()
        output_path = str(self.output_dir / "test_sync_execution.onnx")
        _, mock_task = self._run_handler({
            "onnx_path": str(TEST_RUNTIME_FIX_ONNX),
            "output_path": output_path,
            "device": "NPU",
        })
        mock_task.run_task_sync.assert_called_once()
        mock_task.create_task.assert_not_called()
        mock_task.start_task.assert_not_called()

    # -- integration (real CLI subprocess) ------------------------------------

    @unittest.skipIf(not check_onnx_dependencies(), "ONNX dependencies not available")
    @unittest.skipIf(not TEST_RUNTIME_FIX_ONNX.exists(), f"Test model not found: {TEST_RUNTIME_FIX_ONNX}")
    def test_mcp_runtime_op_fix_integration_split(self):
        """Integration test: MCP runtime_op_fix with Split operator"""
        output_path = str(self.output_dir / "integration_split_fixed.onnx")
        result = self._run_cli(output_path, operators=["Split"])
        combined_output = (result.stdout or "") + (result.stderr or "")
        # Accepted success markers: "修复完成" ("fix completed"),
        # "成功替换" ("replaced successfully"), "OK", or "success" in any case.
        succeeded = (any(marker in combined_output for marker in ("修复完成", "OK", "成功替换"))
                     or "success" in combined_output.lower())
        self.assertTrue(succeeded,
                        f"Runtime operator fix failed. stdout: {result.stdout}, stderr: {result.stderr}")
        self.assertTrue(Path(output_path).exists(), f"Fixed ONNX file not created: {output_path}")

    @unittest.skipIf(not check_onnx_dependencies(), "ONNX dependencies not available")
    @unittest.skipIf(not TEST_RUNTIME_FIX_ONNX.exists(), f"Test model not found: {TEST_RUNTIME_FIX_ONNX}")
    def test_mcp_runtime_op_fix_output_valid_onnx(self):
        """Test that fixed output is a valid ONNX model"""
        import onnx
        output_path = str(self.output_dir / "validation_fixed.onnx")
        self._run_cli(output_path)
        if not Path(output_path).exists():
            self.skipTest("Output file not created, possibly no operators to fix")
        try:
            onnx.checker.check_model(onnx.load(output_path))
        except Exception as ex:  # any load/check error means the output is unusable
            self.fail(f"Fixed ONNX model is not valid: {ex}")

    @unittest.skipIf(not check_onnx_dependencies(), "ONNX dependencies not available")
    @unittest.skipIf(not TEST_RUNTIME_FIX_ONNX.exists(), f"Test model not found: {TEST_RUNTIME_FIX_ONNX}")
    def test_mcp_runtime_op_fix_split_replaced_with_slices(self):
        """Test that Split nodes are replaced with Slice nodes"""
        import onnx
        output_path = str(self.output_dir / "split_to_slices.onnx")
        self._run_cli(output_path, operators=["Split"])
        if not Path(output_path).exists():
            self.skipTest("Output file not created")
        op_types = [node.op_type for node in onnx.load(output_path).graph.node]
        split_count = op_types.count("Split")
        slice_count = op_types.count("Slice")
        self.assertEqual(split_count, 0, f"Should have no Split nodes, found {split_count}")
        self.assertGreater(slice_count, 0, "Should have Slice nodes after replacement")
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main(module="__main__")