from collections import namedtuple
import os
import pandas as pd
from msprof_analyze.cluster_analyse.common_func.utils import describe_duration
from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis
from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.prof_common.logger import get_logger
from msprof_analyze.prof_exports.mstx_event_export import MstxMarkExport, MstxRangeExport
from msprof_analyze.prof_exports.mstx_step_export import MstxStepExport
# Module-level logger used by the helper functions and the MstxSum recipe in this file.
logger = get_logger()
# One measured mstx interval: its message name, the elapsed time seen at the
# framework / CANN / device level, the thread id of the row it was built from
# and the CANN timestamp at which it started.
MarkInfo = namedtuple(
    "MarkInfo",
    (
        "name",
        "framework_duration",
        "cann_duration",
        "device_duration",
        "tid",
        "start_ns",
    ),
)
def format_mark_info(df: pd.DataFrame, start_idx, stop_idx, name) -> MarkInfo:
    """Build a MarkInfo from a matched pair of start/stop mark rows.

    Args:
        df: Mark table; the rows must expose ``framework_ts``, ``cann_ts``,
            ``device_ts`` and ``tid``.
        start_idx: Positional index (``iloc``) of the start mark row.
        stop_idx: Positional index (``iloc``) of the stop mark row.
        name: Name stored on the resulting record.

    Returns:
        MarkInfo whose durations are the stop-minus-start timestamp
        differences (as floats) and whose ``tid`` / ``start_ns`` come from the
        start row.
    """
    begin = df.iloc[start_idx]
    end = df.iloc[stop_idx]
    framework_elapsed = float(end["framework_ts"] - begin["framework_ts"])
    cann_elapsed = float(end["cann_ts"] - begin["cann_ts"])
    device_elapsed = float(end["device_ts"] - begin["device_ts"])
    return MarkInfo(
        name=name,
        framework_duration=framework_elapsed,
        cann_duration=cann_elapsed,
        device_duration=device_elapsed,
        tid=begin["tid"],
        start_ns=begin["cann_ts"],
    )
def format_range_info(df: pd.DataFrame, idx, name) -> MarkInfo:
    """Build a MarkInfo from a single range row.

    Args:
        df: Range table; the row must expose ``cann_start_ts``,
            ``cann_end_ts``, ``device_start_ts``, ``device_end_ts`` and ``tid``.
        idx: Positional index (``iloc``) of the range row.
        name: Name stored on the resulting record.

    Returns:
        MarkInfo with a framework duration of 0.0 and CANN / device durations
        computed as end-minus-start (as floats).
    """
    row = df.iloc[idx]
    cann_elapsed = float(row["cann_end_ts"] - row["cann_start_ts"])
    device_elapsed = float(row["device_end_ts"] - row["device_start_ts"])
    return MarkInfo(
        name=name,
        framework_duration=float(0),
        cann_duration=cann_elapsed,
        device_duration=device_elapsed,
        tid=row["tid"],
        start_ns=row["cann_start_ts"],
    )
def rename_mark_msg_name(mstx_stats_df: pd.DataFrame):
    """Disambiguate names that occur more than once within the same step.

    Rows sharing both ``step_id`` and ``name`` are renamed in place to
    ``<name>_0``, ``<name>_1``, ... in row order. Names that are unique within
    their step are left untouched.

    Args:
        mstx_stats_df: Frame with ``step_id`` and ``name`` columns; modified in
            place.

    Returns:
        None.
    """
    # step_id -> name -> row positions, in order of appearance.
    positions_by_step = {}
    for pos, mark_info in enumerate(mstx_stats_df.itertuples(index=False)):
        positions_by_step.setdefault(mark_info.step_id, {}).setdefault(mark_info.name, []).append(pos)

    renames = [
        (pos, f"{msg}_{i}")
        for names in positions_by_step.values()
        for msg, pos_list in names.items()
        if len(pos_list) > 1
        for i, pos in enumerate(pos_list)
    ]
    if not renames:
        return

    # The collected values are row *positions*, so write through ``iloc``.
    # Label-based ``loc`` only works when the index is the default RangeIndex;
    # otherwise it touches the wrong rows or silently appends new ones.
    name_col = mstx_stats_df.columns.get_loc("name")
    for pos, new_name in renames:
        mstx_stats_df.iloc[pos, name_col] = new_name
def compute_step_id(mark_stat, step_stats_df: pd.DataFrame):
    """Return the id of the first step whose time span contains the mark.

    Args:
        mark_stat: Record exposing ``start_ns`` and a name; either a pandas
            row Series (as passed by ``DataFrame.apply(axis=1)``) or any object
            with ``start_ns`` / ``name`` attributes.
        step_stats_df: Frame with ``start_ns``, ``end_ns`` and ``step_id``
            columns; both bounds are inclusive.

    Returns:
        The ``step_id`` of the first matching step, or 0 (after logging a
        warning) when no step contains the mark.
    """
    mark_start = mark_stat.start_ns
    for step_info in step_stats_df.itertuples(index=False):
        if step_info.start_ns <= mark_start <= step_info.end_ns:
            return step_info.step_id

    # On a row Series, ``.name`` is the row label rather than the "name"
    # column, so read the column entry explicitly to log the real mark name.
    if isinstance(mark_stat, pd.Series):
        mark_name = mark_stat.get("name", mark_stat.name)
    else:
        mark_name = mark_stat.name
    logger.warning(f"{mark_name} is not in any step.")
    return 0
