#ifndef MOJO_PUBLIC_CPP_BINDINGS_RECEIVER_SET_H_
#define MOJO_PUBLIC_CPP_BINDINGS_RECEIVER_SET_H_
#include <stddef.h>
#include <stdint.h>

#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "base/check.h"
#include "base/compiler_specific.h"
#include "base/component_export.h"
#include "base/containers/contains.h"
#include "base/containers/variant_map.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/raw_ptr_exclusion.h"
#include "base/memory/weak_ptr.h"
#include "base/task/sequenced_task_runner.h"
#include "base/types/pass_key.h"
#include "mojo/public/cpp/bindings/connection_error_callback.h"
#include "mojo/public/cpp/bindings/message.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/runtime_features.h"
#include "mojo/public/cpp/bindings/unique_ptr_impl_ref_traits.h"
namespace mojo {
// Forward declaration of a test class; ReceiverSetBase names it as a friend so
// that it can reach the set's private members.
namespace test {
class ReceiverSetStaticAssertTests;
}
// Identifier handed back when a receiver is added to a receiver set and used
// afterwards to look that receiver up or remove it.
using ReceiverId = uint64_t;
// Maps a concrete receiver type onto the associated types that
// ReceiverSetBase needs. Only specializations are defined; the primary
// template is intentionally left incomplete.
template <typename ReceiverType>
struct ReceiverSetTraits;

// Traits for Receiver<Interface, ImplRefTraits>.
template <typename Interface, typename ImplRefTraits>
struct ReceiverSetTraits<Receiver<Interface, ImplRefTraits>> {
  // The mojom interface served by the receiver.
  using InterfaceType = Interface;

  // The unbound endpoint type accepted when adding to, and returned when
  // taking from, a set.
  using PendingType = PendingReceiver<Interface>;

  // How the set refers to the implementation, as dictated by |ImplRefTraits|.
  using ImplPointerType = typename ImplRefTraits::PointerType;
};
// Describes how a receiver set stores its per-receiver context for a given
// |ContextType|. In the general case the context is stored as-is and the
// context-related APIs are available.
template <typename ContextType>
struct ReceiverSetContextTraits {
  static constexpr bool SupportsContext() { return true; }

  using Type = ContextType;
};

// A void context means "no context": an empty placeholder type is stored
// instead and SupportsContext() reports false, which the set uses in
// static_asserts to reject context-related calls.
template <>
struct ReceiverSetContextTraits<void> {
  struct Empty {};

  static constexpr bool SupportsContext() { return false; }

  using Type = Empty;
};
// Type-agnostic bookkeeping shared by every ReceiverSetBase instantiation: the
// map of live entries keyed by ReceiverId, the user-installed disconnect
// handlers, and the context / receiver id most recently recorded through
// SetDispatchContext(). Member functions that are only declared here are
// defined out of line.
class COMPONENT_EXPORT(MOJO_CPP_BINDINGS) ReceiverSetState {
 public:
  using PassKey = base::PassKey<ReceiverSetState>;

  // Type-erased view of a single receiver together with its context. The only
  // implementation visible in this header is ReceiverSetBase::ReceiverEntry.
  class ReceiverState {
   public:
    virtual ~ReceiverState() = default;

    // Return an untyped pointer to the context stored alongside the receiver.
    // Callers cast it back to the set's concrete context type.
    virtual const void* GetContext() const = 0;
    virtual void* GetContext() = 0;

    // Hands |filter| and |disconnect_handler| over to the underlying receiver.
    virtual void InstallDispatchHooks(
        std::unique_ptr<MessageFilter> filter,
        RepeatingConnectionErrorWithReasonCallback disconnect_handler) = 0;

    // Forwards to the underlying receiver's FlushForTesting().
    virtual void FlushForTesting() = 0;

    // Resets the underlying receiver, passing along |custom_reason_code| and
    // |description|.
    virtual void ResetWithReason(uint32_t custom_reason_code,
                                 const std::string& description) = 0;
  };

  // One element of the set: a type-erased receiver, the id it is registered
  // under, and a reference back to the owning ReceiverSetState.
  class COMPONENT_EXPORT(MOJO_CPP_BINDINGS) Entry {
   public:
    // NOTE(review): defined out of line; presumably wires |filter| and a
    // disconnect handler into |receiver| via
    // ReceiverState::InstallDispatchHooks() -- confirm in the .cc file.
    Entry(ReceiverSetState& state,
          ReceiverId id,
          std::unique_ptr<ReceiverState> receiver,
          std::unique_ptr<MessageFilter> filter);
    ~Entry();

