from mindspeed.auto_settings.utils.logger import get_logger
class ProfileList(object):
    """Forward (``fw``) and backward (``bw``) operator lists, plus helpers that
    copy ``index_name`` labels from one operator sequence onto another.

    Entries handled by the alignment helpers must expose mutable ``type`` and
    ``index_name`` attributes; ``index_name == ''`` means "not labelled yet".
    The print helpers additionally read ``name``.
    """

    def __init__(self):
        self.fw = []
        self.bw = []
        self._logger = get_logger('ProfileList')

    def reset_index_name(self, list1, list2):
        """Copy ``index_name`` labels from ``list2`` onto type-matching entries of ``list1``.

        Runs ``reset_index_name_single`` at most four times. Between passes,
        scanning of ``list1`` resumes just after the last matched entry, and
        one ``list2`` entry is skipped (presumably one that has no counterpart
        in ``list1`` — TODO confirm intent).

        Args:
            list1: entries to label; mutated in place.
            list2: labelled reference entries.

        Returns:
            ``(list1, first_mat)``. ``first_mat`` is the first mismatch position
            recorded by the final pass. If that pass recorded none (0), it is
            one past the last matched ``list1`` index.
        """
        N = len(list2)
        i, j = 0, 0
        index = 0  # number of extra passes taken so far (capped at 3)
        last_mat = (0, 0)
        first_mat = 0
        while True:
            list1, i, j, last_mat, first_mat = self.reset_index_name_single(list1, list2, i, j, last_mat)
            if j < N - 1 and index < 3:
                index += 1
                i = last_mat[0] + 1
                j += 1
            else:
                break
        if first_mat == 0:
            first_mat = last_mat[0] + 1
        return list1, first_mat

    def reset_index_name_single(self, list1, list2, i, j, last_mat):
        """Run one greedy alignment pass over ``list1[i:]`` against ``list2[j:]``.

        For each still-unlabelled ``list1`` entry, while ``list2`` is not exhausted:

        * If its ``type`` equals ``list2[j].type``, it takes
          ``list2[j].index_name`` and ``j`` advances.
        * Otherwise a "dispersed" run starts. Its positions, and those of the
          matches that follow it, are collected in ``disperses_list``. The
          first such mismatch after any match in this pass is remembered in
          ``first_mat``.

        A dispersed run is handed to ``attract_list`` in three cases:

        * more than 5 consecutive matches have followed it;
        * the last ``list1`` entry is reached;
        * the scan meets an already-labelled entry, or ``list2`` is exhausted.

        Returns:
            ``(list1, i, j, last_mat, first_mat)``. ``list1`` is mutated in
            place. ``i``/``j`` are where scanning stopped. ``last_mat`` is the
            ``(list1, list2)`` index pair of the last match. ``first_mat`` is 0
            when no mismatch was recorded.
        """
        M, N = len(list1), len(list2)
        dp_flag = False        # currently inside a dispersed run
        mat_flag = False       # at least one match has happened in this pass
        disperses_list = []
        first_mat = 0
        continue_num = 0       # consecutive matches since the last mismatch
        while i < M:
            if j < N and list1[i].index_name == '':
                if list1[i].type == list2[j].type:
                    mat_flag = True
                    if dp_flag:
                        disperses_list.append(i)
                        continue_num += 1
                        if continue_num > 5 or i >= M - 1:
                            dp_flag = False
                            continue_num = 0
                            list1 = self.attract_list(disperses_list, list1, i)
                            disperses_list = []
                    list1[i].index_name = list2[j].index_name
                    last_mat = (i, j)
                    j += 1
                else:
                    if mat_flag and first_mat == 0:
                        first_mat = i
                    disperses_list.append(i)
                    continue_num = 0
                    dp_flag = True
            elif dp_flag and len(disperses_list) > 0:
                # Extend to the last unlabelled entry before flushing the run.
                while i < M and list1[i].index_name == '':
                    i += 1
                i = i - 1
                dp_flag = False
                continue_num = 0
                list1 = self.attract_list(disperses_list, list1, i)
                disperses_list = []
            i += 1
        return list1, i, j, last_mat, first_mat

    def attract_list(self, disperses_list, list1, i):
        """Shift labels from dispersed positions toward position ``i``.

        Walks backward in lockstep from ``list1[i]`` and from the last position
        in ``disperses_list``. While the two entries share a ``type``, the
        ``index_name`` moves from ``list1[disperses_list[-1 - k]]`` to
        ``list1[i - k]``; the source is cleared before the destination is
        written. Stops at the first type mismatch or when either side runs out.

        Returns:
            ``list1``, which is mutated in place.
        """
        index = 0
        len_dp = len(disperses_list)
        while i - index >= 0 and len_dp - index - 1 >= 0 and list1[i - index].type == list1[
                disperses_list[len_dp - index - 1]].type:
            temp = list1[disperses_list[len_dp - index - 1]].index_name
            list1[disperses_list[len_dp - index - 1]].index_name = ''
            list1[i - index].index_name = temp
            index += 1
        return list1

    def print_list(self):
        """Log the forward list, then the backward list, at DEBUG level."""
        self.print_list_fw()
        self.print_list_bw()

    def print_list_fw(self):
        """Log a ``fw`` header and then every entry of ``self.fw`` at DEBUG level."""
        self._logger.debug("fw")
        for item in self.fw:
            # Placeholders are required: logging %-merges the extra args into msg.
            self._logger.debug("name %s type %s index_name %s",
                               item.name, item.type, item.index_name)

    def print_list_bw(self):
        """Log a ``bw`` header and then every entry of ``self.bw`` at DEBUG level."""
        self._logger.debug("bw")
        for item in self.bw:
            # Placeholders are required: logging %-merges the extra args into msg.
            self._logger.debug("name %s type %s index_name %s",
                               item.name, item.type, item.index_name)
