import os
import re
import logging
import uuid
import time
import inspect
import asyncio
import json
from functools import wraps
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Dict, Any, Union, Tuple, Optional
from fastapi import status, UploadFile
from openjiuwen.core.retrieval.indexing.processor.parser.auto_file_parser import AutoFileParser
from openjiuwen.core.retrieval.indexing.processor.chunker.chunking import TextChunker
from openjiuwen.core.retrieval.indexing.processor.extractor.triple_extractor import TripleExtractor
from openjiuwen.core.retrieval.indexing.indexer.milvus_indexer import MilvusIndexer
from openjiuwen.core.retrieval.vector_store.milvus_store import MilvusVectorStore
from openjiuwen.core.retrieval.simple_knowledge_base import SimpleKnowledgeBase
from openjiuwen.core.retrieval.graph_knowledge_base import GraphKnowledgeBase
from openjiuwen.core.retrieval.common.config import (
KnowledgeBaseConfig,
EmbeddingConfig,
VectorStoreConfig,
IndexConfig,
)
from openjiuwen.core.retrieval.common.document import Document, TextChunk
from openjiuwen.core.retrieval.embedding.api_embedding import APIEmbedding
from openjiuwen_deepsearch.framework.openjiuwen.llm.llm_model_factory import (
LLMModelFactory,
LLMModelParams,
)
from openjiuwen_deepsearch.utils.common_utils.embedding_utils import (
get_embedding_requests_verify,
)
from server.core.database import SessionLocal, milliseconds
from server.core.kb_obs_requirement import knowledge_base_requires_obs, kb_obs_misconfigured_message
from server.local_retrieval.core.manager.repositories.knowledge_base_repository import (
knowledge_base_repository,
)
from server.schemas.knowledge_base import (
EmbedModelConfig,
LLMConfig,
KnowledgeBaseCreate,
KnowledgeBaseResponseCreate,
KnowledgeBaseGet,
KnowledgeBaseUpdateRequest,
KnowledgeBaseInfo,
DocumentUploadResponse,
DocumentUploadBatchResponse,
KnowledgeBaseSearchRequest,
KnowledgeBaseSearchResponse,
KnowledgeBaseListRequest,
KnowledgeBaseListResponse,
KnowledgeBaseListItem,
DocumentStatusRequest,
DocumentStatusResponse,
DocumentStatusListResponse,
DocumentProcessRequest,
DocumentProcessResponse,
DocumentListRequest,
DocumentListResponse,
DocumentListItem,
DocumentUpdateRequest,
DocumentDeleteRequest,
TaskProgressRequest,
TaskProgressResponse,
TaskProgressItem,
)
from server.schemas.common import ResponseModel
from server.local_retrieval.models.knowledge_base_document import DocumentStatus
from server.core.manager.model_manager.utils import SecurityUtils
from server.local_retrieval.core.object.aioboto_storage_client import AioBotoClient
logger = logging.getLogger(__name__)
_RESILIENT_PDF_REGISTERED = False
def _ensure_resilient_pdf_parser() -> None:
"""Use PDF parser that clamps image bboxes; avoids pdfplumber crop failures on bad PDFs."""
global _RESILIENT_PDF_REGISTERED
if _RESILIENT_PDF_REGISTERED:
return
from server.local_retrieval.core.parser.resilient_pdf_parser import ResilientPDFParser
AutoFileParser.register_new_parser(".pdf", lambda: ResilientPDFParser())
_RESILIENT_PDF_REGISTERED = True
class OBSDocumentManager:
"""
Manages OBS documents and uploads/downloads them to/from OBS
"""
backend_dir = Path(__file__).resolve().parent.parent.parent.parent.parent
def __init__(self, bucket: Optional[str] = None):
self.bucket = bucket or os.getenv("OBS_BUCKET")
if not self.bucket:
logger.warning("[OBS] OBS_BUCKET not set, skipping upload_document")
if AioBotoClient is None:
logger.warning("[OBS] AioBotoClient not available, OBS operations will be no-op")
self.obs_client = None
return
server = os.getenv("OBS_SERVER")
region_name = os.getenv("OBS_REGION")
access_key_id = SecurityUtils.get_decrypted_secret(
"OBS_ACCESS_KEY_ID",
os.getenv("OBS_ACCESS_KEY_ID", None),
)
secret_access_key = SecurityUtils.get_decrypted_secret(
"OBS_SECRET_ACCESS_KEY",
os.getenv("OBS_SECRET_ACCESS_KEY", None),
)
self.obs_client = AioBotoClient(
server=server,
access_key_id=access_key_id,
secret_access_key=secret_access_key,
region_name=region_name,
)
@staticmethod
def obs_name(space_id: str, kb_id: str, file_name: str) -> str:
"""生成对象存储中的文件对象名。"""
return f"{space_id}/{kb_id}/{file_name}"
@classmethod
def local_path(cls, space_id: str, kb_id: str, file_name: str) -> Path:
"""生成本地文件存储路径。"""
storage_path = cls.backend_dir / "data" / "knowledge_base" / space_id / kb_id
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path / file_name
async def delete_document(self, object_name: str):
"""删除对象存储中的文档对象。"""
if not self.bucket or not self.obs_client:
return
await self.obs_client.delete_object(self.bucket, object_name)
async def download_document(
self,
object_name: str,
file_path: str | Path,
):
"""从对象存储下载文档到本地路径。"""
if not self.bucket or not self.obs_client:
return
file_path = Path(file_path)
file_dir = file_path.parent
if not os.path.isdir(file_dir):
file_dir.mkdir(parents=True, exist_ok=True)
await self.obs_client.download_file(self.bucket, object_name, file_path)
async def upload_document(
self,
object_name: str,
file_path: str | Path,
):
"""上传文档到对象存储并返回对象信息。"""
if not self.bucket or not self.obs_client:
raise RuntimeError(
"OBS upload skipped: OBS_BUCKET unset or object storage client unavailable"
)
ok = await self.obs_client.upload_file(self.bucket, object_name, file_path)
if not ok:
raise RuntimeError(f"OBS upload failed for object key {object_name!r} (storage returned False)")
async def download_if_updated(
self,
object_name: str,
file_path: str,
):
"""当远端文档更新时下载并覆盖本地缓存。"""
if not self.bucket or not self.obs_client:
return
listed_objects = await self.obs_client.list_objects(
self.bucket, object_prefix=object_name, max_objects=1
)
if not listed_objects:
logger.info("No matching objects found on OBS, skipping download.")
return
obs_last_modified = listed_objects[0].get("LastModified")
if not obs_last_modified:
logger.info("OBS object missing LastModified, skipping download.")
return
if not os.path.exists(file_path):
await self.download_document(object_name, file_path)
logger.info("Local file missing, downloaded from OBS.")
return
local_mtime = os.path.getmtime(file_path)
local_modified = datetime.fromtimestamp(local_mtime, tz=timezone.utc)
if obs_last_modified <= local_modified:
logger.info("Local file is up to date, skipping download.")
return
await self.download_document(object_name, file_path)
logger.info("Downloaded updated file.")
class LocalSimpleKnowledgeBase(SimpleKnowledgeBase):
    """``SimpleKnowledgeBase`` that indexes chunks into ``ds_kb_<kb_id>_chunks``."""

    async def add_documents(
        self,
        documents: List[Document],
        **kwargs: Any,
    ) -> List[str]:
        """Chunk ``documents`` and index the chunks into this KB's chunk index.

        Args:
            documents: Parsed documents to add.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The ``id_`` of every input document, in input order.

        Raises:
            ValueError: If ``chunker`` or ``index_manager`` is not configured.
            RuntimeError: If ``index_manager.build_index`` reports failure.
        """
        for required in ("chunker", "index_manager"):
            if not getattr(self, required):
                raise ValueError(f"{required} is required for add_documents")

        chunks = self.chunker.chunk_documents(documents)
        logger.info(f"Chunked {len(documents)} documents into {len(chunks)} chunks")

        built = await self.index_manager.build_index(
            chunks=chunks,
            config=IndexConfig(
                index_name=f"ds_kb_{self.config.kb_id}_chunks",
                index_type=self.config.index_type,
            ),
            embed_model=self.embed_model,
        )
        if not built:
            raise RuntimeError("Failed to build index")

        added_ids = [document.id_ for document in documents]
        logger.info(f"Successfully added {len(added_ids)} documents to knowledge base")
        return added_ids
class LocalGraphKnowledgeBase(GraphKnowledgeBase):
    """``GraphKnowledgeBase`` that stores chunks and triples in per-KB indexes.

    Chunk vectors go to ``ds_kb_<kb_id>_chunks``. When ``config.use_graph`` is
    set and an ``extractor`` is configured, extracted triples are embedded into
    ``ds_kb_<kb_id>_triples``.
    """

    async def add_documents(
        self,
        documents: List[Document],
        **kwargs: Any,
    ) -> List[str]:
        """Chunk and index documents, then optionally extract and index triples.

        Args:
            documents: Parsed documents to add.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The ``id_`` of every input document, in input order.

        Raises:
            ValueError: If ``chunker`` or ``index_manager`` is not configured.
            RuntimeError: If building the chunk index fails. A triple-index
                failure is only logged (best effort).
        """
        if not self.chunker:
            raise ValueError("chunker is required for add_documents")
        if not self.index_manager:
            raise ValueError("index_manager is required for add_documents")

        chunks = self.chunker.chunk_documents(documents)
        logger.info(f"Chunked {len(documents)} documents into {len(chunks)} chunks")
        chunk_index_config = IndexConfig(
            index_name=f"ds_kb_{self.config.kb_id}_chunks",
            index_type=self.config.index_type,
        )
        success = await self.index_manager.build_index(
            chunks=chunks,
            config=chunk_index_config,
            embed_model=self.embed_model,
        )
        if not success:
            raise RuntimeError("Failed to build chunk index")

        if self.config.use_graph and self.extractor:
            logger.info("Extracting triples for graph index...")
            triples = await self.extractor.extract(chunks)
            if triples:
                logger.info(f"Extracted {len(triples)} triples")
                triple_index_config = IndexConfig(
                    index_name=f"ds_kb_{self.config.kb_id}_triples",
                    index_type=self.config.index_type,
                )
                triple_chunks = []
                for i, triple in enumerate(triples):
                    # Tolerate extractors that leave metadata unset (None).
                    triple_metadata = triple.metadata or {}
                    triple_chunks.append(
                        TextChunk(
                            # The id must be unique across calls. A bare per-call counter
                            # ("triple_0", "triple_1", ...) restarts at 0 on every upload,
                            # so ids from different documents in the same index would
                            # repeat. The position is still kept in "chunk_index".
                            id_=f"triple_{uuid.uuid4().hex}",
                            text=f"{triple.subject} {triple.predicate} {triple.object}",
                            doc_id=triple_metadata.get("doc_id", ""),
                            metadata={
                                **triple_metadata,
                                "triple": json.dumps(
                                    [triple.subject, triple.predicate, triple.object]
                                ),
                                "confidence": triple.confidence or 0,
                                "chunk_index": i,
                            },
                        )
                    )
                success = await self.index_manager.build_index(
                    chunks=triple_chunks,
                    config=triple_index_config,
                    embed_model=self.embed_model,
                )
                if not success:
                    logger.error("Failed to build triple index")
                else:
                    logger.info(f"Built triple index with {len(triple_chunks)} triples")

        doc_ids = [doc.id_ for doc in documents]
        logger.info(f"Successfully added {len(doc_ids)} documents to knowledge base")
        return doc_ids
def _extract_full_error_message(error: Exception) -> str:
"""提取完整的错误信息,包括异常链中的所有错误
用于提取 openjiuwen 包抛出的异常信息,因为 openjiuwen 包内部可能捕获异常后
使用 cause 参数重新抛出,形成异常链。
Args:
error: 异常对象
Returns:
完整的错误信息字符串,包含所有异常链中的错误
"""
error_parts = []
current_error = error
while current_error is not None:
error_str = str(current_error)
if error_str:
error_parts.append(error_str)
if hasattr(current_error, '__cause__') and current_error.__cause__:
current_error = current_error.__cause__
elif hasattr(current_error, '__context__') and current_error.__context__:
current_error = current_error.__context__
else:
break
if len(error_parts) == 1:
return error_parts[0]
return " -> ".join(error_parts)
def _format_error_message_for_frontend(error_msg: str) -> str:
"""格式化错误信息供前端显示
改写规则:
1. 固定错误消息保持不变
2. 带前缀的错误:去掉前缀、状态码、箭头(替换为分号)
3. 在 "reason" 之前截断(如果存在)
4. 确保首字母大写
Args:
error_msg: 原始错误信息
Returns:
格式化后的错误信息
"""
if not error_msg:
return error_msg
fixed_messages = {
"Document not found",
"Document status invalid",
"File path not found",
"Failed to update document status",
"Document validation failed",
"Processing failed with unknown error",
"Failed to update status to INDEXED",
}
if error_msg in fixed_messages:
return error_msg
result = error_msg
prefixes = [
"File parsing failed: ",
"Index building failed: ",
"Failed to update status to INDEXING: ",
]
for prefix in prefixes:
if result.startswith(prefix):
result = result[len(prefix):]
break
result = re.sub(r'\[\d+\]\s*', '', result)
result = re.sub(r'\s*->\s*', '; ', result)
reason_pattern = r',\s*reason\s*:'
match = re.search(reason_pattern, result, re.IGNORECASE)
if match:
result = result[:match.start()].strip()
result = ' '.join(result.split())
if result:
result = result[0].upper() + result[1:] if len(result) > 1 else result.upper()
return result
def _create_llm_client(llm_config: LLMConfig):
    """Create an LLM client from the request-level ``LLMConfig``.

    Args:
        llm_config: LLM settings from the request or the stored KB config.
            ``api_key`` may be ``str``, ``bytes``/``bytearray`` (decoded as
            UTF-8) or ``None`` (treated as an empty key).

    Returns:
        Tuple ``(llm_client, model_name)``.
    """
    logger.info(
        f"[LLM_CLIENT] Creating LLM client from request config - "
        f"Model: {llm_config.model_name}, Type: {llm_config.model_type}"
    )
    raw_key = llm_config.api_key
    if raw_key is None:
        # str(None) would yield the literal "None" and send it as the key.
        api_key = ""
    elif isinstance(raw_key, (bytes, bytearray)):
        api_key = raw_key.decode("utf-8")
    else:
        api_key = str(raw_key)
    timeout = 120
    llm_params = LLMModelParams(
        model_provider=llm_config.model_type,
        api_key=api_key,
        api_base=llm_config.base_url or "",
        timeout=timeout,
        hyper_parameters=llm_config.hyper_parameters or None,
        extension=llm_config.extension or None,
    )
    llm_client = LLMModelFactory().get_model(llm_params)
    # Apply the requested model name when the client exposes a model_config.
    if llm_config.model_name and getattr(llm_client, "model_config", None) is not None:
        llm_client.model_config.model_name = llm_config.model_name
    logger.info(
        f"[LLM_CLIENT] LLM client created - Model: {llm_config.model_name}"
    )
    return llm_client, llm_config.model_name
