* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "self_attn_block_manager.h"
#include <numeric>
#include "cpu_npu_block_allocator.h"
#include "log.h"
#include "lwd_self_attn_block_manager.h"
#include "math_utils.h"
#include "msServiceProfiler/msServiceProfiler.h"
namespace mindie_llm {
SelfAttnBlockManager::SelfAttnBlockManager(const BlockManagerConfig &config, size_t localDPRank)
: blockSize_(config.cacheBlockSize),
cpuBlockNum_(config.cpuBlockNum),
npuBlockNum_(config.npuBlockNum),
reservedBlockNum_(config.reservedBlockNum),
speculativeSlots_(config.speculativeSlots),
enableCaching_(config.enableCaching),
localDPRank_(localDPRank),
rankSize_(config.rankSize),
hostSize_(config.hostSize),
allocationMode_(config.allocationMode) {
if (reservedBlockNum_ > npuBlockNum_) {
throw std::invalid_argument("The num of reserved block is larger than npu block num");
}
if (rankSize_ == 0 || hostSize_ == 0) {
throw std::invalid_argument("The rank and host size must be greater than 0");
}
size_t allocatableNpuBlockNum = npuBlockNum_ - reservedBlockNum_;
PROF(INFO, AddMetaInfo("allocatableBlockNum", allocatableNpuBlockNum));
PROF(INFO, AddMetaInfo("npuBlockNum", npuBlockNum_));
PROF(INFO, AddMetaInfo("reservedBlockNum", reservedBlockNum_));
AllocatorConfig allocatorConfig;
allocatorConfig.allocatorType = enableCaching_ ? BlockAllocatorType::PREFIXCACHING : BlockAllocatorType::HASHLESS;
allocatorConfig.numCpuBlocks = cpuBlockNum_;
allocatorConfig.numNpuBlocks = allocatableNpuBlockNum;
allocatorConfig.blockSize = blockSize_;
allocatorConfig.rankSize = rankSize_;
blockAllocator_ = std::make_shared<CpuNpuBlockAllocator>(allocatorConfig);
seqsBlockComputedTracker_ = SeqsBlocksComputedTracker(blockAllocator_, blockSize_, enableCaching_, rankSize_);
seqsLastAccessBlocksTracker_ = SeqsLastAccessBlocksTracker(blockAllocator_);
if (config.enableKvPool) {
PyGILState_STATE gstate = PyGILState_Ensure();
py::object memPoolCls_ = py::module_::import("mindie_llm.text_generator.mempool").attr("MemPool");
py::object memPool_ = memPoolCls_.attr("create_pool")(config.cachePoolBackend, config.cachePoolConfigPath);
memPoolInstance_ = std::make_shared<MemPool>(std::make_shared<py::object>(memPool_));
PyGILState_Release(gstate);
}
MINDIE_LLM_LOG_INFO("SelfAttnBlockManager init success!");
}
BlockSpaceManagerSPtr BlockManagerFactory::CreateBlockSpaceManager(BlockManagerType type,
const BlockManagerConfig &config,
size_t localDPRank) {
switch (type) {
case BlockManagerType::SELFATTNBLOCKMANAGER:
return std::make_shared<SelfAttnBlockManager>(config, localDPRank);
case BlockManagerType::LWDSELFATTNBLOCKMANAGER:
return std::make_shared<LwdSelfAttnBlockManager>(config, localDPRank);
default:
throw std::invalid_argument("Invalid block manager type");
}
}
size_t SelfAttnBlockManager::GetNumRequiredBlocks(size_t seqLen, size_t blockSize) const {
if (blockSize == 0) {
throw std::runtime_error("the blockSize should not be zero");
}
size_t num = (seqLen + blockSize - 1) / blockSize;
if (rankSize_ > 1) {
num += 1;
}
return num;
}
AllocStatus SelfAttnBlockManager::CanAllocate(const SequenceGroupSPtr &seqGroup) const {
std::vector<SequenceSPtr> waitingSeqs = seqGroup->GetFirstSequence(SequenceStatus::WAITING);
if (waitingSeqs.empty()) {
return AllocStatus::NEVER;
}
const std::vector<TokenId> &tokenIds = waitingSeqs.at(0)->GetTokenIds();
size_t numRequiredBlocks = GetNumRequiredBlocks(tokenIds.size(), blockSize_);
size_t numFreeNpuBlocks = blockAllocator_->GetNumFreeBlock(DeviceType::NPU);
if (reservedBlockNum_ > npuBlockNum_) {
throw std::runtime_error("The num of reserved block is larger than npu block num");
}
if ((npuBlockNum_ * rankSize_ - reservedBlockNum_) < numRequiredBlocks) {
return AllocStatus::NEVER;
} else if (numFreeNpuBlocks >= numRequiredBlocks) {
return AllocStatus::OK;
}
return AllocStatus::LATER;
}
bool SelfAttnBlockManager::Allocate(const SequenceGroupSPtr &seqGroup) {
std::vector<SequenceSPtr> waitingSeqs = seqGroup->GetFirstSequence(SequenceStatus::WAITING);
if (waitingSeqs.empty()) {
return false;
}
for (const auto &sequence : waitingSeqs) {
SequenceId sequenceId = sequence->seqId_;
auto it = seqId2BlockTable_.find(sequenceId);
if (it != seqId2BlockTable_.end()) {
return false;
}
}
SequenceSPtr sequence = waitingSeqs.at(0);
BlockTable blockTable = BlockTable(blockSize_, blockAllocator_, rankSize_);
const std::vector<TokenId> &tokenIds = sequence->GetTokenIds();
if (tokenIds.empty()) {
return false;
}
if (rankSize_ == 1) {
blockTable.Allocate(tokenIds, DeviceType::NPU, sequence->GetExtraHash());
} else {
blockTable.AllocateSmallRankFirst(tokenIds, DeviceType::NPU, sequence->GetExtraHash());
}
seqId2BlockTable_[sequence->seqId_] = blockTable;
seqsLastAccessBlocksTracker_.AddSeq(sequence->seqId_);
for (auto it = waitingSeqs.begin() + 1; it != waitingSeqs.end(); ++it) {
SequenceId seqId = (*it)->seqId_;
seqId2BlockTable_[seqId] = blockTable.Fork();
seqsLastAccessBlocksTracker_.AddSeq(seqId);
}
return true;
}
bool SelfAttnBlockManager::CanAppendSlot(const SequenceGroupSPtr &seqGroup) const {
size_t numRelatedBlocks = 0;
std::vector<SequenceSPtr> runningSeqs = seqGroup->GetSequences(SequenceStatus::RUNNING);
for (auto &sequence : runningSeqs) {
SequenceId seqId = sequence->seqId_;
auto it = seqId2BlockTable_.find(seqId);
if (it == seqId2BlockTable_.end()) {
return false;
}
const BlockTable &blockTable = it->second;
const std::vector<TokenId> &tokenIds = sequence->GetTokenIds();
size_t tokenIdSize = blockTable.GetNewGenTokenIds(tokenIds).size();
numRelatedBlocks += blockTable.GetNumRelatedBlocks(tokenIdSize, speculativeSlots_);
}
size_t numFreeNpuBlocks = blockAllocator_->GetNumFreeBlock(DeviceType::NPU);
return numRelatedBlocks <= numFreeNpuBlocks;
}
std::vector<std::pair<BlockId, BlockId>> SelfAttnBlockManager::AppendSlot(const SequenceSPtr &sequence) {
if (rankSize_ > 1) {
throw std::runtime_error("throw not supported exception.");
}
SequenceId seqId = sequence->seqId_;
BlockTable &blockTable = seqId2BlockTable_.at(seqId);
const std::vector<TokenId> &tokenIds = sequence->GetTokenIds();
blockTable.AppendTokenIds(blockTable.GetNewGenTokenIds(tokenIds), sequence->GetExtraHash(), speculativeSlots_);
std::vector<std::pair<BlockId, BlockId>> block2Copy = blockAllocator_->ClearCopyOnWrites();
return block2Copy;
}
bool SelfAttnBlockManager::CanAppendSlotNew(const SequenceGroupSPtr &seqGroup) const {
std::vector<SequenceSPtr> runningSeqs = seqGroup->GetFirstSequence(SequenceStatus::RUNNING);
if (runningSeqs.empty()) {
return false;
}
const SequenceSPtr sequence = runningSeqs.at(0);
const BlockTable &blockTable = seqId2BlockTable_.at(sequence->seqId_);
const std::vector<TokenId> &tokenIds = sequence->GetTokenIds();
std::vector<size_t> numFullSlotsPerRank = GetTokenCountPerRank(sequence->seqId_);
size_t fullSlotsNum = std::accumulate(numFullSlotsPerRank.begin(), numFullSlotsPerRank.end(), size_t(0));
std::vector<TokenId> appendTokenIds(tokenIds.begin() + fullSlotsNum, tokenIds.end());
return blockTable.CanAppendNewTokens(appendTokenIds, speculativeSlots_);
}
void SelfAttnBlockManager::AppendSlotNew(const SequenceGroupSPtr &seqGroup) {
std::vector<SequenceSPtr> runningSeqs = seqGroup->GetFirstSequence(SequenceStatus::RUNNING);
if (runningSeqs.empty()) {
return;
}
const SequenceSPtr sequence = runningSeqs.at(0);
BlockTable &blockTable = seqId2BlockTable_.at(sequence->seqId_);
const std::vector<TokenId> &tokenIds = sequence->GetTokenIds();
std::vector<size_t> numFullSlotsPerRank = GetTokenCountPerRank(sequence->seqId_);
size_t fullSlotsNum = std::accumulate(numFullSlotsPerRank.begin(), numFullSlotsPerRank.end(), size_t(0));
std::vector<TokenId> appendTokenIds(tokenIds.begin() + fullSlotsNum, tokenIds.end());
blockTable.AppendNewTokens(appendTokenIds, sequence->GetExtraHash(), speculativeSlots_);
}
void SelfAttnBlockManager::AppendTokenToLatestRank(SequenceId seqId, const std::vector<TokenId> &tokens) {
BlockTable &blockTable = seqId2BlockTable_[seqId];
blockTable.AppendToSpRank(blockTable.GetLatestAppendedRankId(), tokens);
}
bool SelfAttnBlockManager::IsAppendBlock(SequenceId seqId) {
BlockTable &blockTable = seqId2BlockTable_[seqId];
return blockTable.IsAppendBlock();
}
void SelfAttnBlockManager::Free(SequenceId seqId) {
auto it = seqId2BlockTable_.find(seqId);
if (it == seqId2BlockTable_.end()) {
return;
}
BlockTable &blockTable = it->second;
std::vector<std::vector<BlockId>> rankedBlockIds;
GetRankedBlockIds(seqId, rankedBlockIds);
seqsLastAccessBlocksTracker_.UpdateSeqBlocksLastAccess(seqId, rankedBlockIds);
seqsLastAccessBlocksTracker_.RemoveSeq(seqId);
seqsBlockComputedTracker_.RemoveSeq(seqId);
blockTable.Free();
seqId2BlockTable_.erase(it);
}
std::vector<BlockIds> SelfAttnBlockManager::GetBlockIds(SequenceId seqId) const {
return {seqId2BlockTable_.at(seqId).GetBlockIds()};
}
void SelfAttnBlockManager::GetRankedBlockIds(SequenceId seqId, std::vector<RankedBlockId> &rankedBlockIds) const {
rankedBlockIds.clear();
std::vector<BlockObjSPtr> blocks = seqId2BlockTable_.at(seqId).GetBlockObjs();
for (const auto &block : blocks) {
RankedBlockId rankedBlockId = {block->GetBlockId(), block->GetRankIdx()};
rankedBlockIds.push_back(rankedBlockId);
}
}
void SelfAttnBlockManager::GetRankedBlockIds(SequenceId seqId,
std::vector<std::vector<BlockId>> &rankedBlockIds) const {
rankedBlockIds.clear();
rankedBlockIds.resize(rankSize_);
std::vector<BlockObjSPtr> blocks = seqId2BlockTable_.at(seqId).GetBlockObjs();
for (const auto &block : blocks) {
size_t rankIdx = block->GetRankIdx();
BlockId blockId = block->GetBlockId();
rankedBlockIds[rankIdx].push_back(blockId);
}
}
std::vector<std::vector<HashValue>> SelfAttnBlockManager::GetRankedHashValues(SequenceId seqId) const {
std::vector<std::vector<HashValue>> rankedHashValues;
rankedHashValues.resize(rankSize_);
std::vector<BlockObjSPtr> blocks = seqId2BlockTable_.at(seqId).GetBlockObjs();
for (const auto &block : blocks) {
size_t rankIdx = block->GetRankIdx();
HashValue hashValue = block->GetHashValue();
rankedHashValues[rankIdx].push_back(hashValue);
}
return rankedHashValues;
}
std::vector<HashValue> SelfAttnBlockManager::GetSeqHashValues(SequenceId seqId) const {
std::vector<HashValue> seqHashValues;
std::vector<BlockObjSPtr> blocks = seqId2BlockTable_.at(seqId).GetBlockObjs();
for (const auto &block : blocks) {
HashValue hashValue = block->GetHashValue();
seqHashValues.push_back(hashValue);
}
return seqHashValues;
}
void SelfAttnBlockManager::AccessAllblocksInSeq(const SequenceSPtr &seq, float accessTime) {
if (enableCaching_) {
seqsLastAccessBlocksTracker_.UpdateSeqLastAccess(seq->seqId_, accessTime);
}
}
void SelfAttnBlockManager::MarkBlocksAsComputed() { blockAllocator_->MarkBlocksAsComputed(); }
std::vector<size_t> SelfAttnBlockManager::GetAllrankComputedBlockNum(const std::vector<SequenceSPtr> &seqs) {
size_t maxCachedBlocksPreRank = 0;
std::vector<std::vector<std::vector<BlockId>>> rankedComputedSeqBlockIds;
for (size_t rankIdx = 0; rankIdx < rankSize_; rankIdx++) {
std::vector<std::vector<BlockId>> computedSeqBlockIds;
for (const auto &seq : seqs) {
std::vector<std::vector<HashValue>> rankedHashValues = GetRankedHashValues(seq->seqId_);
std::vector<BlockId> allBlocks = seqId2BlockTable_.at(seq->seqId_).GetRankedBlockIds(rankIdx);
size_t numCachedBlocks = seqsBlockComputedTracker_.GetCachedTokensNum(
seq, rankIdx, rankedHashValues[rankIdx], seq->IsPrefill()) /
blockSize_;
std::vector<BlockId> computedBlockIds(allBlocks.begin(), allBlocks.begin() + numCachedBlocks);
computedSeqBlockIds.push_back(computedBlockIds);
maxCachedBlocksPreRank = std::max(maxCachedBlocksPreRank, allBlocks.size());
}
rankedComputedSeqBlockIds.push_back(computedSeqBlockIds);
}
std::vector<size_t> allRankComputedBlockNum =
blockAllocator_->GetAllRankCommonComputedBlockNum(rankedComputedSeqBlockIds);
for (auto &computedBlockNum : allRankComputedBlockNum) {
maxCachedBlocksPreRank = std::min(maxCachedBlocksPreRank, computedBlockNum);
}
bool flag = true;
for (auto &computedBlockNum : allRankComputedBlockNum) {
if (flag && computedBlockNum > maxCachedBlocksPreRank) {
computedBlockNum = maxCachedBlocksPreRank + 1;
} else {
flag = false;
computedBlockNum = maxCachedBlocksPreRank;
}
}
return allRankComputedBlockNum;
}
std::vector<BlockId> SelfAttnBlockManager::GetCommonComputedBlockIds(const std::vector<SequenceSPtr> &seqs) {
std::vector<std::vector<BlockId>> computedSeqBlockIds;
for (const auto &seq : seqs) {
const std::vector<BlockId> &allBlocks = seqId2BlockTable_.at(seq->seqId_).GetBlockIds();
size_t numCachedTokens = GetNumCachedTokens(seq);
if (numCachedTokens % blockSize_ != 0) {
throw std::runtime_error("The number of cached tokens is not a multiple of block size.");
}
size_t numCachedBlocks = numCachedTokens / blockSize_;
std::vector<BlockId> computedBlockIds(allBlocks.begin(), allBlocks.begin() + numCachedBlocks);
computedSeqBlockIds.push_back(computedBlockIds);
}
return blockAllocator_->GetCommonComputedBlockIds(computedSeqBlockIds);
}
std::vector<BlockId> SelfAttnBlockManager::GetRemoteComputedBlockIds(const std::vector<SequenceSPtr> &seqs,
size_t computedLens, uint32_t tpSize,
std::string modelName) {
std::vector<std::vector<BlockId>> remoteComputedSeqBlockIds;
for (const auto &seq : seqs) {
const std::vector<BlockId> &allBlocks = seqId2BlockTable_.at(seq->seqId_).GetBlockIds();
size_t numCachedBlocks = computedLens;
std::vector<HashValue> hashValues = GetSeqHashValues(seq->seqId_);
std::vector<std::string> allKeys;
for (size_t i = computedLens; i < hashValues.size(); i++) {
HashValue hashValue = hashValues[i];
for (uint32_t tpIds = 0; tpIds < tpSize; tpIds++) {
std::string key = std::to_string(hashValue) + "_" + std::to_string(tpIds) + "_" +
std::to_string(tpSize) + "_" + modelName;
allKeys.push_back(key);
}
}
if (!allKeys.empty() && tpSize > 0) {
std::vector<bool> lookRes = memPoolInstance_->LookUp(allKeys);
size_t numElem = 0;
while (numElem < lookRes.size() && lookRes[numElem]) {
++numElem;
}
numCachedBlocks += numElem / tpSize;
}
std::vector<BlockId> remoteComputedBlockIds(allBlocks.begin(), allBlocks.begin() + numCachedBlocks);
remoteComputedSeqBlockIds.push_back(remoteComputedBlockIds);
}
return blockAllocator_->GetCommonComputedBlockIds(remoteComputedSeqBlockIds);
}
std::vector<size_t> SelfAttnBlockManager::GetAllRankRemoteComputedBlockIds(const std::vector<SequenceSPtr> &seqs,
std::vector<size_t> &computedBlocksNum,
std::string modelName) {
if (seqs.size() > 1) {
throw std::runtime_error("`Kv pool` do not support `splitfuse`!");
}
const auto &seq = seqs[0];
size_t rankIdx = 0;
while (rankIdx < rankSize_ && computedBlocksNum.back() < computedBlocksNum[rankIdx]) {
rankIdx = (rankIdx + 1) % rankSize_;
}
std::vector<size_t> remoteComputedBlocksNum(computedBlocksNum);
std::vector<std::vector<HashValue>> rankedHashValues = GetRankedHashValues(seq->seqId_);
std::vector<std::string> allKeys;
for (size_t i = remoteComputedBlocksNum[rankIdx]; i < rankedHashValues[0].size(); i++) {
size_t curRankIdx = i == remoteComputedBlocksNum[rankIdx] ? rankIdx : 0;
while (curRankIdx < rankSize_ && i < rankedHashValues[curRankIdx].size()) {
HashValue hashValue = rankedHashValues[curRankIdx][i];
std::string key = std::to_string(hashValue) + "_" + std::to_string(curRankIdx) + "_" +
std::to_string(rankSize_) + "_" + modelName;
allKeys.push_back(key);
curRankIdx++;
}
}
if (!allKeys.empty()) {
std::vector<bool> batchLookupResult = memPoolInstance_->LookUp(allKeys);
size_t startIdx = 0;
for (auto num : computedBlocksNum) {
startIdx += num;
}
for (size_t idx = 0; idx < batchLookupResult.size() && batchLookupResult[idx]; idx++) {
remoteComputedBlocksNum[(idx + startIdx) % rankSize_]++;
}
}
return remoteComputedBlocksNum;
}
void SelfAttnBlockManager::Fork(SequenceSPtr &parentSeq, SequenceSPtr &childSeq) {
auto it = seqId2BlockTable_.find(parentSeq->seqId_);
if (it == seqId2BlockTable_.end()) {
throw std::runtime_error("SequenceId not found in seqId2BlockTable_ When Fork sequence");
}
BlockTable &blockTable = it->second;
seqId2BlockTable_[childSeq->seqId_] = blockTable;
seqsLastAccessBlocksTracker_.AddSeq(childSeq->seqId_);
}
AllocStatus SelfAttnBlockManager::CanSwap(const SequenceGroupSPtr &seqGroup, DeviceType dstDeviceType,
SequenceStatus status, size_t numLookaheads) {
size_t numRelatedBlocks = 0;
for (const auto &seq : seqGroup->GetFirstSequence(status)) {
BlockTable &blockTable = seqId2BlockTable_.at(seq->seqId_);
std::vector<BlockObjSPtr> blocksObj = blockTable.GetBlockObjs();
numRelatedBlocks += CeilDiv(seq->GetLen() + numLookaheads, blockSize_);
}
size_t numTotalBlocks = blockAllocator_->GetNumTotalBlocks(dstDeviceType);
size_t numFreeBlocks = blockAllocator_->GetNumFreeBlock(dstDeviceType);
if (numRelatedBlocks > numTotalBlocks) {
return AllocStatus::NEVER;
} else if (numFreeBlocks >= numRelatedBlocks) {
return AllocStatus::OK;
}
return AllocStatus::LATER;
}
AllocStatus SelfAttnBlockManager::CanSwapIn(const SequenceGroupSPtr &seqGroup, size_t numLookheadSlots) {
return CanSwap(seqGroup, DeviceType::NPU, SequenceStatus::SWAPPED, numLookheadSlots);
}
std::vector<std::pair<PhysicalBlockId, PhysicalBlockId>> SelfAttnBlockManager::SwapIn(
const SequenceGroupSPtr &seqGroup) {
std::vector<std::pair<PhysicalBlockId, PhysicalBlockId>> physicalBlockIdMapping;
for (const auto &seq : seqGroup->GetFirstSequence(SequenceStatus::SWAPPED)) {
std::vector<BlockObjSPtr> blocks = seqId2BlockTable_.at(seq->seqId_).GetBlockObjs();
std::vector<std::pair<BlockId, BlockId>> seqSwapMapping =
blockAllocator_->Swap(blocks, DeviceType::CPU, DeviceType::NPU);
seqId2BlockTable_[seq->seqId_].Update(blocks);
for (const auto &seqSwap : seqSwapMapping) {
physicalBlockIdMapping.push_back({blockAllocator_->GetPhysicalBlockId(seqSwap.first),
blockAllocator_->GetPhysicalBlockId(seqSwap.second)});
}
}
return physicalBlockIdMapping;
}
bool SelfAttnBlockManager::CanSwapOut(const SequenceGroupSPtr &seqGroup) {
AllocStatus status = CanSwap(seqGroup, DeviceType::CPU, SequenceStatus::RUNNING);
return status == AllocStatus::OK;
}
std::vector<std::pair<PhysicalBlockId, PhysicalBlockId>> SelfAttnBlockManager::SwapOut(
const SequenceGroupSPtr &seqGroup) {
std::vector<std::pair<PhysicalBlockId, PhysicalBlockId>> physicalBlockIdMapping;
for (const auto &seq : seqGroup->GetFirstSequence(SequenceStatus::RUNNING)) {
std::vector<BlockObjSPtr> blocks = seqId2BlockTable_.at(seq->seqId_).GetBlockObjs();
std::vector<std::pair<BlockId, BlockId>> seqSwapMapping =
blockAllocator_->Swap(blocks, DeviceType::NPU, DeviceType::CPU);
seqId2BlockTable_[seq->seqId_].Update(blocks);
for (const auto &seqSwap : seqSwapMapping) {
physicalBlockIdMapping.push_back({blockAllocator_->GetPhysicalBlockId(seqSwap.first),
blockAllocator_->GetPhysicalBlockId(seqSwap.second)});
}
}
return physicalBlockIdMapping;
}
size_t SelfAttnBlockManager::GetNumFreeNpuBlocks() const { return blockAllocator_->GetNumFreeBlock(DeviceType::NPU); }
size_t SelfAttnBlockManager::GetNumFreeCpuBlocks() const { return blockAllocator_->GetNumFreeBlock(DeviceType::CPU); }
float SelfAttnBlockManager::GetPrefixCacheHitRate() const { return blockAllocator_->GetPrefixCacheHitRate(); }
bool SelfAttnBlockManager::ResetPrefixCache() const { return blockAllocator_->ResetPrefixCache(); }
size_t SelfAttnBlockManager::GetNumCachedTokens(const SequenceSPtr &seq) {
size_t numCachedTokens = 0;
std::vector<std::vector<HashValue>> rankedHashValues = GetRankedHashValues(seq->seqId_);
for (size_t rankIdx = 0; rankIdx < rankSize_; rankIdx++) {
numCachedTokens +=
seqsBlockComputedTracker_.GetCachedTokensNum(seq, rankIdx, rankedHashValues[rankIdx], seq->IsPrefill());
}
return numCachedTokens;
}
size_t SelfAttnBlockManager::GetSeqNumCachedTokens(const SequenceSPtr &seq) {
size_t numCachedTokens = 0;
numCachedTokens += seqsBlockComputedTracker_.GetCachedTokensNum(seq);
return numCachedTokens;
}
eg. trailingPlaceHolderNum = 3 replacedPlaceHolderNum = 2
替换 尾部trailingPlaceHolderNum个 中的 前replacedPlaceHolderNum个 为有效值
替换前:[x x x x -1] [-1 -1]
替换后:[x x x x r] [ r -1]
*/
void SelfAttnBlockManager::ReplaceTrailingPlaceHolder(const SequenceSPtr &seq, size_t trailingPlaceHolderNum,
size_t replacedPlaceHolderNum) {
if (rankSize_ != 1) {
return;
}
const std::vector<TokenId> &allTokenIds = seq->GetTokenIds();
if (trailingPlaceHolderNum < replacedPlaceHolderNum || allTokenIds.size() < trailingPlaceHolderNum) {
throw std::runtime_error("tokenIds size is less than replacement size");
}
size_t begin = allTokenIds.size() - trailingPlaceHolderNum;
size_t end = allTokenIds.size() - trailingPlaceHolderNum + replacedPlaceHolderNum;
std::vector<TokenId> newTokenIds(allTokenIds.begin() + begin, allTokenIds.begin() + end);
BlockTable &blockTable = seqId2BlockTable_.at(seq->seqId_);
blockTable.ReplaceTrailingPlaceHolder(newTokenIds, trailingPlaceHolderNum, replacedPlaceHolderNum);
return;
}
size_t SelfAttnBlockManager::GetLocalDPRank() const { return localDPRank_; }
std::vector<size_t> SelfAttnBlockManager::GetTokenCountPerRank(SequenceId seqId) const {
const BlockTable &blockTable = seqId2BlockTable_.at(seqId);
std::vector<size_t> tokenCounts(rankSize_, 0);
for (const auto &blockObj : blockTable.GetBlockObjs()) {
size_t rank = blockObj->GetRankIdx();
tokenCounts[rank] += blockObj->GetTokenIds().size();
}
return tokenCounts;
}
size_t SelfAttnBlockManager::GetLatestAppendedRankId(SequenceId seqId) const {
const BlockTable &blockTable = seqId2BlockTable_.at(seqId);
return blockTable.GetLatestAppendedRankId();
}
size_t SelfAttnBlockManager::GetAppendedBlockRankId(SequenceId seqId) const {
const BlockTable &blockTable = seqId2BlockTable_.at(seqId);
return blockTable.GetAppendedBlockRankId();
}
}