#include "mojo/public/cpp/bindings/direct_receiver.h"
#include <optional>
#include <utility>
#include "base/check_op.h"
#include "base/debug/crash_logging.h"
#include "base/memory/ptr_util.h"
#include "base/memory/scoped_refptr.h"
#include "base/metrics/histogram_functions.h"
#include "base/no_destructor.h"
#include "base/synchronization/lock.h"
#include "base/task/single_thread_task_runner.h"
#include "build/build_config.h"
#include "mojo/core/embedder/embedder.h"
#include "mojo/core/ipcz_api.h"
#include "mojo/core/ipcz_driver/driver.h"
#include "mojo/core/ipcz_driver/transport.h"
#include "mojo/public/cpp/system/handle.h"
#include "third_party/ipcz/include/ipcz/ipcz.h"
namespace mojo::internal {
namespace {
using Transport = core::ipcz_driver::Transport;
using TransportPair =
    std::pair<scoped_refptr<Transport>, scoped_refptr<Transport>>;

// Builds a connected, in-process pair of ipcz driver transports.
// `first` is the endpoint for the process-wide node and takes the same
// broker/non-broker role as that node; `second` is the endpoint for a
// thread-local node and is always a non-broker. Both ends are told that their
// remote process is the current process.
TransportPair CreateTransportPair() {
  const bool global_is_broker = core::GetIpczNodeOptions().is_broker;
  TransportPair pair = Transport::CreatePair(
      global_is_broker ? Transport::EndpointType::kBroker
                       : Transport::EndpointType::kNonBroker,
      Transport::EndpointType::kNonBroker);
  for (Transport* endpoint : {pair.first.get(), pair.second.get()}) {
    endpoint->set_remote_process(base::Process::Current());
  }
  return pair;
}
#if BUILDFLAG(IS_WIN)
// Set by CreateDirectReceiverTransportBeforeSandbox(). When true,
// ThreadLocalNode takes its transports from TransportPairStorage instead of
// creating them on demand.
bool g_use_precreated_transport = false;

// Process-wide, lock-protected slot holding at most one TransportPair that was
// created before the sandbox was engaged. The pair can be stored once
// (CHECKed) and taken out once; taking from an empty slot crashes.
class TransportPairStorage {
 public:
  // Returns the lazily-created, never-destroyed singleton.
  static TransportPairStorage& Get() {
    static base::NoDestructor<TransportPairStorage> storage;
    return *storage;
  }

  // Creates the pair and stores it. Must be called at most once.
  void CreateTransportPairBeforeSandbox() {
    base::AutoLock guard(lock_);
    CHECK(!transport_pair_.has_value());
    transport_pair_.emplace(CreateTransportPair());
  }

  // Moves the stored pair out, leaving the slot empty.
  TransportPair TakeTransportPair() {
    base::AutoLock guard(lock_);
    TransportPair taken = std::move(transport_pair_).value();
    transport_pair_.reset();
    return taken;
  }

