"""Unit tests for SQLContextFS.

Tests use mocked psycopg2 connections to verify behavior without
requiring a running PostgreSQL instance.
"""

from datetime import UTC, datetime
from unittest.mock import MagicMock, patch

import pytest

from core.errors import (
    AccessDeniedError,
    ConcurrentModificationError,
    NodeBrokenError,
    NodeNotFoundError,
)
from core.models import ContextNode, RelationEdge, RequestContext

# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

def _ctx(account_id="acme", user_id="alice", session_id="s1"):
    """Build a RequestContext for the tests in this module.

    The account, user and session identifiers can be overridden by the
    caller; ``agent_id`` and ``trace_id`` are always ``"main"`` and ``"t1"``.
    """
    fields = {
        "account_id": account_id,
        "user_id": user_id,
        "agent_id": "main",
        "session_id": session_id,
        "trace_id": "t1",
    }
    return RequestContext(**fields)


def _profile_node(content="Profile content", abstract="Short abstract"):
    """Build the ContextNode for alice's profile memory under account ``acme``.

    Only ``content`` and ``abstract`` are configurable; every other field is
    a fixed test value. A fresh metadata dict is created on each call, so
    tests may mutate it freely.
    """
    fields = {
        "uri": "ctx://acme/users/alice/memories/profile",
        "context_type": "MEMORY",
        "category": "profile",
        "level": 0,
        "owner_space": "user:alice",
        "abstract": abstract,
        "overview": "## Overview\nDetailed overview",
        "content": content,
        "metadata": {"tags": ["test"]},
    }
    return ContextNode(**fields)


def _pref_node(slug="coffee"):
    """Build a preference ContextNode for alice whose URI ends in ``slug``.

    Only the URI varies with ``slug``; the abstract, overview and content
    are the same fixed coffee-themed strings regardless of its value.
    """
    uri = f"ctx://acme/users/alice/memories/preferences/{slug}"
    fields = {
        "uri": uri,
        "context_type": "MEMORY",
        "category": "preference",
        "level": 0,
        "owner_space": "user:alice",
        "abstract": "Likes coffee",
        "overview": "## Coffee preferences",
        "content": "Prefers dark roast",
        "metadata": {},
    }
    return ContextNode(**fields)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

@pytest.fixture
def mock_sql_fs():
    """Yield ``(fs, mock_conn, mock_cursor)`` for a SQLContextFS wired to mocks.

    ``psycopg2`` and ``Json`` are patched inside the adapter module for the
    lifetime of the fixture. The instance is created with ``__new__`` so its
    ``__init__`` never runs, and its connection accessors are replaced so
    every operation uses the same mocked connection.
    """
    def _identity(value):
        # Stand-in for Json: hand the payload back unchanged so tests can
        # inspect it as a plain Python object.
        return value

    with patch("fs.sql_adapter.sql_context_fs.psycopg2") as pg_module:
        with patch("fs.sql_adapter.sql_context_fs.Json", _identity):
            cursor_mock = MagicMock()
            conn_mock = MagicMock()

            # ``with conn.cursor() as cur`` must bind ``cur`` to cursor_mock,
            # and leaving the block must not swallow exceptions.
            cursor_cm = conn_mock.cursor.return_value
            cursor_cm.__enter__.return_value = cursor_mock
            cursor_cm.__exit__.return_value = False
            pg_module.connect.return_value = conn_mock

            # Import while the patches are active.
            from fs.sql_adapter.sql_context_fs import SQLContextFS

            sql_fs = SQLContextFS.__new__(SQLContextFS)
            sql_fs._connection_string = "host=localhost dbname=test"
            sql_fs._pool_size = 5
            sql_fs._pool = []

            # Route connection checkout/return through the mock connection.
            sql_fs._get_connection = lambda: conn_mock
            sql_fs._return_connection = lambda c: None

            yield sql_fs, conn_mock, cursor_mock


