from msprof_analyze.prof_common.logger import get_logger
import os
from abc import abstractmethod
from decimal import Decimal
from typing import List, Any
import pandas as pd
from msprof_analyze.cluster_analyse.common_func.table_constant import TableConstant
from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.advisor.dataset.profiling.info_collection import OpInfo
from msprof_analyze.advisor.dataset.profiling.profiling_parser import ProfilingParser
from msprof_analyze.advisor.utils.utils import format_excel_title, lazy_property
# Module-level logger used by the op summary parsers below to report parse failures.
logger = get_logger()
class OpSummaryBase(ProfilingParser):
    """
    Shared base of the op summary parsers.

    Keeps the parsed operators in ``op_list`` together with the accumulated
    task duration and task wait time. Concrete subclasses implement the
    parsing and the op-state queries.
    """
    # Human readable description of the accepted file names.
    FILE_PATTERN_MSG = "op_summary file pattern"
    # Short description of the kind of data handled by this parser.
    FILE_INFO = "op summary"
    # Values of the "op_state" attribute that subclasses compare against.
    STATIC_OP_STATE = "static"
    DYNAMIC_OP_STATE = "dynamic"
    # Regular expressions of accepted file names; overridden by subclasses.
    file_pattern_list = []

    def __init__(self, path: str) -> None:
        """
        :param path: forwarded unchanged to ``ProfilingParser.__init__``
        """
        super().__init__(path)
        self._total_task_wait_time = 0.0
        self._total_task_duration = 0.0
        self.op_list: List[OpInfo] = []

    @abstractmethod
    def parse_from_file(self, file: str):
        """
        Parse ``file`` and fill ``op_list``; the base implementation only returns False.
        NOTE(review): ``abstractmethod`` is only enforced when the base class uses an
        ABC metaclass - confirm against ``ProfilingParser``.
        """
        return False

    @abstractmethod
    def get_static_shape_operators(self):
        """
        Return the operators whose op state is static; the base implementation returns False.
        """
        return False

    @abstractmethod
    def contains_op_state_info(self):
        """
        Tell whether op state information is available; the base implementation returns False.
        """
        return False

    def get_total_task_duration(self):
        """
        Get the total task duration of all operators.
        :return: the value accumulated in ``_total_task_duration``
        """
        return self._total_task_duration

    @lazy_property
    def task_dict(self):
        """
        Group the entries of ``op_list`` by their ``op_name``.
        :return: dict mapping op name -> list of op infos, in ``op_list`` order
        """
        grouped = {}
        for op in self.op_list:
            # setdefault creates the bucket on first sight of a name
            grouped.setdefault(op.op_name, []).append(op)
        return grouped
