"""ReAct loop for memory extraction with tool use.
Implements iterative tool-assisted extraction with safety checks.
"""
from __future__ import annotations
import json
import time
from dataclasses import dataclass, field
from logging import getLogger
from typing import Any
from uuid import uuid4
from core.interfaces import ContextFS, LLM
from core.models import CandidateMemory, RequestContext
from core.uri_resolver import URIResolver
from extraction.prefetch import MemoryPrefetcher, PrefetchResult
from extraction.schemas.registry import SchemaRegistry
from extraction.tool_builder import build_extraction_tools, parse_tool_call
logger = getLogger(__name__)
_READ_TOOL = {
"name": "read",
"description": "Read a memory node by URI.",
"input_schema": {
"type": "object",
"properties": {
"uri": {
"type": "string",
"description": "URI of the memory node to read.",
}
},
"required": ["uri"],
},
}
_LIST_TOOL = {
"name": "list",
"description": "List child nodes under a directory URI.",
"input_schema": {
"type": "object",
"properties": {
"uri": {
"type": "string",
"description": "URI of the directory to list.",
}
},
"required": ["uri"],
},
}
_RELATIONS_TOOL = {
"name": "get_relations",
"description": "Get related nodes for a URI. Returns one-hop neighbors.",
"input_schema": {
"type": "object",
"properties": {"uri": {"type": "string"}},
"required": ["uri"],
},
}
_ACCESS_STATS_TOOL = {
"name": "get_access_stats",
"description": "Get access stats for a node: last access time, 30d hit count.",
"input_schema": {
"type": "object",
"properties": {"uri": {"type": "string"}},
"required": ["uri"],
},
}
@dataclass
class ReActIterationTrace:
    """Trace data for a single ReAct iteration.

    One instance is built per loop iteration right after the LLM call and
    appended to ``ReActTrace.iterations``.
    """

    # 1-based iteration number within the run.
    iteration: int
    # Seconds elapsed between the start of the iteration and the return of
    # the LLM call (tool execution time is not included).
    latency_seconds: float
    # Number of tool calls the LLM returned in this iteration.
    tool_calls_count: int = 0
    # Names of those tool calls ("" for a call without a name).
    tool_call_names: list[str] = field(default_factory=list)
    # Length of the textual content returned by the LLM (0 when absent).
    content_length: int = 0
    # True when the unread-existing-file safety check forced a refetch.
    safety_check_triggered: bool = False
    # URIs that were refetched because of the safety check.
    refetch_uris: list[str] = field(default_factory=list)
@dataclass
class ReActTrace:
    """Full trace for a ReAct extraction run."""

    # Value of the loop's iteration counter when the run ended.
    total_iterations: int = 0
    # Wall-clock duration of the whole run, in seconds.
    total_latency_seconds: float = 0.0
    # Per-iteration trace records, in execution order.
    iterations: list[ReActIterationTrace] = field(default_factory=list)
    # Number of candidates returned by the run.
    final_candidate_count: int = 0
    # Number of distinct URIs considered "read" at the end of the run
    # (prefetched, read via tool call, or safety-refetched).
    total_read_uris: int = 0
@dataclass
class ReActResult:
    """Result of a ReAct extraction loop."""

    # Candidates parsed from the LLM's final textual response.
    candidates: list[CandidateMemory] = field(default_factory=list)
    # One record per executed tool call: tool_name, params and result.
    tools_used: list[dict] = field(default_factory=list)
    # Number of loop iterations reported by the run.
    iterations: int = 0
    # URIs considered "read" during the run (prefetch, tool reads, refetch).
    read_uris: set[str] = field(default_factory=set)
    # Detailed timing/diagnostic trace; None when no trace was collected.
    # Spelled as a PEP 604 union because ``Optional`` is not imported in
    # this module, which made the annotation unresolvable.
    trace: ReActTrace | None = None
class ExtractionReActLoop:
    """ReAct loop for memory extraction with tool use.

    Workflow:
        1. Start with system prompt + prefetch context + conversation.
        2. While iteration < max_iterations:
            a. Call the LLM with tools (read, list, get_relations,
               get_access_stats, extract_*).
            b. If tool calls are returned: execute them and continue.
            c. If content is returned: parse operations -> safety check ->
               return.
            d. If neither: disable tools and continue.
        3. Return extracted candidates with metadata.

    Safety features:
        - Refetch check: existing files targeted by a candidate that were
          not read yet are read and shown to the LLM before its operations
          are accepted.
        - URI validation: every candidate URI is checked against the allowed
          namespace; violations are logged as errors.
        - Single safety re-read: the refetch check fires at most once per
          run, to prevent infinite loops.
    """
