"""Unit tests for Triton autotune patch helpers."""
import pytest
from akg_agents.op.utils import triton_autotune_patch as patch
class TestTritonAutotunePatch:
    """Exercise the restore wrapper applied to Triton autotune benchmark calls.

    Each test swaps ``patch.akg_restore_copy`` for a small in-memory stand-in
    so the wrapper's behaviour can be observed on plain dictionaries.
    """

    def test_wrap_kernel_call_restores_before_and_after(self, monkeypatch):
        """The wrapped call restores the argument before and after the kernel."""
        events = []
        buffer = {"value": "dirty"}

        def record_restore(dst, src):
            # Log the value seen *before* overwriting so ordering is checkable.
            events.append(("restore", dst["value"], src))
            dst["value"] = src

        def kernel_call():
            # The kernel should observe the restored value, then dirty it again.
            events.append(("kernel", buffer["value"]))
            buffer["value"] = "kernel_result"
            return "ok"

        monkeypatch.setattr(patch, "akg_restore_copy", record_restore)
        wrapped = patch._wrap_kernel_call_with_restore(
            kernel_call,
            {"saved": {0: "clean"}, "args": [buffer]},
        )

        outcome = wrapped()

        assert outcome == "ok"
        assert buffer["value"] == "clean"
        expected_events = [
            ("restore", "dirty", "clean"),
            ("kernel", "clean"),
            ("restore", "kernel_result", "clean"),
        ]
        assert events == expected_events

    def test_wrap_kernel_call_restores_after_exception(self, monkeypatch):
        """A raising kernel must still leave the argument in its saved state."""
        buffer = {"value": "dirty"}

        def overwrite(dst, src):
            dst["value"] = src

        def kernel_call():
            # Pollute the argument first, then fail.
            buffer["value"] = "broken"
            raise RuntimeError("boom")

        monkeypatch.setattr(patch, "akg_restore_copy", overwrite)
        wrapped = patch._wrap_kernel_call_with_restore(
            kernel_call,
            {"saved": {0: "clean"}, "args": [buffer]},
        )

        with pytest.raises(RuntimeError, match="boom"):
            wrapped()

        assert buffer["value"] == "clean"

    def test_wrap_kernel_call_prevents_partial_write_pollution(self, monkeypatch):
        """Elements a kernel leaves untouched must not survive from earlier runs."""
        output = {"value": [9.0] * 4}
        pristine = [0.0] * 4

        def copy_back(dst, src):
            # Store a fresh list so the saved snapshot is never aliased.
            dst["value"] = [*src]

        def bad_kernel_call():
            # Writes only the first two slots, leaving the tail as-is.
            for slot in (0, 1):
                output["value"][slot] = 1.0

        monkeypatch.setattr(patch, "akg_restore_copy", copy_back)
        wrapped = patch._wrap_kernel_call_with_restore(
            bad_kernel_call,
            {"saved": {0: pristine}, "args": [output]},
        )

        wrapped()

        assert output["value"] == [0.0, 0.0, 0.0, 0.0]

    def test_wrap_kernel_call_without_restore_info_returns_original(self):
        """Passing ``None`` as restore info returns the very same callable."""

        def kernel_call():
            return "ok"

        assert patch._wrap_kernel_call_with_restore(kernel_call, None) is kernel_call