* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "dynamic_batch_size.h"
#include "dynamic_batch_recorder.h"
#include "log.h"
#include "policy/seq_group_collection.h"
#include "stage_policy/stage_policy.h"
namespace mindie_llm {
/**
 * @brief Create the dynamic batch-size controller for one data-parallel (DP) rank.
 *
 * The search range is bounded below by a fixed floor of 3 and above by the
 * maxBatchSize present in the scheduler config at construction time. The current
 * search interval starts out equal to that full range. A decode batch-size tracker
 * with a 20000-sample window is created and, together with the latency predictor,
 * handed to this rank's DynamicBatchRecorder instance.
 *
 * @param schedulerConfig Scheduler configuration; maxBatchSize is read here and the
 *                        batch-size limits in it are rewritten later by this class.
 * @param predictor       Latency predictor passed on to the recorder.
 * @param blockManager    Block manager used to count blocks of scheduled groups.
 * @param localDPRank     Local DP rank index; selects the recorder instance.
 */
DynamicBatchSize::DynamicBatchSize(const SchedulerConfigSPtr schedulerConfig,
                                   std::shared_ptr<LatencyPredictor> predictor,
                                   std::shared_ptr<BlockSpaceManager> blockManager, size_t localDPRank)
    : schedulerConfig_(schedulerConfig), predictor_(predictor), blockManager_(blockManager), localDPRank_(localDPRank) {
    // Hard limits of the search range.
    batchSizeLowerBound_ = 3U;
    batchSizeUpperBound_ = schedulerConfig_->maxBatchSize;

    // Working interval starts fully open.
    batchSizeLower_ = batchSizeLowerBound_;
    batchSizeUpper_ = batchSizeUpperBound_;

    // Sliding window of recent decode batch sizes.
    batchTrackerWindowSize_ = 20000;
    decodeBatchSizeQueue_ = std::make_shared<DecodeBatchSizeTracker>(batchTrackerWindowSize_);

    auto &dpRecorder = DynamicBatchRecorder::GetInstance(localDPRank_);
    dpRecorder.SetDecodeBatchSizeTracker(decodeBatchSizeQueue_);
    dpRecorder.SetLatencyPredictor(predictor_);
}
/**
 * @brief Post-scheduling hook that optionally retunes the batch-size limits.
 *
 * The hook returns without doing anything when the scheduler produced an empty output
 * while prefill is still permitted (maxPrefillBatchSize != 0). If the output is a decode
 * step, its number of scheduled sequence groups becomes the "previous decode batch size".
 * previousStage_ is then set to 1; this function never sets it back to anything else.
 * The actual adjustment runs only for Role::PnD with dynamicBatchSizeEnable set.
 *
 * @param role          Role of this instance.
 * @param schedulerOut  Output of the scheduling round that just finished.
 * @param waitingSize   Number of waiting (prefill) requests.
 * @param runningSize   Number of running requests.
 * @param swappedSize   Number of swapped-out requests.
 */
void DynamicBatchSize::ApplyDynamicBatchSize(Role role, SchedulerOutputs &schedulerOut, size_t waitingSize,
                                             size_t runningSize, size_t swappedSize) {
    // An empty schedule with prefill still allowed is skipped. When prefill is throttled to 0
    // we fall through (presumably so the limit can be lifted again - confirm intent).
    if (schedulerOut.IsEmpty() && schedulerConfig_->maxPrefillBatchSize != 0) {
        return;
    }

    const bool wasDecodeStep = (schedulerOut.forwardMode_ == ForwardMode::DECODE);
    const size_t previousDecodeBatchSize = wasDecodeStep ? schedulerOut.scheduledSeqGroups_.size() : 0;
    if (wasDecodeStep) {
        previousStage_ = 1;
    }

    const bool adjustmentEnabled = (role == Role::PnD) && schedulerConfig_->dynamicBatchSizeEnable;
    if (!adjustmentEnabled) {
        return;
    }
    AdjustBatchSize(previousStage_, previousDecodeBatchSize, waitingSize, runningSize, swappedSize);
}
/**
 * @brief Run one adjustment step of the dynamic batch size for this DP rank.
 *
 * In order:
 *  1. Record the previous decode batch size (when non-zero) in the sliding-window tracker.
 *  2. Return early if there are neither waiting nor running/swapped requests.
 *  3. Publish running+swapped to this rank's recorder, then aggregate decode latency,
 *     batch size and decode request count across DP ranks (labelled "max" in the log).
 *  4. If the aggregated batch size is 0, the latency is ~0, or previousStage != 1, reset
 *     the applied maxBatchSize to batchSizeLowerBound_ and collapse the search interval
 *     onto the aggregated batch size. Execution continues to step 5 regardless.
 *  5. Alternate between BinarySearchBatchSize and SetMinimalBatchSize on successive calls.
 *
 * @param previousStage           Stage marker; the caller in this file passes 1 once a
 *                                decode step has been seen.
 * @param previousDecodeBatchSize Batch size of the last decode step, 0 if none.
 * @param waitingSize             Waiting (prefill) requests.
 * @param runningSize             Running requests.
 * @param swappedSize             Swapped-out requests.
 */
void DynamicBatchSize::AdjustBatchSize(size_t previousStage, size_t previousDecodeBatchSize, size_t waitingSize,
                                       size_t runningSize, size_t swappedSize) {
    const uint64_t localDecodeRequestNum = static_cast<uint64_t>(runningSize + swappedSize);

    if (previousDecodeBatchSize != 0) {
        decodeBatchSizeQueue_->AddDataPoint(previousDecodeBatchSize);
    }
    // Nothing waiting and nothing decoding: nothing to size a batch for.
    if (waitingSize == 0 && localDecodeRequestNum == 0) {
        return;
    }

    DynamicBatchRecorder::GetInstance(localDPRank_).SetRunningSize(runningSize + swappedSize);

    double maxDecodeLatency = 0.0;
    uint64_t maxBatchSize = 0;
    size_t maxDecodeRequestNum = 0;
    const size_t validDPCount = DynamicBatchRecorder::AggregateAllFromAllDPs(
        batchTrackerWindowSize_, maxDecodeLatency, maxBatchSize, maxDecodeRequestNum);

    const bool statsMissing = (maxBatchSize == 0) || (std::fabs(maxDecodeLatency) < 1e-6f);
    if (statsMissing || previousStage != 1) {
        // NOTE(review): the applied limit is reset to batchSizeLowerBound_, but the interval is
        // collapsed onto the aggregated maxBatchSize (possibly 0), and there is no early return,
        // so an adjustment step still runs below. Confirm this is intended.
        schedulerConfig_->maxBatchSize = batchSizeLowerBound_;
        batchSizeLower_ = maxBatchSize;
        batchSizeUpper_ = maxBatchSize;
    }

    MINDIE_LLM_LOG_INFO_REQUEST("[DynamicBatchSize|DP"
                                << localDPRank_ << "] Inputs: " << "localDecodeRequestNum=" << localDecodeRequestNum
                                << ", " << "aggregated(max)DecodeRequestNum=" << maxDecodeRequestNum << ", "
                                << "aggregated(max)DecodeLatency=" << maxDecodeLatency << "ms, "
                                << "aggregated(max)BatchSize=" << maxBatchSize << ", "
                                << "validDPCount=" << validDPCount);

    const uint64_t aggregatedDecodeRequestNum = static_cast<uint64_t>(maxDecodeRequestNum);
    const uint32_t stageModulo = 2;
    stage_ = (stage_ + 1) % stageModulo;
    const bool runBinarySearch = (stage_ % stageModulo == 0);
    if (!runBinarySearch) {
        SetMinimalBatchSize(aggregatedDecodeRequestNum, maxDecodeLatency);
        return;
    }
    BinarySearchBatchSize(aggregatedDecodeRequestNum, maxDecodeLatency, maxBatchSize);
}
/**
 * @brief One bounded search step on the decode max batch size.
 *
 * The observed decode latency is compared with decodeExpectedTime, using a +/-5 dead
 * band. The working interval [batchSizeLower_, batchSizeUpper_] is updated as follows:
 *  - latency above the band: pull the upper end toward the observed batch size, keeping
 *    it at least 3 above the lower end, and relax the lower end by 2;
 *  - latency below the band: push the lower end toward the observed batch size, keeping
 *    it at least 3 below the upper end, and relax the upper end by 2;
 *  - inside the band: re-centre a narrow interval on the observed batch size.
 *
 * The interval midpoint becomes the new decode limit, subject to three constraints:
 *  - it moves at most +/-2 from the currently applied maxBatchSize;
 *  - it never drops below the in-flight decode request count;
 *  - it never exceeds the configured upper bound.
 * Headroom above the decode request count is given to prefill.
 *
 * Every unsigned subtraction saturates at 0. Plain subtraction wraps for small inputs:
 * avgBatchSize 0/1, an interval end of 0 after a reset, or an applied limit below 2.
 * A wrap either throws the interval to the upper bound or hands std::clamp lo > hi,
 * which is undefined behaviour. The clamp limits are also kept ordered when the
 * configured maximum is below the fixed floor or below the decode request count.
 *
 * @param currentDecodeRequestNum Decode requests that must still fit in a batch.
 * @param avgDecodeLatency        Observed decode latency.
 * @param avgBatchSize            Observed decode batch size.
 */
void DynamicBatchSize::BinarySearchBatchSize(uint64_t currentDecodeRequestNum, double avgDecodeLatency,
                                             uint64_t avgBatchSize) {
    const uint32_t deltaAdjustUpper = 3;
    const uint32_t deltaAdjustLower = 2;
    const double deltaMs = 5.0;
    const uint64_t deltaBatchSize = 2;
    // Unsigned subtraction that stops at 0 instead of wrapping around.
    const auto subOrZero = [](uint64_t lhs, uint64_t rhs) -> uint64_t { return lhs > rhs ? lhs - rhs : 0; };
    // Keep clamp limits ordered even if the configured maximum is below the fixed floor.
    const uint64_t upperBound = batchSizeUpperBound_;
    const uint64_t lowerBound = std::min<uint64_t>(batchSizeLowerBound_, upperBound);

    if (avgDecodeLatency > schedulerConfig_->decodeExpectedTime + deltaMs) {
        // Too slow: shrink from the top, relax the bottom.
        batchSizeUpper_ = std::max(avgBatchSize, batchSizeLower_ + deltaAdjustUpper);
        batchSizeLower_ = std::max(subOrZero(batchSizeLower_, deltaBatchSize), lowerBound);
    } else if (avgDecodeLatency < schedulerConfig_->decodeExpectedTime - deltaMs) {
        // Too fast: grow from the bottom, relax the top.
        batchSizeLower_ = std::min(avgBatchSize, subOrZero(batchSizeUpper_, deltaAdjustUpper));
        batchSizeUpper_ = std::min<uint64_t>(batchSizeUpper_ + deltaBatchSize, upperBound);
    } else {
        // Within the dead band: narrow interval around the observed batch size.
        batchSizeUpper_ = std::min<uint64_t>(avgBatchSize + deltaAdjustUpper, upperBound);
        batchSizeLower_ = std::max(subOrZero(avgBatchSize, deltaAdjustLower), lowerBound);
    }
    batchSizeUpper_ = std::clamp<uint64_t>(batchSizeUpper_, lowerBound, upperBound);
    batchSizeLower_ = std::clamp<uint64_t>(batchSizeLower_, lowerBound, upperBound);

    // Move at most deltaBatchSize away from the limit that is currently applied.
    const uint64_t appliedMax = schedulerConfig_->maxBatchSize;
    const uint64_t midpoint = (batchSizeUpper_ + batchSizeLower_) / 2;
    uint64_t newDecodeMaxBatchSize =
        std::clamp<uint64_t>(midpoint, subOrZero(appliedMax, deltaBatchSize), appliedMax + deltaBatchSize);
    // Never go below the in-flight decode requests; the configured bound wins on conflict.
    const uint64_t decodeFloor = std::min(currentDecodeRequestNum, upperBound);
    newDecodeMaxBatchSize = std::clamp<uint64_t>(newDecodeMaxBatchSize, decodeFloor, upperBound);
    const uint64_t newPrefillMaxBatchSize = subOrZero(newDecodeMaxBatchSize, currentDecodeRequestNum);

    MINDIE_LLM_LOG_DEBUG("[DynamicBatchSize|DP" << localDPRank_ << "] BinarySearch result: " << "batchSizeUpper="
                                                << batchSizeUpper_ << ", batchSizeLower=" << batchSizeLower_ << ", "
                                                << "newDecodeMaxBatchSize=" << newDecodeMaxBatchSize << ", "
                                                << "newPrefillMaxBatchSize=" << newPrefillMaxBatchSize);
    ApplyUpdatedBatchSize(newDecodeMaxBatchSize, newPrefillMaxBatchSize);
}
/**
 * @brief Emergency shrink: when decode is too slow, stop admitting prefill.
 *
 * If there are in-flight decode requests and the latency exceeds
 * decodeExpectedTime + 5, the decode limit is set to exactly the current decode
 * request count and the prefill limit to 0. Otherwise the current limits are kept
 * and only a debug line is logged.
 *
 * @param currentDecodeRequestNum In-flight decode requests.
 * @param avgDecodeLatency        Observed decode latency.
 */
void DynamicBatchSize::SetMinimalBatchSize(uint64_t currentDecodeRequestNum, double avgDecodeLatency) {
    const double deltaMs = 5.0;
    const bool shouldShrink =
        (currentDecodeRequestNum != 0) && (avgDecodeLatency > schedulerConfig_->decodeExpectedTime + deltaMs);
    if (!shouldShrink) {
        const uint64_t keptDecodeLimit = schedulerConfig_->maxBatchSize;
        const uint64_t keptPrefillLimit = schedulerConfig_->maxPrefillBatchSize;
        MINDIE_LLM_LOG_DEBUG("[DynamicBatchSize|DP" << localDPRank_ << "] SetMinimal: condition not met, "
                                                    << "keeping current maxBatchSize=" << keptDecodeLimit
                                                    << ", " << "maxPrefillBatchSize=" << keptPrefillLimit);
        return;
    }

    const uint64_t cappedDecodeLimit = currentDecodeRequestNum;
    const uint64_t cappedPrefillLimit = 0UL;
    MINDIE_LLM_LOG_DEBUG("[DynamicBatchSize|DP" << localDPRank_
                                                << "] SetMinimal: newDecodeMaxBatchSize=" << cappedDecodeLimit
                                                << ", newPrefillMaxBatchSize=" << cappedPrefillLimit);
    ApplyUpdatedBatchSize(cappedDecodeLimit, cappedPrefillLimit);
}
/**
 * @brief Append a decode batch size, keeping at most windowSize_ samples.
 *
 * The oldest samples are evicted first. A window size of 0 retains nothing. Without
 * that guard, the old `size() == windowSize_` test called pop_front() on an empty
 * container (undefined behaviour) and then let the queue grow without bound.
 *
 * @param batchSize Batch size of one decode step.
 */
void DecodeBatchSizeTracker::AddDataPoint(uint64_t batchSize) {
    if (windowSize_ == 0) {
        return;
    }
    // After this loop there is room for exactly one more sample; the queue is non-empty
    // whenever we pop because size() >= windowSize_ >= 1.
    while (queue_.size() >= windowSize_) {
        queue_.pop_front();
    }
    queue_.push_back(batchSize);
}
/**
 * @brief Count the blocks held by the scheduled sequence groups.
 *
 * For each scheduled group, the count is taken from its first sequence only. That
 * sequence is looked up in the block manager, and the size of the first entry of
 * the returned id table is added. The meaning of that first entry (e.g. per-rank or
 * per-cache-type) is not visible here - TODO confirm. Groups with no sequences or
 * no block ids are skipped.
 *
 * NOTE(review): the parameter is taken by value (a full copy of the outputs) because
 * that is what the header declares; consider `const SchedulerOutputs &` in both places.
 *
 * @param schedulerOut Outputs of the current scheduling round.
 * @return Total block count as uint32_t.
 */
uint32_t DynamicBatchSize::GetScheduledOutBlockNum(SchedulerOutputs schedulerOut) {
    uint32_t blockNum = 0;
    // Bind by reference: avoids atomic shared_ptr refcount churn and copying the id table.
    for (const auto &scheduledGroup : schedulerOut.scheduledSeqGroups_) {
        const auto &seqGroup = scheduledGroup->seqGroup_;
        if (seqGroup->seqs_.empty()) {
            continue;
        }
        const SequenceId seqId = seqGroup->seqs_[0]->seqId_;
        const auto &allIds = blockManager_->GetBlockIds(seqId);
        if (allIds.empty()) {
            continue;
        }
        blockNum += static_cast<uint32_t>(allIds[0].size());
    }
    return blockNum;
}
void DynamicBatchSize::RecordPredictorMetrics(const SchedulerOutputs &schedulerOut, const SchedulingBudget &budget) {
if (!schedulerConfig_->dynamicBatchSizeEnable &&
schedulerConfig_->stageSelectPolicy != static_cast<uint32_t>(StagePolicyType::LATENCY_FIRST)) {
return;
}
auto batchStatsPtr = std::make_shared<BatchStats>();
batchStatsPtr->forwardMode = schedulerOut.forwardMode_;
batchStatsPtr->numBatchedTokens = budget.numBatchedTokens_;
batchStatsPtr->kvCacheBlockNum = GetScheduledOutBlockNum(schedulerOut);
predictor_->SaveBatchStats(batchStatsPtr);
}
/**
 * @brief Batch-size-weighted average over the most recent forwardNum units.
 *
 * The window is walked from newest to oldest sample. A batch of size b contributes
 * b units of weight, until forwardNum units have been collected; if the window holds
 * fewer, all samples are used. The batch straddling the forwardNum boundary
 * contributes only the units that still fit inside the limit. The previous code
 * weighted it by the overshoot (sumVal - forwardNum), i.e. the part outside the
 * limit, which skewed the result toward older batches.
 *
 * Result: sum(b * weight) / sum(weight), truncated toward zero. This also avoids the
 * two temporary vectors the previous implementation allocated on every call.
 *
 * @param forwardNum Number of units (running sum of batch sizes) to average over.
 * @return The weighted average, or 0 when the window is empty, forwardNum is 0, or
 *         every covered sample is 0.
 */
uint64_t DecodeBatchSizeTracker::GetRecentAvgBatchSize(uint64_t forwardNum) {
    uint64_t weightedSum = 0;
    uint64_t coveredUnits = 0;
    for (auto it = queue_.rbegin(); it != queue_.rend() && coveredUnits < forwardNum; ++it) {
        const uint64_t batchSize = *it;
        // Only the in-window part of the boundary batch counts; coveredUnits < forwardNum here.
        const uint64_t weight = std::min(batchSize, forwardNum - coveredUnits);
        weightedSum += batchSize * weight;
        coveredUnits += weight;
    }
    if (coveredUnits == 0) {
        return 0;
    }
    return weightedSum / coveredUnits;
}
/**
 * @brief Write new decode/prefill batch-size limits into the scheduler config.
 *
 * Both limits are always written. An info line is logged, and
 * previousDecodeMaxBatchSize_ is updated, only when the decode limit differs from
 * the last value recorded here.
 *
 * @param newDecodeMaxBatchSize  New maxBatchSize.
 * @param newPrefillMaxBatchSize New maxPrefillBatchSize.
 */
void DynamicBatchSize::ApplyUpdatedBatchSize(uint64_t newDecodeMaxBatchSize, uint64_t newPrefillMaxBatchSize) {
    schedulerConfig_->maxBatchSize = newDecodeMaxBatchSize;
    schedulerConfig_->maxPrefillBatchSize = newPrefillMaxBatchSize;

    const bool decodeLimitChanged = (previousDecodeMaxBatchSize_ != newDecodeMaxBatchSize);
    if (decodeLimitChanged) {
        MINDIE_LLM_LOG_INFO_REQUEST(
            "[DynamicBatchSize|DP" << localDPRank_ << "] Updated maxPrefillBatchSize: " << newPrefillMaxBatchSize
                                   << ", maxBatchSize: " << previousDecodeMaxBatchSize_ << " -> " << newDecodeMaxBatchSize);
        previousDecodeMaxBatchSize_ = newDecodeMaxBatchSize;
    }
}
}