def __init__(
    self,
    llm: LLM,
    fs: ContextFS,
    registry: SchemaRegistry,
    uri_resolver: URIResolver,
    prefetcher: MemoryPrefetcher | None = None,
    max_iterations: int = 3,
    timeout_seconds: float = 30.0,
    internal_tool_usage_tracker: Any | None = None,
):
    """Initialize the ReAct loop.

    Args:
        llm: LLM instance for completion with tools
        fs: ContextFS for memory node operations
        registry: SchemaRegistry for schema definitions
        uri_resolver: URIResolver for URI generation and validation
        prefetcher: Optional MemoryPrefetcher for context loading
        max_iterations: Maximum number of ReAct iterations
        timeout_seconds: Maximum seconds to run before graceful degradation
        internal_tool_usage_tracker: Optional object receiving
            ``record_round`` / ``record_tool_call`` notifications; tracking
            is skipped when it is falsy
    """
    self._llm = llm
    self._fs = fs
    self._registry = registry
    self._uri_resolver = uri_resolver
    self._prefetcher = prefetcher
    self._max_iterations = max_iterations
    self._timeout_seconds = timeout_seconds
    self._internal_tool_usage_tracker = internal_tool_usage_tracker
    self._extraction_tools = build_extraction_tools(registry)
    self._all_tools = [
        _READ_TOOL,
        _LIST_TOOL,
        _RELATIONS_TOOL,
        _ACCESS_STATS_TOOL,
    ] + self._extraction_tools

    # Per-run state. run() resets these, but they are given real values here
    # (a bare annotation creates no attribute) so that helper methods do not
    # raise AttributeError when used before the first run().
    self._read_files: set[str] = set()
    self._disable_tools_for_iteration: bool = False
    self._did_safety_reread: bool = False
    self._run_start_time: float = 0.0
