#include "mojo/public/cpp/bindings/interface_endpoint_client.h"
#include <stdint.h>
#include <optional>
#include <string_view>
#include <tuple>
#include <vector>
#include "base/check.h"
#include "base/containers/contains.h"
#include "base/debug/alias.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/metrics/histogram_functions.h"
#include "base/synchronization/waitable_event.h"
#include "base/task/bind_post_task.h"
#include "base/task/common/task_annotator.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool/thread_pool_instance.h"
#include "base/threading/thread_local.h"
#include "base/trace_event/interned_args_helper.h"
#include "base/trace_event/typed_macros.h"
#include "build/build_config.h"
#include "mojo/public/cpp/bindings/associated_group.h"
#include "mojo/public/cpp/bindings/associated_group_controller.h"
#include "mojo/public/cpp/bindings/interface_endpoint_controller.h"
#include "mojo/public/cpp/bindings/lib/task_runner_helper.h"
#include "mojo/public/cpp/bindings/lib/validation_util.h"
#include "mojo/public/cpp/bindings/sync_call_restrictions.h"
#include "mojo/public/cpp/bindings/sync_event_watcher.h"
#include "mojo/public/cpp/bindings/thread_safe_proxy.h"
#include "third_party/perfetto/protos/perfetto/trace/track_event/chrome_mojo_event_info.pbzero.h"
namespace mojo {
namespace {
// Per-thread histogram that HandleValidatedMessage() feeds with the time
// between a message's creation timestamp and its dispatch on this thread.
// Set by InterfaceEndpointClient::SetThreadNameSuffixForMetrics(); if nothing
// has set it on this thread by the first timed message, the "Default" suffix
// is used.
constinit thread_local base::HistogramBase* g_end_to_end_metric = nullptr;
// ThreadSafeProxy that lets an InterfaceEndpointClient bound to |task_runner_|
// be used from any sequence. Handles in outgoing messages are serialized
// against |associated_group_|'s controller on the calling sequence. The
// message is then posted to |task_runner_|, where it is forwarded to
// |endpoint_| if that endpoint still exists.
class ThreadSafeInterfaceEndpointClientProxy : public ThreadSafeProxy {
public:
ThreadSafeInterfaceEndpointClientProxy(
base::WeakPtr<InterfaceEndpointClient> endpoint,
scoped_refptr<ThreadSafeProxy::Target> target,
const AssociatedGroup& associated_group,
scoped_refptr<base::SequencedTaskRunner> task_runner,
const base::Location& location)
: endpoint_(std::move(endpoint)),
target_(std::move(target)),
associated_group_(associated_group),
task_runner_(std::move(task_runner)),
location_(location) {}
ThreadSafeInterfaceEndpointClientProxy(
const ThreadSafeInterfaceEndpointClientProxy&) = delete;
ThreadSafeInterfaceEndpointClientProxy& operator=(
const ThreadSafeInterfaceEndpointClientProxy&) = delete;
// Fire-and-forget send: serializes handles here, then posts the message to
// |task_runner_|. The posted task holds a reference to this proxy.
void SendMessage(Message& message) override {
message.SerializeHandles(associated_group_.GetController());
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&ThreadSafeInterfaceEndpointClientProxy::ForwardMessage,
this, std::move(message)));
}
// Defined out of line below.
void SendMessageWithResponder(
Message& message,
std::unique_ptr<MessageReceiver> responder) override;
private:
// Wakes every caller still blocked in SendMessageWithResponder() waiting for
// a sync reply. Their |received| flag is left as-is, so a caller that had no
// reply yet returns without invoking its responder.
~ThreadSafeInterfaceEndpointClientProxy() override {
base::AutoLock l(sync_calls_->lock);
for (ThreadSafeInterfaceEndpointClientProxy::SyncResponseInfo*
pending_response : sync_calls_->pending_responses) {
pending_response->event.Signal();
}
}
// State shared between a blocked cross-sequence sync caller and the
// SyncResponseSignaler that delivers its reply: the reply message, whether it
// arrived, and a manual-reset event (initially unsignaled) the caller waits on.
struct SyncResponseInfo
: public base::RefCountedThreadSafe<SyncResponseInfo> {
SyncResponseInfo() = default;
Message message;
bool received = false;
base::WaitableEvent event{base::WaitableEvent::ResetPolicy::MANUAL,
base::WaitableEvent::InitialState::NOT_SIGNALED};
private:
friend class base::RefCountedThreadSafe<SyncResponseInfo>;
~SyncResponseInfo() = default;
};
// Responder used for cross-sequence sync calls. Accept() stores the reply,
// marks it received and signals the waiter. If the signaler is destroyed
// without a reply, the destructor still signals so the waiter wakes up with
// |received| == false.
class SyncResponseSignaler : public MessageReceiver {
public:
explicit SyncResponseSignaler(scoped_refptr<SyncResponseInfo> response)
: response_(std::move(response)) {}
~SyncResponseSignaler() override {
if (response_)
response_->event.Signal();
}
bool Accept(Message* message) override {
response_->message = std::move(*message);
response_->received = true;
response_->event.Signal();
// Cleared so the destructor does not signal a second time.
response_ = nullptr;
return true;
}
private:
scoped_refptr<SyncResponseInfo> response_;
};
// Lock-protected list of sync calls currently blocked on this proxy. It is
// ref-counted separately from the proxy so a waiting caller can keep its own
// reference and still unregister after the proxy has been destroyed.
struct InProgressSyncCalls
: public base::RefCountedThreadSafe<InProgressSyncCalls> {
InProgressSyncCalls() = default;
base::Lock lock;
std::vector<raw_ptr<SyncResponseInfo, VectorExperimental>> pending_responses
GUARDED_BY(lock);
private:
friend class base::RefCountedThreadSafe<InProgressSyncCalls>;
~InProgressSyncCalls() = default;
};
// Wraps an async responder so that both the reply and the wrapped responder's
// destruction happen on the sequence that created this wrapper (the sequence
// that called SendMessageWithResponder()).
class ForwardToCallingThread : public MessageReceiver {
public:
explicit ForwardToCallingThread(std::unique_ptr<MessageReceiver> responder,
const base::Location& location)
: responder_(std::move(responder)),
caller_task_runner_(base::SequencedTaskRunner::GetCurrentDefault()),
location_(location) {}
// |responder_| may already have been moved out by Accept(); it is handed to
// DeleteSoon either way.
~ForwardToCallingThread() override {
caller_task_runner_->DeleteSoon(location_, std::move(responder_));
}
private:
bool Accept(Message* message) override {
caller_task_runner_->PostTask(
location_,
base::BindOnce(&ForwardToCallingThread::CallAcceptAndDeleteResponder,
std::move(responder_), std::move(*message)));
return true;
}
// Runs on the caller's sequence. |responder| is destroyed when this returns;
// the result of Accept() is deliberately ignored.
static void CallAcceptAndDeleteResponder(
std::unique_ptr<MessageReceiver> responder,
Message message) {
std::ignore = responder->Accept(&message);
}
std::unique_ptr<MessageReceiver> responder_;
scoped_refptr<base::SequencedTaskRunner> caller_task_runner_;
const base::Location location_;
};
// Responder used when a sync call is made on the endpoint's own sequence. It
// keeps |proxy_| alive. It drops the reply (Accept() returns false) when this
// responder holds the only remaining reference to the proxy.
class ForwardSameThreadResponder : public MessageReceiver {
public:
explicit ForwardSameThreadResponder(
scoped_refptr<ThreadSafeProxy> proxy,
std::unique_ptr<MessageReceiver> responder)
: proxy_(std::move(proxy)), responder_(std::move(responder)) {}
~ForwardSameThreadResponder() override = default;
private:
bool Accept(Message* message) override {
if (proxy_->HasOneRef())
return false;
return responder_->Accept(message);
}
const scoped_refptr<ThreadSafeProxy> proxy_;
const std::unique_ptr<MessageReceiver> responder_;
};
// Runs on |task_runner_|. Silently drops |message| if the endpoint is gone.
void ForwardMessage(Message message) {
DCHECK(task_runner_->RunsTasksInCurrentSequence());
if (!endpoint_)
return;
endpoint_->SendMessage(&message, false);
}
// Runs on |task_runner_|. If the endpoint is gone, |message| is dropped and
// |responder| is destroyed without being invoked.
void ForwardMessageWithResponder(
Message message,
InterfaceEndpointClient::SyncSendMode sync_send_mode,
std::unique_ptr<MessageReceiver> responder) {
DCHECK(task_runner_->RunsTasksInCurrentSequence());
if (!endpoint_)
return;
endpoint_->SendMessageWithResponder(&message, false,
sync_send_mode, std::move(responder));
}
// Only dereferenced on |task_runner_|.
const base::WeakPtr<InterfaceEndpointClient> endpoint_;
// Retained for the lifetime of the proxy.
const scoped_refptr<ThreadSafeProxy::Target> target_;
AssociatedGroup associated_group_;
const scoped_refptr<base::SequencedTaskRunner> task_runner_;
const scoped_refptr<InProgressSyncCalls> sync_calls_{
base::MakeRefCounted<InProgressSyncCalls>()};
const base::Location location_;
};
// Runs |callback| with true only when |client| is still alive and has not yet
// encountered a connection error.
void DetermineIfEndpointIsConnected(
    const base::WeakPtr<InterfaceEndpointClient>& client,
    base::OnceCallback<void(bool)> callback) {
  const bool still_connected = client && !client->encountered_error();
  std::move(callback).Run(still_connected);
}
// Responder handed to the incoming receiver for a request that expects a
// reply. It forwards the reply to the endpoint client on |task_runner_|. If it
// is destroyed without ever being used, it raises an error on the endpoint.
class ResponderThunk : public MessageReceiverWithStatus {
public:
explicit ResponderThunk(
const base::WeakPtr<InterfaceEndpointClient>& endpoint_client,
scoped_refptr<base::SequencedTaskRunner> runner)
: endpoint_client_(endpoint_client),
accept_was_invoked_(false),
task_runner_(std::move(runner)) {}
ResponderThunk(const ResponderThunk&) = delete;
ResponderThunk& operator=(const ResponderThunk&) = delete;
// A responder dropped without Accept() means the reply will never be sent, so
// an error is raised on the endpoint. Off-sequence, the RaiseError call is
// posted to |task_runner_| bound to the weak pointer.
~ResponderThunk() override {
if (!accept_was_invoked_) {
if (task_runner_->RunsTasksInCurrentSequence()) {
if (endpoint_client_) {
endpoint_client_->RaiseError();
}
} else {
// NOTE(review): presumably lets this post be dropped rather than block
// if the destructor runs during thread-pool shutdown — confirm.
base::ThreadPoolInstance::ScopedFizzleBlockShutdownTasks fizzler;
task_runner_->PostTask(
FROM_HERE, base::BindOnce(&InterfaceEndpointClient::RaiseError,
endpoint_client_));
}
}
}
// Holds a ConnectionGroup reference for as long as this responder lives.
void set_connection_group(ConnectionGroup::Ref connection_group) {
connection_group_ = std::move(connection_group);
}
bool PrefersSerializedMessages() override {
return endpoint_client_ && endpoint_client_->PrefersSerializedMessages();
}
// Sends the reply through the endpoint client. Must run on |task_runner_|.
// Returns false if the endpoint is gone or rejects the message.
bool Accept(Message* message) override {
DCHECK(task_runner_->RunsTasksInCurrentSequence());
accept_was_invoked_ = true;
DCHECK(message->has_flag(Message::kFlagIsResponse));
bool result = false;
if (endpoint_client_)
result = endpoint_client_->Accept(message);
return result;
}
// Must run on |task_runner_|.
bool IsConnected() override {
DCHECK(task_runner_->RunsTasksInCurrentSequence());
return endpoint_client_ && !endpoint_client_->encountered_error();
}
// Callable from any sequence; the answer is computed on |task_runner_|.
void IsConnectedAsync(base::OnceCallback<void(bool)> callback) override {
if (task_runner_->RunsTasksInCurrentSequence()) {
DetermineIfEndpointIsConnected(endpoint_client_, std::move(callback));
} else {
task_runner_->PostTask(
FROM_HERE, base::BindOnce(&DetermineIfEndpointIsConnected,
endpoint_client_, std::move(callback)));
}
}
private:
base::WeakPtr<InterfaceEndpointClient> endpoint_client_;
bool accept_was_invoked_;
scoped_refptr<base::SequencedTaskRunner> task_runner_;
ConnectionGroup::Ref connection_group_;
};
}
// Pairs an outstanding request's message name with the responder that will
// receive its reply. The name lets HandleValidatedMessage() reject replies
// whose name does not match the request.
InterfaceEndpointClient::PendingAsyncResponse::PendingAsyncResponse(
    uint32_t message_name,
    std::unique_ptr<MessageReceiver> reply_receiver)
    : request_message_name(message_name),
      responder(std::move(reply_receiver)) {}
