/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "llm_common/cache_manager.h"
#include <numeric>
#include "flow_func/flow_func_log.h"
#include "flow_func/meta_flow_func.h"
#include "common/data_utils.h"
#include "ascend_hal.h"
namespace FlowFunc {
namespace {
// Alignment argument passed to MetaRunContext::AllocTensorMsgWithAlign for every cache tensor
// allocated by CacheManager::AllocateCache (presumably a byte alignment -- TODO confirm).
constexpr uint32_t kMemAlign = 512U;
}
/**
 * @brief Returns the shared CacheManager object.
 * @return Reference to a function-local static instance that is constructed on first call.
 */
CacheManager &CacheManager::GetInstance() {
  static CacheManager instance;
  return instance;
}
/**
 * @brief Allocates the tensors of a new cache and registers it under req_info.cache_id.
 *
 * One tensor per req_info.num_tensors is allocated through run_context->AllocTensorMsgWithAlign.
 * The data address of tensor i is written to data_addresses[i] and, once every allocation has
 * succeeded, recorded in addr_to_size_. The cache indices built from the request are then
 * registered (block mode or per-batch mode) and the entry is stored in cache_id_to_entry_.
 *
 * @param run_context    context used to allocate tensors; must not be null.
 * @param req_info       allocation request (cache id, shape, dtype, request ids, ...).
 * @param data_addresses output array with room for at least req_info.num_tensors entries.
 * @return kFsmSuccess on success;
 *         kFsmParamInvalid for zero tensors, null pointers, an invalid tensor description or a
 *         batch size that evaluates to zero;
 *         the status of CheckAndCreateCacheIndices if the id/keys are already in use;
 *         kFsmOutOfMemory if a tensor cannot be allocated.
 */
FsmStatus CacheManager::AllocateCache(const std::shared_ptr<MetaRunContext> &run_context,
                                      const AllocateCacheReqInfo &req_info,
                                      uint64_t *data_addresses) {
  std::lock_guard<std::mutex> lk(mu_);
  if (req_info.num_tensors == 0U) {
    return FsmStatus::kFsmParamInvalid;
  }
  if ((run_context == nullptr) || (data_addresses == nullptr)) {
    UDF_LOG_ERROR("[cache_id:%ld][Allocate] failed, run_context or data_addresses is null", req_info.cache_id);
    return FsmStatus::kFsmParamInvalid;
  }
  CacheEntry cache_entry;
  std::vector<CacheIndex> cache_indices;
  const auto ret = CheckAndCreateCacheIndices(req_info, cache_indices);
  if (ret != FsmStatus::kFsmSuccess) {
    return ret;
  }
  std::vector<int64_t> shape(req_info.dims, req_info.dims + req_info.num_dims);
  auto dtype = static_cast<TensorDataType>(req_info.dtype);
  int64_t tensor_size = CalcDataSize(shape, dtype);
  if (tensor_size <= 0) {
    UDF_LOG_ERROR("[cache_id:%ld][Allocate] failed, tensor desc is invalid, shape=%s, dtype=%d",
                  req_info.cache_id, ToString(shape).c_str(), static_cast<int32_t>(dtype));
    return FsmStatus::kFsmParamInvalid;
  }
  cache_entry.ext_ref_count = 1U;
  cache_entry.tensor_size = static_cast<size_t>(tensor_size);
  cache_entry.batch_size = shape.empty() ? 1U : static_cast<uint32_t>(shape.front());
  // The narrowing cast above can yield 0 for a very large leading dimension; reject it before it
  // is used as a divisor.
  if (cache_entry.batch_size == 0U) {
    UDF_LOG_ERROR("[cache_id:%ld][Allocate] failed, batch size is invalid, shape=%s",
                  req_info.cache_id, ToString(shape).c_str());
    return FsmStatus::kFsmParamInvalid;
  }
  cache_entry.batch_stride = tensor_size / cache_entry.batch_size;
  cache_entry.seq_len_dim_index = req_info.seq_len_dim_index;
  UDF_LOG_INFO("[cache_id:%ld][Allocate] allocate start, tensorNum = %u, tensor_size = %lu, "
               "batch_size = %u, batch_stride = %lu, seq_len_dim_index=%d, cache_indices:%zu.",
               req_info.cache_id, req_info.num_tensors, cache_entry.tensor_size,
               cache_entry.batch_size, cache_entry.batch_stride, req_info.seq_len_dim_index, cache_indices.size());
  const auto num_tensors = static_cast<size_t>(req_info.num_tensors);
  cache_entry.tensors.resize(num_tensors);
  for (size_t i = 0U; i < num_tensors; ++i) {
    auto cache_tensor = run_context->AllocTensorMsgWithAlign(shape, dtype, kMemAlign);
    if (cache_tensor == nullptr) {
      UDF_LOG_ERROR("Allocate cache failed. cache_id = %ld, tensorNum = %u, tensor_size = %lu, "
                    "currentCacheNum = %zu",
                    req_info.cache_id, req_info.num_tensors, cache_entry.tensor_size, cache_id_to_entry_.size());
      // Nothing has been added to addr_to_size_ yet, so the tensors released together with
      // cache_entry leave no stale address behind.
      return FsmStatus::kFsmOutOfMemory;
    }
    data_addresses[i] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(cache_tensor->GetTensor()->GetData()));
    cache_entry.tensors[i] = std::move(cache_tensor);
  }
  // Publish the addresses only after every tensor was allocated successfully.
  for (size_t i = 0U; i < num_tensors; ++i) {
    UDF_LOG_INFO("addr_to_size_ emplace data:%lu, size:%lu", data_addresses[i], cache_entry.tensor_size);
    addr_to_size_[data_addresses[i]] = cache_entry.tensor_size;
  }
  cache_entry.block_len = cache_entry.batch_stride;
  if (req_info.is_allocate_blocks) {
    AddCacheIndicesForBlocks(req_info, cache_indices);
    cache_entry.is_blocks = true;
    if (!cache_indices.empty()) {
      cache_entry.blocks_cache_index = cache_indices[0];
    }
  } else {
    AddCacheIndices(cache_entry, req_info, cache_indices);
  }
  cache_id_to_entry_[req_info.cache_id] = std::move(cache_entry);
  UDF_LOG_INFO("[cache_id:%ld][Allocate] allocate ended successfully", req_info.cache_id);
  return FsmStatus::kFsmSuccess;
}
/**
 * @brief Registers the per-batch cache indices of a freshly allocated (non-block) cache.
 *
 * The position of an index inside cache_indices is its batch index. Indices whose first component
 * is UINT64_MAX are skipped but still occupy their batch position. For every registered index the
 * full set of tensor indices [0, tensors.size()) is recorded as "unpulled".
 *
 * @param cache_entry   entry being built; its per-request bookkeeping maps are filled in.
 * @param req_info      allocation request providing cache_id and the is_prefix flag.
 * @param cache_indices indices created for the request, ordered by batch position.
 */
