#include "content/browser/aggregation_service/aggregation_service_key_fetcher.h"
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include "base/containers/circular_deque.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/rand_util.h"
#include "content/browser/aggregation_service/aggregation_service_storage.h"
#include "content/browser/aggregation_service/aggregation_service_storage_context.h"
#include "services/network/public/cpp/is_potentially_trustworthy.h"
#include "url/gurl.h"
namespace content {
AggregationServiceKeyFetcher::AggregationServiceKeyFetcher(
AggregationServiceStorageContext* storage_context,
std::unique_ptr<NetworkFetcher> network_fetcher)
: storage_context_(storage_context),
network_fetcher_(std::move(network_fetcher)) {}
AggregationServiceKeyFetcher::~AggregationServiceKeyFetcher() = default;
void AggregationServiceKeyFetcher::GetPublicKey(const GURL& url,
FetchCallback callback) {
CHECK(network::IsUrlPotentiallyTrustworthy(url));
base::circular_deque<FetchCallback>& pending_callbacks = url_callbacks_[url];
bool in_progress = !pending_callbacks.empty();
pending_callbacks.push_back(std::move(callback));
if (in_progress) {
return;
}
storage_context_->GetStorage()
.AsyncCall(&AggregationServiceStorage::GetPublicKeys)
.WithArgs(url)
.Then(base::BindOnce(
&AggregationServiceKeyFetcher::OnPublicKeysReceivedFromStorage,
weak_factory_.GetWeakPtr(), url));
}
// Reply handler for the storage lookup started in GetPublicKey().
//
// If storage returned any keys for `url`, the queued callbacks are completed
// with them. Otherwise the keys are requested from the network.
void AggregationServiceKeyFetcher::OnPublicKeysReceivedFromStorage(
    const GURL& url,
    std::vector<PublicKey> keys) {
  if (!keys.empty()) {
    RunCallbacksForUrl(url, keys);
    return;
  }

  // Nothing cached for this URL; fall back to the network.
  FetchPublicKeysFromNetwork(url);
}
// Requests the public keys for `url` from the network.
//
// When the fetcher was constructed without a network fetcher, network
// fetching is disabled. In that case every callback queued for `url` is
// failed immediately, with no keys.
void AggregationServiceKeyFetcher::FetchPublicKeysFromNetwork(const GURL& url) {
  if (network_fetcher_) {
    // NOTE(review): base::Unretained(this) is presumably safe because
    // `network_fetcher_` is owned by this object and destroyed with it, so the
    // callback cannot outlive `this`. Confirm against NetworkFetcher's contract.
    network_fetcher_->FetchPublicKeys(
        url, base::BindOnce(
                 &AggregationServiceKeyFetcher::OnPublicKeysReceivedFromNetwork,
                 base::Unretained(this), url));
    return;
  }

  RunCallbacksForUrl(url, /*keys=*/{});
}
// Reply handler for the network fetch started in FetchPublicKeysFromNetwork().
//
// First persists the outcome for `url`:
//  - if `keyset` is absent (the fetch failed) or its `expiry_time` is null,
//    any keys stored for `url` are cleared;
//  - otherwise the fetched keyset is written to storage.
// Then it completes every callback queued for `url`. Callbacks get the fetched
// keys whenever a keyset is present (even one that was not stored because of a
// null expiry), and a failure otherwise.
void AggregationServiceKeyFetcher::OnPublicKeysReceivedFromNetwork(
    const GURL& url,
    std::optional<PublicKeyset> keyset) {
  if (!keyset.has_value() || keyset->expiry_time.is_null()) {
    storage_context_->GetStorage()
        .AsyncCall(&AggregationServiceStorage::ClearPublicKeys)
        .WithArgs(url);
  } else {
    storage_context_->GetStorage()
        .AsyncCall(&AggregationServiceStorage::SetPublicKeys)
        .WithArgs(url, keyset.value());
  }

  // An explicit branch is used instead of `?:` on purpose. A conditional that
  // mixes the lvalue `keyset->keys` with a temporary vector yields a prvalue,
  // which would copy the whole key list just to bind it to a const reference.
  if (!keyset.has_value()) {
    RunCallbacksForUrl(url, /*keys=*/{});
    return;
  }
  RunCallbacksForUrl(url, keyset->keys);
}
void AggregationServiceKeyFetcher::RunCallbacksForUrl(
const GURL& url,
const std::vector<PublicKey>& keys) {
auto iter = url_callbacks_.find(url);
CHECK(iter != url_callbacks_.end());
base::circular_deque<FetchCallback> pending_callbacks =
std::move(iter->second);
CHECK(!pending_callbacks.empty());
url_callbacks_.erase(iter);
if (keys.empty()) {
for (auto& callback : pending_callbacks) {
std::move(callback).Run(std::nullopt,
PublicKeyFetchStatus::kPublicKeyFetchFailed);
}
} else {
for (auto& callback : pending_callbacks) {
uint64_t key_index = base::RandGenerator(keys.size());
std::move(callback).Run(keys[key_index], PublicKeyFetchStatus::kOk);
}
}
}
}