"""Unit tests for ExtractionReActLoop — core loop logic only.
Tests the loop control flow, not the individual tool implementations
or JSON parsing (those belong to their own modules).
"""
import json
from unittest.mock import Mock, patch
import pytest
from core.models import CandidateMemory, ContextNode, RequestContext
from extraction.prefetch import PrefetchResult
from extraction.react_loop import ExtractionReActLoop, ReActResult, ReActTrace
from server.internal_tool_usage import InternalToolUsageTracker
_PARSE_PATCH = patch(
"extraction.react_loop.parse_tool_call",
side_effect=lambda name, inp, _reg: (
(cat, inp.get("owner_scope", "user"), CandidateMemory(
category=cat,
owner_scope=inp.get("owner_scope", "user"),
routing_key=inp.get("routing_key", "k"),
abstract=inp.get("abstract", ""),
overview=inp.get("overview", ""),
content=inp.get("content", ""),
confidence=inp.get("confidence", 0.9),
)) if (name.startswith("extract_")) and (cat := name.replace("extract_", ""))
else None
),
)
@pytest.fixture(autouse=True)
def _mock_parse():
    """Keep ``_PARSE_PATCH`` active for the duration of every test in this module.

    The patch is started before the test body and stopped at fixture teardown.
    """
    _PARSE_PATCH.start()
    try:
        yield
    finally:
        _PARSE_PATCH.stop()
def _ctx():
    """Return a fixed RequestContext (account "acme", user "alice", session "s1")."""
    fields = dict(
        account_id="acme",
        user_id="alice",
        agent_id="bob",
        session_id="s1",
        trace_id="t1",
    )
    return RequestContext(**fields)
def _node(content="hello", metadata=None):
    """Return a level-0 "profile" MEMORY ContextNode at ``ctx://acme/x``.

    *content* becomes the node body. A falsy *metadata* is replaced with a
    fresh empty dict.
    """
    if not metadata:
        metadata = {}
    return ContextNode(
        uri="ctx://acme/x",
        context_type="MEMORY",
        category="profile",
        level=0,
        owner_space="user:alice",
        abstract="a",
        overview="o",
        content=content,
        metadata=metadata,
    )
def _extract_input(routing_key="k", **kw):
"""Shorthand for extract_profile tool input."""
return {"routing_key": routing_key, "abstract": "a", "overview": "o",
"content": "c", "confidence": 0.9, "owner_scope": "user", **kw}
def _extract_json(**kw):
    """Serialize a one-element operation list holding a single extract_profile call.

    Keyword arguments are forwarded to ``_extract_input``.
    """
    operation = {"name": "extract_profile", "input": _extract_input(**kw)}
    return json.dumps([operation])
def _make_llm():
llm = Mock()
llm._queue = []
def _call(messages, tools=None, tool_choice="auto"):
return llm._queue.pop(0) if llm._queue else ([], "")
llm.complete_with_tools_messages = _call
return llm
def _make_fs(exists=False, content="hello"):
    """Build a fake filesystem.

    ``read_node`` returns a ``_node`` carrying *content*, ``list_children``
    returns an empty list, and ``exists`` returns *exists*.
    """
    return Mock(
        read_node=Mock(return_value=_node(content=content)),
        list_children=Mock(return_value=[]),
        exists=Mock(return_value=exists),
    )
def _make_registry(is_add_only=False):
schema = Mock(is_add_only=is_add_only)
reg = Mock()
reg.get = Mock(return_value=schema)
reg.list_enabled = Mock(return_value=[])
return reg
def _make_uri_resolver():
r = Mock()
r.resolve = Mock(return_value="ctx://acme/test/uri")
r.validate_uri = Mock(return_value=True)
return r
def _make_loop(llm=None, fs=None, *, registry=None, uri_resolver=None,
               max_iterations=3, timeout_seconds=30.0,
               internal_tool_usage_tracker=None):
    """Build an ``ExtractionReActLoop`` wired to fakes, overriding any collaborator given.

    Options are spelled out as keyword-only parameters instead of being
    swallowed by ``**kw``. A misspelled option (e.g. ``max_iteration=5``)
    therefore raises ``TypeError`` rather than silently running the test with
    default settings. Default fakes are only constructed when the caller did
    not supply one.
    """
    return ExtractionReActLoop(
        llm=llm if llm is not None else _make_llm(),
        fs=fs if fs is not None else _make_fs(),
        registry=registry if registry is not None else _make_registry(),
        uri_resolver=(uri_resolver if uri_resolver is not None
                      else _make_uri_resolver()),
        max_iterations=max_iterations,
        timeout_seconds=timeout_seconds,
        internal_tool_usage_tracker=internal_tool_usage_tracker,
    )
