"""
dataset module
"""
from msprof_analyze.prof_common.logger import get_logger
import os
import re
from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.prof_common.file_manager import FileManager
from msprof_analyze.prof_common.path_manager import PathManager
from msprof_analyze.advisor.config.config import Config
# Module-level logger obtained from the project's get_logger() helper.
logger = get_logger()
class Dataset:
    """
    Base class for datasets built from a profiling collection directory.

    :param collection_path: dataset path; it is joined onto ``Config().work_path``
        and normalised to an absolute path
    :param data: optional dict that the instance appends itself to, under
        ``get_key()``, when ``_parse()`` returns a truthy value
    :param kwargs: ``output_path`` is read from here; when it is missing or
        falsy, ``collection_path`` is used as the output path
    """
    # File names of the DB exports looked for inside the profiler output folder.
    PYTORCH_DB_PATTERN = re.compile(r'ascend_pytorch_profiler(?:_\d+)?\.db$')
    MINDSPORE_DB_PATTERN = re.compile(r'ascend_mindspore_profiler(?:_\d+)?\.db$')

    def __init__(self, collection_path, data=None, **kwargs) -> None:
        registry = {} if data is None else data

        work_relative = os.path.join(Config().work_path, collection_path)
        self.collection_path = os.path.abspath(work_relative)

        # output_path is stored before get_data_type() runs and only afterwards
        # falls back to collection_path, keeping the attribute-assignment order.
        self.output_path = kwargs.get("output_path")
        self.data_type = self.get_data_type()
        if not self.output_path:
            self.output_path = self.collection_path

        class_name = self.__class__.__name__
        logger.debug("init %s with %s", class_name, self.collection_path)

        # Register the instance only when parsing reports success.
        if not self._parse():
            return
        registry.setdefault(self.get_key(), []).append(self)

    @staticmethod
    def _parse():
        """
        Parse hook called at the end of ``__init__``.

        The base implementation returns None, so a plain ``Dataset`` is never
        added to ``data``. Presumably subclasses override this -- verify there.
        """
        return None

    @classmethod
    def get_key(cls):
        """
        get key of dataset
        :return: the class name (the part after the last '.', if any)
        """
        _, _, tail = cls.__name__.rpartition('.')
        return tail

    def get_data_type(self):
        """
        Detect whether the collection holds DB exports or text exports.

        Walks ``collection_path`` via ``PathManager.limited_depth_walk`` and, in
        every ``Constant.ASCEND_PROFILER_OUTPUT`` directory found, checks the
        file names against the PyTorch / MindSpore DB patterns.

        :return: ``Constant.DB`` on the first matching file, else ``Constant.TEXT``
        """
        output_dir_name = Constant.ASCEND_PROFILER_OUTPUT
        for parent, subdirs, _files in PathManager.limited_depth_walk(self.collection_path):
            if output_dir_name not in subdirs:
                continue
            candidate_dir = os.path.join(parent, Constant.ASCEND_PROFILER_OUTPUT)
            has_db_file = any(
                self.PYTORCH_DB_PATTERN.match(entry) or self.MINDSPORE_DB_PATTERN.match(entry)
                for entry in os.listdir(candidate_dir)
            )
            if has_db_file:
                return Constant.DB
        return Constant.TEXT