#ifndef NET_DNS_MOCK_HOST_RESOLVER_H_
#define NET_DNS_MOCK_HOST_RESOLVER_H_
#include <stddef.h>
#include <stdint.h>

#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "base/check.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/ref_counted.h"
#include "base/memory/weak_ptr.h"
#include "base/strings/string_piece.h"
#include "base/strings/string_piece_forward.h"
#include "base/synchronization/lock.h"
#include "base/synchronization/waitable_event.h"
#include "base/thread_annotations.h"
#include "base/threading/thread_checker.h"
#include "net/base/address_family.h"
#include "net/base/address_list.h"
#include "net/base/completion_once_callback.h"
#include "net/base/host_port_pair.h"
#include "net/base/net_errors.h"
#include "net/base/network_anonymization_key.h"
#include "net/dns/host_cache.h"
#include "net/dns/host_resolver.h"
#include "net/dns/host_resolver_proc.h"
#include "net/dns/public/dns_query_type.h"
#include "net/dns/public/host_resolver_results.h"
#include "net/dns/public/host_resolver_source.h"
#include "net/dns/public/mdns_listener_update_type.h"
#include "net/dns/public/secure_dns_policy.h"
#include "net/log/net_log_with_source.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "third_party/abseil-cpp/absl/types/variant.h"
#include "url/scheme_host_port.h"
// Forward declaration only: this header uses base::TickClock solely through
// pointers (MockHostResolverBase::set_tick_clock() and its `tick_clock_`).
namespace base {
class TickClock;
}
namespace net {
// Forward declarations for types used only by pointer/reference below.
class HostCache;
class IPEndPoint;
class URLRequestContext;
// Parses `host_list` into `ip_endpoints`.
// NOTE(review): the accepted format of `host_list` (presumably a
// comma-separated list of IP literals), whether `ip_endpoints` is cleared
// first, and the meaning of the int result (presumably a net::Error such as
// OK on success) are defined in the .cc — confirm there.
int ParseAddressList(base::StringPiece host_list,
std::vector<net::IPEndPoint>* ip_endpoints);
// Shared implementation behind MockHostResolver and MockCachingHostResolver: a
// HostResolver whose answers come from a configurable RuleResolver (exposed
// via rules()) instead of real DNS. The constructor is private; instances are
// created through the friend classes MockHostResolver, MockCachingHostResolver
// and MockHostResolverFactory.
class MockHostResolverBase
    : public HostResolver,
      public base::SupportsWeakPtr<MockHostResolverBase> {
 private:
  class RequestImpl;
  class ProbeRequestImpl;
  class MdnsListenerImpl;

 public:
  // A table of rules mapping request properties (RuleKey) to canned results
  // (RuleResultOrError), plus an optional fallback result.
  class RuleResolver {
   public:
    // Describes which requests a rule applies to. operator< compares the
    // fields lexicographically (via GetTuple()) so RuleKey can key the
    // std::map of rules.
    struct RuleKey {
      // Alternatives for `scheme`. NOTE(review): presumably WildcardScheme
      // matches any request and NoScheme matches only requests made without a
      // scheme (HostPortPair) — confirm in RuleResolver::Resolve().
      struct WildcardScheme : absl::monostate {};
      struct NoScheme : absl::monostate {};
      using Scheme = std::string;

      RuleKey();
      ~RuleKey();

      RuleKey(const RuleKey&);
      RuleKey& operator=(const RuleKey&);
      RuleKey(RuleKey&&);
      RuleKey& operator=(RuleKey&&);

      // Tuple of references to every field, in the order operator< compares
      // them.
      auto GetTuple() const {
        return std::tie(scheme, hostname_pattern, port, query_type,
                        query_source);
      }

      bool operator<(const RuleKey& other) const {
        return GetTuple() < other.GetTuple();
      }

      absl::variant<WildcardScheme, NoScheme, Scheme> scheme = WildcardScheme();
      // NOTE(review): the "*" default and the unset optionals below presumably
      // mean "match anything" — confirm in RuleResolver::Resolve().
      std::string hostname_pattern = "*";
      absl::optional<uint16_t> port;
      absl::optional<DnsQueryType> query_type;
      absl::optional<HostResolverSource> query_source;
    };

    // Successful outcome of a rule: the endpoints to return and the DNS
    // aliases (empty unless provided).
    struct RuleResult {
      RuleResult();
      explicit RuleResult(
          std::vector<HostResolverEndpointResult> endpoints,
          std::set<std::string> aliases = std::set<std::string>());
      ~RuleResult();

      RuleResult(const RuleResult&);
      RuleResult& operator=(const RuleResult&);
      RuleResult(RuleResult&&);
      RuleResult& operator=(RuleResult&&);

      std::vector<HostResolverEndpointResult> endpoints;
      std::set<std::string> aliases;
    };

    // A failing rule is represented by the net::Error it produces.
    using ErrorResult = Error;
    using RuleResultOrError = absl::variant<RuleResult, ErrorResult>;

    // `default_result`, if given, is stored as the fallback result.
    // NOTE(review): presumably returned by Resolve() when no rule matches;
    // confirm what happens when it is unset.
    explicit RuleResolver(
        absl::optional<RuleResultOrError> default_result = absl::nullopt);
    ~RuleResolver();

    RuleResolver(const RuleResolver&);
    RuleResolver& operator=(const RuleResolver&);
    RuleResolver(RuleResolver&&);
    RuleResolver& operator=(RuleResolver&&);

    // Returns the result configured for a request. The result is returned by
    // const reference, presumably into this RuleResolver's own storage — do
    // not use it after the RuleResolver is modified or destroyed (TODO
    // confirm in the .cc).
    const RuleResultOrError& Resolve(const Host& request_endpoint,
                                     DnsQueryTypeSet request_types,
                                     HostResolverSource request_source) const;

    // Clears the configured rules. NOTE(review): whether `default_result_`
    // survives is decided in the .cc.
    void ClearRules();

    // Canned result intended for localhost lookups; see the .cc for the exact
    // endpoints.
    static RuleResultOrError GetLocalhostResult();

    // Rule-adding helpers. NOTE(review): the overloads taking a hostname
    // pattern presumably build a RuleKey with only `hostname_pattern` set, and
    // the `ip_literal` overloads presumably resolve to that literal address —
    // confirm in the .cc.
    void AddRule(RuleKey key, RuleResultOrError result);
    void AddRule(RuleKey key, base::StringPiece ip_literal);
    void AddRule(base::StringPiece hostname_pattern, RuleResultOrError result);
    void AddRule(base::StringPiece hostname_pattern,
                 base::StringPiece ip_literal);
    void AddRule(base::StringPiece hostname_pattern, Error error);
    void AddIPLiteralRule(base::StringPiece hostname_pattern,
                          base::StringPiece ip_literal,
                          base::StringPiece canonical_name);
    void AddIPLiteralRuleWithDnsAliases(base::StringPiece hostname_pattern,
                                        base::StringPiece ip_literal,
                                        std::vector<std::string> dns_aliases);
    void AddIPLiteralRuleWithDnsAliases(base::StringPiece hostname_pattern,
                                        base::StringPiece ip_literal,
                                        std::set<std::string> dns_aliases);
    void AddSimulatedFailure(base::StringPiece hostname_pattern);
    void AddSimulatedTimeoutFailure(base::StringPiece hostname_pattern);
    void AddRuleWithFlags(base::StringPiece host_pattern,
                          base::StringPiece ip_literal,
                          HostResolverFlags flags,
                          std::vector<std::string> dns_aliases = {});

   private:
    // Rules keyed, and therefore ordered, by RuleKey::operator<.
    std::map<RuleKey, RuleResultOrError> rules_;
    // Fallback given to the constructor; nullopt when none was supplied.
    absl::optional<RuleResultOrError> default_result_;
  };

  // Pending requests keyed by request id (see last_id() / ResolveNow()).
  // NOTE(review): raw pointers, presumably non-owning — callers own the
  // requests returned by CreateRequest() as std::unique_ptr.
  using RequestMap = std::map<size_t, RequestImpl*>;

  // Ref-counted bookkeeping: pending requests, the running DoH probe and
  // resolve counters. Exposed read-only through state().
  // NOTE(review): presumably ref-counted so outstanding requests can keep it
  // alive after the resolver is destroyed — confirm in the .cc.
  class State : public base::RefCounted<State> {
   public:
    State();

    bool has_pending_requests() const { return !requests_.empty(); }
    bool IsDohProbeRunning() const { return !!doh_probe_request_; }
    size_t num_resolve() const { return num_resolve_; }
    size_t num_resolve_from_cache() const { return num_resolve_from_cache_; }
    size_t num_non_local_resolves() const { return num_non_local_resolves_; }

    RequestMap& mutable_requests() { return requests_; }
    void IncrementNumResolve() { ++num_resolve_; }
    void IncrementNumResolveFromCache() { ++num_resolve_from_cache_; }
    void IncrementNumNonLocalResolves() { ++num_non_local_resolves_; }

    void ClearDohProbeRequest() { doh_probe_request_ = nullptr; }
    // Clears the recorded probe only when it is `request`; a stale request
    // therefore cannot clear a newer probe.
    void ClearDohProbeRequestIfMatching(ProbeRequestImpl* request) {
      if (request == doh_probe_request_) {
        doh_probe_request_ = nullptr;
      }
    }
    // Records the running DoH probe. `request` must be non-null and no probe
    // may already be recorded (both DCHECK-enforced).
    void set_doh_probe_request(ProbeRequestImpl* request) {
      DCHECK(request);
      DCHECK(!doh_probe_request_);
      doh_probe_request_ = request;
    }

   private:
    friend class RefCounted<State>;
    ~State();

    RequestMap requests_;
    // Non-owning pointer to the running DoH probe, or null.
    raw_ptr<ProbeRequestImpl> doh_probe_request_ = nullptr;
    size_t num_resolve_ = 0;
    size_t num_resolve_from_cache_ = 0;
    size_t num_non_local_resolves_ = 0;
  };

  MockHostResolverBase(const MockHostResolverBase&) = delete;
  MockHostResolverBase& operator=(const MockHostResolverBase&) = delete;
  ~MockHostResolverBase() override;

  // Mutable access to the rule table this resolver answers from.
  RuleResolver* rules() { return &rule_resolver_; }
  // Read-only view of the shared bookkeeping.
  scoped_refptr<const State> state() const { return state_; }

  // NOTE(review): presumably, in synchronous mode requests complete inside
  // Start() rather than asynchronously — confirm in the .cc.
  void set_synchronous_mode(bool is_synchronous) {
    synchronous_mode_ = is_synchronous;
  }
  // NOTE(review): presumably, in on-demand mode requests stay pending until
  // ResolveAllPending(), ResolveNow() or ResolveOnlyRequestNow() is called.
  void set_ondemand_mode(bool is_ondemand) {
    ondemand_mode_ = is_ondemand;
  }

  // HostResolver implementation:
  void OnShutdown() override;
  std::unique_ptr<ResolveHostRequest> CreateRequest(
      url::SchemeHostPort host,
      NetworkAnonymizationKey network_anonymization_key,
      NetLogWithSource net_log,
      absl::optional<ResolveHostParameters> optional_parameters) override;
  std::unique_ptr<ResolveHostRequest> CreateRequest(
      const HostPortPair& host,
      const NetworkAnonymizationKey& network_anonymization_key,
      const NetLogWithSource& net_log,
      const absl::optional<ResolveHostParameters>& optional_parameters)
      override;
  std::unique_ptr<ProbeRequest> CreateDohProbeRequest() override;
  std::unique_ptr<MdnsListener> CreateMdnsListener(
      const HostPortPair& host,
      DnsQueryType query_type) override;
  HostCache* GetHostCache() override;
  // Intentionally a no-op: the mock does not use a URLRequestContext.
  void SetRequestContext(URLRequestContext* request_context) override {}

  // Resolves `endpoint` for the purpose of populating the cache.
  // NOTE(review): the int result is presumably a net::Error — confirm in the
  // .cc.
  int LoadIntoCache(
      absl::variant<url::SchemeHostPort, HostPortPair> endpoint,
      const NetworkAnonymizationKey& network_anonymization_key,
      const absl::optional<ResolveHostParameters>& optional_parameters);
  int LoadIntoCache(
      const Host& endpoint,
      const NetworkAnonymizationKey& network_anonymization_key,
      const absl::optional<ResolveHostParameters>& optional_parameters);

  bool has_pending_requests() const { return state_->has_pending_requests(); }

  // Control and inspection of pending requests. `id` identifies a request
  // (see last_id()).
  void ResolveAllPending();
  size_t last_id();
  void ResolveNow(size_t id);
  void DetachRequest(size_t id);
  base::StringPiece request_host(size_t id);
  RequestPriority request_priority(size_t id);
  const NetworkAnonymizationKey& request_network_anonymization_key(size_t id);
  void ResolveOnlyRequestNow();

  // Counters forwarded from state().
  size_t num_resolve() const { return state_->num_resolve(); }
  size_t num_resolve_from_cache() const {
    return state_->num_resolve_from_cache();
  }
  size_t num_non_local_resolves() const {
    return state_->num_non_local_resolves();
  }

  // Properties recorded from the most recent request.
  RequestPriority last_request_priority() const {
    return last_request_priority_;
  }
  // Const so it can be queried through a const reference, like its siblings.
  const absl::optional<NetworkAnonymizationKey>&
  last_request_network_anonymization_key() const {
    return last_request_network_anonymization_key_;
  }
  SecureDnsPolicy last_secure_dns_policy() const {
    return last_secure_dns_policy_;
  }

  bool IsDohProbeRunning() const { return state_->IsDohProbeRunning(); }

  // Deliver a synthetic mDNS update for (`host`, `query_type`).
  // NOTE(review): presumably forwarded to the matching registered listeners;
  // the overloads differ only in the result payload.
  void TriggerMdnsListeners(const HostPortPair& host,
                            DnsQueryType query_type,
                            MdnsListenerUpdateType update_type,
                            const IPEndPoint& address_result);
  void TriggerMdnsListeners(const HostPortPair& host,
                            DnsQueryType query_type,
                            MdnsListenerUpdateType update_type,
                            const std::vector<std::string>& text_result);
  void TriggerMdnsListeners(const HostPortPair& host,
                            DnsQueryType query_type,
                            MdnsListenerUpdateType update_type,
                            const HostPortPair& host_result);
  void TriggerMdnsListeners(const HostPortPair& host,
                            DnsQueryType query_type,
                            MdnsListenerUpdateType update_type);

  // Stores a non-owning clock pointer; the clock must outlive this resolver.
  void set_tick_clock(const base::TickClock* tick_clock) {
    tick_clock_ = tick_clock;
  }

 private:
  friend class MockHostResolver;
  friend class MockCachingHostResolver;
  friend class MockHostResolverFactory;

  RequestImpl* request(size_t id);

  // NOTE(review): `use_caching` presumably decides whether `cache_` is
  // created; `cache_invalidation_num` initializes
  // `initial_cache_invalidation_num_`.
  MockHostResolverBase(bool use_caching,
                       int cache_invalidation_num,
                       RuleResolver rule_resolver);

  int Resolve(RequestImpl* request);
  int ResolveFromIPLiteralOrCache(
      const Host& endpoint,
      const NetworkAnonymizationKey& network_anonymization_key,
      DnsQueryType dns_query_type,
      HostResolverFlags flags,
      HostResolverSource source,
      HostResolver::ResolveHostParameters::CacheUsage cache_usage,
      std::vector<HostResolverEndpointResult>* out_endpoints,
      std::set<std::string>* out_aliases,
      absl::optional<HostCache::EntryStaleness>* out_stale_info);
  int DoSynchronousResolution(RequestImpl& request);

  void AddListener(MdnsListenerImpl* listener);
  void RemoveCancelledListener(MdnsListenerImpl* listener);

  // Values recorded from the most recent request.
  RequestPriority last_request_priority_ = DEFAULT_PRIORITY;
  absl::optional<NetworkAnonymizationKey>
      last_request_network_anonymization_key_;
  SecureDnsPolicy last_secure_dns_policy_ = SecureDnsPolicy::kAllow;
  bool synchronous_mode_ = false;
  bool ondemand_mode_ = false;
  RuleResolver rule_resolver_;
  std::unique_ptr<HostCache> cache_;
  const int initial_cache_invalidation_num_;
  std::map<HostCache::Key, int> cache_invalidation_nums_;
  // Maintained by AddListener() / RemoveCancelledListener(); raw pointers,
  // presumably non-owning.
  std::set<MdnsListenerImpl*> listeners_;
  size_t next_request_id_ = 1;
  raw_ptr<const base::TickClock> tick_clock_;
  scoped_refptr<State> state_;
  THREAD_CHECKER(thread_checker_);
};
class MockHostResolver : public MockHostResolverBase {
public:
explicit MockHostResolver(absl::optional<RuleResolver::RuleResultOrError>
default_result = absl::nullopt)
: MockHostResolverBase(false,
0,
RuleResolver(std::move(default_result))) {}
~MockHostResolver() override = default;
};
// Caching mock resolver: forwards to MockHostResolverBase with caching
// enabled.
class MockCachingHostResolver : public MockHostResolverBase {
 public:
  // `cache_invalidation_num` is passed straight to the base constructor;
  // `default_result`, when set, becomes the RuleResolver's fallback result.
  explicit MockCachingHostResolver(
      int cache_invalidation_num = 0,
      absl::optional<RuleResolver::RuleResultOrError> default_result =
          absl::nullopt)
      : MockHostResolverBase(/*use_caching=*/true,
                             cache_invalidation_num,
                             RuleResolver{std::move(default_result)}) {}

