#include "mojo/public/cpp/system/wait_set.h"
#include <algorithm>
#include <limits>
#include <map>
#include <set>
#include <vector>
#include "base/check_op.h"
#include "base/compiler_specific.h"
#include "base/containers/span.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/ref_counted.h"
#include "base/synchronization/lock.h"
#include "base/synchronization/waitable_event.h"
#include "mojo/public/cpp/system/trap.h"
#include "third_party/abseil-cpp/absl/container/inlined_vector.h"
namespace mojo {
// Ref-counted implementation state behind a WaitSet.
//
// Owns the Mojo trap used to watch handles, the bookkeeping that maps each
// watched handle to its trap Context, the set of handles currently known to be
// ready, and the set of user-supplied WaitableEvents.
//
// Locking as practiced by this class:
//   * |contexts_|, |handle_to_context_| and |ready_handles_| are only touched
//     while |lock_| is held.
//   * |cancelled_contexts_| is touched under |lock_| everywhere except in
//     ShutDown().
//   * |user_events_| and |waitable_index_shift_| are never accessed under
//     |lock_|.
//   NOTE(review): presumably the public entry points are expected to be called
//   from a single sequence, with only Notify() arriving concurrently — confirm
//   against the WaitSet header.
class WaitSet::State : public base::RefCountedThreadSafe<State> {
 public:
  // Creates the manual-reset, initially unsignaled |handle_event_| and the
  // trap whose notifications are dispatched to Context::OnNotification().
  State()
      : handle_event_(base::WaitableEvent::ResetPolicy::MANUAL,
                      base::WaitableEvent::InitialState::NOT_SIGNALED) {
    MojoResult rv = CreateTrap(&Context::OnNotification, &trap_handle_);
    DCHECK_EQ(MOJO_RESULT_OK, rv);
  }

  State(const State&) = delete;
  State& operator=(const State&) = delete;

  // Closes the trap and then drops every retained cancelled Context.
  // NOTE(review): closing the trap presumably delivers MOJO_RESULT_CANCELLED
  // to each remaining Context (re-entering Notify()), which is why the clear
  // happens after the reset — confirm against the trap API documentation.
  void ShutDown() {
    trap_handle_.reset();
    cancelled_contexts_.clear();
  }

  // Registers |event| as a user event for Wait().
  // Returns MOJO_RESULT_OK if it was newly inserted, or
  // MOJO_RESULT_ALREADY_EXISTS if it was already registered.
  MojoResult AddEvent(base::WaitableEvent* event) {
    auto result = user_events_.insert(event);
    if (result.second)
      return MOJO_RESULT_OK;
    return MOJO_RESULT_ALREADY_EXISTS;
  }

  // Unregisters |event|.
  // Returns MOJO_RESULT_OK on success, or MOJO_RESULT_NOT_FOUND if |event| was
  // never registered.
  MojoResult RemoveEvent(base::WaitableEvent* event) {
    auto it = user_events_.find(event);
    if (it == user_events_.end())
      return MOJO_RESULT_NOT_FOUND;
    user_events_.erase(it);
    return MOJO_RESULT_OK;
  }

  // Starts watching |handle| for |signals| becoming satisfied.
  //
  // Returns:
  //   MOJO_RESULT_ALREADY_EXISTS   if |handle| is already being watched.
  //   MOJO_RESULT_INVALID_ARGUMENT if MojoAddTrigger() rejected the arguments;
  //                                all bookkeeping is rolled back first.
  //   MOJO_RESULT_OK               otherwise.
  MojoResult AddHandle(Handle handle, MojoHandleSignals signals) {
    DCHECK(trap_handle_.is_valid());
    scoped_refptr<Context> context = new Context(this, handle);
    {
      base::AutoLock lock(lock_);
      if (handle_to_context_.count(handle))
        return MOJO_RESULT_ALREADY_EXISTS;
      DCHECK(!contexts_.count(context->context_value()));
      handle_to_context_[handle] = context;
      contexts_[context->context_value()] = context;
    }

    // Manual reference held on behalf of the trap, which only stores the raw
    // context value. It is released in Notify() when MOJO_RESULT_CANCELLED is
    // delivered, or immediately below if the trigger could not be added.
    context->AddRef();

    // |lock_| is deliberately not held across this call; Notify() acquires
    // |lock_| and would deadlock if the trap invoked it synchronously.
    MojoResult rv =
        MojoAddTrigger(trap_handle_.get().value(), handle.value(), signals,
                       MOJO_TRIGGER_CONDITION_SIGNALS_SATISFIED,
                       context->context_value(), nullptr);
    if (rv == MOJO_RESULT_INVALID_ARGUMENT) {
      base::AutoLock lock(lock_);
      handle_to_context_.erase(handle);
      contexts_.erase(context->context_value());
      // Balances the AddRef() above.
      context->Release();
      return rv;
    }
    // NOTE(review): any other failure code is only caught by this DCHECK; in
    // builds where DCHECKs are disabled the bookkeeping above is left in
    // place. Confirm MojoAddTrigger() cannot return anything else here.
    DCHECK_EQ(MOJO_RESULT_OK, rv);
    return rv;
  }

