"""The parser for parsing hccl files."""
import csv
import json
import os
import stat
from enum import Enum
import numpy as np
from mindspore.profiler.common.exceptions.exceptions import \
ProfilerPathErrorException, ProfilerFileNotFoundException, \
ProfilerDirNotFoundException, ProfilerRawFileException
from mindspore import log as logger
from mindspore.profiler.common.validator.validate_path import \
validate_and_normalize_path
class CommunicationInfo(Enum):
    """
    Names of link types and task types that appear in hccl trace events.

    Link (transport) types:
        RDMA: Communication link between servers in cluster training.
        SDMA: Communication link inside a server in cluster training.
        LOCAL: Operation of the local card, without a transmission process.

    Task (operator) types:
        RDMASEND: Communication operator of the RDMA link.
        REDUCE_INLINE: Communication operator of the SDMA link.
        MEMCPY: Communication operator of the SDMA link.
        NOTIFY_RECORD: Communication operator of the SDMA link.
        NOTIFY_WAIT: Operator of LOCAL.
    """
    # Transport types.
    RDMA = 'RDMA'
    SDMA = 'SDMA'
    LOCAL = 'LOCAL'
    # Task types.
    RDMASEND = 'RDMASend'
    REDUCE_INLINE = 'Reduce Inline'
    MEMCPY = 'Memcpy'
    NOTIFY_RECORD = 'Notify Record'
    NOTIFY_WAIT = 'Notify Wait'
class HcclParser:
    """
    The parser for parsing hccl file.

    Construction already touches the file system: it validates the output
    dir, reads the step trace file `step_trace_raw_{rank_id}_detail_time.csv`
    from `output_path` and scans `source_dir` for operator dirs whose names
    end with `device_id`.

    Args:
        source_dir (str): The hccl source dir.
        device_id (str): The device ID.
        rank_id (str): The rank ID.
        output_path (str): The directory of the parsed file. It is a required
            argument and must also contain the step trace file.

    Raises:
        ProfilerPathErrorException: If the hccl file path or the output path is invalid.
        ProfilerFileNotFoundException: If the step trace file or an hccl file does not exist.
        ProfilerDirNotFoundException: If the source dir or the output dir does not exist.
    """
    # File name template of the parsed result; formatted with the rank ID.
    _parsed_hccl_file_name = 'hccl_raw_{}.csv'
    # Header row written to the parsed CSV file.
    _col_names = ['step_num', 'communication_cost', 'wait_cost', 'link_info', 'communication_operator_cost']
def __init__(self, source_dir, device_id, rank_id, output_path):
    """
    Store the identifiers and preload everything `parse` needs.

    Args:
        source_dir (str): The hccl source dir.
        device_id (str): The device ID.
        rank_id (str): The rank ID.
        output_path (str): The dir of the parsed file; the step trace file is read from it too.
    """
    self._source_dir = source_dir
    self._dev_id = device_id
    self._rank_id = rank_id
    # Both lookups below format file names with the rank ID stored above.
    self._save_path = self._get_save_path(output_path)
    self._step_trace_info = self._get_step_trace_info(output_path)
    # The mapping needs the source dir, the device ID and the step trace info.
    mapping_info = self._get_communication_operator_name_mapping_info()
    self._communication_operator_name_mapping_info = mapping_info
def parse(self):
    """Parse the communication info under the source dir and save it as a CSV file."""
    hccl_source_dir = self._source_dir
    self._parse_and_save(hccl_source_dir)
