"""Unit tests for SQLSessionArchiveStore.
Tests use mocked psycopg2 connections to verify behavior without
requiring a running PostgreSQL instance.
"""
import json
from unittest.mock import MagicMock, patch
import pytest
from core.models import RequestContext
from session.models import ArchiveEntry, ArchiveWriteResult
def _ctx(account_id="test_account", user_id="test_user"):
    """Build a RequestContext for tests.

    Only ``account_id`` and ``user_id`` are configurable; agent, session and
    trace identifiers are fixed test values.
    """
    fixed_fields = {
        "agent_id": "test_agent",
        "session_id": "test_session",
        "trace_id": "test_trace",
    }
    return RequestContext(account_id=account_id, user_id=user_id, **fixed_fields)
# Two user/assistant exchanges (four messages, alternating roles) used as the
# archived conversation in write tests.
SAMPLE_MESSAGES = [
    {"role": role, "content": content}
    for role, content in (
        ("user", "Hello"),
        ("assistant", "Hi!"),
        ("user", "How are you?"),
        ("assistant", "Doing well!"),
    )
]
@pytest.fixture
def mock_store():
    """Yield ``(store, conn, cursor)`` for a SQLSessionArchiveStore backed by mocks.

    Inside ``session.sql_archive_store``, ``psycopg2`` is replaced by a mock,
    ``Json`` by an identity function (so tests can inspect the raw Python
    values passed as query parameters), and ``_HAS_PSYCOPG2`` is forced to
    True. The patches stay active until the test finishes.

    The store is created with ``__new__`` so ``__init__`` is bypassed; its
    private attributes are set by hand, ``_get_connection`` always returns
    the same mock connection and ``_return_connection`` does nothing.
    """

    def _passthrough_json(value):
        # Stand-in for the module's ``Json`` wrapper: hand the value back as-is.
        return value

    with patch("session.sql_archive_store.psycopg2") as mock_pg, \
            patch("session.sql_archive_store.Json", _passthrough_json), \
            patch("session.sql_archive_store._HAS_PSYCOPG2", True):
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        # ``with conn.cursor() as cur:`` binds ``cur`` to mock_cursor.
        # __exit__ returns False so exceptions raised in the block propagate.
        cursor_cm = mock_conn.cursor.return_value
        cursor_cm.__enter__.return_value = mock_cursor
        cursor_cm.__exit__.return_value = False
        mock_pg.connect.return_value = mock_conn

        from session.sql_archive_store import SQLSessionArchiveStore

        store = SQLSessionArchiveStore.__new__(SQLSessionArchiveStore)
        store._connection_string = "host=localhost dbname=test"
        store._pool_size = 5
        store._pool = []
        store._get_connection = lambda: mock_conn
        store._return_connection = lambda c: None
        yield store, mock_conn, mock_cursor
class TestEnsureTable:
    """_ensure_table DDL generation and failure handling.

    Both tests reset ``fs.sql_adapter.schema._schema_ensured`` through
    ``monkeypatch`` so the original value is restored on teardown and the
    module-level flag cannot leak between tests.
    """

    def test_creates_table_with_composite_pk(self, mock_store, monkeypatch):
        """_ensure_table issues CREATE TABLE with composite primary key."""
        import fs.sql_adapter.schema as schema_mod

        # NOTE(review): the flag presumably gates schema DDL; start from
        # "not ensured" so the result does not depend on test order.
        monkeypatch.setattr(schema_mod, "_schema_ensured", False, raising=False)
        store, _conn, cursor = mock_store
        store._ensure_table()
        sqls = [call.args[0] for call in cursor.execute.call_args_list]
        assert any("PRIMARY KEY (account_id, session_id, archive_id)" in s for s in sqls)

    def test_raises_on_failure(self, mock_store, monkeypatch):
        """_ensure_table raises RuntimeError when DDL fails."""
        import fs.sql_adapter.schema as schema_mod

        # Previously assigned directly and never restored; monkeypatch undoes
        # the change after the test.
        monkeypatch.setattr(schema_mod, "_schema_ensured", False, raising=False)
        store, conn, cursor = mock_store
        cursor.execute.side_effect = Exception("connection lost")
        with pytest.raises(RuntimeError, match="Failed to ensure schema"):
            store._ensure_table()
        conn.rollback.assert_called()
