from typing import Literal, List
from pydantic import BaseModel, Field, field_validator
from openjiuwen_deepsearch.utils.validation_utils.param_validation import (
SAFE_CONVERSATION_ID_PATTERN,
)
# Message carried by the ValueError raised when a conversation_id fails the
# SAFE_CONVERSATION_ID_PATTERN check in DeepSearchRequest. The regex quoted in
# the text is informational; the imported pattern is what is actually enforced.
_CONVERSATION_ID_SCHEMA_ERR = (
    "conversation_id must be 1–128 characters and use only ASCII letters, "
    "digits, underscore, or hyphen (^[A-Za-z0-9_-]{1,128}$)."
)
class WebSearchConfig(BaseModel):
    # Settings for the online (web) search engine attached to a request.

    # Identifier of the web search engine configuration; required (no default).
    web_search_config_id: int = Field(
        description="联网增强引擎ID",
    )

    # Maximum number of results returned by a single web search.
    # Constrained to 1..10 inclusive; defaults to 5.
    max_web_search_results: int = Field(
        default=5,
        ge=1,
        le=10,
        description="一次网页搜索的最大返回结果数量",
    )
class LocalSearchConfig(BaseModel):
    # Settings for retrieval from local knowledge bases.

    # IDs of the local knowledge bases to search; empty list by default.
    local_search_config_ids: List[str] = Field(
        default=[],
        description="本地知识库ID列表",
    )

    # Maximum number of local search results.
    # Constrained to 1..10 inclusive; defaults to 5.
    max_local_search_results: int = Field(
        default=5,
        ge=1,
        le=10,
        description="最大本地搜索结果数量",
    )

    # Knowledge-base retrieval threshold.
    # Constrained to 0.0..1.0 inclusive; defaults to 0.5.
    recall_threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="知识库检索阈值",
    )
class PluginToolParam(BaseModel):
    # Describes a single parameter of a plugin API tool.
    # NOTE(review): the meaning of the integer codes used by `type` and
    # `method` is not defined in this file; confirm against the consumer.

    # Parameter name; required.
    name: str = Field(
        ...,
        description="参数名称",
    )
    # Free-text description of the parameter.
    desc: str = Field(
        default="",
        description="参数描述",
    )
    # Integer code for the parameter type; defaults to 1.
    type: int = Field(
        default=1,
        description="参数类型",
    )
    # Whether the parameter must be supplied.
    is_required: bool = Field(
        default=False,
        description="是否必填",
    )
    # Integer code for how the parameter is sent; defaults to 0.
    method: int = Field(
        default=0,
        description="参数发送方式",
    )
    # Whether the parameter is a runtime parameter.
    is_runtime: bool = Field(
        default=False,
        description="是否运行时参数",
    )
    # Default value of the parameter, kept as a string.
    value: str = Field(
        default="",
        description="默认值",
    )
    # Parameter priority; defaults to 0.
    priority: int = Field(
        default=0,
        description="参数优先级",
    )
class PluginApiHeader(BaseModel):
    # One HTTP request header (name/value pair) for a plugin API tool.

    # Header name; required.
    name: str = Field(
        ...,
        description="请求头名称",
    )
    # Header value; empty string when not supplied.
    value: str = Field(
        default="",
        description="请求头值",
    )