def _parse_communication_cost(self, operators_cost_info, info, operators_dict):
"""Parse communication cost."""
for key, value in operators_cost_info.items():
for item in value:
if info[0] == item[0]:
operators_dict[key] = item
def _parse_and_save(self, dir_path):
    """
    Parse the communication info under `dir_path` and save it as a CSV file.

    One row is written per step, followed by one row of average values whose
    step column is '-'. The link info and the per-operator cost columns are
    stored as JSON strings. When no communication info is found at all, only
    the header row is written (there is nothing to average).

    Args:
        dir_path (str): The hccl source dir.
    """
    operators_cost_info = self._get_communication_operators_cost_info(dir_path)

    # Merge the per-operator records of every operator into per-step records.
    all_operator_costs = [cost for costs in operators_cost_info.values() for cost in costs]
    communication_info_cache = self._merge_communication_info_by_step_num(all_operator_costs)
    for info in communication_info_cache:
        operators_dict = dict()
        self._parse_communication_cost(operators_cost_info, info, operators_dict)
        info.append(operators_dict)

    # The averages must be computed while the link info is still a dict.
    device_average = self._calculate_communication_average_value(communication_info_cache)
    average_row = None
    if device_average:
        operators_average_value = dict()
        for operator_name, costs in operators_cost_info.items():
            operators_average_value[operator_name] = \
                ['-'] + self._calculate_communication_average_value(costs)
        average_row = ['-', device_average[0], device_average[1],
                       json.dumps(device_average[2]), json.dumps(operators_average_value)]

    with open(self._save_path, 'w', newline='') as save_file:
        csv_writer = csv.writer(save_file)
        csv_writer.writerow(self._col_names)
        for info in communication_info_cache:
            csv_writer.writerow(info[:3] + [json.dumps(info[3]), json.dumps(info[4])])
        if average_row is not None:
            csv_writer.writerow(average_row)
    # Restrict the result file to owner read/write.
    os.chmod(self._save_path, stat.S_IREAD | stat.S_IWRITE)
def _get_save_path(self, output_path):
    """
    Build the path of the parsed file inside the output dir.

    Args:
        output_path (str): The output dir; it is validated first.

    Returns:
        str, the validated output dir joined with the parsed file name of this rank.
    """
    validated_dir = self._validate_dir_path(output_path)
    parsed_file_name = self._parsed_hccl_file_name.format(self._rank_id)
    return os.path.join(validated_dir, parsed_file_name)
def _get_step_trace_info(self, source_dir):
    """
    Get the start and end timestamps in a step and communication operators names.

    Args:
        source_dir (str): The dir that contains `step_trace_raw_{rank_id}_detail_time.csv`.

    Returns:
        list, `[communication_operators_names, step_timestamps_info]`. The names are
        the header columns from index 9 on; each timestamps entry is
        `[step, start, end]` taken from the first three columns of a row whose
        first column is a digit string.

    Raises:
        ProfilerPathErrorException: If the file path is invalid.
        ProfilerFileNotFoundException: If the step trace file does not exist.
        ProfilerRawFileException: If the step trace file has no header row.
    """
    file_path = os.path.join(
        source_dir,
        f'step_trace_raw_{self._rank_id}_detail_time.csv'
    )
    try:
        file_path = validate_and_normalize_path(file_path)
    except RuntimeError as err:
        logger.warning('file path is invalid.')
        raise ProfilerPathErrorException('file path is invalid.') from err
    if not os.path.isfile(file_path):
        logger.warning('The step trace file <%s> not found.', file_path)
        raise ProfilerFileNotFoundException(file_path)

    # newline='' is what the csv module asks for when reading files.
    with open(file_path, 'r', newline='') as src_file:
        csv_reader = csv.reader(src_file)
        header = next(csv_reader, None)
        if header is None:
            # An empty file would otherwise leak a bare StopIteration.
            logger.warning('The step trace file <%s> is empty.', file_path)
            raise ProfilerRawFileException('The step trace file is empty.')
        communication_operators_names = header[9:]
        # Blank lines are yielded as empty lists by csv.reader, so skip them.
        # NOTE(review): the division by 100 is presumably a unit conversion - confirm.
        step_timestamps_info = [[info[0], float(info[1]) / 100, float(info[2]) / 100]
                                for info in csv_reader if info and info[0].isdigit()]
    return [communication_operators_names, step_timestamps_info]
