"""
SQLite State Store for Hermes Agent.
Provides persistent session storage with FTS5 full-text search, replacing
the per-session JSONL file approach. Stores session metadata, full message
history, and model configuration for CLI and gateway sessions.
Key design decisions:
- WAL mode for concurrent readers + one writer (gateway multi-platform)
- FTS5 virtual table for fast text search across all session messages
- Compression-triggered session splitting via parent_session_id chains
- Batch runner and RL trajectories are NOT stored here (separate systems)
- Session source tagging ('cli', 'telegram', 'discord', etc.) for filtering
"""
import json
import logging
import random
import re
import sqlite3
import threading
import time
from pathlib import Path
from agent.memory_manager import sanitize_context
from hermes_constants import get_hermes_home
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
# Module logger; used for the WAL-fallback warning and schema debug messages.
logger = logging.getLogger(__name__)
# Return type of the callable passed to SessionDB._execute_write.
T = TypeVar("T")
# Default database location; get_hermes_home() is evaluated once, at import time.
DEFAULT_DB_PATH = get_hermes_home() / "state.db"
# Stored in the schema_version table. _init_schema runs the data migrations
# gated on lower versions (currently "< 10" and "< 11") and then stamps this.
SCHEMA_VERSION = 12
# Lower-case substrings of sqlite3.OperationalError messages that indicate
# WAL cannot be used on the underlying filesystem. Matched (against the
# lower-cased message) by apply_wal_with_fallback to decide on the DELETE
# fallback, and by format_session_db_unavailable to add the NFS/SMB/FUSE hint.
_WAL_INCOMPAT_MARKERS = (
    "locking protocol",
    "not authorized",
    "disk i/o error",
)
# Most recent SessionDB init failure ("<ExceptionType>: <message>"), read by
# get_last_init_error(); written only via _set_last_init_error under its lock.
_last_init_error: Optional[str] = None
_last_init_error_lock = threading.Lock()
# db_label values (not file paths, despite the name) that have already logged
# the WAL-fallback warning in this process; guarded by the lock below.
_wal_fallback_warned_paths: set[str] = set()
_wal_fallback_warned_lock = threading.Lock()
def _set_last_init_error(msg: Optional[str]) -> None:
    """Store *msg* as the latest state.db initialisation failure.

    Passing ``None`` clears the stored value. The assignment is made while
    holding ``_last_init_error_lock``.

    ``SessionDB.__init__`` only calls this with a message (on failure) and
    never clears it after a successful open. When several threads each
    create their own ``SessionDB`` (gateway / web_server per request), a
    success finishing just after another thread's failure would otherwise
    erase the cause that thread's ``/resume`` handler is about to report.
    Explicit clearing (e.g. from test fixtures) is done by passing None.
    """
    global _last_init_error
    _last_init_error_lock.acquire()
    try:
        _last_init_error = msg
    finally:
        _last_init_error_lock.release()
def get_last_init_error() -> Optional[str]:
    """Return the last recorded state.db init failure, or ``None``.

    ``None`` means no failure has been recorded (SessionDB opened fine, was
    never attempted, or the value was explicitly cleared). Slash-command
    handlers such as ``/resume``, ``/title``, ``/history`` and ``/branch``
    use this to explain *why* ``_session_db is None``.
    """
    recorded = _last_init_error
    return recorded
def format_session_db_unavailable(prefix: str = "Session database not available") -> str:
    """Build the user-facing "session DB unavailable" message, including the cause.

    When ``SessionDB()`` fails to initialise, callers set ``_session_db = None``
    and slash commands such as /resume, /title, /history and /branch need to
    tell the user *why*. This appends the cause recorded by
    ``_set_last_init_error``. ``SessionDB.__init__`` records it as
    ``"<ExceptionType>: <message>"``. When the cause matches one of
    ``_WAL_INCOMPAT_MARKERS``, a hint pointing at network/FUSE filesystems is
    added as well.

    Args:
        prefix: Leading sentence fragment; a single period is always appended.

    Returns:
        ``"<prefix>."`` when no (non-blank) cause is recorded, otherwise
        ``"<prefix>: <cause><hint>."``. Example::

            Session database not available: OperationalError: locking
            protocol (state.db may be on NFS/SMB/FUSE — see
            https://www.sqlite.org/wal.html).
    """
    # Strip trailing whitespace/periods so the sentence never ends in "..".
    cause = (get_last_init_error() or "").rstrip().rstrip(".").rstrip()
    if not cause:
        return f"{prefix}."
    hint = ""
    if any(marker in cause.lower() for marker in _WAL_INCOMPAT_MARKERS):
        hint = " (state.db may be on NFS/SMB/FUSE — see https://www.sqlite.org/wal.html)"
    return f"{prefix}: {cause}{hint}."
def apply_wal_with_fallback(
    conn: sqlite3.Connection,
    *,
    db_label: str = "state.db",
) -> str:
    """Enable WAL journaling on *conn*, falling back to DELETE where WAL is unsupported.

    Returns the journal mode SQLite reports after the PRAGMA, lower-cased.
    This is normally ``"wal"``, or ``"delete"`` after a fallback.
    ``PRAGMA journal_mode`` answers with the mode actually in effect rather
    than raising when it cannot switch (an in-memory database, for example,
    stays ``"memory"``). That reported value is returned instead of assuming
    the request succeeded.

    On WAL-incompatible filesystems (NFS, SMB, some FUSE), setting WAL
    raises ``OperationalError`` whose message contains one of
    ``_WAL_INCOMPAT_MARKERS`` (e.g. "locking protocol"). The connection is
    then switched to DELETE mode, the pre-WAL rollback journal, which works
    on NFS. A WARNING is logged through ``_log_wal_fallback_once``, at most
    once per process per *db_label*. This keeps callers that reconnect on
    every operation (``hermes_cli.kanban_db.connect``) from flooding the
    log, while state.db and kanban.db still each warn once. Any other
    ``OperationalError`` is re-raised.

    Shared by :class:`SessionDB` and ``hermes_cli.kanban_db.connect`` so
    both databases get identical fallback behavior.

    Args:
        conn: Open SQLite connection to configure.
        db_label: Name used in the warning and as the dedup key.

    Raises:
        sqlite3.OperationalError: when setting WAL fails for a reason not
            listed in ``_WAL_INCOMPAT_MARKERS``.
    """
    try:
        row = conn.execute("PRAGMA journal_mode=WAL").fetchone()
        requested = "wal"
    except sqlite3.OperationalError as exc:
        msg = str(exc).lower()
        if not any(marker in msg for marker in _WAL_INCOMPAT_MARKERS):
            raise
        _log_wal_fallback_once(db_label, exc)
        row = conn.execute("PRAGMA journal_mode=DELETE").fetchone()
        requested = "delete"
    # The PRAGMA yields one row holding the resulting mode; only if nothing
    # came back do we fall back to reporting the mode we asked for.
    if row is None or row[0] is None:
        return requested
    return str(row[0]).lower()
def _log_wal_fallback_once(db_label: str, exc: Exception) -> None:
    """Emit the WAL-fallback WARNING the first time *db_label* is seen in this process.

    Later calls for the same label return silently. Without this, NFS
    users running kanban (a fresh connection per operation, see
    hermes_cli/kanban_db.py) would fill errors.log with hundreds of
    identical warnings per hour. Membership is checked and recorded under
    ``_wal_fallback_warned_lock``; the log call itself happens outside it.
    """
    with _wal_fallback_warned_lock:
        first_sighting = db_label not in _wal_fallback_warned_paths
        if first_sighting:
            _wal_fallback_warned_paths.add(db_label)
    if not first_sighting:
        return
    logger.warning(
        "%s: WAL journal_mode unsupported on this filesystem (%s) — "
        "falling back to journal_mode=DELETE (slower rollback-journal "
        "mode; reduces concurrency but works on NFS/SMB/FUSE). See "
        "https://www.sqlite.org/wal.html for details. This warning "
        "fires once per process per database.",
        db_label,
        exc,
    )
