#ifdef UNSAFE_BUFFERS_BUILD
#pragma allow_unsafe_libc_calls
#endif
#include "base/sync_socket.h"
#include <stddef.h>
#include <stdio.h>
#include <array>
#include <memory>
#include <sstream>
#include <string>
#include "base/containers/span.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/memory/raw_ptr.h"
#include "base/run_loop.h"
#include "base/task/single_thread_task_runner.h"
#include "base/threading/thread.h"
#include "base/types/fixed_array.h"
#include "build/build_config.h"
#include "ipc/ipc_test_base.h"
#include "testing/gtest/include/gtest/gtest.h"
#if BUILDFLAG(IS_POSIX) || BUILDFLAG(IS_FUCHSIA)
#include "base/file_descriptor_posix.h"
#endif
// IPC_MESSAGE_IMPL is defined before including the message macros so that the
// IPC_MESSAGE_CONTROL* declarations below presumably expand into full message
// class implementations in this translation unit (confirm in
// ipc/ipc_message_macros.h).
#define IPC_MESSAGE_IMPL
#include "ipc/ipc_message_macros.h"
#include "ipc/ipc_message_start.h"
// Every message declared below is tagged with the TestMsgStart class id.
#define IPC_MESSAGE_START TestMsgStart
// Test process -> SyncSocketServerClient child: hands over the child's end of
// the SyncSocket pair. Windows passes the (duplicated) raw handle; POSIX and
// Fuchsia wrap the descriptor in base::FileDescriptor.
#if BUILDFLAG(IS_WIN)
IPC_MESSAGE_CONTROL1(MsgClassSetHandle, base::SyncSocket::Handle)
#elif BUILDFLAG(IS_POSIX) || BUILDFLAG(IS_FUCHSIA)
IPC_MESSAGE_CONTROL1(MsgClassSetHandle, base::FileDescriptor)
#endif
// Child -> test process: sent after the child has written the string into the
// socket; the payload is the string the test should find there.
IPC_MESSAGE_CONTROL1(MsgClassResponse, std::string)
// Test process -> child: makes the child's run loop quit.
IPC_MESSAGE_CONTROL0(MsgClassShutdown)
namespace {
// Text exchanged through the sockets in every test below.
constexpr char kHelloString[] = "Hello, SyncSocket Client";
// std::size() of the char array, so this counts the terminating NUL as well.
constexpr size_t kHelloStringLength = std::size(kHelloString);
class SyncSocketServerListener : public IPC::Listener {
public:
SyncSocketServerListener() : chan_(nullptr) {}
SyncSocketServerListener(const SyncSocketServerListener&) = delete;
SyncSocketServerListener& operator=(const SyncSocketServerListener&) = delete;
void Init(IPC::Channel* chan) {
chan_ = chan;
}
bool OnMessageReceived(const IPC::Message& msg) override {
if (msg.routing_id() == MSG_ROUTING_CONTROL) {
IPC_BEGIN_MESSAGE_MAP(SyncSocketServerListener, msg)
IPC_MESSAGE_HANDLER(MsgClassSetHandle, OnMsgClassSetHandle)
IPC_MESSAGE_HANDLER(MsgClassShutdown, OnMsgClassShutdown)
IPC_END_MESSAGE_MAP()
}
return true;
}
void set_quit_closure(base::OnceClosure quit_closure) {
quit_closure_ = std::move(quit_closure);
}
private:
#if BUILDFLAG(IS_WIN)
void OnMsgClassSetHandle(const base::SyncSocket::Handle handle) {
SetHandle(handle);
}
#elif BUILDFLAG(IS_POSIX) || BUILDFLAG(IS_FUCHSIA)
void OnMsgClassSetHandle(const base::FileDescriptor& fd_struct) {
SetHandle(fd_struct.fd);
}
#else
# error "What platform?"
#endif
void SetHandle(base::SyncSocket::Handle handle) {
base::SyncSocket sync_socket(handle);
auto bytes_to_send = base::as_byte_span(kHelloString);
EXPECT_EQ(sync_socket.Send(bytes_to_send), bytes_to_send.size());
IPC::Message* msg = new MsgClassResponse(kHelloString);
EXPECT_TRUE(chan_->Send(msg));
}
void OnMsgClassShutdown() { std::move(quit_closure_).Run(); }
raw_ptr<IPC::Channel> chan_;
base::OnceClosure quit_closure_;
};
// Child-process entry point for SanityTest: serves SyncSocketServerListener
// on the IPC channel until MsgClassShutdown makes its run loop quit.
DEFINE_IPC_CHANNEL_MOJO_TEST_CLIENT(SyncSocketServerClient) {
  SyncSocketServerListener server_listener;
  base::RunLoop run_loop;
  server_listener.set_quit_closure(run_loop.QuitWhenIdleClosure());

  Connect(&server_listener);
  server_listener.Init(channel());

  run_loop.Run();
  Close();
}
// Listener used in the test process for SanityTest. On MsgClassResponse it
// reads the bytes the child wrote into |socket_|, checks they match the IPC
// payload, replies with MsgClassShutdown and runs the quit closure.
class SyncSocketClientListener : public IPC::Listener {
 public:
  SyncSocketClientListener() = default;

