import logging
import json
from dataclasses import dataclass, field
from typing import List, Dict, NamedTuple, Set
import copy
from openjiuwen_deepsearch.utils.constants_utils.session_contextvars import llm_context
from openjiuwen_deepsearch.algorithm.prompts.template import apply_system_prompt
from openjiuwen_deepsearch.utils.common_utils.llm_utils import ainvoke_llm_with_stats, normalize_json_output
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.utils.log_utils.log_manager import LogManager
from openjiuwen_deepsearch.utils.constants_utils.node_constants import AgentLlmName
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Upper bound on the number of attempts call_model makes (first call included)
# before it gives up and returns its fallback value.
MAX_LLM_RETRY_TIMES = 3
class GraphInfo(NamedTuple):
    """Immutable bundle of the four values describing a numbered inference graph.

    Field order matters: instances can be built and unpacked positionally.
    """

    # List of inference entries; presumably the [head_ids, relation, tail_id]
    # triples produced by NumberNodeParam.update_structured_inference -- TODO confirm.
    structured_inference: List[List]
    # Mapping used for node numbering (key/value schema not visible in this chunk).
    node_map: Dict
    # Integer ids of nodes treated as citations.
    citation_ids: List[int]
    # Integer ids of nodes treated as conclusions.
    conclusion_ids: List[int]
@dataclass
class NumberNodeParam:
    """Intermediate state passed between functions while nodes are being numbered.

    Every mutable field uses ``default_factory`` so that separate instances
    never share the same underlying container.
    """

    # Nodes seen so far.
    node_set: Set = field(default_factory=set)
    # Mapping maintained during numbering (schema not visible in this chunk).
    node_map: Dict = field(default_factory=dict)
    # Running node counter, starting at 0.
    node_index: int = 0
    # Ids collected as citations.
    citation_ids: Set[int] = field(default_factory=set)
    # Ids collected as conclusions.
    conclusion_ids: Set[int] = field(default_factory=set)
    # Tail id of the relation currently being recorded; -1 until assigned.
    tail_id: int = -1
    # Head ids of the relation currently being recorded.
    head_id_list: List = field(default_factory=list)
    # Accumulated [head_ids, relation, tail_id] entries.
    structured_inference: List = field(default_factory=list)

    def update_structured_inference(self, relation: str) -> None:
        """Append ``[head_ids, relation, tail_id]`` built from the current state.

        Args:
            relation: Relation label stored as the middle element of the entry.
        """
        # Deep-copy the head ids so later mutation of head_id_list cannot
        # alter the entry that has already been recorded.
        heads_snapshot = copy.deepcopy(self.head_id_list)
        entry = [heads_snapshot, relation, self.tail_id]
        self.structured_inference.append(entry)
def type_check(result, expected_type):
    """Ensure ``result`` is an instance of ``expected_type``.

    Args:
        result: Value to validate.
        expected_type: Type (or tuple of types) accepted by ``isinstance``.

    Raises:
        CustomValueException: With the SOURCE_TRACER_INFER_DATA_TYPE_ERROR
            status code when ``result`` is not of the expected type.
    """
    # Guard clause: nothing to do when the type already matches.
    if isinstance(result, expected_type):
        return

    error_msg = f"[SOURCE TRACER INFER]: 生成结果类型错误, 生成结果类型{type(result)}, 期望类型为{expected_type}"
    status = StatusCode.SOURCE_TRACER_INFER_DATA_TYPE_ERROR
    raise CustomValueException(status.code, status.errmsg.format(e=error_msg))
def is_equal_length(result, target):
    """Validate that ``result`` is a list of lists, each holding ``target`` items.

    Args:
        result: Value to validate; must be a list whose elements are lists.
        target: Required length of every inner list.

    Raises:
        CustomValueException: Raised by ``type_check`` when ``result`` or one
            of its elements is not a list, or raised here with the
            SOURCE_TRACER_INFER_DATA_LEN_ERROR status code when an inner
            list's length differs from ``target``.
    """
    type_check(result, list)
    for item in result:
        type_check(item, list)
        if len(item) == target:
            continue
        # Report the length of the offending inner list -- the value that was
        # actually compared against ``target`` -- rather than the length of
        # the outer list, which is unrelated to the failed check.
        error_msg = (
            "[SOURCE TRACER INFER]: 生成结果数量错误,"
            f"生成结果数量{len(item)}, 目标数量{target}"
        )
        status = StatusCode.SOURCE_TRACER_INFER_DATA_LEN_ERROR
        raise CustomValueException(status.code, status.errmsg.format(e=error_msg))
async def call_model(model_name: str, prompt: str, user_input: dict,
                     detection_func_and_args: dict = None,
                     agent_name: str = AgentLlmName.SOURCE_TRACER_INFER.value):
    """Call an LLM, parse its reply as JSON and optionally validate the result.

    Each attempt renders the prompt, invokes the model, normalizes the reply,
    strips Markdown code fences and parses the text with ``json.loads``. Any
    exception raised during an attempt is logged as a warning and the attempt
    is repeated, up to ``MAX_LLM_RETRY_TIMES`` attempts in total.

    Args:
        model_name: Key used to look up the LLM in ``llm_context``.
        prompt: Prompt identifier handed to ``apply_system_prompt``.
        user_input: Input data handed to ``apply_system_prompt``.
        detection_func_and_args: Optional dict with a ``"detection_func"``
            callable and its ``"args"`` value; when given, the callable is
            invoked as ``detection_func(parsed_result, args)`` and is expected
            to raise on invalid output.
        agent_name: Name forwarded to ``ainvoke_llm_with_stats``.

    Returns:
        The parsed JSON value from the first successful attempt, or an empty
        list once every attempt has failed.
    """
    # ``retries`` is the 1-based number of the current attempt; it is only
    # reported when that attempt fails.
    for retries in range(1, MAX_LLM_RETRY_TIMES + 1):
        try:
            user_prompt = apply_system_prompt(prompt, user_input)
            llm = llm_context.get().get(model_name)
            response = await ainvoke_llm_with_stats(llm, user_prompt, agent_name=agent_name)
            normalized = normalize_json_output(response.get("content", ""))
            # Drop any leftover Markdown fences before parsing.
            without_fences = normalized.replace("```json", "").replace("```", "")
            llm_result = json.loads(without_fences)
            if not detection_func_and_args:
                return llm_result
            validator = detection_func_and_args.get("detection_func")
            validator_args = detection_func_and_args.get("args")
            # The validator signals bad output by raising, which triggers a retry.
            validator(llm_result, validator_args)
            return llm_result
        except CustomValueException as e:
            logger.warning(f'[SOURCE TRACER INFER] retry: {retries}/{MAX_LLM_RETRY_TIMES}, '
                           f'call_model error {e}')
        except Exception as e:
            # Exception details are left out of the log in sensitive mode.
            if not LogManager.is_sensitive():
                logger.warning(f'[SOURCE TRACER INFER] retry: {retries}/{MAX_LLM_RETRY_TIMES}, '
                               f'call_model error {e}')
            else:
                logger.warning(f'[SOURCE TRACER INFER] retry: {retries}/{MAX_LLM_RETRY_TIMES}, '
                               f'call_model error')

    logger.error(f'[SOURCE TRACER INFER] retry {MAX_LLM_RETRY_TIMES} times, call_model error')
    return []