"""
"""
import os
import math
import logging
import torch
import torch_npu
import numpy as np
from numpy.testing import assert_allclose
import pypto
small_value_thres_dict = {
torch.float16: 2**-11,
torch.bfloat16: 2**-8,
torch.float32: 2**-14,
torch.uint8: 2**-4, torch.float8_e4m3fn: 2**-4
}
small_value_error_thres_dict = {
torch.float16: 2**-16,
torch.bfloat16: 2**-16,
torch.float32: 2**-30,
torch.uint8: 2**-6, torch.float8_e4m3fn: 2**-6
}
def get_split_index(golden_data, dtype):
    """Split the golden tensor into large- and small-magnitude element masks.

    Args:
        golden_data: Reference tensor whose element magnitudes drive the split.
        dtype: Key into ``small_value_thres_dict`` selecting the cut-off.

    Returns:
        A ``(large_mask, small_mask, thres)`` tuple. ``large_mask`` is True
        where ``|golden_data| >= thres``, ``small_mask`` is True where
        ``|golden_data| < thres``, and ``thres`` is the cut-off that was used.

    Raises:
        KeyError: If ``dtype`` has no entry in ``small_value_thres_dict``.
    """
    cutoff = small_value_thres_dict[dtype]
    magnitude = torch.abs(golden_data)
    # Both comparisons are evaluated explicitly (instead of negating one mask)
    # so that NaN elements, which fail both comparisons, land in neither mask.
    return magnitude >= cutoff, magnitude < cutoff, cutoff
def compute_matrix_small_value(input_data, golden_data, dtype, small_mask):
    """Count small-value elements whose absolute error exceeds the tolerance.

    Args:
        input_data: Tensor under test.
        golden_data: Reference tensor, indexed with the same mask.
        dtype: Key into ``small_value_error_thres_dict`` selecting the
            absolute-error tolerance.
        small_mask: Boolean mask selecting the elements to check.

    Returns:
        The number of masked elements where ``|input - golden|`` is strictly
        greater than the tolerance, or ``0`` when the mask selects nothing.
    """
    if small_mask.any():
        tolerance = small_value_error_thres_dict[dtype]
        deviation = (input_data[small_mask] - golden_data[small_mask]).abs()
        return (deviation > tolerance).sum().item()
    # Nothing selected: the tolerance table is not consulted at all.
    return 0
def compute_matrix_large_value(input_data, golden_data, large_mask):
    """Compute error statistics over the elements selected by ``large_mask``.

    Args:
        input_data: Tensor under test.
        golden_data: Reference tensor, indexed with the same mask.
        large_mask: Boolean mask selecting the elements to evaluate.

    Returns:
        A ``(mare, mere, rmse)`` tuple of Python numbers: the maximum relative
        error, the mean relative error, and the root-mean-square of the raw
        difference. ``(0, 0, 0)`` is returned when the mask selects nothing.
    """
    if not large_mask.any():
        return 0, 0, 0

    reference = golden_data[large_mask]
    delta = input_data[large_mask] - reference
    # 1e-7 is added to the denominator to guard against division by zero.
    rel_err = delta.abs() / (reference.abs() + 1e-7)

    max_rel = rel_err.max().item()
    mean_rel = rel_err.mean().item()
    root_mean_sq = delta.pow(2).mean().sqrt().item()
    return max_rel, mean_rel, root_mean_sq
def compute_re_matrix(input_value, bm_value, small_value_thres):
    """Return the ratio of an error metric to its benchmark counterpart.

    Args:
        input_value: Error metric measured for the data under test.
        bm_value: The same metric measured for the benchmark.
        small_value_thres: Lower bound applied to the denominator.

    Returns:
        ``1`` when ``bm_value`` is infinite or NaN (checked first);
        ``1000`` when ``input_value`` is infinite or NaN;
        otherwise ``input_value / max(bm_value, small_value_thres)``.
    """
    # The benchmark check takes precedence over the input check.
    if not math.isfinite(bm_value):
        return 1
    if not math.isfinite(input_value):
        return 1000
    # Clamp the denominator from below so a tiny benchmark error does not
    # blow the ratio up.
    denominator = max(bm_value, small_value_thres)
    return input_value / denominator
