from __future__ import annotations
from core.models import ContextNode, RequestContext, SeedHit
from extraction.prefetch import MemoryPrefetcher
from extraction.schemas.models import FieldType, MemoryTypeSchema, SchemaField
from extraction.schemas.registry import SchemaRegistry
def _ctx() -> RequestContext:
    """Build the fixed request context passed to every prefetch call in these tests."""
    fields = {
        "account_id": "acct",
        "user_id": "user",
        "agent_id": "agent",
        "session_id": "sess",
        "trace_id": "trace",
    }
    return RequestContext(**fields)
def _node(uri: str, category: str = "preference") -> ContextNode:
    """Return a level-2 ``MEMORY`` node owned by ``user:user``.

    The abstract, overview and content fields are ``"<category> abstract"``,
    ``"<category> overview"`` and ``"<category> content"`` respectively.
    """
    abstract, overview, content = (
        f"{category} {part}" for part in ("abstract", "overview", "content")
    )
    return ContextNode(
        uri=uri,
        context_type="MEMORY",
        category=category,
        level=2,
        owner_space="user:user",
        abstract=abstract,
        overview=overview,
        content=content,
    )
class FakeFS:
    """In-memory stand-in for the filesystem object handed to ``MemoryPrefetcher``.

    ``nodes`` maps a URI to the node returned by ``read_node``; ``children`` maps
    a directory URI to the child URIs returned by ``list_children``. The ``ctx``
    argument of every method is accepted for interface compatibility and ignored.
    """

    def __init__(self):
        self.nodes: dict[str, ContextNode] = {}
        self.children: dict[str, list[str]] = {}

    def exists(self, uri, ctx):
        """Return True when a node has been registered under *uri*."""
        return uri in self.nodes

    def read_node(self, uri, ctx):
        """Return the node stored at *uri*; raises ``KeyError`` when none is registered."""
        return self.nodes[uri]

    def list_children(self, uri, ctx):
        """Return the child URIs registered for *uri*, or an empty list.

        A fresh list is returned so that a caller that sorts or filters the
        result in place cannot silently rewrite the fixture a test set up.
        """
        return list(self.children.get(uri, ()))
class FailingFS(FakeFS):
    """``FakeFS`` whose individual operations can be switched to raise ``RuntimeError``.

    Each keyword flag makes the matching method raise before touching the
    in-memory state; with a flag off the method behaves exactly like ``FakeFS``.
    """

    def __init__(self, *, fail_exists=False, fail_read=False, fail_list=False):
        super().__init__()
        self.fail_exists = fail_exists
        self.fail_read = fail_read
        self.fail_list = fail_list

    @staticmethod
    def _maybe_fail(enabled, message):
        # Shared guard for the three overridden operations.
        if enabled:
            raise RuntimeError(message)

    def exists(self, uri, ctx):
        self._maybe_fail(self.fail_exists, "exists failed")
        return super().exists(uri, ctx)

    def read_node(self, uri, ctx):
        self._maybe_fail(self.fail_read, "read failed")
        return super().read_node(uri, ctx)

    def list_children(self, uri, ctx):
        self._maybe_fail(self.fail_list, "list failed")
        return super().list_children(uri, ctx)
class FakeEmbedder:
    """Deterministic embedder double that records every text it is asked to embed.

    ``queries`` accumulates the texts across all calls, so tests can assert
    exactly which query strings the prefetcher embedded.
    """

    def __init__(self):
        self.queries: list[str] = []

    def embed_texts(self, texts):
        """Record *texts* and return one constant ``[1.0, 0.0, 0.0]`` vector per text.

        *texts* is materialised once up front: recording and then iterating a
        one-shot iterator a second time would otherwise return no vectors.
        """
        batch = list(texts)
        self.queries.extend(batch)
        return [[1.0, 0.0, 0.0] for _ in batch]
class FakeVectorIndex:
    """Vector-search double: serves the canned ``hits`` list and records the filters it saw."""

    def __init__(self):
        # Tests assign canned search results here before running the prefetcher.
        self.hits = []
        # Stays None until the first search, so tests can detect "no search issued".
        self.filters = None

    def search_by_vector(self, query_vector, filters, top_k):
        """Remember *filters* and hand back ``hits`` as-is; *query_vector* and *top_k* are unused."""
        self.filters = filters
        return self.hits
class FailingVectorIndex(FakeVectorIndex):
    """Vector index double whose every search raises ``RuntimeError``."""

    def search_by_vector(self, query_vector, filters, top_k):
        """Simulate a search backend failure; the filters are not recorded."""
        error = RuntimeError("search failed")
        raise error
