"""Service API layer - write operations for ContextEngine.
This is the ONLY layer where RequestContext is mandatory and account_id is injected.
All external calls must provide a RequestContext for multi-tenant isolation.
Note: dev branch is WRITE-ONLY. Read operations (ReadAPI) are in phase1 branch.
See CLAUDE.md §7 for tool interface spec and §8 for multi-tenant rules.
"""
import threading
import uuid
from dataclasses import dataclass
from typing import Optional
from core.logging_config import get_logger
# Module-level logger; MemoryWriteAPI uses it to report background commit failures.
logger = get_logger(__name__)
from core.interfaces import ContextFS, LLM, CandidateExtractor
from core.models import RequestContext, CandidateMemory, WritePlan
from commit import ContextWriter, CandidatePipeline, OutboxStore
from extraction import Extractor
class MemoryWriteAPI:
    """Public API for memory write operations.

    All methods enforce multi-tenant isolation via RequestContext.
    Writes are orchestrated through the commit pipeline:
    1. Extract candidates via CandidateExtractor
    2. Plan write actions via MergePolicy
    3. Build ContextNode via ArchiveBuilder
    4. Write to storage via ContextFS
    5. Register OutboxEvents for async indexing

    Asynchronous commits are tracked in an in-memory task table
    (``self._tasks``) guarded by ``self._tasks_lock``. Nothing in this class
    evicts finished entries, so the table grows for the lifetime of the
    instance.
    """

    # Max characters of raw chunk text copied into a session_archive abstract.
    _ABSTRACT_MAX_CHARS = 200
    # Max rendered message lines copied into a session_archive overview.
    _OVERVIEW_MAX_LINES = 5

    def __init__(
        self,
        fs: ContextFS,
        llm: LLM,
        outbox_store: Optional[OutboxStore] = None,
        schema_registry=None,
        vector_index=None,
        embedder=None,
        uri_resolver=None,
        internal_tool_usage_tracker=None,
    ):
        """Initialize the API with required dependencies.

        Args:
            fs: ContextFS implementation for persisting nodes
            llm: LLM instance for extraction
            outbox_store: OutboxStore for registering index events (optional)
            schema_registry: SchemaRegistry for dynamic tool generation
                (a default SchemaRegistry() is created if None)
            vector_index: Optional VectorIndex for prefetching existing memories
            embedder: Optional Embedder for prefetching existing memories
            uri_resolver: Optional URIResolver for prefetching existing memories
            internal_tool_usage_tracker: Optional tracker for oGMem internal tool calls
        """
        self._fs = fs
        self._llm = llm
        self._outbox_store = outbox_store
        self._vector_index = vector_index
        self._embedder = embedder
        self._uri_resolver = uri_resolver
        self._internal_tool_usage_tracker = internal_tool_usage_tracker
        if schema_registry is None:
            from extraction.schemas.registry import SchemaRegistry
            schema_registry = SchemaRegistry()
        self._schema_registry = schema_registry
        from commit.policy_router import PolicyRouter
        policy_router = PolicyRouter(fs, registry=schema_registry, uri_resolver=uri_resolver)
        self._writer = ContextWriter(
            fs, llm=self._llm, outbox_store=outbox_store, policy_router=policy_router
        )
        self._pipeline = CandidatePipeline()
        # Must run after the attributes above: _create_extractors reads them.
        self._pipeline.set_extractors(self._create_extractors())
        self._tasks: dict[str, dict] = {}
        self._tasks_lock = threading.Lock()

    def _create_extractors(self) -> list[CandidateExtractor]:
        """Create the default extractor list for the pipeline.

        First tries a single eager-mode Extractor wired to a PromptManager.
        If anything in that path raises (PromptManager import/construction or
        the Extractor constructor itself), it falls back to an Extractor built
        without a prompt manager or mode override. The failure is logged
        rather than silently hidden.

        Returns:
            A one-element list containing the Extractor instance.
        """
        shared_kwargs = {
            "schema_registry": self._schema_registry,
            "fs": self._fs,
            "vector_index": self._vector_index,
            "embedder": self._embedder,
            "uri_resolver": self._uri_resolver,
            "internal_tool_usage_tracker": self._internal_tool_usage_tracker,
        }
        try:
            from extraction.prompts import PromptManager
            pm = PromptManager()
            return [Extractor(self._llm, prompt_manager=pm, mode="eager", **shared_kwargs)]
        except Exception:
            # Best-effort: keep the service usable with default prompts, but make
            # the degraded configuration visible in the logs.
            logger.warning(
                "Falling back to default Extractor: PromptManager/eager setup failed",
                exc_info=True,
            )
            return [Extractor(self._llm, **shared_kwargs)]

    @staticmethod
    def _plan_to_dict(plan: WritePlan) -> dict:
        """Serialize a WritePlan into the public result-dict shape."""
        return {
            "action": plan.action,
            "target_uri": plan.target_uri,
            "merged_fields": plan.merged_fields,
        }

    def commit_session(
        self,
        messages: list[dict],
        ctx: RequestContext,
        confidence_threshold: float = 0.5,
        wait: bool = True,
        session_time=None,
        session_summary: str = "",
        tool_stats_text: str = "",
        archive_id: str | None = None,
    ) -> dict:
        """Commit a conversation session to memory.

        This is the main entry point for writing memories. It extracts
        candidates from ``messages``, drops those below
        ``confidence_threshold``, and writes the rest to storage. The work
        always runs synchronously in the calling thread. Use
        commit_session_async() for background processing.

        Args:
            messages: List of message dicts with "role" and "content"
                Example: [{"role": "user", "content": "..."}, ...]
            ctx: RequestContext for this operation
            confidence_threshold: Minimum confidence for writing (default 0.5)
            wait: Kept for backward compatibility. When False, the call still
                blocks until extraction and writing finish. The result is also
                recorded in the task table under a fresh task_id, so it can be
                looked up later via get_task_status().
            session_time: Optional datetime for temporal resolution (defaults to now).
            session_summary: Optional summary of previously extracted content.
            tool_stats_text: Optional tool usage statistics text.
            archive_id: Optional archive_id for provenance tracking. When None,
                one is generated via session.session_manager.generate_archive_id().

        Returns:
            Dict with write results:
            {
                "archive_id": str,
                "candidates_extracted": int,
                "candidates_filtered": int,   # dropped by the confidence filter
                "writes_completed": int,      # plans whose action != "skip"
                "writes_skipped": int,        # plans whose action == "skip"
                "writes_failed": int,         # kept candidates with no plan returned
                "plans": list[dict],          # action / target_uri / merged_fields
                "task_id": str,               # only if wait=False
                "status": "completed",        # only if wait=False
            }
        """
        if archive_id is None:
            from session.session_manager import generate_archive_id
            archive_id = generate_archive_id()
        candidates = self._pipeline.extract(
            messages, ctx, session_time=session_time,
            session_summary=session_summary,
            tool_stats_text=tool_stats_text,
            archive_id=archive_id,
        )
        kept = self._pipeline.filter_by_confidence(candidates, confidence_threshold)
        plans = self._writer.write_candidates(kept, ctx)
        writes_skipped = sum(1 for p in plans if p.action == "skip")
        writes_completed = len(plans) - writes_skipped
        # NOTE(review): presumably write_candidates omits plans for candidates it
        # failed to write; any shortfall versus `kept` is reported as failed.
        writes_failed = len(kept) - writes_completed - writes_skipped
        result = {
            "archive_id": archive_id,
            "candidates_extracted": len(candidates),
            "candidates_filtered": len(candidates) - len(kept),
            "writes_completed": writes_completed,
            "writes_skipped": writes_skipped,
            "writes_failed": writes_failed,
            "plans": [self._plan_to_dict(p) for p in plans],
        }
        if not wait:
            task_id = str(uuid.uuid4())
            with self._tasks_lock:
                self._tasks[task_id] = {"status": "completed", "result": result}
            result["task_id"] = task_id
            result["status"] = "completed"
        return result

    @staticmethod
    def _render_chunk_lines(messages: list[dict]) -> tuple[list[str], list[str]]:
        """Render messages as "role: text" lines and collect participant roles.

        List-valued content is flattened by joining the non-empty "text"
        fields of its dict blocks with spaces. Messages with no text are
        skipped. Participants are the roles (other than "system" and
        "assistant") of messages that contributed text. They are returned in
        first-appearance order, which is deterministic, unlike set iteration
        order for strings.
        """
        lines: list[str] = []
        participants: dict[str, None] = {}  # insertion-ordered "set"
        for msg in messages:
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            if isinstance(content, list):
                content = " ".join(
                    block.get("text", "") for block in content
                    if isinstance(block, dict) and block.get("text")
                )
            if content:
                lines.append(f"{role}: {content}")
                if role not in ("system", "assistant"):
                    participants[role] = None
        return lines, list(participants)

    @staticmethod
    def _chunk_date_str(session_time) -> str:
        """Return session_time as a compact YYYYMMDD string, or "" if unavailable.

        date/datetime values are formatted with strftime. Strings have their
        first 10 characters de-hyphenated. NOTE(review): this assumes an
        ISO-like "YYYY-MM-DD..." prefix; confirm with callers. Falsy values and
        any other type yield "".
        """
        if not session_time:
            return ""
        from datetime import date
        try:
            # datetime is a subclass of date, so both are handled here.
            if isinstance(session_time, date):
                return session_time.strftime("%Y%m%d")
            if isinstance(session_time, str):
                return session_time[:10].replace("-", "")
        except Exception:
            # Best-effort: a malformed timestamp must not block the raw-chunk
            # write; fall back to a routing key without a date component.
            logger.debug(
                "Could not format session_time %r for chunk routing key",
                session_time,
                exc_info=True,
            )
        return ""

    def write_raw_chunk(
        self,
        messages: list[dict],
        ctx: RequestContext,
        chunk_index: int = 0,
        session_id: str = "",
        session_time=None,
    ) -> dict:
        """Write raw chunk text as a session_archive node for retrieval.

        Stores the original conversation text so that details missed by
        structured extraction can still be found via vector search.

        Args:
            messages: List of message dicts in this chunk
            ctx: RequestContext for this operation
            chunk_index: Index of this chunk within the session
            session_id: Session identifier; its first 8 characters go into
                the routing key ("unknown" when empty)
            session_time: Optional date/datetime or date string for temporal context

        Returns:
            {"action": "skip", "reason": ...} when there is nothing to store;
            otherwise {"action": str, "target_uri": str} from the write plan.
        """
        if not messages:
            return {"action": "skip", "reason": "empty messages"}
        lines, participants = self._render_chunk_lines(messages)
        raw_text = "\n".join(lines)
        if not raw_text.strip():
            return {"action": "skip", "reason": "no text content"}

        sid = session_id[:8] if session_id else "unknown"
        time_str = self._chunk_date_str(session_time)
        routing_key = f"chunk_{time_str}_{sid}_{chunk_index}"

        max_chars = self._ABSTRACT_MAX_CHARS
        abstract = raw_text[:max_chars].replace("\n", " ")
        if len(raw_text) > max_chars:
            abstract += "..."

        max_lines = self._OVERVIEW_MAX_LINES
        overview = "\n".join(lines[:max_lines])
        if len(lines) > max_lines:
            overview += f"\n... ({len(lines) - max_lines} more messages)"

        candidate = CandidateMemory(
            category="session_archive",
            owner_scope="user",
            routing_key=routing_key,
            abstract=abstract,
            overview=overview,
            content=raw_text,
            confidence=1.0,
            when=str(session_time) if session_time else None,
            who=", ".join(participants) if participants else None,
        )
        plan = self._writer.write_candidate(candidate, ctx)
        return {
            "action": plan.action,
            "target_uri": plan.target_uri,
        }

    def commit_session_async(
        self,
        messages: list[dict],
        ctx: RequestContext,
        confidence_threshold: float = 0.5,
        session_time=None,
        session_summary: str = "",
        tool_stats_text: str = "",
        archive_id: str | None = None,
    ) -> str:
        """Fire-and-forget version of commit_session.

        Registers a "processing" task, runs commit_session(wait=True) on a
        daemon thread, and returns the task_id immediately. Poll
        get_task_status(task_id) for the outcome. The record becomes either
        {"status": "completed", "result": <commit_session result>} or
        {"status": "failed", "result": None, "error": <message>}.

        Args:
            messages, ctx, confidence_threshold, session_time,
            session_summary, tool_stats_text: Forwarded unchanged to
                commit_session().
            archive_id: Optional archive_id for provenance tracking, forwarded
                to commit_session(); one is generated there when None.

        Returns:
            task_id string for tracking the background job.
        """
        task_id = str(uuid.uuid4())
        with self._tasks_lock:
            self._tasks[task_id] = {"status": "processing", "result": None}

        def _run() -> None:
            try:
                result = self.commit_session(
                    messages=messages,
                    ctx=ctx,
                    confidence_threshold=confidence_threshold,
                    wait=True,
                    session_time=session_time,
                    session_summary=session_summary,
                    tool_stats_text=tool_stats_text,
                    archive_id=archive_id,
                )
            except Exception as exc:
                # Runs on a daemon thread: record the failure so pollers see it.
                logger.error(
                    "commit_session_async failed for task %s: %s", task_id, exc, exc_info=True
                )
                record = {"status": "failed", "result": None, "error": str(exc)}
            else:
                record = {"status": "completed", "result": result}
            with self._tasks_lock:
                self._tasks[task_id] = record

        worker = threading.Thread(target=_run, daemon=True, name=f"commit-{task_id[:8]}")
        worker.start()
        return task_id

    def get_task_status(self, task_id: str) -> dict | None:
        """Return a snapshot of an async commit task's record, or None if unknown.

        The returned dict is a shallow copy, so callers cannot mutate the
        internal task table through it.
        """
        with self._tasks_lock:
            record = self._tasks.get(task_id)
            return dict(record) if record is not None else None

    def write_memory(
        self,
        candidate: CandidateMemory,
        ctx: RequestContext,
    ) -> dict:
        """Write a single candidate memory.

        Bypasses extraction; use this when you already have a CandidateMemory.

        Args:
            candidate: CandidateMemory to write
            ctx: RequestContext for this operation

        Returns:
            Dict with write result:
            {
                "action": str,
                "target_uri": str,
                "merged_fields": dict,
            }
        """
        plan = self._writer.write_candidate(candidate, ctx)
        return self._plan_to_dict(plan)

    def write_memories(
        self,
        candidates: list[CandidateMemory],
        ctx: RequestContext,
        parallel: bool = True,
    ) -> list[dict]:
        """Write multiple candidate memories.

        Args:
            candidates: List of CandidateMemory to write
            ctx: RequestContext for these operations
            parallel: If True (default), use the writer's parallel write path;
                otherwise write sequentially.

        Returns:
            List of write result dicts (same shape as write_memory()).
        """
        if parallel:
            plans = self._writer.write_candidates_parallel(candidates, ctx)
        else:
            plans = self._writer.write_candidates(candidates, ctx)
        return [self._plan_to_dict(p) for p in plans]
