import asyncio
import json
import sys
from contextlib import AsyncExitStack
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List

import pandas as pd
from langchain_core.tools import BaseTool
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

from mindsdb.interfaces.agents.langchain_agent import LangchainAgent
from mindsdb.interfaces.storage import db
from mindsdb.utilities import log
logger = log.getLogger(__name__)
def _mcp_content_to_python(content: Any) -> Any:
    """Convert MCP tool-result content into plain Python values.

    ``CallToolResult.content`` is a list of content objects (e.g. ``TextContent``),
    which ``json.dumps`` cannot serialise directly. Lists are converted item by
    item; objects carrying a string ``text`` attribute are decoded as JSON when
    possible (falling back to the raw text); other pydantic models are dumped in
    JSON mode; anything else (e.g. a plain dict) is returned unchanged.
    """
    if isinstance(content, (list, tuple)):
        return [_mcp_content_to_python(item) for item in content]
    text = getattr(content, "text", None)
    if isinstance(text, str):
        try:
            return json.loads(text)
        except ValueError:
            return text
    if hasattr(content, "model_dump"):
        return content.model_dump(mode="json")
    return content


class MCPQueryTool(BaseTool):
    """LangChain tool that runs SQL through the ``query`` tool of an MCP server.

    The tool wraps an already-initialised MCP ``ClientSession`` and forwards the
    query text to the server-side tool named ``"query"``.
    """

    name: ClassVar[str] = "mcp_query"
    description: ClassVar[str] = "Execute SQL queries against the MindsDB server via MCP protocol"

    def __init__(self, session: ClientSession):
        super().__init__()
        # NOTE(review): BaseTool is a pydantic model; confirm the langchain_core
        # version in use allows assigning this undeclared attribute.
        self.session = session

    async def _arun(self, query: str) -> str:
        """Execute ``query`` via the MCP ``query`` tool and render the result as text.

        Returns a table rendered with ``DataFrame.to_string()`` when the server
        answers with ``data`` and ``column_names``; otherwise a JSON dump of the
        decoded result. Failures are returned as ``"Error ..."`` strings rather
        than raised, so the agent can read them.
        """
        try:
            logger.info(f"Executing MCP query: {query}")
            tools_response = await self.session.list_tools()
            if not any(tool.name == "query" for tool in tools_response.tools):
                return "Error: No 'query' tool found in the MCP server"

            result = await self.session.call_tool("query", {"query": query})
            payload = _mcp_content_to_python(result.content)
            if isinstance(payload, list) and len(payload) == 1:
                # Tools typically answer with a single content item; unwrap it.
                payload = payload[0]

            # The server flags tool-level failures on the result instead of raising.
            if getattr(result, "isError", False):
                return f"Error executing query: {payload}"

            if isinstance(payload, dict) and "data" in payload and "column_names" in payload:
                df = pd.DataFrame(payload["data"], columns=payload["column_names"])
                return df.to_string()
            return f"Query executed successfully: {json.dumps(payload)}"
        except Exception as e:
            logger.exception("Error executing MCP query:")
            return f"Error executing query: {e}"

    def _run(self, query: str) -> str:
        """Synchronous entry point: drive ``_arun`` on the current event loop.

        The same loop used when the session was opened must be reused, because
        the session's streams are tied to it. This raises ``RuntimeError`` if
        that loop is already running.
        """
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self._arun(query))
class MCPLangchainAgent(LangchainAgent):
    """LangchainAgent that additionally exposes a MindsDB MCP server via ``MCPQueryTool``.

    The MCP client session is opened lazily — while the agent's tools are built
    or before a completion is requested — and is released by ``cleanup()``.
    """

    def __init__(
        self,
        agent: db.Agents,
        model: dict = None,
        llm_params: dict = None,
        mcp_host: str = "127.0.0.1",
        mcp_port: int = 47337,
    ):
        """Create the agent without opening any MCP connection.

        Args:
            agent: Stored agent record, passed through to ``LangchainAgent``.
            model: Optional model record, passed through to ``LangchainAgent``.
            llm_params: Optional LLM parameters, passed through to ``LangchainAgent``.
            mcp_host: Exported to the MCP server subprocess as ``MCP_HOST``.
            mcp_port: Exported to the MCP server subprocess as ``MCP_PORT``.
        """
        super().__init__(agent, model, llm_params)
        self.mcp_host = mcp_host
        self.mcp_port = mcp_port
        # Owns the stdio transport and client-session contexts opened in connect_to_mcp().
        self.exit_stack = AsyncExitStack()
        self.session = None
        self.stdio = None
        self.write = None

    async def connect_to_mcp(self):
        """Spawn the MindsDB MCP server and open a client session over stdio.

        Does nothing if a session is already open. On failure, all partially
        opened resources are closed and the agent is left disconnected, so a
        later call can retry.

        Raises:
            ConnectionError: if the server cannot be started or initialised.
        """
        if self.session is not None:
            return

        logger.info(f"Connecting to MCP server at {self.mcp_host}:{self.mcp_port}")
        try:
            server_params = StdioServerParameters(
                # Use the running interpreter so the child sees the same mindsdb install;
                # a bare "python" on PATH may be a different interpreter or missing entirely.
                command=sys.executable or "python",
                args=["-m", "mindsdb", "--api=mcp"],
                # NOTE(review): depending on the mcp SDK version, ``env`` may replace rather
                # than extend the child's environment — confirm the server gets its config.
                env={"MCP_HOST": self.mcp_host, "MCP_PORT": str(self.mcp_port)},
            )
            self.stdio, self.write = await self.exit_stack.enter_async_context(stdio_client(server_params))
            session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
            await session.initialize()
            tools_response = await session.list_tools()
        except Exception as e:
            logger.exception("Failed to connect to MCP server:")
            await self._reset_connection()
            raise ConnectionError(f"Failed to connect to MCP server: {e}") from e

        # Publish the session only once it is initialised; the `is None` checks elsewhere
        # would otherwise treat a half-open session as usable and never retry.
        self.session = session
        logger.info(
            f"Successfully connected to MCP server. Available tools: {[tool.name for tool in tools_response.tools]}"
        )

    async def _reset_connection(self):
        """Best-effort release of a partially opened connection; logs instead of raising."""
        try:
            await self.cleanup()
        except Exception:
            logger.exception("Error while closing MCP connection resources:")

    def _langchain_tools_from_skills(self, llm):
        """Build the base agent tools and append an ``MCPQueryTool`` when MCP is reachable.

        Connection problems are logged, and the agent falls back to its regular tools.
        """
        tools = super()._langchain_tools_from_skills(llm)
        try:
            if self.session is None:
                # NOTE(review): run_until_complete raises if this loop is already running
                # (e.g. when invoked from async code); that error is logged below.
                asyncio.get_event_loop().run_until_complete(self.connect_to_mcp())
            if self.session:
                tools.append(MCPQueryTool(self.session))
                logger.info("Added MCP query tool to agent tools")
        except Exception:
            logger.exception("Failed to add MCP query tool:")
        return tools

    def get_completion(self, messages, stream: bool = False):
        """Ensure an MCP connection (best effort), then delegate to ``LangchainAgent``.

        A result exposing ``to_string()`` (e.g. a DataFrame) is converted to a string.
        """
        try:
            if self.session is None:
                asyncio.get_event_loop().run_until_complete(self.connect_to_mcp())
        except Exception:
            # Connection failure is non-fatal: the agent still answers without MCP.
            logger.exception("Failed to connect to MCP server:")

        response = super().get_completion(messages, stream)
        if hasattr(response, "to_string"):
            return response.to_string()
        return response

    async def cleanup(self):
        """Close the MCP session and stdio transport, then forget them.

        State is reset even if closing raises, and a fresh exit stack is created
        so a later ``connect_to_mcp()`` starts clean.
        """
        try:
            if self.exit_stack is not None:
                await self.exit_stack.aclose()
        finally:
            self.exit_stack = AsyncExitStack()
            self.session = None
            self.stdio = None
            self.write = None
