import asyncio
import argparse
import json
from typing import List, Dict, Optional
from contextlib import AsyncExitStack
import uvicorn
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mindsdb.utilities import log
from mindsdb.interfaces.agents.mcp_client_agent import create_mcp_agent
# Module-level logger obtained through MindsDB's logging helper.
logger = log.getLogger(__name__)

# FastAPI application exposing the MCP agent through chat-completion style routes.
app = FastAPI(title="MindsDB MCP Agent LiteLLM API")

# CORS is fully open: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Process-wide state shared by the route handlers in this module.
# Assigned by init_agent(); stays None until the agent has been created.
agent_wrapper = None
# Assigned by the /test-mcp-connection handler once an MCP session is open.
mcp_session = None
# Holds the async contexts entered for that MCP session; closed on shutdown.
exit_stack = AsyncExitStack()
class ChatMessage(BaseModel):
    """A single message of a chat conversation."""

    # Speaker of the message; passed to the agent unchanged.
    role: str
    # Text of the message.
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``."""

    # Model identifier; echoed back in the response and passed to the streaming call.
    model: str
    # Conversation to send to the agent, in order.
    messages: List[ChatMessage]
    # When true, the reply is delivered as a ``text/event-stream`` response.
    stream: bool = False
    # Accepted for client compatibility; the handler in this module does not read them.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
class ChatCompletionChoice(BaseModel):
    """One choice entry of a chat-completion response."""

    # Position of this choice within the response's ``choices`` list.
    index: int = 0
    # Complete message, e.g. {"role": "assistant", "content": "..."}.
    message: Optional[Dict[str, str]] = None
    # Partial message; not populated by the handlers in this module.
    delta: Optional[Dict[str, str]] = None
    # Defaults to "stop".
    finish_reason: Optional[str] = "stop"
class ChatCompletionResponse(BaseModel):
    """Body returned by the non-streaming chat-completion path."""

    # Fixed placeholder identifier; not unique per response.
    id: str = "mcp-agent-response"
    object: str = "chat.completion"
    # Always 0 unless the caller supplies a value; no timestamp is generated here.
    created: int = 0
    # Model name, copied from the request by the handler.
    model: str
    choices: List[ChatCompletionChoice]
    # Token accounting is not computed in this module; every counter defaults to zero.
    # default_factory gives each instance its own dict.
    usage: Dict[str, int] = Field(
        default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    )
class DirectSQLRequest(BaseModel):
    """Request body accepted by ``POST /direct-sql``."""

    # Query text forwarded verbatim to the MCP "query" tool.
    query: str
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Answer a chat-completion request with the MCP agent.

    Args:
        request: Parsed request body. ``temperature`` and ``max_tokens`` are
            accepted but not forwarded to the agent.

    Returns:
        A ``StreamingResponse`` of server-sent events when ``request.stream``
        is true, otherwise a ``ChatCompletionResponse`` holding one assistant
        message.

    Raises:
        HTTPException: 500 when the agent has not been initialized or the
            non-streaming completion fails.
    """
    if agent_wrapper is None:
        raise HTTPException(
            status_code=500,
            detail="Agent not initialized. Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http",
        )

    try:
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        if request.stream:

            async def generate():
                """Yield one SSE ``data:`` event per agent chunk, then ``[DONE]``."""
                try:
                    async for chunk in agent_wrapper.acompletion_stream(messages, model=request.model):
                        yield f"data: {json.dumps(chunk)}\n\n"
                    yield "data: [DONE]\n\n"
                except Exception:
                    # Headers are already sent once streaming starts, so the
                    # failure is reported in-band. The payload goes through
                    # json.dumps so clients can parse it like any other event.
                    logger.exception("Streaming error:")
                    error_payload = json.dumps({"error": "Streaming failed due to an internal error."})
                    yield f"data: {error_payload}\n\n"

            return StreamingResponse(generate(), media_type="text/event-stream")

        # NOTE(review): unlike the streaming call, the model name is not passed
        # here -- confirm against the signature of acompletion().
        response = await agent_wrapper.acompletion(messages)
        content = response["choices"][0]["message"].get("content", "")
        if content is None:
            # A present-but-null content must not be rendered as the text "None".
            content = ""
        elif not isinstance(content, str):
            content = str(content)

        return ChatCompletionResponse(
            model=request.model,
            choices=[ChatCompletionChoice(message={"role": "assistant", "content": content})],
        )
    except Exception as e:
        logger.exception("Error in chat completion:")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/direct-sql")
