* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef GRAPH_CACHE_POLICY_CACHE_STATE_H
#define GRAPH_CACHE_POLICY_CACHE_STATE_H
#include <chrono>
#include <cstdint>
#include <functional>
#include <limits>
#include <mutex>
#include <queue>
#include <unordered_map>
#include <vector>

#include "compile_cache_desc.h"
namespace ge {
class CacheInfo;

// Identifier of a single cache entry.
using CacheItemId = uint64_t;

// Sentinel for "no valid cache item": the largest value representable by CacheItemId.
// Derived from the alias itself so it stays correct if CacheItemId's underlying type changes.
constexpr CacheItemId KInvalidCacheItemId = std::numeric_limits<CacheItemId>::max();

// Predicate over a cache entry. CacheInfoQueue::Erase takes it as `is_need_delete_func`,
// so `true` presumably means "delete this entry" — confirm against the definitions.
using DelCacheFunc = std::function<bool(CacheInfo &)>;

// Cache entries grouped by a 64-bit key (presumably the main CacheHashKey passed to
// CacheInfoQueue::Insert/EmplaceBack — confirm).
using CCStatType = std::unordered_map<uint64_t, std::vector<CacheInfo>>;
/**
 * One registered cache entry: its item id, the cache descriptor it refers to, and a timer
 * count (presumably the logical time handed out by CacheState::GetNextTimerCount() — confirm).
 *
 * All copy/move operations are member-wise and defaulted. Declaring them explicitly as
 * defaulted (instead of hand-writing the copies) keeps the move operations available, so
 * containers such as std::vector<CacheInfo> move `desc_` on reallocation/erase instead of
 * copying it.
 */
class CacheInfo {
  friend class CacheState;

 public:
  CacheInfo(const uint64_t timer_count, const CacheItemId item_id, const CacheDescPtr &desc)
      : item_id_(item_id), desc_(desc), timer_count_(timer_count) {}
  CacheInfo() = delete;
  CacheInfo(const CacheInfo &other) = default;
  CacheInfo &operator=(const CacheInfo &other) = default;
  CacheInfo(CacheInfo &&other) = default;
  CacheInfo &operator=(CacheInfo &&other) = default;
  ~CacheInfo() = default;

  // Overwrites the stored timer count with `time_count`.
  void RefreshTimerCount(uint64_t time_count) noexcept {
    timer_count_ = time_count;
  }
  uint64_t GetTimerCount() const noexcept {
    return timer_count_;
  }
  CacheItemId GetItemId() const noexcept {
    return item_id_;
  }
  const CacheDescPtr &GetCacheDesc() const noexcept {
    return desc_;
  }

 private:
  CacheItemId item_id_;
  CacheDescPtr desc_;
  uint64_t timer_count_;
};
// Container of CacheInfo entries grouped by key (cc_state_), plus a running entry count
// (cache_info_num_). Member functions are defined out of line and no locking appears here;
// CacheState declares it next to cache_info_queue_mu_, which presumably guards it — confirm.
struct CacheInfoQueue {
// Presumably adds the entries in `cache_info` under `main_hash_key`. Taken by non-const
// reference, so the definition may move from or modify it — confirm.
void Insert(const CacheHashKey main_hash_key, std::vector<CacheInfo> &cache_info);
// Presumably appends one entry under `main_hash_key` — confirm against the definition.
void EmplaceBack(const CacheHashKey main_hash_key, CacheInfo &cache_info);
// Presumably removes entries for which `is_need_delete_func` returns true and records their
// ids in `delete_ids` (non-const reference, i.e. an out-parameter) — confirm.
void Erase(std::vector<CacheItemId> &delete_ids, const DelCacheFunc &is_need_delete_func);
// Entries grouped by key.
CCStatType cc_state_;
// Entry count; expected to be kept in step with cc_state_ by the member functions above.
uint64_t cache_info_num_ = 0U;
};
/**
 * Owns the registered compile-cache entries (via CacheInfoQueue) together with the counters
 * used to label them: an item-id counter with a queue of ids, and a timer counter.
 *
 * Three mutexes are declared. cache_info_queue_mu_ and cache_item_mu_ are not used in this
 * header and are presumably taken by the out-of-line definitions — confirm in the .cc.
 * cache_timer_count_mu_ guards cache_timer_count_ for both the writer (GetNextTimerCount) and
 * the reader (GetCurTimerCount).
 */
class CacheState {
 public:
  CacheState() = default;
  ~CacheState() = default;

  // Defined out of line. Presumably registers `cache_desc` under `main_hash_key` and returns
  // the id assigned to the new entry — confirm against the definition.
  CacheItemId AddCache(const CacheHashKey main_hash_key, const CacheDescPtr &cache_desc);
  // Defined out of line. Presumably deletes every entry for which `func` returns true and
  // returns the removed ids — confirm.
  std::vector<CacheItemId> DelCache(const DelCacheFunc &func);
  // Defined out of line. Presumably deletes the listed ids and returns those removed — confirm.
  std::vector<CacheItemId> DelCache(const std::vector<CacheItemId> &delete_item);

  // Returns a reference to internal storage WITHOUT taking cache_info_queue_mu_; callers must
  // ensure the state is not modified concurrently while the reference is in use.
  const CCStatType &GetState() const {
    return cache_info_queue.cc_state_;
  }
  // Entry count as recorded by the queue. Also not synchronized with cache_info_queue_mu_
  // (locking here could deadlock if called from code already holding it).
  uint64_t GetCacheInfoNum() const {
    return cache_info_queue.cache_info_num_;
  }
  // Value the next GetNextTimerCount() call will return. Takes the same lock the writer uses,
  // so reading while another thread advances the counter is not a data race.
  uint64_t GetCurTimerCount() const {
    const std::lock_guard<std::mutex> lock(cache_timer_count_mu_);
    return cache_timer_count_;
  }

 private:
  CacheItemId GetNextCacheItemId();
  void RecoveryCacheItemId(const std::vector<CacheItemId> &cache_items);
  // Returns the current timer value and advances it by one, under cache_timer_count_mu_.
  uint64_t GetNextTimerCount() {
    const std::lock_guard<std::mutex> lock(cache_timer_count_mu_);
    return cache_timer_count_++;
  }

  std::mutex cache_info_queue_mu_;
  std::mutex cache_item_mu_;
  // NOTE(review): id bookkeeping is stored as int64_t while CacheItemId is uint64_t (and
  // KInvalidCacheItemId is UINT64_MAX); confirm the conversions in GetNextCacheItemId /
  // RecoveryCacheItemId are intentional.
  int64_t cache_item_counter_ = 0L;
  std::queue<int64_t> cache_item_queue_;
  CacheInfoQueue cache_info_queue;
  uint64_t cache_timer_count_ = 0U;
  // mutable so the const reader GetCurTimerCount() can lock it.
  mutable std::mutex cache_timer_count_mu_;
};
}
#endif