import argparse
import base64
import io
import json
import os
import re
import shutil
from http import HTTPStatus
from typing import Optional
from pathlib import Path
import docx
import fitz
import httpx
import uvicorn
from PIL import Image
from fastapi import FastAPI, UploadFile
from langchain_openai import ChatOpenAI
from loguru import logger
from openai import OpenAI
from pydantic import BaseModel
from pymilvus import MilvusClient
from starlette.responses import JSONResponse
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from paddle.base import libpaddle
from mx_rag.embedding.service import TEIEmbedding
from mx_rag.document import LoaderMng
from mx_rag.document.loader import DocxLoader, PdfLoader
from mx_rag.knowledge import KnowledgeStore, KnowledgeDB
from mx_rag.reranker.service import TEIReranker
from mx_rag.retrievers import Retriever, FullTextRetriever
from mx_rag.storage.document_store import MilvusDocstore
from mx_rag.storage.vectorstore import MilvusDB
from mx_rag.utils import ClientParam
# FastAPI application exposing the knowledge-base HTTP endpoints defined below.
app = FastAPI()
# Directory where uploaded source files are stored (overridable via UPLOAD_FILE_DIR).
upload_file_dir = os.environ.get("UPLOAD_FILE_DIR", "/home/data/files")
# Root directory for images extracted from .docx/.pdf uploads; each uploaded file gets a
# sub-directory named after its base name (overridable via IMG_STORE_DIR).
images_store_dir = os.environ.get("IMG_STORE_DIR", "/home/data/images")
# Process-wide runtime configuration; main() replaces it with load_config()'s result
# before the server starts, and get_config() reads from it.
global_config = {}
# Prompt sent to the VLM together with one base64-encoded image by extract_image_info_by_vlm.
# It asks for a coarse-grained and a fine-grained description, answered in Chinese; the
# resulting text is indexed as an "image" chunk of the uploaded document.
img_to_text_prompt = """Given an image containing a table or figure, please provide a structured and detailed
description in chinese with two levels of granularity:
Coarse-grained Description:
- Summarize the overall content and purpose of the image.
- Briefly state what type of data or information is presented (e.g., comparison, trend, distribution).
- Mention the main topic or message conveyed by the table or figure.
Fine-grained Description:
- Describe the specific details present in the image.
- For tables: List the column and row headers, units, and any notable values, patterns, or anomalies.
- For figures (e.g., plots, charts): Explain the axes, data series, legends, and any significant trends, outliers,
or data points.
- Note any labels, captions, or annotations included in the image.
- Highlight specific examples or noteworthy details.
Deliver the description in a clear, organized, and reader-friendly manner, using bullet points or paragraphs
as appropriate, answer in chinese"""
# System prompt for answer generation. compose_text_messages lists image quotes as
# "image1", "image2", ...; the model must cite them as ``![{conclusion}](imageN)``, which is
# exactly the markdown form replace_image_paths rewrites into real image paths. The format
# was previously empty (two backticks with nothing between them), so the model was never
# told how to cite images.
text_infer_prompt = """
You are a helpful question-answering assistant. Your task is to generate a interleaved text and image response based on provided questions and quotes. Here‘s how to refine your process:
1. **Evidence Selection**:
- From both text and image quotes, pinpoint those really relevant for answering the question. Focus on significance and direct relevance.
- Each image quote is the description of the image.
2. **Answer Construction**:
- Use Markdown to embed text and images in your response, avoid using obvious headings or divisions; ensure the response flows naturally and cohesively.
- Conclude with a direct and concise answer to the question in a simple and clear sentence.
3. **Quote Citation**:
- Cite text by adding [index]; for example, quote from the first text should be [1].
- Cite images using the format `![{conclusion}](image{index})`; for the first image, use `![{conclusion}](image1)`;The {conclusion} should be a concise one-sentence summary of the image’s content.
- Ensure the cite of the image must strict follow `![{conclusion}](image{index})`, do not simply stating "See image1", "image1 shows" ,"[image1]" or "image1".
- Each image or text can only be quoted once.
- Do not cite irrelevant quotes.
- Compose a detailed and articulate interleaved answer to the question.
- Ensure that your answer is logical, informative, and directly ties back to the evidence provided by the quotes.
- Interleaved answer must contain both text and image response.
- Answer in chinese.
"""
def get_config(key, default=None):
    """Return the configuration value stored under *key*, or *default* when it is absent."""
    try:
        return global_config[key]
    except KeyError:
        return default
def create_llm_chain(base_url, model_name):
    """Build a LangChain chat model that talks to an OpenAI-compatible endpoint.

    :param base_url: base URL of the OpenAI-compatible service (e.g. ``http://host:port/v1``).
    :param model_name: model identifier sent with each request.
    :return: a ``ChatOpenAI`` instance with temperature 0.5 and streaming enabled.
    """
    # A placeholder API key is supplied; presumably the backing services do not check it.
    completions_api = OpenAI(
        base_url=base_url,
        api_key="sk_fake",
        http_client=httpx.Client(),
    ).chat.completions
    return ChatOpenAI(
        api_key="sk_fake",
        client=completions_api,
        model_name=model_name,
        temperature=0.5,
        streaming=True,
    )
