"""Server-side session lifecycle manager.
Manages per-session message buffers with optional AGFS persistence.
Accumulate → threshold commit → archive + extract session model.
"""
from __future__ import annotations
import json
import logging
import threading
import uuid
from copy import deepcopy
from dataclasses import dataclass, field, fields as dataclass_fields
from datetime import datetime, timezone
from typing import Any, Callable, Optional
from core.models import ContextNode, RequestContext
from providers.token_tracker import UsageTracker
from session.models import ArchiveEntry, SessionMessage, SessionMeta, SessionWindowState
from session.topic_buffer import SessionTopicBuffer
from session.session_state import Commitment, SessionState, TaskState
# Module-level logger; every message emitted in this file goes to the "ogmem.session" channel.
logger = logging.getLogger("ogmem.session")
def generate_archive_id() -> str:
    """Build a fresh archive identifier.

    The result has the shape ``<UTC timestamp>_<uuid8>``: the timestamp is
    formatted as ``%Y%m%d_%H%M%S`` and the suffix is the first eight hex
    digits of a random UUID4, e.g. ``'20260515_100000_a1b2c3d4'``.
    """
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    suffix = uuid.uuid4().hex[:8]
    return f"{stamp}_{suffix}"
@dataclass
class SessionBuffer:
    """In-memory buffer for a single session's messages.

    Besides the pending messages the buffer carries the per-session runtime
    bookkeeping: turn counter, window state, usage statistics, the extraction
    watermark and summary, the compaction "prepare" token fields, and the
    fields that track background extractions.
    """

    session_id: str
    messages: list[SessionMessage] = field(default_factory=list)
    meta: SessionMeta = field(default_factory=SessionMeta)
    commit_in_progress: bool = False
    turn_count: int = 0
    window_state: SessionWindowState = field(default_factory=SessionWindowState)
    usage_stats: UsageTracker = field(default_factory=UsageTracker)
    # Position inside ``messages``; remove_messages_by_id() and
    # rewind_watermark_to_ids() compare message indexes against it.
    extraction_watermark: int = 0
    extraction_summary: str = ""
    # Compaction "prepare" bookkeeping; all five fields are reset together
    # by snapshot_and_clear().
    compaction_prepare_token: str = ""
    compaction_prepare_archive_id: str = ""
    compaction_prepare_message_count: int = 0
    compaction_prepare_watermark: int = 0
    compaction_prepare_token_created_at: str = ""
    # Background-extraction tracking; mutated only under ``extraction_lock``
    # by begin_extraction() / end_extraction().
    extraction_in_progress: bool = False
    extraction_done_event: Optional[threading.Event] = None
    extraction_active_count: int = 0
    extraction_lock: object = field(default_factory=threading.Lock, repr=False)

    @property
    def pending_tokens(self) -> int:
        """Sum of ``estimated_tokens`` over the buffered messages."""
        return sum(m.estimated_tokens for m in self.messages)

    @property
    def tool_usage_stats(self) -> dict[str, dict]:
        """Tool statistics exposed by the usage tracker."""
        return self.usage_stats.tool_stats

    @tool_usage_stats.setter
    def tool_usage_stats(self, stats: dict[str, dict]) -> None:
        # Assignment merges into the tracker instead of replacing its data.
        self.usage_stats.merge_tool_stats(stats)

    def add(self, role: str, content: str, created_at: str | None = None) -> SessionMessage:
        """Append a message and update the turn counter and meta.

        A message with role ``"user"`` increments ``turn_count``; when its
        text (stripped, cut to 200 characters) contains one of a few action
        words it also becomes ``window_state.active_task``.

        Args:
            role: Message role.
            content: Message text.
            created_at: Optional timestamp; only forwarded to SessionMessage
                when truthy.

        Returns:
            The newly created SessionMessage.
        """
        kwargs: dict = {
            "id": f"msg_{uuid.uuid4().hex[:12]}",
            "role": role,
            "content": content,
        }
        if created_at:
            kwargs["created_at"] = created_at
        msg = SessionMessage(**kwargs)
        self.messages.append(msg)
        if role == "user":
            self.turn_count += 1
            task_candidate = content.strip()[:200] if content else ""
            if task_candidate:
                action_words = ("help", "need", "want", "please", "can you", "could you")
                lowered = task_candidate.lower()
                # Plain substring match, so e.g. "helpful" also counts.
                if any(word in lowered for word in action_words):
                    self.window_state.active_task = task_candidate
        self.meta.message_count = len(self.messages)
        self.meta.updated_at = datetime.now(timezone.utc).isoformat()
        return msg

    def should_compress(
        self,
        turn_threshold: int = 10,
        token_threshold: int = 10_000,
    ) -> bool:
        """Check if window state needs recompression.

        True when either the turns or the pending tokens accumulated since
        the values recorded in ``window_state`` reach their threshold.
        """
        turns_since = self.turn_count - self.window_state.turn_count_at_last_compress
        tokens_since = self.pending_tokens - self.window_state.token_count_at_last_compress
        return turns_since >= turn_threshold or tokens_since >= token_threshold

    def snapshot_and_clear(self) -> list[SessionMessage]:
        """Return a copy of the buffered messages and reset the buffer.

        ``extraction_summary`` is left untouched. ``extraction_watermark`` is
        reset to 0 because it is a position inside ``messages``, which is
        empty afterwards. The turn counter, the window state and the
        compaction prepare fields are reset as well.
        """
        snap = list(self.messages)
        self.messages.clear()
        self.turn_count = 0
        self.window_state = SessionWindowState()
        self.meta.message_count = 0
        self.meta.updated_at = datetime.now(timezone.utc).isoformat()
        self.extraction_watermark = 0
        self.compaction_prepare_token = ""
        self.compaction_prepare_archive_id = ""
        self.compaction_prepare_message_count = 0
        self.compaction_prepare_watermark = 0
        self.compaction_prepare_token_created_at = ""
        return snap

    def capture_rollback_state(self) -> dict:
        """Capture buffer fields that snapshot_and_clear mutates."""
        return {
            # Shallow copy: the list is independent, the messages are shared.
            "messages": list(self.messages),
            "turn_count": self.turn_count,
            "window_state": deepcopy(self.window_state),
            "meta_message_count": self.meta.message_count,
            "meta_updated_at": self.meta.updated_at,
            "extraction_watermark": self.extraction_watermark,
            "extraction_summary": self.extraction_summary,
            "compaction_prepare_token": self.compaction_prepare_token,
            "compaction_prepare_archive_id": self.compaction_prepare_archive_id,
            "compaction_prepare_message_count": self.compaction_prepare_message_count,
            "compaction_prepare_watermark": self.compaction_prepare_watermark,
            "compaction_prepare_token_created_at": self.compaction_prepare_token_created_at,
        }

    def restore_rollback_state(self, state: dict) -> None:
        """Restore buffer fields after a failed archive commit.

        Messages added to the buffer since the capture are kept and placed
        after the restored ones; the turn counter and message count are
        recomputed from the combined list.
        """
        original_messages = list(state.get("messages", []))
        original_ids = {msg.id for msg in original_messages}
        new_messages = [msg for msg in self.messages if msg.id not in original_ids]
        self.messages = original_messages + new_messages
        self.turn_count = sum(1 for msg in self.messages if msg.role == "user")
        self.window_state = deepcopy(state.get("window_state", SessionWindowState()))
        self.meta.message_count = len(self.messages)
        self.meta.updated_at = state.get("meta_updated_at", self.meta.updated_at)
        # Never let the watermark point past the end of the message list.
        self.extraction_watermark = min(
            int(state.get("extraction_watermark", 0) or 0),
            len(self.messages),
        )
        self.extraction_summary = state.get("extraction_summary", "")
        self.compaction_prepare_token = state.get("compaction_prepare_token", "")
        self.compaction_prepare_archive_id = state.get("compaction_prepare_archive_id", "")
        self.compaction_prepare_message_count = int(
            state.get("compaction_prepare_message_count", 0) or 0
        )
        self.compaction_prepare_watermark = int(
            state.get("compaction_prepare_watermark", 0) or 0
        )
        self.compaction_prepare_token_created_at = state.get(
            "compaction_prepare_token_created_at", ""
        )

    def begin_extraction(self) -> threading.Event:
        """Mark a background extraction as active and return the shared done event."""
        with self.extraction_lock:
            event = self.extraction_done_event
            # A new event is needed when none exists, when no extraction is
            # running, or when the previous event has already been set.
            if event is None or not self.extraction_in_progress or getattr(event, "is_set", lambda: False)():
                event = threading.Event()
                self.extraction_done_event = event
            self.extraction_active_count += 1
            self.extraction_in_progress = True
            return event

    def end_extraction(self) -> None:
        """Mark one background extraction as finished.

        When the active count reaches zero the in-progress flag is cleared
        and the shared done event (if any) is set.
        """
        with self.extraction_lock:
            if self.extraction_active_count > 0:
                self.extraction_active_count -= 1
            if self.extraction_active_count <= 0:
                self.extraction_active_count = 0
                self.extraction_in_progress = False
                event = self.extraction_done_event
                if event is not None:
                    event.set()

    def remove_messages_by_id(self, message_ids: set[str]) -> int:
        """Remove archived snapshot messages without touching newer messages.

        Returns:
            The number of messages removed.
        """
        if not message_ids:
            return 0
        kept: list[SessionMessage] = []
        removed = 0
        removed_before_watermark = 0
        for idx, msg in enumerate(self.messages):
            if msg.id in message_ids:
                removed += 1
                if idx < self.extraction_watermark:
                    removed_before_watermark += 1
            else:
                kept.append(msg)
        if removed:
            self.messages = kept
            self.turn_count = sum(1 for msg in self.messages if msg.role == "user")
            self.window_state.session_state_sync_turn_count = min(
                self.window_state.session_state_sync_turn_count,
                self.turn_count,
            )
            # Shift the watermark left by the removed entries that sat
            # before it, clamped to the new list bounds.
            self.extraction_watermark = max(
                0,
                min(len(self.messages), self.extraction_watermark - removed_before_watermark),
            )
            self.meta.message_count = len(self.messages)
            self.meta.updated_at = datetime.now(timezone.utc).isoformat()
        return removed

    def rewind_watermark_to_ids(self, message_ids: set[str]) -> bool:
        """Rewind extraction watermark to the first remaining failed message.

        Only the first buffered message whose id is in *message_ids* is
        considered.

        Returns:
            True when the watermark was moved back, False otherwise.
        """
        if not message_ids:
            return False
        for idx, msg in enumerate(self.messages):
            if msg.id in message_ids:
                if idx < self.extraction_watermark:
                    self.extraction_watermark = idx
                    self.meta.updated_at = datetime.now(timezone.utc).isoformat()
                    return True
                # First match already sits at or after the watermark.
                return False
        return False