 private:
  base::Lock lock_;
  std::optional<TransportPair> transport_pair_ GUARDED_BY(lock_);
};
#endif
// Outcome of AdoptPipe() sending a pipe to the thread-local node; recorded to
// the "Mojo.DirectReceiver.AdoptPipeResult" histogram by LogAdoptPipeResult().
// NOTE(review): values are logged via UmaHistogramEnumeration, so existing
// entries should presumably never be renumbered or reused — confirm against
// the histogram enum definition.
enum class MojoAdoptPipeResult {
// Put() on `global_portal_` returned IPCZ_RESULT_OK.
kSuccess = 0,
// Put() failed with IPCZ_RESULT_NOT_FOUND on merge ID 1.
kTransportNotConnected = 1,
// Put() failed in any other way.
kPutFailed = 2,
kMaxValue = kPutFailed,
};
// Outcome of merging a transferred portal with its pending partner; recorded
// to the "Mojo.DirectReceiver.MergePortalsResult" histogram by
// LogMergePortalsResult(). Same renumbering caveat as above.
enum class MojoMergePortalsResult {
// OnTransferredPortalAvailable() merged the received portal.
kSuccess = 0,
// Logged by AdoptPipe() when its Put() failed, so no merge will happen.
kNotAttempted = 1,
// Get() on `local_portal_` failed in OnTransferredPortalAvailable().
kGetFailed = 2,
kMaxValue = kGetFailed,
};
// Records how AdoptPipe() fared when handing a pipe to the thread-local node.
void LogAdoptPipeResult(MojoAdoptPipeResult result) {
  constexpr char kHistogramName[] = "Mojo.DirectReceiver.AdoptPipeResult";
  base::UmaHistogramEnumeration(kHistogramName, result);
}

// Records whether a transferred portal was merged with its pending partner.
void LogMergePortalsResult(MojoMergePortalsResult result) {
  constexpr char kHistogramName[] = "Mojo.DirectReceiver.MergePortalsResult";
  base::UmaHistogramEnumeration(kHistogramName, result);
}
// The calling thread's ThreadLocalNode, or null if none exists. This is a raw
// pointer and holds no reference. The constructor sets it and the destructor
// clears it.
thread_local ThreadLocalNode* g_thread_local_node = nullptr;
}
// Creates this thread's own ipcz node and connects it to the process-wide
// node (core::GetIpczNode()) over an in-process Transport pair. It then starts
// watching `local_portal_` for portals sent by AdoptPipe(). The PassKey
// restricts construction to ThreadLocalNode itself (see Get()).
ThreadLocalNode::ThreadLocalNode(base::PassKey<ThreadLocalNode>) {
CHECK(IsDirectReceiverSupported());
// At most one instance per thread; register this one as the current thread's.
CHECK(!g_thread_local_node);
g_thread_local_node = this;
scoped_refptr<Transport> global_transport;
scoped_refptr<Transport> local_transport;
#if BUILDFLAG(IS_WIN)
// If CreateDirectReceiverTransportBeforeSandbox() ran, use the pair it
// created instead of creating one now.
if (g_use_precreated_transport) {
std::tie(global_transport, local_transport) =
TransportPairStorage::Get().TakeTransportPair();
// TakeTransportPair() empties the storage, so the precreated pair can only be
// taken once. This extra reference is never released anywhere in this file,
// so this node is never destroyed. NOTE(review): presumably an intentional
// leak so that a second TakeTransportPair() is never attempted — confirm.
AddRef();
} else {
std::tie(global_transport, local_transport) = CreateTransportPair();
}
#else
std::tie(global_transport, local_transport) = CreateTransportPair();
#endif
// Create the thread-local node using Mojo core's ipcz driver.
const IpczAPI& ipcz = core::GetIpczAPI();
const IpczCreateNodeOptions create_options = {
.size = sizeof(create_options),
.memory_flags = IPCZ_MEMORY_FIXED_PARCEL_CAPACITY,
};
IpczHandle node;
const IpczResult create_result = ipcz.CreateNode(
&core::ipcz_driver::kDriver, IPCZ_NO_FLAGS, &create_options, &node);
CHECK_EQ(create_result, IPCZ_RESULT_OK);
node_.reset(Handle(node));
// Choose the connect flags for each side. If the process-wide node is a
// broker, the thread-local node connects to it as its broker. Otherwise, the
// process-wide node shares its broker and the thread-local node inherits it.
// In that case the thread-local side also requests
// IPCZ_CONNECT_NODE_TO_ALLOCATION_DELEGATE unless the process-wide node is
// configured for local shared-memory allocation.
const core::IpczNodeOptions& global_node_options = core::GetIpczNodeOptions();
IpczConnectNodeFlags local_connect_flags;
IpczConnectNodeFlags global_connect_flags;
if (global_node_options.is_broker) {
global_connect_flags = IPCZ_NO_FLAGS;
local_connect_flags = IPCZ_CONNECT_NODE_TO_BROKER;
} else {
global_connect_flags = IPCZ_CONNECT_NODE_SHARE_BROKER;
local_connect_flags = IPCZ_CONNECT_NODE_INHERIT_BROKER;
if (!global_node_options.use_local_shared_memory_allocation) {
local_connect_flags |= IPCZ_CONNECT_NODE_TO_ALLOCATION_DELEGATE;
}
}
// Run the thread-local transport's I/O on this thread's default task runner.
local_transport->OverrideIOTaskRunner(
base::SingleThreadTaskRunner::GetCurrentDefault());
// Hand each Transport to ipcz (ReleaseAsHandle gives up our reference) and
// connect each node with one initial portal. AdoptPipe() sends over
// `global_portal_`, and OnTransferredPortalAvailable() receives on
// `local_portal_`.
IpczHandle global_portal;
const IpczResult global_connect_result = ipcz.ConnectNode(
core::GetIpczNode(),
Transport::ReleaseAsHandle(std::move(global_transport)),
1, global_connect_flags, nullptr, &global_portal);
CHECK_EQ(global_connect_result, IPCZ_RESULT_OK);
global_portal_.reset(Handle(global_portal));
IpczHandle local_portal;
const IpczResult local_connect_result = ipcz.ConnectNode(
node_->value(), Transport::ReleaseAsHandle(std::move(local_transport)),
1, local_connect_flags, nullptr, &local_portal);
CHECK_EQ(local_connect_result, IPCZ_RESULT_OK);
local_portal_.reset(Handle(local_portal));
// Arm the trap that picks up portals arriving on `local_portal_`.
WatchForIncomingTransfers();
}
// Unregisters this node as the current thread's instance. Must be destroyed
// on the thread that created it.
ThreadLocalNode::~ThreadLocalNode() {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  g_thread_local_node = nullptr;
}
// Returns a new reference to the calling thread's node. If the thread has no
// node yet, one is created (and registers itself).
scoped_refptr<ThreadLocalNode> ThreadLocalNode::Get() {
  ThreadLocalNode* const existing = g_thread_local_node;
  return existing ? base::WrapRefCounted(existing)
                  : base::MakeRefCounted<ThreadLocalNode>(
                        base::PassKey<ThreadLocalNode>{});
}
// True if a ThreadLocalNode is currently registered for the calling thread.
bool ThreadLocalNode::CurrentThreadHasInstance() {
  return static_cast<bool>(g_thread_local_node);
}
// Moves `pipe` onto this thread-local node and synchronously returns a pipe
// endpoint that lives on it.
//
// A fresh portal pair is opened on the thread-local node. One half
// (`portal_to_bind`) is returned immediately. The other (`portal_to_merge`) is
// parked in `pending_merges_` under a new merge ID. `pipe`'s portal is then
// Put() on `global_portal_`, tagged with that ID. When it arrives on
// `local_portal_`, OnTransferredPortalAvailable() merges it with the parked
// half.
//
// If the Put() fails, the caller's original pipe is returned unchanged. Both
// halves of the new portal pair are then closed and the pending-merge entry is
// removed, so nothing leaks.
ScopedMessagePipeHandle ThreadLocalNode::AdoptPipe(
    ScopedMessagePipeHandle pipe) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  const IpczAPI& ipcz = core::GetIpczAPI();