// Out-of-line defaulted move operations and destructor. Moving transfers
// ownership of the pending responder, e.g. when an entry is taken out of
// |async_responders_|.
InterfaceEndpointClient::PendingAsyncResponse::PendingAsyncResponse(
PendingAsyncResponse&&) = default;
InterfaceEndpointClient::PendingAsyncResponse&
InterfaceEndpointClient::PendingAsyncResponse::operator=(
PendingAsyncResponse&&) = default;
InterfaceEndpointClient::PendingAsyncResponse::~PendingAsyncResponse() =
default;
// Bookkeeping for a blocking sync call made on this endpoint: the request's
// message name (checked against the reply) and a pointer to the caller's
// on-stack "response received" flag, which HandleValidatedMessage() sets when
// the reply arrives.
InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo(
    uint32_t message_name,
    bool* received_flag)
    : request_message_name(message_name), response_received(received_flag) {}

InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() = default;
// Terminal receiver of the client's message dispatcher. It routes every
// message the dispatcher accepts back into the owning client.
InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk(
    InterfaceEndpointClient* endpoint_client)
    : owner_(endpoint_client) {}

InterfaceEndpointClient::HandleIncomingMessageThunk::
    ~HandleIncomingMessageThunk() = default;

bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept(
    Message* message) {
  InterfaceEndpointClient* const client = owner_;
  return client->HandleValidatedMessage(message);
}
// Sends |message| from any sequence and routes its reply to |responder|:
//  - Async messages are posted to |task_runner_|. The reply (and the
//    responder's destruction) is posted back to this sequence via
//    ForwardToCallingThread.
//  - Sync messages issued on |task_runner_| itself are sent directly with
//    kAllowSyncWait.
//  - Sync messages from any other sequence are posted in kForceAsync mode
//    while this thread blocks on a SyncResponseInfo event. The wait ends when
//    the reply arrives, the signaler is dropped, or the proxy is destroyed.
//    |responder| is invoked only if a reply was actually received.
void ThreadSafeInterfaceEndpointClientProxy::SendMessageWithResponder(
Message& message,
std::unique_ptr<MessageReceiver> responder) {
message.SerializeHandles(associated_group_.GetController());
if (!message.has_flag(Message::kFlagIsSync)) {
auto reply_forwarder = std::make_unique<ForwardToCallingThread>(
std::move(responder), location_);
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&ThreadSafeInterfaceEndpointClientProxy ::
ForwardMessageWithResponder,
this, std::move(message),
InterfaceEndpointClient::SyncSendMode::kForceAsync,
std::move(reply_forwarder)));
return;
}
// Already on the endpoint's sequence: send synchronously in place.
if (task_runner_->RunsTasksInCurrentSequence()) {
ForwardMessageWithResponder(
std::move(message),
InterfaceEndpointClient::SyncSendMode::kAllowSyncWait,
std::make_unique<ForwardSameThreadResponder>(this,
std::move(responder)));
return;
}
// Computed before |message| is moved into the posted task below.
const bool allow_interrupt =
SyncCallRestrictions::AreSyncCallInterruptsEnabled() &&
!message.has_flag(Message::kFlagNoInterrupt);
auto response = base::MakeRefCounted<SyncResponseInfo>();
auto response_signaler = std::make_unique<SyncResponseSignaler>(response);
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(
&ThreadSafeInterfaceEndpointClientProxy::ForwardMessageWithResponder,
this, std::move(message),
InterfaceEndpointClient::SyncSendMode::kForceAsync,
std::move(response_signaler)));
// Keep a local reference to the pending-call list so unregistering below is
// still safe if the proxy is destroyed during the wait. The proxy's
// destructor signals every registered event.
auto sync_calls = sync_calls_;
{
base::AutoLock l(sync_calls->lock);
sync_calls->pending_responses.push_back(response.get());
}
SyncCallRestrictions::AssertSyncCallAllowed();
if (allow_interrupt) {
// Cooperative wait: SyncEventWatcher::SyncWatch returns once |signaled| is
// set by the event callback.
bool signaled = false;
auto set_flag = [](bool* flag) { *flag = true; };
SyncEventWatcher watcher(&response->event,
base::BindRepeating(set_flag, &signaled));
const bool* stop_flags[] = {&signaled};
watcher.SyncWatch(stop_flags);
} else {
// Exclusive wait: block the thread until the event is signaled.
response->event.Wait();
}
{
base::AutoLock l(sync_calls->lock);
std::erase(sync_calls->pending_responses, response.get());
}
// False when woken without a reply (signaler dropped or proxy destroyed).
if (response->received)
std::ignore = responder->Accept(&response->message);
}
// Builds a client for |handle|. Incoming messages pass through |dispatcher_|
// (with |payload_validator| if one is given) before reaching |receiver|.
// Attaches to the group controller right away when the handle is already
// associated; otherwise it waits for the association event.
InterfaceEndpointClient::InterfaceEndpointClient(
    ScopedInterfaceEndpointHandle handle,
    MessageReceiverWithResponderStatus* receiver,
    std::unique_ptr<MessageReceiver> payload_validator,
    base::span<const uint32_t> sync_method_ordinals,
    scoped_refptr<base::SequencedTaskRunner> task_runner,
    uint32_t interface_version,
    const char* interface_name,
    MessageToMethodInfoCallback method_info_callback,
    MessageToMethodNameCallback method_name_callback)
    : sync_method_ordinals_(sync_method_ordinals),
      handle_(std::move(handle)),
      incoming_receiver_(receiver),
      dispatcher_(&thunk_),
      task_runner_(std::move(task_runner)),
      control_message_handler_(this, interface_version),
      interface_name_(interface_name),
      method_info_callback_(method_info_callback),
      method_name_callback_(method_name_callback) {
  DCHECK(interface_name_);
  DCHECK(handle_.is_valid());

  // Unbind the checker so it binds to whichever sequence first uses the client.
  sequence_checker_.DetachFromSequence();

  if (payload_validator)
    dispatcher_.SetValidator(std::move(payload_validator));

  if (!handle_.pending_association()) {
    InitControllerIfNecessary();
    return;
  }

  // Association is still pending. On |task_runner_|'s sequence the handler may
  // call straight back into |this|.
  if (task_runner_->RunsTasksInCurrentSequence()) {
    handle_.SetAssociationEventHandler(
        base::BindOnce(&InterfaceEndpointClient::OnAssociationEvent,
                       base::Unretained(this)));
    return;
  }

  // Elsewhere, bounce the event to |task_runner_| and guard it with a weak
  // pointer.
  handle_.SetAssociationEventHandler(base::BindPostTask(
      task_runner_,
      base::BindOnce(&InterfaceEndpointClient::OnAssociationEvent,
                     weak_ptr_factory_.GetWeakPtr())));
}
// Must run on the client's sequence. Detaches from the group controller if
// this client was ever attached to one.
InterfaceEndpointClient::~InterfaceEndpointClient() {
  CHECK(sequence_checker_.CalledOnValidSequence());
  if (!controller_)
    return;
  handle_.group_controller()->DetachEndpointClient(handle_);
}
// Returns the AssociatedGroup for |handle_|. It is created on first request
// and cached for later calls.
AssociatedGroup* InterfaceEndpointClient::associated_group() {
  if (associated_group_)
    return associated_group_.get();
  associated_group_ = std::make_unique<AssociatedGroup>(handle_);
  return associated_group_.get();
}
// Creates a ThreadSafeProxy through which other sequences can send messages
// on this endpoint. Messages are posted to |task_runner_| and forwarded to
// this client while it is alive. |location| is used when posting replies back
// to calling sequences.
scoped_refptr<ThreadSafeProxy> InterfaceEndpointClient::CreateThreadSafeProxy(
    scoped_refptr<ThreadSafeProxy::Target> target,
    const base::Location& location) {
  // Go through the lazy accessor: |associated_group_| is only populated by
  // associated_group(), so dereferencing the member directly is a null
  // dereference when no one has requested the group yet.
  return base::MakeRefCounted<ThreadSafeInterfaceEndpointClientProxy>(
      weak_ptr_factory_.GetWeakPtr(), std::move(target), *associated_group(),
      task_runner_, location);
}
// Releases ownership of the endpoint handle. Before returning it, this
// unregisters the association callback and detaches from the group
// controller. Returns an empty handle if this client holds no valid handle.
ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() {
  CHECK(sequence_checker_.CalledOnValidSequence());
  DCHECK(!has_pending_responders());

  if (!handle_.is_valid())
    return ScopedInterfaceEndpointHandle();

  // Replace any pending association handler with an empty callback.
  handle_.SetAssociationEventHandler(
      ScopedInterfaceEndpointHandle::AssociationEventCallback());

  const bool attached = static_cast<bool>(controller_);
  if (attached) {
    controller_ = nullptr;
    handle_.group_controller()->DetachEndpointClient(handle_);
  }
  return std::move(handle_);
}
void InterfaceEndpointClient::SetFilter(std::unique_ptr<MessageFilter> filter) {
dispatcher_.SetFilter(std::move(filter));
}
void InterfaceEndpointClient::RaiseError() {
CHECK(sequence_checker_.CalledOnValidSequence());
if (!handle_.pending_association())
handle_.group_controller()->RaiseError();
}
void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason,
std::string_view description) {
CHECK(sequence_checker_.CalledOnValidSequence());
auto handle = PassHandle();
handle.ResetWithReason(custom_reason, description);
}
bool InterfaceEndpointClient::PrefersSerializedMessages() {
auto* controller = handle_.group_controller();
return controller && controller->PrefersSerializedMessages();
}
// Control and interface messages share the same send paths. They differ only
// in |is_control_message|, which excludes control traffic from idle-tracking
// ack counting.
void InterfaceEndpointClient::SendControlMessage(Message* message) {
  SendMessage(message, /*is_control_message=*/true);
}