class TestWriteArchive:
    """write_archive: result object, INSERT parameters, id generation, errors."""

    @staticmethod
    def _write(store, **kwargs):
        """Call ``store.write_archive`` with ``session_id="sess1"`` and a default ctx."""
        kwargs.setdefault("session_id", "sess1")
        kwargs.setdefault("ctx", _ctx())
        return store.write_archive(**kwargs)

    def test_success(self, mock_store):
        """A successful write reports ids, a ctx:// URI, no error, and commits once."""
        store, conn, _cursor = mock_store
        outcome = self._write(
            store,
            overview="Test overview",
            abstract="Test abstract",
            messages=SAMPLE_MESSAGES,
            archive_id="arc001",
        )
        assert outcome.success is True
        assert (outcome.archive_id, outcome.session_id) == ("arc001", "sess1")
        assert "ctx://test_account/sessions/sess1/history/arc001" in outcome.uri
        assert outcome.error is None
        conn.commit.assert_called_once()

    def test_auto_generates_archive_id(self, mock_store):
        """Without archive_id, a non-empty id made of three '_'-separated parts is generated."""
        store, _conn, _cursor = mock_store
        outcome = self._write(
            store,
            overview="Overview",
            abstract="Abstract",
            messages=SAMPLE_MESSAGES,
        )
        assert outcome.success is True
        generated = outcome.archive_id
        assert generated is not None
        assert len(generated) > 0
        assert len(generated.split("_")) == 3

    def test_passes_correct_params(self, mock_store):
        """The last execute is an upsert whose parameters follow the column order."""
        store, _conn, cursor = mock_store
        self._write(
            store,
            overview="Overview text",
            abstract="Abstract text",
            messages=SAMPLE_MESSAGES,
            archive_id="arc001",
        )
        sql, params = cursor.execute.call_args_list[-1].args[:2]
        assert "INSERT INTO session_archives" in sql
        assert "ON CONFLICT (account_id, session_id, archive_id)" in sql
        expected_leading = [
            "arc001",
            "sess1",
            "test_account",
            "Abstract text",
            "Overview text",
            SAMPLE_MESSAGES,
        ]
        assert list(params[:6]) == expected_leading
        assert params[6]["message_count"] == 4

    def test_returns_failure_on_db_error(self, mock_store):
        """A DB exception becomes a failed result carrying the message; the txn rolls back."""
        store, conn, cursor = mock_store
        cursor.execute.side_effect = Exception("disk full")
        outcome = self._write(
            store,
            overview="O",
            abstract="A",
            messages=[],
            archive_id="arc001",
        )
        assert outcome.success is False
        assert "disk full" in outcome.error
        conn.rollback.assert_called()

    def test_empty_strings_for_none_abstract_overview(self, mock_store):
        """None abstract/overview are stored as empty strings."""
        store, _conn, cursor = mock_store
        self._write(
            store,
            overview=None,
            abstract=None,
            messages=[],
            archive_id="arc001",
        )
        params = cursor.execute.call_args.args[1]
        assert (params[3], params[4]) == ("", "")

    def test_merges_extra_metadata(self, mock_store):
        """Caller-supplied metadata is merged alongside the generated archive_id."""
        store, _conn, cursor = mock_store
        outcome = self._write(
            store,
            overview="overview",
            abstract="abstract",
            messages=[],
            archive_id="arc-quality",
            metadata={"compression_quality": {"entity_retention_ratio": 1.0}},
        )
        assert outcome.success is True
        stored_metadata = cursor.execute.call_args.args[1][6]
        assert stored_metadata["archive_id"] == "arc-quality"
        assert stored_metadata["compression_quality"] == {"entity_retention_ratio": 1.0}