  ~MockCachingHostResolver() override = default;
};
// HostResolver::Factory producing mock resolvers. The constructor arguments
// are stored in const members for the factory's lifetime.
class MockHostResolverFactory : public HostResolver::Factory {
 public:
  // NOTE(review): `rules`, `use_caching` and `cache_invalidation_num` are
  // presumably applied to every resolver this factory creates — confirm in
  // the .cc.
  explicit MockHostResolverFactory(MockHostResolverBase::RuleResolver rules =
                                       MockHostResolverBase::RuleResolver(),
                                   bool use_caching = false,
                                   int cache_invalidation_num = 0);

  MockHostResolverFactory(const MockHostResolverFactory&) = delete;
  MockHostResolverFactory& operator=(const MockHostResolverFactory&) = delete;

  ~MockHostResolverFactory() override;

  // HostResolver::Factory overrides. NOTE(review): whether
  // `host_mapping_rules` and `enable_caching` are honored is decided in the
  // .cc.
  std::unique_ptr<HostResolver> CreateResolver(
      HostResolverManager* manager,
      base::StringPiece host_mapping_rules,
      bool enable_caching) override;
  std::unique_ptr<HostResolver> CreateStandaloneResolver(
      NetLog* net_log,
      const HostResolver::ManagerOptions& options,
      base::StringPiece host_mapping_rules,
      bool enable_caching) override;

