/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "data_cache_engine.h"
#include <set>
#include "llm_datadist/llm_error_codes.h"
#include "statistic_manager.h"
#include "common/common.h"
#include "common/llm_utils.h"
#include "common/mem_utils.h"
#include "swap_impl.h"
#include "cache_manager.h"
#include "data_transfer/d2h_data_transfer_job.h"
#include "data_transfer/layer_wise_transfer_job.h"
#include "data_transfer/data_transfer_client.h"
#include "base/err_msg.h"
#include "common/llm_checker.h"
#include "common/llm_scope_guard.h"
#include "llm_datadist/llm_engine_types.h"
namespace llm {
namespace {
// Bound applied to the number of shape dimensions in Register() and Allocate(); both compare with '<'.
constexpr size_t kMaxDimNum = 32U;
// NOTE(review): not referenced anywhere in the visible part of this file - confirm it is still needed.
constexpr size_t kAlignment = 4096U;
// Parses a memory pool configuration supplied as a JSON string.
//   mem_pool_config: JSON text; "memory_size" is mandatory, "page_shift" is optional.
//   pool_size:       receives the value of "memory_size".
//   page_shift:      receives the value of "page_shift" when that key is present, otherwise it is left untouched.
// Both fields must be unsigned integers. Returns ge::LLM_PARAM_INVALID when a field has the wrong type or when
// nlohmann::json throws (for example malformed text or a missing "memory_size"), ge::SUCCESS otherwise.
ge::Status ParseMemoryPoolConfig(const std::string &mem_pool_config, size_t &pool_size, size_t &page_shift) {
  const char *const config_text = mem_pool_config.c_str();
  try {
    const nlohmann::json config = nlohmann::json::parse(mem_pool_config);
    const auto &memory_size = config.at("memory_size");
    LLM_CHK_BOOL_RET_STATUS(memory_size.is_number_unsigned(), ge::LLM_PARAM_INVALID,
                            "memory_size is not an unsigned integer: config = %s", config_text);
    pool_size = memory_size.get<size_t>();
    if (config.contains("page_shift")) {
      const auto &shift = config.at("page_shift");
      LLM_CHK_BOOL_RET_STATUS(shift.is_number_unsigned(), ge::LLM_PARAM_INVALID,
                              "page_shift is not an unsigned integer: config = %s", config_text);
      page_shift = shift.get<size_t>();
    }
  } catch (const nlohmann::json::exception &e) {
    // Any library exception is reported to the caller as an invalid parameter.
    REPORT_INNER_ERR_MSG("E19999", "Failed to parse memory pool config: %s", config_text);
    LLMLOGE(ge::LLM_PARAM_INVALID, "Failed to parse memory pool config: \"%s\", exception = %s", config_text,
            e.what());
    return ge::LLM_PARAM_INVALID;
  }
  return ge::SUCCESS;
}
// Verifies that tensor_indices holds distinct values which together form one gap-free integer range.
// The position of the values inside the vector is not examined, and an empty vector is accepted.
// Returns ge::LLM_PARAM_INVALID when a value is duplicated or the range has a gap, ge::SUCCESS otherwise.
ge::Status CheckTensorIndicesContinuous(const std::vector<uint64_t> &tensor_indices) {
  if (tensor_indices.empty()) {
    return ge::SUCCESS;
  }
  const std::set<uint64_t> unique_elements(tensor_indices.cbegin(), tensor_indices.cend());
  // A smaller set means at least one index was given more than once.
  LLM_CHK_BOOL_RET_STATUS(unique_elements.size() == tensor_indices.size(), ge::LLM_PARAM_INVALID,
                          "tensor_indices is not continuous");
  // std::set keeps its keys ordered, so the extremes are the first and last element; no linear scan is needed.
  const uint64_t smallest = *unique_elements.cbegin();
  const uint64_t largest = *unique_elements.crbegin();
  // Compare the span with (count - 1) instead of adding 1 to the span, so the arithmetic cannot wrap around.
  LLM_CHK_BOOL_RET_STATUS((largest - smallest) == (unique_elements.size() - 1U), ge::LLM_PARAM_INVALID,
                          "tensor_indices is not continuous");
  return ge::SUCCESS;
}
}
// Registers caller-provided tensor addresses as a cache.
//   cache_desc: description of the cache; its shape must be non-empty and pass the kMaxDimNum bound, and
//               num_tensors must equal the number of addresses in cache.per_device_tensor_addrs[0].
//   cache_keys: keys forwarded to the cache manager together with the new entry.
//   cache:      supplies the addresses; on success cache.cache_id receives the generated id.
// The id generation and both registrations run while mu_ is held.
// Returns ge::LLM_PARAM_INVALID for an invalid description, otherwise the status of the failing step or ge::SUCCESS.
// NOTE(review): the shape check uses '< kMaxDimNum' (32) while the message states the range [1, 33) - confirm
// which of the two is intended.
ge::Status DataCacheEngine::Register(const llm::CacheDesc &cache_desc, const std::vector<CacheKey> &cache_keys,
                                     llm::Cache &cache) {
  LLM_CHK_BOOL_RET_STATUS(!cache_desc.shape.empty() && (cache_desc.shape.size() < kMaxDimNum), ge::LLM_PARAM_INVALID,
                          "Invalid shape: %s, dim_num (%zu) must be in range [1, 33)",
                          ToString(cache_desc.shape).c_str(), cache_desc.shape.size());
  LLM_CHECK_GE(cache.per_device_tensor_addrs.size(), 1);
  LLM_CHK_BOOL_RET_STATUS(cache_desc.num_tensors == static_cast<uint32_t>(cache.per_device_tensor_addrs[0].size()),
                          ge::LLM_PARAM_INVALID, "cache addrs size[%zu] is not equal to num_tensors[%u] of cache_desc",
                          cache.per_device_tensor_addrs[0].size(), cache_desc.num_tensors);
  std::lock_guard<std::mutex> lock(mu_);
  const auto cache_id = cache_id_gen_.fetch_add(1, std::memory_order_relaxed);
  LLM_CHECK_GE(cache_id, 1);
  // Initialized so that the value is defined even if the helper returns without writing it.
  int64_t tensor_size = 0;
  LLM_CHK_STATUS_RET(LLMUtils::CalcTensorMemSize(cache_desc.shape, cache_desc.data_type, tensor_size),
                     "Failed to calc tensor size, shape = %s, dtype = %d",
                     ToString(cache_desc.shape).c_str(),
                     static_cast<int32_t>(cache_desc.data_type));
  LLMLOGI("[Register] start, placement:%u", static_cast<uint32_t>(cache_desc.placement));
  LLM_CHK_STATUS_RET(comm_mem_manager_->RegisterCacheMem(cache_id, cache_desc,
                                                         cache.per_device_tensor_addrs[0U], tensor_size),
                     "Register cache addr failed, cache_id = %ld.", cache_id);
  // The id is only handed to the caller on success, so a failure of the next step would leave the memory
  // registration unreachable. Undo it unless the whole registration succeeds.
  LLM_DISMISSABLE_GUARD(rollback_cache_mem, [this, cache_id]() -> void {
    LLM_CHK_STATUS(comm_mem_manager_->UnregisterCacheMem(cache_id));
  });
  LLM_CHK_STATUS_RET(cache_manager_->RegisterCacheEntry(cache_id, cache_keys, cache_desc,
                                                        cache.per_device_tensor_addrs[0U], tensor_size),
                     "Register cache entry failed.");
  LLM_DISMISS_GUARD(rollback_cache_mem);
  cache.cache_id = cache_id;
  LLMLOGI("[cache_id:%ld][Register] success, num_tensors = %u, shape = %s", cache_id, cache_desc.num_tensors,
          ToString(cache_desc.shape).c_str());
  return ge::SUCCESS;
}
// Removes the cache identified by cache_id: first its memory registration, then its cache entry.
// A failure of the first step returns immediately, so the entry is left in place in that case.
// Logs the elapsed time in milliseconds on success.
ge::Status DataCacheEngine::Unregister(int64_t cache_id) {
  using Clock = std::chrono::steady_clock;
  const auto begin_time = Clock::now();
  LLM_CHK_STATUS_RET(comm_mem_manager_->UnregisterCacheMem(cache_id),
                     "Unregister cache addr failed, cache_id = %ld.", cache_id);
  LLM_CHK_STATUS_RET(cache_manager_->UnregisterCacheEntry(cache_id),
                     "Unregister cache entry failed, cache_id = %ld.", cache_id);
  LLMLOGI("[cache_id:%ld][Unregister] success, cost = %ld ms", cache_id,
          std::chrono::duration_cast<std::chrono::milliseconds>(Clock::now() - begin_time).count());
  return ge::SUCCESS;
}
// Pulls cache data from the remote cluster cache_key.prompt_cluster_id into the local cache cache_id.
// Steps: look up the local entry, validate pull_cache_param, find the remote entity, take its pull mutex,
// then run one of three transfers on req_stream_:
//   - access_remote_cache_ set: DataTransferClient::PullCacheByGet
//   - entry placed on HOST:     D2HDataTransferClient::PullCache (requires the device memory pool)
//   - otherwise:                DataTransferClient::PullCache
// Once the abort guard is armed, every failing return aborts req_stream_; success dismisses the guard.
// Returns ge::LLM_KV_CACHE_NOT_EXIST, ge::LLM_NOT_YET_LINK, ge::LLM_PARAM_INVALID or the status of the
// failing step; ge::SUCCESS otherwise.
ge::Status DataCacheEngine::PullCache(int64_t cache_id, const CacheKey &cache_key,
                                      const PullCacheParam &pull_cache_param) {
  const auto begin_time = std::chrono::steady_clock::now();
  CacheEntry cache_entry;
  LLM_CHK_BOOL_RET_STATUS(cache_manager_->GetCacheEntry(cache_id, cache_entry), ge::LLM_KV_CACHE_NOT_EXIST,
                          "cache id:%ld not found", cache_id);
  LLM_CHK_STATUS_RET(CheckParam(cache_entry, pull_cache_param), "[cache_id:%ld] check param failed", cache_id);
  LLMLOGI("pull cache with tensor num per layer:%lu.", pull_cache_param.tensor_num_per_layer);

  const auto remote_entity = comm_entity_manager_->GetEntityByRemoteClusterId(cache_key.prompt_cluster_id);
  LLM_CHK_BOOL_RET_STATUS(remote_entity != nullptr, ge::LLM_NOT_YET_LINK,
                          "current cluster is not linked with remote cluster:%lu", cache_key.prompt_cluster_id);
  std::lock_guard<std::mutex> entity_lock(remote_entity->GetPullMutex());
  // The state is read again under the lock and a destroyed entity is rejected.
  LLM_CHK_BOOL_RET_STATUS((remote_entity->GetCurState() != FsmState::FSM_DESTROYED_STATE), ge::LLM_NOT_YET_LINK,
                          "current cluster is not linked with remote cluster:%lu", cache_key.prompt_cluster_id);
  LLMLOGI("Get lock cost:%ld us.",
          std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - begin_time).count());

