import copy
import os
import shutil
import unittest
from collections import defaultdict
import numpy as np
import onnx
import onnxruntime as ort
import torch
from torch import nn
from amct_pytorch.classic.graph_based.amct_pytorch.custom_op.utils import (
copy_tensor,
)
from amct_pytorch.classic.graph_based.amct_pytorch.nn.module.quantization.conv1d import (
Conv1dQAT,
)
from amct_pytorch.classic.graph_based.amct_pytorch.nn.module.quantization.conv2d import (
Conv2dQAT,
)
from amct_pytorch.classic.graph_based.amct_pytorch.nn.module.quantization.conv3d import (
Conv3dQAT,
)
from amct_pytorch.classic.graph_based.amct_pytorch.nn.module.quantization.conv_transpose_1d import (
ConvTranspose1dQAT,
)
from amct_pytorch.classic.graph_based.amct_pytorch.nn.module.quantization.conv_transpose_2d import (
ConvTranspose2dQAT,
)
from amct_pytorch.classic.graph_based.amct_pytorch.nn.module.quantization.linear import (
LinearQAT,
)
from amct_pytorch.classic.graph_based.amct_pytorch.nn.module.quantization.matmul import (
MatMulQAT,
)
from amct_pytorch.classic.graph_based.amct_pytorch.utils.log import LOGGER
# Key of the activation ("data") section inside a QAT layer config dict.
RETRAIN_DATA_CONFIG = 'retrain_data_config'

# Directory that contains this test module; temporary artefacts are created
# underneath it.
CUR_DIR = os.path.dirname(os.path.realpath(__file__))

# Six layer configs, indexed positionally by the ConvTranspose1d tests
# (one entry per converted layer).
quant_configs = [
    # 0: both sections empty.
    {
        'retrain_weight_config': {},
        RETRAIN_DATA_CONFIG: {},
    },
    # 1: ulq weights, batch_num only.
    {
        'retrain_weight_config': {'weights_retrain_algo': 'ulq_retrain'},
        RETRAIN_DATA_CONFIG: {'batch_num': 3},
    },
    # 2: channel_wise disabled, explicit clip range.
    {
        'retrain_weight_config': {'channel_wise': False},
        RETRAIN_DATA_CONFIG: {
            'batch_num': 3,
            'clip_min': -1.0,
            'clip_max': 1.0,
        },
    },
    # 3: ulq weights, channel_wise disabled, fixed_min enabled.
    {
        'retrain_weight_config': {
            'weights_retrain_algo': 'ulq_retrain',
            'channel_wise': False,
        },
        RETRAIN_DATA_CONFIG: {
            'batch_num': 3,
            'fixed_min': True,
        },
    },
    # 4: ulq weights, channel_wise disabled, clip range, fixed_min disabled.
    {
        'retrain_weight_config': {
            'weights_retrain_algo': 'ulq_retrain',
            'channel_wise': False,
        },
        RETRAIN_DATA_CONFIG: {
            'batch_num': 3,
            'clip_min': -1.0,
            'clip_max': 1.0,
            'fixed_min': False,
        },
    },
    # 5: explicit INT8 dst_type on both sections.
    {
        'retrain_weight_config': {
            'dst_type': 'INT8',
            'weights_retrain_algo': 'ulq_retrain',
        },
        RETRAIN_DATA_CONFIG: {
            'dst_type': 'INT8',
            'batch_num': 3,
            'clip_min': -1.0,
            'clip_max': 1.0,
            'fixed_min': False,
        },
    },
]

# File name of the ONNX model exported by the MatMulQAT tests.
MATMUL_QAT_ONNX = 'matmul_qat.onnx'
# onnxruntime execution provider name passed to InferenceSession.
CPUEXECUTIONPROVIDER = 'CPUExecutionProvider'
def similarity(data0, data1):
    """Return the cosine similarity of two arrays, expressed as a percentage.

    NaN entries in either input are treated as 1 before comparing. The inputs
    themselves are never modified: all work happens on float64 copies, which
    also keeps the intermediate products from overflowing a narrower dtype.

    Args:
        data0: array-like of numbers.
        data1: array-like of numbers, broadcast-compatible with ``data0``.

    Returns:
        ``100`` when the (NaN-replaced) inputs are element-wise equal,
        ``0`` when the similarity is undefined (NaN), otherwise the cosine
        similarity multiplied by 100. The value may be infinite if it still
        overflows after rescaling.
    """
    # np.array(...) always copies, so the caller's buffers stay untouched.
    lhs = np.array(data0, dtype=np.float64)
    rhs = np.array(data1, dtype=np.float64)
    lhs[np.isnan(lhs)] = 1
    rhs[np.isnan(rhs)] = 1

    def _cosine_percent(vec0, vec1):
        # Overflow / 0-by-0 are expected here and handled by the caller via
        # the NaN/inf checks, so silence the floating point warnings.
        with np.errstate(all='ignore'):
            dot = np.sum(vec0 * vec1)
            norms = np.sqrt(np.sum(vec0 ** 2)) * np.sqrt(np.sum(vec1 ** 2))
            return dot / norms * 100

    cos_sim = _cosine_percent(lhs, rhs)
    if (lhs == rhs).all():
        # Covers identical inputs, including all-zero data where the ratio
        # itself is 0/0.
        cos_sim = 100

    # Cosine similarity is scale invariant, so shrinking both operands lets
    # us retry when the products overflowed. The divisor must be a float:
    # np.power(10, 38) is evaluated in int64 and silently wraps around.
    for _ in range(2):
        if not (np.isnan(cos_sim) or np.isinf(cos_sim)):
            break
        lhs = lhs / 1e38
        rhs = rhs / 1e38
        cos_sim = _cosine_percent(lhs, rhs)

    if np.isnan(cos_sim):
        cos_sim = 0
    return cos_sim