def compose_text_messages(question, text_docs, img_docs):
    """Render the user prompt: numbered text quotes, image descriptions, then the question.

    Text quotes are labelled ``[1]``, ``[2]``, ... and image quotes ``image1``, ``image2``, ...;
    these labels are what the system prompt tells the model to cite.

    :param question: the user question.
    :param text_docs: documents whose ``page_content`` is quoted as text.
    :param img_docs: documents whose ``page_content`` describes an image.
    :return: the assembled prompt string.
    """
    parts = ["Text Quotes are:"]
    parts.extend(f"\n[{idx}] {doc.page_content}" for idx, doc in enumerate(text_docs, start=1))
    parts.append("\nImage Quotes are:")
    parts.extend(
        f"\nimage{idx} is described as: {doc.page_content}"
        for idx, doc in enumerate(img_docs, start=1)
    )
    parts.append("\n\n")
    parts.append(f"The user question is: {question}")
    return "".join(parts)
def extract_images_from_docx(image_out_dir, file_path):
    """Write the images referenced by a .docx main document part to disk.

    Images go to ``<image_out_dir>/<basename(file_path)>/`` under their package part
    name (e.g. ``image1.png``); existing files with the same name are overwritten.

    :param image_out_dir: root directory for extracted images.
    :param file_path: path of the .docx file.
    """
    doc = docx.Document(file_path)
    output_folder = os.path.join(image_out_dir, os.path.basename(file_path))
    os.makedirs(output_folder, exist_ok=True)
    for rel in doc.part.rels.values():
        # External relationships (e.g. a hyperlink whose URL contains "image") have no
        # package part, and reading ``target_part`` on them raises. Skip them instead of
        # letting one link abort the whole upload.
        if rel.is_external or "image" not in rel.target_ref:
            continue
        image_part = rel.target_part
        image_path = os.path.join(output_folder, os.path.basename(image_part.partname))
        with open(image_path, "wb") as image_file:
            image_file.write(image_part.blob)
    logger.info(f"extract images from {file_path} successfully")
def extract_images_from_pdf(image_out_dir, file_path):
    """Write every image embedded in a PDF to disk.

    Files go to ``<image_out_dir>/<basename(file_path)>/`` as
    ``image_<page>_<index>.<ext>`` (1-based page and per-page index; the extension is the
    one reported by PyMuPDF for the image).

    :param image_out_dir: root directory for extracted images.
    :param file_path: path of the PDF file.
    """
    output_folder = os.path.join(image_out_dir, os.path.basename(file_path))
    os.makedirs(output_folder, exist_ok=True)
    # The context manager closes the document even if extraction or a write fails;
    # previously an exception leaked the open document.
    with fitz.open(file_path) as pdf_document:
        for page_num, page in enumerate(pdf_document):
            for img_index, img in enumerate(page.get_images(full=True)):
                xref = img[0]
                base_image = pdf_document.extract_image(xref)
                if not base_image:
                    # Nothing extractable for this xref; skip rather than fail on a lookup.
                    continue
                image_filename = f"image_{page_num + 1}_{img_index + 1}.{base_image['ext']}"
                with open(os.path.join(output_folder, image_filename), "wb") as image_file:
                    image_file.write(base_image["image"])
    logger.info(f"extract images from {file_path} successfully")
def get_model_name(base_url):
    """Return the id of the first model listed by the service's ``/models`` endpoint.

    :param base_url: API base URL (e.g. ``http://host:port/v1``); trailing slashes are ignored.
    :return: the model name, or ``None`` if the list is empty or anything goes wrong
        (the error is logged).
    """
    try:
        endpoint = "{}/models".format(base_url.rstrip("/"))
        reply = httpx.get(endpoint, timeout=30)
        reply.raise_for_status()
        payload = reply.json()
        if "data" not in payload or len(payload["data"]) <= 0:
            return None
        return payload["data"][0]["id"]
    except Exception as e:
        logger.error(f"Failed to get model name from {base_url}: {e}")
        return None
def get_embedding(url):
    """Create a TEI embedding client for *url* that uses plain HTTP."""
    http_only = ClientParam(use_http=True)
    return TEIEmbedding(url=url, client_param=http_only)
def get_embedding_dim(url):
    """Probe the embedding service at *url* and return the length of its vectors."""
    probe_vector = get_embedding(url).embed_query("The capital of China is Beijing.")
    return len(probe_vector)
