"""PostgreSQL-backed SessionArchiveStore implementation.

Stores session archives in a ``session_archives`` table instead of
per-archive AGFS directories.  Same public interface as
``session.archive_store.SessionArchiveStore``.
"""

from __future__ import annotations

import json
import logging
import uuid
from datetime import UTC, datetime

from core.models import RequestContext
from fs.sql_adapter.pool import PoolAdapterMixin
from session.models import ArchiveEntry, ArchiveWriteResult

# psycopg2 is optional at import time: when it is missing this module still
# imports, and SQLSessionArchiveStore.__init__ raises ImportError (with install
# instructions) based on the _HAS_PSYCOPG2 flag.
try:
    import psycopg2
    from psycopg2.extras import Json

    _HAS_PSYCOPG2 = True
except ImportError:
    # Placeholders so the module-level names exist; the store refuses to
    # construct when _HAS_PSYCOPG2 is False.
    psycopg2 = None
    Json = None
    _HAS_PSYCOPG2 = False

logger = logging.getLogger(__name__)


def _metadata_from_db(raw) -> dict:
    if isinstance(raw, dict):
        return dict(raw)
    return json.loads(raw) if isinstance(raw, str) else (raw or {})


def _is_merged_metadata(metadata: dict) -> bool:
    return str(metadata.get("status", "")).upper() == "MERGED"


