"""
-------------------------------------------------------------------------
This file is part of the RAGSDK project.
Copyright (c) 2026 Huawei Technologies Co.,Ltd.
RAGSDK is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import json
import logging
import re
import threading
from dataclasses import dataclass, field
from typing import Optional, List, Tuple

from mx_rag.llm.text2text import Text2TextLLM
from mx_rag.corag.utils import normalize_retrieve_api_results, truncate_long_text_by_char
from mx_rag.utils.url import RequestUtils
from mx_rag.utils import ClientParam
from mx_rag.corag.prompts import (
    get_generate_subquery_prompt,
    get_generate_intermediate_answer_prompt,
    get_generate_final_answer_prompt
)
def _process_subquery(input_subquery: str) -> Tuple[str, Optional[str]]:
    """Clean an LLM-generated sub-query and extract its reasoning text.

    Processing steps, in order:
      1. Inputs longer than 10000 characters skip all regex work and are
         returned whitespace-trimmed with no reasoning.
      2. The first ``<reasoning>...</reasoning>`` span (whose content holds
         no ``<``) is captured as the reasoning text.
      3. All such ``<reasoning>`` and ``<think>`` spans are removed.
      4. Surrounding whitespace and one pair of enclosing double quotes are
         stripped.
      5. A leading ``Step N: `` or ``Intermediate query N: `` label is
         dropped.

    Args:
        input_subquery: Raw text returned by the LLM.

    Returns:
        A ``(subquery, reasoning)`` tuple; ``reasoning`` is ``None`` when no
        ``<reasoning>`` span was found or the input was oversized.
    """
    if len(input_subquery) > 10000:
        return input_subquery.strip(), None

    reasoning_re = re.compile(r'<reasoning>([^<]*)</reasoning>')

    found = reasoning_re.search(input_subquery)
    reasoning_content = None
    if found:
        reasoning_content = found.group(1).strip()

    cleaned = reasoning_re.sub('', input_subquery)
    cleaned = re.sub(r'<think>([^<]*)</think>', '', cleaned).strip()

    # Drop exactly one pair of enclosing double quotes, if present.
    if cleaned.startswith('"') and cleaned.endswith('"'):
        cleaned = cleaned[1:-1]

    cleaned = re.sub(r'^(Step|Intermediate query) \d+: ', '', cleaned)
    return cleaned, reasoning_content
@dataclass
class ReasoningPath:
    """
    A CoRAG reasoning path: the original query plus, for every reasoning
    step, the sub-query, its sub-answer, the retrieved document IDs, the
    reasoning text and the retrieved documents.

    The per-step lists are filled in lockstep by ``CoRagAgent.sample_path``,
    so index ``i`` of each list refers to the same step.
    """
    # The query the whole path is trying to answer.
    original_query: str
    # Sub-queries, in the order they were accepted.
    subqueries: List[str] = field(default_factory=list)
    # Intermediate answer produced for each sub-query.
    subanswers: List[str] = field(default_factory=list)
    # For each sub-query, the IDs of the documents retrieved for it.
    document_ids: List[List[str]] = field(default_factory=list)
    # Reasoning text extracted for each step; an entry is None when the
    # sub-query output carried no <reasoning> span.
    reasoning_steps: List[str] = field(default_factory=list)
    # For each sub-query, the retrieved document contents.
    documents: List[List[str]] = field(default_factory=list)
logger = logging.getLogger(__name__)


class CoRagAgent:
    """Chain-of-Retrieval (CoRAG) agent.

    The agent asks an LLM for follow-up sub-queries, answers each one from
    documents fetched through an HTTP retrieval endpoint, and finally asks an
    LLM to compose an answer from the collected reasoning path.
    """

    # Minimum sampling temperature used for the retry after the LLM repeats
    # a sub-query that is already on the path.
    _RETRY_TEMPERATURE = 0.7

    def __init__(
            self, base_llm: Text2TextLLM,
            retrieve_api_url: str,
            final_llm: Optional[Text2TextLLM] = None,
            sub_answer_llm: Optional[Text2TextLLM] = None,
            retrieve_top_k: int = 5,
            client_param: ClientParam = ClientParam()
    ):
        """Create an agent.

        Args:
            base_llm: LLM used to generate sub-queries; also the fallback for
                sub-answers and the final answer.
            retrieve_api_url: URL that receives a JSON POST with ``query``,
                ``top_k`` and ``stream``. When falsy, retrieval is skipped
                and sub-answers are generated without documents.
            final_llm: Optional LLM for the final answer.
            sub_answer_llm: Optional LLM for intermediate answers.
            retrieve_top_k: Value sent as ``top_k`` in the retrieval request.
            client_param: Passed to ``RequestUtils``. NOTE: the default
                instance is created once, when the class is defined, and is
                shared by every agent that does not pass its own.
        """
        self.base_llm = base_llm
        self.final_llm = final_llm
        self.sub_answer_llm = sub_answer_llm
        self.retrieve_api_url = retrieve_api_url
        self.retrieve_top_k = retrieve_top_k
        self.client_param = client_param
        self._client = RequestUtils(client_param=self.client_param)
        self.lock = threading.Lock()

    def sample_path(
            self, query: str, task_desc: str,
            max_path_length: int = 3
    ) -> ReasoningPath:
        """Build a reasoning path by iteratively generating sub-queries.

        Each iteration asks ``base_llm`` for a new sub-query. A sub-query
        already on the path is discarded and retried with a temperature of at
        least ``_RETRY_TEMPERATURE``; otherwise its sub-answer and documents
        are collected. The loop stops after ``max_path_length`` accepted
        sub-queries or ``4 * max_path_length`` LLM calls, whichever comes
        first.

        Args:
            query: The original user query.
            task_desc: Task description forwarded to the sub-query prompt.
            max_path_length: Maximum number of sub-queries on the path.

        Returns:
            The collected ``ReasoningPath``.
        """
        interaction_queries: List[str] = []
        interaction_answers: List[str] = []
        retrieved_doc_ids: List[List[str]] = []
        retrieved_docs: List[List[str]] = []
        thought_process: List[Optional[str]] = []

        original_temp = self.base_llm.llm_config.temperature
        llm_call_count = 0
        max_allowed_calls = 4 * max_path_length

        try:
            while len(interaction_queries) < max_path_length and llm_call_count < max_allowed_calls:
                llm_call_count += 1
                followup_prompt = get_generate_subquery_prompt(
                    query=query,
                    past_subqueries=interaction_queries,
                    past_subanswers=interaction_answers,
                    task_desc=task_desc,
                )
                followup_prompt = truncate_long_text_by_char(
                    followup_prompt, max_token_length=self.base_llm.llm_config.max_tokens
                )
                generated_subquery = self.base_llm.chat(query=followup_prompt)
                processed_query, step_reasoning = _process_subquery(generated_subquery)

                with self.lock:
                    self.base_llm.llm_config.temperature = original_temp
                    if processed_query in interaction_queries:
                        # Duplicate: retry with more randomness on the next call.
                        self.base_llm.llm_config.temperature = max(original_temp, self._RETRY_TEMPERATURE)
                        continue

                query_answer, current_doc_ids, current_docs = self._get_subanswer_and_doc_ids(
                    subquery=processed_query
                )
                interaction_queries.append(processed_query)
                interaction_answers.append(query_answer)
                retrieved_doc_ids.append(current_doc_ids)
                retrieved_docs.append(current_docs)
                thought_process.append(step_reasoning)
        finally:
            # The loop may end (call budget exhausted, or an exception) right
            # after the temperature was raised; never leave the shared LLM
            # config modified.
            with self.lock:
                self.base_llm.llm_config.temperature = original_temp

        return ReasoningPath(
            original_query=query,
            subqueries=interaction_queries,
            subanswers=interaction_answers,
            document_ids=retrieved_doc_ids,
            documents=retrieved_docs,
            reasoning_steps=thought_process,
        )

    def generate_final_answer(
            self, rag_path: ReasoningPath, task_description: str
    ) -> str:
        """Generate the final answer from a completed reasoning path.

        Builds the final-answer prompt from the path's original query,
        sub-queries, sub-answers and documents plus the task description,
        truncates it to the answering LLM's ``max_tokens`` setting and sends
        it to ``final_llm`` (or ``base_llm`` when no final LLM was given).

        Args:
            rag_path: Reasoning path holding the query and its history.
            task_description: Task instructions included in the prompt.

        Returns:
            The answer text returned by the LLM.
        """
        final_prompt = get_generate_final_answer_prompt(
            original_query=rag_path.original_query,
            interaction_queries=rag_path.subqueries or [],
            interaction_answers=rag_path.subanswers or [],
            task_instructions=task_description,
            reference_docs=rag_path.documents or [],
        )
        answer_llm = self.final_llm if self.final_llm is not None else self.base_llm
        final_prompt = truncate_long_text_by_char(final_prompt, max_token_length=answer_llm.llm_config.max_tokens)
        return answer_llm.chat(query=final_prompt)

    def _retrieve_documents(self, subquery: str) -> Tuple[List[str], List[str]]:
        """Fetch documents for ``subquery`` from the retrieval endpoint.

        Retrieval is best-effort: a missing URL, a failed request or an
        unparsable response all yield empty lists instead of an exception.

        Returns:
            A ``(doc_ids, documents)`` tuple of equal-length lists.
        """
        doc_ids: List[str] = []
        documents: List[str] = []
        if not self.retrieve_api_url:
            return doc_ids, documents

        request_body = {
            "query": subquery,
            "top_k": self.retrieve_top_k,
            "stream": False,
        }
        response = self._client.post(
            url=self.retrieve_api_url,
            body=json.dumps(request_body),
            headers={"Content-Type": "application/json"},
        )
        if not response.success:
            logger.warning("retrieve request failed, continue without retrieved documents")
            return doc_ids, documents

        retriever_results = []
        try:
            data = json.loads(response.data)
            retriever_results = normalize_retrieve_api_results(data)
        except json.JSONDecodeError as e:
            logger.error("response content cannot convert to json format: %s", e)
            retriever_results = []
        except Exception as e:
            # Deliberately broad: a malformed payload must not abort the path.
            logger.error("unexpected error while parsing JSON response. Error: %s", e)
            retriever_results = []

        for res in retriever_results:
            if isinstance(res, str):
                documents.append(res)
                doc_ids.append('graph_chunk')
            elif isinstance(res, dict):
                # First non-empty text field wins; fall back to the dict's repr.
                content = res.get('contents') or res.get('content') or res.get('text') or str(res)
                documents.append(content)
                doc_ids.append(str(res.get('id') or res.get('doc_id') or 'graph_chunk'))
        return doc_ids, documents

    def _get_subanswer_and_doc_ids(
            self, subquery: str
    ) -> Tuple[str, List[str], List[str]]:
        """Retrieve documents for a sub-query and generate its sub-answer.

        The intermediate-answer prompt is built from the sub-query and the
        retrieved documents, truncated to the answering LLM's ``max_tokens``
        setting and sent to ``sub_answer_llm`` (or ``base_llm`` when no
        sub-answer LLM was given).

        Returns:
            A ``(subanswer, doc_ids, documents)`` tuple.
        """
        doc_ids, documents = self._retrieve_documents(subquery)
        prompt = get_generate_intermediate_answer_prompt(
            subquery=subquery,
            documents=documents,
        )
        client = self.sub_answer_llm if self.sub_answer_llm is not None else self.base_llm
        prompt = truncate_long_text_by_char(prompt, max_token_length=client.llm_config.max_tokens)
        subanswer: str = client.chat(query=prompt)
        return subanswer, doc_ids, documents