import torch
def get_abs_err(x, y):
    """Return the largest element-wise absolute difference between x and y.

    Both inputs are detached before subtracting, so the computation is not
    tracked by autograd. The result is converted to a Python scalar with
    ``.item()``.
    """
    delta = x.detach() - y.detach()
    worst = delta.flatten().abs().max()
    return worst.item()
def get_err_ratio(x, y):
    """Return the RMS of (x - y) divided by the RMS of x.

    Both inputs are detached first. A small constant (1e-8) is added to the
    denominator so that an all-zero ``x`` does not cause a division by zero.
    """
    reference = x.detach()
    diff_rms = (reference - y.detach()).flatten().square().mean().sqrt().item()
    ref_rms = reference.flatten().square().mean().sqrt().item()
    return diff_rms / (ref_rms + 1e-8)
def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
    """Check that ``tri`` is close to ``ref``.

    The check passes when either
      * the maximum absolute difference (``get_abs_err``) is at most
        ``err_atol``, or
      * the relative RMS error (``get_err_ratio``) is strictly below ``ratio``.

    Args:
        prefix: Label placed at the start of the failure message
            (right-aligned to 16 characters).
        ref: Reference tensor.
        tri: Tensor under test.
        ratio: Exclusive upper bound on the relative RMS error.
        warning: When true, a failed check emits a warning through
            ``warnings.warn`` instead of raising.
        err_atol: Absolute-difference threshold at or below which the
            tensors are accepted without looking at the ratio.

    Raises:
        AssertionError: If the check fails and ``warning`` is false.
    """
    abs_atol = get_abs_err(ref, tri)
    # Compute the ratio once and reuse it for both the message and the check.
    error_rate = get_err_ratio(ref, tri)
    msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {error_rate:.6f}"

    if abs_atol <= err_atol:
        return
    # Written as a positive test so that a NaN ratio (every comparison with
    # NaN is false) falls through to the failure path.
    if error_rate < ratio:
        return

    if warning:
        import warnings

        # stacklevel=2 attributes the warning to the caller of assert_close.
        warnings.warn(msg, stacklevel=2)
        return
    # Explicit raise rather than `assert`, so the check is not stripped when
    # Python runs with -O.
    raise AssertionError(msg)
def print_diff(name, ref, tri, atol=0.005):
    """Print the maximum absolute difference between ``ref`` and ``tri``.

    A second line is printed when that maximum is strictly greater than
    ``atol``. Nothing is returned.
    """
    worst = (ref - tri).abs().max().item()
    print(f"[{name}] Max absolute difference: {worst:.6f}")
    if not worst > atol:
        return
    print(f"Exceeds tolerance ({atol})!")