"""Custom assertions for tensor and latency checks."""
import math
import torch
def assert_tensor_close(
actual: torch.Tensor,
expected: torch.Tensor,
*,
rtol: float = 1e-5,
atol: float = 1e-8,
equal_nan: bool = False,
) -> None:
"""Assert two tensors are element-wise close (torch.testing.assert_close semantics)."""
torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan)
def assert_latency_within(
actual_ms: float,
expected_ms: float,
*,
metric: str = "latency",
tolerance_ms: float | None = None,
rel_tolerance: float = 0.0,
) -> None:
"""Assert latency is within absolute and/or relative tolerance.
When ``tolerance_ms`` and ``rel_tolerance`` are both unset/zero, uses exact
match with a small float epsilon for deterministic unit tests. Callers
comparing noisy measurements must pass ``tolerance_ms`` and/or ``rel_tolerance``.
"""
if actual_ms < 0 or expected_ms < 0:
raise AssertionError(f"Latency must be non-negative, got actual={actual_ms}, expected={expected_ms}")
if tolerance_ms is None and rel_tolerance == 0:
if not math.isclose(actual_ms, expected_ms, rel_tol=0.0, abs_tol=1e-9):
delta = abs(actual_ms - expected_ms)
raise AssertionError(
f"{metric} out of range: metric={metric}, baseline={expected_ms}, "
f"actual={actual_ms}, delta={delta}, allowed=1e-9"
)
return
allowed_abs = 0.0 if tolerance_ms is None else tolerance_ms
if rel_tolerance > 0:
allowed_delta = max(allowed_abs, rel_tolerance * max(abs(actual_ms), abs(expected_ms)))
else:
allowed_delta = allowed_abs
delta = abs(actual_ms - expected_ms)
if delta > allowed_delta:
raise AssertionError(
f"{metric} out of range: metric={metric}, baseline={expected_ms}, "
f"actual={actual_ms}, delta={delta}, allowed={allowed_delta}"
)