 private:
  const MockHostResolverBase::RuleResolver rules_;
  const bool use_caching_;
  const int cache_invalidation_num_;
};
// HostResolverProc that answers Resolve() from an ordered list of Rules.
// The rule list and the per-pattern resolve counters are guarded by
// `rule_lock_`.
// NOTE(review): `previous` and `allow_fallback` presumably control which proc
// is consulted when no rule matches — confirm in the .cc.
class RuleBasedHostResolverProc : public HostResolverProc {
 public:
  explicit RuleBasedHostResolverProc(scoped_refptr<HostResolverProc> previous,
                                     bool allow_fallback = true);

  // Rule-adding helpers. NOTE(review): each presumably builds a Rule (see
  // below) and appends it via AddRuleInternal(); the ResolverType chosen by
  // each helper is defined in the .cc.
  void AddRule(base::StringPiece host_pattern, base::StringPiece ip_literal);
  void AddRuleForAddressFamily(base::StringPiece host_pattern,
                               AddressFamily address_family,
                               base::StringPiece ip_literal);
  void AddRuleWithFlags(base::StringPiece host_pattern,
                        base::StringPiece ip_literal,
                        HostResolverFlags flags,
                        std::vector<std::string> dns_aliases = {});
  void AddIPLiteralRule(base::StringPiece host_pattern,
                        base::StringPiece ip_literal,
                        base::StringPiece canonical_name);
  void AddIPLiteralRuleWithDnsAliases(base::StringPiece host_pattern,
                                      base::StringPiece ip_literal,
                                      std::vector<std::string> dns_aliases);
  void AddRuleWithLatency(base::StringPiece host_pattern,
                          base::StringPiece replacement,
                          int latency_ms);
  void AllowDirectLookup(base::StringPiece host);
  // Both failure helpers default `flags` to HOST_RESOLVER_LOOPBACK_ONLY.
  void AddSimulatedFailure(
      base::StringPiece host,
      HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY);
  void AddSimulatedTimeoutFailure(
      base::StringPiece host,
      HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY);