def get_vector_store():
    """Return the Milvus vector store for collection ``<knowledge_name>_vector``.

    Uses ``milvus_url`` and ``embedding_dim`` from the runtime configuration.
    """
    name = get_config("knowledge_name")
    client = MilvusClient(get_config("milvus_url"))
    return MilvusDB.create(
        client=client,
        x_dim=get_config("embedding_dim"),
        collection_name=f"{name}_vector",
    )
def get_chunk_store():
    """Return the Milvus-backed chunk store for collection ``<knowledge_name>_chunk``."""
    collection = "{}_chunk".format(get_config("knowledge_name"))
    return MilvusDocstore(MilvusClient(get_config("milvus_url")), collection_name=collection)
def get_knowledge_db():
    """Assemble a ``KnowledgeDB`` over the configured chunk store, vector store and SQL metadata store.

    The knowledge name is registered under user ``knowledge_demo`` on every call.
    NOTE(review): presumably ``add_knowledge`` is idempotent for an existing name; confirm.
    """
    owner = "knowledge_demo"
    chunks = get_chunk_store()
    vectors = get_vector_store()
    metadata_store = KnowledgeStore(db_path="./knowledge_store_sql.db")
    name = get_config("knowledge_name")
    metadata_store.add_knowledge(name, user_id=owner)
    # white_paths presumably restricts which directories files may be added from; the
    # default UPLOAD_FILE_DIR lives under /home. Confirm if that env var is changed.
    return KnowledgeDB(
        knowledge_store=metadata_store,
        chunk_store=chunks,
        vector_store=vectors,
        knowledge_name=name,
        white_paths=["/home"],
        user_id=owner,
    )
def get_document_loader_splitter(file_suffix):
    """Return the ``(loader_info, splitter_info)`` registered for *file_suffix*.

    .txt/.md use ``TextLoader``, .docx ``DocxLoader`` and .pdf ``PdfLoader``. All four types
    share a ``RecursiveCharacterTextSplitter`` (750-char chunks, 150-char overlap,
    separators dropped).
    """
    manager = LoaderMng()
    registrations = (
        (TextLoader, [".txt", ".md"]),
        (DocxLoader, [".docx"]),
        (PdfLoader, [".pdf"]),
    )
    for loader_cls, suffixes in registrations:
        manager.register_loader(loader_class=loader_cls, file_types=suffixes)
    split_params = {
        "chunk_size": 750,
        "chunk_overlap": 150,
        "keep_separator": False,
    }
    manager.register_splitter(
        splitter_class=RecursiveCharacterTextSplitter,
        file_types=[".docx", ".txt", ".md", ".pdf"],
        splitter_params=split_params,
    )
    return manager.get_loader(file_suffix), manager.get_splitter(file_suffix)
def retrieve_similarity_docs(query, top_k, score_threshold):
    """Run dense and full-text retrieval and merge the hits.

    Dense retrieval returns up to *top_k* chunks and applies *score_threshold*; full-text
    retrieval returns up to *top_k* chunks without a threshold. Hits are de-duplicated by
    ``page_content``: dense hits come first, and the first occurrence of a text wins.

    :return: list of unique documents.
    """
    embedder = get_embedding(get_config("embedding_url"))
    doc_store = get_chunk_store()
    vec_store = get_vector_store()
    dense_hits = Retriever(
        vector_store=vec_store,
        document_store=doc_store,
        embed_func=embedder.embed_documents,
        k=top_k,
        score_threshold=score_threshold,
    ).invoke(query)
    keyword_hits = FullTextRetriever(document_store=doc_store, k=top_k).invoke(query)
    seen_texts = set()
    unique_docs = []
    for batch in (dense_hits, keyword_hits):
        for candidate in batch:
            if candidate.page_content in seen_texts:
                continue
            seen_texts.add(candidate.page_content)
            unique_docs.append(candidate)
    logger.info("retrieve similarity chunks from knowledge successfully")
    return unique_docs
def generate_answer(query, q_docs):
    """Ask the LLM for an interleaved text/image answer grounded in *q_docs*.

    Documents with ``metadata["type"] == "text"`` become numbered text quotes and those with
    ``"image"`` become image quotes; documents of any other type are ignored.

    :param query: the user question.
    :param q_docs: retrieved (and possibly reranked) documents.
    :return: the answer with ``imageN`` links resolved to image paths, or ``""`` if the LLM
        call fails (the error is logged).
    """
    text_docs, img_docs = [], []
    for doc in q_docs:
        doc_type = doc.metadata.get("type", "")
        if doc_type == "text":
            text_docs.append(doc)
        elif doc_type == "image":
            img_docs.append(doc)
    user_prompt = compose_text_messages(query, text_docs, img_docs)
    messages = [
        {"role": "system", "content": text_infer_prompt},
        {"role": "user", "content": user_prompt},
    ]
    llm = create_llm_chain(base_url=get_config("llm_base_url"), model_name=get_config("llm_model_name"))
    answer = ""
    try:
        answer = llm.invoke(messages).content
        logger.info("generate answer by llm successfully")
    except Exception as e:
        logger.error(f"call llm invoke failed:{e}")
    return replace_image_paths(answer, img_docs)