class SessionManager:
"""Server-side session lifecycle manager.
In-memory buffer + optional AGFS persistence.
Provides: get_or_create, add_message, get_session, commit, get_context.
"""
def __init__(
self,
get_llm: Callable[[], Any] | None = None,
get_write_api: Callable[[], Any] | None = None,
get_agfs: Callable[[], Any] | None = None,
get_archive_store: Callable[[], Any] | None = None,
archive_store_required: bool = False,
get_context_fs: Callable[[], Any] | None = None,
session_state: SessionState | None = None,
archive_max_count: int = 10,
archive_merge_threshold: int = 10,
compression_quality_enabled: bool = False,
compression_quality_persist_metadata: bool = False,
):
self._sessions: dict[str, SessionBuffer] = {}
self._get_llm = get_llm
self._get_write_api = get_write_api
self._get_agfs = get_agfs
self._get_archive_store = get_archive_store
self._archive_store_required = archive_store_required
self._get_context_fs = get_context_fs
self._session_state = session_state or SessionState()
self._session_state_loaded: set[str] = set()
self._session_load_events: dict[str, threading.Event] = {}
self._topic_buffers: dict[str, SessionTopicBuffer] = {}
self._archive_max_count = max(1, int(archive_max_count or 10))
self._archive_merge_threshold = max(1, int(archive_merge_threshold or 10))
self._archive_merge_max_messages = 200
self._compression_quality_enabled = compression_quality_enabled
self._compression_quality_persist_metadata = compression_quality_persist_metadata
self._tasks: dict[str, dict] = {}
self._lock = threading.RLock()
self._save_lock = threading.RLock()
def get_session_state(self) -> SessionState:
return self._session_state
@staticmethod
def _session_state_context(
    session_id: str,
    meta: SessionMeta | None,
    fallback_ctx: RequestContext,
) -> RequestContext:
    """Build the request context used for session-state persistence.

    ``account_id``, ``user_id`` and ``agent_id`` come from *meta* when it is
    present and the value is truthy, otherwise from *fallback_ctx*.
    ``trace_id`` and ``role`` always come from *fallback_ctx*, and
    ``visible_owner_spaces`` is always an empty tuple.
    """

    def _identity(attr: str):
        bound = getattr(meta, attr) if meta else None
        return bound if bound else getattr(fallback_ctx, attr)

    return RequestContext(
        account_id=_identity("account_id"),
        user_id=_identity("user_id"),
        agent_id=_identity("agent_id"),
        session_id=session_id,
        trace_id=fallback_ctx.trace_id,
        role=fallback_ctx.role,
        visible_owner_spaces=(),
    )