void CacheManager::AddCacheIndices(CacheEntry &cache_entry, const AllocateCacheReqInfo &req_info,
                                   const std::vector<CacheIndex> &cache_indices) {
  const int64_t owner_id = req_info.cache_id;
  // Build the "all tensors still unpulled" set once and copy it for each registered index.
  std::set<uint64_t> all_tensor_ids;
  for (uint64_t tensor_id = 0U; tensor_id < cache_entry.tensors.size(); ++tensor_id) {
    (void) all_tensor_ids.insert(tensor_id);
  }
  auto &index_map = req_info.is_prefix ? prefix_index_to_id_ : cache_index_to_id_;
  for (size_t pos = 0U; pos < cache_indices.size(); ++pos) {
    const auto &key = cache_indices[pos];
    if (key.first == UINT64_MAX) {
      continue;
    }
    const auto slot = static_cast<uint32_t>(pos);
    (void) index_map.emplace(key, owner_id);
    cache_entry.id_to_batch_index_and_size[key.first] = std::make_pair(slot, cache_entry.batch_stride);
    cache_entry.id_to_unpulled_tensor_indices[key.first] = all_tensor_ids;
    cache_id_and_batch_id_to_cache_index_[std::make_pair(owner_id, slot)] = key;
    UDF_LOG_INFO("[cache_id:%ld][AddCacheIndex] cacheKey(%lu, %lu) added, batch_index = %u, size = %lu",
                 owner_id, key.first, key.second, slot, cache_entry.batch_stride);
  }
}
void CacheManager::AddCacheIndicesForBlocks(const AllocateCacheReqInfo &req_info,
const std::vector<CacheIndex> &cache_indices) {
int64_t cache_id = req_info.cache_id;
for (const auto &cache_index: cache_indices) {
(void) cache_index_to_id_.emplace(cache_index, cache_id);
UDF_LOG_INFO("[cache_id:%ld][AddCacheIndicesForBlocks] cacheKey(%lu, %lu) added.",
cache_id,
cache_index.first,
cache_index.second);
}
}
void CacheManager::RemoveAddrToSize(const CacheEntry &cache_entry) {
for (const auto &tensor : cache_entry.tensors) {
auto addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(tensor->GetTensor()->GetData()));
if (addr_to_size_.find(addr) != addr_to_size_.end()) {
(void) addr_to_size_.erase(addr);
}
}
}
/**
 * @brief Releases the external reference of a cache and frees it when nothing else refers to it.
 *
 * ext_ref_count is set to 0. For block-mode caches the mapping stored in blocks_cache_index is
 * removed from cache_index_to_id_. The entry itself (and its addr_to_size_ records) is erased only
 * when id_to_batch_index_and_size is empty; otherwise it stays until RemoveCacheIndex drops the
 * last remaining index.
 *
 * @param cache_id id of the cache to deallocate.
 * @return kFsmKvNotExist if the id is unknown, kFsmSuccess otherwise.
 */
FsmStatus CacheManager::DeallocateCache(int64_t cache_id) {
  std::lock_guard<std::mutex> lk(mu_);
  const auto it = cache_id_to_entry_.find(cache_id);
  if (it == cache_id_to_entry_.cend()) {
    UDF_LOG_INFO("[cache_id:%ld][Deallocate] failed, cache_id not exist", cache_id);
    return FsmStatus::kFsmKvNotExist;
  }
  auto &cache_entry = it->second;
  cache_entry.ext_ref_count = 0U;
  UDF_LOG_INFO("[cache_id:%ld][Deallocate] ext_ref_count was set to 0, ref_count remaining = %zu",
               cache_id, cache_entry.id_to_batch_index_and_size.size());
  if (cache_entry.is_blocks) {
    // Only remove the mapping if it still belongs to this cache: the key may already have been
    // removed through RemoveCacheIndex and re-registered by another cache in the meantime.
    const auto index_it = cache_index_to_id_.find(cache_entry.blocks_cache_index);
    if ((index_it != cache_index_to_id_.end()) && (index_it->second == cache_id)) {
      (void) cache_index_to_id_.erase(index_it);
      UDF_LOG_INFO("[cache_id:%ld][Deallocate] blocks cache index was removed", cache_id);
    }
    // NOTE(review): AddCacheIndicesForBlocks registers every index of the request, but only the
    // first one is remembered in blocks_cache_index; any further indices stay mapped to this id.
  }
  if (cache_entry.id_to_batch_index_and_size.empty()) {
    RemoveAddrToSize(cache_entry);
    (void) cache_id_to_entry_.erase(it);
    UDF_LOG_INFO("[cache_id:%ld][Deallocate] deallocate ended successfully", cache_id);
  }
  return FsmStatus::kFsmSuccess;
}
/**
 * @brief Removes a cache index (request key) and frees the owning cache once it is unreferenced.
 *
 * When 0 < num_tensor_indices < tensors.size() the call is a partial pull: the listed tensor
 * indices are removed from the key's unpulled set and the key is kept while that set is non-empty.
 * Otherwise the key's batch bookkeeping and index mapping are removed. The cache entry is erased
 * when no batch index remains and ext_ref_count is 0.
 *
 * @param cache_index        key to remove.
 * @param is_prefix          selects prefix_index_to_id_ instead of cache_index_to_id_.
 * @param num_tensor_indices number of elements in tensor_indices (0 means "all tensors").
 * @param tensor_indices     tensor indices that were pulled; required for a partial pull.
 * @return kFsmSuccess on success;
 *         kFsmKvNotExist if the key is unknown or refers to a cache that no longer exists;
 *         kFsmParamInvalid if a partial pull is requested with a null tensor_indices.
 */