def replace_image_paths(text, image_docs):
    """Resolve ``![alt](imageN)`` markdown links in an LLM answer to stored image paths.

    ``imageN`` refers to the N-th (1-based) entry of *image_docs*; its
    ``metadata["image_path"]`` (``""`` if missing) becomes the link target and the alt
    text is kept. Links to unknown indices are left untouched.

    :param text: markdown answer produced by the LLM.
    :param image_docs: image documents, in the order they were shown to the LLM.
    :return: the answer with resolved image links (empty/None input is returned as-is).
    """
    if not text:
        return text
    image_mapping = {
        f"image{i}": doc.metadata.get("image_path", "")
        for i, doc in enumerate(image_docs, start=1)
    }

    def replace_match(match):
        alt_text, image_ref = match.group(1), match.group(2)
        if image_ref not in image_mapping:
            return match.group(0)
        # Rebuild the markdown image link with the real path. The previous code returned
        # an empty string here, which deleted every resolved image from the answer.
        return f"![{alt_text}]({image_mapping[image_ref]})"

    updated_response = re.sub(r"!\[([^\]]*)\]\((image\d+)\)", replace_match, text)
    logger.info("Image paths replacement completed")
    return updated_response
def extract_image_info_by_vlm(image_path):
    """Describe one image with the configured VLM.

    Images smaller than 256 px on both sides are skipped. Otherwise the image is converted
    to RGB and re-encoded as JPEG, halving both sides when either exceeds 1024 px. It is
    then sent as a base64 data URL together with ``img_to_text_prompt``.

    :param image_path: path of the image file.
    :return: the VLM's description, or ``""`` if the image is skipped, cannot be read or
        converted, or the VLM call fails (each case is logged).
    """
    try:
        with Image.open(image_path) as img:
            width, height = img.size
            if width < 256 and height < 256:
                logger.warning(f"image:{image_path} size: ({width},{height}) too little, will be discarded")
                return ""
            # JPEG cannot store alpha/palette modes (RGBA, P, LA, ...). Normalise every
            # non-RGB image instead of keying on a case-sensitive ".png" suffix.
            if img.mode != "RGB":
                img = img.convert("RGB")
            if width > 1024 or height > 1024:
                img = img.resize(size=(width // 2, height // 2))
            buffer = io.BytesIO()
            img.save(buffer, format="JPEG")
    except (OSError, ValueError) as e:
        # Corrupt or unconvertible images (PIL raises OSError subclasses such as
        # UnidentifiedImageError) are skipped, like undersized ones, so one bad image
        # does not abort processing of the whole uploaded file.
        logger.warning(f"image:{image_path} can not be processed, will be discarded: {e}")
        return ""
    img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": img_to_text_prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image;base64,{img_str}"},
                },
            ],
        }
    ]
    vlm = create_llm_chain(base_url=get_config("vlm_base_url"), model_name=get_config("vlm_model_name"))
    try:
        return vlm.invoke(messages).content
    except Exception as e:
        logger.error(f"call vlm invoke failed:{e}")
        return ""
def find_images_files(directory, recursive=False):
    """List .jpg, .jpeg and .png files under *directory* as path strings.

    Results are grouped by extension in that order; within a group the order is whatever
    ``Path.glob`` yields. Suffix matching follows ``Path.glob``'s case rules.

    :param directory: directory to search.
    :param recursive: when True, also search all sub-directories.
    :return: list of matching file paths.
    """
    root = Path(directory)
    prefix = "**/*" if recursive else "*"
    return [
        str(match)
        for suffix in (".jpg", ".jpeg", ".png")
        for match in root.glob(prefix + suffix)
    ]
def extract_images_info_by_vlm(image_out_dir, file_name):
    """Describe every image extracted for *file_name* and cache the result.

    Descriptions are written to ``<image_out_dir>/<file_name>/image_info.json`` as a list of
    ``{"image_path", "image_description"}`` objects. If that file already exists, nothing is
    done. Images the VLM could not describe are omitted, and no file is written when none
    succeeded.

    :param image_out_dir: root directory for extracted images.
    :param file_name: base name of the uploaded file (names its image sub-directory).
    """
    image_dir = os.path.join(image_out_dir, file_name)
    info_path = os.path.join(image_dir, "image_info.json")
    if os.path.exists(info_path):
        logger.warning(f"all images in {image_dir} have been extracted, no need to repeat extraction")
        return
    logger.info(f"start to extract images info in file [{file_name}] by vlm ...")
    described = []
    for image_file in find_images_files(image_dir):
        logger.debug(f"start to deal {[image_file]} by vlm")
        description = extract_image_info_by_vlm(image_file)
        if description:
            described.append({"image_path": image_file, "image_description": description})
    if described:
        with open(info_path, "w", encoding="utf-8") as f:
            json.dump(described, f, indent=4, ensure_ascii=False)
    logger.info("extract images info successfully")
