import json
import os
import unittest
import numpy as np
import torch
from amct_pytorch.classic.graph_based.amct_pytorch.configuration.retrain_config import (
CONFIGURER,
RetrainConfig,
)
from amct_pytorch.classic.graph_based.amct_pytorch.parser.parser import Parser
from tests.amct_pytorch.testcase_python.configuration.utils import models
REGULAR_PRUNE_ENABLE = 'regular_prune_enable'
REGULAR_PRUNE_CONFIG = 'regular_prune_config'
CUR_DIR = os.path.split(os.path.realpath(__file__))[0]
ALGO = 'algo'
PRUNE_TYPE = 'prune_type'
DST_TYPE = 'dst_type'
FILTER_PRUNE = 'filter_prune'
INT8 = 'INT8'
BALANCED_L2_NORM_FILTER_PRUNE = 'balanced_l2_norm_filter_prune'
PRUNE_RATIO = 'prune_ratio'
ASCEND_OPTIMIZED = 'ascend_optimized'
class TestRetrainConfigForPrune(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.temp_folder = os.path.join(CUR_DIR, 'test_retrain_config')
if not os.path.isdir(cls.temp_folder):
os.makedirs(cls.temp_folder)
cls.model_001 = models.Net001().to(torch.device("cpu"))
cls.args_shape = [(1, 2, 28, 28)]
cls.args = list()
for input_shape in cls.args_shape:
cls.args.append(torch.randn(input_shape))
cls.args = tuple(cls.args)
cls.onnx_file = os.path.join(cls.temp_folder, 'net_001.onnx')
Parser.export_onnx(cls.model_001, cls.args, cls.onnx_file)
cls.graph = Parser.parse_net_to_graph(cls.onnx_file)
cls.graph.model = cls.model_001
@classmethod
def tearDownClass(cls):
os.popen('rm -r ' + cls.temp_folder)
pass
def setUp(self):
pass
def tearDown(self):
pass
def test_complete_cfg(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_complete.cfg")
RetrainConfig.init(self.graph, config_defination)
fc_config = {
REGULAR_PRUNE_ENABLE: False,
REGULAR_PRUNE_CONFIG: {
PRUNE_TYPE: FILTER_PRUNE,
ALGO: BALANCED_L2_NORM_FILTER_PRUNE,
PRUNE_RATIO: 0.3,
ASCEND_OPTIMIZED: True
}
}
override_layer_config = {
REGULAR_PRUNE_ENABLE: True,
REGULAR_PRUNE_CONFIG: {
PRUNE_TYPE: FILTER_PRUNE,
ALGO: BALANCED_L2_NORM_FILTER_PRUNE,
PRUNE_RATIO: 0.4,
ASCEND_OPTIMIZED: False
}
}
override_layer_types = {
REGULAR_PRUNE_ENABLE: True,
REGULAR_PRUNE_CONFIG: {
PRUNE_TYPE: FILTER_PRUNE,
ALGO: BALANCED_L2_NORM_FILTER_PRUNE,
PRUNE_RATIO: 0.5,
ASCEND_OPTIMIZED: True
}
}
skip_override_layer_types = {
REGULAR_PRUNE_ENABLE: False,
REGULAR_PRUNE_CONFIG: {
PRUNE_TYPE: FILTER_PRUNE,
ALGO: BALANCED_L2_NORM_FILTER_PRUNE,
PRUNE_RATIO: 0.5,
ASCEND_OPTIMIZED: True
}
}
self.assertEqual(RetrainConfig.retrain_config['fc.0'], fc_config)
self.assertEqual(RetrainConfig.retrain_config['fc.2'], fc_config)
self.assertEqual(RetrainConfig.retrain_config['fc.5'], fc_config)
self.assertEqual(RetrainConfig.retrain_config['layer1.0'], skip_override_layer_types)
self.assertEqual(RetrainConfig.retrain_config['layer2.0'], override_layer_config)
self.assertEqual(RetrainConfig.retrain_config.get('layer3.0'), None)
self.assertEqual(RetrainConfig.retrain_config.get('layer4.0'), None)
self.assertEqual(RetrainConfig.retrain_config['layer5.0'], override_layer_types)
self.assertEqual(RetrainConfig.retrain_config['layer6.0'], override_layer_types)
def test_no_bcp_prune_ratio(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_miss_ratio.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination)
def test_no_bcp_algo(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_empty_001.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination)
def test_prune_has_quant(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_has_quant.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination)
def test_quant_has_prune(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_quant_has_prune.cfg")
RetrainConfig.set_ability(enable_retrain=True, enable_prune=False)
retrain_config = {}
self.assertRaises(
ValueError, CONFIGURER.create_config_from_proto,
retrain_config, self.graph, config_defination)
def test_prune_skip_override_repeated(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_skip_override_repeated.cfg")
RetrainConfig.init(self.graph, config_defination)
def test_prune_unsupport_type(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_unsupport_type.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination)
def test_prune_no_layer(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_no_layer.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination)
def test_compressed_complete_cfg(self):
config_defination = os.path.join(CUR_DIR, "./utils/compressed_cfg/net_001_compressed_complete.cfg")
RetrainConfig.init(self.graph, config_defination, True, True)
def test_compressed_empty_cfg(self):
config_defination = os.path.join(CUR_DIR, "./utils/compressed_cfg/net_001_compressed_empty.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination, True, True)
def test_compressed_only_prune_complete_prune_cfg(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_complete.cfg")
RetrainConfig.init(self.graph, config_defination, True, True)
fc_config = {
REGULAR_PRUNE_ENABLE: False,
REGULAR_PRUNE_CONFIG: {
PRUNE_TYPE: FILTER_PRUNE,
ALGO: BALANCED_L2_NORM_FILTER_PRUNE,
PRUNE_RATIO: 0.3,
ASCEND_OPTIMIZED: True
}
}
override_layer_config = {
REGULAR_PRUNE_ENABLE: True,
REGULAR_PRUNE_CONFIG: {
PRUNE_TYPE: FILTER_PRUNE,
ALGO: BALANCED_L2_NORM_FILTER_PRUNE,
PRUNE_RATIO: 0.4,
ASCEND_OPTIMIZED: False
}
}
override_layer_types = {
REGULAR_PRUNE_ENABLE: True,
REGULAR_PRUNE_CONFIG: {
PRUNE_TYPE: FILTER_PRUNE,
ALGO: BALANCED_L2_NORM_FILTER_PRUNE,
PRUNE_RATIO: 0.5,
ASCEND_OPTIMIZED: True
}
}
skip_override_layer_types = {
REGULAR_PRUNE_ENABLE: False,
REGULAR_PRUNE_CONFIG: {
PRUNE_TYPE: FILTER_PRUNE,
ALGO: BALANCED_L2_NORM_FILTER_PRUNE,
PRUNE_RATIO: 0.5,
ASCEND_OPTIMIZED: True
}
}
self.assertEqual(RetrainConfig.retrain_config['fc.0'], fc_config)
self.assertEqual(RetrainConfig.retrain_config['fc.2'], fc_config)
self.assertEqual(RetrainConfig.retrain_config['fc.5'], fc_config)
self.assertEqual(RetrainConfig.retrain_config['layer1.0'], skip_override_layer_types)
self.assertEqual(RetrainConfig.retrain_config['layer2.0'], override_layer_config)
self.assertEqual(RetrainConfig.retrain_config.get('layer3.0'), None)
self.assertEqual(RetrainConfig.retrain_config.get('layer4.0'), None)
self.assertEqual(RetrainConfig.retrain_config['layer5.0'], override_layer_types)
self.assertEqual(RetrainConfig.retrain_config['layer6.0'], override_layer_types)
def test_compressed_only_prune_no_bcp_ration(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_miss_ratio.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination, True, True)
def test_compressed_only_prune_no_bcp_algo(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_empty_001.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination, True, True)
def test_compressed_only_prune_no_filter_prune_algo(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_empty_002.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination, True, True)
def test_compressed_only_prune_skip_override_repeated(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_skip_override_repeated.cfg")
RetrainConfig.init(self.graph, config_defination, True, True)
def test_compressed_only_prune_prune_unsupported_type(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_unsupport_type.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination, True, True)
def test_compressed_only_prune_prune_no_layer(self):
config_defination = os.path.join(CUR_DIR, "./utils/net_001_prune_no_layer.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination, True, True)
config_defination = os.path.join(CUR_DIR, "./utils/net_001_quant_has_prune.cfg")
layer1_0_skip_override_types = {
REGULAR_PRUNE_ENABLE: False,
REGULAR_PRUNE_CONFIG: {
PRUNE_TYPE: FILTER_PRUNE,
ALGO: BALANCED_L2_NORM_FILTER_PRUNE,
PRUNE_RATIO: 0.5,
ASCEND_OPTIMIZED: True
}
}
layer2_0_override_layer_configs = {
REGULAR_PRUNE_ENABLE: True,
REGULAR_PRUNE_CONFIG: {
PRUNE_TYPE: FILTER_PRUNE,
ALGO: BALANCED_L2_NORM_FILTER_PRUNE,
PRUNE_RATIO: 0.4,
ASCEND_OPTIMIZED: False
}
}
self.assertEqual(RetrainConfig.retrain_config['layer1.0'], layer1_0_skip_override_types)
self.assertEqual(RetrainConfig.retrain_config['layer2.0'], layer2_0_override_layer_configs)
self.assertEqual(RetrainConfig.retrain_config.get('layer3.0'), None)
self.assertEqual(RetrainConfig.retrain_config.get('layer4.0'), None)
override_layer_types = {
REGULAR_PRUNE_ENABLE: True,
REGULAR_PRUNE_CONFIG: {
PRUNE_TYPE: FILTER_PRUNE,
ALGO: BALANCED_L2_NORM_FILTER_PRUNE,
PRUNE_RATIO: 0.5,
ASCEND_OPTIMIZED: True
}
}
self.assertEqual(RetrainConfig.retrain_config['layer5.0'], override_layer_types)
self.assertEqual(RetrainConfig.retrain_config['layer6.0'], override_layer_types)
fc_config = {
REGULAR_PRUNE_ENABLE: False,
REGULAR_PRUNE_CONFIG: {
PRUNE_TYPE: FILTER_PRUNE,
ALGO: BALANCED_L2_NORM_FILTER_PRUNE,
PRUNE_RATIO: 0.3,
ASCEND_OPTIMIZED: True
}
}
self.assertEqual(RetrainConfig.retrain_config['fc.0'], fc_config)
self.assertEqual(RetrainConfig.retrain_config['fc.2'], fc_config)
self.assertEqual(RetrainConfig.retrain_config['fc.5'], fc_config)
def test_compressed_only_quant_complete_quant_cfg(self):
config_defination = os.path.join(CUR_DIR, "./utils/compressed_cfg/net_001_compressed_quant_only_complete.cfg")
RetrainConfig.init(self.graph, config_defination, True, True)
def test_compressed_only_quant_only_data_weight_cfg(self):
config_defination = os.path.join(
CUR_DIR, "./utils/compressed_cfg/net_001_compressed_quant_only_data_weight.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination, True, True)
def test_compressed_only_quant_only_data_weight_no_channelwise_cfg(self):
config_defination = os.path.join(
CUR_DIR,
"./utils/compressed_cfg/net_001_compressed_quant_only_data_weight_nochannelwise.cfg")
RetrainConfig.init(self.graph, config_defination, True, True)
layer_quant = {
'retrain_enable': True,
'retrain_data_config': {
ALGO: 'ulq_quantize',
DST_TYPE: INT8,
'clip_max': 6.0,
'clip_min': -6.0
},
'retrain_weight_config': {
ALGO: 'arq_retrain',
'channel_wise': True,
DST_TYPE: INT8
}
}
fc_quant = {
'retrain_enable': True,
'retrain_data_config': {
ALGO: 'ulq_quantize',
DST_TYPE: INT8,
'clip_max': 6.0,
'clip_min': -6.0
},
'retrain_weight_config': {
ALGO: 'arq_retrain',
'channel_wise': False,
DST_TYPE: INT8
}
}
self.assertEqual(RetrainConfig.retrain_config['layer1.0'], layer_quant)
self.assertEqual(RetrainConfig.retrain_config['layer2.0'], layer_quant)
self.assertEqual(RetrainConfig.retrain_config['layer3.0'], layer_quant)
self.assertEqual(RetrainConfig.retrain_config['layer4.0'], layer_quant)
self.assertEqual(RetrainConfig.retrain_config['layer5.0'], layer_quant)
self.assertEqual(RetrainConfig.retrain_config['layer6.0'], layer_quant)
self.assertEqual(RetrainConfig.retrain_config['fc.0'], fc_quant)
self.assertEqual(RetrainConfig.retrain_config['fc.2'], fc_quant)
self.assertEqual(RetrainConfig.retrain_config['fc.5'], fc_quant)
def test_compressed_only_quant_only_skip_layers(self):
config_defination = os.path.join(
CUR_DIR, "./utils/compressed_cfg/net_001_compressed_quant_only_skip_layers.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination, True, True)
def test_compressed_only_quant_only_skip_layer_types(self):
config_defination = os.path.join(
CUR_DIR, "./utils/compressed_cfg/net_001_compressed_quant_only_skip_layer_types.cfg")
self.assertRaises(ValueError, RetrainConfig.init, self.graph, config_defination, True, True)
def test_compressed_only_override_layer_configs(self):
config_defination = os.path.join(
CUR_DIR,
"./utils/compressed_cfg/net_001_compressed_quant_only_override_layer_configs.cfg")
RetrainConfig.init(self.graph, config_defination, True, True)
self.assertTrue(RetrainConfig.retrain_config['layer1.0'].get(REGULAR_PRUNE_ENABLE))
self.assertEqual(RetrainConfig.retrain_config['layer1.0'].get(REGULAR_PRUNE_CONFIG),
dict([(PRUNE_TYPE, FILTER_PRUNE), (ALGO, BALANCED_L2_NORM_FILTER_PRUNE),
(PRUNE_RATIO, 0.5), (ASCEND_OPTIMIZED, True)]))
if __name__ == "__main__":
unittest.main()