class SQLSessionArchiveStore(PoolAdapterMixin):
    """PostgreSQL-backed session archive storage.

    Provides the same ``write_archive``, ``list_archives``,
    ``read_archive``, and ``read_archive_abstract`` methods as the
    AGFS-based ``SessionArchiveStore``, but backed by a single
    ``session_archives`` table.  ``list_archives_since``,
    ``delete_archive``, ``mark_archive_merged`` and
    ``unmark_archive_merged`` operate on the same table.

    Every statement filters on ``account_id`` and runs after
    ``bind_tenant`` has been called on the connection.  Once a connection
    has been obtained, database errors are logged and reported through the
    return value (``[]``, ``None``, ``False``, or an ``ArchiveWriteResult``
    with ``success=False``) instead of being raised.  Errors while
    obtaining a connection propagate to the caller.
    """

    def __init__(
        self,
        connection_string: str | None = None,
        pool_size: int = 5,
        pool=None,
    ):
        """Initialise the connection pool and ensure the schema exists.

        Args:
            connection_string: Forwarded to ``_init_pool``.
            pool_size: Forwarded to ``_init_pool``.
            pool: Forwarded to ``_init_pool``.

        Raises:
            ImportError: If psycopg2 is not installed.
            RuntimeError: If ``ensure_schema`` fails.
        """
        if not _HAS_PSYCOPG2:
            raise ImportError(
                "psycopg2 is required for SQLSessionArchiveStore. "
                "Install with: pip install psycopg2-binary"
            )
        self._init_pool(pool, connection_string, pool_size)
        self._ensure_table()

    # ------------------------------------------------------------------
    # Connection pool (standalone fallback — used when pool=None)
    # ------------------------------------------------------------------

    def _get_connection(self):
        """Return a pooled connection if it is still open, else a new one."""
        if self._pool:
            conn = self._pool.pop()
            if conn.closed == 0:
                return conn
        return psycopg2.connect(self._connection_string)

    def _return_connection(self, conn):
        """Keep *conn* for reuse, or close it.

        The connection is closed instead of pooled when the pool already
        holds ``self._pool_size`` connections or *conn* is no longer open
        (``conn.closed != 0``).
        """
        if len(self._pool) < self._pool_size and conn.closed == 0:
            self._pool.append(conn)
        else:
            conn.close()

    def _safe_rollback(self, conn) -> None:
        """Roll back *conn*, logging instead of raising if that fails.

        ``rollback()`` raises on a connection that is already closed or
        broken.  Called bare from an ``except``/``finally`` clause it would
        replace the original error, skip ``_return_connection``, and defeat
        the best-effort return values of the public methods.
        """
        try:
            conn.rollback()
        except Exception:
            logger.warning(
                "Rollback failed on session archive connection", exc_info=True
            )

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    def _ensure_table(self) -> None:
        """Run ``ensure_schema`` on a pooled connection.

        Raises:
            RuntimeError: Wrapping any exception raised by ``ensure_schema``.
        """
        from fs.sql_adapter.schema import ensure_schema

        conn = self._get_connection()
        try:
            ensure_schema(conn)
        except Exception as exc:
            self._safe_rollback(conn)
            raise RuntimeError(
                f"Failed to ensure schema: {exc}"
            ) from exc
        finally:
            self._return_connection(conn)

    # ------------------------------------------------------------------
    # Row decoding
    # ------------------------------------------------------------------

    @staticmethod
    def _messages_from_db(raw) -> list:
        """Decode the ``messages`` column (JSON text or driver-decoded)."""
        if isinstance(raw, str):
            return json.loads(raw)
        if raw is None:
            return []
        return list(raw)

    @classmethod
    def _row_to_entry(
        cls,
        row,
        *,
        include_messages: bool,
        skip_merged: bool = False,
    ) -> ArchiveEntry | None:
        """Convert one archive row into an ``ArchiveEntry``.

        Args:
            row: ``(archive_id, session_id, abstract, overview, messages,
                metadata, created_at)`` -- the column order used by every
                archive SELECT in this class.
            include_messages: Decode the ``messages`` column; when False the
                entry gets an empty message list.
            skip_merged: Return ``None`` for rows whose metadata has a
                ``MERGED`` status.

        Returns:
            The entry, or ``None`` when *skip_merged* filtered the row out.
        """
        (
            archive_id,
            sess_id,
            abstract,
            overview,
            messages_raw,
            metadata_raw,
            created_at,
        ) = row

        meta = _metadata_from_db(metadata_raw)
        if skip_merged and _is_merged_metadata(meta):
            return None

        return ArchiveEntry(
            archive_id=archive_id,
            session_id=sess_id,
            overview=overview or "",
            abstract=abstract or "",
            messages=(
                cls._messages_from_db(messages_raw) if include_messages else []
            ),
            created_at=str(created_at) if created_at else "",
            metadata=meta,
        )

    # ------------------------------------------------------------------
    # Public API (same signatures as SessionArchiveStore)
    # ------------------------------------------------------------------

    def write_archive(
        self,
        session_id: str,
        overview: str,
        abstract: str,
        messages: list[dict],
        ctx: RequestContext,
        archive_id: str | None = None,
        metadata: dict | None = None,
    ) -> ArchiveWriteResult:
        """Write a session archive to PostgreSQL.

        Uses ``INSERT ... ON CONFLICT DO NOTHING`` (insert-if-absent).
        Writing an ``(account_id, session_id, archive_id)`` that already
        exists leaves the stored row untouched and is still reported as
        success.  In that case the returned ``created_at`` is this call's
        timestamp, not the stored one.

        Args:
            session_id: Session the archive belongs to.
            overview: Overview text (``None``/empty stored as ``""``).
            abstract: Abstract text (``None``/empty stored as ``""``).
            messages: Messages stored as JSON.
            ctx: Request context; ``ctx.account_id`` scopes the row.
            archive_id: Explicit archive id; generated from the current UTC
                time plus a random suffix when omitted.
            metadata: Extra metadata merged over the generated
                ``archive_id``/``session_id``/``created_at``/
                ``message_count`` keys (caller values win).

        Returns:
            ``ArchiveWriteResult`` with ``success=False`` and ``error`` set
            when the database write fails.
        """
        if archive_id is None:
            archive_id = (
                f"{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}"
                f"_{uuid.uuid4().hex[:8]}"
            )

        uri = (
            f"ctx://{ctx.account_id}/sessions/{session_id}"
            f"/history/{archive_id}"
        )
        now = datetime.now(UTC).isoformat()

        archive_metadata = {
            "archive_id": archive_id,
            "session_id": session_id,
            "created_at": now,
            "message_count": len(messages),
        }
        if metadata:
            archive_metadata.update(metadata)

        conn = self._get_connection()
        try:
            self.bind_tenant(conn, ctx.account_id)
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO session_archives
                        (archive_id, session_id, account_id,
                         abstract, overview, messages, metadata, created_at)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (account_id, session_id, archive_id) DO NOTHING
                    """,
                    (
                        archive_id,
                        session_id,
                        ctx.account_id,
                        abstract or "",
                        overview or "",
                        Json(messages),
                        Json(archive_metadata),
                        now,
                    ),
                )
                if cur.rowcount == 0:
                    logger.debug(
                        "Archive %s for session %s already exists; left unchanged",
                        archive_id,
                        session_id,
                    )
            conn.commit()
            return ArchiveWriteResult(
                archive_id=archive_id,
                session_id=session_id,
                uri=uri,
                success=True,
                created_at=now,
            )
        except Exception as exc:
            logger.warning(
                "Failed to write archive %s for session %s",
                archive_id,
                session_id,
                exc_info=True,
            )
            self._safe_rollback(conn)
            return ArchiveWriteResult(
                archive_id=archive_id,
                session_id=session_id,
                uri=uri,
                success=False,
                error=str(exc),
            )
        finally:
            self._return_connection(conn)

    def list_archives(
        self, session_id: str, ctx: RequestContext
    ) -> list[ArchiveEntry]:
        """List all non-merged archives for a session, newest first.

        Entries carry overview, abstract and metadata but an empty
        ``messages`` list; use ``read_archive`` for the full content.
        Returns ``[]`` if the query fails.
        """
        conn = self._get_connection()
        try:
            self.bind_tenant(conn, ctx.account_id)
            with conn.cursor() as cur:
                # Listings never return messages, so skip transferring the
                # messages payload; NULL keeps the standard column order.
                cur.execute(
                    """
                    SELECT archive_id, session_id, abstract, overview,
                           NULL AS messages, metadata, created_at
                    FROM session_archives
                    WHERE account_id = %s AND session_id = %s
                      AND UPPER(COALESCE(metadata->>'status', '')) <> 'MERGED'
                    ORDER BY created_at DESC
                    """,
                    (ctx.account_id, session_id),
                )
                rows = cur.fetchall()
        except Exception:
            logger.warning(
                "Failed to list archives for session %s", session_id, exc_info=True
            )
            return []
        finally:
            # Read-only: end the transaction before pooling the connection.
            self._safe_rollback(conn)
            self._return_connection(conn)

        entries: list[ArchiveEntry] = []
        for row in rows:
            # Defensive re-check of the SQL MERGED filter.
            entry = self._row_to_entry(
                row, include_messages=False, skip_merged=True
            )
            if entry is not None:
                entries.append(entry)
        return entries

    def list_archives_since(
        self,
        since: datetime,
        ctx: RequestContext,
        limit: int = 50,
    ) -> list[ArchiveEntry]:
        """List archives created since a given timestamp, across all sessions.

        Args:
            since: Only return archives created after this timestamp
            ctx: RequestContext for account_id
            limit: Maximum number of archives to return (default 50)

        Returns:
            List of non-merged ArchiveEntry with messages included, ordered
            by created_at DESC; ``[]`` if the query fails.
        """
        conn = self._get_connection()
        try:
            self.bind_tenant(conn, ctx.account_id)
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT archive_id, session_id, abstract, overview,
                           messages, metadata, created_at
                    FROM session_archives
                    WHERE account_id = %s AND created_at > %s
                      AND UPPER(COALESCE(metadata->>'status', '')) <> 'MERGED'
                    ORDER BY created_at DESC
                    LIMIT %s
                    """,
                    (ctx.account_id, since, limit),
                )
                rows = cur.fetchall()
        except Exception:
            logger.warning(
                "Failed to list archives since %s", since, exc_info=True
            )
            return []
        finally:
            self._safe_rollback(conn)
            self._return_connection(conn)

        entries: list[ArchiveEntry] = []
        for row in rows:
            # Defensive re-check of the SQL MERGED filter.
            entry = self._row_to_entry(
                row, include_messages=True, skip_merged=True
            )
            if entry is not None:
                entries.append(entry)
        return entries

    def read_archive(
        self,
        session_id: str,
        archive_id: str,
        ctx: RequestContext,
    ) -> ArchiveEntry | None:
        """Read a full session archive including messages.

        Merged archives are returned as well.  Returns ``None`` when the
        archive does not exist or the query fails.
        """
        conn = self._get_connection()
        try:
            self.bind_tenant(conn, ctx.account_id)
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT archive_id, session_id, abstract, overview,
                           messages, metadata, created_at
                    FROM session_archives
                    WHERE archive_id = %s AND session_id = %s AND account_id = %s
                    """,
                    (archive_id, session_id, ctx.account_id),
                )
                row = cur.fetchone()
        except Exception:
            logger.warning(
                "Failed to read archive %s for session %s",
                archive_id,
                session_id,
                exc_info=True,
            )
            return None
        finally:
            self._safe_rollback(conn)
            self._return_connection(conn)

        if row is None:
            return None
        return self._row_to_entry(row, include_messages=True)

    def read_archive_abstract(
        self,
        session_id: str,
        archive_id: str,
        ctx: RequestContext,
    ) -> str | None:
        """Read only the abstract from an archive.

        Returns ``None`` when the archive is missing, its abstract is empty,
        or the query fails.
        """
        conn = self._get_connection()
        try:
            self.bind_tenant(conn, ctx.account_id)
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT abstract FROM session_archives
                    WHERE archive_id = %s AND session_id = %s AND account_id = %s
                    """,
                    (archive_id, session_id, ctx.account_id),
                )
                row = cur.fetchone()
        except Exception:
            logger.warning(
                "Failed to read abstract of archive %s for session %s",
                archive_id,
                session_id,
                exc_info=True,
            )
            return None
        finally:
            self._safe_rollback(conn)
            self._return_connection(conn)

        if row is None:
            return None
        return row[0] or None

    def delete_archive(
        self,
        session_id: str,
        archive_id: str,
        ctx: RequestContext,
    ) -> bool:
        """Delete a single archive row for the current account and session.

        Returns True only when exactly one row was deleted; otherwise the
        transaction is rolled back and False is returned.
        """
        conn = self._get_connection()
        try:
            self.bind_tenant(conn, ctx.account_id)
            with conn.cursor() as cur:
                cur.execute(
                    """
                    DELETE FROM session_archives
                    WHERE archive_id = %s AND session_id = %s AND account_id = %s
                    """,
                    (archive_id, session_id, ctx.account_id),
                )
                deleted_one = cur.rowcount == 1
            if not deleted_one:
                self._safe_rollback(conn)
                return False
            conn.commit()
            return True
        except Exception:
            logger.warning(
                "Failed to delete archive %s for session %s",
                archive_id,
                session_id,
                exc_info=True,
            )
            self._safe_rollback(conn)
            return False
        finally:
            self._return_connection(conn)

    def unmark_archive_merged(
        self,
        session_id: str,
        archive_id: str,
        ctx: RequestContext,
        merged_into: str,
    ) -> bool:
        """Undo a merge marker when a merge operation rolls back.

        The ``status``, ``merged_into`` and ``merged_at`` keys are removed
        only if the row's ``merged_into`` equals *merged_into*.  Returns
        False when the row is missing, was merged into something else, or
        the update fails.
        """
        conn = self._get_connection()
        try:
            self.bind_tenant(conn, ctx.account_id)
            with conn.cursor() as cur:
                # FOR UPDATE locks the row so a concurrent mark/unmark cannot
                # change metadata between this read and the UPDATE below
                # (which would otherwise be silently overwritten).
                cur.execute(
                    """
                    SELECT metadata FROM session_archives
                    WHERE archive_id = %s AND session_id = %s AND account_id = %s
                    FOR UPDATE
                    """,
                    (archive_id, session_id, ctx.account_id),
                )
                row = cur.fetchone()
                if row is None:
                    self._safe_rollback(conn)
                    return False

                metadata = _metadata_from_db(row[0])
                if metadata.get("merged_into") != merged_into:
                    self._safe_rollback(conn)
                    return False

                for key in ("status", "merged_into", "merged_at"):
                    metadata.pop(key, None)

                cur.execute(
                    """
                    UPDATE session_archives
                    SET metadata = %s
                    WHERE archive_id = %s AND session_id = %s AND account_id = %s
                    """,
                    (Json(metadata), archive_id, session_id, ctx.account_id),
                )
            conn.commit()
            return True
        except Exception:
            logger.warning(
                "Failed to unmark archive %s for session %s as merged",
                archive_id,
                session_id,
                exc_info=True,
            )
            self._safe_rollback(conn)
            return False
        finally:
            self._return_connection(conn)

    def mark_archive_merged(
        self,
        session_id: str,
        archive_id: str,
        ctx: RequestContext,
        merged_into: str,
    ) -> bool:
        """Mark an archive row as merged into another archive.

        Merges ``status="MERGED"``, ``merged_into`` and a UTC ``merged_at``
        timestamp into the row's metadata in a single UPDATE.  Returns True
        only when exactly one row was updated.
        """
        conn = self._get_connection()
        try:
            self.bind_tenant(conn, ctx.account_id)
            marker = {
                "status": "MERGED",
                "merged_into": merged_into,
                "merged_at": datetime.now(UTC).isoformat(),
            }
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE session_archives
                    SET metadata = COALESCE(metadata, '{}'::jsonb) || %s::jsonb
                    WHERE archive_id = %s AND session_id = %s AND account_id = %s
                    """,
                    (Json(marker), archive_id, session_id, ctx.account_id),
                )
                updated_one = cur.rowcount == 1
            if not updated_one:
                self._safe_rollback(conn)
                return False
            conn.commit()
            return True
        except Exception:
            logger.warning(
                "Failed to mark archive %s for session %s as merged",
                archive_id,
                session_id,
                exc_info=True,
            )
            self._safe_rollback(conn)
            return False
        finally:
            self._return_connection(conn)