class RetrievalParam(BaseModel):
    """Request body for ``POST /retrieval``."""

    # Text to retrieve knowledge chunks for.
    query: str
    # Accepted in the request but not read by the /retrieval handler in this file.
    knowledge_id: str = ""
    # Optional "top_k" and "score_threshold" keys; the handler falls back to 3 and 0.5.
    retrieval_setting: Optional[dict] = {"top_k": 3, "score_threshold": 0.5}
@app.post("/retrieval")
async def retrieve(arg: RetrievalParam):
    """Retrieve and rerank knowledge chunks for ``arg.query``.

    ``retrieval_setting`` may carry ``top_k`` (default 3) and ``score_threshold``
    (default 0.5). Missing, null or unparsable values fall back to the defaults.

    :return: ``{"records": [{"content", "score", "title", "metadata"}, ...]}``.
    """
    logger.info("doing retrieval")
    settings = arg.retrieval_setting or {}
    try:
        top_k = int(settings.get("top_k", 3))
    except (ValueError, TypeError):
        top_k = 3
    try:
        # Convert explicitly. A bare ``.get()`` can never raise, so the fallback never
        # triggered and values like "0.5" reached the retriever as strings.
        score_threshold = float(settings.get("score_threshold", 0.5))
    except (ValueError, TypeError):
        score_threshold = 0.5
    q_docs = retrieve_similarity_docs(arg.query, top_k, score_threshold)
    if q_docs:
        text_reranker = TEIReranker(url=get_config("reranker_url"), k=top_k, client_param=ClientParam(use_http=True))
        score = text_reranker.rerank(arg.query, [doc.page_content for doc in q_docs])
        q_docs = text_reranker.rerank_top_k(q_docs, score)
    records = [
        {
            "content": doc.page_content,
            "score": doc.metadata.get("score", 0),
            "title": doc.metadata.get("source", ""),
            "metadata": doc.metadata,
        }
        for doc in q_docs
    ]
    return JSONResponse(content={"records": records})
class QueryParam(BaseModel):
    """Request body for ``POST /query``."""

    # The user question to answer.
    query: str
    # Optional "top_k" and "score_threshold" keys; the handler falls back to 3 and 0.5.
    retrieval_setting: Optional[dict] = {"top_k": 3, "score_threshold": 0.5}
@app.post("/query")
async def query_answer(arg: QueryParam):
    """Answer ``arg.query`` with the LLM, grounded in retrieved and reranked chunks.

    ``retrieval_setting`` is parsed as in ``/retrieval``: ``top_k`` (default 3) and
    ``score_threshold`` (default 0.5). Invalid values fall back to the defaults.

    :return: the markdown answer as a JSON string (``""`` if the LLM call failed).
    """
    logger.info("doing query")
    settings = arg.retrieval_setting or {}
    try:
        top_k = int(settings.get("top_k", 3))
    except (ValueError, TypeError):
        top_k = 3
    try:
        # Convert explicitly; a bare ``.get()`` can never raise, so the fallback was dead code.
        score_threshold = float(settings.get("score_threshold", 0.5))
    except (ValueError, TypeError):
        score_threshold = 0.5
    q_docs = retrieve_similarity_docs(arg.query, top_k, score_threshold)
    # Only rerank when something was retrieved (as /retrieval does); an empty rerank
    # request is pointless and may be rejected by the reranker service.
    if q_docs:
        text_reranker = TEIReranker(url=get_config("reranker_url"), k=top_k, client_param=ClientParam(use_http=True))
        score = text_reranker.rerank(arg.query, [doc.page_content for doc in q_docs])
        q_docs = text_reranker.rerank_top_k(q_docs, score)
    response = generate_answer(arg.query, q_docs)
    # The latest answer is also written to ./response.md (overwritten on every request).
    with open("response.md", "w", encoding="utf-8") as f:
        f.write(response)
    return JSONResponse(content=response)