void InterfaceEndpointClient::SendControlMessageWithResponder(
    Message* message,
    std::unique_ptr<MessageReceiver> responder) {
  SendMessageWithResponder(message, /*is_control_message=*/true,
                           SyncSendMode::kAllowSyncWait, std::move(responder));
}

// Outgoing interface message without a reply.
bool InterfaceEndpointClient::Accept(Message* message) {
  return SendMessage(message, /*is_control_message=*/false);
}

// Outgoing interface message whose reply goes to |responder|. Sync messages
// may block here.
bool InterfaceEndpointClient::AcceptWithResponder(
    Message* message,
    std::unique_ptr<MessageReceiver> responder) {
  return SendMessageWithResponder(message, /*is_control_message=*/false,
                                  SyncSendMode::kAllowSyncWait,
                                  std::move(responder));
}
// Sends a message that expects no reply. Returns false if the endpoint has
// already seen an error or the controller refuses the message; in that case
// the message's serialized handles are notified of peer closure. On success,
// non-control messages are counted as unacked while idle tracking is on.
bool InterfaceEndpointClient::SendMessage(Message* message,
                                          bool is_control_message) {
  CHECK(sequence_checker_.CalledOnValidSequence());
  DCHECK(!message->has_flag(Message::kFlagExpectsResponse));
  CHECK(!handle_.pending_association())
      << "Cannot send a message when the endpoint hasn't been associated with "
         "a message pipe. This failure typically happens when attempting to "
         "make a call with an AssociatedRemote before one of the endpoints "
         "(either the AssociatedRemote itself or its entangled "
         "AssociatedReceiver) is sent over a Remote/Receiver pair or an "
         "already-established AssociatedRemote/AssociatedReceiver pair.";

  message->SerializeHandles(handle_.group_controller());

  bool delivered = false;
  if (!encountered_error_) {
    InitControllerIfNecessary();
    message->set_heap_profiler_tag(interface_name_);
    delivered = controller_->SendMessage(message);
  }
  if (!delivered) {
    message->NotifyPeerClosureForSerializedHandles(handle_.group_controller());
    return false;
  }

  if (idle_handler_ && !is_control_message)
    ++num_unacked_messages_;
  return true;
}
// Sends |message|, which expects a reply, and arranges for that reply to reach
// |responder|.
//
// |is_control_message| excludes the message from idle-tracking ack counting.
// For async messages, and for sync messages sent with
// SyncSendMode::kForceAsync, |responder| is stored in |async_responders_| and
// this returns immediately. Otherwise it blocks in a sync watch. |responder|
// sees the reply only if one was received and |this| survived the wait.
//
// Returns false, after notifying the serialized handles of peer closure, if the
// endpoint has already seen an error or the controller refuses the message.
// Returns true otherwise.
bool InterfaceEndpointClient::SendMessageWithResponder(
    Message* message,
    bool is_control_message,
    SyncSendMode sync_send_mode,
    std::unique_ptr<MessageReceiver> responder) {
#if BUILDFLAG(ARKWEB_BUGFIX_CRASH)
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
#else
  CHECK(sequence_checker_.CalledOnValidSequence());
#endif
  DCHECK(message->has_flag(Message::kFlagExpectsResponse));
  // Enforced in every build, matching SendMessage(). While association is
  // pending, InitControllerIfNecessary() attaches no controller, so continuing
  // would dereference a null |controller_| below.
  CHECK(!handle_.pending_association())
      << "Cannot send a message when the endpoint hasn't been associated with "
         "a message pipe. This failure typically happens when attempting to "
         "make a call with an AssociatedRemote before one of the endpoints "
         "(either the AssociatedRemote itself or its entangled "
         "AssociatedReceiver) is sent over a Remote/Receiver pair or an "
         "already-established AssociatedRemote/AssociatedReceiver pair.";

  message->SerializeHandles(handle_.group_controller());
  if (encountered_error_) {
    message->NotifyPeerClosureForSerializedHandles(handle_.group_controller());
    return false;
  }

  InitControllerIfNecessary();

  // Request ID 0 is never handed out; skip it if the counter wraps around.
  uint64_t request_id = next_request_id_++;
  if (request_id == 0)
    request_id = next_request_id_++;
  message->set_request_id(request_id);
  message->set_heap_profiler_tag(interface_name_);

  // Captured up front because they are needed after |message| has been handed
  // to the controller.
  const uint32_t message_name = message->name();
  const bool is_sync = message->has_flag(Message::kFlagIsSync);
  const bool exclusive_wait =
      message->has_flag(Message::kFlagNoInterrupt) ||
      !SyncCallRestrictions::AreSyncCallInterruptsEnabled();

  if (!controller_->SendMessage(message)) {
    message->NotifyPeerClosureForSerializedHandles(handle_.group_controller());
    return false;
  }

  if (!is_control_message && idle_handler_)
    ++num_unacked_messages_;

  if (!is_sync || sync_send_mode == SyncSendMode::kForceAsync) {
    if (is_sync) {
      // A null placeholder tells HandleValidatedMessage() to route this sync
      // reply to the async responder registered below.
      sync_responses_.emplace(request_id, nullptr);
      controller_->RegisterExternalSyncWaiter(request_id);
    }
    base::AutoLock lock(async_responders_lock_);
    async_responders_.emplace(
        request_id, PendingAsyncResponse{message_name, std::move(responder)});
    return true;
  }

  SyncCallRestrictions::AssertSyncCallAllowed();

  // HandleValidatedMessage() stores the reply in the SyncResponseInfo and
  // flips |response_received|.
  bool response_received = false;
  sync_responses_.insert(std::make_pair(
      request_id,
      std::make_unique<SyncResponseInfo>(message_name, &response_received)));

  // |this| may be destroyed during the sync wait; only touch members afterward
  // if |weak_self| is still valid.
  base::WeakPtr<InterfaceEndpointClient> weak_self =
      weak_ptr_factory_.GetWeakPtr();
  if (exclusive_wait)
    controller_->SyncWatchExclusive(request_id);
  else
    controller_->SyncWatch(response_received);

  if (weak_self) {
    DCHECK(base::Contains(sync_responses_, request_id));
    auto iter = sync_responses_.find(request_id);
    DCHECK_EQ(&response_received, iter->second->response_received);
    if (response_received) {
      std::ignore = responder->Accept(&iter->second->response);
    } else {
      DVLOG(1) << "Mojo sync call returns without receiving a response. "
               << "Typcially it is because the interface has been "
               << "disconnected.";
    }
    sync_responses_.erase(iter);
  }
  return true;
}
// Feeds an incoming message through the dispatcher (validator, filter, then
// HandleValidatedMessage()). Logs and returns false if it is rejected.
bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) {
  CHECK(sequence_checker_.CalledOnValidSequence());
  // Read before dispatching. NOTE(review): presumably because dispatch may
  // consume |message| or destroy |this| — confirm.
  const char* const interface_name = interface_name_;
  const uint32_t message_name = message->name();

  if (dispatcher_.Accept(message))
    return true;

  LOG(ERROR) << "Message " << message_name << " rejected by interface "
             << interface_name;
  return false;
}
// Handles disconnection of the endpoint. Only the first call has any effect.
// It drops all pending async responders, tells the control proxy, and runs
// whichever error handler is installed (the plain one takes precedence).
// |reason|, when present, supplies the reason code and description; otherwise
// 0 and an empty string are reported.
void InterfaceEndpointClient::NotifyError(
    const std::optional<DisconnectReason>& reason) {
  TRACE_EVENT("toplevel", "Closed mojo endpoint",
              [&](perfetto::EventContext& ctx) {
                auto* info = ctx.event()->set_chrome_mojo_event_info();
                info->set_mojo_interface_tag(interface_name_);
              });
  CHECK(sequence_checker_.CalledOnValidSequence());

  if (encountered_error_)
    return;
  encountered_error_ = true;

  // Copies the interface name onto the stack so it shows up in crash dumps.
  DEBUG_ALIAS_FOR_CSTR(interface_name, interface_name_, 256);

  // Move the pending responders into a local. They are destroyed when this
  // function returns, outside the lock and after the error handler has run.
  AsyncResponderMap responders;
  {
    base::AutoLock lock(async_responders_lock_);
    std::swap(responders, async_responders_);
  }

  control_message_proxy_.OnConnectionError();

  if (error_handler_) {
    std::move(error_handler_).Run();
    return;
  }
  if (!error_with_reason_handler_)
    return;
  if (!reason) {
    std::move(error_with_reason_handler_).Run(0, std::string());
    return;
  }
  std::move(error_with_reason_handler_)
      .Run(reason->custom_reason, reason->description);
}
// Thin forwarders to the control message proxy.