def load_session_state(self, session_id: str, ctx: RequestContext) -> bool:
# Load persisted state for ``session_id`` through the context FS.
# Returns True immediately when the id is already in ``_session_state_loaded``.
# Returns False when no context-FS getter is configured, when the getter
# yields None, or when anything below raises (the error is logged as a
# warning and swallowed). Otherwise returns the ``loaded`` flag reported by
# ``SessionState.load_with_payload``.
with self._lock:
if session_id in self._session_state_loaded:
return True
if self._get_context_fs is None:
return False
try:
fs = self._get_context_fs()
if fs is None:
return False
loaded, payload = self._session_state.load_with_payload(session_id, fs, ctx)
if loaded:
# Copy the persisted meta / window state onto the live buffer, if one exists.
self._restore_runtime_state_from_payload(session_id, payload or {})
with self._lock:
# Record the id so later calls take the early-return path above.
self._session_state_loaded.add(session_id)
return loaded
except Exception as exc:
logger.warning("load_session_state failed session=%s: %s", session_id, exc)
return False
def save_session_state(self, session_id: str, ctx: RequestContext) -> bool:
# Serialize the session's durable state and write it as a ContextNode
# through the context FS. Returns True after ``fs.write_node`` returns;
# returns False when no context FS is available or when any exception is
# raised (logged as a warning, never propagated).
if self._get_context_fs is None:
return False
try:
fs = self._get_context_fs()
if fs is None:
return False
# Lock order: ``_save_lock`` is acquired before ``_lock``.
with self._save_lock:
with self._lock:
buf = self._sessions.get(session_id)
if buf is not None:
# Promote runtime window state into SessionState before serializing.
self._sync_runtime_state_to_session_state(session_id, buf)
# Deep copies so the payload is built from a stable snapshot of the buffer.
session_meta = deepcopy(buf.meta) if buf is not None else None
window_state = deepcopy(buf.window_state) if buf is not None else None
extraction_summary = buf.extraction_summary if buf is not None else None
payload = self._session_state.to_dict(
session_id,
session_meta=session_meta,
window_state=window_state,
extraction_summary=extraction_summary,
)
now = datetime.now(timezone.utc).isoformat()
node = ContextNode(
# The state URI is derived from the caller's ``ctx.account_id``.
uri=self._session_state.state_uri(ctx.account_id, session_id),
context_type="RESOURCE",
category="state",
level=0,
owner_space=f"session:{session_id}",
abstract="Session task state and commitments",
overview="Session task state and commitments",
content=json.dumps(payload, ensure_ascii=False, indent=2),
metadata={
"session_id": session_id,
"updated_at": now,
},
)
fs.write_node(node, ctx)
return True
except Exception as exc:
logger.warning("save_session_state failed session=%s: %s", session_id, exc)
return False
def _sync_runtime_state_to_session_state(
    self,
    session_id: str,
    buf: SessionBuffer,
) -> None:
    """Promote runtime window state into durable SessionState before save.

    The task objective and blockers are only filled in when the durable
    task state has none. Open loops, confirmed constraints and recent
    decisions are added as commitments unless an identical
    (content, status, kind) commitment already exists.
    """
    store = self._session_state
    window = buf.window_state

    durable_task = store.get_task_state(session_id)
    pending = TaskState()
    if not durable_task.objective and window.active_task:
        pending.objective = window.active_task
    if not durable_task.blockers and window.uncertainties:
        pending.blockers = list(window.uncertainties)
    if pending.objective is not None or pending.blockers:
        store.update_task_state(session_id, pending)

    seen = {
        (commitment.content, commitment.status, commitment.kind)
        for commitment in store.get_commitments(session_id, status=None)
    }
    # (source list, commitment status, commitment kind)
    sources = (
        (window.open_loops, "open", "loop"),
        (window.confirmed_constraints, "fulfilled", "constraint"),
        (window.recent_decisions, "fulfilled", "decision"),
    )
    for entries, status, kind in sources:
        for raw in entries:
            text = str(raw or "").strip()
            if not text:
                continue
            key = (text, status, kind)
            if key in seen:
                continue
            store.add_commitment(
                session_id,
                Commitment(content=text, status=status, kind=kind),
            )
            seen.add(key)
