import logging
from common_func.constant import Constant
from common_func.db_manager import DBManager
from common_func.db_name_constant import DBNameConstant
from common_func.info_conf_reader import InfoConfReader
from common_func.ms_constant.number_constant import NumberConstant
from common_func.ms_constant.str_constant import StrConstant
from common_func.ms_multi_process import MsMultiProcess
from common_func.msprof_iteration import MsprofIteration
from common_func.path_manager import PathManager
from common_func.profiling_scene import ProfilingScene
from mscalculate.ts_task.ai_cpu.aicpu_from_ts_collector import AICpuFromTsCollector
from msconfig.config_manager import ConfigManager
from msmodel.step_trace.ts_track_model import TsTrackModel
from viewer.calculate_rts_data import calculate_task_schedule_data
from viewer.calculate_rts_data import multi_calculate_task_cost_time
class TaskSchedulerCalculator(MsMultiProcess):
"""
calculate origin data
"""
TABLE_PATH = ConfigManager.TABLES
COMPLETE_TIME_INDEX = 10
def __init__(self: any, file_list: dict, sample_config: dict) -> None:
super().__init__(sample_config)
self._file_list = file_list
self.sample_config = sample_config
self.project_path = sample_config.get(StrConstant.SAMPLE_CONFIG_PROJECT_PATH)
self.iter_range = sample_config.get(StrConstant.PARAM_ITER_ID)
@staticmethod
def _insert_report_task_data(runtime_conn: any, runtime_curs: any, device_id: str) -> None:
report_data = calculate_task_schedule_data(runtime_curs, device_id)
if not report_data:
logging.info('Unable to get report task data')
return
sql = 'insert into ReportTask values({value})'.format(
value='?,' * (len(report_data[0]) - 1) + '?')
DBManager.executemany_sql(runtime_conn, sql, report_data)
def ms_run(self: any) -> None:
"""
entry
:return:
"""
if not ProfilingScene().is_operator():
self.process()
def process(self: any) -> None:
"""
calculate for task scheduler
:return:
"""
try:
self.generate_report_data()
except (OSError, SystemError, ValueError, TypeError, RuntimeError) as err:
logging.error(str(err))
def create_task_time(self: any, runtime_conn: any, device: int, iter_time_range: list) -> None:
"""
create task time table
:param runtime_conn: connect for runtime
:param device: device id
:param iter_time_range: iteration range
:return: None
"""
runtime_curs = runtime_conn.cursor()
logging.info('start create task time table')
if not DBManager.judge_table_exist(runtime_curs, DBNameConstant.TABLE_RUNTIME_TIMELINE):
logging.warning("TimeLine data not found, and no need to create task time data, "
"please check the ts track data.")
return
self._create_task_time_table(runtime_conn, runtime_curs)
try:
cal_task_data = self._get_timeline_data(device, iter_time_range, runtime_curs)
except (OSError, SystemError, ValueError, TypeError, RuntimeError) as err:
logging.error(err, exc_info=Constant.TRACE_BACK_SWITCH)
return
task_time = self._add_info(sorted(cal_task_data, key=lambda x: x[-1]))
self._collect_aicpu(task_time)
self._insert_task_time_data(task_time, runtime_conn)
logging.info('create task time table end')
def insert_report_data(self: any, project_path: str, device_id: str) -> None:
"""
insert data into report
:param device_id:
:param project_path: project path
:return:
"""
logging.info('start insert data into report table')
db_path = PathManager.get_db_path(project_path, DBNameConstant.DB_RUNTIME)
runtime_conn, runtime_curs = DBManager.check_connect_db_path(db_path)
if not runtime_conn or not runtime_curs \
or not DBManager.check_tables_in_db(db_path, DBNameConstant.TABLE_RUNTIME_TASK_TIME):
DBManager.destroy_db_connect(runtime_conn, runtime_curs)
return
self._create_report_task_table(runtime_conn)
self._insert_report_task_data(runtime_conn, runtime_curs, device_id)
logging.info('Insert data into report table finished.')
DBManager.destroy_db_connect(runtime_conn, runtime_curs)
def generate_report_data(self: any) -> None:
"""
insert report task table
:return: None
"""
devices = InfoConfReader().get_device_list()
if not devices:
logging.error("No device list found in info.json")
return
iter_time_range = MsprofIteration(self.project_path).get_step_syscnt_range_by_iter_range(self.iter_range)
if not iter_time_range:
return
self.__pre_mini_task_data(self.project_path, devices[0], iter_time_range)
self.insert_report_data(self.project_path, devices[0])
def _insert_task_time_data(self: any, task_time: list, runtime_conn: any) -> None:
task_time = sorted(task_time, key=lambda data: data[self.COMPLETE_TIME_INDEX])
insert_sql = "insert into TaskTime " \
"values ({value})".format(value="?," * (len(task_time[0]) - 1) + "?")
DBManager.executemany_sql(runtime_conn, insert_sql, task_time)
def _create_task_time_table(self: any, runtime_conn: any, runtime_curs: any) -> None:
if DBManager.judge_table_exist(runtime_curs, DBNameConstant.TABLE_RUNTIME_TASK_TIME):
DBManager.drop_table(runtime_conn, DBNameConstant.TABLE_RUNTIME_TASK_TIME)
sql = DBManager.sql_create_general_table('TaskTimeMap', DBNameConstant.TABLE_RUNTIME_TASK_TIME,
self.TABLE_PATH)
DBManager.execute_sql(runtime_conn, sql)
def _get_timeline_data(self: any, device: int, iter_time_range: list, runtime_curs: any) -> list:
timeline_sql = "select replayId,device_id,'','',taskType," \
"task_id,stream_id,timeStamp,taskState " \
"from TimeLine WHERE device_id=? and timestamp>? and timestamp<?" \
"order by task_id, stream_id,timeStamp,taskState,device_id;"
timeline_data = DBManager.fetch_all_data(runtime_curs, timeline_sql, (device,
iter_time_range[0], iter_time_range[1]))
cal_task_data = multi_calculate_task_cost_time(timeline_data, self.project_path)
return cal_task_data
def __pre_mini_task_data(self: any, project_path: str, device_id: int, iter_time_range: list) -> None:
runtime_conn, runtime_curs = \
DBManager.check_connect_db_path(PathManager.get_db_path(project_path, DBNameConstant.DB_RUNTIME))
if not runtime_conn or not runtime_curs:
return
try:
self.create_task_time(runtime_conn, device_id, iter_time_range)
except Exception as err:
logging.error(err, exc_info=Constant.TRACE_BACK_SWITCH)
finally:
DBManager.destroy_db_connect(runtime_conn, runtime_curs)
def _create_report_task_table(self: any, runtime_conn: any) -> None:
if DBManager.check_tables_in_db(PathManager.get_db_path(self.project_path, DBNameConstant.DB_RUNTIME),
DBNameConstant.TABLE_RUNTIME_REPORT_TASK):
DBManager.drop_table(runtime_conn, DBNameConstant.TABLE_RUNTIME_REPORT_TASK)
sql = DBManager.sql_create_general_table('ReportTaskMap', 'ReportTask', self.TABLE_PATH)
DBManager.execute_sql(runtime_conn, sql)
def _add_info(self: any, cal_task_data: list) -> list:
with TsTrackModel(self.project_path, DBNameConstant.DB_STEP_TRACE,
[DBNameConstant.TABLE_STEP_TRACE_DATA]) as _trace:
step_trace_data = _trace.get_step_end_list_with_iter_range(self.iter_range)
task_time = [
task_data +
(MsprofIteration(self.project_path).get_iter_id_within_iter_range(step_trace_data, task_data[-1],
self.iter_range),
self.iter_range.model_id,
NumberConstant.DEFAULT_BATCH_ID)
for task_data in cal_task_data
]
return task_time
def _collect_aicpu(self: any, task_time: list) -> None:
aicpu_collector = AICpuFromTsCollector(self.project_path)
for data in task_time:
task_id = data[5]
stream_id = data[6]
start = data[9]
end = data[10]
task_type = data[4]
aicpu_data = (stream_id, task_id, start, end, task_type)
aicpu_collector.filter_aicpu(aicpu_data)
aicpu_collector.save_aicpu()