class OpSummary(OpSummaryBase):
    """
    Op summary parser for ``op_summary_*.csv`` text files.
    """
    FILE_PATTERN_MSG = "op_summary_*.csv"
    FILE_INFO = "op summary from text"
    file_pattern_list = [r"^op_summary_[_\d]+\.csv$"]

    def __init__(self, path: str) -> None:
        """
        :param path: forwarded unchanged to the base class constructor
        """
        super().__init__(path)
        # Rows produced by ``_parse_csv``; the first row is used as the header.
        self._raw_data: List[List[str]] = []

    @staticmethod
    def _scale_start_time(value):
        """
        Multiply a task start time given in 'us' by 1000.

        :param value: raw cell text
        :return: the scaled value as a string, or ``value`` unchanged when it is
                 not a plain number
        """
        if not value.replace('.', '').replace("E+", "").isnumeric():
            return value
        try:
            return str(Decimal(value) * Decimal(1000))
        except ArithmeticError:
            # The pre-check above is only a heuristic: strings such as "1.2.3" or
            # "E+5" pass it but are rejected by Decimal. All decimal signals derive
            # from ArithmeticError, so keep the raw text instead of aborting the parse.
            logger.warning("Invalid task start time value: %s", value)
            return value

    def parse_from_file(self, file: str):
        """
        Parse an op summary csv file into ``op_list`` and accumulate the totals.

        :param file: path of the csv file
        :return: True when at least one operator row was parsed, otherwise False
        """
        if not self._parse_csv(file):
            return False
        if not self._raw_data:
            # No header row at all: nothing to index below.
            logger.error("No valid op info in %s", file)
            return False

        # Format every header once instead of once per cell, and remember which
        # columns hold a start time expressed in 'us'.
        columns = []
        for title in self._raw_data[0]:
            formatted_title = format_excel_title(title)
            columns.append((formatted_title, formatted_title == 'task_start_time' and 'us' in title))
        column_count = len(columns)

        for op_data in self._raw_data[1:]:
            op_info = OpInfo()
            for idx, value in enumerate(op_data):
                if idx < column_count:
                    formatted_title, needs_scaling = columns[idx]
                else:
                    # Cells beyond the header width are stored under the formatted empty title.
                    formatted_title, needs_scaling = format_excel_title(""), False
                if needs_scaling:
                    value = self._scale_start_time(value)
                op_info.add_attr(formatted_title, value)
            self.op_list.append(op_info)
            self._total_task_duration += self.get_float(op_info.get_attr("task_duration"))
            self._total_task_wait_time += self.get_float(op_info.get_attr("task_wait_time"))

        if not self.op_list:
            logger.error("No valid op info in %s", file)
            return False
        return True

    def get_static_shape_operators(self) -> List[Any]:
        """
        :return: the "op_name" of every parsed operator whose "op_state" is static
        """
        return [op_info.get_attr("op_name")
                for op_info in self.op_list if op_info.get_attr("op_state") == self.STATIC_OP_STATE]

    def contains_op_state_info(self):
        """
        :return: always True for the csv based parser
        """
        return True
