#include "mojo/public/cpp/bindings/direct_receiver.h"
#include <memory>
#include <utility>
#include "base/barrier_closure.h"
#include "base/functional/bind.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/raw_ref.h"
#include "base/message_loop/message_pump_type.h"
#include "base/synchronization/waitable_event.h"
#include "base/test/bind.h"
#include "base/test/task_environment.h"
#include "base/threading/sequence_bound.h"
#include "base/threading/thread.h"
#include "mojo/core/embedder/embedder.h"
#include "mojo/core/ipcz_api.h"
#include "mojo/core/test/mojo_test_base.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/tests/direct_receiver_unittest.test-mojom.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/ipcz/src/test/test_base.h"
namespace mojo::test::direct_receiver_unittest {
namespace {
// Runs `closure` on `thread` and blocks the caller until it has returned.
// Because the caller is blocked for the whole run, `closure` may safely
// capture the caller's locals by reference.
template <typename Fn>
void RunOn(base::Thread& thread, Fn closure) {
  base::WaitableEvent finished;
  thread.task_runner()->PostTask(
      FROM_HERE, base::BindLambdaForTesting([&closure, &finished] {
        closure();
        finished.Signal();
      }));
  finished.Wait();
}
// ipcz serializer callback for the dummy application object. The object
// carries no state, so it always reports zero bytes and zero handles and
// succeeds; the remaining inputs are intentionally ignored.
IpczResult DummyBoxSerializer(uintptr_t /*object*/,
                              uint32_t /*flags*/,
                              const void* /*options*/,
                              volatile void* /*data*/,
                              size_t* num_bytes,
                              IpczHandle* /*handles*/,
                              size_t* num_handles) {
  *num_handles = 0;
  *num_bytes = 0;
  return IPCZ_RESULT_OK;
}
// ipcz destructor callback for the dummy application object. The object owns
// nothing, so there is nothing to release.
void DummyBoxDestructor(uintptr_t /*object*/,
                        uint32_t /*flags*/,
                        const void* /*options*/) {}
// Boxes a stateless application object (see DummyBoxSerializer /
// DummyBoxDestructor) on `node` and returns the box handle. ServiceImpl uses
// this as a stand-in portal when simulating a failure.
ScopedHandle CreateDummyHandle(IpczHandle node) {
  const IpczBoxContents dummy_contents{
      .size = sizeof(IpczBoxContents),
      .type = IPCZ_BOX_TYPE_APPLICATION_OBJECT,
      .object = {.application_object = 0},
      .serializer = &DummyBoxSerializer,
      .destructor = &DummyBoxDestructor,
  };
  // Initialize the out-param. If Box() fails without writing it, the checks
  // below and the returned handle then see a well-defined invalid handle
  // rather than an indeterminate value.
  IpczHandle dummy_handle = IPCZ_INVALID_HANDLE;
  const IpczResult box_result = core::GetIpczAPI().Box(
      node, &dummy_contents, IPCZ_NO_FLAGS, nullptr, &dummy_handle);
  EXPECT_EQ(box_result, IPCZ_RESULT_OK);
  EXPECT_NE(dummy_handle, IPCZ_INVALID_HANDLE);
  return ScopedHandle{Handle{dummy_handle}};
}
}
// Test fixture for DirectReceiver. It inherits helpers used throughout this
// file (e.g. WaitForDirectRemoteLink(), RunTestClient()) from its two bases
// and owns a TaskEnvironment for the main test thread. Every test is skipped
// unless MojoIpcz is enabled.
class DirectReceiverTest : public ipcz::test::internal::TestBase,
                           public core::test::MojoTestBase {
 protected:
  void SetUp() override {
    // Chain to the gtest base fixture so any setup it performs is not
    // silently bypassed by this override.
    MojoTestBase::SetUp();
    if (!core::IsMojoIpczEnabled()) {
      GTEST_SKIP() << "This test is only valid when MojoIpcz is enabled.";
    }
  }

 private:
  base::test::TaskEnvironment task_environment_;
};
// Scoped helper that keeps the Mojo core IO thread (core::GetIOTaskRunner())
// blocked for the lifetime of the object. The constructor returns only once a
// blocking task is actually running on the IO thread. The destructor releases
// that task and then waits for the IO thread to run one more task.
class ScopedPauseIOThread {
 public:
  ScopedPauseIOThread() {
    base::RunLoop loop;
    // The posted task first lets this constructor return (via `quit`), then
    // parks the IO thread until `unblock_event_` is signaled.
    core::GetIOTaskRunner()->PostTask(
        FROM_HERE,
        base::BindLambdaForTesting([this, quit = loop.QuitClosure()] {
          quit.Run();
          unblock_event_.Wait();
        }));
    loop.Run();
  }
  ~ScopedPauseIOThread() {
    base::RunLoop loop;
    unblock_event_.Signal();
    // Round-trip a task through the IO thread so that the blocking task has
    // returned before `unblock_event_` is destroyed. This assumes the IO task
    // runner executes tasks in posting order.
    core::GetIOTaskRunner()->PostTask(FROM_HERE, loop.QuitClosure());
    loop.Run();
  }
 private:
  // Signaled by the destructor to release the parked IO thread.
  base::WaitableEvent unblock_event_;
};
// mojom::Service implementation that receives calls through a DirectReceiver.
// Every Ping() is expected to run on the task runner passed at construction.
// The first Ping() after BindAndPauseReceiver() signals the event supplied
// there.
class ServiceImpl : public mojom::Service {
 public:
  explicit ServiceImpl(scoped_refptr<base::SingleThreadTaskRunner> task_runner)
      : task_runner_(std::move(task_runner)) {}
  ~ServiceImpl() override = default;

