import json
import logging
import math
import os
import sys
import unittest
from unittest import mock
import numpy as np
import torch
from amct_pytorch.classic.graph_based.amct_pytorch.custom_op import bcp
# Module-level logger used by the tests for diagnostic output.
logger = logging.getLogger(__name__)
# Absolute directory that contains this file; scratch folders are created
# beneath it.
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
# CUDA device string.
# NOTE(review): not referenced by the code visible in this section - confirm
# whether it is still needed.
DEVICE = 'cuda:0'
class TestBcp(unittest.TestCase):
    """Unit tests for the ``bcp`` custom op.

    Every case computes a prune mask with ``bcp`` and compares it against the
    mask produced by the NumPy reference implementation ``BcpFaker``.
    """

    @classmethod
    def setUpClass(cls):
        """Create the scratch directory used by this test class."""
        cls.temp_folder = os.path.join(CUR_DIR, 'test_bcp')
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(cls.temp_folder, exist_ok=True)

    @classmethod
    def tearDownClass(cls):
        """Remove the scratch directory created in ``setUpClass``."""
        # Imported locally so this block does not depend on a file-level
        # import being added.
        import shutil

        # Best-effort cleanup: a missing or undeletable folder must not fail
        # the run. Unlike a shell ``rm -r`` this is portable, synchronous and
        # safe for paths containing spaces or shell metacharacters.
        shutil.rmtree(cls.temp_folder, ignore_errors=True)

    @staticmethod
    def _make_param(ascend_optimized, prune_group, prune_ratio=0.3):
        """Build the parameter dict shared by ``bcp`` and ``BcpFaker``."""
        return {
            'prune_ratio': prune_ratio,
            'ascend_optimized': ascend_optimized,
            'prune_group': prune_group,
        }

    @staticmethod
    def _compute_masks(tensor_list, prune_axises, algo_param):
        """Run the op and the reference on the same inputs.

        Returns:
            tuple: ``(prune_mask_list, prune_mask)`` where ``prune_mask_list``
            is the reference mask returned by ``BcpFaker.do`` and
            ``prune_mask`` is the value returned by ``bcp``.
        """
        # The op is called before the reference, as in the original tests.
        prune_mask = bcp(
            tensor_list, prune_axises,
            algo_param['prune_ratio'],
            algo_param['prune_group'],
            algo_param['ascend_optimized'],
        )
        prune_mask_list = BcpFaker(tensor_list, prune_axises, algo_param).do()
        return prune_mask_list, prune_mask

    def _assert_masks_equal(self, prune_mask_list, prune_mask):
        """Assert that the op mask matches the reference mask element-wise."""
        is_equal = (prune_mask_list == prune_mask.cpu().numpy()).all()
        self.assertTrue(
            is_equal, 'bcp mask differs from the BcpFaker reference mask')

    def _check(self, tensor_list, prune_axises, algo_param):
        """Compute both masks for the inputs and assert they are equal."""
        prune_mask_list, prune_mask = self._compute_masks(
            tensor_list, prune_axises, algo_param)
        self._assert_masks_equal(prune_mask_list, prune_mask)

    def test_algo_001(self):
        """ tensor:1, ascend_optimized:true, prune_group:1"""
        tensor_list = [torch.randn([160, 4, 3, 3])]
        algo_param = self._make_param(ascend_optimized=True, prune_group=1)
        # Logging arguments need a %s placeholder; without one the logging
        # module reports a formatting error instead of the message.
        logger.info('ori_data: %s', tensor_list[0].type())
        logger.info('%s', tensor_list[0][0])
        self._check(tensor_list, [0], algo_param)

    def test_algo_002(self):
        """ tensor:1, ascend_optimized:true, prune_group:2"""
        tensor_list = [torch.randn([160, 4, 3, 3])]
        algo_param = self._make_param(ascend_optimized=True, prune_group=2)
        self._check(tensor_list, [0], algo_param)

    def test_algo_003(self):
        """ tensor:1, ascend_optimized:false, prune_group:1"""
        tensor_list = [torch.randn([160, 4, 3, 3])]
        algo_param = self._make_param(ascend_optimized=False, prune_group=1)
        self._check(tensor_list, [0], algo_param)

    def test_algo_004(self):
        """ tensor:1, ascend_optimized:true, prune_group:1
        num is too little
        """
        tensor_list = [torch.randn([16, 4, 3, 3])]
        algo_param = self._make_param(ascend_optimized=True, prune_group=1)
        self._check(tensor_list, [0], algo_param)

    def test_algo_005(self):
        """ tensor:1, ascend_optimized:false, prune_group:1
        num is too little
        """
        tensor_list = [torch.randn([16, 4, 3, 3])]
        algo_param = self._make_param(ascend_optimized=False, prune_group=1)
        self._check(tensor_list, [0], algo_param)

    @unittest.skip('*')
    def test_algo_006(self):
        """ tensor:2, ascend_optimized:true, prune_group:2"""
        tensor_list = [
            torch.randn([160, 4, 3, 3]),
            torch.randn([4, 160, 3, 3]),
        ]
        algo_param = self._make_param(ascend_optimized=True, prune_group=2)
        prune_mask_list, prune_mask = self._compute_masks(
            tensor_list, [0, 1], algo_param)
        op_mask = prune_mask.cpu().numpy()
        # Diagnostics are emitted before the assertion so they are available
        # when the comparison fails.
        logger.info('prune_mask_list %s', prune_mask_list)
        logger.info('prune_mask %s', prune_mask)
        logger.info('equal: %s', prune_mask_list == op_mask)
        logger.info('numpy_num %s %s',
                    sum(prune_mask_list[0:80]), sum(prune_mask_list[80:160]))
        logger.info('op_num %s %s',
                    op_mask[0:80].sum(), op_mask[80:160].sum())
        self._assert_masks_equal(prune_mask_list, prune_mask)