# Base schema, run with executescript() on every SessionDB startup (every
# statement is CREATE ... IF NOT EXISTS, so re-running is harmless).
# It is also the single source of truth for columns: _parse_schema_columns
# parses it and _reconcile_columns ADDs any declared column an existing
# database lacks. Do not add indexes here on columns that reconciliation
# may still have to add to older databases. CREATE INDEX on a missing
# column would make the executescript fail before reconciliation runs, so
# such indexes (e.g. idx_messages_platform_msg_id) are created in
# _init_schema instead.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS schema_version (
version INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
source TEXT NOT NULL,
user_id TEXT,
model TEXT,
model_config TEXT,
system_prompt TEXT,
parent_session_id TEXT,
started_at REAL NOT NULL,
ended_at REAL,
end_reason TEXT,
message_count INTEGER DEFAULT 0,
tool_call_count INTEGER DEFAULT 0,
input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0,
cache_read_tokens INTEGER DEFAULT 0,
cache_write_tokens INTEGER DEFAULT 0,
reasoning_tokens INTEGER DEFAULT 0,
billing_provider TEXT,
billing_base_url TEXT,
billing_mode TEXT,
estimated_cost_usd REAL,
actual_cost_usd REAL,
cost_status TEXT,
cost_source TEXT,
pricing_version TEXT,
title TEXT,
api_call_count INTEGER DEFAULT 0,
handoff_state TEXT,
handoff_platform TEXT,
handoff_error TEXT,
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL REFERENCES sessions(id),
role TEXT NOT NULL,
content TEXT,
tool_call_id TEXT,
tool_calls TEXT,
tool_name TEXT,
timestamp REAL NOT NULL,
token_count INTEGER,
finish_reason TEXT,
reasoning TEXT,
reasoning_content TEXT,
reasoning_details TEXT,
codex_reasoning_items TEXT,
codex_message_items TEXT,
platform_message_id TEXT
);
CREATE TABLE IF NOT EXISTS state_meta (
key TEXT PRIMARY KEY,
value TEXT
);
CREATE INDEX IF NOT EXISTS idx_sessions_source ON sessions(source);
CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id);
CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC);
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp);
"""
# Word-level FTS5 index over messages (default tokenizer). Each message is
# indexed under its own rowid (= messages.id) as
# "content tool_name tool_calls" (NULLs become ''). The three triggers keep
# the index in sync on INSERT / DELETE / UPDATE of messages.
FTS_SQL = """
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
content
);
CREATE TRIGGER IF NOT EXISTS messages_fts_insert AFTER INSERT ON messages BEGIN
INSERT INTO messages_fts(rowid, content) VALUES (
new.id,
COALESCE(new.content, '') || ' ' || COALESCE(new.tool_name, '') || ' ' || COALESCE(new.tool_calls, '')
);
END;
CREATE TRIGGER IF NOT EXISTS messages_fts_delete AFTER DELETE ON messages BEGIN
DELETE FROM messages_fts WHERE rowid = old.id;
END;
CREATE TRIGGER IF NOT EXISTS messages_fts_update AFTER UPDATE ON messages BEGIN
DELETE FROM messages_fts WHERE rowid = old.id;
INSERT INTO messages_fts(rowid, content) VALUES (
new.id,
COALESCE(new.content, '') || ' ' || COALESCE(new.tool_name, '') || ' ' || COALESCE(new.tool_calls, '')
);
END;
"""
# Same indexed text as FTS_SQL, but using FTS5's trigram tokenizer, which
# supports substring matching (the tokenizer requires SQLite 3.34+). It has
# its own set of sync triggers.
FTS_TRIGRAM_SQL = """
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts_trigram USING fts5(
content,
tokenize='trigram'
);
CREATE TRIGGER IF NOT EXISTS messages_fts_trigram_insert AFTER INSERT ON messages BEGIN
INSERT INTO messages_fts_trigram(rowid, content) VALUES (
new.id,
COALESCE(new.content, '') || ' ' || COALESCE(new.tool_name, '') || ' ' || COALESCE(new.tool_calls, '')
);
END;
CREATE TRIGGER IF NOT EXISTS messages_fts_trigram_delete AFTER DELETE ON messages BEGIN
DELETE FROM messages_fts_trigram WHERE rowid = old.id;
END;
CREATE TRIGGER IF NOT EXISTS messages_fts_trigram_update AFTER UPDATE ON messages BEGIN
DELETE FROM messages_fts_trigram WHERE rowid = old.id;
INSERT INTO messages_fts_trigram(rowid, content) VALUES (
new.id,
COALESCE(new.content, '') || ' ' || COALESCE(new.tool_name, '') || ' ' || COALESCE(new.tool_calls, '')
);
END;
"""
class SessionDB:
"""
SQLite-backed session storage with FTS5 search.
Thread-safe for the common gateway pattern (multiple reader threads,
single writer via WAL mode). Each method opens its own cursor.
"""
_WRITE_MAX_RETRIES = 15
_WRITE_RETRY_MIN_S = 0.020
_WRITE_RETRY_MAX_S = 0.150
_CHECKPOINT_EVERY_N_WRITES = 50
def __init__(self, db_path: Optional[Path] = None):
    """Open (creating if needed) the SQLite state database and initialise its schema.

    Args:
        db_path: Database file location; defaults to ``DEFAULT_DB_PATH``.
            A plain ``str`` is accepted as well and converted to ``Path``.

    Raises:
        Exception: whatever failed during parent-directory creation,
            ``sqlite3.connect``, the journal-mode / foreign-key PRAGMAs or
            ``_init_schema``. Before re-raising, the cause is recorded via
            ``_set_last_init_error`` (for ``format_session_db_unavailable``).
            Any connection already opened is closed so a failed init does
            not leak a file handle.
    """
    self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
    self._lock = threading.Lock()
    self._write_count = 0
    self._conn = None
    try:
        # Inside the try so permission / read-only-FS errors are recorded
        # for the /resume-style error messages too.
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(
            str(self.db_path),
            check_same_thread=False,  # shared across threads; guarded by self._lock
            timeout=1.0,  # short busy wait; _execute_write adds jittered retries
            isolation_level=None,  # autocommit; transactions are explicit BEGIN IMMEDIATE
        )
        self._conn.row_factory = sqlite3.Row
        apply_wal_with_fallback(self._conn, db_label="state.db")
        self._conn.execute("PRAGMA foreign_keys=ON")
        self._init_schema()
    except Exception as exc:
        _set_last_init_error(f"{type(exc).__name__}: {exc}")
        if self._conn is not None:
            try:
                self._conn.close()
            except Exception:
                pass  # best-effort cleanup; the original error is what matters
            self._conn = None
        raise
def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T:
    """Run *fn* inside a ``BEGIN IMMEDIATE`` transaction, retrying on lock contention.

    *fn* receives the connection and issues its INSERT/UPDATE/DELETE
    statements; it must NOT commit — the commit (or rollback on any
    exception, which is then re-raised) happens here while ``self._lock``
    is held.

    ``BEGIN IMMEDIATE`` takes the write lock when the transaction starts,
    so contention shows up immediately rather than at commit time. When an
    ``OperationalError`` mentions "locked" or "busy", the Python lock is
    released, the thread sleeps a random 20-150 ms and the whole
    transaction is retried, up to ``_WRITE_MAX_RETRIES`` attempts. The
    random jitter breaks the convoy that SQLite's deterministic busy backoff
    produces. Other errors propagate unchanged.

    After a successful commit the write counter is bumped and, every
    ``_CHECKPOINT_EVERY_N_WRITES`` writes, a best-effort PASSIVE checkpoint
    runs (outside the lock, since ``_try_wal_checkpoint`` takes it itself).

    Returns:
        Whatever *fn* returned.
    """
    max_attempts = self._WRITE_MAX_RETRIES
    last_err: Optional[Exception] = None
    failures = 0
    while failures < max_attempts:
        try:
            with self._lock:
                self._conn.execute("BEGIN IMMEDIATE")
                try:
                    outcome = fn(self._conn)
                    self._conn.commit()
                except BaseException:
                    try:
                        self._conn.rollback()
                    except Exception:
                        pass
                    raise
        except sqlite3.OperationalError as exc:
            text = str(exc).lower()
            if not ("locked" in text or "busy" in text):
                raise
            last_err = exc
            failures += 1
            if failures < max_attempts:
                time.sleep(
                    random.uniform(self._WRITE_RETRY_MIN_S, self._WRITE_RETRY_MAX_S)
                )
            continue
        self._write_count += 1
        if self._write_count % self._CHECKPOINT_EVERY_N_WRITES == 0:
            self._try_wal_checkpoint()
        return outcome
    raise last_err or sqlite3.OperationalError(
        "database is locked after max retries"
    )
def _try_wal_checkpoint(self) -> None:
"""Best-effort PASSIVE WAL checkpoint. Never blocks, never raises.
Flushes committed WAL frames back into the main DB file for any
frames that no other connection currently needs. Keeps the WAL
from growing unbounded when many processes hold persistent
connections.
"""
try:
with self._lock:
result = self._conn.execute(
"PRAGMA wal_checkpoint(PASSIVE)"
).fetchone()
if result and result[1] > 0:
logger.debug(
"WAL checkpoint: %d/%d pages checkpointed",
result[2], result[1],
)
except Exception:
pass
def close(self):
"""Close the database connection.
Attempts a PASSIVE WAL checkpoint first so that exiting processes
help keep the WAL file from growing unbounded.
"""
with self._lock:
if self._conn:
try:
self._conn.execute("PRAGMA wal_checkpoint(PASSIVE)")
except Exception:
pass
self._conn.close()
self._conn = None
@staticmethod
def _parse_schema_columns(schema_sql: str) -> Dict[str, Dict[str, str]]:
"""Extract expected columns per table from SCHEMA_SQL.
Uses an in-memory SQLite database to parse the SQL — SQLite itself
handles all syntax (DEFAULT expressions with commas, inline
REFERENCES, CHECK constraints, etc.) so there are zero regex
edge cases. The in-memory DB is opened, the schema DDL is
executed, and PRAGMA table_info extracts the column metadata.
Adding a column to SCHEMA_SQL is all that's needed; the
reconciliation loop picks it up automatically.
"""
ref = sqlite3.connect(":memory:")
try:
ref.executescript(schema_sql)
table_columns: Dict[str, Dict[str, str]] = {}
for (tbl,) in ref.execute(
"SELECT name FROM sqlite_master "
"WHERE type='table' AND name NOT LIKE 'sqlite_%'"
).fetchall():
cols: Dict[str, str] = {}
for row in ref.execute(
f'PRAGMA table_info("{tbl}")'
).fetchall():
col_name = row[1]
col_type = row[2] or ""
notnull = row[3]
default = row[4]
pk = row[5]
parts = [col_type] if col_type else []
if notnull and not pk:
parts.append("NOT NULL")
if default is not None:
parts.append(f"DEFAULT {default}")
cols[col_name] = " ".join(parts)
table_columns[tbl] = cols
return table_columns
finally:
ref.close()
def _reconcile_columns(self, cursor: sqlite3.Cursor) -> None:
    """Add every column declared in SCHEMA_SQL that a live table is missing.

    SCHEMA_SQL's CREATE TABLE statements are the single source of truth
    (the Beets / sqlite-utils pattern). On each startup the declared
    columns (via ``_parse_schema_columns``) are diffed against
    ``PRAGMA table_info`` of the live table. Each missing column is added
    with ``ALTER TABLE ... ADD COLUMN`` using its declared type / NOT NULL /
    DEFAULT. Adding a column therefore only needs a SCHEMA_SQL edit; no
    version-gated migration is required.

    Failures never propagate:
      * "duplicate column" (e.g. another process added it concurrently) is
        logged at DEBUG;
      * any other failure (for instance a NOT NULL column without a
        DEFAULT, which SQLite refuses to ADD) is logged at WARNING, because
        later statements that use that column will fail.

    Args:
        cursor: Cursor on the live database connection.
    """
    expected = self._parse_schema_columns(SCHEMA_SQL)
    for table_name, declared_cols in expected.items():
        quoted_table = table_name.replace('"', '""')
        try:
            rows = cursor.execute(
                f'PRAGMA table_info("{quoted_table}")'
            ).fetchall()
        except sqlite3.OperationalError:
            continue
        if not rows:
            # table_info returns no rows (instead of raising) for a table
            # that does not exist, so there is nothing to ALTER.
            continue
        live_cols = {
            row[1] if isinstance(row, (tuple, list)) else row["name"]
            for row in rows
        }
        for col_name, col_type in declared_cols.items():
            if col_name in live_cols:
                continue
            safe_name = col_name.replace('"', '""')
            try:
                cursor.execute(
                    f'ALTER TABLE "{quoted_table}" ADD COLUMN "{safe_name}" {col_type}'
                )
            except sqlite3.OperationalError as exc:
                if "duplicate column" in str(exc).lower():
                    logger.debug(
                        "reconcile %s.%s: %s", table_name, col_name, exc,
                    )
                else:
                    logger.warning(
                        "SessionDB schema reconcile: could not add missing "
                        "column %s.%s (%s); statements using it will fail",
                        table_name, col_name, exc,
                    )
def _init_schema(self):
    """Create or upgrade the schema on ``self._conn``, then commit.

    Steps, in order:
      1. ``executescript(SCHEMA_SQL)``: CREATE ... IF NOT EXISTS for the
         base tables and indexes.
      2. ``_reconcile_columns``: ADD COLUMN for anything declared in
         SCHEMA_SQL but missing from the live tables. This is declarative
         (Beets / sqlite-utils style), so reordered or inserted migrations
         can no longer skip a column.
      3. Create the partial index on ``messages.platform_message_id``. It
         lives here rather than in SCHEMA_SQL because older databases only
         gain that column in step 2.
      4. Stamp ``SCHEMA_VERSION`` when no version row exists. Otherwise run
         the version-gated data migrations (FTS rebuilds below v10 / v11)
         and bump the stored version.
      5. Ensure the unique title index and both FTS tables exist.

    The ``schema_version`` table remains for data migrations (rewriting
    existing rows), which column reconciliation cannot express.
    """
    cursor = self._conn.cursor()
    cursor.executescript(SCHEMA_SQL)
    self._reconcile_columns(cursor)
    try:
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_messages_platform_msg_id "
            "ON messages(session_id, platform_message_id) "
            "WHERE platform_message_id IS NOT NULL"
        )
    except sqlite3.OperationalError as exc:
        logger.debug("idx_messages_platform_msg_id create skipped: %s", exc)
    cursor.execute("SELECT version FROM schema_version LIMIT 1")
    row = cursor.fetchone()
    if row is None:
        # No version row (typically a database created just now).
        cursor.execute(
            "INSERT INTO schema_version (version) VALUES (?)",
            (SCHEMA_VERSION,),
        )
    else:
        current_version = row["version"] if isinstance(row, sqlite3.Row) else row[0]
        if current_version < 10:
            # v10: add the trigram FTS index if it is missing, backfilled from
            # message content only. The v11 step below rebuilds it anyway.
            try:
                cursor.execute("SELECT * FROM messages_fts_trigram LIMIT 0")
                _fts_trigram_exists = True
            except sqlite3.OperationalError:
                _fts_trigram_exists = False
            if not _fts_trigram_exists:
                cursor.executescript(FTS_TRIGRAM_SQL)
                cursor.execute(
                    "INSERT INTO messages_fts_trigram(rowid, content) "
                    "SELECT id, content FROM messages WHERE content IS NOT NULL"
                )
        if current_version < 11:
            # v11: drop both FTS tables and their triggers, recreate them and
            # reindex every message as "content tool_name tool_calls" (the
            # same expression the sync triggers use).
            for _trig in (
                "messages_fts_insert",
                "messages_fts_delete",
                "messages_fts_update",
                "messages_fts_trigram_insert",
                "messages_fts_trigram_delete",
                "messages_fts_trigram_update",
            ):
                try:
                    cursor.execute(f"DROP TRIGGER IF EXISTS {_trig}")
                except sqlite3.OperationalError:
                    pass
            for _tbl in ("messages_fts", "messages_fts_trigram"):
                try:
                    cursor.execute(f"DROP TABLE IF EXISTS {_tbl}")
                except sqlite3.OperationalError:
                    pass
            cursor.executescript(FTS_SQL)
            cursor.executescript(FTS_TRIGRAM_SQL)
            cursor.execute(
                "INSERT INTO messages_fts(rowid, content) "
                "SELECT id, "
                "COALESCE(content, '') || ' ' || "
                "COALESCE(tool_name, '') || ' ' || "
                "COALESCE(tool_calls, '') "
                "FROM messages"
            )
            cursor.execute(
                "INSERT INTO messages_fts_trigram(rowid, content) "
                "SELECT id, "
                "COALESCE(content, '') || ' ' || "
                "COALESCE(tool_name, '') || ' ' || "
                "COALESCE(tool_calls, '') "
                "FROM messages"
            )
        if current_version < SCHEMA_VERSION:
            cursor.execute(
                "UPDATE schema_version SET version = ?",
                (SCHEMA_VERSION,),
            )
    try:
        # Titles must be unique, but any number of sessions may have none.
        cursor.execute(
            "CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique "
            "ON sessions(title) WHERE title IS NOT NULL"
        )
    except (sqlite3.OperationalError, sqlite3.IntegrityError) as exc:
        # Legacy rows with duplicate titles make SQLite report a UNIQUE
        # constraint failure, which Python raises as IntegrityError.
        # Skipping the index keeps startup working; set_session_title still
        # checks for conflicts itself.
        logger.debug("idx_sessions_title_unique create skipped: %s", exc)
    # Safety net: make sure both FTS tables (and their triggers) exist,
    # e.g. for databases whose version row was just stamped.
    try:
        cursor.execute("SELECT * FROM messages_fts LIMIT 0")
    except sqlite3.OperationalError:
        cursor.executescript(FTS_SQL)
    try:
        cursor.execute("SELECT * FROM messages_fts_trigram LIMIT 0")
    except sqlite3.OperationalError:
        cursor.executescript(FTS_TRIGRAM_SQL)
    self._conn.commit()
def _insert_session_row(
self,
session_id: str,
source: str,
model: str = None,
model_config: Dict[str, Any] = None,
system_prompt: str = None,
user_id: str = None,
parent_session_id: str = None,
) -> None:
"""Shared INSERT OR IGNORE for session rows."""
def _do(conn):
conn.execute(
"""INSERT OR IGNORE INTO sessions (id, source, user_id, model, model_config,
system_prompt, parent_session_id, started_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(
session_id,
source,
user_id,
model,
json.dumps(model_config) if model_config else None,
system_prompt,
parent_session_id,
time.time(),
),
)
self._execute_write(_do)
def create_session(self, session_id: str, source: str, **kwargs) -> str:
"""Create a new session record. Returns the session_id."""
self._insert_session_row(session_id, source, **kwargs)
return session_id
def end_session(self, session_id: str, end_reason: str) -> None:
"""Mark a session as ended.
No-ops when the session is already ended. The first end_reason wins:
compression-split sessions must keep their ``end_reason = 'compression'``
record even if a later stale ``end_session()`` call (e.g. from a
desynced CLI session_id after ``/resume`` or ``/branch``) targets them
with a different reason. Use ``reopen_session()`` first if you
intentionally need to re-end a closed session with a new reason.
"""
def _do(conn):
conn.execute(
"UPDATE sessions SET ended_at = ?, end_reason = ? "
"WHERE id = ? AND ended_at IS NULL",
(time.time(), end_reason, session_id),
)
self._execute_write(_do)
def reopen_session(self, session_id: str) -> None:
"""Clear ended_at/end_reason so a session can be resumed."""
def _do(conn):
conn.execute(
"UPDATE sessions SET ended_at = NULL, end_reason = NULL WHERE id = ?",
(session_id,),
)
self._execute_write(_do)
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
"""Store the full assembled system prompt snapshot."""
def _do(conn):
conn.execute(
"UPDATE sessions SET system_prompt = ? WHERE id = ?",
(system_prompt, session_id),
)
self._execute_write(_do)
def update_token_counts(
self,
session_id: str,
input_tokens: int = 0,
output_tokens: int = 0,
model: str = None,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
reasoning_tokens: int = 0,
estimated_cost_usd: Optional[float] = None,
actual_cost_usd: Optional[float] = None,
cost_status: Optional[str] = None,
cost_source: Optional[str] = None,
pricing_version: Optional[str] = None,
billing_provider: Optional[str] = None,
billing_base_url: Optional[str] = None,
billing_mode: Optional[str] = None,
api_call_count: int = 0,
absolute: bool = False,
) -> None:
"""Update token counters and backfill model if not already set.
When *absolute* is False (default), values are **incremented** — use
this for per-API-call deltas (CLI path).
When *absolute* is True, values are **set directly** — use this when
the caller already holds cumulative totals (gateway path, where the
cached agent accumulates across messages).
"""
self._insert_session_row(session_id, "unknown", model=model)
if absolute:
sql = """UPDATE sessions SET
input_tokens = ?,
output_tokens = ?,
cache_read_tokens = ?,
cache_write_tokens = ?,
reasoning_tokens = ?,
estimated_cost_usd = COALESCE(?, 0),
actual_cost_usd = CASE
WHEN ? IS NULL THEN actual_cost_usd
ELSE ?
END,
cost_status = COALESCE(?, cost_status),
cost_source = COALESCE(?, cost_source),
pricing_version = COALESCE(?, pricing_version),
billing_provider = COALESCE(billing_provider, ?),
billing_base_url = COALESCE(billing_base_url, ?),
billing_mode = COALESCE(billing_mode, ?),
model = COALESCE(model, ?),
api_call_count = ?
WHERE id = ?"""
else:
sql = """UPDATE sessions SET
input_tokens = input_tokens + ?,
output_tokens = output_tokens + ?,
cache_read_tokens = cache_read_tokens + ?,
cache_write_tokens = cache_write_tokens + ?,
reasoning_tokens = reasoning_tokens + ?,
estimated_cost_usd = COALESCE(estimated_cost_usd, 0) + COALESCE(?, 0),
actual_cost_usd = CASE
WHEN ? IS NULL THEN actual_cost_usd
ELSE COALESCE(actual_cost_usd, 0) + ?
END,
cost_status = COALESCE(?, cost_status),
cost_source = COALESCE(?, cost_source),
pricing_version = COALESCE(?, pricing_version),
billing_provider = COALESCE(billing_provider, ?),
billing_base_url = COALESCE(billing_base_url, ?),
billing_mode = COALESCE(billing_mode, ?),
model = COALESCE(model, ?),
api_call_count = COALESCE(api_call_count, 0) + ?
WHERE id = ?"""
params = (
input_tokens,
output_tokens,
cache_read_tokens,
cache_write_tokens,
reasoning_tokens,
estimated_cost_usd,
actual_cost_usd,
actual_cost_usd,
cost_status,
cost_source,
pricing_version,
billing_provider,
billing_base_url,
billing_mode,
model,
api_call_count,
session_id,
)
def _do(conn):
conn.execute(sql, params)
self._execute_write(_do)
def ensure_session(
self,
session_id: str,
source: str = "unknown",
model: str = None,
**kwargs,
) -> str:
"""Ensure a session row exists (INSERT OR IGNORE). Accepts optional kwargs."""
self._insert_session_row(session_id, source, model=model, **kwargs)
return session_id
def prune_empty_ghost_sessions(self, sessions_dir: "Optional[Path]" = None) -> int:
"""Remove empty TUI ghost sessions (no messages, no title, >24hr old)."""
cutoff = time.time() - 86400
def _do(conn):
rows = conn.execute("""
SELECT id FROM sessions
WHERE source = 'tui'
AND title IS NULL
AND ended_at IS NOT NULL
AND started_at < ?
AND NOT EXISTS (
SELECT 1 FROM messages WHERE messages.session_id = sessions.id
)
""", (cutoff,)).fetchall()
ids = [r[0] if isinstance(r, (tuple, list)) else r["id"] for r in rows]
if ids:
placeholders = ",".join("?" * len(ids))
conn.execute(
f"DELETE FROM sessions WHERE id IN ({placeholders})", ids
)
return ids
removed_ids = self._execute_write(_do) or []
if sessions_dir and removed_ids:
for sid in removed_ids:
self._remove_session_files(sessions_dir, sid)
return len(removed_ids)
def finalize_orphaned_compression_sessions(self) -> int:
"""Mark orphaned compression continuation sessions as ended.
Targets child sessions that were never finalized: parent is ended
with reason='compression', child has messages but no end_reason/ended_at
and api_call_count=0. Non-destructive: preserves all messages and sets
end_reason='orphaned_compression'. Fix for #20001.
"""
cutoff = time.time() - 604800
def _do(conn):
now = time.time()
result = conn.execute(
"""
UPDATE sessions
SET ended_at = ?,
end_reason = 'orphaned_compression'
WHERE api_call_count = 0
AND end_reason IS NULL
AND ended_at IS NULL
AND started_at < ?
AND parent_session_id IS NOT NULL
AND EXISTS (
SELECT 1 FROM sessions p
WHERE p.id = sessions.parent_session_id
AND p.end_reason = 'compression'
AND p.ended_at IS NOT NULL
)
AND EXISTS (
SELECT 1 FROM messages m
WHERE m.session_id = sessions.id
)
""",
(now, cutoff),
)
return result.rowcount
return self._execute_write(_do) or 0
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get a session by ID."""
with self._lock:
cursor = self._conn.execute(
"SELECT * FROM sessions WHERE id = ?", (session_id,)
)
row = cursor.fetchone()
return dict(row) if row else None
def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]:
"""Resolve an exact or uniquely prefixed session ID to the full ID.
Returns the exact ID when it exists. Otherwise treats the input as a
prefix and returns the single matching session ID if the prefix is
unambiguous. Returns None for no matches or ambiguous prefixes.
"""
exact = self.get_session(session_id_or_prefix)
if exact:
return exact["id"]
escaped = (
session_id_or_prefix
.replace("\\", "\\\\")
.replace("%", "\\%")
.replace("_", "\\_")
)
with self._lock:
cursor = self._conn.execute(
"SELECT id FROM sessions WHERE id LIKE ? ESCAPE '\\' ORDER BY started_at DESC LIMIT 2",
(f"{escaped}%",),
)
matches = [row["id"] for row in cursor.fetchall()]
if len(matches) == 1:
return matches[0]
return None
MAX_TITLE_LENGTH = 100
@staticmethod
def sanitize_title(title: Optional[str]) -> Optional[str]:
"""Validate and sanitize a session title.
- Strips leading/trailing whitespace
- Removes ASCII control characters (0x00-0x1F, 0x7F) and problematic
Unicode control chars (zero-width, RTL/LTR overrides, etc.)
- Collapses internal whitespace runs to single spaces
- Normalizes empty/whitespace-only strings to None
- Enforces MAX_TITLE_LENGTH
Returns the cleaned title string or None.
Raises ValueError if the title exceeds MAX_TITLE_LENGTH after cleaning.
"""
if not title:
return None
cleaned = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', title)
cleaned = re.sub(
r'[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]',
'', cleaned,
)
cleaned = re.sub(r'\s+', ' ', cleaned).strip()
if not cleaned:
return None
if len(cleaned) > SessionDB.MAX_TITLE_LENGTH:
raise ValueError(
f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})"
)
return cleaned
def set_session_title(self, session_id: str, title: str) -> bool:
"""Set or update a session's title.
Returns True if session was found and title was set.
Raises ValueError if title is already in use by another session,
or if the title fails validation (too long, invalid characters).
Empty/whitespace-only strings are normalized to None (clearing the title).
"""
title = self.sanitize_title(title)
def _do(conn):
if title:
cursor = conn.execute(
"SELECT id FROM sessions WHERE title = ? AND id != ?",
(title, session_id),
)
conflict = cursor.fetchone()
if conflict:
raise ValueError(
f"Title '{title}' is already in use by session {conflict['id']}"
)
cursor = conn.execute(
"UPDATE sessions SET title = ? WHERE id = ?",
(title, session_id),
)
return cursor.rowcount
rowcount = self._execute_write(_do)
return rowcount > 0
def get_session_title(self, session_id: str) -> Optional[str]:
"""Get the title for a session, or None."""
with self._lock:
cursor = self._conn.execute(
"SELECT title FROM sessions WHERE id = ?", (session_id,)
)
row = cursor.fetchone()
return row["title"] if row else None
def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
"""Look up a session by exact title. Returns session dict or None."""
with self._lock:
cursor = self._conn.execute(
"SELECT * FROM sessions WHERE title = ?", (title,)
)
row = cursor.fetchone()
return dict(row) if row else None
def resolve_session_by_title(self, title: str) -> Optional[str]:
"""Resolve a title to a session ID, preferring the latest in a lineage.
If the exact title exists, returns that session's ID.
If not, searches for "title #N" variants and returns the latest one.
If the exact title exists AND numbered variants exist, returns the
latest numbered variant (the most recent continuation).
"""
exact = self.get_session_by_title(title)
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
with self._lock:
cursor = self._conn.execute(
"SELECT id, title, started_at FROM sessions "
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
(f"{escaped} #%",),
)
numbered = cursor.fetchall()
if numbered:
return numbered[0]["id"]
elif exact:
return exact["id"]
return None
def get_next_title_in_lineage(self, base_title: str) -> str:
"""Generate the next title in a lineage (e.g., "my session" → "my session #2").
Strips any existing " #N" suffix to find the base name, then finds
the highest existing number and increments.
"""
match = re.match(r'^(.*?) #(\d+)$', base_title)
if match:
base = match.group(1)
else:
base = base_title
escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
with self._lock:
cursor = self._conn.execute(
"SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
(base, f"{escaped} #%"),
)
existing = [row["title"] for row in cursor.fetchall()]
if not existing:
return base
max_num = 1
for t in existing:
m = re.match(r'^.* #(\d+)$', t)
if m:
max_num = max(max_num, int(m.group(1)))
return f"{base} #{max_num + 1}"
def get_compression_tip(self, session_id: str) -> Optional[str]:
"""Walk the compression-continuation chain forward and return the tip.
A compression continuation is a child session where:
1. The parent's ``end_reason = 'compression'``
2. The child was created AFTER the parent was ended (started_at >= ended_at)
The second condition distinguishes compression continuations from
delegate subagents or branch children, which can also have a
``parent_session_id`` but were created while the parent was still live.
Returns the session_id of the latest continuation in the chain, or the
input ``session_id`` if it isn't part of a compression chain (or if the
input itself doesn't exist).
"""
current = session_id
for _ in range(100):
with self._lock:
cursor = self._conn.execute(
"SELECT id FROM sessions "
"WHERE parent_session_id = ? "
" AND started_at >= ("
" SELECT ended_at FROM sessions "
" WHERE id = ? AND end_reason = 'compression'"
" ) "
"ORDER BY started_at DESC LIMIT 1",
(current, current),
)
row = cursor.fetchone()
if row is None:
return current
current = row["id"]
return current
def list_sessions_rich(
    self,
    source: str = None,
    exclude_sources: List[str] = None,
    limit: int = 20,
    offset: int = 0,
    include_children: bool = False,
    project_compression_tips: bool = True,
    order_by_last_active: bool = False,
) -> List[Dict[str, Any]]:
    """List sessions with a short preview and a last-activity timestamp.

    Each returned dict holds every ``sessions`` column plus:

    - ``preview``: the first user message (by timestamp, then id), with
      newlines turned into spaces, cut to 60 characters and suffixed with
      ``"..."`` when longer; ``""`` when there is no such message.
    - ``last_active``: the newest message timestamp, or ``started_at`` for
      sessions without messages.

    Everything is fetched in one SQL statement using correlated subqueries,
    plus a recursive CTE when ``order_by_last_active`` is set.

    Child sessions are hidden unless ``include_children=True``.  Children of
    a ``'branched'`` parent that started at or after the parent ended are
    still listed.

    ``project_compression_tips`` (default ``True``; only applies while
    children are hidden): a row whose ``end_reason`` is ``'compression'`` is
    replaced by its latest continuation (see :meth:`get_compression_tip`).
    The tip's id, ended_at, end_reason, message_count, tool_call_count,
    title, last_active, preview, model and system_prompt overwrite the
    root's values, and ``_lineage_root_id`` records the root id.  Pass
    ``False`` to get the raw root rows.

    ``order_by_last_active``: sort by the most recent activity across the
    whole compression chain instead of by ``started_at``.  Ties are broken
    by ``started_at`` and then ``id``, both descending.  ``limit`` and
    ``offset`` are applied in SQL in both modes.
    """
    filters: List[str] = []
    filter_params: list = []
    if not include_children:
        # Only top-level rows, plus post-branch children of 'branched' parents.
        filters.append(
            "(s.parent_session_id IS NULL"
            " OR EXISTS (SELECT 1 FROM sessions p"
            " WHERE p.id = s.parent_session_id"
            " AND p.end_reason = 'branched'"
            " AND s.started_at >= p.ended_at))"
        )
    if source:
        filters.append("s.source = ?")
        filter_params.append(source)
    if exclude_sources:
        marks = ",".join("?" for _ in exclude_sources)
        filters.append(f"s.source NOT IN ({marks})")
        filter_params.extend(exclude_sources)
    where_sql = ("WHERE " + " AND ".join(filters)) if filters else ""

    if order_by_last_active:
        query = f"""
