* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include <gtest/gtest.h>
#include "lora/loraops_mixin.h"
#include "lora_manager.h"
using namespace mindie_llm;
/**
 * Builds a shared LoraParam for use in the tests below.
 * @param name   adapter name (defaults to "test_lora")
 * @param path   adapter path (defaults to "/path/to/test_lora")
 * @param master master model name (defaults to "master_model")
 * @return a newly allocated LoraParam constructed from (name, path, master)
 *
 * Declared static so the helper has internal linkage: other test translation units
 * linked into the same test binary may define a helper with the same name/signature.
 */
static LoraParamSPtr CreateTestLoraParam(const std::string& name = "test_lora",
    const std::string& path = "/path/to/test_lora",
    const std::string& master = "master_model")
{
    return std::make_shared<LoraParam>(name, path, master);
}
class MockModelExecutor : public IExecutor {
public:
void ExecuteRecoverCommand(RecoverCommandInfo &commandInfo) override {}
bool ExecutorInstanceInit(std::map<std::string, std::string> &config, bool isMultiNodesInfer,
size_t dpIdx = 0) override
{
return true;
}
bool AsyncExecuteModel(ExecuteModelRequestPtr &modelExecRequest,
std::function<void(ModelBatchResultSPtr)> callback = nullptr) override
{
return true;
}
bool AsyncTGCleanup(TGCleanupRequestPtr &TGCleanupRequest) override
{
return true;
}
bool ExecutorParseConfigAndInitGRPC(std::map<std::string, std::string> &configFromManager, bool isMultiNodesInfer,
size_t rankIdx) override
{
return true;
}
bool MasterAndSlaveModelInit(const std::map<std::string, std::string> &pdInfo) override { return true; }
bool SetupPDLink(model_execute_data::PDLinkRequest &pdLinkRequest) override { return true; }
bool QueryPDLinkStatus(model_execute_data::PDLinkStatusRequest &pdLinkStatusRequest) override
{
return true;
}
model_execute_data::PDLinkStatusResponse GetPDLinkStatusResponse() const override
{
return model_execute_data::PDLinkStatusResponse();
}
bool ExecuteKVTransfer(PullKVRequestPtr &pullKVRequest,
std::function<void(PullKVResponseSPtr)> callback = nullptr) override
{
return true;
}
bool ExecutorInstanceFinalize() override { return true; }
uint32_t GetCpuBlockNum() const override { return 1; }
uint32_t GetNpuBlockNum() const override { return 1; }
uint32_t GetLwdCloudNpuBlockNum() const override { return 1; }
uint32_t GetMaxPositionEmbeddings() const override { return 4096; }
ThinkingConfig GetThinkingConfig() const override
{
ThinkingConfig conf;
return conf;
}
bool ExecutLoraRequest(LoraOperationRequest &loraOperationRequest) override
{
return true;
}
bool AsyncEOSCleanup(TGCleanupRequestPtr &TGCleanupRequest) override
{
return true;
}
model_execute_data::LoraOperationResponse GetLoraOperationResponse() const override
{
return model_execute_data::LoraOperationResponse();
}
};
class LoraOpsMixinTest : public testing::Test {
protected:
LoraOpsMixin mixin;
void SetUp() override
{
maxLoras_ = 2;
mockExecutor_ = std::make_shared<MockModelExecutor>();
std::vector<IExecutorSPtr> executors(1, mockExecutor_);
LoraManager::Initialize(executors, maxLoras_);
}
void TearDown() override
{
mockExecutor_.reset();
}
uint32_t maxLoras_;
std::shared_ptr<MockModelExecutor> mockExecutor_;
};
/**
 * Configures one static adapter ("test_lora") on "master_model" and runs
 * InitStaticLoras with dpSize = 1. The test then expects LoraManager instance 0
 * to hold exactly one loaded entry.
 */
TEST_F(LoraOpsMixinTest, InitStaticLoras)
{
    std::map<std::string, std::string> testLoraModules = {
        {"test_lora", "/path/to/test_lora"}
    };
    ModelParam modelParam;
    modelParam.loraModules = testLoraModules;
    modelParam.modelName = "master_model";
    std::vector<ModelParam> modelParamVec = {modelParam};
    size_t dpSize = 1;
    mixin.InitStaticLoras(modelParamVec, dpSize);

    // Guard the dereference: a missing instance should fail this test, not crash the
    // whole test binary with a null-pointer access.
    const auto &manager = LoraManager::GetInstance(0);
    ASSERT_NE(manager, nullptr);
    // Unsigned literal avoids a signed/unsigned comparison (-Wsign-compare) inside EXPECT_EQ.
    EXPECT_EQ(manager->loaded_.Size(), 1U);
}
// LoraLoad with an empty adapter list is expected to return ERROR with a
// status message mentioning "invalid".
TEST_F(LoraOpsMixinTest, TestLoraLoadInvalidLoraInfoCount)
{
    std::vector<LoraParamSPtr> noLoras;
    Status loadStatus = mixin.LoraLoad(noLoras, 1);

    EXPECT_EQ(loadStatus.StatusCode(), Error::Code::ERROR);
    const auto invalidPos = loadStatus.StatusMsg().find("invalid");
    EXPECT_NE(invalidPos, std::string::npos);
}
// LoraLoad with a single default test adapter ("test_lora") and dpSize = 1 is expected to succeed.
TEST_F(LoraOpsMixinTest, TestLoraLoadSuccess)
{
    std::vector<LoraParamSPtr> requested;
    requested.push_back(CreateTestLoraParam());

    size_t dpCount = 1;
    Status loadStatus = mixin.LoraLoad(requested, dpCount);
    EXPECT_TRUE(loadStatus.IsOk());
}
// LoraUnLoad with an empty adapter list is expected to return ERROR with a
// status message mentioning "invalid".
TEST_F(LoraOpsMixinTest, TestLoraUnLoadInvalidLoraInfoCount)
{
    std::vector<LoraParamSPtr> noLoras;
    Status unloadStatus = mixin.LoraUnLoad(noLoras, 1);

    EXPECT_EQ(unloadStatus.StatusCode(), Error::Code::ERROR);
    const auto invalidPos = unloadStatus.StatusMsg().find("invalid");
    EXPECT_NE(invalidPos, std::string::npos);
}
// LoraUnLoad of a single adapter named "unload_test" with dpSize = 1 is expected to succeed.
// NOTE(review): this test never loads "unload_test" first; confirm that success is
// the intended outcome for unloading an adapter that was not loaded here.
TEST_F(LoraOpsMixinTest, TestLoraUnLoadSuccess)
{
    std::vector<LoraParamSPtr> toUnload;
    toUnload.push_back(CreateTestLoraParam("unload_test"));

    size_t dpCount = 1;
    Status unloadStatus = mixin.LoraUnLoad(toUnload, dpCount);
    EXPECT_TRUE(unloadStatus.IsOk());
}
// LoraGetLoaded with dpSize = 1 is expected to succeed; the contents written to
// the output vector are not inspected by this test.
TEST_F(LoraOpsMixinTest, TestLoraGetLoaded)
{
    std::vector<LoraParamSPtr> loadedLoras;
    size_t dpCount = 1;

    Status queryStatus = mixin.LoraGetLoaded(loadedLoras, dpCount);
    EXPECT_TRUE(queryStatus.IsOk());
}