"""
adaptive_search 模块单元测试
覆盖 SuccessDB、UCBParentSelector、AsyncTaskPool、数据类等核心组件的
接口正确性、参数校验和纯逻辑行为,不依赖 LLM 或 GPU。
"""
import asyncio
import json
import math
import os
import tempfile
from dataclasses import asdict
import pytest
from akg_agents.op.adaptive_search.success_db import SuccessDB, SuccessRecord
from akg_agents.op.adaptive_search.ucb_selector import UCBParentSelector
from akg_agents.op.adaptive_search.task_pool import (
AsyncTaskPool,
PendingTask,
TaskResult,
TaskWrapper,
TaskStatus,
)
from akg_agents.op.adaptive_search.controller import SearchConfig
from akg_agents.op.adaptive_search.task_generator import TaskGeneratorConfig
class TestSuccessRecord:
def test_default_fields(self):
r = SuccessRecord(id="t1", impl_code="pass")
assert r.id == "t1"
assert r.impl_code == "pass"
assert r.sketch == ""
assert r.gen_time == float("inf")
assert r.speedup == 0.0
assert r.selection_count == 0
assert r.generation == 0
assert r.parent_id is None
assert r.sketch_generated is False
def test_roundtrip_dict(self):
r = SuccessRecord(id="t2", impl_code="code", gen_time=1.5, speedup=2.0)
d = r.to_dict()
r2 = SuccessRecord.from_dict(d)
assert r2.id == r.id
assert r2.gen_time == r.gen_time
assert r2.speedup == r.speedup
def test_to_inspiration_format(self):
r = SuccessRecord(id="t3", impl_code="kernel", sketch="sk", generation=2)
insp = r.to_inspiration()
assert insp["id"] == "t3"
assert insp["sketch"] == "sk"
assert insp["impl_code"] == "kernel"
assert insp["strategy_mode"] == "adaptive_search"
assert insp["generation"] == 2
assert insp["is_parent"] is False
class TestSuccessDB:
def test_empty_db(self):
db = SuccessDB()
assert db.is_empty()
assert db.size() == 0
assert db.get_best_record() is None
assert db.get_best_gen_time() == float("inf")
assert db.get_best_speedup() == 0.0
def test_add_and_get(self):
db = SuccessDB()
r = SuccessRecord(id="r1", impl_code="code1", gen_time=10.0)
assert db.add(r) is True
assert db.size() == 1
assert not db.is_empty()
fetched = db.get("r1")
assert fetched is not None
assert fetched.id == "r1"
assert fetched.gen_time == 10.0
def test_add_duplicate_better_performance(self):
db = SuccessDB()
db.add(SuccessRecord(id="r1", impl_code="v1", gen_time=10.0))
result = db.add(SuccessRecord(id="r1", impl_code="v2", gen_time=5.0))
assert result is True
assert db.get("r1").gen_time == 5.0
def test_add_duplicate_worse_performance(self):
db = SuccessDB()
db.add(SuccessRecord(id="r1", impl_code="v1", gen_time=5.0))
result = db.add(SuccessRecord(id="r1", impl_code="v2", gen_time=10.0))
assert result is False
assert db.get("r1").gen_time == 5.0
def test_get_all_preserves_order(self):
db = SuccessDB()
for i in range(5):
db.add(SuccessRecord(id=f"r{i}", impl_code=f"c{i}", gen_time=float(10 - i)))
ids = [r.id for r in db.get_all()]
assert ids == ["r0", "r1", "r2", "r3", "r4"]
def test_get_all_sorted_by_performance(self):
db = SuccessDB()
db.add(SuccessRecord(id="slow", impl_code="s", gen_time=100.0))
db.add(SuccessRecord(id="fast", impl_code="f", gen_time=1.0))
db.add(SuccessRecord(id="mid", impl_code="m", gen_time=50.0))
sorted_ids = [r.id for r in db.get_all_sorted_by_performance()]
assert sorted_ids == ["fast", "mid", "slow"]
def test_increment_decrement_selection(self):
db = SuccessDB()
db.add(SuccessRecord(id="r1", impl_code="c"))
db.increment_selection("r1")
db.increment_selection("r1")
assert db.get("r1").selection_count == 2
assert db.get_total_selections() == 2
db.decrement_selection("r1")
assert db.get("r1").selection_count == 1
assert db.get_total_selections() == 1
def test_decrement_does_not_go_negative(self):
db = SuccessDB()
db.add(SuccessRecord(id="r1", impl_code="c"))
db.decrement_selection("r1")
assert db.get("r1").selection_count == 0
def test_add_from_result(self):
db = SuccessDB()
state = {
"coder_code": "def kernel(): pass",
"profile_res": {"gen_time": 3.14, "speedup": 1.5},
"verifier_result": {"passed": True},
}
record = db.add_from_result("task1", state, generation=1, parent_id="p0")
assert record is not None
assert record.gen_time == 3.14
assert record.speedup == 1.5
assert record.generation == 1
assert record.parent_id == "p0"
assert record.sketch_generated is False
def test_update_sketch(self):
db = SuccessDB()
db.add(SuccessRecord(id="r1", impl_code="c"))
assert db.update_sketch("r1", "new_sketch") is True
assert db.get("r1").sketch == "new_sketch"
assert db.get("r1").sketch_generated is True
def test_update_sketch_nonexistent(self):
db = SuccessDB()
assert db.update_sketch("nonexistent", "sk") is False
def test_get_pending_sketch_records(self):
db = SuccessDB()
db.add(SuccessRecord(id="r1", impl_code="c1"))
db.add(SuccessRecord(id="r2", impl_code="c2"))
db.update_sketch("r1", "done")
pending = db.get_pending_sketch_records()
assert len(pending) == 1
assert pending[0].id == "r2"
def test_statistics(self):
db = SuccessDB()
stats = db.get_statistics()
assert stats["total_records"] == 0
db.add(SuccessRecord(id="r1", impl_code="c", gen_time=10.0, speedup=2.0))
db.add(SuccessRecord(id="r2", impl_code="c", gen_time=20.0, speedup=3.0))
stats = db.get_statistics()
assert stats["total_records"] == 2
assert stats["best_gen_time"] == 10.0
assert stats["worst_gen_time"] == 20.0
assert stats["best_speedup"] == 3.0
def test_persistence(self, tmp_path):
db1 = SuccessDB(storage_dir=str(tmp_path))
db1.add(SuccessRecord(id="r1", impl_code="persistent", gen_time=5.0))
db2 = SuccessDB(storage_dir=str(tmp_path))
assert db2.size() == 1
assert db2.get("r1").gen_time == 5.0
assert db2.get("r1").impl_code == "persistent"
def test_increment_selection_not_auto_persisted(self, tmp_path):
"""increment_selection 不会自动持久化,这是当前设计行为"""
db1 = SuccessDB(storage_dir=str(tmp_path))
db1.add(SuccessRecord(id="r1", impl_code="c", gen_time=5.0))
db1.increment_selection("r1")
assert db1.get("r1").selection_count == 1
db2 = SuccessDB(storage_dir=str(tmp_path))
assert db2.get("r1").selection_count == 0
class TestUCBParentSelector:
def _make_db(self, records):
db = SuccessDB()
for r in records:
db.add(r)
return db
def test_select_from_empty_db(self):
db = SuccessDB()
selector = UCBParentSelector(db)
assert selector.select() is None
def test_select_single_record(self):
db = self._make_db([
SuccessRecord(id="only", impl_code="c", gen_time=10.0)
])
selector = UCBParentSelector(db)
result = selector.select()
assert result is not None
assert result.id == "only"
assert db.get("only").selection_count == 1
def test_quality_ranking(self):
db = self._make_db([
SuccessRecord(id="best", impl_code="c", gen_time=1.0),
SuccessRecord(id="worst", impl_code="c", gen_time=100.0),
])
selector = UCBParentSelector(db, random_factor=0.0)
best_q = selector._compute_quality(db.get("best"), db.get_all())
worst_q = selector._compute_quality(db.get("worst"), db.get_all())
assert best_q == 1.0
assert worst_q == 0.0
def test_quality_single_record(self):
db = self._make_db([
SuccessRecord(id="solo", impl_code="c", gen_time=5.0)
])
selector = UCBParentSelector(db)
q = selector._compute_quality(db.get("solo"), db.get_all())
assert q == 1.0
def test_exploration_decreases_with_selection(self):
db = self._make_db([
SuccessRecord(id="r1", impl_code="c", gen_time=5.0)
])
selector = UCBParentSelector(db)
e0 = selector._compute_exploration(selection_count=0, total_selections=10)
e5 = selector._compute_exploration(selection_count=5, total_selections=10)
assert e0 > e5
def test_exploration_zero_total(self):
db = SuccessDB()
selector = UCBParentSelector(db, exploration_coef=1.414)
e = selector._compute_exploration(selection_count=0, total_selections=0)
assert e == 1.414
def test_greedy_select_picks_highest_ucb(self):
db = self._make_db([
SuccessRecord(id="fast", impl_code="c", gen_time=1.0),
SuccessRecord(id="slow", impl_code="c", gen_time=100.0),
])
selector = UCBParentSelector(db, random_factor=0.0)
for _ in range(5):
result = selector.select()
assert result is not None
def test_softmax_select_returns_valid(self):
records = [
SuccessRecord(id=f"r{i}", impl_code="c", gen_time=float(i + 1))
for i in range(5)
]
db = self._make_db(records)
selector = UCBParentSelector(db, use_softmax=True, softmax_temperature=0.5)
result = selector.select()
assert result is not None
assert result.id in {f"r{i}" for i in range(5)}
def test_selection_stats(self):
db = self._make_db([
SuccessRecord(id="r1", impl_code="c", gen_time=5.0),
SuccessRecord(id="r2", impl_code="c", gen_time=10.0),
])
selector = UCBParentSelector(db, random_factor=0.0)
selector.select()
selector.select()
stats = selector.get_selection_stats()
assert stats["total_records"] == 2
assert stats["total_selections"] == 2
def test_empty_stats(self):
db = SuccessDB()
selector = UCBParentSelector(db)
stats = selector.get_selection_stats()
assert stats["total_records"] == 0
assert stats["total_selections"] == 0
class TestAsyncTaskPool:
def test_init(self):
pool = AsyncTaskPool(max_concurrent=4)
assert pool.max_concurrent == 4
def test_pending_task_dataclass(self):
async def dummy():
return True
pt = PendingTask(task_id="t1", coroutine_factory=dummy, generation=1, parent_id="p0")
assert pt.task_id == "t1"
assert pt.generation == 1
assert pt.parent_id == "p0"
assert callable(pt.coroutine_factory)
def test_task_result_dataclass(self):
tr = TaskResult(
task_id="t1", success=True,
final_state={"coder_code": "pass"},
generation=2, parent_id="p1"
)
assert tr.task_id == "t1"
assert tr.success is True
assert tr.generation == 2
assert tr.error is None
def test_task_result_failure(self):
tr = TaskResult(
task_id="t1", success=False,
final_state={}, error="timeout"
)
assert tr.success is False
assert tr.error == "timeout"
def test_task_status_enum(self):
assert TaskStatus.PENDING.value == "pending"
assert TaskStatus.RUNNING.value == "running"
assert TaskStatus.COMPLETED.value == "completed"
@pytest.mark.asyncio
async def test_context_manager(self):
async with AsyncTaskPool(max_concurrent=2) as pool:
assert pool.max_concurrent == 2
class TestSearchConfig:
def test_defaults(self):
cfg = SearchConfig()
assert cfg.max_concurrent == 8
assert cfg.initial_task_count == 8
assert cfg.max_total_tasks == 100
assert cfg.exploration_coef == 1.414
assert cfg.random_factor == 0.1
assert cfg.use_softmax is False
assert cfg.softmax_temperature == 1.0
assert cfg.use_evolution_controller is False
assert cfg.storage_dir is None
assert cfg.poll_interval == 0.5
def test_custom_values(self):
cfg = SearchConfig(
max_concurrent=16,
max_total_tasks=200,
exploration_coef=2.0,
use_softmax=True,
)
assert cfg.max_concurrent == 16
assert cfg.max_total_tasks == 200
assert cfg.exploration_coef == 2.0
assert cfg.use_softmax is True
class TestTaskGeneratorConfig:
def test_defaults(self):
cfg = TaskGeneratorConfig()
assert cfg.inspiration_sample_num == 3
assert cfg.use_tiered_sampling is True
assert cfg.handwrite_sample_num == 2
assert cfg.handwrite_decay_rate == 2.0
assert cfg.meta_prompts_per_task == 1
def test_custom_values(self):
cfg = TaskGeneratorConfig(
inspiration_sample_num=5,
handwrite_sample_num=0,
)
assert cfg.inspiration_sample_num == 5
assert cfg.handwrite_sample_num == 0
class TestModuleExports:
def test_package_exports(self):
from akg_agents.op.adaptive_search import (
adaptive_search,
adaptive_search_from_config,
load_adaptive_search_config,
SuccessDB,
SuccessRecord,
UCBParentSelector,
AsyncTaskPool,
PendingTask,
TaskResult,
TaskWrapper,
TaskGenerator,
TaskGeneratorConfig,
AdaptiveSearchController,
SearchConfig,
)
assert callable(adaptive_search)
assert callable(adaptive_search_from_config)
assert callable(load_adaptive_search_config)