def run(
    self,
    conversation_text: str,
    ctx: RequestContext,
    prefetch_result: PrefetchResult | None = None,
) -> ReActResult:
    """Run the ReAct loop for memory extraction.

    Per-run state (read URIs, tool-disable flag, safety re-read flag) is
    stored on the instance and reset at the start of every call.

    Candidates are taken only from textual LLM content parsed by
    ``_parse_operations``; ``extract_*`` tool calls are executed as no-op
    "recorded" acknowledgements and only show up in ``tools_used``.

    The iteration budget can grow past ``max_iterations`` by at most two:
    once to let the LLM answer after a tool call on the last iteration
    (tools are disabled for that extra turn) and once for the safety
    refetch.

    Args:
        conversation_text: Text content of the conversation to extract from
        ctx: RequestContext for multi-tenant isolation
        prefetch_result: Optional prefetch result with pre-loaded context

    Returns:
        ReActResult with extracted candidates and metadata. ``candidates``
        is empty when the loop ends without parseable operations or hits
        the timeout.
    """
    self._read_files = set()
    self._disable_tools_for_iteration = False
    self._did_safety_reread = False
    if prefetch_result:
        # Prefetched nodes count as already read for the safety check.
        self._read_files.update(prefetch_result.read_uris)

    messages = self._build_initial_messages(conversation_text, prefetch_result, ctx)
    trace = ReActTrace()
    run_start = time.monotonic()
    self._run_start_time = run_start

    iteration = 0
    max_iterations = self._max_iterations
    # The tool-call extension may be granted only once; otherwise an LLM
    # that keeps returning tool calls would extend the loop indefinitely.
    extended_for_tools = False
    final_candidates = []
    tools_used = []

    while iteration < max_iterations:
        iteration += 1
        iter_start = time.monotonic()
        logger.info(f"ReAct iteration {iteration}/{max_iterations}")

        # The first iteration always runs, even with a tiny timeout.
        elapsed = time.monotonic() - run_start
        if elapsed > self._timeout_seconds and iteration > 1:
            logger.warning(
                f"ReAct timeout ({elapsed:.1f}s > {self._timeout_seconds}s) "
                f"at iteration {iteration}, returning current results"
            )
            # This iteration never ran; do not count it in the result.
            iteration -= 1
            break

        is_last_iteration = iteration >= max_iterations
        if is_last_iteration:
            messages.append({
                "role": "user",
                "content": (
                    "You have reached the maximum number of tool call iterations. "
                    "Do not call any more tools - return your final result directly now."
                ),
            })

        tool_calls, content = self._call_llm(messages)
        iter_latency = time.monotonic() - iter_start
        iter_trace = ReActIterationTrace(
            iteration=iteration,
            latency_seconds=iter_latency,
            tool_calls_count=len(tool_calls) if tool_calls else 0,
            tool_call_names=[tc.get("name", "") for tc in tool_calls] if tool_calls else [],
            content_length=len(content) if content else 0,
        )

        if tool_calls:
            round_id = self._record_round(tool_calls, ctx)
            has_unknown = self._execute_tool_calls(
                messages, tool_calls, tools_used, ctx, round_id
            )
            if has_unknown:
                self._disable_tools_for_iteration = True
                logger.info("Unknown tool called, disabling tools for next iteration")
            if iteration >= max_iterations:
                if extended_for_tools:
                    logger.warning(
                        "LLM returned tool calls again after the final extension; stopping"
                    )
                else:
                    # Give the LLM one tool-free turn to produce its answer
                    # from the tool results it just received.
                    extended_for_tools = True
                    max_iterations += 1
                    self._disable_tools_for_iteration = True
                    logger.info(f"Extended max_iterations to {max_iterations} for tool call")
            trace.iterations.append(iter_trace)
            continue

        if content:
            candidates = self._parse_operations(content)
            if candidates:
                refetch_uris = self._check_unread_existing_files(candidates, ctx)
                if refetch_uris:
                    logger.info(f"Found unread existing files: {refetch_uris}, refetching...")
                    iter_trace.safety_check_triggered = True
                    iter_trace.refetch_uris = refetch_uris
                    self._add_refetch_results(messages, refetch_uris, ctx)
                    # Bounded: the refetch check fires at most once per run.
                    if iteration >= max_iterations:
                        max_iterations += 1
                        logger.info(f"Extended max_iterations to {max_iterations} for refetch")
                    trace.iterations.append(iter_trace)
                    continue
                # Validation failures are reported but do not drop candidates.
                errors = self._validate_operations_uris(candidates, ctx)
                if errors:
                    logger.error(f"URI validation errors: {errors}")
                trace.iterations.append(iter_trace)
                final_candidates = candidates
                break

        # Reached when there is no content, or content without operations.
        trace.iterations.append(iter_trace)
        if is_last_iteration:
            logger.warning("Last iteration: returning empty candidates")
            break
        if content:
            logger.warning(
                f"LLM content contained no parseable operations (iteration {iteration}/{max_iterations})"
            )
        else:
            logger.warning(
                f"LLM returned neither tool calls nor content (iteration {iteration}/{max_iterations})"
            )
        self._disable_tools_for_iteration = True

    trace.total_iterations = iteration
    trace.total_latency_seconds = time.monotonic() - run_start
    trace.final_candidate_count = len(final_candidates)
    trace.total_read_uris = len(self._read_files)
    return ReActResult(
        candidates=final_candidates,
        tools_used=tools_used,
        iterations=iteration,
        read_uris=self._read_files.copy(),
        trace=trace,
    )
