import json
import logging
import difflib
import re
from typing import List, Dict, Tuple
import copy
from openjiuwen_deepsearch.utils.log_utils.log_manager import LogManager
from openjiuwen_deepsearch.algorithm.source_tracer_infer.infer_call_model import GraphInfo, NumberNodeParam
logger = logging.getLogger(__name__)
class NumberNode:
    """Assign integer ids to the nodes of a structured inference.

    Each item of a structured inference is a ``(head_list, relation, tail)``
    triple. Integer heads are indices into the search records (citations or
    programmer nodes); every other head, and every tail, is a free-text
    statement. All numbering state lives in the ``NumberNodeParam`` object
    that is threaded through the ``number_*`` methods and mutated in place.
    """

    # Splits text into word tokens and single punctuation characters.
    _TOKEN_PATTERN = re.compile(r'\w+|[^\w\s]')

    def __init__(self):
        pass

    @staticmethod
    def _warn_failure(method_name: str, error: Exception) -> None:
        """Log a numbering failure, masking the detail when logs are sensitive."""
        detail = "***" if LogManager.is_sensitive() else error
        logger.warning("[SOURCE TRACER INFER] %s error: %s", method_name, detail)

    @staticmethod
    def number_citation_node(node, number_node_param: "NumberNodeParam", title, url) -> "NumberNodeParam":
        """Number a citation node, reusing the id of an entry with the same URL.

        Args:
            node: Raw head value; its stripped string form is added to ``node_set``.
            number_node_param: Numbering state, mutated in place.
            title: Title of the cited record, used to build the node label.
            url: URL of the cited record, used as the de-duplication key.

        Returns:
            The same ``number_node_param`` with the node id appended to
            ``head_id_list``.

        Raises:
            Exception: Any error is logged and re-raised unchanged.
        """
        try:
            number_node_param.node_set.add(str(node).strip())
            # Only entries that actually carry a "url" key are candidates.
            # Conclusion and programmer entries have no such key and must not
            # be mistaken for a citation whose URL is the empty string.
            node_id = next(
                (idx for idx, info in number_node_param.node_map.items()
                 if "url" in info and info["url"] == url),
                None,
            )
            if node_id is None:
                node_id = number_node_param.node_index
                number_node_param.node_map[node_id] = {"label": f"《{title}》", "url": url}
                number_node_param.citation_ids.add(node_id)
                number_node_param.node_index += 1
            number_node_param.head_id_list.append(node_id)
        except Exception as e:
            NumberNode._warn_failure("number_citation_node", e)
            raise
        return number_node_param

    @staticmethod
    def number_programmer_node(node, number_node_param: "NumberNodeParam"):
        """Number a programmer node's statement, treated as a special citation.

        An existing ``node_map`` entry whose label equals the stripped string
        form of ``node`` is reused; otherwise a new entry flagged with
        ``is_program_info`` is created and recorded in ``citation_ids``.

        Args:
            node: Raw head value; its stripped string form is the label.
            number_node_param: Numbering state, mutated in place.

        Returns:
            The same ``number_node_param`` with the node id appended to
            ``head_id_list``.

        Raises:
            Exception: Any error is logged and re-raised unchanged.
        """
        try:
            label = str(node).strip()
            number_node_param.node_set.add(label)
            node_id = next(
                (idx for idx, info in number_node_param.node_map.items()
                 if info.get("label", "") == label),
                None,
            )
            if node_id is None:
                node_id = number_node_param.node_index
                number_node_param.node_map[node_id] = {"label": label, "is_program_info": True}
                number_node_param.citation_ids.add(node_id)
                number_node_param.node_index += 1
            number_node_param.head_id_list.append(node_id)
        except Exception as e:
            NumberNode._warn_failure("number_programmer_node", e)
            raise
        return number_node_param

    @staticmethod
    def replace_index_with_url(index: int, search_records: List[Dict]) -> Tuple[str, str]:
        """Resolve a citation index to the ``(title, url)`` of its search record.

        Missing ``title`` / ``url`` keys yield empty strings.

        Raises:
            ValueError: If ``index`` is outside ``search_records``.
        """
        if not 0 <= index < len(search_records):
            raise ValueError("[SOURCE TRACER INFER] The index of search_records is out of range.")
        record = search_records[index]
        return record.get("title", ""), record.get("url", "")

    @staticmethod
    def _token_set_ratio(str1: str, str2: str) -> float:
        """Token-set similarity in [0, 100] (stand-in for rapidfuzz ``token_set_ratio``).

        Both strings are lower-cased and split into a set of tokens. The score
        is the mean of (a) the Jaccard index of the two sets and (b) the
        average ``SequenceMatcher`` ratio between the sorted shared tokens and
        each string's own sorted tokens. Returns 0.0 when either string
        yields no tokens.
        """
        tokens1 = set(NumberNode._TOKEN_PATTERN.findall(str1.lower()))
        tokens2 = set(NumberNode._TOKEN_PATTERN.findall(str2.lower()))
        if not tokens1 or not tokens2:
            return 0.0

        shared = tokens1 & tokens2
        jaccard = len(shared) / len(tokens1 | tokens2) * 100

        # Sorting makes the comparison independent of token order.
        shared_text = ' '.join(sorted(shared))
        ratio1 = difflib.SequenceMatcher(None, shared_text, ' '.join(sorted(tokens1))).ratio() * 100
        ratio2 = difflib.SequenceMatcher(None, shared_text, ' '.join(sorted(tokens2))).ratio() * 100
        return (jaccard + (ratio1 + ratio2) / 2) / 2

    @staticmethod
    def _extract_best_match(query: str, choices: list, limit: int = 1) -> list:
        """Return the ``limit`` best-scoring choices (stand-in for rapidfuzz ``process.extract``).

        Returns:
            ``[(matched_string, score), ...]`` sorted by descending score;
            an empty list when ``choices`` is empty.
        """
        if not choices:
            return []
        scored = [(choice, NumberNode._wr_ratio(query, choice)) for choice in choices]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:limit]

    @staticmethod
    def _partial_ratio(str1: str, str2: str) -> float:
        """Partial-match similarity in [0, 100] (stand-in for rapidfuzz ``partial_ratio``).

        Slides the shorter string over every same-length window of the longer
        one and returns the best ``SequenceMatcher`` ratio. Returns 0.0 when
        either string is empty.
        """
        if not str1 or not str2:
            return 0.0
        shorter, longer = (str1, str2) if len(str1) <= len(str2) else (str2, str1)
        width = len(shorter)
        best_ratio = 0.0
        for start in range(len(longer) - width + 1):
            ratio = difflib.SequenceMatcher(None, shorter, longer[start:start + width]).ratio() * 100
            if ratio > best_ratio:
                best_ratio = ratio
        return best_ratio

    def number_conclusion_node(self, node,
                               number_node_param: "NumberNodeParam",
                               conclusion, is_tail=False):
        """Number a free-text (conclusion) node.

        The stripped text is fuzzy-matched against ``node_set``. If the best
        match scores above 90 and a ``node_map`` entry carries that label, its
        id is reused; otherwise a new entry is created. A node that resembles
        ``conclusion`` (partial or token-set ratio above 60) is also recorded
        in ``conclusion_ids``.

        Args:
            node: Raw node value; its stripped string form is the label.
            number_node_param: Numbering state, mutated in place.
            conclusion: Final conclusion text to compare the node against.
            is_tail: When true the id is stored as ``tail_id``; otherwise it
                is appended to ``head_id_list``.

        Returns:
            The same ``number_node_param``.

        Raises:
            Exception: Any error is logged and re-raised unchanged.
        """
        try:
            node = str(node).strip()
            node_id = None
            best = self._extract_best_match(node, list(number_node_param.node_set), limit=1)
            if best and best[0][1] > 90:
                matched_label = best[0][0]
                # No early exit: the last entry carrying the label wins.
                for idx, info in number_node_param.node_map.items():
                    if info.get("label", "") == matched_label:
                        node_id = idx
            if node_id is None:
                # Either nothing matched, or the match was a node_set entry
                # that has no node_map label (e.g. a stringified citation
                # index). Register a fresh node rather than emit an invalid id.
                node_id = number_node_param.node_index
                number_node_param.node_set.add(node)
                number_node_param.node_map[node_id] = {"label": node}
                number_node_param.node_index += 1

            if self._partial_ratio(node, conclusion) > 60 or self._token_set_ratio(node, conclusion) > 60:
                number_node_param.conclusion_ids.add(node_id)
            if is_tail:
                number_node_param.tail_id = node_id
            else:
                number_node_param.head_id_list.append(node_id)
        except Exception as e:
            NumberNode._warn_failure("number_conclusion_node", e)
            raise
        return number_node_param

    def number_node(self, structured_inference: List[List], conclusion: str,
                    search_records: List[Dict]) -> "GraphInfo":
        """Number every node of a structured inference.

        Args:
            structured_inference: Items of the form ``(head_list, relation, tail)``;
                a head that is not a list is treated as a single-element list.
            conclusion: Final conclusion text.
            search_records: Search records that integer heads index into.

        Returns:
            GraphInfo holding deep copies of:
                structured_inference: the renumbered structured inference
                node_map: mapping from node id to node info
                citation_ids: ids of citation nodes
                conclusion_ids: ids of final-conclusion nodes

        Raises:
            ValueError: Wraps any error raised while numbering.
        """
        logger.info("[SOURCE TRACER INFER] number_node starting...")
        number_node_param = NumberNodeParam()
        try:
            for head_list, relation, tail in structured_inference:
                number_node_param.head_id_list = []
                if not isinstance(head_list, list):
                    head_list = [head_list]
                for head in head_list:
                    if not isinstance(head, int):
                        number_node_param = self.number_conclusion_node(head, number_node_param, conclusion)
                        continue
                    title, url = self.replace_index_with_url(head, search_records)
                    if title == "ProgrammerNode":
                        number_node_param = self.number_programmer_node(head, number_node_param)
                    else:
                        number_node_param = self.number_citation_node(head, number_node_param,
                                                                      title, url)
                number_node_param = self.number_conclusion_node(tail, number_node_param,
                                                                conclusion, is_tail=True)
                number_node_param.update_structured_inference(relation)
        except Exception as e:
            raise ValueError(f"ERROR in NUMBER_NODE: {e}") from e

        # Serialising the whole result is wasted work unless DEBUG is enabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "[SOURCE TRACER INFER] number_node result\n %s",
                json.dumps(number_node_param.structured_inference, ensure_ascii=False, indent=4))
        return GraphInfo(structured_inference=copy.deepcopy(number_node_param.structured_inference),
                         node_map=copy.deepcopy(number_node_param.node_map),
                         citation_ids=copy.deepcopy(list(number_node_param.citation_ids)),
                         conclusion_ids=copy.deepcopy(list(number_node_param.conclusion_ids)))

    @staticmethod
    def _wr_ratio(str1: str, str2: str) -> float:
        """Whole-string similarity in [0, 100] (stand-in for rapidfuzz ``WRatio``).

        Computed as the ``difflib.SequenceMatcher`` ratio scaled by 100.
        """
        return difflib.SequenceMatcher(None, str1, str2).ratio() * 100