def _load_image_descriptions(image_dir):
    """Return the entries of ``<image_dir>/image_info.json``, or ``[]`` if it is missing or invalid."""
    try:
        with open(os.path.join(image_dir, "image_info.json"), "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        logger.warning(f"read image info failed: {e}")
        return []


@app.post("/upload_file")
async def upload_file(file: UploadFile):
    """Store an uploaded document and index it in the knowledge base.

    Accepts .txt/.md/.docx/.pdf files (suffix compared case-insensitively) up to 50 MB.
    The text is split into chunks. For .docx/.pdf, cached image descriptions (see
    ``parse_image``) are indexed as extra "image" chunks.

    Responses:
      * 400 for an invalid name, disallowed type or oversize file;
      * 409 if a document with the same name is already in the knowledge base
        (the stored file is left untouched);
      * 500 if the upload cannot be read or written;
      * otherwise ``{"info": "upload file:<name> successfully"}``.
    """
    logger.info(f"start to upload file: {file.filename}")
    ALLOWED_EXTENSIONS = {".txt", ".md", ".docx", ".pdf"}
    MAX_FILE_SIZE = 50 * 1024 * 1024
    # The client-supplied name is untrusted. Keep only its last path component (accepting
    # both separators) so names like "../../etc/x.txt" cannot escape upload_file_dir, and
    # apply the same ".." rule as /delete_file so every stored file can later be deleted.
    file_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if not file_name.strip() or ".." in file_name:
        logger.warning(f"Invalid file name: {file.filename}")
        return JSONResponse(
            status_code=HTTPStatus.BAD_REQUEST,
            content={"error_info": "Invalid file name"},
        )
    # Use the lower-cased suffix everywhere: "X.PDF" passes the allow-list, so it must
    # also select the PDF loader and image extractor.
    file_ext = Path(file_name).suffix.lower()
    if file_ext not in ALLOWED_EXTENSIONS:
        logger.warning(f"File type {file_ext} not allowed")
        return JSONResponse(
            status_code=HTTPStatus.BAD_REQUEST,
            content={"error_info": f"File type not allowed. Allowed types: {', '.join(ALLOWED_EXTENSIONS)}"},
        )
    os.makedirs(upload_file_dir, exist_ok=True)
    knowledge_db = get_knowledge_db()
    # Check before touching the disk. Previously a duplicate upload overwrote the stored
    # file and re-ran image extraction, then returned ``null`` with HTTP 200.
    if knowledge_db.check_document_exist(file_name):
        logger.warning(f"file {file_name} exists in knowledge db")
        return JSONResponse(
            status_code=HTTPStatus.CONFLICT,
            content={"error_info": f"File '{file_name}' already exists in knowledge base"},
        )
    file_path = os.path.join(upload_file_dir, file_name)
    try:
        # Read at most one byte past the limit: enough to detect an oversize upload
        # without buffering an arbitrarily large body in memory.
        contents = file.file.read(MAX_FILE_SIZE + 1)
        if len(contents) > MAX_FILE_SIZE:
            logger.warning(f"File size exceeds limit {MAX_FILE_SIZE}")
            return JSONResponse(
                status_code=HTTPStatus.BAD_REQUEST,
                content={"error_info": "File size exceeds limit (max 50MB)"},
            )
        with open(file_path, "wb") as f:
            f.write(contents)
    except Exception as e:
        logger.error(f"write file {file_name} failed: {e}")
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            content={"error_info": "Something went wrong when upload file"},
        )
    finally:
        file.file.close()

    file_obj = Path(file_path)
    if get_config("parse_image"):
        if file_ext == ".docx":
            extract_images_from_docx(images_store_dir, file_path)
        elif file_ext == ".pdf":
            extract_images_from_pdf(images_store_dir, file_path)
        extract_images_info_by_vlm(images_store_dir, file_name)

    loader_info, splitter_info = get_document_loader_splitter(file_ext)
    emb = get_embedding(get_config("embedding_url"))
    loader = loader_info.loader_class(file_path=file_obj.as_posix(), **loader_info.loader_params)
    splitter = splitter_info.splitter_class(**splitter_info.splitter_params)
    docs = [doc for doc in loader.load_and_split(splitter) if doc.page_content]
    texts = [doc.page_content for doc in docs]
    meta_data = [{**doc.metadata, "type": "text"} for doc in docs]
    if file_ext in (".docx", ".pdf"):
        for description in _load_image_descriptions(os.path.join(images_store_dir, file_name)):
            texts.append(description["image_description"])
            meta_data.append(
                {
                    "type": "image",
                    "source": file_path,
                    "image_path": description.get("image_path"),
                }
            )
    knowledge_db.add_file(file_obj, texts, {"dense": emb.embed_documents}, meta_data)
    logger.info(f"upload file {file_name} to knowledge successfully")
    return JSONResponse(content={"info": f"upload file:{file_name} successfully"})
class DeleteFileParam(BaseModel):
    """Request body for ``POST /delete_file``."""

    # Base name of the uploaded file; path separators and ".." are rejected by the handler.
    file_name: str
    # Safety switch: the handler refuses to delete unless this is true.
    confirm: bool = False