class TestQatOp(unittest.TestCase):
    """Checks of the shared QAT configuration handling, exercised via Conv2dQAT."""

    # Log file (relative to the working directory) that setUp resets and that
    # the "init_failed" tests inspect after a config has been rejected.
    _LOG_PATH = './amct_log/amct_pytorch.classic.graph_based.amct_pytorch.log'

    @classmethod
    def setUpClass(cls):
        """No class-level fixtures are needed."""

    @classmethod
    def tearDownClass(cls):
        """No class-level cleanup is needed."""

    def setUp(self):
        """Start every test from a fresh log file."""
        if os.path.exists(self._LOG_PATH):
            os.remove(self._LOG_PATH)
        LOGGER.logi('amct_pytorch.classic.graph_based.amct_pytorch.log is initialized successfully.')

    def _assert_config_keys_logged(self, quant_conf):
        """Assert that every key of both config sections appears in the log.

        NOTE(review): this presumes the validator logs each offending key
        before raising -- confirm against the QAT base implementation.
        """
        with open(self._LOG_PATH, 'r') as log_file:
            log_content = log_file.read()
        for section in (RETRAIN_DATA_CONFIG, 'retrain_weight_config'):
            for item in quant_conf.get(section):
                self.assertIn(item, log_content)

    def test_down(self):
        """Placeholder test with no assertions."""

    def test_qat_base_init_success(self):
        """A valid clip range is stored and both bit widths default to 8."""
        quant_conf = {
            RETRAIN_DATA_CONFIG: {
                'clip_max': 1.0,
                'clip_min': -1.0
            }
        }
        mod = Conv2dQAT(1, 1, 1, config=quant_conf)
        self.assertIsInstance(mod.retrain_data_config, dict)
        self.assertIsInstance(mod.retrain_weight_config, dict)
        self.assertEqual(mod.retrain_data_config.get('clip_max'), 1.0)
        self.assertEqual(mod.retrain_data_config.get('clip_min'), -1.0)
        self.assertEqual(mod.act_num_bits, 8)
        self.assertEqual(mod.wts_num_bits, 8)

    def test_qat_base_init_failed_wrong_config_data_type(self):
        """Config items of the wrong Python type are rejected and logged."""
        wrong_quant_conf = {
            RETRAIN_DATA_CONFIG: {
                'dst_type': 1,
                'batch_num': '1',
                'fixed_min': 1,
                'clip_max': '1.0',
                'clip_min': '-1.0'
            },
            'retrain_weight_config': {
                'dst_type': 1,
                'weights_retrain_algo': 1,
                'channel_wise': 1
            }
        }
        with self.assertRaises(ValueError):
            Conv2dQAT(1, 1, 1, config=wrong_quant_conf)
        # Checked after the context manager so the assertions actually run.
        self._assert_config_keys_logged(wrong_quant_conf)

    def test_qat_base_init_failed_clip_max_min_not_both_set(self):
        """Construct a layer with only clip_max set.

        NOTE(review): despite the name, no exception is expected here --
        confirm whether an assertRaises was intended.
        """
        wrong_quant_conf = {
            RETRAIN_DATA_CONFIG: {
                'clip_max': 1.0,
            }
        }
        Conv2dQAT(1, 1, 1, config=wrong_quant_conf)

    def test_qat_base_init_failed_wrong_config_data_scope(self):
        """Config values outside the accepted range are rejected."""
        wrong_quant_conf = {
            RETRAIN_DATA_CONFIG: {
                'dst_type': 1,
            },
            'retrain_weight_config': {
                'dst_type': 1,
                'weights_retrain_algo': '1',
            }
        }
        with self.assertRaises(ValueError):
            Conv2dQAT(1, 1, 1, config=wrong_quant_conf)
        self._assert_config_keys_logged(wrong_quant_conf)

        out_of_scope_data_configs = (
            {'batch_num': -1},
            {'clip_min': 1.0, 'clip_max': 1.0},
            {'clip_min': -1.0, 'clip_max': -1.0},
        )
        for data_config in out_of_scope_data_configs:
            with self.subTest(data_config=data_config):
                with self.assertRaises(ValueError):
                    Conv2dQAT(1, 1, 1, config={RETRAIN_DATA_CONFIG: data_config})

    def test_qat_base_register_qat_params(self):
        """A freshly built layer exposes the expected initial QAT buffers."""
        mod = Conv2dQAT(1, 1, 1)
        self.assertEqual(mod.cur_batch, torch.tensor(0))
        self.assertEqual(mod.acts_clip_max, torch.tensor(1.0))
        self.assertEqual(mod.acts_clip_min, torch.tensor(-1.0))
        self.assertTrue(np.isnan(mod.acts_clip_max_pre))
        self.assertTrue(np.isnan(mod.acts_clip_min_pre))
        self.assertTrue(np.isnan(mod.acts_scale))
        self.assertEqual(mod.acts_offset_deploy, torch.tensor([0]))
        self.assertTrue(np.isnan(mod.wts_scales))
        self.assertTrue(np.isnan(mod.wts_offsets))
        self.assertEqual(mod.wts_offsets_deploy, torch.tensor(0))
        self.assertEqual(mod.s_rec_flag, torch.tensor(False))

    def test_get_ori_op_params_redundant_params(self):
        """from_float fails when a required parameter is missing on the op."""
        bak_required_params = copy.deepcopy(Conv2dQAT._required_params)
        Conv2dQAT._required_params = ('err_param',)
        try:
            with self.assertRaises(RuntimeError):
                Conv2dQAT.from_float(torch.nn.Conv2d(1, 1, 1))
        finally:
            # Always undo the class-level patch, even when the assertion
            # fails, so later tests see the real required parameters.
            Conv2dQAT._required_params = bak_required_params

    def test_check_qat_config_mismatch_dst_type(self):
        """INT16 for both activations and weights is rejected."""
        wrong_quant_conf = {
            RETRAIN_DATA_CONFIG: {
                'dst_type': 'INT16',
            },
            'retrain_weight_config': {
                'dst_type': 'INT16',
            }
        }
        with self.assertRaises(ValueError):
            Conv2dQAT(1, 1, 1, config=wrong_quant_conf)

    def test_qat_init_d16w8(self):
        """INT16 activations with INT8 weights yield 16/8 bit widths."""
        quant_conf = {
            RETRAIN_DATA_CONFIG: {
                'dst_type': 'INT16',
            },
            'retrain_weight_config': {
                'dst_type': 'INT8',
            }
        }
        mod = Conv2dQAT(1, 1, 1, config=quant_conf)
        self.assertEqual(mod.act_num_bits, 16)
        self.assertEqual(mod.wts_num_bits, 8)
