* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef AIR_RUNTIME_DEPLOY_EXECUTOR_EXECUTOR_CONTEXT_H_
#define AIR_RUNTIME_DEPLOY_EXECUTOR_EXECUTOR_CONTEXT_H_
#include <cstdint>
#include <istream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "ge/ge_api_error_codes.h"
#include "common/model/ge_model.h"
#include "graph/manager/graph_var_manager.h"
#include "graph/model.h"
#include "framework/common/framework_types_internal.h"
#include "framework/common/helper/om_file_helper.h"
#include "executor/dynamic_model_executor.h"
#include "executor/proxy_dynamic_model_executor.h"
#include "proto/deployer.pb.h"
#include "dflow/inc/data_flow/model/pne_model.h"
#include "external/ge/ge_ir_build.h"
#include "acl/acl.h"
namespace ge {
/**
 * @brief Executor-side state for models delivered through deployer load requests.
 *
 * Parses and loads sub-models described by deployer::ExecutorRequest_LoadModelRequest, keeps one
 * ModelHandle per (root_model_id, model_id) pair and caches PneModel objects per root model.
 * Instances are not copyable (the class owns a std::mutex and uniquely-owned handles).
 */
class ExecutorContext {
 public:
  ExecutorContext() = default;
  virtual ~ExecutorContext() = default;
  // Already implicitly deleted because of the std::mutex / std::unique_ptr members;
  // spelled out so the intent matches ModelHandle's GE_DELETE_ASSIGN_AND_COPY.
  ExecutorContext(const ExecutorContext &) = delete;
  ExecutorContext &operator=(const ExecutorContext &) = delete;

  /// @brief Returns the context used by the current executor process.
  /// NOTE(review): presumably a function-local singleton — confirm in the .cpp.
  static ExecutorContext &LocalContext();

  /**
   * @brief Handle for a single sub-model: its serialized data, root graph and load state.
   *
   * Non-copyable. The load/unload/executor-creation hooks are protected virtual so that
   * subclasses can override how the model is actually loaded and executed.
   */
  class ModelHandle {
   public:
    /// Parameters consumed by LoadModel().
    struct LoadParam {
      uint32_t replica_num = 1U;  // number of replicas (default: a single replica)
      uint32_t replica_idx = 0U;  // index of this replica (default: 0)
      std::vector<QueueAttrs> input_queues;
      std::vector<QueueAttrs> output_queues;
      std::vector<int32_t> input_fusion_offsets;
      QueueAttrs status_output_queue;  // NOTE(review): presumably only used when need_report_status is set
      uint32_t model_uuid = 0U;
      bool is_dynamic_sched = false;
      bool need_report_status = false;
      bool is_head = false;
      InputAlignAttrs input_align_attrs{};
    };

    ModelHandle() = default;
    virtual ~ModelHandle();
    GE_DELETE_ASSIGN_AND_COPY(ModelHandle);

    /// @brief Parses the model stored at @p model_path into this handle.
    virtual Status ParseModel(const std::string &model_path);
    /// @brief Loads the previously parsed model using the queue/replica settings in @p param.
    virtual Status LoadModel(const LoadParam &param);
    /// @brief Unloads the model loaded by LoadModel().
    Status UnloadModel();

    void SetExecuteTimes(int32_t execute_times);
    void SetEschedPriority(int32_t esched_process_priority, int32_t esched_event_priority);
    void SetIsDynamicProxyControlled(const bool is_dynamic_proxy_controlled);
    void SetScope(const std::string &scope);
    const std::string &GetScope() const;
    bool IsInvokedNN() const;
    void SetEnableExceptionCatch(bool enable_exception_catch);
    bool IsEnableExceptionCatch() const;

    /// Accessors for the serialized model data and root graph held by this handle.
    Status GetModelData(ModelData &model_data);
    void SetModelData(const ModelData &model_data);
    Status GetRootGraph(ComputeGraphPtr &root_graph);
    void SetRootGraph(const ComputeGraphPtr &root_graph);

    /**
     * @brief Collects identifiers of the loaded model(s).
     * @param davinci_model_runtime_ids [out] runtime ids (presumably of statically loaded models).
     * @param dynamic_model_handles     [out] non-owning pointers to handles (presumably of dynamic models).
     */
    virtual Status GetModelRuntimeIdOrHandle(std::vector<uint32_t> &davinci_model_runtime_ids,
                                             std::vector<ExecutorContext::ModelHandle *> &dynamic_model_handles);
    virtual Status ClearModel(const int32_t clear_type);
    virtual Status ExceptionNotify(uint32_t type, uint64_t trans_id);