def format_columns(df: pd.DataFrame):
    """Rename the internal columns to their export names and drop helper columns.

    Columns missing from the rename table keep their name. After renaming,
    every column whose name ends in ``_ns`` and the ``Tid`` column are dropped.

    Args:
        df: Frame using the internal snake_case column names.

    Returns:
        A new frame holding only the kept, renamed columns.
    """
    export_names = {
        "framework_duration": "FrameworkDurationNs",
        "cann_duration": "CannDurationNs",
        "device_duration": "DeviceDurationNs",
        "duration": "DurationNs",
        "step_id": "StepId",
        "tid": "Tid",
        "name": "Name"
    }
    renamed = df.rename(columns=export_names)
    kept = []
    for column in renamed.columns:
        if column.endswith("_ns") or column in {"Tid"}:
            continue
        kept.append(column)
    return renamed[kept]
def handle_mark_data(mark_df: pd.DataFrame, rank_id: int) -> list:
    """Pair ``*_start`` / ``*_stop`` mark messages and measure each pair.

    Marks are matched per thread id and per base message (the message with its
    suffix removed). A stop mark is paired with the most recent still-unmatched
    start mark of the same thread and base message. Marks left without a
    partner are reported in a single warning, ordered by row position.

    Args:
        mark_df: Mark table with ``msg``, ``tid`` and the timestamp columns
            needed by ``format_mark_info``. Its ``framework_ts`` column is cast
            to int64 in place.
        rank_id: Rank identifier, used only in the warning text.

    Returns:
        List of MarkInfo records, one per matched start/stop pair.
    """
    mark_df["framework_ts"] = mark_df["framework_ts"].astype("int64")
    matched = []
    unmatched = []
    # tid -> base message -> stack of row positions of still-open start marks.
    open_starts = {}
    for pos, record in enumerate(mark_df.itertuples(index=False)):
        text = record.msg
        if text.endswith(MstxSum.START_SUFFIX):
            base = text[:-len(MstxSum.START_SUFFIX)]
            open_starts.setdefault(record.tid, {}).setdefault(base, []).append(pos)
            continue
        if not text.endswith(MstxSum.STOP_SUFFIX):
            # Neither a start nor a stop mark: nothing to pair.
            continue
        base = text[:-len(MstxSum.STOP_SUFFIX)]
        stack = open_starts.get(record.tid, {}).get(base, [])
        if stack:
            matched.append(format_mark_info(mark_df, stack.pop(), pos, base))
        else:
            unmatched.append((text, pos))

    # Whatever is still on a stack is a start mark that never saw its stop.
    for per_thread in open_starts.values():
        for base, stack in per_thread.items():
            unmatched.extend((base + MstxSum.START_SUFFIX, pos) for pos in stack)

    if unmatched:
        unmatched.sort(key=lambda entry: entry[1])
        logger.warning(f"The following mark messages do not match anyone in "
                       f"rank {rank_id}: {','.join(entry[0] for entry in unmatched)}.")
    return matched
def handle_range_data(range_df: pd.DataFrame) -> list:
    """Convert every row of the range table into a MarkInfo record.

    Args:
        range_df: Range table with a ``msg`` column plus the columns needed by
            ``format_range_info``.

    Returns:
        List of MarkInfo records in row order, each named after the row's
        ``msg``.
    """
    return [
        format_range_info(range_df, pos, record.msg)
        for pos, record in enumerate(range_df.itertuples(index=False))
    ]