class TestConv2dQAT(unittest.TestCase):
    """Construction, conversion and forward-pass tests for Conv2dQAT."""

    @classmethod
    def setUpClass(cls):
        """No class-level fixtures are needed."""

    @classmethod
    def tearDownClass(cls):
        """No class-level cleanup is needed."""

    def setUp(self):
        """No per-test fixtures are needed."""

    def _assert_forward_yields_output(self, qat_mod):
        """Run one NCHW batch through ``qat_mod`` and require a result."""
        batch = torch.randn((3, 3, 224, 224))
        result = qat_mod.forward(batch)
        self.assertIsNotNone(result)

    def test_down(self):
        """Placeholder test with no assertions."""

    def test_conv2d_qat_from_float_success(self):
        """A zero-padded nn.Conv2d converts to Conv2dQAT."""
        float_conv = torch.nn.Conv2d(1, 1, 1, padding_mode='zeros')
        converted = Conv2dQAT.from_float(float_conv)
        self.assertIsInstance(converted, Conv2dQAT)

    def test_conv2d_qat_from_float_failed_padding_mode_not_zeros(self):
        """Conversion rejects padding modes other than 'zeros'."""
        float_conv = torch.nn.Conv2d(1, 1, 1, padding_mode='reflect')
        self.assertRaises(ValueError, Conv2dQAT.from_float, float_conv)

    def test_conv2d_qat_from_float_failed_ori_op_not_conv2d(self):
        """Conversion rejects a source op that is not nn.Conv2d."""
        float_conv3d = torch.nn.Conv3d(1, 1, 1)
        self.assertRaises(TypeError, Conv2dQAT.from_float, float_conv3d)

    def test_conv2d_qat_forward(self):
        """Default-configured layer produces an output."""
        self._assert_forward_yields_output(Conv2dQAT(3, 16, 1))

    def test_conv2d_qat_forward_ulq_retrain(self):
        """Layer configured with the ulq_retrain weight algo produces an output."""
        ulq_config = {
            'retrain_enable': True,
            'retrain_weight_config': {'weights_retrain_algo': 'ulq_retrain'},
        }
        self._assert_forward_yields_output(Conv2dQAT(3, 16, 1, config=ulq_config))

    def test_conv2d_qat_forward_do_init_false(self):
        """Forward works with do_init switched off."""
        qat_mod = Conv2dQAT(3, 16, 1)
        qat_mod.do_init = False
        self._assert_forward_yields_output(qat_mod)

    def test_conv2d_qat_forward_do_init_false_cur_batch_one(self):
        """Forward works with do_init off and cur_batch set to 1."""
        qat_mod = Conv2dQAT(3, 16, 1)
        qat_mod.do_init = False
        copy_tensor(qat_mod.cur_batch, torch.tensor(1))
        self._assert_forward_yields_output(qat_mod)

    def test_conv2d_qat_forward_do_init_false_cur_batch_two(self):
        """Forward works with do_init off and cur_batch set to 2."""
        qat_mod = Conv2dQAT(3, 16, 1)
        qat_mod.do_init = False
        copy_tensor(qat_mod.cur_batch, torch.tensor(2))
        self._assert_forward_yields_output(qat_mod)

    def test_conv2d_qat_unsupport_shape_inputs(self):
        """A 3-D input is rejected with RuntimeError."""
        qat_mod = Conv2dQAT(3, 16, 1)
        with self.assertRaises(RuntimeError):
            qat_mod.forward(torch.randn((3, 224, 224)))
