#ifndef NET_DNS_HOST_CACHE_H_
#define NET_DNS_HOST_CACHE_H_
#include <stddef.h>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <ostream>
#include <set>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>
#include "base/check.h"
#include "base/gtest_prod_util.h"
#include "base/memory/raw_ptr.h"
#include "base/numerics/clamped_math.h"
#include "base/threading/thread_checker.h"
#include "base/time/time.h"
#include "base/values.h"
#include "net/base/address_family.h"
#include "net/base/connection_endpoint_metadata.h"
#include "net/base/expiring_cache.h"
#include "net/base/host_port_pair.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/net_export.h"
#include "net/base/network_anonymization_key.h"
#include "net/dns/public/dns_query_type.h"
#include "net/dns/public/host_resolver_results.h"
#include "net/dns/public/host_resolver_source.h"
#include "net/log/net_log_capture_mode.h"
#include "url/scheme_host_port.h"
namespace base {
class TickClock;
}
namespace net {
class HostResolverInternalResult;
class NET_EXPORT HostCache {
public:
struct NET_EXPORT Key {
Key(std::variant<url::SchemeHostPort, std::string> host,
DnsQueryType dns_query_type,
HostResolverFlags host_resolver_flags,
HostResolverSource host_resolver_source,
const NetworkAnonymizationKey& network_anonymization_key
#if BUILDFLAG(ARKWEB_CUSTOM_DNS)
,
bool secure = false,
bool external_added = false
#endif
);
Key();
Key(const Key& key);
Key(Key&& key);
~Key();
static auto GetTuple(const Key* key) {
return std::tie(key->dns_query_type, key->host_resolver_flags, key->host,
key->host_resolver_source, key->network_anonymization_key,
key->secure);
}
#if BUILDFLAG(ARKWEB_CUSTOM_DNS)
static auto GetTuple2(const Key* key) {
return std::tie(key->dns_query_type, key->host_resolver_flags, key->host,
key->host_resolver_source);
}
#endif
bool operator==(const Key& other) const {
#if BUILDFLAG(ARKWEB_CUSTOM_DNS)
if (external_added || other.external_added) {
return GetTuple2(this) == GetTuple2(&other);
}
#endif
return GetTuple(this) == GetTuple(&other);
}
bool operator<(const Key& other) const {
#if BUILDFLAG(ARKWEB_CUSTOM_DNS)
if (external_added || other.external_added) {
return GetTuple2(this) < GetTuple2(&other);
}
#endif
return GetTuple(this) < GetTuple(&other);
}
std::variant<url::SchemeHostPort, std::string> host;
DnsQueryType dns_query_type = DnsQueryType::UNSPECIFIED;
HostResolverFlags host_resolver_flags = 0;
HostResolverSource host_resolver_source = HostResolverSource::ANY;
NetworkAnonymizationKey network_anonymization_key;
bool secure = false;
#if BUILDFLAG(ARKWEB_CUSTOM_DNS)
bool external_added = false;
#endif
};
struct NET_EXPORT EntryStaleness {
base::TimeDelta expired_by;
int network_changes;
int stale_hits;
bool is_stale() const {
return network_changes > 0 || expired_by >= base::TimeDelta();
}
};
class NET_EXPORT Entry {
public:
enum Source : int {
SOURCE_UNKNOWN,
SOURCE_DNS,
SOURCE_HOSTS,
SOURCE_CONFIG,
};
template <typename T>
Entry(int error,
T&& results,
Source source,
std::optional<base::TimeDelta> ttl)
: error_(error),
source_(source),
ttl_(ttl ? ttl.value() : kUnknownTtl) {
DCHECK(!ttl || ttl.value() >= base::TimeDelta());
SetResult(std::forward<T>(results));
}
template <typename T>
Entry(int error, T&& results, Source source)
: Entry(error, std::forward<T>(results), source, std::nullopt) {}
Entry(int error,
std::vector<IPEndPoint> ip_endpoints,
std::set<std::string> aliases,
Source source,
std::optional<base::TimeDelta> ttl = std::nullopt);
Entry(int error,
Source source,
std::optional<base::TimeDelta> ttl = std::nullopt);
Entry(const std::set<std::unique_ptr<HostResolverInternalResult>>& results,
base::Time now,
base::TimeTicks now_ticks,
Source empty_source = SOURCE_UNKNOWN);
Entry(const Entry& entry);
Entry(Entry&& entry);
~Entry();
Entry& operator=(const Entry& entry);
Entry& operator=(Entry&& entry);
bool operator==(const Entry& other) const {
return ContentsEqual(other) &&
std::tie(source_, pinning_, ttl_, expires_, network_changes_,
total_hits_, stale_hits_) ==
std::tie(other.source_, other.pinning_, other.ttl_,
other.expires_, other.network_changes_,
other.total_hits_, other.stale_hits_);
}
bool ContentsEqual(const Entry& other) const {
return std::tie(error_, ip_endpoints_, endpoint_metadatas_, aliases_,
text_records_, hostnames_, canonical_names_) ==
std::tie(other.error_, other.ip_endpoints_,
other.endpoint_metadatas_, other.aliases_,
other.text_records_, other.hostnames_,
other.canonical_names_);
}
int error() const { return error_; }
bool did_complete() const {
return error_ != ERR_NETWORK_CHANGED &&
error_ != ERR_HOST_RESOLVER_QUEUE_TOO_LARGE;
}
void set_error(int error) { error_ = error; }
std::vector<HostResolverEndpointResult> GetEndpoints() const;
const std::vector<IPEndPoint>& ip_endpoints() const {
return ip_endpoints_;
}
void set_ip_endpoints(std::vector<IPEndPoint> ip_endpoints) {
ip_endpoints_ = std::move(ip_endpoints);
}
std::vector<ConnectionEndpointMetadata> GetMetadatas() const;
void ClearMetadatas() { endpoint_metadatas_.clear(); }
const std::set<std::string>& aliases() const { return aliases_; }
void set_aliases(std::set<std::string> aliases) {
aliases_ = std::move(aliases);
}
const std::vector<std::string>& text_records() const {
return text_records_;
}
void set_text_records(std::vector<std::string> text_records) {
text_records_ = std::move(text_records);
}
const std::vector<HostPortPair>& hostnames() const { return hostnames_; }
void set_hostnames(std::vector<HostPortPair> hostnames) {
hostnames_ = std::move(hostnames);
}
std::optional<bool> pinning() const { return pinning_; }
void set_pinning(std::optional<bool> pinning) { pinning_ = pinning; }
const std::set<std::string>& canonical_names() const {
return canonical_names_;
}
void set_canonical_names(std::set<std::string> canonical_names) {
canonical_names_ = std::move(canonical_names);
}
Source source() const { return source_; }
bool has_ttl() const { return ttl_ >= base::TimeDelta(); }
base::TimeDelta ttl() const { return ttl_; }
std::optional<base::TimeDelta> GetOptionalTtl() const;
void set_ttl(base::TimeDelta ttl) { ttl_ = ttl; }
base::TimeTicks expires() const { return expires_; }
int network_changes() const { return network_changes_; }
static Entry MergeEntries(Entry front, Entry back);
base::Value NetLogParams() const;
HostCache::Entry CopyWithDefaultPort(uint16_t port) const;
std::vector<ServiceEndpoint> ConvertToServiceEndpoints(uint16_t port) const;
static std::optional<base::TimeDelta> TtlFromInternalResults(
const std::set<std::unique_ptr<HostResolverInternalResult>>& results,
base::Time now,
base::TimeTicks now_ticks);
private:
using HttpsRecordPriority = uint16_t;
friend class HostCache;
static constexpr base::TimeDelta kUnknownTtl = base::Seconds(-1);
Entry(const Entry& entry,
base::TimeTicks now,
base::TimeDelta ttl,
int network_changes);
Entry(int error,
std::vector<IPEndPoint> ip_endpoints,
std::multimap<HttpsRecordPriority, ConnectionEndpointMetadata>
endpoint_metadatas,
std::set<std::string> aliases,
std::vector<std::string>&& text_results,
std::vector<HostPortPair>&& hostnames,
Source source,
base::TimeTicks expires,
int network_changes);
void SetResult(
std::multimap<HttpsRecordPriority, ConnectionEndpointMetadata>
endpoint_metadatas) {
endpoint_metadatas_ = std::move(endpoint_metadatas);
}
void SetResult(std::vector<std::string> text_records) {
text_records_ = std::move(text_records);
}
void SetResult(std::vector<HostPortPair> hostnames) {
hostnames_ = std::move(hostnames);
}
int total_hits() const { return total_hits_; }
int stale_hits() const { return stale_hits_; }
bool IsStale(base::TimeTicks now, int network_changes) const;
void CountHit(bool hit_is_stale);
void GetStaleness(base::TimeTicks now,
int network_changes,
EntryStaleness* out) const;
base::Value::Dict GetAsValue(bool include_staleness) const;
int error_ = ERR_FAILED;
std::vector<IPEndPoint> ip_endpoints_;
std::multimap<HttpsRecordPriority, ConnectionEndpointMetadata>
endpoint_metadatas_;
std::set<std::string> aliases_;
std::vector<std::string> text_records_;
std::vector<HostPortPair> hostnames_;
Source source_ = SOURCE_UNKNOWN;
std::optional<bool> pinning_;
std::set<std::string> canonical_names_;
base::TimeDelta ttl_ = kUnknownTtl;
base::TimeTicks expires_;
int network_changes_ = -1;
base::ClampedNumeric<int> total_hits_ = 0;
base::ClampedNumeric<int> stale_hits_ = 0;
};
class PersistenceDelegate {
public:
virtual void ScheduleWrite() = 0;
};
using EntryMap = std::map<Key, Entry>;
enum class SerializationType {
kRestorable,
kDebug,
};
static const HostCache::EntryStaleness kNotStale;
explicit HostCache(size_t max_entries);
HostCache(const HostCache&) = delete;
HostCache& operator=(const HostCache&) = delete;
~HostCache();
const std::pair<const Key, Entry>* Lookup(const Key& key,
base::TimeTicks now,
bool ignore_secure = false);
const std::pair<const Key, Entry>* LookupStale(const Key& key,
base::TimeTicks now,
EntryStaleness* stale_out,
bool ignore_secure = false);
void Set(const Key& key,
const Entry& entry,
base::TimeTicks now,
base::TimeDelta ttl);
const HostCache::Key* GetMatchingKeyForTesting(
std::string_view hostname,
HostCache::Entry::Source* source_out = nullptr,
HostCache::EntryStaleness* stale_out = nullptr) const;
void Invalidate();
void set_persistence_delegate(PersistenceDelegate* delegate);
void set_tick_clock_for_testing(const base::TickClock* tick_clock) {
tick_clock_ = tick_clock;
}
void clear();
void ClearForHosts(
const base::RepeatingCallback<bool(const std::string&)>& host_filter);
void GetList(base::Value::List& entry_list,
bool include_staleness,
SerializationType serialization_type) const;
bool RestoreFromListValue(const base::Value::List& old_cache);
size_t last_restore_size() const { return restore_size_; }
size_t size() const;
size_t max_entries() const;
int network_changes() const { return network_changes_; }
const EntryMap& entries() const { return entries_; }
#if BUILDFLAG(ARKWEB_LOGGER_REPORT)
std::vector<IPEndPoint> LookupByHost(url::SchemeHostPort destination);
#endif
private:
FRIEND_TEST_ALL_PREFIXES(HostCacheTest, NoCache);
enum SetOutcome : int;
enum class LookupOutcome {
kLookupMissAbsent = 0,
kLookupMissStale = 1,
kLookupHitValid = 2,
kLookupHitStale = 3,
kMaxValue = kLookupHitStale
};
enum class EraseReason {
kEraseEvict = 0,
kEraseClear = 1,
kEraseDestruct = 2,
kMaxValue = kEraseDestruct
};
static std::pair<const HostCache::Key, HostCache::Entry>*
GetLessStaleMoreSecureResult(
base::TimeTicks now,
std::pair<const HostCache::Key, HostCache::Entry>* result1,
std::pair<const HostCache::Key, HostCache::Entry>* result2);
std::pair<const Key, Entry>* LookupInternalIgnoringFields(
const Key& initial_key,
base::TimeTicks now,
bool ignore_secure);
std::pair<const Key, Entry>* LookupInternal(const Key& key);
void RecordLookup(LookupOutcome outcome,
base::TimeTicks now,
const Key& key,
const Entry* entry);
void RecordErase(EraseReason reason,
base::TimeTicks now,
const Key& key,
const Entry& entry);
void RecordEraseAll(EraseReason reason, base::TimeTicks now);
bool caching_is_disabled() const { return max_entries_ == 0; }
bool EvictOneEntry(base::TimeTicks now);
bool HasActivePin(const Entry& entry);
void AddEntry(const Key& key, Entry&& entry);
EntryMap entries_;
size_t max_entries_;
int network_changes_ = 0;
size_t restore_size_ = 0;
raw_ptr<PersistenceDelegate> delegate_ = nullptr;
raw_ptr<const base::TickClock> tick_clock_;
THREAD_CHECKER(thread_checker_);
};
}
// Stream insertion for HostCache::EntryStaleness, presumably for debug logging
// and test diagnostics (TODO confirm).
// NOTE(review): this is declared at global scope, outside namespace net, so
// argument-dependent lookup will not find it. Callers inside namespaces that
// declare their own operator<< may not see it.
std::ostream& operator<<(std::ostream& os,
                         const net::HostCache::EntryStaleness& staleness);
#endif