class TestSQLContextFSWriteNode:
    """write_node: commit on success, access checks, optimistic locking."""

    def test_write_new_node(self, mock_sql_fs):
        """A permitted write commits exactly once and runs SQL on the cursor."""
        sql_fs, conn, cursor = mock_sql_fs

        sql_fs.write_node(_profile_node(), _ctx())

        conn.commit.assert_called_once()
        # At least one statement must have gone through the cursor.
        cursor.execute.assert_called()

    def test_write_access_denied(self, mock_sql_fs):
        """Writing an ``acme`` node from another account raises AccessDeniedError."""
        sql_fs = mock_sql_fs[0]
        foreign_ctx = _ctx(account_id="other")  # account differs from the URI's

        with pytest.raises(AccessDeniedError):
            sql_fs.write_node(_profile_node(), foreign_ctx)

    def test_write_with_relations(self, mock_sql_fs):
        """A node carrying a RelationEdge in ``_relations`` is written and committed."""
        sql_fs, conn, _ = mock_sql_fs
        profile_uri = "ctx://acme/users/alice/memories/profile"
        coffee_uri = "ctx://acme/users/alice/memories/preferences/coffee"

        edge = RelationEdge(
            from_uri=profile_uri,
            to_uri=coffee_uri,
            relation_type="related_to",
            weight=0.8,
            reason="test",
        )
        node = ContextNode(
            uri=profile_uri,
            context_type="MEMORY",
            category="profile",
            level=0,
            owner_space="user:alice",
            abstract="abs",
            overview="ov",
            content="cont",
            metadata={"_relations": [edge]},
        )

        sql_fs.write_node(node, _ctx())

        conn.commit.assert_called_once()

    def test_write_optimistic_lock_failure(self, mock_sql_fs):
        """A stale ``expected_version`` raises ConcurrentModificationError and rolls back."""
        sql_fs, conn, cursor = mock_sql_fs
        node = _profile_node()
        node.metadata["expected_version"] = 3

        # The versioned UPDATE reports no matching row ...
        cursor.rowcount = 0
        # ... and the follow-up lookup of the stored version returns 5.
        cursor.fetchone.return_value = (5,)

        with pytest.raises(ConcurrentModificationError):
            sql_fs.write_node(node, _ctx())

        conn.rollback.assert_called_once()

    def test_write_optimistic_lock_create_new(self, mock_sql_fs):
        """Without ``expected_version`` the write takes the create path and commits."""
        sql_fs, conn, _ = mock_sql_fs

        # No expected_version in metadata → INSERT ON CONFLICT path.
        sql_fs.write_node(_profile_node(), _ctx())

        conn.commit.assert_called_once()


class TestSQLContextFSAccess:
    """_is_accessible rules for session-scoped URIs within one account."""

    def test_session_state_matching_session_allowed(self, mock_sql_fs):
        """Session state is accessible from the session that owns it."""
        sql_fs = mock_sql_fs[0]
        allowed = sql_fs._is_accessible(
            "ctx://acme/sessions/session-1/state",
            _ctx(session_id="session-1"),
        )
        assert allowed is True

    def test_session_state_different_session_denied(self, mock_sql_fs):
        """Session state is not accessible from a different session."""
        sql_fs = mock_sql_fs[0]
        allowed = sql_fs._is_accessible(
            "ctx://acme/sessions/session-1/state",
            _ctx(session_id="session-2"),
        )
        assert allowed is False

    def test_session_history_same_account_remains_accessible(self, mock_sql_fs):
        """Session history stays accessible from another session of the account."""
        sql_fs = mock_sql_fs[0]
        allowed = sql_fs._is_accessible(
            "ctx://acme/sessions/session-1/history",
            _ctx(session_id="session-2"),
        )
        assert allowed is True

    def test_session_archive_same_account_remains_accessible(self, mock_sql_fs):
        """An archive under session history stays accessible from another session."""
        sql_fs = mock_sql_fs[0]
        allowed = sql_fs._is_accessible(
            "ctx://acme/sessions/session-1/history/archive-1",
            _ctx(session_id="session-2"),
        )
        assert allowed is True