  // Stops watching |handle|.
  //
  // Returns MOJO_RESULT_NOT_FOUND if |handle| is not currently tracked;
  // otherwise returns the result of MojoRemoveTrigger(), which is expected to
  // be MOJO_RESULT_OK or MOJO_RESULT_NOT_FOUND.
  MojoResult RemoveHandle(Handle handle) {
    DCHECK(trap_handle_.is_valid());
    scoped_refptr<Context> context;
    {
      base::AutoLock lock(lock_);

      // Opportunistically drop Contexts that Notify() parked after
      // cancellation.
      cancelled_contexts_.clear();

      auto it = handle_to_context_.find(handle);
      if (it == handle_to_context_.end())
        return MOJO_RESULT_NOT_FOUND;
      context = std::move(it->second);
      handle_to_context_.erase(it);

      // Make sure Wait() never reports this handle again. Notify() only adds
      // handles that are still present in |handle_to_context_|, so it cannot
      // be re-added after the erase above.
      ready_handles_.erase(handle);
    }

    // Called without |lock_| held; Notify() acquires |lock_| and may be
    // re-entered from inside this call.
    MojoResult rv = MojoRemoveTrigger(trap_handle_.get().value(),
                                      context->context_value(), nullptr);
    DCHECK(rv == MOJO_RESULT_OK || rv == MOJO_RESULT_NOT_FOUND);
    return rv;
  }

