* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "prefix_cache_block_allocator.h"
using namespace std;
namespace mindie_llm {
PrefixCacheBlockAllocator::PrefixCacheBlockAllocator(BlockId beginBlockId, size_t numBlocks, size_t blockSize,
BlockObjPoolSPtr blockObjPool)
: beginBlockId_(beginBlockId),
blockSize_(blockSize),
numBlocks_(numBlocks),
blockComputedAttr_(numBlocks_, beginBlockId) {
if (!blockObjPool) {
throw std::invalid_argument("blockObjPool cannot be null");
}
for (size_t blockIdx = 0; blockIdx < numBlocks; blockIdx++) {
allBlockIndices_.push_back(blockIdx + beginBlockId);
freeBlockIndices_.push_back(blockIdx + beginBlockId);
}
refCounter_ = MakeRefCounterProtocol(allBlockIndices_);
if (!refCounter_) {
throw std::runtime_error("Failed to create refCounter");
}
cowTracker_ = CopyOnWriteTracker(refCounter_);
evictor_ = MakeEvictor(EvictionPolicy::LRU);
if (!evictor_) {
throw std::runtime_error("Failed to create evictor");
}
hitRateCalculator_ = make_shared<HitRateCalculator>();
blockObjPool_ = blockObjPool;
}
BlockObjSPtr PrefixCacheBlockAllocator::AllocateMutableBlock(vector<TokenId> &tokenIds, BlockObjSPtr prevBlock,
HashValue extraHash) {
BlockId blockId = AllocateBlockId();
shared_ptr<BlockObj> block = blockObjPool_->AcquireObj();
BlockAllocatorSPtr allocator = shared_from_this();
const BlockSharedAttr blockSharedAttr = {blockSize_, allocator};
std::vector<TokenId> emptyTokenIds;
block->InitBlockObj(prevBlock, emptyTokenIds, blockSharedAttr, blockId, extraHash);
AppendTokenIds(block, tokenIds);
return block;
}
size_t PrefixCacheBlockAllocator::GetCachedBlockNum(std::vector<HashValue> &hashValues) {
size_t cachedBlockNum = 0;
for (HashValue key : hashValues) {
auto it = cachedBlocks_.find(key);
if (it != cachedBlocks_.end()) {
cachedBlockNum++;
} else {
break;
}
}
return cachedBlockNum;
}
BlockObjSPtr PrefixCacheBlockAllocator::AllocateImmutableBlock(vector<TokenId> &tokenIds, BlockObjSPtr prevBlock,
HashValue extraHash) {
BlockId cachedBlockId = INVALID_BLOCKID;
BlockObjSPtr block = blockObjPool_->AcquireObj();
BlockAllocatorSPtr allocator = shared_from_this();
const BlockSharedAttr blockSharedAttr = {blockSize_, allocator};
block->InitBlockObj(prevBlock, tokenIds, blockSharedAttr, cachedBlockId, extraHash);
HashValue key = block->PrefixHash();
auto it = cachedBlocks_.find(key);
if (it != cachedBlocks_.end()) {
cachedBlockId = it->second;
}
if (cachedBlockId != INVALID_BLOCKID) {
block->SetBlockId(cachedBlockId);
hitRateCalculator_->Record(true);
IncrRefCountCacheBlock(block);
return block;
}
blockObjPool_->FreeObj(block);
hitRateCalculator_->Record(false);
block = AllocateMutableBlock(tokenIds, prevBlock, extraHash);
return block;
}
vector<BlockObjSPtr> PrefixCacheBlockAllocator::AllocateImmutableBlocks(vector<vector<TokenId>> &tokenIds,
BlockObjSPtr prevBlock, HashValue extraHash) {
vector<BlockObjSPtr> blockObjs = {};
BlockObjSPtr frontBlock = prevBlock;
for (auto tokens : tokenIds) {
frontBlock = AllocateImmutableBlock(tokens, frontBlock, extraHash);
blockObjs.push_back(frontBlock);
}
return blockObjs;
}
BlockId PrefixCacheBlockAllocator::MayBeAllocateEvictedBlockId() {
EvictionResult evictionResult = evictor_->Evict();
BlockId blockId = evictionResult.blockId;
HashValue key = evictionResult.prefixHash;
if (blockId == INVALID_BLOCKID) {
throw runtime_error("BlockAllocator no blockId in evict!");
}
blockId = cachedBlocks_.at(key);
if (blockId != evictionResult.blockId) {
throw runtime_error("BlockAllocator blockId in evict unmatch blockId in cache!");
}
if (blockId >= beginBlockId_ + numBlocks_ || blockId < beginBlockId_) {
throw runtime_error("BlockId out of range!");
}
cachedBlocks_.erase(key);
refCounter_->Increase(blockId);
TrackBlockId(blockId, false);
return blockId;
}
BlockId PrefixCacheBlockAllocator::AllocateBlockId() {
if (!freeBlockIndices_.empty()) {
BlockId blockId = freeBlockIndices_.front();
freeBlockIndices_.pop_front();
refCounter_->Increase(blockId);
TrackBlockId(blockId, false);
return blockId;
}
return MayBeAllocateEvictedBlockId();
}
void PrefixCacheBlockAllocator::FreeBlockId(BlockObjSPtr &block) {
BlockId blockId = block->GetBlockId();
if (block->PrefixHash() != INVALID_HASH_VALUE) {
DecrRefCountCacheBlock(block);
} else {
RefCount refCount = refCounter_->Decrease(blockId);
if (refCount == 0) {
UntrackBlockId(blockId);
freeBlockIndices_.push_back(blockId);
}
block->ResetBlockId();
}
}
void PrefixCacheBlockAllocator::Free(BlockObjSPtr &block, bool keepBlockObj) {
FreeBlockId(block);
if (!keepBlockObj) {
blockObjPool_->FreeObj(block);
}
}
vector<BlockObjSPtr> PrefixCacheBlockAllocator::Fork(BlockObjSPtr &lastBlockObj) {
std::deque<BlockObjSPtr> sourceBlocks = {};
BlockObjSPtr visitBlockObj = lastBlockObj;
while (visitBlockObj) {
sourceBlocks.push_front(visitBlockObj);
visitBlockObj = visitBlockObj->GetPrevBlock();
}
std::vector<BlockObjSPtr> forkedBlocks = {};
BlockObjSPtr prevBlock = nullptr;
for (BlockObjSPtr sourceBlock : sourceBlocks) {
RefCount refCount = refCounter_->Increase(sourceBlock->GetBlockId());
constexpr long validRefCountThreshold = 2;
if (refCount < validRefCountThreshold) {
throw std::runtime_error("can't fork free'd block!");
}
std::vector<TokenId> tokenIds = sourceBlock->GetTokenIds();
BlockObjSPtr forkedBlock = blockObjPool_->AcquireObj();
BlockAllocatorSPtr allocator = shared_from_this();
const BlockSharedAttr blockSharedAttr = {blockSize_, allocator};
forkedBlock->InitBlockObj(prevBlock, tokenIds, blockSharedAttr, sourceBlock->GetBlockId(),
sourceBlock->ExtraHash());
forkedBlocks.push_back(forkedBlock);
prevBlock = forkedBlock;
}
return forkedBlocks;
}
void PrefixCacheBlockAllocator::IncrRefCountCacheBlock(BlockObjSPtr block) {
BlockId blockId = block->GetBlockId();
RefCount refCount = refCounter_->Increase(blockId);
if (refCount == 1) {
if (evictor_->ContainsBlock(blockId)) {
evictor_->Remove(blockId);
}
TrackBlockId(blockId, true);
}
}
void PrefixCacheBlockAllocator::DecrRefCountCacheBlock(BlockObjSPtr block) {
BlockId blockId = block->GetBlockId();
block->SetBlockId(INVALID_BLOCKID);
RefCount refCount = refCounter_->Decrease(blockId);
if (refCount < 0) {
throw runtime_error("Prefix blocks blockId counter has been decremented to a negative value.");
}
if (refCount == 0) {
HashValue prefixHash = block->PrefixHash();
BlockId reuseBlockId;
auto it = cachedBlocks_.find(prefixHash);
if (it != cachedBlocks_.end()) {
reuseBlockId = it->second;
} else {
throw runtime_error("Prefix blocks blockId not in cache.");
}
if (blockId != reuseBlockId) {
throw runtime_error("Prefix blocks blockId unmatch.");
}
evictor_->Add(blockId, block->PrefixHash(), block->GetNumTokensTotal(),
blockComputedAttr_.LastAccessed(blockId));
UntrackBlockId(blockId);
}
}
BlockId PrefixCacheBlockAllocator::PromoteToImmutableBlock(const BlockObjSPtr &block) {
HashValue prefixHash = block->PrefixHash();
auto it = cachedBlocks_.find(prefixHash);
if (it == cachedBlocks_.end()) {
cachedBlocks_[prefixHash] = block->GetBlockId();
touchedBlocks_.insert(block->GetBlockId());
return block->GetBlockId();
} else {
RefCount refCount = refCounter_->Decrease(block->GetBlockId());
if (refCount == 0) {
UntrackBlockId(block->GetBlockId());
freeBlockIndices_.push_back(block->GetBlockId());
}
block->ResetBlockId();
BlockId reuseBlockId = it->second;
block->SetBlockId(reuseBlockId);
IncrRefCountCacheBlock(block);
return reuseBlockId;
}
}
void PrefixCacheBlockAllocator::MarkBlocksAsAccessed(const vector<BlockId> &blockIds, TimeStamp now) {
for (BlockId blockId : blockIds) {
if (blockComputedAttr_.IsActive(blockId)) {
blockComputedAttr_.UpdateAccessTime(blockId, now);
} else if (evictor_->ContainsBlock(blockId)) {
evictor_->Update(blockId, now);
throw runtime_error("Block in evictor can not update time.");
} else {
throw runtime_error("Mark block as accessed which is not belonged to NPU.");
}
}
}
void PrefixCacheBlockAllocator::MarkBlocksAsComputed() {
for (BlockId blockId : touchedBlocks_) {
blockComputedAttr_.SetComputed(blockId, true);
}
touchedBlocks_.clear();
}
BlockId PrefixCacheBlockAllocator::CowBlockIfNotAppendable(BlockObjSPtr &block) {
BlockId srcBlockId = block->GetBlockId();
if (srcBlockId == INVALID_BLOCKID) {
throw runtime_error("Cow block id invalid.");
}
if (cowTracker_.IsAppendable(srcBlockId)) {
return srcBlockId;
}
FreeBlockId(block);
BlockId trgBlockId = AllocateBlockId();
cowTracker_.RecordCow(srcBlockId, trgBlockId);
return trgBlockId;
}
vector<BlockId> PrefixCacheBlockAllocator::GetCommonComputedBlockIds(
const std::vector<std::vector<BlockId>> &computedSeqBlockIds) {
if (computedSeqBlockIds.empty()) {
return {};
}
if (computedSeqBlockIds.size() == 1) {
return computedSeqBlockIds[0];
}
std::vector<BlockId> commonBlockIds = computedSeqBlockIds[0];
for (size_t i = 1; i < computedSeqBlockIds.size(); i++) {
const vector<BlockId> ¤t = computedSeqBlockIds[i];
size_t j = 0;
while (j < commonBlockIds.size() && j < current.size() && current[j] == commonBlockIds[j]) {
j++;
}
commonBlockIds.resize(j);
if (commonBlockIds.empty()) {
break;
}
}
return commonBlockIds;
}
size_t PrefixCacheBlockAllocator::GetNumTotalBlocks() const { return allBlockIndices_.size(); }
size_t PrefixCacheBlockAllocator::GetNumFreeBlock() const {
return freeBlockIndices_.size() + evictor_->GetNumblocks();
}
size_t PrefixCacheBlockAllocator::GetNumFullBlocksTouched(const std::vector<BlockObjSPtr> &blocks) {
int numTouchedBlocks = 0;
for (BlockObjSPtr block : blocks) {
if (!block->IsFull()) {
continue;
}
if (!IsBlockCached(block) ||
(block->PrefixHash() != INVALID_HASH_VALUE && evictor_->ContainsBlock(block->GetBlockId()))) {
numTouchedBlocks++;
}
}
return numTouchedBlocks;
}
void PrefixCacheBlockAllocator::SwapOut(std::vector<BlockObjSPtr> &blocks) {
for (BlockObjSPtr block : blocks) {
FreeBlockId(block);
}
}
void PrefixCacheBlockAllocator::SwapIn(std::vector<BlockObjSPtr> &blocks) {
for (BlockObjSPtr block : blocks) {
BlockObjSPtr tmpBlock;
if (block->IsFull()) {
std::vector<TokenId> token = block->GetTokenIds();
tmpBlock = AllocateImmutableBlock(token, block->GetPrevBlock(), block->ExtraHash());
} else {
std::vector<TokenId> token = block->GetTokenIds();
tmpBlock = AllocateMutableBlock(token, block->GetPrevBlock(), block->ExtraHash());
}
block->SetBlockId(tmpBlock->GetBlockId());
blockObjPool_->FreeObj(tmpBlock);
}
}
bool PrefixCacheBlockAllocator::FindCachedBlockPrefix(HashValue blockHash) const {
if (blockHash == INVALID_HASH_VALUE) {
return false;
}
auto cachedIt = cachedBlocks_.find(blockHash);
if (cachedIt != cachedBlocks_.end()) {
return IsBlockComputed(cachedIt->second);
}
return false;
}
std::vector<BlockId> PrefixCacheBlockAllocator::FindCachedBlocksPrefix(std::vector<HashValue> &blockHashes) const {
std::vector<BlockId> result;
for (auto blockHash : blockHashes) {
if (IsBlockCached(blockHash)) {
result.push_back(cachedBlocks_.at(blockHash));
} else {
break;
}
}
return result;
}
float PrefixCacheBlockAllocator::GetPrefixCacheHitRate() const {
double ret = hitRateCalculator_->GetHitRate();
return static_cast<float>(ret);
}
flows to invalid prefix caching after the weights are updated,
or used for resetting prefix caching status for benchmarking. */
bool PrefixCacheBlockAllocator::ResetPrefixCache() {
int useBlockNum = static_cast<int>(GetNumTotalBlocks()) - static_cast<int>(GetNumFreeBlock());
if (useBlockNum > 0) {
return false;
}
allBlockIndices_.clear();
freeBlockIndices_.clear();
for (size_t blockIdx = 0; blockIdx < numBlocks_; blockIdx++) {
allBlockIndices_.push_back(blockIdx + beginBlockId_);
freeBlockIndices_.push_back(blockIdx + beginBlockId_);
}
refCounter_ = MakeRefCounterProtocol(allBlockIndices_);
cowTracker_ = CopyOnWriteTracker(refCounter_);
evictor_ = MakeEvictor(EvictionPolicy::LRU);
return true;
}
bool PrefixCacheBlockAllocator::IsBlockCached(const BlockObjSPtr &block) const {
HashValue prefixHash = block->PrefixHash();
if (prefixHash == INVALID_HASH_VALUE) {
return false;
}
auto it = cachedBlocks_.find(prefixHash);
if (it != cachedBlocks_.end()) {
return true;
}
return false;
}
void PrefixCacheBlockAllocator::TrackBlockId(BlockId blockId, bool computed)
{
blockComputedAttr_.Enable(blockId);
blockComputedAttr_.SetComputed(blockId, computed);
}
void PrefixCacheBlockAllocator::UntrackBlockId(BlockId blockId) { blockComputedAttr_.Disable(blockId); }
bool PrefixCacheBlockAllocator::IsBlockComputed(BlockId blockId) const {
if (blockComputedAttr_.IsActive(blockId)) {
return blockComputedAttr_.IsComputed(blockId);
} else {
return evictor_->ContainsBlock(blockId);
}
}
bool PrefixCacheBlockAllocator::IsBlockCached(const HashValue prefixHash) const {
auto it = cachedBlocks_.find(prefixHash);
if (it == cachedBlocks_.end()) {
return false;
}
BlockId cachedBlockId = it->second;
return IsBlockComputed(cachedBlockId);
}
std::vector<std::pair<BlockId, BlockId>> PrefixCacheBlockAllocator::ClearCopyOnWrites() {
std::vector<std::pair<BlockId, BlockId>> ret = cowTracker_.ClearCows();
return ret;
}
void PrefixCacheBlockAllocator::AppendTokenIds(BlockObjSPtr blockObj, const std::vector<TokenId> &tokenIds) {
blockObj->AppendTokenIds(tokenIds);
BlockId blockId = CowBlockIfNotAppendable(blockObj);
blockObj->SetBlockId(blockId);
if (blockObj->PrefixHash() != INVALID_HASH_VALUE) {
PromoteToImmutableBlock(blockObj);
}
}
void PrefixCacheBlockAllocator::ReplaceToken(BlockObjSPtr blockObj, size_t startIndex, TokenId newToken) {
blockObj->ReplaceToken(startIndex, newToken);
if (blockObj->PrefixHash() != INVALID_HASH_VALUE) {
PromoteToImmutableBlock(blockObj);
}
}
}