"""Inference latency benchmark for trained PhysicsResidualGRU."""
from __future__ import annotations
import argparse
import os
import time
import numpy as np
import torch
from models.physics_residual_gru import PhysicsResidualGRU
# Absolute directory that contains this script. Relative checkpoint paths
# given on the command line are resolved against it.
_SCRIPT_PATH = os.path.abspath(__file__)
PROJECT_DIR = os.path.dirname(_SCRIPT_PATH)
def resolve_ckpt_path(path_str: str) -> str:
    """Return the checkpoint path to load for a user-supplied path string.

    A leading ``~`` / ``~user`` is expanded first, so values such as
    ``--ckpt=~/models/best.pt`` (which the shell does not expand) work.

    Args:
        path_str: Absolute path, home-relative path, or path relative to
            the directory containing this script.

    Returns:
        The path unchanged (apart from ``~`` expansion) when it is absolute;
        otherwise the path joined onto ``PROJECT_DIR`` and normalized.
    """
    expanded = os.path.expanduser(path_str)
    if os.path.isabs(expanded):
        return expanded
    # Relative paths are anchored at the script directory, not the current
    # working directory, so the script behaves the same from any cwd.
    return os.path.normpath(os.path.join(PROJECT_DIR, expanded))
def _select_device() -> torch.device:
    """Return the CUDA device when it is actually usable, else the CPU."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    try:
        # Probe with a tiny allocation: this guards against CUDA being
        # reported as available while the first real allocation fails.
        torch.zeros(1, device="cuda")
    except Exception as exc:  # deliberate best-effort: any failure -> CPU
        import sys

        print(f"CUDA probe failed ({exc}); falling back to CPU.", file=sys.stderr)
        return torch.device("cpu")
    return torch.device("cuda")


def main(argv: list[str] | None = None) -> None:
    """Benchmark forward-pass latency of a trained PhysicsResidualGRU.

    Loads a checkpoint, rebuilds the model from the stored configuration,
    runs ``--warmup`` untimed forward passes followed by ``--runs`` timed
    ones on random inputs, and prints mean / P50 / P95 / P99 latency in
    milliseconds.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Raises:
        FileNotFoundError: If the resolved checkpoint path does not exist.
        SystemExit: If argument parsing or argument validation fails.
    """
    parser = argparse.ArgumentParser(description="Benchmark inference latency for PhysicsResidualGRU.")
    parser.add_argument("--ckpt", type=str, default=os.path.join(PROJECT_DIR, "checkpoints", "default", "best_model.pt"))
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--seq-len", type=int, default=21)
    parser.add_argument("--input-dim", type=int, default=17)
    parser.add_argument("--num-priors", type=int, default=3)
    parser.add_argument("--warmup", type=int, default=50)
    parser.add_argument("--runs", type=int, default=200)
    args = parser.parse_args(argv)

    # Fail early with a usage message instead of crashing later: zero runs
    # would leave the statistics below operating on an empty array.
    if args.runs < 1:
        parser.error("--runs must be >= 1")
    if args.warmup < 0:
        parser.error("--warmup must be >= 0")
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")
    if args.seq_len < 1:
        parser.error("--seq-len must be >= 1")

    ckpt_path = resolve_ckpt_path(args.ckpt)
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # NOTE(security): torch.load may unpickle arbitrary objects depending on
    # the installed version's ``weights_only`` default; only load checkpoints
    # from trusted sources.
    payload = torch.load(ckpt_path, map_location="cpu")
    # Tolerate checkpoints that store ``"config": None``.
    cfg = payload.get("config") or {}
    # Dimensions stored in the checkpoint take precedence over CLI values.
    input_dim = int(payload.get("input_dim", args.input_dim))
    num_priors = int(payload.get("num_priors", args.num_priors))

    model = PhysicsResidualGRU(
        input_dim=input_dim,
        hidden_dim=int(cfg.get("hidden_dim", 64)),
        num_layers=int(cfg.get("num_layers", 2)),
        dropout=float(cfg.get("dropout", 0.1)),
        num_priors=num_priors,
    )
    model.load_state_dict(payload["model_state_dict"])
    model.eval()

    device = _select_device()
    model.to(device)

    # Random inputs created directly on the target device, so no host/device
    # transfer is included in the timed region.
    x = torch.randn(args.batch_size, args.seq_len, input_dim, device=device)
    priors = torch.randn(args.batch_size, args.seq_len, num_priors, 2, device=device)
    prior_conf = torch.softmax(torch.randn(args.batch_size, num_priors, device=device), dim=-1)

    on_cuda = device.type == "cuda"
    times = []
    with torch.no_grad():
        for _ in range(args.warmup):
            model(x, priors, prior_conf)
        if on_cuda:
            # Drain queued warmup kernels before the first timed run.
            torch.cuda.synchronize()

        for _ in range(args.runs):
            t0 = time.perf_counter()
            model(x, priors, prior_conf)
            if on_cuda:
                # CUDA launches are asynchronous; wait for completion so the
                # measured interval covers the whole forward pass.
                torch.cuda.synchronize()
            t1 = time.perf_counter()
            times.append((t1 - t0) * 1000.0)  # seconds -> milliseconds

    arr = np.array(times, dtype=np.float64)
    print("=" * 60)
    print("Latency benchmark finished")
    print("Device:", device)
    print("Batch:", args.batch_size, "SeqLen:", args.seq_len, "InputDim:", input_dim, "NumPriors:", num_priors)
    print(f"Mean: {arr.mean():.4f} ms")
    print(f"P50: {np.percentile(arr, 50):.4f} ms")
    print(f"P95: {np.percentile(arr, 95):.4f} ms")
    print(f"P99: {np.percentile(arr, 99):.4f} ms")
    print("=" * 60)
# Run the benchmark only when this file is executed as a script; importing
# the module must not start it.
if __name__ == "__main__":
    main()