FsmStatus CacheManager::RemoveCacheIndex(const CacheIndex &cache_index,
                                         bool is_prefix,
                                         uint32_t num_tensor_indices,
                                         const TransferInfo *tensor_indices) {
  std::lock_guard<std::mutex> lk(mu_);
  auto &index_to_id = is_prefix ? prefix_index_to_id_ : cache_index_to_id_;
  const auto cache_id_it = index_to_id.find(cache_index);
  if (cache_id_it == index_to_id.cend()) {
    UDF_LOG_INFO("[RemoveIndex] cacheKey (%lu, %lu) not exist", cache_index.first, cache_index.second);
    return FsmStatus::kFsmKvNotExist;
  }
  const auto cache_id = cache_id_it->second;
  const auto it = cache_id_to_entry_.find(cache_id);
  if (it == cache_id_to_entry_.cend()) {
    // The mapping outlived its cache; drop the stale mapping instead of dereferencing end().
    UDF_LOG_ERROR("[cache_id:%ld][RemoveIndex] cacheKey (%lu, %lu) refers to a cache that no longer exists",
                  cache_id, cache_index.first, cache_index.second);
    (void) index_to_id.erase(cache_id_it);
    return FsmStatus::kFsmKvNotExist;
  }
  auto &cache_entry = it->second;
  auto batch_index_and_size_it = cache_entry.id_to_batch_index_and_size.find(cache_index.first);
  if (batch_index_and_size_it != cache_entry.id_to_batch_index_and_size.cend()) {
    if (num_tensor_indices > 0 && num_tensor_indices < cache_entry.tensors.size()) {
      if (tensor_indices == nullptr) {
        UDF_LOG_ERROR("[cache_id:%ld][RemoveIndex] cacheKey (%lu, %lu) tensor_indices is null",
                      cache_id, cache_index.first, cache_index.second);
        return FsmStatus::kFsmParamInvalid;
      }
      auto &unpulled = cache_entry.id_to_unpulled_tensor_indices[cache_index.first];
      for (uint32_t i = 0U; i < num_tensor_indices; ++i) {
        unpulled.erase(tensor_indices[i].tensor_index);
      }
      UDF_LOG_INFO("[RemoveIndex] cacheKey (%lu, %lu) partial pull, unpulled = %s",
                   cache_index.first, cache_index.second,
                   ToString(std::vector<uint64_t>(unpulled.cbegin(), unpulled.cend())).c_str());
      if (!unpulled.empty()) {
        // Some tensors are still waiting to be pulled, so the key must stay registered.
        return FsmStatus::kFsmSuccess;
      }
    }
    const auto batch_index = batch_index_and_size_it->second.first;
    (void) cache_id_and_batch_id_to_cache_index_.erase(std::make_pair(cache_id, batch_index));
    (void) cache_entry.id_to_batch_index_and_size.erase(batch_index_and_size_it);
  }
  UDF_LOG_INFO("[cache_id:%ld][RemoveIndex] cacheKey (%lu, %lu) was removed, ref_count = %zu, ext_ref_count = %zu",
               cache_id, cache_index.first, cache_index.second,
               cache_entry.id_to_batch_index_and_size.size(), cache_entry.ext_ref_count);
  // Total references = remaining per-batch indices + the external reference held since allocation.
  const auto ref_count = cache_entry.id_to_batch_index_and_size.size() + cache_entry.ext_ref_count;
  (void) index_to_id.erase(cache_id_it);
  if (ref_count == 0) {
    RemoveAddrToSize(cache_entry);
    (void) cache_id_to_entry_.erase(it);
    UDF_LOG_INFO("[cache_id:%ld][Deallocate] ref_count = 0 and was deallocated", cache_id);
  }
  return FsmStatus::kFsmSuccess;
}
/**
 * @brief Copies blocks inside the destination cache, for every tensor of that cache.
 *
 * Each of the req_info.copy_block_count (src_block_index, dst_block_index) pairs is validated
 * against the cache's batch_size first; afterwards block_len bytes are copied with halSdmaCopy
 * from the source block to the destination block of each tensor.
 *
 * @param req_info copy request; only dst_cache_id is used to look up the cache.
 * @return kFsmSuccess on success;
 *         kFsmKvNotExist if dst_cache_id is unknown;
 *         kFsmParamInvalid if a block index is out of range;
 *         kFsmCopyKvFailed if halSdmaCopy reports an error.
 */