class RuntimeApiToolRequest(BaseModel):
    # Definition of an API tool passed in with a request at runtime.

    # --- identity -----------------------------------------------------------
    # Tool ID; required.
    tool_id: str = Field(
        ...,
        description="工具ID",
    )
    # Tool name; required.
    name: str = Field(
        ...,
        description="工具名称",
    )
    # Tool description.
    desc: str = Field(
        default="",
        description="工具描述",
    )
    # Response wrapper type; per the description only "search_result" is
    # currently supported.
    response_wrapper: str = Field(
        default="",
        description="响应包装器类型,当前支持 search_result",
    )

    # --- endpoint -----------------------------------------------------------
    # Endpoint path or a full URL; required.
    path: str = Field(
        ...,
        description="接口路径或完整 URL",
    )
    # HTTP method, accepted either as an int code or as a string; required.
    method: int | str = Field(
        ...,
        description="HTTP 方法",
    )

    # --- parameters and headers --------------------------------------------
    # Request parameter definitions; a fresh empty list per instance.
    request_params: List[PluginToolParam] = Field(
        default_factory=list,
        description="请求参数",
    )
    # Response parameter definitions. Per the description these are only
    # passed through and kept; response_wrapper should be preferred for
    # constraining the actual return structure.
    response_params: List[PluginToolParam] = Field(
        default_factory=list,
        description="响应参数定义,当前仅透传保留,实际返回结构请优先使用 response_wrapper 约束",
    )
    # Default request headers.
    headers: List[PluginApiHeader] = Field(
        default_factory=list,
        description="默认请求头",
    )

    # --- owning plugin ------------------------------------------------------
    # Plugin ID.
    plugin_id: str = Field(
        default="",
        description="插件ID",
    )
    # Plugin version; None when unspecified.
    plugin_version: str | None = Field(
        default=None,
        description="插件版本",
    )
    # Base address of the plugin service.
    url: str = Field(
        default="",
        description="插件服务基址",
    )