class OpSummaryDB(OpSummaryBase):
    """
    Op summary parser for profiler db files.

    Reads compute, communication and communication-schedule tasks through SQL,
    merges them into one frame ordered by task start time and converts every
    row into an ``OpInfo``.
    """
    FILE_PATTERN_MSG = "ascend_*_profiler.db"
    FILE_INFO = "op summary from db"
    # Column whose presence selects the blockNum/mixBlockNum naming in COMPUTE_TASK_INFO.
    COLUMN_BLOCK_NUM = "blockNum"
    file_pattern_list = [r'^ascend_pytorch_profiler(?:_\d+)?\.db$',
                         r'^ascend_mindspore_profiler(?:_\d+)?\.db$',
                         r'^msprof_\d{14}\.db$']
    # Template with two placeholders: {block_dim_state} (block dim columns) and
    # {op_state} (optional op_state column); filled in by export_compute_task.
    COMPUTE_INFO_SQL = """
WITH compute_info AS (
SELECT
(SELECT value FROM STRING_IDS WHERE id = t.name) AS op_name,
t.globalTaskId,
{block_dim_state}
(SELECT value FROM STRING_IDS WHERE id = t.opType) AS op_type,
(SELECT value FROM STRING_IDS WHERE id = t.taskType) AS task_type,
(SELECT value FROM STRING_IDS WHERE id = t.inputFormats) AS input_formats,
(SELECT value FROM STRING_IDS WHERE id = t.inputShapes) AS input_shapes,
(SELECT value FROM STRING_IDS WHERE id = t.inputDataTypes) AS input_data_types,
(SELECT value FROM STRING_IDS WHERE id = t.outputShapes) AS output_shapes,
(SELECT value FROM STRING_IDS WHERE id = t.outputFormats) AS output_formats,
(SELECT value FROM STRING_IDS WHERE id = t.outputDataTypes) AS output_data_types
{op_state}
FROM
COMPUTE_TASK_INFO t
)
SELECT
compute_info.*,
task.startNs as task_start_time,
task.endNs as task_end_time,
task.endNs - task.startNs as task_duration,
task.deviceId as device_id,
task.modelId as model_id,
task.streamId as stream_id,
task.contextId as context_id,
task.taskId as task_id
FROM
compute_info
JOIN
TASK as task ON compute_info.globalTaskId = task.globalTaskId;
"""
    # Fragment substituted for {op_state} when COMPUTE_TASK_INFO has the op state column.
    SELECT_OP_STATE = """,
(SELECT value FROM STRING_IDS WHERE id = t.opState) AS op_state
"""
    # One row per (task, pmu counter name); pivoted to columns in export_compute_task.
    PMU_SQL = """
SELECT
pmu.globalTaskId,
str.value as name,
pmu.value
FROM TASK_PMU_INFO AS pmu
JOIN STRING_IDS AS str ON str.id = pmu.name
"""
    # Communication operators joined with the device/model of their connection; only
    # connections mapping to exactly one device and one model are kept.
    COMMUNICATION_INFO_SQL = """
WITH comm_info AS (
SELECT
(SELECT value FROM STRING_IDS WHERE id = c.opName) AS op_name,
(SELECT value FROM STRING_IDS WHERE id = c.opType) AS op_type,
startNs as task_start_time,
endNs as task_end_time,
endNs - startNs as task_duration,
connectionId
FROM
COMMUNICATION_OP c
)
SELECT
comm.*,
t.deviceId as device_id,
t.modelId as model_id,
'COMMUNICATION' as task_type
FROM
comm_info comm
JOIN (
SELECT
connectionId,
deviceId,
modelId
FROM TASK
GROUP BY connectionId
HAVING COUNT(DISTINCT deviceId) = 1 AND COUNT(DISTINCT modelId) = 1
) t ON comm.connectionId = t.connectionId
"""
    COMMUNICATION_SCHEDULE_SQL = """
SELECT
(SELECT value FROM STRING_IDS WHERE id = CSTI.name) AS op_name,
(SELECT value FROM STRING_IDS WHERE id = CSTI.opType) AS op_type,
(SELECT value FROM STRING_IDS WHERE id = CSTI.taskType) AS task_type,
task.startNs as task_start_time,
task.endNs as task_end_time,
task.endNs - task.startNs as task_duration,
task.deviceId as device_id,
task.modelId as model_id,
task.streamId as stream_id,
task.contextId as context_id,
task.taskId as task_id
FROM COMMUNICATION_SCHEDULE_TASK_INFO as CSTI
JOIN TASK as task ON task.globalTaskId = CSTI.globalTaskId
"""

    def __init__(self, path: str) -> None:
        """
        :param path: forwarded unchanged to the base class constructor
        """
        super().__init__(path)
        # Set by export_compute_task once the db schema has been inspected.
        self.has_op_state = False

    def parse_from_file(self, file: str):
        """
        Load all operators of a profiler db into ``op_list`` and compute the totals.

        :param file: path of the db file
        :return: False when the path is missing or no operator was found, else True
        """
        if not file or not os.path.exists(file):
            logger.error("db path is None.")
            return False
        compute_df = self.export_compute_task(file)
        communication_df = self.export_communication_task(file)
        comm_schedule_df = self._execute_sql(file, self.COMMUNICATION_SCHEDULE_SQL,
                                             [Constant.TABLE_COMMUNICATION_SCHEDULE_TASK_INFO])
        if compute_df.empty and communication_df.empty and comm_schedule_df.empty:
            logger.warning(f"No compute and communication operators in db: {file}")
            return False
        total_df = self.post_process([compute_df, communication_df, comm_schedule_df])
        self._total_task_duration = float(total_df['task_duration'].sum())
        self._total_task_wait_time = float(total_df['task_wait_time'].sum())
        self.convert_to_op_info_list(total_df)
        return True

    def contains_op_state_info(self):
        """
        :return: True when the parsed db exposed the op state column
        """
        return self.has_op_state

    def get_static_shape_operators(self) -> List[Any]:
        """
        :return: the "op_name" of every operator whose "op_state" is static, or an
                 empty list when the db carries no op state information
        """
        if not self.contains_op_state_info():
            return []
        return [op_info.get_attr("op_name")
                for op_info in self.op_list if op_info.get_attr("op_state") == self.STATIC_OP_STATE]

    def export_compute_task(self, db_path):
        """
        Query the compute tasks and attach their PMU counters as extra columns.

        Also updates ``has_op_state`` according to the COMPUTE_TASK_INFO schema.

        :param db_path: path of the db file
        :return: DataFrame of compute tasks; the plain task frame is returned when
                 either the task query or the PMU query yields nothing
        """
        if self._check_table_column_exists(db_path, Constant.TABLE_COMPUTE_TASK_INFO, TableConstant.OP_STATE):
            op_state = self.SELECT_OP_STATE
            self.has_op_state = True
        else:
            op_state = ""
            self.has_op_state = False

        # The block dim columns exist under two naming schemes; select whichever
        # the db provides and expose them under the same aliases.
        if self._check_table_column_exists(db_path, Constant.TABLE_COMPUTE_TASK_INFO, self.COLUMN_BLOCK_NUM):
            block_dim_state = (
                "\n"
                "t.blockNum AS block_dim,\n"
                "t.mixBlockNum AS mix_block_dim,\n"
            )
        else:
            block_dim_state = (
                "\n"
                "t.blockDim AS block_dim,\n"
                "t.mixBlockDim AS mix_block_dim,\n"
            )
        comp_info_sql = self.COMPUTE_INFO_SQL.format(
            op_state=op_state,
            block_dim_state=block_dim_state
        )
        basic_df = self._execute_sql(db_path, comp_info_sql, [Constant.TABLE_COMPUTE_TASK_INFO])
        pmu_df = self._execute_sql(db_path, self.PMU_SQL, [Constant.TABLE_TASK_PMU_INFO])
        if basic_df.empty or pmu_df.empty:
            return basic_df
        # Long -> wide: one column per PMU counter name, keyed by globalTaskId.
        pivoted_pmu_df = pmu_df.pivot_table(
            index='globalTaskId',
            columns='name',
            values='value',
            aggfunc='first'
        ).reset_index()
        # Left join keeps tasks without PMU data; their missing values become 0.
        compute_df = basic_df.merge(pivoted_pmu_df, on='globalTaskId', how='left').fillna(0)
        return compute_df

    def export_communication_task(self, db_path):
        """
        Query the communication operators.

        :param db_path: path of the db file
        :return: DataFrame produced by COMMUNICATION_INFO_SQL
        """
        # Pass the table list like every other query in this class does.
        return self._execute_sql(db_path, self.COMMUNICATION_INFO_SQL, [Constant.TABLE_COMMUNICATION_OP])

    def post_process(self, df_list):
        """
        Merge the task frames, derive the wait time and rescale the time columns.

        :param df_list: DataFrames to merge; each needs 'task_start_time',
                        'task_end_time' and 'task_duration'
        :return: merged DataFrame ordered by task start time, with missing cells set
                 to 'N/A', every time column and 'task_duration' divided by 1000 and
                 the helper columns dropped
        """
        total_df = pd.concat(df_list, ignore_index=True).sort_values(by='task_start_time')
        # Re-label rows in sorted order so that label 0 below really is the earliest
        # task; without this, loc[0] would hit whichever row came first in df_list.
        total_df = total_df.reset_index(drop=True)
        total_df = total_df.fillna('N/A')
        # NOTE(review): this is "current end - previous start"; a gap between tasks
        # would be "current start - previous end". Left unchanged - confirm intent.
        total_df['task_wait_time'] = total_df['task_end_time'] - total_df['task_start_time'].shift(1)
        # The first task has no predecessor.
        total_df.loc[0, 'task_wait_time'] = 0
        time_cols = [col for col in total_df.columns.tolist() if 'time' in col]
        time_cols.append('task_duration')
        for col in time_cols:
            # presumably ns -> us (the source columns are startNs/endNs) - TODO confirm
            total_df[col] = total_df[col].apply(lambda x: Decimal(x) / 1000 if x != 'N/A' else x)
        total_df = total_df.rename(columns={'aiv_total_time': 'aiv_time', 'aic_total_time': 'aicore_time'},
                                   errors='ignore')
        total_df = total_df.drop(columns=['task_end_time', 'globalTaskId', 'connectionId'], errors='ignore')
        return total_df

    def convert_to_op_info_list(self, df):
        """
        Append one ``OpInfo`` per row of ``df`` to ``op_list``.

        :param df: DataFrame whose column names become attribute names
        """
        columns = list(df.columns)
        # Plain tuples paired with the column labels: named tuples would rename
        # columns that are not valid identifiers, breaking lookup by column name.
        for row in df.itertuples(index=False, name=None):
            op_info = OpInfo()
            for col, value in zip(columns, row):
                setattr(op_info, col, value)
            self.op_list.append(op_info)