* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "executor/executor_context.h"
#include <fstream>
#include "nlohmann/json.hpp"
#include "common/compile_profiling/ge_call_wrapper.h"
#include "executor/cpu_sched_event_dispatcher.h"
#include "executor/proxy_dynamic_model_executor.h"
#include "executor/dynamic_model_executor.h"
#include "cpu_tasks.h"
#include "framework/executor/ge_executor.h"
#include "graph/ge_context.h"
#include "graph/manager/mem_manager.h"
#include "mmpa/mmpa_api.h"
#include "graph/debug/ge_attr_define.h"
#include "dflow/base/deploy/deploy_planner.h"
#include "common/checker.h"
#include "common/thread_pool/thread_pool.h"
#include "common/helper/model_parser_base.h"
#include "framework/common/helper/model_helper.h"
#include "toolchain/prof_api.h"
#include "dflow/inc/data_flow/model/graph_model.h"
#include "acl/acl_mdl.h"
#include "acl/acl_rt.h"
#include "graph/types.h"
#include "dflow/base/model/flow_model_om_loader.h"
#include "graph/utils/op_type_utils.h"
#include "common/df_chk.h"
namespace ge {
namespace {
constexpr int32_t kUserUnsetPriority = -1;
constexpr const char *OPTION_FLOAT_OVERFLOW_MODE = "ge.exec.float_overflow_mode";
constexpr const char *OPTION_FLOAT_OVERFLOW_MODE_SATURATION = "saturation";
constexpr const char *OPTION_FLOAT_OVERFLOW_MODE_INFNAN = "inf_nan";
const std::string kGatherDequeue = "gatherDequeue";
constexpr int32_t kDefaultAttachQueueTimeout = 3000;
constexpr const char_t *kAttrNameInvokedByBuiltIn = "_dflow_invoked_by_built_in";
constexpr int32_t kDefaultHostDeviceId = 64;
}
class SharedMemoryManager : public MemManager {
public:
static SharedMemoryManager &GetInstance() {
static SharedMemoryManager instance;
return instance;
}
uint8_t *MallocMemory(rtMemType_t memory_type,
const std::string &purpose,
const std::string &memory_key,
size_t memory_size,
uint32_t device_id) override {
if (memory_type != RT_MEMORY_HBM) {
return MemManager::MallocMemory(memory_type, purpose, memory_key, memory_size, device_id);
}
{
std::lock_guard<std::mutex> lk(mu_);
auto it = var_mem_bases_.find(memory_key);
if (it != var_mem_bases_.cend()) {
return it->second;
}
}
return MemManager::MallocMemory(memory_type, purpose, memory_key, memory_size, device_id);
}
Status FreeMemory(rtMemType_t memory_type, const std::string &memory_key, uint32_t device_id) override {
std::lock_guard<std::mutex> lk(mu_);
if (memory_type != RT_MEMORY_HBM || var_mem_bases_.count(memory_key) == 0) {
return MemManager::FreeMemory(memory_type, memory_key, device_id);
}
return SUCCESS;
}
uint8_t *GetMemoryBase(rtMemType_t memory_type, const std::string &memory_key, uint32_t device_id) override {
if (memory_type != RT_MEMORY_HBM) {
return MemManager::GetMemoryBase(memory_type, memory_key, device_id);
}
return GetMemoryAddr(memory_type, memory_key, device_id);
}
uint8_t *GetMemoryAddr(rtMemType_t memory_type, const std::string &memory_key, uint32_t device_id) override {
if (memory_type != RT_MEMORY_HBM) {
return MemManager::GetMemoryAddr(memory_type, memory_key, device_id);
}
std::lock_guard<std::mutex> lk(mu_);
auto it = var_mem_bases_.find(memory_key);
if (it == var_mem_bases_.cend()) {
GELOGW("MemoryAllocator::GetMemoryAddr failed, memory_key[%s] was does not exist", memory_key.c_str());
return nullptr;
}
return it->second;
}
void AddSharedMemory(const std::string &key, uint8_t *memory_base) {
std::lock_guard<std::mutex> lk(mu_);
var_mem_bases_[key] = memory_base;
}
private:
std::mutex mu_;
std::map<std::string, uint8_t *> var_mem_bases_;
};
ExecutorContext::ModelHandle *ExecutorContext::GetOrCreateModelHandle(uint32_t root_model_id, uint32_t model_id) {
std::lock_guard<std::mutex> lk(mu_);
auto &submodels = model_handles_[root_model_id];
const auto &it = submodels.find(model_id);
if (it != submodels.cend()) {
return it->second.get();
}
auto new_handle = MakeUnique<ModelHandle>();
if (new_handle == nullptr) {
return nullptr;
}
if (!model_handles_[root_model_id].emplace(model_id, std::move(new_handle)).second) {
GELOGE(FAILED, "[Add][Model] failed, model already added, sub_model_id = %u", model_id);
return nullptr;
}
ExecutorContext::ModelHandle *handle = model_handles_[root_model_id][model_id].get();
return handle;
}
ExecutorContext &ExecutorContext::LocalContext() {
static ExecutorContext context;
return context;
}
PneModelPtr ExecutorContext::GetLocalModel(uint32_t root_model_id, uint32_t model_id) {
std::lock_guard<std::mutex> lk(mu_);
const auto &root_it = models_.find(root_model_id);
if (root_it == models_.cend()) {
return nullptr;
}
const auto &sub_it = root_it->second.find(model_id);
if (sub_it == root_it->second.cend()) {
return nullptr;
}
return sub_it->second;
}
void ExecutorContext::AddLocalModel(uint32_t root_model_id, uint32_t model_id, const PneModelPtr &model) {
std::lock_guard<std::mutex> lk(mu_);
models_[root_model_id][model_id] = model;
}
void ExecutorContext::RemoveLocalModel(uint32_t root_model_id) {
std::lock_guard<std::mutex> lk(mu_);
models_.erase(root_model_id);
}
Status ExecutorContext::Initialize() const {
return SUCCESS;
}
void ExecutorContext::Finalize() const {}
void ExecutorContext::SetBaseDir(const std::string &base_dir) {
base_dir_ = base_dir;
}
Status ExecutorContext::ParseModel(uint32_t root_model_id, uint32_t model_id, const std::string &model_path) {
GELOGI("Parse model begin, root_model_id = %u, submodel_id = %u", root_model_id, model_id);
auto handle = GetOrCreateModelHandle(root_model_id, model_id);
GE_CHECK_NOTNULL(handle);
auto model = LocalContext().GetLocalModel(root_model_id, model_id);
if (model != nullptr) {
auto graph_model = std::dynamic_pointer_cast<GraphModel>(model);
GE_CHECK_NOTNULL(graph_model, "cast to graph model failed, model name=%s", model->GetModelName().c_str());
handle->SetModelData(graph_model->GetModelData());
handle->SetRootGraph(graph_model->GetRootGraph());
GELOGI("Parse model from cache success");
return SUCCESS;
}
GE_CHK_BOOL_RET_STATUS((model_path.find("..") == std::string::npos), FAILED,
"Model path[%s] is invalid, include relative path.", model_path.c_str());
std::string model_real_path = RealPath(model_path.c_str());
GE_CHK_BOOL_RET_STATUS((!model_real_path.empty()), FAILED, "Model path[%s] is invalid.", model_path.c_str());
GE_CHK_STATUS_RET_NOLOG(handle->ParseModel(model_path));
GELOGI("Parse model from om success, path = %s, root_model_id = %u, submodel_id = %u",
model_path.c_str(), root_model_id, model_id);
return SUCCESS;
}
std::unique_ptr<std::istream> ExecutorContext::CreateInputStream(const std::string &path) const {
return MakeUnique<std::ifstream>(path, std::ios::in | std::ios::binary);
}
Status ExecutorContext::GetModel(uint32_t root_model_id,
std::map<uint32_t, std::unique_ptr<ModelHandle>> *&submodel_map) {
std::lock_guard<std::mutex> lk(mu_);
auto it = model_handles_.find(root_model_id);
if (it == model_handles_.end()) {
GELOGE(FAILED, "[Get][Model] failed, root model id = %u.", root_model_id);
return FAILED;
}
submodel_map = &(it->second);
return SUCCESS;
}
ExecutorContext::ModelHandle::~ModelHandle() {
GELOGD("Begin to destruct model_handle.");
if (loaded_) {
(void) UnloadModel();
}
if (model_data_from_cache_ && model_data_.model_data != nullptr) {
delete[] static_cast<char_t *>(model_data_.model_data);
model_data_.model_data = nullptr;
model_data_from_cache_ = false;
}
}
Status ExecutorContext::SyncSharedVarManager(const deployer::ExecutorRequest &request) const {
auto &sync_var_manager_request = request.sync_var_manager_message();
const auto &var_manager_info = sync_var_manager_request.var_manager_info();
auto session_id = var_manager_info.session_id();
GELOGI("Init var manager, session id is %zu.", session_id);
const std::vector<rtMemType_t> mem_type{RT_MEMORY_HBM, RT_MEMORY_P2P_DDR};
auto &var_mem_manager = SharedMemoryManager::GetInstance();
GE_CHK_STATUS_RET(var_mem_manager.Initialize(mem_type),
"[Init][MemManager] MemoryAllocatorManager initialize failed.");
int32_t device_id = -1;
DF_CHK_ACL_RET(aclrtGetDevice(&device_id));
GELOGD("SyncSharedVarManager get device_id[%d] success", device_id);
for (const auto &shared_content_desc : sync_var_manager_request.shared_content_descs()) {
GeTensorDesc tensor_desc;
GeTensorSerializeUtils::AssembleGeTensorDescFromProto(&shared_content_desc.tensor_desc(), tensor_desc);
}
GELOGI("Init var manager successfully, session_id = %lu.", session_id);
return SUCCESS;
}
void ExecutorContext::UpdateGraphOptions(const std::string &key, const std::string &value) {
std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions();
graph_options[key] = value;
GetThreadLocalContext().SetGraphOption(graph_options);
}
void ExecutorContext::UpdateOptions(const deployer::Options &options) {
std::map<std::string, std::string> global_options = GetThreadLocalContext().GetAllGlobalOptions();
for (const auto &item : options.global_options()) {
global_options[item.first] = item.second;
GELOGI("Insert global option, key = %s, value = %s.", item.first.c_str(), item.second.c_str());
}
GetThreadLocalContext().SetGlobalOption(global_options);
std::map<std::string, std::string> session_options = GetThreadLocalContext().GetAllSessionOptions();
for (const auto &item : options.session_options()) {
session_options[item.first] = item.second;
GELOGI("Insert session option, key = %s, value = %s.", item.first.c_str(), item.second.c_str());
}
GetThreadLocalContext().SetSessionOption(session_options);
std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions();
for (const auto &item : options.graph_options()) {
graph_options[item.first] = item.second;
GELOGI("Insert graph option, key = %s, value = %s.", item.first.c_str(), item.second.c_str());
}
GetThreadLocalContext().SetGraphOption(graph_options);
}
Status ExecutorContext::SetOpTimeout() {
std::string op_wait_timeout_str;
(void)GetContext().GetOption(OP_WAIT_TIMEOUT, op_wait_timeout_str);
if (!op_wait_timeout_str.empty()) {
int64_t op_wait_timeout_sec = 0;
GE_CHK_STATUS_RET(ConvertToInt64(op_wait_timeout_str, op_wait_timeout_sec),
"Convert result[%s] to int32 failed.", op_wait_timeout_str.c_str());
GE_CHECK_LE(op_wait_timeout_sec, static_cast<int64_t>(UINT32_MAX));
auto ret = aclrtSetOpWaitTimeout(static_cast<uint32_t>(op_wait_timeout_sec));
GELOGI("Set aclrtSetOpWaitTimeout[%s], ret = %d.", op_wait_timeout_str.c_str(), ret);
}
std::string op_execute_timeout_str;
(void)GetContext().GetOption(OP_EXECUTE_TIMEOUT, op_execute_timeout_str);
if (!op_execute_timeout_str.empty()) {
int64_t op_execute_timeout_sec = 0;
GE_CHK_STATUS_RET(ConvertToInt64(op_execute_timeout_str, op_execute_timeout_sec),
"Convert result[%s] to int32 failed.", op_execute_timeout_str.c_str());
GE_CHECK_LE(op_execute_timeout_sec, static_cast<int64_t>(UINT32_MAX));
auto ret = aclrtSetOpExecuteTimeOut(static_cast<uint32_t>(op_execute_timeout_sec));
GELOGI("Set aclrtSetOpExecuteTimeOut[%s], ret = %d.", op_execute_timeout_str.c_str(), ret);
}
return SUCCESS;
}
Status ExecutorContext::SetDeviceSatMode() {
std::string sat_mode;
(void)GetContext().GetOption(OPTION_FLOAT_OVERFLOW_MODE, sat_mode);
if (!sat_mode.empty()) {
static std::map<std::string, aclrtFloatOverflowMode> mode_transfer = {
{OPTION_FLOAT_OVERFLOW_MODE_SATURATION, ACL_RT_OVERFLOW_MODE_SATURATION},
{OPTION_FLOAT_OVERFLOW_MODE_INFNAN, ACL_RT_OVERFLOW_MODE_INFNAN}};
const auto &it = mode_transfer.find(sat_mode);
if (it != mode_transfer.cend()) {
DF_CHK_ACL_RET(aclrtSetDeviceSatMode(it->second));
GEEVENT("aclrtSetDeviceSatMode success, mode = %d.", static_cast<int32_t>(it->second));
}
}
return SUCCESS;
}
Status ExecutorContext::ParseModelEschedPriority(const deployer::ExecutorRequest_LoadModelRequest &request,
ModelHandle &handle) const {
auto attrs = request.attrs();
int32_t process_priority = kUserUnsetPriority;
int32_t event_priority = kUserUnsetPriority;
if (attrs.contains(ATTR_NAME_ESCHED_PROCESS_PRIORITY)) {
GE_CHK_STATUS_RET(ConvertToInt32(attrs[ATTR_NAME_ESCHED_PROCESS_PRIORITY], process_priority),
"Failed to parse model esched process priority.");
}
if (attrs.contains(ATTR_NAME_ESCHED_EVENT_PRIORITY)) {
GE_CHK_STATUS_RET(ConvertToInt32(attrs[ATTR_NAME_ESCHED_EVENT_PRIORITY], event_priority),
"Failed to parse model esched event priority.");
}
handle.SetEschedPriority(process_priority, event_priority);
GELOGD("[ModelEschedPriority]: process priority[%d], event priority[%d].", process_priority, event_priority);
return SUCCESS;
}
Status ExecutorContext::ParseInputAlignAttrs(const deployer::ExecutorRequest_LoadModelRequest &request,
InputAlignAttrs &input_align_attrs) {
if (!request.has_input_align_attrs()) {
return SUCCESS;
}
const auto &proto_input_align_attrs = request.input_align_attrs();
input_align_attrs.align_max_cache_num = proto_input_align_attrs.align_max_cache_num();
input_align_attrs.align_timeout = proto_input_align_attrs.align_timeout();
input_align_attrs.drop_when_not_align = proto_input_align_attrs.drop_when_not_align();
return SUCCESS;
}
Status ExecutorContext::ParseModel(const deployer::ExecutorRequest_LoadModelRequest &request) {
GE_TIMESTAMP_START(ParseModel);
auto root_model_id = request.root_model_id();
auto sub_model_id = request.model_id();
auto model_path = base_dir_ + request.model_path();
const auto &saved_model_path = request.saved_model_file_path();
const auto &model_ins_name = request.model_instance_name();
std::string real_path = RealPath(saved_model_path.c_str());
if (!real_path.empty()) {
model_path = real_path;
}
std::string trace_log = "Parse model[" + model_ins_name + "] from path[" + model_path + "]";
GE_CHK_STATUS_RET(ParseModel(root_model_id, sub_model_id, model_path), "Failed to parse model");
GE_TIMESTAMP_EVENT_END(ParseModel, trace_log.c_str());
return SUCCESS;
}
Status ExecutorContext::AttachQueues(const deployer::ExecutorRequest_LoadModelRequest &request) {
std::map<int32_t, std::set<uint32_t>> attach_queue_list;
for (const auto &queue_attr : request.model_queues_attrs().input_queues_attrs()) {
if ((queue_attr.device_type() == static_cast<int32_t>(NPU)) &&
(queue_attr.queue_id() != UINT32_MAX)) {
attach_queue_list[queue_attr.device_id()].emplace(queue_attr.queue_id());
}
}
for (const auto &queue_attr : request.model_queues_attrs().output_queues_attrs()) {
if ((queue_attr.device_type() == static_cast<int32_t>(NPU)) &&
(queue_attr.queue_id() != UINT32_MAX)) {
attach_queue_list[queue_attr.device_id()].emplace(queue_attr.queue_id());
}
}
for (const auto &queue_attr : request.status_queues().output_queues_attrs()) {
if ((queue_attr.device_type() == static_cast<int32_t>(NPU)) &&
(queue_attr.queue_id() != UINT32_MAX)) {
attach_queue_list[queue_attr.device_id()].emplace(queue_attr.queue_id());
}
}
for (const auto &it : attach_queue_list) {
auto device_id = it.first;
const auto &queue_ids = it.second;
const auto ret = rtMemQueueInit(device_id);
if (ret != RT_ERROR_NONE && ret != ACL_ERROR_RT_REPEATED_INIT) {
return RT_ERROR_TO_GE_STATUS(ret);
}
for (auto queue_id : queue_ids) {
GE_CHK_RT_RET(rtMemQueueAttach(device_id, queue_id, kDefaultAttachQueueTimeout));
}
}
return SUCCESS;
}
Status ExecutorContext::LoadModel(const deployer::ExecutorRequest_LoadModelRequest &request) {
ModelHandle::LoadParam param;
for (const auto &input_queue : request.model_queues_attrs().input_queues_attrs()) {
QueueAttrs in_queue_attrs;
in_queue_attrs.queue_id = input_queue.queue_id();
in_queue_attrs.device_type = input_queue.device_type();
in_queue_attrs.device_id = input_queue.device_id();
in_queue_attrs.logic_id = input_queue.global_logic_id();
param.input_queues.emplace_back(in_queue_attrs);
}
for (const auto &output_queue : request.model_queues_attrs().output_queues_attrs()) {
QueueAttrs out_queue_attrs;
out_queue_attrs.queue_id = output_queue.queue_id();
out_queue_attrs.device_type = output_queue.device_type();
out_queue_attrs.device_id = output_queue.device_id();
out_queue_attrs.logic_id = output_queue.global_logic_id();
param.output_queues.emplace_back(out_queue_attrs);
}
param.input_fusion_offsets = std::move(std::vector<int32_t>(request.input_fusion_offsets().cbegin(),
request.input_fusion_offsets().cend()));
param.replica_num = request.replica_num();
param.replica_idx = request.replica_idx();
param.model_uuid = request.model_uuid();
std::vector<QueueAttrs> status_output_queues_attrs;
for (const auto &status_output_queue : request.status_queues().output_queues_attrs()) {
QueueAttrs status_output_queue_attrs;
status_output_queue_attrs.queue_id = status_output_queue.queue_id();
status_output_queue_attrs.device_type = status_output_queue.device_type();
status_output_queue_attrs.device_id = status_output_queue.device_id();
status_output_queue_attrs.logic_id = status_output_queue.global_logic_id();
status_output_queues_attrs.emplace_back(status_output_queue_attrs);
}
auto root_model_id = request.root_model_id();
auto sub_model_id = request.model_id();
const auto &model_path = base_dir_ + request.model_path();
GE_CHK_STATUS_RET(ParseInputAlignAttrs(request, param.input_align_attrs),
"Failed to parse input align attrs, model_id = %d, sub_model_id = %d", root_model_id, sub_model_id);
if (status_output_queues_attrs.size() == 1U) {
param.status_output_queue = status_output_queues_attrs[0];
param.need_report_status = true;
} else {
param.need_report_status = false;
}
param.is_dynamic_sched = request.is_dynamic_sched();
param.is_head = request.is_head();
GELOGI("DynamicSched model load info: model_uuid=%u, status output queue=%u, need report status=%d,"
" is dynamic sched=%d, is_head=%d", param.model_uuid, param.status_output_queue.queue_id,
param.need_report_status, param.is_dynamic_sched, param.is_head);
auto handle = GetOrCreateModelHandle(root_model_id, sub_model_id);
GE_CHECK_NOTNULL(handle);
handle->SetExecuteTimes(request.execute_times());
handle->SetIsDynamicProxyControlled(request.is_dynamic_proxy_controlled());
handle->SetScope(request.scope());
handle->SetEnableExceptionCatch(request.enable_exception_catch());
GE_CHK_STATUS_RET_NOLOG(ParseModelEschedPriority(request, *handle));
GELOGD("Begin to load model, root_model_id = %u, submodel_id = %u, path = %s, is_dynamic_proxy_controlled = %d.",
root_model_id, sub_model_id, model_path.c_str(), static_cast<int32_t>(request.is_dynamic_proxy_controlled()));
GE_CHK_STATUS_RET(handle->LoadModel(param), "Failed to load model");
GELOGD("Success to load model, root_model_id = %u, submodel_id = %u, path = %s, is_dynamic_proxy_controlled = %d.",
root_model_id, sub_model_id, model_path.c_str(), static_cast<int32_t>(request.is_dynamic_proxy_controlled()));
return SUCCESS;
}
Status ExecutorContext::UpdateProfInfo(const deployer::ExecutorRequest &request) {
const auto &prof_info_request = request.update_prof_message();
const std::string &prof_data = prof_info_request.prof_data();
const bool is_prof_start = prof_info_request.is_prof_start();
static bool is_start_status = false;
GELOGI("Is prof start[%d], current status %d.", is_prof_start, is_start_status);
if (is_start_status == is_prof_start) {
return SUCCESS;
}
MsprofConfig param = {};
param.devNums = 1;
param.devIdList[0] = kDefaultHostDeviceId;
auto ret = strcpy_s(param.dumpPath, sizeof(param.dumpPath), std::to_string(kDefaultHostDeviceId).c_str());
GE_CHK_BOOL_RET_STATUS(ret == EOK, FAILED, "strcpy_s ret:%d", ret);
ret = strcpy_s(param.sampleConfig, sizeof(param.sampleConfig), prof_data.c_str());
GE_CHK_BOOL_RET_STATUS(ret == EOK, FAILED, "strcpy_s prof_data ret:%d", ret);
ret = is_prof_start ? MsprofStart(MSPROF_CTRL_INIT_PURE_CPU, reinterpret_cast<const char *>(¶m), sizeof(param)) :
MsprofStop(MSPROF_CTRL_INIT_PURE_CPU, reinterpret_cast<const char *>(¶m), sizeof(param));
GE_CHK_BOOL_RET_STATUS(ret == MSPROF_ERROR_NONE, FAILED,
"MsprofStart/MsprofStop failed, is prof start : %d", is_prof_start);
is_start_status = is_prof_start;
return SUCCESS;
}
Status ExecutorContext::ModelHandle::GetModelData(ModelData &model_data) {
model_data = model_data_;
return SUCCESS;
}
void ExecutorContext::ModelHandle::SetModelData(const ModelData &model_data) {
model_data_ = model_data;
}
Status ExecutorContext::ModelHandle::GetRootGraph(ComputeGraphPtr &root_graph) {
root_graph = root_graph_;
return SUCCESS;
}
void ExecutorContext::ModelHandle::SetRootGraph(const ComputeGraphPtr &root_graph) {
root_graph_ = root_graph;
}
Status ExecutorContext::ModelHandle::GetModelRuntimeIdOrHandle(std::vector<uint32_t> &davinci_model_runtime_ids,
std::vector<ExecutorContext::ModelHandle *> &dynamic_model_handles) {
if (inner_model_id_ != UINT32_MAX) {
uint32_t model_runtime_id = UINT32_MAX;
GE_CHK_STATUS_RET(GeExecutor::GetRuntimeModelId(inner_model_id_, model_runtime_id),
"Failed to get runtime model id.");
davinci_model_runtime_ids.emplace_back(model_runtime_id);
return SUCCESS;
}
if (dynamic_model_executor_ != nullptr) {
if (is_dynamic_proxy_controlled_) {
const ProxyDynamicModelExecutor *const proxyDynamicModel =
dynamic_cast<const ProxyDynamicModelExecutor *>(dynamic_model_executor_.get());
GE_CHECK_NOTNULL(proxyDynamicModel);
davinci_model_runtime_ids.emplace_back(proxyDynamicModel->GetRuntimeModelId());
}
dynamic_model_handles.emplace_back(this);
return SUCCESS;
}
GELOGE(FAILED, "This model is not davinci model or dynamic model.");
return FAILED;
}
Status ExecutorContext::ModelHandle::ClearModel(const int32_t clear_type) {
return dynamic_model_executor_->ClearModel(clear_type);
}
Status ExecutorContext::ModelHandle::ExceptionNotify(uint32_t type, uint64_t trans_id) {
return dynamic_model_executor_->ExceptionNotify(type, trans_id);
}
Status ExecutorContext::ModelHandle::ParseModel(const std::string &model_path) {
GELOGD("Begin to parse model[%s].", model_path.c_str());
ge::ModelData model{};
GE_CHK_STATUS_RET(ModelParserBase::LoadFromFile(model_path.c_str(), 0, model),
"Load model from file[%s] failed.", model_path.c_str());
GE_CHK_STATUS_RET(FlowModelOmLoader::TransModelDataToComputeGraph(model, root_graph_), "Failed to trans Modeldata to ComputeGraph");
GE_CHECK_NOTNULL(root_graph_, "load root graph is null");
model_data_ = model;
model_data_from_cache_ = true;
return SUCCESS;
}
Status ExecutorContext::ModelHandle::LoadModel(const LoadParam ¶m) {
if (loaded_) {
GELOGW("Already loaded");
return SUCCESS;
}
ModelData model_data;
ComputeGraphPtr root_graph;
GE_CHK_STATUS_RET_NOLOG(GetRootGraph(root_graph));
GE_CHK_STATUS_RET_NOLOG(GetModelData(model_data));
GE_CHK_STATUS_RET(DoLoadModel(model_data, root_graph, param), "Do load model failed.");
loaded_ = true;
return SUCCESS;
}
Status ExecutorContext::ModelHandle::DoLoadModel(const ModelData &model_data,
const ComputeGraphPtr &root_graph,
const LoadParam ¶ms) {
return DoLoadModelWithQ(model_data, root_graph, params);
}
Status ExecutorContext::ModelHandle::UnloadModel() {
if (dynamic_model_executor_ != nullptr) {
dynamic_model_executor_->UnloadModel();
dynamic_model_executor_->Finalize();
dynamic_model_executor_.reset();
} else {
GE_CHK_STATUS_RET(DoUnloadModel(inner_model_id_),
"[Unload][Model] failed, model_id = %u", inner_model_id_);
GELOGD("[Unload][Model] success, model_id = %u", inner_model_id_);
inner_model_id_ = UINT32_MAX;
}
loaded_ = false;
return SUCCESS;
}
Status ExecutorContext::ModelHandle::CheckAicpuAlignTask(const InputAlignAttrs &input_align_attrs) {
if (input_align_attrs.align_max_cache_num == 0U) {
GELOGD("No need to check gather dequeue task result cache number is 0.");
return SUCCESS;
}
bool is_gather_supported = false;
GE_CHK_STATUS_RET(CpuTasks::ExecuteCheckSupported(kGatherDequeue, is_gather_supported));
GE_ASSERT_TRUE(is_gather_supported, "Gather dequeue is not supported in current version. "
"Please update software or unset input align attrs.");
return SUCCESS;
}
Status ExecutorContext::ModelHandle::DoLoadModelWithQ(const ModelData &model_data,
const ComputeGraphPtr &root_graph,
const LoadParam ¶ms) {
ModelQueueParam model_queue_param;
model_queue_param.group_total_count = params.replica_num;
model_queue_param.group_index = params.replica_idx;
for (const auto &queue : params.input_queues) {
model_queue_param.input_queues.emplace_back(queue.queue_id);
model_queue_param.input_queues_attrs.emplace_back(queue);
}
for (const auto &queue : params.output_queues) {
model_queue_param.output_queues.emplace_back(queue.queue_id);
model_queue_param.output_queues_attrs.emplace_back(queue);
}
model_queue_param.input_fusion_offsets = params.input_fusion_offsets;
model_queue_param.model_uuid = params.model_uuid;
model_queue_param.status_output_queue = params.status_output_queue;
model_queue_param.is_dynamic_sched = params.is_dynamic_sched;
model_queue_param.need_report_status = params.need_report_status;
model_queue_param.is_head = params.is_head;
(void)AttrUtils::GetBool(root_graph, ATTR_NAME_DATA_FLOW_UDF_INVOKED_NN, is_invoked_nn_);
if (!is_invoked_nn_) {
model_queue_param.input_align_attrs = params.input_align_attrs;
}
bool dflow_invoked_by_built_in = false;
(void)AttrUtils::GetBool(root_graph, kAttrNameInvokedByBuiltIn,
dflow_invoked_by_built_in);
model_queue_param.need_check_inputs = !dflow_invoked_by_built_in;
model_queue_param.need_model_config = true;
model_queue_param.mark_dump_step = true;
model_queue_param.io_with_tensor_desc = true;
model_queue_param.copy_inputs_for_non_zero_copy = true;
bool is_host = GetContext().GetHostExecFlag();
bool is_dynamic = false;
(void)AttrUtils::GetBool(root_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic);
is_dynamic = (is_dynamic || (root_graph->GetGraphUnknownFlag()));
GE_TIMESTAMP_START(LoadModel);
if (is_dynamic_proxy_controlled_) {
GELOGD("Load proxy dynamic model, is_dynamic = %d, is_dynamic_proxy_controlled = %d.",
static_cast<int32_t>(is_dynamic), static_cast<int32_t>(is_dynamic_proxy_controlled_));
if (IsEnableExceptionCatch()) {
GE_CHK_STATUS_RET(CpuTasks::CheckSupportExceptionNotify(), "aicpu does not support exception catch.");
}
GE_CHK_STATUS_RET(CheckAicpuAlignTask(model_queue_param.input_align_attrs),
"[Check][Align] attrs failed.");
dynamic_model_executor_ = CreateProxyDynamicModelExecutor();
GE_CHECK_NOTNULL(dynamic_model_executor_);
GE_CHK_STATUS_RET_NOLOG(dynamic_model_executor_->Initialize());
GE_CHK_STATUS_RET_NOLOG(dynamic_model_executor_->LoadModel(model_data, root_graph, model_queue_param));
GE_TIMESTAMP_EVENT_END(LoadModel, "LoadDynamicModel");
GELOGD("[LoadProxyDynamicModel] success");
return SUCCESS;
}
if (is_dynamic || is_host) {
GELOGD("Load dynamic model, is_dynamic = %d, is_dynamic_proxy_controlled = %d.",
static_cast<int32_t>(is_dynamic), static_cast<int32_t>(is_dynamic_proxy_controlled_));
dynamic_model_executor_ = CreateDynamicModelExecutor(is_host);
GE_CHECK_NOTNULL(dynamic_model_executor_);
GE_CHK_STATUS_RET_NOLOG(dynamic_model_executor_->Initialize());
if (IsEnableExceptionCatch()) {
GE_CHK_STATUS_RET(dynamic_model_executor_->CheckLocalAicpuSupportExceptionNotify(),
"local aicpu does not support exception catch.");
}
dynamic_model_executor_->SetModelEschedPriority(esched_process_priority_, esched_event_priority_);
GE_CHK_STATUS_RET_NOLOG(dynamic_model_executor_->LoadModel(model_data, root_graph, model_queue_param));
GE_TIMESTAMP_EVENT_END(LoadModel, "LoadDynamicModel");
GELOGD("[LoadDynamicModel] success");
return SUCCESS;
}
aclError ret;
GELOGD("Load static model, replica_num = %u, replica_idx = %u, is_dynamic = %d, is_dynamic_proxy_controlled = %d.",
params.replica_num, params.replica_idx, static_cast<int32_t>(is_dynamic),
static_cast<int32_t>(is_dynamic_proxy_controlled_));
std::vector<FileConstantMem> external_weight_mem_data{};
GE_CHK_STATUS_RET(DynamicModelExecutor::InitExternalWeightMem(root_graph, external_weight_mem_data), "Failed to init external weright mem.");
if (params.input_queues.empty() && params.output_queues.empty()) {
handle_ = aclmdlCreateConfigHandle();
GE_CHECK_NOTNULL(handle_, "Create acl load config handle failed.");
GE_CHK_STATUS_RET(DynamicModelExecutor::GenerateLoadConfig(model_data, external_weight_mem_data, handle_));
ret = aclmdlLoadWithConfig(handle_, &inner_model_id_);
if (ret != ACL_SUCCESS) {
GELOGE(FAILED, "Failed to load model");
(void) aclmdlDestroyConfigHandle(handle_);
handle_ = nullptr;
return FAILED;
}
GELOGI("Load static model[%u] on success.", inner_model_id_);
} else {
GeExecutor executor;
if (IsEnableExceptionCatch()) {
GE_CHK_STATUS_RET(CpuTasks::CheckSupportExceptionNotify(), "aicpu not support exception catch.");
}
GE_CHK_STATUS_RET(CpuTasks::ExecuteModelEschedPriorityTask(esched_process_priority_, esched_event_priority_));
GE_CHK_STATUS_RET(CheckAicpuAlignTask(model_queue_param.input_align_attrs),
"[Check][Algin] attrs failed.");
if (!external_weight_mem_data.empty()) {
model_queue_param.file_constant_mems = &external_weight_mem_data;
}
GE_CHK_STATUS_RET(executor.LoadModelWithQueueParam(inner_model_id_, model_data, model_queue_param), "[LoadModelWithQueueParam] failed");
GE_TIMESTAMP_EVENT_END(LoadModel, "LoadModelWithQ");
}
GELOGI("[LoadModelWithQ] success, model_id = %u, model_name = %s",
inner_model_id_, root_graph->GetName().c_str());
return SUCCESS;
}
Status ExecutorContext::ModelHandle::DoUnloadModel(uint32_t model_id) {
if (handle_ != nullptr) {
(void) aclmdlUnload(model_id);
(void) aclmdlDestroyConfigHandle(handle_);
handle_ = nullptr;
} else {
GE_CHK_STATUS_RET(GeExecutor().UnloadModel(model_id), "Unload model[%u] failed.", model_id);
}
return SUCCESS;
}
std::unique_ptr<DynamicModelExecutor> ExecutorContext::ModelHandle::CreateDynamicModelExecutor(bool is_host) {
return MakeUnique<DynamicModelExecutor>(is_host);
}
std::unique_ptr<ProxyDynamicModelExecutor> ExecutorContext::ModelHandle::CreateProxyDynamicModelExecutor() {
return MakeUnique<ProxyDynamicModelExecutor>();
}
void ExecutorContext::ModelHandle::SetExecuteTimes(int32_t execute_times) {
execute_times_ = execute_times;
}
void ExecutorContext::ModelHandle::SetIsDynamicProxyControlled(const bool is_dynamic_proxy_controlled) {
is_dynamic_proxy_controlled_ = is_dynamic_proxy_controlled;
}
void ExecutorContext::ModelHandle::SetEschedPriority(int32_t esched_process_priority, int32_t esched_event_priority) {
esched_process_priority_ = esched_process_priority;
esched_event_priority_ = esched_event_priority;
}
void ExecutorContext::ModelHandle::SetScope(const std::string &scope) {
scope_ = scope;
}
const std::string &ExecutorContext::ModelHandle::GetScope() const {
return scope_;
}
bool ExecutorContext::ModelHandle::IsInvokedNN() const {
return is_invoked_nn_;
}
void ExecutorContext::ModelHandle::SetEnableExceptionCatch(bool enable_exception_catch) {
enable_exception_catch_ = enable_exception_catch;
}
bool ExecutorContext::ModelHandle::IsEnableExceptionCatch() const {
return enable_exception_catch_;
}
}