class LiteLLMAgentWrapper:
    """Adapter exposing an ``MCPLangchainAgent`` through LiteLLM-style completion calls.

    Responses are plain dicts shaped like OpenAI/LiteLLM chat-completion objects.
    """

    def __init__(self, agent: "MCPLangchainAgent"):
        self.agent = agent

    @staticmethod
    def _format_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Convert chat messages into the question/answer records the agent consumes.

        User content becomes ``question`` and assistant content becomes ``answer``.
        Messages with any other role become an empty question/answer pair.
        """
        return [
            {
                "question": msg["content"] if msg["role"] == "user" else "",
                "answer": msg["content"] if msg["role"] == "assistant" else "",
            }
            for msg in messages
        ]

    def _resolve_model_name(self, kwargs: Dict[str, Any]) -> str:
        """Return the model name to report.

        Uses the explicit ``model`` kwarg, else the agent's ``model_name``, else ``"mcp-agent"``.
        """
        return kwargs.get("model", self.agent.args.get("model_name", "mcp-agent"))

    async def acompletion(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """Return a single chat-completion dict for ``messages``.

        Note: ``get_completion`` is synchronous, so this call blocks the event loop
        while the agent runs.
        """
        response = self.agent.get_completion(self._format_messages(messages))
        if not isinstance(response, str):
            response = response.to_string() if hasattr(response, "to_string") else str(response)
        return {
            "choices": [{"message": {"role": "assistant", "content": response}}],
            "model": self._resolve_model_name(kwargs),
            "object": "chat.completion",
        }

    async def acompletion_stream(
        self, messages: List[Dict[str, str]], **kwargs
    ) -> AsyncIterator[Dict[str, Any]]:
        """Yield chat-completion chunk dicts for each non-empty string output of the agent stream.

        Errors from the underlying stream are logged and re-raised.
        """
        formatted_messages = self._format_messages(messages)
        model_name = self._resolve_model_name(kwargs)
        try:
            for chunk in self.agent._get_completion_stream(formatted_messages):
                content = chunk.get("output", "")
                if content and isinstance(content, str):
                    yield {
                        "choices": [{"delta": {"role": "assistant", "content": content}}],
                        "model": model_name,
                        "object": "chat.completion.chunk",
                    }
                    # Give other tasks a chance to run between chunks.
                    await asyncio.sleep(0)
        except Exception:
            logger.exception("Streaming error:")
            raise

    async def cleanup(self):
        """Release the wrapped agent's MCP resources."""
        await self.agent.cleanup()
