import numpy as np
from numpy.testing import assert_allclose
class Colors:
    """ANSI escape sequences used to style terminal output."""

    # Text attributes.
    RESET = "\x1b[0m"
    BOLD = "\x1b[1m"
    UNDERLINE = "\x1b[4m"

    # Foreground colors (codes 91-96).
    RED = "\x1b[91m"
    GREEN = "\x1b[92m"
    YELLOW = "\x1b[93m"
    BLUE = "\x1b[94m"
    PURPLE = "\x1b[95m"
    CYAN = "\x1b[96m"
def detailed_allclose_manual(
    cpu,
    npu,
    name,
    rtol=1e-3,
    atol=1e-3,
    max_prints=50,
    force_print_first_n=5,
):
    """Compare two arrays element by element and report every mismatch.

    This is a verbose, hand-written counterpart of ``np.allclose``: each
    element pair is checked for NaN and for exceeding the tolerance
    ``atol + rtol * |npu|``, offending elements are printed together with
    their multi-dimensional index, and a summary is printed at the end.

    Args:
        cpu: Array produced on the CPU side.
        npu: Array produced on the NPU side.
        name: Label forwarded to the summary printout.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        max_prints: Maximum number of abnormal elements to print.
        force_print_first_n: Number of leading elements that are always
            printed, whether or not they are abnormal.

    Returns:
        ``False`` immediately when the shapes differ. Otherwise ``True`` when
        no element was NaN or out of tolerance, ``False`` when only NaN
        elements were found and ``assert_allclose`` still passed (it treats
        NaNs at matching positions as equal by default).

    Raises:
        AssertionError: Raised by ``numpy.testing.assert_allclose`` after the
            report has been printed, when the arrays are not close.
    """
    if cpu.shape != npu.shape:
        print(f"错误: 形状不一致 - cpu {cpu.shape} vs npu {npu.shape}")
        return False

    total_elements = cpu.size
    abnormal_count = 0
    nan_count = 0
    exceed_tolerance_count = 0

    print(f"开始比较数组,形状: {cpu.shape}, 总元素数: {total_elements}")
    print(f"容差条件: rtol={rtol}, atol={atol}")
    print("=" * 80)

    YELLOW = '\033[93m'
    RESET = '\033[0m'

    cpu_flat = cpu.reshape(-1)
    npu_flat = npu.reshape(-1)

    def get_multi_index(flat_index, shape):
        """Convert a row-major flat index into a tuple index for ``shape``."""
        indices = []
        remaining = flat_index
        for dim in reversed(shape):
            indices.append(remaining % dim)
            remaining = remaining // dim
        return tuple(reversed(indices))

    _print_first_n(cpu_flat, npu_flat, get_multi_index, force_print_first_n, YELLOW, RESET)

    for flat_idx in range(total_elements):
        cpu_val = cpu_flat[flat_idx]
        npu_val = npu_flat[flat_idx]

        # Both values must be inspected: a NaN on the CPU side alone is just
        # as abnormal as one on the NPU side.
        if _is_nan(cpu_val, npu_val):
            abnormal_count += 1
            nan_count += 1
            if abnormal_count <= max_prints:
                # The multi-index is only needed for elements that get logged.
                multi_idx = get_multi_index(flat_idx, cpu.shape)
                _log_nan_error(multi_idx, cpu_val, npu_val, YELLOW, RESET)
        elif _is_above_tolerance(cpu_val, npu_val, rtol, atol):
            abnormal_count += 1
            exceed_tolerance_count += 1
            if abnormal_count <= max_prints:
                multi_idx = get_multi_index(flat_idx, cpu.shape)
                _log_tolerance_error(multi_idx, cpu_val, npu_val, rtol, atol, YELLOW, RESET)

    _print_summary(abnormal_count, nan_count, exceed_tolerance_count, total_elements, name)

    is_allclose = abnormal_count == 0
    print(f"\nnp.allclose 等价结果: {is_allclose}")

    # Emit the truncation note before asserting; otherwise the assertion
    # would raise first and the note would never be shown.
    if abnormal_count > max_prints:
        print(f"\n注意: 只显示了前 {max_prints} 个异常,共有 {abnormal_count} 个异常元素")

    # A single assertion is enough: it either raises or passes.
    assert_allclose(cpu, npu, rtol, atol)
    return is_allclose
def _is_nan(cpu_val, npu_val):
    """Report whether at least one of the two values is NaN."""
    cpu_is_nan = np.isnan(cpu_val)
    if cpu_is_nan:
        return cpu_is_nan
    return np.isnan(npu_val)