async def direct_sql(request: DirectSQLRequest, background_tasks: BackgroundTasks):
    """Execute a direct SQL query via MCP (for testing).

    The agent's own MCP session is preferred. When the agent is missing or has
    no session, the module-level ``mcp_session`` is used instead.

    Args:
        request: Body carrying the query text.
        background_tasks: Injected by FastAPI; unused, kept for interface
            compatibility.

    Returns:
        ``{"result": <content of the tool result>}``.

    Raises:
        HTTPException: 500 when no session is available or the tool call fails.
    """
    if agent_wrapper is None and mcp_session is None:
        raise HTTPException(
            status_code=500, detail="No MCP session available. Make sure MindsDB server is running with MCP enabled."
        )

    # getattr with defaults keeps the fallback reachable when agent_wrapper is
    # None; a plain attribute access would raise AttributeError before
    # mcp_session was ever considered.
    agent = getattr(agent_wrapper, "agent", None)
    session = getattr(agent, "session", None) or mcp_session
    if not session:
        # Raised outside the try block so it is not re-wrapped by the generic handler.
        raise HTTPException(status_code=500, detail="No MCP session available")

    try:
        result = await session.call_tool("query", {"query": request.query})
        return {"result": result.content}
    except Exception as e:
        logger.exception("Error executing direct SQL:")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/v1/models")
async def list_models():
    """Return a one-element model list describing the agent's model.

    The id is the agent's configured ``model_name``; it falls back to
    ``"mcp-agent"`` when the agent is not initialized or has no such argument.
    """
    if agent_wrapper is None:
        model_id = "mcp-agent"
    else:
        model_id = agent_wrapper.agent.args.get("model_name", "mcp-agent")

    entry = {"id": model_id, "object": "model", "created": 0, "owned_by": "mindsdb"}
    return {"object": "list", "data": [entry]}
@app.get("/health")
async def health_check():
    """Report service health.

    ``status`` and ``agent_initialized`` are always present. ``mcp_connected``,
    ``agent_name`` and ``model_name`` are added only once the agent exists.
    """
    report = {"status": "ok", "agent_initialized": agent_wrapper is not None}
    if agent_wrapper is None:
        return report

    agent = agent_wrapper.agent
    # Connected means the agent exposes a session attribute that is not None.
    report["mcp_connected"] = hasattr(agent, "session") and agent.session is not None
    report["agent_name"] = agent.agent.name
    report["model_name"] = agent.args.get("model_name", "unknown")
    return report
@app.get("/test-mcp-connection")
async def test_mcp_connection():
    """Test the connection to the MCP server.

    An existing module-level session is probed with ``list_tools()``. If there
    is none, or the probe fails, a new stdio session is opened with the
    ``python -m mindsdb --api=mcp`` command and stored in ``mcp_session``.

    Returns:
        A dict with ``status``, ``message`` and the list of tool names.

    Raises:
        HTTPException: 500 when no working session can be established.
    """
    global mcp_session
    try:
        tools_response = None

        if mcp_session:
            try:
                tools_response = await mcp_session.list_tools()
            except Exception:
                # Drop the stale reference before tearing down, so a failing
                # aclose() cannot leave other handlers using a dead session.
                mcp_session = None
                await exit_stack.aclose()

        if tools_response is None:
            # NOTE(review): "python" is resolved through PATH and may differ
            # from the interpreter running this server -- confirm.
            server_params = StdioServerParameters(command="python", args=["-m", "mindsdb", "--api=mcp"], env=None)
            try:
                stdio_transport = await exit_stack.enter_async_context(stdio_client(server_params))
                stdio, write = stdio_transport
                session = await exit_stack.enter_async_context(ClientSession(stdio, write))
                await session.initialize()
                tools_response = await session.list_tools()
            except Exception:
                # Release whatever was opened so far. Otherwise each failed
                # attempt would stay registered on the exit stack until
                # shutdown.
                await exit_stack.aclose()
                raise
            # Publish the session only after it has answered a request.
            mcp_session = session

        return {
            "status": "ok",
            "message": "Successfully connected to MCP server",
            "tools": [tool.name for tool in tools_response.tools],
        }
    except Exception as e:
        logger.exception("Error connecting to MCP server:")
        error_detail = f"Error connecting to MCP server: {str(e)}. Make sure MindsDB server is running with HTTP enabled: python -m mindsdb --api=http"
        raise HTTPException(status_code=500, detail=error_detail)