# Module-wide default instance: set by init_write_api(), returned by get_write_api().
_default_write_api: Optional[MemoryWriteAPI] = None
def init_write_api(
    fs: ContextFS,
    llm: LLM,
    outbox_store: Optional[OutboxStore] = None,
    schema_registry=None,
    vector_index=None,
    embedder=None,
    uri_resolver=None,
    internal_tool_usage_tracker=None,
) -> MemoryWriteAPI:
    """Build a MemoryWriteAPI and install it as the module-wide default.

    Every argument is forwarded unchanged to MemoryWriteAPI. Any previously
    installed default instance is replaced.

    Args:
        fs: ContextFS implementation
        llm: LLM instance for extraction
        outbox_store: Optional OutboxStore for async indexing
        schema_registry: Optional SchemaRegistry for dynamic tool generation
        vector_index: Optional VectorIndex for prefetching existing memories
        embedder: Optional Embedder for prefetching existing memories
        uri_resolver: Optional URIResolver for prefetching existing memories
        internal_tool_usage_tracker: Optional tracker for oGMem internal tool calls

    Returns:
        The newly created instance (also available via get_write_api()).
    """
    global _default_write_api
    # Keyword arguments keep this call correct even if the constructor's
    # parameter order changes.
    api = MemoryWriteAPI(
        fs=fs,
        llm=llm,
        outbox_store=outbox_store,
        schema_registry=schema_registry,
        vector_index=vector_index,
        embedder=embedder,
        uri_resolver=uri_resolver,
        internal_tool_usage_tracker=internal_tool_usage_tracker,
    )
    _default_write_api = api
    return api
