"""
"""
import math
import os
import torch
import pypto
import logging
import numpy as np
from numpy.testing import assert_allclose
def compare(t: torch.Tensor, t_ref: torch.Tensor, name, atol, rtol, max_error_ratio=0.005, max_error_count=10):
"""
比较两个张量的差异,超过阈值时打印错误点并抛出断言错误
Args:
t: 待比较张量
t_ref: 参考张量
name: 张量名称(用于日志)
atol: 绝对容差
rtol: 相对容差
max_error_ratio: 误差点占总元素数的最大比例
max_error_count: 显示的最大误差点数量(同时也是误差点阈值的上限)
"""
def check_is_nan_inf():
nan_mask = torch.isnan(t)
nan_count = nan_mask.sum().item()
inf_mask = torch.isinf(t)
inf_count = inf_mask.sum().item()
if nan_count > 0 or inf_count > 0:
error_msg = f"\n========== 张量 {name} 检测到非法值(禁止存在NaN/Inf)=========="
if nan_count > 0:
nan_positions = torch.nonzero(nan_mask, as_tuple=False)
show_nan_count = min(nan_count, max_error_count)
error_msg += f"\n- NaN数量:{nan_count},前 {show_nan_count} 个位置:"
for i in range(show_nan_count):
pos_tuple = tuple(p.item() for p in nan_positions[i])
error_msg += f"\n 位置 {pos_tuple}"
if inf_count > 0:
inf_positions = torch.nonzero(inf_mask, as_tuple=False)
show_inf_count = min(inf_count, max_error_count)
error_msg += f"\n- Inf数量:{inf_count},前 {show_inf_count} 个位置(值类型):"
for i in range(show_inf_count):
pos = inf_positions[i]
pos_tuple = tuple(p.item() for p in pos)
inf_val = t[pos_tuple].item()
inf_type = "+Inf" if inf_val == float('inf') else "-Inf"
error_msg += f"\n 位置 {pos_tuple}:{inf_type}"
error_msg += "\n" + "="*80 + "\n"
assert False, error_msg
check_is_nan_inf()
assert t.shape == t_ref.shape, f"张量形状不一致:t.shape={t.shape}, t_ref.shape={t_ref.shape}"
assert t.dtype == t_ref.dtype, f"张量数据类型不一致:t.dtype={t.dtype}, t_ref.dtype={t_ref.dtype}"
assert t.device == t_ref.device, f"张量设备不一致:t.device={t.device}, t_ref.device={t_ref.device}"
error_count_threshold = round(max_error_ratio * t_ref.numel())
diff_abs = (t - t_ref).abs()
tolerance = atol + rtol * t_ref.abs()
diff_mask = diff_abs > tolerance
error_count = diff_mask.sum().item()
max_diff, flat_max_pos = torch.max(diff_abs.flatten(), dim=0)
max_pos = torch.unravel_index(flat_max_pos, t.shape)
max_pos = tuple(idx.item() for idx in max_pos)
if error_count > 0:
print(f"\n========== 张量 {name} 存在 {error_count} 个误差点(阈值:{error_count_threshold})==========")
error_positions = torch.nonzero(diff_mask, as_tuple=False)
show_count = min(error_count, max_error_count)
print(f"显示前 {show_count} 个误差点(位置 | 待比较值 | 参考值 | 绝对误差 | 允许阈值):")
for i in range(show_count):
pos = error_positions[i]
pos_tuple = tuple(p.item() for p in pos)
t_val = t[pos_tuple].item()
t_ref_val = t_ref[pos_tuple].item()
diff_val = diff_abs[pos_tuple].item()
tol_val = tolerance[pos_tuple].item()
print(f" 位置 {pos_tuple}: {t_val:.8f} vs {t_ref_val:.8f} | 误差={diff_val:.8f} | 阈值={tol_val:.8f}")
print(f"\n最大误差点:位置 {max_pos} | 误差={max_diff.item():.8f} | 阈值={tolerance[max_pos].item():.8f}")
print("=" * 80 + "\n")
assert error_count <= error_count_threshold, \
(f"compare fail: {name}, max diff: {max_diff.item():.8f} at {max_pos}, "
f"error_count: {error_count}, error_count_threshold: {error_count_threshold}")
print("compare success !!!!")