  // Blocks until the internal handle event or one of the user events is
  // signaled, then reports ready handles.
  //
  // |ready_event|       Optional out-param. Set to the user event that woke
  //                     the wait, or to nullptr if the internal handle event
  //                     did.
  // |num_ready_handles| In: maximum number of handles to report.
  //                     Out: number of entries actually written.
  // |ready_handles|     Receives the ready handles. Must be non-empty.
  // |ready_results|     Receives the result for each ready handle. Must be
  //                     non-empty.
  // |signals_states|    Receives the signals state for each ready handle.
  //                     May be empty, in which case nothing is written.
  void Wait(base::WaitableEvent** ready_event,
            size_t* num_ready_handles,
            base::span<Handle> ready_handles,
            base::span<MojoResult> ready_results,
            base::span<HandleSignalsState> signals_states) {
    DCHECK(trap_handle_.is_valid());
    DCHECK(num_ready_handles);
    DCHECK(!ready_handles.empty());
    DCHECK(!ready_results.empty());
    {
      base::AutoLock lock(lock_);
      if (ready_handles_.empty()) {
        // Nothing is known to be ready: reset the event and try to arm the
        // trap so that future notifications signal it.
        handle_event_.Reset();
        DCHECK_LE(*num_ready_handles, std::numeric_limits<uint32_t>::max());
        uint32_t num_blocking_events =
            static_cast<uint32_t>(*num_ready_handles);
        absl::InlinedVector<MojoTrapEvent, 4> blocking_events;
        blocking_events.resize(num_blocking_events);
        for (size_t i = 0; i < num_blocking_events; ++i) {
          blocking_events[i].struct_size = sizeof(blocking_events[i]);
        }
        MojoResult rv =
            MojoArmTrap(trap_handle_.get().value(), nullptr,
                        &num_blocking_events, blocking_events.data());
        if (rv == MOJO_RESULT_FAILED_PRECONDITION) {
          // Arming failed. The first |num_blocking_events| entries of
          // |blocking_events| are recorded as ready, and the event is
          // signaled so that the WaitMany() below still runs and returns.
          handle_event_.Signal();
          for (size_t i = 0; i < num_blocking_events; ++i) {
            const auto& event = blocking_events[i];
            auto it = contexts_.find(event.trigger_context);
            CHECK(it != contexts_.end());
            ready_handles_[it->second->handle()] = {event.result,
                                                    event.signals_state};
          }
        } else if (rv == MOJO_RESULT_NOT_FOUND) {
          // With no user events either, signal the handle event so that
          // WaitMany() below does not block on an event nothing will signal.
          if (user_events_.empty())
            handle_event_.Signal();
        } else {
          DCHECK_EQ(MOJO_RESULT_OK, rv);
        }
      }
    }

    // Build the array passed to WaitMany(). |handle_event_| occupies a slot
    // that advances by one on every call; the user events fill the following
    // slots cyclically.
    // NOTE(review): presumably this rotation exists to avoid starving any one
    // event given WaitMany()'s ordering among signaled events — confirm.
    absl::InlinedVector<base::WaitableEvent*, 4> events;
    events.resize(user_events_.size() + 1);
    if (waitable_index_shift_ > user_events_.size())
      waitable_index_shift_ = 0;
    size_t dest_index = waitable_index_shift_++;
    events[dest_index] = &handle_event_;
    for (base::WaitableEvent* e : user_events_) {
      dest_index = (dest_index + 1) % events.size();
      events[dest_index] = e;
    }
    size_t index = base::WaitableEvent::WaitMany(events);

    base::AutoLock lock(lock_);

    // Pop up to the requested number of ready handles, regardless of which
    // event woke us. Entries are taken from the front of the ordered map.
    *num_ready_handles = std::min(*num_ready_handles, ready_handles_.size());
    for (size_t i = 0; i < *num_ready_handles; ++i) {
      auto it = ready_handles_.begin();
      ready_handles[i] = it->first;
      ready_results[i] = it->second.result;
      if (!signals_states.empty()) {
        signals_states[i] = it->second.signals_state;
      }
      ready_handles_.erase(it);
    }

    if (ready_event) {
      // The internal handle event is never exposed to the caller.
      if (events[index] == &handle_event_) {
        *ready_event = nullptr;
      } else {
        *ready_event = events[index];
      }
    }
  }

 private:
  friend class base::RefCountedThreadSafe<State>;

  // Per-handle trap context. Its own address doubles as the opaque context
  // value registered with the trap, and it holds a reference to the owning
  // State.
  class Context : public base::RefCountedThreadSafe<Context> {
   public:
    // |state| is a sink parameter: it is moved into |state_| so that
    // construction costs a single reference-count increment, not two plus a
    // decrement.
    Context(scoped_refptr<State> state, Handle handle)
        : state_(std::move(state)), handle_(handle) {}

    Context(const Context&) = delete;
    Context& operator=(const Context&) = delete;

    // The handle this context was created for.
    Handle handle() const { return handle_; }

    // The value registered with the trap: this object's address.
    uintptr_t context_value() const {
      return reinterpret_cast<uintptr_t>(this);
    }

    // Trap event handler. Recovers the Context from |event->trigger_context|
    // and forwards the result and signals state to it.
    static void OnNotification(const MojoTrapEvent* event) {
      reinterpret_cast<Context*>(event->trigger_context)
          ->Notify(event->result, event->signals_state);
    }

   private:
    friend class base::RefCountedThreadSafe<Context>;

    ~Context() = default;

    // Forwards the notification to the owning State.
    void Notify(MojoResult result, MojoHandleSignalsState signals_state) {
      state_->Notify(handle_, result, signals_state, this);
    }

    const scoped_refptr<State> state_;
    const Handle handle_;
  };

  ~State() = default;

