import logging
import re
import unittest
from dataclasses import dataclass
from typing import Optional

import torch
from parameterized import parameterized
from tensor_cast.core.input_generator import generate_inputs
from tensor_cast.core.model_runner import ModelRunner, ModelRunnerMetrics
from tensor_cast.core.quantization.datatypes import (
    QuantizeLinearAction,
)
from tensor_cast.core.user_config import UserInputConfig

logger = logging.getLogger(__name__)


@dataclass
class AutoBaselineCase:
    """One auto-baseline scenario: two configurations whose total times are compared.

    Attributes:
        name: Case identifier, used in test names and log/assert messages.
        description: Free-form human-readable description of the scenario.
        baseline_input: Configuration whose total time is the reference.
        compare_input: Configuration measured against the reference.
        tolerance: Maximum allowed absolute relative difference, as a fraction
            (0.05 == 5%).

    Raises:
        ValueError: If ``tolerance`` is negative or NaN.
    """

    name: str
    description: str
    baseline_input: UserInputConfig
    compare_input: UserInputConfig
    tolerance: float = 0.05

    def __post_init__(self) -> None:
        # The comparison accepts abs(relative diff) <= tolerance, so a negative or NaN
        # tolerance would make the case fail regardless of the measured times.
        # `not x >= 0` rejects NaN as well as negative values.
        if not self.tolerance >= 0:
            raise ValueError(f"tolerance must be a non-negative number, got {self.tolerance!r}")


@dataclass
class AutoBaselineResult:
    """Outcome of a single auto-baseline comparison.

    When ``error`` is set, the run aborted early and the numeric fields hold
    placeholder values (see ``run_auto_baseline``), so only ``error`` and
    ``passed`` are meaningful.
    """

    # Label of the case this result belongs to.
    case_name: str
    # Total time of the baseline configuration, in seconds.
    baseline_time_s: float
    # Total time of the comparison configuration, in seconds.
    actual_time_s: float
    # Relative difference (actual - baseline) / baseline, as a fraction.
    diff_pct: float
    # Maximum allowed abs(diff_pct), as a fraction.
    tolerance: float
    # True when the comparison ran and stayed within tolerance.
    passed: bool
    # Human-readable failure reason; None on a completed comparison.
    error: Optional[str] = None


def _parse_total_time_s(table_result: str, model_name: str = "analytic") -> float:
    """Extract "Total time for <model_name>: <value><unit>" from a report, in seconds.

    Args:
        table_result: Report text to search (the first match is used).
        model_name: Label printed after "Total time for"; matched literally.

    Returns:
        The parsed total time converted to seconds.

    Raises:
        ValueError: If no matching "Total time for <model_name>" entry is found.
    """
    unit_scale_s = {
        "ns": 1e-9,
        "us": 1e-6,
        "\u00b5s": 1e-6,  # MICRO SIGN
        "\u03bcs": 1e-6,  # GREEK SMALL LETTER MU
        "ms": 1e-3,
        "s": 1.0,
    }
    # re.escape: model names may contain regex metacharacters such as "." or "+".
    # The number accepts plain decimals ("12", "12.5", ".5") and an optional exponent
    # ("1.5e-3"), which the previous `[\d.]+` could not match.
    pattern = (
        rf"Total time for {re.escape(model_name)}:\s*"
        r"((?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)\s*"
        r"(ns|us|\u00b5s|\u03bcs|ms|s)"
    )
    m = re.search(pattern, table_result)
    if not m:
        raise ValueError(f"Could not find 'Total time for {model_name}' in output:\n{table_result}")
    return float(m.group(1)) * unit_scale_s[m.group(2)]


def _run_single(user_input: UserInputConfig) -> float:
    """Run one inference for ``user_input`` and return its analytic total time in seconds.

    ``torch.compiler.reset()`` is called first, presumably so compiled state from a
    previous configuration does not leak into this run (confirm).

    Raises:
        TypeError: If the runner returns something other than ``ModelRunnerMetrics``.
    """
    torch.compiler.reset()
    metrics = ModelRunner(user_input).run_inference(generate_inputs_func=generate_inputs)
    if not isinstance(metrics, ModelRunnerMetrics):
        raise TypeError(f"Unexpected result type: {type(metrics)}")
    return _parse_total_time_s(metrics.table_result)


