import argparse
import threading
import traceback
from loguru import logger
from pymilvus import MilvusClient
from mx_rag.chain import SingleText2TextChain
from mx_rag.embedding.service import TEIEmbedding
from mx_rag.llm import Text2TextLLM
from mx_rag.retrievers import Retriever
from mx_rag.storage.document_store import MilvusDocstore
from mx_rag.storage.vectorstore import MilvusDB
from mx_rag.utils import ClientParam
class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter):
def _get_default_metavar_for_optional(self, action):
return action.type.__name__
def _get_default_metavar_for_positional(self, action):
return action.type.__name__
class ThreadWithResult(threading.Thread):
def __init__(self, group=None, target=None, name=None, args=None, kwargs=None, *, daemon=None):
if args is None:
args = ()
if kwargs is None:
kwargs = {}
self.result = None
def function():
self.result = target(*args, **kwargs)
super().__init__(group=group, target=function, name=name, daemon=daemon)
def rag_demo_query():
parse = argparse.ArgumentParser(formatter_class=CustomFormatter)
parse.add_argument(
"--embedding_url",
type=str,
default="http://127.0.0.1:8080/embed",
help="使用TEI服务化的embedding模型url地址",
)
parse.add_argument(
"--llm_url",
type=str,
default="http://127.0.0.1:1025/v1/chat/completions",
help="大模型url地址",
)
parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称")
parse.add_argument(
"--score_threshold",
type=float,
default=0.5,
help="相似性得分的阈值,大于阈值认为检索的信息与问题越相关,取值范围[0,1]",
)
parse.add_argument("--query", type=str, action="append", help="用户问题")
parse.add_argument("--num_threads", type=int, default=2, help="可以根据实际情况调整线程数量")
args = parse.parse_args().__dict__
embedding_url: str = args.pop("embedding_url")
llm_url: str = args.pop("llm_url")
model_name: str = args.pop("model_name")
score_threshold: float = args.pop("score_threshold")
query: list[str] = args.pop("query")
num_threads: int = args.pop("num_threads")
try:
emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True))
embedding_dim = len(emb.embed_documents(["test"])[0])
client = MilvusClient("./milvus.db")
vector_store = MilvusDB.create(client=client, x_dim=embedding_dim, collection_name="milvus_vector")
chunk_store = MilvusDocstore(client=client, collection_name="milvus_chunk")
text_retriever = Retriever(
vector_store=vector_store,
document_store=chunk_store,
embed_func=emb.embed_documents,
k=1,
score_threshold=score_threshold,
)
llm = Text2TextLLM(
base_url=llm_url,
model_name=model_name,
client_param=ClientParam(use_http=True, timeout=60),
)
def process_query(input_string: str) -> str:
text2text_chain = SingleText2TextChain(retriever=text_retriever, llm=llm)
res = text2text_chain.query(input_string)
logger.info(res)
return f"{res}"
results = []
batch_size = len(query) // num_threads
if len(query) % num_threads != 0:
batch_size += 1
batchs = [query[i : i + batch_size] for i in range(0, len(query), batch_size)]
threads = []
for batch in batchs:
def process_batch(batch):
batch_results = []
for s in batch:
batch_results.append(process_query(s))
return batch_results
thread = ThreadWithResult(target=process_batch, args=(batch,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
results.extend(thread.result)
return results
except Exception as e:
stack_trace = traceback.format_exc()
logger.error(stack_trace)
raise e
if __name__ == "__main__":
rag_demo_query()