* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef GE_MODEL_GE_ROOT_MODEL_H_
#define GE_MODEL_GE_ROOT_MODEL_H_
#include <mutex>
#include <map>
#include <sstream>
#include "framework/common/ge_types.h"
#include "ge/ge_graph_compile_summary.h"
#include "ge/ge_allocator.h"
#include "graph/compute_graph.h"
#include "common/model/ge_model.h"
#include "common/op_so_store/op_so_store.h"
#include "common/memory/mem_type_utils.h"
#include "common/memory/feature_memory_impl.h"
#include "common/host_resource_center/host_resource_center.h"
#include "ge/ge_ir_build.h"
namespace ge {
// Forward declaration: this header only refers to PortableOp through a pointer parameter
// (see GeRootModel::ResolvePortableOpSoPath), so the full definition is not required here.
class PortableOp;
/// Plain record describing one fixed feature-memory entry.
///
/// The struct declares no constructors and no default member initializers, so a
/// default-initialized instance (`FixedFeatureMemory m;`) has indeterminate field values;
/// value-initialize it (`FixedFeatureMemory m{};`) when zeroed fields are required.
struct FixedFeatureMemory {
  rtMemType_t type;           // printed via MemTypeUtils::ToString() as "rts memory type"
  void *addr;                 // printed as "addr"
  size_t size;                // printed as "size"
  bool user_alloc;            // presumably: buffer was supplied by the user - confirm with callers
  bool ge_alloc;              // presumably: buffer was allocated by GE itself - confirm with callers
  bool is_session_allocator;  // printed as "session_allocator"
  uint64_t session_id;        // printed as "session_id"
  MemBlock *block;            // printed as "block"

  /// Renders every field into a single human-readable line.
  /// @return the formatted description; numeric fields are written in decimal.
  std::string ToString() const {
    std::stringstream text;
    text << "rts memory type: " << MemTypeUtils::ToString(type);
    text << ", addr: " << std::hex << addr;
    // Switch back to decimal before the size so the hex request above does not leak into it.
    text << ", size: " << std::dec << size;
    text << ", user_alloc: " << user_alloc;
    text << ", ge_alloc: " << ge_alloc;
    text << ", block: " << std::hex << block;
    text << ", session_allocator: " << std::dec << is_session_allocator;
    text << ", session_id: " << session_id;
    return text.str();
  }
};
/// Aggregates a root ComputeGraph together with the GeModel of each subgraph instance and
/// the metadata kept alongside them (model ids, weight size, op .so store, fixed feature
/// memory, host resource center, ...).
///
/// Thread-safety: `model_ids_mutex_` guards `model_ids_` and `flatten_graph_` for every
/// accessor that takes or returns them by value, and guards the setters of
/// `nodes_to_task_defs_` / `graph_to_static_models_`. Accessors that hand out a reference
/// (GetNodesToTaskDef, GetGraphToStaticModels, ...) cannot be protected by the lock; callers
/// must not use them concurrently with the matching setter. All other members are
/// unsynchronized.
///
/// `model_ids_mutex_` is a non-recursive std::mutex: code that already holds it must not
/// call GetAllModelId(), ClearAllModelId(), GetFlattenGraph() or any of the locking setters.
class GeRootModel : public std::enable_shared_from_this<GeRootModel> {
 public:
  GeRootModel() = default;
  ~GeRootModel() = default;

  // Defined out of line.
  Status Initialize(const ComputeGraphPtr &root_graph);
  Status ModifyOwnerGraphForSubModels();

  // Subgraph-instance-name -> GeModel mapping. The mutators are defined out of line.
  void SetSubgraphInstanceNameToModel(const std::string &instance_name, const GeModelPtr &ge_model);
  void RemoveInstanceSubgraphModel(const std::string &instance_name);
  const std::map<std::string, GeModelPtr> &GetSubgraphInstanceNameToModel() const {
    return subgraph_instance_name_to_model_;
  }