def _build_initial_messages(
self,
conversation_text: str,
prefetch_result: PrefetchResult | None,
ctx: RequestContext,
) -> list[dict]:
"""Build initial messages for the LLM.
Args:
conversation_text: Conversation text to extract from
prefetch_result: Optional prefetch result
ctx: RequestContext
Returns:
List of messages in OpenAI format
"""
messages = []
system_prompt = self._build_system_prompt(ctx)
messages.append({"role": "system", "content": system_prompt})
if prefetch_result and prefetch_result.messages:
for msg in prefetch_result.messages:
messages.append({"role": "system", "content": msg})
messages.append({
"role": "user",
"content": f"## Conversation\n\n{conversation_text}",
})
return messages
def _build_system_prompt(self, ctx: RequestContext) -> str:
"""Build system prompt for extraction.
Args:
ctx: RequestContext
Returns:
System prompt string
"""
tool_descriptions = []
for tool in self._all_tools:
desc = f"- {tool['name']}: {tool['description']}"
if tool.get("input_schema", {}).get("properties"):
params = ", ".join(tool["input_schema"]["properties"].keys())
desc += f" (params: {params})"
tool_descriptions.append(desc)
prompt = f"""# Memory Extraction Task
You are a memory extraction assistant. Your task is to extract structured memories from the conversation.
## Available Tools
{chr(10).join(tool_descriptions)}
## Instructions
1. Use `read` and `list` tools to explore existing memories before making decisions.
2. Use `extract_*` tools to create new memories or update existing ones.
3. Always read existing content before deciding to merge or create new entries.
4. Return your final result using the extraction tool calls in your response.
## Context
Account: {ctx.account_id}
User: {ctx.user_id}
Agent: {ctx.agent_id}
"""
return prompt
def _call_llm(
self, messages: list[dict]
) -> tuple[list[dict] | None, str | None]:
"""Call LLM with tools.
Args:
messages: Current message list
Returns:
Tuple of (tool_calls, content) - one will be None
"""
tools = None
tool_choice = "none"
if not self._disable_tools_for_iteration:
tools = self._all_tools
tool_choice = "auto"
try:
result = self._llm.complete_with_tools_messages(
messages=messages,
tools=tools,
tool_choice=tool_choice,
)
if isinstance(result, tuple) and len(result) == 2:
return result
if isinstance(result, list):
return result, None
if isinstance(result, str):
return None, result
return None, None
except Exception as e:
logger.error(f"LLM call failed: {e}")
return None, None
def _execute_tool_calls(
self,
messages: list[dict],
tool_calls: list[dict],
tools_used: list[dict],
ctx: RequestContext,
round_id: str,
) -> bool:
"""Execute tool calls and add results to messages.
Args:
messages: Message list to append to
tool_calls: List of tool call dicts
tools_used: List to track tool usage
ctx: RequestContext
Returns:
True if any unknown tool was encountered
"""
has_unknown = False
for tool_call in tool_calls:
tool_name = tool_call.get("name", "")
tool_input = tool_call.get("input", {})
call_id = tool_call.get("id", "unknown")
started = time.monotonic()
result = self._execute_single_tool(tool_name, tool_input, ctx)
duration_ms = (time.monotonic() - started) * 1000
has_error = isinstance(result, dict) and bool(result.get("error"))
if isinstance(result, dict) and result.get("error", "").startswith("Unknown tool"):
has_unknown = True
self._record_tool_call(
ctx=ctx,
round_id=round_id,
tool_name=tool_name,
status="error" if has_error else "success",
duration_ms=duration_ms,
error_type=str(result.get("error", "")) if isinstance(result, dict) else "",
)
if tool_name == "read" and tool_input.get("uri"):
self._read_files.add(tool_input["uri"])
tools_used.append({
"tool_name": tool_name,
"params": tool_input,
"result": result,
})
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": json.dumps(tool_input),
},
}],
})
messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": json.dumps(result) if isinstance(result, dict) else str(result),
})
return has_unknown
def _record_round(self, tool_calls: list[dict], ctx: RequestContext) -> str:
    """Derive the round id for a batch of tool calls and report the round.

    The round is reported to the usage tracker only when one is configured;
    the round id is returned either way.

    Args:
        tool_calls: Tool call dicts returned by the LLM for this round
        ctx: RequestContext supplying account, user and session ids

    Returns:
        Round identifier derived from the tool calls
    """
    round_id = _round_id_from_tool_calls(tool_calls)
    tracker = self._internal_tool_usage_tracker
    if tracker:
        usage = _usage_from_tool_calls(tool_calls)
        # A call may carry its name under "name" or "tool"; unnamed calls
        # are left out of the report.
        called_names = []
        for call in tool_calls:
            label = call.get("name") or call.get("tool")
            if label:
                called_names.append(label)
        tracker.record_round(
            account_id=ctx.account_id,
            user_id=ctx.user_id,
            session_id=ctx.session_id,
            pipeline="extraction.lazy",
            round_id=round_id,
            tool_names=called_names,
            model=usage.get("model", ""),
            input_tokens=usage.get("input_tokens", 0),
            output_tokens=usage.get("output_tokens", 0),
            cache_read=usage.get("cache_read", 0),
            cache_write=usage.get("cache_write", 0),
        )
    return round_id
