* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/partition/engine_place.h"
#include <mutex>
#include "framework/common/op/ge_op_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "common/thread_pool/thread_pool.h"
#include "base/err_msg.h"
#include "base/err_mgr.h"
#include "api/gelib/gelib.h"
#include "graph/partition/optimizer/dynamic_data_flow_engine_reassign_pass.h"
#include "graph/partition/optimizer/hostcpu_engine_update_pass.h"
namespace ge {
namespace {
std::mutex check_support_cost_mutex;
constexpr uint32_t kDefaultThreadNum = 16;
}
Status EnginePlacer::Check() const {
if (compute_graph_ == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "compute_graph_ is nullptr, check invalid.");
GELOGE(GE_GRAPH_NULL_INPUT, "[Check][Param] compute_graph_ is nullptr.");
return FAILED;
}
std::shared_ptr<GELib> instance_ptr = GELib::GetInstance();
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
REPORT_INNER_ERR_MSG("E19999", "GELib instance is nullptr or it is not InitFlag, check invalid.");
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Get][GELib] Run enginePlacer failed, because GELib is invalid.");
return FAILED;
}
return SUCCESS;
}
Status EnginePlacer::SelectEngine(const NodePtr &node, const std::set<std::string> &exclude_engines,
bool &is_check_support_success, OpInfo &matched_op_info) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
std::string engine_name;
std::string kernel_name;
bool has_engine_attr =
AttrUtils::GetStr(op_desc, ATTR_NAME_ENGINE_NAME_FOR_LX, engine_name) && !engine_name.empty();
bool has_kernel_attr =
AttrUtils::GetStr(op_desc, ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX, kernel_name) && !kernel_name.empty();
if (op_desc->GetOpEngineName().empty()) {
if (has_kernel_attr && has_engine_attr) {
UpdateOpdescWithAttr(op_desc, engine_name, kernel_name, matched_op_info);
} else {
std::string selected_engine_name =
DNNEngineManager::GetInstance().GetDNNEngineName(node, exclude_engines, matched_op_info);
if (selected_engine_name.empty()) {
is_check_support_success = false;
REPORT_PREDEFINED_ERR_MSG(
"EZ3003", std::vector<const char *>({"opname", "optype"}),
std::vector<const char *>({op_desc->GetName().c_str(), op_desc->GetType().c_str()}));
GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "[Check][Param] Cannot find engine of op name %s type %s",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
return FAILED;
}
}
} else {
GetOpInfoByOpDesc(op_desc, engine_name, kernel_name, matched_op_info);
matched_op_info.flagAsync = false;
}
GELOGD("Assigning DNNEngine %s to node %s, op type %s", op_desc->GetOpEngineName().c_str(), node->GetName().c_str(),
node->GetType().c_str());
std::lock_guard<std::mutex> lock(mutex_);
node_atomic_engine_map_.emplace(node, op_desc->GetOpEngineName());
return SUCCESS;
}
void EnginePlacer::GetOpInfoByOpDesc(const OpDescPtr &op_desc, const std::string &attr_engine_name,
const std::string &attr_kernel_name, OpInfo &matched_op_info) {
if (attr_engine_name.empty()) {
matched_op_info.engine = op_desc->GetOpEngineName();
} else {
if (attr_engine_name != op_desc->GetOpEngineName()) {
GELOGW("[%s] of attr[%s] in op[%s] is not equal to engine[%s]", attr_engine_name.c_str(),
ATTR_NAME_ENGINE_NAME_FOR_LX.c_str(), op_desc->GetName().c_str(), op_desc->GetOpEngineName().c_str());
}
}
if (attr_kernel_name.empty()) {
matched_op_info.opKernelLib = op_desc->GetOpKernelLibName();
} else {
if (attr_kernel_name != op_desc->GetOpKernelLibName()) {
GELOGW("[%s] of attr[%s] in op[%s] is not equal to [%s]", attr_kernel_name.c_str(),
ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX.c_str(), op_desc->GetName().c_str(),
op_desc->GetOpKernelLibName().c_str());
}
}
}
void EnginePlacer::UpdateOpdescWithAttr(const OpDescPtr &op_desc, const std::string &attr_engine_name,
const std::string &attr_kernel_name, OpInfo &matched_op_info) {
op_desc->SetOpEngineName(attr_engine_name);
if (op_desc->GetOpKernelLibName().empty()) {
matched_op_info.opKernelLib = attr_kernel_name;
} else {
if (op_desc->GetOpKernelLibName() != attr_kernel_name) {
GELOGW("kernel[%s] in op[%s] is not equal to [%s] of attr[%s]", op_desc->GetOpKernelLibName().c_str(),
op_desc->GetName().c_str(), attr_kernel_name.c_str(), ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX.c_str());
}
}
}
Status EnginePlacer::RunAllSubgraphs() {
DNNEngineManager::GetInstance().InitPerformanceStatistic();
GE_ASSERT_SUCCESS(Run(false), "RunAllSubgraphs failed, graph=%s.", compute_graph_->GetName().c_str());
DNNEngineManager::GetInstance().LogCheckSupportCost();
return SUCCESS;
}
Status EnginePlacer::Run(bool direct_node_flag) {
std::lock_guard<std::mutex> lock(check_support_cost_mutex);
GELOGD("Engine placer starts.");
if (Check() != SUCCESS) {
return FAILED;
}
ThreadPool thread_pool("ge_fdengine", kDefaultThreadNum, false);
std::vector<std::future<Status>> fut_rets;
std::set<std::string> exclude_engines;
DNNEngineManager::GetExcludeEngines(exclude_engines);
const auto &context = GetThreadLocalContext();
const auto &err_msg_ctx = error_message::GetErrMgrContext();
std::map<NodePtr, OpInfo> nodes_op_infos;
for (const auto &node_ptr : compute_graph_->GetNodes(direct_node_flag)) {
GE_CHECK_NOTNULL(node_ptr);
auto fut = thread_pool.commit(
[this, node_ptr, &exclude_engines, &context, &err_msg_ctx, &nodes_op_infos]() -> Status {
error_message::SetErrMgrContext(err_msg_ctx);
bool is_check_support_success = true;
GetThreadLocalContext() = context;
OpInfo op_info;
GE_ASSERT_SUCCESS(SelectEngine(node_ptr, exclude_engines, is_check_support_success, op_info),
"Failed to select engine for [%s][%s].", node_ptr->GetNamePtr(), node_ptr->GetTypePtr());
if ((!op_info.engine.empty()) || (!op_info.opKernelLib.empty())) {
std::lock_guard<std::mutex> inner_lock(mutex_);
nodes_op_infos.emplace(node_ptr, op_info);
}
GE_ASSERT_TRUE(is_check_support_success);
return SUCCESS;
});
fut_rets.emplace_back(std::move(fut));
}
Status result = SUCCESS;
for (auto &fut : fut_rets) {
Status ret = SUCCESS;
try {
ret = fut.get();
} catch(std::runtime_error &e) {
GELOGE(INTERNAL_ERROR, "Got an exception, reason[%s].", e.what());
ret = INTERNAL_ERROR;
}
if (ret != SUCCESS) {
GELOGE(ret, "Failed to SelectEngine.");
result = ret;
}
}
DNNEngineManager::UpdateOpDescsWithOpInfos(nodes_op_infos);
GELOGD("Engine placer ends.");
return result;
}
Status EnginePlacer::AssignCompositeEngine() {
if (OpsKernelManager::GetInstance().GetCompositeEngines().empty()) {
GELOGI("No composite engine registers, ignore assign composite engine");
return SUCCESS;
}
std::vector<ComputeGraphPtr> subgraphs;
if (GraphUtils::GetSubgraphsRecursively(compute_graph_, subgraphs) != GRAPH_SUCCESS) {
GE_CHECK_NOTNULL(compute_graph_);
GELOGE(FAILED, "[Get][Subgraphs] Get subgraphs contained in graph %s failed", compute_graph_->GetName().c_str());
return FAILED;
}
for (const auto &subgraph : subgraphs) {
(void)subgraph->DelAttr(ATTR_NAME_COMPOSITE_ENGINE_NAME);
}
std::reverse(subgraphs.begin(), subgraphs.end());
subgraphs.emplace_back(compute_graph_);
for (const auto &subgraph : subgraphs) {
for (const auto &node : subgraph->GetDirectNode()) {
std::string composite_engine_name = DNNEngineManager::GetInstance().GetCompositeEngineName(node, 1);
GELOGD("Assign composite engine %s to node %s, op type %s", composite_engine_name.c_str(),
node->GetName().c_str(), node->GetType().c_str());
node_composite_engine_map_.insert(std::make_pair(node, composite_engine_name));
}
}
return SUCCESS;
}
const NodeEngineMap &EnginePlacer::GetNodeEngineMap(bool is_composite_engine_mode) const {
std::lock_guard<std::mutex> lock(mutex_);
return is_composite_engine_mode ? node_composite_engine_map_ : node_atomic_engine_map_;
}
Status EnginePlacer::RunSinglePass(const std::string &pass_name,
const std::shared_ptr<EngineReAssignPass> &pass) {
GE_CHECK_NOTNULL(pass);
Status ret;
{
std::lock_guard<std::mutex> lock(mutex_);
ret = pass->Run(compute_graph_, node_atomic_engine_map_, node_composite_engine_map_);
}
if (ret == SUCCESS) {
GELOGD("Engine reassign pass %s return SUCCESS.", pass_name.c_str());
} else if (ret == NOT_CHANGED) {
GELOGD("Engine reassign pass %s return NOT_CHANGED.", pass_name.c_str());
} else {
REPORT_INNER_ERR_MSG("E19999", "Engine reassign pass %s run failed.", pass_name.c_str());
GELOGE(ret, "[Call][Run] Engine reassign pass %s failed.", pass_name.c_str());
return ret;
}
return ret;
}
Status EnginePlacer::RunHostcpuEngineUpdatePass() {
auto hostcpu_pass = MakeShared<HostcpuEngineUpdatePass>();
return RunSinglePass("HostcpuEngineUpdatePass", hostcpu_pass);
}
Status EnginePlacer::ReAssignEngine() {
std::vector<std::pair<std::string, std::shared_ptr<EngineReAssignPass>>> passes;
passes.emplace_back(
std::make_pair("DynamicDataFlowEngineReassignPass", MakeShared<DynamicDataFlowEngineReassignPass>()));
passes.emplace_back(std::make_pair("HostcpuEngineUpdatePass", MakeShared<HostcpuEngineUpdatePass>()));
for (const auto &pass_item : passes) {
Status ret = RunSinglePass(pass_item.first, pass_item.second);
if (ret != SUCCESS && ret != NOT_CHANGED) {
return ret;
}
}
return SUCCESS;
}
}