def create_mcp_agent(
    agent_name: str, project_name: str, mcp_host: str = "127.0.0.1", mcp_port: int = 47337
) -> LiteLLMAgentWrapper:
    """Load a stored agent, build an ``MCPLangchainAgent`` for it, and wrap it for LiteLLM.

    Args:
        agent_name: Name of the agent to load.
        project_name: Project that owns the agent.
        mcp_host: Passed to ``MCPLangchainAgent`` (exported to the MCP server as ``MCP_HOST``).
        mcp_port: Passed to ``MCPLangchainAgent`` (exported to the MCP server as ``MCP_PORT``).

    Raises:
        ValueError: if no agent named ``agent_name`` exists in ``project_name``.
    """
    # Imported lazily; presumably to avoid an import cycle with the agents package — TODO confirm.
    from mindsdb.interfaces.agents.agents_controller import AgentsController

    # `db` is the module-level storage import; initialise it before querying agents.
    db.init()

    agent_controller = AgentsController()
    agent_db = agent_controller.get_agent(agent_name, project_name)
    if agent_db is None:
        raise ValueError(f"Agent {agent_name} not found in project {project_name}")

    llm_params = agent_controller.get_agent_llm_params(agent_db.params)
    mcp_agent = MCPLangchainAgent(agent_db, llm_params=llm_params, mcp_host=mcp_host, mcp_port=mcp_port)
    return LiteLLMAgentWrapper(mcp_agent)