import logging
import sqlite3
from common_func.constant import Constant
from common_func.db_name_constant import DBNameConstant
from common_func.ms_constant.number_constant import NumberConstant
from common_func.ms_constant.str_constant import StrConstant
from common_func.msvp_constant import MsvpConstant
from msmodel.interface.view_model import ViewModel
def get_ge_model_data(params: dict, table_name: str, configs: dict) -> tuple:
"""
get ge model data
"""
result_data = []
project_path = params.get(StrConstant.PARAM_RESULT_DIR)
search_data_sql = "select {0}.model_name, {1}.model_id, fusion_name, op_names, memory_input/{BYTES_TO_KB}, " \
"memory_output/{BYTES_TO_KB}, memory_weight/{BYTES_TO_KB}, memory_workspace/{BYTES_TO_KB}," \
" memory_total/{BYTES_TO_KB} from {1} inner join {0} where {0}.model_id={1}.model_id " \
.format(table_name, DBNameConstant.TABLE_GE_FUSION_OP_INFO, BYTES_TO_KB=NumberConstant.BYTES_TO_KB)
model_view = ViewModel(project_path, configs.get(StrConstant.CONFIG_DB), [table_name])
if not model_view.check_table():
return MsvpConstant.MSVP_EMPTY_DATA
data = model_view.get_sql_data(search_data_sql)
if not data:
return MsvpConstant.MSVP_EMPTY_DATA
hash_dict = get_ge_hash_dict(project_path)
_update_hash_data(data, hash_dict, result_data)
return configs.get(StrConstant.CONFIG_HEADERS), result_data, len(result_data)
def _update_hash_data(data: list, hash_dict: dict, result_data: list) -> None:
for _data in data:
_data = list(_data)
_data[0] = hash_dict.get(_data[0], _data[0])
_data[2] = hash_dict.get(_data[2], _data[2])
_data[3] = ";".join(map(str, [hash_dict.get(str(i), i) for i in list(_data[3].split(","))]))
result_data.append(_data)
def get_ge_hash_dict(project_path: str) -> dict:
"""
get ge hash dict
"""
hash_view = ViewModel(project_path, DBNameConstant.DB_GE_HASH, [DBNameConstant.TABLE_GE_HASH])
try:
if not hash_view.check_table():
return {}
return dict(hash_view.get_all_data(DBNameConstant.TABLE_GE_HASH))
except sqlite3.Error as err:
logging.error(str(err), exc_info=Constant.TRACE_BACK_SWITCH)
return {}
finally:
hash_view.finalize()
def get_ge_model_name_dict(project_path: str) -> dict:
"""
get ge model name dict
"""
model_view = ViewModel(project_path, DBNameConstant.DB_GE_MODEL_INFO, [DBNameConstant.TABLE_MODEL_NAME])
try:
if not model_view.check_table():
model_view.finalize()
return {}
except sqlite3.Error as err:
logging.error(str(err), exc_info=Constant.TRACE_BACK_SWITCH)
return {}
sql = "select model_id, model_name from {}".format(DBNameConstant.TABLE_MODEL_NAME)
data = model_view.get_sql_data(sql)
model_view.finalize()
hash_dict = get_ge_hash_dict(project_path)
model_name_list = []
for _data in data:
_data = list(_data)
_data[1] = hash_dict.get(_data[1], _data[1])
model_name_list.append(_data)
return dict(model_name_list)