"""Tests for web_ui.result_store module."""
from __future__ import annotations
import tempfile
from pathlib import Path
import pytest
from web_ui.result_store import (
_enrich_optimizer_summary,
_extract_optimizer_top1_from_log,
_infer_optimizer_no_result_reason_from_params,
_resolve_log_path,
ResultStore,
)
from web_ui.schemas import ExperimentResult, ExperimentTask
class TestResolveLogPath:
"""Tests for _resolve_log_path function."""
def test_resolve_log_path_empty_string(self) -> None:
"""Test with empty string."""
result = _resolve_log_path("")
assert result == Path("")
def test_resolve_log_path_none(self) -> None:
"""Test with None."""
result = _resolve_log_path(None)
assert result == Path("")
def test_resolve_log_path_unix_style(self) -> None:
"""Test with Unix-style path."""
result = _resolve_log_path("logs/test.log")
assert result == Path("logs/test.log")
def test_resolve_log_path_windows_style(self) -> None:
"""Test with Windows-style path."""
result = _resolve_log_path(r"logs\test.log")
assert result == Path("logs/test.log")
def test_resolve_log_path_mixed_separators(self) -> None:
"""Test with mixed separators."""
result = _resolve_log_path(r"logs\\sub/test.log")
assert result == Path("logs/sub/test.log")
class TestExtractOptimizerTop1FromLog:
"""Tests for _extract_optimizer_top1_from_log function."""
def test_extract_top1_none_log(self) -> None:
"""Test with None log."""
result = _extract_optimizer_top1_from_log(None)
assert result == {}
def test_extract_top1_empty_log(self) -> None:
"""Test with empty log."""
result = _extract_optimizer_top1_from_log("")
assert result == {}
def test_extract_top1_valid_log(self) -> None:
"""Test extracting from valid log."""
log = """
| Top | Throughput (token/s) | TTFT (ms) | TPOT (ms) | concurrency | num_devices | parallel | batch_size |
| 1 | 2500.0 | 15000.0 | 45.0 | 200 | 8 | TP=8 | DP=1 | 200 |
"""
result = _extract_optimizer_top1_from_log(log)
assert result["best_throughput"] == 2500.0
assert result["best_ttft_ms"] == 15000.0
assert result["best_tpot_ms"] == 45.0
assert result["best_concurrency"] == 200
assert result["best_batch_size"] == 200
def test_extract_top1_with_pp(self) -> None:
"""Test extracting with PP (Pipeline Parallel) column."""
log = """
| Top | Throughput (token/s) | TTFT (ms) | TPOT (ms) | concurrency | num_devices | parallel | batch_size |
| 1 | 1000.0 | 16000.0 | 50.0 | 130 | 8 | TP=4 | PP=1 | DP=2 | 65 |
"""
result = _extract_optimizer_top1_from_log(log)
assert result["best_parallel"] == "TP=4 | PP=1 | DP=2"
assert result["best_batch_size"] == 65
def test_extract_top1_no_match(self) -> None:
"""Test with log that doesn't match pattern."""
log = "Some other log content\nwithout table"
result = _extract_optimizer_top1_from_log(log)
assert result == {}
class TestInferOptimizerNoResultReasonFromParams:
"""Tests for _infer_optimizer_no_result_reason_from_params function."""
def test_infer_both_limits(self) -> None:
"""Test with both limits set."""
params = {"ttft_limits": 100.0, "tpot_limits": 50.0}
result = _infer_optimizer_no_result_reason_from_params(params)
assert "TTFT=100 ms" in result
assert "TPOT=50 ms" in result
def test_infer_ttft_only(self) -> None:
"""Test with only TTFT set."""
params = {"ttft_limits": 200.0, "tpot_limits": None}
result = _infer_optimizer_no_result_reason_from_params(params)
assert "TTFT=200 ms" in result
assert "TPOT=unlimited" in result
def test_infer_tpot_only(self) -> None:
"""Test with only TPOT set."""
params = {"ttft_limits": None, "tpot_limits": 75.0}
result = _infer_optimizer_no_result_reason_from_params(params)
assert "TTFT=unlimited" in result
assert "TPOT=75 ms" in result
def test_infer_both_none(self) -> None:
"""Test with both limits None."""
params = {"ttft_limits": None, "tpot_limits": None}
result = _infer_optimizer_no_result_reason_from_params(params)
assert "TTFT=unlimited" in result
assert "TPOT=unlimited" in result
class TestEnrichOptimizerSummary:
"""Tests for _enrich_optimizer_summary function."""
def test_enrich_none_summary(self) -> None:
"""Test with None summary."""
result = _enrich_optimizer_summary(None, {}, "")
assert isinstance(result, dict)
def test_enrich_from_top_configs(self) -> None:
"""Test enriching from top_configs table."""
summary = {}
tables = {
"top_configs": [{"parallel": "TP=4", "batch_size": 64, "concurrency": 100, "throughput_token_s": 1000.0}]
}
result = _enrich_optimizer_summary(summary, tables, "", {}, None)
assert result["best_parallel"] == "TP=4"
assert result["best_batch_size"] == 64
def test_enrich_from_log(self) -> None:
"""Test enriching from raw log when table empty."""
summary = {}
tables = {}
log = """
| Top | Throughput (token/s) | TTFT (ms) | TPOT (ms) | concurrency | num_devices | parallel | batch_size |
| 1 | 1000.0 | 15000.0 | 45.0 | 200 | 8 | TP=8 | DP=1 | 200 |
"""
result = _enrich_optimizer_summary(summary, tables, log, {}, None)
assert result["best_throughput"] == 1000.0
assert result["best_concurrency"] == 200
def test_enrich_adds_no_result_reason(self) -> None:
"""Test that no_result_reason is added when needed."""
summary = {}
tables = {"top_configs": []}
params = {"ttft_limits": 50.0, "tpot_limits": 25.0}
result = _enrich_optimizer_summary(summary, tables, "", params, None)
assert "no_result_reason" in result
def test_enrich_preserves_existing_fields(self) -> None:
"""Test that existing fields are preserved."""
summary = {"existing_field": "value"}
tables = {}
result = _enrich_optimizer_summary(summary, tables, "", {}, None)
assert result["existing_field"] == "value"
class TestResultStore:
"""Tests for ResultStore class."""
@pytest.fixture
def temp_store(self) -> ResultStore:
"""Create a ResultStore with temporary directory."""
import shutil
tmpdir = tempfile.mkdtemp()
store = ResultStore(root=tmpdir)
yield store
try:
import gc
del store
gc.collect()
shutil.rmtree(tmpdir, ignore_errors=True)
except Exception:
pass
def test_store_creates_directories(self, temp_store: ResultStore) -> None:
"""Test that store creates required directories."""
assert temp_store.root.exists()
assert temp_store.logs_dir.exists()
def test_store_creates_database(self, temp_store: ResultStore) -> None:
"""Test that database is created."""
assert temp_store.db_path.exists()
def test_store_database_schema(self, temp_store: ResultStore) -> None:
"""Test that database has correct schema."""
with temp_store._connect() as conn:
cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='runs'")
tables = cursor.fetchall()
assert len(tables) == 1
assert tables[0][0] == "runs"
def test_get_cached_result_none(self, temp_store: ResultStore) -> None:
"""Test getting non-existent cached result."""
task = ExperimentTask("text_generate", {}, [], "hash", "test")
result = temp_store.get_cached_result(task)
assert result is None
def test_save_and_get_result(self, temp_store: ResultStore) -> None:
"""Test saving and retrieving a result."""
result = ExperimentResult(
sim_type="text_generate",
status="success",
params={"model_id": "test_model"},
command=["test"],
task_hash="test_hash",
label="test_result",
summary={"tps_per_device": 100.0},
raw_log="Test log content",
)
temp_store.save_result(result)
task = ExperimentTask("text_generate", {"model_id": "test_model"}, [], "test_hash", "test_result")
cached = temp_store.get_cached_result(task)
assert cached is not None
assert cached.status == "success"
assert cached.summary["tps_per_device"] == 100.0
assert cached.source == "cache"
def test_save_creates_log_file(self, temp_store: ResultStore) -> None:
"""Test that saving creates a log file."""
result = ExperimentResult(
sim_type="text_generate",
status="success",
params={},
command=[],
task_hash="log_test",
label="test",
raw_log="Log line 1\nLog line 2",
)
temp_store.save_result(result)
log_file = temp_store.logs_dir / "log_test.log"
assert log_file.exists()
assert log_file.read_text(encoding="utf-8") == "Log line 1\nLog line 2"
def test_save_updates_existing(self, temp_store: ResultStore) -> None:
"""Test that save updates existing result."""
result1 = ExperimentResult(
sim_type="text_generate",
status="running",
params={},
command=[],
task_hash="update_test",
label="test",
summary={"progress": 50},
)
temp_store.save_result(result1)
result2 = ExperimentResult(
sim_type="text_generate",
status="success",
params={},
command=[],
task_hash="update_test",
label="test",
summary={"progress": 100},
)
temp_store.save_result(result2)
task = ExperimentTask("text_generate", {}, [], "update_test", "test")
cached = temp_store.get_cached_result(task)
assert cached.status == "success"
assert cached.summary["progress"] == 100
def test_query_rows_empty(self, temp_store: ResultStore) -> None:
"""Test querying rows from empty store."""
rows = temp_store.query_rows()
assert rows == []
def test_query_rows_by_sim_type(self, temp_store: ResultStore) -> None:
"""Test querying rows filtered by sim_type."""
result1 = ExperimentResult(
sim_type="text_generate",
status="success",
params={"device": "D1"},
command=[],
task_hash="h1",
label="test1",
)
result2 = ExperimentResult(
sim_type="video_generate",
status="success",
params={"device": "D2"},
command=[],
task_hash="h2",
label="test2",
)
temp_store.save_result(result1)
temp_store.save_result(result2)
text_rows = temp_store.query_rows("text_generate")
assert len(text_rows) == 1
assert text_rows[0]["sim_type"] == "text_generate"
video_rows = temp_store.query_rows("video_generate")
assert len(video_rows) == 1
assert video_rows[0]["sim_type"] == "video_generate"
def test_query_rows_orders_by_created_at(self, temp_store: ResultStore) -> None:
"""Test that results are ordered by created_at descending."""
for i in range(3):
result = ExperimentResult(
sim_type="text_generate",
status="success",
params={"index": i},
command=[],
task_hash=f"h{i}",
label=f"test{i}",
)
temp_store.save_result(result)
rows = temp_store.query_rows("text_generate")
assert rows[0]["label"] == "test2"
assert rows[2]["label"] == "test0"
def test_query_rows_includes_top_configs(self, temp_store: ResultStore) -> None:
"""Test that optimizer rows include top_configs."""
result = ExperimentResult(
sim_type="throughput_optimizer",
status="success",
params={},
command=[],
task_hash="opt_test",
label="opt",
tables={"top_configs": [{"rank": 1, "throughput_token_s": 1000}]},
)
temp_store.save_result(result)
rows = temp_store.query_rows("throughput_optimizer")
assert len(rows) == 1
assert "top_configs" in rows[0]
assert len(rows[0]["top_configs"]) == 1
def test_get_cached_result_failed_overwritten(self, temp_store: ResultStore) -> None:
"""Test that failed results are not cached."""
result = ExperimentResult(
sim_type="text_generate",
status="failed",
params={},
command=[],
task_hash="fail_test",
label="fail",
error="Test error",
)
temp_store.save_result(result)
task = ExperimentTask("text_generate", {}, [], "fail_test", "fail")
cached = temp_store.get_cached_result(task)
assert cached is not None
assert cached.status == "failed"
def test_optimizer_enrichment_on_query(self, temp_store: ResultStore) -> None:
"""Test that optimizer results are enriched on query."""
result = ExperimentResult(
sim_type="throughput_optimizer",
status="success",
params={"ttft_limits": 100.0},
command=[],
task_hash="opt_enrich",
label="opt",
summary={"best_throughput": None},
tables={},
raw_log="",
)
temp_store.save_result(result)
rows = temp_store.query_rows("throughput_optimizer")
assert len(rows) == 1
assert "no_result_reason" in rows[0]
def test_multiple_saves_same_hash(self, temp_store: ResultStore) -> None:
"""Test that multiple saves with same hash update correctly."""
for i in range(3):
result = ExperimentResult(
sim_type="text_generate",
status="success",
params={"run": i},
command=[],
task_hash="same_hash",
label=f"run_{i}",
summary={"count": i},
)
temp_store.save_result(result)
rows = temp_store.query_rows("text_generate")
assert len(rows) == 1
assert rows[0]["count"] == 2
def test_query_rows_all_types(self, temp_store: ResultStore) -> None:
"""Test querying all sim_types."""
for sim_type in ["text_generate", "video_generate", "throughput_optimizer"]:
result = ExperimentResult(
sim_type=sim_type,
status="success",
params={},
command=[],
task_hash=f"{sim_type}_h",
label=sim_type,
)
temp_store.save_result(result)
rows = temp_store.query_rows()
assert len(rows) == 3
def test_result_with_warnings_and_infos(self, temp_store: ResultStore) -> None:
"""Test saving result with warnings and infos."""
result = ExperimentResult(
sim_type="text_generate",
status="success",
params={},
command=[],
task_hash="warn_test",
label="test",
warnings=["WARNING: Test warning"],
infos=["INFO: Test info"],
)
temp_store.save_result(result)
rows = temp_store.query_rows("text_generate")
assert rows[0]["warning_count"] == 1
assert rows[0]["info_count"] == 1
def test_result_with_error(self, temp_store: ResultStore) -> None:
"""Test saving result with error."""
result = ExperimentResult(
sim_type="text_generate",
status="failed",
params={},
command=[],
task_hash="error_test",
label="test",
error="Process failed",
)
temp_store.save_result(result)
rows = temp_store.query_rows("text_generate")
assert rows[0]["status"] == "failed"