class TestListArchives:
    """list_archives: account/session scoping, ordering, errors, MERGED filtering."""

    def test_empty(self, mock_store):
        """No rows gives an empty list; the query filters on account and session."""
        store, _conn, cursor = mock_store
        cursor.fetchall.return_value = []
        assert store.list_archives("sess1", _ctx()) == []
        sql, params = cursor.execute.call_args.args[:2]
        assert "WHERE account_id = %s AND session_id = %s" in sql
        assert params == ("test_account", "sess1")

    def test_returns_entries_newest_first(self, mock_store):
        """Rows are mapped to entries in query order, which is created_at DESC."""
        store, _conn, cursor = mock_store
        newer_row = (
            "arc2", "sess1", "Abstract 2", "Overview 2",
            [], {"archive_id": "arc2"}, "2025-04-15T12:00:00",
        )
        older_row = (
            "arc1", "sess1", "Abstract 1", "Overview 1",
            [{"role": "user", "content": "hi"}], {"archive_id": "arc1"},
            "2025-04-14T12:00:00",
        )
        cursor.fetchall.return_value = [newer_row, older_row]
        entries = store.list_archives("sess1", _ctx())
        assert len(entries) == 2
        first, second = entries
        assert first.archive_id == "arc2"
        assert first.overview == "Overview 2"
        assert first.messages == []
        assert second.archive_id == "arc1"
        assert "ORDER BY created_at DESC" in cursor.execute.call_args.args[0]

    def test_returns_empty_on_db_error(self, mock_store):
        """A DB exception yields an empty list instead of propagating."""
        store, _conn, cursor = mock_store
        cursor.execute.side_effect = Exception("timeout")
        assert store.list_archives("sess1", _ctx()) == []

    def test_isolates_by_account(self, mock_store):
        """The caller's account_id is the first bound parameter, so other accounts stay invisible."""
        store, _conn, cursor = mock_store
        cursor.fetchall.return_value = []
        store.list_archives("sess1", _ctx(account_id="other_account"))
        assert cursor.execute.call_args.args[1][0] == "other_account"

    def test_excludes_merged_archives(self, mock_store):
        """MERGED rows that come back from the query are dropped from the result."""
        store, _conn, cursor = mock_store
        merged_row = (
            "arc2", "sess1", "Abstract 2", "Overview 2",
            [], {"archive_id": "arc2", "status": "MERGED"}, "2025-04-15T12:00:00",
        )
        live_row = (
            "arc1", "sess1", "Abstract 1", "Overview 1",
            [], {"archive_id": "arc1"}, "2025-04-14T12:00:00",
        )
        cursor.fetchall.return_value = [merged_row, live_row]
        surviving_ids = [e.archive_id for e in store.list_archives("sess1", _ctx())]
        assert surviving_ids == ["arc1"]
class TestListArchivesSince:
    """list_archives_since: MERGED archives are excluded."""

    def test_excludes_merged_archives(self, mock_store):
        """The SQL filters MERGED before LIMIT, and MERGED rows returned anyway are dropped."""
        from datetime import datetime

        merged_clause = "UPPER(COALESCE(metadata->>'status', '')) <> 'MERGED'"
        store, _conn, cursor = mock_store
        cursor.fetchall.return_value = [
            ("arc2", "sess1", "Abstract 2", "Overview 2",
             [], {"archive_id": "arc2", "status": "MERGED"}, "2025-04-15T12:00:00"),
            ("arc1", "sess1", "Abstract 1", "Overview 1",
             [], {"archive_id": "arc1"}, "2025-04-14T12:00:00"),
        ]
        entries = store.list_archives_since(datetime(2025, 4, 1), _ctx())
        sql = cursor.execute.call_args.args[0]
        assert merged_clause in sql
        assert sql.index(merged_clause) < sql.index("LIMIT")
        assert [e.archive_id for e in entries] == ["arc1"]
