#include "remoting/host/security_key/security_key_auth_handler.h"
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/notreached.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/single_thread_task_runner.h"
#include "base/threading/thread_checker.h"
#include "base/time/time.h"
#include "base/timer/timer.h"
#include "base/win/win_util.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver_set.h"
#include "remoting/base/logging.h"
#include "remoting/host/client_session_details.h"
#include "remoting/host/mojom/remote_security_key.mojom.h"
#include "remoting/host/security_key/security_key_ipc_constants.h"
namespace remoting {
namespace {
constexpr base::TimeDelta kInitialRequestTimeout = base::Seconds(5);
constexpr base::TimeDelta kSecurityKeyRequestTimeout = base::Seconds(60);
struct ActiveConnection {
mojo::ReceiverId receiver_id;
base::OneShotTimer disconnect_timer;
mojom::SecurityKeyForwarder::OnSecurityKeyRequestCallback
on_security_key_request_callback;
};
}
class SecurityKeyAuthHandlerWin : public SecurityKeyAuthHandler,
public mojom::SecurityKeyForwarder {
public:
explicit SecurityKeyAuthHandlerWin(
ClientSessionDetails* client_session_details);
SecurityKeyAuthHandlerWin(const SecurityKeyAuthHandlerWin&) = delete;
SecurityKeyAuthHandlerWin& operator=(const SecurityKeyAuthHandlerWin&) =
delete;
~SecurityKeyAuthHandlerWin() override;
private:
using ActiveConnections = std::map< int, ActiveConnection>;
void BindSecurityKeyForwarder(
mojo::PendingReceiver<mojom::SecurityKeyForwarder> receiver) override;
void CreateSecurityKeyConnection() override;
bool IsValidConnectionId(int security_key_connection_id) const override;
void SendClientResponse(int security_key_connection_id,
const std::string& response) override;
void SendErrorAndCloseConnection(int security_key_connection_id) override;
void SetSendMessageCallback(const SendMessageCallback& callback) override;
size_t GetActiveConnectionCountForTest() const override;
void SetRequestTimeoutForTest(base::TimeDelta timeout) override;
void OnSecurityKeyRequest(const std::string& request_data,
OnSecurityKeyRequestCallback callback) override;
void OnIpcPeerDisconnected();
void CloseSecurityKeyRequestConnection(int connection_id);
base::OnceClosure GetCloseConnectionClosure(int connection_id);
int last_connection_id_ = 0;
SendMessageCallback send_message_callback_;
raw_ptr<ClientSessionDetails> client_session_details_ = nullptr;
ActiveConnections active_connections_;
mojo::ReceiverSet<mojom::SecurityKeyForwarder, int>
receiver_set_;
base::ThreadChecker thread_checker_;
base::WeakPtrFactory<SecurityKeyAuthHandlerWin> weak_factory_{this};
};
std::unique_ptr<SecurityKeyAuthHandler> SecurityKeyAuthHandler::Create(
ClientSessionDetails* client_session_details,
const SendMessageCallback& send_message_callback,
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner) {
std::unique_ptr<SecurityKeyAuthHandler> auth_handler(
new SecurityKeyAuthHandlerWin(client_session_details));
auth_handler->SetSendMessageCallback(send_message_callback);
return auth_handler;
}
SecurityKeyAuthHandlerWin::SecurityKeyAuthHandlerWin(
ClientSessionDetails* client_session_details)
: client_session_details_(client_session_details) {
DCHECK(client_session_details_);
receiver_set_.set_disconnect_handler(
base::BindRepeating(&SecurityKeyAuthHandlerWin::OnIpcPeerDisconnected,
weak_factory_.GetWeakPtr()));
}
SecurityKeyAuthHandlerWin::~SecurityKeyAuthHandlerWin() {
DCHECK(thread_checker_.CalledOnValidThread());
}
void SecurityKeyAuthHandlerWin::BindSecurityKeyForwarder(
mojo::PendingReceiver<mojom::SecurityKeyForwarder> receiver) {
DCHECK(thread_checker_.CalledOnValidThread());
int new_connection_id = ++last_connection_id_;
ActiveConnection& connection = active_connections_[new_connection_id];
connection.receiver_id =
receiver_set_.Add(this, std::move(receiver), new_connection_id);
connection.disconnect_timer.Start(
FROM_HERE, kInitialRequestTimeout,
GetCloseConnectionClosure(new_connection_id));
}
void SecurityKeyAuthHandlerWin::CreateSecurityKeyConnection() {
}
bool SecurityKeyAuthHandlerWin::IsValidConnectionId(int connection_id) const {
DCHECK(thread_checker_.CalledOnValidThread());
return active_connections_.find(connection_id) != active_connections_.end();
}
void SecurityKeyAuthHandlerWin::SendClientResponse(
int connection_id,
const std::string& response_data) {
DCHECK(thread_checker_.CalledOnValidThread());
auto iter = active_connections_.find(connection_id);
if (iter == active_connections_.end()) {
HOST_LOG << "Invalid security key connection ID received: "
<< connection_id;
return;
}
ActiveConnection& connection = iter->second;
std::move(connection.on_security_key_request_callback).Run(response_data);
connection.disconnect_timer.Start(FROM_HERE, kSecurityKeyRequestTimeout,
GetCloseConnectionClosure(connection_id));
}
void SecurityKeyAuthHandlerWin::SendErrorAndCloseConnection(int connection_id) {
DCHECK(thread_checker_.CalledOnValidThread());
SendClientResponse(connection_id, kSecurityKeyConnectionError);
CloseSecurityKeyRequestConnection(connection_id);
}
void SecurityKeyAuthHandlerWin::SetSendMessageCallback(
const SendMessageCallback& callback) {
DCHECK(thread_checker_.CalledOnValidThread());
send_message_callback_ = callback;
}
size_t SecurityKeyAuthHandlerWin::GetActiveConnectionCountForTest() const {
return active_connections_.size();
}
void SecurityKeyAuthHandlerWin::SetRequestTimeoutForTest(
base::TimeDelta timeout) {
NOTREACHED();
}
void SecurityKeyAuthHandlerWin::OnSecurityKeyRequest(
const std::string& request_data,
OnSecurityKeyRequestCallback callback) {
DCHECK(thread_checker_.CalledOnValidThread());
DCHECK(send_message_callback_);
int connection_id = receiver_set_.current_context();
auto iter = active_connections_.find(connection_id);
DCHECK(iter != active_connections_.end());
ActiveConnection& connection = iter->second;
if (connection.on_security_key_request_callback) {
LOG(ERROR) << "Received security key request while waiting for a response";
CloseSecurityKeyRequestConnection(connection_id);
return;
}
connection.disconnect_timer.Start(FROM_HERE, kSecurityKeyRequestTimeout,
GetCloseConnectionClosure(connection_id));
connection.on_security_key_request_callback = std::move(callback);
send_message_callback_.Run(connection_id, request_data);
}
void SecurityKeyAuthHandlerWin::OnIpcPeerDisconnected() {
DCHECK(thread_checker_.CalledOnValidThread());
active_connections_.erase(receiver_set_.current_context());
}
void SecurityKeyAuthHandlerWin::CloseSecurityKeyRequestConnection(
int connection_id) {
DCHECK(thread_checker_.CalledOnValidThread());
auto iter = active_connections_.find(connection_id);
if (iter == active_connections_.end()) {
LOG(ERROR) << "Connection ID " << connection_id << " doesn't exist.";
return;
}
receiver_set_.Remove(iter->second.receiver_id);
active_connections_.erase(iter);
}
base::OnceClosure SecurityKeyAuthHandlerWin::GetCloseConnectionClosure(
int connection_id) {
return base::BindOnce(
&SecurityKeyAuthHandlerWin::CloseSecurityKeyRequestConnection,
weak_factory_.GetWeakPtr(), connection_id);
}
}