  SyncSocketClientListener(const SyncSocketClientListener&) = delete;
  SyncSocketClientListener& operator=(const SyncSocketClientListener&) = delete;

  // |socket| is the local end of the pair and |chan| carries MsgClassShutdown
  // back to the child. Both are held as non-owning raw_ptrs.
  void Init(base::SyncSocket* socket, IPC::Channel* chan) {
    socket_ = socket;
    chan_ = chan;
  }

  bool OnMessageReceived(const IPC::Message& msg) override {
    if (msg.routing_id() == MSG_ROUTING_CONTROL) {
      IPC_BEGIN_MESSAGE_MAP(SyncSocketClientListener, msg)
        IPC_MESSAGE_HANDLER(MsgClassResponse, OnMsgClassResponse)
      IPC_END_MESSAGE_MAP()
    }
    return true;
  }

  void set_quit_closure(base::OnceClosure quit_closure) {
    quit_closure_ = std::move(quit_closure);
  }

 private:
  void OnMsgClassResponse(const std::string& str) {
    // SyncSocketServerListener::SetHandle() writes the string, including its
    // NUL, into the socket before sending MsgClassResponse. All of those bytes
    // should therefore already be readable.
    const size_t expected_bytes_to_receive = str.length() + 1;
    EXPECT_EQ(socket_->Peek(), expected_bytes_to_receive);

    base::FixedArray<char> buf(expected_bytes_to_receive);
    const size_t bytes_received =
        socket_->Receive(base::as_writable_byte_span(buf));
    EXPECT_EQ(bytes_received, expected_bytes_to_receive);
    // Compare exactly the bytes that arrived, including the trailing NUL.
    // The FixedArray's chars start out uninitialized, so strcmp() on |buf|
    // after a short read could read garbage or run past the buffer.
    EXPECT_EQ(std::string(buf.data(), bytes_received),
              std::string(str.c_str(), expected_bytes_to_receive));

    // The socket should now be drained.
    EXPECT_EQ(0U, socket_->Peek());

    IPC::Message* msg = new MsgClassShutdown();
    EXPECT_TRUE(chan_->Send(msg));
    std::move(quit_closure_).Run();
  }

