import argparse
import threading
import traceback
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from loguru import logger
from pymilvus import MilvusClient
from mx_rag.document import LoaderMng
from mx_rag.document.loader import DocxLoader, PdfLoader
from mx_rag.embedding.service import TEIEmbedding
from mx_rag.knowledge import KnowledgeDB
from mx_rag.knowledge.handler import upload_files
from mx_rag.knowledge.knowledge import KnowledgeStore
from mx_rag.storage.document_store import MilvusDocstore
from mx_rag.storage.vectorstore import MilvusDB
from mx_rag.utils import ClientParam
class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter):
    """Help formatter that shows defaults and uses the type name as the metavar.

    For an argument declared with ``type=int`` the help text reads
    ``--num_threads int`` instead of the stock ``--num_threads NUM_THREADS``.
    When an argument has no ``type`` (``action.type`` is ``None``) or the type
    callable has no ``__name__``, the stock argparse metavar is used instead of
    raising ``AttributeError`` while rendering help.
    """

    @staticmethod
    def _type_name(action):
        """Return the ``__name__`` of ``action.type``, or ``None`` if unavailable."""
        return getattr(action.type, "__name__", None)

    def _get_default_metavar_for_optional(self, action):
        """Return the metavar for an optional argument that did not set one."""
        name = self._type_name(action)
        if name is None:
            # No usable type name: defer to argparse's default behaviour.
            return super()._get_default_metavar_for_optional(action)
        return name

    def _get_default_metavar_for_positional(self, action):
        """Return the metavar for a positional argument that did not set one."""
        name = self._type_name(action)
        if name is None:
            # No usable type name: defer to argparse's default behaviour.
            return super()._get_default_metavar_for_positional(action)
        return name
def _parse_bool(value):
    """Convert a command-line token into a real ``bool``.

    ``type=bool`` is a well-known argparse trap: ``bool("False")`` is ``True``
    because any non-empty string is truthy. This parser accepts the usual
    spellings (case-insensitive) and rejects everything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "1", "yes", "y", "on"):
        return True
    if normalized in ("false", "0", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def rag_demo_upload():
    """Parse CLI arguments and upload the given files into the knowledge base.

    Steps performed:
      1. Parse ``--tei_emb``, ``--embedding_url``, ``--white_path``,
         ``--file_path`` (repeatable, required) and ``--num_threads``.
      2. Register loaders for ``.txt``/``.md``, ``.pdf`` and ``.docx`` plus a
         ``RecursiveCharacterTextSplitter`` for all of those file types.
      3. Build a ``TEIEmbedding`` client, embed the string ``"test"`` once to
         discover the vector dimension, and create the Milvus-backed vector,
         chunk and knowledge stores on ``./milvus.db``.
      4. Split the file list into at most ``num_threads`` batches and run
         ``upload_files`` for each batch in its own thread.
      5. Log the names of all documents now present in the knowledge base.

    Invalid command lines (missing ``--file_path``, ``--num_threads`` < 1,
    unparsable boolean) terminate through argparse's normal error path.
    Any exception raised during steps 2-5 is logged with its traceback and
    swallowed, so the function itself returns ``None`` in that case.
    """
    parse = argparse.ArgumentParser(formatter_class=CustomFormatter)
    # metavar is set explicitly so the help text still shows "bool" even though
    # the converter is no longer the builtin ``bool``.
    parse.add_argument(
        "--tei_emb",
        type=_parse_bool,
        metavar="bool",
        default=False,
        help="是否使用TEI服务化的embedding模型",
    )
    parse.add_argument(
        "--embedding_url",
        type=str,
        default="http://127.0.0.1:8080/embed",
        help="使用TEI服务化的embedding模型url地址",
    )
    parse.add_argument("--white_path", type=str, nargs="+", default=["/home"], help="文件白名单路径")
    # Required: without at least one file the value would be None and the
    # batching code below would fail on len(None).
    parse.add_argument(
        "--file_path",
        type=str,
        action="append",
        required=True,
        help="要上传的文件路径,需在白名单路径下",
    )
    parse.add_argument("--num_threads", type=int, default=2, help="可以根据实际情况调整线程数量")
    args = parse.parse_args().__dict__

    # NOTE(review): --tei_emb is parsed but never consulted below; a
    # TEIEmbedding client is always used. Confirm whether a local-embedding
    # branch is intended.
    embedding_url: str = args.pop("embedding_url")
    white_path: list[str] = args.pop("white_path")
    file_path: list[str] = args.pop("file_path")
    num_threads: int = args.pop("num_threads")

    # Reject a non-positive thread count up front; it would otherwise cause a
    # ZeroDivisionError (0) or nonsensical batch sizes (negative).
    if num_threads < 1:
        parse.error("--num_threads must be a positive integer")

    try:
        # Loaders and splitter, keyed by file extension.
        loader_mng = LoaderMng()
        loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"])
        loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"])
        loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"])
        loader_mng.register_splitter(
            splitter_class=RecursiveCharacterTextSplitter,
            file_types=[".pdf", ".docx", ".txt", ".md"],
            splitter_params={
                "chunk_size": 750,
                "chunk_overlap": 150,
                "keep_separator": False,
            },
        )

        emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True))
        # Embed a throwaway string to learn the vector dimension of the model.
        embedding_dim = len(emb.embed_documents(["test"])[0])

        client = MilvusClient("./milvus.db")
        vector_store = MilvusDB.create(client=client, x_dim=embedding_dim, collection_name="milvus_vector")
        chunk_store = MilvusDocstore(client=client, collection_name="milvus_chunk")
        knowledge_store = KnowledgeStore(db_path="./milvus.db")
        knowledge_store.add_knowledge("test", "Default", "admin")
        knowledge_db = KnowledgeDB(
            knowledge_store=knowledge_store,
            chunk_store=chunk_store,
            vector_store=vector_store,
            knowledge_name="test",
            white_paths=white_path,
            user_id="Default",
        )

        # Ceiling division so that at most ``num_threads`` batches are created.
        batch_size = (len(file_path) + num_threads - 1) // num_threads
        file_batches = [
            file_path[start : start + batch_size]
            for start in range(0, len(file_path), batch_size)
        ]

        threads = []
        for batch in file_batches:
            # The trailing True is passed positionally to upload_files;
            # its meaning is defined by mx_rag (TODO confirm, likely a force/overwrite flag).
            worker = threading.Thread(
                target=upload_files,
                args=(knowledge_db, batch, loader_mng, emb.embed_documents, True),
            )
            threads.append(worker)
            worker.start()
        for worker in threads:
            worker.join()

        documents = [document.document_name for document in knowledge_db.get_all_documents()]
        logger.info(documents)
    except Exception:
        # Top-level boundary of the demo script: log the full traceback
        # instead of crashing.
        stack_trace = traceback.format_exc()
        logger.error(stack_trace)
# Run the upload demo only when this file is executed as a script; importing
# the module must not trigger argument parsing or any uploads.
if __name__ == "__main__":
    rag_demo_upload()