    ReceiverState& receiver() { return *receiver_; }

   private:
    class DispatchFilter;

    // Defined out of line. NOTE(review): presumably invoked by DispatchFilter
    // around each message dispatch and when the receiver disconnects --
    // confirm against the implementation.
    void WillDispatch();
    void DidDispatchOrReject();
    void OnDisconnect(uint32_t custom_reason_code,
                      const std::string& description);

    // Plain reference, explicitly opted out of raw_ptr via RAW_PTR_EXCLUSION.
    RAW_PTR_EXCLUSION ReceiverSetState& state_;
    const ReceiverId id_;
    const std::unique_ptr<ReceiverState> receiver_;
  };

  using EntryMap = base::VariantMap<ReceiverId, std::unique_ptr<Entry>>;

  ReceiverSetState();
  ReceiverSetState(const ReceiverSetState&) = delete;
  ReceiverSetState& operator=(const ReceiverSetState&) = delete;
  ~ReceiverSetState();

  // Direct access to the entry map. ReceiverSetBase uses this to look up,
  // iterate over, swap out and clear entries.
  EntryMap& entries() { return entries_; }
  const EntryMap& entries() const { return entries_; }

  // Untyped pointer to the currently recorded dispatch context. DCHECKs that
  // one has been recorded, i.e. that |current_context_| is non-null.
  const void* current_context() const {
    DCHECK(current_context_);
    return current_context_;
  }

  void* current_context() {
    DCHECK(current_context_);
    return current_context_;
  }

  // Id recorded together with the current dispatch context. Like
  // current_context(), DCHECKs that a dispatch context is set.
  ReceiverId current_receiver() const {
    DCHECK(current_context_);
    return current_receiver_;
  }

  // Setters for the two user-supplied disconnect handlers stored below.
  // NOTE(review): presumably run from OnDisconnect() -- confirm in the .cc.
  void set_disconnect_handler(base::RepeatingClosure handler);
  void set_disconnect_with_reason_handler(
      RepeatingConnectionErrorWithReasonCallback handler);

  ReportBadMessageCallback GetBadMessageCallback();

  // Registers |receiver| (with an optional message |filter|) and returns the
  // id assigned to it.
  ReceiverId Add(std::unique_ptr<ReceiverState> receiver,
                 std::unique_ptr<MessageFilter> filter);

  // Remove the entry registered under |id|. NOTE(review): the bool result
  // presumably reports whether such an entry existed -- confirm in the .cc.
  bool Remove(ReceiverId id);
  bool RemoveWithReason(ReceiverId id,
                        uint32_t custom_reason_code,
                        const std::string& description);

  void FlushForTesting();

  // Records |context| and |receiver_id| as the values that current_context()
  // and current_receiver() report.
  void SetDispatchContext(void* context, ReceiverId receiver_id);

  void OnDisconnect(ReceiverId id,
                    uint32_t custom_reason_code,
                    const std::string& description);

 private:
  base::RepeatingClosure disconnect_handler_;
  RepeatingConnectionErrorWithReasonCallback disconnect_with_reason_handler_;
  ReceiverId next_receiver_id_ = 0;
  EntryMap entries_;
  raw_ptr<void, DanglingUntriaged> current_context_ = nullptr;
  // Explicitly initialized so that a read outside of dispatch (the DCHECK in
  // current_receiver() is compiled out of release builds) yields a
  // deterministic value instead of an indeterminate one.
  ReceiverId current_receiver_ = 0;
  // Declared last so it is destroyed first, invalidating outstanding weak
  // pointers before the other members are torn down.
  base::WeakPtrFactory<ReceiverSetState> weak_ptr_factory_{this};
};
// Owns a dynamic collection of bound receivers of type |ReceiverType|. Each
// receiver is paired with an implementation pointer and a per-receiver context
// of type |ContextType| (void means "no context"), and is identified by the
// ReceiverId returned from Add(). Type-independent bookkeeping is delegated to
// ReceiverSetState.
template <typename ReceiverType, typename ContextType>
class ReceiverSetBase {
 public:
  using PassKey = ::base::PassKey<ReceiverSetBase<ReceiverType, ContextType>>;
  using Traits = ReceiverSetTraits<ReceiverType>;
  using Interface = typename Traits::InterfaceType;
  using PendingType = typename Traits::PendingType;
  using ImplPointerType = typename Traits::ImplPointerType;
  using ContextTraits = ReceiverSetContextTraits<ContextType>;
  using Context = typename ContextTraits::Type;
  using PreDispatchCallback = base::RepeatingCallback<void(const Context&)>;

