* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_set>
#include <utility>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "load_balancer.h"
#include "llm_engine.h"
#include "scheduler.h"
#include "../crash_handler.h"
using namespace mindie_llm;
using namespace std;
namespace {
// State shared between every MockScheduler instance and the test bodies.
// `seqCount` tallies the sequence groups handed to any MockScheduler; it is incremented while holding
// `mtx`, and `cv` is signalled after each increment so a waiting test can re-check the count.
std::mutex mtx;
std::condition_variable cv;
int seqCount{0};
}  // namespace
/**
 * IScheduler stub whose methods do nothing: arguments are ignored, collections come back empty and
 * value-returning methods yield default-constructed results. Test doubles derive from it and override
 * only the methods they care about.
 */
class EmptyScheduler : public IScheduler {
public:
    EmptyScheduler() = default;

    void AddSeqGroup(SequenceGroupSPtr &seqGroup) override { (void)seqGroup; }

    SchedulerMetric CollectSchedulerMetric() override
    {
        SchedulerMetric metric;
        return metric;
    }

    std::pair<SequenceGroupMetaDatas, SchedulerOutputs> Schedule(bool needSync = false) override
    {
        (void)needSync;
        return {};
    }

    std::pair<SequenceGroupMetaDatas, SchedulerKVTransferOutput> ScheduleTransfer() { return {}; }

    size_t GetUnFinishedSeqGroups() { return 0; }

    // The Fetch* methods ignore their queue argument and return a reference to a member set that this
    // class never fills.
    std::unordered_set<SequenceId> &FetchFinishedSeqIds(ConcurrentDeque<SequenceId> &finishedSeqIds)
    {
        (void)finishedSeqIds;
        return seqIds_;
    }

    std::unordered_set<SequenceId> &FetchExceptionSeqIds(ConcurrentDeque<SequenceId> &exceptionSeqIds)
    {
        (void)exceptionSeqIds;
        return seqIds_;
    }

    std::unordered_set<RequestId> &FetchAbortedReqIds(ConcurrentDeque<RequestId> &abortedReqIds)
    {
        (void)abortedReqIds;
        return reqIds_;
    }

    void KVPulledReqEnterRunningQueue(ConcurrentDeque<RequestId> &pulledReqIds) { (void)pulledReqIds; }

    void NotifyMeKvPulledSeqIds(SequenceId seqId) { (void)seqId; }

    std::unordered_set<SequenceId> ClearAndReturnTerminatedSeqIds() { return {}; }

    void FetchSeqGeneratedTokens(ConcurrentDeque<std::pair<SequenceId, TokenId>> &seqIdToOutputTokenQueue)
    {
        (void)seqIdToOutputTokenQueue;
    }

    void MarkLastScheduleEmpty() {}

    void ClearLastScheduleEmpty() {}

    void PrepareNextSchedule(std::vector<ScheduledSequenceGroupSPtr> &scheduledSeqGroups)
    {
        (void)scheduledSeqGroups;
    }

    void ClearSeqGrp(SequenceGroupSPtr seqGroup, SequenceStatus finalStatus)
    {
        (void)seqGroup;
        (void)finalStatus;
    }

    void CollectAndClearAbortedParallelSeqGroups() override {}

    std::vector<SequenceGroupSPtr> &GetAbortedParallelSeqGroups() override { return abortedParallelSeqGroups_; }

    void SetPrefillPercentage(uint32_t prefillPercentage) { (void)prefillPercentage; }

    // Explicit value-initialised results: flowing off the end of a non-void function is undefined
    // behaviour as soon as a caller invokes it.
    Role SwitchRole() { return {}; }

    std::shared_ptr<StagePolicy> GetStagePolicy() { return nullptr; }

private:
    std::unordered_set<SequenceId> seqIds_;  // returned by both FetchFinishedSeqIds and FetchExceptionSeqIds
    std::unordered_set<RequestId> reqIds_;
    std::vector<SequenceGroupSPtr> abortedParallelSeqGroups_;
};
/**
 * Counting scheduler double: each sequence group routed to it increments the shared `seqCount`
 * (under `mtx`) and signals `cv` so a test waiting on the count can re-check it.
 */
class MockScheduler final : public EmptyScheduler {
public:
    MockScheduler() : EmptyScheduler() {}

