import logging
import os
from collections import deque
from dataclasses import dataclass
from typing import Union
from typing import List
from common_func.constant import Constant
from common_func.db_manager import DBManager
from common_func.db_name_constant import DBNameConstant
from common_func.info_conf_reader import InfoConfReader
from common_func.ms_constant.number_constant import NumberConstant
from common_func.msprof_iteration import MsprofIteration
from common_func.msvp_common import format_high_precision_for_csv
from common_func.path_manager import PathManager
from common_func.profiling_scene import ProfilingScene
from msmodel.ai_cpu.ai_cpu_model import AiCpuModel
from profiling_bean.db_dto.ge_task_dto import GeTaskDto
from profiling_bean.db_dto.step_trace_dto import IterationRange
from profiling_bean.db_dto.task_time_dto import TaskTimeDto
@dataclass
class AiCpuData:
stream_id: int = None
task_id: int = None
sys_start: int = None
sys_end: int = None
node_name: str = None
compute_time: float = None
memcpy_time: float = None
task_time: float = None
dispatch_time: float = None
total_time: float = None
batch_id: int = None
class ParseAiCpuData:
"""
class for parse aicpu data
"""
@staticmethod
def analysis_aicpu(project_path: str, iter_range: IterationRange) -> list:
"""
parse and analysis AI CPU related dp data
:return: ai cpu data , headers
"""
ai_cpu_results = ParseAiCpuData.get_ai_cpu_data(project_path, iter_range)
ascend_task_results = ParseAiCpuData.get_ascend_task_ai_cpu_data(project_path)
res = ParseAiCpuData.get_aicpu_batch_id(ai_cpu_results, ascend_task_results)
ge_results = ParseAiCpuData.get_ge_summary_aicpu_data(project_path)
res = ParseAiCpuData.match_aicpu_with_ge_summary(res, ge_results)
res.sort(key=lambda x: x[0])
return res
@staticmethod
def get_aicpu_batch_id(ai_cpu_data: List[AiCpuData], ascend_task_data: List[TaskTimeDto]) -> List[AiCpuData]:
"""
get ai cpu batch_id from ascend_task data
"""
sep_ai_cpu_dict = ParseAiCpuData._sep_task_by_stream_task(ai_cpu_data)
seq_ascend_task_dict = ParseAiCpuData._sep_task_by_stream_task(ascend_task_data)
res = []
stream_task_set = sep_ai_cpu_dict.keys()
for key in stream_task_set:
ai_cpu_t = sep_ai_cpu_dict.get(key, [])
ascend_task_t = seq_ascend_task_dict.get(key, [])
match_res = ParseAiCpuData._match_aicpu_data_by_task_time(ai_cpu_t, ascend_task_t)
res.extend([*match_res])
return res
@staticmethod
def match_aicpu_with_ge_summary(ai_cpu_data: List[AiCpuData], ge_summary_data: List[GeTaskDto]) -> list:
"""
get ai cpu op_name from ge_summary data
"""
start_ts_float = float(InfoConfReader().get_collect_time()[0])
ge_summary_dic = {}
for dto in ge_summary_data:
ge_summary_dic.setdefault((dto.stream_id, dto.task_id, dto.batch_id), dto.op_name)
index = 0
res = [[] for _ in range(len(ai_cpu_data))]
for i, dto in enumerate(ai_cpu_data):
sys_start = InfoConfReader().trans_into_local_time(dto.sys_start, use_us=True)
if start_ts_float > float(sys_start):
continue
sys_start = format_high_precision_for_csv(sys_start)
compute_time = round(dto.compute_time, NumberConstant.ROUND_THREE_DECIMAL)
memcpy_time = round(dto.memcpy_time, NumberConstant.ROUND_THREE_DECIMAL)
task_time = round(dto.task_time, NumberConstant.ROUND_THREE_DECIMAL)
dispatch_time = round(dto.dispatch_time, NumberConstant.ROUND_THREE_DECIMAL)
total_time = round(dto.total_time, NumberConstant.ROUND_THREE_DECIMAL)
node_name = ge_summary_dic.get((dto.stream_id, dto.task_id, dto.batch_id), dto.node_name)
res[index] = [
sys_start, node_name, compute_time, memcpy_time,
task_time, dispatch_time, total_time, dto.stream_id, dto.task_id
]
index = index + 1
return res[:index]
@staticmethod
def get_ge_summary_aicpu_data(project_path: str) -> list:
"""
get ge_summary data
"""
db_path = PathManager.get_db_path(project_path, DBNameConstant.DB_AICORE_OP_SUMMARY)
ge_summary_conn, ge_summary_curs = DBManager.check_connect_db_path(db_path)
if not ge_summary_conn or not ge_summary_curs:
logging.warning("Can't connect ai_core_op_summary.db!")
return []
ge_results = ParseAiCpuData._get_ge_summary_aicpu_data(ge_summary_conn)
DBManager.destroy_db_connect(ge_summary_conn, ge_summary_curs)
return ge_results
@staticmethod
def get_ascend_task_ai_cpu_data(project_path: str) -> list:
"""
get ascend_task_data
"""
db_path = PathManager.get_db_path(project_path, DBNameConstant.DB_ASCEND_TASK)
ascend_task_conn, ascend_task_curs = DBManager.check_connect_db_path(db_path)
if not ascend_task_conn or not ascend_task_curs:
logging.warning("Can't connect ascend_task.db!")
return []
ascend_task_results = ParseAiCpuData._get_ascend_task_aicpu_data(ascend_task_conn)
DBManager.destroy_db_connect(ascend_task_conn, ascend_task_curs)
return ascend_task_results
@staticmethod
def get_ai_cpu_data(project_path: str, iter_range: IterationRange) -> list:
"""
get ai cpu data
"""
db_path = PathManager.get_db_path(project_path, DBNameConstant.DB_AI_CPU)
ai_cpu_conn, ai_cpu_curs = DBManager.check_connect_db_path(db_path)
if not ai_cpu_conn or not ai_cpu_curs:
logging.warning("Can't connect ai_cpu.db!")
return []
ai_cpu_results = ParseAiCpuData._get_aicpu_data(ai_cpu_conn, iter_range, project_path)
DBManager.destroy_db_connect(ai_cpu_conn, ai_cpu_curs)
return ai_cpu_results
@staticmethod
def get_aicpu_mi_data(project_path: str) -> list:
"""
get ai cpu mi data
"""
db_path = PathManager.get_db_path(project_path, DBNameConstant.DB_CLUSTER_DATA_PREPROCESS)
conn, curs = DBManager.check_connect_db_path(db_path)
if not conn or not curs:
logging.warning("Can't connect %s", DBNameConstant.DB_CLUSTER_DATA_PREPROCESS)
return []
sql = "select node_name, start_time, end_time, queue_size from {0}".format(DBNameConstant.TABLE_DATA_QUEUE)
aicpu_mi_data = DBManager.fetch_all_data(conn.cursor(), sql)
DBManager.destroy_db_connect(conn, curs)
return aicpu_mi_data
@staticmethod
def get_ai_cpu_from_ts(project_path: str) -> list:
"""
get ai cpu query sql
"""
aicpu_model = AiCpuModel(project_path)
if not os.path.exists(PathManager.get_db_path(project_path, DBNameConstant.DB_AI_CPU)):
logging.info("no aicpu db found")
return []
with aicpu_model:
aicpu_data = aicpu_model.get_ai_cpu_data_from_ts()
return aicpu_data
@staticmethod
def _sep_task_by_stream_task(tasks: Union[List[TaskTimeDto], List[AiCpuData]]) -> dict:
ret = {}
for task in tasks:
ret.setdefault((task.stream_id, task.task_id), []).append(task)
return ret
@staticmethod
def _match_aicpu_data_by_task_time(ai_cpu_data: List[AiCpuData], ascend_task_data: List[TaskTimeDto]) -> List:
aicpu_index = 0
task_index = 0
while aicpu_index < len(ai_cpu_data) and task_index < len(ascend_task_data):
aicpu = ai_cpu_data[aicpu_index]
task = ascend_task_data[task_index]
if task.start_time <= aicpu.sys_end * NumberConstant.NS_TO_US <= task.end_time:
ai_cpu_data[aicpu_index] = aicpu.replace(batch_id=task.batch_id)
aicpu_index += 1
task_index += 1
elif aicpu.sys_end * NumberConstant.NS_TO_US < task.start_time:
aicpu_index += 1
elif aicpu.sys_end * NumberConstant.NS_TO_US > task.end_time:
task_index += 1
return ai_cpu_data
@staticmethod
def _get_ge_summary_aicpu_data(ge_summary_conn):
sql = "select op_name, stream_id, task_id, batch_id from {0} where task_type = '{1}' " \
"order by timestamp".format(DBNameConstant.TABLE_SUMMARY_GE,
Constant.TASK_TYPE_AI_CPU)
return DBManager.fetch_all_data(ge_summary_conn.cursor(), sql, dto_class=GeTaskDto)
@staticmethod
def _get_ascend_task_aicpu_data(ascend_task_conn):
sql = "select batch_id, stream_id, task_id, start_time, start_time + duration AS end_time " \
"from {0} where host_task_type = 'KERNEL_AICPU' " \
"order by start_time".format(DBNameConstant.TABLE_ASCEND_TASK)
return DBManager.fetch_all_data(ascend_task_conn.cursor(), sql, dto_class=TaskTimeDto)
@staticmethod
def _get_aicpu_data(ai_cpu_conn, iter_range, project_path):
where_condition = ""
if not ProfilingScene().is_all_export():
iter_time = MsprofIteration(project_path).get_iter_interval(iter_range)
if iter_time:
where_condition = "where sys_start>={0} " \
"and sys_end<={1}".format(iter_time[0] / NumberConstant.MS_TO_NS,
iter_time[1] / NumberConstant.MS_TO_NS)
sql = "select sys_start*{MS_TO_US} as sys_start, sys_end * {MS_TO_US} as sys_end, " \
"'{node_name}' as node_name," \
"compute_time*{MS_TO_US} as compute_time, memcpy_time*{MS_TO_US} as memcpy_time," \
"task_time*{MS_TO_US} as task_time,dispatch_time*{MS_TO_US} as dispatch_time," \
"total_time*{MS_TO_US} as total_time, stream_id, task_id from {0} {where_condition} " \
"order by sys_start".format(DBNameConstant.TABLE_AI_CPU,
MS_TO_US=NumberConstant.MS_TO_US,
local_time_offset=InfoConfReader().get_local_time_offset(),
node_name=Constant.NA,
where_condition=where_condition)
return DBManager.fetch_all_data(ai_cpu_conn.cursor(), sql, dto_class=AiCpuData)