* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "error_tracking.h"
#include "runtime/rt.h"
#include "framework/common/debug/log.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/attr_utils.h"
#include "rt_error_codes.h"
#include "base/err_msg.h"
namespace {
// Resolves the name under which `op_desc` is reported in error messages.
//
// `origin_op_name` receives op_desc->GetName() by default. When the
// ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES list attribute can be read and is not
// empty, it receives the list entries joined with ';' instead (no trailing
// separator).
//
// `op_desc` is dereferenced unconditionally, so it must not be null.
void GetOpOriginName(const ge::OpDescPtr &op_desc, std::string &origin_op_name) {
  origin_op_name = op_desc->GetName();

  std::vector<std::string> dump_origin_names;
  const bool has_origin_names =
      ge::AttrUtils::GetListStr(op_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, dump_origin_names) &&
      !dump_origin_names.empty();
  if (!has_origin_names) {
    return;
  }

  // Put the separator in front of every element except the first, so nothing
  // has to be trimmed afterwards.
  std::string joined;
  for (auto name_it = dump_origin_names.begin(); name_it != dump_origin_names.end(); ++name_it) {
    if (name_it != dump_origin_names.begin()) {
      joined += ";";
    }
    joined += *name_it;
  }
  origin_op_name = joined;
}
}
namespace ge {
// Upper bound on the number of task entries kept per model in
// graph_task_to_op_info_; passed as `max_count` to AddTaskOpdescInfo.
constexpr uint32_t kMaxGraphOpDescInfoNum = 4U * 1024U * 1024U;
// Returns the single ErrorTracking object, held in a function-local static
// that is constructed on the first call.
ErrorTracking &ErrorTracking::GetInstance() {
  static ErrorTracking tracker;
  return tracker;
}
// No construction logic of its own: members are initialised by their own
// default constructors / in-class initialisers.
ErrorTracking::ErrorTracking() = default;
void ErrorTracking::AddTaskOpdescInfo(const OpDescPtr &op, const TaskKey &key,
std::map<TaskKey, ErrorTrackingOpInfo> &map, uint32_t max_count) const {
if (op != nullptr) {
GELOGD("Add task opdesc info, opname %s, task_id=%u, stream_id=%u, thread_id %u, context_id %u.",
op->GetName().c_str(), key.GetTaskId(), key.GetStreamId(), key.GetThreadId(), key.GetContextId());
std::string origin_name;
GetOpOriginName(op, origin_name);
if (map.size() >= max_count) {
(void)map.erase(map.begin());
}
ErrorTrackingOpInfo info = {op->GetName(), origin_name};
map[key] = info;
}
}
bool ErrorTracking::GetGraphTaskOpdescInfo(const uint32_t task_id, const uint32_t stream_id, ErrorTrackingOpInfo &op_info) {
TaskKey key(task_id, stream_id);
return GetTaskOpdescInfo(key, graph_task_to_op_info_, op_info);
}
// Looks up the graph-mode op info recorded under `key` across all models.
// Returns true and fills `op_info` on a hit, false otherwise.
bool ErrorTracking::GetGraphTaskOpdescInfo(const TaskKey &key, ErrorTrackingOpInfo &op_info) {
  const bool is_found = GetTaskOpdescInfo(key, graph_task_to_op_info_, op_info);
  return is_found;
}
bool ErrorTracking::GetSingleOpTaskOpdescInfo(const uint32_t task_id, const uint32_t stream_id, ErrorTrackingOpInfo &op_info) {
TaskKey key(task_id, stream_id);
return GetTaskOpdescInfo(key, single_op_task_to_op_info_, op_info);
}
// Searches a flat task map for `key` while holding mutex_.
// On a hit the stored info is copied into `op_info` and true is returned;
// on a miss `op_info` is left unchanged and false is returned.
bool ErrorTracking::GetTaskOpdescInfo(const TaskKey &key, const std::map<TaskKey, ErrorTrackingOpInfo> &map,
                                      ErrorTrackingOpInfo &op_info) {
  const std::lock_guard<std::mutex> guard(mutex_);
  const auto entry = map.find(key);
  if (entry == map.end()) {
    return false;
  }
  op_info = entry->second;
  return true;
}
// Searches every per-model task map for `key` while holding mutex_.
// Models are visited in ascending order of their map key; the first model
// containing `key` wins. On a hit the stored info is copied into `op_info` and
// true is returned; otherwise `op_info` is left unchanged and false is returned.
bool ErrorTracking::GetTaskOpdescInfo(const TaskKey &key,
                                      const std::map<uint32_t, std::map<TaskKey, ErrorTrackingOpInfo>> &map,
                                      ErrorTrackingOpInfo &op_info)
{
  const std::lock_guard<std::mutex> guard(mutex_);
  for (const auto &model_entry : map) {
    const auto &task_map = model_entry.second;
    const auto hit = task_map.find(key);
    if (hit == task_map.end()) {
      continue;
    }
    op_info = hit->second;
    return true;
  }
  return false;
}
void ErrorTracking::SaveGraphTaskOpdescInfo(const OpDescPtr &op, const uint32_t task_id, const uint32_t stream_id,
const uint32_t model) {
TaskKey key(task_id, stream_id);
const std::lock_guard<std::mutex> lk(mutex_);
AddTaskOpdescInfo(op, key, graph_task_to_op_info_[model], kMaxGraphOpDescInfoNum);
}
void ErrorTracking::UpdateTaskId(const uint32_t old_task_id, const uint32_t new_task_id, const uint32_t stream_id, const uint32_t model) {
TaskKey old_key(old_task_id, stream_id);
TaskKey new_key(new_task_id, stream_id);
const std::lock_guard<std::mutex> lk(mutex_);
auto model_it = graph_task_to_op_info_.find(model);
if (model_it == graph_task_to_op_info_.end()) {
GELOGW("[Update][TaskId] failed, model %u not found", model);
return;
}
auto &task_map = model_it->second;
auto it = task_map.find(old_key);
if (it != task_map.end()) {
const std::string opname = it->second.op_name;
GELOGD("Update task id, old: %u -> new: %u, stream_id: %u, model: %u, opname: %s",
old_task_id, new_task_id, stream_id, model, opname.c_str());
ErrorTrackingOpInfo op_info = it->second;
(void)task_map.erase(it);
task_map[new_key] = op_info;
} else {
GELOGW("Failed to update task id, old task id %u not found in model %u", old_task_id, model);
}
}
// Records the op info of `op` under `key` in the task map of `model`, under
// mutex_. The per-model map is created on demand via operator[].
void ErrorTracking::SaveGraphTaskOpdescInfo(const OpDescPtr &op, const TaskKey &key, const uint32_t model) {
  const std::lock_guard<std::mutex> guard(mutex_);
  auto &model_tasks = graph_task_to_op_info_[model];
  AddTaskOpdescInfo(op, key, model_tasks, kMaxGraphOpDescInfoNum);
}
void ErrorTracking::SaveSingleOpTaskOpdescInfo(const OpDescPtr &op, const uint32_t task_id, const uint32_t stream_id) {
TaskKey key(task_id, stream_id);
const std::lock_guard<std::mutex> lk(mutex_);
AddTaskOpdescInfo(op, key, single_op_task_to_op_info_, single_op_max_count_);
}
// Drops every task entry recorded for `model`, under mutex_.
// Does nothing when the model has no entries.
void ErrorTracking::ClearUnloadedModelOpdescInfo(const uint32_t model) {
  const std::lock_guard<std::mutex> guard(mutex_);
  // Erasing by key is a no-op for an unknown model, same as find + erase.
  (void)graph_task_to_op_info_.erase(model);
}
// Task-failure callback: maps the failed task reported in `exception_data`
// back to the op that was recorded for it and reports the failure.
//
// - A null `exception_data` is ignored.
// - The return codes ACL_ERROR_RT_AICORE_OVER_FLOW, ACL_ERROR_RT_AIVEC_OVER_FLOW
//   and ACL_ERROR_RT_OVER_FLOW are ignored without any logging.
// - For RT_EXCEPTION_FFTS_PLUS exceptions, the lookup key also carries the
//   context id and thread id, and only the graph task maps are searched.
// - Otherwise the graph task maps are searched first, then the single-op map.
// - Nothing is reported when no matching op is found.
void ErrorTrackingCallback(rtExceptionInfo *const exception_data) {
  if (exception_data == nullptr) {
    return;
  }

  const bool is_overflow = (exception_data->retcode == ACL_ERROR_RT_AICORE_OVER_FLOW) ||
                           (exception_data->retcode == ACL_ERROR_RT_AIVEC_OVER_FLOW) ||
                           (exception_data->retcode == ACL_ERROR_RT_OVER_FLOW);
  if (is_overflow) {
    return;
  }

  GELOGI("ErrorTracking callbak in, task_id %u, stream_id %u.", exception_data->taskid, exception_data->streamid);

  ErrorTracking &tracker = ErrorTracking::GetInstance();
  ErrorTrackingOpInfo op_info;
  bool is_found = false;
  if (exception_data->expandInfo.type == RT_EXCEPTION_FFTS_PLUS) {
    const uint32_t context_id = static_cast<uint32_t>(exception_data->expandInfo.u.fftsPlusInfo.contextId);
    const uint32_t thread_id = static_cast<uint32_t>(exception_data->expandInfo.u.fftsPlusInfo.threadId);
    TaskKey ffts_key(exception_data->taskid, exception_data->streamid, context_id, thread_id);
    is_found = tracker.GetGraphTaskOpdescInfo(ffts_key, op_info);
  } else {
    // Graph-mode records take precedence; single-op records are the fallback.
    is_found = tracker.GetGraphTaskOpdescInfo(exception_data->taskid, exception_data->streamid, op_info) ||
               tracker.GetSingleOpTaskOpdescInfo(exception_data->taskid, exception_data->streamid, op_info);
  }

  if (!is_found) {
    return;
  }

  GELOGE(FAILED, "Error happened, origin_op_name [%s], op_name [%s], task_id %u, stream_id %u.",
         op_info.op_origin_name.c_str(), op_info.op_name.c_str(), exception_data->taskid, exception_data->streamid);
  REPORT_INNER_ERR_MSG("E18888", "Op execute failed. origin_op_name [%s], op_name [%s], "
                       "error_info: task_id %u, stream_id %u, tid %u, device_id %u, retcode 0x%x",
                       op_info.op_origin_name.c_str(), op_info.op_name.c_str(), exception_data->taskid,
                       exception_data->streamid, exception_data->tid, exception_data->deviceid,
                       exception_data->retcode);
}
// Registers ErrorTrackingCallback as a task-failure callback with the runtime
// under the module name "GeErrorTracking".
// Returns 0 when execution reaches the end of the function.
// NOTE(review): GE_CHK_RT_RET is defined outside this file; presumably it
// returns early with an error status when the runtime call fails - confirm
// against the macro definition.
uint32_t RegErrorTrackingCallBack() {
GE_CHK_RT_RET(rtRegTaskFailCallbackByModule("GeErrorTracking", &ErrorTrackingCallback));
return 0;
}
}