class TestConvTranspose2dQAT(unittest.TestCase):
    """Construction, conversion and forward-pass tests for ConvTranspose2dQAT."""

    @classmethod
    def setUpClass(cls):
        """No class-level fixtures are needed."""

    @classmethod
    def tearDownClass(cls):
        """No class-level cleanup is needed."""

    def setUp(self):
        """No per-test fixtures are needed."""

    def _assert_forward_yields_output(self, qat_mod):
        """Run one NCHW batch through ``qat_mod`` and require a result."""
        batch = torch.randn((3, 3, 224, 224))
        result = qat_mod.forward(batch)
        self.assertIsNotNone(result)

    def test_down(self):
        """Placeholder test with no assertions."""

    def test_conv_transpose_2d_qat_from_float_success(self):
        """A zero-padded nn.ConvTranspose2d converts to ConvTranspose2dQAT."""
        float_op = torch.nn.ConvTranspose2d(1, 1, 1, padding_mode='zeros')
        converted = ConvTranspose2dQAT.from_float(float_op)
        self.assertIsInstance(converted, ConvTranspose2dQAT)

    def test_conv_transpose_2d_qat_from_float_failed_ori_op_not_conv_transpose_2d(self):
        """Conversion rejects a source op that is not nn.ConvTranspose2d."""
        float_conv = torch.nn.Conv2d(1, 1, 1)
        self.assertRaises(TypeError, ConvTranspose2dQAT.from_float, float_conv)

    def test_conv_transpose_2d_qat_forward(self):
        """Default-configured layer produces an output."""
        self._assert_forward_yields_output(ConvTranspose2dQAT(3, 16, 1))

    def test_conv_transpose_2d_qat_padding_mode_not_zero(self):
        """Construction rejects padding modes other than 'zeros'."""
        self.assertRaises(
            ValueError, ConvTranspose2dQAT, 1, 1, 1, padding_mode='reflect')

    def test_conv_transpose_2d_qat_forward_ulq_retrain(self):
        """Layer configured with the ulq_retrain weight algo produces an output."""
        ulq_config = {
            'retrain_enable': True,
            'retrain_weight_config': {'weights_retrain_algo': 'ulq_retrain'},
        }
        self._assert_forward_yields_output(
            ConvTranspose2dQAT(3, 16, 1, config=ulq_config))

    def test_conv_transpose_2d_qat_forward_do_init_false(self):
        """Forward works with do_init switched off."""
        qat_mod = ConvTranspose2dQAT(3, 16, 1)
        qat_mod.do_init = False
        self._assert_forward_yields_output(qat_mod)

    def test_conv_transpose_2d_qat_forward_do_init_false_cur_batch_one(self):
        """Forward works with do_init off and cur_batch set to 1."""
        qat_mod = ConvTranspose2dQAT(3, 16, 1)
        qat_mod.do_init = False
        copy_tensor(qat_mod.cur_batch, torch.tensor(1))
        self._assert_forward_yields_output(qat_mod)

    def test_conv_transpose_2d_qat_forward_do_init_false_cur_batch_two(self):
        """Forward works with do_init off and cur_batch set to 2."""
        qat_mod = ConvTranspose2dQAT(3, 16, 1)
        qat_mod.do_init = False
        copy_tensor(qat_mod.cur_batch, torch.tensor(2))
        self._assert_forward_yields_output(qat_mod)

    def test_conv_transpose_2d_qat_unsupport_shape_inputs(self):
        """A 3-D input is rejected with RuntimeError."""
        qat_mod = ConvTranspose2dQAT(3, 16, 1)
        with self.assertRaises(RuntimeError):
            qat_mod.forward(torch.randn(3, 224, 224))
class TestConv3dQAT(unittest.TestCase):
    """Constructor limits, conversion and forward-pass tests for Conv3dQAT."""

    @classmethod
    def setUpClass(cls):
        """No class-level fixtures are needed."""

    @classmethod
    def tearDownClass(cls):
        """No class-level cleanup is needed."""

    def setUp(self):
        """No per-test fixtures are needed."""

    def test_down(self):
        """Placeholder test with no assertions."""

    def test_conv3d_qat_limit_check_01(self):
        """A non-'zeros' padding mode is rejected."""
        self.assertRaises(RuntimeError, Conv3dQAT, 1, 1, 1, padding_mode='reflect')

    def test_conv3d_qat_limit_check_02(self):
        """A two-element dilation is rejected."""
        self.assertRaises(RuntimeError, Conv3dQAT, 1, 1, 1, dilation=(1, 1))

    def test_conv3d_qat_limit_check_03(self):
        """A dilation of (2, 1, 1) is rejected."""
        self.assertRaises(RuntimeError, Conv3dQAT, 1, 1, 1, dilation=(2, 1, 1))

    def test_conv3dqat_limit_check_04(self):
        """A dilation of (1, 1, 1) is accepted and do_init is truthy."""
        layer = Conv3dQAT(1, 1, 1, dilation=(1, 1, 1))
        self.assertTrue(layer.do_init)

    def test_conv3d_qat_forward(self):
        """A directly constructed layer produces an output for a 5-D batch."""
        layer = Conv3dQAT(3, 3, 4, dilation=(1, 2, 1))
        result = layer.forward(torch.randn(30, 3, 12, 64, 64))
        self.assertIsNotNone(result)

    def test_conv3d_qat_from_float_01(self):
        """A layer converted from nn.Conv3d produces an output."""
        float_conv = torch.nn.Conv3d(3, 3, 4, dilation=(1, 2, 1))
        layer = Conv3dQAT.from_float(float_conv)
        result = layer.forward(torch.randn(30, 3, 12, 64, 64))
        self.assertIsNotNone(result)

    def test_conv3d_qat_from_float_02(self):
        """Conversion from an nn.Conv2d raises RuntimeError."""
        float_conv2d = torch.nn.Conv2d(2, 3, 4)
        self.assertRaises(RuntimeError, Conv3dQAT.from_float, float_conv2d)

    def test_conv3d_qat_unsupport_shape_inputs(self):
        """A 4-D input is rejected with RuntimeError."""
        layer = Conv3dQAT(3, 3, 4, dilation=(1, 2, 1))
        with self.assertRaises(RuntimeError):
            layer.forward(torch.randn(3, 12, 64, 64))
