* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "block_table.h"
#include <algorithm>
#include <stdexcept>
#include "log.h"
#include "math_utils.h"
namespace mindie_llm {
std::vector<std::vector<TokenId>> BlockTable::ChunkTokensForAllocate(const std::vector<TokenId> &tokenIds,
size_t chunkSize) {
if (chunkSize == 0) {
throw std::runtime_error("Chunk size cannot be zero.");
}
std::vector<std::vector<TokenId>> tokenBlocks;
for (size_t i = 0; i < tokenIds.size(); i += chunkSize) {
size_t min = std::min(i + chunkSize, tokenIds.size());
tokenBlocks.emplace_back(tokenIds.begin() + i, tokenIds.begin() + min);
}
return tokenBlocks;
}
// Builds an empty block table with rankSize per-rank block lists and zeroed per-rank slot counters.
// Throws std::runtime_error when blockSize is zero.
BlockTable::BlockTable(size_t blockSize, DeviceAwareBlockAllocatorSPtr &blockAllocator, size_t rankSize)
    : blockObjs_(rankSize), blockSize_(blockSize), blockAllocator_(blockAllocator), rankSize_(rankSize) {
    if (blockSize_ == 0) {
        throw std::runtime_error("blockSize can't be zero.");
    }
    // blockObjs_(rankSize) above already created rankSize empty per-rank lists,
    // so only the slot counters remain to be initialised.
    numFullSlotsPerRank_.assign(rankSize, size_t{0});
}
// Builds a block table around already-existing blocks (used by Fork()).
// Copies blockObjs, then derives blockIds_ and the per-rank filled-slot counters from the copied blocks.
// Throws std::runtime_error when blockSize is zero.
BlockTable::BlockTable(size_t blockSize, DeviceAwareBlockAllocatorSPtr &blockAllocator,
                       std::vector<std::vector<BlockObjSPtr>> &blockObjs, size_t rankSize)
    : blockObjs_(blockObjs), blockSize_(blockSize), blockAllocator_(blockAllocator), rankSize_(rankSize) {
    if (blockSize_ == 0) {
        throw std::runtime_error("blockSize can't be zero.");
    }
    numFullSlotsPerRank_.assign(rankSize, size_t{0});
    // Single pass: record every block id (rank by rank) and add the block's token count
    // to the counter of the rank the block itself reports.
    for (const auto &perRank : blockObjs_) {
        for (const auto &blk : perRank) {
            blockIds_.push_back(blk->GetBlockId());
            numFullSlotsPerRank_[blk->GetRankIdx()] += blk->GetTokenIds().size();
        }
    }
}
// Allocates blocks for tokenIds on `device` and installs them as this table's contents.
// All tokens are accounted to rank 0. Throws std::runtime_error if the table already holds blocks.
void BlockTable::Allocate(const std::vector<TokenId> &tokenIds, DeviceType device, HashValue extraHash) {
    if (IsAllocated()) {
        throw std::runtime_error("Blocks of this block table are already allocated.");
    }
    const std::vector<BlockObjSPtr> allocated = AllocateBlocksForTokenIds(tokenIds, device, extraHash);
    Update(allocated);
    numFullSlotsPerRank_[0] = tokenIds.size();
}
void BlockTable::AllocateSmallRankFirst(const std::vector<TokenId> &tokenIds, DeviceType device, HashValue extraHash) {
if (IsAllocated()) {
throw std::runtime_error("Blocks of this block table are already allocated.");
}
std::vector<BlockObjSPtr> newBlockObjs;
newBlockObjs = AllocateBlocksForTokenIdsSmallRankFirst(tokenIds, device, extraHash);
Update(newBlockObjs);
for (BlockObjSPtr blockObj : newBlockObjs) {
numFullSlotsPerRank_[blockObj->GetRankIdx()] += blockObj->GetTokenIds().size();
}
}
void BlockTable::Update(const std::vector<BlockObjSPtr> &blockObjs) {
for (std::vector<BlockObjSPtr> &rankBlocks : blockObjs_) {
rankBlocks.clear();
}
blockIds_.clear();
for (const BlockObjSPtr &blockObj : blockObjs) {
blockObjs_[blockObj->GetRankIdx()].push_back(blockObj);
blockIds_.push_back(blockObj->GetBlockId());
}
}
bool BlockTable::CanAppendNewTokens(const std::vector<TokenId> &tokenIds, size_t numLookaheadSlots) const {
if (rankSize_ > 1) {
size_t numTokenIds = tokenIds.size() + numLookaheadSlots;
size_t numEmptySlots = GetNumEmptySlots(currentSpRank_);
size_t nextSpRank = currentSpRank_;
if ((numEmptySlots == 0 && numLookaheadSlots == 0) ||
(numLookaheadSlots != 0 && tokenIds.size() > numEmptySlots)) {
nextSpRank = (nextSpRank + 1) % rankSize_;
numEmptySlots = GetNumEmptySlots(nextSpRank);
}
size_t needBlocks = 1;
if (numTokenIds > numEmptySlots) {
needBlocks = 1 + CeilDiv(numTokenIds - numEmptySlots, blockSize_);
}
return blockAllocator_->GetNumFreeBlock(DeviceType::NPU, nextSpRank) >= needBlocks;
} else {
throw std::runtime_error("rank size == 1 need call other function.");
}
}
void BlockTable::AppendNewTokens(const std::vector<TokenId> &newTokenIds, HashValue extraHash,
size_t numLookaheadSlots) {
if (!IsAllocated()) {
throw std::runtime_error("No blocks have been allocated.");
}
if (newTokenIds.size() == 0) {
return;
}
if (numLookaheadSlots == 0) {
size_t numEmptySlots = GetNumEmptySlots(currentSpRank_);
if (numEmptySlots == 0) {
currentSpRank_ = (currentSpRank_ + 1) % rankSize_;
numEmptySlots = GetNumEmptySlots(currentSpRank_);
}
if (numEmptySlots < 1) {
auto &rankBlockObjs = blockObjs_[currentSpRank_];
std::vector<TokenId> emptyTokenIds;
BlockObjSPtr newBlockObj = blockAllocator_->AllocateMutableBlock(
DeviceType::NPU, emptyTokenIds, rankBlockObjs.empty() ? nullptr : rankBlockObjs.back(), extraHash,
currentSpRank_);
rankBlockObjs.push_back(newBlockObj);
blockIds_.push_back(newBlockObj->GetBlockId());
}
std::vector<TokenId> currentSpToken{newTokenIds.back()};
AppendToSpRank(currentSpRank_, currentSpToken);
} else {
std::vector<TokenId> tokenIds;
for (size_t i = 0; i < newTokenIds.size(); i++) {
if (newTokenIds[i] != -1) {
tokenIds.push_back(newTokenIds[i]);
}
}
size_t numEmptySlots = GetNumEmptySlots(currentSpRank_);
size_t nextSpRank = currentSpRank_;
isAppendBlock_ = false;
if (newTokenIds.size() > numEmptySlots) {
nextSpRank = (currentSpRank_ + 1) % rankSize_;
size_t reservedSlots = newTokenIds.size() - numEmptySlots;
size_t nextRankEmptySlots = GetNumEmptySlots(nextSpRank);
if (nextRankEmptySlots < reservedSlots) {
auto &rankBlockObjs = blockObjs_[nextSpRank];
size_t numBlocksToAllocate = CeilDiv(reservedSlots, blockSize_);
for (size_t i = 0; i < numBlocksToAllocate; ++i) {
std::vector<TokenId> emptyTokenIds;
BlockObjSPtr newBlockObj = blockAllocator_->AllocateMutableBlock(
DeviceType::NPU, emptyTokenIds, rankBlockObjs.empty() ? nullptr : rankBlockObjs.back(),
extraHash, nextSpRank);
rankBlockObjs.push_back(newBlockObj);
blockIds_.push_back(newBlockObj->GetBlockId());
}
isAppendBlock_ = true;
appendBlockRankId_ = nextSpRank;
}
}
if (tokenIds.size() > 0) {
size_t segIndex = std::min(numEmptySlots, tokenIds.size() - 1);
std::vector<TokenId> lastRankToken(tokenIds.begin(), tokenIds.begin() + segIndex);
AppendToSpRank(currentSpRank_, lastRankToken);
if (numEmptySlots < tokenIds.size()) {
currentSpRank_ = nextSpRank;
}
std::vector<TokenId> currentRankToken(tokenIds.begin() + segIndex, tokenIds.end());
AppendToSpRank(currentSpRank_, currentRankToken);
}
}
}
void BlockTable::AppendToSpRank(size_t spRank, const std::vector<TokenId> &newTokenIds) {
size_t spTokenNum = numFullSlotsPerRank_[spRank];
auto &rankBlockObjs = blockObjs_[spRank];
size_t firstBlockIdx = spTokenNum / blockSize_;
size_t firstTokenIdx = spTokenNum % blockSize_;
for (size_t i = 0; i < newTokenIds.size(); ++i) {
size_t blockIdx = firstBlockIdx + (firstTokenIdx + i) / blockSize_;
if (blockIdx >= rankBlockObjs.size()) {
throw std::runtime_error("Block index out of range");
}
std::vector<TokenId> singleToken{newTokenIds[i]};
blockAllocator_->AppendTokenIds(rankBlockObjs[blockIdx], singleToken);
if (rankBlockObjs[blockIdx]->IsFull()) {
blockIds_[blockIdx] = rankBlockObjs[blockIdx]->GetBlockId();
}
}
numFullSlotsPerRank_[spRank] += newTokenIds.size();
}
void BlockTable::AppendTokenIds(const std::vector<TokenId> &tokenIds, HashValue extraHash, size_t numLookaheadSlots) {
if (!IsAllocated()) {
throw std::runtime_error("No blocks have been allocated.");
}
const size_t firstRank = 0;
std::vector<BlockObjSPtr> &firstRankBlockObjs = blockObjs_[firstRank];
EnsureEnoughSlots(tokenIds.size() + numLookaheadSlots, extraHash);
size_t firstBlockIdx = numFullSlotsPerRank_[firstRank] / blockSize_;
std::vector<std::vector<TokenId>> tokenBlocks = ChunkTokensForAppend(tokenIds);
for (size_t i = 0; i < tokenBlocks.size(); ++i) {
blockAllocator_->AppendTokenIds(firstRankBlockObjs.at(firstBlockIdx + i), tokenBlocks.at(i));
if (i + 1 < tokenBlocks.size()) {
if (!firstRankBlockObjs[firstBlockIdx + i]->IsFull()) {
throw std::runtime_error("a block table has not-full block in the middle, this should not happend.");
}
}
blockIds_.at(firstBlockIdx + i) = firstRankBlockObjs.at(firstBlockIdx + i)->GetBlockId();
}
numFullSlotsPerRank_[firstRank] += tokenIds.size();
}
// Returns how many additional blocks rank 0 needs to hold tokenIdsSize + numLookaheadSlots more slots,
// given how far its last block is already filled (based on numFullSlotsPerRank_[0]).
size_t BlockTable::GetNumRelatedBlocks(size_t tokenIdsSize, size_t numLookaheadSlots) const {
    const size_t demand = tokenIdsSize + numLookaheadSlots;
    const size_t usedInLastBlock = numFullSlotsPerRank_[0] % blockSize_;
    if (usedInLastBlock == 0) {
        // Filled count sits on a block boundary: everything goes into new blocks.
        return CeilDiv(demand, blockSize_);
    }
    const size_t roomInLastBlock = blockSize_ - usedInLastBlock;
    if (demand <= roomInLastBlock) {
        return 0;
    }
    return CeilDiv(demand - roomInLastBlock, blockSize_);
}
// Creates a new BlockTable from the blocks the allocator returns when forking this table's last
// rank-0 block. Throws std::runtime_error when nothing is allocated or when rankSize_ is not 1.
BlockTable BlockTable::Fork() {
    if (!IsAllocated()) {
        throw std::runtime_error("Empty blocks can't be forked.");
    }
    if (rankSize_ != 1) {
        throw std::runtime_error("Fork only supports single rank scenario");
    }
    BlockObjSPtr tailBlock = blockObjs_[0].back();
    std::vector<std::vector<BlockObjSPtr>> forkedPerRank;
    forkedPerRank.push_back(blockAllocator_->Fork(tailBlock));
    return BlockTable(blockSize_, blockAllocator_, forkedPerRank, rankSize_);
}
void BlockTable::Free() {
for (std::vector<BlockObjSPtr> &rankBlocks : blockObjs_) {
for (BlockObjSPtr &blockObj : rankBlocks) {
blockAllocator_->Free(blockObj);
}
}
blockObjs_.clear();
blockIds_.clear();
std::fill(numFullSlotsPerRank_.begin(), numFullSlotsPerRank_.end(), 0);
lastUsedRanks_ = std::queue<size_t>();
currentSpRank_ = 0;
}
// Returns a reference to the flat list of block ids recorded by this table.
const std::vector<BlockId> &BlockTable::GetBlockIds() const {
    return blockIds_;
}
std::vector<BlockId> BlockTable::GetRankedBlockIds(size_t rankIdx) const {
std::vector<BlockId> rankedBlockIds;
for (auto &blockObj : blockObjs_[rankIdx]) {
rankedBlockIds.push_back(blockObj->GetBlockId());
}
return rankedBlockIds;
}
std::vector<BlockObjSPtr> BlockTable::GetBlockObjs() const {
std::vector<BlockObjSPtr> mergedBlocks;
for (const std::vector<BlockObjSPtr> &rankBlocks : blockObjs_) {
mergedBlocks.insert(mergedBlocks.end(), rankBlocks.begin(), rankBlocks.end());
}
return mergedBlocks;
}
// Returns the filled-slot counter of rank 0.
size_t BlockTable::GetNumFullSlots() const {
    return numFullSlotsPerRank_[0];
}
// Makes sure rank rankId has at least numRequiredEmptySlots empty slots, allocating as many empty
// mutable NPU blocks as needed. Each new block is chained to the rank's current last block and
// recorded in both the per-rank list and the flat id list.
// Throws std::runtime_error if no blocks have been allocated yet or if rankId >= rankSize_.
void BlockTable::EnsureEnoughSlots(size_t numRequiredEmptySlots, HashValue extraHash, size_t rankId) {
    if (!IsAllocated()) {
        throw std::runtime_error("No blocks have been allocated, init blockTable before using it");
    }
    if (rankId >= rankSize_) {
        throw std::runtime_error("Invalid rank id");
    }
    DeviceType device = DeviceType::NPU;
    size_t numEmptySlots = GetNumEmptySlots(rankId);
    if (numRequiredEmptySlots <= numEmptySlots) {
        return;
    }
    size_t numBlocksToAllocate = CeilDiv(numRequiredEmptySlots - numEmptySlots, blockSize_);
    std::vector<BlockObjSPtr> &rankBlocks = blockObjs_[rankId];
    for (size_t i = 0; i < numBlocksToAllocate; ++i) {
        std::vector<TokenId> emptyTokenIds;
        // Pass rankId through so the block is created for the same rank whose list it is stored in;
        // omitting it would fall back to the allocator's default rank.
        BlockObjSPtr newBlockObj = blockAllocator_->AllocateMutableBlock(
            device, emptyTokenIds, rankBlocks.empty() ? nullptr : rankBlocks.back(), extraHash, rankId);
        rankBlocks.push_back(newBlockObj);
        blockIds_.push_back(newBlockObj->GetBlockId());
    }
}
std::vector<TokenId> BlockTable::GetNewGenTokenIds(const std::vector<TokenId> &tokenIds) const {
if (rankSize_ > 1) {
throw std::runtime_error("GetNewGenTokenIds only supports single rank");
}
return std::vector<TokenId>(tokenIds.begin() + numFullSlotsPerRank_[0], tokenIds.end());
}
std::vector<BlockObjSPtr> BlockTable::AllocateBlocksForTokenIds(const std::vector<TokenId> &tokenIds, DeviceType device,
HashValue extraHash) {
std::vector<BlockObjSPtr> newBlockObjs;
std::vector<std::vector<TokenId>> fullTokenBlocks;
std::vector<TokenId> nonFullTokenBlock;
std::vector<std::vector<TokenId>> blockedTokenIds = ChunkTokensForAllocate(tokenIds, blockSize_);
for (const std::vector<TokenId> &block : blockedTokenIds) {
if (block.size() == blockSize_) {
fullTokenBlocks.push_back(block);
} else {
nonFullTokenBlock = block;
}
}
BlockObjSPtr prevBlock = nullptr;
if (!fullTokenBlocks.empty()) {
std::vector<BlockObjSPtr> newBlocks =
blockAllocator_->AllocateImmutableBlocks(device, fullTokenBlocks, prevBlock, extraHash);
newBlockObjs.insert(newBlockObjs.end(), newBlocks.begin(), newBlocks.end());
prevBlock = newBlockObjs.back();
}
if (!nonFullTokenBlock.empty()) {
BlockObjSPtr block = blockAllocator_->AllocateMutableBlock(device, nonFullTokenBlock, prevBlock, extraHash);
newBlockObjs.push_back(block);
}
return newBlockObjs;
}
std::vector<BlockObjSPtr> BlockTable::AllocateBlocksForTokenIdsSmallRankFirst(const std::vector<TokenId> &tokenIds,
DeviceType device, HashValue extraHash) {
std::vector<BlockObjSPtr> newBlockObjs;
BlockObjSPtr prevBlock = nullptr;
size_t remainingTokens = tokenIds.size();
size_t currentPos = 0;
size_t rankIdx = 0;
while (remainingTokens > 0) {
size_t tokensForThisRank = std::min(remainingTokens, blockSize_);
std::vector<TokenId> rankTokens(tokenIds.begin() + currentPos,
tokenIds.begin() + currentPos + tokensForThisRank);
currentPos += tokensForThisRank;
remainingTokens -= tokensForThisRank;
BlockObjSPtr blockObj;
if (rankTokens.size() == blockSize_) {
blockObj = blockAllocator_->AllocateImmutableBlock(device, rankTokens, prevBlock, extraHash, rankIdx);
} else {
blockObj = blockAllocator_->AllocateMutableBlock(device, rankTokens, prevBlock, extraHash, rankIdx);
}
newBlockObjs.push_back(blockObj);
prevBlock = blockObj;
if (remainingTokens == 0) {
currentSpRank_ = rankIdx;
appendBlockRankId_ = currentSpRank_;
}
if (remainingTokens <= blockSize_ && newBlockObjs.size() >= rankSize_) {
rankIdx = rankSize_ - 1;
} else {
rankIdx = (rankIdx + 1) % rankSize_;
}
}
return newBlockObjs;
}
// Returns the total number of tokens stored in the blocks of all ranks.
size_t BlockTable::GetNumTokenIds() const {
    size_t tokenCount = 0;
    for (size_t rank = 0; rank < blockObjs_.size(); ++rank) {
        for (const auto &blk : blockObjs_[rank]) {
            tokenCount += blk->GetTokenIds().size();
        }
    }
    return tokenCount;
}
bool BlockTable::IsAllocated() const {
for (const std::vector<BlockObjSPtr> &rankBlocks : blockObjs_) {
if (!rankBlocks.empty()) {
return true;
}
}
return false;
}
// Returns the number of unfilled slots on rank rankId: the capacity of its blocks minus the slots
// already counted as filled. Throws std::runtime_error if nothing is allocated or rankId >= rankSize_.
size_t BlockTable::GetNumEmptySlots(size_t rankId) const {
    if (!IsAllocated()) {
        throw std::runtime_error("No blocks have been allocated.");
    }
    if (rankId >= rankSize_) {
        throw std::runtime_error("Invalid rank id");
    }
    const size_t capacity = blockObjs_[rankId].size() * blockSize_;
    return capacity - numFullSlotsPerRank_[rankId];
}
std::vector<std::vector<TokenId>> BlockTable::ChunkTokensForAppend(const std::vector<TokenId> &tokenIds) const {
std::vector<std::vector<TokenId>> tokenBlocks{};
if (tokenIds.empty()) {
return tokenBlocks;
}
size_t numLastBlockEmptySlots = blockSize_ - (numFullSlotsPerRank_[0] % blockSize_);
tokenBlocks.push_back(
std::vector<TokenId>(tokenIds.begin(), tokenIds.begin() + std::min(numLastBlockEmptySlots, tokenIds.size())));
for (size_t i = numLastBlockEmptySlots; i < tokenIds.size(); i += blockSize_) {
size_t end = std::min(i + blockSize_, tokenIds.size());
tokenBlocks.push_back(std::vector<TokenId>(tokenIds.begin() + i, tokenIds.begin() + end));
}
return tokenBlocks;
}
e.g. trailingPlaceHolderNum = 3, replacedPlaceHolderNum = 2
Replace the first replacedPlaceHolderNum of the trailing trailingPlaceHolderNum placeholders with valid values.
Before: [x x x x -1] [-1 -1]
After:  [x x x x r] [ r -1]
*/
void BlockTable::ReplaceTrailingPlaceHolder(const std::vector<TokenId> &tokenIds, size_t trailingPlaceHolderNum,
size_t replacedPlaceHolderNum, size_t rankId) {
size_t totalTokens = GetNumTokenIds();
if (trailingPlaceHolderNum >= totalTokens) {
throw std::runtime_error("replace start index cannot be greater than the total number of tokens.");
}
if (replacedPlaceHolderNum > trailingPlaceHolderNum) {
throw std::runtime_error("number of tokens to replace cannot be greater than the number of tokens start");
}
size_t startBlockIndex = (totalTokens - trailingPlaceHolderNum) / blockSize_;
size_t startLocalIndex = (totalTokens - trailingPlaceHolderNum) % blockSize_;
if (startBlockIndex == 0 && startLocalIndex == 0) {
throw std::runtime_error("replace start index cannot be the first token.");
}
size_t prevTokenOffset = startLocalIndex == 0 ? blockSize_ - 1 : startLocalIndex - 1;
size_t prevBlockOffset = startLocalIndex == 0 ? startBlockIndex - 1 : startBlockIndex;
std::vector<BlockObjSPtr> &rankBlocks = blockObjs_[rankId];
for (size_t i = 0; i < replacedPlaceHolderNum; ++i) {
size_t blockOffset = startBlockIndex + (startLocalIndex + i) / blockSize_;
size_t tokenOffset = (startLocalIndex + i) % blockSize_;
BlockObjSPtr currentBlock = rankBlocks[blockOffset];
TokenId newToken = tokenIds[i];
if (rankBlocks[prevBlockOffset]->GetTokenIds()[prevTokenOffset] == PLACEHOLDER_TOKEN) {
throw std::runtime_error("preceding token can't be PLACEHOLDER_TOKEN");
}
if (currentBlock->GetTokenIds()[tokenOffset] != PLACEHOLDER_TOKEN) {
throw std::runtime_error("only PLACEHOLDER_TOKEN can be replaced.");
}
blockAllocator_->ReplaceToken(currentBlock, tokenOffset, newToken);
if (rankBlocks[blockOffset]->IsFull()) {
blockIds_[blockOffset] = rankBlocks[blockOffset]->GetBlockId();
}
prevBlockOffset = blockOffset;
prevTokenOffset = tokenOffset;
}
}
// Returns currentSpRank_, the rank index the append/allocate paths last selected for writing tokens.
size_t BlockTable::GetLatestAppendedRankId() const {
    return currentSpRank_;
}
// Returns appendBlockRankId_, the rank recorded when blocks were last allocated for appended tokens.
size_t BlockTable::GetAppendedBlockRankId() const {
    return appendBlockRankId_;
}
// Returns isAppendBlock_, the flag the lookahead append path sets when it allocates new blocks.
bool BlockTable::IsAppendBlock() const {
    return isAppendBlock_;
}
}