from abc import ABC
import multiprocessing
import logging
from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.advisor.common.timeline.event import TimelineEvent
from msprof_analyze.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset
# Module-level logger. ``getLogger()`` is called without a name, so this is the
# root logger rather than a logger scoped to this module.
logger = logging.getLogger()
class TimelineBaseChecker(ABC):
    """Base class for timeline checkers that gather call stacks of matched operators.

    The class keeps a mapping of operator -> matched event indexes
    (``_matched_op_index``) and, via :meth:`query_stack`, turns it into a
    mapping of operator -> {call stack: occurrence count}
    (``matched_op_stacks``).  Nothing in this class fills
    ``_matched_op_index``; presumably subclasses do -- confirm against them.
    """

    def __init__(self, n_processes: int = 1):
        """Set up the bookkeeping containers.

        Args:
            n_processes: Number of worker processes.  When greater than 1 the
                matched-index mapping is a ``multiprocessing.Manager`` dict
                proxy instead of a plain ``dict``.
        """
        self.n_processes = n_processes
        # op -> container of matched event indexes (must support ``len`` and ``in``).
        self._matched_op_index = {} if self.n_processes <= 1 else multiprocessing.Manager().dict()
        # op -> {stack: number of matched events that carried this stack}.
        self.matched_op_stacks = {}
        # Stays True until at least one matched event yields a stack that is kept.
        self.empty_stacks = True
        # Becomes True once a stack has been dropped because of the framework blacklist.
        self.framework_black_list = False

    def query_stack(self, event_dataset: ScheduleAnalysisDataset = None, profiling_with_stack: str = None):
        """Collect call stacks for the matched operators into ``matched_op_stacks``.

        Args:
            event_dataset: Dataset whose events are scanned.  Ignored when
                ``profiling_with_stack`` is given.
            profiling_with_stack: Collection path; when truthy, a new
                ``ScheduleAnalysisDataset`` is created from it and used
                instead of ``event_dataset``.

        Returns:
            None.  Results are accumulated on ``self.matched_op_stacks``.
        """
        # Nothing matched (or nothing registered at all): no stacks to look up.
        if all(len(matched_index) == 0 for matched_index in self._matched_op_index.values()):
            return

        if profiling_with_stack:
            event_dataset = ScheduleAnalysisDataset(
                collection_path=profiling_with_stack, data={}, _datasets={}, analysis_mode="fusion_ops",
                build_dataset=False)

        # Both parameters default to None; without this guard the call below
        # would fail with an opaque AttributeError on NoneType.
        if event_dataset is None:
            logger.warning("Skip querying stacks: neither event_dataset nor profiling_with_stack was provided")
            return

        # Each element is expected to be the dict returned by
        # ``_query_stack_by_matched_index`` for one event.
        op_stack_list = event_dataset.parse_data_with_generator(self._query_stack_by_matched_index)
        for op_stack in op_stack_list:
            for op, stack in op_stack.items():
                # Register the op even if it only ever has the "no stack" marker.
                stack_counts = self.matched_op_stacks.setdefault(op, {})
                if stack == Constant.TIMELINE_FUSION_OPS_NO_STACK_FLAG:
                    continue
                stack_counts[stack] = stack_counts.get(stack, 0) + 1

    def _query_stack_by_matched_index(self, index, event):
        """Return the stack record of one event for every operator that matched it.

        Args:
            index: Index of the event; compared against the index containers
                stored in ``_matched_op_index``.
            event: Raw event, wrapped into ``TimelineEvent`` before use.

        Returns:
            A dict mapping each operator matched at ``index`` to the event's
            call stack, or to ``Constant.TIMELINE_FUSION_OPS_NO_STACK_FLAG``
            when the stack is empty or dropped by the blacklist.  Empty dict
            when no operator matched this index.
        """
        matched_ops = [
            op for op, matched_index in self._matched_op_index.items() if index in matched_index
        ]
        if not matched_ops:
            return {}

        # The stack belongs to the event, not to an operator, so look it up
        # and filter it once instead of once per matched operator.
        event = TimelineEvent(event)
        stack = event.args.get(Constant.CALL_STACKS)
        if not stack:
            logger.debug("Got empty '%s' for event %s", Constant.CALL_STACKS, event)
        elif not self._is_keep_stack(stack):
            # NOTE(review): if this callback runs in worker processes, this
            # flag is set on the worker's copy of ``self`` -- confirm how
            # ``parse_data_with_generator`` dispatches the callback.
            self.framework_black_list = True
            logger.debug("Drop stack from framework %s", Constant.FRAMEWORK_STACK_BLACK_LIST)
        else:
            self.empty_stacks = False
            return {op: stack for op in matched_ops}

        return {op: Constant.TIMELINE_FUSION_OPS_NO_STACK_FLAG for op in matched_ops}

    def _is_keep_stack(self, stack):
        """Tell whether ``stack`` should be kept, judging by its first frame.

        The stack string is split into frames on ``;`` after replacing the
        literal four-character text ``\\r\\n`` (backslashes included, not real
        CR/LF characters) with ``;``.  The first frame is then split on ``/``
        and the stack is rejected if any entry of
        ``Constant.FRAMEWORK_STACK_BLACK_LIST`` equals one of those segments.

        Args:
            stack: Call stack string.

        Returns:
            False if the first frame contains a blacklisted segment, else True.
        """
        # ``str.split`` with an explicit separator always yields at least one
        # element, so taking index 0 is safe even for an empty string.
        final_called_stack = stack.replace("\\r\\n", ";").split(";")[0]
        path_segments = final_called_stack.split("/")
        return not any(framework in path_segments for framework in Constant.FRAMEWORK_STACK_BLACK_LIST)