"""数据库连接抽象:SQLite / PostgreSQL(含连接池)。"""
from __future__ import annotations

import os
import re
import sqlite3
import threading
from contextlib import contextmanager
from typing import Any, Iterator

from app.core.config import database_dialect, get_database_url, load_dotenv_file, sqlite_path_from_url
from app.core.paths import ensure_data_dirs
# NOTE(review): presumably loads a .env file into the process environment. It runs at
# import time, before any function below reads configuration (e.g. get_database_url(),
# COMPILOT_PG_POOL_MIN/MAX) - confirm against app.core.config.
load_dotenv_file()
# Pattern matching a single qmark-style "?" parameter marker.
_PLACEHOLDER_RE = re.compile(r"\?")
# Process-wide PostgreSQL connection pool; stays None until _get_pg_pool() first creates it.
_pg_pool: Any | None = None
class DBCursor:
    """Read-side cursor wrapper that hands rows back as plain ``dict`` objects.

    The wrapped cursor is either a sqlite3 cursor (rows are ``sqlite3.Row``) or a
    psycopg2 cursor created with ``RealDictCursor`` (rows are already dict
    instances). Either way, callers only see dicts keyed by column name.
    """

    def __init__(self, cursor: Any) -> None:
        self._cursor = cursor  # underlying DB-API cursor

    @staticmethod
    def _to_dict(record: Any) -> dict[str, Any]:
        """Return *record* unchanged if it is already a dict, otherwise ``dict(record)``."""
        if isinstance(record, dict):
            return record
        return dict(record)

    def fetchone(self) -> dict[str, Any] | None:
        """Fetch the next row as a dict, or ``None`` once the result set is exhausted."""
        record = self._cursor.fetchone()
        return None if record is None else self._to_dict(record)

    def fetchall(self) -> list[dict[str, Any]]:
        """Fetch every remaining row, each converted to a dict."""
        return [self._to_dict(record) for record in self._cursor.fetchall()]

    @property
    def rowcount(self) -> int:
        """Driver row count as an int; ``None``/0 map to 0, and ``-1`` passes through."""
        count = self._cursor.rowcount
        return int(count) if count else 0
class DBConnection:
    """Dialect-neutral wrapper over a sqlite3 connection or a pooled psycopg2 connection.

    Callers write SQL with qmark ``?`` placeholders. For PostgreSQL these are
    rewritten into psycopg2's ``%s`` markers before execution.
    """

    def __init__(self, raw: Any, dialect: str, *, from_pool: bool = False) -> None:
        self._raw = raw  # underlying driver connection
        self.dialect = dialect  # "sqlite" or "postgres" in this module
        self._from_pool = from_pool  # True -> close() hands the connection back to _pg_pool

    def _adapt_sql(self, sql: str) -> str:
        """Return *sql* rewritten for the driver's parameter style.

        SQLite understands ``?`` natively, so its SQL is returned unchanged. For
        PostgreSQL, every literal ``%`` is first doubled to ``%%``. psycopg2 treats
        ``%`` as a format character whenever parameters are supplied, so
        ``LIKE 'a%'`` would otherwise fail. Then each ``?`` becomes ``%s``.
        The rewrite is purely textual: a ``?`` inside a string literal is also converted.
        """
        if self.dialect != "postgres":
            return sql
        return sql.replace("%", "%%").replace("?", "%s")

    def execute(self, sql: str, params: tuple | list | None = None) -> DBCursor:
        """Execute one statement and wrap the driver cursor in a DBCursor.

        Args:
            sql: Statement using ``?`` placeholders.
            params: Positional parameter values. ``None`` or an empty sequence
                means the statement takes no parameters.

        Returns:
            A DBCursor over the driver cursor that ran the statement.
        """
        args = tuple(params) if params else ()
        if self.dialect == "sqlite":
            return DBCursor(self._raw.execute(sql, args))
        cur = self._raw.cursor()
        if args:
            cur.execute(self._adapt_sql(sql), args)
        else:
            # Without a parameter sequence psycopg2 performs no %-interpolation,
            # so the statement (including any literal '%') is sent verbatim.
            # Passing an empty tuple would still trigger interpolation.
            cur.execute(sql)
        return DBCursor(cur)

    def executescript(self, script: str) -> None:
        """Run a multi-statement SQL script.

        SQLite uses the driver's native ``executescript``. For PostgreSQL the
        script is split with ``_split_sql_script``, and each non-blank statement
        is run through ``execute`` without parameters.
        """
        if self.dialect == "sqlite":
            self._raw.executescript(script)
            return
        for stmt in _split_sql_script(script):
            if stmt.strip():
                self.execute(stmt)

    def commit(self) -> None:
        """Commit the current transaction on the underlying connection."""
        self._raw.commit()

    def rollback(self) -> None:
        """Roll back the current transaction on the underlying connection."""
        self._raw.rollback()

    def close(self) -> None:
        """Release the connection.

        A pooled connection is rolled back and returned to ``_pg_pool``. If that
        best-effort rollback fails, the pool is asked to close the connection
        instead of recycling it. Any other connection is closed directly.
        """
        if self._from_pool and _pg_pool is not None:
            discard = False
            try:
                self._raw.rollback()
            except Exception:
                # Reset failed (for example, the server connection was lost):
                # don't hand a possibly-broken connection to the next borrower.
                discard = True
            _pg_pool.putconn(self._raw, close=discard)
            return
        self._raw.close()
