"""文件上传至RAG,作为临时语料"""
import logging
import uuid
import httpx
from fastapi import status
from apps.common.config import config
from apps.models import Document
from apps.schemas.rag_data import (
RAGFileParseReq,
RAGFileParseReqItem,
RAGFileStatusRspItem,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Address of the RAG service, taken from the application configuration.
rag_host = config.rag.rag_service

# Endpoints for temporary documents. Trailing slashes on the configured host
# are stripped first so that each path is joined with a single "/".
_RAG_DOC_PARSE_URI = f"{rag_host.rstrip('/')}/doc/temporary/parser"
_RAG_DOC_STATUS_URI = f"{rag_host.rstrip('/')}/doc/temporary/status"
_RAG_DOC_DELETE_URI = f"{rag_host.rstrip('/')}/doc/temporary/delete"
class KnowledgeBaseService:
    """Helpers for the RAG service's temporary-document endpoints.

    Every public method POSTs a JSON payload with a bearer token and returns
    the ``"result"`` field of the JSON response. A non-200 response is logged
    and reported to the caller as an empty list. Transport errors raised by
    ``httpx`` are not caught here and propagate to the caller.
    """

    @staticmethod
    def _build_headers(auth_header: str) -> dict[str, str]:
        """Return the JSON request headers carrying ``auth_header`` as a bearer token."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {auth_header}",
        }

    @staticmethod
    async def _post_for_result(uri: str, auth_header: str, payload: dict[str, object]) -> list:
        """POST ``payload`` as JSON to ``uri`` and return the response's ``"result"`` field.

        :param uri: Full URL of the RAG endpoint.
        :param auth_header: Token sent as ``Authorization: Bearer <token>``.
        :param payload: JSON-serializable request body.
        :return: The value stored under ``"result"`` in the response body, or
            ``[]`` when the service answers with a status other than 200.
        :raises KeyError: If a 200 response body has no ``"result"`` key.
        """
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                uri,
                headers=KnowledgeBaseService._build_headers(auth_header),
                json=payload,
                timeout=30.0,
            )

        # Check the status before decoding: an error response is not
        # guaranteed to carry a JSON body, and decoding it first would raise
        # instead of yielding the empty-list fallback.
        if resp.status_code != status.HTTP_200_OK:
            logger.error(
                "[KnowledgeBaseService] RAG request to %s failed with status %s",
                uri,
                resp.status_code,
            )
            return []
        return resp.json()["result"]

    @staticmethod
    async def send_file_to_rag(auth_header: str, docs: list[Document]) -> list[str]:
        """Submit documents to RAG for parsing as temporary corpus.

        :param auth_header: Bearer token used to authenticate against RAG.
        :param docs: Documents to submit; each one's ``id``, ``name`` and
            ``extension`` are forwarded, with the bucket fixed to ``"document"``.
        :return: The ``"result"`` list from the RAG response (presumably the IDs
            of the accepted documents -- confirm against the RAG API), or ``[]``
            on a non-200 response.
        """
        rag_docs = [
            RAGFileParseReqItem(
                id=str(doc.id),
                name=doc.name,
                bucket_name="document",
                type=doc.extension,
            )
            for doc in docs
        ]
        post_data = RAGFileParseReq(document_list=rag_docs).model_dump(
            exclude_none=True,
            by_alias=True,
        )
        return await KnowledgeBaseService._post_for_result(_RAG_DOC_PARSE_URI, auth_header, post_data)

    @staticmethod
    async def delete_doc_from_rag(auth_header: str, doc_ids: list[str]) -> list[str]:
        """Ask RAG to delete temporary documents.

        :param auth_header: Bearer token used to authenticate against RAG.
        :param doc_ids: IDs of the documents to delete.
        :return: The ``"result"`` list from the RAG response, or ``[]`` on a
            non-200 response.
        """
        delete_data = {"ids": doc_ids}
        return await KnowledgeBaseService._post_for_result(_RAG_DOC_DELETE_URI, auth_header, delete_data)

    @staticmethod
    async def get_doc_status_from_rag(auth_header: str, doc_ids: list[uuid.UUID]) -> list[RAGFileStatusRspItem]:
        """Query RAG for the status of temporary documents.

        :param auth_header: Bearer token used to authenticate against RAG.
        :param doc_ids: IDs of the documents to look up.
        :return: One validated ``RAGFileStatusRspItem`` per entry in the
            response's ``"result"`` list, or ``[]`` on a non-200 response.
        """
        # uuid.UUID is not JSON-serializable by the stdlib encoder, so the IDs
        # must be converted to strings before being placed in the payload.
        post_data = {"ids": [str(doc_id) for doc_id in doc_ids]}
        result = await KnowledgeBaseService._post_for_result(_RAG_DOC_STATUS_URI, auth_header, post_data)
        return [RAGFileStatusRspItem.model_validate(item) for item in result]