  /// Stores `model_id` as the current model id and appends it to the list returned by
  /// GetAllModelId().
  void SetModelId(uint32_t model_id) {
    model_id_ = model_id;
    const std::lock_guard<std::mutex> lock(model_ids_mutex_);
    (void)model_ids_.emplace_back(model_id);
  }
  uint32_t GetModelId() const { return model_id_; }

  void SetIsSpecificStream(const bool is_specific_stream) { is_specific_stream_ = is_specific_stream; }
  bool IsSpecificStream() const { return is_specific_stream_; }

  /// Returns a snapshot of every id passed to SetModelId() since the last ClearAllModelId().
  /// Takes the same lock as SetModelId() so the copy cannot race with a concurrent append.
  std::vector<uint32_t> GetAllModelId() const {
    const std::lock_guard<std::mutex> lock(model_ids_mutex_);
    return model_ids_;
  }
  /// Empties the id list under the same lock used by SetModelId().
  void ClearAllModelId() {
    const std::lock_guard<std::mutex> lock(model_ids_mutex_);
    model_ids_.clear();
  }

  // Defined out of line; reports its result through `is_dynamic_shape`.
  Status CheckIsUnknownShape(bool &is_dynamic_shape) const;

  void SetWeightSize(const int64_t weight_size) { total_weight_size_ = weight_size; }
  int64_t GetWeightSize() const { return total_weight_size_; }

  void SetFlattenGraph(const ComputeGraphPtr &flatten_graph) {
    const std::lock_guard<std::mutex> lock(model_ids_mutex_);
    flatten_graph_ = flatten_graph;
  }
  /// Returns a copy of the flatten-graph pointer; locked so the copy cannot race with
  /// SetFlattenGraph().
  ComputeGraphPtr GetFlattenGraph() const {
    const std::lock_guard<std::mutex> lock(model_ids_mutex_);
    return flatten_graph_;
  }

  void SetNodesToTaskDef(const std::unordered_map<ge::NodePtr, std::vector<domi::TaskDef>> &nodes_2_task_def) {
    const std::lock_guard<std::mutex> lock(model_ids_mutex_);
    nodes_to_task_defs_ = nodes_2_task_def;
  }
  // Returns a reference, so no lock is (or can usefully be) held once the call returns.
  const std::unordered_map<ge::NodePtr, std::vector<domi::TaskDef>> &GetNodesToTaskDef() const {
    return nodes_to_task_defs_;
  }

  void SetGraphToStaticModels(const std::unordered_map<std::string, ge::GeModelPtr> &graph_2_static_models) {
    const std::lock_guard<std::mutex> lock(model_ids_mutex_);
    graph_to_static_models_ = graph_2_static_models;
  }
  // Returns a reference, so no lock is (or can usefully be) held once the call returns.
  const std::unordered_map<std::string, ge::GeModelPtr> &GetGraphToStaticModels() const {
    return graph_to_static_models_;
  }

  // Op .so store / "so in om" accessors, all defined out of line.
  const uint8_t *GetOpSoStoreData() const;
  size_t GetOpStoreDataSize() const;
  bool LoadSoBinData(const uint8_t *const data, const size_t len);
  std::vector<OpSoBinPtr> GetAllSoBin() const;
  Status CheckAndSetNeedSoInOM();
  uint16_t GetSoInOmFlag() const;
  void SetSoInOmInfo(const SoInOmInfo &so_info);
  SoInOmInfo GetSoInOmInfo() const;

  void SetFileConstantWeightDir(const std::string &file_constant_weight_dir) {
    file_constant_weight_dir_ = file_constant_weight_dir;
  }
  // Returns by value (a copy); the signature is kept as-is for existing callers.
  const std::string GetFileConstantWeightDir() const {
    return file_constant_weight_dir_;
  }

  uint32_t GetCurModelId() const { return cur_model_id_; }
  void SetCurModelId(uint32_t model_id) { cur_model_id_ = model_id; }