  DirectReceiver<mojom::Service>& receiver() { return receiver_; }

  // Replaces the receiver node's portal with a boxed dummy object. The
  // FallbackToIOThreadHopOnFailure test uses this to exercise the failure path.
  void SimulateFailure() {
    receiver_.node_for_testing().ReplacePortalForTesting(
        CreateDummyHandle(receiver_.node_for_testing().node()));
  }

  // Records `ping_event` for the next Ping(), binds `receiver`, immediately
  // pauses dispatch, and then signals `bound_event`.
  void BindAndPauseReceiver(PendingReceiver<mojom::Service> receiver,
                            base::WaitableEvent* bound_event,
                            base::WaitableEvent* ping_event) {
    ping_event_ = ping_event;
    receiver_.Bind(std::move(receiver));
    receiver_.receiver_for_testing().Pause();
    bound_event->Signal();
  }

  // Resumes dispatch that BindAndPauseReceiver() paused.
  void UnpauseReceiver() { receiver_.receiver_for_testing().Resume(); }

  // Returns the raw handle underlying the bound receiver's pipe.
  IpczHandle GetReceiverPortal() {
    auto& bound_receiver = receiver_.receiver_for_testing();
    return bound_receiver.internal_state()->handle().value();
  }

  // mojom::Service:
  void Ping(PingCallback callback) override {
    EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
    std::move(callback).Run();
    // Consume the pending event so that only one Ping() signals it.
    base::WaitableEvent* const event_to_signal = ping_event_.get();
    ping_event_ = nullptr;
    if (event_to_signal) {
      event_to_signal->Signal();
    }
  }