FsmStatus CacheManager::CopyBlockCache(const CopyCacheReqInfo &req_info) const {
  CacheEntry cache_entry;
  const auto cache_entry_exist = GetCacheEntry(req_info.dst_cache_id, cache_entry);
  if (!cache_entry_exist) {
    UDF_LOG_ERROR("[CopyCache](%ld->%ld) failed, dst_cache_exist = %d",
                  req_info.src_cache_id, req_info.dst_cache_id,
                  static_cast<int32_t>(cache_entry_exist));
    return FsmStatus::kFsmKvNotExist;
  }
  for (uint32_t i = 0; i < req_info.copy_block_count; ++i) {
    if ((req_info.block_copy_infos[i].src_block_index >= cache_entry.batch_size) ||
        (req_info.block_copy_infos[i].dst_block_index >= cache_entry.batch_size)) {
      // batch_size is cast so that the argument really matches the %zu conversion.
      UDF_LOG_ERROR("[cache_id:%ld][CopyBlockCache] copyBlockIndex (%lu, %lu) out of range [0, %zu).",
                    req_info.dst_cache_id, req_info.block_copy_infos[i].src_block_index,
                    req_info.block_copy_infos[i].dst_block_index, static_cast<size_t>(cache_entry.batch_size));
      return FsmStatus::kFsmParamInvalid;
    }
  }
  for (size_t tensor_index = 0U; tensor_index < cache_entry.tensors.size(); ++tensor_index) {
    auto data_addr = static_cast<uint8_t *>(cache_entry.tensors[tensor_index]->GetTensor()->GetData());
    for (uint32_t i = 0; i < req_info.copy_block_count; ++i) {
      auto src_mem = data_addr + req_info.block_copy_infos[i].src_block_index * cache_entry.block_len;
      auto dst_mem = data_addr + req_info.block_copy_infos[i].dst_block_index * cache_entry.block_len;
      const auto copy_ret =
          halSdmaCopy(reinterpret_cast<uintptr_t>(dst_mem), cache_entry.block_len,
                      reinterpret_cast<uintptr_t>(src_mem), cache_entry.block_len);
      if (copy_ret != DRV_ERROR_NONE) {
        // Report the index of the tensor that failed (not the block loop counter).
        UDF_LOG_ERROR("[CopyCache][%ld] halSdmaCopy failed, tensor index = %zu, ret = %d",
                      req_info.dst_cache_id, tensor_index, static_cast<int32_t>(copy_ret));
        return FsmStatus::kFsmCopyKvFailed;
      }
    }
  }
  return FsmStatus::kFsmSuccess;
}
/**
 * @brief Copies one batch slot (or part of it) from a source cache to a destination cache.
 *
 * Requests with copy_block_count > 0 are forwarded to CopyBlockCache. Otherwise both caches are
 * looked up, the request is validated by CheckCopyCacheParam, and the resulting byte range is
 * copied tensor by tensor with D2DCopy.
 *
 * @param req_info copy request (cache ids, batch indices, offset, copy size).
 * @return kFsmSuccess on success;
 *         kFsmKvNotExist if either cache is unknown;
 *         kFsmParamInvalid if CheckCopyCacheParam rejects the request;
 *         kFsmCopyKvFailed if a copy fails.
 */
FsmStatus CacheManager::CopyCache(const CopyCacheReqInfo &req_info) const {
  if (req_info.copy_block_count > 0U) {
    return CopyBlockCache(req_info);
  }

  CacheEntry source;
  CacheEntry target;
  const auto has_source = GetCacheEntry(req_info.src_cache_id, source);
  const auto has_target = GetCacheEntry(req_info.dst_cache_id, target);
  if (!(has_source && has_target)) {
    UDF_LOG_ERROR("[CopyCache](%ld->%ld) failed, src_cache_exist = %d, dst_cache_exist = %d",
                  req_info.src_cache_id, req_info.dst_cache_id,
                  static_cast<int32_t>(has_source), static_cast<int32_t>(has_target));
    return FsmStatus::kFsmKvNotExist;
  }

  uint64_t bytes_to_copy = 0U;
  const auto check_ret = CheckCopyCacheParam(source, target, req_info, bytes_to_copy);
  if (check_ret != FsmStatus::kFsmSuccess) {
    return FsmStatus::kFsmParamInvalid;
  }

  // Byte offsets of the selected batch slot (plus the in-slot offset) inside every tensor.
  const auto target_offset = target.batch_stride * req_info.dst_batch_index + req_info.offset;
  const auto source_offset = source.batch_stride * req_info.src_batch_index + req_info.offset;
  // Space left in the destination slot behind the offset.
  const auto target_capacity = target.batch_stride - req_info.offset;
  const size_t tensor_count = source.tensors.size();
  UDF_LOG_INFO("[CopyCache](%ld->%ld) start, num_tensors = %zu, copy_size = %lu",
               req_info.src_cache_id, req_info.dst_cache_id, tensor_count, bytes_to_copy);
  UDF_LOG_INFO("Src: [batch_index = %u, offset = %lu], Dst: [BatchIndex = %u, offset = %lu]",
               req_info.src_batch_index, source_offset, req_info.dst_batch_index, target_offset);

  for (size_t tensor_pos = 0U; tensor_pos < tensor_count; ++tensor_pos) {
    auto from = static_cast<uint8_t *>(source.tensors[tensor_pos]->GetTensor()->GetData()) + source_offset;
    auto to = static_cast<uint8_t *>(target.tensors[tensor_pos]->GetTensor()->GetData()) + target_offset;
    const auto copy_ret = D2DCopy(to, target_capacity, from, bytes_to_copy);
    if (copy_ret != DRV_ERROR_NONE) {
      UDF_LOG_ERROR("[CopyCache][%ld->%ld] halSdmaCopy failed, tensor index = %zu, ret = %d",
                    req_info.src_cache_id, req_info.dst_cache_id, tensor_pos, copy_ret);
      return FsmStatus::kFsmCopyKvFailed;
    }
  }
  UDF_LOG_INFO("[CopyCache](%ld->%ld) ended successfully", req_info.src_cache_id, req_info.dst_cache_id);
  return FsmStatus::kFsmSuccess;
}
/**
 * @brief Validates a batch-to-batch copy request and computes the number of bytes to copy.
 *
 * @param src_cache_entry source cache.
 * @param dst_cache_entry destination cache.
 * @param req_info        copy request (batch indices, offset, copy_size; -1 means "rest of the
 *                        source batch slot behind offset").
 * @param copy_size       output: bytes to copy; written before the final range check, so it is
 *                        only meaningful when kFsmSuccess is returned.
 * @return kFsmSuccess if the request is valid;
 *         kFsmCacheIncompatible if the caches hold a different number of tensors;
 *         kFsmParamInvalid if a batch index, the offset or the copy size is out of range.
 */
