import glob
import importlib
import os
import pandas as pd
import sys
import torch
import torch_npu
from datetime import datetime, timezone
from xlsxwriter import Workbook
# Put the directory two levels above this file's own directory on sys.path
# before the msprof_analyze imports below are executed.
# NOTE(review): presumably that directory is the repository root, so the
# absolute `msprof_analyze.*` imports resolve when this file is run directly
# without the package being installed — confirm against the package layout.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.prof_common.file_manager import FileManager
from msprof_analyze.prof_common.logger import get_logger
from msprof_analyze.prof_common.path_manager import PathManager
from msprof_analyze.prof_exports.inductor_triton_export import InductorTritonExport
# Module-level logger used by every method of ComparisonGenerator.
logger = get_logger()
class ComparisonGenerator:
    """Profile Inductor FX graph modules on NPU and compare Triton vs. original operators.

    The workflow driven by :meth:`run` is:

    1. import every ``<dir>/<dir>.py`` module found under ``fx_graph_path``;
    2. call each module's ``run()`` under the torch_npu profiler, wrapped in an
       mstx range named ``inductor_triton_<index>``;
    3. read the exported profiler DB, group the operators by mstx message and
       compute, per group, the Triton operator duration, the original operator
       durations and their ratio;
    4. write the result to an ``.xlsx`` file under ``output_path``.
    """

    # Glob (relative to the profiling directory) locating the exported DB.
    DB_PATTERN = "*_ascend_pt/ASCEND_PROFILER_OUTPUT/ascend_pytorch_profiler.db"
    # Prefix of the mstx range messages and name of the profiling sub-directory.
    MSTX_MESSAGE = "inductor_triton"
    # Operators whose name starts with this prefix are treated as Triton operators.
    TRITON_OP_PREFIX = "triton_"

    # Column names read from the exported operator table.
    COL_MESSAGE = "message"
    COL_NAME = "Name"
    COL_DURATION = "Duration(us)"

    # Column names of the comparison result.
    COL_TRITON_OP = "Triton Op"
    COL_TRITON_DURATION = "Triton Op Duration(us)"
    COL_ORIGINAL_DURATION = "Original Op Duration(us)"
    COL_ORIGINAL_TOTAL_DURATION = "Original Op Total Duration(us)"
    COL_DURATION_DIFF_RATIO = "Duration Diff Ratio"

    # Operators dropped from the table before the comparison is computed.
    EXCLUDE_OPS = ["aclnnIsClose_IsCloseAiCore_IsClose", "aclnnAll_ReduceAll_ReduceAll"]

    # xlsxwriter cell format definitions.
    DEFAULT = {"font_name": "Arial", 'font_size': 11, 'align': 'left', 'valign': 'vcenter', 'border': True,
               'num_format': '#,##0'}
    DEFAULT_FLOAT = {'text_wrap': True, "font_name": "Arial", 'font_size': 11, 'align': 'left',
                     'valign': 'vcenter', 'border': True, 'num_format': '#,##0.00'}
    DEFAULT_RATIO = {"font_name": "Arial", 'font_size': 11, 'align': 'left', 'valign': 'vcenter',
                     'border': True, 'num_format': '0.00%'}
    RED_RATIO = {"font_name": "Arial", 'font_size': 11, 'align': 'left', 'valign': 'vcenter',
                 'border': True, 'num_format': '0.00%', "fg_color": Constant.RED_COLOR}
    BOLD_STR = {"font_name": "Arial", 'font_size': 11, 'align': 'left', 'valign': 'vcenter', 'border': True,
                'bold': True}
    BLUE_BOLD = {"font_name": "Arial", 'font_size': 11, 'fg_color': Constant.BLUE_COLOR, 'align': 'left',
                 'valign': 'vcenter', 'bold': True, 'border': True}
    GREEN_BOLD = {"font_name": "Arial", 'font_size': 11, 'fg_color': Constant.GREEN_COLOR, 'align': 'left',
                  'valign': 'vcenter', 'bold': True, 'border': True}
    YELLOW_BOLD = {"font_name": "Arial", 'font_size': 11, 'fg_color': Constant.YELLOW_COLOR, 'align': 'left',
                   'valign': 'vcenter', 'bold': True, 'border': True}

    def __init__(self, params):
        """Store the input/output locations.

        Args:
            params: object exposing ``fx_graph_path`` (directory scanned for FX
                graph modules) and ``output_path`` (directory receiving the
                profiling data and the xlsx report).
        """
        self.fx_graph_path = params.fx_graph_path
        self.output_path = params.output_path
        # Profiling data is written to <output_path>/inductor_triton.
        self.profiling_data_path = os.path.join(params.output_path, self.MSTX_MESSAGE)
        self._result_data = None

    def process_group(self, group: pd.DataFrame):
        """Summarise the operators recorded under one mstx message.

        Args:
            group: rows sharing one ``COL_MESSAGE`` value; must contain the
                ``COL_NAME`` and ``COL_DURATION`` columns.

        Returns:
            A ``pd.Series`` indexed by ``COL_TRITON_OP``, ``COL_TRITON_DURATION``,
            ``COL_ORIGINAL_DURATION`` and ``COL_ORIGINAL_TOTAL_DURATION``, or
            ``None`` when the group has no Triton operator or no original
            operator to compare against.
        """
        # ``name`` is only set when called through groupby.apply(); fall back
        # gracefully so the method can also be called on a plain DataFrame.
        group_label = getattr(group, "name", None)

        is_triton = group[self.COL_NAME].str.startswith(self.TRITON_OP_PREFIX, na=False)
        triton_rows = group[is_triton]
        if triton_rows.empty:
            logger.warning(f"{group_label}: No triton operator found")
            return None

        original_rows = group[~is_triton]
        if original_rows.empty:
            logger.warning(f"{group_label}: No original operators found")
            return None

        # List the original operators slowest first, one "name:duration" per line.
        original_rows = original_rows.sort_values(by=self.COL_DURATION, ascending=False)
        original_op_duration = '\n'.join(
            f"{op_name}:{op_duration:.2f}"
            for op_name, op_duration in zip(original_rows[self.COL_NAME], original_rows[self.COL_DURATION])
        )

        return pd.Series(
            [
                # Only the first Triton operator name is reported; durations are summed.
                triton_rows[self.COL_NAME].iloc[0],
                triton_rows[self.COL_DURATION].sum(),
                original_op_duration,
                original_rows[self.COL_DURATION].sum(),
            ],
            index=[
                self.COL_TRITON_OP,
                self.COL_TRITON_DURATION,
                self.COL_ORIGINAL_DURATION,
                self.COL_ORIGINAL_TOTAL_DURATION,
            ],
        )

    def get_all_fx_graph(self):
        """Import every ``<dir>/<dir>.py`` module found directly under ``fx_graph_path``.

        Sub-directories are visited in sorted order so that the index used in
        the mstx range names is reproducible between runs. Modules that fail to
        import are logged and skipped.

        NOTE: importing a module executes its top-level code, so
        ``fx_graph_path`` must point to trusted content.

        Returns:
            list: the successfully imported modules (possibly empty).
        """
        if not os.path.isdir(self.fx_graph_path):
            logger.error(f"Invalid fx graph path: {self.fx_graph_path}")
            return []

        # Avoid piling up duplicate entries when called more than once.
        if self.fx_graph_path not in sys.path:
            sys.path.append(self.fx_graph_path)

        fx_graph_modules = []
        for dir_name in sorted(os.listdir(self.fx_graph_path)):
            sub_dir_fullpath = os.path.join(self.fx_graph_path, dir_name)
            if not os.path.isdir(sub_dir_fullpath):
                continue
            if not os.path.isfile(os.path.join(sub_dir_fullpath, f"{dir_name}.py")):
                continue
            module_name = f"{dir_name}.{dir_name}"
            try:
                fx_graph_modules.append(importlib.import_module(module_name))
            except ImportError as err:
                logger.error(f"Import error for '{module_name}': {err}")
            except Exception as err:
                logger.error(f"Unexpected error for '{module_name}': {err}", exc_info=True)
        return fx_graph_modules

    def get_prof_config(self):
        """Build the torch_npu profiler used by :meth:`run`.

        The profiler records NPU activity at ``Level0`` with mstx enabled,
        exports in DB format and writes its output under
        ``profiling_data_path``.

        Returns:
            The (not yet started) ``torch_npu.profiler.profile`` object.
        """
        experimental_config = torch_npu.profiler._ExperimentalConfig(
            export_type=[
                torch_npu.profiler.ExportType.Db
            ],
            profiler_level=torch_npu.profiler.ProfilerLevel.Level0,
            mstx=True,
        )
        return torch_npu.profiler.profile(
            activities=[
                torch_npu.profiler.ProfilerActivity.NPU
            ],
            on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(self.profiling_data_path),
            experimental_config=experimental_config,
        )

    def generate_compare_result(self):
        """Read the exported profiler DB and compute the per-group comparison.

        On success the result is stored in ``self._result_data``; on any
        failure an error is logged and ``self._result_data`` is left as
        ``None``.
        """
        # Never let a previous (stale) result survive a failed run.
        self._result_data = None

        res = glob.glob(os.path.join(self.profiling_data_path, self.DB_PATTERN))
        if not res:
            logger.error(f"Invalid profiling data: {self.output_path}, "
                         f"please check if the ascend_pytorch_profiler.db file exists.")
            return
        db_path = res[0]
        FileManager.check_file_size(db_path)
        PathManager.check_input_file_path(db_path)

        df = InductorTritonExport(db_path).read_export_db()
        if df is None or df.empty:
            logger.error(f"Invalid profiling data, the db path is {db_path}")
            return

        filtered_df = df[~df[self.COL_NAME].isin(self.EXCLUDE_OPS)]
        grouped_df = (
            filtered_df.groupby(self.COL_MESSAGE)
            .apply(self.process_group)
            .reset_index(drop=True)
        )
        # Groups rejected by process_group() come back as all-NaN rows.
        grouped_df = grouped_df.dropna()
        # When every group is rejected, apply() returns a frame without the
        # result columns; bail out instead of failing with a KeyError below.
        if grouped_df.empty or self.COL_TRITON_DURATION not in grouped_df.columns:
            logger.error(f"No comparable triton and original operators found, the db path is {db_path}")
            return

        grouped_df[self.COL_DURATION_DIFF_RATIO] = (grouped_df[self.COL_TRITON_DURATION] /
                                                    grouped_df[self.COL_ORIGINAL_TOTAL_DURATION])
        self._result_data = grouped_df

    def generate_view(self):
        """Write ``self._result_data`` to a timestamped xlsx file under ``output_path``.

        Ratios greater than 1 (Triton slower than the original operators) are
        highlighted in red. Infinite and NaN ratios are written as the text
        ``INF`` / ``NAN`` because xlsxwriter cannot store them as numbers.
        """
        if self._result_data is None or self._result_data.empty:
            logger.error("Invalid comparison result, please check if the comparison result exists.")
            return

        timestamp = datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')
        file_path = os.path.join(self.output_path,
                                 f"inductor_triton_performance_comparison_result_{timestamp}.xlsx")
        num_metrics = 2
        data_cols = [
            self.COL_TRITON_OP,
            self.COL_TRITON_DURATION,
            self.COL_ORIGINAL_DURATION,
            self.COL_ORIGINAL_TOTAL_DURATION,
            self.COL_DURATION_DIFF_RATIO
        ]
        last_col_idx = len(data_cols) - 1
        # Columns holding (potentially long) text get a wider width.
        str_col_indices = [0, 2]

        with Workbook(file_path) as workbook:
            worksheet = workbook.add_worksheet()
            str_format = workbook.add_format(self.BOLD_STR)
            green_format = workbook.add_format(self.GREEN_BOLD)
            yellow_format = workbook.add_format(self.YELLOW_BOLD)
            default_ratio_format = workbook.add_format(self.DEFAULT_RATIO)
            red_ratio_format = workbook.add_format(self.RED_RATIO)
            float_format = workbook.add_format(self.DEFAULT_FLOAT)

            worksheet.set_column(0, last_col_idx, 27)
            for c_idx in str_col_indices:
                worksheet.set_column(c_idx, c_idx, 60)

            # Header row: Triton columns green, original columns yellow, ratio bold.
            r_idx = 0
            for c_idx, header in enumerate(data_cols):
                if c_idx < num_metrics:
                    header_format = green_format
                elif c_idx < num_metrics * 2:
                    header_format = yellow_format
                else:
                    header_format = str_format
                worksheet.write(r_idx, c_idx, header, header_format)
            r_idx += 1

            for _, row in self._result_data.iterrows():
                # Index by column name so the layout does not depend on the
                # column order of the stored DataFrame.
                for c_idx, col_name in enumerate(data_cols):
                    cell_data = row[col_name]
                    cell_format = float_format
                    if c_idx == last_col_idx:
                        # A ratio of 0 is still a ratio: always use a percent format.
                        cell_format = red_ratio_format if cell_data > 1 else default_ratio_format
                        if pd.isna(cell_data):
                            cell_data = "NAN"
                        elif cell_data == float('inf'):
                            cell_data = "INF"
                    worksheet.write(r_idx, c_idx, cell_data, cell_format)
                r_idx += 1

        # The file only exists once the workbook has been closed.
        os.chmod(file_path, Constant.FILE_AUTHORITY)
        logger.info(f"Generate comparison result successfully: {file_path}")

    def run(self):
        """Profile every FX graph module and generate the comparison report.

        Previous profiling data is removed first. A failure inside one
        module's ``run()`` is logged and does not stop the remaining modules;
        a failure while generating the report is logged as well.
        """
        PathManager.remove_path_safety(self.profiling_data_path)
        fx_graph_modules = self.get_all_fx_graph()
        if not fx_graph_modules:
            logger.error(f"No fx graph module found in: {self.fx_graph_path}")
            return

        prof = self.get_prof_config()
        stream = torch_npu.npu.current_stream()
        prof.start()
        try:
            for index, fx_module in enumerate(fx_graph_modules):
                range_id = torch_npu.npu.mstx.range_start(f"{self.MSTX_MESSAGE}_{index}", stream)
                try:
                    fx_module.run()
                except Exception as err:
                    logger.error(f"Unexpected error for '{fx_module}': {err}", exc_info=True)
                finally:
                    # Always close the range so later ranges stay well-formed.
                    torch_npu.npu.mstx.range_end(range_id)
        finally:
            # Make sure the profiler is stopped even on an interrupt.
            prof.stop()

        try:
            self.generate_compare_result()
            self.generate_view()
        except Exception as err:
            logger.error(f"Failed to generate comparison result: {err}", exc_info=True)