import json
import logging
import os
from decimal import Decimal
from common_func.common import error
from common_func.constant import Constant
from common_func.empty_class import EmptyClass
from common_func.info_conf_reader import InfoConfReader
from common_func.ms_constant.number_constant import NumberConstant
from common_func.ms_constant.str_constant import StrConstant
from common_func.msprof_iteration import MsprofIteration
from common_func.singleton import singleton
from common_func.trace_view_header_constant import TraceViewHeaderConstant
from common_func.trace_view_manager import TraceViewManager
from common_func.profiling_scene import ProfilingScene
from profiling_bean.db_dto.step_trace_dto import IterationRange
from viewer.association.host_connect_device import HostToDevice
from viewer.training.step_trace_viewer import StepTraceViewer
@singleton
class MsprofTimeline:
"""
This class is used to export a summary timeline json file.
"""
FILE_NAME = os.path.basename(__file__)
def __init__(self: any) -> None:
self._iter_range = None
self._model_id = NumberConstant.DEFAULT_MODEL_ID
self._result_dir = None
self._export_data_list = []
self._iteration_time = (float('-inf'), float('inf'))
self._default_sort_index = TraceViewHeaderConstant.DEFAULT_LAYER_SORT_START
self.connect_list = {}
@classmethod
def get_timeline_header(cls: any, pid: str, pid_sort_index: int) -> list:
"""
get timeline header
"""
header = [["process_sort_index", pid, InfoConfReader().get_json_tid_data(), pid_sort_index]]
process_index = TraceViewManager.metadata_event(header)
return process_index
@classmethod
def filter_msprof_timeline(cls: any, json_list: list) -> list:
"""
filter msprof timeline data, and partition data by pid.
"""
pid_process_name = dict()
pid_json_data = dict()
filtered_data_list = []
for data_dict in json_list:
pid = data_dict.get(StrConstant.TRACE_HEADER_PID, TraceViewHeaderConstant.DEFAULT_PID_VALUE)
if data_dict.get(StrConstant.TRACE_HEADER_NAME) == "process_name":
pid_process_name[pid] = data_dict.get(StrConstant.TRACE_HEADER_ARGS, {}) \
.get(StrConstant.TRACE_HEADER_NAME, "")
if pid_process_name.get(pid) not in TraceViewHeaderConstant.MSPROF_TIMELINE_FILTER_LIST:
pid_json_data.setdefault(pid, []).append(data_dict)
for key in pid_process_name.keys():
if pid_process_name.get(key) not in TraceViewHeaderConstant.MSPROF_TIMELINE_FILTER_LIST:
filtered_data_list.append([key, pid_process_name.get(key), pid_json_data.get(key)])
return filtered_data_list
@classmethod
def modify_timeline_info(cls: any, process_name: str, layer_info: TraceViewHeaderConstant.LayerInfo,
format_pid: int, value: dict) -> None:
"""
modify timeline info based on layer_info
"""
value[StrConstant.TRACE_HEADER_PID] = format_pid
if value.get(StrConstant.TRACE_HEADER_NAME) == "process_name":
value.setdefault(StrConstant.TRACE_HEADER_ARGS, {})[StrConstant.TRACE_HEADER_NAME] = \
layer_info.component_layer
if value.get(StrConstant.TRACE_HEADER_NAME) == "thread_name" and \
process_name == TraceViewHeaderConstant.PROCESS_STEP_TRACE:
value.setdefault(StrConstant.TRACE_HEADER_ARGS, {})[StrConstant.TRACE_HEADER_NAME] = \
f'{process_name}({value.get(StrConstant.TRACE_HEADER_ARGS, {}).get(StrConstant.TRACE_HEADER_NAME, "")})'
if cls.is_cann_ai_stack_data(layer_info, value):
value[StrConstant.TRACE_HEADER_PID] = format_pid
level = value.get(StrConstant.TRACE_HEADER_ARGS, {}).get(StrConstant.API_EVENT_HEADER_LEVEL)
if not level:
prefix = process_name
else:
prefix = StrConstant.LEVEL_MAP.get(level, level.capitalize())
value[StrConstant.TRACE_HEADER_NAME] = \
f'{prefix}@{value.get(StrConstant.TRACE_HEADER_NAME, "")}'
@classmethod
def get_layer_label_and_sort(cls: any, pid: int, layer_info: TraceViewHeaderConstant.LayerInfo) -> list:
"""
get layer_label layer_sort headers
"""
general_layer = cls.get_label_name_by_pid(pid, layer_info.general_layer)
label_header = [["process_labels", pid, InfoConfReader().get_json_tid_data(), general_layer]]
sort_header = [["process_sort_index", pid, InfoConfReader().get_json_tid_data(), layer_info.sort_index]]
process_label = TraceViewManager.metadata_event(label_header)
process_sort = TraceViewManager.metadata_event(sort_header)
return process_label + process_sort
@classmethod
def get_label_name_by_pid(cls, pid: int, label_name: str):
device_id = TraceViewManager.get_device_id_from_format_pid(pid)
return label_name if device_id == TraceViewManager.HOST_ID_FOR_PID else label_name + " " + str(device_id)
@classmethod
def is_cann_ai_stack_data(cls: any, layer_info: TraceViewHeaderConstant.LayerInfo, value: dict) -> bool:
"""
return whether the data is cann ai stack data
"""
return value.get("ph") == "X" and \
layer_info.component_layer == TraceViewHeaderConstant.COMPONENT_LAYER_CANN
def init_export_data(self: any) -> None:
self._export_data_list = []
def add_export_data(self: any, export_data: json, data_type: str) -> None:
"""
index events in bulk json
:param data_type: data type
:param export_data: data to export
:param data_type: data_type
:return: None
"""
if not export_data:
return
try:
if isinstance(export_data, list):
self.add_sort_index(export_data)
self.add_connect_json_line(export_data, data_type)
if data_type not in ["api"]:
start_time, end_time = self.get_start_end_time()
export_data = filter(
lambda value: value["ph"] == "M" or self.is_in_iteration(value, start_time, end_time),
export_data)
self._export_data_list.extend(export_data)
except (TypeError, ValueError) as err:
logging.error(err, exc_info=Constant.TRACE_BACK_SWITCH)
error(self.FILE_NAME, err)
def add_connect_json_line(self: any, json_list: list, data_type: str) -> None:
"""
add connect line with task time
:param json_list: json list
:param data_type: data_type
"""
for connect_obj in self.connect_list.values():
connect_obj.add_connect_line(json_list, data_type)
def add_sort_index(self: any, json_list: list) -> None:
"""
add sort index and header
:param json_list: json list
"""
if isinstance(json_list, list):
filtered_data_list = self.filter_msprof_timeline(json_list)
json_list.clear()
for filtered_data in filtered_data_list:
process_name = filtered_data[1]
json_data = filtered_data[2]
pid = filtered_data[0]
layer_info = self.get_layer_info(process_name)
format_pid = TraceViewManager.get_format_pid(pid, layer_info)
for value in json_data:
self.modify_timeline_info(process_name, layer_info, format_pid, value)
json_list.extend(json_data)
json_list.extend(self.get_layer_label_and_sort(format_pid, layer_info))
def export_all_data(self: any) -> list:
"""
get bulk data
:return: json for timeline
"""
data = EmptyClass()
if not ProfilingScene().is_all_export():
data = StepTraceViewer.get_one_iter_timeline_data(self._result_dir, self._iter_range)
if not isinstance(data, EmptyClass):
data_list = data
if isinstance(data_list, list) and data_list:
self.add_sort_index(data_list)
data_list.extend(self._export_data_list)
return data_list
if not self._export_data_list:
return []
return self._export_data_list
def is_in_iteration(self: any, json_value: dict, start_time: Decimal, end_time: Decimal) -> bool:
"""
check if in iteration
"""
if not self._iteration_time:
return True
time_start = Decimal(str(json_value.get(TraceViewHeaderConstant.TRACE_HEADER_TS,
NumberConstant.DEFAULT_START_TIME)))
time_dur = Decimal(str(json_value.get(TraceViewHeaderConstant.TRACE_HEADER_DURATION,
NumberConstant.DEFAULT_START_TIME)))
time_end = time_start + time_dur
return start_time <= time_start < end_time or time_start < start_time < time_end
def set_connection_list(self, result_dir: str):
host_to_device_obj = HostToDevice(result_dir)
self.connect_list["host_to_device"] = host_to_device_obj
def set_iteration_info(self: any, result_dir: str, iter_range: IterationRange) -> None:
"""
get iteration time
"""
self._result_dir = result_dir
self._iter_range = iter_range
self._model_id = iter_range.model_id
if ProfilingScene().is_all_export():
start_time, _ = InfoConfReader().get_collect_time()
self._iteration_time = (float(start_time), float('inf'))
else:
start_time, end_time = MsprofIteration(result_dir).get_iter_interval(iter_range,
NumberConstant.MICRO_SECOND)
start_time = Decimal(InfoConfReader().trans_into_local_time(start_time, use_us=True))
end_time = Decimal(InfoConfReader().trans_into_local_time(end_time, use_us=True))
self._iteration_time = (float(start_time), float(end_time))
def get_layer_info(self: any, process_name: str) -> TraceViewHeaderConstant.LayerInfo:
"""
get msprof timeline layer info based on map
"""
layer_info = TraceViewHeaderConstant.LAYER_INFO_MAP.get(process_name, "")
if layer_info:
return layer_info
else:
self._default_sort_index += 1
return TraceViewHeaderConstant.LayerInfo(process_name, TraceViewHeaderConstant.GENERAL_LAYER_NPU,
self._default_sort_index)
def get_start_end_time(self: any):
return self._iteration_time