* Copyright (c) Huawei Technologies Co., Ltd. 2025-2026. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include <gtest/gtest.h>
#include "mockcpp/mockcpp.hpp"
#include <filesystem>
#include <iostream>
#include <string>
#include "dmi_msg_sender.h"
#include <grpcpp/channel.h>
#include <grpcpp/create_channel.h>
#include "health_checker.h"
#include "config/config_manager.h"
#include "env_util.h"
using namespace prefillAndDecodeCommunication;
namespace grpc {
/**
 * Test-only equality for grpc::Status: two statuses are equal when both their
 * status code and their error message match.
 */
bool operator==(const Status& lhs, const Status& rhs)
{
    if (lhs.error_code() != rhs.error_code()) {
        return false;
    }
    return lhs.error_message() == rhs.error_message();
}

/**
 * Test-only inequality for grpc::Status; the exact negation of operator== above.
 */
bool operator!=(const Status& lhs, const Status& rhs)
{
    return !operator==(lhs, rhs);
}
}
namespace mindie_llm {
bool operator==(const ServerConfig& lhs, const ServerConfig& rhs)
{
return lhs.inferMode == rhs.inferMode &&
lhs.npuUsageThreshold == rhs.npuUsageThreshold;
}
const ServerConfig& GetServerConfig();
class DmiMsgSenderTest : public testing::Test {
protected:
void SetUp() override
{
EnvUtil::GetInstance().SetEnvVar(
"RANK_TABLE_FILE",
GetParentDirectory() + "/../../config_manager/conf/ranktable.json");
EnvUtil::GetInstance().SetEnvVar("MIES_CONTAINER_IP", "127.0.0.1");
EnvUtil::GetInstance().SetEnvVar("HOST_IP", "127.0.0.1");
EnvUtil::GetInstance().SetEnvVar("MINDIE_CHECK_INPUTFILES_PERMISSION", "1");
ASSERT_TRUE(ConfigManager::CreateInstance(
GetParentDirectory() + "/../../config_manager/conf/config_grpc.json"));
ServerConfig serverConfig;
serverConfig.inferMode = "standard";
serverConfig.npuUsageThreshold = 0;
serverConfig.distDPServerEnabled = false;
MOCKER_CPP(GetServerConfig, const ServerConfig& (*)())
.stubs()
.will(returnValue(serverConfig));
}
void TearDown() override
{
EnvUtil::GetInstance().ClearEnvVar("RANK_TABLE_FILE");
EnvUtil::GetInstance().ClearEnvVar("MINDIE_CHECK_INPUTFILES_PERMISSION");
EnvUtil::GetInstance().ClearEnvVar("MIES_CONTAINER_IP");
EnvUtil::GetInstance().ClearEnvVar("HOST_IP");
GlobalMockObject::verify();
}
static std::string GetParentDirectory()
{
try {
return std::filesystem::current_path().string();
} catch (const std::filesystem::filesystem_error& e) {
std::cerr << "Error getting current directory: " << e.what() << std::endl;
return "";
}
}
};
TEST_F(DmiMsgSenderTest, DecodeRequestSender_InitWithoutTlsSuccess)
{
DecodeRequestSender sender("127.0.0.1:50051", "127.0.0.1:50052", false, nullptr);
EXPECT_TRUE(sender.Init());
}
TEST_F(DmiMsgSenderTest, DecodeRequestSender_InitWithTlsSuccess)
{
auto tlsOptions = std::make_unique<grpc::experimental::TlsChannelCredentialsOptions>();
DecodeRequestSender sender("127.0.0.1:50051", "127.0.0.1:50052", true, std::move(tlsOptions));
EXPECT_TRUE(sender.Init());
}
TEST_F(DmiMsgSenderTest, DecodeRequestSender_InitWithTlsButNoOptions)
{
DecodeRequestSender sender("127.0.0.1:50051", "127.0.0.1:50052", true, nullptr);
EXPECT_FALSE(sender.Init());
}
TEST_F(DmiMsgSenderTest, DecodeRequestSender_CreateStubSuccess)
{
auto channel = grpc::CreateChannel("127.0.0.1:50051", grpc::InsecureChannelCredentials());
DecodeRequestSender sender("127.0.0.1:50051", "127.0.0.1:50052", false, nullptr);
sender.CreateStub(channel);
EXPECT_EQ(sender.receiverAddr_, "127.0.0.1:50051");
EXPECT_EQ(sender.localAddr_, "127.0.0.1:50052");
}
TEST_F(DmiMsgSenderTest, DecodeRequestSender_ResponseIsNotValiddecodeParameters)
{
auto channel = grpc::CreateChannel("127.0.0.1:50051", grpc::InsecureChannelCredentials());
DecodeRequestSender sender("127.0.0.1:50051", "127.0.0.1:50052", false, nullptr);
sender.CreateStub(channel);
MOCKER_CPP(&DecodeService::Stub::DecodeRequestChannel, grpc::Status (*)(
grpc::ClientContext*, const DecodeParameters&, DecodeRequestResponse*))
.stubs().will(returnValue(grpc::Status::OK));
std::string errMsg;
DecodeParameters request;
std::string reqId = "req-";
EXPECT_FALSE(sender.SendDecodeRequestMsg(request, reqId, errMsg));
EXPECT_EQ(errMsg, "");
}
TEST_F(DmiMsgSenderTest, DecodeRequestSender_StubNull)
{
DecodeRequestSender sender("127.0.0.1:50051", "127.0.0.1:50052", false, nullptr);
MOCKER_CPP(&DecodeService::Stub::DecodeRequestChannel, grpc::Status (*)(
grpc::ClientContext*, const DecodeParameters&, DecodeRequestResponse*))
.stubs().will(returnValue(grpc::Status::OK));
std::string errMsg;
DecodeParameters request;
std::string reqId = "req-";
EXPECT_FALSE(sender.SendDecodeRequestMsg(request, reqId, errMsg));
EXPECT_EQ(errMsg, "The stub_ is nullptr");
}
TEST_F(DmiMsgSenderTest, DecodeRequestSender_DecodeRequestChannelError)
{
auto channel = grpc::CreateChannel("127.0.0.1:50051", grpc::InsecureChannelCredentials());
DecodeRequestSender sender("127.0.0.1:50051", "127.0.0.1:50052", false, nullptr);
sender.CreateStub(channel);
MOCKER_CPP(&DecodeService::Stub::DecodeRequestChannel, grpc::Status (*)(
grpc::ClientContext*, const DecodeParameters&, DecodeRequestResponse*))
.stubs().will(returnValue(grpc::Status::CANCELLED));
std::string errMsg;
DecodeParameters request;
std::string reqId = "req-";
EXPECT_FALSE(sender.SendDecodeRequestMsg(request, reqId, errMsg));
EXPECT_EQ(errMsg, "Failed to send decode request msg because[1] receiverAddr is 127.0.0.1:50051. RequestId is req-");
}
TEST_F(DmiMsgSenderTest, KvReleaseSender_CreateStubSuccess)
{
auto channel = grpc::CreateChannel("127.0.0.1:50051", grpc::InsecureChannelCredentials());
KvReleaseSender sender("127.0.0.1:50051", "127.0.0.1:50052", false, nullptr);
sender.CreateStub(channel);
EXPECT_EQ(sender.receiverAddr_, "127.0.0.1:50051");
EXPECT_EQ(sender.localAddr_, "127.0.0.1:50052");
}
TEST_F(DmiMsgSenderTest, KvReleaseSender_Success)
{
auto channel = grpc::CreateChannel("127.0.0.1:50051", grpc::InsecureChannelCredentials());
KvReleaseSender sender("127.0.0.1:50051", "127.0.0.1:50052", false, nullptr);
sender.CreateStub(channel);
MOCKER_CPP(&PrefillService::Stub::ReleaseKVCacheChannel, grpc::Status (*)(
grpc::ClientContext*, const RequestId&, google::protobuf::Empty*))
.stubs().will(returnValue(grpc::Status::OK));
prefillAndDecodeCommunication::RequestId request;
EXPECT_TRUE(sender.SendKvReleaseMsg(request));
}
TEST_F(DmiMsgSenderTest, KvReleaseSender_StubNull)
{
KvReleaseSender sender("127.0.0.1:50051", "127.0.0.1:50052", false, nullptr);
MOCKER_CPP(&PrefillService::Stub::ReleaseKVCacheChannel, grpc::Status (*)(
grpc::ClientContext*, const RequestId&, google::protobuf::Empty*))
.stubs().will(returnValue(grpc::Status::OK));
prefillAndDecodeCommunication::RequestId request;
EXPECT_FALSE(sender.SendKvReleaseMsg(request));
}
TEST_F(DmiMsgSenderTest, KvReleaseSender_DecodeRequestChannelError)
{
auto channel = grpc::CreateChannel("127.0.0.1:50051", grpc::InsecureChannelCredentials());
KvReleaseSender sender("127.0.0.1:50051", "127.0.0.1:50052", false, nullptr);
sender.CreateStub(channel);
MOCKER_CPP(&PrefillService::Stub::ReleaseKVCacheChannel, grpc::Status (*)(
grpc::ClientContext*, const RequestId&, google::protobuf::Empty*))
.stubs().will(returnValue(grpc::Status::CANCELLED));
prefillAndDecodeCommunication::RequestId request;
EXPECT_FALSE(sender.SendKvReleaseMsg(request));
}
}