#include "net/first_party_sets/global_first_party_sets.h"
#include <algorithm>
#include <iterator>
#include <map>
#include <optional>
#include <set>
#include <tuple>
#include <utility>
#include "base/containers/contains.h"
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/containers/map_util.h"
#include "base/functional/function_ref.h"
#include "base/types/optional_ref.h"
#include "base/types/optional_util.h"
#include "net/base/schemeful_site.h"
#include "net/first_party_sets/addition_overlaps_union_find.h"
#include "net/first_party_sets/first_party_set_entry.h"
#include "net/first_party_sets/first_party_set_entry_override.h"
#include "net/first_party_sets/first_party_set_metadata.h"
#include "net/first_party_sets/first_party_sets_context_config.h"
#include "net/first_party_sets/first_party_sets_validator.h"
#include "net/first_party_sets/local_set_declaration.h"
namespace net {
namespace {
using FlattenedSets = base::flat_map<SchemefulSite, FirstPartySetEntry>;
using SingleSet = base::flat_map<SchemefulSite, FirstPartySetEntry>;
FlattenedSets Flatten(const std::vector<SingleSet>& set_list) {
FlattenedSets sets;
for (const auto& set : set_list) {
for (const auto& site_and_entry : set) {
bool inserted = sets.emplace(site_and_entry).second;
CHECK(inserted);
}
}
return sets;
}
std::pair<SchemefulSite, FirstPartySetEntryOverride>
SiteAndEntryToSiteAndOverride(
const std::pair<SchemefulSite, FirstPartySetEntry>& pair) {
return std::make_pair(pair.first, FirstPartySetEntryOverride(pair.second));
}
}
GlobalFirstPartySets::GlobalFirstPartySets() = default;
// Constructs sets from the public-sets data, with an empty manual config.
// If `public_sets_version` is invalid, `entries` and `aliases` are discarded
// (replaced by empty maps), because the delegated-to constructor CHECKs that
// an invalid version comes with no entries and no aliases.
// NOTE(review): `public_sets_version` is copied into the first argument and
// also read via IsValid() in later arguments. It is never moved from here, so
// the unspecified argument evaluation order does not matter.
GlobalFirstPartySets::GlobalFirstPartySets(
base::Version public_sets_version,
base::flat_map<SchemefulSite, FirstPartySetEntry> entries,
base::flat_map<SchemefulSite, SchemefulSite> aliases)
: GlobalFirstPartySets(
public_sets_version,
public_sets_version.IsValid()
? std::move(entries)
: base::flat_map<SchemefulSite, FirstPartySetEntry>(),
public_sets_version.IsValid()
? std::move(aliases)
: base::flat_map<SchemefulSite, SchemefulSite>(),
FirstPartySetsContextConfig()) {}
// Primary constructor; takes ownership of every argument.
// CHECK-fails if:
//  - `public_sets_version` is invalid but `entries` or `aliases` is non-empty,
//  - some alias maps to a canonical site that has no entry in `entries`, or
//  - the resulting sets (public sets plus `manual_config`) fail IsValid().
GlobalFirstPartySets::GlobalFirstPartySets(
base::Version public_sets_version,
base::flat_map<SchemefulSite, FirstPartySetEntry> entries,
base::flat_map<SchemefulSite, SchemefulSite> aliases,
FirstPartySetsContextConfig manual_config)
: public_sets_version_(std::move(public_sets_version)),
entries_(std::move(entries)),
aliases_(std::move(aliases)),
manual_config_(std::move(manual_config)) {
// Without a valid public-sets version there must be no public data.
if (!public_sets_version_.IsValid()) {
CHECK(entries_.empty());
CHECK(aliases_.empty());
}
// Every alias must point at a canonical site that has its own entry.
CHECK(std::ranges::all_of(aliases_, [&](const auto& pair) {
return entries_.contains(pair.second);
}));
CHECK(IsValid()) << "Sets must be valid";
}
// Move construction, move assignment and destruction use the compiler-generated
// defaults.
GlobalFirstPartySets::GlobalFirstPartySets(GlobalFirstPartySets&&) = default;
GlobalFirstPartySets& GlobalFirstPartySets::operator=(GlobalFirstPartySets&&) =
default;
GlobalFirstPartySets::~GlobalFirstPartySets() = default;
// Memberwise equality (defaulted comparison over all data members).
bool GlobalFirstPartySets::operator==(const GlobalFirstPartySets& other) const =
default;
// Returns an independent copy of these sets. The manual config is duplicated
// through its own Clone(); the other members are copied directly.
GlobalFirstPartySets GlobalFirstPartySets::Clone() const {
  FirstPartySetsContextConfig manual_config_copy = manual_config_.Clone();
  return GlobalFirstPartySets(public_sets_version_, entries_, aliases_,
                              std::move(manual_config_copy));
}
std::optional<FirstPartySetEntry> GlobalFirstPartySets::FindEntry(
const SchemefulSite& site,
const FirstPartySetsContextConfig& config) const {
return FindEntry(site, &config);
}
std::optional<FirstPartySetEntry> GlobalFirstPartySets::FindEntry(
const SchemefulSite& site,
const FirstPartySetsContextConfig* config) const {
if (config) {
if (const auto override = config->FindOverride(site);
override.has_value()) {
return override->IsDeletion() ? std::nullopt
: std::make_optional(override->GetEntry());
}
}
if (const auto manual_override = manual_config_.FindOverride(site);
manual_override.has_value()) {
return manual_override->IsDeletion()
? std::nullopt
: std::make_optional(manual_override->GetEntry());
}
return base::OptionalFromPtr(base::FindOrNull(entries_, ResolveAlias(site)));
}
// Returns the effective entry (see FindEntry) of each site in `sites` under
// `config`. Sites without an effective entry are omitted from the result.
base::flat_map<SchemefulSite, FirstPartySetEntry>
GlobalFirstPartySets::FindEntries(
    const base::flat_set<SchemefulSite>& sites,
    const FirstPartySetsContextConfig& config) const {
  std::vector<std::pair<SchemefulSite, FirstPartySetEntry>> found;
  found.reserve(sites.size());
  for (const SchemefulSite& site : sites) {
    if (std::optional<FirstPartySetEntry> maybe_entry = FindEntry(site, config);
        maybe_entry) {
      found.emplace_back(site, *std::move(maybe_entry));
    }
  }
  return found;
}
// Builds the First-Party Set metadata for `site`. The top-frame entry is looked
// up only when `top_frame_site` is present. Both lookups use
// `fps_context_config` as the highest-precedence customization.
FirstPartySetMetadata GlobalFirstPartySets::ComputeMetadata(
    const SchemefulSite& site,
    base::optional_ref<const SchemefulSite> top_frame_site,
    const FirstPartySetsContextConfig& fps_context_config) const {
  std::optional<FirstPartySetEntry> frame_entry =
      FindEntry(site, fps_context_config);
  std::optional<FirstPartySetEntry> top_frame_entry;
  if (top_frame_site.has_value()) {
    top_frame_entry = FindEntry(*top_frame_site, fps_context_config);
  }
  return FirstPartySetMetadata(std::move(frame_entry),
                               std::move(top_frame_entry));
}
// Computes the manual config from `local_set_declaration` and installs it.
// An empty declaration is a no-op. CHECK-fails if a manual config is already
// present, or if the resulting sets are not valid.
void GlobalFirstPartySets::ApplyManuallySpecifiedSet(
    const LocalSetDeclaration& local_set_declaration) {
  CHECK(manual_config_.empty());
  if (!local_set_declaration.empty()) {
    manual_config_ = ComputeConfig(local_set_declaration.ComputeMutation());
    CHECK(IsValid()) << "Sets must be valid";
  }
}
// Installs `manual_config` directly as the manual config. Unlike
// ApplyManuallySpecifiedSet(), this neither derives the config from a
// declaration nor CHECKs IsValid() afterwards (presumably the reason for the
// "Unsafe" prefix). CHECK-fails if a manual config is already present.
void GlobalFirstPartySets::UnsafeSetManualConfig(
FirstPartySetsContextConfig manual_config) {
CHECK(manual_config_.empty());
manual_config_ = std::move(manual_config);
}
// For every site in `additions` that already has an entry in the public sets
// or the manual config (no per-context config is consulted), maps the primary
// of that existing entry to the addition's entry. The result is a map, so it
// holds a single value per existing primary.
base::flat_map<SchemefulSite, FirstPartySetEntry>
GlobalFirstPartySets::FindPrimariesAffectedByAdditions(
    const FlattenedSets& additions) const {
  std::vector<std::pair<SchemefulSite, FirstPartySetEntry>> primary_to_addition;
  for (const auto& site_and_entry : additions) {
    const std::optional<FirstPartySetEntry> existing =
        FindEntry(site_and_entry.first, nullptr);
    if (!existing.has_value()) {
      continue;
    }
    primary_to_addition.emplace_back(existing->primary(),
                                     site_and_entry.second);
  }
  return primary_to_addition;
}
// Determines how `replacements` affect the existing sets (public sets plus the
// manual config; no per-context config is consulted).
//
// Returns a pair of:
//  - potential singletons: keyed by an existing primary that loses members to
//    `replacements`, mapped to those removed members and all of their alias
//    variants. Only primaries that are not themselves in
//    `addition_intersected_primaries`, `additions` or `replacements` are
//    tracked. Such a set only *may* become a singleton; the caller must still
//    check whether it keeps other members.
//  - replaced existing primaries: existing primaries that are themselves keys
//    of `replacements`.
// Both containers are empty when `replacements` is empty.
std::pair<base::flat_map<SchemefulSite, base::flat_set<SchemefulSite>>,
          base::flat_set<SchemefulSite>>
GlobalFirstPartySets::FindPrimariesAffectedByReplacements(
    const FlattenedSets& replacements,
    const FlattenedSets& additions,
    const base::flat_map<SchemefulSite, FirstPartySetEntry>&
        addition_intersected_primaries) const {
  if (replacements.empty()) {
    return {{}, {}};
  }
  // Invert the alias relation so every alias of a canonical site can be found.
  std::map<SchemefulSite, std::set<SchemefulSite>> canonical_to_aliases;
  ForEachAlias([&](const SchemefulSite& alias, const SchemefulSite& canonical) {
    canonical_to_aliases[canonical].insert(alias);
  });
  // Calls `f` on the canonical form of `site` and then on each of its aliases.
  const auto for_all_variants =
      [this, canonical_to_aliases = std::move(canonical_to_aliases)](
          const SchemefulSite& site,
          const base::FunctionRef<void(const SchemefulSite&)> f) {
        const SchemefulSite& canonical = ResolveAlias(site);
        f(canonical);
        if (const std::set<SchemefulSite>* aliases =
                base::FindOrNull(canonical_to_aliases, canonical)) {
          for (const auto& alias : *aliases) {
            f(alias);
          }
        }
      };
  base::flat_map<SchemefulSite, base::flat_set<SchemefulSite>>
      potential_singletons;
  base::flat_set<SchemefulSite> replaced_existing_primaries;
  for (const auto& [new_site, unused_entry] : replacements) {
    const auto existing_entry = FindEntry(new_site, nullptr);
    if (!existing_entry.has_value()) {
      continue;
    }
    // Skip sets whose primary is being added, replaced, or intersected by an
    // addition; only the remaining sets are tracked as potential singletons.
    if (!addition_intersected_primaries.contains(existing_entry->primary()) &&
        !additions.contains(existing_entry->primary()) &&
        !replacements.contains(existing_entry->primary())) {
      for_all_variants(new_site, [&](const SchemefulSite& variant) {
        if (existing_entry->primary() != variant) {
          potential_singletons[existing_entry->primary()].insert(variant);
        }
      });
    }
    // `new_site` is itself the primary of its existing set.
    if (existing_entry->primary() == new_site) {
      bool inserted =
          replaced_existing_primaries.emplace(existing_entry->primary()).second;
      CHECK(inserted);
    }
  }
  // Move the (potentially large) locals into the result instead of copying.
  return std::make_pair(std::move(potential_singletons),
                        std::move(replaced_existing_primaries));
}
// Converts `mutation` (replacement sets, addition sets and aliases) into a
// FirstPartySetsContextConfig. The config holds per-site overrides to apply on
// top of the existing sets (public sets plus the manual config).
// Returns an empty config if every replacement and addition set is empty.
// CHECK-fails if any alias in `mutation` ends up without an override or with a
// deletion override, if config creation fails, or if the resulting sets are
// not valid.
FirstPartySetsContextConfig GlobalFirstPartySets::ComputeConfig(
SetsMutation mutation) const {
if (std::ranges::all_of(mutation.replacements(), &SingleSet::empty) &&
std::ranges::all_of(mutation.additions(), &SingleSet::empty)) {
return FirstPartySetsContextConfig();
}
// Flatten into site -> entry maps. Additions are first normalized so that
// addition sets overlapping the same existing set are merged into one.
const FlattenedSets replacements = Flatten(mutation.replacements());
const FlattenedSets additions =
Flatten(NormalizeAdditionSets(mutation.additions()));
// Every replaced or added site gets an override pointing at its new entry.
std::vector<std::pair<SchemefulSite, FirstPartySetEntryOverride>>
site_to_override;
std::ranges::transform(replacements, std::back_inserter(site_to_override),
SiteAndEntryToSiteAndOverride);
std::ranges::transform(additions, std::back_inserter(site_to_override),
SiteAndEntryToSiteAndOverride);
// Existing primaries whose sets contain an added site, mapped to an entry of
// the addition that intersects them.
const base::flat_map<SchemefulSite, FirstPartySetEntry>
addition_intersected_primaries =
FindPrimariesAffectedByAdditions(additions);
// Existing sets that lose members to replacements, and existing primaries
// that are themselves replaced.
auto [potential_singletons, replaced_existing_primaries] =
FindPrimariesAffectedByReplacements(replacements, additions,
addition_intersected_primaries);
if (!addition_intersected_primaries.empty() ||
!potential_singletons.empty() || !replaced_existing_primaries.empty()) {
// Walk the existing effective entries (manual config + public sets only).
ForEachEffectiveSetEntry(
std::nullopt,
[&](const SchemefulSite& member, const FirstPartySetEntry& set_entry) {
// Members of an existing set intersected by an addition are re-pointed
// at the addition's primary, unless they are explicitly replaced.
if (const FirstPartySetEntry* entry = base::FindOrNull(
addition_intersected_primaries, set_entry.primary());
entry && !replacements.contains(member)) {
site_to_override.emplace_back(
member, FirstPartySetEntry(entry->primary(),
member == entry->primary()
? SiteType::kPrimary
: SiteType::kAssociated));
}
// The remaining checks only concern non-primary members.
if (member == set_entry.primary())
return true;
// The set keeps a member that is not being removed, so it will not be
// left as a singleton.
if (const auto singletons_it =
potential_singletons.find(set_entry.primary());
singletons_it != potential_singletons.end() &&
!singletons_it->second.contains(member)) {
potential_singletons.erase(singletons_it);
}
// This set's primary was replaced: its remaining, unreplaced members
// (of a set not absorbed by an addition) get a default-constructed
// override. NOTE(review): assumes the default override means deletion.
if (replaced_existing_primaries.contains(set_entry.primary()) &&
!replacements.contains(member) &&
!addition_intersected_primaries.contains(set_entry.primary())) {
site_to_override.emplace_back(member, FirstPartySetEntryOverride());
}
return true;
});
// Primaries whose other members were all removed would be singletons.
for (const auto& [primary, members] : potential_singletons) {
site_to_override.emplace_back(primary, FirstPartySetEntryOverride());
}
}
// An existing alias (manual config or public sets) whose canonical site got
// an override, but which has none itself, gets a default-constructed
// override too.
ForEachAlias([&](const SchemefulSite& alias, const SchemefulSite& canonical) {
if (base::Contains(
site_to_override, canonical,
&std::pair<SchemefulSite, FirstPartySetEntryOverride>::first) &&
!base::Contains(
site_to_override, alias,
&std::pair<SchemefulSite, FirstPartySetEntryOverride>::first)) {
site_to_override.emplace_back(alias, FirstPartySetEntryOverride());
}
});
// Every alias declared by the mutation must end up with a non-deletion
// override.
CHECK(std::ranges::none_of(
mutation.aliases(), [&](const auto& alias_pair) -> bool {
const auto alias_override_it = std::ranges::find_if(
site_to_override, [&](const auto& site_override_pair) -> bool {
return site_override_pair.first == alias_pair.first;
});
return alias_override_it == site_to_override.end() ||
alias_override_it->second.IsDeletion();
}));
std::optional<FirstPartySetsContextConfig> config =
FirstPartySetsContextConfig::Create(std::move(site_to_override),
mutation.aliases());
CHECK(config.has_value());
CHECK(IsValid(config)) << "Sets must not contain singleton or orphan";
return std::move(config).value();
}
// Merges addition sets that (transitively) overlap the same existing set.
//
// An addition set overlaps an existing set when one of its sites already has
// an entry (per the public sets and manual config; no per-context config is
// consulted) whose primary is that existing set's primary. Addition sets that
// overlap the same existing primary are unioned into a single set. The merged
// set keeps the representative set's own entries. Sites from the other (child)
// sets become associated members of the representative set's primary.
//
// Empty addition sets contribute no sites and are omitted from the result.
// Returns an empty vector when every input set is empty.
std::vector<base::flat_map<SchemefulSite, FirstPartySetEntry>>
GlobalFirstPartySets::NormalizeAdditionSets(
    const std::vector<base::flat_map<SchemefulSite, FirstPartySetEntry>>&
        addition_sets) const {
  if (std::ranges::all_of(addition_sets, &SingleSet::empty)) {
    return {};
  }
  // Existing primary -> indices of the addition sets containing one of its
  // current members.
  base::flat_map<SchemefulSite, base::flat_set<size_t>> addition_set_overlaps;
  for (size_t set_idx = 0; set_idx < addition_sets.size(); set_idx++) {
    for (const auto& site_and_entry : addition_sets[set_idx]) {
      if (const auto entry = FindEntry(site_and_entry.first, nullptr);
          entry.has_value()) {
        addition_set_overlaps[entry->primary()].insert(set_idx);
      }
    }
  }
  // Union all addition sets that overlap the same existing primary.
  AdditionOverlapsUnionFind union_finder(addition_sets.size());
  for (const auto& [public_site, addition_set_indices] :
       addition_set_overlaps) {
    for (size_t representative : addition_set_indices) {
      union_finder.Union(*addition_set_indices.begin(), representative);
    }
  }
  std::vector<SingleSet> normalized_additions;
  for (const auto& [rep, children] : union_finder.SetsMapping()) {
    // An empty addition set has no sites, so it is never passed to Union()
    // above and can only appear here as a childless representative. It adds
    // nothing. It must be skipped, because reading its primary below would
    // dereference begin() of an empty map.
    if (addition_sets[rep].empty()) {
      continue;
    }
    SingleSet normalized = addition_sets[rep];
    const SchemefulSite& rep_primary =
        addition_sets[rep].begin()->second.primary();
    for (size_t child_set_idx : children) {
      for (const auto& child_site_and_entry : addition_sets[child_set_idx]) {
        bool inserted =
            normalized
                .emplace(child_site_and_entry.first,
                         FirstPartySetEntry(rep_primary, SiteType::kAssociated))
                .second;
        CHECK(inserted);
      }
    }
    normalized_additions.push_back(std::move(normalized));
  }
  return normalized_additions;
}
// Invokes `f` on every public-set entry. It first visits each (site, entry) in
// `entries_`, then each alias in `aliases_` paired with its canonical site's
// entry. Stops and returns false as soon as `f` returns false; otherwise
// returns true. CHECK-fails if an alias's canonical site has no entry.
bool GlobalFirstPartySets::ForEachPublicSetEntry(
    base::FunctionRef<bool(const SchemefulSite&, const FirstPartySetEntry&)> f)
    const {
  const bool all_entries_visited =
      std::ranges::all_of(entries_, [&](const auto& site_and_entry) {
        return f(site_and_entry.first, site_and_entry.second);
      });
  if (!all_entries_visited) {
    return false;
  }
  for (const auto& alias_and_canonical : aliases_) {
    const FirstPartySetEntry* entry =
        base::FindOrNull(entries_, alias_and_canonical.second);
    CHECK(entry);
    if (!f(alias_and_canonical.first, *entry)) {
      return false;
    }
  }
  return true;
}
// Invokes `f` on every customization in the manual config, including deletion
// overrides. Returns whatever the manual config's ForEachCustomizationEntry
// returns.
bool GlobalFirstPartySets::ForEachManualConfigEntry(
    base::FunctionRef<bool(const SchemefulSite&,
                           const FirstPartySetEntryOverride&)> f) const {
  const bool result = manual_config_.ForEachCustomizationEntry(f);
  return result;
}
bool GlobalFirstPartySets::ForEachEffectiveSetEntry(
const FirstPartySetsContextConfig& config,
base::FunctionRef<bool(const SchemefulSite&, const FirstPartySetEntry&)> f)
const {
return ForEachEffectiveSetEntry(&config, f);
}
// Invokes `f` on every effective (non-deleted) entry. Precedence is
// `config` > manual config > public sets. An entry from a lower-precedence
// source is skipped when a higher-precedence source customizes that site,
// even via a deletion. Stops and returns false as soon as `f` returns false;
// otherwise returns true.
bool GlobalFirstPartySets::ForEachEffectiveSetEntry(
    base::optional_ref<const FirstPartySetsContextConfig> config,
    base::FunctionRef<bool(const SchemefulSite&, const FirstPartySetEntry&)> f)
    const {
  // True iff the caller-supplied config customizes `site`.
  const auto in_context_config = [&](const SchemefulSite& site) {
    return config.has_value() && config->Contains(site);
  };
  // Forwards non-deletion overrides to `f`; deletions are skipped.
  const auto visit_override = [&](const SchemefulSite& site,
                                  const FirstPartySetEntryOverride& override) {
    return override.IsDeletion() ? true : f(site, override.GetEntry());
  };

  if (config.has_value() &&
      !config->ForEachCustomizationEntry(visit_override)) {
    return false;
  }

  const bool manual_completed = manual_config_.ForEachCustomizationEntry(
      [&](const SchemefulSite& site,
          const FirstPartySetEntryOverride& override) {
        return in_context_config(site) ? true : visit_override(site, override);
      });
  if (!manual_completed) {
    return false;
  }

  return ForEachPublicSetEntry(
      [&](const SchemefulSite& site, const FirstPartySetEntry& entry) {
        if (in_context_config(site) || manual_config_.Contains(site)) {
          return true;
        }
        return f(site, entry);
      });
}
// Invokes `f(alias, canonical)` for every alias. It first visits the aliases
// declared by the manual config, then the public-set aliases that the manual
// config does not customize.
void GlobalFirstPartySets::ForEachAlias(
    base::FunctionRef<void(const SchemefulSite&, const SchemefulSite&)> f)
    const {
  manual_config_.ForEachAlias(f);
  for (const auto& alias_and_canonical : aliases_) {
    const SchemefulSite& alias = alias_and_canonical.first;
    if (!manual_config_.Contains(alias)) {
      f(alias, alias_and_canonical.second);
    }
  }
}
// Returns whether the effective sets pass FirstPartySetsValidator. `config`, if
// present, is applied on top of the manual config and public sets. Every
// effective (site, primary) pair is fed to the validator.
bool GlobalFirstPartySets::IsValid(
    base::optional_ref<const FirstPartySetsContextConfig> config) const {
  FirstPartySetsValidator validator;
  const auto record = [&validator](const SchemefulSite& site,
                                   const FirstPartySetEntry& entry) {
    validator.Update(site, entry.primary());
    return true;
  };
  ForEachEffectiveSetEntry(config, record);
  return validator.IsValid();
}
// Returns the canonical site if `site` is a public-set alias, else `site`
// itself. The result may refer to `site`, so it must not outlive it.
const SchemefulSite& GlobalFirstPartySets::ResolveAlias(
    const SchemefulSite& site) const {
  if (const SchemefulSite* canonical = base::FindOrNull(aliases_, site)) {
    return *canonical;
  }
  return site;
}
// Writes a human-readable debug representation of `sets`. It lists the public
// entries, the public aliases, and the manual-config overrides. Every element
// in every list is followed by the same "}, " separator.
std::ostream& operator<<(std::ostream& os, const GlobalFirstPartySets& sets) {
  os << "{entries = {";
  for (const auto& [site, entry] : sets.entries_) {
    os << "{" << site.Serialize() << ": " << entry << "}, ";
  }
  os << "}, aliases = {";
  for (const auto& [alias, canonical] : sets.aliases_) {
    os << "{" << alias.Serialize() << ": " << canonical.Serialize() << "}, ";
  }
  os << "}, manual_config = {";
  sets.ForEachManualConfigEntry(
      [&](const SchemefulSite& site,
          const FirstPartySetEntryOverride& override) {
        // Same "}, " separator as the entries and aliases lists above.
        os << "{" << site.Serialize() << ": " << override << "}, ";
        return true;
      });
  os << "}}";
  return os;
}
}