"""Compute M6: Empirical E2E Prediction Ratio (offline metric).

M6 = TC_empirical_hit_total / Real_per_forward_pass

Where:
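  - TC_empirical_hit_total = sum of durations of complete ("ph": "X") TC
    chrome trace events with dur > 0, a non-empty args.kernel_type, and
    args.source ∈ {MEASURED, INTERPOLATED} (empirical data, not analytic)
  - Real_per_forward_pass = compute + hcom kernel duration sum from a clean
    forward pass trace CSV (pre-extracted from kernel_details.csv); hcom rows
    duplicated across streams are counted once and AicpuKernel rows are
    excluded.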

M6 = 1.0 is perfect. >1 = overestimate, <1 = underestimate.

This tool computes the E2E ratio only. For per-kernel analysis, use
tools/perf_data_analysis/generate_op_comparison.py.

Usage:
    python3.10 tools/perf_data_analysis/compute_m6.py \\
        --tc-trace results/qwen3_dc_trace.json \\
        --prof-trace docs/perf_database/forward_pass_traces/qwen3-32b_dc_16tok.csv

    # Only count exact CSV matches (exclude interpolated):
    python3.10 tools/perf_data_analysis/compute_m6.py \\
        --tc-trace results/qwen3_dc_trace.json \\
        --prof-trace docs/perf_database/forward_pass_traces/qwen3-32b_dc_16tok.csv \\
        --source-filter MEASURED
"""

import argparse
import csv
import json
from collections import defaultdict
from pathlib import Path

# QuerySource names counted as "empirical" when no source filter is given.
# Immutable because the object is handed unchanged to callers and loaders;
# a mutable set could be altered by accident and change every later default.
DEFAULT_SOURCE_FILTER = frozenset({"MEASURED", "INTERPOLATED"})


def _sum_kernels_with_dedup(events: list) -> tuple:
    """Aggregate kernel durations, counting each duplicated hcom only once.

    kernel_details.csv lists every hcom twice (on Stream N/A and on a
    hardware stream) with the same start time and duration. hcom events are
    therefore keyed by (int(start), kernel_type); when a key repeats, only
    the largest duration seen for it contributes. Types ending in
    "AicpuKernel" are counted and summed separately. They contribute to
    neither compute_us, hcom_us, nor the per-type breakdown.

    Args:
        events: list of (start, end, kernel_type, input_shapes)

    Returns:
        (compute_us, hcom_us, aicpu_us, kernel_count, kernel_type_durations)
    """
    compute_total = 0.0
    hcom_total = 0.0
    aicpu_total = 0.0
    n_kernels = 0
    per_type: dict[str, float] = {}
    best_hcom: dict[tuple, float] = {}

    for begin, finish, kind, _shapes in events:
        span = finish - begin

        if kind.endswith("AicpuKernel"):
            n_kernels += 1
            aicpu_total += span
            continue

        if not kind.startswith("hcom_"):
            n_kernels += 1
            compute_total += span
            per_type[kind] = per_type.get(kind, 0.0) + span
            continue

        dedup_key = (int(begin), kind)
        if dedup_key not in best_hcom:
            # First sighting of this hcom: it counts as one kernel.
            best_hcom[dedup_key] = span
            hcom_total += span
            per_type[kind] = per_type.get(kind, 0.0) + span
            n_kernels += 1
        elif span > best_hcom[dedup_key]:
            # Duplicate with a longer duration: add only the difference so
            # the total reflects the max duration for this key.
            growth = span - best_hcom[dedup_key]
            hcom_total += growth
            per_type[kind] = per_type.get(kind, 0.0) + growth
            best_hcom[dedup_key] = span

    return compute_total, hcom_total, aicpu_total, n_kernels, per_type


def _load_tc_trace(
    tc_trace_path: Path,
    source_filter: set[str],
) -> float:
    """Load TC chrome trace, sum empirical HIT durations.

    Only counts events where args.source ∈ source_filter and dur > 0.

    Returns:
        empirical_hit_us
    """
    with tc_trace_path.open() as f:
        data = json.load(f)

    total = 0.0
    for event in data.get("traceEvents", []):
        if event.get("ph") != "X":
            continue
        args = event.get("args", {})
        source = args.get("source", "")
        if source not in source_filter:
            continue
        dur = event.get("dur", 0)
        if dur <= 0:
            continue
        if args.get("kernel_type", ""):
            total += dur

    return total


def _load_prof_trace(prof_trace_path: Path) -> tuple[float, float, float]:
    """Load a prof forward-pass CSV and sum kernel durations with hcom dedup.

    Reads the ``Type``, ``Start Time(us)`` and ``Duration(us)`` columns.
    Rows with an empty ``Type`` or a non-numeric start/duration are skipped.
    A blank start or duration cell is treated as 0. Rows are aggregated by
    ``_sum_kernels_with_dedup``, so AicpuKernel rows are left out of both
    totals and duplicated hcom rows are counted once.

    Returns:
        (real_per_fwd_us, compute_us, hcom_us), where
        real_per_fwd_us == compute_us + hcom_us.

    Raises:
        ValueError: if the CSV header lacks any of the required columns
            (including an empty file with no header row).
    """
    required_columns = ("Type", "Start Time(us)", "Duration(us)")
    events = []
    # utf-8-sig drops a leading BOM so the first header name compares equal;
    # newline="" is the mode the csv module documents for reading files.
    with prof_trace_path.open(encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        header = reader.fieldnames or []
        missing = [col for col in required_columns if col not in header]
        if missing:
            # A missing column would otherwise not raise. Rows would be skipped
            # (no Type) or parsed as 0 (no start/duration), which silently
            # zeroes the totals or merges every hcom of a type under one
            # dedup key.
            raise ValueError(
                f"Prof trace {prof_trace_path} is missing required column(s): "
                f"{', '.join(missing)}"
            )
        for row in reader:
            ktype = (row.get("Type") or "").strip()
            if not ktype:
                continue
            try:
                start = float((row.get("Start Time(us)") or "0").strip())
                dur = float((row.get("Duration(us)") or "0").strip())
            except ValueError:
                # Best effort: skip rows whose timing cells are not numeric.
                continue
            events.append((start, start + dur, ktype, ""))

    compute_us, hcom_us, _, _, _ = _sum_kernels_with_dedup(events)
    return compute_us + hcom_us, compute_us, hcom_us


def compute_m6(
    tc_trace: str | Path,
    prof_trace: str | Path,
    source_filter: set[str] | frozenset[str] | None = None,
) -> dict:
    """Compute M6 by comparing a TC chrome trace against a prof forward-pass trace.

    M6 = empirical_hit_us / real_per_fwd_us.

    Args:
        tc_trace: Path to TC chrome trace JSON (from --chrome-trace), as a
            ``str`` or ``pathlib.Path``.
        prof_trace: Path to clean forward pass CSV (pre-extracted), as a
            ``str`` or ``pathlib.Path``.
        source_filter: Set of QuerySource names to include from TC trace.
            Default: {"MEASURED", "INTERPOLATED"}.

    Returns:
        dict with m6_ratio, empirical_hit_us, real_per_fwd_us,
        selected_fwd_compute_us, selected_fwd_hcom_us, the two input paths
        and the sorted source_filter. Paths are stored as strings so the dict
        stays JSON-serialisable. Durations are rounded to 2 decimals. m6_ratio
        is left unrounded and is 0.0 when real_per_fwd_us is not positive.

    Raises:
        FileNotFoundError: if either trace path does not exist.
    """
    if source_filter is None:
        source_filter = DEFAULT_SOURCE_FILTER

    tc_trace_path = Path(tc_trace)
    prof_trace_path = Path(prof_trace)

    if not tc_trace_path.exists():
        raise FileNotFoundError(f"TC trace not found: {tc_trace_path}")
    if not prof_trace_path.exists():
        raise FileNotFoundError(f"Prof trace not found: {prof_trace_path}")

    empirical_hit_us = _load_tc_trace(tc_trace_path, source_filter)
    real_per_fwd_us, prof_compute, prof_hcom = _load_prof_trace(prof_trace_path)

    # Guard against an empty/zero prof trace instead of dividing by zero.
    m6_ratio = empirical_hit_us / real_per_fwd_us if real_per_fwd_us > 0 else 0.0

    return {
        "m6_ratio": m6_ratio,
        "empirical_hit_us": round(empirical_hit_us, 2),
        "real_per_fwd_us": round(real_per_fwd_us, 2),
        "selected_fwd_compute_us": round(prof_compute, 2),
        "selected_fwd_hcom_us": round(prof_hcom, 2),
        # str() keeps the result JSON-serialisable when Path objects are passed.
        "tc_trace": str(tc_trace),
        "prof_trace": str(prof_trace),
        "source_filter": sorted(source_filter),
    }


def _format_report(result: dict) -> str:
    """Render a ``compute_m6`` result dict as a fixed-width text report.

    Durations are printed in microseconds, right-aligned in a 12-character
    column with thousands separators and one decimal. The two headline
    totals are also shown in milliseconds. M6 is printed with 3 decimals.
    """

    def us_col(key: str) -> str:
        return f"{result[key]:>12,.1f}"

    def ms_val(key: str) -> str:
        return f"{result[key] / 1e3:,.1f}"

    banner = "=" * 60
    title_part = [banner, "M6: Empirical E2E Prediction Ratio", banner, ""]
    input_part = [
        f"TC trace:        {result['tc_trace']}",
        f"Prof trace:      {result['prof_trace']}",
        f"Source filter:   {result['source_filter']}",
        "",
    ]
    hit_key, real_key = "empirical_hit_us", "real_per_fwd_us"
    timing_part = [
        f"Empirical HIT total: {us_col(hit_key)} us ({ms_val(hit_key)} ms)",
        f"Real per-fwd:        {us_col(real_key)} us ({ms_val(real_key)} ms)",
        f"  Compute:           {us_col('selected_fwd_compute_us')} us",
        f"  hcom:              {us_col('selected_fwd_hcom_us')} us",
        "",
    ]
    ratio_part = [f"M6 = {result['m6_ratio']:.3f}  (Empirical / Real)"]

    return "\n".join([*title_part, *input_part, *timing_part, *ratio_part])


def build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser for the M6 tool.

    Options:
        --tc-trace (required): TC chrome trace JSON.
        --prof-trace (required): clean forward-pass CSV.
        --source-filter: comma-separated QuerySource names. Left as None
            when omitted, so the caller applies its default filter.
        --json-output: optional path for a JSON dump of the result.
    """
    parser = argparse.ArgumentParser(
        description="Compute M6: compare TC chrome trace vs profiling forward pass."
    )
    parser.add_argument(
        "--tc-trace",
        required=True,
        help="Path to TC chrome trace JSON (from --chrome-trace)",
    )
    parser.add_argument(
        "--prof-trace",
        required=True,
        help="Path to clean forward pass CSV (pre-extracted from kernel_details)",
    )
    parser.add_argument(
        "--source-filter",
        default=None,
        help="Comma-separated QuerySource names to include "
        "(default: MEASURED,INTERPOLATED). Use MEASURED to exclude interpolated.",
    )
    parser.add_argument(
        "--json-output",
        default=None,
        help="Optional path to also write the M6 result dict as JSON "
        "(parent directories are created as needed)",
    )
    return parser


def main() -> None:
    """CLI entry point: compute M6, print the report, optionally dump JSON.

    An empty --source-filter value (e.g. only commas) and missing trace files
    are reported through ``parser.error`` (usage message, exit status 2)
    rather than as a traceback.
    """
    parser = build_argparser()
    args = parser.parse_args()

    source_filter = DEFAULT_SOURCE_FILTER
    if args.source_filter:
        # Drop empty entries: a stray comma (e.g. "MEASURED,") would otherwise
        # put "" in the filter, and "" matches TC events with no args.source.
        source_filter = {
            name.strip() for name in args.source_filter.split(",") if name.strip()
        }
        if not source_filter:
            parser.error("--source-filter must name at least one QuerySource")

    try:
        result = compute_m6(
            tc_trace=args.tc_trace,
            prof_trace=args.prof_trace,
            source_filter=source_filter,
        )
    except FileNotFoundError as err:
        parser.error(str(err))

    print(_format_report(result))
    if args.json_output:
        output_path = Path(args.json_output)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # ensure_ascii=False emits raw non-ASCII text, so pin the encoding
        # instead of relying on the platform default.
        with output_path.open("w", encoding="utf-8") as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
        print(f"\nJSON output written to: {output_path}")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()