"""Document (file) manager: storage, lookup, status updates and deletion of conversation files."""
import logging
import uuid
import magic
from fastapi import UploadFile
from sqlalchemy import and_, func, select
from apps.common.postgres import postgres
from apps.models import (
ConvDocAssociated,
Conversation,
ConversationDocument,
Document,
Record,
)
from .knowledge_service import KnowledgeBaseService
# Module-level logger used by DocumentManager to report skipped or missing items.
logger = logging.getLogger(__name__)
class DocumentManager:
    """Document (file) operations backed by the PostgreSQL session from ``postgres``."""

    @staticmethod
    async def storage_docs(
        user_id: str, conversation_id: uuid.UUID, documents: list[UploadFile],
    ) -> list[Document]:
        """Create and persist ``Document`` rows for several uploaded files.

        Uploads without a filename or size are logged and skipped. The MIME type is
        sniffed from the file content with libmagic and stored in ``extension``; the
        size is stored as ``bytes / 1024``.

        Args:
            user_id: Owner of the uploaded files.
            conversation_id: Conversation the files are attached to.
            documents: The uploaded files.

        Returns:
            The ``Document`` objects that were added and committed (may be empty).
        """
        uploaded_files: list[Document] = []
        for document in documents:
            if document.filename is None or document.size is None:
                logger.error("[DocumentManager] 文件名或大小为空: %s, %s", document.filename, document.size)
                continue
            # Sniff from the start of the content, then rewind: leaving the stream
            # at EOF would make any later read of this UploadFile return b"".
            # The async UploadFile methods keep file I/O off the event loop when
            # the upload has been spooled to disk.
            await document.seek(0)
            content = await document.read()
            await document.seek(0)
            mime = magic.from_buffer(content, mime=True)
            uploaded_files.append(
                Document(
                    userId=user_id,
                    name=document.filename,
                    extension=mime,
                    size=document.size / 1024.0,
                    conversationId=conversation_id,
                ),
            )

        if not uploaded_files:
            # Nothing valid to store; do not open an empty transaction.
            return uploaded_files

        async with postgres.session() as session:
            session.add_all(uploaded_files)
            await session.commit()
        return uploaded_files

    @staticmethod
    async def get_unused_docs(conversation_id: uuid.UUID) -> list[Document]:
        """Return the documents of a conversation whose link is still marked unused.

        Args:
            conversation_id: Conversation to inspect.

        Returns:
            ``Document`` rows whose ``ConversationDocument`` link has ``isUnused``
            set; an empty list when there are none.
        """
        async with postgres.session() as session:
            conv_docs = (await session.scalars(
                select(ConversationDocument).where(
                    and_(
                        ConversationDocument.conversationId == conversation_id,
                        ConversationDocument.isUnused.is_(True),
                    ),
                ),
            )).all()
            if not conv_docs:
                # NOTE(review): this also fires for an existing conversation that
                # simply has no unused files; the message/level may be misleading.
                logger.error("[DocumentManager] 对话不存在: %s", conversation_id)
                return []
            doc_ids = [conv_doc.documentId for conv_doc in conv_docs]
            docs = (await session.scalars(select(Document).where(Document.id.in_(doc_ids)))).all()
            return list(docs)

    @staticmethod
    async def get_used_docs_by_record(record_id: str, doc_type: str | None = None) -> list[Document]:
        """Return the documents linked to a record.

        Args:
            record_id: Record whose ``ConversationDocument`` links are looked up.
            doc_type: ``"question"``, ``"answer"`` or ``None``; any other value is
                logged and yields an empty list.

        Returns:
            The linked ``Document`` rows (each at most once); empty when the record
            has no linked documents.
        """
        if doc_type not in ["question", "answer", None]:
            logger.error("[DocumentManager] 参数错误: %s", doc_type)
            return []
        # NOTE(review): doc_type is only validated, never used to filter the links
        # (e.g. by ``ConversationDocument.associated``) — confirm what callers expect.
        async with postgres.session() as session:
            doc_ids = (await session.scalars(
                select(ConversationDocument.documentId).where(ConversationDocument.recordId == record_id),
            )).all()
            if not doc_ids:
                logger.info("[DocumentManager] 记录组不存在: %s", record_id)
                return []
            # A single IN query instead of one query per link.
            docs = (await session.scalars(select(Document).where(Document.id.in_(doc_ids)))).all()
            return list(docs)

    @staticmethod
    async def get_used_docs(
        conversation_id: uuid.UUID, record_num: int | None = 10, doc_type: str | None = None,
    ) -> list[Document]:
        """Return the documents used by the most recent records of a conversation.

        Args:
            conversation_id: Conversation to inspect.
            record_num: How many of the newest records (by ``createdAt``) to
                consider; ``None`` applies no LIMIT.
            doc_type: ``"question"``, ``"answer"`` or ``None``; any other value is
                logged and yields an empty list.

        Returns:
            Distinct ``Document`` rows linked to those records, in no particular order.
        """
        if doc_type not in ["question", "answer", None]:
            logger.error("[DocumentManager] 参数错误: %s", doc_type)
            return []
        # NOTE(review): doc_type is only validated, never used for filtering — confirm.
        async with postgres.session() as session:
            record_ids = (await session.scalars(
                select(Record.id)
                .where(Record.conversationId == conversation_id)
                .order_by(Record.createdAt.desc())
                .limit(record_num),
            )).all()
            if not record_ids:
                return []
            # Links for all selected records in one query; duplicates are harmless
            # because the final IN lookup returns each Document once.
            doc_ids = (await session.scalars(
                select(ConversationDocument.documentId).where(
                    ConversationDocument.recordId.in_(record_ids),
                ),
            )).all()
            if not doc_ids:
                return []
            docs = (await session.scalars(select(Document).where(Document.id.in_(doc_ids)))).all()
            return list(docs)

    @staticmethod
    async def delete_document(user_id: str, document_list: list[str]) -> None:
        """Delete documents owned by ``user_id`` that are still unused.

        For each ID, the ``Document`` must belong to the user and have at least one
        ``ConversationDocument`` link with ``isUnused`` set; otherwise the ID is
        logged and skipped. All deletions are committed in one transaction.

        Args:
            user_id: Owner the documents must belong to.
            document_list: IDs of the documents to delete.
        """
        async with postgres.session() as session:
            for doc_id in document_list:
                # .one_or_none(): a bare ScalarResult is always truthy, so it can
                # neither be used as a "found" check nor passed to session.delete().
                doc_info = (await session.scalars(
                    select(Document).where(
                        and_(
                            Document.id == doc_id,
                            Document.userId == user_id,
                        ),
                    ),
                )).one_or_none()
                if doc_info is None:
                    logger.error("[DocumentManager] 文件不存在: %s", doc_id)
                    continue
                conv_docs = (await session.scalars(
                    select(ConversationDocument).where(
                        and_(
                            ConversationDocument.documentId == doc_id,
                            ConversationDocument.isUnused.is_(True),
                        ),
                    ),
                )).all()
                if not conv_docs:
                    logger.error("[DocumentManager] 文件不存在或已使用: %s", doc_id)
                    continue
                # Remove the link rows before the document they reference.
                for conv_doc in conv_docs:
                    await session.delete(conv_doc)
                await session.delete(doc_info)
            await session.commit()

    @staticmethod
    async def delete_document_by_conversation_id(
        conversation_id: uuid.UUID, auth_header: str,
    ) -> list[str]:
        """Delete every ``Document`` of a conversation and remove them from RAG.

        Args:
            conversation_id: Conversation whose documents are deleted.
            auth_header: Authorization header forwarded to the knowledge-base service.

        Returns:
            The IDs (as strings) of the deleted documents.
        """
        async with postgres.session() as session:
            docs = (await session.scalars(
                select(Document).where(Document.conversationId == conversation_id),
            )).all()
            # Capture IDs before deleting/committing: attributes of deleted,
            # committed instances must not be accessed afterwards.
            doc_ids = [str(doc.id) for doc in docs]
            for doc in docs:
                await session.delete(doc)
            await session.commit()

        # Remote call made after the DB session is released; skipped when there is
        # nothing to delete.
        if doc_ids:
            await KnowledgeBaseService.delete_doc_from_rag(auth_header, doc_ids)
        return doc_ids

    @staticmethod
    async def get_doc_count(conversation_id: uuid.UUID) -> int:
        """Return the number of ``ConversationDocument`` links of a conversation.

        Args:
            conversation_id: Conversation to count links for.

        Returns:
            The link count (used and unused alike).
        """
        async with postgres.session() as session:
            return (await session.scalars(
                select(func.count(ConversationDocument.id)).where(
                    ConversationDocument.conversationId == conversation_id,
                ),
            )).one()

    @staticmethod
    async def change_doc_status(user_id: str, conversation_id: uuid.UUID) -> None:
        """Mark every unused document link of a user's conversation as used.

        Logs and returns without changes when the conversation does not exist or
        does not belong to ``user_id``.

        Args:
            user_id: Owner the conversation must belong to.
            conversation_id: Conversation whose links are updated.
        """
        async with postgres.session() as session:
            conversation = (await session.scalars(
                select(Conversation).where(
                    and_(
                        Conversation.id == conversation_id,
                        Conversation.userId == user_id,
                    ),
                ),
            )).one_or_none()
            if not conversation:
                logger.error("[DocumentManager] 对话不存在: %s", conversation_id)
                return
            docs = (await session.scalars(
                select(ConversationDocument).where(
                    and_(
                        ConversationDocument.conversationId == conversation_id,
                        ConversationDocument.isUnused.is_(True),
                    ),
                ),
            )).all()
            if not docs:
                return
            for doc in docs:
                doc.isUnused = False
            await session.commit()

    @staticmethod
    async def save_answer_doc(user_id: str, record_id: uuid.UUID, doc_infos: list[ConversationDocument]) -> None:
        """Attach document links to an answer record and mark them used (PostgreSQL).

        Each ``ConversationDocument`` whose ``id`` appears in ``doc_infos`` gets
        ``isUnused=False``, ``associated=ConvDocAssociated.ANSWER`` and
        ``recordId=record_id``. IDs not found in the database are ignored.

        Args:
            user_id: Owner the record must belong to; otherwise nothing is changed.
            record_id: Answer record to link the documents to.
            doc_infos: Link objects identifying (by ``id``) the rows to update.
        """
        async with postgres.session() as session:
            record = (await session.scalars(
                select(Record).where(
                    and_(
                        Record.id == record_id,
                        Record.userId == user_id,
                    ),
                ),
            )).one_or_none()
            if not record:
                logger.error("[DocumentManager] 记录不存在或非当前用户: %s", record_id)
                return
            link_ids = [doc_info.id for doc_info in doc_infos]
            if not link_ids:
                return
            # One IN query instead of one lookup per link.
            conv_docs = (await session.scalars(
                select(ConversationDocument).where(ConversationDocument.id.in_(link_ids)),
            )).all()
            for doc in conv_docs:
                doc.isUnused = False
                doc.associated = ConvDocAssociated.ANSWER
                doc.recordId = record_id
            await session.commit()