/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_
#define GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "common/blocking_queue.h"
#include "framework/common/ge_types.h"
#include "framework/common/framework_types_internal.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/compute_graph.h"
#include "common/context/local_context.h"
#include "common/context/ome_context.h"
#include "graph/graph.h"
#include "graph/model.h"
#include "common/model/ge_model.h"
#include "common/model/ge_root_model.h"
#include "register/register_fmk_types.h"
#include "ge/ge_api_types.h"
#include "ge/ge_graph_compile_summary.h"
#include "graph/types.h"
#include "framework/common/util.h"
#include "ge/ge_allocator.h"
#include "graph/ge_tensor.h"
namespace ge {
// Identifier type for graphs: a 32-bit unsigned integer.
using GraphId = uint32_t;
// Shared pointer to a const ge::Graph (the pointee cannot be modified through it).
using ConstGraphPtr = std::shared_ptr<const ge::Graph>;
// Sentinel session id: every bit of the 64-bit value is set.
constexpr uint64_t INVALID_SESSION_ID{0xFFFF'FFFF'FFFF'FFFFULL};
// Sentinel graph id: every bit of the 32-bit value is set.
// GraphNode uses it as the "no origin graph" marker (see GraphNode::IsForkedGraph).
constexpr uint32_t INVALID_GRAPH_ID{0xFFFF'FFFFU};
// Initial value of GraphNode::max_load_record_.
constexpr uint32_t kMaxLoadNum{8U};
// Holder for a single model identifier; a fresh instance carries INVALID_MODEL_ID.
struct ModelIdInfo {
  uint32_t model_id = INVALID_MODEL_ID;
};
class SubGraphInfo {
public:
SubGraphInfo();
~SubGraphInfo();
void SetSubGraph(const ComputeGraphPtr &sub_graph_ptr) { subgraph_ptr_ = sub_graph_ptr; }
ComputeGraphPtr GetSubGraph() const { return subgraph_ptr_; }
void SetEngineName(const std::string &engine_name) { engine_name_ = engine_name; }
const std::string &GetEngineName() const { return engine_name_; }
void SetInputFlag(const vector_bit_t &input_flag) { input_flag_ = input_flag; }
void SetOutputFlag(const vector_bit_t &output_flag) { output_flag_ = output_flag; }
void SetOutputContext(const std::string &output) { output_names_ = output; }
void SetStreamLabel(const std::string &stream_label) { stream_label_ = stream_label; }
const std::string &GetStreamLabel() const { return stream_label_; }
void SetUserStreamLabel(const std::string &user_stream_label) { user_stream_label_ = user_stream_label; }
const std::string &GetUserStreamLabel() const { return user_stream_label_; }
void SetEnd2PldMap(const std::unordered_map<ge::NodePtr, ge::NodePtr> &end_map) { end_to_pld_ = end_map; }
const std::unordered_map<ge::NodePtr, ge::NodePtr> &GetEnd2PldMap() const { return end_to_pld_; }
void SetPld2EndMap(const std::unordered_map<ge::NodePtr, ge::NodePtr> &pld_map) { pld_to_end_ = pld_map; }
const std::unordered_map<ge::NodePtr, ge::NodePtr> &GetPld2EndMap() const { return pld_to_end_; }
private:
ComputeGraphPtr subgraph_ptr_;
std::string engine_name_;
vector_bit_t input_flag_;
vector_bit_t output_flag_;
ModelIdInfo model_id_info_;
GeModelPtr ge_model_ptr_;
std::string output_names_;
std::string stream_label_;
std::string user_stream_label_;
std::unordered_map<ge::NodePtr, ge::NodePtr> end_to_pld_;
std::unordered_map<ge::NodePtr, ge::NodePtr> pld_to_end_;
};
// Shared-ownership handle to a SubGraphInfo.
using SubGraphInfoPtr = std::shared_ptr<ge::SubGraphInfo>;
// Maps a compute graph to a list of SubGraphInfo handles.
using Graph2SubGraphInfoList = std::unordered_map<ComputeGraphPtr, std::vector<SubGraphInfoPtr>>;
// Maps a compute graph to a single SubGraphInfo handle
// (presumably the sub-graph holding its input nodes, going by the alias name — confirm against users).
using Graph2InputNodesSubGraphInfo = std::unordered_map<ComputeGraphPtr, SubGraphInfoPtr>;
// ModelListener specialisation that owns two blocking queues of callbacks,
// one for RunAsyncCallback and one for RunAsyncCallbackV2.
// SetCallback and OnComputeDone are defined out of line.
class RunAsyncListener : public ModelListener {
 public:
  // Both queues are built with the argument 32U
  // (presumably the queue capacity — confirm against BlockingQueue's constructor).
  RunAsyncListener()
      : ModelListener(),
        sem_(32U),
        sem_v2_(32U) {}

  ~RunAsyncListener() override = default;

  void SetCallback(const RunAsyncCallbackV2 &callback) override;

  Status OnComputeDone(const uint32_t model_id,
                       const uint32_t data_index,
                       const uint32_t result_code,
                       std::vector<gert::Tensor> &outputs) override;

 private:
  BlockingQueue<RunAsyncCallback> sem_;
  BlockingQueue<RunAsyncCallbackV2> sem_v2_;
};
// Which graph execution entry point a graph is run with.
// kRunGraphModeEnd is the last enumerator; GraphNode uses it as the initial
// value of its run_graph_mode_ member.
enum class RunGraphMode : uint32_t {
  kRunGraph = 0U,
  kRunGraphAsync = 1U,
  kRunGraphWithStreamAsync = 2U,
  kRunGraphModeEnd = 3U
};
// Returns a C string for the given RunGraphMode. Defined out of line; the exact
// strings and the handling of out-of-range values are not visible in this header.
const char_t *GetRunGraphModeStr(RunGraphMode mode);
// (base address, size in bytes) pair describing a memory region;
// used by GraphNode for its const / feature / refreshable-feature memory bases.
using InputMemoryBaseInfo = std::pair<const void *, size_t>;
class GraphNode {
public:
explicit GraphNode(const GraphId graph_id);
~GraphNode();
GraphId GetGraphId() const { return graph_id_; }
ConstGraphPtr GetGraph() const { return graph_; }
void SetGraph(const GraphPtr &graph) { graph_ = graph; }
ComputeGraphPtr GetComputeGraph() const { return compute_graph_; }
void SetComputeGraph(const ComputeGraphPtr &compute_graph) { compute_graph_ = compute_graph; }
bool GetRunFlag() const { return run_flag_; }
void SetRunFlag(const bool flag) { run_flag_ = flag; }
void SetOmeContext(const OmeContext &context) { context_ = context; }
const OmeContext &GetOmeContext() const { return context_; }
bool IsAsync() const { return async_; }
void SetAsync(const bool flag) { async_ = flag; }
bool GetCompiledFlag() const { return compiled_flag_; }
void SetCompiledFlag(const bool flag) { compiled_flag_ = flag; }
bool GetBuildFlag() const { return build_flag_; }
void SetBuildFlag(const bool buildFlag) { build_flag_ = buildFlag; }
bool GetLoadFlag() const { return load_flag_; }
RunGraphMode GetRunGraphMode() const { return run_graph_mode_; }
void SetRunGraphMode(RunGraphMode mode) { run_graph_mode_ = mode; }
void UpdateLoadFlag() { load_flag_ = ((load_count_ == 0U) || (load_record_ >= max_load_record_)); }
void SetLoadFlag(const bool load_flag) { load_flag_ = load_flag; }
void SetIsSpecificStream(const bool specific_stream) { is_specific_stream_ = specific_stream; }
bool IsSpecificStream() const { return is_specific_stream_; }
void SetGeRootModel(const GeRootModelPtr &ge_root_model) { ge_root_model_ = ge_root_model; }
GeRootModelPtr GetGeRootModel() const { return ge_root_model_; }
const std::map<std::string, std::string>& GetOptions() const { return options_; }
std::map<std::string, std::string>& GetMutableOptions() { return options_; }
void SetOptions(const std::map<std::string, std::string> &options) { options_ = options; }
void Lock();
void Unlock();
void SetSemSize(const uint32_t size) { sem_.SetMaxSize(size); }
void SetLoadCount(const uint32_t count) { load_count_ = count; }
uint32_t GetLoadRecord() const { return load_record_; }
void SetLoadRecord(const uint32_t record) { load_record_ = record; }
void IncreaseLoadCount();
void SetLoaded();
void SetFeatureBaseRefreshable(const bool refreshable) { is_feature_base_refreshable_ = refreshable; }
bool IsFeatureBaseRefreshable() const { return is_feature_base_refreshable_; }
void SetConstMemoryBase(const void * const memory, const size_t size) {
const_mem_ = std::make_pair(memory, size);
}
const InputMemoryBaseInfo &GetConstMemoryBase() const { return const_mem_; }
void SetFeatureMemoryBase(const void * const memory, const size_t size) {
feature_mem_ = std::make_pair(memory, size);
}
const InputMemoryBaseInfo &GetFeatureMemoryBase() const { return feature_mem_; }
void SetRefreshableFeatureMemoryBase(const void * const memory, const size_t size) {
refreshable_feature_mem_ = std::make_pair(memory, size);
}
const InputMemoryBaseInfo &GetRefreshableFeatureMemoryBase() const { return refreshable_feature_mem_; }
CompiledGraphSummaryPtr GetCompiledGraphSummary() const { return compiled_summary_; }
void SaveCompiledGraphSummary(const CompiledGraphSummaryPtr &summary) { compiled_summary_ = summary; }
void SetAppRefreshConstMemoryFlag() {
app_refresh_const_memory_flag_ = true;
}
bool IsAppRefreshConstMemory() const {
return app_refresh_const_memory_flag_;
}
void SetAppRefreshFeatureMemoryFlag() {
app_refresh_feature_memory_flag_ = true;
}
bool IsAppRefreshFeatureMemory() const {
return app_refresh_feature_memory_flag_;
}
void SetConstMemBlock(MemBlock *const const_mem_block) {
const_mem_block_ = const_mem_block;
}
void SetFeatureMemBlock(MemBlock *const feature_mem_block) {
feature_mem_block_ = feature_mem_block;
}
MemBlock *GetConstMemBlock() {
return const_mem_block_;
}
MemBlock *GetFeatureMemBlock() {
return feature_mem_block_;
}
void SetNetOutputNode(const NodePtr &net_output_node) { net_output_node_ = net_output_node; }
NodePtr GetNetOutputNode() const { return net_output_node_; }
void SetTensorSize(size_t tensor_size) {
(void)tensor_sizes_.emplace_back(tensor_size);
}
const std::vector<size_t> &GetTensorSize() const {
return tensor_sizes_;
}
void SetGeTensorDescPtr(GeTensorDescPtr &ge_tensor_desc) {
(void)ge_tensor_descs_.emplace_back(ge_tensor_desc);
}
const std::vector<GeTensorDescPtr> &GetGeTensorDescPtr() const {
return ge_tensor_descs_;
}
bool IsSavedNetOutputTensorInfoFlag() const {
return is_saved_net_output_tensor_info_flag_;
}
void SetSavedNetOutputTensorInfoFlag(const bool is_saved_net_output_tensor_info_flag) {
is_saved_net_output_tensor_info_flag_ = is_saved_net_output_tensor_info_flag;
}
std::unordered_set<uint32_t> &GetFrozenInputIndex() {
return frozen_input_indexes_;
}
const std::map<uint32_t, std::pair<uint64_t, uint64_t>> &GetFrozenInputInfo() const {
return frozen_index_to_node_info_;
}
Status ParseFrozenInputIndex();
void SetOriginGraphId(uint32_t graph_id) {
origin_graph_id_ = graph_id;
}
uint32_t GetOriginGraphId() const {
return origin_graph_id_;
}
bool IsForkedGraph() const {
return origin_graph_id_ != INVALID_GRAPH_ID;
}
void SetCommunicationNodes(const std::map<std::string, std::vector<std::string>> &group_2_comm_nodes) {
group_2_communication_nodes_ = group_2_comm_nodes;
}
std::map<std::string, std::vector<std::string>> GetCommunicationNodes() const {
return group_2_communication_nodes_;
}
std::mutex &GetRunMutex(){
return run_mutex_;
}
std::shared_ptr<GraphNode> Fork(uint32_t forked_graph_id);
private:
GraphId graph_id_;
GraphId origin_graph_id_{INVALID_GRAPH_ID};
std::map<std::string, std::string> options_;
bool run_flag_{false};
std::vector<SubGraphInfoPtr> subgraph_ptr_list_;
OmeContext context_;
GraphPtr graph_;
ComputeGraphPtr compute_graph_;
bool compiled_flag_{false};
bool build_flag_{false};
bool load_flag_{false};
* GeSession restricts usage to only one of the graph execution methods:
* RunGraph, RunGraphAsync, or RunGraphWithStreamAsync.
* These methods are mutually exclusive and must not be mixed.
*/
RunGraphMode run_graph_mode_{RunGraphMode::kRunGraphModeEnd};
bool async_{false};
bool is_specific_stream_{false};
GeModelPtr ge_model_;
GeRootModelPtr ge_root_model_;
BlockingQueue<uint8_t> sem_;
uint32_t load_count_{0U};
uint32_t load_record_{0U};
uint32_t max_load_record_{kMaxLoadNum};
std::mutex load_count_mu_;
bool is_feature_base_refreshable_ = false;
InputMemoryBaseInfo const_mem_;
InputMemoryBaseInfo feature_mem_;
InputMemoryBaseInfo refreshable_feature_mem_;
CompiledGraphSummaryPtr compiled_summary_{nullptr};
bool app_refresh_const_memory_flag_{false};
bool app_refresh_feature_memory_flag_{false};
MemBlock *const_mem_block_{nullptr};
MemBlock *feature_mem_block_{nullptr};
NodePtr net_output_node_;
std::vector<size_t> tensor_sizes_;
bool is_saved_net_output_tensor_info_flag_{false};
std::vector<GeTensorDescPtr> ge_tensor_descs_;
std::unordered_set<uint32_t> frozen_input_indexes_;
std::map<uint32_t, std::pair<uint64_t, uint64_t>> frozen_index_to_node_info_;
std::map<std::string, std::vector<std::string>> group_2_communication_nodes_;
std::mutex run_mutex_;
};
// Shared-ownership handle to a GraphNode (the same type GraphNode::Fork returns).
using GraphNodePtr = std::shared_ptr<GraphNode>;
// ModelListener specialisation that keeps a result code and a "finished" flag
// next to a mutex / condition-variable pair. All non-defaulted members are
// defined out of line.
// NOTE(review): presumably OnComputeDone publishes the result and waiters block
// on condition_ — confirm against the implementation file.
class GraphModelListener : public ge::ModelListener {
 public:
  GraphModelListener();
  ~GraphModelListener() override = default;

  Status OnComputeDone(const uint32_t model_id,
                       const uint32_t data_index,
                       const uint32_t result_code,
                       std::vector<gert::Tensor> &outputs) override;

  uint32_t GetResultCode() override;

  Status ResetResult() override;

 private:
  uint32_t result_code_{0U};
  bool is_finished_{false};
  std::mutex mutex_;
  std::condition_variable condition_;
};
// Bundle of graph-manager options. Every scalar field carries an in-class
// default; string and container fields without an initializer start out empty.
// The meaning of individual options is defined by the code that reads them and
// is not visible in this header.
struct GraphManagerOptions {
  // User-provided empty constructor: all defaults come from the member
  // initializers below.
  GraphManagerOptions() {}

  // Numeric settings.
  int32_t stream_num{1};
  int32_t perf_level{domi::GEN_TASK_WITHOUT_FUSION};
  int32_t encrypt_mode{-1};
  int32_t framework_type{domi::TENSORFLOW};

  // String settings (empty by default).
  std::string ek_file;
  std::string cert_file;
  std::string hw_key_file;
  std::string private_key_file;
  std::string calibration_conf_file;
  std::string insert_op_file;
  std::string input_format;
  std::string output_node_name;
  std::string func_bin_path;
  std::string input_nodes_set_fp16;
  std::string core_type;

  // Boolean switches (all off by default).
  bool compress_flag{false};
  bool run_graph_flag{false};
  bool train_graph_flag{false};
  bool local_fmk_op_flag{false};
  bool hcom_parallel{false};
  bool enable_print_op_pass{false};
  bool is_single_op{false};

  std::string dynamic_image_size;
  std::map<std::string, int32_t> stream_max_parallel_num;
  std::string output_datatype;
  std::string original_model_file;
  // Stored as text rather than bool; default is the literal "false".
  std::string save_original_model{"false"};
  std::string build_mode;
  std::string build_step;
  std::string tuning_path;
  std::string input_shape;
  std::string dynamic_dims;
  int32_t dynamic_node_type{-1};
  std::set<std::string> exclude_engines;
  // Stored as text rather than bool; default is the literal "true".
  std::string build_inner_model{"true"};
  std::string event{"event"};
};
}
#endif