import re
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.config.config import LLMConfig
from openjiuwen_deepsearch.utils.validation_utils.field_validation import validate_not_empty_field, \
validate_str_field, validate_bool_field
SAFE_CONVERSATION_ID_PATTERN = re.compile(r"^[A-Za-z0-9_-]{1,128}$")
def validate_generate_template_params(
file_name: str,
file_stream: str,
is_template,
):
"""
校验 generate template 入参
"""
validate_not_empty_field("file_name", file_name)
validate_str_field("file_name", file_name)
validate_not_empty_field("file_stream", file_stream)
validate_str_field("file_stream", file_stream)
validate_bool_field("is_template", is_template)
def validate_run_agent_params(
message: str,
conversation_id: str,
report_template: str = "",
interrupt_feedback: str = "",
):
"""
校验 run agent 入参
"""
allow_empty_message_for_accept = interrupt_feedback == "accepted" and message == ""
if not allow_empty_message_for_accept:
validate_not_empty_field("message", message)
validate_str_field("message", message)
validate_not_empty_field("conversation_id", conversation_id)
validate_str_field("conversation_id", conversation_id)
if not SAFE_CONVERSATION_ID_PATTERN.fullmatch(conversation_id):
raise CustomValueException(
StatusCode.PARAM_CHECK_ERROR_COMMON_INVALID.code,
StatusCode.PARAM_CHECK_ERROR_COMMON_INVALID.errmsg.format(param="conversation_id"),
)
validate_str_field("report_template", report_template)
validate_str_field("interrupt_feedback", interrupt_feedback)
allowed_interrupt_feedback = ("", "accepted", "cancel", "revise_outline", "revise_comment")
if interrupt_feedback.strip() and interrupt_feedback not in allowed_interrupt_feedback:
raise CustomValueException(StatusCode.PARAM_CHECK_ERROR_INTERRUPT_FEEDBACK_ERROR.code,
StatusCode.PARAM_CHECK_ERROR_INTERRUPT_FEEDBACK_ERROR.errmsg)
def validate_llm_obj_params(input_config: LLMConfig):
"""
校验自定义的 llm 入参
"""
llm_config = input_config.model_dump()
model_provider = llm_config.get("model_type", "")
api_base = llm_config.get("base_url", "")
model_name = llm_config.get("model_name", "")
validate_not_empty_field("model_provider", model_provider)
validate_str_field("model_provider", model_provider)
validate_not_empty_field("api_base", api_base)
validate_str_field("api_base", api_base)
validate_not_empty_field("model_name", model_name)
validate_str_field("model_name", model_name)