FsmStatus CacheManager::CheckCopyCacheParam(const CacheEntry &src_cache_entry,
                                            const CacheEntry &dst_cache_entry,
                                            const CopyCacheReqInfo &req_info,
                                            uint64_t& copy_size) {
  if (dst_cache_entry.tensors.size() != src_cache_entry.tensors.size()) {
    UDF_LOG_ERROR("[CopyCache](%ld->%ld) failed, tensor number mismatches, src = %zu, dst = %zu",
                  req_info.src_cache_id,
                  req_info.dst_cache_id,
                  src_cache_entry.tensors.size(),
                  dst_cache_entry.tensors.size());
    return FsmStatus::kFsmCacheIncompatible;
  }
  if (req_info.src_batch_index >= src_cache_entry.batch_size || req_info.dst_batch_index >= dst_cache_entry.batch_size) {
    UDF_LOG_ERROR("[CopyCache](%ld->%ld) failed, batch_index out of range, "
                  "src: index = %u, range = [0, %u), dst: index = %u, range = [0, %u)",
                  req_info.src_cache_id, req_info.dst_cache_id,
                  req_info.src_batch_index, src_cache_entry.batch_size, req_info.dst_batch_index, dst_cache_entry.batch_size);
    return FsmStatus::kFsmParamInvalid;
  }
  // The offset must lie strictly inside both batch slots; this also keeps the subtractions below
  // from wrapping around.
  if ((req_info.offset >= dst_cache_entry.batch_stride) || (req_info.offset >= src_cache_entry.batch_stride)) {
    UDF_LOG_ERROR("[CopyCache](%ld->%ld) failed, offset out of range, "
                  "offset=%lu, src_stride=%lu, dst_stride=%lu",
                  req_info.src_cache_id, req_info.dst_cache_id,
                  req_info.offset, src_cache_entry.batch_stride, dst_cache_entry.batch_stride);
    return FsmStatus::kFsmParamInvalid;
  }
  // Any other negative copy_size becomes a huge unsigned value and is rejected by the range check.
  copy_size = req_info.copy_size != -1 ?
              static_cast<uint64_t>(req_info.copy_size) :
              src_cache_entry.batch_stride - req_info.offset;
  if ((copy_size > (src_cache_entry.batch_stride - req_info.offset)) ||
      (copy_size > (dst_cache_entry.batch_stride - req_info.offset))) {
    UDF_LOG_ERROR("[CopyCache](%ld->%ld) failed, copy_size out of range, "
                  "offset=%lu, src_stride=%lu, dst_stride=%lu, size=%ld",
                  req_info.src_cache_id, req_info.dst_cache_id,
                  req_info.offset, src_cache_entry.batch_stride, dst_cache_entry.batch_stride, req_info.copy_size);
    return FsmStatus::kFsmParamInvalid;
  }
  return FsmStatus::kFsmSuccess;
}
/**
 * @brief Verifies that every destination tensor index of a pull request addresses an existing
 *        tensor of the cache.
 * @param req_info              pull request; dst_tensor_index_count entries are checked.
 * @param cache_entry           cache whose tensor count bounds the indices.
 * @param enable_page_attention forwarded to LocalTensorIndices to locate the index array.
 * @return kFsmSuccess if all indices are in range, kFsmParamInvalid otherwise.
 */
FsmStatus CacheManager::CheckTensorIndices(const PullKvReqInfo &req_info,
                                           const CacheEntry &cache_entry,
                                           bool enable_page_attention) {
  const uint64_t *indices = LocalTensorIndices(req_info, enable_page_attention);
  uint32_t pos = 0U;
  while (pos < req_info.dst_tensor_index_count) {
    const bool in_range = (indices[pos] < cache_entry.tensors.size());
    if (!in_range) {
      UDF_LOG_ERROR("[cache_id:%ld][Pull] dstTensorIndex[%u] = %lu, out of range, tensorNum = %zu.",
                    req_info.cache_id, pos, indices[pos], cache_entry.tensors.size());
      return FsmStatus::kFsmParamInvalid;
    }
    ++pos;
  }
  return FsmStatus::kFsmSuccess;
}
/**
 * @brief Validates a pull request against the destination cache and fills in request defaults.
 *
 * Side effects on req_info: block_count 0 is replaced by 1, and block_len 0 is replaced by the
 * space left in a batch slot behind dst_cache_offset.
 *
 * @param cache_entry           destination cache.
 * @param req_info              pull request; validated and normalised in place.
 * @param enable_page_attention output: true when the request carries block_count > 0.
 * @return kFsmSuccess if the request is valid, kFsmParamInvalid otherwise.
 */