def _prefetcher(fs: FakeFS, vector_index: FakeVectorIndex, embedder: FakeEmbedder):
    """Build a ``MemoryPrefetcher`` over the given doubles, using a default ``SchemaRegistry``.

    The same registry instance is shared with the ``URIResolver``.
    """
    from core.uri_resolver import URIResolver

    schema_registry = SchemaRegistry()
    resolver = URIResolver(schema_registry)
    return MemoryPrefetcher(fs, vector_index, embedder, schema_registry, resolver)
def test_prefetch_single_file_profile_reads_existing_node():
    """The ``profile`` category lists and reads its existing node and reports it."""
    profile_uri = "ctx://acct/users/user/memories/profile"
    fs = FakeFS()
    fs.nodes[profile_uri] = _node(profile_uri, "profile")
    prefetcher = _prefetcher(fs, FakeVectorIndex(), FakeEmbedder())

    result = prefetcher.prefetch("profile", _ctx())

    assert profile_uri in result.listed_uris
    assert profile_uri in result.read_uris
    first_message = result.messages[0]
    assert "Existing profile memory" in first_message
def test_prefetch_multi_file_uses_conversation_query_and_owner_filter():
    """Multi-file prefetch embeds the conversation text, filters by owner, and reads the hit's parent node."""
    node_uri = "ctx://acct/users/user/memories/preferences/coffee"
    hit_uri = f"{node_uri}/content.md"

    fs = FakeFS()
    fs.nodes[node_uri] = _node(node_uri, "preference")
    vector = FakeVectorIndex()
    vector.hits = [
        SeedHit(uri=hit_uri, score=0.91, category="preference", abstract="coffee", level=2),
    ]
    embedder = FakeEmbedder()
    prefetcher = _prefetcher(fs, vector, embedder)

    result = prefetcher.prefetch("preference", _ctx(), conversation_text="coffee preference")

    expected_filters = {
        "category": "preference",
        "account_id": "acct",
        "owner_space": "user:user",
    }
    assert embedder.queries == ["coffee preference"]
    assert vector.filters == expected_filters
    assert node_uri in result.read_uris
    assert "similarity: 0.91" in result.messages[0]
def test_prefetch_add_only_lists_recent_entries_newest_first():
    """Add-only prefetch reads every timestamped entry and presents the newest one first."""
    fs = FakeFS()
    vector = FakeVectorIndex()
    embedder = FakeEmbedder()
    dir_uri = "ctx://acct/users/user/memories/events"
    old_uri = f"{dir_uri}/20240501000000_old"
    new_uri = f"{dir_uri}/20240601000000_new"
    # Children are registered oldest-first, so any newest-first ordering
    # has to come from the prefetcher itself.
    fs.children[dir_uri] = [old_uri, new_uri]
    fs.nodes[old_uri] = _node(old_uri, "event")
    fs.nodes[new_uri] = _node(new_uri, "event")

    result = _prefetcher(fs, vector, embedder).prefetch("event", _ctx())

    assert result.read_uris == {old_uri, new_uri}
    new_marker = f"URI: {new_uri}"
    old_marker = f"URI: {old_uri}"
    assert new_marker in result.messages[0]

    def _first_position(marker):
        """Return (message index, offset) of the first occurrence of *marker*, or None."""
        for index, message in enumerate(result.messages):
            offset = message.find(marker)
            if offset != -1:
                return index, offset
        return None

    # The test name promises newest-first ordering. Check it whether the
    # entries share one message or are spread over several.
    old_position = _first_position(old_marker)
    if old_position is not None:
        assert _first_position(new_marker) < old_position
def test_prefetch_ignores_incompatible_schema_versions():
    """A schema registered with version "2.0" is skipped: no embedding, no search, no messages."""
    from core.uri_resolver import URIResolver

    routing_field = SchemaField(
        name="routing_key",
        field_type=FieldType.STRING,
        required=True,
    )
    future_schema = MemoryTypeSchema(
        memory_type="future",
        description="Future schema",
        directory="future",
        filename_template="{{ routing_key }}.md",
        operation_mode="upsert",
        version="2.0",
        fields=[routing_field],
    )
    registry = SchemaRegistry(schemas_dir="/path/that/does/not/exist")
    registry.register(future_schema)

    fs, vector, embedder = FakeFS(), FakeVectorIndex(), FakeEmbedder()
    prefetcher = MemoryPrefetcher(fs, vector, embedder, registry, URIResolver(registry))
    result = prefetcher.prefetch("future", _ctx(), conversation_text="future")

    assert result.messages == []
    assert embedder.queries == []
    assert vector.filters is None