@app.post("/delete_file")
async def delete_file(file: DeleteFileParam):
    """Remove one uploaded file from the knowledge base, the upload directory and the image store.

    Responses: 400 for an empty/invalid name or missing confirmation, 404 if no such
    uploaded file exists, 500 if deletion fails, otherwise a success message string.
    """
    logger.info(f"start to delete file:{file.file_name}")
    if not file.file_name or not file.file_name.strip():
        logger.warning("File name cannot be empty")
        return JSONResponse(
            status_code=HTTPStatus.BAD_REQUEST,
            content={"error_info": "File name cannot be empty"},
        )
    if ".." in file.file_name or "/" in file.file_name or "\\" in file.file_name:
        logger.warning(f"Invalid file name: {file.file_name}")
        return JSONResponse(
            status_code=HTTPStatus.BAD_REQUEST,
            content={"error_info": "Invalid file name"},
        )
    if not file.confirm:
        logger.warning(f"Deletion not confirmed for file: {file.file_name}")
        return JSONResponse(
            status_code=HTTPStatus.BAD_REQUEST,
            content={"error_info": "Please set 'confirm: true' to confirm deletion"},
        )
    file_path = os.path.join(upload_file_dir, file.file_name)
    # isfile (not exists): a name like "." resolves to the upload directory itself, which
    # previously reached the knowledge-base delete before failing in os.remove().
    if not os.path.isfile(file_path):
        logger.warning(f"File not found: {file.file_name}")
        return JSONResponse(
            status_code=HTTPStatus.NOT_FOUND,
            content={"error_info": f"File '{file.file_name}' not found"},
        )
    try:
        knowledge_db = get_knowledge_db()
        knowledge_db.delete_file(file.file_name)
        os.remove(file_path)
        # Always drop extracted images, even if parse_image is currently off. A stale
        # image_info.json would otherwise be reused (extract_images_info_by_vlm skips
        # directories that already have one) when a different file with this name is
        # uploaded later.
        image_dir = os.path.join(images_store_dir, file.file_name)
        if os.path.isdir(image_dir):
            shutil.rmtree(image_dir)
    except Exception as e:
        logger.error(f"delete file [{file.file_name}] failed: {e}")
        # Report failure with a 5xx status; it was previously returned as HTTP 200.
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            content={"error_info": f"delete file [{file.file_name}] failed: {e}"},
        )
    return JSONResponse(content=f"delete file [{file.file_name}] successfully")
class DeleteAllFilesParam(BaseModel):
    """Request body for ``DELETE /delete_all_files``."""

    # Safety switch: the handler refuses to delete anything unless this is true.
    confirm: bool = False
@app.delete("/delete_all_files")
async def delete_all_files(param: DeleteAllFilesParam):
    """Empty the knowledge base, drop its Milvus collections and remove stored files.

    Responses: 400 without confirmation or when the knowledge base is already empty,
    500 if clearing the knowledge base fails, otherwise ``{"message": "ok"}``.
    Removing the upload and image directories is best-effort: failures are logged only.
    """
    logger.info("start to delete all files")
    if not param.confirm:
        logger.warning("Deletion not confirmed")
        return JSONResponse(
            status_code=HTTPStatus.BAD_REQUEST,
            content={"error_info": "Please set 'confirm: true' to confirm deletion"},
        )
    knowledge_db = get_knowledge_db()
    all_docs = knowledge_db.get_all_documents()
    if len(all_docs) == 0:
        logger.warning("No files to delete")
        return JSONResponse(
            status_code=HTTPStatus.BAD_REQUEST,
            content={"error_info": "Knowledge base is already empty"},
        )
    logger.warning(f"DELETING ALL {len(all_docs)} FILES - this action cannot be undone!")
    try:
        knowledge_db.delete_all()
        vector_store = get_vector_store()
        chunk_store = get_chunk_store()
        vector_store.drop_collection()
        chunk_store.drop_collection()
    except Exception as e:
        logger.error(f"delete all files failed: {e}")
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            content={"error_info": f"Failed to delete all files: {e}"},
        )
    # The image store is removed regardless of parse_image: leftover image_info.json files
    # would otherwise be reused for re-uploaded files with the same names. Each directory is
    # attempted independently, and failures are logged as warnings (they were info-level).
    for directory in (upload_file_dir, images_store_dir):
        try:
            if os.path.isdir(directory):
                shutil.rmtree(directory)
        except OSError as e:
            logger.warning(f"delete directory {directory} failed: {e}")
    return JSONResponse({"message": "ok"})
@app.get("/list_files")
async def list_files():
    """Return the names of all documents in the knowledge base as a JSON list."""
    logger.info("start to list all file")
    names = []
    for document in get_knowledge_db().get_all_documents():
        names.append(document.document_name)
    return JSONResponse(content=names)
class QueryFileContentParam(BaseModel):
    """Request body for ``POST /query_file_content``."""

    # Document name as listed by /list_files (presumably the uploaded file's base name).
    file_name: str
