#include "remoting/host/linux/certificate_watcher.h"

#include <memory>
#include <string>
#include <utility>

#include "base/files/file_path_watcher.h"
#include "base/files/file_util.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/hash/hash.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/path_service.h"
#include "base/task/single_thread_task_runner.h"
#include "base/threading/thread_checker.h"
#include "base/timer/timer.h"
namespace remoting {
namespace {

// Delay, in seconds, handed to the content watcher's DelayTimer: the database
// files are re-read only after the directory has stopped changing for this
// long.
constexpr int kReadDelayInSeconds = 2;

// NSS database directory, relative to the user's home directory.
constexpr char kCertDirectoryPath[] = ".pki/nssdb";

// Files inside kCertDirectoryPath whose contents are hashed to detect changes.
constexpr const char* kCertFiles[] = {"cert9.db", "key4.db", "pkcs11.txt"};

}  // namespace
// Watches the NSS certificate database directory on the IO thread (it is
// started and deleted via CertificateWatcher's |io_task_runner_|). After a
// change, it waits until the directory has been quiet for |read_delay|,
// re-hashes the database files and, if their contents actually differ,
// notifies the owning CertificateWatcher on |caller_task_runner|.
class CertDbContentWatcher {
 public:
  CertDbContentWatcher(
      base::WeakPtr<CertificateWatcher> watcher,
      scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner,
      base::FilePath cert_watch_path,
      base::TimeDelta read_delay);

  CertDbContentWatcher(const CertDbContentWatcher&) = delete;
  CertDbContentWatcher& operator=(const CertDbContentWatcher&) = delete;

  ~CertDbContentWatcher();

  // Takes the initial content snapshot and begins watching
  // |cert_watch_path_|. Binds this object to the calling thread.
  void StartWatching();

 private:
  using HashValue = size_t;

  // Invoked by |file_watcher_| when something under the watched path changes.
  void OnCertDirectoryChanged(const base::FilePath& path, bool error);

  // Invoked by |read_timer_| once the directory has been quiet for |delay_|.
  void OnTimer();

  // Returns a combined hash over every file in kCertFiles. A file that cannot
  // be read contributes 0.
  HashValue ComputeHash();

  base::ThreadChecker thread_checker_;

  // The owner. It is only dereferenced on |caller_task_runner_|.
  base::WeakPtr<CertificateWatcher> watcher_;
  scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner_;

  std::unique_ptr<base::FilePathWatcher> file_watcher_;
  base::FilePath cert_watch_path_;

  // Debounces bursts of directory-change notifications.
  std::unique_ptr<base::DelayTimer> read_timer_;
  base::TimeDelta delay_;