def _create_embed_model(embed_model_config: EmbedModelConfig) -> APIEmbedding:
    """Create an ``APIEmbedding`` from the request-level ``EmbedModelConfig``.

    Side effect: when ``get_embedding_requests_verify`` returns a non-empty
    string, it is stored in the process-wide ``REQUESTS_CA_BUNDLE``
    environment variable.

    Args:
        embed_model_config: Embedding settings. ``api_key`` may be ``str``,
            ``bytes``/``bytearray`` (decoded as UTF-8) or ``None`` (treated as
            an empty key).

    Returns:
        The configured ``APIEmbedding`` instance.
    """
    logger.info(
        f"[EMBED_MODEL] Creating embed model from request config - "
        f"Model: {embed_model_config.model_name}"
    )
    raw_key = embed_model_config.api_key
    if raw_key is None:
        # str(None) would yield the literal "None" and send it as the key.
        api_key = ""
    elif isinstance(raw_key, (bytes, bytearray)):
        api_key = raw_key.decode("utf-8")
    else:
        api_key = str(raw_key)

    verify = get_embedding_requests_verify(embed_model_config.base_url)
    if isinstance(verify, str) and verify:
        # NOTE(review): this mutates a process-global env var, so it affects every
        # requests-based client in the process, not only this embedding model.
        os.environ["REQUESTS_CA_BUNDLE"] = verify

    embed_config = EmbeddingConfig(
        model_name=embed_model_config.model_name,
        api_key=api_key,
        base_url=embed_model_config.base_url,
    )
    embed_model = APIEmbedding(
        config=embed_config,
        timeout=embed_model_config.timeout,
        max_retries=embed_model_config.max_retries,
        max_batch_size=embed_model_config.max_batch_size,
    )
    logger.debug("[EMBED_MODEL] Embed model created from request config successfully")
    return embed_model
def _config_dict_to_embed_and_llm(config: dict) -> tuple[EmbedModelConfig, LLMConfig]:
    """Parse a stored KB ``config`` dict into ``EmbedModelConfig`` and ``LLMConfig``.

    Used e.g. for document processing. ``config`` must contain the
    ``"embed_model_config"`` and ``"llm_config"`` sub-dicts, with ``api_key``
    stored as a UTF-8 string. Missing or null numeric fields fall back to
    defaults: batch size 1, timeout 60, 3 retries. The model type defaults to
    ``"openai"``.

    Args:
        config: The knowledge base's stored config.

    Returns:
        ``(embed_model_config, llm_config)``.

    Raises:
        ValueError: If ``config`` or either sub-dict is empty or missing.
    """
    if not config:
        raise ValueError("知识库 config 为空,无法读取 embed_model_config/llm_config")
    embed_dict = config.get("embed_model_config")
    llm_dict = config.get("llm_config")
    if not embed_dict:
        raise ValueError("知识库 config 中缺少 embed_model_config")
    if not llm_dict:
        raise ValueError("知识库 config 中缺少 llm_config")

    def _as_int(value: Any, default: int) -> int:
        # dict.get(key, default) still returns None when the key is stored with a
        # null value (e.g. an optional field dumped by model_dump()). int(None)
        # raises TypeError, so use the default in that case (0 stays 0).
        return default if value is None else int(value)

    embed_model_config = EmbedModelConfig(
        model_name=embed_dict.get("model_name", ""),
        api_key=embed_dict.get("api_key") or "",
        base_url=embed_dict.get("base_url", ""),
        max_batch_size=_as_int(embed_dict.get("max_batch_size"), 1),
        timeout=_as_int(embed_dict.get("timeout"), 60),
        max_retries=_as_int(embed_dict.get("max_retries"), 3),
    )
    llm_config = LLMConfig(
        model_name=llm_dict.get("model_name", ""),
        model_type=llm_dict.get("model_type", "openai"),
        base_url=llm_dict.get("base_url", ""),
        api_key=llm_dict.get("api_key") or "",
        hyper_parameters=llm_dict.get("hyper_parameters") or {},
        extension=llm_dict.get("extension") or {},
    )
    return embed_model_config, llm_config