def compute_re_triplet_matrix(npu_matrix, golden_matrix, small_value_thres):
    """Apply ``compute_re_matrix`` to each metric of two (mare, mere, rmse) triplets.

    Args:
        npu_matrix: Three-element sequence ``(mare, mere, rmse)`` for the data
            under test.
        golden_matrix: Three-element sequence ``(mare, mere, rmse)`` for the
            benchmark.
        small_value_thres: Lower bound forwarded to ``compute_re_matrix``.

    Returns:
        A ``(mare_ratio, mere_ratio, rmse_ratio)`` tuple.

    Raises:
        ValueError: If either input does not unpack into exactly three values.
    """
    # Unpack both inputs up front so a wrong-length argument fails before any
    # ratio is computed.
    test_mare, test_mere, test_rmse = npu_matrix
    ref_mare, ref_mere, ref_rmse = golden_matrix
    metric_pairs = (
        (test_mare, ref_mare),
        (test_mere, ref_mere),
        (test_rmse, ref_rmse),
    )
    return tuple(
        compute_re_matrix(measured, reference, small_value_thres)
        for measured, reference in metric_pairs
    )
def precision_compare_triple(npu_data, bm_data, golden_data, thres=(2, 1.2, 1.2)):
    """Compare NPU output against a benchmark, both measured relative to golden data.

    The errors of ``npu_data`` and ``bm_data`` with respect to ``golden_data``
    are computed separately, and the NPU result passes when its errors are not
    worse than the benchmark's errors by more than the allowed ratios.

    Args:
        npu_data: Tensor produced by the implementation under test. Its dtype
            selects the thresholds used for the comparison.
        bm_data: Tensor produced by the benchmark implementation.
        golden_data: Reference tensor.
        thres: Maximum allowed ``(mare, mere, rmse)`` ratios of NPU error to
            benchmark error.

    Returns:
        ``(result, mare_matrix, mere_matrix, rmse_matrix, small_value_matrix)``
        where ``result`` is ``"PASS"`` or ``"FAILED"``, the three ``*_matrix``
        values are the NPU/benchmark error ratios for large-magnitude
        elements, and ``small_value_matrix`` is the ratio of small-value error
        counts (benchmark count clamped to at least 1).

    Raises:
        NotImplementedError: If ``npu_data`` has dtype ``torch.int8`` or
            ``torch.int32``.
    """
    dtype = npu_data.dtype
    # A torch.dtype never compares equal to a string, so the integer guard
    # has to be expressed with torch.dtype objects in order to ever fire.
    if dtype in (torch.int8, torch.int32):
        raise NotImplementedError("precision compare triplet only support float")

    if dtype == torch.uint8:
        # uint8 inputs are decoded as hifloat8 through the NPU cast op.
        # NOTE(review): presumably uint8 is the storage container for
        # hifloat8 here - confirm against callers.
        npu_data = torch_npu.npu_dtype_cast(npu_data, torch.float32, input_dtype=torch_npu.hifloat8)
        bm_data = torch_npu.npu_dtype_cast(bm_data, torch.float32, input_dtype=torch_npu.hifloat8)
        golden_data = torch_npu.npu_dtype_cast(golden_data, torch.float32, input_dtype=torch_npu.hifloat8)
    else:
        npu_data = npu_data.to(torch.float32)
        bm_data = bm_data.to(torch.float32)
        golden_data = golden_data.to(torch.float32)

    npu_data = npu_data.cpu()
    bm_data = bm_data.cpu()
    golden_data = golden_data.cpu()

    # Thresholds are looked up with the ORIGINAL dtype, not float32.
    large_value_idx, small_value_idx, small_value_thres = get_split_index(golden_data, dtype)

    # Small-magnitude elements: compare counts of out-of-tolerance elements.
    npu_error_count = compute_matrix_small_value(npu_data, golden_data, dtype, small_value_idx)
    bm_error_count = compute_matrix_small_value(bm_data, golden_data, dtype, small_value_idx)
    # Clamp the benchmark count to 1 to avoid dividing by zero.
    small_value_matrix = npu_error_count / max(bm_error_count, 1)

    # Large-magnitude elements: compare relative-error statistics.
    mare_npu, mere_npu, rmse_npu = compute_matrix_large_value(npu_data, golden_data, large_value_idx)
    mare_bm, mere_bm, rmse_bm = compute_matrix_large_value(bm_data, golden_data, large_value_idx)
    mare_matrix, mere_matrix, rmse_matrix = compute_re_triplet_matrix(
        [mare_npu, mere_npu, rmse_npu], [mare_bm, mere_bm, rmse_bm], small_value_thres)

    is_mare_acceptable = mare_matrix <= thres[0]
    is_mere_acceptable = mere_matrix <= thres[1]
    is_rmse_acceptable = rmse_matrix <= thres[2]
    # The small-value error count may be at most twice the benchmark's.
    if small_value_matrix <= 2 and is_mare_acceptable and is_mere_acceptable and is_rmse_acceptable:
        result = "PASS"
    else:
        result = "FAILED"
    return result, mare_matrix, mere_matrix, rmse_matrix, small_value_matrix