class TestReadArchive:
    """read_archive: row decoding, session scoping, errors, MERGED direct lookup."""

    def test_not_found(self, mock_store):
        """No matching row yields None."""
        store, _conn, cursor = mock_store
        cursor.fetchone.return_value = None
        assert store.read_archive("sess1", "missing", _ctx()) is None

    def test_found(self, mock_store):
        """A row is decoded into an entry, with JSON-text messages parsed."""
        store, _conn, cursor = mock_store
        history = [{"role": "user", "content": "Hello"}]
        cursor.fetchone.return_value = (
            "arc001", "sess1", "Test abstract", "Test overview",
            json.dumps(history), {"key": "val"}, "2025-04-15T10:00:00",
        )
        entry = store.read_archive("sess1", "arc001", _ctx())
        assert entry is not None
        identity = (entry.archive_id, entry.session_id)
        assert identity == ("arc001", "sess1")
        assert (entry.abstract, entry.overview) == ("Test abstract", "Test overview")
        assert entry.messages == history
        assert entry.metadata == {"key": "val"}

    def test_filters_by_session_id(self, mock_store):
        """The WHERE clause includes session_id, which blocks cross-session reads."""
        store, _conn, cursor = mock_store
        cursor.fetchone.return_value = None
        store.read_archive("sess1", "arc001", _ctx())
        sql, params = cursor.execute.call_args.args[:2]
        assert "session_id = %s" in sql
        assert params == ("arc001", "sess1", "test_account")

    def test_returns_none_on_db_error(self, mock_store):
        """A DB exception yields None instead of propagating."""
        store, _conn, cursor = mock_store
        cursor.execute.side_effect = Exception("timeout")
        assert store.read_archive("sess1", "arc001", _ctx()) is None

    def test_handles_string_metadata(self, mock_store):
        """Messages and metadata that come back as JSON strings are parsed."""
        store, _conn, cursor = mock_store
        cursor.fetchone.return_value = (
            "arc001", "sess1", "abs", "ovr",
            '[]', '{"archive_id": "arc001"}', "2025-04-15T10:00:00",
        )
        entry = store.read_archive("sess1", "arc001", _ctx())
        assert entry.metadata == {"archive_id": "arc001"}
        assert entry.messages == []

    def test_handles_null_messages(self, mock_store):
        """A NULL messages column becomes an empty list."""
        store, _conn, cursor = mock_store
        cursor.fetchone.return_value = (
            "arc001", "sess1", "abs", "ovr",
            None, {}, "2025-04-15T10:00:00",
        )
        assert store.read_archive("sess1", "arc001", _ctx()).messages == []

    def test_cross_session_read_returns_none(self, mock_store):
        """A lookup under another session binds that session_id and finds nothing."""
        store, _conn, cursor = mock_store
        cursor.fetchone.return_value = None
        assert store.read_archive("sess2", "arc001", _ctx()) is None
        assert cursor.execute.call_args.args[1][1] == "sess2"

    def test_returns_merged_archive_for_direct_lookup(self, mock_store):
        """A MERGED archive is still returned when read directly by id."""
        store, _conn, cursor = mock_store
        cursor.fetchone.return_value = (
            "arc001", "sess1", "abs", "ovr",
            "[]", {"archive_id": "arc001", "status": "MERGED"}, "2025-04-15T10:00:00",
        )
        entry = store.read_archive("sess1", "arc001", _ctx())
        assert entry is not None
        assert entry.archive_id == "arc001"
        assert entry.metadata["status"] == "MERGED"
class TestReadArchiveAbstract:
    """read_archive_abstract: lookup, empty-value handling, scoping, errors."""

    def test_found(self, mock_store):
        """The abstract column of the matching row is returned."""
        store, _conn, cursor = mock_store
        cursor.fetchone.return_value = ("Test abstract",)
        assert store.read_archive_abstract("sess1", "arc001", _ctx()) == "Test abstract"

    def test_not_found(self, mock_store):
        """No matching row yields None."""
        store, _conn, cursor = mock_store
        cursor.fetchone.return_value = None
        assert store.read_archive_abstract("sess1", "missing", _ctx()) is None

    def test_empty_string_returns_none(self, mock_store):
        """A stored empty abstract is reported as None rather than an empty string."""
        store, _conn, cursor = mock_store
        cursor.fetchone.return_value = ("",)
        assert store.read_archive_abstract("sess1", "arc001", _ctx()) is None

    def test_filters_by_session_id(self, mock_store):
        """The lookup binds archive_id, session_id and account_id."""
        store, _conn, cursor = mock_store
        cursor.fetchone.return_value = None
        store.read_archive_abstract("sess1", "arc001", _ctx())
        sql, params = cursor.execute.call_args.args[:2]
        assert "session_id = %s" in sql
        assert params == ("arc001", "sess1", "test_account")

    def test_returns_none_on_db_error(self, mock_store):
        """A DB exception yields None instead of propagating."""
        store, _conn, cursor = mock_store
        cursor.execute.side_effect = Exception("timeout")
        assert store.read_archive_abstract("sess1", "arc001", _ctx()) is None

    def test_returns_merged_abstract_for_direct_lookup(self, mock_store):
        """A MERGED archive's abstract is still returned on direct lookup."""
        store, _conn, cursor = mock_store
        cursor.fetchone.return_value = (
            "Test abstract",
            {"archive_id": "arc001", "status": "MERGED"},
        )
        assert store.read_archive_abstract("sess1", "arc001", _ctx()) == "Test abstract"