  ReceiverSetBase() = default;
  ReceiverSetBase(const ReceiverSetBase&) = delete;
  ReceiverSetBase& operator=(const ReceiverSetBase&) = delete;

  // Store |handler| in the shared state. NOTE(review): presumably run when a
  // receiver in the set disconnects -- confirm in ReceiverSetState's
  // implementation.
  void set_disconnect_handler(base::RepeatingClosure handler) {
    state_.set_disconnect_handler(std::move(handler));
  }

  void set_disconnect_with_reason_handler(
      RepeatingConnectionErrorWithReasonCallback handler) {
    state_.set_disconnect_with_reason_handler(std::move(handler));
  }

  // Add() overloads.
  //
  // Each overload binds |receiver| to |impl| (optionally on |task_runner|),
  // stores it in the set and yields the id assigned to it. They come in pairs
  // selected by internal::kIsRuntimeFeatureGuarded<Interface>:
  //   * not guarded: returns ReceiverId directly;
  //   * guarded: returns std::optional<ReceiverId>, which is std::nullopt
  //     when the interface's runtime feature is reported as disabled (see
  //     AddImpl()).
  //
  // This first pair stores a value-initialized context.
  ReceiverId Add(ImplPointerType impl,
                 PendingType receiver,
                 scoped_refptr<base::SequencedTaskRunner> task_runner = nullptr)
    requires(!internal::kIsRuntimeFeatureGuarded<Interface>)
  {
    return AddImpl(std::move(impl), std::move(receiver), {},
                   std::move(task_runner), nullptr)
        .value();
  }

  std::optional<ReceiverId> Add(
      ImplPointerType impl,
      PendingType receiver,
      scoped_refptr<base::SequencedTaskRunner> task_runner = nullptr)
    requires(internal::kIsRuntimeFeatureGuarded<Interface>)
  {
    return AddImpl(std::move(impl), std::move(receiver), {},
                   std::move(task_runner), nullptr);
  }

  // As above, but associates |context| with the new receiver. Only usable
  // when |ContextType| is not void.
  ReceiverId Add(ImplPointerType impl,
                 PendingType receiver,
                 Context context,
                 scoped_refptr<base::SequencedTaskRunner> task_runner = nullptr)
    requires(!internal::kIsRuntimeFeatureGuarded<Interface>)
  {
    static_assert(ContextTraits::SupportsContext(),
                  "Context value unsupported for void context type.");
    return AddImpl(std::move(impl), std::move(receiver), std::move(context),
                   std::move(task_runner), nullptr)
        .value();
  }

  std::optional<ReceiverId> Add(
      ImplPointerType impl,
      PendingType receiver,
      Context context,
      scoped_refptr<base::SequencedTaskRunner> task_runner = nullptr)
    requires(internal::kIsRuntimeFeatureGuarded<Interface>)
  {
    static_assert(ContextTraits::SupportsContext(),
                  "Context value unsupported for void context type.");
    return AddImpl(std::move(impl), std::move(receiver), std::move(context),
                   std::move(task_runner), nullptr);
  }

  // As above, additionally passing a message |filter| along with the new
  // receiver.
  ReceiverId Add(ImplPointerType impl,
                 PendingType receiver,
                 Context context,
                 std::unique_ptr<MessageFilter> filter,
                 scoped_refptr<base::SequencedTaskRunner> task_runner = nullptr)
    requires(!internal::kIsRuntimeFeatureGuarded<Interface>)
  {
    static_assert(ContextTraits::SupportsContext(),
                  "Context value unsupported for void context type.");
    return AddImpl(std::move(impl), std::move(receiver), std::move(context),
                   std::move(task_runner), std::move(filter))
        .value();
  }

  std::optional<ReceiverId> Add(
      ImplPointerType impl,
      PendingType receiver,
      Context context,
      std::unique_ptr<MessageFilter> filter,
      scoped_refptr<base::SequencedTaskRunner> task_runner = nullptr)
    requires(internal::kIsRuntimeFeatureGuarded<Interface>)
  {
    static_assert(ContextTraits::SupportsContext(),
                  "Context value unsupported for void context type.");
    return AddImpl(std::move(impl), std::move(receiver), std::move(context),
                   std::move(task_runner), std::move(filter));
  }

