#include "content/test/test_aggregation_service_impl.h"
#include <string>
#include <utility>
#include "base/check.h"
#include "base/files/file_path.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/logging.h"
#include "base/memory/scoped_refptr.h"
#include "base/task/thread_pool.h"
#include "base/time/clock.h"
#include "base/time/time.h"
#include "base/uuid.h"
#include "base/values.h"
#include "components/aggregation_service/aggregation_service.mojom.h"
#include "content/browser/aggregation_service/aggregatable_report.h"
#include "content/browser/aggregation_service/aggregatable_report_assembler.h"
#include "content/browser/aggregation_service/aggregatable_report_sender.h"
#include "content/browser/aggregation_service/aggregation_service_storage.h"
#include "content/browser/aggregation_service/aggregation_service_storage_sql.h"
#include "content/browser/aggregation_service/aggregation_service_test_utils.h"
#include "content/browser/aggregation_service/public_key.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "third_party/blink/public/mojom/private_aggregation/aggregatable_report.mojom.h"
#include "url/gurl.h"
#include "url/origin.h"
namespace content {
namespace {
// Translates the testing tool's operation enum into the operation enum used by
// AggregationServicePayloadContents. Every enumerator is handled explicitly.
AggregationServicePayloadContents::Operation ConvertToOperation(
    TestAggregationService::Operation op) {
  using ToolOp = TestAggregationService::Operation;
  using PayloadOp = AggregationServicePayloadContents::Operation;
  switch (op) {
    case ToolOp::kHistogram:
      return PayloadOp::kHistogram;
  }
}
// Translates the testing tool's aggregation-mode enum into the corresponding
// blink::mojom::AggregationServiceMode value, one enumerator at a time.
blink::mojom::AggregationServiceMode ConvertToAggregationMode(
    TestAggregationService::AggregationMode mode) {
  using ToolMode = TestAggregationService::AggregationMode;
  using MojoMode = blink::mojom::AggregationServiceMode;
  switch (mode) {
    case ToolMode::kTeeBased:
      return MojoMode::kTeeBased;
    case ToolMode::kExperimentalPoplar:
      return MojoMode::kExperimentalPoplar;
  }
}
// Completion adapter for AggregatableReportAssembler::AssembleReport().
// Runs `callback` exactly once. If a report was assembled, it receives the
// report's JSON form. Otherwise the assembly status is logged and `callback`
// receives an empty dict. The echoed request argument is not used.
void HandleAggregatableReportCallback(
    base::OnceCallback<void(base::Value::Dict)> callback,
    AggregatableReportRequest,
    absl::optional<AggregatableReport> report,
    AggregatableReportAssembler::AssemblyStatus status) {
  base::Value::Dict result;
  if (report) {
    result = report->GetAsJson();
  } else {
    LOG(ERROR) << "Failed to assemble the report, status: "
               << static_cast<int>(status);
  }
  std::move(callback).Run(std::move(result));
}
}
// Creates the test service with three parts:
// - an SQL-backed storage on its own MayBlock sequenced task runner, given an
//   empty database path;
// - a testing-only report sender;
// - a testing-only report assembler that uses `this` as its storage context.
//
// `clock` must be non-null. `clock_` is initialized from `*clock`, and `clock`
// is also passed to the storage. NOTE(review): it therefore presumably must
// outlive this object; confirm against the header.
TestAggregationServiceImpl::TestAggregationServiceImpl(
    const base::Clock* clock,
    scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
    : clock_([clock]() -> const base::Clock& {
        // Check before dereferencing. A check in the constructor body would run
        // only after `*clock` had already been evaluated here, and binding a
        // reference through a null pointer is undefined behavior.
        DCHECK(clock);
        return *clock;
      }()),
      storage_(base::SequenceBound<AggregationServiceStorageSql>(
          base::ThreadPool::CreateSequencedTaskRunner({base::MayBlock()}),
          // NOTE(review): presumably the "run in memory" flag, which is why
          // the database path below is empty. Confirm against the storage
          // constructor.
          true,
          base::FilePath(),
          clock)),
      // NOTE(review): the trailing `true` arguments below presumably enable
      // debug logging in the testing factories. Confirm against their
      // declarations.
      sender_(AggregatableReportSender::CreateForTesting(url_loader_factory,
                                                         true)),
      assembler_(AggregatableReportAssembler::CreateForTesting(
          this,
          url_loader_factory,
          true)) {}
// Defaulted. Members such as `storage_`, `sender_` and `assembler_` release
// their resources through their own destructors.
TestAggregationServiceImpl::~TestAggregationServiceImpl() = default;
// Returns the sequence-bound storage member. The reference refers to a member
// of this object, so it is only valid while this object is alive.
const base::SequenceBound<AggregationServiceStorage>&
TestAggregationServiceImpl::GetStorage() {
  return this->storage_;
}
void TestAggregationServiceImpl::SetDisablePayloadEncryption(
bool should_disable) {
content::AggregatableReport::Provider::SetDisableEncryptionForTestingTool(
should_disable);
}
// Reads and parses the public keyset in `json_file`, using `clock_`'s current
// time, then stores it for `url` on the storage sequence.
// - If parsing fails, the parser's error is logged and `callback` gets false
//   immediately.
// - Otherwise `callback` gets true once the storage call has completed.
void TestAggregationServiceImpl::SetPublicKeys(
    const GURL& url,
    const base::FilePath& json_file,
    base::OnceCallback<void(bool)> callback) {
  using Storage = AggregationServiceStorage;

  std::string parse_error;
  absl::optional<PublicKeyset> parsed_keys =
      aggregation_service::ReadAndParsePublicKeys(json_file, clock_->Now(),
                                                  &parse_error);
  if (!parsed_keys.has_value()) {
    LOG(ERROR) << parse_error;
    std::move(callback).Run(false);
    return;
  }

  // Bind the success result now; it is delivered via Then() after the storage
  // sequence has finished the write.
  base::OnceClosure on_stored = base::BindOnce(std::move(callback), true);
  storage_.AsyncCall(&Storage::SetPublicKeys)
      .WithArgs(url, std::move(*parsed_keys))
      .Then(std::move(on_stored));
}
// Builds an AggregatableReportRequest from `request` and hands it to the
// assembler. `callback` receives the assembled report as JSON. It receives an
// empty dict if the request cannot be created or the report cannot be
// assembled; both failures are logged.
void TestAggregationServiceImpl::AssembleReport(
    AssembleRequest request,
    base::OnceCallback<void(base::Value::Dict)> callback) {
  // A single histogram contribution built from the request's bucket and value.
  AggregationServicePayloadContents payload_contents(
      ConvertToOperation(request.operation),
      {blink::mojom::AggregatableReportHistogramContribution(request.bucket,
                                                             request.value)},
      ConvertToAggregationMode(request.aggregation_mode),
      ::aggregation_service::mojom::AggregationCoordinator::kDefault);

  // The report is scheduled 30 seconds from "now". "Now" comes from the
  // injected clock, the same one SetPublicKeys() uses, rather than the wall
  // clock, so a test-supplied clock is honored consistently.
  AggregatableReportSharedInfo shared_info(
      clock_->Now() + base::Seconds(30), base::Uuid::GenerateRandomV4(),
      std::move(request.reporting_origin),
      request.is_debug_mode_enabled
          ? AggregatableReportSharedInfo::DebugMode::kEnabled
          : AggregatableReportSharedInfo::DebugMode::kDisabled,
      std::move(request.additional_fields), std::move(request.api_version),
      std::move(request.api_identifier));

  absl::optional<AggregatableReportRequest> report_request =
      AggregatableReportRequest::CreateForTesting(
          std::move(request.processing_urls), std::move(payload_contents),
          std::move(shared_info));
  if (!report_request.has_value()) {
    // Log here as the assembly-failure path does, so an invalid request is not
    // reported as a bare empty dict.
    LOG(ERROR) << "Failed to create the aggregatable report request";
    std::move(callback).Run(base::Value::Dict());
    return;
  }

  assembler_->AssembleReport(
      std::move(*report_request),
      base::BindOnce(&HandleAggregatableReportCallback, std::move(callback)));
}
// Sends `contents` to `url` through the testing sender. `callback` receives
// true if the sender reports RequestStatus::kOk and false for any other status.
void TestAggregationServiceImpl::SendReport(
    const GURL& url,
    const base::Value& contents,
    base::OnceCallback<void(bool)> callback) {
  sender_->SendReport(
      url, contents,
      base::BindOnce(
          // Captureless on purpose. The sender may run this after SendReport()
          // returns, so state is bound as an argument rather than captured by
          // reference.
          [](base::OnceCallback<void(bool)> done,
             AggregatableReportSender::RequestStatus status) {
            std::move(done).Run(
                status == AggregatableReportSender::RequestStatus::kOk);
          },
          std::move(callback)));
}
void TestAggregationServiceImpl::GetPublicKeys(
const GURL& url,
base::OnceCallback<void(std::vector<PublicKey>)> callback) const {
storage_.AsyncCall(&AggregationServiceStorage::GetPublicKeys)
.WithArgs(url)
.Then(std::move(callback));
}
}