def _get_communication_operator_name_mapping_info(self):
    """
    Get the name of communication operators mapping between hccl and step trace.

    Returns:
        dict, hccl operator type -> list of `(hccl_name, step_trace_name)` pairs.
        hccl names are the sub dirs of the source dir ending with the device ID,
        ordered by the integer after their first '_'; step trace names keep their
        order in the step trace header. Types are matched case-insensitively.
    """
    dir_path = self._validate_dir_path(self._source_dir)
    hccl_names = [entry.name for entry in os.scandir(dir_path)
                  if entry.is_dir() and entry.name.endswith(self._dev_id)]

    # Group the hccl dir names by operator type, i.e. the text before the first '_'.
    hccl_groups = dict()
    for hccl_name in hccl_names:
        hccl_groups.setdefault(hccl_name.split('_')[0], []).append(hccl_name)
    for op_type, names in hccl_groups.items():
        hccl_groups[op_type] = sorted(names, key=lambda name: int(name.split('_')[1]))

    # Every third step trace column is used as an operator name.
    step_trace_names = self._step_trace_info[0][::3]
    step_trace_groups = dict()
    for trace_name in step_trace_names:
        op_type = trace_name.split('_')[3].split('-')[0]
        step_trace_groups.setdefault(op_type, []).append(trace_name)

    # Look the step trace groups up by lower-cased type for the case-insensitive match.
    trace_names_by_lower_type = {op_type.lower(): names for op_type, names in step_trace_groups.items()}
    communication_operator_mapping_info = dict()
    for hccl_type, names in hccl_groups.items():
        trace_names = trace_names_by_lower_type.get(hccl_type.lower())
        if trace_names is not None:
            communication_operator_mapping_info[hccl_type] = list(zip(names, trace_names))
    logger.info("Communication operator name mapping info is %s", communication_operator_mapping_info)
    return communication_operator_mapping_info
def _calculate_the_step_by_timestamp(self, timestamp):
"""Calculate the step according to the timestamp."""
step_timestamps_info = self._step_trace_info[1]
step_timestamps_len = len(step_timestamps_info)
if timestamp < step_timestamps_info[0][1]:
step_num = "1"
elif step_timestamps_info[step_timestamps_len - 1][2] < timestamp:
step_num = step_timestamps_info[step_timestamps_len - 1][0]
else:
for item in step_timestamps_info:
if item[1] <= timestamp < item[2]:
step_num = item[0]
return step_num
def _get_communication_operators_cost_info(self, dir_path):
    """
    Obtain time-consuming information of all communication operators.

    Args:
        dir_path (str): The hccl source dir; its sub dirs ending with the device ID are parsed.

    Returns:
        dict, operator name -> per-step cost records. The step trace name is used
        as key when a mapping exists, otherwise the hccl dir name.
    """
    dir_path = self._validate_dir_path(dir_path)
    hccl_names = [entry.name for entry in os.scandir(dir_path)
                  if entry.is_dir() and entry.name.endswith(self._dev_id)]

    operators_cost_info = dict()
    for hccl_name in hccl_names:
        operator_cost = self._calculate_communication_operator_cost(os.path.join(dir_path, hccl_name))
        name_pairs = self._communication_operator_name_mapping_info.get(hccl_name.split('_')[0], [])
        mapped_names = [pair[1] for pair in name_pairs if pair[0] == hccl_name]
        if mapped_names:
            operator_name = mapped_names[0]
        else:
            logger.warning("The mapping relationship between op name in hccl and op name in step trace "
                           "cannot be found. Use op name in hccl to show the name of the communication operator.")
            operator_name = hccl_name
        operators_cost_info[operator_name] = operator_cost
    return operators_cost_info
def _calculate_communication_operator_cost(self, dir_path):
    """
    Calculate the cost of one communication operator, such as allReduce_1 or allReduce_2.

    Args:
        dir_path (str): The dir of the operator; every file in it is parsed as one execution round.

    Returns:
        list, the cost records of all rounds merged by step num.
    """
    dir_path = self._validate_dir_path(dir_path)
    file_names = [entry.name for entry in os.scandir(dir_path) if entry.is_file()]
    rounds_cost = [self._calculate_communication_operator_iter_cost(os.path.join(dir_path, file_name))
                   for file_name in file_names]
    return self._merge_communication_info_by_step_num(rounds_cost)
