"""Unit tests for ReplaySnapshotStore: record, replay, and guard logic."""
import json
import tempfile
import pytest
from akg_agents.core_v2.langgraph_base.replay_guard import (
ReplayGuardError,
ReplaySnapshotStore,
_fingerprint,
_snap_value,
_match,
_is_empty,
GUARD_FIELDS,
_NODE_SNAPSHOTS_KEY,
)
@pytest.fixture
def cache_file(tmp_path):
    """Return the path (as ``str``) of a fresh cache file containing ``{}``.

    The file is created inside pytest's per-test ``tmp_path`` directory, so
    pytest owns its cleanup (nothing is left behind in the system temp dir).
    The file is written and closed before the path is handed out, so no
    handle stays open while the store under test reads/rewrites it; an open
    handle can prevent rewriting or replacing the file on Windows.
    """
    path = tmp_path / "replay_cache.json"
    path.write_text("{}", encoding="utf-8")
    return str(path)


class TestSnapValueAndMatch:
    """Tests for the low-level helpers ``_snap_value`` and ``_match``."""

    def test_short_string_kept_verbatim(self):
        # Small scalars -- a bool and a short str here -- are stored as-is
        # instead of being replaced by a hash record.
        verdict = _snap_value("verifier_result", True)
        assert verdict is True
        decision = _snap_value("conductor_decision", "coder")
        assert decision == "coder"

    def test_long_string_hashed(self):
        payload = "x" * 300
        snapshot = _snap_value("coder_code", payload)
        # A 300-character string is not kept verbatim: it becomes a dict
        # carrying a "__hash" entry plus the original length.
        assert isinstance(snapshot, dict)
        assert "__hash" in snapshot
        assert snapshot["__len"] == 300

    def test_match_hash(self):
        original = "x" * 300
        recorded = {"__hash": _fingerprint(original), "__len": 300}
        assert _match("coder_code", recorded, original)
        tampered = "y" * 300
        assert not _match("coder_code", recorded, tampered)

    def test_match_bool(self):
        assert _match("verifier_result", True, True)
        assert not _match("verifier_result", True, False)

    def test_match_none(self):
        assert _match("some_field", None, None)
        assert not _match("some_field", None, "something")

    def test_match_none_vs_empty_string(self):
        # None matches both "" and [] regardless of which side is the
        # recorded value and which is the live value.
        empty_cases = [
            ("codegen_invalid_reason", None, ""),
            ("codegen_invalid_reason", "", None),
            ("some_field", None, []),
            ("some_field", [], None),
        ]
        for field_name, recorded, live in empty_cases:
            assert _match(field_name, recorded, live)


class TestReplaySnapshotStore:
    """Round-trip tests for ReplaySnapshotStore.

    Most tests write snapshots through a ``mode="record"`` store and then
    verify against a separate ``mode="replay"`` store opened on the same
    cache-file path, so recorded snapshots must be readable back from disk.
    """

    def test_record_and_replay_pass(self, cache_file):
        store = ReplaySnapshotStore(cache_file, mode="record")
        result = {
            "coder_code": "class ModelNew: pass",
            "codegen_invalid": False,
            "codegen_invalid_reason": None,
        }
        store.record("coder", step=1, result=result)

        replay_store = ReplaySnapshotStore(cache_file, mode="replay")
        assert replay_store.has_snapshots
        # Identical result -> verify must not raise.
        replay_store.verify("coder", step=1, result=result)

    def test_record_and_replay_mismatch_raises(self, cache_file):
        store = ReplaySnapshotStore(cache_file, mode="record")
        result = {
            "verifier_result": True,
            "verifier_error": "",
        }
        store.record("verifier", step=2, result=result)

        replay_store = ReplaySnapshotStore(cache_file, mode="replay")
        bad_result = {
            "verifier_result": False,
            "verifier_error": "env error: compiler not found",
        }
        with pytest.raises(ReplayGuardError) as exc_info:
            replay_store.verify("verifier", step=2, result=bad_result)
        # Inspect the exception only after the ``with`` block: statements
        # placed inside it after the raising call would never execute.
        assert "verifier_result" in str(exc_info.value)
        assert exc_info.value.field == "verifier_result"

    def test_unknown_node_ignored(self, cache_file):
        store = ReplaySnapshotStore(cache_file, mode="record")
        store.record("unknown_node", step=0, result={"foo": "bar"})
        assert not store.has_snapshots

    def test_missing_snapshot_warns_not_raises(self, cache_file):
        store = ReplaySnapshotStore(cache_file, mode="replay")
        # Only "does not raise" is asserted here.
        # NOTE(review): the warning itself is not checked because this file
        # does not show whether it goes through `warnings` or `logging` --
        # consider pytest.warns / caplog once confirmed.
        store.verify("coder", step=99, result={"coder_code": "x"})

    def test_conductor_decision_guard(self, cache_file):
        store = ReplaySnapshotStore(cache_file, mode="record")
        result = {
            "conductor_decision": "coder",
            "conductor_suggestion": "fix the import",
        }
        store.record("conductor", step=3, result=result)

        replay_store = ReplaySnapshotStore(cache_file, mode="replay")
        replay_store.verify("conductor", step=3, result=result)

        bad = {
            "conductor_decision": "finish",
            "conductor_suggestion": "fix the import",
        }
        with pytest.raises(ReplayGuardError) as exc_info:
            replay_store.verify("conductor", step=3, result=bad)
        assert exc_info.value.field == "conductor_decision"

    def test_long_code_hash_comparison(self, cache_file):
        long_code = "import torch\n" * 100 + "class ModelNew:\n pass\n"

        def coder_result(code):
            # Only ``coder_code`` varies between the recorded and live results.
            return {
                "coder_code": code,
                "codegen_invalid": False,
                "codegen_invalid_reason": None,
            }

        store = ReplaySnapshotStore(cache_file, mode="record")
        store.record("coder", step=0, result=coder_result(long_code))

        replay_store = ReplaySnapshotStore(cache_file, mode="replay")
        replay_store.verify("coder", step=0, result=coder_result(long_code))

        with pytest.raises(ReplayGuardError) as exc_info:
            replay_store.verify(
                "coder", step=0, result=coder_result(long_code + "# modified")
            )
        # coder_code is the only differing field, so it must be the one
        # reported.
        assert exc_info.value.field == "coder_code"

    def test_snapshots_persisted_in_cache_file(self, cache_file):
        store = ReplaySnapshotStore(cache_file, mode="record")
        store.record("verifier", step=0, result={
            "verifier_result": True,
            "verifier_error": "",
        })
        # Explicit encoding: don't depend on the platform's default codec.
        with open(cache_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        assert _NODE_SNAPSHOTS_KEY in data
        assert "verifier@0" in data[_NODE_SNAPSHOTS_KEY]

    def test_invalid_mode_raises(self, cache_file):
        with pytest.raises(ValueError, match="record/replay"):
            ReplaySnapshotStore(cache_file, mode="off")