FsmStatus CacheManager::CheckReqInfo(const CacheEntry &cache_entry, PullKvReqInfo *req_info, bool &enable_page_attention) {
  if (req_info->batch_index >= cache_entry.batch_size) {
    // batch_size is cast so that the argument really matches the %zu conversion.
    UDF_LOG_ERROR("[cache_id:%ld][Pull] batch_index (%lu) out of range [0, %zu).",
                  req_info->cache_id, req_info->batch_index, static_cast<size_t>(cache_entry.batch_size));
    return FsmStatus::kFsmParamInvalid;
  }
  enable_page_attention = (req_info->block_count > 0U);
  if (enable_page_attention
      && ClusterManager::CheckBlockIndices(req_info, cache_entry.batch_size) != FsmStatus::kFsmSuccess) {
    return FsmStatus::kFsmParamInvalid;
  }
  if (CheckTensorIndices(*req_info, cache_entry, enable_page_attention) != FsmStatus::kFsmSuccess) {
    return FsmStatus::kFsmParamInvalid;
  }
  if (enable_page_attention && req_info->block_count != req_info->prompt_block_count) {
    UDF_LOG_ERROR("[cache_id:%ld][Pull] block_count(%lu) != prompt_block_count(%lu).",
                  req_info->cache_id, req_info->block_count, req_info->prompt_block_count);
    return FsmStatus::kFsmParamInvalid;
  }
  if (enable_page_attention && (cache_entry.block_len > INT32_MAX)) {
    UDF_LOG_ERROR("[cache_id:%ld][Pull] block_len(%lu) is over limit.", req_info->cache_id, cache_entry.block_len);
    return FsmStatus::kFsmParamInvalid;
  }
  req_info->block_count = (req_info->block_count == 0U) ? 1U : req_info->block_count;
  // Reject an offset beyond the slot first: otherwise "stride - offset" wraps around and the
  // wrapped default block_len would make "offset + block_len" wrap back into the valid range.
  if (req_info->dst_cache_offset > cache_entry.batch_stride) {
    UDF_LOG_ERROR("[cache_id:%ld][Pull] dst_cache_offset(%lu) or block_len(%lu) is invalid, cache batch stride(%lu).",
                  req_info->cache_id, req_info->dst_cache_offset, req_info->block_len, cache_entry.batch_stride);
    return FsmStatus::kFsmParamInvalid;
  }
  const auto remaining = cache_entry.batch_stride - req_info->dst_cache_offset;
  if (req_info->block_len == 0U) {
    req_info->block_len = remaining;
  }
  // Compare against the remaining space instead of summing, so a huge block_len cannot overflow.
  if (req_info->block_len > remaining) {
    UDF_LOG_ERROR("[cache_id:%ld][Pull] dst_cache_offset(%lu) or block_len(%lu) is invalid, cache batch stride(%lu).",
                  req_info->cache_id, req_info->dst_cache_offset, req_info->block_len, cache_entry.batch_stride);
    return FsmStatus::kFsmParamInvalid;
  }
  return FsmStatus::kFsmSuccess;
}
/**
 * @brief Pulls cache data for req_info->cache_id through the cluster manager.
 *
 * Works on a copy of the cache entry obtained via GetCacheEntry, validates/normalises the request
 * with CheckReqInfo and then delegates to ClusterManager::PullCacheFromPrompt.
 *
 * @param req_info        pull request; may be modified by CheckReqInfo.
 * @param cluster_manager object that performs the pull.
 * @param start_tick      value forwarded unchanged to PullCacheFromPrompt.
 * @return kFsmSuccess on success, kFsmKvNotExist if the cache id is unknown, otherwise the
 *         failing status of CheckReqInfo or PullCacheFromPrompt.
 */
FsmStatus CacheManager::PullCache(PullKvReqInfo *req_info, ClusterManager &cluster_manager, uint64_t start_tick) const {
  CacheEntry local_entry;
  const bool found = GetCacheEntry(req_info->cache_id, local_entry);
  if (!found) {
    UDF_LOG_ERROR("[cache_id:%ld][Pull] cache_id not found.", req_info->cache_id);
    return FsmStatus::kFsmKvNotExist;
  }

  bool paged = false;
  const FsmStatus check_ret = CheckReqInfo(local_entry, req_info, paged);
  if (check_ret != FsmStatus::kFsmSuccess) {
    UDF_LOG_ERROR("[cache_id:%ld][Pull] check req info failed, ret = %d.",
                  req_info->cache_id, static_cast<int32_t>(check_ret));
    return check_ret;
  }

  UDF_LOG_INFO("[cache_id:%ld][Pull] start, %s, num_tensors=%zu, block_len = %lu",
               req_info->cache_id, CacheKeyDebugString(*req_info).c_str(), local_entry.tensors.size(), req_info->block_len);
  const FsmStatus pull_ret = cluster_manager.PullCacheFromPrompt(req_info, paged, start_tick, &local_entry);
  if (pull_ret == FsmStatus::kFsmSuccess) {
    UDF_LOG_INFO("[cache_id:%ld][Pull] success, %s, num_tensors=%zu, block_len = %lu",
                 req_info->cache_id, CacheKeyDebugString(*req_info).c_str(), local_entry.tensors.size(), req_info->block_len);
    return FsmStatus::kFsmSuccess;
  }
  UDF_LOG_ERROR("[cache_id:%ld][Pull] failed, ret = %d.", req_info->cache_id, static_cast<int32_t>(pull_ret));
  return pull_ret;
}
/**
 * @brief Validates a transfer request and hands it to the cluster manager.
 *
 * Works on a copy of the cache entry obtained via GetCacheEntry. Side effects on req_info:
 * block_count 0 is replaced by 1 and block_len 0 is replaced by the cache's batch_stride.
 *
 * @param req_info        transfer request; validated and normalised in place.
 * @param cluster_manager object that performs the transfer.
 * @param start_tick      value forwarded unchanged to ClusterManager::TransferCache.
 * @return kFsmSuccess on success;
 *         kFsmKvNotExist if the cache id is unknown;
 *         kFsmParamInvalid if tensor_num_per_layer is 0 or does not divide the tensor count, or if
 *         the layer index, batch index, block count/indices or block length are invalid;
 *         otherwise the failing status of ClusterManager::TransferCache.
 */