  TemporaryRtContext with_context(rt_context_);
  LLM_CHK_BOOL_RET_STATUS(remote_entity->CheckEntityInfo(), ge::LLM_NOT_YET_LINK,
                          "pull cache must wait until the query_register_mem_status return ok");
  LLM_DISMISSABLE_GUARD(abort_on_failure, [this]() -> void {
    LLM_CHK_ACL(rtStreamAbort(req_stream_));
  });

  if (access_remote_cache_) {
    DataTransferClient get_client(*remote_entity, req_stream_);
    LLM_CHK_STATUS_RET(get_client.PullCacheByGet(cache_entry, cache_key, pull_cache_param, sync_cache_timeout_));
    LLMLOGI("[PullCache] success, cache_id = %ld, num_tensors = %zu, stride = %lu, "
            "pull_size = %ld, local_block_cnt = %zu, remote_block_cnt = %zu",
            cache_id, cache_entry.cache_addrs.size(), cache_entry.stride,
            pull_cache_param.size, pull_cache_param.decoder_blocks.size(), pull_cache_param.prompt_blocks.size());
    LLM_DISMISS_GUARD(abort_on_failure);
    return ge::SUCCESS;
  }

  // Only the two request/response paths below clear the response flags first.
  remote_entity->ClearResponseFlags();
  if (cache_entry.placement == CachePlacement::HOST) {
    LLM_CHK_BOOL_RET_STATUS(npu_pool_memory_ != nullptr, ge::LLM_PARAM_INVALID, "Device memory pool is not enabled.");
    D2HDataTransferClient host_client(*remote_entity, req_stream_);
    LLM_CHK_STATUS_RET(host_client.PullCache(cache_entry, cache_key, pull_cache_param, sync_cache_timeout_),
                       "Failed to pull kv from remote cluster:%lu", cache_key.prompt_cluster_id);
    LLM_DISMISS_GUARD(abort_on_failure);
    return ge::SUCCESS;
  } else {
    DataTransferClient device_client(*remote_entity, req_stream_);
    LLM_CHK_STATUS_RET(device_client.PullCache(cache_entry, cache_key, pull_cache_param, sync_cache_timeout_),
                       "Failed to pull kv from remote cluster:%lu", cache_key.prompt_cluster_id);
    LLM_DISMISS_GUARD(abort_on_failure);
    return ge::SUCCESS;
  }
}
// Forwards a block swap between src and dst to a SwapImpl created for device_id_.
// block_size, type and block_mapping are passed through unchanged to SwapImpl::SwapBlocksV2.
// Returns the status of SwapBlocksV2 on failure, ge::SUCCESS otherwise.
ge::Status DataCacheEngine::SwapBlocks(const Cache &src, const Cache &dst, const uint64_t block_size,
                                       const uint32_t type,
                                       const std::vector<std::pair<int64_t, int64_t>> &block_mapping) const {
  SwapImpl swapper(device_id_);
  LLM_CHK_STATUS_RET(swapper.SwapBlocksV2(src, dst, block_size, type, block_mapping));
  return ge::SUCCESS;
}
// Initializes the engine from the option map.
// Order of work: validate the injected managers, read the device id, capture the current runtime context,
// parse the remote-cache-access flag, initialize the cache manager, read the sync timeout, set up the memory
// pools and finally create req_stream_.
// Returns the status of the first failing step, ge::SUCCESS otherwise.
ge::Status DataCacheEngine::Initialize(const std::map<ge::AscendString, ge::AscendString> &options) {
  // cache_manager_ is dereferenced below and comm_mem_manager_ / cache_manager_ are dereferenced by the pool
  // initializers, so the managers must be validated before any of that work starts.
  LLM_CHECK_NOTNULL(comm_entity_manager_);
  LLM_CHECK_NOTNULL(comm_mem_manager_);
  LLM_CHECK_NOTNULL(cache_manager_);
  int32_t device_id = 0;
  LLM_CHK_STATUS_RET(LLMUtils::ParseDeviceId(options, device_id), "Failed to get device id");
  device_id_ = device_id;
  LLM_ASSERT_RT_OK(rtCtxGetCurrent(&rt_context_));
  LLM_CHK_STATUS_RET(LLMUtils::ParseFlag(kLlmOptionEnableRemoteCacheAccessible, options, access_remote_cache_),
                     "Failed to parse option %s", kLlmOptionEnableRemoteCacheAccessible);
  LLM_CHK_STATUS_RET(cache_manager_->Initialize(access_remote_cache_));
  DecoderWaitTimeInfo wait_time_info{};
  LLM_CHK_STATUS_RET(LLMUtils::ParserWaitTimeInfo(options, wait_time_info), "parser wait time info failed");
  sync_cache_timeout_ = wait_time_info.sync_kv_wait_time;
  LLM_CHK_STATUS_RET(InitializeMemoryPool(options), "Failed to initialize memory pool");
  LLM_ASSERT_RT_OK(
      rtStreamCreateWithFlags(&req_stream_, RT_STREAM_PRIORITY_DEFAULT, RT_STREAM_FAST_LAUNCH | RT_STREAM_FAST_SYNC));
  return ge::SUCCESS;
}
// Releases what the engine set up: finalizes the cache manager, frees the device and host pool memory and
// destroys the request and transfer streams. Each resource is only touched when its handle is non-null.
// All work runs with rt_context_ installed through TemporaryRtContext.
// NOTE(review): the method is const and does not reset the handles, so a second call would release the same
// handles again - confirm callers invoke it only once.
void DataCacheEngine::Finalize() const {
  {
    TemporaryRtContext with_context(rt_context_);
    // The cache manager is injected through SetCacheManager(); tolerate the case where that never happened.
    if (cache_manager_ != nullptr) {
      cache_manager_->Finalize();
    }
    if (npu_pool_memory_ != nullptr) {
      LLM_CHK_ACL(rtFree(npu_pool_memory_));
    }
    if (host_pool_memory_ != nullptr) {
      LLM_CHK_ACL(rtFreeHost(host_pool_memory_));
    }
    if (req_stream_ != nullptr) {
      LLM_CHK_ACL(rtStreamDestroy(req_stream_));
    }
    if (transfer_stream_ != nullptr) {
      LLM_CHK_ACL(rtStreamDestroy(transfer_stream_));
    }
  }
}
// Stores the pointer used to look up remote-cluster entities; the pointer is kept as is, not copied or checked.
void DataCacheEngine::SetCommEntityManager(CommEntityManager *comm_entity_manager) {
  this->comm_entity_manager_ = comm_entity_manager;
}
// Stores the pointer used for memory registration calls; the pointer is kept as is, not copied or checked.
void DataCacheEngine::SetCommMemManager(CommMemManager *comm_mem_manager) {
  this->comm_mem_manager_ = comm_mem_manager;
}
// Stores the pointer used for cache entry bookkeeping; the pointer is kept as is, not copied or checked.
void DataCacheEngine::SetCacheManager(CacheManager *cache_manager) {
  this->cache_manager_ = cache_manager;
}
// Creates the device memory pool when LLM_OPTION_MEM_POOL_CONFIG is present in options; without that option
// the function logs that the pool is disabled and succeeds.
// Steps: parse the JSON config (page_shift defaults to 16), create the LlmMemPool, allocate npu_pool_size_
// bytes with rtMalloc, initialize the pool on that memory, register the memory with the comm memory manager
// and hand the pool to the cache manager.
// Returns ge::LLM_OUT_OF_MEMORY when the allocation fails, otherwise the status of the failing step or
// ge::SUCCESS.
ge::Status DataCacheEngine::InitializeDeviceMemoryPool(const std::map<ge::AscendString, ge::AscendString> &options) {
  const auto it = options.find(LLM_OPTION_MEM_POOL_CONFIG);
  if (it == options.cend()) {
    LLMLOGI("memory pool is not enabled");
    return ge::SUCCESS;
  }
  const std::string json_str = it->second.GetString();
  size_t page_shift = 16U;
  LLM_CHK_STATUS_RET(ParseMemoryPoolConfig(json_str, npu_pool_size_, page_shift), "parse %s failed",
                     LLM_OPTION_MEM_POOL_CONFIG);
  ScalableConfig config{};
  config.page_idem_num = page_shift;
  config.page_mem_size_total_threshold = npu_pool_size_;
  npu_mem_pool_ = MakeUnique<LlmMemPool>(config);
  LLM_CHECK_NOTNULL(npu_mem_pool_, "Failed to create memory pool");
  LLM_CHK_BOOL_RET_STATUS(
      rtMalloc(&npu_pool_memory_, npu_pool_size_, RT_MEMORY_HBM, LLM_MODULE_NAME_U16) == RT_ERROR_NONE,
      ge::LLM_OUT_OF_MEMORY, "Failed to allocate memory for memory_pool, config = %s", json_str.c_str());
  LLM_CHK_STATUS_RET(npu_mem_pool_->Initialize(npu_pool_memory_, npu_pool_size_),
                     "Failed to initialize memory pool, config = %s", json_str.c_str());
  // A failed registration is returned to the caller, matching InitializeHostMemoryPool(), instead of being
  // logged and ignored.
  LLM_CHK_STATUS_RET(
      comm_mem_manager_->RegisterCommMemAddr(npu_pool_memory_, npu_pool_size_, HcclMemType::HCCL_MEM_TYPE_DEVICE),
      "Failed to register device memory pool, config = %s", json_str.c_str());
  cache_manager_->SetNpuMemPool(npu_mem_pool_.get());
  LLMLOGI("npu memory_size = %lu B, page_shift = %zu, page_size = %lu B", npu_pool_size_, page_shift,
          (1UL << page_shift));
  return ge::SUCCESS;
}
// Creates the host memory pool when LLM_OPTION_HOST_MEM_POOL_CONFIG is present in options; without that option
// the function logs that the pool is disabled and succeeds.
// Steps: parse the JSON config (page_shift defaults to 16), create the LlmMemPool, allocate the pool memory
// with rtMallocHost, initialize the pool on it, register the memory with the comm memory manager and hand the
// pool to the cache manager.
// Returns the status of the first failing step, ge::SUCCESS otherwise.
ge::Status DataCacheEngine::InitializeHostMemoryPool(const std::map<ge::AscendString, ge::AscendString> &options) {
  const auto option_it = options.find(LLM_OPTION_HOST_MEM_POOL_CONFIG);
  if (option_it == options.cend()) {
    LLMLOGI("host memory pool is not enabled");
    return ge::SUCCESS;
  }

  const std::string config_json = option_it->second.GetString();
  size_t host_pool_size = 0UL;
  size_t page_shift = 16U;
  LLM_CHK_STATUS_RET(ParseMemoryPoolConfig(config_json, host_pool_size, page_shift), "parse %s failed",
                     LLM_OPTION_HOST_MEM_POOL_CONFIG);

  ScalableConfig pool_config{};
  pool_config.page_idem_num = page_shift;
  pool_config.page_mem_size_total_threshold = host_pool_size;
  host_mem_pool_ = MakeUnique<LlmMemPool>(pool_config);
  LLM_CHECK_NOTNULL(host_mem_pool_);

  LLM_CHK_ACL_RET(rtMallocHost(&host_pool_memory_, host_pool_size, LLM_MODULE_NAME_U16));
  LLM_CHK_STATUS_RET(host_mem_pool_->Initialize(host_pool_memory_, host_pool_size),
                     "Failed to initialize host memory pool, config = %s", config_json.c_str());
  LLM_CHK_STATUS_RET(
      comm_mem_manager_->RegisterCommMemAddr(host_pool_memory_, host_pool_size, HCCL_MEM_TYPE_HOST));
  cache_manager_->SetHostMemPool(host_mem_pool_.get());

  LLMLOGI("host memory_size = %lu B, page_shift = %zu, page_size = %lu B", host_pool_size, page_shift,
          (1UL << page_shift));
  return ge::SUCCESS;
}
// Sets up the device pool first and the host pool second; the host pool is not attempted when the device
// pool setup fails. Returns the status of the failing step, ge::SUCCESS otherwise.
ge::Status DataCacheEngine::InitializeMemoryPool(const std::map<ge::AscendString, ge::AscendString> &options) {
  // Device pool.
  LLM_CHK_STATUS_RET(InitializeDeviceMemoryPool(options), "initialize device memory pool failed");
  // Host pool.
  LLM_CHK_STATUS_RET(InitializeHostMemoryPool(options), "initialize host memory pool failed");
  return ge::SUCCESS;
}
// Allocates a cache from one of the engine's memory pools.
//   cache_desc: placement must name a pool that has been set up (DEVICE -> device pool, HOST -> host pool);
//               the shape must be non-empty and pass the kMaxDimNum bound.
//   cache_keys, cache: forwarded to CacheManager::Allocate together with a freshly generated cache id.
// Returns ge::LLM_FEATURE_NOT_ENABLED when no pool exists, ge::LLM_PARAM_INVALID for a bad placement or shape,
// otherwise the status of CacheManager::Allocate or ge::SUCCESS.
// NOTE(review): the shape check uses '< kMaxDimNum' (32) while the message states the range [1, 33) - confirm
// which of the two is intended.
ge::Status DataCacheEngine::Allocate(const CacheDesc &cache_desc, const std::vector<CacheKey> &cache_keys, Cache &cache) {
  const bool has_device_pool = (npu_pool_memory_ != nullptr);
  const bool has_host_pool = (host_pool_memory_ != nullptr);
  LLM_CHK_BOOL_RET_STATUS((has_device_pool || has_host_pool), ge::LLM_FEATURE_NOT_ENABLED,
                          "memory pool is not enabled");

  const bool device_requested = (cache_desc.placement == static_cast<uint32_t>(CachePlacement::DEVICE));
  const bool host_requested = (cache_desc.placement == static_cast<uint32_t>(CachePlacement::HOST));
  LLM_CHK_BOOL_RET_STATUS((has_device_pool && device_requested) || (has_host_pool && host_requested),
                          ge::LLM_PARAM_INVALID, "placement must be set that matches memory pool config");

  const size_t dim_num = cache_desc.shape.size();
  LLM_CHK_BOOL_RET_STATUS(!cache_desc.shape.empty() && (dim_num < kMaxDimNum), ge::LLM_PARAM_INVALID,
                          "Invalid shape: %s, dim_num (%zu) must be in range [1, 33)",
                          ToString(cache_desc.shape).c_str(), dim_num);

  const auto cache_id = cache_id_gen_.fetch_add(1, std::memory_order_relaxed);
  LLM_CHECK_GE(cache_id, 1);
  LLM_CHK_STATUS_RET(cache_manager_->Allocate(cache_id, cache_desc, cache_keys, cache));
  LLMLOGI("[cache_id:%ld][Allocate] success, num_tensors = %u, shape = %s", cache_id, cache_desc.num_tensors,
          ToString(cache_desc.shape).c_str());
  return ge::SUCCESS;
}
// Releases the cache identified by cache_id through the cache manager.
// Returns ge::LLM_FEATURE_NOT_ENABLED when neither the device nor the host pool memory exists, otherwise the
// status of CacheManager::Deallocate.
ge::Status DataCacheEngine::Deallocate(int64_t cache_id) const {
  const bool pool_enabled = (npu_pool_memory_ != nullptr) || (host_pool_memory_ != nullptr);
  LLM_CHK_BOOL_RET_STATUS(pool_enabled, ge::LLM_FEATURE_NOT_ENABLED,
                          "memory pool is not enabled");
  return cache_manager_->Deallocate(cache_id);
}
// Removes cache_key through the cache manager.
// Returns ge::LLM_FEATURE_NOT_ENABLED when neither the device nor the host pool memory exists, otherwise the
// status of CacheManager::RemoveCacheKey.
ge::Status DataCacheEngine::RemoveCacheKey(const CacheKey &cache_key) const {
  const bool pool_enabled = (npu_pool_memory_ != nullptr) || (host_pool_memory_ != nullptr);
  LLM_CHK_BOOL_RET_STATUS(pool_enabled, ge::LLM_FEATURE_NOT_ENABLED,
                          "memory pool is not enabled");
  return cache_manager_->RemoveCacheKey(cache_key);
}
// Makes rt_context_ the current runtime context of the calling thread and forwards the copy request to the
// cache manager. Returns early when the context cannot be set, otherwise the status of CacheManager::CopyCache.
ge::Status DataCacheEngine::CopyCache(const CopyCacheParam &copy_cache_param) const {
  LLM_CHK_ACL_RET(rtCtxSetCurrent(rt_context_));
  return cache_manager_->CopyCache(copy_cache_param);
}
// Reports whether the device memory pool can currently serve an allocation of `size` bytes by attempting one.
// Returns ge::LLM_FEATURE_NOT_ENABLED without a device pool, ge::LLM_OUT_OF_MEMORY when the attempt yields
// nullptr, ge::SUCCESS otherwise.
ge::Status DataCacheEngine::CheckCapacity(size_t size) {
  LLM_CHK_BOOL_RET_STATUS(npu_mem_pool_ != nullptr, ge::LLM_FEATURE_NOT_ENABLED, "memory pool is not enabled");
  // The result of AllocShared is a temporary that is not kept; presumably dropping it returns the memory to
  // the pool - confirm against LlmMemPool.
  const bool has_room = (npu_mem_pool_->AllocShared(size) != nullptr);
  const auto ret = has_room ? ge::SUCCESS : ge::LLM_OUT_OF_MEMORY;
  LLMLOGI("check size = %zu, check result = %u", size, ret);
  return ret;
}
// Validates pull_cache_param against the local cache entry before a pull.
// Checks, in order:
//   - batch_index is below the entry's batch_size;
//   - without decoder_blocks: the entry has no blocks or is of type MIX, and a non-negative size does not
//     exceed the entry's stride;
//   - with decoder_blocks: the entry has blocks;
//   - for HOST entries of type BLOCKS: every decoder block index is below num_blocks, and prompt_blocks is
//     either empty or as long as decoder_blocks;
//   - tensor_num_per_layer is at least 1;
//   - the tensor index lists pass CheckTensorIndices().
// Returns ge::LLM_PARAM_INVALID (or the status of CheckTensorIndices) on the first violation, ge::SUCCESS
// otherwise.
ge::Status DataCacheEngine::CheckParam(const CacheEntry &cache_entry, const PullCacheParam &pull_cache_param) {
  LLM_CHK_BOOL_RET_STATUS(pull_cache_param.batch_index < cache_entry.batch_size, ge::LLM_PARAM_INVALID,
                          "dst_batch_index:%u out of range [0, %u)", pull_cache_param.batch_index,
                          cache_entry.batch_size);
  if (pull_cache_param.decoder_blocks.empty()) {
    LLM_CHK_BOOL_RET_STATUS(cache_entry.num_blocks == 0 || cache_entry.cache_mem_type == CacheMemType::MIX,
                            ge::LLM_PARAM_INVALID,
                            "check failed, request expect local cache is non-blocks");
    // A negative size skips the stride comparison.
    LLM_CHK_BOOL_RET_STATUS((pull_cache_param.size < 0) ||
                            (static_cast<uint64_t>(pull_cache_param.size) <= cache_entry.stride),
                            ge::LLM_PARAM_INVALID,
                            "pull_size(%ld) > cache stride(%lu)",
                            pull_cache_param.size, cache_entry.stride);
  } else {
    LLM_CHK_BOOL_RET_STATUS(cache_entry.num_blocks > 0,
                            ge::LLM_PARAM_INVALID,
                            "check failed, request expect local cache is blocks");
  }
  if ((cache_entry.placement == CachePlacement::HOST) && (cache_entry.cache_mem_type == CacheMemType::BLOCKS)) {
    for (const auto block_index : pull_cache_param.decoder_blocks) {
      LLM_CHK_BOOL_RET_STATUS(block_index < cache_entry.num_blocks,
                              ge::LLM_PARAM_INVALID,
                              "local block index out of bound, index = %lu, num_blocks = %lu", block_index,
                              cache_entry.num_blocks);
    }
    LLM_CHK_BOOL_RET_STATUS(pull_cache_param.prompt_blocks.empty() ||
                            (pull_cache_param.decoder_blocks.size() == pull_cache_param.prompt_blocks.size()),
                            ge::LLM_PARAM_INVALID,
                            "check failed, src_block_index.size() = %zu, dst_block_index.size() = %zu",
                            pull_cache_param.prompt_blocks.size(),
                            pull_cache_param.decoder_blocks.size());
  }
  // This check must precede CheckTensorIndices(): that function uses tensor_num_per_layer as a divisor, so a
  // zero value has to be rejected first.
  LLM_CHK_BOOL_RET_STATUS(pull_cache_param.tensor_num_per_layer >= 1,
                          ge::LLM_PARAM_INVALID,
                          "check failed, tensor_num_per_layer expect [1, %lu]", cache_entry.cache_addrs.size());
  LLM_CHK_STATUS_RET(CheckTensorIndices(cache_entry, pull_cache_param), "tensor_indices is invalid");
  return ge::SUCCESS;
}
// Validates src_tensor_indices / dst_tensor_indices of pull_cache_param against the local cache entry.
// When either list is given:
//   - tensor_num_per_layer must be non-zero and divide the entry's tensor count;
//   - each list must consist of distinct values forming one gap-free range;
//   - dst_tensor_indices must not be longer than the tensor count and every index must be below it;
//   - with both lists given their sizes must match; with only src_tensor_indices given its size must equal
//     the tensor count.
// With both lists empty nothing is rejected.
// Returns ge::LLM_PARAM_INVALID on the first violation, ge::SUCCESS otherwise.
ge::Status DataCacheEngine::CheckTensorIndices(const CacheEntry &cache_entry, const PullCacheParam &pull_cache_param) {
  const auto &src_indices = pull_cache_param.src_tensor_indices;
  const auto &dst_indices = pull_cache_param.dst_tensor_indices;
  const size_t tensor_count = cache_entry.cache_addrs.size();
  if ((!src_indices.empty()) || (!dst_indices.empty())) {
    // Reject a zero divisor explicitly; the remainder is only needed (and only computed) on this path.
    LLM_CHK_BOOL_RET_STATUS(pull_cache_param.tensor_num_per_layer != 0U, ge::LLM_PARAM_INVALID,
                            "check failed, tensor_num_per_layer expect [1, %zu]", tensor_count);
    LLM_CHK_BOOL_RET_STATUS((tensor_count % pull_cache_param.tensor_num_per_layer) == 0U, ge::LLM_PARAM_INVALID,
        "When using layer wise transfer, the tensor_num [%zu] of caches must be a multiple of tensor_num_per_layer[%lu].",
                            tensor_count, pull_cache_param.tensor_num_per_layer);
  }
  LLM_CHK_STATUS_RET(CheckTensorIndicesContinuous(src_indices),
                     "src_tensor_indices is not continuous");
  LLM_CHK_STATUS_RET(CheckTensorIndicesContinuous(dst_indices),
                     "dst_tensor_indices is not continuous");
  if (!dst_indices.empty()) {
    LLM_CHK_BOOL_RET_STATUS(dst_indices.size() <= tensor_count,
                            ge::LLM_PARAM_INVALID, "dst_tensor_indices size[%zu] is out of range[0, %zu]",
                            dst_indices.size(), tensor_count);
    // The continuity check does not look at element order, so the largest index is not necessarily at the
    // front or the back; bound-check every index.
    for (const uint64_t tensor_index : dst_indices) {
      LLM_CHK_BOOL_RET_STATUS(tensor_index < tensor_count, ge::LLM_PARAM_INVALID,
                              "dst_tensor_indices index[%lu] is out of range[0, %zu)", tensor_index, tensor_count);
    }
  }
  if ((!src_indices.empty()) && (!dst_indices.empty())) {
    LLM_CHK_BOOL_RET_STATUS(src_indices.size() == dst_indices.size(),
                            ge::LLM_PARAM_INVALID,
                            "src_tensor_indices size[%zu] is not match dst_tensor_indices size[%zu]",
                            src_indices.size(), dst_indices.size());
  } else if (!src_indices.empty()) {
    LLM_CHK_BOOL_RET_STATUS(src_indices.size() == tensor_count,
                            ge::LLM_PARAM_INVALID, "src_tensor_indices size[%zu] is not match dst_num_tensors size[%zu]",
                            src_indices.size(), tensor_count);
  }
  return ge::SUCCESS;
}
// Transfers the local cache transfer_cache_config.src_cache_id to the remote cluster
// transfer_cache_config.cluster_id through a LayerWiseTransferJob.
// Steps: look up the local entry, find the remote entity, take its pull mutex, validate the entity state and
// tensor_num_per_layer, create transfer_stream_ on first use (std::call_once) and run the job on it.
// Once the abort guard is armed, every failing return aborts transfer_stream_; success dismisses the guard.
// task_id is only used in the failure message.
// Returns ge::LLM_KV_CACHE_NOT_EXIST, ge::LLM_NOT_YET_LINK, ge::LLM_PARAM_INVALID or the status of the
// failing step; ge::SUCCESS otherwise.
ge::Status DataCacheEngine::TransferCache(const uint64_t task_id, const TransferCacheConfig &transfer_cache_config,
                                          const TransferBlockConfig &transfer_block_config) {
  CacheEntry cache_entry;
  LLM_CHK_BOOL_RET_STATUS(cache_manager_->GetCacheEntry(transfer_cache_config.src_cache_id, cache_entry),
                          ge::LLM_KV_CACHE_NOT_EXIST, "cache id:%ld not found", transfer_cache_config.src_cache_id);

  const auto remote_entity = comm_entity_manager_->GetEntityByRemoteClusterId(transfer_cache_config.cluster_id);
  LLM_CHK_BOOL_RET_STATUS(remote_entity != nullptr, ge::LLM_NOT_YET_LINK,
                          "current cluster is not linked with remote cluster:%lu", transfer_cache_config.cluster_id);
  std::lock_guard<std::mutex> entity_lock(remote_entity->GetPullMutex());
  LLM_CHK_BOOL_RET_STATUS((remote_entity->GetCurState() != FsmState::FSM_DESTROYED_STATE), ge::LLM_NOT_YET_LINK,
                          "current cluster is not linked with remote cluster:%lu", transfer_cache_config.cluster_id);
  LLM_CHK_BOOL_RET_STATUS(remote_entity->CheckEntityInfo(), ge::LLM_NOT_YET_LINK,
                          "transfer cache must wait until the query_register_mem_status return ok");
  LLM_CHK_BOOL_RET_STATUS(transfer_cache_config.tensor_num_per_layer >= 1,
                          ge::LLM_PARAM_INVALID,
                          "check failed, tensor_num_per_layer expect [1, %lu]", cache_entry.cache_addrs.size());
  LLMLOGI("Transfer cache with tensor num per layer:%lu.", transfer_cache_config.tensor_num_per_layer);

  TemporaryRtContext with_context(rt_context_);
  // Only the call that actually runs the lambda can observe a creation error here; later calls rely on the
  // null check of transfer_stream_ below.
  rtError_t create_ret = RT_ERROR_NONE;
  std::call_once(create_stream_once_flag_, [this, &create_ret]() {
    create_ret = rtStreamCreateWithFlags(&transfer_stream_, RT_STREAM_PRIORITY_DEFAULT,
                                         RT_STREAM_FAST_LAUNCH | RT_STREAM_FAST_SYNC);
  });
  LLM_ASSERT_RT_OK(create_ret, "create transfer stream failed");
  LLM_ASSERT_NOTNULL(transfer_stream_, "transfer stream is nullptr");

  LLM_DISMISSABLE_GUARD(abort_on_failure, [this]() -> void {
    LLM_CHK_ACL(rtStreamAbort(transfer_stream_));
  });
  LayerWiseTransferJob transfer_job(*remote_entity, transfer_stream_);
  LLM_CHK_STATUS_RET(transfer_job.TransferCache(cache_entry, transfer_cache_config, transfer_block_config,
                                                sync_cache_timeout_, access_remote_cache_),
                     "task:%lu of cluster:%lu transfer cache of layer[%lu] failed", task_id,
                     transfer_cache_config.cluster_id, transfer_cache_config.layer_index);
  LLM_DISMISS_GUARD(abort_on_failure);
  return ge::SUCCESS;
}
}