  // Take the raw portal so it can be passed to Put(). If Put() fails,
  // ownership stays with us and the portal is handed back below.
  IpczHandle portal_to_adopt = pipe.release().value();

  IpczHandle portal_to_bind;
  IpczHandle portal_to_merge;
  const IpczResult open_result =
      ipcz.OpenPortals(node_->value(), IPCZ_NO_FLAGS, nullptr, &portal_to_bind,
                       &portal_to_merge);
  CHECK_EQ(open_result, IPCZ_RESULT_OK);
  // Own the returned half right away so it is closed on the failure path.
  ScopedMessagePipeHandle pipe_to_bind{MessagePipeHandle{portal_to_bind}};

  const uint64_t merge_id = next_merge_id_++;
  pending_merges_[merge_id] = ScopedHandle{Handle{portal_to_merge}};

  const IpczResult put_result =
      ipcz.Put(global_portal_->value(), &merge_id, sizeof(merge_id),
               &portal_to_adopt, 1, IPCZ_NO_FLAGS, nullptr);
  if (put_result != IPCZ_RESULT_OK) {
    // NOT_FOUND on merge ID 1 (presumably the first ID issued) is reported
    // separately, as the transport not being connected.
    LogAdoptPipeResult(merge_id == 1 && put_result == IPCZ_RESULT_NOT_FOUND
                           ? MojoAdoptPipeResult::kTransportNotConnected
                           : MojoAdoptPipeResult::kPutFailed);
    LogMergePortalsResult(MojoMergePortalsResult::kNotAttempted);
    // No merge can ever happen for this ID. Erasing the entry closes the
    // parked half, and `pipe_to_bind` closes the other half on return.
    pending_merges_.erase(merge_id);
    return ScopedMessagePipeHandle{MessagePipeHandle{portal_to_adopt}};
  }

