from abc import ABC
from functools import partial
from common_func.constant import Constant
from common_func.db_manager import DBManager
from common_func.db_name_constant import DBNameConstant
from common_func.empty_class import EmptyClass
from common_func.info_conf_reader import InfoConfReader
from common_func.path_manager import PathManager
from common_func.profiling_scene import ProfilingScene
from msmodel.interface.base_model import BaseModel
from msmodel.interface.view_model import ViewModel
from profiling_bean.db_dto.step_trace_dto import IterationRange
from profiling_bean.db_dto.step_trace_dto import StepTraceDto
from profiling_bean.db_dto.step_trace_dto import StepTraceOriginDto
from profiling_bean.db_dto.step_trace_ge_dto import StepTraceGeDto
from profiling_bean.db_dto.tiling_block_dim_dto import TilingBlockDimDto
from profiling_bean.db_dto.time_section_dto import TimeSectionDto
from msparser.step_trace.ts_binary_data_reader.task_flip_bean import TaskFlip
class TsTrackModel(BaseModel, ABC):
"""
acsq task model class
"""
TS_AI_CPU_TYPE = 1
@staticmethod
def __aicpu_in_time_range(data, min_timestamp, max_timestamp):
return min_timestamp <= data[2] <= max_timestamp
def flush(self: any, table_name: str, data_list: list) -> None:
"""
flush acsq task data to db
:param data_list:acsq task data list
:return: None
"""
self.insert_data_to_db(table_name, data_list)
def create_table(self: any, table_name: str) -> None:
"""
create table
"""
if DBManager.judge_table_exist(self.cur, table_name):
DBManager.drop_table(self.conn, table_name)
table_map = "{0}Map".format(table_name)
sql = DBManager.sql_create_general_table(table_map, table_name, self.TABLES_PATH)
DBManager.execute_sql(self.conn, sql)
def get_ai_cpu_data(self: any, iter_time_range) -> list:
"""
get ai cpu data
:param iter_time_range: iteration time range
:return: ai cpu with state
"""
if not DBManager.check_tables_in_db(PathManager.get_db_path(self.result_dir, DBNameConstant.DB_STEP_TRACE),
DBNameConstant.TABLE_TASK_TYPE):
return []
sql = "select stream_id, task_id, timestamp, task_state from {0} " \
"where task_type={1} order by timestamp ".format(DBNameConstant.TABLE_TASK_TYPE, self.TS_AI_CPU_TYPE)
ai_cpu_with_state = DBManager.fetch_all_data(self.cur, sql)
for index, datum in enumerate(ai_cpu_with_state):
ai_cpu_with_state[index] = list(datum)
ai_cpu_with_state[index][2] = int(datum[2])
if not ProfilingScene().is_all_export() and iter_time_range:
min_timestamp = min(iter_time_range)
max_timestamp = max(iter_time_range)
ai_cpu_with_state = list(filter(partial(self.__aicpu_in_time_range, min_timestamp=min_timestamp,
max_timestamp=max_timestamp), ai_cpu_with_state))
return ai_cpu_with_state
def get_step_trace_data(self: any, table_name: str) -> list:
"""
get step trace data
"""
if not DBManager.judge_table_exist(self.cur, table_name) or \
not DBManager.judge_row_exist(self.cur, table_name):
return []
sql = "select model_id, index_id, iter_id, step_start, step_end from {0}".format(table_name)
step_trace_data = DBManager.fetch_all_data(self.cur, sql, dto_class=StepTraceDto)
return step_trace_data
def get_index_range_with_model(self, model_id):
"""
get the max iteration id of the model id.
"""
table_name = ProfilingScene().get_step_table_name()
if not DBManager.judge_table_exist(self.cur, table_name) or \
not DBManager.judge_row_exist(self.cur, table_name):
return []
sql = f'select min(index_id), max(index_id) from {table_name} where model_id=?'
return DBManager.fetchone(self.cur, sql, (model_id,))
def get_step_syscnt_range_by_iter_range(self, iteration: IterationRange):
"""
get step time range by the iteration range.
"""
table_name = ProfilingScene().get_step_table_name()
if not DBManager.judge_table_exist(self.cur, table_name) or \
not DBManager.judge_row_exist(self.cur, table_name):
return EmptyClass()
iteration_range = iteration.get_iteration_range()
if iteration_range[0] == 1:
sql = f"select 0 as step_start, max(step_end) as step_end " \
f"from {table_name} where model_id=? and index_id>=? and index_id<=?"
else:
sql = f"select min(step_start) as step_start, max(step_end) as step_end " \
f"from {table_name} where model_id=? and index_id>=? and index_id<=?"
return DBManager.fetchone(self.cur, sql, (iteration.model_id, *iteration_range),
dto_class=StepTraceDto)
def get_step_syscnt_range(self, iteration: IterationRange):
"""
get step start sys_cnt and end sys_cnt
"""
table_name = ProfilingScene().get_step_table_name()
if not DBManager.judge_table_exist(self.cur, table_name) or \
not DBManager.judge_row_exist(self.cur, table_name):
return EmptyClass()
iteration_range = iteration.get_iteration_range()
sql = f"select min(step_start) as step_start, max(step_end) as step_end " \
f"from {table_name} where model_id = ? and index_id >= ? and index_id <= ?"
return DBManager.fetchone(self.cur, sql, (iteration.model_id, *iteration_range),
dto_class=StepTraceDto)
def get_step_end_list_with_iter_range(self, iteration: IterationRange):
"""
get step trace within the range of iteration.
"""
table_name = ProfilingScene().get_step_table_name()
if not DBManager.judge_table_exist(self.cur, table_name) or \
not DBManager.judge_row_exist(self.cur, table_name):
return []
sql = f"select index_id, step_end from {table_name} " \
f"where model_id=? and index_id>=? and index_id<=? order by step_end"
return DBManager.fetch_all_data(self.cur, sql, (iteration.model_id, *iteration.get_iteration_range()),
dto_class=StepTraceDto)
def get_task_flip_data(self: any) -> list:
if not DBManager.judge_table_exist(self.cur, DBNameConstant.TABLE_DEVICE_TASK_FLIP):
return []
sql = "select stream_id, timestamp, task_id, flip_num from {0}".format(DBNameConstant.TABLE_DEVICE_TASK_FLIP)
return DBManager.fetch_all_data(self.cur, sql, dto_class=TaskFlip)
def get_step_trace_with_tag(self: any, tags: list) -> list:
if not tags or not DBManager.judge_table_exist(self.cur, DBNameConstant.TABLE_STEP_TRACE) or \
not DBManager.judge_row_exist(self.cur, DBNameConstant.TABLE_STEP_TRACE):
return []
tags_condition = ",".join([str(tag) for tag in tags])
select_sql = f"select DISTINCT index_id, model_id, timestamp, tag_id, stream_id " \
f"from {DBNameConstant.TABLE_STEP_TRACE} " \
f"where tag_id in ({tags_condition}) order by timestamp"
return DBManager.fetch_all_data(self.cur, select_sql, dto_class=StepTraceOriginDto)
class TsTrackViewModel(ViewModel):
def __init__(self: any, path: str, table_list: list = None) -> None:
super().__init__(path, DBNameConstant.DB_STEP_TRACE, table_list if table_list else [])
def get_hccl_operator_exe_data(self) -> list:
if not self.attach_to_db(DBNameConstant.DB_GE_INFO):
return []
device_id = InfoConfReader().get_device_id()
sql = "SELECT t1.model_id model_id, t1.index_id index_id, t1.stream_id stream_id, t1.task_id task_id, " \
"t1.tag_id tag_id, t1.timestamp timestamp, t2.op_name op_name, t2.op_type op_type " \
"FROM ( SELECT model_id, index_id, tag_id, stream_id, task_id-1 AS task_id, timestamp " \
"FROM {0} WHERE tag_id>=10000 ) t1 LEFT JOIN ( " \
"SELECT model_id, index_id, stream_id, task_id, op_name, op_type FROM {1} WHERE task_type='{2}' ) t2 " \
"ON t1.model_id=t2.model_id AND (t1.index_id=t2.index_id OR t2.index_id=0 ) " \
"AND t1.stream_id = t2.stream_id AND t1.task_id = t2.task_id AND " \
"t2.device_id = {3} ORDER BY t1.timestamp".format(
DBNameConstant.TABLE_STEP_TRACE, DBNameConstant.TABLE_GE_TASK, Constant.TASK_TYPE_HCCL, device_id)
return DBManager.fetch_all_data(self.cur, sql, dto_class=StepTraceGeDto)
def get_ai_cpu_data(self) -> list:
sql = "SELECT stream_id, task_id, timestamp, task_state FROM {} where " \
"task_type=1 and (task_state=1 or task_state=2) order by timestamp".format(
DBNameConstant.TABLE_TASK_TYPE)
return DBManager.fetch_all_data(self.cur, sql)
def get_iter_time_data(self) -> list:
sql = "select model_id, index_id ,step_start as start_time, step_end as end_time " \
"from {} order by end_time".format(DBNameConstant.TABLE_STEP_TRACE_DATA)
return DBManager.fetch_all_data(self.cur, sql, dto_class=TimeSectionDto)
def get_tiling_block_dim_data(self):
sql = "select stream_id, task_id, timestamp, block_dim from {}".format(DBNameConstant.TABLE_BLOCK_DIM)
return DBManager.fetch_all_data(self.cur, sql, dto_class=TilingBlockDimDto)