class ChangeList:
    """Container holding two ProfileList instances, ``list_2`` and ``list_4``.

    NOTE(review): what the suffixes 2 and 4 denote is not visible here —
    confirm with callers.
    """

    def __init__(self):
        super().__init__()
        self.list_2, self.list_4 = ProfileList(), ProfileList()
class ChangeOperatorList:
    """Container holding two ProfileList instances, ``list_2`` and ``list_4``.

    Structurally identical to ChangeList; kept as a separate type.
    NOTE(review): meaning of the 2/4 suffixes is not visible here — confirm.
    """

    def __init__(self):
        super().__init__()
        self.list_2, self.list_4 = ProfileList(), ProfileList()
class DictShape(object):
    """Record describing one profiled operator.

    All four string fields start empty. ``change_profile_into_dictshape``
    fills them in from a profiling item.
    """

    def __init__(self):
        self.name, self.type, self.accelerator_core, self.index_name = "", "", "", ""

    def change_profile_into_dictshape(self, item, index):
        """Copy ``name``, ``type`` and ``accelerator_core`` from ``item`` and derive ``index_name``.

        Args:
            item: object exposing ``name``, ``type`` and ``accelerator_core``.
            index: position used to build the label. ``-1`` means "no label".

        ``index_name`` becomes ``''`` when ``index == -1``. Otherwise it is
        ``str(index)`` immediately followed by ``str(item.type)``.
        """
        for attr in ("name", "type", "accelerator_core"):
            setattr(self, attr, getattr(item, attr))
        self.index_name = "" if index == -1 else "".join((str(index), str(item.type)))
class OperatorLayerTime(object):
    """Five named forward/backward time pairs for an operator layer.

    Each attribute (``base_operator``, ``cp_exist``, ``cp_diff``,
    ``ep_exist``, ``ep_diff``) is its own ``Element`` starting at 0.0 / 0.0.
    """

    class Element:
        """Mutable pair of ``fw`` and ``bw`` values (default 0.0 each)."""

        def __init__(self, fw=0.0, bw=0.0):
            self.fw, self.bw = fw, bw

    def __init__(self):
        make = self.Element
        self.base_operator = make()
        self.cp_exist = make()
        self.cp_diff = make()
        self.ep_exist = make()
        self.ep_diff = make()
class DictModelShape(DictShape):
    """DictShape plus four float fields, all starting at 0.0.

    Adds ``model_w``, ``model_b``, ``shape_model_w`` and ``shape_model_b``.
    NOTE(review): the w/b suffixes presumably mean weight/bias — confirm.
    """

    def __init__(self):
        super().__init__()
        self.model_w = self.model_b = 0.0
        self.shape_model_w = self.shape_model_b = 0.0
class DictCalShape(DictShape):
    """DictShape plus ``input_cal`` and ``output_cal`` float fields (both 0.0)."""

    def __init__(self):
        super().__init__()
        self.input_cal = self.output_cal = 0.0
class OperatorList(ProfileList):
    """ProfileList with an additional ``re`` list and its own logger.

    NOTE(review): the meaning of ``re`` is not visible here (presumably
    recomputation operators) — confirm with callers.
    """

    def __init__(self):
        super(OperatorList, self).__init__()
        self.fw = []
        self.bw = []
        self.re = []
        # Replace the logger set up by the base-class constructor.
        self._logger = get_logger('operator_list')

    def print_list(self):
        """Log the ``fw``, ``bw`` and ``re`` lists, in that order, at DEBUG level."""
        self.print_list_fw()
        self.print_list_bw()
        self.print_list_re()

    def print_list_fw(self):
        """Log a ``fw`` header and then every entry of ``self.fw`` at DEBUG level."""
        self._logger.debug("fw")
        for item in self.fw:
            # Placeholders are required: logging %-merges the extra args into msg.
            self._logger.debug("name %s type %s index_name %s",
                               item.name, item.type, item.index_name)

    def print_list_bw(self):
        """Log a ``bw`` header and then every entry of ``self.bw`` at DEBUG level."""
        self._logger.debug("bw")
        for item in self.bw:
            # Placeholders are required: logging %-merges the extra args into msg.
            self._logger.debug("name %s type %s index_name %s",
                               item.name, item.type, item.index_name)

    def print_list_re(self):
        """Log a ``re`` header and then every entry of ``self.re`` at DEBUG level."""
        self._logger.debug("re")
        for item in self.re:
            # Placeholders are required: logging %-merges the extra args into msg.
            self._logger.debug("name %s type %s index_name %s",
                               item.name, item.type, item.index_name)
class OperatorDetailList(OperatorList):
    """OperatorList that additionally keeps an ``all`` list.

    ``fw``, ``bw``, ``re`` and ``all`` are each reset to a fresh empty list.
    """

    def __init__(self):
        super().__init__()
        self.fw, self.bw, self.re, self.all = [], [], [], []