"""MCP Client"""
import asyncio
import logging
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from apps.constants import MCP_PATH
from apps.schemas.mcp import (
MCPServerSSEConfig,
MCPServerStdioConfig,
MCPStatus,
)
if TYPE_CHECKING:
from mcp.types import CallToolResult
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class MCPClient:
    """
    MCP client base class.

    Owns one MCP transport (SSE or stdio) plus the :class:`~mcp.ClientSession`
    layered on top of it. The transport and session are entered and closed
    inside a single background task (:attr:`task`); other coroutines talk to
    that task through the ``ready_sign`` / ``error_sign`` / ``stop_sign`` events.
    """

    mcp_id: str
    task: asyncio.Task
    ready_sign: asyncio.Event
    error_sign: asyncio.Event
    stop_sign: asyncio.Event
    client: ClientSession
    status: MCPStatus
    # Tools reported by the server's ``list_tools`` call, cached by :meth:`init`.
    tools: list

    def __init__(self) -> None:
        """Initialize the MCP client in the ``UNINITIALIZED`` state."""
        self.status = MCPStatus.UNINITIALIZED

    async def _create_transport(
        self, user_id: str | None, mcp_id: str, config: MCPServerSSEConfig | MCPServerStdioConfig,
    ):
        """
        Build (but do not enter) the transport context manager for ``config``.

        For stdio servers the working directory is created first:
        ``MCP_PATH/users/<user_id>/<mcp_id>/project`` when ``user_id`` is truthy,
        otherwise ``MCP_PATH/template/<mcp_id>/project``.

        :param str | None user_id: user ID, or ``None`` for the template instance
        :param str mcp_id: MCP ID
        :param MCPServerSSEConfig | MCPServerStdioConfig config: MCP configuration
        :return: the un-entered ``sse_client`` / ``stdio_client`` context manager
        :raises TypeError: if ``config`` is neither an SSE nor a stdio config
        """
        if isinstance(config, MCPServerSSEConfig):
            return sse_client(url=config.url, headers=config.headers or {})

        if isinstance(config, MCPServerStdioConfig):
            if user_id:
                cwd = MCP_PATH / "users" / user_id / mcp_id / "project"
            else:
                cwd = MCP_PATH / "template" / mcp_id / "project"
            # mkdir is awaited, so MCP_PATH is presumably an async path type.
            await cwd.mkdir(parents=True, exist_ok=True)
            return stdio_client(server=StdioServerParameters(
                command=config.command,
                args=config.args,
                env=config.env,
                cwd=cwd.as_posix(),
            ))

        # Previously this fell through and failed with UnboundLocalError on an
        # unset local, before error_sign was set -- which hung init().
        msg = f"[MCPClient] MCP {mcp_id}: unsupported config type {type(config).__name__}"
        raise TypeError(msg)

    @staticmethod
    async def _close_stack(exit_stack: AsyncExitStack, mcp_id: str) -> None:
        """Close ``exit_stack``, logging (not raising) any ``Exception``."""
        try:
            await exit_stack.aclose()
        except Exception:
            logger.exception("[MCPClient] MCP %s:关闭失败", mcp_id)

    async def _main_loop(
        self, user_id: str | None, mcp_id: str, config: MCPServerSSEConfig | MCPServerStdioConfig,
    ) -> None:
        """
        Run the MCP client for its whole lifetime (this is the body of :attr:`task`).

        Creates the transport and a :class:`~mcp.ClientSession` on top of it
        (the current MCP implementation pairs Client and Session 1:1),
        initializes the session, sets ``ready_sign`` and then keeps everything
        open until ``stop_sign`` is set. All contexts are entered and closed
        inside this one task.

        On initialization failure, everything entered so far is closed, status
        becomes ``STOPPED``, ``error_sign`` is set and the exception is re-raised.

        :param str | None user_id: user ID
        :param str mcp_id: MCP ID
        :param MCPServerSSEConfig | MCPServerStdioConfig config: MCP configuration
        :return: None
        """
        exit_stack = AsyncExitStack()
        try:
            transport = await self._create_transport(user_id, mcp_id, config)
            read, write = await exit_stack.enter_async_context(transport)
            self.client = ClientSession(read, write)
            session = await exit_stack.enter_async_context(self.client)
            await session.initialize()
        except Exception:
            logger.exception("[MCPClient] MCP %s:初始化失败", mcp_id)
            # Release anything entered before the failure (e.g. a spawned
            # server process) instead of leaking it.
            await self._close_stack(exit_stack, mcp_id)
            self.status = MCPStatus.STOPPED
            self.error_sign.set()
            raise

        self.ready_sign.set()
        self.status = MCPStatus.RUNNING
        try:
            await self.stop_sign.wait()
            logger.info("[MCPClient] MCP %s:收到停止信号,正在关闭", mcp_id)
        finally:
            # Also runs if this task is cancelled, so the session/transport
            # are still closed; status no longer stays RUNNING on close failure.
            await self._close_stack(exit_stack, mcp_id)
            self.status = MCPStatus.STOPPED

    async def init(self, user_id: str | None, mcp_id: str, config: MCPServerSSEConfig | MCPServerStdioConfig) -> None:
        """
        Initialize the MCP client.

        Starts the background task that creates the MCP transport (spawning the
        server process for stdio configs) and the ClientSession, waits until it
        is ready or has failed, then caches the server's tools in :attr:`tools`.

        :param str | None user_id: user ID
        :param str mcp_id: MCP ID
        :param MCPServerSSEConfig | MCPServerStdioConfig config: MCP configuration
        :return: None
        :raises RuntimeError: if initialization fails; the background task's
            exception, when there is one, is chained as ``__cause__``
        """
        self.mcp_id = mcp_id
        self.ready_sign = asyncio.Event()
        self.error_sign = asyncio.Event()
        self.stop_sign = asyncio.Event()
        self.task = asyncio.create_task(self._main_loop(user_id, mcp_id, config))

        ready_waiter = asyncio.create_task(self.ready_sign.wait())
        error_waiter = asyncio.create_task(self.error_sign.wait())
        try:
            # Also wait on the main task itself, so that if it ends without
            # setting either signal, init() cannot block forever.
            await asyncio.wait(
                {ready_waiter, error_waiter, self.task},
                return_when=asyncio.FIRST_COMPLETED,
            )
        finally:
            # Cancel only the helper waiters -- never the main task.
            ready_waiter.cancel()
            error_waiter.cancel()

        if not self.ready_sign.is_set():
            self.status = MCPStatus.ERROR
            err_msg = f"[MCPClient] MCP {mcp_id} 初始化失败"
            logger.error(err_msg)
            # Let the main task finish, then retrieve its exception. This
            # avoids asyncio's "exception was never retrieved" warning and
            # keeps the root cause.
            await asyncio.wait({self.task})
            cause = None if self.task.cancelled() else self.task.exception()
            raise RuntimeError(err_msg) from cause

        self.tools = (await self.client.list_tools()).tools

    async def call_tool(self, tool_name: str, params: dict) -> "CallToolResult":
        """
        Call a tool on the MCP server.

        :param str tool_name: name of the tool to call
        :param dict params: arguments passed to the tool
        :return: the server's ``CallToolResult``
        """
        return await self.client.call_tool(tool_name, params)

    async def stop(self) -> None:
        """
        Stop the MCP client.

        Sets ``stop_sign`` and waits for the background task to close the
        session and transport. An ``Exception`` raised by the task is logged as
        a warning rather than propagated. Does nothing if :meth:`init` was
        never called.
        """
        task = getattr(self, "task", None)
        if task is None:
            return
        self.stop_sign.set()
        try:
            await task
        except Exception as e:
            logger.warning("[MCPClient] MCP %s:停止时发生异常:%s", self.mcp_id, e)