* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include <gtest/gtest.h>
#include <vector>
#include <string>
#include "lora_manager.h"
#include "executor/executor_interface.h"
using namespace mindie_llm;
// Test double implementing IExecutor for the LoraManager tests.
// Every command-style override ignores its arguments and returns true; every
// getter returns a fixed constant or a default-constructed object.
class MockModelExecutor : public IExecutor {
public:
    // ---- Instance set-up / tear-down ------------------------------------

    // Ignores all arguments and reports success.
    bool ExecutorInstanceInit(std::map<std::string, std::string> &config, bool isMultiNodesInfer,
                              size_t dpIdx = 0) override
    {
        return true;
    }

    // Ignores all arguments and reports success.
    bool ExecutorParseConfigAndInitGRPC(std::map<std::string, std::string> &configFromManager,
                                        bool isMultiNodesInfer, size_t rankIdx) override
    {
        return true;
    }

    // Ignores pdInfo and reports success.
    bool MasterAndSlaveModelInit(const std::map<std::string, std::string> &pdInfo) override
    {
        return true;
    }

    // Reports success without doing anything.
    bool ExecutorInstanceFinalize() override
    {
        return true;
    }

    // ---- Execution / cleanup requests -----------------------------------

    // Intentionally empty: the command is ignored.
    void ExecuteRecoverCommand(RecoverCommandInfo &commandInfo) override {}

    // Reports success; the callback is never invoked.
    bool AsyncExecuteModel(ExecuteModelRequestPtr &modelExecRequest,
                           std::function<void(ModelBatchResultSPtr)> callback = nullptr) override
    {
        return true;
    }

    // Ignores the request and reports success.
    bool AsyncTGCleanup(TGCleanupRequestPtr &TGCleanupRequest) override { return true; }

    // Ignores the request and reports success.
    bool AsyncEOSCleanup(TGCleanupRequestPtr &TGCleanupRequest) override { return true; }

    // Reports success; the callback is never invoked.
    bool ExecuteKVTransfer(PullKVRequestPtr &pullKVRequest,
                           std::function<void(PullKVResponseSPtr)> callback = nullptr) override
    {
        return true;
    }

    // ---- PD link --------------------------------------------------------

    // Ignores the request and reports success.
    bool SetupPDLink(model_execute_data::PDLinkRequest &pdLinkRequest) override
    {
        return true;
    }

    // Ignores the request and reports success.
    bool QueryPDLinkStatus(model_execute_data::PDLinkStatusRequest &pdLinkStatusRequest) override { return true; }

    // Returns a default-constructed response.
    model_execute_data::PDLinkStatusResponse GetPDLinkStatusResponse() const override
    {
        return model_execute_data::PDLinkStatusResponse();
    }

    // ---- LoRA operations ------------------------------------------------

    // Ignores the request and reports success.
    bool ExecutLoraRequest(LoraOperationRequest &loraOperationRequest) override { return true; }

    // Returns a default-constructed response.
    model_execute_data::LoraOperationResponse GetLoraOperationResponse() const override
    {
        return model_execute_data::LoraOperationResponse();
    }

    // ---- Fixed capacity / configuration values --------------------------

    uint32_t GetCpuBlockNum() const override
    {
        return 1;
    }

    uint32_t GetNpuBlockNum() const override
    {
        return 1;
    }

    uint32_t GetLwdCloudNpuBlockNum() const override
    {
        return 1;
    }

    uint32_t GetMaxPositionEmbeddings() const override
    {
        return 4096;
    }

    // Returns a default-initialized ThinkingConfig.
    ThinkingConfig GetThinkingConfig() const override
    {
        ThinkingConfig defaultConfig;
        return defaultConfig;
    }
};
// Fixture shared by all LoraManager tests: each test gets its own
// MockModelExecutor and a LoraManager constructed with maxLoras_ == 2.
class LoraManagerTest : public ::testing::Test {
protected:
    // Builds the mock executor first, then the manager that is handed the executor.
    void SetUp() override
    {
        maxLoras_ = 2;
        mockExecutor_ = std::make_shared<MockModelExecutor>();
        loraManager_ = std::make_shared<LoraManager>(mockExecutor_, maxLoras_);
    }

    // Drops the manager before the executor it was constructed with.
    void TearDown() override
    {
        loraManager_.reset();
        mockExecutor_.reset();
    }