  // Handles a trap notification for |handle| delivered through |context|.
  void Notify(Handle handle,
              MojoResult result,
              MojoHandleSignalsState signals_state,
              Context* context) {
    base::AutoLock lock(lock_);

    // RemoveHandle() erases the handle from |handle_to_context_| before it
    // removes the trigger, so a notification for an untracked handle is
    // not recorded and does not wake Wait().
    if (handle_to_context_.count(handle)) {
      ready_handles_[handle] = {result, signals_state};
      handle_event_.Signal();
    }

    if (result == MOJO_RESULT_CANCELLED) {
      contexts_.erase(context->context_value());
      handle_to_context_.erase(handle);

      // Keep the Context alive in |cancelled_contexts_| until the next
      // RemoveHandle() or ShutDown().
      // NOTE(review): presumably this prevents the Context's address (which
      // is also its trap context value) from being reused by a new Context
      // while stale references to the old value may still be in flight —
      // confirm.
      cancelled_contexts_.emplace_back(base::WrapRefCounted(context));

      // Balances the AddRef() in AddHandle().
      context->Release();
    }
  }

  // Result and signals state recorded for a handle that became ready.
  struct ReadyState {
    ReadyState() = default;
    ReadyState(MojoResult result, MojoHandleSignalsState signals_state)
        : result(result), signals_state(signals_state) {}
    ~ReadyState() = default;

    MojoResult result = MOJO_RESULT_UNKNOWN;
    MojoHandleSignalsState signals_state = {0, 0};
  };

  // Trap that watches every added handle; closed by ShutDown().
  ScopedTrapHandle trap_handle_;

  base::Lock lock_;

  // Live contexts keyed by their trap context value.
  std::map<uintptr_t, scoped_refptr<Context>> contexts_;

  // Live contexts keyed by the handle they watch.
  std::map<Handle, scoped_refptr<Context>> handle_to_context_;

  // Handles with a pending, not yet reported, notification.
  std::map<Handle, ReadyState> ready_handles_;

  // Cancelled contexts kept alive until the next RemoveHandle()/ShutDown().
  std::vector<scoped_refptr<Context>> cancelled_contexts_;

  // User-supplied events waited on alongside |handle_event_|.
  std::set<raw_ptr<base::WaitableEvent, SetExperimental>> user_events_;

  // Signaled whenever |ready_handles_| may be non-empty (or to force Wait()
  // to return when there is nothing else to wait on).
  base::WaitableEvent handle_event_;

  // Slot that |handle_event_| takes in the next Wait() call's event array.
  size_t waitable_index_shift_ = 0;
};
// Constructs a WaitSet backed by a freshly allocated, ref-counted State.
WaitSet::WaitSet() : state_(base::MakeRefCounted<State>()) {}
// Shuts the shared State down explicitly. The State is ref-counted and each
// Context holds its own reference to it, so dropping |state_| alone is not
// what closes the trap.
WaitSet::~WaitSet() {
  State& impl = *state_;
  impl.ShutDown();
}
// Registers |event| with the underlying State and returns its result
// unchanged.
MojoResult WaitSet::AddEvent(base::WaitableEvent* event) {
  const MojoResult result = state_->AddEvent(event);
  return result;
}
// Unregisters |event| from the underlying State and returns its result
// unchanged.
MojoResult WaitSet::RemoveEvent(base::WaitableEvent* event) {
  const MojoResult result = state_->RemoveEvent(event);
  return result;
}
// Asks the underlying State to start watching |handle| for |signals| and
// returns its result unchanged.
MojoResult WaitSet::AddHandle(Handle handle, MojoHandleSignals signals) {
  const MojoResult result = state_->AddHandle(handle, signals);
  return result;
}
// Asks the underlying State to stop watching |handle| and returns its result
// unchanged.
MojoResult WaitSet::RemoveHandle(Handle handle) {
  const MojoResult result = state_->RemoveHandle(handle);
  return result;
}
// Forwards the wait to the underlying State; every argument is passed through
// untouched and all out-parameters are filled in by State::Wait().
void WaitSet::Wait(base::WaitableEvent** ready_event,
                   size_t* num_ready_handles,
                   base::span<Handle> ready_handles,
                   base::span<MojoResult> ready_results,
                   base::span<HandleSignalsState> signals_states) {
  State* const impl = state_.get();
  impl->Wait(ready_event, num_ready_handles, ready_handles, ready_results,
             signals_states);
}
}