 private:
  DirectReceiver<mojom::Service> receiver_{DirectReceiverKey{}, this};
  const scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
  raw_ptr<base::WaitableEvent> ping_event_ = nullptr;
};
class ServiceRunner {
public:
explicit ServiceRunner(DirectReceiverTest& test,
base::Thread* shared_thread = nullptr)
: test_(test), impl_thread_(shared_thread) {
if (!impl_thread_) {
owned_thread_ = std::make_unique<base::Thread>("Impl Thread");
impl_thread_ = owned_thread_.get();
impl_thread_->StartWithOptions(
base::Thread::Options{base::MessagePumpType::IO, 0});
}
}
~ServiceRunner() { service_.SynchronouslyResetForTest(); }
base::Thread* impl_thread() const { return impl_thread_; }
void Start(PendingReceiver<mojom::Service> receiver,
base::WaitableEvent& ping_event,
bool simulate_failure = false) {
service_.emplace(impl_thread_->task_runner(), impl_thread_->task_runner());
if (simulate_failure) {
service_.AsyncCall(&ServiceImpl::SimulateFailure);
}
base::WaitableEvent bound_event;
service_.AsyncCall(&ServiceImpl::BindAndPauseReceiver)
.WithArgs(std::move(receiver), &bound_event, &ping_event);
bound_event.Wait();
base::RunLoop wait_loop;
service_.AsyncCall(&ServiceImpl::GetReceiverPortal)
.Then(base::BindLambdaForTesting([&](IpczHandle portal) {
test_->WaitForDirectRemoteLink(portal);
wait_loop.Quit();
}));
wait_loop.Run();
base::RunLoop unpause_loop;
service_.AsyncCall(&ServiceImpl::UnpauseReceiver)
.Then(unpause_loop.QuitClosure());
unpause_loop.Run();
}
private:
const raw_ref<DirectReceiverTest> test_;
std::unique_ptr<base::Thread> owned_thread_;
raw_ptr<base::Thread> impl_thread_;
base::SequenceBound<ServiceImpl> service_;
};
// The service runs in the parent (the broker, per the test name) behind a
// DirectReceiver. The parent's IO thread is blocked while the child is told
// to ping, and the ping must still reach the service.
TEST_F(DirectReceiverTest, NoIOThreadHopInBroker) {
  ServiceRunner runner{*this};
  RunTestClient("NoIOThreadHopInBroker_Child", [&](MojoHandle child) {
    // Create the Service pipe and send the remote end to the child.
    PendingRemote<mojom::Service> remote;
    PendingReceiver<mojom::Service> receiver =
        remote.InitWithNewPipeAndPassReceiver();
    MojoHandle remote_pipe = remote.PassPipe().release().value();
    WriteMessageWithHandles(child, "", &remote_pipe, 1);
    base::WaitableEvent ping_event;
    // Bind on the impl thread and wait for a direct remote link.
    runner.Start(std::move(receiver), ping_event);
    {
      // Keep the IO thread blocked until the service has handled the ping.
      ScopedPauseIOThread pause_io;
      WriteMessage(child, "ok go");
      ping_event.Wait();
    }
    EXPECT_EQ("done", ReadMessage(child));
  });
}
// Child side of NoIOThreadHopInBroker: receives the Service remote, waits for
// the parent's "ok go", pings once, and then reports "done".
DEFINE_TEST_CLIENT_TEST_WITH_PIPE(NoIOThreadHopInBroker_Child,
                                  DirectReceiverTest,
                                  test_pipe_handle) {
  const ScopedMessagePipeHandle test_pipe{MessagePipeHandle{test_pipe_handle}};
  const MojoHandle control = test_pipe->value();

  MojoHandle service_handle;
  ReadMessageWithHandles(control, &service_handle, 1);
  WaitForDirectRemoteLink(service_handle);
  Remote<mojom::Service> service{PendingRemote<mojom::Service>{
      MakeScopedHandle(MessagePipeHandle{service_handle}), 0}};

  EXPECT_EQ("ok go", ReadMessage(control));
  base::RunLoop ping_loop;
  service->Ping(ping_loop.QuitClosure());
  ping_loop.Run();
  WriteMessage(control, "done");
}
// Counterpart of NoIOThreadHopInBroker with the roles swapped. The service
// runs in the child (non-broker, per the test name) process, which blocks its
// own IO thread, and the parent sends the ping.
TEST_F(DirectReceiverTest, NoIOThreadHopInNonBrokerProcess) {
  PendingRemote<mojom::Service> remote;
  PendingReceiver<mojom::Service> receiver =
      remote.InitWithNewPipeAndPassReceiver();
  RunTestClient("NoIOThreadHopInNonBroker_Child", [&](MojoHandle child) {
    // Send the receiving end to the child, which binds it to its service.
    MojoHandle service_pipe = receiver.PassPipe().release().value();
    WriteMessageWithHandles(child, "", &service_pipe, 1);
    WaitForDirectRemoteLink(remote.internal_state()->pipe->value());
    Remote<mojom::Service> service{std::move(remote)};
    // The child sends "ok go" only after binding and pausing its IO thread.
    EXPECT_EQ("ok go", ReadMessage(child));
    base::RunLoop loop;
    service->Ping(loop.QuitClosure());
    loop.Run();
    EXPECT_EQ("done", ReadMessage(child));
  });
}
// Child side of NoIOThreadHopInNonBrokerProcess. It hosts the service behind a
// DirectReceiver and blocks its IO thread while the parent's ping is
// delivered.
DEFINE_TEST_CLIENT_TEST_WITH_PIPE(NoIOThreadHopInNonBroker_Child,
                                  DirectReceiverTest,
                                  test_pipe_handle) {
  const ScopedMessagePipeHandle test_pipe{MessagePipeHandle{test_pipe_handle}};
  MojoHandle service_pipe;
  ReadMessageWithHandles(test_pipe->value(), &service_pipe, 1);
  PendingReceiver<mojom::Service> receiver{
      ScopedMessagePipeHandle{MessagePipeHandle{service_pipe}}};
  ServiceRunner runner{*this};
  base::WaitableEvent ping_event;
  runner.Start(std::move(receiver), ping_event);
  {
    // Keep this process's IO thread blocked until the ping has been handled.
    ScopedPauseIOThread pause_io;
    WriteMessage(test_pipe->value(), "ok go");
    ping_event.Wait();
  }
  WriteMessage(test_pipe->value(), "done");
}
// Two services share one impl thread. The "fallback" one has its node portal
// replaced via SimulateFailure(). While the IO thread is blocked, only the
// direct service may receive its ping. The fallback ping must arrive only
// after the IO thread is released.
TEST_F(DirectReceiverTest, FallbackToIOThreadHopOnFailure) {
  ServiceRunner direct_runner{*this};
  ServiceRunner fallback_runner{*this, direct_runner.impl_thread()};
  RunTestClient("FallbackToIOThreadHopOnFailure_Child", [&](MojoHandle child) {
    PendingRemote<mojom::Service> direct_remote;
    PendingReceiver<mojom::Service> direct_receiver =
        direct_remote.InitWithNewPipeAndPassReceiver();
    PendingRemote<mojom::Service> fallback_remote;
    PendingReceiver<mojom::Service> fallback_receiver =
        fallback_remote.InitWithNewPipeAndPassReceiver();
    // The child expects the direct remote first, then the fallback remote.
    MojoHandle remote_pipes[] = {direct_remote.PassPipe().release().value(),
                                 fallback_remote.PassPipe().release().value()};
    WriteMessageWithHandles(child, "", remote_pipes, 2);
    base::WaitableEvent direct_ping_event;
    direct_runner.Start(std::move(direct_receiver), direct_ping_event);
    base::WaitableEvent fallback_ping_event;
    fallback_runner.Start(std::move(fallback_receiver), fallback_ping_event,
                          /*simulate_failure=*/true);
    {
      ScopedPauseIOThread pause_io;
      WriteMessage(child, "ok go");
      direct_ping_event.Wait();
      // With the IO thread blocked, the fallback service must not have been
      // pinged yet.
      EXPECT_FALSE(fallback_ping_event.IsSignaled());
    }
    // Releasing the IO thread lets the fallback ping through.
    fallback_ping_event.Wait();
    EXPECT_EQ("done", ReadMessage(child));
  });
}
// Child side of FallbackToIOThreadHopOnFailure. It receives the direct and
// fallback Service remotes (in that order), waits for "ok go", pings both, and
// reports "done" once both replies have arrived.
DEFINE_TEST_CLIENT_TEST_WITH_PIPE(FallbackToIOThreadHopOnFailure_Child,
                                  DirectReceiverTest,
                                  test_pipe_handle) {
  const ScopedMessagePipeHandle test_pipe{MessagePipeHandle{test_pipe_handle}};
  const MojoHandle control = test_pipe->value();

  // handles[0] is the direct service; handles[1] is the fallback service.
  MojoHandle handles[2];
  ReadMessageWithHandles(control, handles, 2);
  for (const MojoHandle handle : handles) {
    WaitForDirectRemoteLink(handle);
  }
  Remote<mojom::Service> direct_service{PendingRemote<mojom::Service>{
      MakeScopedHandle(MessagePipeHandle{handles[0]}), 0}};
  Remote<mojom::Service> fallback_service{PendingRemote<mojom::Service>{
      MakeScopedHandle(MessagePipeHandle{handles[1]}), 0}};

  EXPECT_EQ("ok go", ReadMessage(control));
  // The loop quits only after both Ping() replies have been received.
  base::RunLoop both_replied;
  const auto on_reply = base::BarrierClosure(2, both_replied.QuitClosure());
  fallback_service->Ping(on_reply);
  direct_service->Ping(on_reply);
  both_replied.Run();
  WriteMessage(control, "done");
}
// Two DirectReceivers on the same thread must report the same node handle.
// The thread-local node instance must survive until the last of them is
// destroyed.
TEST_F(DirectReceiverTest, ThreadLocalInstanceShared) {
  base::Thread io_thread{"Test IO thread"};
  io_thread.StartWithOptions(
      base::Thread::Options{base::MessagePumpType::IO, 0});
  std::unique_ptr<ServiceImpl> impl1;
  std::unique_ptr<ServiceImpl> impl2;
  Remote<mojom::Service> remote1;
  Remote<mojom::Service> remote2;
  auto receiver1 = remote1.BindNewPipeAndPassReceiver();
  auto receiver2 = remote2.BindNewPipeAndPassReceiver();
  // Create and bind both services on `io_thread`.
  RunOn(io_thread, [&] {
    impl1 = std::make_unique<ServiceImpl>(io_thread.task_runner());
    impl1->receiver().Bind(std::move(receiver1));
    impl2 = std::make_unique<ServiceImpl>(io_thread.task_runner());
    impl2->receiver().Bind(std::move(receiver2));
    EXPECT_EQ(impl1->receiver().node_for_testing().node(),
              impl2->receiver().node_for_testing().node());
  });
  // Each remote must get a reply from its service.
  base::RunLoop loop1;
  remote1->Ping(loop1.QuitClosure());
  loop1.Run();
  base::RunLoop loop2;
  remote2->Ping(loop2.QuitClosure());
  loop2.Run();
  // Destroy the services on their own thread. The thread-local instance must
  // outlive the first reset and be gone after the second.
  RunOn(io_thread, [&] {
    EXPECT_TRUE(internal::ThreadLocalNode::CurrentThreadHasInstance());
    impl1.reset();
    EXPECT_TRUE(internal::ThreadLocalNode::CurrentThreadHasInstance());
    impl2.reset();
    EXPECT_FALSE(internal::ThreadLocalNode::CurrentThreadHasInstance());
  });
}
// DirectReceivers on different threads must report distinct node handles.
// Each thread's thread-local node instance must go away when that thread's
// only DirectReceiver is destroyed.
TEST_F(DirectReceiverTest, UniqueNodePerThread) {
  base::Thread io_thread1{"Test IO thread 1"};
  base::Thread io_thread2{"Test IO thread 2"};
  io_thread1.StartWithOptions(
      base::Thread::Options{base::MessagePumpType::IO, 0});
  io_thread2.StartWithOptions(
      base::Thread::Options{base::MessagePumpType::IO, 0});
  std::unique_ptr<ServiceImpl> impl1;
  std::unique_ptr<ServiceImpl> impl2;
  Remote<mojom::Service> remote1;
  Remote<mojom::Service> remote2;
  auto receiver1 = remote1.BindNewPipeAndPassReceiver();
  auto receiver2 = remote2.BindNewPipeAndPassReceiver();

  // Filled in on the IO threads below. They are initialized so the checks
  // never read an indeterminate value and a missing write is detectable.
  IpczHandle node1 = IPCZ_INVALID_HANDLE;
  IpczHandle node2 = IPCZ_INVALID_HANDLE;
  RunOn(io_thread1, [&] {
    impl1 = std::make_unique<ServiceImpl>(io_thread1.task_runner());
    impl1->receiver().Bind(std::move(receiver1));
    node1 = impl1->receiver().node_for_testing().node();
  });
  RunOn(io_thread2, [&] {
    impl2 = std::make_unique<ServiceImpl>(io_thread2.task_runner());
    impl2->receiver().Bind(std::move(receiver2));
    node2 = impl2->receiver().node_for_testing().node();
  });
  // Without these, the inequality below would pass vacuously if only one of
  // the nodes had been set.
  EXPECT_NE(node1, IPCZ_INVALID_HANDLE);
  EXPECT_NE(node2, IPCZ_INVALID_HANDLE);
  EXPECT_NE(node1, node2);

  // Each remote must get a reply from its service.
  base::RunLoop loop1;
  remote1->Ping(loop1.QuitClosure());
  loop1.Run();
  base::RunLoop loop2;
  remote2->Ping(loop2.QuitClosure());
  loop2.Run();

  // Destroy each service on its own thread. This releases that thread's
  // instance.
  RunOn(io_thread1, [&] {
    EXPECT_TRUE(internal::ThreadLocalNode::CurrentThreadHasInstance());
    impl1.reset();
    EXPECT_FALSE(internal::ThreadLocalNode::CurrentThreadHasInstance());
  });
  RunOn(io_thread2, [&] {
    EXPECT_TRUE(internal::ThreadLocalNode::CurrentThreadHasInstance());
    impl2.reset();
    EXPECT_FALSE(internal::ThreadLocalNode::CurrentThreadHasInstance());
  });
}
// Binding a DirectReceiver to a null PendingReceiver must leave the underlying
// receiver unbound.
TEST_F(DirectReceiverTest, BindInvalidPendingReceiver) {
  base::Thread io_thread("Test IO thread 1");
  io_thread.StartWithOptions(
      base::Thread::Options{base::MessagePumpType::IO, 0});
  RunOn(io_thread, [&] {
    // This fresh thread must start without a thread-local node instance.
    EXPECT_FALSE(internal::ThreadLocalNode::CurrentThreadHasInstance());
    auto service = std::make_unique<ServiceImpl>(io_thread.task_runner());
    DirectReceiver<mojom::Service>& direct_receiver = service->receiver();
    direct_receiver.Bind(NullReceiver());
    EXPECT_FALSE(direct_receiver.receiver_for_testing().is_bound());
  });
}
}