* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#define private public
#define protected public
#include "core/DomainSocket.h"
#undef protected
#undef private
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <thread>
#include <unistd.h>
#include <vector>
#include "mockcpp/mockcpp.hpp"
/**
 * Keeps reading from client `clientId` of `server` until at least `message.size()`
 * characters have been collected, then asserts that the collected text equals `message`.
 *
 * A read that reports failure is simply retried, so this helper spins until enough
 * data has arrived.
 */
void AssertReadMessage(DomainSocketServer &server, std::size_t clientId, std::string const &message)
{
    constexpr std::size_t chunkLimit = 1024ULL;
    std::string collected;
    std::string chunk;
    size_t bytesRead = 0;
    for (;;) {
        if (collected.size() >= message.size()) {
            break;
        }
        if (!server.Read(clientId, chunk, chunkLimit, bytesRead)) {
            continue;
        }
        collected.append(chunk);
    }
    ASSERT_EQ(collected, message);
}
/**
 * Sends `message` to client `clientId` of `server`, retrying until every character has
 * been accepted, then asserts that the number of bytes sent equals `message.size()`.
 *
 * Each attempt transmits only the part of the message that has not been sent yet, so a
 * short write is continued from where it stopped instead of restarting at the beginning.
 */
void AssertWriteMessage(DomainSocketServer &server, std::size_t clientId, std::string const &message)
{
    std::size_t totalSent = 0;
    while (totalSent < message.size()) {
        size_t sentBytes = 0;
        // Only the unsent tail goes out: resending the whole message after a short write
        // would duplicate data for the peer and push totalSent past message.size().
        std::string const remaining = message.substr(totalSent);
        if (server.Write(clientId, remaining, sentBytes) && sentBytes > 0) {
            totalSent += sentBytes;
        }
    }
    ASSERT_EQ(totalSent, message.size());
}
/**
 * Keeps reading from `client` until at least `message.size()` characters have been
 * collected, then asserts that the collected text equals `message`.
 *
 * A read that reports failure is simply retried, so this helper spins until enough
 * data has arrived.
 */
void AssertReadMessage(DomainSocketClient &client, std::string const &message)
{
    constexpr std::size_t chunkLimit = 1024ULL;
    std::string chunk;
    std::string collected;
    size_t bytesRead = 0;
    const std::size_t wanted = message.size();
    while (collected.size() < wanted) {
        if (!client.Read(chunk, chunkLimit, bytesRead)) {
            continue;
        }
        collected += chunk;
    }
    ASSERT_EQ(collected, message);
}
/**
 * Writes `message` through `client` with a single Write call and asserts that the call
 * succeeded and reported exactly `message.size()` bytes written.
 */
void AssertWriteMessage(DomainSocketClient &client, std::string const &message)
{
    size_t bytesWritten = 0;
    ASSERT_TRUE(client.Write(message, bytesWritten));
    const std::size_t expectedBytes = message.size();
    ASSERT_EQ(bytesWritten, expectedBytes);
}
/**
 * Blocks until `client.Connect()` succeeds, pausing 100 ms after every failed attempt.
 */
