from typing import List, Tuple
from collections import namedtuple
from common_func.constant import Constant
from common_func.db_manager import DBManager
from common_func.db_name_constant import DBNameConstant
from common_func.ms_constant.number_constant import NumberConstant
from common_func.msprof_object import CustomizedNamedtupleFactory
from mscalculate.ascend_task.ascend_task import TopDownTask
from msmodel.interface.base_model import BaseModel
from msmodel.interface.view_model import ViewModel
class AscendTaskModel(BaseModel):
"""
class used to operate top-down task
"""
def __init__(self: any, result_dir: str, table_list: list) -> None:
super(AscendTaskModel, self).__init__(result_dir, DBNameConstant.DB_ASCEND_TASK, table_list)
def get_ascend_task_data_without_unknown(self: any) -> List[TopDownTask]:
if not DBManager.judge_table_exist(self.cur, DBNameConstant.TABLE_ASCEND_TASK):
return []
param = (
Constant.TASK_TYPE_UNKNOWN,
)
data_sql = "select * from {} " \
"where device_task_type != ?".format(DBNameConstant.TABLE_ASCEND_TASK)
return DBManager.fetch_all_data(self.cur, data_sql, param=param, dto_class=TopDownTask)
class AscendTaskViewModel(ViewModel):
"""
class used to operate top-down task
"""
ASCEND_TASK_TYPE = CustomizedNamedtupleFactory.enhance_namedtuple(
namedtuple("AscendTask",
["model_id", "index_id", "stream_id", "task_id", "context_id", "batch_id", "timestamp",
"duration", "host_task_type", "device_task_type", "connection_id", "op_name",
"ts_virtual_batch_id"]),
{})
def __init__(self: any, result_dir: str, table_list: list) -> None:
super().__init__(result_dir, DBNameConstant.DB_ASCEND_TASK, table_list)
def get_ascend_task_data_with_op_name_pattern_and_stream_id(
self: any, device_id: str,
op_name_pattern: str,
stream_ids: tuple
) -> List[ASCEND_TASK_TYPE]:
if not DBManager.judge_table_exist(self.cur, DBNameConstant.TABLE_ASCEND_TASK) \
or not self.attach_to_db(DBNameConstant.DB_GE_INFO):
return []
stream_id_condition = "b.stream_id in ({})".format(",".join(map(str, stream_ids)))
sql = "SELECT b.model_id, b.index_id, b.stream_id, b.task_id, b.context_id, b.batch_id, b.start_time, " \
"b.duration, b.host_task_type, b.device_task_type, b.connection_id, a.op_name as op_name, " \
"-1 as ts_virtual_batch_id from {0} as a inner join {1} as b " \
"on a.stream_id = b.stream_id " \
"and a.task_id = b.task_id " \
"and a.batch_id = b.batch_id " \
"and a.context_id = b.context_id " \
"and a.device_id = {device_id} " \
"and b.start_time != {invalid_start} " \
"and (a.op_name like '%{pattern}' or {stream_id_condition}) " \
"order by start_time" \
.format(DBNameConstant.TABLE_GE_TASK, DBNameConstant.TABLE_ASCEND_TASK, device_id=device_id,
invalid_start=NumberConstant.INVALID_TASK_TIME, pattern=op_name_pattern,
stream_id_condition=stream_id_condition)
ascend_task_data = DBManager.fetch_all_data(self.cur, sql)
return [self.ASCEND_TASK_TYPE(*data) for data in ascend_task_data]
def get_ascend_task_data_with_stream_id(self: any, stream_ids: tuple) -> List[ASCEND_TASK_TYPE]:
if not DBManager.judge_table_exist(self.cur, DBNameConstant.TABLE_ASCEND_TASK):
return []
stream_id_condition = "stream_id in ({})".format(",".join(map(str, stream_ids)))
sql = "SELECT model_id, index_id, stream_id, task_id, context_id, batch_id, start_time, duration, " \
"host_task_type, device_task_type, connection_id, 'N/A' as op_name, batch_id as ts_virtual_batch_id " \
"from {0} where {stream_id_condition} " \
.format(DBNameConstant.TABLE_ASCEND_TASK, stream_id_condition=stream_id_condition)
ascend_task_data = DBManager.fetch_all_data(self.cur, sql)
return [self.ASCEND_TASK_TYPE(*data) for data in ascend_task_data]
def get_ascend_task_time_extremes(self: any, filter_type: list) -> Tuple[TopDownTask, TopDownTask]:
if not DBManager.judge_table_exist(self.cur, DBNameConstant.TABLE_ASCEND_TASK):
return (None, None)
type_condition = "host_task_type not in ({})".format(
",".join(map(lambda x: f"'{x}'", filter_type))
)
latest_sql = f"SELECT * FROM {DBNameConstant.TABLE_ASCEND_TASK} " \
f"WHERE {type_condition} " \
f"AND device_task_type != '{Constant.TASK_TYPE_UNKNOWN}'" \
f"ORDER BY (start_time + duration) DESC " \
f"LIMIT 1"
latest_result = DBManager.fetch_all_data(self.cur, latest_sql, dto_class=TopDownTask)
latest_task = latest_result[0] if latest_result else None
earliest_sql = f"SELECT * FROM {DBNameConstant.TABLE_ASCEND_TASK} " \
f"WHERE {type_condition} " \
f"AND device_task_type != '{Constant.TASK_TYPE_UNKNOWN}'" \
f"ORDER BY start_time ASC " \
f"LIMIT 1"
earliest_result = DBManager.fetch_all_data(self.cur, earliest_sql, dto_class=TopDownTask)
earliest_task = earliest_result[0] if earliest_result else None
return latest_task, earliest_task