* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "cann_tracing_profiler.h"
#include <unordered_map>
#include "runtime/model_v2_executor.h"
#include "core/builder/node_types.h"
#include "graph/def_types.h"
#include "graph/debug/ge_attr_define.h"
#include "lowering/placement/placed_lowering_result.h"
#include "lowering/pass_changed_kernels_info.h"
#include "graph/load/model_manager/davinci_model.h"
#include "subscriber/subscriber_utils.h"
#include "engine/aicore/launch_kernel/ai_core_launch_kernel.h"
#include "runtime/subscriber/global_profiler.h"
namespace gert {
namespace {
// Tag id reported when a node whose trace attrs have is_fp set starts executing.
constexpr uint16_t kFpBeginLogId{2U};
// Tag id reported when a node whose trace attrs have is_bp set, but no positive start_log_id, finishes.
constexpr uint16_t kBpEndLogId{3U};

// Converts an ExecuteArgIndex into a position inside an execute-input array holding total_num entries.
//
// The array is treated as (total_num - ExecuteArgIndex::kNum) leading entries followed by the
// ExecuteArgIndex slots; arg_index is mapped to the offset (-arg_index - 1) past those leading entries,
// i.e. -1 -> 0, -2 -> 1, and so on.
// NOTE(review): presumably the ExecuteArgIndex enumerators other than kNum are negative and
// total_num >= kNum; otherwise the unsigned arithmetic wraps around - confirm against the enum and callers.
size_t CalcArgIndex(size_t total_num, ExecuteArgIndex arg_index) {
  const auto reserved_num = static_cast<size_t>(ExecuteArgIndex::kNum);
  const auto raw_index = static_cast<int64_t>(arg_index);
  const auto offset_past_tensors = static_cast<size_t>(-raw_index - 1);
  return (total_num - reserved_num) + offset_past_tensors;
}
}  // namespace
void CannTracingProfiler::Init() {
const auto execution_data =
static_cast<const ExecutionData *>(extend_info_.executor->GetExeGraphExecutor(kMainExeGraph)->GetExecutionData());
if (execution_data == nullptr) {
GELOGW("Execution data is empty, do not init tracing profiler.");
return;
}
GELOGD("Training trace init, execute node num = %zu", execution_data->base_ed.node_num);
for (size_t i = 0UL; i < execution_data->base_ed.node_num; ++i) {
const auto node = execution_data->base_ed.nodes[i];
auto kernel_context = reinterpret_cast<KernelContext *>(&node->context);
const auto kernel_extend_info = static_cast<const KernelExtendInfo *>(kernel_context->GetKernelExtend());
if (kernel_extend_info == nullptr) {
GELOGW("Kernel extend info is nullptr.");
return;
}
if (ProfilerRegistry::GetInstance().IsProfLaunchType(kernel_extend_info->GetKernelType())) {
const auto compute_node_info = static_cast<const ComputeNodeInfo *>(kernel_context->GetComputeNodeExtend());
if (compute_node_info == nullptr) {
continue;
}
const auto node_name = compute_node_info->GetNodeName();
const auto iter = extend_info_.node_names_to_attrs.find(node_name);
if (iter != extend_info_.node_names_to_attrs.cend()) {
node_ids_to_attrs_[node->node_id] = iter->second;
}
}
}
}
// Takes ownership of the subscriber extend info and, when an executor is attached, builds the
// node-id -> trace-attrs table via Init().
CannTracingProfiler::CannTracingProfiler(SubscriberExtendInfo extend_info)
    : BaseExecutorProfiler(), extend_info_(std::move(extend_info)) {
  if (extend_info_.executor == nullptr) {
    // Without an executor there is no execution data to scan.
    return;
  }
  Init();
}
// Issues one profiling trace record (aclrtProfTrace) carrying the current iteration number, the model id
// and tag_id.
//
// The record is launched on extend_info_.stream unless the node has an entry in node_ids_to_attrs_; in
// that case the stream is taken from the executor's stream array at the entry's logic_stream_id.
// The stream array is resolved from the main exe graph's execution inputs on first use and cached in
// rt_streams_.
//
// @param tag_id  tag written into the trace record.
// @param node    execute node the record belongs to; must not be null.
// @return ge::SUCCESS once the trace task has been launched; a failure status if any pointer check or the
//         runtime call fails.
ge::Status CannTracingProfiler::ReportTraceInfo(uint16_t tag_id, const Node *node) {
  GE_ASSERT_NOTNULL(node);
  if (rt_streams_ == nullptr) {
    GE_ASSERT_NOTNULL(extend_info_.executor);
    const auto exe_graph_executor = extend_info_.executor->GetExeGraphExecutor(kMainExeGraph);
    GE_ASSERT_NOTNULL(exe_graph_executor);
    const auto execution_data = static_cast<const ExecutionData *>(exe_graph_executor->GetExecutionData());
    GE_ASSERT_NOTNULL(execution_data);
    const auto stream_idx = CalcArgIndex(execution_data->base_ed.input_num, ExecuteArgIndex::kStream);
    const auto stream_chain = reinterpret_cast<Chain *>(execution_data->base_ed.input_values[stream_idx]);
    GE_ASSERT_NOTNULL(stream_chain);
    rt_streams_ = stream_chain->GetValue<ContinuousVector *>();
  }

  aclrtStream cur_stream = extend_info_.stream;
  int64_t logic_stream_id = 0;
  // Single lookup; the attrs are read in place instead of being copied out of the map.
  const auto attrs_iter = node_ids_to_attrs_.find(node->node_id);
  if (attrs_iter != node_ids_to_attrs_.end()) {
    logic_stream_id = attrs_iter->second.logic_stream_id;
    GE_ASSERT_NOTNULL(rt_streams_);
    // NOTE(review): logic_stream_id is used as an index without a range check - presumably it is always
    // a valid position in the stream array; confirm against the code that fills the trace attrs.
    const auto stream_array = reinterpret_cast<aclrtStream *>(rt_streams_->MutableData());
    cur_stream = *(stream_array + static_cast<size_t>(logic_stream_id));
  }

  rtProfTraceUserData userData = {
      .id = iteration_num_,
      .model_id = static_cast<uint64_t>(extend_info_.model_id),
      .tag_id = tag_id
  };
  GE_ASSERT_RT_OK(aclrtProfTrace(&userData, sizeof(rtProfTraceUserData), cur_stream));
  // Arguments are cast so that they match the %llu / %u / %lld conversions exactly.
  GELOGD(
      "Profiling Step Info TraceTask execute async success, index_id = %llu, model_id = %u, tag_id = %u, "
      "logic_stream_id= %lld",
      static_cast<unsigned long long>(iteration_num_), extend_info_.model_id, static_cast<uint32_t>(tag_id),
      static_cast<long long>(logic_stream_id));
  return ge::SUCCESS;
}
// Emits the trace records due when a node starts executing.
//
// Nodes without an entry in node_ids_to_attrs_ are ignored. For a known node, a kFpBeginLogId record is
// reported when its attrs have is_fp set, followed by a record tagged with start_log_id when that id is
// positive.
//
// @param node  execute node that is about to run; must not be null.
// @return ge::SUCCESS when nothing had to be reported or all reports succeeded, otherwise the failure
//         produced by the assertion macros.
ge::Status CannTracingProfiler::ReportStartTraceInfo(const Node *node) {
  GE_ASSERT_NOTNULL(node);
  // One lookup per event; the attrs are accessed by reference rather than copied.
  const auto attrs_iter = node_ids_to_attrs_.find(node->node_id);
  if (attrs_iter == node_ids_to_attrs_.end()) {
    return ge::SUCCESS;
  }
  const auto &trace_info = attrs_iter->second;
  if (trace_info.is_fp) {
    GE_ASSERT_SUCCESS(ReportTraceInfo(kFpBeginLogId, node));
  }
  if (trace_info.start_log_id > 0LL) {
    GE_ASSERT_SUCCESS(ReportTraceInfo(static_cast<uint16_t>(trace_info.start_log_id), node));
  }
  return ge::SUCCESS;
}
// Emits the trace record due when a node finishes executing.
//
// Nodes without an entry in node_ids_to_attrs_ are ignored. For a known node:
//   - a positive start_log_id yields a record tagged start_log_id + 1 (whether or not is_bp is set);
//   - otherwise, if is_bp is set, a kBpEndLogId record is reported;
//   - otherwise nothing is reported.
//
// @param node  execute node that has just finished; must not be null.
// @return ge::SUCCESS when nothing had to be reported or the report succeeded, otherwise the failure
//         produced by the assertion macros.
ge::Status CannTracingProfiler::ReportEndTraceInfo(const Node *node) {
  GE_ASSERT_NOTNULL(node);
  // One lookup per event; the attrs are accessed by reference rather than copied.
  const auto attrs_iter = node_ids_to_attrs_.find(node->node_id);
  if (attrs_iter == node_ids_to_attrs_.end()) {
    return ge::SUCCESS;
  }
  const auto &trace_info = attrs_iter->second;
  if (trace_info.start_log_id > 0LL) {
    GE_ASSERT_SUCCESS(ReportTraceInfo(static_cast<uint16_t>(trace_info.start_log_id + 1LL), node));
  } else if (trace_info.is_bp) {
    GE_ASSERT_SUCCESS(ReportTraceInfo(kBpEndLogId, node));
  }
  return ge::SUCCESS;
}
void CannTracingProfiler::OnExecuteEvent(int32_t sub_exe_graph_type, CannTracingProfiler *profiler, ExecutorEvent event,
const void *node, KernelStatus result) {
(void)result;
(void)sub_exe_graph_type;
if (profiler == nullptr) {
GELOGW("Cann tracing profiler is nullptr.");
return;
}
if (event == kExecuteStart) {
(void)profiler->ReportStartTraceInfo(static_cast<const Node *>(node));
}
if (event == kExecuteEnd) {
(void)profiler->ReportEndTraceInfo(static_cast<const Node *>(node));
}
if (event == kModelEnd) {
profiler->IncreaseIterationNum();
}
}
}