void InterfaceEndpointClient::QueryVersion(
    base::OnceCallback<void(uint32_t)> callback) {
  auto& proxy = control_message_proxy_;
  proxy.QueryVersion(std::move(callback));
}

void InterfaceEndpointClient::RequireVersion(uint32_t version) {
  auto& proxy = control_message_proxy_;
  proxy.RequireVersion(version);
}

void InterfaceEndpointClient::FlushForTesting() {
  auto& proxy = control_message_proxy_;
  proxy.FlushForTesting();
}

void InterfaceEndpointClient::FlushAsyncForTesting(base::OnceClosure callback) {
  auto& proxy = control_message_proxy_;
  proxy.FlushAsyncForTesting(std::move(callback));
}
// Enables idle tracking with |timeout| through the control message proxy and
// stores |handler|, which AcceptNotifyIdle() runs once no sent messages remain
// unacked.
void InterfaceEndpointClient::SetIdleHandler(base::TimeDelta timeout,
                                             base::RepeatingClosure handler) {
  control_message_proxy_.EnableIdleTracking(timeout);
  idle_handler_ = std::move(handler);
}

// Stores a callback that AcceptEnableIdleTracking() runs, once, with a weak
// reference to the newly created idle-tracking connection group.
void InterfaceEndpointClient::SetIdleTrackingEnabledCallback(
    IdleTrackingEnabledCallback callback) {
  idle_tracking_enabled_callback_ = std::move(callback);
}
// Control-message handler for "enable idle tracking". If a tracking-enabled
// callback is registered, it creates a connection group whose ref-count
// changes re-evaluate the idle timer and hands a weak copy of the group to
// that callback. It always records |timeout| and accepts the message.
bool InterfaceEndpointClient::AcceptEnableIdleTracking(
    base::TimeDelta timeout) {
  if (idle_tracking_enabled_callback_) {
    auto on_group_refs_changed =
        base::BindRepeating(&InterfaceEndpointClient::MaybeStartIdleTimer,
                            weak_ptr_factory_.GetWeakPtr());
    idle_tracking_connection_group_ =
        ConnectionGroup::Create(std::move(on_group_refs_changed), task_runner_);
    std::move(idle_tracking_enabled_callback_)
        .Run(idle_tracking_connection_group_.WeakCopy());
  }
  idle_timeout_ = timeout;
  return true;
}
// Handles an ack from the peer. With an idle handler installed and unacked
// messages outstanding, it decrements the count and returns true; otherwise
// the ack is unexpected and it returns false.
bool InterfaceEndpointClient::AcceptMessageAck() {
  if (idle_handler_ && num_unacked_messages_ != 0) {
    --num_unacked_messages_;
    return true;
  }
  return false;
}

