import json
import logging
from typing import Any, Dict
from openjiuwen_deepsearch.algorithm.source_trace.add_source import (add_source_references,
generate_source_datas,
merge_source_datas)
from openjiuwen_deepsearch.algorithm.source_trace.content_analyzer import recognize_content_to_cite
from openjiuwen_deepsearch.algorithm.source_trace.source_matcher import match_sources
from openjiuwen_deepsearch.algorithm.source_trace.source_tracer_preprocessors import (generate_origin_report_data,
preprocess_report,
preprocess_search_record)
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.utils.log_utils.log_manager import LogManager
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class SourceTracer:
    """Source-tracing processor that attaches citation information to a report.

    The visible workflow is: build an instance from the algorithm inputs, await
    ``research_trace_source()`` to compute and cache the citation data, then
    call ``add_source_to_report()`` to obtain the report with references added.
    """

    def __init__(self, algorithm_inputs: dict):
        """Initialise the tracer with the data required for source tracing.

        Args:
            algorithm_inputs (dict): Input data for tracing. Keys read here:
                - report (str): Report text to trace. Defaults to "".
                - classified_content (list): Top-K article information used
                  while generating the sub-report. Defaults to [].
                - llm_model_name (str): Name of the LLM model to use.
                  Defaults to "".
        """
        # Forwarded to recognize_content_to_cite().
        self._similarity_threshold = 0.9
        # Forwarded to preprocess_search_record().
        self._search_record_max_content_len = 3000
        # Forwarded to match_sources().
        self._chunk_size = 40
        self._report = algorithm_inputs.get("report", "")
        self._classified_content = algorithm_inputs.get("classified_content", [])
        self._search_record = self.transform_search_record(self._classified_content)
        self._llm_model_name = algorithm_inputs.get("llm_model_name", "")
        # Filled by research_trace_source(); consumed by add_source_to_report().
        self._trace_source_datas = []

    @staticmethod
    def transform_search_record(classified_content: list) -> dict:
        """Convert top-K article information into the search-record format.

        Entries that are not dicts, or that lack any of the keys ``url``,
        ``title`` or ``original_content``, are skipped.

        Args:
            classified_content (list): Top-K article information used while
                generating the sub-report.

        Returns:
            dict: ``{'search_record': [{'url': ..., 'title': ..., 'content': ...}, ...]}``,
            or an empty dict when ``classified_content`` is empty or falsy.
        """
        if not classified_content:
            return {}

        required_keys = ("url", "title", "original_content")
        records = [
            {
                "url": item["url"],
                "title": item["title"],
                # The record format names this field "content".
                "content": item["original_content"],
            }
            for item in classified_content
            if isinstance(item, dict) and all(key in item for key in required_keys)
        ]
        return {"search_record": records}

    async def research_trace_source(self) -> None:
        """Trace the sources of the report in research mode.

        On success the generated citation data list is stored on the instance
        for a later ``add_source_to_report()`` call. The method returns early,
        after logging a warning and leaving the stored data untouched, when the
        report is empty, the search record cannot be preprocessed, no content
        to cite is recognised, or no trace result is obtained.

        Returns:
            None

        Raises:
            CustomValueException: Wraps any exception raised during tracing,
                using the SOURCE_TRACER_TRACE_SOURCE_ERROR status code.
        """
        try:
            if not self._report:
                logger.warning("[research_trace_source] report为空,不做溯源")
                return

            # Only the preprocessed report text is needed in this step.
            _, preprocessed_report = preprocess_report(self._report)
            preprocessed_search_record = preprocess_search_record(
                self._search_record, self._search_record_max_content_len)
            if not preprocessed_search_record:
                logger.warning("[research_trace_source] 预处理搜索记录失败,退出溯源")
                return

            # Serialise the payload only if the debug line will be emitted.
            if not LogManager.is_sensitive() and logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    "[research_trace_source] 预处理后的搜索记录: %s",
                    json.dumps(preprocessed_search_record, ensure_ascii=False, indent=2))

            content_recognition_result = await recognize_content_to_cite(
                preprocessed_report, self._similarity_threshold, self._llm_model_name)
            if not content_recognition_result:
                logger.warning("[research_trace_source] 未识别到需要增加引用的内容")
                return

            trace_results = await match_sources(
                content_recognition_result,
                preprocessed_search_record,
                self._chunk_size,
                self._llm_model_name,
            )
            if not trace_results:
                logger.warning("[research_trace_source] 未获取到有效溯源结果")
                return

            self._trace_source_datas = generate_source_datas(
                preprocessed_report, preprocessed_search_record, trace_results)
        except Exception as e:
            raise CustomValueException(
                StatusCode.SOURCE_TRACER_TRACE_SOURCE_ERROR.code,
                StatusCode.SOURCE_TRACER_TRACE_SOURCE_ERROR.errmsg.format(e=str(e))) from e

    def add_source_to_report(self) -> Dict[str, Any]:
        """Add the traced citation information to the report text.

        Returns:
            Dict[str, Any]: Result dictionary with the keys:
                - modified_report (str): Report text after references are added.
                - datas (list): Merged citation data returned by
                  ``add_source_references``.

        Raises:
            CustomValueException: Wraps any exception raised while adding
                sources, using the SOURCE_TRACER_ADD_SOURCE_ERROR status code.
        """
        try:
            datas = self._trace_source_datas
            removed_section, preprocessed_report = preprocess_report(self._report)
            origin_report_dict = generate_origin_report_data(
                preprocessed_report, self._classified_content)
            origin_report_datas = origin_report_dict.get("origin_report_data", [])
            need_add_source_report = origin_report_dict.get("modified_report", "")

            all_datas = merge_source_datas(
                need_add_source_report, datas, origin_report_datas)
            added_source_report, all_datas = add_source_references(
                need_add_source_report, all_datas)
            # Append the section that preprocess_report() split off.
            added_source_report = added_source_report + removed_section

            if not LogManager.is_sensitive():
                # Lazy %-args: the message is only built if INFO is enabled.
                logger.info('[add_source_to_report] 添加来源引用后的报告 %s', added_source_report)
                # Serialise the payload only if the debug line will be emitted.
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug('[add_source_to_report] 合并后的引用信息: %s',
                                 json.dumps(all_datas, ensure_ascii=False, indent=2))
            return dict(modified_report=added_source_report, datas=all_datas)
        except Exception as e:
            raise CustomValueException(
                StatusCode.SOURCE_TRACER_ADD_SOURCE_ERROR.code,
                StatusCode.SOURCE_TRACER_ADD_SOURCE_ERROR.errmsg.format(e=str(e))) from e