def _record_tool_call(
self,
*,
ctx: RequestContext,
round_id: str,
tool_name: str,
status: str,
duration_ms: float,
error_type: str = "",
) -> None:
if not self._internal_tool_usage_tracker or not tool_name:
return
self._internal_tool_usage_tracker.record_tool_call(
account_id=ctx.account_id,
session_id=ctx.session_id,
pipeline="extraction.lazy",
round_id=round_id,
tool_name=tool_name,
status=status,
duration_ms=duration_ms,
error_type=error_type,
)
def _execute_single_tool(
self, tool_name: str, tool_input: dict, ctx: RequestContext
) -> dict | str:
"""Execute a single tool call.
Args:
tool_name: Name of the tool to execute
tool_input: Input parameters for the tool
ctx: RequestContext
Returns:
Tool result as dict or string
"""
try:
if tool_name == "read":
uri = tool_input.get("uri", "")
if not uri:
return {"error": "Missing required parameter: uri"}
node = self._fs.read_node(uri, ctx)
return {
"uri": uri,
"abstract": node.abstract,
"overview": node.overview,
"content": node.content,
"metadata": node.metadata,
}
elif tool_name == "list":
uri = tool_input.get("uri", "")
if not uri:
return {"error": "Missing required parameter: uri"}
children = self._fs.list_children(uri, ctx)
return {
"uri": uri,
"children": children,
}
elif tool_name == "get_relations":
uri = tool_input.get("uri", "")
return self._execute_get_relations(uri, ctx)
elif tool_name == "get_access_stats":
uri = tool_input.get("uri", "")
return self._execute_get_stats(uri, ctx)
elif tool_name.startswith("extract_"):
return {
"tool": tool_name,
"input": tool_input,
"status": "recorded",
}
else:
return {"error": f"Unknown tool: {tool_name}"}
except Exception as e:
logger.error(f"Tool execution failed for {tool_name}: {e}")
return {"error": str(e)}
def _parse_operations(self, content: str) -> list[CandidateMemory] | None:
"""Parse extraction operations from LLM content.
Args:
content: LLM response content
Returns:
List of CandidateMemory or None if parsing fails
"""
try:
data = json.loads(content)
candidates = []
if isinstance(data, list):
items = data
elif isinstance(data, dict):
if "tool_calls" in data:
items = data["tool_calls"]
elif "operations" in data:
items = data["operations"]
else:
items = [data]
else:
logger.warning(f"Unexpected JSON structure: {type(data)}")
return None
for item in items:
if not isinstance(item, dict):
continue
tool_name = item.get("name", "")
tool_input = item.get("input", item)
if not tool_name.startswith("extract_"):
for key in item.keys():
if key.startswith("extract_"):
tool_name = key
tool_input = item[key]
break
if not tool_name.startswith("extract_"):
continue
parsed = parse_tool_call(tool_name, tool_input, self._registry)
if parsed:
memory_type, owner_scope, candidate = parsed
candidates.append(candidate)
logger.info(f"Parsed candidate: {memory_type}/{candidate.routing_key}")
return candidates if candidates else None
except json.JSONDecodeError:
logger.warning("Failed to parse content as JSON")
return None
except Exception as e:
logger.error(f"Error parsing operations: {e}")
return None
def _check_unread_existing_files(
self, candidates: list[CandidateMemory], ctx: RequestContext
) -> list[str]:
"""Check if candidates target existing files that weren't read.
Args:
candidates: List of CandidateMemory to check
ctx: RequestContext
Returns:
List of URIs that need to be refetched
"""
if self._did_safety_reread:
return []
refetch_uris = []
for candidate in candidates:
schema = self._registry.get(candidate.category)
if schema is None:
continue
if schema.is_add_only:
continue
fields = self._candidate_to_fields(candidate)
try:
uri = self._uri_resolver.resolve(candidate.category, fields, ctx)
except Exception as e:
logger.warning(f"Failed to resolve URI for {candidate.category}: {e}")
continue
if uri in self._read_files:
continue
if self._fs.exists(uri, ctx):
refetch_uris.append(uri)
return refetch_uris
def _add_refetch_results(
self, messages: list[dict], refetch_uris: list[str], ctx: RequestContext
) -> None:
"""Add refetched file contents to messages.
Args:
messages: Message list to append to
refetch_uris: URIs to refetch
ctx: RequestContext
"""
self._did_safety_reread = True
for uri in refetch_uris:
try:
node = self._fs.read_node(uri, ctx)
self._read_files.add(uri)
result = {
"uri": uri,
"abstract": node.abstract,
"overview": node.overview,
"content": node.content,
"metadata": node.metadata,
}
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": f"refetch_{uri}",
"type": "function",
"function": {
"name": "read",
"arguments": json.dumps({"uri": uri}),
},
}],
})
messages.append({
"role": "tool",
"tool_call_id": f"refetch_{uri}",
"content": json.dumps(result),
})
except Exception as e:
logger.warning(f"Failed to refetch {uri}: {e}")
messages.append({
"role": "user",
"content": (
"SAFETY REMINDER: The files above were automatically read because they "
"exist and you didn't read them before deciding to write. Please consider "
"the existing content when making write decisions. You can now output "
"updated operations."
),
})
def _validate_operations_uris(
self, candidates: list[CandidateMemory], ctx: RequestContext
) -> list[str]:
"""Validate that all candidate URIs are within allowed namespace.
Args:
candidates: List of CandidateMemory to validate
ctx: RequestContext
Returns:
List of error strings (empty if all valid)
"""
errors = []
for candidate in candidates:
fields = self._candidate_to_fields(candidate)
try:
uri = self._uri_resolver.resolve(candidate.category, fields, ctx)
if not self._uri_resolver.validate_uri(uri, ctx):
errors.append(
f"Invalid URI for {candidate.category}/{candidate.routing_key}: {uri}"
)
except Exception as e:
errors.append(
f"Failed to validate URI for {candidate.category}/{candidate.routing_key}: {e}"
)
return errors
def _candidate_to_fields(self, candidate: CandidateMemory) -> dict:
"""Convert CandidateMemory to fields dict for URI resolution.
Args:
candidate: CandidateMemory to convert
Returns:
Dict of field names to values
"""
fields = {
"routing_key": candidate.routing_key,
"abstract": candidate.abstract,
"overview": candidate.overview,
"content": candidate.content,
"confidence": candidate.confidence,
}
if candidate.when:
fields["when"] = candidate.when
if candidate.who:
fields["who"] = candidate.who
if candidate.where:
fields["where"] = candidate.where
if candidate.tool_stats:
fields.update(candidate.tool_stats)
return fields
def _execute_get_relations(self, uri: str, ctx: RequestContext) -> dict:
"""Execute get_relations tool call.
Args:
uri: URI to get relations for
ctx: RequestContext
Returns:
Dict with uri and relations list
"""
try:
if hasattr(self._fs, 'get_relations'):
edges = self._fs.get_relations(uri, ctx)
return {"uri": uri, "relations": [
{"target": e.target_uri, "weight": e.weight, "reason": e.reason}
for e in edges
]}
return {"uri": uri, "relations": [], "note": "Not available"}
except Exception as e:
return {"error": str(e)}
def _execute_get_stats(self, uri: str, ctx: RequestContext) -> dict:
"""Execute get_access_stats tool call.
Args:
uri: URI to get access stats for
ctx: RequestContext
Returns:
Dict with uri, last_accessed_at, and hit_count_30d
"""
try:
node = self._fs.read_node(uri, ctx)
metadata = node.metadata or {}
return {
"uri": uri,
"last_accessed_at": metadata.get("last_accessed_at"),
"hit_count_30d": metadata.get("hit_count_30d", 0),
}
except Exception as e:
return {"error": str(e)}
def _usage_from_tool_calls(tool_calls: list[dict]) -> dict:
for call in tool_calls or []:
usage = call.get("_llm_usage")
if isinstance(usage, dict):
return usage
return {}
def _round_id_from_tool_calls(tool_calls: list[dict]) -> str:
    """Pick an identifier for the LLM round that produced the tool calls.

    Preference order: the ``round_id`` of the attached LLM usage, then the
    first truthy tool call ``id``, then a freshly generated ``llm_``-prefixed
    random id.

    Args:
        tool_calls: Tool call dicts; ``None`` or an empty list is accepted

    Returns:
        Round identifier as a string
    """
    usage_round = _usage_from_tool_calls(tool_calls).get("round_id")
    if usage_round:
        return str(usage_round)
    first_call_id = next(
        (call["id"] for call in tool_calls or [] if call.get("id")),
        None,
    )
    if first_call_id is not None:
        return str(first_call_id)
    return f"llm_{uuid4().hex[:12]}"