def get_write_api() -> Optional[MemoryWriteAPI]:
    """Return the module-wide MemoryWriteAPI installed by init_write_api().

    Returns:
        The installed instance, or None when init_write_api() has not run yet.
    """
    current = _default_write_api
    return current
from core.errors import AccessDeniedError, ValidationError as CoreValidationError
from core.models import (
RetrievalConfig,
RetrievedBlock,
RetrieverMode,
SearchMemoryResult,
)
from retrieval.pipeline import RetrievalPipeline
from retrieval.context_reader import ContextReader
class ReadAPI:
    """Public API for memory search and read operations.

    Exposes two tools consumed by AI agents:
    - search_memory: semantic retrieval -> structured SearchMemoryResult
    - read_memory: URI-based read -> RetrievedBlock with full content

    Both tools reject ``ctx://`` URIs that are not under the caller's own
    ``ctx://{account_id}/`` prefix.
    """

    def __init__(
        self,
        pipeline: RetrievalPipeline,
        read_service: ContextReader | None = None,
        config: RetrievalConfig | None = None,
    ) -> None:
        """Store collaborators.

        Args:
            pipeline: RetrievalPipeline that executes search_memory queries.
            read_service: ContextReader backing read_memory. It is optional,
                but read_memory raises RuntimeError when it is missing.
            config: Retrieval limits; a default RetrievalConfig() is used
                when None.
        """
        self._pipeline = pipeline
        self._read_service = read_service
        self._cfg = config or RetrievalConfig()

    @staticmethod
    def _ensure_same_account(uri: str, ctx: RequestContext, reason: str) -> None:
        """Raise AccessDeniedError if ``uri`` is a ctx:// URI of another account.

        URIs that do not use the ``ctx://`` scheme are not checked here.
        """
        prefix = f"ctx://{ctx.account_id}/"
        if uri.startswith("ctx://") and not uri.startswith(prefix):
            raise AccessDeniedError(uri, ctx.account_id, reason)

    def search_memory(
        self,
        query: str,
        ctx: RequestContext,
        *,
        top_k: int = 10,
        categories: list[str] | None = None,
        target_uri: str | None = None,
        session_archive: dict | None = None,
        score_threshold: float | None = None,
        include_debug: bool = False,
        mode: str = RetrieverMode.QUICK,
        fill_content_for_top_k: int = 0,
    ) -> SearchMemoryResult:
        """Run semantic retrieval for ``query`` within the caller's account.

        Args:
            query: Non-blank search text.
            ctx: RequestContext of the caller.
            top_k: Number of results; must be between 1 and config.max_top_k.
            categories, target_uri, session_archive, score_threshold, mode,
            fill_content_for_top_k: Forwarded unchanged to the retrieval
                pipeline. ``target_uri`` is first checked against the
                caller's account.
            include_debug: When False (default), the result's ``trace`` is
                cleared before returning.

        Returns:
            The SearchMemoryResult produced by the pipeline.

        Raises:
            CoreValidationError: blank query or out-of-range top_k.
            AccessDeniedError: target_uri belongs to another account.
        """
        if not (query or "").strip():
            raise CoreValidationError("query", "query must not be empty")
        if top_k < 1:
            raise CoreValidationError("top_k", f"top_k={top_k} must be >= 1")
        if top_k > self._cfg.max_top_k:
            raise CoreValidationError("top_k", f"top_k={top_k} exceeds max {self._cfg.max_top_k}")
        if target_uri:
            self._ensure_same_account(target_uri, ctx, "target_uri account mismatch")
        result = self._pipeline.run(
            query, ctx,
            top_k=top_k,
            categories=categories,
            target_uri=target_uri,
            session_archive=session_archive,
            score_threshold=score_threshold,
            mode=mode,
            fill_content_for_top_k=fill_content_for_top_k,
        )
        if not include_debug:
            # Strip pipeline trace data unless the caller explicitly asked for it.
            result.trace = None
        return result

    def read_memory(
        self,
        uri: str,
        ctx: RequestContext,
    ) -> RetrievedBlock:
        """Read L2 md file content by URI.

        search_memory already returns abstracts in its results, so
        read_memory only reads the actual md file content.

        Args:
            uri: Non-blank URI of the memory to read. ``ctx://`` URIs must be
                under the caller's account prefix.
            ctx: RequestContext of the caller.

        Returns:
            The RetrievedBlock returned by the configured ContextReader.

        Raises:
            RuntimeError: ReadAPI was constructed without a read_service.
            CoreValidationError: blank uri.
            AccessDeniedError: uri belongs to another account.
        """
        if self._read_service is None:
            raise RuntimeError(
                "read_memory is unavailable: ReadAPI was created without a read_service"
            )
        if not (uri or "").strip():
            raise CoreValidationError("uri", "uri must not be empty")
        # Same tenant-isolation rule that search_memory applies to target_uri.
        self._ensure_same_account(uri, ctx, "uri account mismatch")
        return self._read_service.read(uri, ctx=ctx)