  raw_ptr<base::SyncSocket> socket_;
  raw_ptr<IPC::Channel, DanglingUntriaged> chan_;
  base::OnceClosure quit_closure_;
};
using SyncSocketTest = IPCChannelMojoTestBase;

// Passes one end of a SyncSocket pair to the SyncSocketServerClient child over
// IPC. The child writes kHelloString into it, and SyncSocketClientListener
// verifies that the bytes arrive on the other end.
TEST_F(SyncSocketTest, SanityTest) {
  Init("SyncSocketServerClient");

  base::RunLoop loop;
  SyncSocketClientListener listener;
  listener.set_quit_closure(loop.QuitWhenIdleClosure());
  CreateChannel(&listener);

  std::array<base::SyncSocket, 2> pair;
  // Like the other tests in this file, stop here if the pair can't be
  // created, rather than continuing with invalid sockets.
  ASSERT_TRUE(base::SyncSocket::CreatePair(&pair[0], &pair[1]));
  EXPECT_EQ(0U, pair[0].Peek());
  EXPECT_EQ(0U, pair[1].Peek());

  ASSERT_TRUE(ConnectChannel());
  listener.Init(&pair[0], channel());

#if BUILDFLAG(IS_WIN)
  // Value-initialized so a failed DuplicateHandle() doesn't send garbage.
  base::SyncSocket::Handle target_handle{};
  BOOL retval = DuplicateHandle(GetCurrentProcess(), pair[1].handle(),
                                client_process().Handle(), &target_handle,
                                0, FALSE, DUPLICATE_SAME_ACCESS);
  EXPECT_TRUE(retval);
  IPC::Message* msg = new MsgClassSetHandle(target_handle);
#else
  const base::SyncSocket::Handle target_handle = pair[1].handle();
  // The `false` presumably disables auto-close, leaving pair[1] as the owner.
  // pair[1] is closed explicitly below.
  base::FileDescriptor filedesc(target_handle, false);
  IPC::Message* msg = new MsgClassSetHandle(filedesc);
#endif
  EXPECT_TRUE(sender()->Send(msg));

  // Quits once SyncSocketClientListener has handled MsgClassResponse.
  loop.Run();

  pair[0].Close();
  pair[1].Close();
  EXPECT_TRUE(WaitForClientShutdown());
  DestroyChannel();
}
// Worker-thread helper for the CancelableSyncSocket tests.
// First sends kHelloString on |socket| so the main thread knows the task has
// started. Then calls Receive() into |buffer| and stores the returned byte
// count in |*received|. The callers expect that Receive() blocks until data
// arrives or the socket is shut down.
// (Not `static`: the enclosing anonymous namespace already gives it internal
// linkage.)
void BlockingRead(base::SyncSocket* socket,
                  base::span<uint8_t> buffer,
                  size_t* received) {
  socket->Send(base::as_byte_span(kHelloString));
  *received = socket->Receive(buffer);
}
// Checks that Shutdown() on a CancelableSyncSocket releases a Receive() that
// another thread is waiting in, and that the Receive() then reports 0 bytes.
TEST_F(SyncSocketTest, DisconnectTest) {
  std::array<base::CancelableSyncSocket, 2> pair;
  ASSERT_TRUE(base::CancelableSyncSocket::CreatePair(&pair[0], &pair[1]));
  base::Thread worker("BlockingThread");
  worker.Start();

  // The worker says hello on pair[0] and then reads from pair[0].
  char buf[0xff] = {};
  size_t received = 1U;  // Non-zero sentinel; the worker must overwrite it.
  worker.task_runner()->PostTask(
      FROM_HERE, base::BindOnce(&BlockingRead, &pair[0],
                                base::as_writable_byte_span(buf), &received));

  // Wait for the worker's hello so we know its task is running.
  char hello[kHelloStringLength] = {};
  EXPECT_EQ(kHelloStringLength,
            pair[1].Receive(base::as_writable_byte_span(hello)));
  EXPECT_STREQ(kHelloString, hello);

  // Give the worker a chance to enter Receive() before cancelling it.
  base::PlatformThread::YieldCurrentThread();
  pair[0].Shutdown();

  // |received| is only read after Stop() has shut the worker thread down.
  worker.Stop();
  EXPECT_EQ(0U, received);
}
// Checks that a Receive() waiting on a worker thread completes with the full
// message once the peer sends it.
TEST_F(SyncSocketTest, BlockingReceiveTest) {
  std::array<base::CancelableSyncSocket, 2> pair;
  ASSERT_TRUE(base::CancelableSyncSocket::CreatePair(&pair[0], &pair[1]));
  base::Thread worker("BlockingThread");
  worker.Start();

  // Zero-initialized so |buf| stays NUL-terminated even after a short read.
  char buf[kHelloStringLength] = {};
  size_t received = 1U;  // Sentinel; the worker overwrites it.
  worker.task_runner()->PostTask(
      FROM_HERE, base::BindOnce(&BlockingRead, &pair[0],
                                base::as_writable_byte_span(buf), &received));

  // Wait for the worker's hello so we know its task is running.
  char hello[kHelloStringLength] = {};
  EXPECT_EQ(kHelloStringLength,
            pair[1].Receive(base::as_writable_byte_span(hello)));
  EXPECT_STREQ(kHelloString, hello);

  // Give the worker a chance to enter Receive(), then satisfy it.
  base::PlatformThread::YieldCurrentThread();
  const auto bytes_to_send = base::as_byte_span(kHelloString);
  EXPECT_EQ(bytes_to_send.size(), pair[1].Send(bytes_to_send));

  // |buf| and |received| are only read after Stop() has shut the worker down.
  worker.Stop();
  EXPECT_STREQ(kHelloString, buf);
  EXPECT_EQ(received, bytes_to_send.size());
}
// Fills pair[0]'s outgoing buffer until Send() returns 0, then checks three
// things:
//   - one more Send() is refused without changing what pair[1] can Peek();
//   - after draining one message, a full Send() succeeds again.
// The test relies on CancelableSyncSocket::Send() returning 0 rather than
// blocking when the buffer is full.
TEST_F(SyncSocketTest, NonBlockingWriteTest) {
  std::array<base::CancelableSyncSocket, 2> pair;
  ASSERT_TRUE(base::CancelableSyncSocket::CreatePair(&pair[0], &pair[1]));

  const auto payload = base::as_byte_span(kHelloString);
  for (;;) {
    if (pair[0].Send(payload) == 0) {
      break;
    }
  }

  const size_t buffered = pair[1].Peek();
  EXPECT_NE(buffered, 0U);
  EXPECT_EQ(pair[0].Send(payload), 0U);
  EXPECT_EQ(buffered, pair[1].Peek());

  // Free up room for exactly one more message.
  char drained[kHelloStringLength] = {};
  pair[1].Receive(base::as_writable_byte_span(drained));
  EXPECT_EQ(pair[0].Send(payload), payload.size());
}
}