  void ClearRules();

  // NOTE(review): presumably forbids further rule changes (see
  // `modifications_allowed_`) — confirm in the .cc.
  void DisableModifications();

  // HostResolverProc implementation.
  int Resolve(const std::string& host,
              AddressFamily address_family,
              HostResolverFlags host_resolver_flags,
              AddressList* addrlist,
              int* os_error) override;

  // One resolution rule, matched against the host being resolved.
  struct Rule {
    // How a matching host is handled (fail, fail with timeout, system lookup,
    // or IP literal).
    enum ResolverType {
      kResolverTypeFail,
      kResolverTypeFailTimeout,
      kResolverTypeSystem,
      kResolverTypeIPLiteral,
    };

    Rule(ResolverType resolver_type,
         base::StringPiece host_pattern,
         AddressFamily address_family,
         HostResolverFlags host_resolver_flags,
         base::StringPiece replacement,
         std::vector<std::string> dns_aliases,
         int latency_ms);
    Rule(const Rule& other);
    ~Rule();

    ResolverType resolver_type;
    std::string host_pattern;
    AddressFamily address_family;
    HostResolverFlags host_resolver_flags;
    // NOTE(review): presumably an IP literal or a replacement hostname,
    // depending on `resolver_type`.
    std::string replacement;
    std::vector<std::string> dns_aliases;
    // NOTE(review): presumably an artificial delay in milliseconds.
    int latency_ms;
  };