  // Remove the receiver registered under |id|; the result is whatever the
  // shared state reports. The second form also forwards a reason code and
  // description.
  bool Remove(ReceiverId id) { return state_.Remove(id); }

  bool RemoveWithReason(ReceiverId id,
                        uint32_t custom_reason_code,
                        const std::string& description) {
    return state_.RemoveWithReason(id, custom_reason_code, description);
  }

  // Moves every entry out of the set, unbinds each receiver and returns the
  // resulting pending endpoints.
  std::vector<PendingType> TakeReceivers() {
    // Swap into a local map first so the set no longer holds the entries
    // while they are being unbound.
    ReceiverSetState::EntryMap entries(PassKey{});
    std::swap(state_.entries(), entries);
    std::vector<PendingType> pending_receivers;
    // The final size is known up front; avoid repeated reallocation.
    pending_receivers.reserve(entries.size());
    for (auto& entry : entries) {
      ReceiverEntry& receiver =
          static_cast<ReceiverEntry&>(entry.second->receiver());
      pending_receivers.push_back(receiver.Unbind());
    }
    return pending_receivers;
  }

  // Like TakeReceivers(), but each pending endpoint is returned together with
  // the context that was stored for it (moved out of the entry).
  std::vector<std::pair<PendingType, Context>> TakeReceiversWithContext() {
    static_assert(ContextTraits::SupportsContext(),
                  "TakeReceiversWithContext() requires non-void context type.");
    ReceiverSetState::EntryMap entries(PassKey{});
    std::swap(state_.entries(), entries);
    std::vector<std::pair<PendingType, Context>> pending_receivers;
    // The final size is known up front; avoid repeated reallocation.
    pending_receivers.reserve(entries.size());
    for (auto& entry : entries) {
      ReceiverEntry& receiver =
          static_cast<ReceiverEntry&>(entry.second->receiver());
      pending_receivers.emplace_back(
          receiver.Unbind(),
          std::move(*static_cast<Context*>(receiver.GetContext())));
    }
    return pending_receivers;
  }

  // Drops every entry in the set.
  void Clear() { state_.entries().clear(); }

  // Resets every receiver with |custom_reason_code| / |description| and then
  // drops all entries.
  void ClearWithReason(uint32_t custom_reason_code,
                       const std::string& description) {
    for (auto& entry : state_.entries()) {
      entry.second->receiver().ResetWithReason(custom_reason_code, description);
    }
    Clear();
  }

  // True if an entry is currently registered under |id|.
  bool HasReceiver(ReceiverId id) const {
    return base::Contains(state_.entries(), id);
  }

  // Returns a pointer to the context stored for |id|, or nullptr when no such
  // entry exists. The pointer refers to storage inside the entry.
  Context* GetContext(ReceiverId id) const {
    static_assert(ContextTraits::SupportsContext(),
                  "GetContext() requires non-void context type.");
    auto it = state_.entries().find(id);
    if (it == state_.entries().end()) {
      return nullptr;
    }
    return static_cast<Context*>(it->second->receiver().GetContext());
  }

  // Returns a snapshot mapping every registered id to a pointer to its stored
  // context. The pointers refer to storage inside the entries.
  std::map<ReceiverId, Context*> GetAllContexts() const {
    static_assert(ContextTraits::SupportsContext(),
                  "GetAllContexts() requires non-void context type.");
    std::map<ReceiverId, Context*> contexts;
    for (const auto& [receiver_id, entry] : state_.entries()) {
      contexts[receiver_id] =
          static_cast<Context*>(entry->receiver().GetContext());
    }
    return contexts;
  }

  bool empty() const { return state_.entries().empty(); }

  size_t size() const { return state_.entries().size(); }

  // Context recorded as the current dispatch context in the shared state,
  // which DCHECKs that one is set.
  const Context& current_context() const {
    static_assert(ContextTraits::SupportsContext(),
                  "current_context() requires non-void context type.");
    return *static_cast<const Context*>(state_.current_context());
  }

  Context& current_context() {
    static_assert(ContextTraits::SupportsContext(),
                  "current_context() requires non-void context type.");
    return *static_cast<Context*>(state_.current_context());
  }

  // Id recorded together with the current dispatch context; the shared state
  // DCHECKs that a dispatch context is set.
  ReceiverId current_receiver() const { return state_.current_receiver(); }