  // Hash of the database contents as last observed. It is initialized here
  // so it is never read indeterminate; StartWatching() replaces it with the
  // real snapshot.
  HashValue current_hash_ = 0;
};
// All parameters are sink arguments taken by value, so they are moved into
// the members rather than copied. Copying |caller_task_runner| would cost an
// extra atomic refcount round-trip.
CertDbContentWatcher::CertDbContentWatcher(
    base::WeakPtr<CertificateWatcher> watcher,
    scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner,
    base::FilePath cert_watch_path,
    base::TimeDelta read_delay)
    : watcher_(std::move(watcher)),
      caller_task_runner_(std::move(caller_task_runner)),
      cert_watch_path_(std::move(cert_watch_path)),
      delay_(read_delay) {
  // The object is constructed on the caller thread, but StartWatching() and
  // every later call run on another thread. Detach so the checker binds to
  // whichever thread calls into it first.
  thread_checker_.DetachFromThread();
}
// Must be destroyed on the thread the checker bound to in StartWatching().
// CertificateWatcher's destructor arranges this by using DeleteSoon() on the
// IO task runner.
CertDbContentWatcher::~CertDbContentWatcher() {
  DCHECK(thread_checker_.CalledOnValidThread());
}
void CertDbContentWatcher::StartWatching() {
DCHECK(!cert_watch_path_.empty());
DCHECK(thread_checker_.CalledOnValidThread());
file_watcher_.reset(new base::FilePathWatcher());
current_hash_ = ComputeHash();
file_watcher_->Watch(
cert_watch_path_, base::FilePathWatcher::Type::kRecursive,
base::BindRepeating(&CertDbContentWatcher::OnCertDirectoryChanged,
base::Unretained(this)));
read_timer_.reset(new base::DelayTimer(FROM_HERE, delay_, this,
&CertDbContentWatcher::OnTimer));
}
// FilePathWatcher callback. Any change (re)starts the debounce timer. An
// error reported by the watcher terminates the process.
void CertDbContentWatcher::OnCertDirectoryChanged(const base::FilePath& path,
                                                  bool error) {
  DCHECK(thread_checker_.CalledOnValidThread());
  DCHECK(path == cert_watch_path_);

  LOG_IF(FATAL, error) << "Error occurred while watching for changes of file: "
                       << cert_watch_path_.MaybeAsASCII();

  // Each notification pushes the re-read out by another |delay_|, so a burst
  // of writes leads to a single hash computation in OnTimer().
  read_timer_->Reset();
}
void CertDbContentWatcher::OnTimer() {
DCHECK(thread_checker_.CalledOnValidThread());
HashValue new_hash = ComputeHash();
if (new_hash != current_hash_) {
current_hash_ = new_hash;
caller_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&CertificateWatcher::DatabaseChanged, watcher_));
} else {
VLOG(1) << "Directory changed but contents are the same.";
}
}
// Folds the hash of each kCertFiles entry (in array order) into one value. A
// file that cannot be read contributes 0 instead of its content hash.
CertDbContentWatcher::HashValue CertDbContentWatcher::ComputeHash() {
  DCHECK(thread_checker_.CalledOnValidThread());

  HashValue combined = 0;
  for (const char* const file_name : kCertFiles) {
    std::string file_contents;
    const bool read_ok = base::ReadFileToString(
        cert_watch_path_.AppendASCII(file_name), &file_contents);
    const HashValue per_file = read_ok ? base::Hash(file_contents) : 0;
    combined = base::HashInts(combined, per_file);
  }
  return combined;
}
// Binds to the current thread's default task runner as the "caller" thread.
// The watch path defaults to ~/.pki/nssdb. Terminates the process if the home
// directory cannot be determined.
CertificateWatcher::CertificateWatcher(
    const base::RepeatingClosure& restart_action,
    scoped_refptr<base::SingleThreadTaskRunner> io_task_runner)
    : restart_action_(restart_action),
      caller_task_runner_(base::SingleThreadTaskRunner::GetCurrentDefault()),
      // |io_task_runner| is a by-value sink; move it instead of paying for an
      // extra refcount increment and decrement.
      io_task_runner_(std::move(io_task_runner)),
      delay_(base::Seconds(kReadDelayInSeconds)) {
  base::FilePath home_dir;
  if (!base::PathService::Get(base::DIR_HOME, &home_dir)) {
    LOG(FATAL) << "Failed to get path of the home directory.";
  }
  cert_watch_path_ = home_dir.AppendASCII(kCertDirectoryPath);
}
// Tears down whatever Start() and SetMonitor() set up. The content watcher is
// handed to the IO thread for deletion rather than destroyed here.
CertificateWatcher::~CertificateWatcher() {
  DCHECK(caller_task_runner_->BelongsToCurrentThread());

  if (is_started()) {
    // Stop receiving client connect/disconnect notifications.
    if (monitor_) {
      monitor_->RemoveStatusObserver(this);
    }
    // The content watcher is bound to the IO thread and must die there.
    io_task_runner_->DeleteSoon(FROM_HERE, content_watcher_.release());
    VLOG(1) << "Stopped watching certificate changes.";
  }
}
void CertificateWatcher::Start() {
DCHECK(caller_task_runner_->BelongsToCurrentThread());
DCHECK(!cert_watch_path_.empty());
content_watcher_.reset(new CertDbContentWatcher(weak_factory_.GetWeakPtr(),
caller_task_runner_,
cert_watch_path_, delay_));
io_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&CertDbContentWatcher::StartWatching,
base::Unretained(content_watcher_.get())));
VLOG(1) << "Started watching certificate changes.";
}
// Replaces the status monitor whose client connect/disconnect events gate
// restarts. Passing a null |monitor| detaches from the current one instead of
// dereferencing null.
void CertificateWatcher::SetMonitor(scoped_refptr<HostStatusMonitor> monitor) {
  // The observer callbacks DCHECK the caller thread, so (un)registration must
  // happen there too.
  DCHECK(caller_task_runner_->BelongsToCurrentThread());
  DCHECK(is_started());

  if (monitor_) {
    monitor_->RemoveStatusObserver(this);
  }
  monitor_ = std::move(monitor);
  if (monitor_) {
    monitor_->AddStatusObserver(this);
  }
}
// While a client is connected, DatabaseChanged() defers the restart instead
// of running it.
void CertificateWatcher::OnClientConnected(const std::string& jid) {
  DCHECK(caller_task_runner_->BelongsToCurrentThread());
  DCHECK(is_started());

  inhibit_mode_ = true;
}
// Lifts the restart inhibition. If a restart was deferred while the client
// was connected, it runs now.
void CertificateWatcher::OnClientDisconnected(const std::string& jid) {
  DCHECK(caller_task_runner_->BelongsToCurrentThread());
  DCHECK(is_started());

  inhibit_mode_ = false;
  if (!restart_pending_) {
    return;
  }

  restart_pending_ = false;
  restart_action_.Run();
}
// Test hook: overrides the debounce delay. Start() passes |delay_| into the
// content watcher, so it can only be changed before Start().
void CertificateWatcher::SetDelayForTests(const base::TimeDelta& delay) {
  DCHECK(!is_started());

  delay_ = delay;
}
// Test hook: replaces the default ~/.pki/nssdb watch path. Start() passes the
// path into the content watcher, so it can only be changed before Start().
void CertificateWatcher::SetWatchPathForTests(
    const base::FilePath& watch_path) {
  DCHECK(!is_started());

  cert_watch_path_ = watch_path;
}
// Start() is what creates |content_watcher_|, so its presence marks the
// started state.
bool CertificateWatcher::is_started() const {
  return static_cast<bool>(content_watcher_);
}
// Posted by the content watcher when the database contents change. Restarts
// immediately, or remembers the request while a client is connected.
void CertificateWatcher::DatabaseChanged() {
  DCHECK(caller_task_runner_->BelongsToCurrentThread());

  if (!inhibit_mode_) {
    VLOG(1) << "Certificate was updated. Calling restart...";
    restart_action_.Run();
    return;
  }

  // OnClientDisconnected() will perform the deferred restart.
  restart_pending_ = true;
}
}