class TestLoopControl:
    """Iteration counting, termination, prefetch merging and per-run state reset."""

    def test_tool_call_then_content(self):
        """Round one issues a read; round two returns extraction JSON."""
        llm = _make_llm()
        llm._queue.extend([
            ([{"name": "read", "input": {"uri": "ctx://acme/p"}, "id": "1"}], ""),
            ([], _extract_json(routing_key="java-dev")),
        ])
        loop = _make_loop(llm=llm)
        result = loop.run("I switched to Java", _ctx())

        assert len(result.candidates) == 1
        assert result.candidates[0].routing_key == "java-dev"
        assert result.iterations == 2
        assert "ctx://acme/p" in result.read_uris

    def test_max_iterations_exhausted(self):
        """Endless tool calls hit the iteration cap and yield no candidates."""
        llm = _make_llm()
        # A fresh tuple/list/dict per entry, so no queued call shares state.
        llm._queue.extend(
            ([{"name": "read", "input": {"uri": "ctx://x"}, "id": "1"}], "")
            for _ in range(5)
        )
        result = _make_loop(llm=llm, max_iterations=3).run("conv", _ctx())

        assert result.candidates == []
        assert result.iterations >= 3

    def test_state_reset_between_runs(self):
        """URIs read during one run() must not appear in the next run's result."""
        llm = _make_llm()
        llm._queue.extend([
            ([{"name": "read", "input": {"uri": "ctx://a"}, "id": "1"}], ""),
            ([], _extract_json()),
            ([], _extract_json(routing_key="k2")),
        ])
        loop = _make_loop(llm=llm)

        first = loop.run("first", _ctx())
        assert "ctx://a" in first.read_uris

        second = loop.run("second", _ctx())
        assert "ctx://a" not in second.read_uris

    def test_prefetch_uris_included(self):
        """URIs recorded in the PrefetchResult show up in the final read_uris."""
        prefetched_uri = "ctx://acme/p"
        llm = _make_llm()
        llm._queue.append(([], _extract_json()))
        prefetch = PrefetchResult(
            messages=["ctx"], read_uris={prefetched_uri}, listed_uris=set()
        )
        result = _make_loop(llm=llm).run("conv", _ctx(), prefetch)

        assert prefetched_uri in result.read_uris

    def test_internal_tool_usage_records_round_and_tool_tokens(self):
        """Token usage of internal ReAct tool calls is recorded under session "s1"."""
        usage = {
            "round_id": "round-read",
            "input_tokens": 100,
            "output_tokens": 20,
        }
        read_call = {
            "id": "call-read",
            "name": "read",
            "input": {"uri": "ctx://acme/p"},
            "_llm_usage": usage,
        }
        llm = _make_llm()
        llm._queue.extend([([read_call], ""), ([], "")])
        tracker = InternalToolUsageTracker()
        _make_loop(llm=llm, internal_tool_usage_tracker=tracker).run("conv", _ctx())

        stats = tracker.get_stats(session_id="s1", include_rounds=True)
        summary = stats["summary"]
        assert summary["llm_tool_rounds"] == 1
        assert summary["tool_calls"] == 1
        assert summary["total_tokens"] == 120
        read_stats = stats["tools"][0]
        assert read_stats["tool_name"] == "read"
        assert read_stats["attributed_tokens"] == 120
        assert read_stats["allocated_tokens"] == 120
        assert stats["rounds"][0]["round_id"] == "round-read"

    def test_content_present_but_parse_fails(self):
        """Non-JSON content is not parsed as operations, so tools get disabled (tool_choice="none")."""
        llm = _make_llm()
        choices = []

        def _respond(messages, tools=None, tool_choice="auto"):
            choices.append(tool_choice)
            # The length of the record doubles as a 1-based call counter.
            if len(choices) == 1:
                return [], "I think the user likes Python"
            return [], _extract_json()

        llm.complete_with_tools_messages = _respond
        result = _make_loop(llm=llm, max_iterations=5).run("conv", _ctx())

        assert len(result.candidates) == 1
        assert "none" in choices