@staticmethod
def _restore_dataclass_fields(target: Any, payload: dict) -> None:
if not isinstance(payload, dict):
return
for item in dataclass_fields(target):
name = item.name
if name not in payload:
continue
value = payload.get(name)
current = getattr(target, name)
if isinstance(current, list):
setattr(target, name, list(value) if isinstance(value, list) else [])
elif isinstance(current, int):
try:
setattr(target, name, int(value or 0))
except (TypeError, ValueError):
setattr(target, name, 0)
elif isinstance(current, str):
setattr(target, name, str(value or ""))
else:
setattr(target, name, value)
def _restore_runtime_state_from_payload(self, session_id: str, payload: dict) -> None:
# Apply a persisted payload to the live buffer of ``session_id``.
# Does nothing when no buffer exists for the session. Reads the
# "session_meta", "window_state" and "extraction_summary" keys of the payload.
with self._lock:
buf = self._sessions.get(session_id)
if buf is None:
return
self._restore_dataclass_fields(buf.meta, payload.get("session_meta") or {})
# The session id always wins over whatever the payload carried.
buf.meta.session_id = session_id
self._restore_dataclass_fields(buf.window_state, payload.get("window_state") or {})
extraction_summary = payload.get("extraction_summary")
# Only a string summary replaces the current one; other types are ignored.
if isinstance(extraction_summary, str):
buf.extraction_summary = extraction_summary
if self._session_state.has_state(session_id):
try:
# Imported lazily at call time.
from session.session_state_bridge import apply_session_state_bridge
apply_session_state_bridge(
buf.window_state,
self._session_state,
session_id,
turn_count=buf.turn_count,
force=True,
)
except Exception as exc:
# Best effort: a bridge failure is logged and does not abort the restore.
logger.warning("restore session state bridge failed session=%s: %s", session_id, exc)
def get_topic_buffer(self, session_id: str) -> SessionTopicBuffer:
"""Get or create the per-session topic buffer."""
with self._lock:
if session_id not in self._topic_buffers:
self._topic_buffers[session_id] = SessionTopicBuffer(session_id)
return self._topic_buffers[session_id]
def _maybe_merge_archives(
self,
session_id: str,
ctx: RequestContext,
wait: bool = False,
) -> None:
if self._get_archive_store is None:
return
def _run() -> None:
try:
self._merge_archives_once(session_id, ctx)
except Exception as exc:
logger.warning("archive auto-merge failed session=%s: %s", session_id, exc)
if wait:
_run()
return
thread = threading.Thread(
target=_run,
daemon=True,
name=f"archive-merge-{session_id[:8]}",
)
thread.start()
def _merge_archives_once(self, session_id: str, ctx: RequestContext) -> bool:
# Merge the oldest archives of a session into one "merged_*" archive.
# Returns True when a merged archive was written and every source archive
# was marked (or the store has no ``mark_archive_merged``); returns False
# when there is nothing to merge or a step fails.
store = self._get_archive_store()
if store is None:
return False
entries = store.list_archives(session_id, ctx)
# Needs at least two archives and at least ``_archive_merge_threshold`` of them.
if len(entries) < 2 or len(entries) < self._archive_merge_threshold:
return False
# Oldest first; entries without ``created_at`` sort to the front.
sorted_entries = sorted(entries, key=lambda entry: entry.created_at or "")
# Merge enough of the oldest entries to bring the total back to
# ``_archive_max_count`` once the merged archive is added, but never fewer than two.
merge_count = max(0, len(sorted_entries) - self._archive_max_count + 1)
merge_count = max(2, merge_count)
to_merge_entries = sorted_entries[:merge_count]
full_entries = []
for entry in to_merge_entries:
merged_entry = store.read_archive(session_id, entry.archive_id, ctx)
# Fall back to the listing entry when the read does not yield an ArchiveEntry.
full_entries.append(merged_entry if isinstance(merged_entry, ArchiveEntry) else entry)
from session.archive_merger import ArchiveMerger
merged = ArchiveMerger(max_messages=self._archive_merge_max_messages).merge(full_entries)
# Nothing worth writing when the merge produced neither overview nor messages.
if not merged.overview and not merged.messages:
return False
merged_archive_id = (
f"merged_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}"
f"_{uuid.uuid4().hex[:8]}"
)
result = store.write_archive(
session_id=session_id,
overview=merged.overview,
abstract=merged.abstract,
messages=merged.messages,
ctx=ctx,
archive_id=merged_archive_id,
metadata=merged.metadata,
)
if not getattr(result, "success", False):
return False
# Prefer the id reported by the store, if it reports one.
merged_archive_id = getattr(result, "archive_id", merged_archive_id) or merged_archive_id
marked_archive_ids: list[str] = []
# Marking is optional: stores without ``mark_archive_merged`` skip this phase.
if hasattr(store, "mark_archive_merged"):
for entry in to_merge_entries:
try:
ok = store.mark_archive_merged(
session_id,
entry.archive_id,
ctx,
merged_archive_id,
)
except Exception as exc:
logger.warning(
"archive merge mark failed session=%s archive=%s: %s",
session_id,
entry.archive_id,
exc,
)
ok = False
if not ok:
# Roll back: unmark what was marked so far, then delete the merged
# archive. Both steps are best effort and depend on store support.
if hasattr(store, "unmark_archive_merged"):
for marked_id in marked_archive_ids:
try:
store.unmark_archive_merged(
session_id,
marked_id,
ctx,
merged_archive_id,
)
except Exception:
# NOTE(review): the exception itself is not included in this log line.
logger.warning(
"archive merge unmark failed session=%s archive=%s",
session_id,
marked_id,
)
if hasattr(store, "delete_archive"):
try:
store.delete_archive(session_id, merged_archive_id, ctx)
except Exception:
logger.warning(
"archive merge cleanup delete failed session=%s archive=%s",
session_id,
merged_archive_id,
)
return False
marked_archive_ids.append(entry.archive_id)
# The cached "archive_history" topic slot is invalidated after a successful merge.
self.invalidate_topic_slot(session_id, "archive_history")
return True
@staticmethod
def _commit_succeeded(task_info: dict) -> bool:
return (
task_info.get("status") == "completed"
and task_info.get("archived", False) is True
)
def _update_task(self, task_info: dict, **fields) -> dict:
with self._lock:
task_info.update(fields)
return dict(task_info)
def _task_snapshot(self, task_info: dict) -> dict:
with self._lock:
return dict(task_info)
def get_or_create(self, session_id: str, ctx: RequestContext | None = None) -> SessionBuffer:
"""Get existing buffer or create a new one.
On first creation, binds identity (account_id, user_id, agent_id)
from *ctx* into the SessionMeta so subsequent requests can inherit it.
"""
# Persisted state is loaded at most once per session id: the caller that
# registers the load event performs the load; callers that find an unset
# event registered for the id wait on it instead.
load_event = None
should_load = False
wait_for_load = False
with self._lock:
buf = self._sessions.get(session_id)
if buf is None:
meta = SessionMeta(session_id=session_id)
if ctx is not None:
meta.account_id = ctx.account_id
meta.user_id = ctx.user_id
meta.agent_id = ctx.agent_id
buf = SessionBuffer(
session_id=session_id, meta=meta
)
self._sessions[session_id] = buf
logger.debug("Created new session buffer: %s", session_id)
load_event = self._session_load_events.get(session_id)
if load_event is not None and not load_event.is_set():
# Another caller is loading this session right now.
wait_for_load = True
# A load is only started when a ctx is supplied; ctx=None callers never trigger one.
elif ctx is not None and session_id not in self._session_state_loaded:
should_load = True
load_event = threading.Event()
self._session_load_events[session_id] = load_event
# The load itself (and the wait) happens after the ``with self._lock`` statement above.
if should_load and load_event is not None:
try:
self.load_session_state(session_id, ctx)
finally:
# Always drop the event registration and set the event so waiters
# are released even if the load raised.
with self._lock:
pending_event = self._session_load_events.pop(session_id, None)
if pending_event is not None:
pending_event.set()
elif wait_for_load and load_event is not None:
# No timeout is passed to wait().
load_event.wait()
return buf
def get_window_state(self, session_id: str) -> SessionWindowState:
"""Get current window state for a session."""
buf = self.get_or_create(session_id)
with self._lock:
return buf.window_state
def update_window_state(
self,
session_id: str,
window_state: SessionWindowState,
) -> None:
"""Update window state after compression."""
buf = self.get_or_create(session_id)
with self._lock:
buf.window_state = window_state
def invalidate_topic_slot(self, session_id: str, name: str) -> None:
# Invalidate the slot called ``name`` in the session's topic buffer.
# Silently does nothing when the session has no topic buffer; unlike
# get_topic_buffer(), this never creates one.
with self._lock:
buffer = self._topic_buffers.get(session_id)
if buffer is not None:
buffer.invalidate(name)
def list_session_working_set(self) -> list[dict]:
with self._lock:
rows: list[dict] = []
for session_id, buf in self._sessions.items():
rows.append({
"session_id": session_id,
"last_accessed_at": buf.window_state.last_accessed_at,
"message_count": len(buf.messages),
"pending_tokens": buf.pending_tokens,
"active_task": buf.window_state.active_task,
})
return sorted(
rows,
key=lambda row: row.get("last_accessed_at") or "",
reverse=True,
)
def evict_idle_sessions(
self,
max_idle_seconds: int,
now_iso: str | None = None,
ctx: RequestContext | None = None,
) -> list[str]:
if max_idle_seconds <= 0:
return []
now = (
datetime.fromisoformat(now_iso.replace("Z", "+00:00"))
if now_iso
else datetime.now(timezone.utc)
)
evicted: list[str] = []
with self._lock:
candidates = [
(
session_id,
buf.window_state.last_accessed_at,
buf.commit_in_progress,
buf.extraction_in_progress,
len(buf.messages),
buf.extraction_watermark,
deepcopy(buf.meta),
)
for session_id, buf in self._sessions.items()
]
for (
session_id,
last,
commit_in_progress,
extraction_in_progress,
message_count,
extraction_watermark,
session_meta,
) in candidates:
if not last:
continue
try:
accessed = datetime.fromisoformat(last.replace("Z", "+00:00"))
except Exception:
continue
if (now - accessed).total_seconds() > max_idle_seconds:
if (
commit_in_progress
or extraction_in_progress
or message_count > extraction_watermark
):
continue
if ctx is not None:
self.save_session_state(
session_id,
self._session_state_context(session_id, session_meta, ctx),
)
with self._lock:
buf = self._sessions.get(session_id)
if buf is None:
continue
if (
buf.commit_in_progress
or buf.extraction_in_progress
or len(buf.messages) > buf.extraction_watermark
):
continue
self._sessions.pop(session_id, None)
self._topic_buffers.pop(session_id, None)
logger.debug("Removed session buffer: %s", session_id)
evicted.append(session_id)
return evicted
def issue_compaction_prepare_token(self, session_id: str, archive_id: str = "") -> str:
buf = self.get_or_create(session_id)
with self._lock:
token = uuid.uuid4().hex
buf.compaction_prepare_token = token
buf.compaction_prepare_archive_id = archive_id
buf.compaction_prepare_message_count = len(buf.messages)
buf.compaction_prepare_watermark = buf.extraction_watermark
buf.compaction_prepare_token_created_at = datetime.now(timezone.utc).isoformat()
return token
def get_compaction_prepare_archive_id(self, session_id: str) -> str:
buf = self.get_or_create(session_id)
with self._lock:
return buf.compaction_prepare_archive_id
def get_compaction_prepare_state(
self,
session_id: str,
ttl_seconds: int = 300,
) -> dict | None:
buf = self.get_or_create(session_id)
with self._lock:
token = buf.compaction_prepare_token
if not token:
return None
if not self._compaction_prepare_token_is_fresh(
buf.compaction_prepare_token_created_at,
ttl_seconds,
):
self._clear_compaction_prepare_token_locked(buf)
return None
if (
len(buf.messages) != buf.compaction_prepare_message_count
or buf.extraction_watermark != buf.compaction_prepare_watermark
):
self._clear_compaction_prepare_token_locked(buf)
return None
return {
"prepareToken": token,
"archive_id": buf.compaction_prepare_archive_id or "",
}
def consume_compaction_prepare_token(
self,
session_id: str,
token: str,
ttl_seconds: int = 300,
) -> bool:
buf = self.get_or_create(session_id)
with self._lock:
token_matches = bool(token) and token == buf.compaction_prepare_token
message_count_matches = len(buf.messages) == buf.compaction_prepare_message_count
watermark_matches = buf.extraction_watermark == buf.compaction_prepare_watermark
ttl_matches = self._compaction_prepare_token_is_fresh(
buf.compaction_prepare_token_created_at,
ttl_seconds,
)
if buf.compaction_prepare_token and not ttl_matches:
self._clear_compaction_prepare_token_locked(buf)
return False
is_valid = token_matches and message_count_matches and watermark_matches and ttl_matches
if is_valid:
self._clear_compaction_prepare_token_locked(buf)
return is_valid
def clear_compaction_prepare_token(self, session_id: str) -> None:
buf = self.get_or_create(session_id)
with self._lock:
self._clear_compaction_prepare_token_locked(buf)
@staticmethod
def _clear_compaction_prepare_token_locked(buf: SessionBuffer) -> None:
buf.compaction_prepare_token = ""
buf.compaction_prepare_archive_id = ""
buf.compaction_prepare_message_count = 0
buf.compaction_prepare_watermark = 0
buf.compaction_prepare_token_created_at = ""
@staticmethod
def _compaction_prepare_token_is_fresh(created_at: str, ttl_seconds: int) -> bool:
if ttl_seconds <= 0:
return False
if not created_at:
return False
try:
created = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
now = datetime.now(created.tzinfo or timezone.utc)
return (now - created).total_seconds() < ttl_seconds
except Exception:
return False
def add_message(
self,
session_id: str,
role: str,
content: str,
ctx: RequestContext,
created_at: str | None = None,
) -> dict:
"""Add a message to the session buffer."""
buf = self.get_or_create(session_id, ctx=ctx)
with self._lock:
msg = buf.add(role, content, created_at=created_at)
pending_tokens = buf.pending_tokens
return {
"ok": True,
"message_id": msg.id,
"pending_tokens": pending_tokens,
}
def get_session(self, session_id: str, ctx: RequestContext) -> dict:
"""Return session meta + pending_tokens."""
buf = self.get_or_create(session_id, ctx=ctx)
with self._lock:
return {
"ok": True,
"session_id": session_id,
"pending_tokens": buf.pending_tokens,
"message_count": buf.meta.message_count,
"commit_count": buf.meta.commit_count,
"last_commit_at": buf.meta.last_commit_at,
"created_at": buf.meta.created_at,
"updated_at": buf.meta.updated_at,
}
def commit(
self,
session_id: str,
ctx: RequestContext,
wait: bool = False,
skip_extraction: bool = False,
session_time=None,
archive_id: str | None = None,
) -> dict:
"""Two-phase commit.
Phase 1 (sync): snapshot messages, clear buffer, write raw archive.
Phase 2 (background thread): generate overview/abstract from archive.
Returns { task_id, archived, archive_uri, status }.
Args:
skip_extraction: If True, skip memory extraction in phase 2
(used when caller already performed extraction, e.g. after_turn).
session_time: Optional datetime for temporal resolution in extraction.
"""
# NOTE(review): ``skip_extraction`` and ``session_time`` are not referenced
# anywhere in this method and are not forwarded to _process_snapshot().
# NOTE(review): the returned dicts carry ``archive_id`` (plus ``ok`` and,
# when wait=True, ``error``), not ``archive_uri``; the URI is only stored on
# the task record by _process_snapshot().
buf = self.get_or_create(session_id, ctx=ctx)
with self._lock:
# Early exits: nothing buffered, or another commit already owns the buffer.
if not buf.messages:
return {"ok": True, "archived": False, "reason": "empty_buffer"}
if buf.commit_in_progress:
return {"ok": True, "archived": False, "reason": "commit_in_progress"}
buf.commit_in_progress = True
self._sync_runtime_state_to_session_state(session_id, buf)
# Capture the rollback state before snapshot_and_clear() resets the buffer.
rollback_state = buf.capture_rollback_state()
snapshot = buf.snapshot_and_clear()
if archive_id is None:
archive_id = generate_archive_id()
task_id = f"task_{uuid.uuid4().hex[:12]}"
task_info = {
"task_id": task_id,
"session_id": session_id,
"archive_id": archive_id,
"status": "processing",
"created_at": datetime.now(timezone.utc).isoformat(),
}
# The task record is registered in ``_tasks``; nothing in this method removes it.
self._tasks[task_id] = task_info
if wait:
# Synchronous path: process the snapshot in the calling thread.
try:
self._process_snapshot(
snapshot, session_id, archive_id, task_info, ctx
)
except Exception as exc:
logger.error(
"commit phase2 failed session=%s: %s",
session_id, exc, exc_info=True,
)
self._update_task(task_info, status="failed", error=str(exc))
with self._lock:
buf.commit_in_progress = False
# commit_count / last_commit_at are updated whether or not the commit succeeded.
buf.meta.commit_count += 1
buf.meta.last_commit_at = datetime.now(timezone.utc).isoformat()
task_snapshot = dict(task_info)
should_save = self._commit_succeeded(task_snapshot)
if task_snapshot.get("status") == "failed":
# Put the snapshot messages back in front of anything added meanwhile.
buf.restore_rollback_state(rollback_state)
if should_save:
self.save_session_state(session_id, ctx)
return {
"ok": True,
"archived": task_snapshot.get("archived", False),
"archive_id": archive_id,
"task_id": task_id,
"status": task_snapshot.get("status", "completed"),
"error": task_snapshot.get("error"),
}
# Asynchronous path: same processing and finalization, run in a daemon thread.
def _commit_thread():
try:
self._process_snapshot(
snapshot, session_id, archive_id, task_info, ctx
)
except Exception as exc:
logger.error(
"commit phase2 failed session=%s: %s",
session_id, exc, exc_info=True,
)
self._update_task(task_info, status="failed", error=str(exc))
finally:
with self._lock:
buf.commit_in_progress = False
buf.meta.commit_count += 1
buf.meta.last_commit_at = datetime.now(timezone.utc).isoformat()
task_snapshot = dict(task_info)
should_save = self._commit_succeeded(task_snapshot)
if task_snapshot.get("status") == "failed":
buf.restore_rollback_state(rollback_state)
if should_save:
self.save_session_state(session_id, ctx)
t = threading.Thread(target=_commit_thread, daemon=True, name=f"commit-{session_id[:8]}")
t.start()
# The caller gets the task id immediately; "archived" is always False here
# because the archive has not been written yet.
return {
"ok": True,
"archived": False,
"archive_id": archive_id,
"task_id": task_id,
"status": "processing",
}
def commit_snapshot(
self,
session_id: str,
snapshot: list[SessionMessage],
ctx: RequestContext,
wait: bool = True,
archive_id: str | None = None,
) -> dict:
"""Archive a fixed message snapshot without reading/clearing live buffer."""
# Unlike commit(), this neither sets ``commit_in_progress`` nor captures or
# restores rollback state; only the meta commit counters of the buffer are
# touched. Note that ``wait`` defaults to True here.
if not snapshot:
return {"ok": True, "archived": False, "reason": "empty_snapshot"}
# get_or_create() is called without ctx, so this call never triggers a state load.
buf = self.get_or_create(session_id)
with self._lock:
if archive_id is None:
archive_id = generate_archive_id()
task_id = f"task_{uuid.uuid4().hex[:12]}"
task_info = {
"task_id": task_id,
"session_id": session_id,
"archive_id": archive_id,
"status": "processing",
"created_at": datetime.now(timezone.utc).isoformat(),
}
self._tasks[task_id] = task_info
# Shallow copy so later changes to the caller's list do not affect processing.
snapshot_copy = list(snapshot)
if wait:
# Synchronous path.
try:
self._process_snapshot(snapshot_copy, session_id, archive_id, task_info, ctx)
except Exception as exc:
logger.error(
"commit_snapshot phase2 failed session=%s: %s",
session_id, exc, exc_info=True,
)
self._update_task(task_info, status="failed", error=str(exc))
with self._lock:
# Counters are updated whether or not the archive write succeeded.
buf.meta.commit_count += 1
buf.meta.last_commit_at = datetime.now(timezone.utc).isoformat()
task_snapshot = dict(task_info)
should_save = self._commit_succeeded(task_snapshot)
if should_save:
self.save_session_state(session_id, ctx)
return {
"ok": True,
"archived": task_snapshot.get("archived", False),
"archive_id": archive_id,
"task_id": task_id,
"status": task_snapshot.get("status", "completed"),
"error": task_snapshot.get("error"),
}
# Asynchronous path: same processing and finalization, run in a daemon thread.
def _commit_snapshot_thread():
try:
self._process_snapshot(snapshot_copy, session_id, archive_id, task_info, ctx)
except Exception as exc:
logger.error(
"commit_snapshot phase2 failed session=%s: %s",
session_id, exc, exc_info=True,
)
self._update_task(task_info, status="failed", error=str(exc))
finally:
with self._lock:
buf.meta.commit_count += 1
buf.meta.last_commit_at = datetime.now(timezone.utc).isoformat()
task_snapshot = dict(task_info)
should_save = self._commit_succeeded(task_snapshot)
if should_save:
self.save_session_state(session_id, ctx)
t = threading.Thread(
target=_commit_snapshot_thread,
daemon=True,
name=f"commit-snapshot-{session_id[:8]}",
)
t.start()
# Returned before the archive is written, hence "archived": False.
return {
"ok": True,
"archived": False,
"archive_id": archive_id,
"task_id": task_id,
"status": "processing",
}
def _process_snapshot(
    self,
    snapshot: list[SessionMessage],
    session_id: str,
    archive_id: str,
    task_info: dict,
    ctx: RequestContext,
    skip_extraction: bool = False,
    session_time=None,
) -> None:
    """Compress a committed snapshot and write it out as an archive.

    Memory extraction is intentionally absent from this step: the caller
    (e.g. after_turn or compact in MemoryService) is expected to extract
    memories before calling commit(), which prevents double extraction.

    Progress is reported through ``self._update_task(task_info, ...)``:
    ``archived`` and ``archive_uri`` are recorded after the write, then
    ``status`` is set to ``"failed"`` (together with ``error``) or to
    ``"completed"``. Only a successful write invalidates the
    ``"archive_history"`` topic slot and triggers ``_maybe_merge_archives``.

    ``skip_extraction`` and ``session_time`` are accepted but not read here.
    """
    payload = []
    for msg in snapshot:
        payload.append(
            {
                "role": msg.role,
                "content": msg.content,
                "id": msg.id,
                "created_at": msg.created_at,
            }
        )

    # The previous archive's overview/abstract are handed to the compressor
    # so that the new summary can be fused with it.
    earlier_overview, earlier_abstract = self._get_latest_archive_context(
        session_id, ctx,
    )
    overview, abstract = self._compress(
        payload,
        prev_overview=earlier_overview,
        prev_abstract=earlier_abstract,
    )

    extra_metadata: dict = {}
    if self._compression_quality_enabled:
        # Quality scoring is best-effort: a failure is logged and the
        # archive is still written.
        try:
            from session.compression_quality import CompressionQualityEvaluator

            evaluator = CompressionQualityEvaluator()
            quality = evaluator.evaluate(payload, overview, abstract)
            self._update_task(task_info, compression_quality=quality)
            if self._compression_quality_persist_metadata:
                extra_metadata["compression_quality"] = quality
        except Exception as exc:
            logger.warning(
                "compression quality evaluation failed session=%s: %s",
                session_id,
                exc,
            )

    write_result = self._write_archive(
        session_id,
        archive_id,
        overview,
        abstract,
        payload,
        ctx,
        metadata=extra_metadata,
    )
    written = write_result.get("success", False)
    self._update_task(
        task_info,
        archived=written,
        archive_uri=write_result.get("uri", ""),
    )

    if written:
        self._update_task(task_info, status="completed")
        self.invalidate_topic_slot(session_id, "archive_history")
        self._maybe_merge_archives(session_id, ctx)
        return

    self._update_task(
        task_info,
        status="failed",
        error=write_result.get("error", "archive write failed"),
    )
