from msprof_analyze.prof_common.logger import get_logger
import os
from typing import Dict, List
import json
import pandas as pd
from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.advisor.dataset.profiling.info_collection import TaskInfo
from msprof_analyze.advisor.dataset.profiling.profiling_parser import ProfilingParser
logger = get_logger()
class TaskChecker:
"""
check task info
"""
def __init__(self):
self.sqe_keys = set()
def is_sqe(self, task: TaskInfo) -> bool:
"""check sqe"""
key = (task.pid, task.tid)
if task.args.get('name', '').endswith('_SQE'):
self.sqe_keys.add(key)
return False
return key in self.sqe_keys
class Msprof(ProfilingParser):
"""
msprof
"""
FILE_PATTERN_MSG = "msprof_*.json"
FILE_INFO = "msprof"
file_pattern_list = [r"^msprof[_\d]+.json$"]
def __init__(self, path: str) -> None:
super().__init__(path)
self._tasks: List[TaskInfo] = []
self._hccl_tasks: List[TaskInfo] = []
self._iteration_time = 0.0
self._model_id = None
self._iteration_id = None
self._process_pid: Dict[str, str] = {}
self._min_time = 0.0
self._max_time = 0.0
self._data_process_time = 0.0
self._start_point = 0.0
def __len__(self):
return len(self._tasks)
@property
def step_time(self):
return self._iteration_time + self._data_process_time
@property
def iteration_time(self):
return self._iteration_time
@property
def iter_max_time(self):
return self._max_time
@property
def iter_min_time(self):
return self._min_time
@property
def data_process_time(self):
return self._data_process_time
@property
def tasks(self):
return self._tasks
@property
def hccl_tasks(self):
return self._hccl_tasks
@property
def model_id(self):
return self._model_id
@property
def iteration_id(self):
return self._iteration_id
@property
def process_pid(self):
return self._process_pid
@property
def start_point(self):
return self._start_point
def parse_from_file(self, file: str):
if not self._parse_json(file):
return False
min_time = float('inf')
max_time = 0.0
task_checker = TaskChecker()
is_iter = False
self._tasks = [None] * len(self._raw_data)
task_index = 0
for item in self._raw_data:
task = TaskInfo(item)
if task.cat == "Iteration Time":
self._min_time = task.start_time
self._max_time = task.end_time
self._iteration_time = task.dur
is_iter = True
if task.cat == "Data_aug Bound" and "Data_aug Bound(us)" in task.args:
self._data_process_time = task.args["Data_aug Bound(us)"]
if self._start_point == 0 and task.start_time > 0:
self._start_point = task.start_time
if task_checker.is_sqe(task):
continue
self._tasks[task_index] = task
task_index += 1
self._parse_task(task)
start_time = task.start_time
dur = task.dur
if start_time == -1 or dur == -1 or dur == 0:
continue
if start_time < min_time:
min_time = start_time
end_time = start_time + dur
if end_time > max_time:
max_time = end_time
if not is_iter:
self._iteration_time = dur
self._max_time = max_time
self._min_time = min_time
self._tasks = self._tasks[:task_index]
if self._tasks:
self._tasks.sort(key=lambda x: x.start_time)
return True
return False
def _parse_task(self, task):
if "Iteration Refresh" in task.name:
self._iteration_id = task.args.get("Iteration ID")
elif "Model ID" in task.name:
self._model_id = int(task.name.split(":")[1])
elif "process_name" == task.name:
self._process_pid[task.args.get("name")] = task.pid
class MsprofDB(Msprof):
FILE_PATTERN_MSG = "ascend_*_profiler.db"
FILE_INFO = "timeline info from db"
file_pattern_list = [r'^ascend_pytorch_profiler(?:_\d+)?\.db$',
r'^ascend_mindspore_profiler(?:_\d+)?\.db$',
r'^msprof_\d{14}\.db$']
HCCL_TASK_SQL = """
SELECT
str.value as name,
task.globalPid as pid,
task.startNs / 1000.0 as ts,
(task.endNs - task.startNs) / 1000.0 as dur,
JSON_OBJECT(
'task type', str.value,
'stream id', task.streamId,
'task id', task.taskId,
'transport type', trans.name,
'link type', link.name,
'size(Byte)', comm.size
) as args
FROM COMMUNICATION_TASK_INFO as comm
JOIN TASK as task ON comm.globalTaskId = task.globalTaskId
JOIN STRING_IDS as str ON str.id = comm.taskType
JOIN ENUM_HCCL_LINK_TYPE as link ON link.id = comm.linkType
JOIN ENUM_HCCL_TRANSPORT_TYPE as trans ON trans.id = comm.transportType
"""
HCCL_OP_SQL = """
SELECT
str.value as name,
comm.startNs / 1000.0 as ts,
(comm.endNs - comm.startNs) / 1000.0 as dur
FROM COMMUNICATION_OP as comm
JOIN STRING_IDS as str ON str.id = comm.opName
"""
NODE_INFO_SQL = """
WITH ranked_apis AS (
SELECT
str.value AS name,
api.startNs / 1000.0 AS ts,
(api.endNs - api.startNs) / 1000.0 AS dur,
type.name as type,
LAG(str.value) OVER (ORDER BY api.startNs) AS prev_name
FROM CANN_API as api
JOIN STRING_IDS as str ON api.name = str.id
JOIN ENUM_API_TYPE as type ON api.type = type.id
)
SELECT
name,
ts,
dur,
json_object('item_id', prev_name) AS args
FROM ranked_apis
WHERE type = 'node'
ORDER BY ts;
"""
def __init__(self, path: str) -> None:
super().__init__(path)
def parse_from_file(self, file: str):
if not file or not os.path.exists(file):
logger.error("db path is None.")
return False
self.process_communication_tasks(file)
self.process_node_tasks(file)
return True
def process_communication_tasks(self, db_path):
self._process_task_data(db_path, self.HCCL_TASK_SQL, self._hccl_tasks, [Constant.TABLE_COMMUNICATION_TASK_INFO])
self._process_task_data(db_path, self.HCCL_OP_SQL, self._hccl_tasks, [Constant.TABLE_COMMUNICATION_OP])
self._hccl_tasks.sort(key=lambda x: x.start_time)
def process_node_tasks(self, db_path):
self._process_task_data(db_path, self.NODE_INFO_SQL, self._tasks)
def _process_task_data(self, db_path, sql: str, result: list, tables_to_check=None):
df = self._execute_sql(db_path, sql, tables_to_check)
if df.empty:
return
if 'args' in df.columns:
df['args'] = df['args'].apply(lambda x: json.loads(x) if pd.notna(x) else {})
result.extend([TaskInfo(record) for record in df.to_dict('records')])