import logging
import torch
from torch import Tensor
# Upper bounds on the ratio (fused-operator error / baseline-operator error)
# that compare_data_with_double_pole accepts for each metric.
MARE_L1 = 5  # limit for the maximum-relative-error ratio
MERE_L1 = 1.5  # limit for the mean-relative-error ratio
RMSE_L1 = 1.5  # limit for the root-mean-square-error ratio
def allclose(a, b, atol, ratio):
    """Check that the fraction of mismatching elements stays below ``ratio``.

    An element mismatches when ``|a - b| > atol``. Both operands are moved to
    the CPU before they are subtracted.

    Args:
        a: First tensor.
        b: Second tensor; must have the same shape as ``a``.
        atol: Absolute tolerance applied element-wise.
        ratio: Exclusive upper bound on the fraction of mismatching elements.

    Returns:
        The result of ``mismatch_fraction < ratio`` (a boolean tensor). Empty
        inputs have a mismatch fraction of 0.

    Raises:
        ValueError: If ``a`` and ``b`` have different shapes.
    """
    if a.shape != b.shape:
        raise ValueError("The shape of a and b must be same.")
    mismatch = torch.abs(a.cpu() - b.cpu()) > atol
    mismatch_count = torch.sum(mismatch)
    # Clamp the denominator: for empty tensors 0 / 0 would be NaN, and NaN
    # compares False against any ratio, wrongly reporting a mismatch.
    mismatch_ratio = mismatch_count / max(a.numel(), 1)
    return mismatch_ratio < ratio
def compute_mare(actual: Tensor, golden: Tensor):
    """Compute the maximum relative error of ``actual`` against ``golden``.

    The per-element error is ``|actual - golden| / (|golden| + 1e-7)``.
    Elements whose denominator is not greater than 1e-7 use the absolute
    error ``|actual - golden|`` instead.

    Args:
        actual: Tensor under test.
        golden: Reference tensor with the same shape as ``actual``.

    Returns:
        The largest per-element error as a Python scalar.

    Raises:
        ValueError: If the two shapes differ.
    """
    if actual.shape != golden.shape:
        raise ValueError(f"actual shape {actual.shape} != golden shape {golden.shape}")

    abs_err = (actual - golden).abs()
    scale = golden.abs() + 1e-7
    # Fall back to the absolute error where the scale is just the epsilon.
    per_element = torch.where(scale > 1e-7, abs_err / scale, abs_err)
    return torch.max(per_element).item()
def compute_mere(actual: Tensor, golden: Tensor):
    """Compute the mean relative error of ``actual`` against ``golden``.

    The per-element error is ``|actual - golden| / (|golden| + 1e-7)``.
    Elements whose denominator is not greater than 1e-7 use the absolute
    error ``|actual - golden|`` instead.

    Args:
        actual: Tensor under test.
        golden: Reference tensor with the same shape as ``actual``.

    Returns:
        The mean of the per-element errors as a Python scalar.

    Raises:
        ValueError: If the two shapes differ.
    """
    if actual.shape != golden.shape:
        raise ValueError(f"actual shape {actual.shape} != golden shape {golden.shape}")

    abs_err = (actual - golden).abs()
    scale = golden.abs() + 1e-7
    # Fall back to the absolute error where the scale is just the epsilon.
    per_element = torch.where(scale > 1e-7, abs_err / scale, abs_err)
    return torch.mean(per_element).item()
def compute_rmse(actual: Tensor, golden: Tensor):
    """Compute the root-mean-square error of ``actual`` against ``golden``.

    Args:
        actual: Tensor under test.
        golden: Reference tensor with the same shape as ``actual``.

    Returns:
        ``sqrt(mean((actual - golden) ** 2))`` as a Python scalar.

    Raises:
        ValueError: If the two shapes differ.
    """
    if actual.shape != golden.shape:
        raise ValueError(f"actual shape {actual.shape} != golden shape {golden.shape}")

    residual = actual - golden
    mean_sq = torch.pow(residual, 2).mean()
    return mean_sq.sqrt().item()
def _error_ratio_ok(fused_err: float, npu_err: float, limit: float) -> bool:
    """Return True when the fused error is acceptable relative to the baseline error."""
    if npu_err != 0.0:
        return fused_err / npu_err <= limit
    # The baseline matches golden exactly, so the ratio is undefined; require
    # the fused error to be (almost) zero as well.
    return abs(fused_err - npu_err) < 1e-6


def compare_data_with_double_pole(tensor_msg: str, actual_fused: Tensor, actual_npu: Tensor, golden: Tensor):
    """Dual-benchmark comparison of a fused result against a baseline result.

    Both results are measured against ``golden`` with three metrics (maximum
    relative error, mean relative error, RMSE). For each metric the fused
    error divided by the baseline error must not exceed the matching limit
    (``MARE_L1``, ``MERE_L1``, ``RMSE_L1``). When the baseline error is
    exactly zero, the fused error must be below 1e-6 instead.

    Args:
        tensor_msg: Description of the tensor being compared; used as the
            prefix of the failure message.
        actual_fused: Result computed by the NPU fused operator.
        actual_npu: Result computed by the NPU small (non-fused) operators.
        golden: High-precision result computed on the CPU.

    Raises:
        AssertionError: If any metric's error ratio exceeds its limit. The
            exception is raised explicitly so the check still runs when
            Python is started with ``-O``, which strips ``assert`` statements.
    """
    # Bring both results onto golden's device before comparing.
    if actual_fused.device.type != golden.device.type:
        actual_fused = actual_fused.to(golden.device)
    if actual_npu.device.type != golden.device.type:
        actual_npu = actual_npu.to(golden.device)

    # Compute every metric in float32 regardless of the input dtypes.
    actual_fused = actual_fused.float()
    actual_npu = actual_npu.float()
    golden = golden.float()

    mare_fused = compute_mare(actual_fused, golden)
    mare_npu = compute_mare(actual_npu, golden)

    mere_fused = compute_mere(actual_fused, golden)
    mere_npu = compute_mere(actual_npu, golden)

    rmse_fused = compute_rmse(actual_fused, golden)
    rmse_npu = compute_rmse(actual_npu, golden)

    print_msg = (f"{tensor_msg}, mare_fused: {mare_fused}, mare_npu: {mare_npu}, mere_fused: {mere_fused},"
                 f" mere_npu: {mere_npu}, rmse_fused: {rmse_fused}, rmse_npu: {rmse_npu};")

    checks = (
        ("mare", mare_fused, mare_npu, MARE_L1),
        ("mere", mere_fused, mere_npu, MERE_L1),
        ("rmse", rmse_fused, rmse_npu, RMSE_L1),
    )
    for metric, fused_err, npu_err, limit in checks:
        if not _error_ratio_ok(fused_err, npu_err, limit):
            raise AssertionError(f"{print_msg} {metric} error ratio does not meet the requirement")