  // Fixed feature memory, keyed by rts memory type.
  const std::map<rtMemType_t, FixedFeatureMemory> &GetFixedFeatureMemory() const {
    return fixed_feature_mems_;
  }
  std::map<rtMemType_t, FixedFeatureMemory> &MutableFixedFeatureMemory() {
    return fixed_feature_mems_;
  }

  // Defined out of line.
  Status GetSummaryFeatureMemory(std::vector<FeatureMemoryPtr> &all_feature_memory,
                                 size_t &hbm_fixed_feature_mem);
  bool IsNeedMallocFixedFeatureMem() const;
  bool IsNeedMallocFixedFeatureMemByType(const rtMemType_t rt_mem_type) const;
  HostResourceCenterPtr GetHostResourceCenterPtr() const;

  const std::unordered_set<std::string> &GetOpMasterDeviceSoSet() const {
    return op_master_device_so_set_;
  }
  const std::unordered_set<std::string> &GetAutofuseSoSet() const {
    return autofuse_so_set_;
  }
  const std::unordered_set<std::string> &GetCustomOpSoSet() const {
    return custom_op_so_set_;
  }

  // Defined out of line.
  std::shared_ptr<GeRootModel> Fork();

  inline void SetRootGraph(const ComputeGraphPtr &graph) {
    root_graph_ = graph;
  }
  inline const ComputeGraphPtr &GetRootGraph() const {
    return root_graph_;
  }
  inline void SetModelName(const std::string &model_name) {
    model_name_ = model_name;
  }
  inline const std::string &GetModelName() const {
    return model_name_;
  }

 private:
  // Internal helpers, all defined out of line; see the implementation file for their behavior.
  Status CheckAndSetSpaceRegistry();
  Status CheckAndSetOpMasterDevice();
  Status CheckAndSetAutofuseSo();
  Status CheckAndSetCustomOpSo();
  Status GetTargetHostEnv(std::string &host_env_os, std::string &host_env_cpu) const;
  bool IsCrossCompileTarget(const std::string &target_os, const std::string &target_cpu) const;
  Status ResolvePortableOpSoPath(const std::string &op_type, PortableOp *portable_op, std::string &so_path) const;
  Status CheckSoArchMatchesTarget(const std::string &so_path, const std::string &target_cpu) const;
  Status CollectCustomOpSoFromCustomOppPath(const std::string &target_os, const std::string &target_cpu);

  ComputeGraphPtr root_graph_ = nullptr;
  std::map<std::string, GeModelPtr> subgraph_instance_name_to_model_;
  std::vector<uint32_t> model_ids_;
  // `mutable` so that const accessors (GetAllModelId, GetFlattenGraph) can take the lock.
  mutable std::mutex model_ids_mutex_;
  bool is_specific_stream_ = false;
  int64_t total_weight_size_ = 0;
  std::unordered_map<ge::NodePtr, std::vector<domi::TaskDef>> nodes_to_task_defs_;
  std::unordered_map<std::string, ge::GeModelPtr> graph_to_static_models_;
  ComputeGraphPtr flatten_graph_ = nullptr;
  OpSoStore op_so_store_;
  uint16_t so_in_om_ = 0U;
  SoInOmInfo so_info_ = {};
  std::string file_constant_weight_dir_;
  std::string model_name_;
  uint32_t model_id_ = INVALID_MODEL_ID;
  uint32_t cur_model_id_ = 0U;
  bool all_feature_memory_init_flag_ = false;
  std::vector<FeatureMemoryPtr> all_feature_memory_;
  std::map<rtMemType_t, FixedFeatureMemory> fixed_feature_mems_;
  HostResourceCenterPtr host_resource_center_ = ge::MakeShared<HostResourceCenter>();
  std::unordered_set<std::string> op_master_device_so_set_{};
  std::unordered_set<std::string> autofuse_so_set_{};
  std::unordered_set<std::string> custom_op_so_set_{};
};
// Shared-ownership handle to a GeRootModel.
using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>;
}
#endif