def _merge_communication_info_by_step_num(self, communication_info: list):
"""According to step num to merge communication info."""
steps_communication_info = list()
info_set = set()
for item in communication_info:
if item[0].isdigit():
info_set.add(int(item[0]))
info_set = sorted(info_set)
for item in info_set:
item = str(item)
step_communication_info = [info for info in communication_info if info[0] == item]
step_communication_cost = sum([i[1] for i in step_communication_info])
step_communication_wait_cost = sum([i[2] for i in step_communication_info])
step_communication_link = self._calculate_link_value([i[3] for i in step_communication_info], "total")
steps_communication_info.append([item, step_communication_cost,
step_communication_wait_cost, step_communication_link])
return steps_communication_info
def _calculate_communication_operator_iter_cost(self, file_path):
    """
    Calculate the time-consuming of communication operator in one execution round.

    Args:
        file_path (str): The JSON file of one execution round; it must hold a
            non-empty "traceEvents" list.

    Returns:
        list, `[step, communication_cost, wait_cost, link_info]`. The two costs come
        from the events of the thread with the largest "tid" only, while the link
        info is calculated over the events of all threads.

    Raises:
        ProfilerPathErrorException: If the file path is invalid.
        ProfilerFileNotFoundException: If the file does not exist.
        ProfilerRawFileException: If the file is not valid JSON or holds no trace events.
    """
    def _inner_calculate_communication_operator_iter_cost(events):
        total_notify_wait = HcclParser._calculate_notify_wait_time(events)
        src_dst_dict = self._divide_communication_info_by_src_dst_rank(events)
        src_dst_link_info = self._calculate_src_dst_link_info(src_dst_dict)
        communication_cost, communication_wait = self._calculate_device_communication_cost(src_dst_link_info)
        # The wait time already counted inside the link cost is not reported twice.
        total_notify_wait -= communication_wait
        return [communication_cost, total_notify_wait, src_dst_link_info]

    file_path = self._validate_file_path(file_path)
    with open(file_path, 'r') as src_file:
        try:
            operator_info = json.load(src_file)
        except (json.JSONDecodeError, TypeError) as err:
            logger.warning(err)
            raise ProfilerRawFileException('Fail to parse operator file.') from err

    # Valid JSON may still be no object, or an object without any trace events;
    # report that like any other malformed raw file instead of a TypeError/IndexError.
    trace_events = operator_info.get("traceEvents") if isinstance(operator_info, dict) else None
    if not trace_events:
        logger.warning('No trace events found in operator file <%s>.', file_path)
        raise ProfilerRawFileException('Fail to parse operator file.')

    # The timestamp of the first event decides which step the round belongs to.
    operator_timestamp = trace_events[0].get("ts", 0)
    step_id = self._calculate_the_step_by_timestamp(operator_timestamp)
    total_communication_operator_iter_cost = \
        _inner_calculate_communication_operator_iter_cost(trace_events)

    threads_dict = self._divide_communication_info_by_thread(trace_events)
    major_thread = max(threads_dict)
    mainstream_communication_operator_iter_cost = \
        _inner_calculate_communication_operator_iter_cost(threads_dict.get(major_thread))
    return [step_id, mainstream_communication_operator_iter_cost[0],
            mainstream_communication_operator_iter_cost[1],
            total_communication_operator_iter_cost[2]]