def test_prefetch_unknown_and_unknown_operation_modes_return_empty():
    """An unregistered category and a schema with an unhandled operation mode both yield no messages."""
    # NOTE(review): the test name repeats "unknown". It covers an unknown
    # *category* ("missing") and an unknown *operation mode* ("replace");
    # consider renaming to say so.
    from core.uri_resolver import URIResolver

    registry = SchemaRegistry(schemas_dir="/path/that/does/not/exist")
    registry.register(
        MemoryTypeSchema(
            memory_type="weird",
            description="Weird schema",
            directory="weird",
            filename_template="{{ routing_key }}.md",
            operation_mode="replace",
            fields=[],
        )
    )
    prefetcher = MemoryPrefetcher(
        FakeFS(), FakeVectorIndex(), FakeEmbedder(), registry, URIResolver(registry)
    )

    for category in ("missing", "weird"):
        assert prefetcher.prefetch(category, _ctx()).messages == []
def test_prefetch_for_span_combines_results_and_continues_after_error(monkeypatch):
    """A category whose prefetch raises must not stop the remaining categories in the span."""
    profile_uri = "ctx://acct/users/user/memories/profile"
    fs = FakeFS()
    fs.nodes[profile_uri] = _node(profile_uri, "profile")
    prefetcher = _prefetcher(fs, FakeVectorIndex(), FakeEmbedder())
    real_prefetch = prefetcher.prefetch

    def flaky_prefetch(category, ctx, *, conversation_text=None):
        # Simulate a hard failure for exactly one category.
        if category != "broken":
            return real_prefetch(category, ctx, conversation_text=conversation_text)
        raise RuntimeError("broken")

    # Patch the instance attribute so calls made through ``prefetcher.prefetch`` hit the stub.
    monkeypatch.setattr(prefetcher, "prefetch", flaky_prefetch)
    result = prefetcher.prefetch_for_span(["broken", "profile"], _ctx())

    assert profile_uri in result.read_uris
    assert result.messages
def test_prefetch_single_file_handles_missing_and_errors():
    """Single-file prefetch yields no messages when the node is absent or ``exists`` raises."""
    vector, embedder = FakeVectorIndex(), FakeEmbedder()
    outcomes = {
        "missing node": _prefetcher(FakeFS(), vector, embedder).prefetch("profile", _ctx()),
        "exists raises": _prefetcher(FailingFS(fail_exists=True), vector, embedder).prefetch(
            "profile", _ctx()
        ),
    }
    for label, outcome in outcomes.items():
        assert outcome.messages == [], label
def test_prefetch_multi_file_handles_no_hits_search_failure_and_read_failure():
    """No hits and a failing search yield nothing; a failing node read is reported in the message."""
    hit_uri = "ctx://acct/users/user/memories/preferences/coffee/content.md"
    fs = FakeFS()

    no_hits = _prefetcher(fs, FakeVectorIndex(), FakeEmbedder()).prefetch("preference", _ctx())
    failing_search = _prefetcher(fs, FailingVectorIndex(), FakeEmbedder()).prefetch(
        "preference", _ctx()
    )

    seeded_index = FakeVectorIndex()
    seeded_index.hits = [
        SeedHit(uri=hit_uri, score=0.91, category="preference", abstract="coffee", level=2),
    ]
    read_failed = _prefetcher(FailingFS(fail_read=True), seeded_index, FakeEmbedder()).prefetch(
        "preference", _ctx()
    )

    assert no_hits.messages == []
    assert failing_search.messages == []
    assert "read failed" in read_failed.messages[0]
def test_prefetch_multi_file_uses_agent_owner_filter():
    """The ``tool`` category is searched within the agent's owner space, not the user's."""
    fs = FakeFS()
    vector = FakeVectorIndex()
    embedder = FakeEmbedder()

    _prefetcher(fs, vector, embedder).prefetch("tool", _ctx(), conversation_text="tool usage")

    # Fail with a readable assertion rather than a TypeError on ``None[...]``
    # if the prefetcher never queried the vector index.
    assert vector.filters is not None, "prefetch did not search the vector index"
    assert vector.filters["owner_space"] == "agent:agent"
def test_prefetch_add_only_handles_empty_read_failure_and_list_failure():
    """Add-only prefetch yields no messages for an empty directory, a failing read, or a failing listing."""
    fs = FakeFS()
    vector = FakeVectorIndex()
    embedder = FakeEmbedder()
    empty = _prefetcher(fs, vector, embedder).prefetch("event", _ctx())

    dir_uri = "ctx://acct/users/user/memories/events"
    event_uri = f"{dir_uri}/20240501000000_event"
    read_fail_fs = FailingFS(fail_read=True)
    read_fail_fs.children[dir_uri] = [event_uri]
    # Also register the node. If the prefetcher checks exists() before
    # reading, the entry would otherwise be skipped, and the read_node()
    # failure this case is meant to cover would never be exercised.
    read_fail_fs.nodes[event_uri] = _node(event_uri, "event")
    read_failed = _prefetcher(read_fail_fs, vector, embedder).prefetch("event", _ctx())

    list_failed = _prefetcher(FailingFS(fail_list=True), vector, embedder).prefetch("event", _ctx())

    assert empty.messages == []
    assert read_failed.messages == []
    assert list_failed.messages == []