import json
import logging
import asyncio
from typing import List, Dict, Tuple
import base64
from openjiuwen_deepsearch.utils.log_utils.log_manager import LogManager
from openjiuwen_deepsearch.algorithm.source_tracer_infer.infer_extract_info import ResearchInferPreprocess
from openjiuwen_deepsearch.algorithm.source_tracer_infer.number_node import NumberNode
from openjiuwen_deepsearch.algorithm.source_tracer_infer.supplement_graph import SupplementGraph
from openjiuwen_deepsearch.algorithm.source_tracer_infer.generate_html import GenerateHTML
from openjiuwen_deepsearch.algorithm.source_tracer_infer.infer_call_model import (call_model, is_equal_length,
type_check, GraphInfo)
from openjiuwen_deepsearch.utils.constants_utils.node_constants import AgentLlmName
# Module-level logger used by SourceTracerInfer below.
logger = logging.getLogger(__name__)
class SourceTracerInfer:
    """Source-tracing inference over a generated report.

    For every conclusion in the report, selects the search records that support
    it, asks the LLM for the reasoning linking them, renders that reasoning as
    an HTML graph, and marks the conclusion in the report with a link to it.
    """

    def __init__(self, context):
        """Read configuration and inputs from ``context``.

        Keys read: "language" (default "zh-CN"), "llm_model_name" (default ""),
        "source_tracer_response" (the report text, default "") and
        "conclusion_with_records" (optional pre-extracted conclusions; when it
        is missing or empty they are extracted by ``get_conclusion_and_records``).
        """
        self.context = context
        self.language = context.get("language", "zh-CN")
        self.model_name = context.get("llm_model_name", "")
        self.response = context.get("source_tracer_response", "")
        self.conclusion_with_records = context.get("conclusion_with_records", None)
        # Filled by run() with the data needed by the source-tracer check step.
        self.checker_infos = {"graph_infos": [], "search_records": []}
        self.node_number = NumberNode()
        self.supplement_graph = SupplementGraph(self.model_name)
        self.generate_html = GenerateHTML(self.language)

    async def run(self) -> Tuple[str, List[Dict], Dict, str]:
        """Run source-tracing inference for every conclusion of the report.

        Returns:
            response: the report text in which every successfully inferred
                conclusion is replaced by ``[conclusion](#inference:<id>)``;
                the original report if the run fails.
            infer_messages: one dict per successful conclusion with the keys
                "conclusion", "inference", "html_base64" and "id".
            checker_infos: dict with "graph_infos" (the final graph data of
                each kept conclusion) and "search_records" (the search records
                of each kept conclusion), used by the inference check step.
            error: ``str(e)`` of the exception that aborted the run, or None.
        """
        logger.info("[SOURCE TRACER INFER] run starting...")
        logger.debug("[SOURCE TRACER INFER] The response before Source Tracer Infer:\n %s", self.response)
        infer_messages = []
        error = None
        try:
            await self.get_conclusion_and_records()
            conclusion_infos = self.conclusion_with_records or []
            tasks = [
                self.async_run({"conclusion": item.get("conclusion", []),
                                "search_records": item.get("search_records", [])})
                for item in conclusion_infos
            ]
            results = await asyncio.gather(*tasks)
            checked_infer_graphs = []
            final_conclusion_info = []
            for index, (infer_message, checked_infer_graph) in enumerate(results):
                # A conclusion whose pipeline failed comes back without HTML; drop it.
                if not infer_message.get("html_base64", ""):
                    continue
                # The id is the conclusion's original position, so ids may be non-contiguous.
                infer_message["id"] = index
                infer_messages.append(infer_message)
                checked_infer_graphs.append(checked_infer_graph)
                final_conclusion_info.append(conclusion_infos[index])
            response = self.mark_conclusion_in_report(infer_messages, final_conclusion_info)
            self.checker_infos["graph_infos"] = checked_infer_graphs
            self.checker_infos["search_records"] = [info.get("search_records", []) for info in final_conclusion_info]
        except Exception as e:
            if LogManager.is_sensitive():
                logger.error("[SOURCE TRACER INFER] run error: **")
            else:
                logger.error(f"[SOURCE TRACER INFER] run error: {repr(e)}")
            error = str(e)
            response = self.response
            infer_messages = []
        logger.debug("[SOURCE TRACER INFER] The response after Source Tracer Infer:\n %s", response)
        logger.info("[SOURCE TRACER INFER] run end.")
        return response, infer_messages, self.checker_infos, error

    async def async_run(self, datas: Dict) -> Tuple[Dict, "GraphInfo"]:
        """Build the inference graph for a single conclusion.

        Pipeline: extract supporting references -> LLM inference -> validity
        filter -> structuring -> node numbering -> graph supplementing -> HTML.

        Args:
            datas: {"conclusion": [...], "search_records": [...]} for one conclusion.

        Returns:
            infer_message: {"conclusion", "inference", "html_base64"}, or {} when
                no supporting reference was found or any step failed.
            checked_infer_graphs: the numbered and supplemented graph data
                (structured_inference, node_map, citation_ids, conclusion_ids),
                or None on failure.
        """
        logger.info("[SOURCE TRACER INFER] async_run starting...")
        infer_message = {}
        checked_infer_graphs = None
        inferences: Dict = {}
        try:
            conclusion_and_evidences = await self.extract_reference(datas)
            if not conclusion_and_evidences:
                # Nothing to reason about: skip the remaining (costly) LLM calls.
                logger.info("[SOURCE TRACER INFER] async_run end.")
                return {}, None
            inferences = await self.infer(conclusion_and_evidences)
            inferences = await self.filter_invalid_infer(inferences)
            structured_inferences = await self.structured_infer(inferences)
            infer_graphs = self.node_number.number_node(structured_inferences,
                                                        conclusion=conclusion_and_evidences.get("conclusion", ""),
                                                        search_records=datas.get("search_records", []))
            checked_infer_graphs = await self.supplement_graph.run(infer_graphs)
            html_content = self.generate_html.run(checked_infer_graphs)
            infer_message["conclusion"] = inferences.get("conclusion", "")
            infer_message["inference"] = inferences.get("inference", "")
            infer_message["html_base64"] = self._encode_html_to_base64(html_content)
        except Exception as e:
            if LogManager.is_sensitive():
                error_msg = "run error: **"
            else:
                error_msg = f"run error: {e}, the conclusion is: {inferences.get('conclusion', '')}"
            logger.warning(f"[SOURCE TRACER INFER] single conclusion infer error: {error_msg}")
            infer_message = {}
            checked_infer_graphs = None
        logger.info("[SOURCE TRACER INFER] async_run end.")
        return infer_message, checked_infer_graphs

    @staticmethod
    def _encode_html_to_base64(html_content: str) -> str:
        """Encode an HTML string as UTF-8 and return its base64 text.

        base64 is lossless, so no decode-and-compare round trip is needed.
        """
        return base64.b64encode(html_content.encode("utf-8")).decode("utf-8")

    @staticmethod
    def _apply_conclusion_labels(text: str, spans: List[Tuple[int, int, str]]) -> Tuple[str, int]:
        """Replace each ``(start, end, label)`` span of ``text`` with ``label``.

        Offsets refer to the ORIGINAL ``text``. Spans are applied from the highest
        start offset downward so that every replacement leaves the offsets of the
        spans still to be applied valid, whatever order ``spans`` is given in.
        Spans that are not integer offsets within ``text`` or that overlap an
        already-applied span are skipped.

        Returns:
            (labeled text, number of skipped spans)
        """
        skipped = 0
        valid_spans = []
        for start, end, label in spans:
            if isinstance(start, int) and isinstance(end, int) and 0 <= start <= end <= len(text):
                valid_spans.append((start, end, label))
            else:
                skipped += 1
        pieces = []
        cursor = len(text)  # everything from cursor onward is already emitted
        for start, end, label in sorted(valid_spans, key=lambda span: span[:2], reverse=True):
            if end > cursor:  # overlaps a span that was already applied
                skipped += 1
                continue
            pieces.append(text[end:cursor])
            pieces.append(label)
            cursor = start
        pieces.append(text[:cursor])
        return "".join(reversed(pieces)), skipped

    async def get_conclusion_and_records(self):
        """Populate ``conclusion_with_records`` via ResearchInferPreprocess.

        Does nothing when conclusions were already supplied in the context.
        Only the research mode is supported for now.
        """
        if self.conclusion_with_records:
            return
        preprocessor = ResearchInferPreprocess(self.context)
        self.conclusion_with_records = await preprocessor.run()

    async def extract_reference(self, datas: Dict) -> Dict:
        """Select the search records that support a conclusion.

        Args:
            datas: {
                "conclusion": list of conclusion texts; the first is sent to the
                    model as the statement, the last is kept as the conclusion,
                "search_records": search records of the corresponding section
            }

        Returns:
            {"conclusion": <last conclusion text>,
             "reference": [{"id": <record index>, "content": <record content>}, ...]},
            or {} when the input is empty, the model finds no supporting record,
            or an error occurs.
        """
        logger.info("[SOURCE TRACER INFER] extract valid citations starting...")
        conclusions, search_records = datas.get("conclusion", []), datas.get("search_records", [])
        if not conclusions or not search_records:
            if LogManager.is_sensitive():
                logger.warning(
                    "[SOURCE TRACER INFER]: conclusion: *** or search_records *** is None, skip current infer.")
            else:
                # conclusions may be empty here, so it must not be indexed unconditionally.
                last_conclusion = conclusions[-1] if conclusions else conclusions
                logger.warning(
                    f"[SOURCE TRACER INFER]: conclusion: {last_conclusion} or search_records {search_records} "
                    f"is None, skip current infer.")
            return {}
        records = [{"id": index, "content": record.get("content", "")} for index, record in enumerate(search_records)]
        handle_datas = {"statement": conclusions[0], "references": records}
        detection_func_and_args = {"detection_func": type_check, "args": list}
        results = await call_model(self.model_name, "infer_validate_prompt", handle_datas,
                                   detection_func_and_args=detection_func_and_args,
                                   agent_name=AgentLlmName.SOURCE_TRACER_INFER_EXTRACT_REFERENCE.value)
        if not results:
            logger.warning("[SOURCE TRACER INFER] No supported reference.")
            return {}
        try:
            references = []
            seen_ids = set()
            for index in results:
                # The model output is untrusted: ignore non-integer, out-of-range and repeated ids.
                if isinstance(index, bool) or not isinstance(index, int):
                    continue
                if index in seen_ids or not 0 <= index < len(search_records):
                    continue
                seen_ids.add(index)
                references.append({"id": index, "content": search_records[index].get("content", "")})
            if not references:
                logger.warning("[SOURCE TRACER INFER] No supported reference.")
                return {}
            evidence = {"conclusion": conclusions[-1], "reference": references}
            logger.debug(
                "[SOURCE TRACER INFER] extract supported references:\n %s",
                json.dumps(evidence, ensure_ascii=False, indent=4))
        except Exception as e:
            if LogManager.is_sensitive():
                error_msg = "extract supported references error: ***"
            else:
                error_msg = f"extract supported references error: {str(e)}, the conclusion is: {conclusions[-1]}"
            logger.warning(f"[SOURCE TRACER INFER] {error_msg}")
            return {}
        logger.info("[SOURCE TRACER INFER] extract valid citations end.")
        return evidence

    async def infer(self, evidences: Dict) -> Dict:
        """Ask the LLM for the reasoning that leads from the references to the conclusion.

        Args:
            evidences: {"conclusion": conclusion to explain,
                        "reference": selected supporting references}

        Returns:
            {"conclusion": the same conclusion,
             "inference": first item of the model's list answer, or "" if none}
        """
        logger.info("[SOURCE TRACER INFER] infer start...")
        detection_func_and_args = {"detection_func": type_check, "args": list}
        results = await call_model(self.model_name, "infer_conclusion_prompt", evidences,
                                   detection_func_and_args=detection_func_and_args,
                                   agent_name=AgentLlmName.SOURCE_TRACER_INFER_INFER.value)
        inference = results[0] if (isinstance(results, list) and results) else ""
        results = {"conclusion": evidences.get("conclusion", ""), "inference": inference}
        logger.debug("[SOURCE TRACER INFER] infer result:\n %s", json.dumps(results, ensure_ascii=False, indent=4))
        logger.info("[SOURCE TRACER INFER] infer end.")
        return results

    async def filter_invalid_infer(self, inferences: Dict) -> Dict:
        """Reject invalid or low-quality inferences.

        Sends the inference text to the filter prompt. An empty model answer is
        treated as a rejection.

        Returns:
            ``inferences`` unchanged when it passes the check.

        Raises:
            ValueError: if the model returns an empty result.
        """
        logger.info("[SOURCE TRACER INFER] filter invalid inference starting...")
        input_inferences = inferences.get('inference', "")
        detection_func_and_args = {"detection_func": type_check, "args": str}
        results = await call_model(self.model_name, "infer_filter_inference_prompt", {"input": [input_inferences]},
                                   detection_func_and_args=detection_func_and_args,
                                   agent_name=AgentLlmName.SOURCE_TRACER_INFER_FILTER_INVALID_INFER.value)
        if not results:
            if LogManager.is_sensitive():
                logger.warning("[SOURCE TRACER INFER] filter invalid inference: ***")
            else:
                logger.warning(f"[SOURCE TRACER INFER] filter invalid inference: {input_inferences}")
            raise ValueError("invalid inference")
        logger.info("[SOURCE TRACER INFER] infer filter inference ending.")
        return inferences

    async def structured_infer(self, inference: Dict) -> List[List]:
        """Ask the LLM to turn the inference into structured reference relations.

        The answer is validated with ``is_equal_length(..., 3)`` (presumably
        every item must have three elements; confirm against infer_call_model).

        Raises:
            ValueError: if the model returns an empty result.
        """
        logger.info("[SOURCE TRACER INFER] structured_infer starting...")
        detection_func_and_args = {"detection_func": is_equal_length, "args": 3}
        result = await call_model(self.model_name, "infer_structured_prompt", inference,
                                  detection_func_and_args=detection_func_and_args,
                                  agent_name=AgentLlmName.SOURCE_TRACER_INFER_STRUCTURED_INFER.value)
        if not result:
            raise ValueError("unstructured inference!")
        logger.debug("[SOURCE TRACER INFER] structured_infer result:\n %s",
                     json.dumps(result, ensure_ascii=False, indent=4))
        return result

    def mark_conclusion_in_report(self, infer_messages, conclusion_infos):
        """Mark each inferred conclusion in the report as an inference link.

        The span ``[start_pos, end_pos)`` of each conclusion info (offsets into
        the current ``self.response``) is replaced by
        ``[conclusion](#inference:<id>)``. ``infer_messages`` and
        ``conclusion_infos`` are paired positionally. Spans with invalid or
        overlapping positions are skipped.

        On success ``self.response`` is updated and returned. On any error,
        ``self.response`` is left untouched and the original report is returned.
        """
        logger.info("[SOURCE TRACER INFER] mark conclusion in report starting...")
        origin_response = self.response
        label_template = "[{conclusion}](#inference:{infer_id})"
        try:
            spans = []
            for conclusion_info, infer_message in zip(conclusion_infos, infer_messages):
                labeled_conclusion = label_template.format(conclusion=infer_message.get("conclusion", ""),
                                                           infer_id=infer_message.get("id", -1))
                spans.append((conclusion_info.get("start_pos"), conclusion_info.get("end_pos"), labeled_conclusion))
            labeled_response, skipped = self._apply_conclusion_labels(origin_response, spans)
            if skipped:
                logger.warning("[SOURCE TRACER INFER] skip %d conclusion(s) with invalid or overlapping positions.",
                               skipped)
            self.response = labeled_response
            logger.debug("[SOURCE TRACER INFER] report with marked inference conclusions:\n %s", self.response)
            logger.info("[SOURCE TRACER INFER] mark conclusion in report end.")
            return self.response
        except Exception as e:
            if LogManager.is_sensitive():
                error_msg = "mark conclusion in report error: ***"
            else:
                error_msg = f"mark conclusion in report error: {repr(e)}"
            logger.warning(f"[SOURCE TRACER INFER] {error_msg}")
            # Roll back: never hand out a partially labeled report.
            self.response = origin_response
            return origin_response