from collections import namedtuple
from common_func.db_name_constant import DBNameConstant
from common_func.db_manager import DBManager
from common_func.msprof_object import CustomizedNamedtupleFactory
from msmodel.interface.parser_model import ParserModel
from msmodel.interface.view_model import ViewModel
class GeModel(ParserModel):
    """
    Parser-side model that writes GE info data into the GE info database.
    """

    def __init__(self: any, result_dir: str, table_list: list) -> None:
        super().__init__(result_dir, DBNameConstant.DB_GE_INFO, table_list)
        # Table that flush() writes into; flush_all() switches it per table.
        self._current_table_name = None

    def flush_all(self: any, data_dict: dict) -> None:
        """
        Write the rows of every table in data_dict into the database.
        :param data_dict: mapping of table name -> rows to insert into that table
        :return: None
        """
        for name, rows in data_dict.items():
            # flush() reads the target table from instance state, so select it first.
            self._current_table_name = name
            self.flush(rows)

    def flush(self: any, data_list: list) -> None:
        """
        Insert data_list into the table currently held in _current_table_name.
        """
        target_table = self._current_table_name
        self.insert_data_to_db(target_table, data_list)

    def delete_table(self: any, table_name: str) -> None:
        """
        Remove every row from table_name (a DELETE, so the table itself remains).
        """
        # NOTE(review): table_name is interpolated directly; assumes callers pass trusted names.
        statement = "delete from {name}".format(name=table_name)
        self.cur.execute(statement)

    def get_ge_model_name(self: any) -> any:
        """
        Return the name of this instance's concrete class.
        """
        cls = self.__class__
        return cls.__name__
class GeInfoViewModel(ViewModel):
    """
    Read-side view over the GE info database.
    """

    # Row type for task info queries. The SELECT column list is generated from
    # this tuple's _fields, and rows are unpacked positionally into it, so the
    # field order here defines the query's column order.
    TASK_INFO_TYPE = CustomizedNamedtupleFactory.enhance_namedtuple(
        namedtuple(
            "TaskInfo",
            [
                "model_id",
                "op_name",
                "stream_id",
                "task_id",
                "block_num",
                "mix_block_num",
                "op_state",
                "task_type",
                "op_type",
                "index_id",
                "thread_id",
                "timestamp",
                "batch_id",
                "tensor_num",
                "input_formats",
                "input_data_types",
                "input_shapes",
                "output_formats",
                "output_data_types",
                "output_shapes",
                "device_id",
                "context_id",
                "op_flag",
                "hashid",
            ],
        ),
        {},
    )

    def __init__(self, result_dir: str, table_list: list):
        super().__init__(result_dir, DBNameConstant.DB_GE_INFO, table_list)

    @staticmethod
    def _quote_sql_literal(value: any) -> str:
        """
        Render value as a single-quoted SQL string literal.

        Embedded single quotes are doubled (standard SQL escaping), so the
        value can never terminate the literal early.
        """
        return "'{0}'".format("{0}".format(value).replace("'", "''"))

    def get_ge_info_by_device_id(self: any, table_name: str, device_id: str, task_type_filter: tuple = tuple()) -> any:
        """
        Fetch task info rows of one device from table_name.
        :param table_name: table to query
        :param device_id: device whose rows are selected
        :param task_type_filter: task_type values to exclude from the result
        :return: list of TASK_INFO_TYPE rows
        """
        fields = ", ".join(self.TASK_INFO_TYPE._fields)
        # NOTE(review): table_name and device_id are interpolated unquoted, as before;
        # assumes callers pass trusted values.
        ge_sql = "select {0} from {1} where device_id={2} ".format(fields, table_name, device_id)
        # Materialize once so a one-shot iterable (e.g. a generator) is consumed exactly once.
        excluded = [self._quote_sql_literal(task_type) for task_type in task_type_filter]
        if excluded:
            # Equivalent to chaining "task_type != x" terms (including NULL handling),
            # but with every value safely escaped.
            ge_sql += " AND task_type NOT IN ({0}) ".format(", ".join(excluded))
        task_info_data = DBManager.fetch_all_data(self.cur, ge_sql)
        return [self.TASK_INFO_TYPE(*row) for row in task_info_data]