* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "fuzz_compile_bin_selector.h"
#include "common/profiling_definitions.h"
#include "framework/common/debug/log.h"
#include "graph/def_types.h"
#include "api/gelib/gelib.h"
#include "engines/manager/opskernel_manager/ops_kernel_builder_manager.h"
#include "common/opskernel/ops_kernel_info_types.h"
namespace ge {
namespace hybrid {
namespace {
// Placeholder address installed as both weightMemBase and dataMemBase of the RunContext in DoGenerateTask.
constexpr uintptr_t kWeightBase = 0x10000000U;
// Placeholder value written into every input offset, output offset and workspace entry before task generation.
constexpr uintptr_t kMemBase = 0x20000000U;
// Placeholder size reported as weightMemSize and dataMemSize of the RunContext in DoGenerateTask.
constexpr uint64_t kFakeSize = 0x10000000UL;
// Name used to look up the ops kernel info store in FuzzCompileBinSelector::Initialize.
const std::string kKernelName("AIcoreEngine");
// Snapshots the shape and origin shape of every non-null input and output tensor descriptor of
// `node_ptr` into `in_out_shapes`, keyed by tensor index.
// Entries are added with emplace, so an index that is already present keeps its earlier value.
// Indexes whose tensor descriptor is null are skipped.
void BackUpNodeShapes(const ge::NodePtr &node_ptr, NodeInOutShape &in_out_shapes) {
  const auto op_desc = node_ptr->GetOpDesc();

  const auto input_count = op_desc->GetAllInputsSize();
  for (size_t idx = 0UL; idx < input_count; ++idx) {
    const auto desc = op_desc->MutableInputDesc(idx);
    if (desc != nullptr) {
      in_out_shapes.in_ori_shapes.emplace(idx, desc->GetOriginShape());
      in_out_shapes.in_shapes.emplace(idx, desc->GetShape());
    }
  }

  const auto output_count = op_desc->GetAllOutputsDescSize();
  for (size_t idx = 0UL; idx < output_count; ++idx) {
    const auto desc = op_desc->MutableOutputDesc(idx);
    if (desc != nullptr) {
      in_out_shapes.out_ori_shapes.emplace(idx, desc->GetOriginShape());
      in_out_shapes.out_shapes.emplace(idx, desc->GetShape());
    }
  }
}
// Writes the shapes recorded in `in_out_shapes` back onto the input and output tensor descriptors
// of `node_ptr`.
// A tensor is only touched when a backup entry exists for its index: lookups use find() rather
// than operator[], so a missing index neither inserts a default-constructed shape into the backup
// nor overwrites the tensor's current shape with it.
// `in_out_shapes` is not modified; the non-const reference is kept for interface compatibility.
void SetNodeShapes(const ge::NodePtr &node_ptr, NodeInOutShape &in_out_shapes) {
  const auto op_desc = node_ptr->GetOpDesc();

  for (size_t in_index = 0UL; in_index < op_desc->GetAllInputsSize(); ++in_index) {
    const auto in_tensor_desc = op_desc->MutableInputDesc(in_index);
    if (in_tensor_desc == nullptr) {
      continue;
    }
    const auto ori_shape_it = in_out_shapes.in_ori_shapes.find(in_index);
    if (ori_shape_it != in_out_shapes.in_ori_shapes.end()) {
      in_tensor_desc->SetOriginShape(ori_shape_it->second);
    }
    const auto shape_it = in_out_shapes.in_shapes.find(in_index);
    if (shape_it != in_out_shapes.in_shapes.end()) {
      in_tensor_desc->SetShape(shape_it->second);
    }
  }

  for (size_t out_index = 0UL; out_index < op_desc->GetAllOutputsDescSize(); ++out_index) {
    const auto out_tensor_desc = op_desc->MutableOutputDesc(out_index);
    if (out_tensor_desc == nullptr) {
      continue;
    }
    const auto ori_shape_it = in_out_shapes.out_ori_shapes.find(out_index);
    if (ori_shape_it != in_out_shapes.out_ori_shapes.end()) {
      out_tensor_desc->SetOriginShape(ori_shape_it->second);
    }
    const auto shape_it = in_out_shapes.out_shapes.find(out_index);
    if (shape_it != in_out_shapes.out_shapes.end()) {
      out_tensor_desc->SetShape(shape_it->second);
    }
  }
}
}
// Resolves the ops kernel info store registered under kKernelName and caches it in
// aicore_kernel_store_.
// Returns GE_CLI_GE_NOT_INITIALIZED when GELib reports that it has not been initialized; the
// GE_CHECK_NOTNULL guards return early when the GELib instance or the looked-up store is null.
// Returns SUCCESS once the store has been cached.
Status FuzzCompileBinSelector::Initialize() {
  const auto ge_lib = GELib::GetInstance();
  GE_CHECK_NOTNULL(ge_lib);

  const bool ge_initialized = ge_lib->InitFlag();
  if (!ge_initialized) {
    GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge_lib is uninitialized, failed.");
    return GE_CLI_GE_NOT_INITIALIZED;
  }

  const auto &ops_kernel_manager = ge_lib->OpsKernelManagerObj();
  aicore_kernel_store_ = ops_kernel_manager.GetOpsKernelInfoStore(kKernelName);
  GE_CHECK_NOTNULL(aicore_kernel_store_);
  return SUCCESS;
}
// Returns the compile cache item to use for `node`, compiling and registering a new kernel binary
// when the cache has no match.
//
// Flow:
//   1. If `ge_context` is non-null it is copied into the calling thread's GE thread-local context.
//   2. The compile cache (nccm_) is queried; a hit is returned directly.
//   3. On a miss the node is fuzz-compiled. The shapes present on the node before compilation
//      ("execution shapes") and after compilation ("generalize shapes") are both snapshotted, and
//      the execution shapes are put back before running-param calculation and task generation.
//   4. The last generated task def is used to register the kernel binary, a cache item is built,
//      and it is added to the cache while the node temporarily carries the generalize shapes.
//   5. The execution shapes are restored before returning.
//
// `task_defs` receives the task defs produced by task generation.
// Returns nullptr when `node` or its OpDesc is null, when any step fails, or when task generation
// reports success but leaves `task_defs` empty.
NodeCompileCacheItem *FuzzCompileBinSelector::SelectBin(const NodePtr &node, const GEThreadLocalContext *ge_context,
                                                        std::vector<domi::TaskDef> &task_defs) {
  if (node == nullptr) {
    GELOGE(FAILED, "Node is nullptr.");
    return nullptr;
  }
  if (ge_context != nullptr) {
    GetThreadLocalContext() = *ge_context;
  }

  PROFILING_START(-1, profiling::kFindCompileCache);
  NodeCompileCacheItem *cci = nccm_->FindCompileCache(node);
  PROFILING_END(-1, profiling::kFindCompileCache);
  if (cci != nullptr) {
    GELOGD("Found cci %lu of node %s.", cci->GetCacheItemId(), node->GetName().c_str());
    return cci;
  }

  // Everything below dereferences the OpDesc (shape backup, task generation, bin registration),
  // so reject a node without one up front. The shared pointer is held for the rest of the call.
  const auto op_desc_ptr = node->GetOpDesc();
  if (op_desc_ptr == nullptr) {
    GELOGE(FAILED, "OpDesc of node is nullptr.");
    return nullptr;
  }

  NodeInOutShape execution_shapes{};
  BackUpNodeShapes(node, execution_shapes);
  auto ret = DoCompileOp(node);
  // Snapshot whatever shapes compilation left on the node, then put the execution shapes back
  // regardless of the compile result so the node is never left with the compile-time shapes.
  NodeInOutShape generalize_shapes{};
  BackUpNodeShapes(node, generalize_shapes);
  SetNodeShapes(node, execution_shapes);
  if (ret != SUCCESS) {
    return nullptr;
  }

  ret = OpsKernelBuilderManager::Instance().CalcOpRunningParam(*node);
  if (ret != SUCCESS) {
    return nullptr;
  }

  ret = DoGenerateTask(node, task_defs);
  if (ret != SUCCESS) {
    GELOGE(FAILED, "Failed to execute generate task, node = %s", node->GetName().c_str());
    return nullptr;
  }
  // back() on an empty vector is undefined behaviour; treat "no task generated" as a failure.
  if (task_defs.empty()) {
    GELOGE(FAILED, "No task def is generated, node = %s", node->GetName().c_str());
    return nullptr;
  }
  GELOGD("Successfully generate task, node is %s", node->GetName().c_str());

  // DoRegisterBin only needs read access, so bind by reference instead of deep-copying the
  // OpDesc and the TaskDef.
  const auto &task_def = task_defs.back();
  void *handle = nullptr;
  KernelLaunchBinType bin_type = KernelLaunchBinType::kBinTypeEnd;
  ret = DoRegisterBin(*op_desc_ptr, task_def, bin_type, handle);
  if (ret != SUCCESS) {
    GELOGE(FAILED, "Register bin failed.");
    return nullptr;
  }

  NodeCompileCacheItem node_cci;
  if (NodeCompileCacheItem::Build(bin_type, node, handle, node_cci) != SUCCESS) {
    GELOGI("Fail to build compile cache item of node %s.", node->GetName().c_str());
    return nullptr;
  }

  // The cache entry is added while the node carries the generalize shapes; the execution shapes
  // are restored immediately afterwards.
  SetNodeShapes(node, generalize_shapes);
  PROFILING_START(-1, profiling::kAddCompileCache);
  const auto ret_cci = nccm_->AddCompileCache(node, node_cci);
  PROFILING_END(-1, profiling::kAddCompileCache);
  SetNodeShapes(node, execution_shapes);
  return ret_cci;
}
// Fuzz-compiles the single node `node_ptr` through aicore_kernel_store_ and returns the store's
// status unchanged. A non-SUCCESS status is only logged at info level here; the caller decides
// how to react to it.
Status FuzzCompileBinSelector::DoCompileOp(const NodePtr &node_ptr) const {
  // FuzzCompileOp works on a node list, so wrap the node in a one-element vector.
  std::vector<NodePtr> nodes_to_compile(1U, node_ptr);

  PROFILING_START(-1, profiling::kFuzzCompileOp);
  const auto compile_status = aicore_kernel_store_->FuzzCompileOp(nodes_to_compile);
  PROFILING_END(-1, profiling::kFuzzCompileOp);

  const bool compile_ok = (compile_status == SUCCESS);
  if (!compile_ok) {
    GELOGI("Fuzz compile node %s, ret is not SUCCESS. Execute the origin graph", node_ptr->GetName().c_str());
  }
  return compile_status;
}
// Registers the kernel binary attached to `op_desc` and reports how it was registered.
//
// The registration path is chosen from the task type of `task_def`:
//   - MODEL_TASK_ALL_KERNEL: `bin_type` becomes kWithHandle and the binary is registered with a
//     handle.
//   - any other type: `bin_type` becomes kStubFunc and the binary is registered under the stub
//     function name taken from the task's kernel definition.
// `bin_type` is written before the registration call, so it is set even if registration fails.
// `handle` is passed to the registration call as its output argument.
// Returns the status of the registration call.
Status FuzzCompileBinSelector::DoRegisterBin(const OpDesc &op_desc, const domi::TaskDef &task_def,
                                             KernelLaunchBinType &bin_type, void *&handle) const {
  PROFILING_SCOPE(-1, profiling::kRegisterBin);
  AttrNameOfBinOnOp attr_names_of_bin = {OP_EXTATTR_NAME_TBE_KERNEL, TVM_ATTR_NAME_METADATA, "_kernelname",
                                         TVM_ATTR_NAME_MAGIC};

  const bool is_all_kernel_task =
      (static_cast<ModelTaskType>(task_def.type()) == ModelTaskType::MODEL_TASK_ALL_KERNEL);
  if (is_all_kernel_task) {
    bin_type = KernelLaunchBinType::kWithHandle;
    return BinRegisterUtils::RegisterBinWithHandle(op_desc, attr_names_of_bin, handle);
  }

  bin_type = KernelLaunchBinType::kStubFunc;
  return BinRegisterUtils::RegisterBin(op_desc, task_def.kernel().stub_func(), attr_names_of_bin, handle);
}
// Generates the task defs for `node_ptr` into `task_defs` using placeholder memory information.
//
// Before generation, every input offset, output offset and workspace entry of the node's OpDesc
// is overwritten with kMemBase, and a RunContext is filled with placeholder base addresses and
// sizes. The GenerateTask call itself is serialized through gen_task_mutex_.
// Returns the status of OpsKernelBuilderManager::GenerateTask.
Status FuzzCompileBinSelector::DoGenerateTask(const NodePtr &node_ptr, std::vector<domi::TaskDef> &task_defs) {
  PROFILING_SCOPE(-1, profiling::kGenTask);

  // NOTE(review): dataMemBase is set from kWeightBase, the same value as weightMemBase, while
  // kMemBase is only used for the offsets below - confirm this is intended.
  RunContext run_ctx;
  run_ctx.weightMemBase = PtrToPtr<void, uint8_t>(ValueToPtr(kWeightBase));
  run_ctx.weightMemSize = kFakeSize;
  run_ctx.dataMemBase = PtrToPtr<void, uint8_t>(ValueToPtr(kWeightBase));
  run_ctx.dataMemSize = kFakeSize;

  auto op_desc = node_ptr->GetOpDesc();

  std::vector<int64_t> fake_input_offsets(op_desc->GetInputsSize(), kMemBase);
  op_desc->SetInputOffset(fake_input_offsets);

  std::vector<int64_t> fake_output_offsets(op_desc->GetOutputsSize(), kMemBase);
  op_desc->SetOutputOffset(fake_output_offsets);

  std::vector<int64_t> fake_workspaces(op_desc->GetWorkspaceBytes().size(), kMemBase);
  op_desc->SetWorkspace(std::move(fake_workspaces));

  // Only the GenerateTask call is made while holding gen_task_mutex_.
  const std::lock_guard<std::mutex> gen_task_guard(gen_task_mutex_);
  return OpsKernelBuilderManager::Instance().GenerateTask(*node_ptr, run_ctx, task_defs);
}
}
}