"""选择MCP Server及其工具"""
import logging
import random
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
from apps.llm.reasoning import ReasoningLLM
from apps.llm.token import TokenCalculator
from apps.scheduler.mcp_agent.base import MCPBase
from apps.scheduler.mcp_agent.prompt import TOOL_SELECT
from apps.schemas.mcp import MCPTool, MCPToolIdsSelectResult
from apps.schemas.enum_var import LanguageType
logger = logging.getLogger(__name__)
_env = SandboxedEnvironment(
loader=BaseLoader,
autoescape=True,
trim_blocks=True,
lstrip_blocks=True,
)
FINAL_TOOL_ID = "FIANL"
SUMMARIZE_TOOL_ID = "SUMMARIZE"
SELF_DESC_TOOL_ID = "SELF_DESC"
class MCPSelector(MCPBase):
"""MCP选择器"""
@staticmethod
async def select_top_tool(
goal: str,
tool_list: list[MCPTool],
additional_info: str | None = None,
top_n: int | None = None,
reasoning_llm: ReasoningLLM | None = None,
language: LanguageType = LanguageType.CHINESE,
) -> list[MCPTool]:
"""选择最合适的工具"""
random.shuffle(tool_list)
max_tokens = reasoning_llm._config.max_tokens
template = _env.from_string(TOOL_SELECT[language])
token_calculator = TokenCalculator()
if (
token_calculator.calculate_token_length(
messages=[
{
"role": "user",
"content": template.render(goal=goal, tools=[], additional_info=additional_info),
}
],
pure_text=True,
)
> max_tokens
):
logger.warning("[MCPSelector] 工具选择模板长度超过最大令牌数,无法进行选择")
return []
current_index = 0
tool_ids = []
while current_index < len(tool_list):
index = current_index
sub_tools = []
while index < len(tool_list):
tool = tool_list[index]
tokens = token_calculator.calculate_token_length(
messages=[
{
"role": "user",
"content": template.render(
goal=goal, tools=[tool], additional_info=additional_info
),
}
],
pure_text=True,
)
if tokens > max_tokens:
continue
sub_tools.append(tool)
tokens = token_calculator.calculate_token_length(
messages=[
{
"role": "user",
"content": template.render(
goal=goal, tools=sub_tools, additional_info=additional_info
),
},
],
pure_text=True,
)
if tokens > max_tokens:
del sub_tools[-1]
break
else:
index += 1
current_index = index
if sub_tools:
schema = MCPToolIdsSelectResult.model_json_schema()
if "items" not in schema["properties"]["tool_ids"]:
schema["properties"]["tool_ids"]["items"] = {}
schema["properties"]["tool_ids"]["items"]["enum"] = [tool.id for tool in sub_tools]
result = await MCPSelector.get_resoning_result(
template.render(goal=goal, tools=sub_tools, additional_info="请根据目标选择对应的工具"),
reasoning_llm,
)
result = await MCPSelector._parse_result(result, schema)
try:
result = MCPToolIdsSelectResult.model_validate(result)
tool_ids.extend(result.tool_ids)
except Exception:
logger.exception("[MCPSelector] 解析MCP工具ID选择结果失败")
continue
mcp_tools = [tool for tool in tool_list if tool.id in tool_ids]
if top_n is not None:
mcp_tools = mcp_tools[:top_n]
mcp_tools.append(
MCPTool(id=FINAL_TOOL_ID, name="Final", description="终止", mcp_id=FINAL_TOOL_ID, input_schema={})
)
return mcp_tools