class TestLinearQAT(unittest.TestCase):
    """Constructor limits, conversion and forward-pass tests for LinearQAT."""

    @classmethod
    def setUpClass(cls):
        """No class-level fixtures are needed."""

    @classmethod
    def tearDownClass(cls):
        """No class-level cleanup is needed."""

    def setUp(self):
        """No per-test fixtures are needed."""

    @staticmethod
    def _weight_config(channel_wise):
        """Build a fresh config dict with only ``channel_wise`` set."""
        return {'retrain_weight_config': {'channel_wise': channel_wise}}

    def test_down(self):
        """Placeholder test with no assertions."""

    def test_lineard_qat_limit_check_01(self):
        """channel_wise=True is rejected for linear layers."""
        with self.assertRaises(RuntimeError):
            LinearQAT(1, 1, config=self._weight_config(True))

    def test_lineard_qat_forward(self):
        """A layer with channel_wise=False produces an output."""
        layer = LinearQAT(16, 10, config=self._weight_config(False))
        result = layer.forward(torch.rand((128, 16)))
        self.assertIsNotNone(result)

    def test_lineard_qat_from_float_01(self):
        """nn.Linear converts to a LinearQAT instance."""
        float_linear = torch.nn.Linear(8, 10)
        converted = LinearQAT.from_float(
            float_linear, config=self._weight_config(False))
        self.assertIsNotNone(converted)

    def test_lineard_qat_from_float_02(self):
        """Conversion from an nn.Conv3d raises RuntimeError."""
        float_conv3d = torch.nn.Conv3d(1, 1, 1)
        with self.assertRaises(RuntimeError):
            LinearQAT.from_float(float_conv3d, config=self._weight_config(False))

    def test_lineard_qat_channel_wise_default_false_00(self):
        """Constructing a layer leaves the caller's weight config unchanged."""
        user_config = {'retrain_weight_config': {'dst_type': 'INT8'}}
        LinearQAT(1, 1, config=user_config)
        self.assertEqual(
            user_config.get('retrain_weight_config'), {'dst_type': 'INT8'})

    def test_lineard_qat_channel_wise_default_false_01(self):
        """Constructing a layer with config=None is accepted."""
        user_config = None
        LinearQAT(1, 1, config=user_config)
        self.assertIsNone(user_config)
def set_module(model, sub_module_name, module):
    """Assign ``module`` at the dotted attribute path ``sub_module_name``.

    Every component but the last is resolved with ``getattr`` starting from
    ``model``; the last component is then set on the object reached.
    """
    *parent_names, leaf_name = sub_module_name.split('.')
    owner = model
    for attr_name in parent_names:
        owner = getattr(owner, attr_name)
    setattr(owner, leaf_name, module)
class NetConv1d(nn.Module):
    """Six-stage float Conv1d network used as the source for QAT conversion.

    Every stage is Conv1d (kernel size 3) followed by BatchNorm1d; the even
    stages additionally end with an in-place ReLU.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv1d(2, 16, kernel_size=3, bias=False),
            nn.BatchNorm1d(16),
        )
        self.layer2 = nn.Sequential(
            nn.Conv1d(16, 16, kernel_size=3, bias=True),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
        )
        self.layer3 = nn.Sequential(
            nn.Conv1d(16, 16, kernel_size=3, groups=16),
            nn.BatchNorm1d(16),
        )
        self.layer4 = nn.Sequential(
            nn.Conv1d(16, 16, kernel_size=3, groups=16),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
        )
        self.layer5 = nn.Sequential(
            nn.Conv1d(16, 32, kernel_size=3, groups=4),
            nn.BatchNorm1d(32),
        )
        self.layer6 = nn.Sequential(
            nn.Conv1d(32, 8, kernel_size=3, groups=8),
            nn.BatchNorm1d(8),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply layer1 through layer6 in order and return the result."""
        stages = (self.layer1, self.layer2, self.layer3,
                  self.layer4, self.layer5, self.layer6)
        for stage in stages:
            x = stage(x)
        return x
