import argparse
import json
import os
import shutil
import sys
import traceback
from abc import abstractmethod, ABC
import pandas as pd
from msprof_analyze.prof_common.db_manager import DBManager
from msprof_analyze.cluster_analyse.common_func.utils import convert_unit
from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.prof_common.database_service import DatabaseService
from msprof_analyze.prof_common.logger import get_logger
from msprof_analyze.prof_common.path_manager import PathManager
from msprof_analyze.cluster_analyse.cluster_data_preprocess.msprof_data_preprocessor import MsprofDataPreprocessor
from msprof_analyze.prof_common.file_manager import FileManager
from msprof_analyze.prof_common.utils import convert_to_int
# Module-level logger used by every method of this module for warnings, errors and progress messages.
logger = get_logger()
class BaseRecipeAnalysis(ABC):
    """
    Base class for cluster analysis recipes.

    A recipe resolves the per-rank profiling databases (``_get_rank_db``), runs ``_mapper_func`` on each of
    them through the execution context (``mapper_func``) and exports results with ``dump_data`` /
    ``create_notebook``. Subclasses must implement ``run`` and the ``base_dir`` property.
    """

    UNIT = "Us"  # unit that values are converted to when exporting CSV files
    DB_UNIT = "Ns"  # unit of the values as read from / written to databases
    RANK_LIST = "rank_list"
    TP_SIZE = "tensor_model_parallel_size"
    PP_SIZE = "pipeline_model_parallel_size"
    DP_SIZE = "data_parallel_size"

    def __init__(self, params):
        """
        Initialize the recipe from the cluster-analysis parameter dict.

        Params:
            params: dict keyed by ``Constant`` names (collection path, data map, recipe name, export type,
                profiling type, rank list, step id, extra args, ...).

        The output directory is created eagerly: for the "db" export type results go directly into the
        cluster analysis output directory, otherwise into a sub-directory named after the recipe.
        """
        self._collection_dir = params.get(Constant.COLLECTION_PATH, "")
        self._data_map = params.get(Constant.DATA_MAP, {})
        self._recipe_name = params.get(Constant.RECIPE_NAME, "")
        self._parallel_mode = params.get(Constant.PARALLEL_MODE, "")
        self._export_type = params.get(Constant.EXPORT_TYPE, "")
        self._prof_type = params.get(Constant.PROFILING_TYPE)
        self._cluster_analysis_output_path = os.path.join(
            params.get(Constant.CLUSTER_ANALYSIS_OUTPUT_PATH, self._collection_dir), Constant.CLUSTER_ANALYSIS_OUTPUT)
        self._output_path = self._cluster_analysis_output_path if self._export_type == "db" else os.path.join(
            self._cluster_analysis_output_path, self._recipe_name)
        rank_list = params.get(Constant.RANK_LIST, 'all')
        # "all" selects every rank; otherwise a comma-separated string of rank ids. Entries are stripped so
        # that "0, 1" selects both ranks; entries that are not digit strings are ignored.
        self._rank_list = rank_list if rank_list == "all" else [
            convert_to_int(rank.strip()) for rank in rank_list.split(",") if rank.strip().isdigit()
        ]
        self._step_id = params.get(Constant.STEP_ID, Constant.VOID_STEP)
        self._extra_args = self.get_extra_argument(params.get(Constant.EXTRA_ARGS, []))
        PathManager.make_dir_safety(self._output_path)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Log and print any exception raised inside the ``with`` block; the exception still propagates."""
        if exc_type is not None:
            logger.error(f"Failed to exit analysis: {exc_val}")
            # Print the exception handed to __exit__ explicitly instead of relying on sys.exc_info().
            traceback.print_exception(exc_type, exc_val, exc_tb, file=sys.stdout)

    @property
    def output_path(self):
        """Directory that all exported files of this recipe are written to."""
        return self._output_path

    @property
    @abstractmethod
    def base_dir(self):
        """
        Name of the directory (under the template root) where this recipe's stats.ipynb is located.
        Typical implementation: return os.path.basename(os.path.dirname(__file__))
        """
        raise NotImplementedError("Property base_dir need to be implemented.")

    @staticmethod
    def _filter_data(mapper_data):
        """Keep only the ``(rank, data)`` pairs whose data is neither None nor empty."""
        return [(rank, data) for rank, data in mapper_data if data is not None and len(data) != 0]

    @classmethod
    def add_parser_argument(cls, parser):
        """Hook for subclasses to register recipe-specific arguments on the argparse ``parser``."""
        pass

    @classmethod
    def get_extra_argument(cls, args_list) -> dict:
        """
        Parse recipe-specific extra arguments.

        Params:
            args_list: list of command-line style tokens, e.g. ["--tp", "2"].
        Returns:
            dict of the arguments registered by ``add_parser_argument``; unknown tokens are logged and ignored.
        """
        parser = argparse.ArgumentParser()
        cls.add_parser_argument(parser)
        args, unknown_args = parser.parse_known_args(args_list)
        if unknown_args:
            unknown_args = " ".join(unknown_args)
            logger.warning(f"Invalid parameters: {unknown_args}. It will not have any effect.")
        return vars(args)

    @abstractmethod
    def run(self, context):
        raise NotImplementedError("Function run need to be implemented.")

    def mapper_func(self, context):
        """
        Run ``_mapper_func`` on every rank descriptor returned by ``_get_rank_db`` via ``context`` and wait
        for the results. The shape of the returned value depends on ``context.wait`` (presumably one result
        per rank -- confirm against the context implementation).
        """
        return context.wait(
            context.map(
                self._mapper_func,
                self._get_rank_db(),
                analysis_class=self._recipe_name
            )
        )

    def dump_data(self, data, file_name, table_name=None, index=True, custom_db_path=None):
        """
        Export a DataFrame to a database table or to a CSV file.

        Params:
            data: pandas DataFrame to export; None, non-DataFrame and empty inputs are skipped with a log.
            file_name: output file name, resolved relative to ``output_path``.
            table_name: when given, write to this table of a database (replacing an existing table);
                otherwise write a CSV file.
            index: whether to write the DataFrame index.
            custom_db_path: database path overriding ``output_path/file_name`` (database export only).
        """
        if data is None:
            logger.warning("No data to dump, skipping.")
            return
        if not isinstance(data, pd.DataFrame):
            logger.error(f"Unknown dump data type: {type(data)}, expected pandas DataFrame")
            return
        if data.empty:
            logger.warning("Empty DataFrame. Skip data dump!")
            return
        if table_name:
            result_db = custom_db_path or os.path.join(self.output_path, file_name)
            logger.info(f"Exporting data to database: {result_db}, table: {table_name}")
            conn, cursor = DBManager.create_connect_db(result_db)
            try:
                data.to_sql(table_name, conn, if_exists='replace', index=index)
            finally:
                # Release the connection even when to_sql raises.
                DBManager.destroy_db_connect(conn, cursor)
        else:
            result_csv = os.path.join(self.output_path, file_name)
            logger.info(f"Exporting data to CSV file: {result_csv}")
            # Only the CSV export converts values from DB_UNIT to UNIT; database export keeps them unchanged.
            data = convert_unit(data, self.DB_UNIT, self.UNIT)
            FileManager.create_csv_from_dataframe(result_csv, data, index=index)

    def create_notebook(self, filename, notebook_template_dir=None, replace_dict=None):
        """
        Export a notebook template into ``output_path``.

        Params:
            filename: template file name, looked up under ``<template root>/<base_dir>/``.
            notebook_template_dir: template root directory; defaults to this module's directory.
            replace_dict: optional mapping; every occurrence of ``str(key)`` in the template text is replaced
                by ``str(value)``. Without it the template file is copied verbatim.
        """
        template_path = os.path.dirname(__file__) if notebook_template_dir is None else notebook_template_dir
        output_file_path = os.path.join(self.output_path, filename)
        template_file = os.path.join(template_path, self.base_dir, filename)
        if replace_dict is None:
            shutil.copy(template_file, output_file_path)
            os.chmod(output_file_path, Constant.FILE_AUTHORITY)
        else:
            template_content = FileManager.read_common_file(template_file)
            for key, value in replace_dict.items():
                template_content = template_content.replace(str(key), str(value))
            FileManager.create_common_file(output_file_path, template_content)
        logger.info(f"Notebook export path is: {output_file_path}")

    def add_helper_file(self, helper_file):
        """
        Copy ``helper_file`` from this module's directory into ``output_path``.

        A missing source file is reported with a warning and skipped instead of failing inside shutil.copy.
        """
        helper_file_path = os.path.join(os.path.dirname(__file__), helper_file)
        if not os.path.isfile(helper_file_path):
            logger.warning(f"Helper file not found, skip copying: {helper_file_path}")
            return
        helper_output_path = os.path.join(self.output_path, helper_file)
        shutil.copy(helper_file_path, helper_output_path)
        os.chmod(helper_output_path, Constant.FILE_AUTHORITY)

    def map_rank_pp_stage(self, distributed_args):
        """
        Map every global rank to its pipeline stage.

        Ranks are assumed to be laid out stage by stage: each stage holds ``tp * dp`` consecutive ranks
        (missing sizes default to 1). E.g. tp=2, pp=2, dp=1 -> {0: 0, 1: 0, 2: 1, 3: 1}.
        """
        tp_size = distributed_args.get(self.TP_SIZE, 1)
        pp_size = distributed_args.get(self.PP_SIZE, 1)
        dp_size = distributed_args.get(self.DP_SIZE, 1)
        ranks_per_stage = tp_size * dp_size
        return {
            stage * ranks_per_stage + offset: stage
            for stage in range(pp_size)
            for offset in range(ranks_per_stage)
        }

    @staticmethod
    def _is_valid_parallel_size(*sizes):
        """Return True when every given parallel size is a positive int."""
        return all(isinstance(size, int) and size > 0 for size in sizes)

    def _build_distributed_args(self, tp_size, pp_size, dp_size):
        """Pack the parallel sizes into the distributed_args dict layout used by ``map_rank_pp_stage``."""
        return {
            self.TP_SIZE: tp_size,
            self.PP_SIZE: pp_size,
            self.DP_SIZE: dp_size,
        }

    def load_distributed_args(self):
        """
        Resolve the tensor/pipeline/data parallel sizes.

        When all of ``tp``, ``pp`` and ``dp`` were given as extra arguments they are used (and must be
        positive ints); otherwise they are read from the META_DATA table of the first rank's profiler DB.

        Returns:
            dict keyed by TP_SIZE / PP_SIZE / DP_SIZE, or None when the sizes are invalid or unavailable.
        """
        tp_size = self._extra_args.get("tp", None)
        pp_size = self._extra_args.get("pp", None)
        dp_size = self._extra_args.get("dp", None)
        # Compare against None so that an explicit 0 is rejected rather than silently ignored.
        if tp_size is not None and pp_size is not None and dp_size is not None:
            if not self._is_valid_parallel_size(tp_size, pp_size, dp_size):
                logger.error("Invalid distributed_args, tp pp dp must be positive integers.")
                return None
            return self._build_distributed_args(tp_size, pp_size, dp_size)
        return self._load_distributed_args_from_db()

    def _load_distributed_args_from_db(self):
        """Read distributed_args from the META_DATA table of the first rank's profiler DB (None on failure)."""
        if not self._data_map:
            logger.error("No profiling data found, failed to load distributed args.")
            return None
        rank_id = next(iter(self._data_map))
        rank_path = self._data_map[rank_id]
        db_path = self._get_profiler_db_path(rank_id, rank_path)
        if not os.path.exists(db_path):
            logger.error(f"Db_file: {db_path} not exist.")
            return None
        try:
            service = DatabaseService(db_path, {})
            service.add_table_for_query("META_DATA", ["name", "value"])
            df = service.query_data().get("META_DATA", None)
            distributed_args = df.loc[df["name"] == "distributed_args", "value"]
            if distributed_args.empty:
                # No record: fall back to the default sizes of 1 below, but tell the user.
                distributed_args = {}
                logger.error("Distributed args not in profiling files, please input manually.")
            else:
                distributed_args = json.loads(distributed_args.values[0])
        except Exception as err:
            logger.error(err)
            logger.error("Distributed args not in profiling files, please input manually.")
            return None
        if not isinstance(distributed_args, dict):
            logger.error("Invalid distributed_args in profiling files, please input manually.")
            return None
        tp_size = distributed_args.get(self.TP_SIZE, 1)
        pp_size = distributed_args.get(self.PP_SIZE, 1)
        dp_size = distributed_args.get(self.DP_SIZE, 1)
        if not self._is_valid_parallel_size(tp_size, pp_size, dp_size):
            logger.error("Invalid distributed_args in profiling files, please input manually.")
            return None
        return self._build_distributed_args(tp_size, pp_size, dp_size)

    def _get_rank_db(self):
        """
        Build the per-rank input descriptors consumed by ``_mapper_func``.

        Returns:
            list of dicts with keys Constant.RANK_ID, PROFILER_DB_PATH, ANALYSIS_DB_PATH, STEP_RANGE and
            PROFILING_PATH. Ranks without a profiler DB are skipped with a warning; requested ranks that are
            absent from the data map are reported in one warning.
        """
        invalid_rank_id = []
        if self._rank_list == 'all':
            rank_ids = list(self._data_map.keys())
        else:
            rank_ids = []
            for rank_id in self._rank_list:
                if rank_id in self._data_map:
                    rank_ids.append(rank_id)
                else:
                    invalid_rank_id.append(str(rank_id))
        db_paths = []
        for rank_id in rank_ids:
            rank_path = self._data_map[rank_id]
            db_path_dict = {Constant.RANK_ID: rank_id, Constant.PROFILER_DB_PATH: "", Constant.ANALYSIS_DB_PATH: "",
                            Constant.STEP_RANGE: {}, Constant.PROFILING_PATH: rank_path}
            profiler_db_path = self._get_profiler_db_path(rank_id, rank_path)
            analysis_db_path = self._get_analysis_db_path(rank_path)
            if os.path.exists(profiler_db_path):
                db_path_dict[Constant.PROFILER_DB_PATH] = profiler_db_path
                db_path_dict[Constant.STEP_RANGE] = self._get_step_range(profiler_db_path)
            else:
                logger.warning(f"Profiler DB file not found, rank id: {rank_id}, db path: {profiler_db_path}.")
            # _get_analysis_db_path defines no analysis DB for msmonitor data, so it is not looked up.
            if self._prof_type != Constant.MSMONITOR:
                if os.path.exists(analysis_db_path):
                    db_path_dict[Constant.ANALYSIS_DB_PATH] = analysis_db_path
                else:
                    logger.warning(f"Analysis DB file not found, rank id: {rank_id}, db path: {analysis_db_path}.")
            if db_path_dict.get(Constant.PROFILER_DB_PATH):
                db_paths.append(db_path_dict)
        if invalid_rank_id:
            logger.warning(f"Invalid Rank id: [{','.join(invalid_rank_id)}].")
        return db_paths

    def _get_profiler_db_path(self, rank_id, data_path):
        """Return the expected profiler DB path for one rank, or "" for an unknown profiling type."""
        if self._prof_type == Constant.MSPROF:
            db_path = MsprofDataPreprocessor.get_msprof_profiler_db_path(data_path)
            # Placeholder name when no msprof DB was found; presumably it never exists, so callers' existence
            # checks fail and log this path.
            return db_path if db_path else os.path.join(data_path, "msprof_xx.db")
        if self._prof_type == Constant.MINDSPORE:
            return os.path.join(data_path, Constant.SINGLE_OUTPUT, f"ascend_mindspore_profiler_{rank_id}.db")
        if self._prof_type == Constant.PYTORCH:
            return os.path.join(data_path, Constant.SINGLE_OUTPUT, f"ascend_pytorch_profiler_{rank_id}.db")
        if self._prof_type == Constant.MSMONITOR:
            # For msmonitor the data map entry is itself the DB path.
            return data_path
        return ""

    def _get_analysis_db_path(self, data_path):
        """Return the expected communication-analysis DB path, or "" for other profiling types."""
        if self._prof_type == Constant.MSPROF:
            return os.path.join(data_path, Constant.ANALYZE_DIR, "communication_analyzer.db")
        if self._prof_type == Constant.MINDSPORE:
            return os.path.join(data_path, Constant.SINGLE_OUTPUT, "communication_analyzer.db")
        if self._prof_type == Constant.PYTORCH:
            return os.path.join(data_path, Constant.SINGLE_OUTPUT, "analysis.db")
        return ""

    def _get_step_range(self, db_path):
        """
        Look up the STEP_TIME row of ``self._step_id`` in ``db_path``.

        Returns:
            the matching row (with "id", "startNs", "endNs"), or {} when no step was requested, the table
            is missing, the query fails or the step id is unknown (the last three are logged).
        """
        step_range = {}
        if self._step_id == Constant.VOID_STEP:
            return step_range
        conn, cursor = DBManager.create_connect_db(db_path)
        try:
            if not DBManager.judge_table_exists(cursor, "STEP_TIME"):
                logger.error(f"The STEP_TIME table does not exist in the database: {db_path}, "
                             f"the parameter step_id will not take effect.")
                return step_range
            try:
                step_time = DBManager.fetch_all_data(cursor, "select id, startNs, endNs from STEP_TIME")
            except Exception as err:
                logger.error(err)
                return step_range
        finally:
            # Always release the connection, including when judge_table_exists raises.
            DBManager.destroy_db_connect(conn, cursor)
        step_time = step_time or []
        step_range = next((step for step in step_time if step.get("id") == self._step_id), {})
        if not step_range:
            step_list = ", ".join([str(step.get("id", "")) for step in step_time])
            logger.error(f"Invalid step_id {self._step_id} in the database: {db_path}, "
                         f"step_id must be an element of the set ({step_list}), "
                         f"the parameter step_id will not take effect.")
        return step_range

    def _mapper_func(self, data_map, analysis_class):
        """
        Extract the profiling data required for cluster analysis from each device, and then aggregate the
        results from each device to be processed by a reduce function. The base implementation does nothing
        and returns None; recipes override it.
        Params:
            data_map: eg1. {"RANK_ID": 1,
                            "profiler_db_path": "xxx/ASCEND_PROFILER_OUTPUT/ascend_pytorch_profiler_1.db",
                            "analysis_db_path": "xxx/ASCEND_PROFILER_OUTPUT/analysis.db",
                            "step_range": {"id": 2, "startNs": 12345, "endNs": 12443}}
                      eg2. {"RANK_ID": 1,
                            "profiler_db_path": "xxx/msprof_20250227145123.db",
                            "analysis_db_path": "xxx/analyze/communication_analyzer.db",
                            "step_range": {"id": 2, "startNs": 12345, "endNs": 12443}}
            analysis_class: hccl_sum, compute_op_sum, cann_api_sum, mstx_sum……
        """
        pass