def _compress(
    self,
    messages: list[dict],
    prev_overview: str = "",
    prev_abstract: str = "",
) -> tuple[str, str]:
    """Summarise ``messages`` into an (overview, abstract) pair.

    ``prev_overview`` and ``prev_abstract`` are forwarded to
    ``SessionCompressor.compress`` so the result can be fused with the
    previous archive. An LLM-backed compressor is used when
    ``self._get_llm`` is configured and yields an LLM; if that is not the
    case, or the LLM path raises, a compressor built with ``llm=None`` is
    used instead.
    """
    fusion = {"prev_overview": prev_overview, "prev_abstract": prev_abstract}

    try:
        llm = self._get_llm() if self._get_llm is not None else None
        if llm is not None:
            from session.compressor import SessionCompressor

            return SessionCompressor(llm=llm).compress(messages, **fusion)
    except Exception as exc:
        # Any failure on the LLM path degrades to the non-LLM compressor.
        logger.warning("LLM compress failed, using fallback: %s", exc)

    from session.compressor import SessionCompressor

    fallback = SessionCompressor(llm=None)
    return fallback.compress(messages, **fusion)
def _write_archive(
    self,
    session_id: str,
    archive_id: str,
    overview: str,
    abstract: str,
    messages: list[dict],
    ctx: RequestContext,
    metadata: dict | None = None,
) -> dict:
    """Persist an archive through the configured store, if there is one.

    The store comes from ``self._get_archive_store`` when that getter is
    configured; otherwise one is built on top of ``self._get_agfs()``.

    Returns a dict with ``success``, ``uri`` and ``archive_id``, plus
    ``error`` on failure. When ``archive_store_required`` is enabled, an
    unavailable archive store, a failed write, or an exception is reported
    as a failure. In every other case that does not end in a successful
    store write, the result is a success carrying an in-memory
    ``memory://`` URI.
    """

    def _failed(uri, reason) -> dict:
        # Shared shape for every failure result.
        return {
            "success": False,
            "uri": uri,
            "archive_id": archive_id,
            "error": reason,
        }

    try:
        backend = None
        if self._get_archive_store is not None:
            backend = self._get_archive_store()
            if backend is None and self._archive_store_required:
                return _failed("", "archive store unavailable")
        elif self._get_agfs is not None:
            fs = self._get_agfs()
            if fs is not None:
                from session.archive_store import SessionArchiveStore

                backend = SessionArchiveStore(fs=fs)

        if backend is not None:
            outcome = backend.write_archive(
                session_id=session_id,
                overview=overview,
                abstract=abstract,
                messages=messages,
                ctx=ctx,
                archive_id=archive_id,
                metadata=metadata or {},
            )
            if outcome.success:
                return {
                    "success": True,
                    "uri": outcome.uri,
                    "archive_id": outcome.archive_id,
                }
            logger.warning("Archive write failed: %s", outcome.error)
            if self._archive_store_required:
                return _failed(outcome.uri, outcome.error or "archive write failed")
    except Exception as exc:
        logger.warning("Archive write failed: %s", exc, exc_info=True)
        if self._archive_store_required:
            return _failed("", str(exc))

    # Nothing was persisted and persistence is not mandatory here.
    return {
        "success": True,
        "uri": f"memory://{session_id}/{archive_id}",
        "archive_id": archive_id,
    }
