"""数据库连接抽象:SQLite / PostgreSQL(含连接池)。"""

from __future__ import annotations

import os
import re
import sqlite3
from contextlib import contextmanager
from typing import Any, Iterator

from app.core.config import database_dialect, get_database_url, load_dotenv_file, sqlite_path_from_url
from app.core.paths import ensure_data_dirs

# Runs at import time, before any of the connection helpers below can be called.
# NOTE(review): presumably loads variables from a .env file so settings such as the
# database URL are available — confirm in app.core.config.
load_dotenv_file()

# Matches every literal "?" character, i.e. the qmark placeholder written in SQL strings.
_PLACEHOLDER_RE = re.compile(r"\?")
# Module-wide PostgreSQL connection pool; stays None until _get_pg_pool() builds it.
_pg_pool: Any | None = None


class DBCursor:
    """Wrap a DB-API cursor so that rows are always returned as plain dicts.

    Rows that are already ``dict`` instances are passed through untouched; any
    other row object is converted with ``dict(row)``.
    """

    def __init__(self, cursor: Any) -> None:
        self._cursor = cursor

    @staticmethod
    def _as_dict(record: Any) -> dict[str, Any]:
        """Return ``record`` itself when it is a dict, else ``dict(record)``."""
        if isinstance(record, dict):
            return record
        return dict(record)

    def fetchone(self) -> dict[str, Any] | None:
        """Fetch the next row as a dict, or ``None`` when no row is left."""
        record = self._cursor.fetchone()
        return None if record is None else self._as_dict(record)

    def fetchall(self) -> list[dict[str, Any]]:
        """Fetch all remaining rows as a list of dicts (empty list if none)."""
        return [self._as_dict(record) for record in self._cursor.fetchall()]

    @property
    def rowcount(self) -> int:
        """Row count reported by the underlying cursor; a falsy value maps to 0."""
        affected = self._cursor.rowcount
        return int(affected) if affected else 0


class DBConnection:
    """Dialect-aware wrapper around a raw database connection.

    Callers write SQL with qmark (``?``) placeholders. For the ``"postgres"``
    dialect the placeholders are rewritten to ``%s`` before execution; for any
    other dialect the SQL is passed through unchanged.

    Args:
        raw: The underlying driver connection object.
        dialect: ``"sqlite"`` or ``"postgres"``.
        from_pool: True when ``raw`` was checked out of the module-level
            PostgreSQL pool and must be returned to it on ``close()``.
    """

    # Quoted literals, quoted identifiers and comments are matched as whole spans so
    # that a "?" inside them is consumed with the span and never treated as a placeholder.
    _QMARK_SCAN_RE = re.compile(
        r"'(?:[^']|'')*'"   # single-quoted string literal ('' is an escaped quote)
        r'|"(?:[^"]|"")*"'  # double-quoted identifier
        r"|--[^\n]*"        # line comment
        r"|/\*.*?\*/"       # block comment
        r"|\?",             # bare qmark placeholder
        re.DOTALL,
    )

    def __init__(self, raw: Any, dialect: str, *, from_pool: bool = False) -> None:
        self._raw = raw
        self.dialect = dialect
        self._from_pool = from_pool

    def _adapt_sql(self, sql: str) -> str:
        """Translate qmark placeholders to ``%s`` for the postgres dialect.

        Only ``?`` characters outside quoted literals, quoted identifiers and
        SQL comments are rewritten, so text such as ``'why?'`` is left intact.
        Other dialects get ``sql`` back unchanged.
        """
        if self.dialect != "postgres":
            return sql
        return self._QMARK_SCAN_RE.sub(
            lambda m: "%s" if m.group() == "?" else m.group(),
            sql,
        )

    def execute(self, sql: str, params: tuple | list | None = None) -> DBCursor:
        """Execute one statement and return a ``DBCursor`` over its result.

        Args:
            sql: Statement text using ``?`` placeholders.
            params: Positional parameters; ``None`` or empty means no parameters.
        """
        sql = self._adapt_sql(sql)
        args = tuple(params) if params else ()
        if self.dialect == "sqlite":
            return DBCursor(self._raw.execute(sql, args))
        cur = self._raw.cursor()
        try:
            cur.execute(sql, args)
        except Exception:
            # Do not leave a half-used cursor open when the statement fails.
            cur.close()
            raise
        return DBCursor(cur)

    def executescript(self, script: str) -> None:
        """Run a multi-statement SQL script.

        SQLite uses the driver's native ``executescript``. Other dialects split
        the script with ``_split_sql_script`` and execute each non-blank
        statement through ``execute``.
        """
        if self.dialect == "sqlite":
            self._raw.executescript(script)
            return
        for stmt in _split_sql_script(script):
            if stmt.strip():
                self.execute(stmt)

    def commit(self) -> None:
        """Commit the current transaction on the underlying connection."""
        self._raw.commit()

    def rollback(self) -> None:
        """Roll back the current transaction on the underlying connection."""
        self._raw.rollback()

    def close(self) -> None:
        """Release the connection.

        A pooled connection is rolled back (best effort) and handed back to the
        pool; any other connection is closed directly.
        """
        if self._from_pool and _pg_pool is not None:
            try:
                self._raw.rollback()
            except Exception:
                # Deliberate best effort: a failed rollback must not prevent the
                # connection from being returned to the pool.
                pass
            _pg_pool.putconn(self._raw)
            return
        self._raw.close()