void TryConnect(DomainSocketClient &client)
{
    const std::chrono::milliseconds retryInterval{100};
    for (;;) {
        if (client.Connect()) {
            return;
        }
        std::this_thread::sleep_for(retryInterval);
    }
}
// CreateSocket() must report failure when the underlying socket() call returns -1.
TEST(DomainSocket, socket_call_failed_expect_create_socket_return_false)
{
    MOCKER(&socket).stubs().will(returnValue(-1));
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    DomainSocket domainSocket(socketPath);
    // EXPECT_* instead of ASSERT_*: a failed ASSERT_* returns immediately and would skip
    // GlobalMockObject::verify(), leaving the socket() stub active for later tests.
    EXPECT_FALSE(domainSocket.CreateSocket());
    GlobalMockObject::verify();
}
// CreateSocket() must report failure when setsockopt() returns -1.
TEST(DomainSocket, set_socket_options_failed_expect_create_socket_return_false)
{
    MOCKER(&setsockopt).stubs().will(returnValue(-1));
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    DomainSocket domainSocket(socketPath);
    // EXPECT_* instead of ASSERT_*: a failed ASSERT_* returns immediately and would skip
    // GlobalMockObject::verify(), leaving the setsockopt() stub active for later tests.
    EXPECT_FALSE(domainSocket.CreateSocket());
    GlobalMockObject::verify();
}
// ListenAndBind() must report failure when the underlying socket() call returns -1.
TEST(DomainSocket, create_socket_failed_expect_listen_and_bind_return_false)
{
    MOCKER(&socket).stubs().will(returnValue(-1));
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    DomainSocketServer server(socketPath, 1);
    // EXPECT_* instead of ASSERT_*: a failed ASSERT_* returns immediately and would skip
    // GlobalMockObject::verify(), leaving the socket() stub active for later tests.
    EXPECT_FALSE(server.ListenAndBind());
    GlobalMockObject::verify();
}
// ListenAndBind() must report failure when bind() returns -1.
TEST(DomainSocket, bind_failed_expect_listen_and_bind_return_false)
{
    MOCKER(&bind).stubs().will(returnValue(-1));
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    DomainSocketServer server(socketPath, 1);
    // EXPECT_* instead of ASSERT_*: a failed ASSERT_* returns immediately and would skip
    // GlobalMockObject::verify(), leaving the bind() stub active for later tests.
    EXPECT_FALSE(server.ListenAndBind());
    GlobalMockObject::verify();
}
// ListenAndBind() must report failure when listen() returns -1.
TEST(DomainSocket, listen_failed_expect_listen_and_bind_return_false)
{
    MOCKER(&listen).stubs().will(returnValue(-1));
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    DomainSocketServer server(socketPath, 1);
    // EXPECT_* instead of ASSERT_*: a failed ASSERT_* returns immediately and would skip
    // GlobalMockObject::verify(), leaving the listen() stub active for later tests.
    EXPECT_FALSE(server.ListenAndBind());
    GlobalMockObject::verify();
}
// Accept() must report failure when the underlying accept() call returns -1.
TEST(DomainSocket, accept_failed_expect_accept_return_false)
{
    MOCKER(&accept).stubs().will(returnValue(-1));
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    DomainSocketServer server(socketPath, 1);
    DomainSocketServer::ClientId clientId;
    // EXPECT_* instead of ASSERT_*: a failed ASSERT_* returns immediately and would skip
    // GlobalMockObject::verify(), leaving the accept() stub active for later tests.
    EXPECT_FALSE(server.Accept(clientId));
    GlobalMockObject::verify();
}
// Accept() must report failure when accept() yields a descriptor but setsockopt() returns -1.
TEST(DomainSocket, set_sock_opt_failed_expect_accept_return_false)
{
    MOCKER(&accept).stubs().will(returnValue(1000));
    MOCKER(&setsockopt).stubs().will(returnValue(-1));
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    DomainSocketServer server(socketPath, 1);
    DomainSocketServer::ClientId clientId;
    // EXPECT_* instead of ASSERT_*: a failed ASSERT_* returns immediately and would skip
    // GlobalMockObject::verify(), leaving both stubs active for later tests.
    EXPECT_FALSE(server.Accept(clientId));
    GlobalMockObject::verify();
}
// On a server constructed with 1 as its second argument (the client limit, judging by the
// test name), the first Accept() must succeed and the second one must be refused.
TEST(DomainSocket, accept_exceed_max_client_num_expect_accept_return_false)
{
    uid_t uid{0};
    gid_t gid{0};  // getgid() returns gid_t, so the stubbed value uses that type
    MOCKER(&accept).stubs().will(returnValue(1000));
    MOCKER(&setsockopt).stubs().will(returnValue(0));
    MOCKER(&getsockopt).stubs().will(returnValue(0));
    MOCKER(&getuid).stubs().will(returnValue(uid));
    MOCKER(&getgid).stubs().will(returnValue(gid));
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    DomainSocketServer server(socketPath, 1);
    DomainSocketServer::ClientId clientId;
    // EXPECT_* instead of ASSERT_*: a failed ASSERT_* returns immediately and would skip
    // GlobalMockObject::verify(), leaving all five stubs active for later tests.
    const bool firstAccepted = server.Accept(clientId);
    EXPECT_TRUE(firstAccepted);
    if (firstAccepted) {
        // The over-limit check is only meaningful once one client has been accepted.
        EXPECT_FALSE(server.Accept(clientId));
    }
    GlobalMockObject::verify();
}
// Connect() must report failure when the underlying socket() call returns -1.
TEST(DomainSocket, create_socket_failed_expect_connect_return_false)
{
    MOCKER(&socket).stubs().will(returnValue(-1));
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    DomainSocketClient client(socketPath);
    // EXPECT_* instead of ASSERT_*: a failed ASSERT_* returns immediately and would skip
    // GlobalMockObject::verify(), leaving the socket() stub active for later tests.
    EXPECT_FALSE(client.Connect());
    GlobalMockObject::verify();
}
// Connect() must report failure when the underlying connect() call returns -1.
TEST(DomainSocket, connect_failed_expect_connect_return_false)
{
    MOCKER(&connect).stubs().will(returnValue(-1));
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    DomainSocketClient client(socketPath);
    // EXPECT_* instead of ASSERT_*: a failed ASSERT_* returns immediately and would skip
    // GlobalMockObject::verify(), leaving the connect() stub active for later tests.
    EXPECT_FALSE(client.Connect());
    GlobalMockObject::verify();
}
// Connect() must report success when the underlying connect() call returns 0.
TEST(DomainSocket, connect_success_expect_connect_return_true)
{
    MOCKER(&connect).stubs().will(returnValue(0));
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    DomainSocketClient client(socketPath);
    // EXPECT_* instead of ASSERT_*: a failed ASSERT_* returns immediately and would skip
    // GlobalMockObject::verify(), leaving the connect() stub active for later tests.
    EXPECT_TRUE(client.Connect());
    GlobalMockObject::verify();
}
// A client in a forked child sends one message; the server in the parent must receive it intact.
TEST(DomainSocket, send_message_from_client_to_server_expect_success)
{
    std::string message = "hello world";
    pid_t pid = fork();
    ASSERT_NE(pid, -1);
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    if (pid == 0) {
        // Runs when the child leaves this scope, after the client below has been destroyed.
        // _exit keeps the child from returning into the test runner, and the exit status
        // tells the parent whether any assertion failed in the child.
        auto exitChild = [](std::nullptr_t&) { _exit(::testing::Test::HasFailure() ? 1 : 0); };
        std::shared_ptr<void> defer(nullptr, exitChild);
        DomainSocketClient client(socketPath);
        TryConnect(client);
        AssertWriteMessage(client, message);
    } else {
        DomainSocketServer server(socketPath, 1);
        DomainSocketServer::ClientId id;
        ASSERT_TRUE(server.ListenAndBind());
        ASSERT_TRUE(server.Accept(id));
        ASSERT_EQ(id, 0ULL);
        AssertReadMessage(server, 0, message);
        // Reap the child so it does not linger as a zombie, and fail if it reported a failure.
        int childStatus = 0;
        ASSERT_EQ(waitpid(pid, &childStatus, 0), pid);
        ASSERT_TRUE(WIFEXITED(childStatus));
        EXPECT_EQ(WEXITSTATUS(childStatus), 0);
    }
}
// The server in the parent sends one message; a client in a forked child must receive it intact.
TEST(DomainSocket, send_message_from_server_to_client_expect_success)
{
    std::string message = "hello world";
    pid_t pid = fork();
    ASSERT_NE(pid, -1);
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    if (pid == 0) {
        // Runs when the child leaves this scope, after the client below has been destroyed.
        // _exit keeps the child from returning into the test runner, and the exit status
        // tells the parent whether any assertion failed in the child.
        auto exitChild = [](std::nullptr_t&) { _exit(::testing::Test::HasFailure() ? 1 : 0); };
        std::shared_ptr<void> defer(nullptr, exitChild);
        DomainSocketClient client(socketPath);
        TryConnect(client);
        AssertReadMessage(client, message);
    } else {
        DomainSocketServer server(socketPath, 1);
        DomainSocketServer::ClientId id;
        ASSERT_TRUE(server.ListenAndBind());
        ASSERT_TRUE(server.Accept(id));
        ASSERT_EQ(id, 0ULL);
        AssertWriteMessage(server, 0, message);
        // Reap the child while the server is still alive, so it does not linger as a
        // zombie, and fail if it reported a failure.
        int childStatus = 0;
        ASSERT_EQ(waitpid(pid, &childStatus, 0), pid);
        ASSERT_TRUE(WIFEXITED(childStatus));
        EXPECT_EQ(WEXITSTATUS(childStatus), 0);
    }
}
// Server (parent) and client (forked child) alternate directions: two server writes, one
// client write, then a final server write. Each side must receive exactly what was sent.
TEST(DomainSocket, send_interleaving_message_expect_success)
{
    std::vector<std::string> messages = {"first", "second", "third", "forth"};
    pid_t pid = fork();
    ASSERT_NE(pid, -1);
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    if (pid == 0) {
        {
            // Inner scope so the client is destroyed before the child terminates.
            DomainSocketClient client(socketPath);
            TryConnect(client);
            // The two server writes are read back as one concatenated stream.
            AssertReadMessage(client, messages[0] + messages[1]);
            AssertWriteMessage(client, messages[2]);
            AssertReadMessage(client, messages[3]);
        }
        // _exit keeps the child from returning into the test runner; the exit status tells
        // the parent whether any assertion failed in the child.
        _exit(::testing::Test::HasFailure() ? 1 : 0);
    } else {
        DomainSocketServer server(socketPath, 1);
        DomainSocketServer::ClientId id;
        ASSERT_TRUE(server.ListenAndBind());
        ASSERT_TRUE(server.Accept(id));
        ASSERT_EQ(id, 0ULL);
        AssertWriteMessage(server, 0, messages[0]);
        AssertWriteMessage(server, 0, messages[1]);
        AssertReadMessage(server, 0, messages[2]);
        AssertWriteMessage(server, 0, messages[3]);
        // Reap the child while the server is still alive, so it does not linger as a
        // zombie, and fail if it reported a failure.
        int childStatus = 0;
        ASSERT_EQ(waitpid(pid, &childStatus, 0), pid);
        ASSERT_TRUE(WIFEXITED(childStatus));
        EXPECT_EQ(WEXITSTATUS(childStatus), 0);
    }
}
// Same interleaved exchange as the sibling test, with the roles swapped: the server runs in
// the forked child and the client runs in the parent.
TEST(DomainSocket, send_messages_with_client_in_parent_and_server_in_child_expect_success)
{
    std::vector<std::string> messages = {"first", "second", "third", "forth"};
    pid_t pid = fork();
    ASSERT_NE(pid, -1);
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    if (pid == 0) {
        // Runs when the child leaves this scope (including an early return from a failed
        // ASSERT_*), after the server below has been destroyed. _exit keeps the child from
        // returning into the test runner, and the exit status tells the parent whether any
        // assertion failed in the child.
        auto exitChild = [](std::nullptr_t&) { _exit(::testing::Test::HasFailure() ? 1 : 0); };
        std::shared_ptr<void> defer(nullptr, exitChild);
        DomainSocketServer server(socketPath, 1);
        DomainSocketServer::ClientId id;
        ASSERT_TRUE(server.ListenAndBind());
        ASSERT_TRUE(server.Accept(id));
        ASSERT_EQ(id, 0ULL);
        AssertWriteMessage(server, 0, messages[0]);
        AssertWriteMessage(server, 0, messages[1]);
        AssertReadMessage(server, 0, messages[2]);
        AssertWriteMessage(server, 0, messages[3]);
    } else {
        DomainSocketClient client(socketPath);
        TryConnect(client);
        // The two server writes are read back as one concatenated stream.
        AssertReadMessage(client, messages[0] + messages[1]);
        AssertWriteMessage(client, messages[2]);
        AssertReadMessage(client, messages[3]);
        // Reap the child so it does not linger as a zombie, and fail if it reported a failure.
        int childStatus = 0;
        ASSERT_EQ(waitpid(pid, &childStatus, 0), pid);
        ASSERT_TRUE(WIFEXITED(childStatus));
        EXPECT_EQ(WEXITSTATUS(childStatus), 0);
    }
}
// Four clients in a forked child connect to one server in the parent. The server must hand
// out ids 0..3 in accept order and exchange one message with each client.
TEST(DomainSocket, multiple_clients_connect_to_server_and_send_messages_expect_success)
{
    std::vector<std::string> messages = {"first", "second", "third", "forth"};
    pid_t pid = fork();
    ASSERT_NE(pid, -1);
    std::string socketPath = "/tmp/msop_connect.1000.sock";
    if (pid == 0) {
        // Runs when the child leaves this scope, after the clients below have been destroyed.
        // _exit keeps the child from returning into the test runner, and the exit status
        // tells the parent whether any assertion failed in the child.
        auto exitChild = [](std::nullptr_t&) { _exit(::testing::Test::HasFailure() ? 1 : 0); };
        std::shared_ptr<void> defer(nullptr, exitChild);
        DomainSocketClient client0(socketPath);
        DomainSocketClient client1(socketPath);
        DomainSocketClient client2(socketPath);
        DomainSocketClient client3(socketPath);
        TryConnect(client0);
        TryConnect(client1);
        TryConnect(client2);
        TryConnect(client3);
        AssertReadMessage(client0, messages[0]);
        AssertReadMessage(client1, messages[1]);
        AssertWriteMessage(client2, messages[2]);
        AssertReadMessage(client3, messages[3]);
    } else {
        DomainSocketServer server(socketPath, 4);
        DomainSocketServer::ClientId id;
        ASSERT_TRUE(server.ListenAndBind());
        for (std::size_t clientNum = 0; clientNum < 4; ++clientNum) {
            ASSERT_TRUE(server.Accept(id));
            ASSERT_EQ(id, clientNum);
        }
        ASSERT_EQ(server.GetClientNum(), 4ULL);
        AssertWriteMessage(server, 0, messages[0]);
        AssertWriteMessage(server, 1, messages[1]);
        AssertReadMessage(server, 2, messages[2]);
        AssertWriteMessage(server, 3, messages[3]);
        // Reap the child while the server is still alive, so it does not linger as a
        // zombie, and fail if it reported a failure.
        int childStatus = 0;
        ASSERT_EQ(waitpid(pid, &childStatus, 0), pid);
        ASSERT_TRUE(WIFEXITED(childStatus));
        EXPECT_EQ(WEXITSTATUS(childStatus), 0);
    }
}