#ifndef CHROMEOS_ASH_EXPERIENCES_ARC_SESSION_CONNECTION_HOLDER_H_
#define CHROMEOS_ASH_EXPERIENCES_ARC_SESSION_CONNECTION_HOLDER_H_
#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>

#include "base/functional/bind.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/notreached.h"
#include "base/threading/thread_checker.h"
#include "chromeos/ash/experiences/arc/session/connection_notifier.h"
#include "chromeos/ash/experiences/arc/session/connection_observer.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/receiver.h"
// Fetches the instance held by |holder| (a pointer to a ConnectionHolder) for
// calling |method_name|. Expands to GetInstanceForVersionDoNotCallDirectly()
// with the Instance type's k<method_name>MinVersion constant and the
// stringified method name (used only for logging).
// NOTE: decltype(holder) is deliberately left unparenthesized; for a variable
// name, decltype((holder)) would yield a reference type that
// std::remove_pointer does not strip.
#define ARC_GET_INSTANCE_FOR_METHOD(holder, method_name)           \
  (holder)->GetInstanceForVersionDoNotCallDirectly(                \
      std::remove_pointer_t<                                       \
          decltype(holder)>::Instance::k##method_name##MinVersion, \
      #method_name)
namespace arc {
namespace internal {
// Overload-resolution based detector. The templated Check() overload is only
// viable when `&InstanceType::Init` is a well-formed expression; otherwise the
// variadic fallback is selected.
struct HasInitImpl {
  template <typename InstanceType>
  static auto Check(InstanceType*)
      -> decltype(static_cast<void>(&InstanceType::Init), std::true_type{});
  static std::false_type Check(...);
};

// std::true_type when InstanceType has an addressable member named Init,
// std::false_type otherwise.
template <typename InstanceType>
using HasInit = decltype(HasInitImpl::Check(
    static_cast<InstanceType*>(nullptr)));
template <typename InstanceType, typename HostType>
class ConnectionHolderImpl {
public:
explicit ConnectionHolderImpl(ConnectionNotifier* connection_notifier)
: connection_notifier_(connection_notifier) {}
ConnectionHolderImpl(const ConnectionHolderImpl&) = delete;
ConnectionHolderImpl& operator=(const ConnectionHolderImpl&) = delete;
InstanceType* instance() { return IsConnected() ? instance_.get() : nullptr; }
uint32_t instance_version() const {
return IsConnected() ? instance_version_ : 0;
}
bool IsConnected() const { return receiver_.get(); }
void SetHost(HostType* host) {
DCHECK(host == nullptr || host_ == nullptr || instance_ == nullptr);
if (host_ == host) {
return;
}
host_ = host;
OnChanged();
}
void SetInstance(InstanceType* instance,
uint32_t version = InstanceType::version_) {
DCHECK(instance);
DCHECK(instance_ != instance);
instance_ = instance;
instance_version_ = version;
OnChanged();
}
void CloseInstance(InstanceType* instance) {
DCHECK(instance);
if (instance != instance_) {
DVLOG(1) << "Dropping request to close a stale instance";
return;
}
instance_ = nullptr;
instance_version_ = 0;
OnChanged();
}
private:
void OnChanged() {
weak_ptr_factory_.InvalidateWeakPtrs();
if (receiver_.get()) {
if (instance_ && host_) {
LOG(ERROR) << "Unbinding instance of a stale connection";
}
OnConnectionClosed();
}
if (!instance_ || !host_) {
return;
}
auto receiver = std::make_unique<mojo::Receiver<HostType>>(host_);
mojo::PendingRemote<HostType> host_proxy;
receiver->Bind(host_proxy.InitWithNewPipeAndPassReceiver());
instance_->Init(
std::move(host_proxy),
base::BindOnce(&ConnectionHolderImpl::OnConnectionReady,
weak_ptr_factory_.GetWeakPtr(), std::move(receiver)));
}
void OnConnectionClosed() {
DCHECK(receiver_);
receiver_.reset();
connection_notifier_->NotifyConnectionClosed();
}
void OnConnectionReady(std::unique_ptr<mojo::Receiver<HostType>> receiver) {
DCHECK(!receiver_);
receiver->set_disconnect_handler(base::BindOnce(
&ConnectionHolderImpl::OnConnectionClosed, base::Unretained(this)));
receiver_ = std::move(receiver);
connection_notifier_->NotifyConnectionReady();
}
const raw_ptr<ConnectionNotifier> connection_notifier_;
raw_ptr<InstanceType, DanglingUntriaged> instance_ = nullptr;
uint32_t instance_version_ = 0;
raw_ptr<HostType, DanglingUntriaged> host_ = nullptr;
std::unique_ptr<mojo::Receiver<HostType>> receiver_;
base::WeakPtrFactory<ConnectionHolderImpl> weak_ptr_factory_{this};
};
// Single-direction specialization (no host). The connection counts as
// established as soon as an instance is set, and the notifier is told
// synchronously from SetInstance() / CloseInstance().
template <typename InstanceType>
class ConnectionHolderImpl<InstanceType, void> {
 public:
  // Instance types exposing Init() expect a host and therefore need the
  // full-duplex holder.
  static_assert(!HasInit<InstanceType>::value,
                "Full duplex ConnectionHolderImpl should be used instead");

  // |connection_notifier| is not owned; it receives ready/closed
  // notifications.
  explicit ConnectionHolderImpl(ConnectionNotifier* connection_notifier)
      : connection_notifier_(connection_notifier) {}

  ConnectionHolderImpl(const ConnectionHolderImpl&) = delete;
  ConnectionHolderImpl& operator=(const ConnectionHolderImpl&) = delete;

  // Returns the current instance (nullptr if none is set).
  InstanceType* instance() { return instance_; }

  // Returns the version passed to SetInstance(), or 0 if no instance is set.
  uint32_t instance_version() const { return instance_version_; }

  // Connected exactly while an instance is set.
  bool IsConnected() const { return instance_ != nullptr; }

  // There is no host for a single-direction connection; any instantiation of
  // this method is rejected at compile time.
  void SetHost(void* unused) {
    static_assert(!sizeof(*this),
                  "ConnectionHolder::SetHost for single direction connection "
                  "is called unexpectedly.");
    NOTREACHED();
  }

  // Sets the instance and its |version|, then reports the connection ready.
  // |instance| must be non-null and differ from the current instance.
  // |version| defaults to InstanceType::Version_, the same default
  // ConnectionHolder::SetInstance() uses (the previous `version_` spelling is
  // not a mojo interface member).
  void SetInstance(InstanceType* instance,
                   uint32_t version = InstanceType::Version_) {
    DCHECK(instance);
    DCHECK(instance_ != instance);
    instance_ = instance;
    instance_version_ = version;
    connection_notifier_->NotifyConnectionReady();
  }

  // Clears the instance if |instance| is the current one and reports the
  // connection closed. Requests for any other (stale) instance are ignored.
  void CloseInstance(InstanceType* instance) {
    DCHECK(instance);
    if (instance != instance_) {
      DVLOG(1) << "Dropping request to close a stale instance";
      return;
    }
    instance_ = nullptr;
    instance_version_ = 0;
    connection_notifier_->NotifyConnectionClosed();
  }

 private:
  const raw_ptr<ConnectionNotifier> connection_notifier_;
  raw_ptr<InstanceType, DanglingUntriaged> instance_ = nullptr;
  uint32_t instance_version_ = 0;
};
}
// Holds the connection to an instance of type InstanceType and, when HostType
// is not void, the host it talks back to. Observers registered here are
// notified when the connection becomes ready or is closed. SetHost(),
// SetInstance() and CloseInstance() must be called on the thread that owns
// |thread_checker_|.
template <typename InstanceType, typename HostType = void>
class ConnectionHolder {
 public:
  using Observer = ConnectionObserver<InstanceType>;
  using Instance = InstanceType;

  ConnectionHolder() = default;
  ConnectionHolder(const ConnectionHolder&) = delete;
  ConnectionHolder& operator=(const ConnectionHolder&) = delete;

  // Version of the held instance as reported by the implementation.
  uint32_t instance_version() const { return impl_.instance_version(); }

  // Whether the implementation reports an established connection.
  bool IsConnected() const { return impl_.IsConnected(); }

  // Use ARC_GET_INSTANCE_FOR_METHOD() instead of calling this directly.
  // Returns the instance if connected and its version is at least
  // |min_version|; otherwise logs why and returns nullptr.
  InstanceType* GetInstanceForVersionDoNotCallDirectly(
      uint32_t min_version,
      const char method_name_for_logging[]) {
    if (!impl_.IsConnected()) {
      VLOG(1) << "Instance " << InstanceType::Name_ << " not available.";
      return nullptr;
    }

    const uint32_t actual_version = impl_.instance_version();
    if (actual_version >= min_version) {
      return impl_.instance();
    }

    LOG(ERROR) << "Instance for " << InstanceType::Name_
               << "::" << method_name_for_logging
               << " version mismatch. Expected " << min_version << " got "
               << actual_version;
    return nullptr;
  }

  // Registers |observer|. If already connected, the observer is told
  // immediately via OnConnectionReady().
  void AddObserver(Observer* observer) {
    connection_notifier_.AddObserver(observer);
    if (!impl_.IsConnected()) {
      return;
    }
    observer->OnConnectionReady();
  }

  // Unregisters |observer|.
  void RemoveObserver(Observer* observer) {
    connection_notifier_.RemoveObserver(observer);
  }

  // Forwards the host to the implementation.
  void SetHost(HostType* host) {
    DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
    impl_.SetHost(host);
  }

  // Forwards a non-null |instance| and its |version| to the implementation.
  void SetInstance(InstanceType* instance,
                   uint32_t version = InstanceType::Version_) {
    DCHECK(instance);
    DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
    impl_.SetInstance(instance, version);
  }

  // Asks the implementation to close a non-null |instance|.
  void CloseInstance(InstanceType* instance) {
    DCHECK(instance);
    DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
    impl_.CloseInstance(instance);
  }

 private:
  THREAD_CHECKER(thread_checker_);

  // Declared before |impl_|, whose initializer takes its address.
  internal::ConnectionNotifier connection_notifier_;
  internal::ConnectionHolderImpl<InstanceType, HostType> impl_{
      &connection_notifier_};
};
}
#endif