// Handles the peer's idle notification. Returns false without an idle handler.
// Otherwise it runs the handler, but only once every sent message has been
// acked, and returns true.
bool InterfaceEndpointClient::AcceptNotifyIdle() {
  if (!idle_handler_)
    return false;
  const bool awaiting_acks = num_unacked_messages_ > 0;
  if (!awaiting_acks)
    idle_handler_.Run();
  return true;
}
// (Re)starts the idle timer when the idle-tracking connection group exists and
// has no outstanding references; otherwise it cancels any running timer.
void InterfaceEndpointClient::MaybeStartIdleTimer() {
  const bool group_idle = idle_tracking_connection_group_ &&
                          idle_tracking_connection_group_.HasZeroRefs();
  if (!group_idle) {
    notify_idle_timer_.reset();
    return;
  }
  DCHECK(idle_timeout_);
  notify_idle_timer_.emplace();
  // The timer is a member of |this|, so an unretained receiver is used here.
  notify_idle_timer_->Start(
      FROM_HERE, *idle_timeout_,
      base::BindOnce(&InterfaceEndpointClient::MaybeSendNotifyIdle,
                     base::Unretained(this)));
}

// Timer callback. Re-checks that the group is still idle before telling the
// peer.
void InterfaceEndpointClient::MaybeSendNotifyIdle() {
  const bool group_idle = idle_tracking_connection_group_ &&
                          idle_tracking_connection_group_.HasZeroRefs();
  if (group_idle)
    control_message_proxy_.NotifyIdle();
}
// Tears down the endpoint from whatever sequence is calling. The sequence
// checker is unbound first so it rebinds to the caller instead of enforcing
// the original sequence. As the name says, the caller is responsible for
// making this safe.
void InterfaceEndpointClient::ResetFromAnotherSequenceUnsafe() {
  sequence_checker_.DetachFromSequence();

  const bool attached = static_cast<bool>(controller_);
  if (attached) {
    controller_ = nullptr;
    handle_.group_controller()->DetachEndpointClient(handle_);
  }
  handle_.reset();
}
void InterfaceEndpointClient::ForgetAsyncRequest(uint64_t request_id) {
std::optional<PendingAsyncResponse> response;
{
base::AutoLock lock(async_responders_lock_);
auto it = async_responders_.find(request_id);
if (it == async_responders_.end())
return;
response = std::move(it->second);
async_responders_.erase(it);
}
}
// Attaches this client to the group controller once the handle is associated.
// Does nothing if already attached or still pending association. Interfaces
// with sync methods bound on the current sequence also opt in to being woken
// by same-thread sync watches.
void InterfaceEndpointClient::InitControllerIfNecessary() {
  const bool can_attach = !controller_ && !handle_.pending_association();
  if (!can_attach)
    return;

  controller_ = handle_.group_controller()->AttachEndpointClient(
      handle_, this, task_runner_);

  const bool has_sync_methods = !sync_method_ordinals_.empty();
  if (has_sync_methods && task_runner_->RunsTasksInCurrentSequence())
    controller_->AllowWokenUpBySyncWatchOnSameThread();
}
// Reacts to the outcome of a pending association. When associated, it
// attaches to the controller. When the peer closed first, it posts
// NotifyError() with the handle's disconnect reason, guarded by a weak
// pointer. Other events are ignored.
void InterfaceEndpointClient::OnAssociationEvent(
    ScopedInterfaceEndpointHandle::AssociationEvent event) {
  if (event == ScopedInterfaceEndpointHandle::ASSOCIATED) {
    InitControllerIfNecessary();
    return;
  }
  if (event != ScopedInterfaceEndpointHandle::PEER_CLOSED_BEFORE_ASSOCIATION)
    return;

  task_runner_->PostTask(
      FROM_HERE, base::BindOnce(&InterfaceEndpointClient::NotifyError,
                                weak_ptr_factory_.GetWeakPtr(),
                                handle_.disconnect_reason()));
}
// Routes a message that has come through the dispatcher (via
// HandleIncomingMessageThunk). It first emits a trace event and records the
// end-to-end latency histogram. Then:
//  - requests expecting a reply go to the control message handler or the
//    incoming receiver, together with a ResponderThunk;
//  - replies are matched by request ID to a blocked sync wait or an async
//    responder, and rejected if the reply's name differs from the request's;
//  - all other messages go to the control message handler or the receiver.
// Returns false if the message was rejected or arrived after an error.
bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) {
TRACE_EVENT("toplevel,mojom",
perfetto::StaticString{method_name_callback_(*message)},
[&](perfetto::EventContext& ctx) {
auto* info = ctx.event()->set_chrome_mojo_event_info();
#if BUILDFLAG(IS_ANDROID) && defined(ARCH_CPU_ARM64)
info->set_mojo_interface_tag(interface_name_);
#else
if (!ctx.ShouldFilterDebugAnnotations()) {
info->set_mojo_interface_tag(interface_name_);
}
#endif
const auto method_info = method_info_callback_(*message);
if (method_info) {
info->set_ipc_hash((*method_info)());
const auto method_address =
reinterpret_cast<uintptr_t>(method_info);
const std::optional<size_t> location_iid =
base::trace_event::InternedUnsymbolizedSourceLocation::
Get(&ctx, method_address);
if (location_iid) {
info->set_mojo_interface_method_iid(*location_iid);
}
}
info->set_payload_size(message->payload_num_bytes());
info->set_data_num_bytes(message->data_num_bytes());
// Flow events are emitted only when the flow categories are enabled.
static const uint8_t* flow_enabled =
TRACE_EVENT_API_GET_CATEGORY_GROUP_ENABLED(
"toplevel.flow,mojom.flow");
if (!*flow_enabled)
return;
perfetto::Flow::Global(message->GetTraceId())(ctx);
});
DCHECK_EQ(handle_.id(), message->interface_id());
// Record creation-to-dispatch latency for messages that carry a creation
// timestamp. The histogram is created with the "Default" suffix on first use
// if none has been set on this thread.
int64_t creation_timeticks_us = message->creation_timeticks_us();
if (creation_timeticks_us > 0) {
if (!g_end_to_end_metric) {
SetThreadNameSuffixForMetrics("Default");
}
base::TimeTicks creation_timeticks =
base::TimeTicks() + base::Microseconds(creation_timeticks_us);
base::TimeDelta end_to_end_duration =
base::TimeTicks::Now() - creation_timeticks;
g_end_to_end_metric->AddTimeMicrosecondsGranularity(end_to_end_duration);
}
if (!message->has_flag(Message::kFlagIsSync)) {
const auto method_info = method_info_callback_(*message);
base::TaskAnnotator::OnIPCReceived(
interface_name_, method_info,
message->has_flag(Message::kFlagIsResponse));
}
if (encountered_error_) {
DVLOG(1) << "A message is received for an interface after it has been "
<< "disconnected. Closing the pipe.";
return false;
}
// The receiver may destroy |this| while handling the message; |weak_self| is
// checked before any member is touched afterward.
auto weak_self = weak_ptr_factory_.GetWeakPtr();
bool accepted_interface_message = false;
bool has_response = false;
if (message->has_flag(Message::kFlagExpectsResponse)) {
has_response = true;
auto responder = std::make_unique<ResponderThunk>(
weak_ptr_factory_.GetWeakPtr(), task_runner_);
if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) {
return control_message_handler_.AcceptWithResponder(message,
std::move(responder));
} else {
// While idle tracking is active, the responder holds a connection-group
// reference until it is destroyed.
if (idle_tracking_connection_group_) {
responder->set_connection_group(idle_tracking_connection_group_);
}
accepted_interface_message = incoming_receiver_->AcceptWithResponder(
message, std::move(responder));
}
} else if (message->has_flag(Message::kFlagIsResponse)) {
uint64_t request_id = message->request_id();
if (message->has_flag(Message::kFlagIsSync)) {
auto it = sync_responses_.find(request_id);
if (it == sync_responses_.end())
return false;
// Non-null entry: a blocked SendMessageWithResponder() call is waiting.
// Hand it the reply and set its flag.
if (it->second) {
if (message->name() != it->second->request_message_name) {
return false;
}
it->second->response = std::move(*message);
*it->second->response_received = true;
return true;
}
// Null placeholder: the sync call was sent in kForceAsync mode, so drop
// it and fall through to the async responder below.
sync_responses_.erase(it);
}
// Take the responder out under the lock; it is invoked after the lock is
// released.
std::optional<PendingAsyncResponse> pending_response;
{
base::AutoLock lock(async_responders_lock_);
auto it = async_responders_.find(request_id);
if (it == async_responders_.end())
return false;
pending_response = std::move(it->second);
async_responders_.erase(it);
}
if (message->name() != pending_response->request_message_name) {
return false;
}
internal::MessageDispatchContext dispatch_context(message);
return pending_response->responder->Accept(message);
} else {
if (mojo::internal::ControlMessageHandler::IsControlMessage(message))
return control_message_handler_.Accept(message);
accepted_interface_message = incoming_receiver_->Accept(message);
}
// Idle tracking: ack each accepted interface message. For messages without a
// reply, also re-evaluate the idle timer. NOTE(review): messages with a reply
// presumably keep the group busy through the ResponderThunk's ref — confirm.
if (weak_self && accepted_interface_message &&
idle_tracking_connection_group_) {
control_message_proxy_.SendMessageAck();
if (!has_response)
MaybeStartIdleTimer();
}
return accepted_interface_message;
}
// Points this thread's end-to-end latency histogram at
// "Mojo.EndToEndLatencyUs.<thread_name>". The histogram records 1us-1s in
// 100 buckets at microsecond granularity.
void InterfaceEndpointClient::SetThreadNameSuffixForMetrics(
    std::string thread_name) {
  const std::string histogram_name = "Mojo.EndToEndLatencyUs." + thread_name;
  g_end_to_end_metric = base::Histogram::FactoryMicrosecondsTimeGet(
      histogram_name, base::Microseconds(1), base::Seconds(1), 100,
      base::HistogramBase::kUmaTargetedHistogramFlag);
}
}