def run_auto_baseline(
    baseline_input: UserInputConfig,
    compare_input: UserInputConfig,
    case_name: str = "auto_baseline",
    tolerance: float = 0.05,
) -> AutoBaselineResult:
    """Run ``baseline_input`` then ``compare_input`` and compare their total times.

    The relative difference ``(actual - baseline) / baseline`` passes when its
    absolute value is at most ``tolerance``.

    Args:
        baseline_input: Configuration whose total time is the reference.
        compare_input: Configuration measured against the reference.
        case_name: Label used in log output and in the returned result.
        tolerance: Maximum allowed ``abs(diff_pct)``, as a fraction (0.05 == 5%).

    Returns:
        An ``AutoBaselineResult``. A failure of either run, or a baseline time
        that is not strictly positive, is reported via ``passed=False`` and
        ``error`` rather than raised. The run failure's traceback is logged.
    """

    def _failure(error: str, baseline_time_s: float = 0.0, actual_time_s: float = 0.0) -> AutoBaselineResult:
        # Shared constructor for the early-exit paths, none of which computes a diff.
        return AutoBaselineResult(
            case_name=case_name,
            baseline_time_s=baseline_time_s,
            actual_time_s=actual_time_s,
            diff_pct=0.0,
            tolerance=tolerance,
            passed=False,
            error=error,
        )

    logger.info("\n%s", "=" * 60)
    logger.info("  AUTO BASELINE TEST: %s", case_name)
    logger.info("%s", "=" * 60)

    logger.info("[Run 1/2] Establishing baseline...")
    try:
        baseline_time_s = _run_single(baseline_input)
    except Exception as e:  # test boundary: turn any failure into a failed result
        # Keep the traceback; the result only carries a one-line message.
        logger.exception("  Baseline run failed")
        return _failure(f"Baseline run failed: {type(e).__name__}: {e}")
    logger.info("  Baseline time: %.3fms", baseline_time_s * 1000)

    logger.info("[Run 2/2] Running comparison...")
    try:
        actual_time_s = _run_single(compare_input)
    except Exception as e:  # test boundary: turn any failure into a failed result
        logger.exception("  Comparison run failed")
        return _failure(
            f"Comparison run failed: {type(e).__name__}: {e}",
            baseline_time_s=baseline_time_s,
        )
    logger.info("  Actual time:   %.3fms", actual_time_s * 1000)

    # `not x > 0` also rejects NaN, which `x <= 0` would let through to the division.
    if not baseline_time_s > 0:
        return _failure(
            f"Invalid baseline time: {baseline_time_s}",
            baseline_time_s=baseline_time_s,
            actual_time_s=actual_time_s,
        )

    diff_pct = (actual_time_s - baseline_time_s) / baseline_time_s
    return AutoBaselineResult(
        case_name=case_name,
        baseline_time_s=baseline_time_s,
        actual_time_s=actual_time_s,
        diff_pct=diff_pct,
        tolerance=tolerance,
        passed=abs(diff_pct) <= tolerance,
    )


def _print_result(result: AutoBaselineResult):
    """Log a human-readable summary of an auto-baseline result.

    If ``result.error`` is set, the error is logged and the numeric summary is
    skipped. A result outside tolerance also emits a WARNING that says whether
    the comparison run was slower or faster than the baseline.
    """
    logger.info("\n%s", "=" * 60)
    logger.info("  RESULT: %s", result.case_name)
    logger.info("%s", "=" * 60)

    if result.error:
        logger.error("  ERROR: %s", result.error)
        logger.info("  Status: FAIL")
        return

    logger.info("  Baseline:  %.3fms", result.baseline_time_s * 1000)
    logger.info("  Actual:    %.3fms", result.actual_time_s * 1000)
    logger.info("  Diff:      %+.2f%%", result.diff_pct * 100)
    # %g keeps fractional tolerances such as 2.5% readable (%.0f would show 2%).
    logger.info("  Tolerance: ±%g%%", result.tolerance * 100)
    logger.info("  Status:    %s", "PASS" if result.passed else "FAIL")

    if result.passed:
        return

    if result.diff_pct > 0:
        logger.warning(
            "  Performance regression: %.2f%% slower than baseline!",
            result.diff_pct * 100,
        )
    else:
        # The tolerance check is symmetric, so a large speedup also fails. It is
        # not a regression, though, so don't label it as one.
        logger.warning(
            "  Performance deviation: %.2f%% faster than baseline!",
            abs(result.diff_pct) * 100,
        )


def _qwen3_8b_decode_tp2_input(context_length: int) -> UserInputConfig:
    """Build the Qwen3-8B decode config (32 queries, TP=2, compiled, unquantized)."""
    return UserInputConfig(
        device="ATLAS_800_A2_376T_64G",
        model_id="Qwen/Qwen3-8B",
        num_queries=32,
        query_len=1,
        context_length=context_length,
        do_compile=True,
        decode=True,
        quantize_linear_action=QuantizeLinearAction.DISABLED,
        tp_size=2,
        world_size=2,
    )


# Registered scenarios; each becomes one parameterized `test_auto_baseline_<name>` test.
AUTO_BASELINE_CASES: list[AutoBaselineCase] = [
    AutoBaselineCase(
        name="qwen3-8B_auto",
        description="Qwen3-8B decode, baseline ctx=1536 vs compare ctx=1500, TP=2, compile",
        baseline_input=_qwen3_8b_decode_tp2_input(context_length=1536),
        compare_input=_qwen3_8b_decode_tp2_input(context_length=1500),
        tolerance=0.05,
    ),
]


class TestAutoBaseline(unittest.TestCase):
    """One test per entry of ``AUTO_BASELINE_CASES``; each must stay within tolerance."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # basicConfig does nothing if the root logger already has handlers configured.
        logging.basicConfig(
            level=logging.INFO,
            format="[%(levelname)s] [%(name)s] %(message)s",
        )

    @parameterized.expand(
        [(case.name, case) for case in AUTO_BASELINE_CASES],
        name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}",
    )
    def test_auto_baseline(self, _name: str, case: AutoBaselineCase):
        """Run the case's baseline and comparison configs and assert the result passed."""
        logger.info("Baseline config:  context_length=%d", case.baseline_input.context_length)
        logger.info("Compare config:   context_length=%d", case.compare_input.context_length)

        result = run_auto_baseline(
            baseline_input=case.baseline_input,
            compare_input=case.compare_input,
            case_name=case.name,
            tolerance=case.tolerance,
        )
        _print_result(result)

        if result.error:
            # On an aborted run, diff_pct is a 0.0 placeholder; report the error instead.
            failure_msg = f"[{case.name}] Auto baseline FAILED: error={result.error}"
        else:
            failure_msg = (
                f"[{case.name}] Auto baseline FAILED: "
                f"diff={result.diff_pct * 100:+.2f}%, "
                f"tolerance=±{result.tolerance * 100:g}%"
            )
        self.assertTrue(result.passed, failure_msg)