def _split_sql_script(script: str) -> list[str]:
    """Break *script* into individual statements (used by ``DBConnection.executescript``).

    This is a line-based heuristic. Lines whose stripped text starts with ``--``
    are dropped, and a statement ends on any line whose stripped text ends with
    ``;``. Leftover lines without a terminator become a final entry, which may be
    blank.

    String literals, inline comments and dollar-quoted bodies are not parsed. A
    line ending in ``;`` inside them therefore also ends a statement.
    """
    statements: list[str] = []
    pending: list[str] = []

    def flush() -> None:
        # Join the buffered lines into one statement and start a fresh buffer.
        statements.append("\n".join(pending))
        pending.clear()

    for raw_line in script.splitlines():
        text = raw_line.strip()
        if text.startswith("--"):
            continue
        pending.append(raw_line)
        if text.endswith(";"):
            flush()
    if pending:
        flush()
    return statements
def _connect_sqlite() -> DBConnection:
    """Open the SQLite database named by the configured database URL.

    The connection returns ``sqlite3.Row`` rows, switches the journal to WAL and
    turns on foreign-key enforcement. ``check_same_thread=False`` lets the handle
    be used from a thread other than the one that created it.

    Raises:
        sqlite3.Error: the file cannot be opened or configured. The half-open
            handle is closed before the error propagates.
    """
    ensure_data_dirs()  # NOTE(review): presumably creates the directory the DB file lives in
    path = sqlite_path_from_url(get_database_url())
    conn = sqlite3.connect(path, check_same_thread=False)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
    except BaseException:
        # E.g. the file is not a SQLite database, or it is locked.
        # Don't leak the raw handle.
        conn.close()
        raise
    return DBConnection(conn, "sqlite")
def _pg_pool_bounds() -> tuple[int, int]:
    """Return ``(minconn, maxconn)`` for the PostgreSQL pool.

    Values come from the ``COMPILOT_PG_POOL_MIN`` / ``COMPILOT_PG_POOL_MAX``
    environment variables. Unset, blank or non-integer values fall back to 2 and
    10 respectively. The minimum is clamped to at least 1. The maximum is clamped
    to at least the minimum, including when it falls back to its default.
    Previously a bad MAX with MIN > 10 produced minconn > maxconn.
    """

    def _read(name: str, default: int) -> int:
        # Parse an integer env var, tolerating surrounding whitespace.
        raw = (os.getenv(name) or "").strip()
        if not raw:
            return default
        try:
            return int(raw)
        except ValueError:
            return default

    lo = max(1, _read("COMPILOT_PG_POOL_MIN", 2))
    hi = max(lo, _read("COMPILOT_PG_POOL_MAX", 10))
    return lo, hi
_pg_pool_lock = threading.Lock()  # serialises first-time creation of _pg_pool


def _get_pg_pool() -> Any:
    """Return the process-wide psycopg2 ``ThreadedConnectionPool``, creating it on first use.

    The pool is sized by ``_pg_pool_bounds()`` and connects to
    ``get_database_url()``. Its connections produce dict rows via
    ``RealDictCursor``.

    Creation runs under ``_pg_pool_lock`` and re-checks ``_pg_pool`` inside the
    lock. This stops concurrent first callers from each building a pool: the pool
    opens ``minconn`` connections up front, and a losing pool's connections
    would leak.

    Raises:
        RuntimeError: psycopg2 is not installed (message asks to install psycopg2-binary).
    """
    global _pg_pool
    pool = _pg_pool
    if pool is not None:  # fast path: no locking once the pool exists
        return pool
    with _pg_pool_lock:
        if _pg_pool is None:  # another thread may have created it while we waited
            try:
                from psycopg2.extras import RealDictCursor
                from psycopg2.pool import ThreadedConnectionPool
            except ImportError as exc:
                raise RuntimeError("请安装 psycopg2-binary 以使用 PostgreSQL") from exc
            lo, hi = _pg_pool_bounds()
            _pg_pool = ThreadedConnectionPool(
                lo,
                hi,
                get_database_url(),
                cursor_factory=RealDictCursor,
            )
        return _pg_pool
def _connect_postgres() -> DBConnection:
    """Borrow a connection from the shared PostgreSQL pool and wrap it.

    Autocommit is switched off, so statements run inside a transaction that the
    caller must commit or roll back. The wrapper is flagged ``from_pool`` so that
    closing it returns the connection to the pool.
    """
    pool = _get_pg_pool()
    pg_raw = pool.getconn()
    pg_raw.autocommit = False
    return DBConnection(pg_raw, "postgres", from_pool=True)
def connect() -> DBConnection:
    """Open a new DBConnection for whichever dialect the configuration selects.

    If ``database_dialect()`` reports ``"postgres"``, a pooled PostgreSQL
    connection is returned. Any other value falls back to SQLite.
    """
    use_postgres = database_dialect() == "postgres"
    return _connect_postgres() if use_postgres else _connect_sqlite()
@contextmanager
def db_conn() -> Iterator[DBConnection]:
    """Yield a connection whose work forms a single transaction.

    The transaction is committed when the ``with`` body finishes normally. If the
    body or the commit itself raises an ``Exception``, the transaction is rolled
    back and the error re-raised. In every case the connection is closed, or
    returned to the pool for PostgreSQL.
    """
    handle = connect()
    try:
        try:
            yield handle
            handle.commit()  # inside the inner try so a failed commit is rolled back too
        except Exception:
            handle.rollback()
            raise
    finally:
        handle.close()