WITH RECURSIVE chain(root_id, cur_id) AS (
SELECT s.id, s.id FROM sessions s {where_sql}
UNION ALL
SELECT c.root_id, child.id
FROM chain c
JOIN sessions parent ON parent.id = c.cur_id
JOIN sessions child ON child.parent_session_id = c.cur_id
WHERE parent.end_reason = 'compression'
AND child.started_at >= parent.ended_at
),
chain_max AS (
SELECT
root_id,
MAX(COALESCE(
(SELECT MAX(m.timestamp) FROM messages m WHERE m.session_id = cur_id),
(SELECT started_at FROM sessions ss WHERE ss.id = cur_id)
)) AS effective_last_active
FROM chain
GROUP BY root_id
)
SELECT s.*,
COALESCE(
(SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63)
FROM messages m
WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL
ORDER BY m.timestamp, m.id LIMIT 1),
''
) AS _preview_raw,
COALESCE(
(SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id),
s.started_at
) AS last_active,
COALESCE(cm.effective_last_active, s.started_at) AS _effective_last_active
FROM sessions s
LEFT JOIN chain_max cm ON cm.root_id = s.id
{where_sql}
ORDER BY _effective_last_active DESC, s.started_at DESC, s.id DESC
LIMIT ? OFFSET ?
"""
        # The filter appears twice: once for the CTE roots, once for the outer query.
        bound = filter_params + filter_params + [limit, offset]
    else:
        query = f"""