def _get_latest_archive_context(
    self, session_id: str, ctx: RequestContext,
) -> tuple[str, str]:
    """Fetch overview and abstract from the most recent archive for fusion.

    The archive store (``self._get_archive_store``) is consulted first. If
    it yields nothing usable, or raises, an AGFS-backed store built from
    ``self._get_agfs()`` is tried. Lookup failures are logged, never raised.

    Returns:
        (prev_overview, prev_abstract) — empty strings if none found.
    """

    def _latest_from_entries(entries) -> tuple[str, str]:
        if not entries:
            return "", ""
        # Select the newest entry without sorting in place: the list belongs
        # to the store and must not be reordered as a side effect. On equal
        # timestamps max() keeps the first entry, which is the same pick a
        # stable descending sort would make. A missing created_at compares
        # as the empty string.
        latest = max(entries, key=lambda e: e.created_at or "")
        return latest.overview or "", latest.abstract or ""

    if self._get_archive_store is not None:
        try:
            store = self._get_archive_store()
            if store is not None:
                overview, abstract = _latest_from_entries(
                    store.list_archives(session_id, ctx),
                )
                # An entirely empty result falls through to the AGFS lookup.
                if overview or abstract:
                    return overview, abstract
        except Exception as exc:
            logger.warning(
                "Failed to fetch latest archive for fusion via archive store: %s",
                exc,
            )

    if self._get_agfs is not None:
        try:
            agfs = self._get_agfs()
            if agfs is not None:
                from session.archive_store import SessionArchiveStore

                store = SessionArchiveStore(fs=agfs)
                return _latest_from_entries(store.list_archives(session_id, ctx))
        except Exception as exc:
            logger.warning(
                "Failed to fetch latest archive for fusion via AGFS fallback: %s",
                exc,
            )

    return "", ""