@staticmethod
def _divide_communication_info_by_thread(trace_events: list):
"""Divide information by thread."""
threads_dict = dict()
for item in trace_events:
thread_id = item.get("tid")
if thread_id not in threads_dict.keys():
threads_dict[thread_id] = [item]
else:
threads_dict[thread_id].append(item)
return threads_dict
def _divide_communication_info_by_src_dst_rank(self, trace_event: list):
"""Divide information by src rank id and dst rank id"""
src_dst_dict = dict()
for item in trace_event:
src_rank = item.get("args").get("src rank")
dst_rank = item.get("args").get("dst rank")
if src_rank is None or dst_rank is None:
continue
if int(src_rank) == int('0xffffffff', 16):
src_rank = dst_rank
if int(dst_rank) == int('0xffffffff', 16):
dst_rank = src_rank
if item.get("args").get("transport type") == CommunicationInfo.LOCAL.value:
item["args"]["src rank"] = dst_rank
item["args"]["dst rank"] = src_rank
src_dst_key = str(dst_rank) + '-' + str(src_rank)
else:
src_dst_key = str(src_rank) + '-' + str(dst_rank)
if src_dst_key not in src_dst_dict.keys():
src_dst_dict[src_dst_key] = [item]
else:
src_dst_dict[src_dst_key].append(item)
return src_dst_dict
def _divide_communication_info_by_link_type(self, trace_event: list):
"""Divide information by link type."""
link_type_dict = dict()
for item in trace_event:
link_type_key = item.get("args").get("transport type")
if link_type_key is None:
continue
if link_type_key in (CommunicationInfo.RDMA.value, CommunicationInfo.SDMA.value):
task_type = item.get("args").get("task type")
if task_type == CommunicationInfo.NOTIFY_RECORD.value:
continue
if link_type_dict.get(link_type_key):
link_type_dict[link_type_key].append(item)
else:
link_type_dict[link_type_key] = [item]
if link_type_key == CommunicationInfo.LOCAL.value:
if link_type_dict.get(CommunicationInfo.RDMA.value):
link_type_dict[CommunicationInfo.RDMA.value].append(item)
return link_type_dict
def _calculate_device_communication_cost(self, src_dst_link_info: dict):
"""Calculate notify wait time."""
total_communication_time = 0
total_wait_time = 0
for src_dst_value in src_dst_link_info.values():
for link_type_value in src_dst_value.values():
total_communication_time += link_type_value[0]
if len(link_type_value) > 3:
total_wait_time += link_type_value[3]
return total_communication_time, total_wait_time
def _parse_link_cost(self, result_dict, key, link_type_dict):
"""Parse link cost."""
for link_type_key, link_type_value in link_type_dict.items():
if link_type_key == CommunicationInfo.RDMA.value:
rdma_infos = []
threads_dict = self._divide_communication_info_by_thread(link_type_value)
for thread_value in threads_dict.values():
rdma_info = self._calculate_adma_link_info(thread_value)
rdma_infos.append(rdma_info)
rdma_total_cost = np.sum(rdma_infos, axis=0).tolist()
result_dict[key][link_type_key] = rdma_total_cost
if link_type_key == CommunicationInfo.SDMA.value:
sdma_total_cost = self._calculate_sdma_link_info(link_type_value)
result_dict[key][link_type_key] = sdma_total_cost
def _calculate_src_dst_link_info(self, src_dst_dict: dict):
"""Calculate src dst link info."""
result_dict = dict()
for key, value in src_dst_dict.items():
link_type_dict = self._divide_communication_info_by_link_type(value)
if not link_type_dict:
continue
result_dict[key] = dict()
self._parse_link_cost(result_dict, key, link_type_dict)
return result_dict
@staticmethod
def _calculate_adma_link_info(trace_event: list):
"""
Calculate RDMA link info.
When the link is RDMA,it is necessary to match three consecutive operators RDMASend, RDMASend \
and Notify Wait,and take the sum of the time of the three operators as one communication time.
"""
rdma_communication_time = 0
rdma_communication_size = 0
rdma_communication_wait_time = 0
start_index = 0
end_index = len(trace_event) - 1
while start_index < end_index:
first_task_type = trace_event[start_index].get("args").get("task type")
if first_task_type == CommunicationInfo.RDMASEND.value and start_index < end_index - 1:
second_task_type = trace_event[start_index + 1].get("args").get("task type")
third_task_type = trace_event[start_index + 2].get("args").get("task type")
if second_task_type == CommunicationInfo.RDMASEND.value and \
third_task_type == CommunicationInfo.NOTIFY_WAIT.value:
rdma_send_cost = trace_event[start_index].get("dur", 0)
notify_record_cost = trace_event[start_index + 1].get("dur", 0)
notify_wait_cost = trace_event[start_index + 2].get("dur", 0)
rdma_communication_time += rdma_send_cost + notify_record_cost + notify_wait_cost
rdma_communication_wait_time += notify_wait_cost
rdma_size = trace_event[start_index].get("args").get("size")
rdma_size = int(rdma_size, 16) if rdma_size else 0
notify_record_size = trace_event[start_index + 1].get("args").get("size")
notify_record_size = int(notify_record_size, 16) if notify_record_size else 0
rdma_communication_size += rdma_size + notify_record_size
start_index += 2
start_index += 1
rdma_communication_wait_time = rdma_communication_wait_time / 1e3
rdma_communication_size = rdma_communication_size / 1e3
rdma_communication_time = rdma_communication_time / 1e3
rdma_bandwidth = rdma_communication_size / (rdma_communication_time / 1e3) \
if rdma_communication_size else 0
return [rdma_communication_time, rdma_communication_size, rdma_bandwidth, rdma_communication_wait_time]
def _calculate_sdma_link_info(self, trace_event: list):
"""
Calculate SDMA link info.
When the link is SDMA, the communication time of the primary link is the sum of the execution time\
of Reduce inline and Memcpy operators.
"""
sdma_communication_time = 0
sdma_communication_size = 0
for item in trace_event:
task_type = item.get("args").get("task type")
if task_type in (CommunicationInfo.REDUCE_INLINE.value, CommunicationInfo.MEMCPY.value):
sdma_communication_time += item.get("dur", 0)
sdma_size = int(item.get("args").get("size"), 16) if item.get("args").get("size") else 0
sdma_communication_size += sdma_size
sdma_communication_time = sdma_communication_time / 1e3
sdma_communication_size = sdma_communication_size / 1e3
sdma_bandwidth = sdma_communication_size / (sdma_communication_time / 1e3) \
if sdma_communication_size else 0
return [sdma_communication_time, sdma_communication_size, sdma_bandwidth]
@staticmethod
def _calculate_notify_wait_time(trace_event: list):
"""Calculate notify wait time."""
total_notify_wait_time = 0
for item in trace_event:
task_type = item.get("args").get("task type")
if task_type == CommunicationInfo.NOTIFY_WAIT.value:
total_notify_wait_time += item.get("dur", 0)
total_notify_wait_time = total_notify_wait_time / 1e3
return total_notify_wait_time
def _calculate_communication_average_value(self, communication_info: list):
"""Calculate communication average value."""
communication_info_size = len(communication_info)
if communication_info_size == 0:
return []
communication_cost_average = sum([i[1] for i in communication_info]) / communication_info_size
wait_cost_average = sum([i[2] for i in communication_info]) / communication_info_size
link_info = [i[3] for i in communication_info]
calculate_type = 'average'
link_average_info = HcclParser._calculate_link_value(link_info, calculate_type)
return [communication_cost_average, wait_cost_average, link_average_info]
@staticmethod
def _parser_link_dict(result_dict, src_dst_key, src_dst_value):
"""Parser link info to dict."""
if src_dst_key not in result_dict.keys():
result_dict[src_dst_key] = dict()
for link_key, link_value in src_dst_value.items():
if link_key not in result_dict[src_dst_key].keys():
result_dict[src_dst_key][link_key] = list()
result_dict[src_dst_key][link_key].append(link_value)
@staticmethod
def _calculate_link_value(link_info: list, calculate_type):
"""Calculate link average or total value."""
result_dict = dict()
for item in link_info:
for src_dst_key, src_dst_value in item.items():
HcclParser._parser_link_dict(result_dict, src_dst_key, src_dst_value)
for src_dst_key, src_dst_value in result_dict.items():
for link_key, _ in src_dst_value.items():
if calculate_type == 'average':
result_dict[src_dst_key][link_key] = np.mean(result_dict[src_dst_key][link_key], axis=0).tolist()
if calculate_type == 'total':
result_dict[src_dst_key][link_key] = np.sum(result_dict[src_dst_key][link_key], axis=0).tolist()
return result_dict
def _validate_file_path(self, file_path):
    """
    Normalize a file path and make sure the file exists.

    Args:
        file_path (str): The file path to check.

    Returns:
        str, the normalized file path.

    Raises:
        ProfilerPathErrorException: If the path is rejected by the path validator.
        ProfilerFileNotFoundException: If the path is not an existing file.
    """
    try:
        normalized_path = validate_and_normalize_path(file_path)
    except RuntimeError:
        logger.warning('file path is invalid.')
        raise ProfilerPathErrorException('file path is invalid.')
    if os.path.isfile(normalized_path):
        return normalized_path
    logger.warning('The file <%s> not found.', normalized_path)
    raise ProfilerFileNotFoundException(normalized_path)
def _validate_dir_path(self, dir_path):
    """
    Normalize a dir path and make sure the dir exists.

    Args:
        dir_path (str): The dir path to check.

    Returns:
        str, the normalized dir path.

    Raises:
        ProfilerPathErrorException: If the path is rejected by the path validator.
        ProfilerDirNotFoundException: If the path is not an existing dir.
    """
    try:
        normalized_path = validate_and_normalize_path(dir_path)
    except RuntimeError:
        logger.warning('dir path is invalid.')
        raise ProfilerPathErrorException('dir path is invalid.')
    if os.path.isdir(normalized_path):
        return normalized_path
    logger.warning('The dir <%s> not found.', normalized_path)
    raise ProfilerDirNotFoundException(normalized_path)