  using RuleList = std::list<Rule>;

  // Returns the current rules by value.
  RuleList GetRules();

  // Number of Resolve() calls attributed to `host_pattern`.
  size_t NumResolvesForHostPattern(base::StringPiece host_pattern);

 private:
  ~RuleBasedHostResolverProc() override;

  void AddRuleInternal(const Rule& rule);

  RuleList rules_ GUARDED_BY(rule_lock_);

  // NOTE(review): the keys are non-owning base::StringPiece values. They must
  // point at storage that outlives each entry, e.g. the `host_pattern`
  // strings inside `rules_` (std::list keeps element addresses stable). If
  // ClearRules() drops rules without also clearing this map, the keys dangle.
  // Consider std::map<std::string, size_t, std::less<>> — verify against the
  // .cc before changing.
  std::map<base::StringPiece, size_t> num_resolves_per_host_pattern_
      GUARDED_BY(rule_lock_);

  base::Lock rule_lock_;

  // NOTE(review): not annotated GUARDED_BY(rule_lock_); confirm it is only
  // accessed on a single thread or while holding the lock.
  bool modifications_allowed_ = true;
};
// Returns a RuleBasedHostResolverProc.
// NOTE(review): judging by the name, it is preconfigured with a catch-all
// rule — see the .cc for the exact rule and fallback.
scoped_refptr<RuleBasedHostResolverProc> CreateCatchAllHostResolverProc();
// HostResolver for tests that need resolution to stall.
// NOTE(review): its requests presumably never complete on their own. It
// exposes the host and NetworkAnonymizationKey it last recorded and a
// cancellation count kept in a ref-counted State.
class HangingHostResolver : public HostResolver {
 public:
  // Ref-counted holder of the cancellation count, readable via state().
  class State : public base::RefCounted<State> {
   public:
    State();