def _extract_memories(
    self, messages: list[dict], ctx: RequestContext,
    buf: SessionBuffer | None = None,
) -> dict:
    """Run memory extraction over ``messages`` through the write API.

    When ``buf`` is given, its ``extraction_summary`` is passed along as
    ``session_summary`` and afterwards extended with one line per non-skip
    plan, keeping only the trailing 2000 characters.

    Returns a dict with ``candidates_extracted`` and ``writes_completed``.
    Both are 0 when no write API is available or when anything raises
    (failures are logged, not propagated).
    """
    try:
        write_api = None
        if self._get_write_api is not None:
            write_api = self._get_write_api()

        if write_api is not None:
            # Read the running summary under the lock, but do not hold the
            # lock across the commit call itself.
            with self._lock:
                prior_summary = buf.extraction_summary if buf else ""

            outcome = write_api.commit_session(
                messages,
                ctx,
                confidence_threshold=0.5,
                wait=True,
                session_summary=prior_summary,
            )

            if buf and outcome.get("candidates_extracted", 0) > 0:
                lines = []
                for plan in outcome.get("plans", []):
                    if plan["action"] == "skip":
                        continue
                    lines.append(f"- [{plan['action']}] {plan['target_uri']}")

                if lines:
                    with self._lock:
                        merged = "\n".join([buf.extraction_summary, *lines]).strip()
                        # Bound the summary to its most recent 2000 characters.
                        buf.extraction_summary = merged[-2000:]

            return {
                "candidates_extracted": outcome.get("candidates_extracted", 0),
                "writes_completed": outcome.get("writes_completed", 0),
            }
    except Exception as exc:
        logger.warning("Memory extraction failed: %s", exc, exc_info=True)

    return {"candidates_extracted": 0, "writes_completed": 0}
