* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "scheduling_budget.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "utils/log.h"
namespace mindie_llm {
/**
 * @brief Builds the token/sequence budget for one scheduling step.
 *
 * @param maxNumBatchedTokens Token limit used unless chunked prefill is enabled.
 * @param maxNumSeqs Sequence limit used unless chunked prefill is enabled.
 * @param schedulerConfig Scheduler configuration; may be nullptr. When it is non-null and
 *        enableChunkedPrefill is set, maxPrefillTokens and maxBatchSize from the config replace
 *        the two limits above, and the partial-prefill per-slot budget table is filled in.
 */
SchedulingBudget::SchedulingBudget(const size_t maxNumBatchedTokens, const size_t maxNumSeqs,
    const SchedulerConfigSPtr &schedulerConfig)
    : maxNumSeqs_(maxNumSeqs), maxNumBatchedTokens_(maxNumBatchedTokens), schedulerConfig_(schedulerConfig) {
    const bool chunkedPrefillOn = schedulerConfig_ != nullptr && schedulerConfig_->enableChunkedPrefill;
    if (!chunkedPrefillOn) {
        // Caller-supplied limits stay in effect.
        return;
    }

    const auto &config = *schedulerConfig_;
    maxNumBatchedTokens_ = config.maxPrefillTokens;
    maxNumSeqs_ = config.maxBatchSize;

    // Table with maxNumPartialPrefills + 1 entries: entry 0 is the full maxPrefillTokens and
    // entry k (k >= 1) is maxPrefillTokens / k. Presumably an even split of the prefill budget
    // across k concurrent partial prefills -- confirm against the header's intent.
    const size_t slotCount = config.maxNumPartialPrefills + 1;
    partialPrefillOccupiedBudgetPerSlot_.resize(slotCount, 0);
    partialPrefillOccupiedBudgetPerSlot_[0] = config.maxPrefillTokens;
    for (size_t concurrent = 1; concurrent < slotCount; ++concurrent) {
        partialPrefillOccupiedBudgetPerSlot_[concurrent] = config.maxPrefillTokens / concurrent;
    }
}
/**
 * @brief Checks whether adding work would keep this step within both budgets.
 *
 * @param numNewTokens Tokens the candidate work would add.
 * @param numNewSeqs Sequences the candidate work would add.
 * @return true only if both the token total and the sequence total stay at or below their limits.
 */
bool SchedulingBudget::CanSchedule(const size_t numNewTokens, const size_t numNewSeqs) const {
    const bool tokensFit = numBatchedTokens_ + numNewTokens <= maxNumBatchedTokens_;
    const bool seqsFit = numCurSeqs_ + numNewSeqs <= maxNumSeqs_;
    return tokensFit && seqsFit;
}
/**
 * @brief Tokens that can still be added in this step.
 *
 * @return maxNumBatchedTokens_ - numBatchedTokens_, or 0 when the batch has already reached or
 *         exceeded the token limit (so the subtraction can never wrap around).
 */
size_t SchedulingBudget::RemainingTokenBudget() const {
    const bool withinBudget = numBatchedTokens_ <= maxNumBatchedTokens_;
    return withinBudget ? maxNumBatchedTokens_ - numBatchedTokens_ : 0;
}
size_t SchedulingBudget::GetPrefillSlots() {
size_t remainingTokenBudget = RemainingTokenBudget();
if (schedulerConfig_ == nullptr || !schedulerConfig_->enableChunkedPrefill) {
return remainingTokenBudget;
}
if (schedulerConfig_->prefillChunkSize != 0) {
return schedulerConfig_->prefillChunkSize;
} else {
return schedulerConfig_->maxPrefillTokens;
}
if (statistics4PartialPrefill_ == nullptr) {
return remainingTokenBudget;
}
size_t numSchedulablePrefills = statistics4PartialPrefill_->SchedulablePrefills();
if (partialPrefillOccupiedBudgetPerSlot_.size() <= numSchedulablePrefills) {
MINDIE_LLM_LOG_ERROR(
"index is out of range of partialPrefillOccupiedBudgetPerSlot_. index=" << numSchedulablePrefills);
throw std::runtime_error("index is out of range of partialPrefillOccupiedBudgetPerSlot_. index=" +
std::to_string(numSchedulablePrefills));
}
return partialPrefillOccupiedBudgetPerSlot_[numSchedulablePrefills];
}
/**
 * @brief Charges a request's batched and cached tokens to this step's budget, at most once.
 *
 * @param reqId Request being charged; repeated calls with the same id are ignored.
 * @param numBatchedTokens Tokens added to the batched-token total.
 * @param numCachedTokens Tokens added to the cached-token total.
 */
void SchedulingBudget::AddNumBatchedTokens(RequestId &reqId, const size_t numBatchedTokens,
    const size_t numCachedTokens) {
    const bool alreadyCharged = requestIdsNumBatchedTokens_.find(reqId) != requestIdsNumBatchedTokens_.end();
    if (!alreadyCharged) {
        requestIdsNumBatchedTokens_.emplace(reqId);
        numBatchedTokens_ += numBatchedTokens;
        numCachedTokens_ += numCachedTokens;
    }
}
/**
 * @brief Charges a request's sequences to this step's budget, at most once.
 *
 * @param reqId Request being charged; repeated calls with the same id are ignored.
 * @param numCurSeqs Sequences added to the current-sequence total.
 */
void SchedulingBudget::AddNumSeqs(RequestId &reqId, const size_t numCurSeqs) {
    const bool alreadyCharged = requestIdsNumCurSeqs_.find(reqId) != requestIdsNumCurSeqs_.end();
    if (!alreadyCharged) {
        requestIdsNumCurSeqs_.emplace(reqId);
        numCurSeqs_ += numCurSeqs;
    }
}
/**
 * @brief Refunds batched tokens previously charged for a request.
 *
 * No-op if reqId was never charged via AddNumBatchedTokens (or was already refunded).
 * NOTE(review): only the id is tracked, not the amount charged; the caller must pass the
 * same numBatchedTokens that was added, otherwise the running total drifts (or wraps if larger).
 *
 * @param reqId Request whose charge is being released.
 * @param numBatchedTokens Tokens removed from the batched-token total.
 */
void SchedulingBudget::SubtractNumBatchedTokens(const RequestId &reqId, const size_t numBatchedTokens) {
    const auto charged = requestIdsNumBatchedTokens_.find(reqId);
    if (charged == requestIdsNumBatchedTokens_.end()) {
        return;
    }
    requestIdsNumBatchedTokens_.erase(charged);
    numBatchedTokens_ -= numBatchedTokens;
}
/**
 * @brief Refunds sequences previously charged for a request.
 *
 * No-op if reqId was never charged via AddNumSeqs (or was already refunded).
 * NOTE(review): the caller must pass the same numCurSeqs that was added; the amount is not
 * recorded per request.
 *
 * @param reqId Request whose charge is being released.
 * @param numCurSeqs Sequences removed from the current-sequence total.
 */
void SchedulingBudget::SubtractNumSeqs(const RequestId &reqId, const size_t numCurSeqs) {
    const auto charged = requestIdsNumCurSeqs_.find(reqId);
    if (charged == requestIdsNumCurSeqs_.end()) {
        return;
    }
    requestIdsNumCurSeqs_.erase(charged);
    numCurSeqs_ -= numCurSeqs;
}
/**
 * @brief Snapshot of partial-prefill counts used while scheduling one step.
 *
 * @param numSchedulablePrefills Value reported by SchedulablePrefills().
 * @param numLongPrefills Initial long-prefill count; MaybeIncrementPartialPrefills() adds to it.
 * @param schedulerConfig Configuration supplying the long-prefill threshold and limits.
 *        Taken by value and moved into the member, so callers passing an rvalue pay no
 *        extra reference-count traffic.
 */
Statistics4PartialPrefill::Statistics4PartialPrefill(int numSchedulablePrefills, int numLongPrefills,
    std::shared_ptr<SchedulerConfig> schedulerConfig)
    : numSchedulablePrefills_(numSchedulablePrefills),
      numLongPrefills_(numLongPrefills),
      schedulerConfig_(std::move(schedulerConfig)) {}
bool Statistics4PartialPrefill::CanSchedule(const std::shared_ptr<SequenceGroup> &seqGroup) const {
if (seqGroup->firstSeq->GetNumUncomputedTokens() <= schedulerConfig_->longPrefillTokenThreshold) {
return true;
}
return numLongPrefills_ < schedulerConfig_->maxLongPartialPrefills;
}
void Statistics4PartialPrefill::MaybeIncrementPartialPrefills(const std::shared_ptr<SequenceGroup> &seqGroup) {
if (seqGroup->firstSeq->GetNumUncomputedTokens() > schedulerConfig_->longPrefillTokenThreshold) {
numLongPrefills_++;
}
}
/**
 * @brief Builds partial-prefill statistics from the current running and waiting queues.
 *
 * Counting rules:
 *   - every running group still in the PREFILL stage counts as a prefill, and as a long prefill
 *     when its uncomputed-token count exceeds longPrefillTokenThreshold;
 *   - waiting groups are then counted until maxNumPartialPrefills prefills are reached, skipping
 *     long ones once (running long + waiting long) would reach maxLongPartialPrefills.
 *
 * @param running Running queue.
 * @param waiting Waiting queue, scanned in order.
 * @param schedulerConfig Thresholds/limits; moved into the returned object.
 * @return Statistics whose SchedulablePrefills() is min(counted prefills, maxNumPartialPrefills)
 *         and whose LongPrefills() counts only the running long prefills.
 */
Statistics4PartialPrefillPtr Statistics4PartialPrefill::FromQueues(
    const std::deque<std::shared_ptr<SequenceGroup>> &running,
    const std::deque<std::shared_ptr<SequenceGroup>> &waiting, std::shared_ptr<SchedulerConfig> schedulerConfig) {
    size_t numPrefills = 0;
    size_t numLongPrefills = 0;
    size_t numWaitingLongPrefills = 0;

    // Running prefills are already admitted, so all of them are counted (no cap here).
    for (const SequenceGroupSPtr &seqGroup : running) {
        if (seqGroup->firstSeq->data_.stage_ != SequenceStage::PREFILL) {
            continue;
        }
        numPrefills++;
        if (seqGroup->firstSeq->GetNumUncomputedTokens() > schedulerConfig->longPrefillTokenThreshold) {
            numLongPrefills++;
        }
    }

    for (const SequenceGroupSPtr &seqGroup : waiting) {
        // Stop once enough prefills are known; the rest of the queue cannot change the result.
        if (numPrefills >= schedulerConfig->maxNumPartialPrefills) {
            break;
        }
        if (seqGroup->firstSeq->GetNumUncomputedTokens() > schedulerConfig->longPrefillTokenThreshold) {
            // A long request that would exceed the long-prefill limit cannot be scheduled, so
            // it is not counted.
            if (numLongPrefills + numWaitingLongPrefills >= schedulerConfig->maxLongPartialPrefills) {
                continue;
            }
            numWaitingLongPrefills++;
        }
        numPrefills++;
    }

    // Waiting long prefills are deliberately not folded into numLongPrefills; presumably they are
    // added via MaybeIncrementPartialPrefills() as they are actually scheduled -- confirm with caller.
    const size_t schedulablePrefills =
        std::min(numPrefills, static_cast<size_t>(schedulerConfig->maxNumPartialPrefills));
    // The constructor takes int counts; the narrowing was previously implicit.
    return std::make_unique<Statistics4PartialPrefill>(static_cast<int>(schedulablePrefills),
        static_cast<int>(numLongPrefills), std::move(schedulerConfig));
}
/**
 * @brief Number of prefills considered schedulable, as fixed at construction.
 */
size_t Statistics4PartialPrefill::SchedulablePrefills() const {
    return numSchedulablePrefills_;
}
/**
 * @brief Current long-prefill count (initial value plus MaybeIncrementPartialPrefills() bumps).
 */
size_t Statistics4PartialPrefill::LongPrefills() const {
    return numLongPrefills_;
}
}