class TestToolDisable:
    """Some responses make the next LLM call use tool_choice="none" (_disable_tools_for_iteration)."""

    def test_unknown_tool_disables_next(self):
        """An unknown tool name switches the following call to tool_choice="none"."""
        llm = _make_llm()
        choices = []

        def _respond(messages, tools=None, tool_choice="auto"):
            choices.append(tool_choice)
            if len(choices) == 1:
                return [{"name": "bogus", "input": {}, "id": "1"}], ""
            return [], _extract_json()

        llm.complete_with_tools_messages = _respond
        _make_loop(llm=llm).run("conv", _ctx())

        assert choices == ["auto", "none"]

    def test_empty_response_disables_next(self):
        """Empty responses (no tool calls, no content) eventually force tool_choice="none"."""
        llm = _make_llm()
        choices = []

        def _respond(messages, tools=None, tool_choice="auto"):
            choices.append(tool_choice)
            if len(choices) > 2:
                return [], _extract_json()
            return ([], "")

        llm.complete_with_tools_messages = _respond
        _make_loop(llm=llm, max_iterations=5).run("conv", _ctx())

        assert "none" in choices
class TestSafetyRefetch:
    """Safety re-read of existing-but-unread targets (_check_unread_existing_files, _did_safety_reread)."""

    def test_refetch_triggered_for_unread_existing(self):
        """A candidate aimed at an existing, never-read file triggers a safety iteration."""
        llm = _make_llm()
        llm._queue.extend([
            ([], _extract_json()),
            ([], _extract_json(routing_key="k2")),
        ])
        fs = _make_fs(exists=True, content="old data")
        result = _make_loop(llm=llm, fs=fs).run("conv", _ctx())

        assert any(step.safety_check_triggered for step in result.trace.iterations)

    def test_refetch_only_once(self):
        """The safety re-read fires at most once per run."""
        llm = _make_llm()
        llm._queue.extend(([], _extract_json()) for _ in range(4))
        fs = _make_fs(exists=True)
        trace = _make_loop(llm=llm, fs=fs, max_iterations=5).run("conv", _ctx()).trace

        triggered = [step for step in trace.iterations if step.safety_check_triggered]
        assert len(triggered) <= 1

    def test_add_only_schema_skips_refetch(self):
        """Add-only schemas cannot conflict, so no safety re-read happens."""
        llm = _make_llm()
        llm._queue.append(([], _extract_json()))
        loop = _make_loop(
            llm=llm,
            fs=_make_fs(exists=True),
            registry=_make_registry(is_add_only=True),
        )
        result = loop.run("conv", _ctx())

        assert not any(step.safety_check_triggered for step in result.trace.iterations)
class TestErrorHandling:
    """Failures in tools or in the LLM call must not crash run()."""

    def test_unknown_tool_returns_error(self):
        """An unknown tool is reported as an error entry and the loop keeps going."""
        llm = _make_llm()
        llm._queue.extend([
            ([{"name": "fly", "input": {}, "id": "1"}], ""),
            ([], _extract_json()),
        ])
        result = _make_loop(llm=llm).run("conv", _ctx())

        first_use = result.tools_used[0]
        assert first_use["result"]["error"] == "Unknown tool: fly"

    def test_tool_exception_continues_loop(self):
        """A tool that raises yields an error entry; extraction still succeeds."""
        fs = _make_fs()
        fs.read_node.side_effect = RuntimeError("boom")
        llm = _make_llm()
        llm._queue.extend([
            ([{"name": "read", "input": {"uri": "ctx://x"}, "id": "1"}], ""),
            ([], _extract_json()),
        ])
        result = _make_loop(llm=llm, fs=fs).run("conv", _ctx())

        assert "error" in result.tools_used[0]["result"]
        assert len(result.candidates) == 1

    def test_llm_exception_returns_empty(self):
        """An LLM that raises on the first call produces an empty result instead of an exception."""
        llm = _make_llm()
        llm.complete_with_tools_messages = Mock(side_effect=RuntimeError("API down"))

        assert _make_loop(llm=llm).run("conv", _ctx()).candidates == []