import numpy as np
import torch
from torch_npu.testing.testcase import TestCase
class TestCase(TestCase):
def assert_acceptable_deviation(self, a, b, deviation):
if a.shape != b.shape:
self.fail("shape error")
if a.dtype != b.dtype:
self.fail("dtype error")
if a.dtype == torch.uint8:
result = np.abs(a.to(torch.int16) - b.to(torch.int16))
elif a.dtype == np.uint8:
result = np.abs(a.astype(np.int16) - b.astype(np.int16))
else:
result = np.abs(a - b)
if result.max() > deviation:
self.fail(f"result error, got deviation {result.max()}")