    uint32_t maxLoras_ = 0;                            // overwritten in SetUp()
    std::shared_ptr<MockModelExecutor> mockExecutor_;  // executor passed to the manager
    std::shared_ptr<LoraManager> loraManager_;         // object under test
};
// After Initialize() with a single executor, GetInstance(0) is expected to be non-null.
TEST_F(LoraManagerTest, TestGetInstanceSuccess)
{
    std::vector<IExecutorSPtr> executorList;
    executorList.push_back(mockExecutor_);
    LoraManager::Initialize(executorList, maxLoras_);

    LlmLoraPtr manager = LoraManager::GetInstance(0);
    EXPECT_NE(manager, nullptr);
}
// Loading one adapter into an empty manager is expected to return OK with a
// message containing "Success".
TEST_F(LoraManagerTest, TestLoadSuccess)
{
    LoraParamSPtr adapter = std::make_shared<LoraParam>("test_lora", "/path/to/test_lora", "master_model");

    Status loadStatus = loraManager_->Load(adapter);

    EXPECT_EQ(loadStatus.StatusCode(), Error::Code::OK);
    const auto &message = loadStatus.StatusMsg();
    EXPECT_NE(message.find("Success"), std::string::npos);
}
// Loading the same adapter a second time is expected to return OK with a
// message containing "has already been added".
TEST_F(LoraManagerTest, TestLoadDuplicated)
{
    LoraParamSPtr adapter = std::make_shared<LoraParam>("test_lora", "/path/to/test_lora", "master_model");

    // First load; only the outcome of the repeated load is examined.
    loraManager_->Load(adapter);
    Status repeatStatus = loraManager_->Load(adapter);

    EXPECT_EQ(repeatStatus.StatusCode(), Error::Code::OK);
    const auto &message = repeatStatus.StatusMsg();
    EXPECT_NE(message.find("has already been added"), std::string::npos);
}
// InitLoadedLoras() with one model that lists two LoRA modules is expected to
// leave two entries in loaded_.
TEST_F(LoraManagerTest, TestInitLoadedLoras)
{
    std::map<std::string, std::string> moduleTable;
    moduleTable["test_lora1"] = "/path/to/test_lora";
    moduleTable["test_lora2"] = "/path/to/test_lora";

    ModelParam masterModel;
    masterModel.modelName = "master_model";
    masterModel.loraModules = moduleTable;

    std::vector<ModelParam> models(1, masterModel);
    loraManager_->InitLoadedLoras(models);

    EXPECT_EQ(loraManager_->loaded_.Size(), 2);
}
// Loading an adapter whose path is the empty string is expected to return OK
// with a message containing "No adapter found".
TEST_F(LoraManagerTest, TestLoadInvalidPath)
{
    LoraParamSPtr pathlessAdapter = std::make_shared<LoraParam>("test_lora", "", "master_model");

    Status loadStatus = loraManager_->Load(pathlessAdapter);

    EXPECT_EQ(loadStatus.StatusCode(), Error::Code::OK);
    const auto &message = loadStatus.StatusMsg();
    EXPECT_NE(message.find("No adapter found"), std::string::npos);
}
// With maxLoras_ == 2, a third Load() after two earlier loads is expected to
// return OK with a message containing "none are currently unloading".
TEST_F(LoraManagerTest, TestLoadSlotsFull)
{
    // Two loads precede the one under test.
    LoraParamSPtr firstAdapter = std::make_shared<LoraParam>("test_lora1", "/path/to/test_lora1", "master_model");
    LoraParamSPtr secondAdapter = std::make_shared<LoraParam>("test_lora2", "/path/to/test_lora2", "master_model");
    loraManager_->Load(firstAdapter);
    loraManager_->Load(secondAdapter);

    LoraParamSPtr extraAdapter = std::make_shared<LoraParam>("test_lora3", "/path/to/test_lora3", "master_model");
    Status overflowStatus = loraManager_->Load(extraAdapter);

    EXPECT_EQ(overflowStatus.StatusCode(), Error::Code::OK);
    const auto &message = overflowStatus.StatusMsg();
    EXPECT_NE(message.find("none are currently unloading"), std::string::npos);
}
// StartToUnload() on a previously loaded adapter is expected to return OK with
// a message containing "removed successfully".
TEST_F(LoraManagerTest, TestStartToUnloadSuccess)
{
    LoraParamSPtr adapter = std::make_shared<LoraParam>("test_lora", "/path/to/test_lora", "master_model");
    loraManager_->Load(adapter);

    Status unloadStatus = loraManager_->StartToUnload("test_lora");

    EXPECT_EQ(unloadStatus.StatusCode(), Error::Code::OK);
    const auto &message = unloadStatus.StatusMsg();
    EXPECT_NE(message.find("removed successfully"), std::string::npos);
}
// StartToUnload() for a name that was never loaded is expected to return OK
// with a message containing "cannot be found".
TEST_F(LoraManagerTest, TestStartToUnloadNotFound)
{
    Status unloadStatus = loraManager_->StartToUnload("nonexistent_lora");

    EXPECT_EQ(unloadStatus.StatusCode(), Error::Code::OK);
    const auto &message = unloadStatus.StatusMsg();
    EXPECT_NE(message.find("cannot be found"), std::string::npos);
}
// A second StartToUnload() on the same adapter is expected to return OK with a
// message containing "cannot be found".
//
// The set-up steps are asserted explicitly: if the initial Load() or the first
// StartToUnload() did not take effect, the final "cannot be found" check would
// pass for the wrong reason (the adapter was simply never there).
TEST_F(LoraManagerTest, TestStartToUnloadAlreadyUnloading)
{
    LoraParamSPtr adapter = std::make_shared<LoraParam>("test_lora", "/path/to/test_lora", "master_model");

    // Precondition 1: the adapter was really loaded.
    Status loadStatus = loraManager_->Load(adapter);
    ASSERT_EQ(loadStatus.StatusCode(), Error::Code::OK);
    ASSERT_NE(loadStatus.StatusMsg().find("Success"), std::string::npos);

    // Precondition 2: the first unload request was accepted.
    Status firstUnload = loraManager_->StartToUnload("test_lora");
    ASSERT_EQ(firstUnload.StatusCode(), Error::Code::OK);
    ASSERT_NE(firstUnload.StatusMsg().find("removed successfully"), std::string::npos);

    // Behaviour under test: repeating the unload request.
    Status secondUnload = loraManager_->StartToUnload("test_lora");
    EXPECT_EQ(secondUnload.StatusCode(), Error::Code::OK);
    EXPECT_NE(secondUnload.StatusMsg().find("cannot be found"), std::string::npos);
}
// After two adapters are loaded, GetLoadedLoras() is expected to return OK and
// fill the output vector with two entries.
TEST_F(LoraManagerTest, TestGetLoadedLorasSuccess)
{
    LoraParamSPtr lora1 = std::make_shared<LoraParam>("test_lora1", "/path/to/test_lora1", "master_model");
    LoraParamSPtr lora2 = std::make_shared<LoraParam>("test_lora2", "/path/to/test_lora2", "master_model");
    loraManager_->Load(lora1);
    loraManager_->Load(lora2);

    std::vector<LoraParamSPtr> loadedLoras;
    Status status = loraManager_->GetLoadedLoras(loadedLoras);

    // Stop here on a non-OK status: the vector contents are only meaningful on success.
    ASSERT_EQ(status.StatusCode(), Error::Code::OK);
    // NOTE(review): the return type of loaded_.Size() is not visible from this file,
    // so its literal is left as-is.
    EXPECT_EQ(loraManager_->loaded_.Size(), 2);
    // std::vector::size() is unsigned; use an unsigned literal to avoid a
    // signed/unsigned comparison inside EXPECT_EQ.
    EXPECT_EQ(loadedLoras.size(), 2U);
}
// With nothing loaded, GetLoadedLoras() is expected to return OK and leave the
// output vector empty.
TEST_F(LoraManagerTest, TestGetLoadedLorasEmpty)
{
    std::vector<LoraParamSPtr> loadedLoras;
    Status status = loraManager_->GetLoadedLoras(loadedLoras);

    EXPECT_EQ(status.StatusCode(), Error::Code::OK);
    // empty() states the intent directly and avoids comparing size_t with an int literal.
    EXPECT_TRUE(loadedLoras.empty());
}
// After Load() -> StartToUnload() -> TryUnLoadWaiting() on an adapter, the
// manager is expected to report no loaded adapters and to have empty loaded_
// and wait2Unloaded_ containers.
TEST_F(LoraManagerTest, TestTryUnLoadWaitingSuccess)
{
    LoraParamSPtr loraInfo = std::make_shared<LoraParam>("test_lora", "/path/to/test_lora", "master_model");

    // Precondition: the adapter must really have been loaded; otherwise every
    // "nothing is loaded" check below would pass without exercising the unload path.
    Status loadStatus = loraManager_->Load(loraInfo);
    ASSERT_EQ(loadStatus.StatusCode(), Error::Code::OK);
    ASSERT_NE(loadStatus.StatusMsg().find("Success"), std::string::npos);

    loraManager_->StartToUnload("test_lora");
    loraManager_->TryUnLoadWaiting();

    std::vector<LoraParamSPtr> loadedLoras;
    loraManager_->GetLoadedLoras(loadedLoras);

    // empty() avoids comparing size_t with an int literal.
    EXPECT_TRUE(loadedLoras.empty());
    EXPECT_EQ(loraManager_->loaded_.Size(), 0);
    EXPECT_EQ(loraManager_->wait2Unloaded_.Size(), 0);
}
// When IncLoraRef() was called on an adapter before StartToUnload(),
// TryUnLoadWaiting() is expected to leave one entry in both loaded_ and
// wait2Unloaded_.
TEST_F(LoraManagerTest, TestTryUnLoadWaitingReferenceExists)
{
    LoraParamSPtr referencedAdapter =
        std::make_shared<LoraParam>("test_lora", "/path/to/test_lora", "master_model");
    loraManager_->Load(referencedAdapter);

    // Take a reference before requesting the unload.
    loraManager_->IncLoraRef("test_lora");

    loraManager_->StartToUnload("test_lora");
    loraManager_->TryUnLoadWaiting();

    EXPECT_EQ(loraManager_->loaded_.Size(), 1);
    EXPECT_EQ(loraManager_->wait2Unloaded_.Size(), 1);
}
// ValidateLoraId() is expected to return true for the name of a loaded adapter.
TEST_F(LoraManagerTest, TestValidateLoraIdSuccess)
{
    LoraParamSPtr adapter = std::make_shared<LoraParam>("test_lora", "/path/to/test_lora", "master_model");
    loraManager_->Load(adapter);

    bool accepted = loraManager_->ValidateLoraId("test_lora");

    EXPECT_TRUE(accepted);
}
// ValidateLoraId() is expected to return false for a name that was never loaded.
TEST_F(LoraManagerTest, TestValidateLoraIdFailed)
{
    bool accepted = loraManager_->ValidateLoraId("nonexistent_lora");

    EXPECT_FALSE(accepted);
}
// IncLoraRef() followed by DecLoraRef() on a loaded adapter is expected to move
// its entry in loraIdRef_ from 1 back to 0.
//
// NOTE(review): assumes loraIdRef_.Get() returns an optional-like object that
// offers has_value() alongside value() - confirm against the container type.
TEST_F(LoraManagerTest, TestIncDecLoraRefSuccess)
{
    LoraParamSPtr loraInfo = std::make_shared<LoraParam>("test_lora", "/path/to/test_lora", "master_model");
    loraManager_->Load(loraInfo);

    loraManager_->IncLoraRef("test_lora");
    {
        const auto &refAfterInc = loraManager_->loraIdRef_.Get("test_lora");
        // Check presence first so a missing entry yields a clear assertion
        // failure instead of an exception from value().
        ASSERT_TRUE(refAfterInc.has_value());
        EXPECT_EQ(refAfterInc.value(), 1);
    }

    loraManager_->DecLoraRef("test_lora");
    {
        const auto &refAfterDec = loraManager_->loraIdRef_.Get("test_lora");
        ASSERT_TRUE(refAfterDec.has_value());
        EXPECT_EQ(refAfterDec.value(), 0);
    }
}
// IncLoraRef()/DecLoraRef() on a name that was never loaded are expected to
// leave loraIdRef_ without an entry for that name.
TEST_F(LoraManagerTest, TestIncDecLoraRefInvalidId)
{
    // Increment on an unknown name: Count() is expected to stay 0.
    loraManager_->IncLoraRef("nonexistent_lora");
    EXPECT_EQ(loraManager_->loraIdRef_.Count("nonexistent_lora"), 0);

    // Decrement on an unknown name: Count() is expected to stay 0.
    loraManager_->DecLoraRef("nonexistent_lora");
    EXPECT_EQ(loraManager_->loraIdRef_.Count("nonexistent_lora"), 0);
}