import re
import logging
from typing import List, Dict
import asyncio
from openjiuwen_deepsearch.algorithm.source_tracer_infer.infer_call_model import call_model, type_check
from openjiuwen_deepsearch.utils.log_utils.log_manager import LogManager
from openjiuwen_deepsearch.utils.constants_utils.node_constants import AgentLlmName
logger = logging.getLogger(__name__)
class ResearchInferPreprocess():
def __init__(self, context: Dict):
self.language = context.get("language", "zh-CN")
self.llm_model = context.get("llm_model_name", "")
self.response = context.get("source_tracer_response", "")
self.search_records = context.get("all_classified_contents", [])
self.search_record_with_index = {}
self.citation_pattern = r'\[\[([^\]]+)\]\]\([^)]+\)'
self.conclusion_infos = {}
async def run(self) -> List[Dict]:
"""预处理是为溯源推理提取报告里需要推理的结论以及对应的搜索记录"""
logger.info(f"[INFERENCE INFO EXTRACT] Extracting preprocessed information for infer of research.")
self.classify_search_record()
self.conclusion_infos = await self._extract_conclusion()
if not self.conclusion_infos:
logger.warning(f"[INFERENCE INFO EXTRACT] no conclusion extracted from report.")
return []
conclusion_with_records = self._match_conclusion_with_records()
if LogManager.is_sensitive():
logger.info(f"[INFERENCE INFO EXTRACT] 结论与搜索记录匹配结果 ***")
else:
logger.info(f"[INFERENCE INFO EXTRACT] 结论与搜索记录匹配结果 {conclusion_with_records}")
return conclusion_with_records
def classify_search_record(self):
"""章节索引映射章节搜索记录, 修改成员变量 search_records"""
search_record_with_index = {}
for i, section_search_record in enumerate(self.search_records):
search_record_with_index[i] = []
for record in section_search_record:
search_record_with_index[i].append({"title": record.get("title", ""),
"url": record.get("url", ""),
"content": record.get("original_content", "")
})
self.search_record_with_index = search_record_with_index
async def _extract_conclusion(self):
"""提取research模式的结论"""
conclusions_info = await self._find_sentences_with_positions()
if not conclusions_info:
logger.warning(f"[INFERENCE INFO EXTRACT]: no conclusion extracted from report.")
if LogManager.is_sensitive():
logger.warning(f"[INFERENCE INFO EXTRACT]: 结论信息 ***")
else:
logger.warning(f"[INFERENCE INFO EXTRACT]: 结论信息 {conclusions_info}")
return conclusions_info
async def _find_sentences_with_positions(self) -> Dict[str, Dict]:
"""
查找句子并记录精确位置信息
Returns:
字典结构:{
'句子内容': {
"sentence_section_index", 章节索引,
'start_pos': global_start, 结论在全文中的起始位置
'end_pos': global_end, 结论在全文中的结束位置
'content_without_citation', 章节内容(去除行内引用)
'found': True
}
}
"""
sections = self._split_markdown_with_detailed_positions()
results = {}
sections = sections[2:-1]
conclusions = await self._extract_conclusions_for_sections(sections)
for section_index, conclusion in enumerate(conclusions):
for sentence in conclusion:
sentence = sentence.strip()
if not sentence:
continue
results[sentence] = self._locate_sentence_in_sections(sentence, sections[section_index], section_index)
return results
def _split_markdown_with_detailed_positions(self) -> List[Dict]:
"""
分割Markdown并记录详细位置信息
"""
h1_pattern = r'(?m)^#\s+(.+)$'
h1_matches = list(re.finditer(h1_pattern, self.response))
sections = []
for i, match in enumerate(h1_matches):
content_start = match.end()
if i < len(h1_matches) - 1:
content_end = h1_matches[i + 1].start()
else:
content_end = len(self.response)
content = self.response[content_start:content_end]
sections.append({
'content': content,
'global_start': content_start,
'global_end': content_end,
})
return sections
async def _extract_conclusions_for_sections(self, sections):
"""从每个章节中提取1个推理结论"""
logger.info(f"[INFERENCE INFO EXTRACT] extract_conclusions starting...")
detection_func_and_args = {"detection_func": type_check, "args": list}
tasks = [call_model(self.llm_model, "infer_extract_conclusion_prompt", {"input": section.get("content", "")},
detection_func_and_args=detection_func_and_args,
agent_name=AgentLlmName.SOURCE_TRACER_INFER_EXTRACT_CONCLUSION.value)
for section in sections]
results = await asyncio.gather(*tasks)
if LogManager.is_sensitive():
logger.info(f"[INFERENCE INFO EXTRACT]: Extracted conclusions: *")
else:
logger.info(f"[INFERENCE INFO EXTRACT]: Extracted conclusions: {results}")
return results
def _locate_sentence_in_sections(self, sentence: str, section: Dict, index) -> Dict:
"""
在章节中定位句子的精确位置
Args:
sentence: 结论句子
section: 提取结论的章节
index: 章节索引
Returns:
dict={
sentence_section_index: 章节索引
start_pos: 结论在全文中的首位置
end_pos: 结论在全文中的末位置
content_without_citation: 清理掉引用信息的章节文本
found: 是否在section中找到sentence
}
"""
content = section['content']
sentence_start = content.find(sentence)
if sentence_start != -1:
global_start = section['global_start'] + sentence_start
global_end = global_start + len(sentence)
return {
"sentence_section_index": index,
'start_pos': global_start,
'end_pos': global_end,
'content_without_citation': self._clean_citation(content),
'found': True
}
return {
'sentence_section_index': -1,
'start_pos': -1,
'end_pos': -1,
'found': False
}
def _clean_citation(self, content: str) -> str:
"""
清理行内引用
"""
cleaned_content = re.sub(self.citation_pattern, '', content)
return cleaned_content
def _match_conclusion_with_records(self) -> List[Dict]:
"""按章节匹配结论与搜索记录"""
match_conclusion_with_records = []
sorted_sentences = sorted(
[(sentence, info) for sentence, info in self.conclusion_infos.items() if info['found']],
key=lambda x: x[1]['start_pos'],
reverse=True
)
all_search_record = []
for _, record in self.search_record_with_index.items():
all_search_record += record
for conclusion, info in sorted_sentences:
section_idx = info.get("sentence_section_index", None)
section_content = info.get("content_without_citation", "")
if section_idx is not None and section_idx in self.search_record_with_index:
match_conclusion_with_records.append({
"conclusion": [section_content, conclusion],
"search_records": self.search_record_with_index[section_idx],
"start_pos": info['start_pos'],
"end_pos": info['end_pos']
})
elif section_idx not in self.search_record_with_index:
match_conclusion_with_records.append({
"conclusion": [section_content, conclusion],
"search_records": all_search_record,
"start_pos": info['start_pos'],
"end_pos": info['end_pos']
})
return match_conclusion_with_records