class DeepSearchRequest(BaseModel):
    # Request body for a deep-search run: identifies the conversation, carries
    # the user message, and exposes the tuning switches declared below.
    #
    # conversation_id is additionally checked against
    # SAFE_CONVERSATION_ID_PATTERN by the field validator at the bottom.

    # --- identity and payload ----------------------------------------------
    # User space ID; required.
    space_id: str = Field(..., description="用户空间ID")
    # Conversation ID; required and validated by _validate_conversation_id_chars.
    conversation_id: str = Field(..., description="请求对话ID")
    # The user's query, or the user's feedback during human-in-the-loop.
    message: str = Field(..., description="用户请求查询或者人机交互时的反馈")

    # --- workflow / outline -------------------------------------------------
    # Whether human-in-the-loop interaction is enabled.
    workflow_human_in_the_loop: bool = Field(default=True, description="是否启用人机交互")
    # Maximum number of planned sections, 1..15 inclusive.
    outliner_max_section_num: int = Field(
        default=10,
        ge=1,
        le=15,
        description="最大规划章节数量,取值范围:[1,15]",
    )

    # --- source tracing -----------------------------------------------------
    source_tracer_research_trace_source_switch: bool = Field(default=True, description="溯源功能开关")
    source_tracer_generated_citation_switch: bool = Field(default=True, description="新增引用生成开关")
    source_tracer_infer_switch: bool = Field(default=True, description="溯源推理功能开关")

    # --- search configuration ----------------------------------------------
    # Which search backend(s) to use. The description fragments are joined
    # with "; " so the rendered text does not run the options together.
    info_collector_search_method: Literal["web", "local", "all"] = Field(
        default="web",
        description=(
            "搜索方式:"
            "web: 联网搜索; "
            "local: 本地搜索工具搜索; "
            "all : 联网+本地融合搜索"
        ),
    )
    # LLM configuration as a free-form dict.
    llm_config: dict = Field(default_factory=dict, description="LLM配置")
    # Both search configs default to None, so the annotations are optional:
    # an explicit null in the payload is accepted the same way as omitting
    # the field.
    # NOTE(review): the descriptions say at least one of the two configs must
    # be chosen, but nothing in this model enforces that — confirm it is
    # checked downstream.
    web_search_config: WebSearchConfig | None = Field(
        default=None,
        description="联网增强引擎配置,和本地知识库配置至少选择一个",
    )
    local_search_config: LocalSearchConfig | None = Field(
        default=None,
        description="本地知识库配置,和联网增强引擎配置至少选择一个",
    )
    # Report template ID (optional); defaults to -1.
    template_id: int = Field(default=-1, description="报告模板ID(可选)")

    # --- interruption / outline interaction --------------------------------
    # Interrupt feedback marker (optional); empty string by default.
    interrupt_feedback: Literal[
        "", "accepted", "cancel", "revise_outline", "revise_comment"
    ] = Field(default="", description="中断反馈标识(可选)")
    outline_interaction_enabled: bool = Field(default=False, description="大纲交互开关")
    # Maximum outline interaction rounds, 1..100 inclusive.
    outline_interaction_max_rounds: int = Field(
        default=3,
        ge=1,
        le=100,
        description="大纲交互最大轮次",
    )

    # --- mode and execution -------------------------------------------------
    search_mode: Literal["research", "search", "react"] = Field(
        default="research",
        description="research: 报告; search: DeepSearch; react: 简单 ReAct",
    )
    enable_question_router: bool = Field(
        default=False,
        description="search 模式下先经 LLM 路由:0→react,1→search(DeepSearch)",
    )
    # Workflow execution strategy; description fragments joined with "; ".
    execution_method: Literal["parallel", "dependency_driving"] = Field(
        default="parallel",
        description=(
            "执行方法:"
            "parallel: 并行工作流执行; "
            "dependency_driving: 依赖驱动工作流执行"
        ),
    )
    # Max QPS for the web search engine; per the description 0 means no rate
    # limit and fractional values are allowed (0.5 = one request per 2 s).
    # NOTE(review): no lower bound is declared, so negative values pass
    # validation — confirm how the consumer treats them.
    web_search_max_qps: float = Field(
        default=0,
        description="联网增强引擎最大 QPS,0 表示不限流,支持浮点数如 0.5 表示每 2 秒 1 个请求",
    )

    # --- user feedback / statistics / tools --------------------------------
    user_feedback_processor_enable: bool = Field(default=False, description="是否启用用户反馈优化功能")
    # Maximum number of feedback interactions, 1..100 inclusive.
    user_feedback_processor_max_interactions: int = Field(
        default=100,
        ge=1,
        le=100,
        description="最大交互次数",
    )
    stats_info_llm: bool = Field(default=False, description="是否开启 LLM 调用统计")
    # API tools supplied by the front end.
    tools: List[RuntimeApiToolRequest] = Field(
        default_factory=list,
        description="前端传入的 API 工具列表",
    )

    # --- VLM chart generation ----------------------------------------------
    vlm_chart_generator_enable: bool = Field(default=False, description="vlm迭代生成图开关")
    # Maximum iterations, 0..3 inclusive; per the description 0 disables iteration.
    vlm_chart_generator_max_iterations: int = Field(
        default=1,
        ge=0,
        le=3,
        description="vlm迭代生成图最大迭代次数,0表示不进行迭代",
    )

    # --- timeouts -----------------------------------------------------------
    # Total LLM timeout per agent, keyed by agent name.
    # NOTE(review): the time unit is not stated in this file — confirm.
    agent_llm_timeouts: dict[str, int] = Field(
        default_factory=dict,
        description="按 agent 配置的 LLM 总超时时间",
    )

    @field_validator("conversation_id")
    @classmethod
    def _validate_conversation_id_chars(cls, v: str) -> str:
        """Reject a conversation_id that does not fully match the safe pattern.

        Args:
            v: The conversation_id value being validated.

        Returns:
            ``v`` unchanged when ``SAFE_CONVERSATION_ID_PATTERN`` fully matches it.

        Raises:
            ValueError: With ``_CONVERSATION_ID_SCHEMA_ERR`` as the message
                when the pattern does not fully match ``v``.
        """
        if not SAFE_CONVERSATION_ID_PATTERN.fullmatch(v):
            raise ValueError(_CONVERSATION_ID_SCHEMA_ERR)
        return v