class MstxSum(BaseRecipeAnalysis):
    """Recipe that summarises mstx mark/range durations per step.

    ``_mapper_func`` turns one profiler database into a per-mark table;
    ``reducer_func`` concatenates those tables and derives, for every step,
    duration statistics grouped by mark name at the framework, CANN and device
    level. ``run`` then exports the results as database tables or as CSV files
    plus a notebook, depending on the configured export type.
    """

    # Names of the database tables written by ``save_db``.
    TABLE_FRAMEWORK_STATS = "MSTXAllFrameworkStats"
    TABLE_CANN_STATS = "MSTXAllCannStats"
    TABLE_DEVICE_STATS = "MSTXAllDeviceStats"
    TABLE_MARK_STATS = "MSTXMarkStats"
    # Message suffixes identifying the start / stop mark of a pair.
    START_SUFFIX = "_start"
    STOP_SUFFIX = "_stop"

    def __init__(self, params):
        """Initialise the recipe; all result tables start out as None."""
        super().__init__(params)
        logger.info("MstxSum init.")
        self.mark_stats = None
        self.all_fwk_stats = None
        self.all_cann_stats = None
        self.all_device_stats = None

    @property
    def base_dir(self):
        """Name of the directory that contains this source file."""
        return os.path.basename(os.path.dirname(__file__))

    def reducer_func(self, mapper_res):
        """Merge the per-database tables and compute per-step statistics.

        Args:
            mapper_res: Iterable of frames produced by ``_mapper_func``;
                ``None`` entries are ignored.

        Side effects:
            Sets ``mark_stats`` and the three ``all_*_stats`` tables. When no
            usable frame is present an error is logged and they stay untouched.
        """
        mapper_res = [df for df in mapper_res if df is not None]
        if not mapper_res:
            logger.error("Mapper data is None.")
            return
        self.mark_stats = pd.concat(mapper_res)

        # Duration column -> statistics collected for it, one frame per step.
        stats_by_column = {
            "FrameworkDurationNs": [],
            "CannDurationNs": [],
            "DeviceDurationNs": [],
        }
        for step_id, step_df in self.mark_stats.groupby("StepId"):
            name_groups = step_df.groupby("Name")
            for column, collected in stats_by_column.items():
                stats = describe_duration(name_groups[column]).assign(StepId=step_id)
                stats.sort_values(by=["SumNs"], inplace=True, ascending=False)
                collected.append(stats)

        self.all_fwk_stats = pd.concat(stats_by_column["FrameworkDurationNs"])
        self.all_cann_stats = pd.concat(stats_by_column["CannDurationNs"])
        self.all_device_stats = pd.concat(stats_by_column["DeviceDurationNs"])

    def run(self, context):
        """Run the map and reduce phases, then export with the configured type."""
        mapper_res = self.mapper_func(context)
        self.reducer_func(mapper_res)
        # NOTE(review): when the reducer found no data the result tables are
        # still None here; confirm that dump_data tolerates None.
        if self._export_type == Constant.DB:
            self.save_db()
        elif self._export_type == Constant.NOTEBOOK:
            self.save_notebook()
        else:
            logger.error("Unknown export type.")

    def save_notebook(self):
        """Dump the result tables as CSV files and create the notebook files."""
        self.dump_data(self.mark_stats, "mark_stats.csv")
        self.dump_data(self.all_fwk_stats, "all_fwk_stats.csv")
        self.dump_data(self.all_cann_stats, "all_cann_stats.csv")
        self.dump_data(self.all_device_stats, "all_device_stats.csv")
        self.create_notebook("stats.ipynb")
        self.add_helper_file("cluster_display.py")

    def save_db(self):
        """Dump the result tables into the cluster analysis database."""
        self.dump_data(self.mark_stats, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, self.TABLE_MARK_STATS)
        self.dump_data(self.all_fwk_stats, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, self.TABLE_FRAMEWORK_STATS)
        self.dump_data(self.all_cann_stats, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, self.TABLE_CANN_STATS)
        self.dump_data(self.all_device_stats, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, self.TABLE_DEVICE_STATS)

    def _mapper_func(self, data_map, analysis_class):
        """Build the per-mark statistics table for one profiler database.

        Args:
            data_map: Mapping providing the profiler database path, the rank
                id and the step range.
            analysis_class: Passed through unchanged to the export classes.

        Returns:
            Frame indexed by ``Name`` holding the formatted mark/range records
            with ``Rank`` and ``StepId`` columns, or ``None`` when the database
            contains no mstx data.
        """
        profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH)
        rank_id = data_map.get(Constant.RANK_ID)
        step_range = data_map.get(Constant.STEP_RANGE)

        step_df = MstxStepExport(profiler_db_path, analysis_class, step_range).read_export_db()
        if step_df is None or step_df.empty:
            # No step information: use a single catch-all step with id 0.
            step_df = pd.DataFrame({"start_ns": [0], "end_ns": [float("inf")], "step_id": [0]})
        mark_df = MstxMarkExport(profiler_db_path, analysis_class, step_range).read_export_db()
        range_df = MstxRangeExport(profiler_db_path, analysis_class, step_range).read_export_db()

        # Treat a None export result like an empty table, as is already done
        # for step_df above, instead of failing on ``None.empty``.
        mstx_res = []
        if mark_df is not None and not mark_df.empty:
            mstx_res += handle_mark_data(mark_df, rank_id)
        if range_df is not None and not range_df.empty:
            mstx_res += handle_range_data(range_df)
        if not mstx_res:
            logger.warning(f"There is no mstx data in {profiler_db_path}.")
            return None

        mstx_stats_df = pd.DataFrame(mstx_res).assign(Rank=rank_id)
        mstx_stats_df["step_id"] = mstx_stats_df.apply(compute_step_id, axis=1, step_stats_df=step_df)
        rename_mark_msg_name(mstx_stats_df)
        mstx_stats_df = format_columns(mstx_stats_df).set_index("Name", drop=True)
        return mstx_stats_df