"""
Precision check tool.
"""
import numpy as np
class PrecisionChecker:
    """
    Check precision between golden data and input data.

    Three metrics are computed over the flattened arrays and each one is
    compared against its own threshold:

    * cosine similarity, which must be strictly greater than ``cos_sim_thd``
    * relative L1 deviation ``sum(|input|) / sum(|golden|) - 1``, whose
      absolute value must be strictly less than ``l1_norm_thd``
    * KL divergence between the softmax of the golden data and the softmax
      of the input data, which must be strictly less than ``kl_dvg_thd``

    Args:
        cos_sim_thd: cos similarity threshold
        l1_norm_thd: l1 normalization threshold
        kl_dvg_thd: kl divergence threshold
    """

    def __init__(self, cos_sim_thd=0.999, l1_norm_thd=0.01, kl_dvg_thd=0.005):
        self.cos_sim_thd = cos_sim_thd
        self.l1_norm_thd = l1_norm_thd
        self.kl_dvg_thd = kl_dvg_thd

    def check_precision(self, golden_np, input_np):
        """
        Check precision between golden data and input data.

        Args:
            golden_np: golden data, shape (batch_size, seq_length, input_size)
            input_np: input data, shape (batch_size, seq_length, input_size)

        Returns:
            True when all three metrics satisfy their thresholds.

        Raises:
            ValueError: if the dtypes or the shapes of the two arrays differ.
            AssertionError: if any metric violates its threshold; the message
                lists every metric together with its required bound.
        """
        if golden_np.dtype != input_np.dtype:
            raise ValueError("The dtype of golden data is not the same as input data")
        if golden_np.shape != input_np.shape:
            raise ValueError("The shape of golden data is not the same as input data")

        # Compute every metric in float64: reductions performed in a narrow
        # dtype such as float16 overflow to inf/nan and make even identical
        # arrays fail the check.
        golden_flatten = np.asarray(golden_np, dtype=np.float64).ravel()
        input_flatten = np.asarray(input_np, dtype=np.float64).ravel()

        kl = self._kl_divergence(golden_flatten, input_flatten)
        cos = self._cosine_similarity_numpy(golden_flatten, input_flatten)

        golden_l1 = np.abs(golden_flatten).sum()
        input_l1 = np.abs(input_flatten).sum()
        if golden_l1 == 0:
            # Avoid 0/0: an all-zero golden matches only an all-zero input.
            l1_norm = 0.0 if input_l1 == 0 else np.inf
        else:
            l1_norm = input_l1 / golden_l1 - 1

        # The deviation is signed; compare its magnitude so that an input
        # whose overall magnitude is too small is rejected as well.
        if (cos > self.cos_sim_thd
                and abs(l1_norm) < self.l1_norm_thd
                and kl < self.kl_dvg_thd):
            return True

        raise AssertionError(f"Precision check failed: "
                             f"cos similarity={cos} (required>{self.cos_sim_thd}), "
                             f"l1_norm={l1_norm} (required |l1_norm|<{self.l1_norm_thd}), "
                             f"kl={kl} (required<{self.kl_dvg_thd}).")

    def _kl_divergence(self, golden_flatten, input_flatten):
        """
        Compute the KL divergence between flattened golden data and flattened input data.

        Both vectors are turned into probability distributions with a softmax
        taken over the last axis, then KL(softmax(golden) || softmax(input))
        is evaluated.

        Args:
            golden_flatten: flattened golden data, shape (batch_size * seq_length * input_size, )
            input_flatten: flattened input data, shape (batch_size * seq_length * input_size, )

        Returns:
            the sum of KL divergence across all elements
        """
        def log_softmax(x, axis=-1):
            # Subtract the maximum before exponentiating for numerical stability.
            x_max = np.max(x, axis=axis, keepdims=True)
            log_sum_exp = np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True))
            return x - x_max - log_sum_exp

        golden_flatten = np.asarray(golden_flatten, dtype=np.float64)
        input_flatten = np.asarray(input_flatten, dtype=np.float64)

        log_golden = log_softmax(golden_flatten, axis=-1)
        log_input = log_softmax(input_flatten, axis=-1)
        p = np.exp(log_golden)
        return np.sum(p * (log_golden - log_input))

    def _cosine_similarity_numpy(self, golden_flatten, input_flatten, axis=-1):
        """
        Compute cosine similarity between flattened golden data and flattened input data.

        Args:
            golden_flatten: flattened golden data, shape (batch_size * seq_length * input_size, )
            input_flatten: flattened input data, shape (batch_size * seq_length * input_size, )
            axis: along this axis to compute similarity, default is -1

        Returns:
            cosine similarity between flattened golden data and flattened input data;
            0 wherever either vector has zero norm
        """
        golden_flatten = np.asarray(golden_flatten, dtype=np.float64)
        input_flatten = np.asarray(input_flatten, dtype=np.float64)

        norm1 = np.linalg.norm(golden_flatten, axis=axis, keepdims=True)
        norm2 = np.linalg.norm(input_flatten, axis=axis, keepdims=True)
        norm_product = norm1 * norm2

        # Substitute 1 for zero denominators so the division is safe, then
        # force the similarity to 0 at those positions.
        zero_mask = norm_product == 0
        norm_product[zero_mask] = 1
        dot_product = np.sum(golden_flatten * input_flatten, axis=axis, keepdims=True)
        cosine_sim = dot_product / norm_product
        cosine_sim[zero_mask] = 0
        return np.squeeze(cosine_sim)