class NetConv1dQAT(nn.Module):
    """Six-stage network built directly from Conv1dQAT layers.

    Every stage is Conv1dQAT (kernel size 3) followed by BatchNorm1d; the
    even stages additionally end with an in-place ReLU.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            Conv1dQAT(2, 16, kernel_size=3, bias=False),
            nn.BatchNorm1d(16),
        )
        self.layer2 = nn.Sequential(
            Conv1dQAT(16, 16, kernel_size=3, bias=True),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
        )
        self.layer3 = nn.Sequential(
            Conv1dQAT(16, 16, kernel_size=3, groups=16),
            nn.BatchNorm1d(16),
        )
        self.layer4 = nn.Sequential(
            Conv1dQAT(16, 16, kernel_size=3, groups=16),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
        )
        self.layer5 = nn.Sequential(
            Conv1dQAT(16, 32, kernel_size=3, groups=4),
            nn.BatchNorm1d(32),
        )
        self.layer6 = nn.Sequential(
            Conv1dQAT(32, 8, kernel_size=3, groups=8),
            nn.BatchNorm1d(8),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply layer1 through layer6 in order and return the result."""
        stages = (self.layer1, self.layer2, self.layer3,
                  self.layer4, self.layer5, self.layer6)
        for stage in stages:
            x = stage(x)
        return x
class TestConv1dQAT(unittest.TestCase):
    """Conversion, forward-pass and input validation tests for Conv1dQAT."""

    @classmethod
    def setUpClass(cls):
        """Log the torch version and create the class-wide temp folder."""
        LOGGER.logi(f'torch version: {torch.__version__}')
        cls.temp_folder = os.path.join(CUR_DIR, 'test_quant_retrain_quant_fusion_model')
        os.makedirs(cls.temp_folder, exist_ok=True)

    @classmethod
    def tearDownClass(cls):
        """Remove the temp folder created in setUpClass."""
        # shutil.rmtree runs synchronously in-process: no shell string is
        # built from the path, no pipe is left open, and it is portable.
        shutil.rmtree(cls.temp_folder, ignore_errors=True)

    def setUp(self):
        """No per-test fixtures are needed."""

    def test_down(self):
        """Placeholder test with no assertions."""

    def test_conv1d_qat_from_float_failed_ori_op_not_conv1d(self):
        """Conversion rejects a source op that is not nn.Conv1d."""
        mod = torch.nn.Conv3d(1, 1, 1)
        with self.assertRaises(TypeError):
            Conv1dQAT.from_float(mod)

    def test_conv1d_qat_scratch_from_zero(self):
        """A network built from Conv1dQAT layers runs five forward passes."""
        net_conv1d_qat = NetConv1dQAT()
        for _ in range(5):
            ret = net_conv1d_qat.forward(torch.rand(1, 2, 28))
        self.assertIsNotNone(ret)

    def test_conv1_qat_convert_from_ori_op(self):
        """Every nn.Conv1d of a float network can be swapped for Conv1dQAT."""
        net_conv1d = NetConv1d()
        # Snapshot the module list first so the network is not being
        # traversed while its sub-modules are replaced.
        for name, module in list(net_conv1d.named_modules()):
            if isinstance(module, nn.Conv1d):
                set_module(net_conv1d, name, Conv1dQAT.from_float(module))
        for _ in range(5):
            ret = net_conv1d.forward(torch.rand(1, 2, 28))
        self.assertIsNotNone(ret)

    def test_conv1d_qat_unsupport_shape_inputs(self):
        """A 4-D input is rejected with RuntimeError."""
        net_conv1d_qat = NetConv1dQAT()
        with self.assertRaises(RuntimeError):
            net_conv1d_qat.forward(torch.rand(1, 2, 28, 28))

    def test_conv1d_qat_op_not_support_padding_mode_not_zeros(self):
        """Construction rejects padding modes other than 'zeros'."""
        self.assertRaises(ValueError, Conv1dQAT, 1, 1, 1, padding_mode='reflect')

    def test_conv1d_qat_not_support_dtype(self):
        """Calling a float64 layer with float64 input raises ValueError."""
        mod = Conv1dQAT(1, 1, 1)
        mod = mod.to(torch.float64)
        self.assertRaises(ValueError, mod, torch.randn(1, 1, 1).to(torch.float64))
class NetConvTranspose1dQAT(nn.Module):
    """Six-stage network built directly from ConvTranspose1dQAT layers.

    Stage ``i`` (1-based) is configured with ``quant_configs[i - 1]`` and is
    followed by BatchNorm1d; the even stages additionally end with an
    in-place ReLU.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            ConvTranspose1dQAT(
                2, 16, kernel_size=3, bias=False, config=quant_configs[0]),
            nn.BatchNorm1d(16),
        )
        self.layer2 = nn.Sequential(
            ConvTranspose1dQAT(
                16, 16, kernel_size=3, bias=True, config=quant_configs[1]),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
        )
        self.layer3 = nn.Sequential(
            ConvTranspose1dQAT(16, 16, kernel_size=3, config=quant_configs[2]),
            nn.BatchNorm1d(16),
        )
        self.layer4 = nn.Sequential(
            ConvTranspose1dQAT(16, 16, kernel_size=3, config=quant_configs[3]),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
        )
        self.layer5 = nn.Sequential(
            ConvTranspose1dQAT(16, 32, kernel_size=3, config=quant_configs[4]),
            nn.BatchNorm1d(32),
        )
        self.layer6 = nn.Sequential(
            ConvTranspose1dQAT(32, 8, kernel_size=3, config=quant_configs[5]),
            nn.BatchNorm1d(8),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply layer1 through layer6 in order and return the result."""
        stages = (self.layer1, self.layer2, self.layer3,
                  self.layer4, self.layer5, self.layer6)
        for stage in stages:
            x = stage(x)
        return x