class TestConstructor:
    """Constructor guard for the optional psycopg2 dependency."""

    def test_raises_without_psycopg2(self):
        """With _HAS_PSYCOPG2 False, construction raises an ImportError naming psycopg2."""
        from session.sql_archive_store import SQLSessionArchiveStore

        with patch("session.sql_archive_store._HAS_PSYCOPG2", False), \
                pytest.raises(ImportError, match="psycopg2 is required"):
            SQLSessionArchiveStore("host=localhost")
def test_delete_archive_removes_single_row(mock_store):
    """delete_archive scopes the DELETE by archive, session and account, then commits."""
    store, conn, cur = mock_store
    cur.rowcount = 1
    deleted = store.delete_archive("sess1", "arc001", _ctx(account_id="acct-test"))
    assert deleted is True
    statement, bound = cur.execute.call_args.args[:2]
    for fragment in ("DELETE FROM session_archives", "archive_id = %s"):
        assert fragment in statement
    assert bound == ("arc001", "sess1", "acct-test")
    conn.commit.assert_called_once()
def test_delete_archive_returns_false_when_no_row_deleted(mock_store):
    """With zero affected rows, delete_archive reports False and rolls back instead of committing."""
    store, conn, cur = mock_store
    cur.rowcount = 0
    assert store.delete_archive("sess1", "missing", _ctx(account_id="acct-test")) is False
    conn.rollback.assert_called_once()
    conn.commit.assert_not_called()
def test_mark_archive_merged_updates_metadata_atomically(mock_store):
    """mark_archive_merged patches metadata in a single jsonb-merge UPDATE.

    No SELECT of the existing metadata may be issued. The patch document
    carries status=MERGED and merged_into, followed by the
    archive/session/account keys.
    """
    store, conn, cur = mock_store
    cur.rowcount = 1
    merged = store.mark_archive_merged(
        "sess1", "arc001", _ctx(account_id="acct-test"), merged_into="merged_001"
    )
    assert merged is True
    statements = [c.args[0] for c in cur.execute.call_args_list]
    assert all("SELECT metadata FROM session_archives" not in s for s in statements)
    assert any("COALESCE(metadata" in s and "|| %s::jsonb" in s for s in statements)
    bound = cur.execute.call_args.args[1]
    patch_doc = bound[0]
    assert patch_doc["status"] == "MERGED"
    assert patch_doc["merged_into"] == "merged_001"
    assert bound[1:] == ("arc001", "sess1", "acct-test")
    conn.commit.assert_called_once()
def test_mark_archive_merged_returns_false_when_no_sql_row_updated(mock_store):
    """When the UPDATE touches no row, mark_archive_merged reports False and rolls back."""
    store, conn, cur = mock_store
    cur.rowcount = 0
    outcome = store.mark_archive_merged(
        "sess1", "missing", _ctx(account_id="acct-test"), merged_into="merged_001"
    )
    assert outcome is False
    conn.rollback.assert_called_once()
    conn.commit.assert_not_called()
def test_unmark_archive_merged_updates_metadata(mock_store):
    """unmark_archive_merged reads metadata, strips the merge markers, and commits.

    The final UPDATE must bind metadata without status/merged_into/merged_at,
    followed by the archive/session/account keys.
    """
    store, conn, cur = mock_store
    ctx = _ctx(account_id="acct-test")
    # Pin rowcount explicitly, as the mark_* tests do. Left unset it is a
    # MagicMock attribute, which is truthy and compares unequal to 0. The
    # success path would then be reached only by accident of how the store
    # inspects rowcount.
    cur.rowcount = 1
    cur.fetchone.return_value = (
        {"archive_id": "arc001", "status": "MERGED", "merged_into": "merged_001"},
    )
    ok = store.unmark_archive_merged("sess1", "arc001", ctx, merged_into="merged_001")
    assert ok is True
    executed = [call.args[0] for call in cur.execute.call_args_list]
    assert any("SELECT metadata FROM session_archives" in sql for sql in executed)
    assert any("UPDATE session_archives" in sql for sql in executed)
    update_params = cur.execute.call_args_list[-1].args[1]
    assert "status" not in update_params[0]
    assert "merged_into" not in update_params[0]
    assert "merged_at" not in update_params[0]
    assert update_params[1:] == ("arc001", "sess1", "acct-test")
    conn.commit.assert_called_once()