"""MCP宿主"""
import json
import logging
import uuid
from typing import Any
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
from mcp.types import TextContent
from apps.llm import LLM, json_generator
from apps.models import LanguageType, MCPTools
from apps.scheduler.pool.mcp.client import MCPClient
from apps.scheduler.pool.mcp.pool import mcp_pool
from apps.schemas.mcp import MCPContext, MCPPlanItem
from apps.services.mcp_service import MCPServiceManager
from .base import MCPNodeBase
from .prompt import MEMORY_TEMPLATE
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class MCPHost(MCPNodeBase):
    """Host service for MCP tool invocation.

    Resolves MCP clients for a user, asks the LLM to fill tool parameters,
    calls tools, and keeps a per-instance list of step contexts that is
    replayed to the LLM as a synthetic user/assistant conversation.
    """

    def __init__(self, user_id: str, task_id: uuid.UUID, llm: LLM, language: LanguageType) -> None:
        """Initialize the MCP host.

        :param user_id: ID of the user on whose behalf MCP clients and tools are fetched.
        :param task_id: ID of the task this host serves (stored; not read elsewhere in this class).
        :param llm: LLM instance, forwarded to ``MCPNodeBase``.
        :param language: Language key used to select the memory prompt template
            (presumably stored by ``MCPNodeBase`` as ``self._language`` — confirm).
        """
        super().__init__(llm, language)
        self._task_id = task_id
        self._user_id = user_id
        # One entry per text result recorded by ``_save_memory``; replayed by ``assemble_memory``.
        self._context_list: list[MCPContext] = []
        # Templates rendered here become LLM prompt text, not HTML, so autoescaping
        # is deliberately off; the sandbox still restricts what templates can do.
        self._env = SandboxedEnvironment(
            loader=BaseLoader(),
            autoescape=False,
            trim_blocks=True,
            lstrip_blocks=True,
        )

    async def init(self) -> None:
        """Asynchronously initialize the MCP host. Currently a no-op."""

    async def get_client(self, mcp_id: str) -> MCPClient | None:
        """Return the MCP client for ``mcp_id`` and the current user.

        :param mcp_id: ID of the MCP service.
        :return: The client, or ``None`` if the pool raised ``KeyError`` or
            ``RuntimeError`` (the error is logged as a warning).
        """
        try:
            return await mcp_pool.get(mcp_id, self._user_id)
        except (KeyError, RuntimeError) as e:
            logger.warning("获取MCP客户端失败: %s", e)
            return None

    async def assemble_memory(self) -> list[dict[str, str]]:
        """Build a synthetic user/assistant conversation from the recorded step contexts.

        Each recorded context yields two messages: a ``user`` message describing
        the step and its input, followed by an ``assistant`` message with the
        step status and output. Both are rendered from the language-specific
        ``MEMORY_TEMPLATE`` and stripped of surrounding whitespace.

        :return: A list of ``{"role": ..., "content": ...}`` dicts, in step order.
        """
        context_list = self._context_list
        conversation_history = []
        template = self._env.from_string(MEMORY_TEMPLATE[self._language])
        # Step indices are 1-based in the rendered text.
        for index, ctx in enumerate(context_list, start=1):
            user_message = template.render(
                msg_type="user",
                step_index=index,
                step_description=ctx.step_description,
                step_name=ctx.step_name,
                input_data=ctx.input_data,
            )
            conversation_history.append({
                "role": "user",
                "content": user_message.strip(),
            })
            assistant_message = template.render(
                msg_type="assistant",
                step_index=index,
                step_status=ctx.step_status,
                output_data=ctx.output_data,
            )
            conversation_history.append({
                "role": "assistant",
                "content": assistant_message.strip(),
            })
        return conversation_history

    async def _save_memory(
        self,
        tool: MCPTools,
        plan_item: MCPPlanItem,
        input_data: dict[str, Any],
        result: str,
    ) -> dict[str, Any]:
        """Record one tool result as a step context and return its output data.

        ``result`` is parsed as JSON. If parsing fails, or the decoded value is
        not a dict, it is wrapped as ``{"message": result}`` instead.

        :param tool: The tool that produced ``result`` (currently unused).
        :param plan_item: Plan step; its ``content`` becomes the step description.
        :param input_data: Parameters the tool was called with.
        :param result: Raw text returned by the tool.
        :return: The dict stored as the step's output data.
        """
        output_data: Any
        try:
            output_data = json.loads(result)
        except (ValueError, TypeError, RecursionError):
            # Tool output is external text: malformed (JSONDecodeError is a
            # ValueError) or pathologically nested JSON must not abort the step,
            # so fall through to the raw-string wrapper below.
            output_data = None
        if not isinstance(output_data, dict):
            logger.warning("[MCPHost] 得到的数据不是dict格式!尝试转换为str")
            output_data = {"message": result}

        # NOTE(review): `tool` is unused, and step_name/step_status (read by
        # assemble_memory) are not set here — confirm MCPContext defaults are intended.
        context = MCPContext(
            step_description=plan_item.content,
            input_data=input_data,
            output_data=output_data,
        )
        self._context_list.append(context)
        return output_data

    async def _fill_params(self, tool: MCPTools, query: str) -> dict[str, Any]:
        """Ask the LLM to generate arguments for ``tool`` that achieve ``query``.

        The ``gen_params`` prompt is rendered with the tool's name, description
        and JSON input schema. The tool is exposed as a function definition,
        and the conversation assembled from memory is supplied as context.

        :param tool: Tool whose ``inputSchema`` the arguments must satisfy.
        :param query: Goal text for this step.
        :return: The arguments produced by ``json_generator.generate``.
        """
        template = self._env.from_string(await self._load_prompt("gen_params"))
        llm_query = template.render(
            current_goal=query,
            goal=query,
            tool_name=tool.toolName,
            tool_description=tool.description,
            input_schema=json.dumps(tool.inputSchema, ensure_ascii=False),
        )
        function_definition = {
            "name": tool.toolName,
            "description": tool.description,
            "parameters": tool.inputSchema,
        }
        # assemble_memory() builds a fresh list each call, so no defensive copy is needed.
        memory_conversation = await self.assemble_memory()
        return await json_generator.generate(
            function=function_definition,
            conversation=memory_conversation,
            prompt=llm_query,
        )

    async def call_tool(self, tool: MCPTools, plan_item: MCPPlanItem) -> list[dict[str, Any]]:
        """Fill ``tool``'s parameters via the LLM, call it, and record its text results.

        :param tool: Tool to call; ``tool.mcpId`` selects the MCP client.
        :param plan_item: Plan step; its ``instruction`` drives parameter filling.
        :return: One output dict per ``TextContent`` item in the tool result
            (see ``_save_memory``). Non-text items are logged and skipped.
        :raises RuntimeError: If the pool yields no client for ``tool.mcpId``.
            Errors raised by ``mcp_pool.get`` itself propagate unchanged.
        """
        client = await mcp_pool.get(tool.mcpId, self._user_id)
        # The pool is not shown to guarantee a non-None client (get_client is typed
        # ``MCPClient | None``). Fail with a clear error before spending an LLM call,
        # instead of an AttributeError on ``None.call_tool``.
        if client is None:
            err = f"[MCPHost] MCP client unavailable: {tool.mcpId}"
            logger.error(err)
            raise RuntimeError(err)

        params = await self._fill_params(tool, plan_item.instruction)
        result = await client.call_tool(tool.toolName, params)

        processed_result = []
        for item in result.content:
            if not isinstance(item, TextContent):
                logger.error("MCP结果类型不支持: %s", item)
                continue
            processed_result.append(await self._save_memory(tool, plan_item, params, item.text))
        return processed_result

    async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTools]:
        """Collect the tools of every MCP service in ``mcp_id_list``.

        Services for which ``MCPServiceManager.get_mcp_tools`` raises ``KeyError``
        are logged as misconfigured and skipped.

        :param mcp_id_list: IDs of MCP services to query.
        :return: The concatenated tool lists, in the order of ``mcp_id_list``.
        """
        tool_list = []
        for mcp_id in mcp_id_list:
            try:
                tool_list.extend(await MCPServiceManager.get_mcp_tools(mcp_id))
            except KeyError:
                logger.warning("用户 %s 的MCP Tool %s 配置错误", self._user_id, mcp_id)
        return tool_list