    void StopRunningRequest() override {}

    // Records the dispatch; the group itself is not inspected.
    void AddSeqGroup(SequenceGroupSPtr &seqGroup) override
    {
        (void)seqGroup;
        {
            std::lock_guard<std::mutex> lock(mtx);
            seqCount++;
        }
        // Signal after releasing mtx so the woken waiter can take the lock straight away.
        cv.notify_one();
    }

    // Reports a fixed 160 free NPU blocks for every instance.
    // NOTE(review): presumably so each DP rank looks equally able to accept work — confirm.
    SchedulerMetric CollectSchedulerMetric() override
    {
        SchedulerMetric metric;
        metric.blockInfo.freeNpuBlockNum_ = 160;
        return metric;
    }
};
class LoadBalancerTest : public ::testing::Test {
protected:
void SetUp() override { seqCount = 0; }
static void SetUpTestSuite()
{
mindie_llm::test::InitCrashHandler();
}
};
// Pushes a full prefill batch (reqNum == maxPrefillBatchSize) and expects every request to reach one of
// the MockSchedulers within 10 ms.
TEST_F(LoadBalancerTest, SendMassiveRequestsAndVerifySchedulingCompletionInShortTime)
{
    constexpr size_t dpSize = 4;
    uint32_t maxPrefillBatchSize = 10;
    std::vector<std::shared_ptr<EnginePerDP>> enginePerDPs;
    enginePerDPs.reserve(dpSize);
    for (size_t i = 0; i < dpSize; ++i) {
        EnginePerDPSPtr enginePerDP = std::make_shared<EnginePerDP>();
        enginePerDPs.emplace_back(enginePerDP);
        enginePerDP->scheduler = std::make_unique<MockScheduler>();
    }
    LoadBalancerPtr loadBalancer = MakeLoadBalancer(enginePerDPs, maxPrefillBatchSize);

    constexpr int reqNum = 10;
    constexpr size_t seqLen[reqNum] = {18, 3, 22, 15, 7, 11, 24, 9, 1, 17};
    std::vector<SequenceGroupSPtr> candidates(reqNum);
    for (int i = 0; i < reqNum; i++) {
        RequestId id(std::to_string(i));
        std::vector<SequenceSPtr> seq(1);
        std::vector<TokenId> inputs(seqLen[i], 5);
        seq[0] = std::make_shared<Sequence>(i, 1, inputs);
        candidates[i] = std::make_shared<SequenceGroup>(id, seq);
    }
    // Iterate by value: AddSeqGroup takes a non-const reference, so hand it a copy and keep the vector intact.
    for (auto candidate : candidates) {
        loadBalancer->AddSeqGroup(candidate);
    }

    int scheduled = 0;
    {
        std::unique_lock<std::mutex> lock(mtx);
        cv.wait_for(lock, std::chrono::milliseconds(10), [] { return seqCount >= reqNum; });
        // Sample while still holding mtx: MockScheduler::AddSeqGroup increments seqCount under this lock,
        // possibly from another thread, so an unlocked read would race.
        scheduled = seqCount;
    }
    EXPECT_EQ(scheduled, reqNum);
}
// A single request cannot fill a prefill batch of maxPrefillBatchSize, so it must be released by the
// periodic trigger; the test allows 1.15 ms for that to happen.
TEST_F(LoadBalancerTest, SendFewRequestsAndVerifyPeriodicTriggerScheduling)
{
    constexpr size_t dpSize = 4;
    uint32_t maxPrefillBatchSize = 10;
    std::vector<std::shared_ptr<EnginePerDP>> enginePerDPs;
    enginePerDPs.reserve(dpSize);
    for (size_t i = 0; i < dpSize; ++i) {
        EnginePerDPSPtr enginePerDP = std::make_shared<EnginePerDP>();
        enginePerDPs.emplace_back(enginePerDP);
        enginePerDP->scheduler = std::make_unique<MockScheduler>();
    }
    LoadBalancerPtr loadBalancer = MakeLoadBalancer(enginePerDPs, maxPrefillBatchSize);

    RequestId id("hi");
    std::vector<SequenceSPtr> seq(1);
    std::vector<TokenId> inputs(1, 5);
    seq[0] = std::make_shared<Sequence>(1, 1, inputs);
    SequenceGroupSPtr candidate = std::make_shared<SequenceGroup>(id, seq);
    // seqCount was already reset by the fixture's SetUp; no load-balancer work has been submitted yet.
    loadBalancer->AddSeqGroup(candidate);

    // Same 1.15 ms budget as a fixed sleep, but returns as soon as the request is dispatched and reads the
    // counter under mtx instead of racing with MockScheduler::AddSeqGroup.
    constexpr auto periodicDeadline = std::chrono::microseconds(150) + std::chrono::milliseconds(1);
    int scheduled = 0;
    {
        std::unique_lock<std::mutex> lock(mtx);
        cv.wait_for(lock, periodicDeadline, [] { return seqCount >= 1; });
        scheduled = seqCount;
    }
    EXPECT_EQ(scheduled, 1);
}
// Alternates a full batch (triggered dispatch), a partial batch (periodic dispatch) and another full batch,
// checking the cumulative number of groups handed to the MockSchedulers after each phase.
TEST_F(LoadBalancerTest, ValidateSequentialAlternationBetweenTriggeredAndPeriodicTasks)
{
    constexpr size_t dpSize = 4;
    uint32_t maxPrefillBatchSize = 10;
    std::vector<std::shared_ptr<EnginePerDP>> enginePerDPs;
    enginePerDPs.reserve(dpSize);
    for (size_t i = 0; i < dpSize; ++i) {
        EnginePerDPSPtr enginePerDP = std::make_shared<EnginePerDP>();
        enginePerDPs.emplace_back(enginePerDP);
        enginePerDP->scheduler = std::make_unique<MockScheduler>();
    }
    LoadBalancerPtr loadBalancer = MakeLoadBalancer(enginePerDPs, maxPrefillBatchSize);

    constexpr int reqNum = 10;
    constexpr size_t seqLen[reqNum] = {18, 3, 22, 15, 7, 11, 24, 9, 1, 17};
    std::vector<SequenceGroupSPtr> candidates(reqNum);
    for (int i = 0; i < reqNum; i++) {
        RequestId id(std::to_string(i));
        std::vector<SequenceSPtr> seq(1);
        std::vector<TokenId> inputs(seqLen[i], 5);
        seq[0] = std::make_shared<Sequence>(i, 1, inputs);
        candidates[i] = std::make_shared<SequenceGroup>(id, seq);
    }

    // Blocks until at least `expected` groups have been dispatched or `timeout` expires, then returns the
    // count sampled under mtx (MockScheduler::AddSeqGroup increments it under the same lock).
    auto waitForSeqCount = [](int expected, std::chrono::microseconds timeout) {
        std::unique_lock<std::mutex> lock(mtx);
        cv.wait_for(lock, timeout, [expected] { return seqCount >= expected; });
        return seqCount;
    };

    // Phase 1: a full batch (reqNum == maxPrefillBatchSize). Iterate by value because AddSeqGroup takes a
    // non-const reference and the same groups are submitted again below.
    for (auto candidate : candidates) {
        loadBalancer->AddSeqGroup(candidate);
    }
    // Later expectations are cumulative, so stop here if this phase already failed.
    ASSERT_EQ(waitForSeqCount(reqNum, std::chrono::milliseconds(10)), reqNum);

    // Phase 2: two requests do not fill a batch and are expected to go out on the periodic trigger.
    loadBalancer->AddSeqGroup(candidates[0]);
    loadBalancer->AddSeqGroup(candidates[1]);
    int expectSeqCount = reqNum + 2;
    const auto periodicDeadline = std::chrono::microseconds(150) + std::chrono::milliseconds(1);
    ASSERT_EQ(waitForSeqCount(expectSeqCount, periodicDeadline), expectSeqCount);

    // Pause before the next burst.
    // NOTE(review): presumably so it lands in a fresh scheduling period — confirm.
    std::this_thread::sleep_for(std::chrono::milliseconds(5));

    // Phase 3: another full batch.
    for (auto candidate : candidates) {
        loadBalancer->AddSeqGroup(candidate);
    }
    expectSeqCount = reqNum * 2 + 2;
    EXPECT_EQ(waitForSeqCount(expectSeqCount, std::chrono::milliseconds(50)), expectSeqCount);
}