"""Thr parser for parsing framework files."""
import csv
import enum
import json
import os
import re
import stat
from mindspore.profiler.common.exceptions.exceptions import \
ProfilerPathErrorException, ProfilerDirNotFoundException, \
ProfilerFileNotFoundException, ProfilerDeviceIdMismatchException, \
ProfilerRawFileException, ProfilerParamValueErrorException
from mindspore.profiler.common.validator.validate_path import \
validate_and_normalize_path
class VmDataType(enum.IntEnum):
"""Definition of vm data type."""
NUMBER_TYPE_BEGIN = 30
NUMBER_TYPE_BOOL = 31
NUMBER_TYPE_INT = 32
NUMBER_TYPE_INT8 = 33
NUMBER_TYPE_INT16 = 34
NUMBER_TYPE_INT32 = 35
NUMBER_TYPE_INT64 = 36
NUMBER_TYPE_UINT = 37
NUMBER_TYPE_UINT8 = 38
NUMBER_TYPE_UINT16 = 39
NUMBER_TYPE_UINT32 = 40
NUMBER_TYPE_UINT64 = 41
NUMBER_TYPE_FLOAT = 42
NUMBER_TYPE_FLOAT16 = 43
NUMBER_TYPE_FLOAT32 = 44
NUMBER_TYPE_FLOAT64 = 45
NUMBER_TYPE_COMPLEX = 46
NUMBER_TYPE_END = 47
@classmethod
def get_data_type_name(cls, num):
"""
Get the name of data type by enum number.
Args:
num (int): Enum number.
Returns:
str, the name of data type.
"""
data_type = cls._value2member_map_.get(num)
return 'UNKNOWN' if data_type is None else data_type.name
class GeDataType(enum.IntEnum):
"""Definition of ge data type."""
DT_FLOAT = 0
DT_FLOAT16 = 1
DT_INT8 = 2
DT_INT16 = 6
DT_UINT16 = 7
DT_UINT8 = 4
DT_INT32 = 3
DT_INT64 = 9
DT_UINT32 = 8
DT_UINT64 = 10
DT_BOOL = 12
DT_DOUBLE = 11
DT_STRING = 13
DT_DUAL_SUB_INT8 = 14
DT_DUAL_SUB_UINT8 = 15
DT_COMPLEX64 = 16
DT_COMPLEX128 = 17
DT_QINT8 = 18
DT_QINT16 = 19
DT_QINT32 = 20
DT_QUINT8 = 21
DT_QUINT16 = 22
DT_RESOURCE = 23
DT_STRING_REF = 24
DT_DUAL = 25
DT_UNDEFINED = 26
@classmethod
def get_data_type_name(cls, num):
"""
Get the name of data type by enum number.
Args:
num (int): Enum number.
Returns:
str, the name of data type.
"""
data_type = cls._value2member_map_.get(num)
return 'UNKNOWN' if data_type is None else data_type.name
class GeFormat(enum.IntEnum):
"""Definition of ge format type."""
FORMAT_NCHW = 0
FORMAT_NHWC = 1
FORMAT_ND = 2
FORMAT_NC1HWC0 = 3
FORMAT_FRACTAL_Z = 4
FORMAT_NC1C0HWPAD = 5
FORMAT_NHWC1C0 = 6
FORMAT_FSR_NCHW = 7
FORMAT_FRACTAL_DECONV = 8
FORMAT_C1HWNC0 = 9
FORMAT_FRACTAL_DECONV_TRANSPOSE = 10
FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11
FORMAT_NC1HWC0_C04 = 12
FORMAT_FRACTAL_Z_C04 = 13
FORMAT_CHWN = 14
FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15
FORMAT_HWCN = 16
FORMAT_NC1KHKWHWC0 = 17
FORMAT_BN_WEIGHT = 18
FORMAT_FILTER_HWCK = 19
FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20
FORMAT_HASHTABLE_LOOKUP_KEYS = 21
FORMAT_HASHTABLE_LOOKUP_VALUE = 22
FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23
FORMAT_HASHTABLE_LOOKUP_HITS = 24
FORMAT_C1HWNCOC0 = 25
FORMAT_MD = 26
FORMAT_NDHWC = 27
FORMAT_FRACTAL_ZZ = 28
FORMAT_FRACTAL_NZ = 29
FORMAT_NCDHW = 30
FORMAT_DHWCN = 31
FORMAT_NDC1HWC0 = 32
FORMAT_FRACTAL_Z_3D = 33
FORMAT_CN = 34
FORMAT_NC = 35
FORMAT_DHWNC = 36
FORMAT_FRACTAL_Z_3D_TRANSPOSE = 37
FORMAT_RESERVED = 38
FORMAT_ALL = 39
@classmethod
def get_format_name(cls, num):
"""
Get the name of format type by enum number.
Args:
num (int): Enum number.
Returns:
str, the name of format type.
"""
format_type = cls._value2member_map_.get(num)
return 'UNKNOWN' if format_type is None else format_type.name
class FrameworkParser:
"""
Thr parser for parsing framework files.
Args:
profiling_id (str): The profiling ID.
device_id (str): The device ID.
rank_id (str): The rank ID.
output_path (str): The directory of the parsed file. Default: `./`.
"""
_regex_framework = r'Framework\.(?P<data_type>.+)\.(?P<device_id>\d).+'
_regex_framework_in_data = r'Framework\.(?P<data_type>.+)\.' \
r'(?P<device_id>\d)\.(?P<profiling_id>[a-zA-Z0-9]+).+'
_col_names = [
'task_id', 'stream_id', 'block_dim', 'full_op_name', 'op_name',
'op_type', 'subgraph', 'op_info'
]
_graph_attr_name = [
'input_format', 'input_data_type', 'input_shape', 'output_format',
'output_data_type', 'output_shape'
]
_task_id_threshold = 25000
def __init__(self, profiling_id, device_id, rank_id, output_path='./'):
self._raw_data_dir = output_path
self._profiling_path = self._get_raw_profiling_path(profiling_id)
self._backend_type = None
self._framework_path = {'graph': [], 'task': [], 'point': []}
self._search_file(profiling_id, device_id)
self._device_id = device_id
self._save_path = self._get_save_path(rank_id, output_path)
self._task_id_full_op_name_dict = {}
self._task_cache = {}
self._point_info = {}
self._parse_task_files()
self._parse_point_files()
@property
def save_path(self):
"""
The property of save path.
Returns:
str, the save path.
"""
return self._save_path
@property
def point_info(self):
"""
The property of the framework point information.
Returns:
dict, the framework point information.
"""
return self._point_info
def to_task_id_full_op_name_dict(self):
"""
Get the task id and full operator name dict.
Returns:
dict, the task id and full operator name dict.
"""
return self._task_id_full_op_name_dict
def parse(self):
"""Parse the framework files."""
self._parse_graph_files_and_save(self._task_cache)
del self._task_cache
def check_op_name(self, op_name, is_prefix=True):
"""
Check whether the operator name exists.
Args:
op_name (str): The operator name or operator name prefix.
is_prefix (bool): `True` if the op_name is prefix, else `False`.
Default: True.
Returns:
bool, `True` if the operator name does exist in framework file, else
`False`.
"""
if not op_name:
raise ProfilerParamValueErrorException('The op_name should exist.')
for full_op_name in self._task_id_full_op_name_dict.values():
if full_op_name:
if is_prefix and full_op_name.startswith(op_name):
return True
if not is_prefix and op_name == full_op_name:
return True
return False
def _get_raw_profiling_path(self, profiling_id):
"""
Get raw profiling path.
Args:
profiling_id (str): The profiling ID.
Returns:
str, the raw profiling path.
Raises:
ProfilerPathErrorException: If the profiling path is invalid.
ProfilerDirNotFoundException: If the profiling dir is not found.
"""
profiling_path = os.path.join(self._raw_data_dir, profiling_id)
try:
profiling_path = validate_and_normalize_path(profiling_path)
except RuntimeError:
raise ProfilerPathErrorException('Profiling path is invalid.')
if not os.path.isdir(profiling_path):
raise ProfilerDirNotFoundException(profiling_path)
return profiling_path
def _search_file(self, profiling_id, device_id):
"""
Search all framework files in raw profiling path.
Args:
profiling_id (str): The profiling ID.
device_id (str): The device ID.
Raises:
ProfilerFileNotFoundException: If the framework files are not found.
"""
self._search_file_from_job_path(device_id, search_in_sub_path=False)
if self._backend_type is None:
self._search_file_from_job_path(device_id, search_in_sub_path=True)
self._search_file_from_data_path(profiling_id, device_id)
if self._backend_type is None:
raise ProfilerFileNotFoundException('Framework')
self._framework_path['graph'].sort()
self._framework_path['task'].sort()
def _search_file_from_job_path(self, device_id, search_in_sub_path=False):
"""
Search framework files from job path.
Args:
device_id (str): The device ID.
search_in_sub_path (bool): `True` if search file in profiling dir,
else search in profiling sub dir. Default: False.
Raises:
ProfilerRawFileException: If the framework file type is inconsistent.
ProfilerDeviceIdMismatchException: If the device id is mismatch
with framework in the raw dir.
"""
profiling_dir = os.path.join(self._profiling_path, 'data') \
if search_in_sub_path else self._profiling_path
if not os.path.isdir(profiling_dir):
return
files = os.listdir(profiling_dir)
for file in files:
pattern = re.search(self._regex_framework, file)
if not pattern or file.endswith('.done'):
continue
attrs = pattern.groupdict()
device_id_in_path = attrs.get('device_id')
if device_id_in_path != device_id:
raise ProfilerDeviceIdMismatchException()
data_type = attrs.get('data_type')
data_type = data_type.replace("host.", "")
if data_type.startswith('vm_'):
if self._backend_type and self._backend_type != 'vm':
raise ProfilerRawFileException('Backend type is inconsistent.')
self._backend_type = 'vm'
_, data_type = data_type.split('_', 1)
else:
if self._backend_type and self._backend_type != 'ge':
raise ProfilerRawFileException('Backend type is inconsistent.')
self._backend_type = 'ge'
if data_type.startswith('graph_desc_info'):
self._framework_path['graph'].append(
os.path.join(profiling_dir, file)
)
elif data_type.startswith('task_desc_info'):
self._framework_path['task'].append(
os.path.join(profiling_dir, file)
)
elif data_type.startswith('point'):
self._framework_path['point'].append(
os.path.join(profiling_dir, file)
)
def _search_file_from_data_path(self, profiling_id, device_id):
"""
Search framework files from data path.
Args:
profiling_id (str): The profiling ID.
device_id (str): The device ID.
Raises:
ProfilerRawFileException: If the framework file type is inconsistent.
ProfilerDeviceIdMismatchException: If the device id is mismatch
with framework in the raw dir.
"""
profiling_data_path = os.path.join(
self._raw_data_dir, 'container', device_id, 'data'
)
if not os.path.isdir(profiling_data_path):
return
files = os.listdir(profiling_data_path)
for file in files:
pattern = re.search(self._regex_framework_in_data, file)
if not pattern or file.endswith('.done') or file.endswith('.zip'):
continue
attrs = pattern.groupdict()
profiling_id_in_path = attrs.get('profiling_id')
if profiling_id_in_path != profiling_id:
continue
device_id_in_path = attrs.get('device_id')
if device_id_in_path != device_id:
raise ProfilerDeviceIdMismatchException()
data_type = attrs.get('data_type')
data_type = data_type.replace("host.", "")
if data_type.startswith('vm_'):
if self._backend_type and self._backend_type != 'vm':
raise ProfilerRawFileException('Backend type is inconsistent.')
self._backend_type = 'vm'
_, data_type = data_type.split('_', 1)
else:
if self._backend_type and self._backend_type != 'ge':
raise ProfilerRawFileException('Backend type is inconsistent.')
self._backend_type = 'ge'
if data_type.startswith('graph_desc_info'):
self._framework_path['graph'].append(
os.path.join(profiling_data_path, file)
)
elif data_type.startswith('task_desc_info'):
self._framework_path['task'].append(
os.path.join(profiling_data_path, file)
)
elif data_type.startswith('point'):
self._framework_path['point'].append(
os.path.join(profiling_data_path, file)
)
def _get_save_path(self, rank_id, output_path):
"""
Get the save path.
Args:
rank_id (str): The rank ID.
output_path (str): The output dir.
Returns:
str, the save path.
Raises:
ProfilerPathErrorException: If the output path is invalid.
ProfilerDirNotFoundException: If the output dir is not found.
"""
try:
output_dir = validate_and_normalize_path(output_path)
except RuntimeError:
raise ProfilerPathErrorException('Output path is invalid.')
if not os.path.isdir(output_dir):
raise ProfilerDirNotFoundException(output_dir)
return os.path.join(
output_dir, '_'.join(['framework', 'raw', rank_id]) + '.csv'
)
def _parse_task_files(self):
"""Parse the framework task files."""
for path in self._framework_path['task']:
path = validate_and_normalize_path(path)
with open(path, 'r') as file:
for task_info in file:
infos = task_info.strip('\n').split(' ')
infos = infos[1:] if len(infos) == 5 else infos
self._task_cache[infos[0]] = [infos[2], infos[3], infos[1]]
task_id = infos[2]
if int(task_id) < self._task_id_threshold:
task_id = '_'.join([infos[3], task_id])
self._task_id_full_op_name_dict[task_id] = infos[0]
def _parse_graph_files_and_save(self, task_cache):
"""
Parse the framework graph files and save the framework information.
Args:
task_cache (dict): The task information cache.
"""
with open(self._save_path, 'w') as save_file:
csv_writer = csv.writer(save_file)
csv_writer.writerow(self._col_names)
pre_graph_info = None
for path in self._framework_path['graph']:
first_row = True
with open(path, 'r') as graph_file:
for graph_info in graph_file:
if first_row is True:
first_row = False
if graph_info.startswith("op_name:") is False:
pre_graph_info = pre_graph_info + graph_info
continue
if pre_graph_info is not None:
self._parse_graph_row_and_save(task_cache, csv_writer, pre_graph_info)
pre_graph_info = graph_info
if pre_graph_info is not None:
self._parse_graph_row_and_save(task_cache, csv_writer, pre_graph_info)
none_list = [None, None, None, None]
for key, value in task_cache.items():
value.append(key)
value.extend(none_list)
csv_writer.writerow(value)
os.chmod(self._save_path, stat.S_IREAD | stat.S_IWRITE)
def _parse_graph_row_and_save(self, task_cache, csv_writer, graph_info):
"""
Parse the framework graph row and save the framework information.
Args:
task_cache (dict): The task information cache.
csv_writer (csv): Csv writer.
graph_info (str): Row info of graph.
"""
result = self._parse_one_row_graph_info(graph_info)
task_info = task_cache.get(result[0])
if task_info:
task_info.extend(result)
csv_writer.writerow(task_info)
del task_cache[result[0]]
else:
save_info = [None, None, None]
save_info.extend(result)
csv_writer.writerow(save_info)
def _parse_one_row_graph_info(self, row_info):
"""
Parse the graph information in one row.
Args:
row_info (str): One row graph information.
Returns:
list[str], the parsed graph information.
"""
full_op_name = None
op_name = None
subgraph_name = None
op_type = None
op_info = dict()
cur_op_info_key = None
infos = row_info.strip('\n').split(' ')
for info in infos:
attr_name, attr_value = info.split(':', 1)
if attr_name == 'op_name':
full_op_name = attr_value
subgraph_name = self._get_subgraph_name(full_op_name)
op_name = self._get_op_name(full_op_name, subgraph_name)
elif attr_name == 'op_type':
op_type = attr_value
elif attr_name in ['input_id', 'output_id']:
cur_op_info_key = '{}_{}'.format(
attr_name.split('_')[0], attr_value
)
op_info[cur_op_info_key] = dict()
elif attr_name in self._graph_attr_name:
op_attr = attr_name.split('_', 1)[1]
if op_attr == 'shape':
attr_value = attr_value.strip('"')
if self._backend_type == 'vm':
if op_attr == 'data_type':
attr_value = VmDataType.get_data_type_name(
int(attr_value)
)
else:
if op_attr == 'data_type':
attr_value = GeDataType.get_data_type_name(
int(attr_value)
)
elif op_attr == 'format':
attr_value = GeFormat.get_format_name(int(attr_value))
op_info[cur_op_info_key][op_attr] = attr_value
return [full_op_name, op_name, op_type, subgraph_name,
json.dumps(op_info)]
def _get_subgraph_name(self, full_op_name):
"""
Get subgraph name.
Args:
full_op_name (str): The full operator name.
Returns:
str, the subgraph name.
"""
subgraph_name = full_op_name.split('/', 1)[0]
if subgraph_name in ['Default', 'Gradients']:
return subgraph_name
return None
def _get_op_name(self, full_op_name, subgraph_name):
"""
Get operator name.
Args:
full_op_name (str): The full operator name.
subgraph_name (str): The subgraph name.
Returns:
str, the operator name.
"""
if subgraph_name is None:
return full_op_name
if self._backend_type == 'vm':
return full_op_name.split('/')[-1]
strs = full_op_name.split(subgraph_name + '/')
op_name = None
for name_str in strs:
if not name_str:
continue
if op_name is None:
op_name = name_str.split('/')[-1]
else:
op_name = '+'.join([op_name, name_str.split('/')[-1]])
return op_name
def _parse_point_files(self):
"""Parse the framework point files."""
for path in self._framework_path['point']:
path = validate_and_normalize_path(path)
with open(path, 'r') as file:
for point_info in file:
infos = point_info.strip('\n').split(' ')
self._point_info[int(infos[0])] = infos[1]