def get_context(
    self,
    session_id: str,
    token_budget: int,
    ctx: RequestContext,
) -> dict:
    """Report the session's pending buffer size and its archive references.

    The session buffer is obtained via ``get_or_create``. Archives are
    listed through ``self._get_archive_store`` when that getter is
    configured, otherwise through an AGFS-backed store; a listing failure
    is logged and whatever was collected so far is returned.

    ``token_budget`` is accepted but not read by this method.

    Returns a dict with ``ok``, ``session_id``, ``pending_tokens``,
    ``estimatedTokens`` (same value as ``pending_tokens``),
    ``active_message_count``, ``archive_count``,
    ``latest_archive_overview`` and ``archives``.
    """
    session_buf = self.get_or_create(session_id)
    with self._lock:
        pending = session_buf.pending_tokens
        message_total = len(session_buf.messages)

    def _resolve_store():
        # Pick the archive backend; may return None when nothing is available.
        if self._get_archive_store is not None:
            return self._get_archive_store()
        if self._get_agfs is not None:
            fs = self._get_agfs()
            if fs is not None:
                from session.archive_store import SessionArchiveStore

                return SessionArchiveStore(fs=fs)
        return None

    newest_overview = ""
    refs: list[dict] = []
    try:
        backend = _resolve_store()
        listed = backend.list_archives(session_id, ctx) if backend is not None else None
        if listed:
            # The first listed entry is taken as the latest archive.
            newest_overview = listed[0].overview or ""
            for item in listed:
                refs.append({
                    "archive_id": item.archive_id,
                    "abstract": item.abstract,
                })
    except Exception as exc:
        logger.warning("get_context archive collection failed: %s", exc)

    return {
        "ok": True,
        "session_id": session_id,
        "pending_tokens": pending,
        "estimatedTokens": pending,
        "active_message_count": message_total,
        "archive_count": len(refs),
        "latest_archive_overview": newest_overview,
        "archives": refs,
    }
def get_task(self, task_id: str) -> dict | None:
    """Return a deep copy of a background task's record, or None if unknown.

    The copy is taken while holding ``self._lock``; callers can change the
    returned dict without touching the stored task.
    """
    with self._lock:
        record = self._tasks.get(task_id)
        if record is None:
            return None
        return deepcopy(record)
def has_session(self, session_id: str) -> bool:
    """Tell whether a buffer is currently registered for ``session_id``."""
    with self._lock:
        known = self._sessions
        return session_id in known
def remove_session(self, session_id: str) -> None:
    """Drop the session's buffer and its topic buffer (dispose/cleanup).

    Unknown session ids are ignored. A debug line is logged only when a
    truthy buffer was actually removed.
    """
    with self._lock:
        dropped = self._sessions.pop(session_id, None)
        self._topic_buffers.pop(session_id, None)
        if not dropped:
            return
        logger.debug("Removed session buffer: %s", session_id)
def get_pending_session_ids(self) -> list[str]:
    """List the ids of sessions that still hold unextracted messages.

    A session counts as pending while its ``extraction_watermark`` is below
    the number of messages in its buffer. Ids come back in the iteration
    order of ``self._sessions``.
    """
    pending: list[str] = []
    with self._lock:
        for session_id, session_buf in self._sessions.items():
            if len(session_buf.messages) > session_buf.extraction_watermark:
                pending.append(session_id)
    return pending