def _is_above_tolerance(cpu_val, npu_val, rtol, atol):
    """Report whether two values differ by more than the allowed tolerance.

    The allowed difference is ``atol + rtol * |npu_val|``.

    Args:
        cpu_val: Value from the CPU side.
        npu_val: Value from the NPU side; it scales the relative tolerance.
        rtol: Relative tolerance.
        atol: Absolute tolerance.

    Returns:
        ``True`` when the pair is out of tolerance. NaN pairs always yield
        ``False`` because they are counted separately by the caller.
        Infinite values are within tolerance only when both are the same
        infinity.
    """
    if np.isnan(cpu_val) or np.isnan(npu_val):
        return False

    # With an infinite operand the tolerance formula degenerates: the allowed
    # difference becomes inf (or NaN), so a mismatch could never be detected.
    # Infinities therefore only match an identical infinity.
    if np.isinf(cpu_val) or np.isinf(npu_val):
        return bool(cpu_val != npu_val)

    abs_diff = np.abs(cpu_val - npu_val)
    allowed_diff = atol + rtol * np.abs(npu_val)
    return bool(abs_diff > allowed_diff)
def _print_first_n(cpu_flat, npu_flat, get_multi_index, n, YELLOW, RESET):
    """Print the leading ``n`` element pairs regardless of their status.

    Nothing is printed when ``n`` is not positive. Each line shows the index
    returned by ``get_multi_index`` for the shape of ``cpu_flat``, both
    values and their absolute difference, wrapped in ``YELLOW``/``RESET``.
    A dashed separator line closes the section.
    """
    if n <= 0:
        return

    def render(value):
        # NaN gets a fixed label instead of a formatted number.
        if np.isnan(value):
            return "NaN"
        return f"{value:.6e}"

    print(f"{YELLOW}强制打印前 {n} 个元素:{RESET}")

    shown = min(n, len(cpu_flat))
    for position in range(shown):
        expected = cpu_flat[position]
        actual = npu_flat[position]
        location = get_multi_index(position, cpu_flat.shape)

        expected_text = render(expected)
        actual_text = render(actual)
        if np.isnan(expected) or np.isnan(actual):
            gap_text = "NaN"
        else:
            gap_text = f"{np.abs(expected - actual):.6e}"

        print(f"{YELLOW}索引 {location}: cpu={expected_text}, npu={actual_text}, 差值={gap_text}{RESET}")

    print("-" * 80)
def _log_nan_error(multi_idx, cpu_val, npu_val, YELLOW, RESET):
    """Print one report line for an element pair that contains NaN.

    ``YELLOW`` and ``RESET`` are accepted for signature parity with the
    other logging helpers but are not used in the output.
    """
    rendered = []
    for value in (cpu_val, npu_val):
        if np.isnan(value):
            rendered.append("NaN")
        else:
            rendered.append(f"{value:.6e}")
    cpu_text, npu_text = rendered
    print(f"索引 {multi_idx}: cpu={cpu_text}, npu={npu_text}, 差值=NaN")
def _log_tolerance_error(multi_idx, cpu_val, npu_val, rtol, atol, YELLOW, RESET):
    """Print one report line for an element pair that exceeds the tolerance.

    The line shows both values, their absolute difference and the allowed
    difference ``atol + rtol * |npu_val|``. ``YELLOW`` and ``RESET`` are
    accepted but not used in the output.
    """
    limit = atol + rtol * np.abs(npu_val)
    gap = np.abs(cpu_val - npu_val)
    message = (
        f"索引 {multi_idx}: cpu={cpu_val:.6e}, npu={npu_val:.6e}, "
        f"差值={gap:.6e}(超过容差{limit:.6e})"
    )
    print(message)
def _print_summary(abnormal_count, nan_count, exceed_tolerance_count, total_elements, name):
    """Print the aggregate statistics of a comparison run.

    Args:
        abnormal_count: Total number of abnormal elements.
        nan_count: Number of abnormal elements caused by NaN.
        exceed_tolerance_count: Number of elements outside the tolerance.
        total_elements: Number of elements that were compared.
        name: Label shown in the summary heading.
    """
    # Empty arrays have no elements to divide by; report 0% instead of
    # raising ZeroDivisionError.
    if total_elements:
        abnormal_percent = abnormal_count / total_elements * 100
    else:
        abnormal_percent = 0.0

    print("=" * 80)
    print(f"{Colors.BOLD}{Colors.PURPLE}{name} 比较结果统计:{Colors.RESET}")
    print(f"总元素数量: {total_elements}")
    print(f"异常元素数量: {abnormal_count}")
    print(f" - NaN 数量: {nan_count}")
    print(f" - 超出容差数量: {exceed_tolerance_count}")
    print(f"异常比例: {abnormal_percent:.4f}%")