@app.post("/query_file_content")
async def query_file_content(file: QueryFileContentParam):
    """Return the stored chunks of one document.

    :return: a JSON list of ``{"page_content", "metadata"}`` objects, or 404 with
        ``error_info`` if no document with that name exists.
    """
    logger.info(f"start to query file:{file.file_name}")
    documents = get_knowledge_db().get_all_documents()
    # ``None`` marks "not found". The old ``0`` sentinel plus a falsy check would also
    # have rejected a document whose id is 0.
    doc_id = next(
        (doc.document_id for doc in documents if doc.document_name == file.file_name),
        None,
    )
    if doc_id is None:
        logger.error(f"there is no {file.file_name} in db")
        return JSONResponse(
            status_code=HTTPStatus.NOT_FOUND,
            content={"error_info": f"File '{file.file_name}' not found"},
        )
    chunks = get_chunk_store().search_by_document_id(doc_id)
    return JSONResponse(content=[{"page_content": chunk.page_content, "metadata": chunk.metadata} for chunk in chunks])
def load_config():
    """Parse command-line options and derive the runtime configuration.

    Besides the parsed options, the config contains ``embedding_dim`` (probed from the
    embedding service) and ``llm_model_name`` (first model listed by ``--llm_base_url``,
    possibly ``None``). With ``--parse_image`` it also contains ``vlm_model_name``.

    :return: ``(args, config)``, where ``config`` is ``vars(args)``, i.e. the namespace's
        own attribute dict, so both views always agree.
    :raises ValueError: if ``--parse_image`` is set but no VLM model can be listed.
    """
    import ssl

    class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter):
        # Show each option's type name (e.g. "str", "int") as its metavar in --help.
        def _get_default_metavar_for_optional(self, action):
            return action.type.__name__

        def _get_default_metavar_for_positional(self, action):
            return action.type.__name__

    parser = argparse.ArgumentParser(formatter_class=CustomFormatter)
    parser.add_argument(
        "--llm_base_url",
        type=str,
        default="http://192.168.9.146:1025/v1",
        help="LLM大模型服务地址",
    )
    parser.add_argument(
        "--vlm_base_url",
        type=str,
        default="http://192.168.9.146:8000/v1",
        help="VLM大模型服务地址",
    )
    parser.add_argument("--host", type=str, default="192.168.9.146", help="服务host")
    parser.add_argument("--port", type=int, default=9098, help="服务端口")
    parser.add_argument("--ssl_key_file", type=str, help="ssl秘钥文件")
    parser.add_argument("--ssl_cert_file", type=str, help="ssl证书文件")
    parser.add_argument("--ssl_ca_certs", type=str, help="ssl证书根证书文件")
    # uvicorn converts this with ssl.VerifyMode(...), so it must be an int
    # (0=CERT_NONE, 1=CERT_OPTIONAL, 2=CERT_REQUIRED). The former type=str/default=None
    # made server startup fail whenever an SSL certificate was configured.
    parser.add_argument(
        "--ssl_cert_reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="ssl证书验证要求",
    )
    parser.add_argument(
        "--embedding_url",
        type=str,
        default="http://192.168.9.146:9123/v1/embeddings",
        help="向量模型服务地址",
    )
    parser.add_argument(
        "--reranker_url",
        type=str,
        default="http://192.168.9.146:9124/v1/rerank",
        help="排序模型服务地址",
    )
    parser.add_argument("--milvus_url", type=str, default="./milvus.db", help="milvus数据库服务地址")
    parser.add_argument("--knowledge_name", type=str, default="test", help="知识库名称")
    parser.add_argument("--parse_image", action="store_true", help="是否解析图片信息")
    args = parser.parse_args()
    # vars() returns the namespace's own __dict__, so keys added to ``config`` below are
    # attributes of ``args`` too. No copy-back loop is needed.
    config = vars(args)
    config["embedding_dim"] = get_embedding_dim(args.embedding_url)
    config["llm_model_name"] = get_model_name(args.llm_base_url)
    if config["llm_model_name"] is None:
        # Not fatal: /retrieval and the file-management endpoints do not use the LLM.
        logger.warning("can not get model name from llm_base_url, answer generation (/query) may fail")
    if config["parse_image"]:
        vlm_model_name = get_model_name(args.vlm_base_url)
        if vlm_model_name is None:
            raise ValueError("parse_image is True, but vlm_base_url is invalid")
        config["vlm_model_name"] = vlm_model_name
    return args, config
def main():
    """Load the configuration, publish it as ``global_config`` and serve ``app`` with uvicorn."""
    global global_config
    args, global_config = load_config()
    ssl_options = {
        "ssl_keyfile": args.ssl_key_file,
        "ssl_certfile": args.ssl_cert_file,
        "ssl_ca_certs": args.ssl_ca_certs,
        "ssl_cert_reqs": args.ssl_cert_reqs,
    }
    uvicorn.run(app, host=args.host, port=args.port, **ssl_options)
# Script entry point: start the server only when this module is executed directly.
if __name__ == "__main__":
    main()