* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <algorithm>
#include <cinttypes>
#include <memory>
#include "framework/common/debug/ge_log.h"
#include "guard_cache.h"
#include "execution_point.h"
#include "graph_metadef/common/ge_common/util.h"
#include "common/checker.h"
#include "base/err_msg.h"
using namespace ge;
/**
 * Searches the cache for the first guarded execution point whose Match() accepts the inputs.
 *
 * Entries are probed in their current order inside cache_models_. On a hit the entry's priority
 * is raised by one and a non-owning pointer to it is returned; the entry stays owned by the cache.
 *
 * @param input_tensor Tensors handed to Match() of every probed entry.
 * @return The first matching entry, or nullptr when no cached entry matches.
 */
GuardedExecutionPoint *GuardCheckCache::FindGuardedExecutionPoint(const std::vector<gert::Tensor> &input_tensor) {
  // The id is only used for logging; -1 stands for "no owner execution point".
  int64_t ep_id = -1;
  if (owner_point_) {
    ep_id = owner_point_->GetId();
  }
  for (auto &item : cache_models_) {
    if (item->Match(input_tensor)) {
      // ep_id is int64_t, so use PRId64: "%ld" is only correct where long is 64 bits wide.
      // The priority logged here is the value before the increment below.
      GELOGI("Get EP[%" PRId64 "] hint GEP(%u) and priority is (%u)", ep_id, item->GetCompiledGraphId(),
             item->GetPriority());
      item->SetPriority(item->GetPriority() + 1);
      return item.get();
    }
  }
  // size() yields size_t; cast so the argument really matches the PRIu64 conversion.
  GELOGI("There is no hint GEP in EP[%" PRId64 "] cache size(%" PRIu64 ")", ep_id,
         static_cast<uint64_t>(cache_models_.size()));
  return nullptr;
}
/**
 * Reports how many guarded execution points are currently held in cache_models_.
 *
 * @return The container's element count, narrowed to uint32_t.
 */
uint32_t GuardCheckCache::GetSavedCacheNum() const {
  const auto saved_count = cache_models_.size();
  return static_cast<uint32_t>(saved_count);
}
/**
 * Returns the cached guarded execution point matching the inputs, creating and caching a new one on a miss.
 *
 * On a miss a "W18888" message is reported, a new GuardedExecutionPoint bound to owner_point_ is
 * allocated and handed to AddCompiledCompiledGraph(). The returned pointer is non-owning: once the
 * entry has been added, the cache owns it.
 *
 * @param inputs Tensors used for the guard lookup.
 * @return The matching or newly cached entry. When adding the new entry fails, GE_ASSERT_SUCCESS
 *         leaves the function early (presumably yielding a null pointer - confirm against the macro).
 */
GuardedExecutionPoint *GuardCheckCache::FindOrCreateGuarded(const std::vector<gert::Tensor> &inputs) {
  GuardedExecutionPoint *gep = FindGuardedExecutionPoint(inputs);
  if (gep) {
    return gep;
  }
  // NOTE(review): this repeats the lookup above with nothing in between; it looks like the
  // remainder of a lock-then-recheck sequence. Kept as is to preserve behavior - confirm whether
  // a lock is missing here or the second lookup can be dropped.
  if (owner_point_) {
    gep = FindGuardedExecutionPoint(inputs);
    if (gep) {
      return gep;
    }
  }
  REPORT_INNER_ERR_MSG("W18888", "There is no hint GEP, cache size(%" PRIu64 "). Guard miss reason in info log", cache_models_.size());
  // Hold the new object in a unique_ptr until the cache has accepted it. AddCompiledCompiledGraph()
  // stores the pointer only on its success path, so on failure the early return taken by
  // GE_ASSERT_SUCCESS must not leave the allocation behind.
  std::unique_ptr<GuardedExecutionPoint> created(new GuardedExecutionPoint(owner_point_));
  GE_ASSERT_SUCCESS(AddCompiledCompiledGraph(created.get()));
  // The cache owns the object now; hand back a non-owning pointer.
  return created.release();
}
/**
 * Adds a guarded execution point to the cache, evicting the last entry first when the cache is full.
 *
 * Steps: reject a null pointer, call CopySlicedGraph() on the new entry, evict the entry at the
 * back of cache_models_ if the size has reached max_cache_count_, append the new entry and re-sort
 * the cache by descending priority.
 *
 * Ownership: the pointer is stored in the cache (as a unique_ptr element) only on the success
 * path. Every failure exit happens before that, so on failure the caller still owns gep.
 *
 * @param gep Entry to add; must not be null.
 * @return SUCCESS when the entry was cached; FAILED for a null gep; otherwise whatever
 *         GE_ASSERT_SUCCESS yields when CopySlicedGraph() does not succeed.
 */
Status GuardCheckCache::AddCompiledCompiledGraph(GuardedExecutionPoint* gep) {
  if (gep == nullptr) {
    return FAILED;
  }
  GE_ASSERT_SUCCESS(gep->CopySlicedGraph());
  // The emptiness check keeps back()/pop_back() well-defined even if max_cache_count_ is 0.
  if (!cache_models_.empty() && (max_cache_count_ <= cache_models_.size())) {
    // The cache is kept sorted by descending priority, so the back entry is a lowest-priority one.
    const auto evict_status = cache_models_.back()->RemoveItem();
    if (evict_status != SUCCESS) {
      // Eviction stays best-effort, but a failed removal should at least be visible in the log.
      GELOGW("Remove evicted GEP from cache failed, ret(%u)", static_cast<uint32_t>(evict_status));
    }
    cache_models_.pop_back();
  }
  cache_models_.emplace_back(gep);
  // The comparator only reads its arguments, so take them by const reference.
  std::sort(cache_models_.begin(), cache_models_.end(),
            [](const std::unique_ptr<ge::GuardedExecutionPoint> &lhs,
               const std::unique_ptr<ge::GuardedExecutionPoint> &rhs) {
              return lhs->GetPriority() > rhs->GetPriority();
            });
  return SUCCESS;
}
/**
 * Calls RemoveItem() on every cached guarded execution point and then empties the cache.
 *
 * Entries are processed in cache order. If one RemoveItem() call fails, processing stops and that
 * status is returned; the entries handled before it are dropped from the cache, while the failing
 * entry and everything after it stay cached.
 *
 * @return ge::SUCCESS when all entries were removed, otherwise the first failing RemoveItem() status.
 */
Status GuardCheckCache::RemoveCompiledGraph() {
  for (auto it = cache_models_.begin(); it != cache_models_.end(); ++it) {
    const auto status = (*it)->RemoveItem();
    if (status != ge::SUCCESS) {
      // Entries before `it` were already removed successfully. Erase them so they can no longer
      // be matched by a lookup and a retry does not call RemoveItem() on them a second time.
      cache_models_.erase(cache_models_.begin(), it);
      return status;
    }
  }
  cache_models_.clear();
  return ge::SUCCESS;
}