def _split_sql_script(script: str) -> list[str]:
    """Split a SQL script into individual statements, line by line.

    A statement ends at a line whose stripped text ends with ``;``. Lines that
    start with ``--`` (outside a dollar-quoted body) are dropped. Semicolons
    inside a PostgreSQL dollar-quoted body (``$$ ... $$`` or ``$tag$ ... $tag$``)
    do not end a statement, so function and ``DO`` block definitions stay in
    one piece.

    Args:
        script: Full script text.

    Returns:
        Statement strings in script order. A trailing remainder without ``;``
        is returned as a final statement unless it is only whitespace.
    """
    dollar_tag = re.compile(r"\$(?:[A-Za-z_][A-Za-z0-9_]*)?\$")
    parts: list[str] = []
    buf: list[str] = []
    open_tag: str | None = None  # dollar-quote delimiter currently open, if any
    for line in script.splitlines():
        stripped = line.strip()
        # Body lines inside a dollar-quoted block are kept verbatim.
        if open_tag is None and stripped.startswith("--"):
            continue
        buf.append(line)
        for match in dollar_tag.finditer(line):
            tag = match.group()
            if open_tag is None:
                open_tag = tag
            elif tag == open_tag:
                open_tag = None
        if open_tag is None and stripped.endswith(";"):
            parts.append("\n".join(buf))
            buf = []
    # Skip a remainder made only of blank lines; it is not a statement.
    if any(piece.strip() for piece in buf):
        parts.append("\n".join(buf))
    return parts


def _connect_sqlite() -> DBConnection:
    """Open a new SQLite connection and wrap it in a ``DBConnection``.

    The database file path is derived from the configured database URL. Rows
    are produced as ``sqlite3.Row`` objects, and WAL journaling plus foreign-key
    enforcement are switched on for the connection.
    """
    ensure_data_dirs()
    db_file = sqlite_path_from_url(get_database_url())
    raw = sqlite3.connect(db_file, check_same_thread=False)
    raw.row_factory = sqlite3.Row
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
        raw.execute(pragma)
    return DBConnection(raw, "sqlite")


def _pg_pool_bounds() -> tuple[int, int]:
    """Read the PostgreSQL pool size limits from the environment.

    ``COMPILOT_PG_POOL_MIN`` (default 2, floored at 1) and
    ``COMPILOT_PG_POOL_MAX`` (default 10) are parsed as integers; an unset,
    empty or unparsable value falls back to its default.

    Returns:
        ``(min_size, max_size)`` with ``max_size >= min_size`` guaranteed, also
        when the maximum falls back to its default.
    """
    try:
        lo = max(1, int((os.getenv("COMPILOT_PG_POOL_MIN") or "2").strip()))
    except ValueError:
        lo = 2
    try:
        hi = int((os.getenv("COMPILOT_PG_POOL_MAX") or "10").strip())
    except ValueError:
        hi = 10
    # Clamp after parsing so the fallback maximum can never undercut the minimum.
    return lo, max(lo, hi)


def _get_pg_pool() -> Any:
    """Return the module-wide PostgreSQL pool, creating it on the first call.

    The pool is a psycopg2 ``ThreadedConnectionPool`` sized by
    ``_pg_pool_bounds()``, connected to the configured database URL and using
    ``RealDictCursor`` as its cursor factory. Later calls return the same pool.

    Raises:
        RuntimeError: If psycopg2 cannot be imported.
    """
    global _pg_pool
    if _pg_pool is None:
        try:
            import psycopg2
            from psycopg2.extras import RealDictCursor
            from psycopg2.pool import ThreadedConnectionPool
        except ImportError as exc:
            raise RuntimeError("请安装 psycopg2-binary 以使用 PostgreSQL") from exc
        min_conn, max_conn = _pg_pool_bounds()
        _pg_pool = ThreadedConnectionPool(
            min_conn,
            max_conn,
            get_database_url(),
            cursor_factory=RealDictCursor,
        )
    return _pg_pool


def _connect_postgres() -> DBConnection:
    """Check a connection out of the PostgreSQL pool and wrap it.

    Autocommit is switched off on the raw connection, and the wrapper is marked
    ``from_pool=True`` so that ``close()`` returns it to the pool.
    """
    pool = _get_pg_pool()
    pg_conn = pool.getconn()
    pg_conn.autocommit = False
    return DBConnection(pg_conn, "postgres", from_pool=True)


def connect() -> DBConnection:
    """Open a connection for the configured dialect.

    Returns a pooled PostgreSQL connection when ``database_dialect()`` reports
    ``"postgres"``; every other value yields a SQLite connection.
    """
    use_postgres = database_dialect() == "postgres"
    return _connect_postgres() if use_postgres else _connect_sqlite()


@contextmanager
def db_conn() -> Iterator[DBConnection]:
    """Context manager yielding a connection with commit-or-rollback semantics.

    The transaction is committed when the ``with`` body finishes normally. If
    the body or the commit raises an ``Exception``, the transaction is rolled
    back and the exception is re-raised. The connection is closed in all cases.
    """
    handle = connect()
    try:
        try:
            yield handle
            # Commit stays inside the guarded region so a failing commit is rolled back too.
            handle.commit()
        except Exception:
            handle.rollback()
            raise
    finally:
        handle.close()