   protected:
    /// Hook performing the actual load of @p model_data / @p root_graph.
    virtual Status DoLoadModel(const ModelData &model_data,
                               const ComputeGraphPtr &root_graph,
                               const LoadParam &params);
    /// Queue-based variant of DoLoadModel().
    virtual Status DoLoadModelWithQ(const ModelData &model_data,
                                    const ComputeGraphPtr &root_graph,
                                    const LoadParam &params);
    /// Hook that unloads the model identified by @p model_id.
    virtual Status DoUnloadModel(uint32_t model_id);
    /// Factory hooks for the executors used by this handle.
    virtual std::unique_ptr<DynamicModelExecutor> CreateDynamicModelExecutor(bool is_host);
    virtual std::unique_ptr<ProxyDynamicModelExecutor> CreateProxyDynamicModelExecutor();

   private:
    static Status CheckAicpuAlignTask(const InputAlignAttrs &input_align_attrs);

    uint32_t inner_model_id_ = UINT32_MAX;  // NOTE(review): UINT32_MAX presumably means "no model loaded"
    std::unique_ptr<DynamicModelExecutor> dynamic_model_executor_;
    bool loaded_ = false;
    int32_t esched_process_priority_ = -1;  // -1 presumably means "not configured"
    int32_t esched_event_priority_ = -1;
    ModelData model_data_;
    bool model_data_from_cache_ = false;
    ComputeGraphPtr root_graph_;
    int32_t execute_times_ = -1;
    bool is_dynamic_proxy_controlled_ = false;
    bool is_invoked_nn_ = false;
    bool enable_exception_catch_ = false;
    std::string scope_;
    aclmdlConfigHandle *handle_ = nullptr;  // NOTE(review): ownership/release not visible here — confirm in ~ModelHandle
  };

  Status Initialize() const;
  void Finalize() const;
  void SetBaseDir(const std::string &base_dir);

  /**
   * @brief Looks up the sub-model handles registered under @p root_model_id.
   * @param submodel_map [out] set to a map owned by this context (not by the caller) —
   *        NOTE(review): presumably points into model_handles_; confirm.
   */
  virtual Status GetModel(uint32_t root_model_id, std::map<uint32_t, std::unique_ptr<ModelHandle>> *&submodel_map);
  Status SyncSharedVarManager(const deployer::ExecutorRequest &request) const;
  Status ParseModel(const deployer::ExecutorRequest_LoadModelRequest &request);
  static Status AttachQueues(const deployer::ExecutorRequest_LoadModelRequest &request);
  Status LoadModel(const deployer::ExecutorRequest_LoadModelRequest &request);

  /// Cache of PneModel objects, keyed by root model id and model id.
  PneModelPtr GetLocalModel(uint32_t root_model_id, uint32_t model_id);
  void AddLocalModel(uint32_t root_model_id, uint32_t model_id, const PneModelPtr &model);
  void RemoveLocalModel(uint32_t root_model_id);

  static Status SetOpTimeout();
  static Status SetDeviceSatMode();
  static void UpdateOptions(const deployer::Options &options);
  static Status UpdateProfInfo(const deployer::ExecutorRequest &request);

 protected:
  /// Opens @p path for reading; virtual so the stream source can be overridden.
  virtual std::unique_ptr<std::istream> CreateInputStream(const std::string &path) const;
  /// Returns the handle for (@p root_model_id, @p model_id), creating it if absent.
  virtual ModelHandle *GetOrCreateModelHandle(uint32_t root_model_id, uint32_t model_id);

 private:
  Status ParseModel(uint32_t root_model_id, uint32_t model_id, const std::string &model_path);
  Status ParseModelEschedPriority(const deployer::ExecutorRequest_LoadModelRequest &request, ModelHandle &handle) const;
  static Status ParseInputAlignAttrs(const deployer::ExecutorRequest_LoadModelRequest &request,
                                     InputAlignAttrs &input_align_attrs);
  static void UpdateGraphOptions(const std::string &key, const std::string &value);

  std::mutex mu_;  // NOTE(review): presumably guards model_handles_ / models_ — confirm in the .cpp
  // root_model_id -> model_id -> handle
  std::map<uint32_t, std::map<uint32_t, std::unique_ptr<ModelHandle>>> model_handles_;
  // root_model_id -> model_id -> cached PneModel
  std::map<uint32_t, std::map<uint32_t, PneModelPtr>> models_;
  std::string base_dir_;
};
}  // namespace ge
#endif