from __future__ import annotations
import json
import re
import sqlite3
import time
from pathlib import Path
from typing import Any, TYPE_CHECKING
from .schemas import ExperimentResult, ExperimentTask
if TYPE_CHECKING:
import os
ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")
def _resolve_log_path(path_str: str) -> Path:
if not path_str:
return Path("")
normalized = path_str.replace("\\", "/")
return Path(normalized)
def _extract_optimizer_top1_from_log(raw_log: str) -> dict[str, Any]:
"""Extract top 1 configuration from standard optimizer log (8-column table)."""
if not raw_log:
return {}
raw_log = ANSI_RE.sub("", raw_log)
pattern = re.compile(
r"^\|\s*1\s*\|\s*([0-9.]+)\s*\|\s*([0-9.]+)\s*\|\s*([0-9.]+)\s*"
r"\|\s*(\d+)\s*\|\s*(\d+)\s*\|\s*(.*?)\s*\|\s*(\d+)\s*\|$",
flags=re.MULTILINE,
)
m = pattern.search(raw_log)
if not m:
return {}
return {
"best_throughput": float(m.group(1)),
"best_ttft_ms": float(m.group(2)),
"best_tpot_ms": float(m.group(3)),
"best_concurrency": int(m.group(4)),
"best_parallel": m.group(6).strip(),
"best_batch_size": int(m.group(7)),
}
def _extract_pd_ratio_top1_from_log(raw_log: str) -> dict[str, Any]:
"""Extract top 1 configuration from PD Ratio mode log (15-column table)."""
if not raw_log:
return {}
raw_log = ANSI_RE.sub("", raw_log)
result = {}
pd_match = re.search(r"PD Ratio:\s+([0-9.]+)", raw_log)
if pd_match:
result["pd_ratio"] = float(pd_match.group(1))
p_qps_match = re.search(r"Prefill QPS:\s+([0-9.]+)", raw_log)
if p_qps_match:
result["p_qps"] = float(p_qps_match.group(1))
ttft_match = re.search(r"TTFT:\s+([0-9.]+)\s*ms", raw_log)
if ttft_match:
result["best_ttft_ms"] = float(ttft_match.group(1))
d_qps_match = re.search(r"Decode QPS:\s+([0-9.]+)", raw_log)
if d_qps_match:
result["d_qps"] = float(d_qps_match.group(1))
tpot_match = re.search(r"TPOT:\s+([0-9.]+)\s*ms", raw_log)
if tpot_match:
result["best_tpot_ms"] = float(tpot_match.group(1))
balanced_match = re.search(r"Balanced[^:]*:\s+([0-9.]+)", raw_log)
if balanced_match:
balanced_qps = float(balanced_match.group(1))
result["balanced_qps"] = balanced_qps
d_parallel_match = re.search(r"D Parallel:\s*([^,\n]+)", raw_log)
if d_parallel_match:
result["best_parallel"] = d_parallel_match.group(1).strip()
p_batch_match = re.search(r"P Batch:\s*(\d+)", raw_log)
if p_batch_match:
result["p_batch_size"] = int(p_batch_match.group(1))
d_batch_match = re.search(r"D Batch:\s*(\d+)", raw_log)
if d_batch_match:
result["best_batch_size"] = int(d_batch_match.group(1))
d_concurrency_match = re.search(r"D Concurrency:\s*(\d+)", raw_log)
if d_concurrency_match:
result["best_concurrency"] = int(d_concurrency_match.group(1))
if result.get("best_throughput") is not None:
return result
pattern = re.compile(
r"^\|\s*1\s*\|\s*([0-9.]+)\s*\|\s*([0-9.]+)\s*\|\s*([0-9.]+)\s*\|\s*([0-9.]+)\s*"
r"\|\s*([0-9.]+)\s*\|\s*([0-9.]+)\s*\|\s*(.*?)\s*\|\s*(.*?)\s*\|"
r"\s*(\d+)\s*\|\s*(\d+)\s*\|\s*(\d+)\s*\|\s*(\d+)\s*\|\s*(\d+)\s*\|\s*(\d+)\s*\|$",
flags=re.MULTILINE,
)
m = pattern.search(raw_log)
if m:
balanced_qps = float(m.group(2))
return {
"pd_ratio": float(m.group(1)),
"balanced_qps": balanced_qps,
"p_qps": float(m.group(3)),
"d_qps": float(m.group(4)),
"best_ttft_ms": float(m.group(5)),
"best_tpot_ms": float(m.group(6)),
"best_parallel": m.group(8).strip(),
"prefill_devices_per_instance": int(m.group(9)),
"decode_devices_per_instance": int(m.group(10)),
"p_batch_size": int(m.group(11)),
"best_batch_size": int(m.group(12)),
"p_concurrency": int(m.group(13)),
"best_concurrency": int(m.group(14)),
}
return {}
def _infer_optimizer_no_result_reason_from_params(params: dict[str, Any]) -> str:
ttft = params.get("ttft_limits")
tpot = params.get("tpot_limits")
ttft_text = f"{ttft:g} ms" if isinstance(ttft, (int, float)) else "unlimited"
tpot_text = f"{tpot:g} ms" if isinstance(tpot, (int, float)) else "unlimited"
return (
f"No valid deployment was found under the current limits "
f"(TTFT={ttft_text}, TPOT={tpot_text})."
"Try relaxing the latency limits, increasing num-devices, "
"changing quantization, or reducing input/output length."
)
def _enrich_optimizer_summary(
summary: dict[str, Any],
tables: dict[str, Any],
raw_log: str = "",
params: dict[str, Any] | None = None,
error: str | None = None,
) -> dict[str, Any]:
if not isinstance(summary, dict):
summary = {}
if not isinstance(tables, dict):
tables = {}
params = params or {}
top_rows = tables.get("top_configs") or []
if isinstance(top_rows, list) and top_rows:
top1 = top_rows[0]
if isinstance(top1, dict):
is_pd_ratio = top1.get("pd_ratio") is not None
if is_pd_ratio:
if top1.get("d_parallel") and summary.get("best_parallel") in (None, ""):
summary["best_parallel"] = top1.get("d_parallel")
if top1.get("d_batch_size") and summary.get("best_batch_size") in (None, ""):
summary["best_batch_size"] = top1.get("d_batch_size")
if top1.get("d_concurrency") and summary.get("best_concurrency") in (None, ""):
summary["best_concurrency"] = top1.get("d_concurrency")
if top1.get("throughput_token_s") and summary.get("best_throughput") in (None, ""):
summary["best_throughput"] = top1.get("throughput_token_s")
else:
if top1.get("parallel") and summary.get("best_parallel") in (None, ""):
summary["best_parallel"] = top1.get("parallel")
if top1.get("batch_size") and summary.get("best_batch_size") in (None, ""):
summary["best_batch_size"] = top1.get("batch_size")
if top1.get("concurrency") and summary.get("best_concurrency") in (None, ""):
summary["best_concurrency"] = top1.get("concurrency")
if top1.get("throughput_token_s") and summary.get("best_throughput") in (None, ""):
summary["best_throughput"] = top1.get("throughput_token_s")
if top1.get("ttft_ms") and summary.get("best_ttft_ms") in (None, ""):
summary["best_ttft_ms"] = top1.get("ttft_ms")
if top1.get("tpot_ms") and summary.get("best_tpot_ms") in (None, ""):
summary["best_tpot_ms"] = top1.get("tpot_ms")
if (
any(summary.get(k) in (None, "") for k in ["best_parallel", "best_batch_size", "best_concurrency"])
or not top_rows
):
is_pd_ratio_mode = (
params.get("enable_optimize_prefill_decode_ratio", False)
or "PD Ratio" in raw_log
or summary.get("pd_ratio") is not None
or (top_rows and top_rows[0].get("pd_ratio") is not None)
)
if is_pd_ratio_mode:
extracted = _extract_pd_ratio_top1_from_log(raw_log)
else:
extracted = _extract_optimizer_top1_from_log(raw_log)
summary.update({k: v for k, v in extracted.items() if summary.get(k) in (None, "")})
summary.setdefault("ttft_limits_ms", params.get("ttft_limits"))
summary.setdefault("tpot_limits_ms", params.get("tpot_limits"))
has_result = summary.get("best_throughput") not in (None, "")
if error and not summary.get("execution_error"):
summary["execution_error"] = error
if not has_result and not summary.get("no_result_reason") and not summary.get("execution_error"):
summary["no_result_reason"] = _infer_optimizer_no_result_reason_from_params(params)
return summary
class ResultStore:
def __init__(self, root: str | os.PathLike | None = None):
self.root = Path(root or ".msmodeling_ui")
self.root.mkdir(parents=True, exist_ok=True)
self.logs_dir = self.root / "logs"
self.logs_dir.mkdir(exist_ok=True)
self.db_path = self.root / "results.sqlite3"
self._init_db()
def _connect(self):
return sqlite3.connect(self.db_path)
def _init_db(self):
with self._connect() as conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS runs (
task_hash TEXT PRIMARY KEY,
sim_type TEXT NOT NULL,
status TEXT NOT NULL,
label TEXT NOT NULL,
params_json TEXT NOT NULL,
summary_json TEXT NOT NULL,
tables_json TEXT NOT NULL,
warnings_json TEXT NOT NULL,
infos_json TEXT NOT NULL,
log_path TEXT NOT NULL,
error TEXT,
created_at REAL NOT NULL
)
"""
)
conn.commit()
def get_cached_result(self, task: ExperimentTask) -> ExperimentResult | None:
with self._connect() as conn:
row = conn.execute(
"SELECT sim_type,status,label,params_json,summary_json,tables_json,"
"warnings_json,infos_json,log_path,error FROM runs WHERE task_hash=?",
(task.task_hash,),
).fetchone()
if not row:
return None
log_file = _resolve_log_path(row[8])
raw_log = log_file.read_text(encoding="utf-8") if log_file.exists() else ""
summary = json.loads(row[4])
tables = json.loads(row[5])
if row[0] == "throughput_optimizer":
summary = _enrich_optimizer_summary(summary, tables, raw_log, json.loads(row[3]), row[9])
return ExperimentResult(
sim_type=row[0],
status=row[1],
params=json.loads(row[3]),
command=task.command,
task_hash=task.task_hash,
label=row[2],
summary=summary,
tables=tables,
warnings=json.loads(row[6]),
infos=json.loads(row[7]),
raw_log=raw_log,
error=row[9],
source="cache",
)
def save_result(self, result: ExperimentResult):
log_path = self.logs_dir / f"{result.task_hash}.log"
log_path.write_text(result.raw_log or "", encoding="utf-8")
with self._connect() as conn:
conn.execute(
"""
INSERT OR REPLACE INTO runs(task_hash, sim_type, status, label, params_json, summary_json,
tables_json, warnings_json, infos_json, log_path, error, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
result.task_hash,
result.sim_type,
result.status,
result.label,
json.dumps(result.params, ensure_ascii=False),
json.dumps(result.summary, ensure_ascii=False),
json.dumps(result.tables, ensure_ascii=False),
json.dumps(result.warnings, ensure_ascii=False),
json.dumps(result.infos, ensure_ascii=False),
str(log_path),
result.error,
time.time(),
),
)
conn.commit()
def query_rows(self, sim_type: str | None = None) -> list[dict[str, Any]]:
query = (
"SELECT sim_type,status,label,params_json,summary_json,warnings_json,"
"infos_json,created_at,task_hash,error FROM runs"
)
args: tuple[Any, ...] = ()
if sim_type:
query += " WHERE sim_type=?"
args = (sim_type,)
query += " ORDER BY created_at DESC"
rows = []
with self._connect() as conn:
for row in conn.execute(query, args).fetchall():
params = json.loads(row[3])
summary = json.loads(row[4])
top_configs: list[dict[str, Any]] = []
if row[0] == "throughput_optimizer":
try:
raw_row = conn.execute(
"SELECT tables_json, log_path FROM runs WHERE task_hash=?",
(row[8],),
).fetchone()
tables = json.loads(raw_row[0]) if raw_row and raw_row[0] else {}
log_file = _resolve_log_path(raw_row[1]) if raw_row and raw_row[1] else Path("")
raw_log = log_file.read_text(encoding="utf-8") if log_file.exists() else ""
except Exception:
tables = {}
raw_log = ""
top_configs = tables.get("top_configs") or []
summary = _enrich_optimizer_summary(summary, tables, raw_log, json.loads(row[3]), row[9])
rows.append(
{
"sim_type": row[0],
"status": row[1],
"label": row[2],
**params,
**summary,
"top_configs": top_configs,
"warning_count": len(json.loads(row[5])),
"info_count": len(json.loads(row[6])),
"created_at": row[7],
"task_hash": row[8],
}
)
return rows