"""Generate per-operator TC vs Profiling comparison data.

Reads TC chrome traces (from --performance-model profiling --chrome-trace)
for exact per-op invocation counts and latencies. No layer_multiplier needed.

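Uses the pre-extracted forward-pass trace CSVs checked into the repository
as profiling ground truth. The TC chrome traces (generated by Step 2 with
--chrome-trace) and the M6 JSON files are read from results/ and must exist
before running this script.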

Usage:
    python3.10 tools/perf_data_analysis/generate_op_comparison.py \
        [--trace-dir docs/perf_database/forward_pass_traces] \
        [--data-dir <db_path>] \
        [--output results/op_comparison.json]

Writes per-kernel-type rows to results/op_comparison.json (override with
--output), comparing TC (HIT empirical + MISS analytic) against Profiling
per forward pass.
"""

import argparse
import ast
import csv
import json
import logging
import os
import sys
from collections import defaultdict
from typing import Any

import yaml


# Default profiling database directory (where op_mapping.yaml is looked up);
# overridable with --data-dir.
DEFAULT_DATA_DIR = (
    "tensor_cast/performance_model/profiling_database/data"
    "/ATLAS_800_A3_752T_128G_DIE/vllm_ascend/vllm0.18.0_torch2.9.0_cann8.5"
)

# Default directory of pre-extracted forward-pass trace CSVs used as profiling
# ground truth; overridable with --trace-dir.
TRACE_DIR = "docs/perf_database/forward_pass_traces"

# Report labels. Non-ASCII text is spelled with escape sequences; the English
# meaning of each label is given in the trailing comment.
# "TC not modeled": status for kernel types seen only in the profiling data.
TC_UNMODELED_STATUS = "TC\u672a\u5efa\u6a21"
# Em dash between the scenario name and the table title.
SCENARIO_SEPARATOR = "\u2014"
TC_LATENCY_HEADER = "TC(\u03bcs)"  # "TC(us)"
TC_COUNT_HEADER = "TC\u6b21\u6570"  # "TC count"
PROF_LATENCY_HEADER = "Prof(\u03bcs)"  # "Prof(us)"
PROF_COUNT_HEADER = "Prof\u6b21\u6570"  # "Prof count"
LATENCY_ERROR_HEADER = "\u5ef6\u8fdf\u8bef\u5dee"  # "latency error"
COUNT_ERROR_HEADER = "\u6b21\u6570\u8bef\u5dee"  # "count error"
PROF_SHARE_HEADER = "Prof\u5360\u6bd4"  # "Prof share"
STATUS_HEADER = "\u72b6\u6001"  # "status"

# TC chrome trace files (generated by Step 2 with --chrome-trace).
# Keys must match the scenario names used in get_scenarios().
TC_TRACE_FILES = {
    "Qwen3 PF": "results/qwen3_pf_trace.json",
    "Qwen3 DC": "results/qwen3_dc_trace.json",
    "DSv3 PF": "results/dsv3_pf_trace.json",
    "DSv3 DC": "results/dsv3_dc_trace.json",
}


def get_scenarios(trace_dir: str = TRACE_DIR):
    """Return the four benchmark scenario descriptors.

    Each descriptor is a dict with the keys:
      - ``name``: scenario label (also the key into TC_TRACE_FILES),
      - ``m6``: path of the scenario's M6 JSON,
      - ``trace_csv``: forward-pass profiling CSV inside *trace_dir*,
      - ``tc_trace``: path of the TC chrome trace.
    """
    # (label, M6 JSON path, profiling CSV file name within trace_dir)
    layout = (
        ("Qwen3 PF", "results/qwen3_prefill_m6.json", "qwen3-32b_pf_4112tok.csv"),
        ("Qwen3 DC", "results/qwen3_decode_m6.json", "qwen3-32b_dc_16tok.csv"),
        ("DSv3 PF", "results/dsv3_prefill_m6.json", "dsv3_pf_4099tok.csv"),
        ("DSv3 DC", "results/dsv3_decode_m6.json", "dsv3_dc_1tok.csv"),
    )
    return [
        {
            "name": label,
            "m6": m6_path,
            "trace_csv": f"{trace_dir}/{csv_name}",
            "tc_trace": TC_TRACE_FILES[label],
        }
        for label, m6_path, csv_name in layout
    ]


def parse_profiling_by_type(trace_csv: str) -> dict[str, dict[str, Any]]:
    """Parse a forward-pass trace CSV with hcom deduplication.

    Accepts a CSV file path directly (pre-extracted forward pass trace
    from docs/perf_database/forward_pass_traces/).

    Dedup strategy: group hcom_* kernels by (int(start_time), Type) and
    take max(duration) per group (parallel HCCL + AICPU streams).
    Non-hcom kernels are summed directly.
    """
    stats: dict[str, dict[str, Any]] = defaultdict(
        lambda: {"total_us": 0.0, "count": 0}
    )
    hcom_groups: dict[tuple[int, str], float] = {}

    with open(trace_csv) as f:
        for row in csv.DictReader(f):
            kt = row.get("Type", "").strip()
            if not kt:
                continue
            dur = float(row.get("Duration(us)", 0))

            if kt.startswith("hcom_"):
                start_str = row.get("Start Time(us)", "0").strip()
                try:
                    start_key = int(float(start_str))
                except ValueError:
                    start_key = 0
                key = (start_key, kt)
                hcom_groups[key] = max(hcom_groups.get(key, 0.0), dur)
            elif kt.endswith("AicpuKernel") or kt == "AicpuKernel":
                continue
            else:
                stats[kt]["total_us"] += dur
                stats[kt]["count"] += 1

    for (_, kt), dur in hcom_groups.items():
        stats[kt]["total_us"] += dur
        stats[kt]["count"] += 1

    return dict(stats)


def load_op_mapping(data_dir: str) -> dict:
    """Load op_mapping.yaml for func_name -> kernel_type resolution.

    Args:
        data_dir: Profiling database directory expected to contain
            ``op_mapping.yaml``.

    Returns:
        The parsed mapping, or ``{}`` when the file is absent or its
        document is empty/falsy.

    Raises:
        ValueError: if the YAML document's top level is not a mapping,
            since every consumer looks entries up with ``.get``.
    """
    mapping_path = os.path.join(data_dir, "op_mapping.yaml")
    if not os.path.exists(mapping_path):
        # Best-effort: without a mapping, MISS ops fall back to their op name.
        return {}
    with open(mapping_path, encoding="utf-8") as f:
        mapping = yaml.safe_load(f)
    if not mapping:
        return {}
    if not isinstance(mapping, dict):
        raise ValueError(
            f"{mapping_path} must contain a YAML mapping at the top level, "
            f"got {type(mapping).__name__}"
        )
    return mapping


def _get_op_entry(func_name: str, op_mapping: dict) -> dict | None:
    """Look up op_mapping entry for a func_name."""
    short = func_name.replace("torch.ops.", "")
    entry = op_mapping.get(short) or op_mapping.get(func_name)
    if entry and isinstance(entry, dict):
        return entry
    return None


def resolve_kernel_type(func_name: str, op_mapping: dict) -> str | None:
    """Return the ``kernel_type`` op_mapping assigns to *func_name*, else None."""
    entry = _get_op_entry(func_name, op_mapping)
    if not entry:
        return None
    return entry.get("kernel_type")


def is_zero_cost(func_name: str, op_mapping: dict) -> bool:
    """Return True when op_mapping flags *func_name* with a truthy ``zero_cost``."""
    entry = _get_op_entry(func_name, op_mapping)
    if not entry:
        return False
    return bool(entry.get("zero_cost"))


def _find_profiling_pid(trace_path: str, events: list[dict]) -> Any:
    """Return the PID of the profiling process in a chrome trace.

    Prefers a ``process_name`` metadata event whose name mentions "profiling"
    or "empirical" (the first match wins). Otherwise falls back to the single
    PID owning complete ("X") events, printing a warning to stderr.

    Raises:
        ValueError: if no named profiling process exists and complete events
            span zero or several PIDs.
    """
    for ev in events:
        if ev.get("ph") == "M" and ev.get("name") == "process_name":
            name_str = ev.get("args", {}).get("name", "").lower()
            if "profiling" in name_str or "empirical" in name_str:
                return ev["pid"]

    # Single-model trace: accept the only PID that has complete events.
    pids = {ev["pid"] for ev in events if ev.get("ph") == "X"}
    if len(pids) != 1:
        raise ValueError(
            f"Cannot identify profiling PID in {trace_path}. Found PIDs: {pids}"
        )
    (pid,) = pids
    print(
        f"Warning: profiling PID not found by name in {trace_path}, "
        f"using sole PID {pid}",
        file=sys.stderr,
    )
    return pid


def _parse_sub_kernel_durations(raw: Any) -> list[tuple[str, float]] | None:
    """Parse the ``sub_kernel_durations`` chrome-trace arg into (name, dur) pairs.

    Accepts either a Python-literal string such as ``"[('MatMul', 1.5)]"`` or
    an already-decoded JSON list of ``[name, duration]`` pairs. Returns None
    when the value is missing, empty, unparseable, or not a list of
    (str, number) pairs, so callers fall back to an equal split.
    """
    if not raw:
        return None
    parsed = raw
    if isinstance(raw, str):
        try:
            # literal_eval only accepts Python literals, never executes code.
            parsed = ast.literal_eval(raw)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
    if not isinstance(parsed, list):
        return None
    pairs: list[tuple[str, float]] = []
    for item in parsed:
        if not (isinstance(item, (list, tuple)) and len(item) == 2):
            return None
        name, dur = item
        is_number = isinstance(dur, (int, float)) and not isinstance(dur, bool)
        if not isinstance(name, str) or not is_number:
            return None
        pairs.append((name, dur))
    return pairs or None


def _add_tc_sample(
    stats: dict[str, dict[str, Any]],
    kernel_type: str,
    fn_short: str,
    dur_us: float,
    counter: str,
) -> None:
    """Accumulate one invocation into ``stats[kernel_type]``.

    *counter* names the outcome field to bump: "hit_count" or "miss_count".
    """
    entry = stats[kernel_type]
    entry[counter] += 1
    entry["total_us"] += dur_us
    entry["count"] += 1
    if fn_short not in entry["func_names"]:
        entry["func_names"].append(fn_short)


def extract_tc_from_chrome_trace(
    trace_path: str, op_mapping: dict
) -> dict[str, dict[str, Any]]:
    """Extract per-kernel-type TC stats from a chrome trace.

    The chrome trace contains one complete ("X") event per process_op call
    across all layers - exact invocation counts, no layer_multiplier needed.
    Both the JSON-object form (``{"traceEvents": [...]}``) and the bare
    JSON-array form are accepted. Only events of the profiling process are
    counted, and ops flagged ``zero_cost`` in op_mapping are skipped.

    Attribution per event:
      - HIT: ``args["kernel_type"]`` is set and not "?"; it may list several
        comma-separated sub-kernels. Durations come from
        ``args["sub_kernel_durations"]`` when present and well-formed;
        otherwise the event duration is split equally across the listed
        sub-kernels, and all_reduce ops additionally get a count-only
        ``hcom_allReduce_`` entry (legacy MC2 handling).
      - MISS: no usable kernel_type; resolved via op_mapping, falling back
        to the op name itself.

    Returns:
        ``{kernel_type: {"total_us": float, "count": int, "hit_count": int,
        "miss_count": int, "func_names": [...], "status": str}}`` where
        status is "HIT" (all invocations HIT), "MISS" (none HIT) or
        "PARTIAL" (a mix).

    Raises:
        ValueError: if the profiling process cannot be identified.
    """
    with open(trace_path, encoding="utf-8") as f:
        trace = json.load(f)
    events = trace if isinstance(trace, list) else trace.get("traceEvents", [])
    profiling_pid = _find_profiling_pid(trace_path, events)

    stats: dict[str, dict[str, Any]] = defaultdict(
        lambda: {
            "total_us": 0.0,
            "count": 0,
            "hit_count": 0,
            "miss_count": 0,
            "func_names": [],
        }
    )

    for ev in events:
        if ev.get("ph") != "X" or ev.get("pid") != profiling_pid:
            continue

        op_name = ev["name"]
        if is_zero_cost(op_name, op_mapping):
            continue
        dur_us = ev.get("dur", 0)
        args = ev.get("args", {})
        fn_short = op_name.replace("torch.ops.", "")

        # HIT: the profiling data source stamped a kernel_type ("?" = unknown).
        kt_from_args = args.get("kernel_type")
        is_hit = bool(kt_from_args) and kt_from_args != "?"
        sub_kts = (
            [s.strip() for s in str(kt_from_args).split(",") if s.strip()]
            if is_hit
            else []
        )
        skd = (
            _parse_sub_kernel_durations(args.get("sub_kernel_durations"))
            if is_hit
            else None
        )

        if skd:
            # Precise per-sub-kernel latencies from a composite lookup.
            skd_names = {sk_name for sk_name, _ in skd}
            if skd_names != set(sub_kts):
                logging.getLogger(__name__).debug(
                    "sub_kernel_durations names %s != kernel_type names %s for %s",
                    skd_names,
                    set(sub_kts),
                    op_name,
                )
            for sk_name, sk_dur in skd:
                _add_tc_sample(stats, sk_name, fn_short, sk_dur, "hit_count")
        elif sub_kts:
            # No per-sub-kernel durations: split the event duration equally.
            per_sub_dur = dur_us / len(sub_kts)
            for sub_kt in sub_kts:
                _add_tc_sample(stats, sub_kt, fn_short, per_sub_dur, "hit_count")
            # Legacy MC2 handling: when sub_kernel_durations is absent, add a
            # count-only entry (zero duration) for hcom_allReduce_.
            if "all_reduce" in op_name.lower() and "hcom_allReduce_" not in sub_kts:
                _add_tc_sample(stats, "hcom_allReduce_", fn_short, 0, "hit_count")
        else:
            # MISS (or a kernel_type arg naming no kernels, which previously
            # dropped the op's duration): resolve via op_mapping, fall back
            # to the op name.
            kt = resolve_kernel_type(op_name, op_mapping) or op_name
            _add_tc_sample(stats, kt, fn_short, dur_us, "miss_count")

    for kt_stats in stats.values():
        if kt_stats["hit_count"] and kt_stats["miss_count"]:
            kt_stats["status"] = "PARTIAL"
        elif kt_stats["hit_count"]:
            kt_stats["status"] = "HIT"
        else:
            kt_stats["status"] = "MISS"

    return dict(stats)


def build_comparison(scenario: dict, op_mapping: dict) -> list[dict]:
    """Build per-kernel-type comparison rows for one scenario: TC vs Profiling.

    TC numbers come from the scenario's chrome trace (exact per-invocation
    counts). Profiling numbers come from its forward-pass trace CSV, divided
    by ``n_forward_passes`` from the scenario's M6 JSON (default 1).

    Args:
        scenario: Descriptor with the keys ``m6``, ``tc_trace`` and
            ``trace_csv``.
        op_mapping: Parsed op_mapping.yaml contents.

    Returns:
        One row per kernel type, sorted by descending
        ``max(tc_total_us, prof_per_fwd_us)``; ties keep alphabetical
        kernel_type order. Error percentages are 0 when the profiling
        baseline is 0.
    """
    # Only trace-csv mode (pre-extracted single forward pass CSVs) is
    # supported, so n_forward_passes normally defaults to 1. If compute_m6.py
    # scenario mode is used in the future, it must export "n_forward_passes".
    with open(scenario["m6"]) as f:
        m6_data = json.load(f)
    n_fwd = m6_data.get("n_forward_passes", 1)

    tc_stats = extract_tc_from_chrome_trace(scenario["tc_trace"], op_mapping)
    prof_stats = parse_profiling_by_type(scenario["trace_csv"])

    def relative_error_pct(measured, reference):
        # A zero baseline has no meaningful relative error; report 0.
        if reference > 0:
            return (measured - reference) / reference * 100
        return 0

    comparison = []
    for kernel_type in sorted(tc_stats.keys() | prof_stats.keys()):
        if kernel_type in tc_stats:
            tc_info = tc_stats[kernel_type]
            status = tc_info["status"]
        else:
            tc_info = {"total_us": 0, "count": 0, "func_names": []}
            status = TC_UNMODELED_STATUS
        prof_info = prof_stats.get(kernel_type, {"total_us": 0, "count": 0})

        if n_fwd > 0:
            prof_us = prof_info["total_us"] / n_fwd
            prof_count = prof_info["count"] / n_fwd
        else:
            prof_us = 0
            prof_count = 0

        tc_us = tc_info["total_us"]
        tc_count = tc_info["count"]
        comparison.append(
            {
                "kernel_type": kernel_type,
                "tc_funcs": ", ".join(tc_info["func_names"][:3]),
                "tc_total_us": round(tc_us, 1),
                "tc_count": tc_count,
                "prof_per_fwd_us": round(prof_us, 1),
                "prof_count": int(round(prof_count)),
                "error_pct": round(relative_error_pct(tc_us, prof_us), 1),
                "count_error_pct": round(
                    relative_error_pct(tc_count, prof_count), 1
                ),
                "status": status,
            }
        )

    comparison.sort(
        key=lambda row: max(row["tc_total_us"], row["prof_per_fwd_us"]),
        reverse=True,
    )
    return comparison


def _print_scenario_table(name: str, rows: list[dict], max_rows: int = 15) -> None:
    """Print the first *max_rows* comparison rows of a scenario plus totals.

    Totals and each row's profiling share are computed over all *rows*, not
    just the printed ones.
    """
    total_tc = sum(r["tc_total_us"] for r in rows)
    total_prof = sum(r["prof_per_fwd_us"] for r in rows)

    print(f"\n{'=' * 115}")
    print(f"{name} {SCENARIO_SEPARATOR} TC vs Profiling per-forward")
    print(f"{'=' * 115}")
    print(
        f"{'kernel_type':<35} {TC_LATENCY_HEADER:>12} {TC_COUNT_HEADER:>7} "
        f"{PROF_LATENCY_HEADER:>12} {PROF_COUNT_HEADER:>8} "
        f"{LATENCY_ERROR_HEADER:>9} {COUNT_ERROR_HEADER:>9} "
        f"{PROF_SHARE_HEADER:>8} {STATUS_HEADER:<6}"
    )
    print("-" * 115)
    for r in rows[:max_rows]:
        pf_pct = r["prof_per_fwd_us"] / total_prof * 100 if total_prof > 0 else 0
        print(
            f"{r['kernel_type']:<35} {r['tc_total_us']:>12.1f} "
            f"{r['tc_count']:>7d} "
            f"{r['prof_per_fwd_us']:>12.1f} {r['prof_count']:>8.0f} "
            f"{r['error_pct']:>8.1f}% {r['count_error_pct']:>8.1f}% "
            f"{pf_pct:>7.1f}% {r['status']:<6}"
        )
    print(f"{'TOTAL':<35} {total_tc:>12.1f} {'':>7} {total_prof:>12.1f}")


def main():
    """CLI entry point: compare TC vs Profiling per scenario and save JSON.

    Scenarios whose input files are missing are skipped with a "SKIP"
    message; the results of the remaining scenarios are still written.
    """
    parser = argparse.ArgumentParser(
        description="Generate per-operator TC vs Profiling comparison data"
    )
    parser.add_argument(
        "--trace-dir",
        default=TRACE_DIR,
        help=(
            "Directory with forward-pass trace CSVs "
            "(default: docs/perf_database/forward_pass_traces)"
        ),
    )
    parser.add_argument(
        "--data-dir",
        default=DEFAULT_DATA_DIR,
        help="Path to profiling database directory (with op_mapping.yaml)",
    )
    parser.add_argument(
        "--output",
        default="results/op_comparison.json",
        help="Output JSON path",
    )
    args = parser.parse_args()

    op_mapping = load_op_mapping(args.data_dir)

    all_results = {}
    for sc in get_scenarios(args.trace_dir):
        name = sc["name"]
        try:
            rows = build_comparison(sc, op_mapping)
        except FileNotFoundError as e:
            print(f"SKIP {name}: {e}")
            continue

        all_results[name] = rows
        _print_scenario_table(name, rows)

    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    # ensure_ascii=False emits non-ASCII status strings (TC_UNMODELED_STATUS)
    # verbatim, so pin UTF-8 instead of relying on the platform locale, which
    # may not be able to encode them.
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)
    print(f"\nSaved to {args.output}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()