* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "pdds_policy.h"
#include <stdexcept>
namespace mindie_llm {
PDDSPolicy::PDDSPolicy(std::shared_ptr<SchedulerConfig> &schedulerConfig, BlockSpaceManagerSPtr &blockManager,
Role role)
: FcfsPolicy(schedulerConfig, blockManager), role_(role) {
if (schedulerConfig == nullptr) {
throw std::invalid_argument("schedulerConfig is null");
}
if (blockManager == nullptr) {
throw std::invalid_argument("blockManager is null");
}
}
PolicyOutput PDDSPolicy::Apply(SchedulingBudget &budget, std::shared_ptr<SeqGroupCollection> &collection) {
if (collection == nullptr) {
throw std::invalid_argument("collection is null");
}
PreprocessCollection(collection);
switch (role_) {
case Role::P:
if (schedulerConfig_->enableChunkedPrefill) {
return ScheduleChunkedPrefill(budget);
}
return SchedulePrefill(budget);
case Role::D:
return ScheduleDecode(budget);
default:
throw std::runtime_error("the role is not support. role = " + std::to_string(static_cast<uint8_t>(role_)));
}
}
PolicyOutput PDDSPolicy::ScheduleChunkedPrefill(SchedulingBudget &budget) {
PrefillOutputs prefillOut;
RunningOutputs runningOut;
SwappedInOutputs swappedInOut;
budget.statistics4PartialPrefill_ = Statistics4PartialPrefill::FromQueues(
queuesCollection_->running_, queuesCollection_->waiting_, schedulerConfig_);
runningOut = ApplyToRunningQueue(budget, true);
prefillOut = ApplyToWaitingQueue(budget, true);
return BuildPolicyOutput(prefillOut, runningOut, swappedInOut, budget);
}
KVTransferSchedulePolicy::KVTransferSchedulePolicy(std::shared_ptr<SchedulerConfig> &schedulerConfig,
BlockSpaceManagerSPtr &blockManager, Role role)
: policyHelper_(schedulerConfig, blockManager),
schedulerConfig_(schedulerConfig),
blockManager_(blockManager),
role_(role) {
if (schedulerConfig == nullptr) {
throw std::invalid_argument("schedulerConfig is null");
}
if (blockManager == nullptr) {
throw std::invalid_argument("blockManager is null");
}
}
KVTransferPolicyOutput KVTransferSchedulePolicy::Apply(SchedulingBudget &budget,
std::shared_ptr<SeqGroupCollection> &collection) {
if (collection == nullptr) {
throw std::invalid_argument("collection is null");
}
this->queuesCollection_ = collection;
return ScheduleKVTransferInD(budget);
}
KVTransferPolicyOutput KVTransferSchedulePolicy::ScheduleKVTransferInD(SchedulingBudget &budget, bool enableChunking) {
std::vector<ScheduledSequenceGroupSPtr> pullSeqGroups = PickPullSeqGroup(budget, enableChunking);
return BuildTransferPolicyOutput(pullSeqGroups);
}
std::vector<ScheduledSequenceGroupSPtr> KVTransferSchedulePolicy::PickPullSeqGroup(SchedulingBudget &budget,
bool enableChunking) {
std::vector<ScheduledSequenceGroupSPtr> pullSeqGroups;
while (!queuesCollection_->waiting_.empty()) {
SequenceGroupSPtr seqGroup = queuesCollection_->waiting_.front();
const auto [numNewTokensUncached, numNewTokensCached] = policyHelper_.GetNumComputeNewUnCachedAndCachedTokens(
seqGroup, SequenceStatus::WAITING, enableChunking, budget);
const size_t numNewTokens = numNewTokensUncached + numNewTokensCached;
size_t freeNpuBlocksNum = blockManager_->GetNumFreeNpuBlocks();
size_t totalNpuBlocksNum = blockManager_->GetTotalNpuBlocks();
size_t preservedNpuBlocks = static_cast<size_t>(PRESERVED_FACTOR_FOR_DECODE * totalNpuBlocksNum);
if (freeNpuBlocksNum < preservedNpuBlocks) {
break;
}
const auto canAllocate = blockManager_->CanAllocate(seqGroup);
if (canAllocate == AllocStatus::LATER) {
break;
} else if (canAllocate == AllocStatus::NEVER) {
throw std::runtime_error("Prompt sequence too long.");
}
const int numNewSeqs = seqGroup->GetMaxNumRunningSeqs();
if (!budget.CanSchedule(numNewTokensUncached, numNewSeqs)) {
break;
}
queuesCollection_->waiting_.pop_front();
policyHelper_.AllocateAndSetRunning(seqGroup);
pullSeqGroups.emplace_back(std::make_shared<ScheduledSequenceGroup>(seqGroup, numNewTokens));
budget.AddNumBatchedTokens(seqGroup->requestId, numNewTokensUncached, numNewTokensCached);
budget.AddNumSeqs(seqGroup->requestId, numNewSeqs);
}
return pullSeqGroups;
}
KVTransferPolicyOutput KVTransferSchedulePolicy::BuildTransferPolicyOutput(
std::vector<ScheduledSequenceGroupSPtr> &pullSeqGroups) const {
KVTransferPolicyOutput output;
output.pullSeqGroups = std::move(pullSeqGroups);
output.withdrewSeqGroups.insert(output.withdrewSeqGroups.end(), queuesCollection_->waiting_.begin(),
queuesCollection_->waiting_.end());
return output;
}
}