from __future__ import annotations

import json
import re
import sqlite3
import time
from contextlib import closing
from pathlib import Path
from typing import Any, TYPE_CHECKING

from .schemas import ExperimentResult, ExperimentTask

if TYPE_CHECKING:
    import os

# Matches ANSI colour/style escape sequences of the form ESC "[" <digits/;> "m"
# (e.g. "\x1b[31m", "\x1b[0m"); used to strip terminal colouring from raw logs
# before they are parsed with regexes.
ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")


def _resolve_log_path(path_str: str) -> Path:
    if not path_str:
        return Path("")
    normalized = path_str.replace("\\", "/")
    return Path(normalized)


def _extract_optimizer_top1_from_log(raw_log: str) -> dict[str, Any]:
    if not raw_log:
        return {}
    raw_log = ANSI_RE.sub("", raw_log)
    pattern = re.compile(
        r"^\|\s*1\s*\|\s*([0-9.]+)\s*\|\s*([0-9.]+)\s*\|\s*([0-9.]+)\s*"
        r"\|\s*(\d+)\s*\|\s*(\d+)\s*\|\s*(.*?)\s*\|\s*(\d+)\s*\|$",
        flags=re.MULTILINE,
    )
    m = pattern.search(raw_log)
    if not m:
        return {}
    return {
        "best_throughput": float(m.group(1)),
        "best_ttft_ms": float(m.group(2)),
        "best_tpot_ms": float(m.group(3)),
        "best_concurrency": int(m.group(4)),
        "best_parallel": m.group(6).strip(),
        "best_batch_size": int(m.group(7)),
    }


def _infer_optimizer_no_result_reason_from_params(params: dict[str, Any]) -> str:
    ttft = params.get("ttft_limits")
    tpot = params.get("tpot_limits")
    ttft_text = f"{ttft:g} ms" if isinstance(ttft, (int, float)) else "unlimited"
    tpot_text = f"{tpot:g} ms" if isinstance(tpot, (int, float)) else "unlimited"
    return (
        f"No valid deployment was found under the current limits "
        f"(TTFT={ttft_text}, TPOT={tpot_text})."
        "Try relaxing the latency limits, increasing num-devices, "
        "changing quantization, or reducing input/output length."
    )


def _enrich_optimizer_summary(
    summary: dict[str, Any],
    tables: dict[str, Any],
    raw_log: str = "",
    params: dict[str, Any] | None = None,
    error: str | None = None,
) -> dict[str, Any]:
    if not isinstance(summary, dict):
        summary = {}
    if not isinstance(tables, dict):
        tables = {}
    params = params or {}

    top_rows = tables.get("top_configs") or []
    if isinstance(top_rows, list) and top_rows:
        top1 = top_rows[0]
        if isinstance(top1, dict):
            summary.setdefault("best_parallel", top1.get("parallel"))
            summary.setdefault("best_batch_size", top1.get("batch_size"))
            summary.setdefault("best_concurrency", top1.get("concurrency"))
            summary.setdefault("best_throughput", top1.get("throughput_token_s"))
            summary.setdefault("best_ttft_ms", top1.get("ttft_ms"))
            summary.setdefault("best_tpot_ms", top1.get("tpot_ms"))

    if (
        any(summary.get(k) in (None, "") for k in ["best_parallel", "best_batch_size", "best_concurrency"])
        or not top_rows
    ):
        summary.update(
            {k: v for k, v in _extract_optimizer_top1_from_log(raw_log).items() if summary.get(k) in (None, "")}
        )

    summary.setdefault("ttft_limits_ms", params.get("ttft_limits"))
    summary.setdefault("tpot_limits_ms", params.get("tpot_limits"))

    has_result = summary.get("best_throughput") not in (None, "")
    if error and not summary.get("execution_error"):
        summary["execution_error"] = error
    if not has_result and not summary.get("no_result_reason") and not summary.get("execution_error"):
        summary["no_result_reason"] = _infer_optimizer_no_result_reason_from_params(params)

    return summary