async def init_agent(agent_name: str, project_name: str, mcp_host: str, mcp_port: int):
    """Create the MCP agent and store it in the module-level ``agent_wrapper``.

    Args:
        agent_name: Name of the agent to load.
        project_name: Project that contains the agent.
        mcp_host: Host of the MCP server.
        mcp_port: Port of the MCP server.

    Returns:
        True when the agent was created. False when creation raised; the
        exception is logged and not propagated.
    """
    global agent_wrapper

    agent_settings = {
        "agent_name": agent_name,
        "project_name": project_name,
        "mcp_host": mcp_host,
        "mcp_port": mcp_port,
    }
    try:
        logger.info(f"Initializing MCP agent '{agent_name}' in project '{project_name}'")
        logger.info(f"Connecting to MCP server at {mcp_host}:{mcp_port}")
        logger.info("Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")
        agent_wrapper = create_mcp_agent(**agent_settings)
        logger.info("Agent initialized successfully")
    except Exception:
        logger.exception("Failed to initialize agent:")
        return False
    return True
@app.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on server shutdown.

    Runs the agent's cleanup first, then closes the exit stack holding the MCP
    session contexts. The stack is closed even if the agent cleanup raises.
    """
    try:
        if agent_wrapper:
            await agent_wrapper.cleanup()
    finally:
        # Must run regardless of the agent cleanup outcome, or the contexts
        # registered on the stack would never be exited.
        await exit_stack.aclose()
async def run_server_async(
    agent_name: str,
    project_name: str = "mindsdb",
    mcp_host: str = "127.0.0.1",
    mcp_port: int = 47337,
    host: str = "0.0.0.0",
    port: int = 8000,
):
    """Initialize the agent and report the outcome as an exit code.

    Despite its name, this coroutine does not start the HTTP server; it only
    awaits ``init_agent``. ``host`` and ``port`` are accepted but unused here.

    Returns:
        0 when the agent was initialized, 1 otherwise.
    """
    if await init_agent(agent_name, project_name, mcp_host, mcp_port):
        return 0

    logger.error("Failed to initialize agent. Make sure MindsDB server is running with MCP enabled.")
    return 1
def run_server(
    agent_name: str,
    project_name: str = "mindsdb",
    mcp_host: str = "127.0.0.1",
    mcp_port: int = 47337,
    host: str = "0.0.0.0",
    port: int = 8000,
):
    """Initialize storage and the agent, then serve the API with uvicorn.

    Args:
        agent_name: Name of the agent to serve.
        project_name: Project that contains the agent.
        mcp_host: Host of the MCP server.
        mcp_port: Port of the MCP server.
        host: Interface the HTTP server binds to.
        port: Port the HTTP server listens on.

    Returns:
        The non-zero code from agent initialization when it fails. Otherwise
        0, once ``uvicorn.run`` has returned.
    """
    logger.info("Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")

    # Imported inside the function, as the module does not import it at the top
    # (presumably to defer storage setup until the server is started -- confirm).
    from mindsdb.interfaces.storage import db

    db.init()

    # Agent initialization runs to completion on a dedicated event loop before
    # uvicorn is started.
    startup_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(startup_loop)
    exit_code = startup_loop.run_until_complete(
        run_server_async(agent_name, project_name, mcp_host, mcp_port)
    )

    if exit_code == 0:
        logger.info(f"Starting server on {host}:{port}")
        uvicorn.run(app, host=host, port=port)
        return 0
    return exit_code
if __name__ == "__main__":
    # Command-line entry point: parse the options and start the server.
    parser = argparse.ArgumentParser(description="Run a LiteLLM-compatible API server for MCP agent")
    parser.add_argument("--agent", type=str, required=True, help="Name of the agent to use")
    parser.add_argument("--project", type=str, default="mindsdb", help="Project containing the agent")
    parser.add_argument("--mcp-host", type=str, default="127.0.0.1", help="MCP server host")
    parser.add_argument("--mcp-port", type=int, default=47337, help="MCP server port")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    args = parser.parse_args()

    # Use run_server's return value as the process exit status, so a failed
    # agent initialization (code 1) is visible to the calling shell or supervisor.
    raise SystemExit(
        run_server(
            agent_name=args.agent,
            project_name=args.project,
            mcp_host=args.mcp_host,
            mcp_port=args.mcp_port,
            host=args.host,
            port=args.port,
        )
    )