SELECT s.*,
COALESCE(
(SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63)
FROM messages m
WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL
ORDER BY m.timestamp, m.id LIMIT 1),
''
) AS _preview_raw,
COALESCE(
(SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id),
s.started_at
) AS last_active
FROM sessions s
{where_sql}
ORDER BY s.started_at DESC
LIMIT ? OFFSET ?
"""
        bound = filter_params + [limit, offset]

    with self._lock:
        fetched = self._conn.execute(query, bound).fetchall()

    def _shape(row) -> Dict[str, Any]:
        # Turn the raw SQL row into the public dict: preview in, helper columns out.
        record = dict(row)
        snippet = record.pop("_preview_raw", "").strip()
        record["preview"] = (
            snippet[:60] + ("..." if len(snippet) > 60 else "") if snippet else ""
        )
        record.pop("_effective_last_active", None)
        return record

    rows_out = [_shape(r) for r in fetched]
    if not (project_compression_tips and not include_children):
        return rows_out

    tip_keys = (
        "id", "ended_at", "end_reason", "message_count",
        "tool_call_count", "title", "last_active", "preview",
        "model", "system_prompt",
    )

    def _project(entry: Dict[str, Any]) -> Dict[str, Any]:
        # Swap a compression root for its live tip; keep it as-is otherwise.
        if entry.get("end_reason") != "compression":
            return entry
        tip_id = self.get_compression_tip(entry["id"])
        if tip_id == entry["id"]:
            return entry
        tip_row = self._get_session_rich_row(tip_id)
        if not tip_row:
            return entry
        merged = dict(entry)
        merged.update({k: tip_row[k] for k in tip_keys if k in tip_row})
        merged["_lineage_root_id"] = entry["id"]
        return merged

    return [_project(e) for e in rows_out]
def _get_session_rich_row(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Fetch a single session with the same enriched columns as
``list_sessions_rich`` (preview + last_active). Returns None if the
session doesn't exist.
"""
query = """
SELECT s.*,
COALESCE(
(SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63)
FROM messages m
WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL
ORDER BY m.timestamp, m.id LIMIT 1),
''
) AS _preview_raw,
COALESCE(
(SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id),
s.started_at
) AS last_active
FROM sessions s
WHERE s.id = ?
"""
with self._lock:
cursor = self._conn.execute(query, (session_id,))
row = cursor.fetchone()
if not row:
return None
s = dict(row)
raw = s.pop("_preview_raw", "").strip()
if raw:
text = raw[:60]
s["preview"] = text + ("..." if len(raw) > 60 else "")
else:
s["preview"] = ""
return s
# Sentinel prepended by _encode_content to JSON-serialized (list/dict)
# message content before it is written to SQLite; _decode_content strips it
# and parses the remainder back into Python objects on read.  The leading
# NUL byte is presumably there to make collisions with ordinary text unlikely.
_CONTENT_JSON_PREFIX = "\x00json:"
@classmethod
def _encode_content(cls, content: Any) -> Any:
"""Serialize structured (list/dict) message content for sqlite.
sqlite3 can only bind ``str``, ``bytes``, ``int``, ``float``, and ``None``
to query parameters. Multimodal messages have ``content`` as a list of
parts (``[{"type": "text", ...}, {"type": "image_url", ...}]``), which
raises ``ProgrammingError: Error binding parameter N: type 'list' is
not supported`` when bound directly.
Returns the value unchanged when it's already a safe scalar, or a
sentinel-prefixed JSON string for lists/dicts. Paired with
:meth:`_decode_content` on read.
"""
if content is None or isinstance(content, (str, bytes, int, float)):
return content
try:
return cls._CONTENT_JSON_PREFIX + json.dumps(content)
except (TypeError, ValueError):
return str(content)
@classmethod
def _decode_content(cls, content: Any) -> Any:
"""Reverse :meth:`_encode_content`; returns scalars unchanged."""
if isinstance(content, str) and content.startswith(cls._CONTENT_JSON_PREFIX):
try:
return json.loads(content[len(cls._CONTENT_JSON_PREFIX):])
except (json.JSONDecodeError, TypeError):
logger.warning(
"Failed to decode JSON-encoded message content; "
"returning raw string"
)
return content
return content
def append_message(
    self,
    session_id: str,
    role: str,
    content: Any = None,
    tool_name: Optional[str] = None,
    tool_calls: Any = None,
    tool_call_id: Optional[str] = None,
    token_count: Optional[int] = None,
    finish_reason: Optional[str] = None,
    reasoning: Optional[str] = None,
    reasoning_content: Optional[str] = None,
    reasoning_details: Any = None,
    codex_reasoning_items: Any = None,
    codex_message_items: Any = None,
    platform_message_id: Optional[str] = None,
) -> int:
    """
    Append a message to a session and return the new message row ID.

    ``content`` may be a string or structured (list/dict) multimodal
    content; it is stored via :meth:`_encode_content`.  ``tool_calls``,
    ``reasoning_details``, ``codex_reasoning_items`` and
    ``codex_message_items`` are stored as JSON, or NULL when falsy.

    The session counters are updated in the same ``_execute_write``
    callback as the INSERT:

    - ``message_count`` always increases by one;
    - ``tool_call_count`` increases by the number of tool calls carried by
      *this* message: ``len(tool_calls)`` for a list, 1 for any other
      non-None value, 0 when ``tool_calls`` is None.

    ``role`` does not affect the counters.  A ``'tool'`` result message
    without ``tool_calls`` only bumps ``message_count``.

    ``platform_message_id`` is the external messaging platform's own
    message ID (e.g. Telegram update_id, Yuanbao msg_id).  It is
    independent of the SQLite autoincrement primary key and is used by
    platform-specific flows like yuanbao's recall guard to redact a
    message by its platform-side identifier.
    """

    def _json_or_null(value: Any) -> Optional[str]:
        # Falsy payloads (None, empty list/dict) are stored as NULL.
        return json.dumps(value) if value else None

    reasoning_details_json = _json_or_null(reasoning_details)
    codex_items_json = _json_or_null(codex_reasoning_items)
    codex_message_items_json = _json_or_null(codex_message_items)
    tool_calls_json = _json_or_null(tool_calls)
    stored_content = self._encode_content(content)

    if tool_calls is None:
        num_tool_calls = 0
    elif isinstance(tool_calls, list):
        num_tool_calls = len(tool_calls)
    else:
        num_tool_calls = 1

    def _do(conn):
        cursor = conn.execute(
            """INSERT INTO messages (session_id, role, content, tool_call_id,
tool_calls, tool_name, timestamp, token_count, finish_reason,
reasoning, reasoning_content, reasoning_details, codex_reasoning_items,
codex_message_items, platform_message_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (
                session_id,
                role,
                stored_content,
                tool_call_id,
                tool_calls_json,
                tool_name,
                time.time(),
                token_count,
                finish_reason,
                reasoning,
                reasoning_content,
                reasoning_details_json,
                codex_items_json,
                codex_message_items_json,
                platform_message_id,
            ),
        )
        msg_id = cursor.lastrowid
        if num_tool_calls > 0:
            conn.execute(
                """UPDATE sessions SET message_count = message_count + 1,
tool_call_count = tool_call_count + ? WHERE id = ?""",
                (num_tool_calls, session_id),
            )
        else:
            conn.execute(
                "UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
                (session_id,),
            )
        return msg_id

    return self._execute_write(_do)
def replace_messages(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
    """Swap the whole transcript of ``session_id`` for ``messages``.

    Intended for transcript-rewriting commands such as /retry, /undo and
    /compress.  The delete, the re-inserts and the counter updates all run
    inside a single ``_execute_write`` callback, so a failure part-way
    should not leave a partial transcript.  NOTE(review): atomicity relies
    on ``_execute_write`` running the callback as one transaction.

    Per message:

    - ``role`` defaults to ``"unknown"``;
    - ``reasoning``, ``reasoning_content``, ``reasoning_details`` and the
      codex item lists are only stored for ``"assistant"`` messages;
    - ``platform_message_id`` falls back to the ``message_id`` key;
    - timestamps start at the current time and grow by 1e-6 per message,
      so insertion order is preserved.

    Finally ``message_count`` and ``tool_call_count`` are set to the totals
    of the new transcript.
    """
    insert_sql = """INSERT INTO messages (session_id, role, content, tool_call_id,
tool_calls, tool_name, timestamp, token_count, finish_reason,
reasoning, reasoning_content, reasoning_details, codex_reasoning_items,
codex_message_items, platform_message_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"""

    def _json_or_null(value: Any) -> Optional[str]:
        # Falsy payloads become NULL, matching append_message.
        return json.dumps(value) if value else None

    def _do(conn):
        conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
        conn.execute(
            "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
            (session_id,),
        )
        stamp = time.time()
        written = 0
        tool_call_total = 0
        for msg in messages:
            role = msg.get("role", "unknown")
            assistant = role == "assistant"
            tool_calls = msg.get("tool_calls")
            details_json = _json_or_null(
                msg.get("reasoning_details") if assistant else None
            )
            codex_reasoning_json = _json_or_null(
                msg.get("codex_reasoning_items") if assistant else None
            )
            codex_message_json = _json_or_null(
                msg.get("codex_message_items") if assistant else None
            )
            tool_calls_json = _json_or_null(tool_calls)
            conn.execute(
                insert_sql,
                (
                    session_id,
                    role,
                    self._encode_content(msg.get("content")),
                    msg.get("tool_call_id"),
                    tool_calls_json,
                    msg.get("tool_name"),
                    stamp,
                    msg.get("token_count"),
                    msg.get("finish_reason"),
                    msg.get("reasoning") if assistant else None,
                    msg.get("reasoning_content") if assistant else None,
                    details_json,
                    codex_reasoning_json,
                    codex_message_json,
                    msg.get("platform_message_id") or msg.get("message_id"),
                ),
            )
            written += 1
            if tool_calls is not None:
                tool_call_total += (
                    len(tool_calls) if isinstance(tool_calls, list) else 1
                )
            stamp += 1e-6
        conn.execute(
            "UPDATE sessions SET message_count = ?, tool_call_count = ? WHERE id = ?",
            (written, tool_call_total, session_id),
        )

    self._execute_write(_do)
def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
    """Return every message of ``session_id`` as dicts, ordered by row id.

    ``content`` is passed through :meth:`_decode_content`.  A non-empty
    ``tool_calls`` column is parsed from JSON; if that fails a warning is
    logged and ``[]`` is used instead.
    """
    with self._lock:
        fetched = self._conn.execute(
            "SELECT * FROM messages WHERE session_id = ? ORDER BY id",
            (session_id,),
        ).fetchall()

    def _inflate(row) -> Dict[str, Any]:
        item = dict(row)
        if "content" in item:
            item["content"] = self._decode_content(item["content"])
        if item.get("tool_calls"):
            try:
                item["tool_calls"] = json.loads(item["tool_calls"])
            except (json.JSONDecodeError, TypeError):
                logger.warning("Failed to deserialize tool_calls in get_messages, falling back to []")
                item["tool_calls"] = []
        return item

    return [_inflate(r) for r in fetched]
def get_messages_around(
    self,
    session_id: str,
    around_message_id: int,
    window: int = 5,
) -> Dict[str, Any]:
    """Fetch the messages surrounding one message id in a session.

    Result keys:

    - ``window``: up to ``window`` messages before the anchor, the anchor
      itself, and up to ``window`` messages after it, in ascending id order.
      ``content`` is decoded and ``tool_calls`` JSON is parsed (``[]`` on
      failure).
    - ``messages_before``: how many messages before the anchor were returned
      (less than ``window`` only when the session start was reached).
    - ``messages_after``: how many messages after the anchor were returned
      (less than ``window`` only when the session end was reached).

    ``session_search`` uses this both for FTS5 hit discovery and for
    scrolling.  Comparing the counts with ``window`` tells the caller when a
    session boundary was hit.  A negative ``window`` is treated as 0.

    If ``around_message_id`` is not a message of ``session_id``, an empty
    window with zero counts is returned.
    """
    span = window if window >= 0 else 0
    with self._lock:
        anchor = self._conn.execute(
            "SELECT 1 FROM messages WHERE id = ? AND session_id = ? LIMIT 1",
            (around_message_id, session_id),
        ).fetchone()
        if not anchor:
            return {"window": [], "messages_before": 0, "messages_after": 0}
        # Anchor plus up to `span` predecessors, newest first (reversed below).
        earlier = self._conn.execute(
            "SELECT * FROM messages "
            "WHERE session_id = ? AND id <= ? "
            "ORDER BY id DESC LIMIT ?",
            (session_id, around_message_id, span + 1),
        ).fetchall()
        later = self._conn.execute(
            "SELECT * FROM messages "
            "WHERE session_id = ? AND id > ? "
            "ORDER BY id ASC LIMIT ?",
            (session_id, around_message_id, span),
        ).fetchall()

    def _inflate(row) -> Dict[str, Any]:
        item = dict(row)
        if "content" in item:
            item["content"] = self._decode_content(item["content"])
        if item.get("tool_calls"):
            try:
                item["tool_calls"] = json.loads(item["tool_calls"])
            except (json.JSONDecodeError, TypeError):
                logger.warning(
                    "Failed to deserialize tool_calls in get_messages_around, falling back to []"
                )
                item["tool_calls"] = []
        return item

    ordered = list(earlier)[::-1] + list(later)
    return {
        "window": [_inflate(r) for r in ordered],
        "messages_before": max(0, len(earlier) - 1),
        "messages_after": len(later),
    }
def get_anchored_view(
    self,
    session_id: str,
    around_message_id: int,
    window: int = 5,
    bookend: int = 3,
    keep_roles: Optional[Tuple[str, ...]] = ("user", "assistant"),
) -> Dict[str, Any]:
    """Return an anchored message window together with session bookends.

    Built on :meth:`get_messages_around`.  The result has:

    - ``window``: the messages around the anchor, restricted to
      ``keep_roles`` (so tool responses are dropped by default).  The
      anchor itself is always kept, whatever its role.
    - ``messages_before`` / ``messages_after``: passed through unchanged
      from :meth:`get_messages_around` (counts of the unfiltered window).
    - ``bookend_start``: the first ``bookend`` messages of the session in
      ``keep_roles`` with non-empty content, limited to ids before the
      unfiltered window's first id.  It is empty when the window already
      reaches the session start.
    - ``bookend_end``: the last ``bookend`` such messages with ids after the
      window's last id, returned in ascending order.

    Bookends let a single FTS5 hit surface the session's opening (goal) and
    closing (resolution) without loading the whole transcript.

    When the anchor is not in the session, every list is empty and both
    counts are 0.  ``keep_roles=None`` disables role filtering for the
    window and the bookends.  A negative ``bookend`` is treated as 0.
    """
    bookend = max(bookend, 0)
    around = self.get_messages_around(session_id, around_message_id, window=window)
    nearby = around["window"]
    if not nearby:
        return {
            "window": [],
            "messages_before": 0,
            "messages_after": 0,
            "bookend_start": [],
            "bookend_end": [],
        }

    if keep_roles is None:
        visible = nearby
    else:
        allowed = set(keep_roles)
        visible = [
            m for m in nearby
            if m.get("id") == around_message_id or m.get("role") in allowed
        ]

    first_id = nearby[0]["id"]
    last_id = nearby[-1]["id"]
    head: List[Any] = []
    tail: List[Any] = []
    if bookend:
        with self._lock:
            if keep_roles is None:
                role_sql, role_args = "", []
            else:
                role_sql = " AND role IN ({})".format(",".join("?" for _ in keep_roles))
                role_args = list(keep_roles)
            # `length(content) > 0` skips NULL/empty content (tool-call-only turns).
            head = self._conn.execute(
                "SELECT * FROM messages "
                f"WHERE session_id = ? AND id < ?{role_sql} "
                "AND length(content) > 0 "
                "ORDER BY id ASC LIMIT ?",
                (session_id, first_id, *role_args, bookend),
            ).fetchall()
            tail = self._conn.execute(
                "SELECT * FROM messages "
                f"WHERE session_id = ? AND id > ?{role_sql} "
                "AND length(content) > 0 "
                "ORDER BY id DESC LIMIT ?",
                (session_id, last_id, *role_args, bookend),
            ).fetchall()
    tail = list(tail)[::-1]

    def _inflate(row) -> Dict[str, Any]:
        item = dict(row)
        if "content" in item:
            item["content"] = self._decode_content(item["content"])
        if item.get("tool_calls"):
            try:
                item["tool_calls"] = json.loads(item["tool_calls"])
            except (json.JSONDecodeError, TypeError):
                logger.warning(
                    "Failed to deserialize tool_calls in get_anchored_view, falling back to []"
                )
                item["tool_calls"] = []
        return item

    return {
        "window": visible,
        "messages_before": around["messages_before"],
        "messages_after": around["messages_after"],
        "bookend_start": [_inflate(r) for r in head],
        "bookend_end": [_inflate(r) for r in tail],
    }
def resolve_resume_session_id(self, session_id: str) -> str:
    """Redirect a resume target to the descendant session that holds the messages.

    Context compression ends the current session and forks a child session
    (linked via ``parent_session_id``).  Because the flush cursor is reset,
    new messages land in the child, and the parent can be left with no
    message rows.  See #15000.

    Resolution:

    - a falsy ``session_id``, or one that already has messages, is returned
      unchanged;
    - otherwise ``parent_session_id`` links are followed downward, always
      taking the child with the latest ``started_at`` (ties broken by the
      larger id).  The first descendant that has at least one message is
      returned.

    The original ``session_id`` is returned when:

    - no such descendant exists;
    - the walk revisits an id;
    - 32 levels have been inspected;
    - any query raises.
    """
    if not session_id:
        return session_id
    probe_sql = "SELECT 1 FROM messages WHERE session_id = ? LIMIT 1"
    child_sql = (
        "SELECT id FROM sessions "
        "WHERE parent_session_id = ? "
        "ORDER BY started_at DESC, id DESC LIMIT 1"
    )
    with self._lock:
        try:
            own_row = self._conn.execute(probe_sql, (session_id,)).fetchone()
        except Exception:
            return session_id
        if own_row is not None:
            return session_id

        visited = {session_id}
        cursor_id = session_id
        for _depth in range(32):
            try:
                child = self._conn.execute(child_sql, (cursor_id,)).fetchone()
            except Exception:
                return session_id
            if child is None:
                return session_id
            # Support both sqlite3.Row and plain tuple row factories.
            child_id = child["id"] if hasattr(child, "keys") else child[0]
            if not child_id or child_id in visited:
                return session_id
            visited.add(child_id)
            try:
                hit = self._conn.execute(probe_sql, (child_id,)).fetchone()
            except Exception:
                return session_id
            if hit is not None:
                return child_id
            cursor_id = child_id
        return session_id
def get_messages_as_conversation(
    self, session_id: str, include_ancestors: bool = False
) -> List[Dict[str, Any]]:
    """
    Load messages as OpenAI-style conversation dicts (role + content, etc.).

    Used by the gateway to restore conversation history.  Rows are read in
    ascending id order.  With ``include_ancestors=True``, messages of the
    whole lineage (see :meth:`_session_lineage_root_to_tip`) are included,
    and user messages re-replayed from an ancestor are skipped via
    :meth:`_is_duplicate_replayed_user_message`.

    String ``content`` of user/assistant messages is passed through
    ``sanitize_context`` and stripped.  Optional columns are only added when
    set:

    - ``tool_call_id``, ``tool_name``, and ``tool_calls`` (parsed JSON,
      ``[]`` on failure);
    - ``message_id``, taken from ``platform_message_id``;
    - for assistant messages only: ``finish_reason``, ``reasoning``,
      ``reasoning_content`` (kept even when empty, only NULL is skipped),
      and ``reasoning_details`` / ``codex_reasoning_items`` /
      ``codex_message_items`` (parsed JSON, ``None`` on failure).
    """
    lineage = (
        self._session_lineage_root_to_tip(session_id)
        if include_ancestors
        else [session_id]
    )
    with self._lock:
        marks = ",".join("?" for _ in lineage)
        rows = self._conn.execute(
            "SELECT role, content, tool_call_id, tool_calls, tool_name, "
            "finish_reason, reasoning, reasoning_content, reasoning_details, "
            "codex_reasoning_items, codex_message_items, platform_message_id "
            f"FROM messages WHERE session_id IN ({marks}) ORDER BY id",
            tuple(lineage),
        ).fetchall()

    json_columns = (
        ("reasoning_details", "Failed to deserialize reasoning_details, falling back to None"),
        ("codex_reasoning_items", "Failed to deserialize codex_reasoning_items, falling back to None"),
        ("codex_message_items", "Failed to deserialize codex_message_items, falling back to None"),
    )
    conversation: List[Dict[str, Any]] = []
    for row in rows:
        role = row["role"]
        body = self._decode_content(row["content"])
        if role in {"user", "assistant"} and isinstance(body, str):
            body = sanitize_context(body).strip()
        entry: Dict[str, Any] = {"role": role, "content": body}
        for column in ("tool_call_id", "tool_name"):
            if row[column]:
                entry[column] = row[column]
        if row["tool_calls"]:
            try:
                entry["tool_calls"] = json.loads(row["tool_calls"])
            except (json.JSONDecodeError, TypeError):
                logger.warning("Failed to deserialize tool_calls in conversation replay, falling back to []")
                entry["tool_calls"] = []
        if row["platform_message_id"]:
            entry["message_id"] = row["platform_message_id"]
        if role == "assistant":
            if row["finish_reason"]:
                entry["finish_reason"] = row["finish_reason"]
            if row["reasoning"]:
                entry["reasoning"] = row["reasoning"]
            if row["reasoning_content"] is not None:
                entry["reasoning_content"] = row["reasoning_content"]
            for column, failure_msg in json_columns:
                if row[column]:
                    try:
                        entry[column] = json.loads(row[column])
                    except (json.JSONDecodeError, TypeError):
                        logger.warning(failure_msg)
                        entry[column] = None
        if include_ancestors and self._is_duplicate_replayed_user_message(conversation, entry):
            continue
        conversation.append(entry)
    return conversation
def _session_lineage_root_to_tip(self, session_id: str) -> List[str]:
if not session_id:
return [session_id]
chain = []
current = session_id
seen = set()
with self._lock:
for _ in range(100):
if not current or current in seen:
break
seen.add(current)
chain.append(current)
row = self._conn.execute(
"SELECT parent_session_id FROM sessions WHERE id = ?",
(current,),
).fetchone()
if row is None:
break
current = row["parent_session_id"] if hasattr(row, "keys") else row[0]
return list(reversed(chain)) or [session_id]
@staticmethod
def _is_duplicate_replayed_user_message(messages: List[Dict[str, Any]], msg: Dict[str, Any]) -> bool:
if msg.get("role") != "user":
return False
content = msg.get("content")
if not isinstance(content, str) or not content:
return False
for prev in reversed(messages):
if prev.get("role") == "user" and prev.get("content") == content:
return True
if prev.get("role") == "assistant" and (prev.get("content") or prev.get("tool_calls")):
return False
return False
@staticmethod
def _sanitize_fts5_query(query: str) -> str:
"""Sanitize user input for safe use in FTS5 MATCH queries.
FTS5 has its own query syntax where characters like ``"``, ``(``, ``)``,
``+``, ``*``, ``{``, ``}`` and bare boolean operators (``AND``, ``OR``,
``NOT``) have special meaning. Passing raw user input directly to
MATCH can cause ``sqlite3.OperationalError``.
Strategy:
- Preserve properly paired quoted phrases (``"exact phrase"``)
- Strip unmatched FTS5-special characters that would cause errors
- Wrap unquoted hyphenated and dotted terms in quotes so FTS5
matches them as exact phrases instead of splitting on the
hyphen/dot (e.g. ``chat-send``, ``P2.2``, ``my-app.config.ts``)
"""
_quoted_parts: list = []
def _preserve_quoted(m: re.Match) -> str:
_quoted_parts.append(m.group(0))
return f"\x00Q{len(_quoted_parts) - 1}\x00"
sanitized = re.sub(r'"[^"]*"', _preserve_quoted, query)
sanitized = re.sub(r'[+{}()\"^]', " ", sanitized)
sanitized = re.sub(r"\*+", "*", sanitized)
sanitized = re.sub(r"(^|\s)\*", r"\1", sanitized)
sanitized = re.sub(r"(?i)^(AND|OR|NOT)\b\s*", "", sanitized.strip())
sanitized = re.sub(r"(?i)\s+(AND|OR|NOT)\s*$", "", sanitized.strip())
sanitized = re.sub(r"\b(\w+(?:[._-]\w+)+)\b", r'"\1"', sanitized)
for i, quoted in enumerate(_quoted_parts):
sanitized = sanitized.replace(f"\x00Q{i}\x00", quoted)
return sanitized.strip()
@staticmethod
def _is_cjk_codepoint(cp: int) -> bool:
return (0x4E00 <= cp <= 0x9FFF or
0x3400 <= cp <= 0x4DBF or
0x20000 <= cp <= 0x2A6DF or
0x3000 <= cp <= 0x303F or
0x3040 <= cp <= 0x309F or
0x30A0 <= cp <= 0x30FF or
0xAC00 <= cp <= 0xD7AF)
@staticmethod
def _contains_cjk(text: str) -> bool:
"""Check if text contains CJK (Chinese, Japanese, Korean) characters."""
for ch in text:
cp = ord(ch)
if (0x4E00 <= cp <= 0x9FFF or
0x3400 <= cp <= 0x4DBF or
0x20000 <= cp <= 0x2A6DF or
0x3000 <= cp <= 0x303F or
0x3040 <= cp <= 0x309F or
0x30A0 <= cp <= 0x30FF or
0xAC00 <= cp <= 0xD7AF):
return True
return False
@classmethod
def _count_cjk(cls, text: str) -> int:
"""Count CJK characters in text."""
return sum(1 for ch in text if cls._is_cjk_codepoint(ord(ch)))
def search_messages(
    self,
    query: str,
    source_filter: List[str] = None,
    exclude_sources: List[str] = None,
    role_filter: List[str] = None,
    limit: int = 20,
    offset: int = 0,
    sort: str = None,
) -> List[Dict[str, Any]]:
    """
    Full-text search across session messages using FTS5.

    Supports FTS5 query syntax:

    - Simple keywords: "docker deployment"
    - Phrases: '"exact phrase"'
    - Boolean: "docker OR kubernetes", "python NOT java"
    - Prefix: "deploy*"

    The query is first cleaned by :meth:`_sanitize_fts5_query`.

    Each result carries session metadata, a content snippet, and a
    ``"context"`` list.  The context holds the match plus up to one message
    before and one after it, each as role + content, with content cut to
    200 characters.  The raw ``content`` column is removed from results.

    ``sort`` controls temporal ordering:

    - ``None`` (default) or an unrecognised value: FTS5 BM25 relevance only.
    - ``"newest"``: order by message timestamp DESC, then by rank.
    - ``"oldest"``: order by message timestamp ASC, then by rank.

    Queries containing CJK characters bypass the default FTS5 table:

    - With at least 3 CJK characters overall and at least 3 in every CJK
      term, the ``messages_fts_trigram`` table is queried.  This path
      honours ``sort``, and an ``OperationalError`` yields no matches.
    - Otherwise a ``LIKE`` scan over content, tool_name and tool_calls is
      used.  Terms are OR-ed, results are ordered by timestamp DESC, and
      ``sort`` is ignored.

    In both CJK paths, double quotes that the sanitizer put around a term
    (e.g. ``"a.b"`` or one word of a quoted phrase) are stripped first.
    The term is then matched as text rather than with literal quote
    characters.

    On the default FTS5 path, an ``sqlite3.OperationalError`` raised when
    executing the query makes the method return ``[]``.
    """
    if not query or not query.strip():
        return []
    query = self._sanitize_fts5_query(query)
    if not query:
        return []

    sort_norm = sort.strip().lower() if isinstance(sort, str) else None
    if sort_norm == "newest":
        order_by_sql = "ORDER BY m.timestamp DESC, rank"
    elif sort_norm == "oldest":
        order_by_sql = "ORDER BY m.timestamp ASC, rank"
    else:
        order_by_sql = "ORDER BY rank"

    bool_ops = {"AND", "OR", "NOT"}

    def _add_filters(clauses: List[str], params: list) -> None:
        # Shared source / exclude / role filters, appended in a fixed order
        # so that clause placeholders and bound params stay aligned.
        if source_filter is not None:
            clauses.append(f"s.source IN ({','.join('?' for _ in source_filter)})")
            params.extend(source_filter)
        if exclude_sources is not None:
            clauses.append(f"s.source NOT IN ({','.join('?' for _ in exclude_sources)})")
            params.extend(exclude_sources)
        if role_filter:
            clauses.append(f"m.role IN ({','.join('?' for _ in role_filter)})")
            params.extend(role_filter)

    if self._contains_cjk(query):
        raw_query = query.strip('"').strip()
        cjk_count = self._count_cjk(raw_query)
        # Drop sanitizer-added quotes per term; discard terms left empty.
        terms = [t.strip('"') for t in raw_query.split()]
        terms = [t for t in terms if t]
        content_terms = [t for t in terms if t.upper() not in bool_ops]
        any_short_cjk = any(
            self._count_cjk(t) < 3 for t in content_terms if self._contains_cjk(t)
        )
        if cjk_count >= 3 and not any_short_cjk:
            trigram_query = " ".join(
                t if t.upper() in bool_ops else '"' + t.replace('"', '""') + '"'
                for t in terms
            )
            tri_where = ["messages_fts_trigram MATCH ?"]
            tri_params: list = [trigram_query]
            _add_filters(tri_where, tri_params)
            tri_params.extend([limit, offset])
            tri_sql = f"""
SELECT
m.id,
m.session_id,
m.role,
snippet(messages_fts_trigram, 0, '>>>', '<<<', '...', 40) AS snippet,
m.content,
m.timestamp,
m.tool_name,
s.source,
s.model,
s.started_at AS session_started
FROM messages_fts_trigram
JOIN messages m ON m.id = messages_fts_trigram.rowid
JOIN sessions s ON s.id = m.session_id
WHERE {' AND '.join(tri_where)}
{order_by_sql}
LIMIT ? OFFSET ?
"""
            with self._lock:
                try:
                    tri_cursor = self._conn.execute(tri_sql, tri_params)
                except sqlite3.OperationalError:
                    matches = []
                else:
                    matches = [dict(row) for row in tri_cursor.fetchall()]
        else:
            like_terms = content_terms or [raw_query]
            token_clauses = []
            like_params: list = []
            for tok in like_terms:
                esc = tok.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
                token_clauses.append(
                    "(m.content LIKE ? ESCAPE '\\' OR m.tool_name LIKE ? ESCAPE '\\' OR m.tool_calls LIKE ? ESCAPE '\\')"
                )
                like_params += [f"%{esc}%"] * 3
            like_where = [f"({' OR '.join(token_clauses)})"]
            _add_filters(like_where, like_params)
            like_sql = f"""
SELECT m.id, m.session_id, m.role,
substr(m.content,
max(1, instr(m.content, ?) - 40),
120) AS snippet,
m.content, m.timestamp, m.tool_name,
s.source, s.model, s.started_at AS session_started
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE {' AND '.join(like_where)}
ORDER BY m.timestamp DESC
LIMIT ? OFFSET ?
"""
            # The first placeholder is the instr() term used to centre the snippet.
            like_params = [like_terms[0]] + like_params + [limit, offset]
            with self._lock:
                like_cursor = self._conn.execute(like_sql, like_params)
                matches = [dict(row) for row in like_cursor.fetchall()]
    else:
        where_clauses = ["messages_fts MATCH ?"]
        params: list = [query]
        _add_filters(where_clauses, params)
        params.extend([limit, offset])
        where_sql = " AND ".join(where_clauses)
        sql = f"""
SELECT
m.id,
m.session_id,
m.role,
snippet(messages_fts, 0, '>>>', '<<<', '...', 40) AS snippet,
m.content,
m.timestamp,
m.tool_name,
s.source,
s.model,
s.started_at AS session_started
FROM messages_fts
JOIN messages m ON m.id = messages_fts.rowid
JOIN sessions s ON s.id = m.session_id
WHERE {where_sql}
{order_by_sql}
LIMIT ? OFFSET ?
"""
        with self._lock:
            try:
                cursor = self._conn.execute(sql, params)
            except sqlite3.OperationalError:
                return []
            else:
                matches = [dict(row) for row in cursor.fetchall()]

    def _context_preview(decoded: Any) -> str:
        # Multimodal content -> joined text parts (or a placeholder).
        if isinstance(decoded, list):
            text_parts = [
                p.get("text", "") for p in decoded
                if isinstance(p, dict) and p.get("type") == "text"
            ]
            text = " ".join(t for t in text_parts if t).strip()
            return text or "[multimodal content]"
        if isinstance(decoded, str):
            return decoded
        return ""

    for match in matches:
        try:
            with self._lock:
                ctx_cursor = self._conn.execute(
                    """WITH target AS (
SELECT session_id, timestamp, id
FROM messages
WHERE id = ?
)
SELECT role, content
FROM (
SELECT m.id, m.timestamp, m.role, m.content
FROM messages m
JOIN target t ON t.session_id = m.session_id
WHERE (m.timestamp < t.timestamp)
OR (m.timestamp = t.timestamp AND m.id < t.id)
ORDER BY m.timestamp DESC, m.id DESC
LIMIT 1
)
UNION ALL
SELECT role, content
FROM messages
WHERE id = ?
UNION ALL
SELECT role, content
FROM (
SELECT m.id, m.timestamp, m.role, m.content
FROM messages m
JOIN target t ON t.session_id = m.session_id
WHERE (m.timestamp > t.timestamp)
OR (m.timestamp = t.timestamp AND m.id > t.id)
ORDER BY m.timestamp ASC, m.id ASC
LIMIT 1
)""",
                    (match["id"], match["id"]),
                )
                context_rows = ctx_cursor.fetchall()
            match["context"] = [
                {
                    "role": r["role"],
                    "content": _context_preview(self._decode_content(r["content"]))[:200],
                }
                for r in context_rows
            ]
        except Exception:
            # Context is best-effort; never fail the whole search over it.
            match["context"] = []

    for match in matches:
        match.pop("content", None)
    return matches
def search_sessions(
    self,
    source: str = None,
    limit: int = 20,
    offset: int = 0,
) -> List[Dict[str, Any]]:
    """Return sessions ordered most-recently-used first, optionally by source.

    Every row carries all ``sessions`` columns plus a computed
    ``last_active`` value: the newest message timestamp of the session,
    or its ``started_at`` when it has no messages. Ties are broken by
    ``started_at`` and then ``id`` (both descending).
    """
    base_query = (
        "SELECT s.*, COALESCE(m.last_active, s.started_at) AS last_active "
        "FROM sessions s "
        "LEFT JOIN ("
        "SELECT session_id, MAX(timestamp) AS last_active "
        "FROM messages GROUP BY session_id"
        ") m ON m.session_id = s.id "
    )
    order_and_page = (
        "ORDER BY last_active DESC, s.started_at DESC, s.id DESC LIMIT ? OFFSET ?"
    )
    # An empty/None source means "no filter", matching session_count().
    where_clause = "WHERE s.source = ? " if source else ""
    params = (source, limit, offset) if source else (limit, offset)
    with self._lock:
        fetched = self._conn.execute(
            base_query + where_clause + order_and_page, params
        ).fetchall()
    return [dict(record) for record in fetched]
def session_count(self, source: str = None) -> int:
    """Return how many sessions exist, restricted to *source* when it is truthy."""
    query = "SELECT COUNT(*) FROM sessions"
    args: tuple = ()
    if source:
        query += " WHERE source = ?"
        args = (source,)
    with self._lock:
        total = self._conn.execute(query, args).fetchone()[0]
    return total
def message_count(self, session_id: str = None) -> int:
    """Return how many messages are stored, limited to *session_id* when it is truthy."""
    query = "SELECT COUNT(*) FROM messages"
    args: tuple = ()
    if session_id:
        query += " WHERE session_id = ?"
        args = (session_id,)
    with self._lock:
        total = self._conn.execute(query, args).fetchone()[0]
    return total
def export_session(self, session_id: str) -> Optional[Dict[str, Any]]:
    """Return one session's metadata merged with a ``messages`` list.

    Returns None when ``get_session`` yields nothing (falsy) for the id.
    """
    record = self.get_session(session_id)
    if record:
        history = self.get_messages(session_id)
        exported = dict(record)
        exported["messages"] = history
        return exported
    return None
def export_all(self, source: str = None) -> List[Dict[str, Any]]:
    """
    Export every session (with its messages) as a list of dicts.

    Suitable for writing to a JSONL file for backup/analysis. Sessions come
    from ``search_sessions`` (most-recently-used first), optionally filtered
    by *source*, and each dict gets a ``messages`` key from ``get_messages``.

    ``limit=-1`` is passed on purpose: SQLite treats a negative LIMIT as
    "no upper bound", so the export is never silently truncated at a fixed
    row cap.
    """
    sessions = self.search_sessions(source=source, limit=-1)
    return [
        {**session, "messages": self.get_messages(session["id"])}
        for session in sessions
    ]
def clear_messages(self, session_id: str) -> None:
    """Remove every message of *session_id* and zero its message/tool-call counters.

    The session row itself is kept; both statements run in one write callback.
    """
    def _wipe(conn):
        conn.execute(
            "DELETE FROM messages WHERE session_id = ?", (session_id,)
        )
        conn.execute(
            "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
            (session_id,),
        )

    self._execute_write(_wipe)
@staticmethod
def _remove_session_files(sessions_dir: Optional[Path], session_id: str) -> None:
    """Remove on-disk transcript files for a session.

    Cleans up ``{session_id}.json``, ``{session_id}.jsonl``, and any
    ``request_dump_{session_id}_*.json`` files left by the gateway.

    The session id is glob-escaped before it is embedded in the dump
    pattern. Ids containing ``*``, ``?`` or ``[`` therefore match only their
    own dump files and never another session's.

    Does nothing when *sessions_dir* is None. Silently skips files that
    don't exist and swallows OSError, so a filesystem hiccup never blocks a
    DB operation.
    """
    import glob  # local import: only needed for glob.escape

    if sessions_dir is None:
        return
    for suffix in (".json", ".jsonl"):
        try:
            (sessions_dir / f"{session_id}{suffix}").unlink(missing_ok=True)
        except OSError:
            pass
    dump_pattern = f"request_dump_{glob.escape(session_id)}_*.json"
    try:
        for dump_file in sessions_dir.glob(dump_pattern):
            try:
                dump_file.unlink(missing_ok=True)
            except OSError:
                pass
    except OSError:
        # Listing the directory itself failed; best-effort cleanup only.
        pass
def delete_session(
    self,
    session_id: str,
    sessions_dir: Optional[Path] = None,
) -> bool:
    """Delete one session together with all of its messages.

    Sessions whose ``parent_session_id`` points at the deleted one are
    detached (set to NULL) rather than deleted, so they stay accessible on
    their own. Once the DB write succeeds and *sessions_dir* is given, the
    session's on-disk transcript files (``.json`` / ``.jsonl`` /
    ``request_dump_*``) are removed as well.

    Returns True if the session existed and was deleted, otherwise False.
    """
    def _purge(conn) -> bool:
        matches = conn.execute(
            "SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
        ).fetchone()[0]
        if matches == 0:
            return False
        # Detach children before deleting the parent row.
        conn.execute(
            "UPDATE sessions SET parent_session_id = NULL "
            "WHERE parent_session_id = ?",
            (session_id,),
        )
        for statement in (
            "DELETE FROM messages WHERE session_id = ?",
            "DELETE FROM sessions WHERE id = ?",
        ):
            conn.execute(statement, (session_id,))
        return True

    was_deleted = self._execute_write(_purge)
    if was_deleted:
        self._remove_session_files(sessions_dir, session_id)
    return was_deleted
def prune_sessions(
    self,
    older_than_days: int = 90,
    source: str = None,
    sessions_dir: Optional[Path] = None,
) -> int:
    """Delete sessions older than N days. Returns count of deleted sessions.

    Only prunes ended sessions (``ended_at`` set) whose ``started_at`` lies
    before the cutoff, optionally restricted to *source*. Child sessions
    that point at a pruned session are orphaned (parent_session_id set to
    NULL) rather than cascade-deleted. When *sessions_dir* is provided, also
    removes on-disk transcript files (``.json`` / ``.jsonl`` /
    ``request_dump_*``) for every pruned session, outside the DB
    transaction.
    """
    cutoff = time.time() - (older_than_days * 86400)
    # Keep every ``IN (...)`` list under SQLite's host-parameter limit
    # (SQLITE_MAX_VARIABLE_NUMBER defaults to 999 before SQLite 3.32.0).
    chunk_size = 500

    def _do(conn) -> List[str]:
        # The pruned ids are returned rather than appended to an outer list.
        # If the write helper ever invokes this callback more than once,
        # ids from an abandoned attempt therefore cannot reach the
        # file-cleanup pass below.
        if source:
            cursor = conn.execute(
                """SELECT id FROM sessions
                WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""",
                (cutoff, source),
            )
        else:
            cursor = conn.execute(
                "SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL",
                (cutoff,),
            )
        session_ids = [row["id"] for row in cursor.fetchall()]
        if not session_ids:
            return []
        for start in range(0, len(session_ids), chunk_size):
            chunk = session_ids[start:start + chunk_size]
            placeholders = ",".join("?" * len(chunk))
            conn.execute(
                f"UPDATE sessions SET parent_session_id = NULL "
                f"WHERE parent_session_id IN ({placeholders})",
                chunk,
            )
        for sid in session_ids:
            conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
            conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
        return session_ids

    removed_ids = self._execute_write(_do) or []
    for sid in removed_ids:
        self._remove_session_files(sessions_dir, sid)
    return len(removed_ids)
def get_meta(self, key: str) -> Optional[str]:
    """Return the ``state_meta`` value stored under *key*, or None if there is none.

    Works with both ``sqlite3.Row`` and plain-tuple row factories.
    """
    with self._lock:
        hit = self._conn.execute(
            "SELECT value FROM state_meta WHERE key = ?", (key,)
        ).fetchone()
    if hit is not None:
        return hit["value"] if isinstance(hit, sqlite3.Row) else hit[0]
    return None
def set_meta(self, key: str, value: str) -> None:
    """Insert or overwrite the ``state_meta`` entry for *key* (an upsert)."""
    upsert_sql = (
        "INSERT INTO state_meta (key, value) VALUES (?, ?) "
        "ON CONFLICT(key) DO UPDATE SET value = excluded.value"
    )

    def _upsert(conn):
        conn.execute(upsert_sql, (key, value))

    self._execute_write(_upsert)
def apply_telegram_topic_migration(self) -> None:
    """Create Telegram DM topic-mode tables on explicit /topic opt-in.

    This migration is deliberately not part of automatic SessionDB startup
    reconciliation. Operators must be able to upgrade Hermes, keep the old
    Telegram bot behavior running, and only mutate topic-mode state when the
    user executes /topic to opt into the feature.

    Schema versions (recorded in ``state_meta`` under the key
    ``telegram_dm_topic_schema_version``):
      v1 — initial shape (no ON DELETE CASCADE on session_id FK)
      v2 — session_id FK gets ON DELETE CASCADE so session pruning
           automatically clears bindings (when SQLite foreign-key
           enforcement is enabled on the connection).

    The v1 → v2 upgrade rebuilds ``telegram_dm_topic_bindings`` through a
    scratch ``telegram_dm_topic_bindings_new`` table. A scratch table left
    over from an earlier interrupted rebuild is dropped first, so the
    upgrade can be retried instead of failing forever with
    "table ... already exists".

    NOTE: depending on the connection's transaction mode,
    ``sqlite3.Connection.executescript`` may COMMIT a pending transaction
    before running its script, so the DDL here is not guaranteed to be
    atomic with the rest of this write callback.
    """
    def _do(conn):
        conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS telegram_dm_topic_mode (
                chat_id TEXT PRIMARY KEY,
                user_id TEXT NOT NULL,
                enabled INTEGER NOT NULL DEFAULT 1,
                activated_at REAL NOT NULL,
                updated_at REAL NOT NULL,
                has_topics_enabled INTEGER,
                allows_users_to_create_topics INTEGER,
                capability_checked_at REAL,
                intro_message_id TEXT,
                pinned_message_id TEXT
            );
            CREATE TABLE IF NOT EXISTS telegram_dm_topic_bindings (
                chat_id TEXT NOT NULL,
                thread_id TEXT NOT NULL,
                user_id TEXT NOT NULL,
                session_key TEXT NOT NULL,
                session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
                managed_mode TEXT NOT NULL DEFAULT 'auto',
                linked_at REAL NOT NULL,
                updated_at REAL NOT NULL,
                PRIMARY KEY (chat_id, thread_id)
            );
            CREATE UNIQUE INDEX IF NOT EXISTS idx_telegram_dm_topic_bindings_session
                ON telegram_dm_topic_bindings(session_id);
            CREATE INDEX IF NOT EXISTS idx_telegram_dm_topic_bindings_user
                ON telegram_dm_topic_bindings(user_id, chat_id);
            """
        )
        current = conn.execute(
            "SELECT value FROM state_meta WHERE key = ?",
            ("telegram_dm_topic_schema_version",),
        ).fetchone()
        # A missing or non-numeric stored version is treated as v0.
        current_version = int(current[0]) if current and str(current[0]).isdigit() else 0
        if current_version < 2:
            # PRAGMA foreign_key_list columns: index 2 = referenced table,
            # index 6 = ON DELETE action.
            fk_rows = conn.execute(
                "PRAGMA foreign_key_list('telegram_dm_topic_bindings')"
            ).fetchall()
            needs_rebuild = any(
                row[2] == "sessions" and (row[6] or "") != "CASCADE"
                for row in fk_rows
            )
            if needs_rebuild:
                # SQLite cannot ALTER a foreign key in place, so copy the rows
                # into a freshly declared table and swap it in.
                conn.executescript(
                    """
                    DROP TABLE IF EXISTS telegram_dm_topic_bindings_new;
                    CREATE TABLE telegram_dm_topic_bindings_new (
                        chat_id TEXT NOT NULL,
                        thread_id TEXT NOT NULL,
                        user_id TEXT NOT NULL,
                        session_key TEXT NOT NULL,
                        session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
                        managed_mode TEXT NOT NULL DEFAULT 'auto',
                        linked_at REAL NOT NULL,
                        updated_at REAL NOT NULL,
                        PRIMARY KEY (chat_id, thread_id)
                    );
                    INSERT INTO telegram_dm_topic_bindings_new
                        SELECT chat_id, thread_id, user_id, session_key,
                               session_id, managed_mode, linked_at, updated_at
                        FROM telegram_dm_topic_bindings;
                    DROP TABLE telegram_dm_topic_bindings;
                    ALTER TABLE telegram_dm_topic_bindings_new
                        RENAME TO telegram_dm_topic_bindings;
                    CREATE UNIQUE INDEX idx_telegram_dm_topic_bindings_session
                        ON telegram_dm_topic_bindings(session_id);
                    CREATE INDEX idx_telegram_dm_topic_bindings_user
                        ON telegram_dm_topic_bindings(user_id, chat_id);
                    """
                )
            conn.execute(
                "INSERT INTO state_meta (key, value) VALUES (?, ?) "
                "ON CONFLICT(key) DO UPDATE SET value = excluded.value",
                ("telegram_dm_topic_schema_version", "2"),
            )

    self._execute_write(_do)
def enable_telegram_topic_mode(
    self,
    *,
    chat_id: str,
    user_id: str,
    has_topics_enabled: Optional[bool] = None,
    allows_users_to_create_topics: Optional[bool] = None,
) -> None:
    """Turn on Telegram DM topic mode for one private chat/user.

    Runs the explicit topic-mode migration first; ordinary SessionDB
    startup must not create these side tables. Upserts the chat's row:
    ``enabled`` becomes 1, ``updated_at`` and ``capability_checked_at``
    are set to now, and the two capability flags are stored as 1/0, or
    NULL when passed as None. On re-enable the original ``activated_at``
    is kept.
    """
    self.apply_telegram_topic_migration()
    stamp = time.time()
    topics_flag, create_flag = (
        None if flag is None else int(bool(flag))
        for flag in (has_topics_enabled, allows_users_to_create_topics)
    )

    def _upsert(conn):
        conn.execute(
            """
            INSERT INTO telegram_dm_topic_mode (
                chat_id, user_id, enabled, activated_at, updated_at,
                has_topics_enabled, allows_users_to_create_topics,
                capability_checked_at
            ) VALUES (?, ?, 1, ?, ?, ?, ?, ?)
            ON CONFLICT(chat_id) DO UPDATE SET
                user_id = excluded.user_id,
                enabled = 1,
                updated_at = excluded.updated_at,
                has_topics_enabled = excluded.has_topics_enabled,
                allows_users_to_create_topics = excluded.allows_users_to_create_topics,
                capability_checked_at = excluded.capability_checked_at
            """,
            (
                str(chat_id),
                str(user_id),
                stamp,
                stamp,
                topics_flag,
                create_flag,
                stamp,
            ),
        )

    self._execute_write(_upsert)
def disable_telegram_topic_mode(
    self,
    *,
    chat_id: str,
    clear_bindings: bool = True,
) -> None:
    """Disable Telegram DM topic mode for one private chat.

    When ``clear_bindings`` is True (default), the (chat_id, thread_id)
    bindings for this chat are also cleared, so re-enabling later starts
    from a clean slate. Set it to False if the operator wants to preserve
    bindings for a later re-enable.

    Never creates the topic-mode tables from scratch. If they don't exist,
    there is nothing to disable and the call is a no-op.

    Raises:
        sqlite3.OperationalError: for any failure other than a missing
            table (e.g. "database is locked"), so a failed disable is not
            silently reported as success.
    """
    chat_key = str(chat_id)

    def _do(conn):
        try:
            conn.execute(
                "UPDATE telegram_dm_topic_mode SET enabled = 0, updated_at = ? "
                "WHERE chat_id = ?",
                (time.time(), chat_key),
            )
            if clear_bindings:
                conn.execute(
                    "DELETE FROM telegram_dm_topic_bindings WHERE chat_id = ?",
                    (chat_key,),
                )
        except sqlite3.OperationalError as exc:
            # Only a missing table means "topic mode was never set up".
            # Anything else is a real failure and must propagate.
            if "no such table" not in str(exc).lower():
                raise

    self._execute_write(_do)
def is_telegram_topic_mode_enabled(self, *, chat_id: str, user_id: str) -> bool:
    """Report whether topic mode is switched on for this chat/user pair.

    Read-only. Returns False when there is no row for the pair or when the
    lookup raises ``sqlite3.OperationalError`` (e.g. the topic-mode table
    has not been created yet).
    """
    lookup_sql = """
    SELECT enabled FROM telegram_dm_topic_mode
    WHERE chat_id = ? AND user_id = ?
    """
    with self._lock:
        try:
            found = self._conn.execute(
                lookup_sql, (str(chat_id), str(user_id))
            ).fetchone()
        except sqlite3.OperationalError:
            return False
    if found is None:
        return False
    flag = found["enabled"] if isinstance(found, sqlite3.Row) else found[0]
    return bool(flag)
def get_telegram_topic_binding(
    self,
    *,
    chat_id: str,
    thread_id: str,
) -> Optional[Dict[str, Any]]:
    """Look up the session bound to one Telegram DM topic thread.

    Returns the full binding row as a dict. Returns None when no binding
    exists or the query raises ``sqlite3.OperationalError`` (e.g. the
    bindings table has not been created yet).
    """
    topic_key = (str(chat_id), str(thread_id))
    with self._lock:
        try:
            binding = self._conn.execute(
                """
                SELECT * FROM telegram_dm_topic_bindings
                WHERE chat_id = ? AND thread_id = ?
                """,
                topic_key,
            ).fetchone()
        except sqlite3.OperationalError:
            return None
    if not binding:
        return None
    return dict(binding)
def list_telegram_topic_bindings_for_chat(
    self,
    *,
    chat_id: str,
) -> List[Dict[str, Any]]:
    """Return every DM topic binding of one chat, most recently updated first.

    Read-only: never triggers the topic-mode migration. Returns [] when
    the query raises ``sqlite3.OperationalError`` (e.g. the bindings table
    does not exist yet).
    """
    listing_sql = (
        "SELECT * FROM telegram_dm_topic_bindings "
        "WHERE chat_id = ? ORDER BY updated_at DESC"
    )
    with self._lock:
        try:
            fetched = self._conn.execute(listing_sql, (str(chat_id),)).fetchall()
        except sqlite3.OperationalError:
            return []
    return list(map(dict, fetched))
def get_telegram_topic_binding_by_session(
    self,
    *,
    session_id: str,
) -> Optional[Dict[str, Any]]:
    """Reverse lookup: find the DM topic binding that owns *session_id*.

    The migration creates a UNIQUE index on
    ``telegram_dm_topic_bindings(session_id)``, so at most one row can
    match. Returns None when the session has no binding or the query raises
    ``sqlite3.OperationalError`` (e.g. the table does not exist yet).
    """
    with self._lock:
        try:
            binding = self._conn.execute(
                """
                SELECT * FROM telegram_dm_topic_bindings
                WHERE session_id = ?
                """,
                (str(session_id),),
            ).fetchone()
        except sqlite3.OperationalError:
            return None
    if not binding:
        return None
    return dict(binding)
def bind_telegram_topic(
    self,
    *,
    chat_id: str,
    thread_id: str,
    user_id: str,
    session_key: str,
    session_id: str,
    managed_mode: str = "auto",
) -> None:
    """Link one Telegram DM topic thread to one Hermes session.

    Runs the explicit topic-mode migration first. In the MVP a session may
    be linked to at most one topic: rebinding the same (chat, thread) to
    the same session just refreshes the row (``linked_at`` is kept,
    ``updated_at`` is bumped), while linking the session to a different
    topic raises ValueError. All identifiers are stored as strings.
    """
    self.apply_telegram_topic_migration()
    now = time.time()
    chat_id, thread_id, user_id = str(chat_id), str(thread_id), str(user_id)
    session_key, session_id = str(session_key), str(session_id)

    def _column(row, name, index):
        # Support both sqlite3.Row and plain-tuple row factories.
        return row[name] if isinstance(row, sqlite3.Row) else row[index]

    def _bind(conn):
        prior = conn.execute(
            """
            SELECT chat_id, thread_id FROM telegram_dm_topic_bindings
            WHERE session_id = ?
            """,
            (session_id,),
        ).fetchone()
        if prior is not None:
            same_topic = (
                str(_column(prior, "chat_id", 0)) == chat_id
                and str(_column(prior, "thread_id", 1)) == thread_id
            )
            if not same_topic:
                raise ValueError("session is already linked to another Telegram topic")
        conn.execute(
            """
            INSERT INTO telegram_dm_topic_bindings (
                chat_id, thread_id, user_id, session_key, session_id,
                managed_mode, linked_at, updated_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(chat_id, thread_id) DO UPDATE SET
                user_id = excluded.user_id,
                session_key = excluded.session_key,
                session_id = excluded.session_id,
                managed_mode = excluded.managed_mode,
                updated_at = excluded.updated_at
            """,
            (chat_id, thread_id, user_id, session_key, session_id, managed_mode, now, now),
        )

    self._execute_write(_bind)
def is_telegram_session_linked_to_topic(self, *, session_id: str) -> bool:
    """Tell whether *session_id* is already bound to some Telegram DM topic.

    Read-only: does NOT trigger the telegram-topic migration. If the query
    raises ``sqlite3.OperationalError`` (e.g. nobody has run ``/topic`` in
    this profile, so the bindings table is absent), the session counts as
    unbound and False is returned.
    """
    probe_sql = """
    SELECT 1 FROM telegram_dm_topic_bindings
    WHERE session_id = ?
    LIMIT 1
    """
    with self._lock:
        try:
            hit = self._conn.execute(probe_sql, (str(session_id),)).fetchone()
        except sqlite3.OperationalError:
            return False
    return hit is not None
def list_unlinked_telegram_sessions_for_user(
    self,
    *,
    chat_id: str,
    user_id: str,
    limit: int = 10,
) -> List[Dict[str, Any]]:
    """List this user's previous Telegram sessions that no DM topic is bound to.

    Read-only: does NOT trigger the telegram-topic migration. If the
    filtered query raises ``sqlite3.OperationalError`` (e.g. the bindings
    table is absent), it is re-run without the binding filter, because no
    session can be bound in that case.

    Each returned dict holds every ``sessions`` column plus:
      * ``last_active`` — newest message timestamp, else ``started_at``;
      * ``preview`` — the session's first user message with CR/LF turned
        into spaces, stripped and cut to 60 characters (``...`` appended
        when longer).
    Results are ordered by ``last_active`` then ``started_at`` (newest
    first) and capped at *limit*.

    NOTE(review): *chat_id* is accepted but not used by the query; rows
    are filtered by ``user_id`` only. Confirm this is intended.
    """
    select_part = """
        SELECT s.*,
            COALESCE(
                (SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63)
                 FROM messages m
                 WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL
                 ORDER BY m.timestamp, m.id LIMIT 1),
                ''
            ) AS _preview_raw,
            COALESCE(
                (SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id),
                s.started_at
            ) AS last_active
        FROM sessions s
        WHERE s.source = 'telegram'
          AND s.user_id = ?
    """
    unbound_filter = """
          AND NOT EXISTS (
              SELECT 1 FROM telegram_dm_topic_bindings b
              WHERE b.session_id = s.id
          )
    """
    order_part = """
        ORDER BY last_active DESC, s.started_at DESC
        LIMIT ?
    """
    params = (str(user_id), int(limit))
    with self._lock:
        try:
            fetched = self._conn.execute(
                select_part + unbound_filter + order_part, params
            ).fetchall()
        except sqlite3.OperationalError:
            fetched = self._conn.execute(select_part + order_part, params).fetchall()

    def _with_preview(row) -> Dict[str, Any]:
        # Swap the raw 63-char SQL snippet for a display preview of <= 60 chars + "...".
        entry = dict(row)
        snippet = str(entry.pop("_preview_raw", "") or "").strip()
        entry["preview"] = snippet if len(snippet) <= 60 else snippet[:60] + "..."
        return entry

    return [_with_preview(row) for row in fetched]
def vacuum(self) -> None:
    """Rewrite the database file with VACUUM to reclaim space after large deletes.

    SQLite reuses the pages of deleted rows instead of shrinking the file,
    so after a big prune the file stays bloated until it is VACUUMed.
    VACUUM rewrites the whole database (expensive: seconds per 100MB),
    cannot run inside a transaction, and takes an exclusive lock. Only
    call it when no other writers are active, e.g. at startup before the
    gateway/CLI begins serving traffic.

    A ``wal_checkpoint(TRUNCATE)`` is attempted first on a best-effort
    basis; any exception it raises is ignored.
    """
    with self._lock:
        conn = self._conn
        try:
            conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")
        except Exception:
            # Best-effort only: a failed checkpoint must not block the VACUUM.
            pass
        conn.execute("VACUUM")
def maybe_auto_prune_and_vacuum(
    self,
    retention_days: int = 90,
    min_interval_hours: int = 24,
    vacuum: bool = True,
    sessions_dir: Optional[Path] = None,
) -> Dict[str, Any]:
    """Idempotent auto-maintenance: prune old sessions + optional VACUUM.

    Records the last run timestamp in state_meta (``last_auto_prune``) so
    that subsequent calls within ``min_interval_hours`` are no-ops. Designed
    to be called once at startup from long-lived entrypoints (CLI, gateway,
    cron scheduler).

    Only a recorded run that lies in the past suppresses maintenance. A
    stored timestamp in the future (the wall clock moved backwards, or an
    ``inf`` value) is treated as stale instead of skipping pruning
    indefinitely. An unparseable timestamp is treated as "never ran".

    When *sessions_dir* is provided, on-disk transcript files
    (``.json`` / ``.jsonl`` / ``request_dump_*``) for pruned sessions are
    removed as part of the same sweep (issue #3015).

    Never raises. On any failure, logs a warning and returns a dict with
    ``"error"`` set.

    Returns a dict with keys:
      - ``"skipped"`` (bool) — true if within min_interval_hours of last run
      - ``"pruned"`` (int) — number of sessions deleted
      - ``"vacuumed"`` (bool) — true if VACUUM ran
      - ``"error"`` (str, optional) — present only on failure
    """
    result: Dict[str, Any] = {"skipped": False, "pruned": 0, "vacuumed": False}
    try:
        last_raw = self.get_meta("last_auto_prune")
        now = time.time()
        if last_raw:
            try:
                elapsed = now - float(last_raw)
            except (TypeError, ValueError):
                elapsed = None  # unparseable marker: behave as if never run
            if elapsed is not None and 0 <= elapsed < min_interval_hours * 3600:
                result["skipped"] = True
                return result
        pruned = self.prune_sessions(
            older_than_days=retention_days,
            sessions_dir=sessions_dir,
        )
        result["pruned"] = pruned
        # VACUUM is only worth its cost when something was actually deleted.
        if vacuum and pruned > 0:
            try:
                self.vacuum()
                result["vacuumed"] = True
            except Exception as exc:
                logger.warning("state.db VACUUM failed: %s", exc)
        self.set_meta("last_auto_prune", str(now))
        if pruned > 0:
            logger.info(
                "state.db auto-maintenance: pruned %d session(s) older than %d days%s",
                pruned,
                retention_days,
                " + VACUUM" if result["vacuumed"] else "",
            )
    except Exception as exc:
        logger.warning("state.db auto-maintenance failed: %s", exc)
        result["error"] = str(exc)
    return result
def request_handoff(self, session_id: str, platform: str) -> bool:
    """Queue a handoff of *session_id* to *platform*.

    The row moves to ``handoff_state = 'pending'``, with the target platform
    recorded and any previous error cleared. This only happens when the
    session has no handoff yet or its last one ended (``completed`` or
    ``failed``). Returns True when that update happened, and False when the
    session does not exist or is in any other handoff state (e.g. pending
    or running).
    """
    transition_sql = (
        "UPDATE sessions "
        "SET handoff_state = 'pending', "
        " handoff_platform = ?, "
        " handoff_error = NULL "
        "WHERE id = ? AND (handoff_state IS NULL "
        " OR handoff_state IN ('completed', 'failed'))"
    )

    def _enqueue(conn) -> bool:
        changed = conn.execute(transition_sql, (platform, session_id)).rowcount
        return changed > 0

    return self._execute_write(_enqueue)
def get_handoff_state(self, session_id: str) -> Optional[Dict[str, Any]]:
    """Read the current handoff state for a session.

    Returns ``{"state", "platform", "error"}`` taken from the session row.
    The values are None when the session exists but has never had a
    handoff. Returns None when the session does not exist or the lookup
    fails for any reason.

    The read holds ``self._lock`` like the other read helpers, so it is
    serialized with the other methods that use ``self._lock`` on the
    shared connection.
    """
    try:
        with self._lock:
            row = self._conn.execute(
                "SELECT handoff_state, handoff_platform, handoff_error "
                "FROM sessions WHERE id = ?",
                (session_id,),
            ).fetchone()
        if not row:
            return None
        return {
            "state": row["handoff_state"],
            "platform": row["handoff_platform"],
            "error": row["handoff_error"],
        }
    except Exception:
        return None
def list_pending_handoffs(self) -> List[Dict[str, Any]]:
    """Return all sessions in handoff_state='pending', oldest first.

    Used by the gateway's handoff watcher. Sessions are ordered by
    ``started_at`` ascending. The query holds ``self._lock`` like the
    other read helpers, so it is serialized with the other methods that
    use ``self._lock`` on the shared connection. Any error yields an
    empty list.
    """
    try:
        with self._lock:
            rows = self._conn.execute(
                "SELECT * FROM sessions "
                "WHERE handoff_state = 'pending' "
                "ORDER BY started_at ASC"
            ).fetchall()
        return [dict(r) for r in rows]
    except Exception:
        return []
def claim_handoff(self, session_id: str) -> bool:
    """Move a handoff from ``pending`` to ``running`` in one conditional UPDATE.

    Returns True only if this call performed the transition. A session that
    is missing or not pending yields False.
    """
    def _claim(conn) -> bool:
        touched = conn.execute(
            "UPDATE sessions SET handoff_state = 'running' "
            "WHERE id = ? AND handoff_state = 'pending'",
            (session_id,),
        ).rowcount
        return touched > 0

    return self._execute_write(_claim)
def complete_handoff(self, session_id: str) -> None:
    """Set the session's handoff to ``completed`` and clear any stored error.

    No precondition on the current state is checked.
    """
    finish_sql = (
        "UPDATE sessions SET handoff_state = 'completed', "
        "handoff_error = NULL WHERE id = ?"
    )

    def _finish(conn):
        conn.execute(finish_sql, (session_id,))

    self._execute_write(_finish)
def fail_handoff(self, session_id: str, error: str) -> None:
    """Mark a handoff as failed and record the reason.

    The reason is stored truncated to 500 characters. Non-string values
    (e.g. an exception instance) are converted with ``str()``, and None is
    stored as an empty string. Reporting a failure therefore never raises
    ``TypeError`` itself and masks the original problem.
    """
    message = "" if error is None else str(error)

    def _do(conn):
        conn.execute(
            "UPDATE sessions SET handoff_state = 'failed', "
            "handoff_error = ? WHERE id = ?",
            (message[:500], session_id),
        )

    self._execute_write(_do)