import os
import shutil
import unittest

import torch

from amct_pytorch.classic.graph_based.amct_pytorch.custom_op import (
    arq_cali_pytorch,
    arq_real_pytorch,
)
# Absolute, symlink-resolved directory that contains this test module.
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
# Device on which every tensor in these tests is allocated.
DEVICE = torch.device('cpu')
class TestArqOp(unittest.TestCase):
    """Unit tests for the ARQ custom ops.

    ``arq_cali_pytorch`` is checked for the scale and offset it returns and
    for how closely its output tracks the input.  ``arq_real_pytorch`` is
    checked for the exact int8 values it produces from a given scale/offset.
    """

    # Shared 4x12 fixture: an all -1 row, an all-zero row, an all +1 row and
    # a row cycling through -1, 0, 1.
    _DATA_LIST = [[-1.0] * 12, [0.0] * 12, [1.0] * 12, [-1.0, 0.0, 1.0] * 4]

    @classmethod
    def setUpClass(cls):
        """Create the scratch folder ``test_arq`` next to this file."""
        cls.temp_folder = os.path.join(CUR_DIR, 'test_arq')
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(cls.temp_folder, exist_ok=True)

    @classmethod
    def tearDownClass(cls):
        """Remove the scratch folder created in ``setUpClass``."""
        # rmtree runs synchronously and without a shell, so the folder is
        # gone before the process exits, no pipe is leaked, and paths with
        # spaces or shell metacharacters are handled.  ignore_errors keeps
        # cleanup best-effort, as it was before.
        shutil.rmtree(cls.temp_folder, ignore_errors=True)

    def _input_tensor(self):
        """Return a fresh tensor built from the shared fixture on DEVICE."""
        return torch.tensor(self._DATA_LIST, device=DEVICE)

    def _assert_all_within(self, diff, tol):
        """Assert that every element of ``diff`` is <= ``tol``.

        Works on any device (no ``.numpy()`` round-trip needed), and a NaN
        in ``diff`` fails the check because ``NaN <= tol`` is False.
        """
        self.assertTrue(torch.all(diff <= tol).item())

    def test_arq_cali_pytorch_channelwise_f_withoffset_f(self):
        '''channel_wise: F withoffset:F '''
        input_data = self._input_tensor()
        scale, offset, output_data = arq_cali_pytorch(input_data, 8, False, False)
        scale_expected = torch.tensor([0.007844], device=DEVICE)
        offset_expected = torch.tensor([0], dtype=torch.int8, device=DEVICE)
        self._assert_all_within(torch.abs(scale_expected - scale), 1e-4)
        self.assertEqual(offset_expected, offset)
        self._assert_all_within(torch.abs(output_data - input_data), 1e-2)

    def test_arq_cali_pytorch_channelwise_t_withoffset_f(self):
        '''channel_wise: T withoffset:F '''
        input_data = self._input_tensor()
        scale, offset, output_data = arq_cali_pytorch(input_data, 8, True, False)
        # One scale per row; the all-zero row is expected to get scale 1.0.
        scale_expected = torch.tensor(
            [0.007844, 1.000000, 0.007844, 0.007844], device=DEVICE)
        offset_expected = torch.tensor(
            [0, 0, 0, 0], dtype=torch.int32, device=DEVICE)
        self._assert_all_within(torch.abs(scale_expected - scale), 1e-4)
        self.assertTrue(torch.equal(offset_expected, offset))
        self._assert_all_within(torch.abs(output_data - input_data), 1e-2)

    def test_arq_cali_pytorch_channelwise_t_withoffset_t(self):
        '''channel_wise: T withoffset:T '''
        input_data = self._input_tensor()
        scale, offset, output_data = arq_cali_pytorch(input_data, 8, True, True)
        scale_expected = torch.tensor(
            [0.003923, 1.000000, 0.003923, 0.007844], device=DEVICE)
        offset_expected = torch.tensor(
            [127, -128, -128, -1], dtype=torch.int32, device=DEVICE)
        self._assert_all_within(torch.abs(scale_expected - scale), 1e-4)
        self.assertTrue(torch.equal(offset_expected, offset))
        self._assert_all_within(torch.abs(output_data - input_data), 1e-2)

    def test_arq_real_pytorch(self):
        '''arq_real_pytorch: per-row scale/offset -> exact int8 output '''
        input_data = self._input_tensor()
        # Same scale/offset values the channel-wise, with-offset calibration
        # test above expects.
        scale = torch.tensor([0.003923, 1.000000, 0.003923, 0.007844], device=DEVICE)
        offset = torch.tensor([127, -128, -128, -1], device=DEVICE, dtype=torch.int32)
        output_data = arq_real_pytorch(input_data, scale, offset, 8)
        out_expected = torch.tensor(
            [[-128] * 12, [-128] * 12, [127] * 12, [-128, -1, 126] * 4],
            dtype=torch.int8, device=DEVICE)
        self.assertTrue(torch.equal(out_expected, output_data))