class NetConvTranspose1d(nn.Module):
    """Six-stage float ConvTranspose1d network used as a conversion source.

    Every stage is ConvTranspose1d (kernel size 3) followed by BatchNorm1d;
    the even stages additionally end with an in-place ReLU.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.ConvTranspose1d(2, 16, kernel_size=3, bias=False),
            nn.BatchNorm1d(16),
        )
        self.layer2 = nn.Sequential(
            nn.ConvTranspose1d(16, 16, kernel_size=3, bias=True),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
        )
        self.layer3 = nn.Sequential(
            nn.ConvTranspose1d(16, 16, kernel_size=3, groups=16),
            nn.BatchNorm1d(16),
        )
        self.layer4 = nn.Sequential(
            nn.ConvTranspose1d(16, 16, kernel_size=3, groups=16),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
        )
        self.layer5 = nn.Sequential(
            nn.ConvTranspose1d(16, 32, kernel_size=3),
            nn.BatchNorm1d(32),
        )
        self.layer6 = nn.Sequential(
            nn.ConvTranspose1d(32, 8, kernel_size=3),
            nn.BatchNorm1d(8),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply layer1 through layer6 in order and return the result."""
        stages = (self.layer1, self.layer2, self.layer3,
                  self.layer4, self.layer5, self.layer6)
        for stage in stages:
            x = stage(x)
        return x
class TestConvTranspose1dQAT(unittest.TestCase):
    """Conversion, forward-pass and input validation tests for ConvTranspose1dQAT."""

    @classmethod
    def setUpClass(cls):
        """Log the torch version once for the whole class."""
        LOGGER.logi(f'torch version: {torch.__version__}')

    @classmethod
    def tearDownClass(cls):
        """No class-level cleanup is needed."""

    def setUp(self):
        """No per-test fixtures are needed."""

    def test_down(self):
        """Placeholder test with no assertions."""

    def test_conv_transpose_1d_qat_scratch_from_zero(self):
        """A network built from QAT layers runs five forward passes."""
        qat_net = NetConvTranspose1dQAT()
        for _ in range(5):
            result = qat_net.forward(torch.rand(1, 2, 28))
        self.assertIsNotNone(result)

    def test_conv_transpose_1d_qat_convert_from_ori_op(self):
        """Each float layer converts using the config at the same position."""
        float_net = NetConvTranspose1d()
        float_layers = [
            (layer_name, layer)
            for layer_name, layer in float_net.named_modules()
            if isinstance(layer, nn.ConvTranspose1d)
        ]
        for position, (layer_name, layer) in enumerate(float_layers):
            converted = ConvTranspose1dQAT.from_float(
                layer, config=quant_configs[position])
            set_module(float_net, layer_name, converted)
        for _ in range(5):
            result = float_net.forward(torch.rand(1, 2, 28))
        self.assertIsNotNone(result)

    def test_conv_transpose_1d_qat_unsupport_shape_inputs(self):
        """A 2-D input is rejected with RuntimeError."""
        qat_net = NetConvTranspose1dQAT()
        with self.assertRaises(RuntimeError):
            qat_net.forward(torch.rand(2, 28))