class TestSQLContextFSReadNode:
    """read_node: row-to-node mapping and its error paths."""

    def test_read_existing_node(self, mock_sql_fs):
        """A fetched row is returned as a node with matching uri/content/abstract."""
        sql_fs, _, cursor = mock_sql_fs
        profile_uri = "ctx://acme/users/alice/memories/profile"

        row = (
            profile_uri,
            "MEMORY",
            "profile",
            0,
            "user:alice",
            "Short abstract",
            "## Overview",
            "Profile content",
            [],  # relations
            {"tags": ["test"], "version": 1},  # metadata
            "ACTIVE",
            datetime.now(UTC),
            datetime.now(UTC),
            1,
        )
        cursor.fetchone.return_value = row

        node = sql_fs.read_node(profile_uri, _ctx())

        assert node.uri == "ctx://acme/users/alice/memories/profile"
        assert node.content == "Profile content"
        assert node.abstract == "Short abstract"

    def test_read_not_found(self, mock_sql_fs):
        """No row for the URI raises NodeNotFoundError."""
        sql_fs, _, cursor = mock_sql_fs
        cursor.fetchone.return_value = None

        with pytest.raises(NodeNotFoundError):
            sql_fs.read_node("ctx://acme/users/alice/memories/profile", _ctx())

    def test_read_broken_node(self, mock_sql_fs):
        """A row carrying the ``"BROKEN"`` marker raises NodeBrokenError."""
        sql_fs, _, cursor = mock_sql_fs
        profile_uri = "ctx://acme/users/alice/memories/profile"

        row = (
            profile_uri,
            "MEMORY", "profile", 0, "user:alice",
            "", "", "", [], {}, "BROKEN", None, None, 1,
        )
        cursor.fetchone.return_value = row

        with pytest.raises(NodeBrokenError):
            sql_fs.read_node(profile_uri, _ctx())

    def test_read_access_denied(self, mock_sql_fs):
        """Reading an ``acme`` node from another account raises AccessDeniedError."""
        sql_fs = mock_sql_fs[0]
        foreign_ctx = _ctx(account_id="other")

        with pytest.raises(AccessDeniedError):
            sql_fs.read_node("ctx://acme/users/alice/memories/profile", foreign_ctx)

    def test_read_with_relations(self, mock_sql_fs):
        """Relation dicts in the row come back as objects under ``_relations``."""
        sql_fs, _, cursor = mock_sql_fs
        profile_uri = "ctx://acme/users/alice/memories/profile"

        stored_relation = {
            "from_uri": profile_uri,
            "to_uri": "ctx://acme/users/alice/memories/preferences/coffee",
            "relation_type": "related_to",
            "weight": 0.8,
            "reason": "test",
        }
        row = (
            profile_uri,
            "MEMORY", "profile", 0, "user:alice",
            "abs", "ov", "cont",
            [stored_relation],  # JSONB relations
            {"version": 1},
            "ACTIVE",
            datetime.now(UTC),
            datetime.now(UTC),
            1,
        )
        cursor.fetchone.return_value = row

        node = sql_fs.read_node(profile_uri, _ctx())

        loaded = node.metadata["_relations"]
        assert len(loaded) == 1
        assert loaded[0].relation_type == "related_to"


class TestSQLContextFSExists:
    """exists: boolean result for present, absent and inaccessible URIs."""

    def test_exists_active(self, mock_sql_fs):
        """A row coming back from the lookup means the node exists."""
        sql_fs, _, cursor = mock_sql_fs
        cursor.fetchone.return_value = (1,)

        found = sql_fs.exists("ctx://acme/users/alice/memories/profile", _ctx())

        assert found is True

    def test_not_exists(self, mock_sql_fs):
        """No row from the lookup means the node does not exist."""
        sql_fs, _, cursor = mock_sql_fs
        cursor.fetchone.return_value = None

        found = sql_fs.exists("ctx://acme/users/alice/memories/profile", _ctx())

        assert found is False

    def test_exists_access_denied(self, mock_sql_fs):
        """A caller from another account gets False rather than an exception."""
        sql_fs = mock_sql_fs[0]
        foreign_ctx = _ctx(account_id="other")

        found = sql_fs.exists("ctx://acme/users/alice/memories/profile", foreign_ctx)

        assert found is False


class TestSQLContextFSDeleteNode:
    """delete_node: commit on success, errors for missing or foreign nodes."""

    def test_delete_existing(self, mock_sql_fs):
        """With ``rowcount`` of 1 the delete is committed exactly once."""
        sql_fs, conn, cursor = mock_sql_fs
        cursor.rowcount = 1

        sql_fs.delete_node("ctx://acme/users/alice/memories/profile", _ctx())

        conn.commit.assert_called_once()

    def test_delete_not_found(self, mock_sql_fs):
        """With ``rowcount`` of 0 the delete raises NodeNotFoundError."""
        sql_fs, _, cursor = mock_sql_fs
        cursor.rowcount = 0

        with pytest.raises(NodeNotFoundError):
            sql_fs.delete_node("ctx://acme/users/alice/memories/profile", _ctx())

    def test_delete_access_denied(self, mock_sql_fs):
        """Deleting an ``acme`` node from another account raises AccessDeniedError."""
        sql_fs = mock_sql_fs[0]
        foreign_ctx = _ctx(account_id="other")

        with pytest.raises(AccessDeniedError):
            sql_fs.delete_node("ctx://acme/users/alice/memories/profile", foreign_ctx)


