#ifndef MOJO_PUBLIC_CPP_BINDINGS_INTERFACE_ENDPOINT_CLIENT_H_
#define MOJO_PUBLIC_CPP_BINDINGS_INTERFACE_ENDPOINT_CLIENT_H_
#include <stdint.h>

#include <map>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>

#include "base/component_export.h"
#include "base/containers/span.h"
#include "base/dcheck_is_on.h"
#include "base/functional/callback.h"
#include "base/location.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/raw_ptr_exclusion.h"
#include "base/memory/raw_span.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/synchronization/lock.h"
#include "base/task/sequenced_task_runner.h"
#include "base/thread_annotations.h"
#include "base/time/time.h"
#include "base/timer/timer.h"
#include "mojo/public/cpp/bindings/connection_error_callback.h"
#include "mojo/public/cpp/bindings/connection_group.h"
#include "mojo/public/cpp/bindings/disconnect_reason.h"
#include "mojo/public/cpp/bindings/lib/control_message_handler.h"
#include "mojo/public/cpp/bindings/lib/control_message_proxy.h"
#include "mojo/public/cpp/bindings/message.h"
#include "mojo/public/cpp/bindings/message_dispatcher.h"
#include "mojo/public/cpp/bindings/message_metadata_helpers.h"
#include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h"
#include "mojo/public/cpp/bindings/thread_safe_proxy.h"
namespace mojo {
// Forward declarations. Both types are only referenced through pointers in
// this header. std::unique_ptr<AssociatedGroup> is usable with the incomplete
// type because ~InterfaceEndpointClient() is declared out of line.
class AssociatedGroup;
class InterfaceEndpointController;
// One endpoint of a Mojo interface connection.
//
// It implements MessageReceiverWithResponder (Accept / AcceptWithResponder)
// and exposes SendMessage* helpers for outgoing traffic. It holds:
// - the ScopedInterfaceEndpointHandle for the endpoint (|handle_|);
// - per-request reply state (|async_responders_|, |sync_responses_|);
// - version/control-message plumbing (|control_message_proxy_|,
//   |control_message_handler_|);
// - optional idle-tracking state.
//
// Threading: the inline accessors below CHECK (not DCHECK) that they run on
// the sequence bound to |sequence_checker_|. |async_responders_| is
// additionally protected by |async_responders_lock_|.
//
// Not copyable or copy-assignable.
class COMPONENT_EXPORT(MOJO_CPP_BINDINGS) InterfaceEndpointClient
: public MessageReceiverWithResponder {
public:
// Ownership, from the parameter types:
// - |handle| and |payload_validator| are passed as owning types.
// - |receiver| is a raw pointer, and |sync_method_ordinals| is a non-owning
//   span (kept in a raw_span member). Both presumably must outlive this
//   client — confirm in the .cc.
// - |interface_name| is a raw `const char*` (stored in a `const char* const`
//   member). NOTE(review): it presumably must have static lifetime (e.g. a
//   string literal) — confirm with callers.
InterfaceEndpointClient(ScopedInterfaceEndpointHandle handle,
MessageReceiverWithResponderStatus* receiver,
std::unique_ptr<MessageReceiver> payload_validator,
base::span<const uint32_t> sync_method_ordinals,
scoped_refptr<base::SequencedTaskRunner> task_runner,
uint32_t interface_version,
const char* interface_name,
MessageToMethodInfoCallback method_info_callback,
MessageToMethodNameCallback method_name_callback);
InterfaceEndpointClient(const InterfaceEndpointClient&) = delete;
InterfaceEndpointClient& operator=(const InterfaceEndpointClient&) = delete;
~InterfaceEndpointClient() override;
// Installs a plain error handler. Installing one of the two error-handler
// kinds resets the other, so at most one is set at a time.
// Must be called on the owning sequence.
void set_connection_error_handler(base::OnceClosure error_handler) {
CHECK(sequence_checker_.CalledOnValidSequence());
error_handler_ = std::move(error_handler);
error_with_reason_handler_.Reset();
}
// Same as above, but the handler's signature is
// ConnectionErrorWithReasonCallback. Resets any plain error handler.
void set_connection_error_with_reason_handler(
ConnectionErrorWithReasonCallback error_handler) {
CHECK(sequence_checker_.CalledOnValidSequence());
error_with_reason_handler_ = std::move(error_handler);
error_handler_.Reset();
}
// Returns |encountered_error_|, which is initialized to false. It is
// presumably set by NotifyError() / RaiseError() — see the .cc.
bool encountered_error() const {
CHECK(sequence_checker_.CalledOnValidSequence());
return encountered_error_;
}
// True if any async responder or sync response entry is still outstanding.
// Takes |async_responders_lock_| (hence the lock being `mutable`).
bool has_pending_responders() const {
CHECK(sequence_checker_.CalledOnValidSequence());
base::AutoLock lock(async_responders_lock_);
return !async_responders_.empty() || !sync_responses_.empty();
}
// The methods below are defined in the .cc, which is not visible here.
// Their descriptions follow from names and signatures only.
AssociatedGroup* associated_group();
// Creates a ThreadSafeProxy for |target|. |location| is presumably for
// diagnostics.
scoped_refptr<ThreadSafeProxy> CreateThreadSafeProxy(
scoped_refptr<ThreadSafeProxy::Target> target,
const base::Location& location);
// Takes ownership of |filter| (applied to incoming messages, presumably).
void SetFilter(std::unique_ptr<MessageFilter> filter);
// Releases |handle_| to the caller.
ScopedInterfaceEndpointHandle PassHandle();
void RaiseError();
void CloseWithReason(uint32_t custom_reason, std::string_view description);
// Send helpers for control messages. The responder overload takes
// ownership of |responder|.
void SendControlMessage(Message* message);
void SendControlMessageWithResponder(
Message* message,
std::unique_ptr<MessageReceiver> responder);
// MessageReceiverWithResponder overrides.
bool PrefersSerializedMessages() override;
bool Accept(Message* message) override;
bool AcceptWithResponder(Message* message,
std::unique_ptr<MessageReceiver> responder) override;
// Passed to SendMessageWithResponder(). Presumably:
// - kAllowSyncWait: lets a sync method block for its reply.
// - kForceAsync: dispatches even sync methods asynchronously.
// Confirm in the .cc.
enum class SyncSendMode {
kAllowSyncWait,
kForceAsync,
};
bool SendMessage(Message* message, bool is_control_message);
bool SendMessageWithResponder(Message* message,
bool is_control_message,
SyncSendMode sync_send_mode,
std::unique_ptr<MessageReceiver> responder);
// Entry point for messages arriving on this endpoint.
bool HandleIncomingMessage(Message* message);
// |reason| is empty when no DisconnectReason is available.
void NotifyError(const std::optional<DisconnectReason>& reason);
// |callback| receives a uint32_t (presumably the remote interface version).
void QueryVersion(base::OnceCallback<void(uint32_t)> callback);
void RequireVersion(uint32_t version);
void FlushForTesting();
void FlushAsyncForTesting(base::OnceClosure callback);
// Presumably stores |handler| in |idle_handler_| and |timeout| in
// |idle_timeout_|.
void SetIdleHandler(base::TimeDelta timeout, base::RepeatingClosure handler);
// Test-only read of |num_unacked_messages_|.
// NOTE(review): unlike the other inline accessors, this one does not CHECK
// the sequence.
unsigned int GetNumUnackedMessagesForTesting() const {
return num_unacked_messages_;
}
// Callback type, receiving a ConnectionGroup::Ref, stored in
// |idle_tracking_enabled_callback_|.
using IdleTrackingEnabledCallback =
base::OnceCallback<void(ConnectionGroup::Ref connection_group)>;
void SetIdleTrackingEnabledCallback(IdleTrackingEnabledCallback callback);
// Idle-tracking control-message hooks and timer helpers. These work with
// |idle_timeout_|, |notify_idle_timer_| and |num_unacked_messages_|,
// presumably — see the .cc.
bool AcceptEnableIdleTracking(base::TimeDelta timeout);
bool AcceptMessageAck();
bool AcceptNotifyIdle();
void MaybeStartIdleTimer();
void MaybeSendNotifyIdle();
// Accessors for values fixed at construction (the backing members are const).
const char* interface_name() const { return interface_name_; }
MessageToMethodInfoCallback method_info_callback() const {
return method_info_callback_;
}
MessageToMethodNameCallback method_name_callback() const {
return method_name_callback_;
}
#if DCHECK_IS_ON()
// Only compiled when DCHECKs are on. Records a location in
// |next_call_location_|; no sequence check is done here.
void SetNextCallLocation(const base::Location& location) {
next_call_location_ = location;
}
#endif
// NOTE(review): per its name, this resets state from a sequence other than
// the owning one and is unsafe; check the .cc for the exact contract before
// calling.
void ResetFromAnotherSequenceUnsafe();
// Drops bookkeeping for |request_id| (presumably its entry in
// |async_responders_|).
void ForgetAsyncRequest(uint64_t request_id);
// Returns the non-owning ordinal view held in |sync_method_ordinals_|.
base::span<const uint32_t> sync_method_ordinals() const {
return sync_method_ordinals_;
}
static void SetThreadNameSuffixForMetrics(std::string thread_name);
private:
// Bookkeeping for an outstanding async request: the request's message name
// plus the owned responder. Move-only.
struct PendingAsyncResponse {
public:
PendingAsyncResponse(uint32_t request_message_name,
std::unique_ptr<MessageReceiver> responder);
PendingAsyncResponse(PendingAsyncResponse&&);
PendingAsyncResponse(const PendingAsyncResponse&) = delete;
PendingAsyncResponse& operator=(PendingAsyncResponse&&);
PendingAsyncResponse& operator=(const PendingAsyncResponse&) = delete;
~PendingAsyncResponse();
uint32_t request_message_name;
std::unique_ptr<MessageReceiver> responder;
};
// Keyed by a uint64_t (presumably the request id).
using AsyncResponderMap = std::map<uint64_t, PendingAsyncResponse>;
// Bookkeeping for an outstanding sync request. Not copyable.
struct SyncResponseInfo {
public:
SyncResponseInfo(uint32_t request_message_name, bool* in_response_received);
SyncResponseInfo(const SyncResponseInfo&) = delete;
SyncResponseInfo& operator=(const SyncResponseInfo&) = delete;
~SyncResponseInfo();
uint32_t request_message_name;
Message response;
// Non-owning pointer to a caller-side flag. Presumably set when |response|
// arrives — confirm in the .cc.
raw_ptr<bool> response_received;
};
// Entries are heap-allocated, so their addresses stay stable across map
// changes.
using SyncResponseMap = std::map<uint64_t, std::unique_ptr<SyncResponseInfo>>;
// MessageReceiver adapter holding a non-owning back pointer to its owning
// client. Presumably Accept() forwards to owner_->HandleIncomingMessage().
class HandleIncomingMessageThunk : public MessageReceiver {
public:
explicit HandleIncomingMessageThunk(InterfaceEndpointClient* owner);
HandleIncomingMessageThunk(const HandleIncomingMessageThunk&) = delete;
HandleIncomingMessageThunk& operator=(const HandleIncomingMessageThunk&) =
delete;
~HandleIncomingMessageThunk() override;
bool Accept(Message* message) override;
private:
RAW_PTR_EXCLUSION InterfaceEndpointClient* const owner_ = nullptr;
};
void InitControllerIfNecessary();
void OnAssociationEvent(
ScopedInterfaceEndpointHandle::AssociationEvent event);
bool HandleValidatedMessage(Message* message);
// NOTE: members are constructed in declaration order and destroyed in
// reverse. |thunk_| and |control_message_proxy_| are initialized with
// |this|. |weak_ptr_factory_| is last, so it is destroyed first. Do not
// reorder without checking the constructor and destructor in the .cc.
//
// Non-owning view; the backing array must outlive this object.
const base::raw_span<const uint32_t> sync_method_ordinals_;
// Idle-tracking state.
base::RepeatingClosure idle_handler_;
IdleTrackingEnabledCallback idle_tracking_enabled_callback_;
std::optional<base::TimeDelta> idle_timeout_;
std::optional<base::OneShotTimer> notify_idle_timer_;
ConnectionGroup::Ref idle_tracking_connection_group_;
unsigned int num_unacked_messages_ = 0;
ScopedInterfaceEndpointHandle handle_;
std::unique_ptr<AssociatedGroup> associated_group_;
// Raw pointers opted out of raw_ptr rewriting via RAW_PTR_EXCLUSION.
// NOTE(review): the reason (e.g. hot-path cost) is not stated here.
RAW_PTR_EXCLUSION InterfaceEndpointController* controller_ = nullptr;
RAW_PTR_EXCLUSION MessageReceiverWithResponderStatus* const
incoming_receiver_ = nullptr;
HandleIncomingMessageThunk thunk_{this};
MessageDispatcher dispatcher_;
// `mutable` so the const has_pending_responders() can acquire it.
mutable base::Lock async_responders_lock_;
AsyncResponderMap async_responders_ GUARDED_BY(async_responders_lock_);
// Not annotated GUARDED_BY. has_pending_responders() reads it under the
// lock, but other accesses are presumably sequence-bound.
// NOTE(review): confirm.
SyncResponseMap sync_responses_;
// Starts at 1. NOTE(review): whether 0 is reserved as "no id" is not
// visible here.
uint64_t next_request_id_ = 1;
// At most one of these two handlers is set (see the setters above).
base::OnceClosure error_handler_;
ConnectionErrorWithReasonCallback error_with_reason_handler_;
bool encountered_error_ = false;
const scoped_refptr<base::SequencedTaskRunner> task_runner_;
internal::ControlMessageProxy control_message_proxy_{this};
internal::ControlMessageHandler control_message_handler_;
const char* const interface_name_;
const MessageToMethodInfoCallback method_info_callback_;
const MessageToMethodNameCallback method_name_callback_;
#if DCHECK_IS_ON()
base::Location next_call_location_;
#endif
// A SequenceCheckerImpl object (not the SEQUENCE_CHECKER macro), checked
// with CHECK in the inline accessors above.
base::SequenceCheckerImpl sequence_checker_;
base::WeakPtrFactory<InterfaceEndpointClient> weak_ptr_factory_{this};
};
}  // namespace mojo
#endif