import os
import logging
from collections import OrderedDict
from dataclasses import dataclass
from typing import List
from typing import Set
from common_func.db_name_constant import DBNameConstant
from common_func.info_conf_reader import InfoConfReader
from common_func.ms_constant.number_constant import NumberConstant
from common_func.ms_constant.str_constant import StrConstant
from common_func.trace_view_header_constant import TraceViewHeaderConstant
from common_func.trace_view_manager import TraceViewManager
from common_func.path_manager import PathManager
from common_func.hccl_info_common import DeviceHcclSource
from mscalculate.hccl.hccl_task import HcclOps
from mscalculate.hccl.hccl_task import HcclTask
from msmodel.hccl.hccl_model import HcclViewModel
from common_func.hash_dict_constant import HashDictData
class HCCLExport:
"""
hccl export
"""
HCCL_SORTED_OFFSET = 70000
INVALID_PLANE = -1
DEFAULT_PLANE = 0
INVALID_GROUP = 'N/A'
@dataclass
class HcclGroup:
group_name: str
planes: Set[int]
id: int
start_index: int
is_aicpu: bool
def __init__(self: any, param: dict) -> None:
self.project_path = param.get(StrConstant.PARAM_RESULT_DIR)
self.result = []
self.err_message = {}
self.iter_range = param.get(StrConstant.PARAM_ITER_ID)
self.pid_value = InfoConfReader().get_json_pid_data()
self.hccl_groups = dict()
self._hash_data = {}
@staticmethod
def get_hccl_arg(hccl_task):
return OrderedDict({
'notify_id': hccl_task.notify_id,
'duration estimated(us)': hccl_task.duration_estimated,
'stream id': hccl_task.stream_id,
'task id': hccl_task.task_id,
'context id': hccl_task.context_id,
'task type': hccl_task.hccl_name,
'src rank': hccl_task.local_rank,
'dst rank': hccl_task.remote_rank,
'transport type': hccl_task.transport_type,
'size(Byte)': hccl_task.size,
'data type': hccl_task.data_type,
'link type': hccl_task.link_type,
"bandwidth(GB/s)": hccl_task.bandwidth
})
def get_hccl_timeline_data(self: any) -> list:
"""
get data for hccl timeline
"""
hccl_data = []
mc2_hccl_task = []
if os.path.exists(PathManager.get_db_path(self.project_path, DBNameConstant.DB_HCCL_SINGLE_DEVICE)):
with HcclViewModel(self.project_path, DBNameConstant.DB_HCCL_SINGLE_DEVICE,
[DBNameConstant.TABLE_HCCL_TASK_SINGLE_DEVICE]) as hccl_model:
hccl_data = hccl_model.get_hccl_task()
mc2_hccl_task = hccl_model.get_kfc_task()
if not hccl_data and not mc2_hccl_task:
logging.error("get hccl data failed, may be the hccl file not completed or hccl parser "
"failed. please check data file.")
return []
self._hash_data = HashDictData(self.project_path).get_ge_hash_dict()
self._get_meta_data(hccl_data, mc2_hccl_task)
if hccl_data:
self._format_hccl_data(hccl_data, DeviceHcclSource.HCCL.value)
if mc2_hccl_task:
self._format_hccl_data(mc2_hccl_task, DeviceHcclSource.MC2.value)
return self.result
def _add_hccl_bar(self):
self.result = TraceViewManager.metadata_event(
[["process_name", self.pid_value, InfoConfReader().get_json_tid_data(), "Communication"]])
def _add_group_threads(self, group: HcclGroup, start_sort_index: int, valid_group: bool) -> int:
"""
add group thread meta data in json
return: end_sort_index
"""
if not group.planes:
return start_sort_index
index_now = start_sort_index
group.start_index = start_sort_index
group_name = self._hash_data.get(group.group_name, group.group_name)
if group.is_aicpu:
thread_name = f"Group {group_name} Aicpu Communication" if valid_group else "Aicpu Communication"
else:
thread_name = f"Group {group_name} Communication" if valid_group else "Communication"
self.result.extend(TraceViewManager.metadata_event(
[["thread_name", self.pid_value, index_now, thread_name]]))
self.result.extend(TraceViewManager.metadata_event(
[["thread_sort_index", self.pid_value, index_now, index_now]]))
plane_infos = []
plane_sort_indexes = []
for plane in group.planes:
index_now = start_sort_index + plane + 1
plane_infos.append(["thread_name", self.pid_value, index_now, f"Plane {plane}"])
plane_sort_indexes.append(["thread_sort_index", self.pid_value, index_now, index_now])
self.result.extend(TraceViewManager.metadata_event(plane_infos))
self.result.extend(TraceViewManager.metadata_event(plane_sort_indexes))
index_now += 1
return index_now
def _init_hccl_group(self, hccl_data: List[HcclTask], kfc_data: List[HcclTask]) -> dict:
name_planes_table: OrderedDict[str, dict] = OrderedDict()
hccl_groups = dict()
for data in hccl_data:
if data.plane_id == self.INVALID_PLANE:
data = data.replace(plane_id=self.DEFAULT_PLANE)
name_planes_table.setdefault(data.group_name, {}).setdefault("hccl", set()).add(data.plane_id)
for data in kfc_data:
if data.plane_id == self.INVALID_PLANE:
data = data.replace(plane_id=self.DEFAULT_PLANE)
name_planes_table.setdefault(data.group_name, {}).setdefault("aicpu", set()).add(data.plane_id)
for _id, (group_name, planes) in enumerate(name_planes_table.items()):
planes_hccl = planes.get("hccl", set())
planes_aicpu = planes.get("aicpu", set())
hccl_group = self.HcclGroup(group_name, planes_hccl, _id, -1, False)
hccl_group_aicpu = self.HcclGroup(group_name, planes_aicpu, _id, -1, True)
hccl_groups[group_name] = [hccl_group, hccl_group_aicpu]
return hccl_groups
def _get_meta_data(self: any, hccl_data: List[HcclTask], kfc_data: List[HcclTask]) -> None:
self.hccl_groups = self._init_hccl_group(hccl_data, kfc_data)
self._add_hccl_bar()
index_now = 0
for groups in self.hccl_groups.values():
for group in groups:
index_now = self._add_group_threads(group, index_now, group.group_name != self.INVALID_GROUP)
def _format_hccl_data(self: any, hccl_data: list, group_type: int = 0) -> None:
_hccl_format_data = self._format_hccl_communication_data(hccl_data, group_type)
_hccl_format_op_data = self._format_hccl_op_data(group_type)
self.result.extend(TraceViewManager.time_graph_trace(
TraceViewHeaderConstant.GRPC_TIME_GRAPH_HEAD, _hccl_format_data + _hccl_format_op_data))
def _format_hccl_op_data(self, group_type: int = 0):
if not os.path.exists(PathManager.get_db_path(self.project_path, DBNameConstant.DB_HCCL_SINGLE_DEVICE)):
return []
with HcclViewModel(
self.project_path, DBNameConstant.DB_HCCL_SINGLE_DEVICE, [DBNameConstant.TABLE_HCCL_TASK_SINGLE_DEVICE,
DBNameConstant.TABLE_HCCL_OP_SINGLE_DEVICE]) as hccl_model:
if group_type == DeviceHcclSource.HCCL.value:
hccl_op_data_from_group = hccl_model.get_hccl_op_data_by_group()
hccl_op_info_from_table = hccl_model.get_hccl_op_info_from_table()
else:
hccl_op_data_from_group = hccl_model.get_kfc_op_data_by_group()
hccl_op_info_from_table = hccl_model.get_hccl_op_info_from_table(DBNameConstant.TABLE_KFC_OP)
hccl_format_op_data = [None] * len(hccl_op_data_from_group)
for idx, hccl_op in enumerate(hccl_op_data_from_group):
if hccl_op.group_name not in self.hccl_groups:
logging.error("The group name %s not exists", hccl_op.group_name)
continue
hccl_op_info = hccl_op_info_from_table.get(hccl_op.connection_id, HcclOps())
args = {
"connection_id": hccl_op.connection_id,
"model id": hccl_op.model_id,
"data_type": hccl_op_info.data_type,
"alg_type": hccl_op_info.alg_type,
"count": hccl_op_info.count
}
hccl_format_op_data[idx] = [
hccl_op.op_name, self.pid_value,
self.hccl_groups.get(hccl_op.group_name)[group_type].start_index,
InfoConfReader().trans_into_local_time(raw_timestamp=hccl_op.timestamp),
hccl_op.duration / NumberConstant.NS_TO_US, args
]
return hccl_format_op_data
def _format_hccl_communication_data(self, hccl_data: List[HcclTask], group_type: int = 0):
if not hccl_data or InfoConfReader().is_level0():
return []
_hccl_format_data = [0] * len(hccl_data)
index = 0
for _, _hccl_data in enumerate(hccl_data):
hccl_args = HCCLExport.get_hccl_arg(_hccl_data)
hccl_args["model id"] = _hccl_data.model_id
if _hccl_data.group_name not in self.hccl_groups:
logging.error("The group name %s not exists: group idx: %d", _hccl_data.group_name, group_type)
continue
if _hccl_data.plane_id == self.INVALID_PLANE:
continue
thread_id = self.hccl_groups.get(_hccl_data.group_name)[group_type].start_index + _hccl_data.plane_id + 1
_hccl_data_piece = [
_hccl_data.hccl_name, self.pid_value, thread_id,
InfoConfReader().trans_into_local_time(raw_timestamp=_hccl_data.timestamp),
_hccl_data.duration / NumberConstant.NS_TO_US, hccl_args
]
_hccl_format_data[index] = _hccl_data_piece
index += 1
return _hccl_format_data[:index]