* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "node_compile_cache_module.h"
#include <securec.h>
#include "ge/ge_api_error_codes.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "graph/cache_policy/policy_register.h"
#include "graph/operator_factory.h"
#include "graph/utils/op_desc_utils.h"
#include "common/math/math_util.h"
#include "graph/compute_graph.h"
#include "graph/utils/node_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "common/plugin/ge_make_unique_util.h"
#include "base/err_msg.h"
namespace {
constexpr ge::char_t const *kAttrSupportDynamicShape = "support_dynamicshape";
template<typename T>
size_t AttrValueSizeByType(const ge::AnyValue &attr_value) {
(void)attr_value;
return sizeof(T);
}
template<>
size_t AttrValueSizeByType<std::string>(const ge::AnyValue &attr_value) {
std::string val;
(void)attr_value.GetValue<std::string>(val);
return val.length();
}
template<typename T>
size_t ListAttrValueSizeByType(const ge::AnyValue &attr_value)
{
std::vector<T> values;
(void)attr_value.GetValue<std::vector<T>>(values);
return (sizeof(T) * values.size());
}
template<>
size_t ListAttrValueSizeByType<std::string>(const ge::AnyValue &attr_value)
{
std::vector<std::string> values;
(void)attr_value.GetValue<std::vector<std::string>>(values);
size_t str_size = 0U;
for (const auto &val : values) {
str_size += val.length();
}
return str_size;
}
template<typename T>
size_t ListListAttrValueSizeByType(const ge::AnyValue &attr_value)
{
std::vector<std::vector<T>> values;
(void)attr_value.GetValue<std::vector<std::vector<T>>>(values);
size_t cnt = 0U;
for (const auto &vals : values) {
cnt += vals.size();
}
return (sizeof(T) * cnt);
}
template<typename T>
ge::Status CopyAttrValueSizeByType(const ge::AnyValue &attr_value,
uint8_t *base, const size_t max_size, size_t &offset) {
T val;
(void)attr_value.GetValue<T>(val);
const auto mem_ret = memcpy_s((base + offset), (max_size - offset), &val, sizeof(T));
if (mem_ret != EOK) {
GELOGE(ge::FAILED, "memcpy failed.");
return ge::FAILED;
}
offset += sizeof(T);
return ge::SUCCESS;
}
template<>
ge::Status CopyAttrValueSizeByType<std::string>(const ge::AnyValue &attr_value,
uint8_t *base, const size_t max_size, size_t &offset) {
std::string val;
(void)attr_value.GetValue<std::string>(val);
if (val.empty()) {
return ge::SUCCESS;
}
const auto mem_ret = memcpy_s((base + offset), (max_size - offset), val.data(), val.length());
if (mem_ret != EOK) {
GELOGE(ge::FAILED, "memcpy failed.");
return ge::FAILED;
}
offset += val.length();
return ge::SUCCESS;
}
template<typename T>
ge::Status CopyListAttrValueSizeByType(const ge::AnyValue &attr_value,
uint8_t *base, const size_t max_size, size_t &offset)
{
std::vector<T> values;
(void)attr_value.GetValue<std::vector<T>>(values);
for (const auto &val : values) {
T tmp_val = val;
const auto mem_ret = memcpy_s((base + offset), (max_size - offset), &tmp_val, sizeof(T));
if (mem_ret != EOK) {
GELOGE(ge::FAILED, "memcpy failed.");
return ge::FAILED;
}
offset += sizeof(T);
}
return ge::SUCCESS;
}
template<>
ge::Status CopyListAttrValueSizeByType<std::string>(const ge::AnyValue &attr_value,
uint8_t *base, const size_t max_size, size_t &offset)
{
std::vector<std::string> values;
(void)attr_value.GetValue<std::vector<std::string>>(values);
for (const auto &val : values) {
if (val.empty()) {
continue;
}
const auto mem_ret = memcpy_s((base + offset), (max_size - offset), val.data(), val.length());
if (mem_ret != EOK) {
GELOGE(ge::FAILED, "memcpy failed.");
return ge::FAILED;
}
offset += val.length();
}
return ge::SUCCESS;
}
template<typename T>
ge::Status CopyListListAttrValueSizeByType(const ge::AnyValue &attr_value,
uint8_t *base, const size_t max_size, size_t &offset)
{
std::vector<std::vector<T>> values;
(void)attr_value.GetValue<std::vector<std::vector<T>>>(values);
for (const auto &vals : values) {
for (const auto &val : vals) {
const auto mem_ret = memcpy_s((base + offset), (max_size - offset), &val, sizeof(T));
if (mem_ret != EOK) {
GELOGE(ge::FAILED, "memcpy failed.");
return ge::FAILED;
}
offset += sizeof(T);
}
}
return ge::SUCCESS;
}
}
namespace ge {
Status NodeCompileCacheItem::Build(const KernelLaunchBinType bin_type, const NodePtr &node, void *handle,
NodeCompileCacheItem &item) {
item.bin_type_ = bin_type;
item.handle_ = handle;
const auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if ((!ge::AttrUtils::GetStr(op_desc, COMPILE_INFO_KEY, item.op_compile_info_.key)) ||
item.op_compile_info_.key.empty()) {
GELOGW("Op[%s] does not have attr[%s].", op_desc->GetName().c_str(), COMPILE_INFO_KEY.c_str());
}
if ((!ge::AttrUtils::GetStr(op_desc, COMPILE_INFO_JSON, item.op_compile_info_.str)) ||
item.op_compile_info_.str.empty()) {
GELOGW("Op[%s] does not have attr[%s].", op_desc->GetName().c_str(), COMPILE_INFO_JSON.c_str());
}
if ((!ge::AttrUtils::GetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, item.atomic_op_compile_info_.key)) ||
item.atomic_op_compile_info_.key.empty()) {
GELOGW("Op[%s] does not have attr[%s].", op_desc->GetName().c_str(), ATOMIC_COMPILE_INFO_KEY.c_str());
}
if ((!ge::AttrUtils::GetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, item.atomic_op_compile_info_.str)) ||
item.atomic_op_compile_info_.str.empty()) {
GELOGW("Op[%s] does not have attr[%s].", op_desc->GetName().c_str(), ATOMIC_COMPILE_INFO_JSON.c_str());
}
if (!ge::AttrUtils::GetInt(op_desc, kAttrOpParamSize, item.max_tiling_size_)) {
GELOGW("Op[%s] does not have attr[%s].", op_desc->GetName().c_str(), kAttrOpParamSize);
}
if (!ge::AttrUtils::GetInt(op_desc, kAttrAtomicOpParamSize, item.atomic_max_tiling_size_)) {
GELOGW("Op[%s] does not have attr[%s].", op_desc->GetName().c_str(), kAttrAtomicOpParamSize);
}
if (!ge::AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, item.is_dynamic_)) {
GELOGW("Op[%s] does not have attr[%s].", op_desc->GetName().c_str(), kAttrSupportDynamicShape);
}
return SUCCESS;
}
uint64_t NodeCompileCacheItem::GetCacheItemId() const {
return cache_item_id_;
}
void NodeCompileCacheItem::SetCacheItemId(const uint64_t cache_item_id) {
cache_item_id_ = cache_item_id;
}
void *NodeCompileCacheItem::GetBinHandle() const {
return handle_;
}
KernelLaunchBinType NodeCompileCacheItem::GetBinType() const {
return bin_type_;
}
const optiling::OpCompileInfo *NodeCompileCacheItem::GetCompileInfo() const {
return &op_compile_info_;
}
const optiling::OpCompileInfo *NodeCompileCacheItem::GetAtomicCompileInfo() const {
return &atomic_op_compile_info_;
}
int64_t NodeCompileCacheItem::GetMaxTilingSize() const {
return max_tiling_size_;
}
int64_t NodeCompileCacheItem::GetAtomicMaxTilingSize() const {
return atomic_max_tiling_size_;
}
bool NodeCompileCacheItem::IsSupportDynamic() const {
return is_dynamic_;
}
static Status GetTensorInfos(const OpDesc &op_desc, CompileCacheDesc &cache_desc, const bool need_range) {
for (size_t i = 0U; i < op_desc.GetInputsSize(); ++i) {
auto input_desc = op_desc.MutableInputDesc(static_cast<uint32_t>(i));
if (input_desc == nullptr) {
continue;
}
TensorInfoArgs tensor_info_args(input_desc->GetFormat(), input_desc->GetOriginFormat(), input_desc->GetDataType());
const auto &dims = input_desc->MutableShape().GetMutableDims();
tensor_info_args.SetShape(dims);
const auto &origin_dims = input_desc->GetOriginShape().GetMutableDims();
tensor_info_args.SetOriginShape(origin_dims);
if (need_range) {
std::vector<std::pair<int64_t, int64_t>> ranges;
(void)input_desc->GetShapeRange(ranges);
tensor_info_args.SetShapeRange(ranges);
}
cache_desc.AddTensorInfo(tensor_info_args);
}
return SUCCESS;
}
static Status GetOriginAttr(const std::string &op_type, std::set<string> &ordered_origin_attr) {
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_type.c_str());
if (node_op.IsEmpty()) {
GELOGE(FAILED, "get op from OperatorFactory fail. opType: %s", op_type.c_str());
return FAILED;
}
GELOGD("get op from OperatorFactory success. opType is %s", op_type.c_str());
auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
node_op.BreakConnect();
if (temp_op_desc == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "GetOpDescFromOperator failed, return nullptr.");
GELOGE(FAILED, "[Get][OpDesc] temp op desc is null");
return FAILED;
}
const auto &ir_origin_attr_names = temp_op_desc->GetIrAttrNames();
ordered_origin_attr.insert(ir_origin_attr_names.cbegin(), ir_origin_attr_names.cend());
GELOGD("get origin attr name success, size is %zu", ordered_origin_attr.size());
return SUCCESS;
}
size_t NodeCompileCacheModule::GetAttrSize(const AnyValue &attr_value) const {
switch (attr_value.GetValueType()) {
case ge::GeAttrValue::VT_STRING:
return AttrValueSizeByType<std::string>(attr_value);
case ge::GeAttrValue::VT_BOOL:
return AttrValueSizeByType<bool>(attr_value);
case ge::GeAttrValue::VT_INT:
return AttrValueSizeByType<int64_t>(attr_value);
case ge::GeAttrValue::VT_FLOAT:
return AttrValueSizeByType<float32_t>(attr_value);
case ge::GeAttrValue::VT_DATA_TYPE:
return AttrValueSizeByType<ge::GeAttrValue::DATA_TYPE>(attr_value);
case ge::GeAttrValue::VT_LIST_STRING:
return ListAttrValueSizeByType<std::string>(attr_value);
case ge::GeAttrValue::VT_LIST_BOOL:
return ListAttrValueSizeByType<bool>(attr_value);
case ge::GeAttrValue::VT_LIST_INT:
return ListAttrValueSizeByType<int64_t>(attr_value);
case ge::GeAttrValue::VT_LIST_FLOAT:
return ListAttrValueSizeByType<float32_t>(attr_value);
case ge::GeAttrValue::VT_LIST_DATA_TYPE:
return ListAttrValueSizeByType<ge::GeAttrValue::DATA_TYPE>(attr_value);
case ge::GeAttrValue::VT_LIST_LIST_INT:
return ListListAttrValueSizeByType<int64_t>(attr_value);
default:
GELOGD("unsupported type %d", attr_value.GetValueType());
return 0U;
}
}
Status NodeCompileCacheModule::CopyAttrValues(const AnyValue &attr_value,
uint8_t *base,
const size_t max_size,
size_t &offset) const {
switch (attr_value.GetValueType()) {
case ge::GeAttrValue::VT_STRING:
return CopyAttrValueSizeByType<std::string>(attr_value, base, max_size, offset);
case ge::GeAttrValue::VT_BOOL:
return CopyAttrValueSizeByType<bool>(attr_value, base, max_size, offset);
case ge::GeAttrValue::VT_INT:
return CopyAttrValueSizeByType<int64_t>(attr_value, base, max_size, offset);
case ge::GeAttrValue::VT_FLOAT:
return CopyAttrValueSizeByType<float32_t>(attr_value, base, max_size, offset);
case ge::GeAttrValue::VT_DATA_TYPE:
return CopyAttrValueSizeByType<ge::GeAttrValue::DATA_TYPE>(attr_value, base, max_size, offset);
case ge::GeAttrValue::VT_LIST_STRING:
return CopyListAttrValueSizeByType<std::string>(attr_value, base, max_size, offset);
case ge::GeAttrValue::VT_LIST_BOOL:
return CopyListAttrValueSizeByType<bool>(attr_value, base, max_size, offset);
case ge::GeAttrValue::VT_LIST_INT:
return CopyListAttrValueSizeByType<int64_t>(attr_value, base, max_size, offset);
case ge::GeAttrValue::VT_LIST_FLOAT:
return CopyListAttrValueSizeByType<float32_t>(attr_value, base, max_size, offset);
case ge::GeAttrValue::VT_LIST_DATA_TYPE:
return CopyListAttrValueSizeByType<ge::GeAttrValue::DATA_TYPE>(attr_value, base, max_size, offset);
case ge::GeAttrValue::VT_LIST_LIST_INT:
return CopyListListAttrValueSizeByType<int64_t>(attr_value, base, max_size, offset);
default:
GELOGD("unsupported type %d, no need to copy", attr_value.GetValueType());
return SUCCESS;
}
}
Status NodeCompileCacheModule::GetAttrTotalSize(const std::map<std::string, AnyValue> &all_attributes,
const std::set<string> &ordered_origin_attr_name, size_t &attr_size) const {
for (const auto &name : ordered_origin_attr_name) {
GELOGD("current origin attr name is %s", name.c_str());
auto it = all_attributes.find(name);
if (it != all_attributes.end()) {
const AnyValue &attr_value = it->second;
FMK_SIZET_ADDCHECK(attr_size, it->first.length());
attr_size += it->first.length();
const size_t current_attr_size = GetAttrSize(attr_value);
GELOGD("find attr name %s, size is %zu", it->first.c_str(), current_attr_size);
FMK_SIZET_ADDCHECK(attr_size, current_attr_size);
attr_size += current_attr_size;
} else {
GELOGD("cannot get attr name %s", name.c_str());
}
}
return SUCCESS;
}
Status NodeCompileCacheModule::CopyAttrToMem(const std::map<std::string, AnyValue> &all_attributes,
std::unique_ptr<uint8_t[]> &attr_mem,
const std::set<string> &ordered_origin_attr_name,
const size_t attr_size) const {
size_t offset = 0U;
for (const auto &name : ordered_origin_attr_name) {
auto it = all_attributes.find(name);
if (it != all_attributes.end()) {
const AnyValue &attr_value = it->second;
FMK_SIZET_SUBCHECK(attr_size, offset);
const auto mem_ret = memcpy_s((attr_mem.get() + offset), (attr_size - offset),
it->first.data(), it->first.length());
if (mem_ret != EOK) {
GELOGE(FAILED, "memcpy failed.");
return FAILED;
}
FMK_SIZET_ADDCHECK(offset, it->first.length());
offset += it->first.length();
if (CopyAttrValues(attr_value, attr_mem.get(), attr_size, offset) != SUCCESS) {
GELOGE(FAILED, "copy attr mem failed.");
return FAILED;
}
}
}
return SUCCESS;
}
Status NodeCompileCacheModule::GetOpAttrMem(OpDesc &op_desc, CompileCacheDesc &cache_desc) const {
if (op_desc.HasAttr("_origin_attr_value_bytes")) {
Buffer attr_mem;
if (AttrUtils::GetZeroCopyBytes(op_desc, "_origin_attr_value_bytes", attr_mem)) {
GELOGD("this op %s has attr mem", op_desc.GetName().c_str());
BinaryHolder binary_holder(attr_mem.GetData(), attr_mem.GetSize());
cache_desc.AddBinary(binary_holder);
return SUCCESS;
}
}
std::set<string> ordered_origin_attr_name;
GE_CHK_STATUS_RET_NOLOG(GetOriginAttr(op_desc.GetType(), ordered_origin_attr_name));
GELOGD("get current origin attr size is %zu", ordered_origin_attr_name.size());
if (ordered_origin_attr_name.empty()) {
return SUCCESS;
}
size_t attr_size = 0U;
const auto &all_attributes = op_desc.GetAllAttrs();
GELOGD("get current attr size is %zu", all_attributes.size());
GE_CHK_STATUS_RET_NOLOG(GetAttrTotalSize(all_attributes, ordered_origin_attr_name, attr_size));
if (attr_size == 0U) {
return SUCCESS;
}
GELOGD("get total size is %zu", attr_size);
auto attr_mem = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[attr_size]);
GE_CHECK_NOTNULL(attr_mem);
GE_CHK_STATUS_RET_NOLOG(CopyAttrToMem(all_attributes, attr_mem, ordered_origin_attr_name, attr_size));
Buffer attr_buf = Buffer::CopyFrom(attr_mem.get(), attr_size);
(void)AttrUtils::SetZeroCopyBytes(op_desc, "_origin_attr_value_bytes", std::move(attr_buf));
auto binary_holder = BinaryHolder::createFrom(std::move(attr_mem), attr_size);
cache_desc.AddBinary(std::move(*binary_holder));
return SUCCESS;
}
NodeCompileCacheModule::NodeCompileCacheModule()
: ccp_(CachePolicy::Create(MatchPolicyType::MATCH_POLICY_EXACT_ONLY, AgingPolicyType::AGING_POLICY_LRU)) {}
void NodeCompileCacheModule::Initialize() {
if (ccp_ == nullptr) {
ccp_ = CachePolicy::Create(MatchPolicyType::MATCH_POLICY_EXACT_ONLY, AgingPolicyType::AGING_POLICY_LRU);
GELOGD("Initialize ccm.");
}
}
void NodeCompileCacheModule::Finalize() {
ccp_.reset(nullptr);
ids_to_cci_.clear();
GELOGD("Finalize ccm.");
}
Status NodeCompileCacheModule::GetFusionOpCacheDesc(const NodePtr &node, CompileCacheDesc &cache_desc) const {
cache_desc.SetOpType(node->GetName());
auto compute_graph = node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(compute_graph);
const uint32_t graph_id = compute_graph->GetGraphID();
const uint64_t session_id = compute_graph->GetSessionID();
cache_desc.SetScopeId({session_id, graph_id});
return SUCCESS;
}
Status NodeCompileCacheModule::GetInputConstTensor(const NodePtr &node, CompileCacheDesc &cache_desc) const {
for (size_t index = 0U; index < node->GetAllInDataAnchorsSize(); ++index) {
NodePtr input_node = nullptr;
(void)NodeUtils::GetInNodeCrossPartionedCallNode(node, static_cast<uint32_t>(index), input_node);
if ((input_node != nullptr) && (input_node->GetOpDesc() != nullptr)) {
GeTensorPtr const_tensor = nullptr;
(void)AttrUtils::MutableTensor(input_node->GetOpDesc(), ATTR_NAME_WEIGHTS, const_tensor);
if ((const_tensor != nullptr) && (const_tensor->GetData().data() != nullptr) &&
(const_tensor->GetData().size() > 0U)) {
GELOGD("find node %s input node %s which has weight", node->GetName().c_str(), input_node->GetName().c_str());
BinaryHolder holder_const(const_tensor->GetData().data(), const_tensor->GetData().size());
GE_CHECK_NOTNULL(holder_const.GetDataPtr());
cache_desc.AddBinary(std::move(holder_const));
}
}
}
return SUCCESS;
}
Status NodeCompileCacheModule::GetCompileCacheDescFromOp(const NodePtr &node,
std::shared_ptr<CompileCacheDesc> &cache_desc,
const bool need_range) const {
const OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
ge::ComputeGraphPtr graph_ptr = nullptr;
GE_CHK_STATUS_RET_NOLOG(GetTensorInfos(*op_desc, *cache_desc, need_range));
GE_CHK_STATUS_RET_NOLOG(GetInputConstTensor(node, *cache_desc));
if (ge::AttrUtils::GetGraph(op_desc, "_original_fusion_graph", graph_ptr)) {
GELOGD("This is fusion op %s", node->GetName().c_str());
GE_CHK_STATUS_RET_NOLOG(GetFusionOpCacheDesc(node, *cache_desc));
return SUCCESS;
}
cache_desc->SetOpType(op_desc->GetType());
GE_CHK_STATUS_RET_NOLOG(GetOpAttrMem(*op_desc, *cache_desc));
return SUCCESS;
}
void NodeCompileCacheModule::UpdateTensorInfos(const NodePtr &node, CompileCacheDesc &cache_desc) const {
const auto &op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
GELOGW("op_desc of node:[%s] is nullptr.", node->GetName().c_str());
return;
}
size_t index = 0U;
for (size_t i = 0U; i < op_desc->GetInputsSize(); ++i) {
auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i));
if (input_desc == nullptr) {
continue;
}
const auto &dims = input_desc->MutableShape().GetMutableDims();
const auto &origin_dims = input_desc->GetOriginShape().GetMutableDims();
auto tensor_info = cache_desc.MutableTensorInfo(index);
++index;
if (tensor_info == nullptr) {
continue;
}
tensor_info->SetShape(dims);
tensor_info->SetOriginShape(origin_dims);
}
return;
}
NodeCompileCacheItem *NodeCompileCacheModule::FindCompileCache(const NodePtr &node) {
if ((ccp_ == nullptr) || (node == nullptr)) {
return nullptr;
}
auto cache_desc = ge::MakeShared<CompileCacheDesc>();
GE_ASSERT_NOTNULL(cache_desc);
const auto added_cache_desc = GetCompileCacheDesc(node);
if (added_cache_desc != nullptr) {
GELOGD("get compile cache desc from map successful");
*cache_desc = *added_cache_desc;
UpdateTensorInfos(node, *cache_desc);
} else {
GELOGW("get compile cache desc from op in find process");
if (GetCompileCacheDescFromOp(node, cache_desc, false) != SUCCESS) {
GELOGW("get cache desc failed in find process");
return nullptr;
}
}
const CacheItemId id = ccp_->FindCache(cache_desc);
GELOGD("find cache item id is %lu, node name is %s", id, node->GetName().c_str());
const std::lock_guard<std::mutex> lk(ids_to_cci_mu_);
const auto it = ids_to_cci_.find(id);
if (it == ids_to_cci_.end()) {
GELOGD("cannot find id %lu", id);
return nullptr;
} else {
return &it->second;
}
}
std::shared_ptr<CompileCacheDesc> NodeCompileCacheModule::GetCompileCacheDesc(const NodePtr &node) {
const uintptr_t node_id = PtrToValue(node.get());
const std::lock_guard<std::mutex> lk(node_to_cache_desc_map_mu_);
const auto it = node_to_cache_desc_map_.find(node_id);
if (it == node_to_cache_desc_map_.end()) {
GELOGW("cannot get cache desc from map, node_id is %lu", node_id);
return nullptr;
}
return it->second;
}
void NodeCompileCacheModule::InsertCompileCacheDesc(const NodePtr &node,
std::shared_ptr<CompileCacheDesc> &cache_desc) {
const uintptr_t node_id = PtrToValue(node.get());
const std::lock_guard<std::mutex> lk(node_to_cache_desc_map_mu_);
node_to_cache_desc_map_[node_id] = cache_desc;
}
NodeCompileCacheItem *NodeCompileCacheModule::AddCompileCache(const NodePtr &node, NodeCompileCacheItem &item) {
if ((ccp_ == nullptr) || (node == nullptr)) {
return nullptr;
}
auto cache_desc = ge::MakeShared<CompileCacheDesc>();
GE_ASSERT_NOTNULL(cache_desc);
if (GetCompileCacheDescFromOp(node, cache_desc, true) != SUCCESS) {
GELOGW("get cache desc failed in add process");
return nullptr;
}
const CacheItemId id = ccp_->AddCache(cache_desc);
GELOGD("add cache item id is %lu, node name is %s", id, node->GetName().c_str());
InsertCompileCacheDesc(node, cache_desc);
const std::lock_guard<std::mutex> lk(ids_to_cci_mu_);
const auto it = ids_to_cci_.find(id);
if (it == ids_to_cci_.end()) {
item.SetCacheItemId(id);
ids_to_cci_[id] = item;
return &ids_to_cci_[id];
} else {
return &it->second;
}
}
}