class BcpFaker:
    """NumPy reference implementation of norm-based channel pruning.

    For each weight tensor a normalised per-channel L2 norm is computed along
    the given axis; the norms of all tensors are averaged and the channels
    with the smallest averaged norm are marked as pruned.
    """

    def __init__(self, tensor_list, prune_axises, prune_param):
        """Store the inputs of one pruning decision.

        Args:
            tensor_list: weight tensors; each must provide ``.cpu().numpy()``.
            prune_axises: for each tensor, the axis along which channels are
                scored.
            prune_param: dict with the keys ``'prune_ratio'``,
                ``'ascend_optimized'`` and ``'prune_group'``.
        """
        self.tensor_list = tensor_list
        self.prune_axises = prune_axises
        self.prune_param = prune_param

    @staticmethod
    def cal_channel_index(norm_list, prune_ratio, ascend_optimized, prune_group):
        """Select the channels to prune; selection is done per group.

        Args:
            norm_list: list of equally sized 1-D per-channel norm arrays; they
                are averaged element-wise before selection.
            prune_ratio: fraction of channels requested to be pruned.
            ascend_optimized: when true, the number of remaining channels is
                rounded up to a multiple of 16.
            prune_group: number of contiguous, equally sized channel groups;
                the same number of channels is pruned from each group.

        Returns:
            tuple: ``(prune_mask, prune_index)`` where ``prune_mask`` is a
            list with 0 for pruned and 1 for kept channels, and
            ``prune_index`` is the sorted list of pruned channel indices.
        """
        def cal_prune_num(num, prune_ratio, ascend_optimized, prune_group):
            # Python's round() rounds exact halves to the nearest even value.
            remain_num = num - round(num * prune_ratio)
            if ascend_optimized:
                # Round the remaining count up to a multiple of 16, but never
                # above the channel count.
                remain_num = min(math.ceil(remain_num / 16) * 16, num)
                if remain_num == 0:
                    remain_num = min(num, 16)
            if prune_group > 1:
                # Round down so the remaining count splits evenly over groups.
                remain_num = min(
                    math.floor(remain_num / prune_group) * prune_group, num)
                if remain_num == 0:
                    remain_num = prune_group
            return num - remain_num

        norm = np.mean(np.array(norm_list), axis=0)
        channel_num = norm.shape[0]
        group_len = channel_num // prune_group
        prune_nums = cal_prune_num(
            channel_num, prune_ratio, ascend_optimized, prune_group)
        prune_nums_grp = prune_nums // prune_group
        if prune_nums == 0:
            # Nothing to prune: keep every channel.
            return [1] * channel_num, []

        prune_index = []
        for grp in range(prune_group):
            start_idx = grp * group_len
            norm_grp = norm[start_idx:start_idx + group_len]
            # argsort is ascending, so the leading entries are the channels
            # with the smallest norm inside this group.
            sort_index = np.argsort(norm_grp) + start_idx
            prune_index.extend(sort_index[:prune_nums_grp].tolist())
        prune_index.sort()

        # A set makes each membership test O(1) instead of scanning the list.
        pruned = set(prune_index)
        prune_mask = [0 if idx in pruned else 1 for idx in range(channel_num)]
        return prune_mask, prune_index

    @staticmethod
    def cal_norm(wts_tensor, axis):
        """Return the min-max normalised per-channel L2 norm of a tensor.

        Args:
            wts_tensor: tensor providing ``.cpu().numpy()``.
            axis: axis whose entries are the channels to score.

        Returns:
            numpy.ndarray: 1-D array with one value per channel, scaled so
            that the smallest norm maps to 0 and the largest to 1.
        """
        norm_algo = 'l2_norm'
        ord_map = {'l1_norm': 1, 'l2_norm': 2}
        wts_np = wts_tensor.cpu().numpy()
        # Bring the channel axis to the front, then flatten each channel.
        wts_np = np.swapaxes(wts_np, 0, axis)
        wts_np = wts_np.reshape([wts_np.shape[0], -1])
        norm = np.linalg.norm(
            wts_np, ord=ord_map.get(norm_algo), axis=1, keepdims=False)
        # Divide by the number of elements per channel.
        norm = norm / wts_np.shape[1]
        # NOTE(review): when every channel has the same norm the denominator
        # is zero and the result is NaN; left unchanged on purpose.
        norm = (norm - norm.min()) / (norm.max() - norm.min())
        return norm

    @staticmethod
    def prune_cout_to_prune_index(prune_cout):
        """Return the indices whose mask value in ``prune_cout`` is 0."""
        return [idx for idx, val in enumerate(prune_cout) if val == 0]

    def do(self):
        """Compute the prune mask for the stored tensors.

        Returns:
            list: 0/1 mask with 0 for pruned and 1 for kept channels.
        """
        norm_list = [
            self.cal_norm(data_tensor, cout_axis)
            for data_tensor, cout_axis in zip(self.tensor_list, self.prune_axises)
        ]
        prune_mask, _ = self.cal_channel_index(
            norm_list,
            self.prune_param['prune_ratio'],
            self.prune_param['ascend_optimized'],
            self.prune_param['prune_group'],
        )
        return prune_mask