class TestSQLContextFSListChildren:
    """list_children: immediate-child filtering, account scoping, access checks."""

    def test_list_children(self, mock_sql_fs):
        """Only direct children of the parent URI are returned."""
        sql_fs, _, cursor = mock_sql_fs
        parent_uri = "ctx://acme/users/alice/memories/preferences"

        # Two direct children plus one grandchild that must be dropped.
        cursor.fetchall.return_value = [
            ("ctx://acme/users/alice/memories/preferences/coffee",),
            ("ctx://acme/users/alice/memories/preferences/tea",),
            ("ctx://acme/users/alice/memories/preferences/food/sushi",),
        ]

        children = sql_fs.list_children(parent_uri, _ctx())

        # The nested path is excluded, leaving exactly the two direct children.
        assert len(children) == 2
        assert "ctx://acme/users/alice/memories/preferences/coffee" in children
        assert "ctx://acme/users/alice/memories/preferences/tea" in children

    def test_list_children_empty(self, mock_sql_fs):
        """An empty query result yields an empty list."""
        sql_fs, _, cursor = mock_sql_fs
        cursor.fetchall.return_value = []

        children = sql_fs.list_children(
            "ctx://acme/users/alice/memories/preferences", _ctx()
        )

        assert children == []

    def test_list_children_filters_by_account(self, mock_sql_fs):
        """The last executed statement mentions account_id and binds ``acme``."""
        sql_fs, _, cursor = mock_sql_fs
        cursor.fetchall.return_value = []

        sql_fs.list_children("ctx://acme/users/alice/memories/preferences", _ctx())

        # Inspect the positional arguments of the most recent execute() call.
        sql, params = cursor.execute.call_args.args
        assert "account_id" in sql
        assert "acme" in params  # account_id extracted from URI

    def test_list_children_access_denied(self, mock_sql_fs):
        """Listing an ``acme`` path from another account raises AccessDeniedError."""
        sql_fs = mock_sql_fs[0]
        foreign_ctx = _ctx(account_id="other")

        with pytest.raises(AccessDeniedError):
            sql_fs.list_children(
                "ctx://acme/users/alice/memories/preferences", foreign_ctx
            )