  // Obtains the bad-message callback from the shared state and runs it
  // immediately with |error|. Marked NOT_TAIL_CALLED; NOTE(review):
  // presumably so that this frame remains visible in stack traces.
  NOT_TAIL_CALLED void ReportBadMessage(const std::string& error) {
    GetBadMessageCallback().Run(error);
  }

  // Returns the bad-message callback produced by the shared state so that it
  // can be run later.
  ReportBadMessageCallback GetBadMessageCallback() {
    return state_.GetBadMessageCallback();
  }

  void FlushForTesting() { state_.FlushForTesting(); }

  // Replaces the implementation bound to receiver |id| with |new_impl| and
  // returns the value produced by the receiver's SwapImplForTesting().
  // Returns nullptr if |id| is not in the set.
  [[nodiscard]] ImplPointerType SwapImplForTesting(ReceiverId id,
                                                   ImplPointerType new_impl) {
    auto it = state_.entries().find(id);
    if (it == state_.entries().end()) {
      return nullptr;
    }
    ReceiverEntry& entry = static_cast<ReceiverEntry&>(it->second->receiver());
    return entry.SwapImplForTesting(std::move(new_impl));
  }

 private:
  friend test::ReceiverSetStaticAssertTests;

  // Concrete ReceiverState: a bound |ReceiverType| plus the context stored
  // for it. Every entry created by this set holds one of these, which is what
  // makes the static_casts from ReceiverState& above valid.
  class ReceiverEntry : public ReceiverSetState::ReceiverState {
   public:
    ReceiverEntry(ImplPointerType impl,
                  PendingType receiver,
                  Context context,
                  scoped_refptr<base::SequencedTaskRunner> task_runner)
        : receiver_(std::move(impl),
                    std::move(receiver),
                    std::move(task_runner)),
          context_(std::move(context)) {}

    ReceiverEntry(const ReceiverEntry&) = delete;
    ReceiverEntry& operator=(const ReceiverEntry&) = delete;

    ~ReceiverEntry() override = default;

    // ReceiverSetState::ReceiverState:
    const void* GetContext() const override { return &context_; }
    void* GetContext() override { return &context_; }

    void InstallDispatchHooks(std::unique_ptr<MessageFilter> filter,
                              RepeatingConnectionErrorWithReasonCallback
                                  disconnect_handler) override {
      receiver_.SetFilter(std::move(filter));
      receiver_.set_disconnect_with_reason_handler(
          std::move(disconnect_handler));
    }

    void FlushForTesting() override { receiver_.FlushForTesting(); }

    void ResetWithReason(uint32_t custom_reason_code,
                         const std::string& description) override {
      receiver_.ResetWithReason(custom_reason_code, description);
    }

    ImplPointerType SwapImplForTesting(ImplPointerType new_impl) {
      return receiver_.SwapImplForTesting(std::move(new_impl));
    }

    PendingType Unbind() { return receiver_.Unbind(); }

   private:
    ReceiverType receiver_;
    // NO_UNIQUE_ADDRESS lets the empty placeholder used for a void context
    // occupy no storage.
    NO_UNIQUE_ADDRESS Context context_;
  };

  // Common implementation behind every Add() overload. DCHECKs that
  // |receiver| is valid, returns std::nullopt when
  // internal::GetRuntimeFeature_ExpectEnabled<Interface>() is false, and
  // otherwise wraps the arguments in a ReceiverEntry, registers it with the
  // shared state and returns the assigned id.
  std::optional<ReceiverId> AddImpl(
      ImplPointerType impl,
      PendingType receiver,
      Context context,
      scoped_refptr<base::SequencedTaskRunner> task_runner,
      std::unique_ptr<MessageFilter> filter) {
    DCHECK(receiver.is_valid());
    if (!internal::GetRuntimeFeature_ExpectEnabled<Interface>()) {
      return std::nullopt;
    }
    return state_.Add(std::make_unique<ReceiverEntry>(
                          std::move(impl), std::move(receiver),
                          std::move(context), std::move(task_runner)),
                      std::move(filter));
  }

  ReceiverSetState state_;
};
// Convenience alias for the common case: a ReceiverSetBase whose receivers are
// Receiver<Interface>. |ContextType| defaults to void, i.e. no per-receiver
// context is stored and the context-related APIs are unavailable.
template <typename Interface, typename ContextType = void>
using ReceiverSet = ReceiverSetBase<Receiver<Interface>, ContextType>;
}
#endif