# -*- coding: utf-8 -*-
"""Inference latency benchmark for trained PhysicsResidualGRU."""

from __future__ import annotations

import argparse
import os
import time

import numpy as np
import torch

from models.physics_residual_gru import PhysicsResidualGRU


# Absolute path of the directory that contains this file; used as the base
# for the default checkpoint location and for relative checkpoint paths.
PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))


def resolve_ckpt_path(path_str: str) -> str:
    """Return the checkpoint path to open for ``path_str``.

    An absolute path is handed back exactly as given (no normalization).
    A relative path is anchored at ``PROJECT_DIR`` rather than the current
    working directory, and the joined result is normalized with
    ``os.path.normpath``.
    """
    if not os.path.isabs(path_str):
        anchored = os.path.join(PROJECT_DIR, path_str)
        return os.path.normpath(anchored)
    return path_str


def _select_device() -> torch.device:
    """Return the CUDA device if it is reported available and usable, else CPU."""
    import sys  # local import: only needed for the fallback warning

    if not torch.cuda.is_available():
        return torch.device("cpu")
    try:
        # is_available() can be True while the runtime still fails to
        # initialize, so probe with a tiny allocation before committing.
        torch.zeros(1, device="cuda")
    except Exception as exc:  # noqa: BLE001
        # Best-effort fallback, but say so: otherwise the report would show
        # CPU latencies with no hint as to why the GPU was not used.
        print(
            f"Warning: CUDA is reported available but could not be used ({exc}); "
            "falling back to CPU.",
            file=sys.stderr,
        )
        return torch.device("cpu")
    return torch.device("cuda")


def _measure_latencies_ms(
    model: torch.nn.Module,
    inputs: tuple,
    warmup: int,
    runs: int,
    device: torch.device,
) -> list:
    """Run ``warmup`` untimed and ``runs`` timed forward passes; return per-run milliseconds."""
    on_cuda = device.type == "cuda"
    latencies_ms = []
    with torch.no_grad():
        for _ in range(warmup):
            model(*inputs)
        if on_cuda:
            # Drain queued warmup kernels so they are not billed to the first timed run.
            torch.cuda.synchronize()

        for _ in range(runs):
            start = time.perf_counter()
            model(*inputs)
            if on_cuda:
                # CUDA launches are asynchronous; wait for completion before
                # reading the clock so the full forward pass is measured.
                torch.cuda.synchronize()
            latencies_ms.append((time.perf_counter() - start) * 1000.0)
    return latencies_ms


def main():
    """Benchmark forward-pass latency of a trained PhysicsResidualGRU checkpoint.

    Parses command-line options, loads the checkpoint on CPU, rebuilds the
    model from the checkpoint's ``config`` entry (falling back to
    ``hidden_dim=64``, ``num_layers=2``, ``dropout=0.1``), moves it to CUDA
    when usable, and times ``--runs`` forward passes on random inputs after
    ``--warmup`` untimed passes. Mean, P50, P95 and P99 latencies in
    milliseconds are printed to stdout.

    ``input_dim`` and ``num_priors`` stored in the checkpoint take precedence
    over ``--input-dim`` and ``--num-priors``. A relative ``--ckpt`` is
    resolved by ``resolve_ckpt_path``.

    Raises:
        FileNotFoundError: If the resolved checkpoint path does not exist.
        KeyError: If the checkpoint has no ``model_state_dict`` entry.
        SystemExit: Via ``argparse`` when an option value is out of range.
    """
    parser = argparse.ArgumentParser(description="Benchmark inference latency for PhysicsResidualGRU.")
    parser.add_argument("--ckpt", type=str, default=os.path.join(PROJECT_DIR, "checkpoints", "default", "best_model.pt"))
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--seq-len", type=int, default=21)
    parser.add_argument("--input-dim", type=int, default=17)
    parser.add_argument("--num-priors", type=int, default=3)
    parser.add_argument("--warmup", type=int, default=50)
    parser.add_argument("--runs", type=int, default=200)
    args = parser.parse_args()

    # Reject sizes that would otherwise fail deep inside torch or produce an
    # empty timing array (NaN / error in the percentile report).
    for flag, value in (
        ("--batch-size", args.batch_size),
        ("--seq-len", args.seq_len),
        ("--runs", args.runs),
    ):
        if value < 1:
            parser.error(f"{flag} must be at least 1 (got {value})")
    if args.warmup < 0:
        parser.error(f"--warmup must be non-negative (got {args.warmup})")

    ckpt_path = resolve_ckpt_path(args.ckpt)
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # NOTE(security): torch.load relies on pickle; only benchmark checkpoints
    # that come from a trusted source.
    payload = torch.load(ckpt_path, map_location="cpu")
    if "model_state_dict" not in payload:
        raise KeyError(f"Checkpoint has no 'model_state_dict' entry: {ckpt_path}")
    cfg = payload.get("config", {})
    # Dimensions saved with the checkpoint win over the command-line values.
    input_dim = int(payload.get("input_dim", args.input_dim))
    num_priors = int(payload.get("num_priors", args.num_priors))

    model = PhysicsResidualGRU(
        input_dim=input_dim,
        hidden_dim=int(cfg.get("hidden_dim", 64)),
        num_layers=int(cfg.get("num_layers", 2)),
        dropout=float(cfg.get("dropout", 0.1)),
        num_priors=num_priors,
    )
    model.load_state_dict(payload["model_state_dict"])
    model.eval()

    device = _select_device()
    model.to(device)

    # Random stand-in inputs, created directly on the target device:
    #   x          -> (batch, seq_len, input_dim)
    #   priors     -> (batch, seq_len, num_priors, 2)
    #   prior_conf -> (batch, num_priors), softmax-normalized over the last dim
    x = torch.randn(args.batch_size, args.seq_len, input_dim, device=device)
    priors = torch.randn(args.batch_size, args.seq_len, num_priors, 2, device=device)
    prior_conf = torch.softmax(torch.randn(args.batch_size, num_priors, device=device), dim=-1)

    times = _measure_latencies_ms(model, (x, priors, prior_conf), args.warmup, args.runs, device)

    arr = np.array(times, dtype=np.float64)
    print("=" * 60)
    print("Latency benchmark finished")
    print("Device:", device)
    print("Batch:", args.batch_size, "SeqLen:", args.seq_len, "InputDim:", input_dim, "NumPriors:", num_priors)
    print(f"Mean: {arr.mean():.4f} ms")
    print(f"P50:  {np.percentile(arr, 50):.4f} ms")
    print(f"P95:  {np.percentile(arr, 95):.4f} ms")
    print(f"P99:  {np.percentile(arr, 99):.4f} ms")
    print("=" * 60)


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()