#include "net/device_bound_sessions/session_store_impl.h"
#include <algorithm>
#include "base/metrics/histogram_functions.h"
#include "base/sequence_checker.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "base/time/time.h"
#include "components/unexportable_keys/background_task_priority.h"
#include "components/unexportable_keys/service_error.h"
#include "components/unexportable_keys/unexportable_key_id.h"
#include "components/unexportable_keys/unexportable_key_service.h"
#include "net/base/features.h"
#include "net/base/schemeful_site.h"
#include "net/device_bound_sessions/proto/storage.pb.h"
namespace net::device_bound_sessions {
namespace {

using unexportable_keys::BackgroundTaskPriority;
using unexportable_keys::ServiceError;
using unexportable_keys::ServiceErrorOr;
using unexportable_keys::UnexportableKeyId;
using unexportable_keys::UnexportableKeyService;

// Traits for the dedicated sequence that performs all database work. The work
// may block, and (per base::TaskShutdownBehavior::BLOCK_SHUTDOWN) shutdown
// waits for already-posted tasks to complete.
constexpr base::TaskTraits kDBTaskTraits = {
    base::MayBlock(), base::TaskPriority::USER_VISIBLE,
    base::TaskShutdownBehavior::BLOCK_SHUTDOWN};

// Name of the key/value table holding one proto::SiteSessions per serialized
// site.
constexpr char kSessionTableName[] = "dbsc_session_tbl";

// Flush delay handed to sqlite_proto::KeyValueData when it is constructed.
const base::TimeDelta kFlushDelay = base::Seconds(2);

// Opens the database at `db_storage_path`, then sets up the session table and
// its cached data. Posted to the DB sequence by LoadSessions().
// Returns kFailure when the database cannot be opened, kSuccess otherwise.
SessionStoreImpl::DBStatus InitializeOnDbSequence(
    sql::Database* db,
    base::FilePath db_storage_path,
    sqlite_proto::ProtoTableManager* table_manager,
    sqlite_proto::KeyValueData<proto::SiteSessions>* session_data) {
  if (!db->Open(db_storage_path)) {
    return SessionStoreImpl::DBStatus::kFailure;
  }

  table_manager->InitializeOnDbSequence(
      db, std::vector<std::string>{kSessionTableName},
      features::kDeviceBoundSessionsSchemaVersion.Get());
  session_data->InitializeOnDBSequence();
  return SessionStoreImpl::DBStatus::kSuccess;
}

}  // namespace
// Creates a store backed by the `sql::Database` at `db_storage_path`.
// The database is not opened here; LoadSessions() opens it and initializes the
// table and cached data on `db_task_runner_`. `key_service` is used to obtain
// wrapped binding keys when saving sessions (SaveSession) and to unwrap them
// when restoring (RestoreSessionBindingKey).
SessionStoreImpl::SessionStoreImpl(base::FilePath db_storage_path,
UnexportableKeyService& key_service)
// NOTE: members are initialized in declaration order. Several initializers
// below read earlier members (`db_task_runner_`, `table_manager_`,
// `session_table_`), so the header must keep this declaration order.
: key_service_(key_service),
db_task_runner_(
base::ThreadPool::CreateSequencedTaskRunner(kDBTaskTraits)),
db_storage_path_(std::move(db_storage_path)),
db_(std::make_unique<sql::Database>(
sql::DatabaseOptions().set_preload(true),
sql::Database::Tag("DBSCSessions"))),
table_manager_(base::MakeRefCounted<sqlite_proto::ProtoTableManager>(
db_task_runner_)),
session_table_(
std::make_unique<sqlite_proto::KeyValueTable<proto::SiteSessions>>(
kSessionTableName)),
session_data_(
std::make_unique<sqlite_proto::KeyValueData<proto::SiteSessions>>(
table_manager_,
session_table_.get(),
// NOTE(review): presumably "no limit" on the number of cached entries;
// confirm against sqlite_proto::KeyValueData's constructor.
std::nullopt,
// Presumably the delay before batched updates are written out; see
// sqlite_proto::KeyValueData.
kFlushDelay)) {}
// Destroys the store without waiting for the DB sequence.
//
// If loading completed successfully, FlushDataToDisk() is called first
// (presumably to persist updates still batched in `session_data_`).
// The objects used on the DB sequence (`table_manager_`, `db_` and
// `session_table_`) are then moved into a task posted to `db_task_runner_`.
// That task calls ProtoTableManager::WillShutdown(), and the objects are
// released on that sequence when the task's parameters go out of scope.
// `session_data_` is instead moved into the reply. It therefore stays alive
// until the DB task has run, and is released back on this sequence. The reply
// then runs the callback set by SetShutdownCallbackForTesting(), if any.
SessionStoreImpl::~SessionStoreImpl() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
// Only flush when the database finished loading successfully.
if (db_status_ == DBStatus::kSuccess) {
session_data_->FlushDataToDisk();
}
// `db_task_runner_` is a sequenced runner, so this task runs after any
// InitializeOnDbSequence() task still pending from LoadSessions(). That task
// accesses these objects through base::Unretained pointers.
db_task_runner_->PostTaskAndReply(
FROM_HERE,
base::BindOnce(
[](scoped_refptr<sqlite_proto::ProtoTableManager> table_manager,
std::unique_ptr<sql::Database> db,
auto session_table) { table_manager->WillShutdown(); },
std::move(table_manager_), std::move(db_), std::move(session_table_)),
base::BindOnce(
// `session_data` is owned by this reply solely to extend its
// lifetime past the DB-sequence task.
[](auto session_data, base::OnceClosure shutdown_callback) {
if (shutdown_callback) {
std::move(shutdown_callback).Run();
}
},
std::move(session_data_), std::move(shutdown_callback_)));
}
// Opens the database on the DB sequence and, once that finishes, hands the
// decoded sessions to `callback` via OnDatabaseLoaded(). May only be called
// while the store is still in the kNotLoaded state.
void SessionStoreImpl::LoadSessions(LoadSessionsCallback callback) {
  CHECK_EQ(db_status_, DBStatus::kNotLoaded);

  // The DB-side objects are passed unretained. The destructor releases them
  // through a later task on the same sequenced runner, so they outlive this
  // task.
  auto open_database = base::BindOnce(
      &InitializeOnDbSequence, base::Unretained(db_.get()), db_storage_path_,
      base::Unretained(table_manager_.get()),
      base::Unretained(session_data_.get()));

  // The reply is bound to a weak pointer because `this` may be destroyed
  // before the DB sequence finishes. The timer starts now, so the recorded
  // duration covers the whole round trip.
  auto on_database_loaded = base::BindOnce(
      &SessionStoreImpl::OnDatabaseLoaded, weak_ptr_factory_.GetWeakPtr(),
      std::move(callback), base::ElapsedTimer());

  db_task_runner_->PostTaskAndReplyWithResult(
      FROM_HERE, std::move(open_database), std::move(on_database_loaded));
}
// Reply for LoadSessions(). Records the load status and, on success, builds
// sessions from the cached table contents and deletes any entries that could
// not be turned into sessions. Emits load-success and load-duration
// histograms, then runs `callback` (with an empty map if loading failed).
void SessionStoreImpl::OnDatabaseLoaded(LoadSessionsCallback callback,
                                        base::ElapsedTimer timer,
                                        DBStatus db_status) {
  db_status_ = db_status;
  const bool load_succeeded = db_status == DBStatus::kSuccess;

  SessionsMap loaded_sessions;
  if (load_succeeded) {
    std::vector<std::string> invalid_site_keys;
    loaded_sessions = CreateSessionsFromLoadedData(
        session_data_->GetAllCached(), invalid_site_keys);
    if (!invalid_site_keys.empty()) {
      session_data_->DeleteData(invalid_site_keys);
    }
  }

  base::UmaHistogramBoolean("Net.DeviceBoundSessions.SessionStoreLoadSuccess",
                            load_succeeded);
  base::UmaHistogramTimes("Net.DeviceBoundSessions.SessionStoreLoadDuration",
                          timer.Elapsed());
  std::move(callback).Run(std::move(loaded_sessions));
}
// Converts the cached table contents (serialized site -> SiteSessions proto)
// into a SessionsMap.
//
// A site entry is rejected as a whole when any of the following holds:
//   - the site deserializes to an opaque site;
//   - any of its sessions lacks a non-empty wrapped key;
//   - any of its sessions fails Session::CreateFromProto().
// The serialized key of every rejected entry is appended to `keys_to_delete`,
// and none of that entry's sessions are returned.
SessionStore::SessionsMap SessionStoreImpl::CreateSessionsFromLoadedData(
    const std::map<std::string, proto::SiteSessions>& loaded_data,
    std::vector<std::string>& keys_to_delete) {
  SessionsMap result;
  for (const auto& [site_str, site_proto] : loaded_data) {
    const SchemefulSite site = net::SchemefulSite::Deserialize(site_str);
    bool site_is_valid = !site.opaque();

    SessionsMap parsed_for_site;
    if (site_is_valid) {
      for (const auto& [session_id, session_proto] : site_proto.sessions()) {
        // RestoreSessionBindingKey() needs the wrapped key, so a session
        // without one is treated as unusable.
        std::unique_ptr<Session> session;
        if (session_proto.has_wrapped_key() &&
            !session_proto.wrapped_key().empty()) {
          session = Session::CreateFromProto(session_proto);
        }
        if (!session) {
          site_is_valid = false;
          break;
        }
        parsed_for_site.emplace(SessionKey{site, session->id()},
                                std::move(session));
      }
    }

    if (!site_is_valid) {
      keys_to_delete.push_back(site_str);
      continue;
    }
    result.merge(parsed_for_site);
  }
  return result;
}
// Test-only hook. The stored closure is run on this sequence by the
// destructor's reply, after the DB-sequence teardown task has finished.
// Replaces any callback set previously.
void SessionStoreImpl::SetShutdownCallbackForTesting(
    base::OnceClosure shutdown_callback) {
  shutdown_callback_ = std::move(shutdown_callback);
}
void SessionStoreImpl::SaveSession(const SchemefulSite& site,
const Session& session) {
if (db_status_ != DBStatus::kSuccess) {
return;
}
CHECK(session.unexportable_key_id().has_value());
ServiceErrorOr<std::vector<uint8_t>> wrapped_key =
key_service_->GetWrappedKey(*session.unexportable_key_id());
if (!wrapped_key.has_value()) {
return;
}
proto::Session session_proto = session.ToProto();
session_proto.set_wrapped_key(
std::string(wrapped_key->begin(), wrapped_key->end()));
proto::SiteSessions site_proto;
std::string site_str = site.Serialize();
session_data_->TryGetData(site_str, &site_proto);
(*site_proto.mutable_sessions())[session_proto.id()] =
std::move(session_proto);
session_data_->UpdateData(site_str, site_proto);
}
// Removes the session identified by `key` from storage.
//
// Nothing happens if the database is not loaded, or if the site or session is
// not stored. When the session is the last one for its site, the whole site
// entry is deleted rather than stored as an empty SiteSessions record.
void SessionStoreImpl::DeleteSession(const SessionKey& key) {
  if (db_status_ != DBStatus::kSuccess) {
    return;
  }

  // Serialize once, and use the same key for the lookup and the write-back.
  const std::string site_str = key.site.Serialize();
  proto::SiteSessions site_proto;
  if (!session_data_->TryGetData(site_str, &site_proto)) {
    return;
  }

  const auto& stored_sessions = site_proto.sessions();
  if (stored_sessions.find(*key.id) == stored_sessions.end()) {
    return;
  }

  if (stored_sessions.size() == 1) {
    session_data_->DeleteData({site_str});
    return;
  }

  site_proto.mutable_sessions()->erase(*key.id);
  session_data_->UpdateData(site_str, site_proto);
}
// Rebuilds sessions from the cached table contents. Returns an empty map if
// the database is not loaded.
SessionStore::SessionsMap SessionStoreImpl::GetAllSessions() const {
  SessionsMap sessions;
  if (db_status_ == DBStatus::kSuccess) {
    std::vector<std::string> rejected_keys;
    sessions = CreateSessionsFromLoadedData(session_data_->GetAllCached(),
                                            rejected_keys);
    // Entries rejected at load time were already deleted in
    // OnDatabaseLoaded(), so finding one here is treated as a bug.
    CHECK(rejected_keys.empty());
  }
  return sessions;
}
// Asynchronously unwraps the stored binding key for `session_key` and reports
// the resulting key id (or an error) to `callback`.
//
// `callback` receives ServiceError::kKeyNotFound immediately in any of these
// cases:
//   - the database is not loaded;
//   - nothing is stored for the site;
//   - the session is not stored.
// Concurrent requests for the same key share a single unwrap. Only the first
// request starts FromWrappedSigningKeySlowlyAsync(); later ones are queued in
// `restore_callbacks_` until OnSessionBindingKeyRestored() runs.
void SessionStoreImpl::RestoreSessionBindingKey(
    const SessionKey& session_key,
    RestoreSessionBindingKeyCallback callback) {
  proto::SiteSessions site_proto;
  if (db_status_ != DBStatus::kSuccess ||
      !session_data_->TryGetData(session_key.site.Serialize(), &site_proto)) {
    std::move(callback).Run(base::unexpected(ServiceError::kKeyNotFound));
    return;
  }

  const auto session_it = site_proto.sessions().find(*session_key.id);
  if (session_it == site_proto.sessions().end()) {
    std::move(callback).Run(base::unexpected(ServiceError::kKeyNotFound));
    return;
  }

  auto [pending_it, is_first_request] =
      restore_callbacks_.try_emplace(session_key);
  pending_it->second.emplace_back(std::move(callback));
  if (!is_first_request) {
    // An unwrap is already in flight; its completion runs this callback too.
    return;
  }

  const std::string& stored_wrapped_key = session_it->second.wrapped_key();
  std::vector<uint8_t> wrapped_key(stored_wrapped_key.begin(),
                                   stored_wrapped_key.end());
  key_service_->FromWrappedSigningKeySlowlyAsync(
      wrapped_key, BackgroundTaskPriority::kUserVisible,
      base::BindOnce(&SessionStoreImpl::OnSessionBindingKeyRestored,
                     weak_ptr_factory_.GetWeakPtr(), session_key));
}
// Completes an unwrap started by RestoreSessionBindingKey(). Runs, with
// `key_or_error`, every callback that was queued for `session_key` while the
// unwrap was in flight.
//
// The pending entry is removed from `restore_callbacks_` BEFORE any callback
// runs, because callbacks may re-enter this object:
//   - A callback that calls RestoreSessionBindingKey() for the same key must
//     start a fresh unwrap. Otherwise it would be appended to the entry being
//     completed, then erased with that entry and never run.
//   - A callback that mutates `restore_callbacks_`, or destroys `this`, must
//     not leave an erase() acting on a stale iterator or a freed map.
void SessionStoreImpl::OnSessionBindingKeyRestored(
    const SessionKey& session_key,
    unexportable_keys::ServiceErrorOr<unexportable_keys::UnexportableKeyId>
        key_or_error) {
  auto it = restore_callbacks_.find(session_key);
  if (it == restore_callbacks_.end()) {
    return;
  }
  auto callbacks = std::move(it->second);
  restore_callbacks_.erase(it);

  // From here on only locals are touched, so this loop stays valid even if a
  // callback destroys `this`.
  for (auto& callback : callbacks) {
    std::move(callback).Run(key_or_error);
  }
}
}