"""MCP宿主"""
import json
import logging
from typing import Any, Optional

from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment

from apps.llm.function import JsonGenerator
from apps.llm.reasoning import ReasoningLLM
from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE
from apps.scheduler.mcp_agent.base import MCPBase
from apps.scheduler.mcp_agent.prompt import GEN_PARAMS, REPAIR_PARAMS
from apps.schemas.enum_var import LanguageType
from apps.schemas.mcp import MCPTool
from apps.schemas.task import Task
logger = logging.getLogger(__name__)

# Shared sandboxed Jinja2 environment used to render every prompt template in
# this module. Templates are supplied as strings via ``from_string``.
_env = SandboxedEnvironment(
    # Jinja2 expects a loader *instance*; passing the class itself would make
    # any loader-backed lookup call ``load`` unbound with the wrong arguments.
    loader=BaseLoader(),
    autoescape=False,
    trim_blocks=True,
    lstrip_blocks=True,
)
def tojson_filter(value):
    """Serialize ``value`` to compact JSON text.

    Non-ASCII characters are emitted as-is rather than as ``\\uXXXX``
    escapes, and no whitespace is placed after ``,`` or ``:``.
    """
    compact_encoder = json.JSONEncoder(
        ensure_ascii=False,
        separators=(",", ":"),
    )
    return compact_encoder.encode(value)
# Register ``tojson_filter`` under the filter name ``tojson`` in the shared
# environment, replacing any existing entry of that name.
_env.filters.update(tojson=tojson_filter)

# Per-language query text passed to the JSON generator when tool parameters
# are being repaired.
LLM_QUERY_FIX: dict[LanguageType, str] = {
    LanguageType.CHINESE: "请生成修复之后的工具参数",
    LanguageType.ENGLISH: "Please generate the tool parameters after repair",
}
class MCPHost(MCPBase):
    """MCP host service.

    Renders prompt templates and asks the LLM to generate or repair the input
    parameters of an MCP tool.
    """

    @staticmethod
    async def assemble_memory(task: Task) -> str:
        """Assemble the memory text for a task.

        Args:
            task: Task whose ``language`` selects the memory template and
                whose ``context`` is passed to it as ``context_list``.

        Returns:
            The rendered memory template.
        """
        memory_template = _env.from_string(MEMORY_TEMPLATE[task.language])
        return memory_template.render(context_list=task.context)

    @staticmethod
    async def _get_first_input_params(
        mcp_tool: MCPTool,
        goal: str,
        current_goal: str,
        task: Task,
        # NOTE(review): this default is evaluated once, when the class body is
        # executed, so every call that omits the argument shares one instance.
        # Kept as-is because constructing a new instance per call could change
        # behaviour; confirm whether sharing is intended.
        resoning_llm: ReasoningLLM = ReasoningLLM(),
    ) -> dict[str, Any]:
        """Fill in the tool parameters for the first invocation of a tool.

        Args:
            mcp_tool: Tool whose name, description and input schema are
                rendered into the prompt.
            goal: Overall goal, rendered into the prompt.
            current_goal: Goal of the current step, rendered into the prompt.
            task: Task that selects the prompt language and supplies the
                context used as background information.
            resoning_llm: LLM handed to ``get_resoning_result``.

        Returns:
            The value produced by ``_parse_result`` for the tool's input
            schema (presumably the parsed parameters; the helper is not
            defined in this class).
        """
        background_info = await MCPHost.assemble_memory(task)
        prompt = _env.from_string(GEN_PARAMS[task.language]).render(
            tool_name=mcp_tool.name,
            tool_description=mcp_tool.description,
            goal=goal,
            current_goal=current_goal,
            input_schema=mcp_tool.input_schema,
            background_info=background_info,
        )
        reasoning_output = await MCPHost.get_resoning_result(prompt, resoning_llm)
        return await MCPHost._parse_result(
            reasoning_output,
            mcp_tool.input_schema,
        )

    @staticmethod
    async def _fill_params(
        mcp_tool: MCPTool,
        goal: str,
        current_goal: str,
        current_input: dict[str, Any],
        error_message: str = "",
        params: Optional[dict[str, Any]] = None,
        params_description: str = "",
        language: LanguageType = LanguageType.CHINESE,
    ) -> dict[str, Any]:
        """Regenerate a tool's parameters after a failed attempt.

        Args:
            mcp_tool: Tool whose name, description and input schema are
                rendered into the repair prompt and used as the JSON schema.
            goal: Overall goal, rendered into the prompt.
            current_goal: Goal of the current step, rendered into the prompt.
            current_input: Parameters used so far, rendered as
                ``input_params``.
            error_message: Error text rendered into the prompt.
            params: Additional parameters rendered into the prompt; an empty
                dict is used when omitted.
            params_description: Description of ``params`` rendered into the
                prompt.
            language: Selects the repair prompt and the query text.

        Returns:
            The result of ``JsonGenerator.generate()`` for the tool's input
            schema.
        """
        # A fresh dict per call replaces the former shared ``{}`` default,
        # which was a single object reused across all invocations.
        if params is None:
            params = {}

        llm_query = LLM_QUERY_FIX[language]
        prompt = _env.from_string(REPAIR_PARAMS[language]).render(
            tool_name=mcp_tool.name,
            goal=goal,
            current_goal=current_goal,
            tool_description=mcp_tool.description,
            input_schema=mcp_tool.input_schema,
            input_params=current_input,
            error_message=error_message,
            params=params,
            params_description=params_description,
        )
        conversation = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        json_generator = JsonGenerator(
            llm_query,
            conversation,
            mcp_tool.input_schema,
        )
        return await json_generator.generate()