    int num_cancellations() const { return num_cancellations_; }
    void IncrementNumCancellations() { ++num_cancellations_; }

   private:
    friend class RefCounted<State>;
    ~State();

    int num_cancellations_ = 0;
  };

  HangingHostResolver();
  ~HangingHostResolver() override;

  // HostResolver implementation:
  void OnShutdown() override;
  std::unique_ptr<ResolveHostRequest> CreateRequest(
      url::SchemeHostPort host,
      NetworkAnonymizationKey network_anonymization_key,
      NetLogWithSource net_log,
      absl::optional<ResolveHostParameters> optional_parameters) override;
  std::unique_ptr<ResolveHostRequest> CreateRequest(
      const HostPortPair& host,
      const NetworkAnonymizationKey& network_anonymization_key,
      const NetLogWithSource& net_log,
      const absl::optional<ResolveHostParameters>& optional_parameters)
      override;
  std::unique_ptr<ProbeRequest> CreateDohProbeRequest() override;
  void SetRequestContext(URLRequestContext* url_request_context) override;

  // Accessors for recorded test observations.
  int num_cancellations() const { return state_->num_cancellations(); }
  const HostPortPair& last_host() const { return last_host_; }
  const NetworkAnonymizationKey& last_network_anonymization_key() const {
    return last_network_anonymization_key_;
  }
  const scoped_refptr<const State> state() const { return state_; }

