* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "llm_common/kv_cache_manager.h"
#include <set>
#include <memory>
#include <vector>
#include "ascend_hal.h"
namespace FlowFunc {
KvCacheManager &KvCacheManager::GetInstance() {
static KvCacheManager manager;
return manager;
}
KvCacheManager::~KvCacheManager() {
ClearDecoderKvCache();
ClearPromptKvCache();
decoder_req_to_kv_cache_.clear();
kv_blocks_tensors_.clear();
}
std::shared_ptr<FlowMsg> KvCacheManager::QueryDecoderKvCache(const ReqKey &req_key) {
std::unique_lock<std::mutex> lock(mutex_);
const auto iter = decoder_req_to_kv_cache_.find(req_key);
if (iter == decoder_req_to_kv_cache_.end()) {
uint64_t cached_req_id = decoder_req_to_kv_cache_.empty() ? UINT64_MAX : decoder_req_to_kv_cache_.begin()->first.req_id;
UDF_LOG_ERROR("Cannot find kv cache for req_id:%lu, current cached req_id:%lu.", req_key.req_id, cached_req_id);
return nullptr;
}
return iter->second;
}
FsmStatus KvCacheManager::SaveDecoderKvCache(const ReqKey &req_key, const std::shared_ptr<FlowMsg> &kv_cache) {
std::unique_lock<std::mutex> lock(mutex_);
const auto iter = decoder_req_to_kv_cache_.find(req_key);
if (iter != decoder_req_to_kv_cache_.end()) {
UDF_LOG_ERROR("Exist kv cache in decoder, req_id:%lu, model_id:%lu.", req_key.req_id, req_key.model_id);
}
decoder_req_to_kv_cache_[req_key] = kv_cache;
UDF_LOG_INFO("Success to save kv cache for decoder, req_id:%lu, model_id:%lu.", req_key.req_id, req_key.model_id);
return FsmStatus::kFsmSuccess;
}
FsmStatus KvCacheManager::ReleaseDecoderKvCache(const ReqKey &req_key) {
std::unique_lock<std::mutex> lock(mutex_);
const auto iter = decoder_req_to_kv_cache_.find(req_key);
if (iter == decoder_req_to_kv_cache_.end()) {
UDF_LOG_INFO("Not exist kv cache in decoder, req_id:%lu, model_id:%lu.", req_key.req_id, req_key.model_id);
return FsmStatus::kFsmSuccess;
}
(void) decoder_req_to_kv_cache_.erase(iter);
return FsmStatus::kFsmSuccess;
}
FsmStatus KvCacheManager::ReleaseDecoderKvCaches() {
std::unique_lock<std::mutex> lock(mutex_);
(void) decoder_req_to_kv_cache_.clear();
return FsmStatus::kFsmSuccess;
}
void KvCacheManager::SaveKvCacheForPrompt(const ReqKey &req_key,
uint64_t cache_id,
std::pair<int64_t, int64_t> offset_and_size) {
std::lock_guard<std::mutex> lock(mutex_);
UDF_LOG_INFO("KV Cache save for prompt, req_id = %lu, model_id = %lu, cache_id = %lu, offset = %ld, size = %ld",
req_key.req_id, req_key.model_id, cache_id, offset_and_size.first, offset_and_size.second);
auto &cache_entry = cache_id_to_cache_entry_[cache_id];
cache_entry.id_to_offset_and_size[req_key.req_id] = std::move(offset_and_size);
req_id_to_cache_id_[req_key] = cache_id;
}
void KvCacheManager::SaveKvCacheForPromptPrefix(const PrefixReqKey &prefix_req_key, uint64_t cache_id,
std::pair<int64_t, int64_t> offset_and_size) {
UDF_LOG_INFO("KV Cache save for prefix, prefix_id = %lu, cache_id = %lu, offset = %ld, size = %ld",
prefix_req_key.prefix_id, cache_id, offset_and_size.first, offset_and_size.second);
std::unique_lock<std::mutex> lock(mutex_);
const auto iter = prefix_id_to_cache_id_.find(prefix_req_key);
if (iter != prefix_id_to_cache_id_.end()) {
UDF_LOG_INFO("Exist kv cache in prompt, prefix_id:%lu.", prefix_req_key.prefix_id);
}
auto &cache_entry = cache_id_to_cache_entry_[cache_id];
cache_entry.id_to_offset_and_size[prefix_req_key.prefix_id] = std::move(offset_and_size);
prefix_id_to_cache_id_[prefix_req_key] = cache_id;
}
FsmStatus KvCacheManager::ReleaseKvCacheForPromptByKey(const ReqKey &req_key) {
UDF_LOG_INFO("Release KV Cache for prompt, req_id=%lu, model_id=%lu", req_key.req_id, req_key.model_id);
const auto iter = req_id_to_cache_id_.find(req_key);
if (iter == req_id_to_cache_id_.end()) {
UDF_LOG_INFO("Not exist kv cache in prompt, req_id:%lu.", req_key.req_id);
return FsmStatus::kFsmSuccess;
}
const auto cache_id = iter->second;
if (cache_id_to_cache_entry_.count(cache_id) > 0UL) {
auto &cache_entry = cache_id_to_cache_entry_[cache_id];
if (cache_entry.id_to_offset_and_size.count(req_key.req_id) > 0UL) {
(void) cache_entry.id_to_offset_and_size.erase(req_key.req_id);
}
if (cache_entry.id_to_offset_and_size.empty()) {
(void) cache_id_to_cache_entry_.erase(cache_id);
UDF_LOG_INFO("KV Cache removed, cache id = %lu", cache_id);
}
}
(void) req_id_to_cache_id_.erase(iter);
return FsmStatus::kFsmSuccess;
}
FsmStatus KvCacheManager::ReleaseKvCacheForPrompt(const ReqKey &req_key) {
std::unique_lock<std::mutex> lock(mutex_);
return ReleaseKvCacheForPromptByKey(req_key);
}
FsmStatus KvCacheManager::ReleaseKvCacheForPrompt(uint64_t req_id) {
std::unique_lock<std::mutex> lock(mutex_);
UDF_LOG_INFO("Release KV Cache for prompt, req_id=%lu", req_id);
std::vector<ReqKey> need_delete_req_keys;
for (const auto &it: req_id_to_cache_id_) {
if (it.first.req_id == req_id) {
need_delete_req_keys.emplace_back(it.first);
}
}
for (const auto &req_key: need_delete_req_keys) {
(void) ReleaseKvCacheForPromptByKey(req_key);
}
return FsmStatus::kFsmSuccess;
}
FsmStatus KvCacheManager::ReleaseKvCacheForPrefix(const PrefixReqKey &prefix_req_key) {
std::unique_lock<std::mutex> lock(mutex_);
UDF_LOG_INFO("Release KV Cache for prefix, prefix_id=%lu, model_id=%lu", prefix_req_key.prefix_id, prefix_req_key.model_id);
const auto iter = prefix_id_to_cache_id_.find(prefix_req_key);
if (iter == prefix_id_to_cache_id_.end()) {
UDF_LOG_INFO("Not exist kv cache in prompt, prefix_id:%lu.", prefix_req_key.prefix_id);
return FsmStatus::kFsmSuccess;
}
const auto cache_id = iter->second;
auto &cache_entry = cache_id_to_cache_entry_[cache_id];
(void) cache_entry.id_to_offset_and_size.erase(prefix_req_key.prefix_id);
if (cache_entry.id_to_offset_and_size.empty()) {
(void) cache_id_to_cache_entry_.erase(cache_id);
UDF_LOG_INFO("KV Cache removed, cache id = %lu", cache_id);
}
(void) prefix_id_to_cache_id_.erase(iter);
return FsmStatus::kFsmSuccess;
}
void KvCacheManager::SetDecoderKvCache(const std::vector<std::shared_ptr<FlowMsg>> &kv_caches) {
decoder_kv_caches_ = kv_caches;
}
std::shared_ptr<FlowMsg> KvCacheManager::GetDecoderKvCache(uint64_t model_id) const {
return decoder_kv_caches_[model_id];
}
void KvCacheManager::ClearDecoderKvCache()
{
decoder_kv_caches_.clear();
}
uint64_t KvCacheManager::GetPromptKvCacheCount() {
std::unique_lock<std::mutex> lock(mutex_);
uint64_t count = 0;
std::set<uint64_t> prompt_cache_ids;
for (const auto &req_id_and_cache_id: req_id_to_cache_id_) {
prompt_cache_ids.insert(req_id_and_cache_id.second);
}
for (const auto cache_id: prompt_cache_ids) {
const auto batch_num = cache_id_to_cache_entry_[cache_id].batch_num;
UDF_LOG_INFO("cache id = %lu, batch_num = %zu", cache_id, batch_num);
count += batch_num;
}
return count;
}
uint64_t KvCacheManager::GetDecoderKvCacheCount(const ReqKey &req_key) {
std::unique_lock<std::mutex> lock(mutex_);
return decoder_req_to_kv_cache_.count(req_key);
}
std::vector<KvTensor> KvCacheManager::QueryKvTensor(uint64_t cache_id, uint64_t tensor_id) {
auto &cache_entry = cache_id_to_cache_entry_[cache_id];
const auto &offset_and_size = cache_entry.id_to_offset_and_size[tensor_id];
std::vector<KvTensor> kv_tensors;
if (!cache_entry.kv_tensors.empty()) {
kv_tensors.reserve(cache_entry.kv_tensors.size() * cache_entry.kv_tensors[0]->GetTensorList().size());
}
for (const auto &kv_tensor_buffer: cache_entry.kv_tensors) {
const auto tensor_list = kv_tensor_buffer->GetTensorList();
for (const auto &tensor : tensor_list) {
KvTensor kv_tensor;
kv_tensor.tensor_buffer = kv_tensor_buffer;
kv_tensor.data_addr = static_cast<uint8_t *>(tensor->GetData()) + offset_and_size.first;
kv_tensor.data_size = offset_and_size.second;
kv_tensor.block_len = kv_tensor.data_size;
kv_tensors.emplace_back(std::move(kv_tensor));
}
}
return kv_tensors;
}
std::vector<KvTensor> KvCacheManager::QueryPromptKvCache(const ReqKey &req_key) {
static std::vector<KvTensor> empty_vector = {};
std::unique_lock<std::mutex> lock(mutex_);
const auto iter = req_id_to_cache_id_.find(req_key);
if (iter == req_id_to_cache_id_.end()) {
UDF_LOG_INFO("Not exist kv cache in prompt, req_id:%lu.", req_key.req_id);
return empty_vector;
}
const auto cache_id = iter->second;
return QueryKvTensor(cache_id, req_key.req_id);
}
std::vector<KvTensor> KvCacheManager::QueryPromptPrefixKvCache(const PrefixReqKey &prefix_req_key) {
static std::vector<KvTensor> empty_vector = {};
std::unique_lock<std::mutex> lock(mutex_);
const auto iter = prefix_id_to_cache_id_.find(prefix_req_key);
if (iter == prefix_id_to_cache_id_.end()) {
UDF_LOG_INFO("Not exist kv cache in prompt, req_id:%lu.", prefix_req_key.prefix_id);
return empty_vector;
}
const auto cache_id = iter->second;
return QueryKvTensor(cache_id, prefix_req_key.prefix_id);
}
void KvCacheManager::InitBlockTensors(uint64_t model_id) {
kv_blocks_tensors_.resize(model_id, std::vector<std::shared_ptr<FlowMsg>>());
}
void KvCacheManager::SaveKvBlocksTensor(uint64_t model_id, std::shared_ptr<FlowMsg> &kv_blocks_tensor) {
kv_blocks_tensors_[model_id].emplace_back(kv_blocks_tensor);
}
std::vector<std::shared_ptr<FlowMsg>> &KvCacheManager::QueryKvBlocksTensors(uint64_t model_id) {
return kv_blocks_tensors_[model_id];
}
void KvCacheManager::ClearKvBlocksTensors() {
kv_blocks_tensors_.clear();
}
uint64_t KvCacheManager::AddKvCache(size_t batch_num, std::vector<std::shared_ptr<FlowMsg>> kv_tensors) {
std::unique_lock<std::mutex> lock(mutex_);
const auto cache_id = cache_id_gen_++;
CacheEntry cache_entry;
cache_entry.batch_num = batch_num;
cache_entry.kv_tensors = std::move(kv_tensors);
cache_id_to_cache_entry_.emplace(cache_id, std::move(cache_entry));
UDF_LOG_INFO("KV Cache added, cache id = %lu, batch_num = %zu", cache_id, batch_num);
return cache_id;
}
FsmStatus KvCacheManager::ShrinkKvCacheForPrefix(const PrefixReqKey &prefix_req_key,
const std::shared_ptr<MetaRunContext> &run_context) {
UDF_LOG_INFO("Shrink start, prefix_id:%lu.", prefix_req_key.prefix_id);
std::vector<std::shared_ptr<FlowMsg>> shrinked_kv_tensors;
int64_t tensor_size = -1;
{
std::unique_lock<std::mutex> lock(mutex_);
const auto &cache_id = prefix_id_to_cache_id_[prefix_req_key];
auto &cache_entry = cache_id_to_cache_entry_[cache_id];
const auto &offset_and_size = cache_entry.id_to_offset_and_size[prefix_req_key.prefix_id];
tensor_size = offset_and_size.second;
shrinked_kv_tensors.reserve(cache_entry.kv_tensors.size());
for (auto &kv_tensor : cache_entry.kv_tensors) {
auto new_kv_tensor = run_context->AllocTensorMsg({tensor_size}, TensorDataType::DT_UINT8);
if (new_kv_tensor == nullptr) {
UDF_LOG_ERROR("Shink kv cache failed. prefix_id = %lu, Currently cache size:%lu",
prefix_req_key.prefix_id, KvCacheManager::GetInstance().GetPromptKvCacheCount());
return FsmStatus::kFsmOutOfMemory;
}
auto rt_ret =
halSdmaCopy(reinterpret_cast<uintptr_t>(new_kv_tensor->GetTensor()->GetData()),
tensor_size,
reinterpret_cast<uintptr_t>(static_cast<uint8_t *>(kv_tensor->GetTensor()->GetData()) +
offset_and_size.first),
tensor_size);
if (rt_ret != DRV_ERROR_NONE) {
UDF_LOG_ERROR("Fail to call halSdmaCopy, kv cache size:%ld, ret:%d.", tensor_size, rt_ret);
return FsmStatus::kFsmFailed;
}
shrinked_kv_tensors.emplace_back(std::move(new_kv_tensor));
kv_tensor.reset();
}
}
UDF_LOG_INFO("copy tensors ended, prefix_id:%lu.", prefix_req_key.prefix_id);
auto new_cache_id = AddKvCache(1U, shrinked_kv_tensors);
(void) ReleaseKvCacheForPrefix(prefix_req_key);
SaveKvCacheForPromptPrefix(prefix_req_key, new_cache_id, std::pair<int64_t, int64_t>(0L, tensor_size));
return FsmStatus::kFsmSuccess;
}
void KvCacheManager::ClearPromptKvCache() {
std::unique_lock<std::mutex> lock(mutex_);
cache_id_to_cache_entry_.clear();
req_id_to_cache_id_.clear();
prefix_id_to_cache_id_.clear();
}
}