class TestMatMulQAT(unittest.TestCase):
    """Validation, calibration and ONNX export tests for MatMulQAT."""

    @classmethod
    def setUpClass(cls):
        """Log the torch version and create the class-wide temp directory."""
        LOGGER.logi(f'torch version: {torch.__version__}')
        cls.temp_dir = os.path.join(CUR_DIR, 'tmp')
        os.makedirs(cls.temp_dir, exist_ok=True)

    @classmethod
    def tearDownClass(cls):
        """Remove the temp directory created in setUpClass."""
        shutil.rmtree(cls.temp_dir)

    def setUp(self):
        """No per-test fixtures are needed."""

    def _assert_each_data_config_rejected(self, data_configs):
        """Assert that every data config alone makes MatMulQAT raise ValueError.

        Each config gets its own assertRaises block: statements placed after
        the first raising call inside a shared block would never execute.
        """
        for data_config in data_configs:
            with self.subTest(data_config=data_config):
                with self.assertRaises(ValueError):
                    MatMulQAT(config={RETRAIN_DATA_CONFIG: data_config})

    def _assert_export_matches_float(self, make_input, make_other):
        """Run ten batches, export to ONNX and compare against torch.matmul.

        Args:
            make_input: zero-argument callable producing the first operand.
            make_other: zero-argument callable producing the second operand.

        The last generated pair is used both for the export and for the
        comparison; the exported path is kept in ``self.onnx_path``.
        """
        mod = MatMulQAT()
        for _ in range(10):
            input_data = make_input()
            other = make_other()
            mod(input_data, other)
        self.onnx_path = os.path.join(self.temp_dir, MATMUL_QAT_ONNX)
        torch.onnx.export(mod, (input_data, other), self.onnx_path)
        ori_out = torch.matmul(input_data, other)
        ort_session = ort.InferenceSession(self.onnx_path, providers=[CPUEXECUTIONPROVIDER])
        session_inputs = ort_session.get_inputs()
        quantized_out = ort_session.run(
            None, {session_inputs[0].name: input_data.cpu().numpy(),
                   session_inputs[1].name: other.cpu().numpy()})
        sim = similarity(ori_out.detach().cpu().numpy(), quantized_out[0])
        self.assertGreater(sim, 99)

    def test_down(self):
        """Placeholder test with no assertions."""

    def test_check_input_error_wrong_dtype(self):
        """check_input rejects float16 tensors."""
        with self.assertRaises(ValueError):
            MatMulQAT.check_input(torch.randn(32, 3).to(torch.float16))

    def test_check_input_error_wrong_shape(self):
        """check_input rejects 7-D and 1-D tensors."""
        with self.assertRaises(RuntimeError):
            MatMulQAT.check_input(torch.randn(3, 3, 3, 3, 3, 3, 3))
        with self.assertRaises(RuntimeError):
            MatMulQAT.check_input(torch.randn(3,))

    def test_acts_quant_init_success(self):
        """With batch_num=2 the init reports completion on the second call."""
        matmul_op = MatMulQAT(config={RETRAIN_DATA_CONFIG: {'batch_num': 2}})
        result = matmul_op.acts_quant_init(
            torch.randn(30, 30), matmul_op.acts_quant_params.get('input'), matmul_op.input_init_module)
        self.assertFalse(result)
        result = matmul_op.acts_quant_init(
            torch.randn(30, 30), matmul_op.acts_quant_params.get('input'), matmul_op.input_init_module)
        self.assertTrue(result)
        self.assertFalse(matmul_op.input_scale.isnan().all())
        self.assertFalse(matmul_op.input_clip_max.isnan().all())
        self.assertFalse(matmul_op.input_clip_min.isnan().all())
        self.assertFalse(matmul_op.input_clip_max_pre.isnan().all())
        self.assertFalse(matmul_op.input_clip_min_pre.isnan().all())

    def test_acts_quant_success(self):
        """acts_quant keeps the input shape and fills the input quant buffers."""
        matmul_op = MatMulQAT(config={RETRAIN_DATA_CONFIG: {'batch_num': 2}})
        result = matmul_op.acts_quant(
            torch.randn(30, 30), matmul_op.acts_quant_params.get('input'))
        self.assertEqual(result.shape, (30, 30))
        self.assertFalse(matmul_op.input_scale.isnan().all())
        self.assertFalse(matmul_op.input_clip_max.isnan().all())
        self.assertFalse(matmul_op.input_clip_min.isnan().all())
        self.assertFalse(matmul_op.input_clip_max_pre.isnan().all())
        self.assertFalse(matmul_op.input_clip_min_pre.isnan().all())

    def test_matmul_qat_init_failed_wrong_device_dtype(self):
        """A non-device ``device`` argument raises TypeError."""
        with self.assertRaises(TypeError):
            MatMulQAT(device=3)

    def test_matmul_qat_init_failed_wrong_config_dtype(self):
        """A non-dict ``config`` argument raises TypeError."""
        with self.assertRaises(TypeError):
            MatMulQAT(config=3)

    def test_matmul_qat_init_failed_wrong_config_item_dtype(self):
        """Each wrongly typed data config item is rejected on its own."""
        self._assert_each_data_config_rejected((
            {'batch_num': '1'},
            {'dst_type': 1},
            {'fixed_min': '1'},
            {'clip_max': '1'},
            {'clip_min': '1'},
        ))

    def test_matmul_qat_init_failed_wrong_config_item_scope(self):
        """Each out-of-range data config item is rejected on its own."""
        self._assert_each_data_config_rejected((
            {'dst_type': 'INT4'},
            {'batch_num': 0},
            {'clip_max': 0},
            {'clip_min': 0},
        ))

    def test_matmul_qat_init_success(self):
        """An INT8 data dst_type is accepted."""
        MatMulQAT(config={RETRAIN_DATA_CONFIG: {'dst_type': 'INT8'}})

    def test_matmul_op_infer_success(self):
        """Exported ONNX output stays close to torch.matmul for four data sets."""
        self._assert_export_matches_float(
            lambda: torch.randn(30, 30, 30),
            lambda: torch.randn(30, 30, 30))
        self._assert_export_matches_float(
            lambda: torch.zeros(30, 30, 30),
            lambda: torch.zeros(30, 30, 30))
        self._assert_export_matches_float(
            lambda: torch.ones(30, 30, 30),
            lambda: torch.ones(30, 30, 30))
        self._assert_export_matches_float(
            lambda: torch.rand(30, 30, 30) * 5 - 2,
            lambda: torch.rand(30, 30, 30) * 3 + 7)

        # The model exported last must hold one quantize/dequantize pair per
        # operand.
        op_type_dict = defaultdict(int)
        model = onnx.load(self.onnx_path)
        for node in model.graph.node:
            op_type_dict[node.op_type] += 1
        self.assertEqual(op_type_dict['QuantizeLinear'], 2)
        self.assertEqual(op_type_dict['DequantizeLinear'], 2)

    def test_matmul_qat_infer_failed_dtype_wrong(self):
        """Calling the op with float16 operands raises ValueError."""
        mod = MatMulQAT()
        input_data = torch.rand(30, 30, 30).to(torch.float16)
        other = torch.rand(30, 30, 30).to(torch.float16)
        with self.assertRaises(ValueError):
            mod(input_data, other)

    def test_matmul_qat_infer_failed_shape_wrong(self):
        """Calling the op with 1-D or 7-D operands raises RuntimeError."""
        mod = MatMulQAT()
        input_data = torch.rand(30)
        other = torch.rand(30)
        with self.assertRaises(RuntimeError):
            mod(input_data, other)
        input_data = torch.rand(3, 3, 3, 3, 3, 3, 3)
        other = torch.rand(3, 3, 3, 3, 3, 3, 3)
        with self.assertRaises(RuntimeError):
            mod(input_data, other)