 private:
  class RequestImpl;
  class ProbeRequestImpl;

  HostPortPair last_host_;
  NetworkAnonymizationKey last_network_anonymization_key_;
  scoped_refptr<State> state_;
  bool shutting_down_ = false;
  base::WeakPtrFactory<HangingHostResolver> weak_ptr_factory_{this};
};
// Scoped helper around the default HostResolverProc. It holds `current_proc_`
// and `previous_proc_`.
// NOTE(review): Init() presumably installs `proc` as the default, and the
// destructor presumably restores `previous_proc_` — confirm in the .cc.
//
// Not copyable: two copies would hold the same saved proc, and each would
// presumably restore it when destroyed.
class ScopedDefaultHostResolverProc {
 public:
  ScopedDefaultHostResolverProc();
  explicit ScopedDefaultHostResolverProc(HostResolverProc* proc);

  ScopedDefaultHostResolverProc(const ScopedDefaultHostResolverProc&) = delete;
  ScopedDefaultHostResolverProc& operator=(
      const ScopedDefaultHostResolverProc&) = delete;

  ~ScopedDefaultHostResolverProc();

  void Init(HostResolverProc* proc);

 private:
  scoped_refptr<HostResolverProc> current_proc_;
  scoped_refptr<HostResolverProc> previous_proc_;
};
}
#endif