  LogAdoptPipeResult(MojoAdoptPipeResult::kSuccess);
  return pipe_to_bind;
}
// Test-only: swaps in `dummy_portal` as the portal AdoptPipe() sends on,
// closing the previous one.
void ThreadLocalNode::ReplacePortalForTesting(ScopedHandle dummy_portal) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  global_portal_.reset(dummy_portal.release());
}
void ThreadLocalNode::WatchForIncomingTransfers() {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
const IpczAPI& ipcz = core::GetIpczAPI();
const IpczTrapConditions conditions = {
.size = sizeof(conditions),
.flags = IPCZ_TRAP_ABOVE_MIN_LOCAL_PARCELS,
.min_local_parcels = 0,
};
auto context = std::make_unique<base::WeakPtr<ThreadLocalNode>>(
weak_ptr_factory_.GetWeakPtr());
for (;;) {
const IpczResult trap_result =
ipcz.Trap(local_portal_->value(), &conditions, &OnTrapEvent,
reinterpret_cast<uintptr_t>(context.get()), IPCZ_NO_FLAGS,
nullptr, nullptr, nullptr);
if (trap_result == IPCZ_RESULT_OK) {
context.release();
return;
}
CHECK_EQ(trap_result, IPCZ_RESULT_FAILED_PRECONDITION);
OnTransferredPortalAvailable();
}
}
void ThreadLocalNode::OnTrapEvent(const IpczTrapEvent* event) {
auto weak_node_ptr = base::WrapUnique(
reinterpret_cast<base::WeakPtr<ThreadLocalNode>*>(event->context));
const base::WeakPtr<ThreadLocalNode>& weak_node = *weak_node_ptr;
if (!weak_node) {
return;
}
weak_node->OnTransferredPortalAvailable();
weak_node->WatchForIncomingTransfers();
}
void ThreadLocalNode::OnTransferredPortalAvailable() {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
IpczHandle portal;
uint64_t merge_id = 0;
size_t num_bytes = sizeof(merge_id);
size_t num_portals = 1;
const IpczAPI& ipcz = core::GetIpczAPI();
const IpczResult get_result = ipcz.Get(
local_portal_->value(), IPCZ_NO_FLAGS, nullptr, &merge_id, &num_bytes,
&portal, &num_portals, nullptr);
if (get_result != IPCZ_RESULT_OK) {
LogMergePortalsResult(MojoMergePortalsResult::kGetFailed);
return;
}
CHECK_EQ(num_bytes, sizeof(merge_id));
CHECK_EQ(num_portals, 1u);
CHECK_NE(portal, IPCZ_INVALID_HANDLE);
auto it = pending_merges_.find(merge_id);
CHECK(it != pending_merges_.end());
const IpczResult merge_result = ipcz.MergePortals(
portal, it->second.release().value(), IPCZ_NO_FLAGS, nullptr);
CHECK_EQ(merge_result, IPCZ_RESULT_OK);
pending_merges_.erase(it);
LogMergePortalsResult(MojoMergePortalsResult::kSuccess);
}
}
namespace mojo {
// DirectReceiver is built on ipcz APIs, so it is only available when Mojo
// core runs on ipcz.
bool IsDirectReceiverSupported() {
  return core::IsMojoIpczEnabled();
}
#if BUILDFLAG(IS_WIN)
// Opts this process into using a precreated transport pair for
// ThreadLocalNode. The pair is created now only if DirectReceiver is
// supported. Must be called at most once.
void CreateDirectReceiverTransportBeforeSandbox() {
  CHECK(!internal::g_use_precreated_transport);
  internal::g_use_precreated_transport = true;
  if (!IsDirectReceiverSupported()) {
    return;
  }
  internal::TransportPairStorage::Get().CreateTransportPairBeforeSandbox();
}
#endif
}