* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "executor/event_handler.h"
#include <string>
#include <future>
#include "executor/executor_context.h"
#include "common/utils/rts_api_utils.h"
#include "mmpa/mmpa_api.h"
#include "common/thread_pool/thread_pool.h"
#include "cpu_tasks.h"
namespace ge {
namespace {
constexpr uint32_t kClearDavinciModelThreadNum = 1U;
constexpr uint32_t kSingleThreadNum = 1U;
constexpr int32_t kClearTypeStop = 1;
constexpr int32_t kClearTypeClear = 2;
constexpr int32_t kMaxParseModelThreadPoolSize = 8;
}
Status EventHandler::Initialize() {
context_ = MakeUnique<ExecutorContext>();
GE_CHECK_NOTNULL(context_);
context_->SetBaseDir(base_dir_);
GE_CHK_STATUS_RET_NOLOG(context_->Initialize());
return SUCCESS;
}
void EventHandler::Finalize() {
context_->Finalize();
context_.reset();
}
void EventHandler::SetBaseDir(const std::string &base_dir) {
base_dir_ = base_dir;
}
void EventHandler::HandleEvent(deployer::ExecutorRequest &request,
deployer::ExecutorResponse &response) {
if (request.has_batch_load_model_message()) {
HandleBatchLoadRequest(request, response);
} else if (request.has_unload_model_message()) {
HandleUnloadRequest(request, response);
} else if (request.has_sync_var_manager_message()) {
HandleSyncVarManagerRequest(request, response);
} else if (request.has_clear_model_message()) {
HandleClearModelRequest(request, response);
} else if (request.has_exception_notify_request()) {
HandleDataFlowExceptionNotifyRequest(request, response);
} else if (request.has_update_prof_message()) {
HandleProfInfo(request, response);
} else {
GELOGE(PARAM_INVALID, "[Handle][Event] failed, request has no content");
response.set_error_code(UNSUPPORTED);
response.set_error_message("request has no content");
}
}
void EventHandler::HandleUnloadRequest(deployer::ExecutorRequest &request, deployer::ExecutorResponse &response) const {
auto &unload_req = request.unload_model_message();
auto root_model_id = unload_req.model_id();
std::map<uint32_t, std::unique_ptr<ExecutorContext::ModelHandle>> *submodel_map = nullptr;
if (context_->GetModel(root_model_id, submodel_map) != SUCCESS) {
response.set_error_code(FAILED);
response.set_error_message("Failed to get model");
return;
}
GEEVENT("[Batch][UnloadModel] begin, model size = %zu.", submodel_map->size());
std::mutex failed_mu;
std::vector<uint32_t> failed;
std::vector<std::thread> threads;
aclrtContext ctx = nullptr;
(void) aclrtGetCurrentContext(&ctx);
const auto &thread_local_ctx = GetThreadLocalContext();
for (auto it = submodel_map->begin(); it != submodel_map->end(); ++it) {
auto submodel_id = it->first;
ExecutorContext::ModelHandle *handle = it->second.get();
try {
threads.emplace_back([handle, submodel_id, &failed, &failed_mu, ctx, &thread_local_ctx]() {
(void) aclrtSetCurrentContext(ctx);
GetThreadLocalContext() = thread_local_ctx;
if (handle->UnloadModel() != SUCCESS) {
std::lock_guard<std::mutex> lk(failed_mu);
failed.emplace_back(submodel_id);
}
});
} catch (std::exception &e) {
GELOGW("Failed to unload model, submodel_id = %u, exception = %s", submodel_id, e.what());
std::lock_guard<std::mutex> lk(failed_mu);
failed.emplace_back(submodel_id);
}
}
for (auto &th : threads) {
if (th.joinable()) {
th.join();
}
}
for (auto it = submodel_map->begin(); it != submodel_map->end();) {
it = submodel_map->erase(it);
}
ExecutorContext().LocalContext().RemoveLocalModel(root_model_id);
if (!failed.empty()) {
response.set_error_code(FAILED);
const std::string error_msg = "Unload model failed, failed model id = " + ToString(failed);
response.set_error_message(error_msg);
GELOGW("%s", error_msg.c_str());
return;
}
response.set_error_code(SUCCESS);
GEEVENT("[Batch][UnloadModel] success.");
}
void EventHandler::HandleSyncVarManagerRequest(deployer::ExecutorRequest &request,
deployer::ExecutorResponse &response) {
GELOGD("[Handle][Init VarManager] begin.");
if (context_->SyncSharedVarManager(request) != SUCCESS) {
response.set_error_code(FAILED);
response.set_error_message("Failed to init VarManger");
return;
}
response.set_error_code(SUCCESS);
GELOGD("[Handle][Init VarManager] success.");
}
void EventHandler::HandleBatchLoadRequest(deployer::ExecutorRequest &request,
deployer::ExecutorResponse &response) {
if (BatchLoadModels(request) != SUCCESS) {
response.set_error_code(FAILED);
response.set_error_message("Failed to batch load models");
return;
}
response.set_error_code(SUCCESS);
GELOGD("[Handle][BatchLoadModel] success.");
}
Status EventHandler::BatchParseAndLoadModels(const deployer::ExecutorRequest_BatchLoadModelMessage &model_messages) {
int32_t pool_size = model_messages.models_size() > kMaxParseModelThreadPoolSize ?
kMaxParseModelThreadPoolSize :
model_messages.models_size();
ThreadPool pool("ge_dpl_prsm", static_cast<uint32_t>(pool_size), false);
std::vector<std::future<Status>> fut_rets;
for (const auto &load_model_req : model_messages.models()) {
auto fut = pool.commit([this, &load_model_req]() -> Status {
GE_CHK_STATUS_RET(context_->ParseModel(load_model_req), "Failed to parse model message.");
GE_CHK_STATUS_RET(context_->AttachQueues(load_model_req), "Failed to attach model queues.");
return SUCCESS;
});
fut_rets.emplace_back(std::move(fut));
}
for (int32_t i = 0; i < model_messages.models_size(); ++i) {
GE_CHK_STATUS_RET(fut_rets[i].get(), "Failed to parse model");
const auto &load_model_req = model_messages.models(i);
GE_CHK_STATUS_RET(context_->LoadModel(load_model_req), "Failed to process load model message.");
}
return SUCCESS;
}
Status EventHandler::BatchLoadModels(deployer::ExecutorRequest &request) {
auto &message = request.batch_load_model_message();
GEEVENT("[Batch][LoadModel] begin, model size = %d.", message.models_size());
context_->UpdateOptions(message.options());
if (!once_inited_) {
GE_CHK_STATUS(context_->SetOpTimeout(), "Failed to set op timeout");
GE_CHK_STATUS(context_->SetDeviceSatMode(), "Failed to set device sat mode");
once_inited_ = true;
}
GE_CHK_STATUS_RET(BatchParseAndLoadModels(message), "Failed to parse and load models");
GEEVENT("[Batch][LoadModel] success, model size = %d.", message.models_size());
return SUCCESS;
}
void EventHandler::HandleClearModelRequest(deployer::ExecutorRequest &request,
deployer::ExecutorResponse &response) {
GELOGD("[Handle][Clear Model] begin.");
const auto &clear_model_req = request.clear_model_message();
const auto clear_msg_type = clear_model_req.clear_msg_type();
if ((clear_msg_type != kClearTypeStop) && (clear_msg_type != kClearTypeClear)) {
GELOGE(FAILED, "Failed to clear model, invalid type: %d", clear_msg_type);
const std::string err_msg = "Failed to clear model, invalid type: " +
std::to_string(clear_msg_type);
response.set_error_code(FAILED);
response.set_error_message(err_msg.c_str());
return;
}
const auto root_model_id = clear_model_req.model_id();
std::map<uint32_t, std::unique_ptr<ExecutorContext::ModelHandle>> *submodel_map = nullptr;
if (context_->GetModel(root_model_id, submodel_map) != SUCCESS) {
GELOGE(FAILED, "Failed to get model: %u", root_model_id);
const std::string err_msg = "Failed to get model, model id: " +
std::to_string(root_model_id);
response.set_error_code(FAILED);
response.set_error_message(err_msg.c_str());
return;
}
std::vector<uint32_t> davinci_model_runtime_ids;
std::vector<ExecutorContext::ModelHandle *> dynamic_model_handles;
for (auto it = submodel_map->begin(); it != submodel_map->end(); it++) {
ExecutorContext::ModelHandle *handle = it->second.get();
if (handle->GetModelRuntimeIdOrHandle(davinci_model_runtime_ids, dynamic_model_handles) != SUCCESS) {
response.set_error_code(FAILED);
response.set_error_message("Failed to get clear model handle");
return;
}
}
DoClearModel(davinci_model_runtime_ids, dynamic_model_handles, clear_msg_type, response);
GELOGD("[Handle][Clear Model] end.");
}
void EventHandler::DoClearModel(const std::vector<uint32_t> &davinci_model_runtime_ids,
const std::vector<ExecutorContext::ModelHandle *> &dynamic_model_handles,
const int32_t clear_msg_type,
deployer::ExecutorResponse &response) const {
uint32_t parallel_num = static_cast<uint32_t>(dynamic_model_handles.size());
if (!davinci_model_runtime_ids.empty()) {
parallel_num += kClearDavinciModelThreadNum;
}
if (parallel_num > kSingleThreadNum) {
return DoClearModelPara(davinci_model_runtime_ids, dynamic_model_handles, clear_msg_type, response);
}
for (const auto dynamic_model_handle : dynamic_model_handles) {
const auto ret = dynamic_model_handle->ClearModel(clear_msg_type);
if (ret != SUCCESS) {
response.set_error_code(FAILED);
response.set_error_message("Failed to clear model");
return;
}
}
const auto ret = CpuTasks::ExecuteModelClearTask(clear_msg_type, davinci_model_runtime_ids);
if (ret != SUCCESS) {
response.set_error_code(FAILED);
response.set_error_message("Failed to clear model");
return;
}
response.set_error_code(SUCCESS);
}
void EventHandler::DoClearModelPara(const std::vector<uint32_t> &davinci_model_runtime_ids,
const std::vector<ExecutorContext::ModelHandle *> &dynamic_model_handles,
const int32_t clear_msg_type,
deployer::ExecutorResponse &response) const {
const uint32_t parallel_num = static_cast<uint32_t>(dynamic_model_handles.size());
ThreadPool pool("ge_dpl_clrm", static_cast<uint32_t>(parallel_num), false);
std::vector<std::future<Status>> fut_rets;
for (const auto dynamic_model_handle : dynamic_model_handles) {
auto fut = pool.commit([dynamic_model_handle, clear_msg_type]() -> Status {
GE_CHK_STATUS_RET_NOLOG(dynamic_model_handle->ClearModel(clear_msg_type));
return SUCCESS;
});
fut_rets.emplace_back(std::move(fut));
}
const auto ret = CpuTasks::ExecuteModelClearTask(clear_msg_type, davinci_model_runtime_ids);
if (ret != SUCCESS) {
response.set_error_code(FAILED);
response.set_error_message("Failed to clear model");
return;
}
for (auto &fut : fut_rets) {
if (fut.get() != SUCCESS) {
response.set_error_code(FAILED);
response.set_error_message("Failed to clear model");
return;
}
}
response.set_error_code(SUCCESS);
}
void EventHandler::HandleDataFlowExceptionNotifyRequest(deployer::ExecutorRequest &request,
deployer::ExecutorResponse &response) {
GELOGD("[Handle][Exception notify] begin.");
const auto &exception_notify_request = request.exception_notify_request();
const auto &exception_notify = exception_notify_request.exception_notify();
const auto root_model_id = exception_notify_request.root_model_id();
std::map<uint32_t, std::unique_ptr<ExecutorContext::ModelHandle>> *submodel_map = nullptr;
if (context_->GetModel(root_model_id, submodel_map) != SUCCESS) {
GELOGE(FAILED, "No model found, model: %u", root_model_id);
response.set_error_code(FAILED);
response.set_error_message("No model found");
return;
}
std::vector<uint32_t> davinci_model_runtime_ids;
std::vector<ExecutorContext::ModelHandle *> dynamic_model_handles;
for (auto it = submodel_map->begin(); it != submodel_map->end(); it++) {
ExecutorContext::ModelHandle *handle = it->second.get();
if ((handle->GetScope() != exception_notify.scope()) || handle->IsInvokedNN()) {
GELOGI("scope[%s] is mismatch exception scope[%s] or is invoked nn[%d], ignore it, model: %u",
handle->GetScope().c_str(), exception_notify.scope().c_str(), static_cast<int32_t>(handle->IsInvokedNN()),
root_model_id);
continue;
}
auto get_ret = handle->GetModelRuntimeIdOrHandle(davinci_model_runtime_ids, dynamic_model_handles);
if (get_ret != SUCCESS) {
GELOGE(get_ret, "Failed to get model: %u", root_model_id);
response.set_error_code(get_ret);
response.set_error_message("get model runtime id or handle failed");
return;
}
}
Status notify_ret = DoDataFlowExceptionNotify(davinci_model_runtime_ids, dynamic_model_handles, exception_notify);
if (notify_ret != SUCCESS) {
response.set_error_code(notify_ret);
response.set_error_message("data flow exception notify failed");
} else {
response.set_error_code(SUCCESS);
}
GELOGD("[Handle][Exception notify] end.");
}
Status EventHandler::DoDataFlowExceptionNotify(const std::vector<uint32_t> &davinci_model_runtime_ids,
const std::vector<ExecutorContext::ModelHandle *> &dynamic_model_handles,
const deployer::DataFlowExceptionNotify &exception_notify) const {
uint32_t type = exception_notify.type();
uint64_t trans_id = exception_notify.trans_id();
const uint32_t parallel_num = static_cast<uint32_t>(dynamic_model_handles.size());
std::vector<std::future<Status>> fut_rets;
std::unique_ptr<ThreadPool> thread_pool;
if (parallel_num > 0) {
fut_rets.reserve(parallel_num);
thread_pool = MakeUnique<ThreadPool>("ge_dpl_hexn", parallel_num, false);
if (thread_pool == nullptr) {
GELOGE(FAILED, "make unique for ThreadPool failed, parallel_num=%u", parallel_num);
return FAILED;
}
for (const auto dynamic_model_handle : dynamic_model_handles) {
auto fut = thread_pool->commit([dynamic_model_handle, type, trans_id]() -> Status {
GE_CHK_STATUS_RET_NOLOG(dynamic_model_handle->ExceptionNotify(type, trans_id));
return SUCCESS;
});
fut_rets.emplace_back(std::move(fut));
}
}
const auto ret = CpuTasks::ExceptionNotify(davinci_model_runtime_ids, type, trans_id);
if (ret != SUCCESS) {
GELOGE(FAILED, "ExecuteDataFlowExceptionNotify failed, model runtime ids=%s",
ToString(davinci_model_runtime_ids).c_str());
return ret;
}
for (auto &fut : fut_rets) {
auto fut_ret = fut.get();
if (fut_ret != SUCCESS) {
GELOGE(fut_ret, "DataFlowExceptionNotify to dynamic model failed, trans_id=%lu", trans_id);
return fut_ret;
}
}
return SUCCESS;
}
void EventHandler::HandleProfInfo(deployer::ExecutorRequest &request,
deployer::ExecutorResponse &response) {
GELOGI("[Handle][Set Proiling Info] begin.");
if (context_->UpdateProfInfo(request) != SUCCESS) {
response.set_error_code(FAILED);
response.set_error_message("Failed to set prof");
return;
}
response.set_error_code(SUCCESS);
GELOGI("[Handle][Set Proiling Info] end.");
}
}