class TestSQLContextFSMoveNode:
    """move_node: commit, error paths, and relation/outbox cascades."""

    def test_move_existing(self, mock_sql_fs):
        """With ``rowcount`` of 1 the move is committed exactly once."""
        sql_fs, conn, cursor = mock_sql_fs
        source_uri = "ctx://acme/users/alice/memories/preferences/coffee"
        target_uri = "ctx://acme/users/alice/memories/preferences/latte"

        # The UPDATE that relocates the node itself matches one row.
        cursor.rowcount = 1

        sql_fs.move_node(source_uri, target_uri, _ctx())

        conn.commit.assert_called_once()

    def test_move_source_not_found(self, mock_sql_fs):
        """With ``rowcount`` of 0 the move raises NodeNotFoundError."""
        sql_fs, _, cursor = mock_sql_fs
        source_uri = "ctx://acme/users/alice/memories/preferences/coffee"
        target_uri = "ctx://acme/users/alice/memories/preferences/latte"

        # The relocating UPDATE matches nothing.
        cursor.rowcount = 0

        with pytest.raises(NodeNotFoundError):
            sql_fs.move_node(source_uri, target_uri, _ctx())

    def test_move_access_denied(self, mock_sql_fs):
        """Moving an ``acme`` node from another account raises AccessDeniedError."""
        sql_fs = mock_sql_fs[0]
        foreign_ctx = _ctx(account_id="other")

        with pytest.raises(AccessDeniedError):
            sql_fs.move_node(
                "ctx://acme/users/alice/memories/preferences/coffee",
                "ctx://acme/users/alice/memories/preferences/latte",
                foreign_ctx,
            )

    def test_move_cascades_to_relations_and_outbox(self, mock_sql_fs):
        """A move touches relation_edges and emits DELETE + UPSERT outbox events."""
        import hashlib

        def _short_digest(text):
            # First 16 hex characters of the SHA-256 of ``text``.
            return hashlib.sha256(text.encode()).hexdigest()[:16]

        sql_fs, conn, cursor = mock_sql_fs
        old_uri = "ctx://acme/users/alice/memories/preferences/coffee"
        new_uri = "ctx://acme/users/alice/memories/preferences/latte"

        cursor.rowcount = 1  # the relocating UPDATE matches one row

        # fetchone answers the two to_regclass probes, in order, reporting
        # that both relation_edges and outbox_events exist.
        cursor.fetchone.side_effect = [
            ("relation_edges",),
            ("outbox_events",),
        ]

        # fetchall is consumed twice:
        #   1) URIs of all moved nodes (used for the DELETE event)
        #   2) full rows of the moved nodes (used for the UPSERT event)
        moved_row = (
            new_uri, "abs", "ov", "content",
            "user:alice", "preference", "MEMORY",
        )
        cursor.fetchall.side_effect = [
            [(new_uri,)],
            [moved_row],
        ]

        sql_fs.move_node(old_uri, new_uri, _ctx())
        conn.commit.assert_called_once()

        executed = cursor.execute.call_args_list
        statements = [str(call.args[0]) for call in executed]

        # 1) relation_edges cascade
        relation_updates = [sql for sql in statements if "relation_edges" in sql]
        assert len(relation_updates) >= 2, (
            f"Expected >= 2 relation_edges updates, got {len(relation_updates)}"
        )

        # 2) DELETE_CONTEXT: ids must cover the old URI, including L2 /content.md
        delete_calls = [
            call for call in executed
            if "DELETE_CONTEXT" in str(call.args[0])
        ]
        assert len(delete_calls) == 1, (
            f"Expected 1 DELETE_CONTEXT, got {len(delete_calls)}"
        )
        # execute(sql, (event_id, uri, account_id, payload_dict))
        delete_params = delete_calls[0].args[1]
        payload = delete_params[3]  # Json is patched to identity, so a dict
        assert isinstance(payload, dict), f"Expected dict payload, got {type(payload)}"
        ids = payload["ids_to_delete"]
        assert len(ids) == 3, f"Expected 3 ids (L0/L1/L2), got {len(ids)}"
        expected_l2 = _short_digest(f"{old_uri}/content.md:2")
        assert expected_l2 in ids, f"L2 id {expected_l2} not in {ids}"

        # 3) UPSERT_CONTEXT: records must carry ids derived from the new URI
        upsert_calls = [
            call for call in executed
            if "UPSERT_CONTEXT" in str(call.args[0])
        ]
        assert len(upsert_calls) == 1, (
            f"Expected 1 UPSERT_CONTEXT, got {len(upsert_calls)}"
        )
        upsert_params = upsert_calls[0].args[1]
        records = upsert_params[3]["records"]
        assert len(records) >= 1, "Expected at least 1 record"
        expected_new_l2 = _short_digest(f"{new_uri}/content.md:2")
        all_ids = [record["id"] for record in records]
        assert expected_new_l2 in all_ids, (
            f"New L2 id {expected_new_l2} not in {all_ids}"
        )

        # 4) Every record's metadata must expose parent_uri, category,
        #    context_type, has_overview and has_content.
        for record in records:
            meta = record["metadata"]
            assert "category" in meta, f"Missing 'category' in metadata: {meta}"
            assert "context_type" in meta, f"Missing 'context_type' in metadata: {meta}"
            assert "parent_uri" in meta, f"Missing 'parent_uri' in metadata: {meta}"
            assert "has_overview" in meta, f"Missing 'has_overview' in metadata: {meta}"
            assert "has_content" in meta, f"Missing 'has_content' in metadata: {meta}"

        # Spot-check values on the level-0 record.
        level0_meta = [r for r in records if r["level"] == 0][0]["metadata"]
        assert level0_meta["category"] == "preference"
        assert level0_meta["context_type"] == "MEMORY"
        assert level0_meta["has_overview"] is True
        assert level0_meta["has_content"] is True

        # The level-2 record's parent_uri is the node URI, not .../content.md.
        level2_meta = [r for r in records if r["level"] == 2][0]["metadata"]
        assert level2_meta["parent_uri"] == new_uri

    def test_move_leaf_node_includes_root_in_upsert(self, mock_sql_fs):
        """Moving a leaf: the re-index SELECT must also match the target URI itself."""
        sql_fs, _, cursor = mock_sql_fs
        new_uri = "ctx://acme/users/alice/memories/preferences/latte"

        cursor.rowcount = 1
        cursor.fetchone.side_effect = [
            ("relation_edges",),
            ("outbox_events",),
        ]
        cursor.fetchall.side_effect = [
            [(new_uri,)],
            [(new_uri, "a", "o", "c", "user:alice", "preference", "MEMORY")],
        ]

        sql_fs.move_node(
            "ctx://acme/users/alice/memories/preferences/coffee",
            new_uri,
            _ctx(),
        )

        # The re-index SELECT has to use `uri = %s OR uri LIKE %s`, so the
        # moved root is matched as well as anything under its prefix.
        statements = [str(call.args[0]) for call in cursor.execute.call_args_list]
        select_for_upsert = [
            sql for sql in statements
            if "SELECT uri, abstract" in sql and "uri = %s OR uri LIKE %s" in sql
        ]
        assert len(select_for_upsert) == 1, (
            "Expected SELECT with `uri = %s OR uri LIKE %s` for re-index"
        )