def with_exception_handling(func):
    """Decorator that converts uncaught exceptions into HTTP 500 responses.

    Supports both plain and ``async`` functions; ``inspect.iscoroutinefunction``
    picks the wrapper. On ``Exception`` the error is logged with its traceback
    and a ``ResponseModel`` with code 500 and the exception text is returned
    instead of raising.
    """

    def _to_error_response(exc: Exception):
        # Called from inside an ``except`` block, so exc_info=True still sees
        # the active exception.
        logger.error(f"[KNOWLEDGE_BASE] Error in {func.__name__}: {str(exc)}", exc_info=True)
        return ResponseModel(
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            message=f"Internal server error: {str(exc)}",
        )

    if not inspect.iscoroutinefunction(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                return _to_error_response(exc)

        return wrapper

    @wraps(func)
    async def async_wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception as exc:
            return _to_error_response(exc)

    return async_wrapper
def _make_json_serializable_dict(d: dict) -> dict:
out: Dict[str, Any] = {}
for k, v in d.items():
if isinstance(v, (bytes, bytearray)):
out[k] = v.decode("utf-8") if v else ""
elif isinstance(v, dict):
out[k] = _make_json_serializable_dict(v)
else:
out[k] = v
return out
def _build_kb_stored_config(
    embed_model_config: EmbedModelConfig,
    llm_config: LLMConfig,
    extra: Optional[Dict[str, Any]] = None,
) -> dict:
    """Build the ``config`` dict persisted with a knowledge base.

    Starts from a shallow copy of ``extra`` (if any). Then it sets
    ``"embed_model_config"`` and ``"llm_config"`` to JSON-safe dumps of the two
    model configs, overriding same-named keys from ``extra``. Shared by the
    create and update paths.
    """
    stored: Dict[str, Any] = {}
    if extra:
        stored.update(extra)
    stored.update(
        embed_model_config=_make_json_serializable_dict(embed_model_config.model_dump()),
        llm_config=_make_json_serializable_dict(llm_config.model_dump()),
    )
    return stored
@with_exception_handling
def knowledge_base_create(req: KnowledgeBaseCreate) -> ResponseModel:
    """Create a knowledge base after validating its name and the index backend.

    Steps:
      1. Reject a name already used in ``req.space_id`` (400).
      2. Verify the index backend is reachable (``_check_index_connection``).
      3. Persist a record with a fresh hex UUID as ``kb_id`` and the merged
         stored config.

    Returns:
        ``ResponseModel`` with 200 and the new id in ``data``, or the failing
        step's code and message.
    """
    start_time = time.time()
    logger.info(f"[KB_CREATE] Creating knowledge base - Name: {req.name}")
    name_exists_result = knowledge_base_repository.knowledge_base_check_name_exists(
        space_id=req.space_id, name=req.name
    )
    if name_exists_result.code != status.HTTP_200_OK:
        logger.error(
            f"[KB_CREATE] Failed to check name existence - Error: {name_exists_result.message}"
        )
        return ResponseModel(
            code=name_exists_result.code,
            message=name_exists_result.message,
        )
    if name_exists_result.data:
        logger.warning(
            f"[KB_CREATE] Knowledge base name already exists - Name: {req.name}, Space: {req.space_id}"
        )
        return ResponseModel(
            code=status.HTTP_400_BAD_REQUEST,
            message=f"知识库名称 '{req.name}' 已存在",
        )

    index_conn = _check_index_connection()
    if index_conn is not None:
        return index_conn

    kb_id = uuid.uuid4().hex
    logger.info(f"[KB_CREATE] Generated KB ID: {kb_id}")
    config = _build_kb_stored_config(
        req.embed_model_config, req.llm_config, req.config
    )
    # Use one timestamp for both fields so a fresh record has
    # create_time == update_time (two milliseconds() calls could differ).
    now = milliseconds()
    kb_data = {
        "space_id": req.space_id,
        "kb_id": kb_id,
        "name": req.name,
        "description": req.description,
        "config": config,
        "create_time": now,
        "update_time": now,
    }
    create_result = knowledge_base_repository.knowledge_base_create(kb_data)
    if create_result.code != status.HTTP_200_OK:
        logger.error(
            f"[KB_CREATE] Database save failed - ID: {kb_id}, Error: {create_result.message}"
        )
        return ResponseModel(
            code=create_result.code,
            message=create_result.message,
        )

    response_data = KnowledgeBaseResponseCreate(id=kb_id)
    logger.info(
        f"[KB_CREATE] Knowledge base created - ID: {kb_id}, Duration: {time.time() - start_time:.3f}s"
    )
    return ResponseModel(
        code=status.HTTP_200_OK,
        message="create knowledge base success",
        data=response_data.model_dump(by_alias=False),
    )
@with_exception_handling
async def knowledge_base_delete(req: KnowledgeBaseGet) -> ResponseModel:
    """Delete a knowledge base: its DB record, local files and vector indices.

    Only the DB deletion decides the response. Removing the local directory
    and the indices is best effort: failures are logged, not returned.

    Args:
        req: Identifies the KB by ``space_id`` and ``kb_id``.

    Returns:
        ``ResponseModel`` with 200 on success, 404 if the KB does not exist, or
        the repository's error code if the DB deletion fails.
    """
    start_time = time.time()
    logger.info(f"[KB_DELETE] Deleting knowledge base - KB ID: {req.kb_id}")
    get_result = knowledge_base_repository.knowledge_base_get(req)
    if get_result.code == status.HTTP_404_NOT_FOUND:
        logger.warning(f"[KB_DELETE] Knowledge base not found - ID: {req.kb_id}")
        return ResponseModel(code=status.HTTP_404_NOT_FOUND, message="Knowledge base not found")

    delete_result = knowledge_base_repository.knowledge_base_delete(req)
    if delete_result.code != status.HTTP_200_OK:
        logger.error(f"[KB_DELETE] Delete failed - ID: {req.kb_id}, Error: {delete_result.message}")
        return ResponseModel(
            code=delete_result.code,
            message=delete_result.message,
        )
    logger.info(
        f"[KB_DELETE] Knowledge base deleted - ID: {req.kb_id}, "
        f"Duration: {time.time() - start_time:.3f}s"
    )

    # Build the path WITHOUT _get_storage_path(): that helper mkdirs the
    # directory, which made the exists() check below always true. The
    # "not found" branch was unreachable, and a missing directory was created
    # only to be removed again. Same layout as OBSDocumentManager.local_path.
    kb_storage_path = (
        OBSDocumentManager.backend_dir / "data" / "knowledge_base" / req.space_id / req.kb_id
    )
    try:
        if kb_storage_path.exists():
            import shutil
            shutil.rmtree(kb_storage_path)
            logger.info(
                f"[KB_DELETE] Local knowledge base directory deleted - Path: {kb_storage_path}"
            )
        else:
            logger.warning(
                f"[KB_DELETE] Local knowledge base directory not found - Path: {kb_storage_path}"
            )
    except Exception as e:
        logger.error(
            f"[KB_DELETE] Failed to delete local knowledge base directory - Path: {kb_storage_path}, Error: {str(e)}",
            exc_info=True,
        )

    # NOTE(review): indices are removed after the DB record. _delete_kb_indices
    # finds documents via knowledge_base_repository.document_list. If the
    # repository's knowledge_base_delete also removes the KB's document rows,
    # no documents are found and the vectors stay in Milvus. TODO confirm.
    try:
        index_result = await _delete_kb_indices(req.kb_id, req.space_id)
        if index_result["success_count"] > 0:
            logger.info(
                f"[KB_DELETE] Indices successfully deleted - KB ID: {req.kb_id}, "
                f"Success: {index_result['success_count']}, Failed: {index_result['failed_count']}"
            )
        if index_result["errors"]:
            logger.warning(
                f"[KB_DELETE] Some indices failed to delete - KB ID: {req.kb_id}, "
                f"Errors: {index_result['errors']}"
            )
    except Exception as e:
        logger.error(
            f"[KB_DELETE] Failed to delete indices - KB ID: {req.kb_id}, Error: {str(e)}",
            exc_info=True,
        )

    return ResponseModel(
        code=status.HTTP_200_OK, message="delete knowledge base success", data=None
    )
@with_exception_handling
def knowledge_base_update(req: KnowledgeBaseUpdateRequest) -> ResponseModel:
    """Update a knowledge base's name, description and stored config.

    Steps:
      1. Load the KB (404 if it is missing or has no data).
      2. On a rename, reject a name already used by another KB in the same
         space (400).
      3. Persist the new values. An empty ``desc`` is stored as ``None``. The
         stored config is rebuilt from ``req.embed_model_config``,
         ``req.llm_config`` and ``req.config``.

    Returns:
        ``ResponseModel`` with 200 on success, otherwise the failing step's code.
    """
    started_at = time.time()
    logger.info(
        f"[KB_UPDATE] Updating knowledge base - KB ID: {req.kb_id}, "
        f"Name: {req.name}, Desc: {repr(req.desc)}"
    )

    existing = knowledge_base_repository.knowledge_base_get(
        KnowledgeBaseGet(space_id=req.space_id, kb_id=req.kb_id)
    )
    if existing.code == status.HTTP_404_NOT_FOUND or not existing.data:
        logger.warning(f"[KB_UPDATE] Knowledge base not found - ID: {req.kb_id}")
        return ResponseModel(code=status.HTTP_404_NOT_FOUND, message="Knowledge base not found")

    stored_kb = existing.data
    old_name = stored_kb.get("name", "")
    old_desc = stored_kb.get("description", "")
    logger.info(
        f"[KB_UPDATE] Current description: {repr(old_desc)}, New description: {repr(req.desc)}"
    )

    renamed = req.name != old_name
    if renamed:
        duplicate_check = knowledge_base_repository.knowledge_base_check_name_exists(
            space_id=req.space_id, name=req.name, exclude_kb_id=req.kb_id
        )
        if duplicate_check.code != status.HTTP_200_OK:
            logger.error(
                f"[KB_UPDATE] Failed to check name existence - Error: {duplicate_check.message}"
            )
            return ResponseModel(code=duplicate_check.code, message=duplicate_check.message)
        if duplicate_check.data:
            logger.warning(
                f"[KB_UPDATE] Knowledge base name already exists - Name: {req.name}, "
                f"Space: {req.space_id}, KB ID: {req.kb_id}"
            )
            return ResponseModel(
                code=status.HTTP_400_BAD_REQUEST,
                message=f"知识库名称 '{req.name}' 已存在",
            )

    outcome = knowledge_base_repository.knowledge_base_update(
        space_id=req.space_id,
        kb_id=req.kb_id,
        name=req.name,
        description=req.desc or None,
        config=_build_kb_stored_config(req.embed_model_config, req.llm_config, req.config),
    )
    if outcome.code != status.HTTP_200_OK:
        logger.error(f"[KB_UPDATE] Update failed - ID: {req.kb_id}, Error: {outcome.message}")
        return ResponseModel(code=outcome.code, message=outcome.message)

    logger.info(
        f"[KB_UPDATE] Knowledge base updated - ID: {req.kb_id}, "
        f"Duration: {time.time() - started_at:.3f}s"
    )
    return ResponseModel(
        code=status.HTTP_200_OK, message="update knowledge base message success", data=None
    )
def _get_storage_path(space_id: str, kb_id: str) -> Path:
"""获取知识库文件存储路径"""
backend_dir = Path(__file__).resolve().parent.parent.parent.parent.parent
storage_path = backend_dir / "data" / "knowledge_base" / space_id / kb_id
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def _get_file_type(filename: str) -> str:
"""根据文件名获取文件类型"""
return Path(filename).suffix.lower().lstrip(".")
def _get_mime_type(file_type: str) -> str:
"""根据文件类型获取 MIME 类型"""
mime_types = {
"pdf": "application/pdf",
"doc": "application/msword",
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"txt": "text/plain",
"md": "text/markdown",
"xls": "application/vnd.ms-excel",
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"ppt": "application/vnd.ms-powerpoint",
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
}
return mime_types.get(file_type.lower(), "application/octet-stream")
def _detect_real_file_type(file_path: str) -> str:
"""检测文件的真实格式(通过文件头魔数)。
Args:
file_path: 文件路径
Returns:
检测到的真实文件扩展名,如 '.docx', '.doc', '.pdf' 等
如果无法识别则返回原扩展名
"""
try:
with open(file_path, "rb") as f:
header = f.read(8)
if header[:4] == b"PK\x03\x04":
return ".docx"
if header[:8] == b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1":
return ".doc"
if header[:4] == b"%PDF":
return ".pdf"
except Exception as e:
logger.warning(f"[PARSE] Failed to detect file type: {file_path}, Error: {e}")
return Path(file_path).suffix.lower()
def _get_corrected_file_path(original_path: str) -> str:
    """Return a path whose extension matches the file's real format.

    Only one mismatch is corrected: a ``.doc`` file whose content is really
    ``.docx``. In that case the file is copied (with metadata) to a NEW
    sibling with the ``.docx`` extension, and the copy's path is returned. The
    caller owns and deletes that copy. In every other case, or if the copy
    fails, ``original_path`` is returned unchanged.

    Args:
        original_path: Path of the uploaded file.

    Returns:
        ``original_path`` or the path of the temporary copy.
    """
    original_ext = Path(original_path).suffix.lower()
    real_ext = _detect_real_file_type(original_path)
    if original_ext == real_ext:
        return original_path
    if not (original_ext == ".doc" and real_ext == ".docx"):
        return original_path

    logger.info(
        f"[PARSE] File format mismatch detected - "
        f"Extension: {original_ext}, Real format: {real_ext}, "
        f"Path: {original_path}"
    )
    import shutil

    original_path_obj = Path(original_path)
    corrected_path = original_path_obj.with_suffix(real_ext)
    if corrected_path.exists():
        # Never hand back a pre-existing file. It may be a different document
        # stored next to this one (e.g. "report.docx" beside "report.doc"), and
        # _parse_file deletes the returned temporary path after parsing.
        corrected_path = original_path_obj.with_name(
            f"{original_path_obj.stem}.{uuid.uuid4().hex[:8]}{real_ext}"
        )
    try:
        shutil.copy2(original_path, corrected_path)
        logger.info(
            f"[PARSE] Created temporary file with correct extension: {corrected_path}"
        )
    except Exception as e:
        logger.warning(
            f"[PARSE] Failed to create temporary file with correct extension: {str(e)}. "
            f"Using original path: {original_path}"
        )
        return original_path
    return str(corrected_path)
async def _parse_file(
    doc_path: str, parsing_strategy, doc_id: str, file_name: str = None
) -> List[Document]:
    """Parse a file into ``Document`` objects with ``AutoFileParser``.

    If ``_get_corrected_file_path`` returns a different path (a temporary copy
    with a corrected extension), that path is parsed instead and deleted
    afterwards, whether or not parsing succeeds. Each returned document gets
    ``metadata["doc_id"]`` set to its own ``id_``.

    Args:
        doc_path: Local path of the file.
        parsing_strategy: Only ``strategy_type`` is read, for logging.
        doc_id: Id forwarded to the parser.
        file_name: Name forwarded to the parser. Defaults to the file name of
            the path actually parsed.

    Returns:
        The parsed documents (never empty).

    Raises:
        ValueError: If ``doc_path`` is empty or nothing could be parsed.
    """
    logger.debug(
        "[PARSE] Parsing file - Path: %s, Strategy type: %s",
        doc_path,
        parsing_strategy.strategy_type,
    )
    if not doc_path:
        raise ValueError("File path is empty")

    _ensure_resilient_pdf_parser()
    parse_path = _get_corrected_file_path(doc_path)
    uses_temp_copy = parse_path != doc_path
    if uses_temp_copy:
        logger.info(
            f"[PARSE] Using corrected file path - Original: {doc_path}, Corrected: {parse_path}"
        )

    try:
        parsed = await AutoFileParser().parse(
            doc=parse_path,
            doc_id=doc_id,
            file_name=file_name or Path(parse_path).name,
        )
        if not parsed:
            raise ValueError(f"No content parsed from file: {doc_path}")
        for item in parsed:
            if item.metadata is None:
                item.metadata = {}
            item.metadata["doc_id"] = item.id_
        logger.debug("[PARSE] Parsed file - Path: %s, Documents: %s", doc_path, len(parsed))
        return parsed
    finally:
        if uses_temp_copy:
            try:
                temp_copy = Path(parse_path)
                if temp_copy.exists() and temp_copy != Path(doc_path):
                    temp_copy.unlink()
                    logger.debug("[PARSE] Cleaned up temporary file: %s", parse_path)
            except Exception as e:
                logger.warning(
                    f"[PARSE] Failed to clean up temporary file {parse_path}: {str(e)}"
                )
def _resolve_chunking_config(segmentation_strategy) -> tuple[int, float, Dict[str, bool]]:
"""提取分段配置,兼容前端字段命名"""
cfg = segmentation_strategy.strategy_config or {}
chunk_size = int(cfg.get("max_tokens") or cfg.get("chunk_size") or 512)
overlap_percent = float(cfg.get("chunk_overlap_percent") or cfg.get("chunk_overlap") or 0)
preprocess_options = {
"normalize_whitespace": bool(
cfg.get("remove_extra_spaces") or cfg.get("normalize_whitespace") or False
),
"remove_url_email": bool(
cfg.get("remove_urls_emails") or cfg.get("remove_url_email") or False
),
}
return chunk_size, overlap_percent, preprocess_options
def _create_chunker(segmentation_strategy, embed_model=None) -> TextChunker:
    """Build a ``TextChunker`` from a segmentation strategy.

    The overlap is the configured percentage of the chunk size, truncated to
    an int (0 when the percentage is not positive).
    ``strategy_config["chunk_unit"]`` selects the unit (default ``"char"``).
    ``embed_model`` is passed on only for the ``"token"`` unit, and preprocess
    options only when at least one of them is enabled.
    """
    chunk_size, overlap_percent, preprocess_options = _resolve_chunking_config(
        segmentation_strategy
    )
    chunk_unit = (segmentation_strategy.strategy_config or {}).get("chunk_unit", "char")

    chunk_overlap = 0
    if overlap_percent > 0:
        chunk_overlap = int(chunk_size * (overlap_percent / 100))

    logger.debug(
        "[CHUNK] Creating chunker - Chunk size: %s, Overlap: %s (%s%%), Unit: %s, Preprocess: %s",
        chunk_size,
        chunk_overlap,
        overlap_percent,
        chunk_unit,
        preprocess_options,
    )
    token_based = chunk_unit == "token"
    any_preprocessing = any(preprocess_options.values())
    return TextChunker(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        chunk_unit=chunk_unit,
        embed_model=embed_model if token_based else None,
        preprocess_options=preprocess_options if any_preprocessing else None,
    )
def _check_milvus_connection() -> Tuple[bool, str]:
"""检查 Milvus 连接性
Returns:
tuple[bool, str]: (是否连接成功, 错误信息)
"""
try:
from pymilvus import connections, utility
except ImportError:
return False, "无法连接到 Milvus: 未安装 pymilvus 库"
milvus_host = os.getenv("MILVUS_HOST", "localhost")
milvus_port = os.getenv("MILVUS_PORT", "19530")
milvus_token = os.getenv("MILVUS_TOKEN") or ""
alias = "kb_connection_test"
try:
try:
if connections.has_connection(alias):
connections.disconnect(alias)
except Exception as e:
logger.warning(f"[MILVUS] Failed to disconnect connection: {alias}, Error: {str(e)}")
connections.connect(
alias=alias, host=milvus_host, port=int(milvus_port), token=milvus_token
)
try:
_ = utility.list_collections(using=alias)
except Exception as e:
try:
connections.disconnect(alias)
except Exception as disconnect_error:
logger.warning(
f"[MILVUS] Failed to disconnect connection: {alias}, Error: {str(disconnect_error)}"
)
return False, f"无法访问 Milvus 服务: {str(e)}"
try:
connections.disconnect(alias)
except Exception as e:
logger.warning(f"[MILVUS] Failed to disconnect connection: {alias}, Error: {str(e)}")
return True, ""
except Exception as e:
error_msg = str(e)
try:
if connections.has_connection(alias):
connections.disconnect(alias)
except Exception as disconnect_error:
logger.warning(
f"[MILVUS] Failed to disconnect connection: {alias}, Error: {str(disconnect_error)}"
)
return False, f"Milvus 连接失败: {error_msg}"
def _check_index_connection() -> Union[ResponseModel, None]:
    """Check that the index backend selected in ``.env`` is reachable.

    ``INDEX_MANAGER_TYPE`` (default ``"milvus"``) selects the backend. Only
    Milvus is actually probed; any other value passes without a check.

    Returns:
        ``None`` when the check passes or is skipped. Otherwise a 503
        ``ResponseModel`` that includes the Milvus error.
    """
    if os.getenv("INDEX_MANAGER_TYPE", "milvus") != "milvus":
        return None

    logger.info("[KB_CREATE] Checking Milvus connection...")
    reachable, failure = _check_milvus_connection()
    if reachable:
        logger.info("[KB_CREATE] Milvus connection check passed")
        return None

    logger.error(f"[KB_CREATE] Milvus connection check failed - Error: {failure}")
    return ResponseModel(
        code=status.HTTP_503_SERVICE_UNAVAILABLE,
        message=(
            f"无法连接到 Milvus 服务,请检查 Milvus 配置和连接状态。"
            f"错误信息: {failure}"
        ),
    )
def _create_index_manager(collection_name: str) -> MilvusIndexer:
    """Create the index manager selected by ``INDEX_MANAGER_TYPE`` (default ``"milvus"``).

    Milvus is reached at ``http://$MILVUS_HOST:$MILVUS_PORT`` (defaults
    ``localhost`` and ``19530``) with ``MILVUS_TOKEN`` (default empty).

    Args:
        collection_name: Collection name for the indexer's vector-store config.

    Returns:
        A ``MilvusIndexer``.

    Raises:
        ValueError: For any other ``INDEX_MANAGER_TYPE``.
    """
    index_manager_type = os.getenv("INDEX_MANAGER_TYPE", "milvus")
    if index_manager_type != "milvus":
        raise ValueError(
            f"Un-supported {index_manager_type=} for env variable INDEX_MANAGER_TYPE"
        )

    host = os.getenv("MILVUS_HOST", "localhost")
    port = os.getenv("MILVUS_PORT", "19530")
    token = os.getenv("MILVUS_TOKEN") or ""
    store_config = VectorStoreConfig(
        store_provider="milvus",
        collection_name=collection_name,
    )
    return MilvusIndexer(
        config=store_config,
        milvus_uri=f"http://{host}:{port}",
        milvus_token=token,
    )
async def _delete_kb_indices(kb_id: str, space_id: str) -> dict:
    """Delete the chunk and triple index entries of every document in a knowledge base.

    Documents are listed from the repository page by page, 100 per page. For
    each ``doc_id`` (taken from ``"doc_id"``, falling back to ``"id"``) its
    entries are removed from ``ds_kb_<kb_id>_chunks`` and
    ``ds_kb_<kb_id>_triples``. Per-document failures are counted and
    collected. Any other failure is recorded in ``errors`` and logged; this
    function does not raise.

    Returns:
        ``{"success_count": int, "failed_count": int, "errors": list[str]}``.
    """
    summary = {"success_count": 0, "failed_count": 0, "errors": []}
    page_size = 100
    try:
        documents = []
        page_no = 1
        while True:
            listing = knowledge_base_repository.document_list(
                space_id=space_id, kb_id=kb_id, page=page_no, size=page_size
            )
            if listing.code != status.HTTP_200_OK or not listing.data:
                break
            batch = listing.data.get("items", [])
            if not batch:
                break
            documents.extend(batch)
            # A short page means this was the last one.
            if len(batch) < page_size:
                break
            page_no += 1

        if not documents:
            logger.debug("[KB_DELETE] No documents to delete indices for KB %s", kb_id)
            return summary

        logger.info(f"[KB_DELETE] Deleting indices for {len(documents)} documents in KB {kb_id}")
        chunk_index = f"ds_kb_{kb_id}_chunks"
        triple_index = f"ds_kb_{kb_id}_triples"
        index_manager = _create_index_manager(collection_name=chunk_index)
        targets = ((chunk_index, "chunks"), (triple_index, "triples"))

        for doc in documents:
            doc_id = doc.get("doc_id") or doc.get("id")
            if not doc_id:
                continue
            try:
                for index_name, index_type in targets:
                    await _delete_document_from_index(
                        index_manager=index_manager,
                        index_name=index_name,
                        doc_id=doc_id,
                        kb_id=kb_id,
                        index_type=index_type,
                    )
            except Exception as e:
                summary["failed_count"] += 1
                summary["errors"].append(f"Doc {doc_id}: {str(e)}")
                logger.warning(f"[KB_DELETE] Failed to delete index for doc {doc_id}: {e}")
            else:
                summary["success_count"] += 1

        logger.info(
            f"[KB_DELETE] Index deletion completed - KB: {kb_id}, "
            f"Success: {summary['success_count']}, Failed: {summary['failed_count']}"
        )
    except Exception as e:
        error_msg = f"Failed to delete KB indices: {str(e)}"
        summary["errors"].append(error_msg)
        logger.error(f"[KB_DELETE] {error_msg}", exc_info=True)
    return summary
def _create_vector_store(collection_name: str) -> MilvusVectorStore:
    """Build the vector store selected by the ``INDEX_MANAGER_TYPE`` env variable.

    Only ``milvus`` (the default) is supported. The Milvus endpoint is read from
    ``MILVUS_HOST`` (default ``localhost``), ``MILVUS_PORT`` (default ``19530``)
    and ``MILVUS_TOKEN`` (default empty string).

    Args:
        collection_name: Name of the collection the store operates on.

    Returns:
        MilvusVectorStore bound to ``collection_name``.

    Raises:
        ValueError: ``INDEX_MANAGER_TYPE`` names an unsupported backend.
    """
    # NOTE: the name ``index_manager_type`` appears verbatim in the error
    # message via the ``{var=}`` f-string form, so it must not be renamed.
    index_manager_type = os.getenv("INDEX_MANAGER_TYPE", "milvus")
    if index_manager_type != "milvus":
        raise ValueError(
            f"Un-supported {index_manager_type=} for env variable INDEX_MANAGER_TYPE"
        )
    host = os.getenv("MILVUS_HOST", "localhost")
    port = os.getenv("MILVUS_PORT", "19530")
    token = os.getenv("MILVUS_TOKEN") or ""
    return MilvusVectorStore(
        config=VectorStoreConfig(store_provider="milvus", collection_name=collection_name),
        milvus_uri=f"http://{host}:{port}",
        milvus_token=token,
    )
async def _delete_document_from_index(
    index_manager: MilvusIndexer,
    index_name: str,
    doc_id: str,
    kb_id: str,
    index_type: str = "chunks",
) -> bool:
    """Remove all data belonging to ``doc_id`` from one index.

    Args:
        index_manager: Indexer used to check for and delete index data.
        index_name: Name of the index to clean.
        doc_id: Document whose data should be removed.
        kb_id: Knowledge base ID (only used in log messages).
        index_type: ``"chunks"`` or ``"triples"``; only used in log messages.

    Returns:
        True when the data is gone afterwards: deleted, nothing to delete, the
        index does not exist, or the backend error text contains "not exist" /
        "not found". False for any other error; errors are never raised.
    """
    try:
        if not await index_manager.index_exists(index_name):
            logger.debug(
                "[DOC_DELETE] %s index does not exist: %s",
                index_type.capitalize(),
                index_name,
            )
            return True
        if await index_manager.delete_index(doc_id=doc_id, index_name=index_name):
            logger.info(
                f"[DOC_DELETE] Deleted {index_type} from index - Index: {index_name}, Doc ID: {doc_id}"
            )
        else:
            logger.debug(
                "[DOC_DELETE] No %s found for doc_id: %s in index: %s",
                index_type,
                doc_id,
                index_name,
            )
        return True
    except Exception as delete_error:
        lowered = str(delete_error).lower()
        missing = "not exist" in lowered or "not found" in lowered
        if not missing:
            logger.warning(
                f"[DOC_DELETE] Failed to delete {index_type} - Doc ID: {doc_id}, KB ID: {kb_id}, Error: {delete_error}"
            )
            return False
        # "Missing" errors mean there was nothing to delete, which is success.
        logger.debug(
            "[DOC_DELETE] No %s found for doc_id: %s in index: %s",
            index_type,
            doc_id,
            index_name,
        )
        return True
async def _index_documents(*args, **kwargs) -> dict:
    """Build the chunk index (and, with graph enhancement, the triple index) for one document.

    Arguments may be passed positionally or by keyword, in this order:
    ``documents``, ``indexing_strategy``, ``segmentation_strategy``, ``space_id``,
    ``kb_id``, ``doc_id``, ``process_info``, ``llm_config``, ``embed_model_config``.

    The document is first marked INDEXING. A chunker, index manager, Milvus
    vector store and a simple or graph knowledge base are then created, and the
    parsed ``documents`` are added to it. The knowledge base is always closed
    afterwards.

    Returns:
        dict with ``chunk_index`` (``ds_kb_{kb_id}_chunks``), ``triple_index``
        (``ds_kb_{kb_id}_triples``, or None without graph enhancement) and
        ``chunk_count`` (an estimate; see below).

    Raises:
        ValueError: ``documents`` or ``embed_model_config`` is missing, or graph
            enhancement is enabled without a usable LLM configuration.
        RuntimeError: the knowledge base returned no document IDs.
        Exception: the repository rejected the INDEXING status update.
    """
    documents = kwargs.get("documents", args[0] if len(args) > 0 else None)
    indexing_strategy = kwargs.get("indexing_strategy", args[1] if len(args) > 1 else None)
    segmentation_strategy = kwargs.get("segmentation_strategy", args[2] if len(args) > 2 else None)
    space_id = kwargs.get("space_id", args[3] if len(args) > 3 else None)
    kb_id = kwargs.get("kb_id", args[4] if len(args) > 4 else None)
    doc_id = kwargs.get("doc_id", args[5] if len(args) > 5 else None)
    process_info = kwargs.get("process_info", args[6] if len(args) > 6 else None)
    llm_config = kwargs.get("llm_config", args[7] if len(args) > 7 else None)
    embed_model_config = kwargs.get("embed_model_config", args[8] if len(args) > 8 else None)
    if documents is None:
        raise ValueError("documents is required for document indexing")
    # process_info is optional and defaults to None. Spreading None into a dict
    # below would raise TypeError, so normalise it first.
    process_info = process_info or {}
    update_indexing_result = knowledge_base_repository.document_update_status(
        space_id=space_id,
        kb_id=kb_id,
        doc_id=doc_id,
        doc_status=DocumentStatus.INDEXING.value,
        process_info={**process_info, "parsing_completed": True, "document_count": len(documents)},
    )
    if update_indexing_result.code != status.HTTP_200_OK:
        raise Exception(f"Failed to update status to INDEXING: {update_indexing_result.message}")
    logger.info(f"[INDEX] Document status updated to INDEXING - Doc ID: {doc_id}")
    use_graph = bool(getattr(indexing_strategy, "enable_graph_enhancement", False))
    chunk_index = f"ds_kb_{kb_id}_chunks"
    triple_index = f"ds_kb_{kb_id}_triples" if use_graph else None
    logger.info(
        f"[INDEX] Indexing documents - KB ID: {kb_id}, Doc ID: {doc_id}, "
        f"Documents: {len(documents)}, Use graph: {use_graph}, "
        f"Chunk index: {chunk_index}, Triple index: {triple_index}"
    )
    if embed_model_config is None:
        raise ValueError("embed_model_config is required for document indexing")
    embed_model = _create_embed_model(embed_model_config)
    llm_client = None
    model_name = None
    if use_graph:
        if llm_config is None:
            raise ValueError(
                "llm_config is required when enable_graph_enhancement is True"
            )
        llm_client, model_name = _create_llm_client(llm_config)
        logger.info(f"[INDEX] LLM client created - Model: {model_name}")
        if not llm_client:
            raise ValueError("llm_client is required when use_graph=True")
    strategy_config = segmentation_strategy.strategy_config or {}
    chunk_unit = strategy_config.get("chunk_unit", "char")
    # The embedding model is only passed to the chunker for token-based chunking.
    chunker = _create_chunker(
        segmentation_strategy, embed_model=embed_model if chunk_unit == "token" else None
    )
    index_manager = _create_index_manager(collection_name=chunk_index)
    vector_store = _create_vector_store(
        collection_name=chunk_index,
    )
    extractor = None
    if use_graph and llm_client:
        extractor = TripleExtractor(
            llm_client=llm_client,
            model_name=model_name,
        )
    kb_config = KnowledgeBaseConfig(
        kb_id=kb_id,
        index_type="vector",
        use_graph=use_graph,
        chunk_size=chunker.chunk_size,
        chunk_overlap=chunker.chunk_overlap,
    )
    if use_graph:
        knowledge_base = LocalGraphKnowledgeBase(
            config=kb_config,
            vector_store=vector_store,
            embed_model=embed_model,
            parser=None,
            chunker=chunker,
            extractor=extractor,
            index_manager=index_manager,
            llm_client=llm_client,
        )
    else:
        knowledge_base = LocalSimpleKnowledgeBase(
            config=kb_config,
            vector_store=vector_store,
            embed_model=embed_model,
            parser=None,
            chunker=chunker,
            index_manager=index_manager,
            llm_client=llm_client,
        )
    try:
        doc_ids = await knowledge_base.add_documents(documents)
        if not doc_ids:
            raise RuntimeError("Index build failed: no document IDs returned")
        chunk_count = 0
        try:
            # Rough estimate only: total characters // chunk_size. It ignores
            # chunk_overlap, and for token-based chunking chunk_size is
            # presumably measured in tokens, not characters.
            total_text_length = sum(len(doc.text) for doc in documents)
            if chunker.chunk_size > 0:
                chunk_count = max(1, total_text_length // chunker.chunk_size)
                logger.debug(
                    "[INDEX] Estimated chunk count: %s (text length: %s, chunk_size: %s)",
                    chunk_count,
                    total_text_length,
                    chunker.chunk_size,
                )
        except Exception as e:
            logger.warning(f"[INDEX] Failed to estimate chunk count: {str(e)}")
            chunk_count = len(documents)
        logger.debug(
            (
                "[INDEX] Indexing completed - KB ID: %s, Doc ID: %s, "
                "Chunk index: %s, Triple index: %s, Estimated chunks: %s"
            ),
            kb_id,
            doc_id,
            chunk_index,
            triple_index,
            chunk_count,
        )
        return {
            "chunk_index": chunk_index,
            "triple_index": triple_index,
            "chunk_count": chunk_count,
        }
    finally:
        try:
            await knowledge_base.close()
        except Exception as e:
            logger.warning(f"[INDEX] Failed to close knowledge base: {str(e)}")
async def process_single_document(*args, **kwargs):
    """Download (if needed), parse and index one document; run as a background task.

    Arguments may be passed positionally or by keyword, in this order:
    ``space_id``, ``kb_id``, ``doc_id``, ``file_path``, ``parsing_strategy``,
    ``segmentation_strategy``, ``indexing_strategy``, ``process_info``,
    ``file_name``, ``llm_config``, ``embed_model_config``, ``obs_name``.

    Steps:
        1. If the local file is missing and ``obs_name`` and ``OBS_BUCKET`` are
           set, download it from OBS.
        2. Parse the file with ``_parse_file``.
        3. Index the parsed documents with ``_index_documents``.
        4. Mark the document INDEXED with the index result in ``process_info``.

    Any failure is logged and the document is marked FAILED, with the error
    recorded in ``process_info``. No exception propagates to the caller.
    """
    space_id = kwargs.get("space_id", args[0] if len(args) > 0 else None)
    kb_id = kwargs.get("kb_id", args[1] if len(args) > 1 else None)
    doc_id = kwargs.get("doc_id", args[2] if len(args) > 2 else None)
    file_path = kwargs.get("file_path", args[3] if len(args) > 3 else None)
    parsing_strategy = kwargs.get("parsing_strategy", args[4] if len(args) > 4 else None)
    segmentation_strategy = kwargs.get("segmentation_strategy", args[5] if len(args) > 5 else None)
    indexing_strategy = kwargs.get("indexing_strategy", args[6] if len(args) > 6 else None)
    process_info = kwargs.get("process_info", args[7] if len(args) > 7 else None)
    file_name = kwargs.get("file_name", args[8] if len(args) > 8 else None)
    llm_config = kwargs.get("llm_config", args[9] if len(args) > 9 else None)
    embed_model_config = kwargs.get("embed_model_config", args[10] if len(args) > 10 else None)
    obs_name = kwargs.get("obs_name", args[11] if len(args) > 11 else None)
    # process_info defaults to None. Normalise it so that both the INDEXED and
    # the FAILED status updates can spread it; previously a None here crashed
    # the failure handler and left the document without a FAILED status.
    process_info = process_info or {}
    try:
        logger.info(
            f"[DOC_PROCESS_BG] Starting background processing - Doc ID: {doc_id}, KB ID: {kb_id}"
        )
        try:
            if not file_name:
                file_name = Path(file_path).name
            if not os.path.exists(file_path) and obs_name and os.getenv("OBS_BUCKET"):
                logger.info(
                    f'[DOC_PROCESS_BG] Local file missing, downloading from OBS - "{obs_name}" -> "{file_path}"'
                )
                obs_manager = OBSDocumentManager()
                await obs_manager.download_document(object_name=obs_name, file_path=file_path)
        except Exception as download_error:
            full_error_msg = _extract_full_error_message(download_error)
            error_message = f"OBS download failed: {full_error_msg}"
            logger.error(
                f"[DOC_PROCESS_BG] OBS file download failed - {file_name=}, {obs_name=}, Error: {error_message}",
                exc_info=True,
            )
            raise Exception(error_message) from download_error
        try:
            documents = await _parse_file(file_path, parsing_strategy, doc_id, file_name=file_name)
        except Exception as parse_error:
            full_error_msg = _extract_full_error_message(parse_error)
            error_message = f"File parsing failed: {full_error_msg}"
            logger.error(
                f"[DOC_PROCESS_BG] File parsing failed - Doc ID: {doc_id}, KB ID: {kb_id}, Error: {error_message}",
                exc_info=True,
            )
            raise Exception(error_message) from parse_error
        try:
            index_result = await _index_documents(
                documents=documents,
                indexing_strategy=indexing_strategy,
                segmentation_strategy=segmentation_strategy,
                space_id=space_id,
                kb_id=kb_id,
                doc_id=doc_id,
                process_info=process_info,
                llm_config=llm_config,
                embed_model_config=embed_model_config,
            )
        except Exception as index_error:
            full_error_msg = _extract_full_error_message(index_error)
            error_message = f"Index building failed: {full_error_msg}"
            logger.error(
                f"[DOC_PROCESS_BG] Index building failed - Doc ID: {doc_id}, KB ID: {kb_id}, Error: {error_message}",
                exc_info=True,
            )
            raise Exception(error_message) from index_error
        final_process_info = {
            **process_info,
            "chunking_completed": True,
            "indexing_completed": True,
            "index_result": index_result,
        }
        update_indexed_result = knowledge_base_repository.document_update_status(
            space_id=space_id,
            kb_id=kb_id,
            doc_id=doc_id,
            doc_status=DocumentStatus.INDEXED.value,
            process_info=final_process_info,
            es_index_name=index_result.get("chunk_index"),
            chunk_count=index_result.get("chunk_count"),
        )
        if update_indexed_result.code != status.HTTP_200_OK:
            raise Exception(f"Failed to update status to INDEXED: {update_indexed_result.message}")
        logger.info(
            f"[DOC_PROCESS_BG] Document indexing completed - Doc ID: {doc_id}, "
            f"Chunk index: {index_result.get('chunk_index')}, "
            f"Chunks: {index_result.get('chunk_count')}, KB ID: {kb_id}"
        )
    except Exception as e:
        error_message = str(e)
        logger.error(
            f"[DOC_PROCESS_BG] Document processing failed - Doc ID: {doc_id}, "
            f"KB ID: {kb_id}, Error: {error_message}",
            exc_info=True,
        )
        try:
            failed_update = knowledge_base_repository.document_update_status(
                space_id=space_id,
                kb_id=kb_id,
                doc_id=doc_id,
                doc_status=DocumentStatus.FAILED.value,
                process_info={
                    **process_info,
                    "error": error_message,
                    "failed_time": milliseconds(),
                },
            )
            # The repository reports failures via its result code; surface them
            # instead of silently leaving the document in its previous state.
            if failed_update.code != status.HTTP_200_OK:
                logger.error(
                    f"[DOC_PROCESS_BG] Failed to update status to FAILED - Doc ID: {doc_id}, "
                    f"Error: {failed_update.message}"
                )
        except Exception as update_error:
            logger.error(
                f"[DOC_PROCESS_BG] Failed to update status to FAILED - Doc ID: {doc_id}, "
                f"Error: {str(update_error)}"
            )
async def _process_documents_sequentially(*args, **kwargs):
    """Process several uploaded documents one after another (background task).

    Arguments may be passed positionally or by keyword, in this order:
    ``space_id``, ``kb_id``, ``documents``, ``parsing_strategy``,
    ``segmentation_strategy``, ``indexing_strategy``, ``task_id``,
    ``process_info_base``, ``llm_config``, ``embed_model_config``.

    Each entry of ``documents`` is a dict read for ``doc_id``, ``file_path``,
    ``name`` and ``obs_name``. Each document is handed to
    ``process_single_document`` with ``process_info_base`` extended by the task
    ID and its 1-based position. A failure on one document is logged and does
    not stop the remaining ones.
    """
    space_id = kwargs.get("space_id", args[0] if len(args) > 0 else None)
    kb_id = kwargs.get("kb_id", args[1] if len(args) > 1 else None)
    documents = kwargs.get("documents", args[2] if len(args) > 2 else None)
    parsing_strategy = kwargs.get("parsing_strategy", args[3] if len(args) > 3 else None)
    segmentation_strategy = kwargs.get("segmentation_strategy", args[4] if len(args) > 4 else None)
    indexing_strategy = kwargs.get("indexing_strategy", args[5] if len(args) > 5 else None)
    task_id = kwargs.get("task_id", args[6] if len(args) > 6 else None)
    process_info_base = kwargs.get("process_info_base", args[7] if len(args) > 7 else None)
    llm_config = kwargs.get("llm_config", args[8] if len(args) > 8 else None)
    embed_model_config = kwargs.get("embed_model_config", args[9] if len(args) > 9 else None)
    # Both default to None. Spreading a None process_info_base made every
    # document fail before processing (leaving it stuck in UPLOADED), and
    # len(None) crashed the task outright.
    documents = documents or []
    process_info_base = process_info_base or {}
    total = len(documents)
    logger.info(
        f"[DOC_PROCESS_SEQ] Starting sequential processing - Task ID: {task_id}, "
        f"KB ID: {kb_id}, Total documents: {total}"
    )
    for idx, doc_info in enumerate(documents, 1):
        doc_id = doc_info.get("doc_id")
        file_path = doc_info.get("file_path")
        doc_name = doc_info.get("name")
        obs_name = doc_info.get("obs_name")
        try:
            logger.info(
                f"[DOC_PROCESS_SEQ] Processing document {idx}/{total} - "
                f"Doc ID: {doc_id}, Task ID: {task_id}"
            )
            process_info = {
                **process_info_base,
                "task_id": task_id,
                "current_index": idx,
                "total_count": total,
            }
            await process_single_document(
                space_id=space_id,
                kb_id=kb_id,
                doc_id=doc_id,
                file_path=file_path,
                parsing_strategy=parsing_strategy,
                segmentation_strategy=segmentation_strategy,
                indexing_strategy=indexing_strategy,
                process_info=process_info,
                file_name=doc_name,
                llm_config=llm_config,
                embed_model_config=embed_model_config,
                obs_name=obs_name,
            )
            logger.info(
                f"[DOC_PROCESS_SEQ] Completed document {idx}/{total} - "
                f"Doc ID: {doc_id}, Name: {doc_name}, Task ID: {task_id}"
            )
        except Exception as e:
            # Keep going: one bad document must not block the rest of the batch.
            logger.error(
                f"[DOC_PROCESS_SEQ] Failed to process document {idx}/{total} - "
                f"Doc ID: {doc_id}, Task ID: {task_id}, Error: {str(e)}",
                exc_info=True,
            )
    logger.info(
        f"[DOC_PROCESS_SEQ] Sequential processing completed - Task ID: {task_id}, "
        f"KB ID: {kb_id}, Total documents: {total}"
    )
async def _cleanup_metadata_synced_document(kb_id: str, doc_id: str, doc_data: dict) -> None:
    """Best-effort removal of a deleted document's local file, OBS object and index data.

    Mirrors the artefact cleanup that ``document_delete`` performs. It is used
    for documents removed because they are absent from ``metadata["doc_list"]``.
    Errors are logged and never raised.
    """
    stored_path = doc_data.get("file_path")
    if stored_path:
        try:
            Path(stored_path).unlink(missing_ok=True)
        except OSError as e:
            logger.warning(
                f"[DOC_UPLOAD] Failed to delete local file of synced-out document - "
                f"Doc ID: {doc_id}, Path: {stored_path}, Error: {e}"
            )
    obs_name = doc_data.get("obs_name")
    if obs_name and os.getenv("OBS_BUCKET"):
        try:
            await OBSDocumentManager().delete_document(obs_name)
        except Exception as e:
            logger.warning(
                f"[DOC_UPLOAD] Failed to delete OBS object of synced-out document - "
                f"Doc ID: {doc_id}, Object: {obs_name}, Error: {e}"
            )
    try:
        chunk_index = f"ds_kb_{kb_id}_chunks"
        index_manager = _create_index_manager(collection_name=chunk_index)
        # Missing indices are treated as success by _delete_document_from_index,
        # so both index types can be attempted unconditionally.
        for index_name, index_type in ((chunk_index, "chunks"), (f"ds_kb_{kb_id}_triples", "triples")):
            await _delete_document_from_index(
                index_manager=index_manager,
                index_name=index_name,
                doc_id=doc_id,
                kb_id=kb_id,
                index_type=index_type,
            )
    except Exception as e:
        logger.warning(
            f"[DOC_UPLOAD] Index cleanup failed for synced-out document - "
            f"Doc ID: {doc_id}, KB ID: {kb_id}, Error: {e}"
        )


async def _discard_failed_upload(file_path: Path | None, obs_object_name: str, file_label: str | None) -> None:
    """Best-effort removal of the local file and OBS object left behind by a failed upload."""
    if file_path is not None:
        try:
            if file_path.exists():
                file_path.unlink()
        except OSError as cleanup_error:
            logger.warning(
                f"[DOC_UPLOAD] Failed to cleanup file after upload error - "
                f"File: {file_label}, Path: {file_path}, "
                f"Error: {str(cleanup_error)}"
            )
    if obs_object_name:
        try:
            await OBSDocumentManager().delete_document(obs_object_name)
        except Exception as obs_cleanup_error:
            logger.warning(
                f"[DOC_UPLOAD] Failed to remove OBS object after upload error - "
                f"File: {file_label}, Object: {obs_object_name}, Error: {obs_cleanup_error}"
            )


async def document_upload(
    space_id: str,
    kb_id: str,
    files: List[UploadFile],
    metadata: Dict[str, Any] | None,
) -> ResponseModel:
    """Upload one or more documents to a knowledge base.

    When ``metadata["doc_list"]`` is a non-empty list, the knowledge base is
    synced to it first: existing documents whose IDs are absent from the list
    are deleted, together with their local file, OBS object and index data.
    ``doc_list[i]`` then supplies the ID for ``files[i]``; files whose ID
    already exists are skipped. Files without a listed ID get a random UUID.

    Each file is validated (extension, 20 MB size limit), saved under the
    knowledge base's storage path and, when OBS is required, uploaded to OBS.
    It is then registered with status UPLOADED. Per-file failures are counted
    and their artefacts cleaned up, without aborting the batch.

    Note: this function is async; exception handling happens in the router layer.

    Returns:
        404 if the knowledge base does not exist, 503 if OBS is required but
        misconfigured, otherwise 200 with a ``DocumentUploadBatchResponse``.
    """
    start_time = time.time()
    logger.info(
        f"[DOC_UPLOAD] Uploading documents - KB ID: {kb_id}, Files: {len(files)}"
    )
    kb_get = KnowledgeBaseGet(space_id=space_id, kb_id=kb_id)
    kb_result = knowledge_base_repository.knowledge_base_get(kb_get)
    if kb_result.code != status.HTTP_200_OK or not kb_result.data:
        logger.warning(f"[DOC_UPLOAD] Knowledge base not found - KB ID: {kb_id}")
        return ResponseModel(code=status.HTTP_404_NOT_FOUND, message="Knowledge base not found")
    obs_required_msg = kb_obs_misconfigured_message()
    if obs_required_msg:
        logger.error(f"[DOC_UPLOAD] {obs_required_msg}")
        return ResponseModel(
            code=status.HTTP_503_SERVICE_UNAVAILABLE,
            message=obs_required_msg,
            data=None,
        )
    storage_path = _get_storage_path(space_id, kb_id)
    metadata_doc_list: list[str] = []
    if isinstance(metadata, dict):
        raw_doc_list = metadata.get("doc_list")
        if isinstance(raw_doc_list, list):
            metadata_doc_list = [str(doc_id) for doc_id in raw_doc_list if doc_id]
    existing_doc_ids: set[str] = set()
    doc_id_list_result = knowledge_base_repository.document_id_list(space_id=space_id, kb_id=kb_id)
    if doc_id_list_result.code == status.HTTP_200_OK:
        existing_doc_ids = set(doc_id_list_result.data or [])
    else:
        logger.warning(
            f"[DOC_UPLOAD] Failed to get document id list - KB ID: {kb_id}, "
            f"Error: {doc_id_list_result.message}"
        )
    if metadata_doc_list:
        metadata_doc_id_set = set(metadata_doc_list)
        stale_doc_ids = [doc_id for doc_id in existing_doc_ids if doc_id not in metadata_doc_id_set]
        for stale_doc_id in stale_doc_ids:
            doc_get_result = knowledge_base_repository.document_get(
                space_id=space_id, kb_id=kb_id, doc_id=stale_doc_id
            )
            if doc_get_result.code != status.HTTP_200_OK or not doc_get_result.data:
                logger.warning(
                    f"[DOC_UPLOAD] Document not found for delete - Doc ID: {stale_doc_id}, KB ID: {kb_id}"
                )
                continue
            delete_result = knowledge_base_repository.document_delete(
                space_id=space_id, kb_id=kb_id, doc_id=stale_doc_id
            )
            if delete_result.code == status.HTTP_200_OK:
                logger.info(
                    f"[DOC_UPLOAD] Document deleted due to metadata sync - Doc ID: {stale_doc_id}, KB ID: {kb_id}"
                )
                # Previously the record was deleted but its stored file, OBS
                # object and index entries were left orphaned.
                await _cleanup_metadata_synced_document(kb_id, stale_doc_id, doc_get_result.data)
            else:
                logger.error(
                    f"[DOC_UPLOAD] Failed to delete document - Doc ID: {stale_doc_id}, KB ID: {kb_id}, "
                    f"Error: {delete_result.message}"
                )
    allowed_file_extensions = {".pdf", ".doc", ".docx", ".txt", ".md"}
    max_file_size = 20 * 1024 * 1024
    uploaded_docs = []
    success_count = 0
    failed_count = 0
    for file_index, file in enumerate(files):
        # Per-file state, reset on every iteration so the error handler below
        # only cleans up artefacts of the file currently being processed.
        # Previously a stale ``file_path`` from an earlier, successfully
        # registered upload could be unlinked.
        file_path: Path | None = None
        obs_uploaded_name = ""
        try:
            doc_id_from_metadata = (
                metadata_doc_list[file_index] if file_index < len(metadata_doc_list) else ""
            )
            if doc_id_from_metadata:
                doc_id = doc_id_from_metadata
                if doc_id in existing_doc_ids:
                    logger.info(
                        f"[DOC_UPLOAD] Skip upload - Doc ID already exists: {doc_id}, "
                        f"KB ID: {kb_id}"
                    )
                    continue
            else:
                doc_id = str(uuid.uuid4())
            filename = file.filename or f"unnamed_{doc_id}"
            file_ext = Path(filename).suffix.lower()
            if file_ext not in allowed_file_extensions:
                failed_count += 1
                logger.warning(
                    f"[DOC_UPLOAD] Unsupported file type - File: {filename}, Extension: {file_ext}, "
                    f"KB ID: {kb_id}"
                )
                continue
            file_type = _get_file_type(filename)
            mime_type = _get_mime_type(file_type)
            safe_filename = f"{doc_id}{Path(filename).suffix}"
            file_path = storage_path / safe_filename
            file_content = await file.read()
            file_size = len(file_content)
            if file_size > max_file_size:
                failed_count += 1
                file_size_mb = file_size / (1024 * 1024)
                max_size_mb = max_file_size / (1024 * 1024)
                logger.warning(
                    f"[DOC_UPLOAD] File size exceeds limit - File: {filename}, Size: {file_size_mb:.2f}MB, "
                    f"Limit: {max_size_mb}MB, KB ID: {kb_id}"
                )
                continue
            with open(file_path, "wb") as f:
                f.write(file_content)
            obs_manager = OBSDocumentManager()
            object_name = obs_manager.obs_name(
                space_id=space_id, kb_id=kb_id, file_name=file_path.name
            )
            obs_required = knowledge_base_requires_obs()
            obs_configured = bool(obs_manager.bucket and obs_manager.obs_client)
            if obs_required:
                if not obs_configured:
                    failed_count += 1
                    logger.error(
                        f"[DOC_UPLOAD] OBS required when CHECKPOINTER_TYPE=redis but not configured - "
                        f"File: {filename}, Doc ID: {doc_id}, KB ID: {kb_id}"
                    )
                    if file_path.exists():
                        try:
                            file_path.unlink()
                        except OSError as unlink_err:
                            logger.warning(
                                f"[DOC_UPLOAD] Failed to remove local file after OBS misconfig - "
                                f"Path: {file_path}, Error: {unlink_err}"
                            )
                    uploaded_docs.append(
                        DocumentUploadResponse(
                            id=doc_id,
                            name=filename,
                            file_size=file_size,
                            status=DocumentStatus.FAILED.value,
                        )
                    )
                    continue
                try:
                    await obs_manager.upload_document(
                        object_name=object_name,
                        file_path=file_path,
                    )
                except Exception as obs_error:
                    failed_count += 1
                    logger.error(
                        f"[DOC_UPLOAD] OBS upload failed (required when CHECKPOINTER_TYPE=redis) - "
                        f"File: {filename}, Doc ID: {doc_id}, KB ID: {kb_id}, Error: {obs_error}",
                        exc_info=True,
                    )
                    if file_path.exists():
                        try:
                            file_path.unlink()
                        except OSError as unlink_err:
                            logger.warning(
                                f"[DOC_UPLOAD] Failed to remove local file after OBS failure - "
                                f"Path: {file_path}, Error: {unlink_err}"
                            )
                    uploaded_docs.append(
                        DocumentUploadResponse(
                            id=doc_id,
                            name=filename,
                            file_size=file_size,
                            status=DocumentStatus.FAILED.value,
                        )
                    )
                    continue
                # Remember the uploaded object so it can be removed if a later
                # step of this iteration fails.
                obs_uploaded_name = object_name
            obs_stored_name = object_name if obs_required else ""
            logger.debug(
                "[DOC_UPLOAD] File saved - Path: %s, Size: %s bytes",
                file_path,
                file_size,
            )
            current_time = milliseconds()
            doc_data = {
                "space_id": space_id,
                "kb_id": kb_id,
                "doc_id": doc_id,
                "name": filename,
                "file_path": str(file_path),
                "obs_name": obs_stored_name,
                "file_size": file_size,
                "file_type": file_type,
                "mime_type": mime_type,
                "status": DocumentStatus.UPLOADED.value,
                "doc_metadata": metadata or {},
                "create_time": current_time,
                "update_time": current_time,
            }
            create_result = knowledge_base_repository.document_create(doc_data)
            if create_result.code == status.HTTP_200_OK:
                # The document is registered now; its stored artefacts must
                # survive any later error in this iteration.
                file_path = None
                obs_uploaded_name = ""
                success_count += 1
                uploaded_docs.append(
                    DocumentUploadResponse(
                        id=doc_id,
                        name=filename,
                        file_size=file_size,
                        status=DocumentStatus.UPLOADED.value,
                    )
                )
                logger.info(f"[DOC_UPLOAD] Document created - Doc ID: {doc_id}, Name: {filename}")
            else:
                failed_count += 1
                # Remove both the local copy and the OBS object; the latter
                # used to be orphaned when the DB record could not be created.
                await _discard_failed_upload(file_path, obs_uploaded_name, filename)
                file_path = None
                obs_uploaded_name = ""
                logger.error(
                    f"[DOC_UPLOAD] Failed to create document record - Doc ID: {doc_id}, Error: {create_result.message}"
                )
        except Exception as e:
            failed_count += 1
            logger.error(
                f"[DOC_UPLOAD] Error uploading file {file.filename}: {str(e)}", exc_info=True
            )
            await _discard_failed_upload(file_path, obs_uploaded_name, file.filename)
    response_data = DocumentUploadBatchResponse(
        success_count=success_count, failed_count=failed_count, documents=uploaded_docs
    )
    logger.info(
        f"[DOC_UPLOAD] Upload completed - KB ID: {kb_id}, "
        f"Success: {success_count}, Failed: {failed_count}, Duration: {time.time() - start_time:.3f}s"
    )
    return ResponseModel(
        code=status.HTTP_200_OK,
        message=f"Upload completed: {success_count} success, {failed_count} failed",
        data=response_data.model_dump(by_alias=False),
    )
def _timestamp_to_date_str(timestamp: int | None) -> str:
"""将时间戳(毫秒)转换为日期时间字符串(YYYY-MM-DD HH:MM:SS)"""
if not timestamp:
return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
dt = datetime.fromtimestamp(timestamp / 1000, tz=timezone.utc)
return dt.strftime("%Y-%m-%d %H:%M:%S")
@with_exception_handling
def knowledge_base_search(req: KnowledgeBaseSearchRequest) -> ResponseModel:
    """Search knowledge bases in a space by query, with pagination.

    ``req.page`` defaults to 1 and ``req.page_size`` to 10. A non-200
    repository result is returned unchanged. Each hit is enriched with a
    per-knowledge-base ``has_graph_enhancement`` lookup.
    """
    started = time.time()
    page = req.page or 1
    page_size = req.page_size or 10
    logger.info(
        f"[KB_SEARCH] Searching knowledge bases - Query: '{req.query}', Page: {page}, PageSize: {page_size}"
    )
    search_result = knowledge_base_repository.knowledge_base_search(
        space_id=req.space_id, query=req.query, page=page, page_size=page_size
    )
    if search_result.code != status.HTTP_200_OK:
        logger.error(f"[KB_SEARCH] Search failed - Error: {search_result.message}")
        return search_result
    payload = search_result.data

    def _to_info(kb: dict) -> KnowledgeBaseInfo:
        # Build the API model for one repository row.
        kb_id = kb.get("kb_id", "")
        graph_flag = knowledge_base_repository.has_graph_enhancement_documents(
            space_id=req.space_id, kb_id=kb_id
        )
        stored_config = kb.get("config") or {}
        return KnowledgeBaseInfo(
            id=kb_id,
            space_id=kb.get("space_id", ""),
            name=kb.get("name", ""),
            description=kb.get("description"),
            embed_model_config=stored_config.get("embed_model_config"),
            llm_config=stored_config.get("llm_config"),
            config=kb.get("config"),
            create_time=kb.get("create_time"),
            update_time=kb.get("update_time"),
            has_graph_enhancement=graph_flag,
        )

    knowledge_bases = [_to_info(kb) for kb in payload.get("knowledge_bases", [])]
    total = payload.get("total", 0)
    total_pages = payload.get("total_pages", 1)
    response_data = KnowledgeBaseSearchResponse(
        knowledge_bases=knowledge_bases,
        total=total,
        page=page,
        page_size=page_size,
        total_pages=total_pages,
    )
    logger.info(
        f"[KB_SEARCH] Search completed - Found: {len(knowledge_bases)}/{total} knowledge bases, "
        f"Page: {page}/{total_pages}, Duration: {time.time() - started:.3f}s"
    )
    return ResponseModel(
        code=status.HTTP_200_OK,
        message="Search knowledge bases successfully",
        data=response_data.model_dump(by_alias=True),
    )
@with_exception_handling
def knowledge_base_list(req: KnowledgeBaseListRequest) -> ResponseModel:
    """List the knowledge bases of a space, with pagination.

    A failed repository query is reported as a 200 response with an empty
    page. Each knowledge base gets an aggregate status derived from its
    documents' statuses (see ``_aggregate_status``).
    """
    started = time.time()
    logger.info(
        f"[KB_LIST] Getting knowledge base list - Space ID: {req.space_id}, "
        f"Page: {req.page}, Size: {req.size}"
    )
    list_result = knowledge_base_repository.knowledge_base_list(
        space_id=req.space_id, page=req.page, size=req.size
    )
    if list_result.code != status.HTTP_200_OK:
        logger.warning(
            f"[KB_LIST] Database query failed, returning empty list - "
            f"Space ID: {req.space_id}, Error: {list_result.message}"
        )
        empty_page = KnowledgeBaseListResponse(items=[], total=0, page=req.page, size=req.size)
        return ResponseModel(
            code=status.HTTP_200_OK,
            message="get knowledge base list success",
            data=empty_page.model_dump(by_alias=False),
        )

    def _aggregate_status(kb_id: str) -> str:
        # Precedence: any failed > any upload* > all indexed > "indexing".
        # "indexing" is also used when the lookup fails or there are no documents.
        lookup = knowledge_base_repository.document_status_list(space_id=req.space_id, kb_id=kb_id)
        if lookup.code != status.HTTP_200_OK:
            return "indexing"
        doc_statuses = lookup.data or []
        if any(s == "failed" for s in doc_statuses):
            return "failed"
        if any(s.startswith("upload") for s in doc_statuses):
            return "uploading"
        if doc_statuses and all(s == "indexed" for s in doc_statuses):
            return "indexed"
        return "indexing"

    items = []
    payload = list_result.data
    for row in payload.get("items", []):
        kb_id = row.get("kb_id", "")
        graph_flag = knowledge_base_repository.has_graph_enhancement_documents(
            space_id=req.space_id, kb_id=kb_id
        )
        items.append(
            KnowledgeBaseListItem(
                name=row.get("name", ""),
                desc=row.get("description"),
                id=kb_id,
                type="text",
                status=_aggregate_status(kb_id),
                created_at=_timestamp_to_date_str(row.get("create_time")),
                updated_at=_timestamp_to_date_str(row.get("update_time")),
                has_graph_enhancement=graph_flag,
            )
        )
    total = payload.get("total", 0)
    response_data = KnowledgeBaseListResponse(items=items, total=total, page=req.page, size=req.size)
    logger.info(
        f"[KB_LIST] Knowledge base list retrieved - Space ID: {req.space_id}, "
        f"Total: {total}, Count: {len(items)}, Page: {req.page}, Size: {req.size}, "
        f"Duration: {time.time() - started:.3f}s"
    )
    return ResponseModel(
        code=status.HTTP_200_OK,
        message="get knowledge base list success",
        data=response_data.model_dump(by_alias=False),
    )
@with_exception_handling
def document_list(req: DocumentListRequest) -> ResponseModel:
    """List the documents of a knowledge base, one page at a time.

    Returns 404 when the knowledge base does not exist. When the document query
    fails, returns the repository's code and message with an empty page.
    Otherwise returns 200 with a ``DocumentListResponse`` payload.
    """
    began = time.time()
    logger.info(
        f"[DOC_LIST] Getting document list - Space ID: {req.space_id}, "
        f"KB ID: {req.kb_id}, Page: {req.page}, Size: {req.size}"
    )
    kb_lookup = knowledge_base_repository.knowledge_base_get(
        KnowledgeBaseGet(space_id=req.space_id, kb_id=req.kb_id)
    )
    if kb_lookup.code != status.HTTP_200_OK or not kb_lookup.data:
        logger.warning(f"[DOC_LIST] Knowledge base not found - KB ID: {req.kb_id}")
        return ResponseModel(code=status.HTTP_404_NOT_FOUND, message="Knowledge base not found")
    page_result = knowledge_base_repository.document_list(
        space_id=req.space_id, kb_id=req.kb_id, page=req.page, size=req.size
    )
    if page_result.code != status.HTTP_200_OK:
        logger.error(
            f"[DOC_LIST] Database query failed - Space ID: {req.space_id}, "
            f"KB ID: {req.kb_id}, Error: {page_result.message}"
        )
        empty_page = {"items": [], "total": 0, "page": req.page, "size": req.size}
        return ResponseModel(code=page_result.code, message=page_result.message, data=empty_page)
    payload = page_result.data
    items = [
        DocumentListItem(
            name=row.get("name", ""),
            id=row.get("doc_id", ""),
            status=row.get("status", ""),
            created_at=_timestamp_to_date_str(row.get("create_time")),
            updated_at=_timestamp_to_date_str(row.get("update_time")),
        )
        for row in payload.get("items", [])
    ]
    total = payload.get("total", 0)
    response_data = DocumentListResponse(items=items, total=total, page=req.page, size=req.size)
    logger.info(
        f"[DOC_LIST] Document list retrieved - Space ID: {req.space_id}, "
        f"KB ID: {req.kb_id}, Count: {len(items)}/{total}, "
        f"Duration: {time.time() - began:.3f}s"
    )
    return ResponseModel(
        code=status.HTTP_200_OK,
        message="get documents success",
        data=response_data.model_dump(by_alias=False),
    )
@with_exception_handling
def document_update(req: DocumentUpdateRequest) -> ResponseModel:
    """Rename a document; the name is currently the only updatable field.

    Returns 404 when the knowledge base or the document does not exist. When
    the update fails, returns the repository's code and message. Otherwise
    returns 200 with no data.
    """
    started_at = time.time()
    logger.info(
        f"[DOC_UPDATE] Updating document - Space ID: {req.space_id}, "
        f"KB ID: {req.kb_id}, Doc ID: {req.document_id}, Name: {req.document_name}"
    )
    kb_lookup = knowledge_base_repository.knowledge_base_get(
        KnowledgeBaseGet(space_id=req.space_id, kb_id=req.kb_id)
    )
    if kb_lookup.code != status.HTTP_200_OK or not kb_lookup.data:
        logger.warning(f"[DOC_UPDATE] Knowledge base not found - KB ID: {req.kb_id}")
        return ResponseModel(code=status.HTTP_404_NOT_FOUND, message="Knowledge base not found")
    doc_lookup = knowledge_base_repository.document_get(
        space_id=req.space_id, kb_id=req.kb_id, doc_id=req.document_id
    )
    if doc_lookup.code != status.HTTP_200_OK or not doc_lookup.data:
        logger.warning(
            f"[DOC_UPDATE] Document not found - Doc ID: {req.document_id}, KB ID: {req.kb_id}"
        )
        return ResponseModel(code=status.HTTP_404_NOT_FOUND, message="Document not found")
    outcome = knowledge_base_repository.document_update(
        space_id=req.space_id, kb_id=req.kb_id, doc_id=req.document_id, name=req.document_name
    )
    if outcome.code != status.HTTP_200_OK:
        logger.error(
            f"[DOC_UPDATE] Update failed - Doc ID: {req.document_id}, KB ID: {req.kb_id}, "
            f"Error: {outcome.message}"
        )
        return ResponseModel(code=outcome.code, message=outcome.message)
    logger.info(
        f"[DOC_UPDATE] Document updated - Doc ID: {req.document_id}, KB ID: {req.kb_id}, "
        f"New Name: {req.document_name}, Duration: {time.time() - started_at:.3f}s"
    )
    return ResponseModel(code=status.HTTP_200_OK, message="update document message success", data=None)
@with_exception_handling
async def document_delete(req: DocumentDeleteRequest) -> ResponseModel:
    """Delete one or more documents from a knowledge base.

    For each ID, the document is looked up and its DB record deleted. Then, on
    a best-effort basis, its local file, its OBS object (when ``OBS_BUCKET`` is
    set) and its chunk/triple index entries are removed. Cleanup failures are
    logged and do not affect the per-document success count or abort the batch.

    Returns:
        404 if the knowledge base does not exist. Otherwise 200 if at least one
        document was deleted, or 400 listing the IDs that could not be deleted.
    """
    start_time = time.time()
    logger.info(
        f"[DOC_DELETE] Deleting documents - Space ID: {req.space_id}, "
        f"KB ID: {req.kb_id}, Doc IDs: {req.document_ids}"
    )
    kb_get = KnowledgeBaseGet(space_id=req.space_id, kb_id=req.kb_id)
    kb_result = knowledge_base_repository.knowledge_base_get(kb_get)
    if kb_result.code != status.HTTP_200_OK or not kb_result.data:
        logger.warning(
            f"[DOC_DELETE] Knowledge base not found - KB ID: {req.kb_id}"
        )
        return ResponseModel(code=status.HTTP_404_NOT_FOUND, message="Knowledge base not found")
    success_count = 0
    failed_count = 0
    failed_doc_ids = []
    chunk_index = f"ds_kb_{req.kb_id}_chunks"
    triple_index = f"ds_kb_{req.kb_id}_triples"
    # Created lazily on first use and shared across the whole batch
    # (previously it was re-created for every document).
    index_manager = None
    for doc_id in req.document_ids:
        doc_get_result = knowledge_base_repository.document_get(
            space_id=req.space_id, kb_id=req.kb_id, doc_id=doc_id
        )
        if doc_get_result.code != status.HTTP_200_OK or not doc_get_result.data:
            logger.warning(
                f"[DOC_DELETE] Document not found - Doc ID: {doc_id}, KB ID: {req.kb_id}"
            )
            failed_count += 1
            failed_doc_ids.append(doc_id)
            continue
        file_path = doc_get_result.data.get("file_path")
        obs_name = doc_get_result.data.get("obs_name")
        delete_result = knowledge_base_repository.document_delete(
            space_id=req.space_id, kb_id=req.kb_id, doc_id=doc_id
        )
        if delete_result.code != status.HTTP_200_OK:
            logger.error(
                f"[DOC_DELETE] Delete failed - Doc ID: {doc_id}, KB ID: {req.kb_id}, "
                f"Error: {delete_result.message}"
            )
            failed_count += 1
            failed_doc_ids.append(doc_id)
            continue
        success_count += 1
        if file_path:
            try:
                file_path_obj = Path(file_path)
                if file_path_obj.exists():
                    file_path_obj.unlink()
                    logger.info(f"[DOC_DELETE] Local file deleted - Path: {file_path}")
                else:
                    logger.warning(f"[DOC_DELETE] Local file not found - Path: {file_path}")
            except Exception as e:
                logger.warning(
                    f"[DOC_DELETE] Failed to delete local file - Path: {file_path}, Error: {str(e)}"
                )
        if obs_name and os.getenv("OBS_BUCKET"):
            # Best effort: an OBS error used to propagate out of the loop after
            # the DB record was already gone, skipping index cleanup and every
            # remaining document in the batch.
            try:
                obs_manager = OBSDocumentManager()
                await obs_manager.delete_document(obs_name)
            except Exception as e:
                logger.warning(
                    f"[DOC_DELETE] Failed to delete OBS object - Object: {obs_name}, "
                    f"Doc ID: {doc_id}, Error: {e}"
                )
        try:
            if index_manager is None:
                index_manager = _create_index_manager(collection_name=chunk_index)
            # Always attempt the triple index as well. Relying on the stored
            # indexing_strategy could leave stale triples behind if it is
            # missing or changed between runs. A missing index counts as
            # success in _delete_document_from_index.
            for index_name, index_type in ((chunk_index, "chunks"), (triple_index, "triples")):
                await _delete_document_from_index(
                    index_manager=index_manager,
                    index_name=index_name,
                    doc_id=doc_id,
                    kb_id=req.kb_id,
                    index_type=index_type,
                )
        except Exception as e:
            logger.warning(
                f"[DOC_DELETE] Index cleanup failed - Doc ID: {doc_id}, KB ID: {req.kb_id}, Error: {e}"
            )
    logger.info(
        f"[DOC_DELETE] Documents deletion completed - KB ID: {req.kb_id}, "
        f"Success: {success_count}, Failed: {failed_count}, "
        f"Duration: {time.time() - start_time:.3f}s"
    )
    if success_count > 0:
        return ResponseModel(code=status.HTTP_200_OK, message="delete documents success", data=None)
    return ResponseModel(
        code=status.HTTP_400_BAD_REQUEST,
        message=f"Failed to delete documents: {failed_doc_ids}",
        data=None,
    )
@with_exception_handling
def document_get_status_batch(req: DocumentStatusRequest) -> ResponseModel:
    """Look up the processing status of several documents at once.

    Documents that cannot be found are logged and omitted from the result.
    For each found document, ``error_msg`` comes from ``process_info["error"]``
    (or a generic message for FAILED documents) and is formatted for the
    frontend. ``enable_graph_enhancement`` is taken from
    ``process_info["indexing_strategy"]`` when present, otherwise None.
    """
    started = time.time()
    logger.info(
        f"[DOC_STATUS] Getting document status batch - "
        f"Space ID: {req.space_id}, KB ID: {req.kb_id}, Doc IDs: {len(req.doc_id_list)}"
    )
    status_items = []
    for doc_id in req.doc_id_list:
        lookup = knowledge_base_repository.document_get(
            space_id=req.space_id, kb_id=req.kb_id, doc_id=doc_id
        )
        if lookup.code != status.HTTP_200_OK or not lookup.data:
            logger.warning(
                f"[DOC_STATUS] Document not found - Space ID: {req.space_id}, "
                f"KB ID: {req.kb_id}, Doc ID: {doc_id}"
            )
            continue
        doc_data = lookup.data
        status_value = doc_data.get("status", DocumentStatus.UPLOADING.value)
        error_msg = None
        graph_enabled = None
        process_info = doc_data.get("process_info")
        if isinstance(process_info, dict):
            error_msg = process_info.get("error")
            strategy = process_info.get("indexing_strategy")
            if isinstance(strategy, dict):
                graph_enabled = strategy.get("enable_graph_enhancement", False)
        if not error_msg and status_value == DocumentStatus.FAILED.value:
            error_msg = "Processing failed with unknown error"
        if error_msg:
            error_msg = _format_error_message_for_frontend(error_msg)
        status_items.append(
            DocumentStatusResponse(
                id=doc_id,
                status=status_value,
                name=doc_data.get("name"),
                error_msg=error_msg,
                enable_graph_enhancement=graph_enabled,
            )
        )
    response_data = DocumentStatusListResponse(items=status_items)
    logger.info(
        f"[DOC_STATUS] Document status batch retrieved - Space ID: {req.space_id}, "
        f"KB ID: {req.kb_id}, Requested: {len(req.doc_id_list)}, "
        f"Found: {len(status_items)}, Duration: {time.time() - started:.3f}s"
    )
    return ResponseModel(
        code=status.HTTP_200_OK,
        message="get document status success",
        data=response_data.model_dump(by_alias=False),
    )
# Strong references to the fire-and-forget tasks started by document_process.
# The asyncio event loop only keeps weak references to tasks, so a task whose
# handle is dropped can be garbage-collected before it finishes.
_DOC_PROCESS_BACKGROUND_TASKS: set = set()


def _on_doc_process_task_done(task: asyncio.Task) -> None:
    """Release a finished background task and log it if it ended abnormally."""
    _DOC_PROCESS_BACKGROUND_TASKS.discard(task)
    if task.cancelled():
        logger.warning("[DOC_PROCESS] Sequential processing task cancelled")
        return
    exc = task.exception()
    if exc is not None:
        logger.error(
            f"[DOC_PROCESS] Sequential processing task failed - Error: {exc}",
            exc_info=exc,
        )


def _mark_doc_process_failed(
    space_id: Any,
    kb_id: Any,
    doc_id: Any,
    process_info_base: dict,
    error: str,
) -> None:
    """Best-effort: set a document to FAILED, recording ``error``. Never raises."""
    try:
        knowledge_base_repository.document_update_status(
            space_id=space_id,
            kb_id=kb_id,
            doc_id=doc_id,
            doc_status=DocumentStatus.FAILED.value,
            process_info={
                **process_info_base,
                "error": error,
                "failed_time": milliseconds(),
            },
        )
    except Exception as update_error:
        # Log instead of raising: a failed status write must not escape into
        # the caller's per-document handler and count the document twice.
        logger.error(
            f"[DOC_PROCESS] Failed to update FAILED status - Doc ID: {doc_id}, "
            f"Error: {str(update_error)}",
            exc_info=True,
        )


@with_exception_handling
async def document_process(req: DocumentProcessRequest) -> ResponseModel:
    """Start the document processing pipeline (agentcore parsing/segmentation/indexing).

    Steps:
      1. Load the knowledge base and derive embedding/LLM config from its stored
         config. A request-level ``llm_config`` dict, if present, overrides
         individual fields.
      2. Validate every document in ``req.doc_id_list``:
         * documents already INDEXED are skipped;
         * missing documents, documents not in UPLOADED status, documents
           without a ``file_path``, and documents whose PROCESSING status update
           fails are marked FAILED and reported in ``failed_docs``;
         * the remaining documents are switched to PROCESSING.
      3. Schedule ``_process_documents_sequentially`` for the valid documents as
         a background asyncio task. This function returns without awaiting it.

    Returns:
        ResponseModel:
          * 404 if the knowledge base does not exist;
          * 400 if its config cannot be converted (``ValueError``);
          * otherwise 200 with a serialized ``DocumentProcessResponse``. Its
            ``processed_count`` counts documents accepted for processing, not
            documents finished.
    """
    start_time = time.time()
    logger.info(
        f"[DOC_PROCESS] Starting document processing - "
        f"KB ID: {req.kb_id}, Files: {len(req.doc_id_list)}"
    )
    kb_get = KnowledgeBaseGet(space_id=req.space_id, kb_id=req.kb_id)
    kb_result = knowledge_base_repository.knowledge_base_get(kb_get)
    if kb_result.code != status.HTTP_200_OK or not kb_result.data:
        logger.warning(
            f"[DOC_PROCESS] Knowledge base not found - KB ID: {req.kb_id}"
        )
        return ResponseModel(code=status.HTTP_404_NOT_FOUND, message="Knowledge base not found")
    try:
        embed_model_config, llm_config = _config_dict_to_embed_and_llm(
            kb_result.data.get("config") or {}
        )
    except ValueError as e:
        logger.warning(f"[DOC_PROCESS] Invalid kb config - KB ID: {req.kb_id}, Error: {e}")
        return ResponseModel(
            code=status.HTTP_400_BAD_REQUEST,
            message=str(e),
        )

    # Optional per-request LLM override: each field falls back to the KB config.
    if getattr(req, "llm_config", None):
        lcd = req.llm_config
        if isinstance(lcd, dict) and lcd:
            mt = str(lcd.get("model_type") or llm_config.model_type or "openai").lower()
            # Only these providers are accepted; anything else is coerced to "openai".
            if mt not in ("openai", "siliconflow"):
                mt = "openai"
            hp = lcd.get("hyper_parameters")
            ext = lcd.get("extension")
            llm_config = LLMConfig(
                model_name=str(lcd.get("model_name") or llm_config.model_name),
                model_type=mt,
                base_url=str(
                    lcd.get("base_url")
                    if lcd.get("base_url") is not None
                    else llm_config.base_url
                ),
                api_key=str(
                    lcd.get("api_key")
                    if lcd.get("api_key") is not None
                    else llm_config.api_key
                ),
                hyper_parameters=hp if isinstance(hp, dict) else llm_config.hyper_parameters,
                extension=ext if isinstance(ext, dict) else llm_config.extension,
            )
            logger.info(
                f"[DOC_PROCESS] Applied request-level llm_config - KB ID: {req.kb_id}, "
                f"model_type={llm_config.model_type}, model_name={llm_config.model_name}"
            )

    processed_count = 0
    failed_count = 0
    failed_docs: list[str] = []
    task_id = str(uuid.uuid4())
    current_time = milliseconds()
    # Shared process_info written on every document touched by this task; the
    # task_id here is what task_progress later filters on.
    process_info_base = {
        "task_id": task_id,
        "parsing_strategy": req.parsing_strategy.model_dump(),
        "segmentation_strategy": req.segmentation_strategy.model_dump(),
        "indexing_strategy": req.indexing_strategy.model_dump(),
        "start_time": current_time,
    }
    valid_documents: list[dict] = []
    for doc_id in req.doc_id_list:
        try:
            doc_result = knowledge_base_repository.document_get(
                space_id=req.space_id, kb_id=req.kb_id, doc_id=doc_id
            )
            if doc_result.code != status.HTTP_200_OK or not doc_result.data:
                failed_count += 1
                failed_docs.append(doc_id)
                logger.warning(f"[DOC_PROCESS] Document not found - Doc ID: {doc_id}")
                _mark_doc_process_failed(
                    req.space_id, req.kb_id, doc_id, process_info_base, "Document not found"
                )
                continue

            current_status = doc_result.data.get("status")
            if current_status == DocumentStatus.INDEXED.value:
                logger.info(
                    f"[DOC_PROCESS] Document already indexed - Doc ID: {doc_id}, "
                    f"KB ID: {req.kb_id}"
                )
                continue
            if current_status != DocumentStatus.UPLOADED.value:
                failed_count += 1
                failed_docs.append(doc_id)
                logger.warning(
                    f"[DOC_PROCESS] Document status invalid - Doc ID: {doc_id}, Current status: {current_status}"
                )
                _mark_doc_process_failed(
                    req.space_id, req.kb_id, doc_id, process_info_base, "Document status invalid"
                )
                continue

            file_path = doc_result.data.get("file_path")
            if not file_path:
                failed_count += 1
                failed_docs.append(doc_id)
                logger.error(f"[DOC_PROCESS] File path not found for document {doc_id}")
                # Previously unguarded: a raise here reached the outer except and
                # counted this document as failed a second time.
                _mark_doc_process_failed(
                    req.space_id, req.kb_id, doc_id, process_info_base, "File path not found"
                )
                continue

            update_result = knowledge_base_repository.document_update_status(
                space_id=req.space_id,
                kb_id=req.kb_id,
                doc_id=doc_id,
                doc_status=DocumentStatus.PROCESSING.value,
                process_info=process_info_base,
            )
            if update_result.code != status.HTTP_200_OK:
                failed_count += 1
                failed_docs.append(doc_id)
                logger.error(
                    f"[DOC_PROCESS] Failed to update document status - "
                    f"Doc ID: {doc_id}, Error: {update_result.message}"
                )
                _mark_doc_process_failed(
                    req.space_id,
                    req.kb_id,
                    doc_id,
                    process_info_base,
                    "Failed to update document status",
                )
                continue

            valid_documents.append({
                "doc_id": doc_id,
                "file_path": file_path,
                "name": doc_result.data.get("name"),
                "obs_name": doc_result.data.get("obs_name"),
            })
            processed_count += 1
            logger.info(
                f"[DOC_PROCESS] Document validated and status updated to PROCESSING - Doc ID: {doc_id}"
            )
        except Exception as e:
            failed_count += 1
            failed_docs.append(doc_id)
            logger.error(
                f"[DOC_PROCESS] Failed to validate document - Doc ID: {doc_id}, "
                f"KB ID: {req.kb_id}, Error: {str(e)}",
                exc_info=True,
            )
            _mark_doc_process_failed(
                req.space_id, req.kb_id, doc_id, process_info_base, "Document validation failed"
            )

    if valid_documents:
        logger.info(
            f"[DOC_PROCESS] Creating sequential processing task - Task ID: {task_id}, "
            f"Valid documents: {len(valid_documents)}, KB ID: {req.kb_id}"
        )
        background_task = asyncio.create_task(
            _process_documents_sequentially(
                space_id=req.space_id,
                kb_id=req.kb_id,
                documents=valid_documents,
                parsing_strategy=req.parsing_strategy,
                segmentation_strategy=req.segmentation_strategy,
                indexing_strategy=req.indexing_strategy,
                task_id=task_id,
                process_info_base=process_info_base,
                llm_config=llm_config,
                embed_model_config=embed_model_config,
            )
        )
        # Keep a strong reference until completion (see _DOC_PROCESS_BACKGROUND_TASKS).
        _DOC_PROCESS_BACKGROUND_TASKS.add(background_task)
        background_task.add_done_callback(_on_doc_process_task_done)
        logger.info(
            f"[DOC_PROCESS] Sequential processing task created - Task ID: {task_id}, "
            f"KB ID: {req.kb_id}, Documents to process: {len(valid_documents)}"
        )

    response_data = DocumentProcessResponse(
        task_id=task_id,
        processed_count=processed_count,
        failed_count=failed_count,
        failed_docs=failed_docs,
    )
    logger.info(
        f"[DOC_PROCESS] Document processing tasks started - Task ID: {task_id}, "
        f"KB ID: {req.kb_id}, Processed: {processed_count}, Failed: {failed_count}, "
        f"Duration: {time.time() - start_time:.3f}s"
    )
    return ResponseModel(
        code=status.HTTP_200_OK,
        message="Document processing tasks started",
        data=response_data.model_dump(by_alias=False),
    )
@with_exception_handling
def task_progress(req: TaskProgressRequest) -> ResponseModel:
    """Query the processing progress of a document-processing task.

    Pages through every document of the knowledge base via
    ``knowledge_base_repository.document_list``; earlier versions read only the
    first 1000. Documents whose ``process_info["task_id"]`` equals
    ``req.task_id`` are aggregated.

    Counters in the response:
      * ``total_count`` - documents belonging to the task;
      * ``success_count`` - documents in INDEXED status;
      * ``failed_count`` - documents in FAILED status;
      * ``processed_count`` - documents in PROCESSING, INDEXING or INDEXED status.

    Returns:
        ResponseModel:
          * 404 if the knowledge base does not exist;
          * the repository's code/message if a document page cannot be listed;
          * otherwise 200 with a serialized ``TaskProgressResponse``.
    """
    start_time = time.time()
    page_size = 1000
    logger.info(
        f"[TASK_PROGRESS] Querying task progress - Task ID: {req.task_id}, KB ID: {req.kb_id}"
    )
    kb_get = KnowledgeBaseGet(space_id=req.space_id, kb_id=req.kb_id)
    kb_result = knowledge_base_repository.knowledge_base_get(kb_get)
    if kb_result.code != status.HTTP_200_OK or not kb_result.data:
        logger.warning(
            f"[TASK_PROGRESS] Knowledge base not found - KB ID: {req.kb_id}"
        )
        return ResponseModel(code=status.HTTP_404_NOT_FOUND, message="Knowledge base not found")

    # Collect all documents of the KB, page by page.
    documents: list[dict] = []
    seen_doc_ids: set = set()
    page = 1
    while True:
        list_result = knowledge_base_repository.document_list(
            space_id=req.space_id, kb_id=req.kb_id, page=page, size=page_size
        )
        if list_result.code != status.HTTP_200_OK:
            logger.error(
                f"[TASK_PROGRESS] Failed to query documents - Space ID: {req.space_id}, "
                f"KB ID: {req.kb_id}, Error: {list_result.message}"
            )
            return ResponseModel(
                code=list_result.code,
                message=list_result.message,
            )
        page_items = (list_result.data or {}).get("items") or []
        fresh_ids = 0
        for doc_data in page_items:
            doc_key = doc_data.get("doc_id")
            if doc_key is not None:
                # Skip documents already seen on an earlier page (offset paging can
                # repeat rows if the collection changes while we read it).
                if doc_key in seen_doc_ids:
                    continue
                seen_doc_ids.add(doc_key)
                fresh_ids += 1
            documents.append(doc_data)
        # A short page is the last one. A full page with no new ids means the
        # backend is not advancing, so stop instead of looping forever.
        if len(page_items) < page_size or fresh_ids == 0:
            break
        page += 1

    task_items = []
    total_count = 0
    processed_count = 0
    success_count = 0
    failed_count = 0
    for doc_data in documents:
        process_info = doc_data.get("process_info", {})
        if isinstance(process_info, dict) and process_info.get("task_id") == req.task_id:
            total_count += 1
            doc_id = doc_data.get("doc_id", "")
            doc_name = doc_data.get("name", "")
            doc_status = doc_data.get("status", "")
            if doc_status == DocumentStatus.INDEXED.value:
                success_count += 1
            elif doc_status == DocumentStatus.FAILED.value:
                failed_count += 1
            if doc_status in [
                DocumentStatus.PROCESSING.value,
                DocumentStatus.INDEXING.value,
                DocumentStatus.INDEXED.value,
            ]:
                processed_count += 1
            error = None
            if doc_status == DocumentStatus.FAILED.value:
                error = process_info.get("error", "Unknown error")
            task_items.append(
                TaskProgressItem(doc_id=doc_id, doc_name=doc_name, status=doc_status, error=error)
            )

    response_data = TaskProgressResponse(
        task_id=req.task_id,
        total_count=total_count,
        processed_count=processed_count,
        success_count=success_count,
        failed_count=failed_count,
        items=task_items,
    )
    logger.info(
        f"[TASK_PROGRESS] Task progress retrieved - Task ID: {req.task_id}, "
        f"KB ID: {req.kb_id}, Total: {total_count}, Processed: {processed_count}, "
        f"Success: {success_count}, Failed: {failed_count}, "
        f"Duration: {time.time() - start_time:.3f}s"
    )
    return ResponseModel(
        code=status.HTTP_200_OK,
        message="get task progress success",
        data=response_data.model_dump(by_alias=False),
    )