class ResultStore:
    """SQLite-backed store of experiment results, with raw logs kept as files.

    Layout under ``root`` (default ``.msmodeling_ui``):
      * ``results.sqlite3``: table ``runs``, one row per ``task_hash``;
      * ``logs/<task_hash>.log``: the raw log text of each saved run.

    Every public method opens its own connection and closes it before
    returning.
    """

    def __init__(self, root: str | os.PathLike | None = None):
        self.root = Path(root or ".msmodeling_ui")
        self.root.mkdir(parents=True, exist_ok=True)
        self.logs_dir = self.root / "logs"
        self.logs_dir.mkdir(exist_ok=True)
        self.db_path = self.root / "results.sqlite3"
        self._init_db()

    def _connect(self):
        """Open a new connection to the results database.

        The caller owns the connection and must close it. Using the connection
        as a context manager (``with conn:``) only commits or rolls back; it
        does not close it, hence the ``closing(...)`` wrappers below.
        """
        return sqlite3.connect(self.db_path)

    def _read_log(self, stored_path: str | None, task_hash: str) -> str:
        """Best-effort read of a run's raw log; return "" if it cannot be read.

        Tries the path stored in the database first. Then it tries
        ``logs/<task_hash>.log`` under this store, which is where
        ``save_result`` writes logs, in case the stored path is stale (e.g.
        it was recorded relative to another working directory).
        ``is_file()`` is used instead of ``exists()`` because an empty stored
        path resolves to ``.``, which exists but cannot be read.
        """
        candidates = []
        if stored_path:
            candidates.append(_resolve_log_path(stored_path))
        candidates.append(self.logs_dir / f"{task_hash}.log")
        for candidate in candidates:
            if not candidate.is_file():
                continue
            try:
                return candidate.read_text(encoding="utf-8", errors="replace")
            except OSError:
                continue
        return ""

    def _init_db(self):
        """Create the ``runs`` table if it does not exist yet."""
        with closing(self._connect()) as conn, conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS runs (
                    task_hash TEXT PRIMARY KEY,
                    sim_type TEXT NOT NULL,
                    status TEXT NOT NULL,
                    label TEXT NOT NULL,
                    params_json TEXT NOT NULL,
                    summary_json TEXT NOT NULL,
                    tables_json TEXT NOT NULL,
                    warnings_json TEXT NOT NULL,
                    infos_json TEXT NOT NULL,
                    log_path TEXT NOT NULL,
                    error TEXT,
                    created_at REAL NOT NULL
                )
                """
            )

    def get_cached_result(self, task: ExperimentTask) -> ExperimentResult | None:
        """Return the stored result for ``task.task_hash``, or ``None`` if absent.

        For ``throughput_optimizer`` runs, the summary is enriched from the
        stored tables and log via ``_enrich_optimizer_summary``.
        """
        with closing(self._connect()) as conn:
            row = conn.execute(
                "SELECT sim_type,status,label,params_json,summary_json,tables_json,"
                "warnings_json,infos_json,log_path,error FROM runs WHERE task_hash=?",
                (task.task_hash,),
            ).fetchone()
        if not row:
            return None
        (
            sim_type,
            status,
            label,
            params_json,
            summary_json,
            tables_json,
            warnings_json,
            infos_json,
            log_path,
            error,
        ) = row
        raw_log = self._read_log(log_path, task.task_hash)
        params = json.loads(params_json)
        summary = json.loads(summary_json)
        tables = json.loads(tables_json)
        if sim_type == "throughput_optimizer":
            summary = _enrich_optimizer_summary(summary, tables, raw_log, params, error)
        return ExperimentResult(
            sim_type=sim_type,
            status=status,
            params=params,
            command=task.command,
            task_hash=task.task_hash,
            label=label,
            summary=summary,
            tables=tables,
            warnings=json.loads(warnings_json),
            infos=json.loads(infos_json),
            raw_log=raw_log,
            error=error,
            source="cache",
        )

    def save_result(self, result: ExperimentResult):
        """Write ``result``'s log to ``logs/<task_hash>.log`` and upsert its row."""
        log_path = self.logs_dir / f"{result.task_hash}.log"
        log_path.write_text(result.raw_log or "", encoding="utf-8")
        with closing(self._connect()) as conn, conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO runs(task_hash, sim_type, status, label, params_json, summary_json,
                    tables_json, warnings_json, infos_json, log_path, error, created_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    result.task_hash,
                    result.sim_type,
                    result.status,
                    result.label,
                    json.dumps(result.params, ensure_ascii=False),
                    json.dumps(result.summary, ensure_ascii=False),
                    json.dumps(result.tables, ensure_ascii=False),
                    json.dumps(result.warnings, ensure_ascii=False),
                    json.dumps(result.infos, ensure_ascii=False),
                    str(log_path),
                    result.error,
                    time.time(),
                ),
            )

    def query_rows(self, sim_type: str | None = None) -> list[dict[str, Any]]:
        """Return one flat dict per stored run, newest first.

        Each dict holds ``sim_type``/``status``/``label``, then the run's params
        and summary merged in, then ``top_configs``, ``warning_count``,
        ``info_count``, ``created_at`` and ``task_hash``.

        Args:
            sim_type: If truthy, only runs of this type are returned.
        """
        # tables_json and log_path are fetched up front so optimizer rows do
        # not need a second query each.
        query = (
            "SELECT sim_type,status,label,params_json,summary_json,warnings_json,"
            "infos_json,created_at,task_hash,error,tables_json,log_path FROM runs"
        )
        args: tuple[Any, ...] = ()
        if sim_type:
            query += " WHERE sim_type=?"
            args = (sim_type,)
        query += " ORDER BY created_at DESC"
        with closing(self._connect()) as conn:
            db_rows = conn.execute(query, args).fetchall()

        rows = []
        for (
            row_sim_type,
            status,
            label,
            params_json,
            summary_json,
            warnings_json,
            infos_json,
            created_at,
            task_hash,
            error,
            tables_json,
            log_path,
        ) in db_rows:
            params = json.loads(params_json)
            summary = json.loads(summary_json)
            top_configs: list[dict[str, Any]] = []
            if row_sim_type == "throughput_optimizer":
                # Tables are auxiliary for listing: tolerate corrupt JSON.
                try:
                    tables = json.loads(tables_json) if tables_json else {}
                except ValueError:
                    tables = {}
                if not isinstance(tables, dict):
                    tables = {}
                raw_log = self._read_log(log_path, task_hash)
                top_configs = tables.get("top_configs") or []
                summary = _enrich_optimizer_summary(summary, tables, raw_log, params, error)
            rows.append(
                {
                    "sim_type": row_sim_type,
                    "status": status,
                    "label": label,
                    **params,
                    **summary,
                    "top_configs": top_configs,
                    "warning_count": len(json.loads(warnings_json)),
                    "info_count": len(json.loads(infos_json)),
                    "created_at": created_at,
                    "task_hash": task_hash,
                }
            )
        return rows