"""
-------------------------------------------------------------------------
This file is part of the RAGSDK project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
RAGSDK is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import copy
from typing import Union, Iterator, List, Dict, Callable
from langchain_core.documents import Document
from loguru import logger
from langchain_core.retrievers import BaseRetriever
from mx_rag.llm.text2text import _check_sys_messages
from mx_rag.utils.common import validate_params, BOOL_TYPE_CHECK_TIP, TEXT_MAX_LEN, MB
from mx_rag.llm.llm_parameter import LLMParameterConfig
from mx_rag.chain import Chain
from mx_rag.llm import Text2TextLLM
from mx_rag.reranker.reranker import Reranker
from mx_rag.utils.common import MAX_PROMPT_LENGTH
DEFAULT_RAG_PROMPT = """根据上述已知信息,简洁和专业地回答用户的问题。如果无法从已知信息中得到答案,请根据自身经验做出回答"""
TEXT_RAG_TEMPLATE = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer as concise and accurate as possible.
Do NOT repeat the question or output any other words.
Context: {context}
Question: {question}
Answer:
"""
def _user_content_builder(query: str, docs: List[Document], prompt: str) -> str:
"""
默认的用户输入拼接逻辑。
参数说明:
----------
query : str
用户原始提问内容。
例如:“请根据以下材料总结关键要点。”
docs : List[Document]
从检索器(retriever)返回的文档对象列表。
每个 Document 通常包含:
- page_content:文档内容文本;
- metadata:元信息(如来源、标题、分数等)。
prompt : str
系统预设提示词(例如 RAG 模板、任务指令)。
会追加到最后一个文档后面,用于指导模型生成更精确的回答。
返回:
-----
str : 拼接后的完整 prompt 文本,作为大模型输入内容。
"""
final_prompt = ""
document_separator: str = "\n\n"
if len(docs) != 0:
if prompt != "":
last_doc = docs[-1]
last_doc.page_content = (last_doc.page_content
+ f"{document_separator}{prompt}")
docs[-1] = last_doc
final_prompt = document_separator.join(x.page_content for x in docs)
if final_prompt != "":
final_prompt += document_separator
final_prompt += query
return final_prompt
def _safe_call_builder(builder: Callable, *args, **kwargs) -> str:
"""
安全地调用用户自定义的 content builder。
"""
result = builder(*args, **kwargs)
if isinstance(result, str) and 0 < len(result) <= 4 * MB:
return result
else:
raise ValueError(f"callback function {builder.__name__} returned invalid result. "
f"Expected: str with length 0 < len <= 4MB. fallback to default builder.")
class SingleText2TextChain(Chain):
@validate_params(
llm=dict(validator=lambda x: isinstance(x, Text2TextLLM), message="param must be instance of Text2TextLLM"),
retriever=dict(validator=lambda x: isinstance(x, BaseRetriever),
message="param must be instance of BaseRetriever"),
reranker=dict(validator=lambda x: isinstance(x, Reranker) or x is None,
message="param must be None or instance of Reranker"),
prompt=dict(validator=lambda x: isinstance(x, str) and 1 <= len(x) <= MAX_PROMPT_LENGTH,
message=f"param must be str and length range [1, {MAX_PROMPT_LENGTH}]"),
sys_messages=dict(validator=lambda x: _check_sys_messages(x),
message="param must be None or List[dict], and length of dict <= 16, "
"k-v of dict: len(k) <=16 and len(v) <= 4 * MB"),
source=dict(validator=lambda x: isinstance(x, bool), message=BOOL_TYPE_CHECK_TIP),
user_content_builder=dict(validator=lambda x: isinstance(x, Callable),
message="param must be Callable"),
)
def __init__(self, llm: Text2TextLLM,
retriever: BaseRetriever,
reranker: Reranker = None,
prompt: str = DEFAULT_RAG_PROMPT,
sys_messages: List[dict] = None,
source: bool = True,
user_content_builder: Callable = _user_content_builder):
super().__init__()
self._retriever = retriever
self._reranker = reranker
self._llm = llm
self._prompt = prompt
self._sys_messages = sys_messages
self._source = source
self._role: str = "user"
self._user_content_builder = user_content_builder
@validate_params(
text=dict(validator=lambda x: isinstance(x, str) and 0 < len(x) <= TEXT_MAX_LEN,
message=f"param must be a str and its length meets (0, {TEXT_MAX_LEN}]"),
llm_config=dict(validator=lambda x: isinstance(x, LLMParameterConfig),
message="llm_config must be instance of LLMParameterConfig")
)
def query(self, text: str, llm_config: LLMParameterConfig = LLMParameterConfig(temperature=0.5, top_p=0.95),
*args, **kwargs) \
-> Union[Dict, Iterator[Dict]]:
return self._query(text, llm_config)
def _query(self,
question: str,
llm_config: LLMParameterConfig) -> Union[Dict, Iterator[Dict]]:
q_docs = self._retriever.invoke(question)
if self._reranker is not None and len(q_docs) > 0:
scores = self._reranker.rerank(question, [doc.page_content for doc in q_docs])
q_docs = self._reranker.rerank_top_k(q_docs, scores)
q_with_prompt = _safe_call_builder(self._user_content_builder,
question,
copy.deepcopy(q_docs),
self._prompt)
if not llm_config.stream:
return self._do_query(q_with_prompt, llm_config, question=question, q_docs=q_docs)
return self._do_stream_query(q_with_prompt, llm_config, question=question, q_docs=q_docs)
def _do_query(self, q_with_prompt: str, llm_config: LLMParameterConfig, question: str, q_docs: List[Document]) \
-> Dict:
logger.info("invoke normal query")
resp = {"query": question, "result": ""}
if self._source:
resp['source_documents'] = [{'metadata': x.metadata, 'page_content': x.page_content} for x in q_docs]
llm_response = self._llm.chat(query=q_with_prompt, sys_messages=self._sys_messages,
role=self._role, llm_config=llm_config)
resp['result'] = llm_response
return resp
def _do_stream_query(self, q_with_prompt: str, llm_config: LLMParameterConfig, question: str,
q_docs: List[Document] = None) -> Iterator[Dict]:
logger.info("invoke stream query")
resp = {"query": question, "result": ""}
if self._source and q_docs:
resp['source_documents'] = [{'metadata': x.metadata, 'page_content': x.page_content} for x in q_docs]
for response in self._llm.chat_streamly(query=q_with_prompt, sys_messages=self._sys_messages,
role=self._role, llm_config=llm_config):
resp['result'] = response
yield resp
class GraphRagText2TextChain(SingleText2TextChain):
def _query(self,
question: str,
llm_config: LLMParameterConfig) -> Union[Dict, Iterator[Dict]]:
contexts = self._retriever.invoke(question)
if self._reranker is not None and len(contexts) > 0:
scores = self._reranker.rerank(question, contexts)
contexts = self._reranker.rerank_top_k(contexts, scores)
input_context = '\n'.join(contexts) if contexts else ""
prompt = TEXT_RAG_TEMPLATE.format(context=input_context, question=question)
if self._llm.llm_config.stream:
return self._do_stream_query(prompt, llm_config, question, [])
return self._do_query(prompt, llm_config, question, [])
def _do_query(self, q_with_prompt: str, llm_config: LLMParameterConfig, question: str, q_docs: List[Document]) \
-> Dict:
logger.info("invoke normal query")
resp = {"query": question, "result": ""}
llm_response = self._llm.chat(query=q_with_prompt, role=self._role, llm_config=llm_config)
resp['result'] = llm_response
return resp
def _do_stream_query(self, q_with_prompt: str, llm_config: LLMParameterConfig, question: str,
q_docs: List[Document] = None) -> Iterator[Dict]:
logger.info("invoke stream query")
resp = {"query": question, "result": ""}
for response in self._llm.chat_streamly(query=q_with_prompt, role=self._role, llm_config=llm_config):
resp['result'] = response
yield resp