from msprof_analyze.prof_common.logger import get_logger
import os
from functools import wraps
from typing import Dict, List, Union
from abc import abstractmethod, ABCMeta
from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.advisor.common.enum_params_parser import EnumParamsParser
from msprof_analyze.advisor.common.version_control import VersionControl
from msprof_analyze.advisor.dataset.dataset import Dataset
from msprof_analyze.advisor.result.result import OptimizeResult
from msprof_analyze.advisor.display.html.render import HTMLRender
from msprof_analyze.advisor.display.html.priority_background_color import PriorityBackgroundColor
from msprof_analyze.advisor.utils.utils import safe_division
from msprof_analyze.prof_common.file_manager import FileManager
from msprof_analyze.prof_common.path_manager import PathManager
logger = get_logger()
ASCEND_PT = "ascend_pt"
ASCEND_MS = "ascend_ms"
PROFILER_INFO_HEAD = "profiler_info_"
PROFILER_INFO_EXTENSION = ".json"
MS_VERSION = "ms_version"
class BaseAnalyzer(VersionControl, metaclass=ABCMeta):
_SUPPORT_VERSIONS = EnumParamsParser().get_options(Constant.CANN_VERSION)
ANALYZER_HIGH_PRIORITY_TIME_RATIO = 0.05
ANALYZER_MEDIUM_PRIORITY_TIME_RATIO = 0.03
dataset_cls_list = []
def __init__(self, collection_path, n_processes: int = 1, **kwargs):
self.n_processes = n_processes
self.kwargs = kwargs
self.collection_path = collection_path
self.output_path = kwargs.get("output_path", None)
self.cann_version = kwargs.get(Constant.CANN_VERSION, EnumParamsParser().get_default(Constant.CANN_VERSION))
self.profiling_type = self.identify_profiling_type(
EnumParamsParser().get_options(Constant.PROFILING_TYPE_UNDER_LINE))
self.profiling_version = self.identify_profiling_version()
self.html_render = HTMLRender()
self.dataset_list: Dict[str, List[Dataset]] = {}
self.init_dataset_list()
self.result = OptimizeResult()
self.record_list: Dict[str, List] = {}
@staticmethod
def get_first_data_by_key(data, key) -> Union[Dataset, None]:
"""
get the first member from data with key
:param data: input data
:param key: data key
:return: the first dataset in dataset list
"""
if key in data and len(data[key]) > 0:
return data[key][0]
return None
@classmethod
def check_data(cls, data_list: tuple):
"""
check if all data in data list is contained
:param data_list: data list to check
:return: func ptr if check success
"""
def decorate(func):
@wraps(func)
def wrapper(self, **kwargs):
data = self.dataset_list
if data is None:
return None
for data_key in data_list:
if data_key not in data:
return None
logger.info("Start analysis %s with %s", self.__class__.__name__, ",".join(data_list))
return func(self, **kwargs)
return wrapper
return decorate
@abstractmethod
def optimize(self, **kwargs):
pass
@abstractmethod
def get_priority(self, max_mem_op_dur):
pass
def identify_profiling_type(self, profiling_type_list):
profiling_type = ""
if not profiling_type_list:
return profiling_type
if self.collection_path.endswith(ASCEND_MS):
profiling_type = [elem for elem in profiling_type_list if Constant.MINDSPORE in elem][0]
elif self.collection_path.endswith(ASCEND_PT):
profiling_type = [elem for elem in profiling_type_list if Constant.PYTORCH in elem][0]
else:
for _, dirs, __ in PathManager.limited_depth_walk(self.collection_path):
is_found_type = False
for direction in dirs:
if direction.endswith(ASCEND_MS):
profiling_type = [elem for elem in profiling_type_list if Constant.MINDSPORE in elem][0]
is_found_type = True
break
elif direction.endswith(ASCEND_PT):
profiling_type = [elem for elem in profiling_type_list if Constant.PYTORCH in elem][0]
is_found_type = True
break
if is_found_type:
break
if self.kwargs.get(Constant.PROFILING_TYPE_UNDER_LINE) and self.kwargs.get(
Constant.PROFILING_TYPE_UNDER_LINE) != profiling_type:
logger.warning("%s The input profiling type %s is inconsistent with the actual profiling type %s.",
self.__class__.__name__, self.kwargs.get(Constant.PROFILING_TYPE_UNDER_LINE), profiling_type)
if not profiling_type:
logger.warning("Unknown profiling type, the default value is set pytorch.")
profiling_type = profiling_type_list[0]
return profiling_type
def identify_profiling_version(self):
profiling_version = ""
if Constant.MINDSPORE in self.profiling_type:
ascend_dirs = []
if self.collection_path.endswith(ASCEND_MS):
ascend_dirs.append(self.collection_path)
else:
for root, dirs, _ in PathManager.limited_depth_walk(self.collection_path):
for direction in dirs:
if direction.endswith(ASCEND_MS):
ascend_dirs.append(os.path.join(root, direction))
if ascend_dirs:
ascend_dir = ascend_dirs[0]
for file_name in os.listdir(ascend_dir):
if file_name.startswith(PROFILER_INFO_HEAD) and file_name.endswith(PROFILER_INFO_EXTENSION):
file_path = os.path.join(ascend_dir, file_name)
config = FileManager.read_json_file(file_path)
profiling_version = config.get(MS_VERSION, "")
break
if profiling_version and self.kwargs.get(Constant.MINDSPORE_VERSION):
if profiling_version != self.kwargs.get(Constant.MINDSPORE_VERSION):
logger.warning("%s The input version %s is inconsistent with the actual version %s.",
self.__class__.__name__, self.kwargs.get(Constant.MINDSPORE_VERSION),
profiling_version)
elif Constant.PYTORCH in self.profiling_type:
profiling_version = self.kwargs.get(Constant.TORCH_VERSION,
EnumParamsParser().get_default(Constant.TORCH_VERSION))
if self.kwargs.get(Constant.TORCH_VERSION) and profiling_version != self.kwargs.get(Constant.TORCH_VERSION):
logger.warning("%s The input version %s is inconsistent with the actual version %s.",
self.__class__.__name__, self.kwargs.get(Constant.TORCH_VERSION), profiling_version)
return profiling_version
def init_dataset_list(self) -> None:
dataset_cls_list = self.dataset_cls_list
if len(dataset_cls_list) == 0:
logger.warning(f"Analyser: %s don't rely on any dataset!", self.__class__.__name__)
return
for dataset_cls in dataset_cls_list:
if dataset_cls and callable(dataset_cls):
try:
dataset = dataset_cls(collection_path=self.collection_path, data=self.dataset_list, **self.kwargs)
except Exception as e:
logger.error(e)
continue
key = dataset_cls.get_key()
if key not in self.dataset_list:
self.dataset_list[key] = []
self.dataset_list[key].append(dataset)
def get_priority_by_time_ratio(self, dur, step_dur):
time_ratio = safe_division(dur, step_dur)
if time_ratio >= self.ANALYZER_HIGH_PRIORITY_TIME_RATIO:
return PriorityBackgroundColor.high
elif time_ratio >= self.ANALYZER_MEDIUM_PRIORITY_TIME_RATIO:
return PriorityBackgroundColor.medium
else:
return PriorityBackgroundColor.low