FsmStatus CacheManager::TransferCache(TransferKvReqInfo *req_info, ClusterManager &cluster_manager,
                                      uint64_t start_tick) const {
  CacheEntry cache_entry;
  if (!GetCacheEntry(req_info->cache_id, cache_entry)) {
    UDF_LOG_ERROR("[cache_id:%ld][Transfer] cache_id not found.", req_info->cache_id);
    return FsmStatus::kFsmKvNotExist;
  }
  // Guard the modulo/division below against a zero divisor.
  if (req_info->tensor_num_per_layer == 0U) {
    UDF_LOG_ERROR("[cache_id:%ld][Transfer] tensor_num_per_layer is 0.", req_info->cache_id);
    return FsmStatus::kFsmParamInvalid;
  }
  if (cache_entry.tensors.size() % req_info->tensor_num_per_layer != 0) {
    UDF_LOG_ERROR("[cache_id:%ld][Transfer] tensors num:%zu is not a multiple of %lu.", req_info->cache_id, cache_entry.tensors.size(),
                  req_info->tensor_num_per_layer);
    return FsmStatus::kFsmParamInvalid;
  }
  uint64_t max_layer_index = cache_entry.tensors.size() / req_info->tensor_num_per_layer;
  if (req_info->src_layer_index >= max_layer_index) {
    UDF_LOG_ERROR("[cache_id:%ld][Transfer] layer index:%lu out of range [0, %lu).", req_info->cache_id,
                  req_info->src_layer_index, max_layer_index);
    return FsmStatus::kFsmParamInvalid;
  }
  if (req_info->src_batch_index >= cache_entry.batch_size) {
    // batch_size is cast so that the argument really matches the %zu conversion.
    UDF_LOG_ERROR("[cache_id:%ld][Transfer] src_batch_index (%lu) out of range [0, %zu).",
                  req_info->cache_id, req_info->src_batch_index, static_cast<size_t>(cache_entry.batch_size));
    return FsmStatus::kFsmParamInvalid;
  }
  if (req_info->block_len > cache_entry.batch_stride) {
    UDF_LOG_ERROR("[cache_id:%ld][Transfer] block_len(%lu) > cached batch_stride(%lu).",
                  req_info->cache_id, req_info->block_len, cache_entry.batch_stride);
    return FsmStatus::kFsmParamInvalid;
  }
  UDF_LOG_INFO("[cache_id:%ld][Transfer] start, %s, num_tensors=%zu, block_len = %lu, dst_cache_id:%ld.", req_info->cache_id,
               TransferReqDebugString(*req_info).c_str(), cache_entry.tensors.size(), req_info->block_len, req_info->dst_cache_id);
  bool enable_page_attention = (req_info->block_count > 0U);
  if (enable_page_attention && req_info->block_count != req_info->prompt_block_count) {
    UDF_LOG_ERROR("[cache_id:%ld][Transfer] block_count(%lu) != prompt_block_count(%lu).",
                  req_info->cache_id, req_info->block_count, req_info->prompt_block_count);
    return FsmStatus::kFsmParamInvalid;
  }
  if (enable_page_attention
      && ClusterManager::CheckBlockIndices(req_info, cache_entry.batch_size, req_info->block_count) != FsmStatus::kFsmSuccess) {
    return FsmStatus::kFsmParamInvalid;
  }
  req_info->block_count = (req_info->block_count == 0U) ? 1U : req_info->block_count;
  if (req_info->block_len == 0U) {
    req_info->block_len = cache_entry.batch_stride;
  }
  // With block transfer enabled, only whole batch slots may be transferred.
  if (enable_page_attention && req_info->block_len != cache_entry.batch_stride) {
    UDF_LOG_ERROR("[cache_id:%ld][Transfer] block_len(%lu) != cached batch_stride(%lu).",
                  req_info->cache_id, req_info->block_len, cache_entry.batch_stride);
    return FsmStatus::kFsmParamInvalid;
  }
  auto ret = cluster_manager.TransferCache(req_info, enable_page_attention, start_tick, &cache_entry);
  if (ret != FsmStatus::kFsmSuccess) {
    UDF_LOG_ERROR("[cache_id:%ld][Transfer] failed, ret = %d.", req_info->cache_id, static_cast<int32_t>(ret));
    return ret;
  }
  UDF_LOG_INFO("[cache_id:%ld][Transfer] success, %s, num_tensors=%zu, block_len = %lu, dst_cache_id:%ld.", req_info->cache_id,
               TransferReqDebugString(*req_info).c_str(), cache_entry.tensors.size(), req_info->block_len, req_info->dst_cache_id);
  return FsmStatus::kFsmSuccess;
}
bool CacheManager::GetCacheEntry(const CacheIndex &cache_index, CacheEntry &cache_entry, bool is_prefix) const {
std::lock_guard<std::mutex> lk(mu_);
auto &index_to_id = is_prefix ? prefix_index_to_id_ : cache_index_to_id_;
const auto it = index_to_id.find(cache_index);
if (it == index_to_id.cend()) {
return false;
}
cache_entry = *DoGetCacheEntry(it->second);
return true;
}
bool CacheManager::GetCacheEntry(int64_t cache_id, CacheEntry &cache_entry) const {
std::lock_guard<std::mutex> lk(mu_);
auto entry = DoGetCacheEntry(cache_id);
bool success = (entry != nullptr);
if (success) {
cache_entry = *entry;
}
return success;
}
/**
 * @brief Finds the cache index registered for a (cache id, batch index) pair.
 * @param cache_id_and_batch_index pair to look up in cache_id_and_batch_id_to_cache_index_.
 * @param cache_index              output: receives the registered index when found.
 * @return true if a mapping exists, false otherwise (cache_index is then left untouched).
 */
