"""
-------------------------------------------------------------------------
This file is part of the RAGSDK project.
Copyright (c) 2026 Huawei Technologies Co.,Ltd.
RAGSDK is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import threading
import re
from typing import List, Dict
import requests
from loguru import logger
from mx_rag.corag.prompts import get_evaluate_answer_prompt
from mx_rag.llm.text2text import Text2TextLLM
def normalize_retrieve_api_results(results):
    """Coerce a retrieve API response into a list.

    Args:
        results: Raw response object returned by the retrieve API.

    Returns:
        - ``results`` itself when it is already a list.
        - For a dict: the first list-valued entry found under the keys
          'chunks', 'data', 'results', 'docs', 'passages' (checked in that
          order); otherwise the dict wrapped in a one-element list, or an
          empty list when the dict is empty.
        - An empty list for any other input.
    """
    if isinstance(results, list):
        return results
    if not isinstance(results, dict):
        return []

    # Known envelope keys, probed in priority order.
    candidate_keys = ('chunks', 'data', 'results', 'docs', 'passages')
    for name in candidate_keys:
        payload = results.get(name)
        if isinstance(payload, list):
            return payload

    # No list payload found: treat a non-empty dict as a single record.
    if results:
        return [results]
    return []
def truncate_long_text_by_char(text: str, max_token_length: int) -> str:
    """Truncate long text by character count, keeping its head and tail.

    The character budget depends on the share of characters in the range
    U+4E00..U+9FFF: when more than half of the characters fall in that range
    the budget is ``max_token_length`` characters, otherwise it is
    ``max_token_length * 2`` characters. Text longer than the budget is cut
    in the middle so that both the beginning and the end are preserved.

    Args:
        text: Text to truncate. Falsy input (empty string or None) is
            returned unchanged.
        max_token_length: Base length budget; doubled for text that is not
            predominantly Chinese.

    Returns:
        The original text when it fits the budget, otherwise the first half
        of the budget taken from the start plus the remainder taken from the
        end. An empty string when the budget is not positive.
    """
    if not text:
        return text

    chinese_count = sum(1 for ch in text if '\u4e00' <= ch <= '\u9fff')
    chinese_ratio = chinese_count / len(text)
    max_char_len = max_token_length if chinese_ratio > 0.5 else max_token_length * 2

    # A non-positive budget must yield nothing. Without this guard the tail
    # slice becomes text[-0:] (the whole string) or a negative-length slice
    # that returns overlapping fragments longer than the input budget.
    if max_char_len <= 0:
        return ""

    if len(text) <= max_char_len:
        return text

    head_len = max_char_len // 2
    tail_len = max_char_len - head_len  # always >= 1 here, so -tail_len is safe
    return text[:head_len] + text[-tail_len:]
def normalize_text(text: str) -> str:
    """Normalize text for answer comparison.

    Lowercases the input, replaces every character that is neither a word
    character nor whitespace with a space, removes the standalone articles
    "a", "an" and "the", then collapses whitespace runs and trims the ends.

    Args:
        text: Text to normalize; may be None.

    Returns:
        The normalized string. An empty string when ``text`` is None or
        longer than 10000 characters.
    """
    # None and over-long inputs are both mapped to the empty string.
    if text is None or len(text) > 10000:
        return ""

    lowered = text.lower()
    without_punct = re.sub(r'[^\w\s]', ' ', lowered)
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    collapsed = re.sub(r'\s+', ' ', without_articles)
    return collapsed.strip()
def check_answer(prediction: str, ground_truths: List[str]) -> bool:
    """Check whether the prediction matches any of the reference answers.

    Both sides are passed through ``normalize_text``. A reference matches
    when its normalized form is non-empty and occurs as a substring of the
    normalized prediction (an exact match is the special case where the two
    are equal).

    Args:
        prediction: Predicted answer.
        ground_truths: List of reference answers.

    Returns:
        True if at least one reference matches, otherwise False. Always
        False for an empty prediction.
    """
    if not prediction:
        return False

    candidate = normalize_text(prediction)
    # Lazy generator: references are normalized one at a time and the scan
    # stops at the first match.
    normalized_refs = (normalize_text(reference) for reference in ground_truths)
    return any(ref and ref in candidate for ref in normalized_refs)
def check_answer_with_llm_judge(
    prediction: str,
    ground_truths: List[str],
    query: str,
    judge_llm: Text2TextLLM,
) -> bool:
    """Use an LLM as a judge to decide whether the prediction is correct.

    The non-empty reference answers are quoted, joined with " or " and sent
    to the judge together with the query and the prediction. The reply is
    upper-cased and accepted when "YES" occurs in it and no "NO" occurs
    before it.

    Args:
        prediction: Predicted answer.
        ground_truths: List of reference answers.
        query: Original query.
        judge_llm: LLM instance used for judging; its ``chat(query=...)``
            method is called with the evaluation prompt.

    Returns:
        True if the judge accepts the answer, otherwise False. Returns False
        without calling the judge when the prediction is empty or no
        non-empty reference answer exists. If the judge call or the parsing
        of its reply raises, the result of ``check_answer`` is returned.
    """
    if not prediction:
        return False

    quoted_refs = [f'"{gt}"' for gt in ground_truths if gt]
    gt_text = " or ".join(quoted_refs)
    if not gt_text:
        return False

    prompt = get_evaluate_answer_prompt(query, prediction, gt_text)
    try:
        verdict = judge_llm.chat(query=prompt).strip().upper()
        if verdict.startswith("YES"):
            return True

        # NOTE: plain substring search, so "NO" also matches inside words
        # such as "KNOW" or "NOT".
        yes_pos = verdict.find("YES")
        if yes_pos == -1:
            return False
        no_pos = verdict.find("NO")
        return no_pos == -1 or yes_pos < no_pos
    except Exception as e:
        # Best-effort evaluation: any judge failure degrades to string matching.
        logger.warning(f"LLM judge error: {e}, falling back to string matching")
        return check_answer(prediction, ground_truths)
class ThreadSafeCounter:
    """Counter whose updates are guarded by a ``threading.Lock``."""

    def __init__(self, initial_value=0):
        """Create the counter.

        Args:
            initial_value: Value the counter starts from (default: 0).
        """
        self._thread_lock = threading.Lock()
        self._count = initial_value

    def increment(self, step=1):
        """Add ``step`` to the counter while holding the lock.

        Args:
            step: Amount to add (default: 1).

        Returns:
            The counter value right after the addition.
        """
        with self._thread_lock:
            self._count += step
            # Capture the value under the lock so the caller sees its own update.
            updated = self._count
        return updated

    def reset(self):
        """Set the counter back to zero while holding the lock.

        Returns:
            The counter value right after the reset (zero).
        """
        with self._thread_lock:
            self._count = 0
            current = self._count
        return current