bool CacheManager::GetCacheIndex(const std::pair<int64_t, uint32_t> &cache_id_and_batch_index, CacheIndex &cache_index) const {
  std::lock_guard<std::mutex> lk(mu_);
  const auto found = cache_id_and_batch_id_to_cache_index_.find(cache_id_and_batch_index);
  if (found == cache_id_and_batch_id_to_cache_index_.cend()) {
    return false;
  }
  cache_index = found->second;
  return true;
}
/**
 * @brief Lookup of a cache entry by id; takes no lock itself (the callers in this file hold mu_).
 * @param cache_id id of the cache.
 * @return pointer to the entry stored in cache_id_to_entry_, or nullptr if the id is unknown.
 */
const CacheEntry *CacheManager::DoGetCacheEntry(int64_t cache_id) const {
  const auto found = cache_id_to_entry_.find(cache_id);
  if (found != cache_id_to_entry_.cend()) {
    return &found->second;
  }
  return nullptr;
}
// Sets the enabled_ flag to true. Note that mu_ is not taken here.
void CacheManager::Enable() {
enabled_ = true;
}
// Returns the current value of the enabled_ flag (set by Enable()). Note that mu_ is not taken here.
bool CacheManager::IsEnabled() const {
return enabled_;
}
/**
 * @brief Checks that an allocation request does not clash with existing caches and builds its
 *        cache indices.
 *
 * One index (req_ids[i], model_id) is appended to cache_indices per request. Indices appended
 * before a clash is detected remain in cache_indices when an error is returned.
 *
 * @param req_info      allocation request (cache_id, req_ids, num_requests, model_id, is_prefix).
 * @param cache_indices output: receives the created indices in request order.
 * @return kFsmSuccess if nothing clashes;
 *         kFsmCacheIdAlreadyExist if cache_id is already registered;
 *         kFsmCacheKeyAlreadyExist if one of the indices is already mapped to a cache.
 */
FsmStatus CacheManager::CheckAndCreateCacheIndices(const AllocateCacheReqInfo &req_info,
                                                   std::vector<CacheIndex> &cache_indices) const {
  const bool id_taken = (cache_id_to_entry_.find(req_info.cache_id) != cache_id_to_entry_.cend());
  if (id_taken) {
    UDF_LOG_ERROR("[cache_id:%ld][Allocate] check failed, cache_id already exist", req_info.cache_id);
    return FsmStatus::kFsmCacheIdAlreadyExist;
  }

  cache_indices.reserve(req_info.num_requests);
  // Prefix caches and regular caches keep their keys in separate maps.
  const auto &lookup = req_info.is_prefix ? prefix_index_to_id_ : cache_index_to_id_;
  size_t pos = 0U;
  while (pos < req_info.num_requests) {
    cache_indices.emplace_back(req_info.req_ids[pos], req_info.model_id);
    const auto hit = lookup.find(cache_indices.back());
    if (hit != lookup.cend()) {
      UDF_LOG_ERROR("[cache_id:%ld][Allocate] check failed, cacheKey (%lu, %lu) already mapped to Cache [%ld]",
                    req_info.cache_id, req_info.req_ids[pos], req_info.model_id, hit->second);
      return FsmStatus::kFsmCacheKeyAlreadyExist;
    }
    ++pos;
  }
  return FsmStatus::kFsmSuccess;
}
/**
 * @brief Checks that [addr, addr + check_size) lies inside a memory range recorded in
 *        addr_to_size_.
 *
 * If addr is the start address of a recorded range, only the size is compared. Otherwise every
 * recorded range is scanned for one that fully contains the requested range. All comparisons are
 * written with subtractions so that neither addr + check_size nor start + size can wrap around.
 *
 * NOTE(review): addr_to_size_ is modified under mu_ elsewhere in this file, but this function
 * does not take the lock -- confirm how callers synchronise before adding one here.
 *
 * @param addr       start address of the range to validate.
 * @param check_size length of the range to validate.
 * @return kFsmSuccess if the range is covered by a recorded range, kFsmParamInvalid otherwise.
 */
FsmStatus CacheManager::CheckAddr(uint64_t addr, uint64_t check_size) {
  const auto it = addr_to_size_.find(addr);
  if (it == addr_to_size_.end()) {
    for (const auto &addr_size_pair : addr_to_size_) {
      const uint64_t exist_addr = addr_size_pair.first;
      const uint64_t exist_size = addr_size_pair.second;
      if (addr <= exist_addr) {
        continue;
      }
      // addr > exist_addr, so the subtraction cannot underflow.
      const uint64_t offset = addr - exist_addr;
      if ((offset <= exist_size) && (check_size <= (exist_size - offset))) {
        return FsmStatus::kFsmSuccess;
      }
    }
    return FsmStatus::kFsmParamInvalid;
  }
  if (it->second < check_size) {
    UDF_LOG_ERROR("Addr size is over limit, real size:%lu, req size:%lu.", it->second, check_size);
    return FsmStatus::kFsmParamInvalid;
  }
  return FsmStatus::kFsmSuccess;
}
}