#include "net/dns/dns_transaction.h"
#include <stdint.h>
#include <algorithm>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>
#include "base/base64url.h"
#include "base/byte_count.h"
#include "base/containers/circular_deque.h"
#include "base/containers/span.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/location.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/ref_counted.h"
#include "base/memory/safe_ref.h"
#include "base/memory/weak_ptr.h"
#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h"
#include "base/numerics/byte_conversions.h"
#include "base/numerics/safe_conversions.h"
#include "base/rand_util.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/single_thread_task_runner.h"
#include "base/threading/thread_checker.h"
#include "base/timer/elapsed_timer.h"
#include "base/timer/timer.h"
#include "base/values.h"
#include "build/build_config.h"
#include "net/base/backoff_entry.h"
#include "net/base/completion_once_callback.h"
#include "net/base/elements_upload_data_stream.h"
#include "net/base/idempotency.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/load_flags.h"
#include "net/base/net_errors.h"
#include "net/base/upload_bytes_element_reader.h"
#include "net/dns/dns_config.h"
#include "net/dns/dns_names_util.h"
#include "net/dns/dns_query.h"
#include "net/dns/dns_response.h"
#include "net/dns/dns_response_result_extractor.h"
#include "net/dns/dns_server_iterator.h"
#include "net/dns/dns_session.h"
#include "net/dns/dns_udp_tracker.h"
#include "net/dns/dns_util.h"
#include "net/dns/host_cache.h"
#include "net/dns/host_resolver_internal_result.h"
#include "net/dns/opt_record_rdata.h"
#include "net/dns/public/dns_over_https_config.h"
#include "net/dns/public/dns_over_https_server_config.h"
#include "net/dns/public/dns_protocol.h"
#include "net/dns/public/dns_query_type.h"
#include "net/dns/public/secure_dns_policy.h"
#include "net/dns/resolve_context.h"
#include "net/http/http_request_headers.h"
#include "net/log/net_log.h"
#include "net/log/net_log_capture_mode.h"
#include "net/log/net_log_event_type.h"
#include "net/log/net_log_source.h"
#include "net/log/net_log_values.h"
#include "net/log/net_log_with_source.h"
#include "net/socket/client_socket_factory.h"
#include "net/socket/datagram_client_socket.h"
#include "net/socket/stream_socket.h"
#include "net/third_party/uri_template/uri_template.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "net/url_request/url_request.h"
#include "net/url_request/url_request_context.h"
#include "net/url_request/url_request_context_builder.h"
#include "url/url_constants.h"
namespace net {
namespace {
constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation =
net::DefineNetworkTrafficAnnotation("dns_transaction", R"(
semantics {
sender: "DNS Transaction"
description:
"DNS Transaction implements a stub DNS resolver as defined in RFC "
"1034."
trigger:
"Any network request that may require DNS resolution, including "
"navigations, connecting to a proxy server, detecting proxy "
"settings, getting proxy config, certificate checking, and more."
data:
"Domain name that needs resolution."
destination: OTHER
destination_other:
"The connection is made to a DNS server based on user's network "
"settings."
}
policy {
cookies_allowed: NO
setting:
"This feature cannot be disabled. Without DNS Transactions Chrome "
"cannot resolve host names."
policy_exception_justification:
"Essential for Chrome's navigation."
})");
const char kDnsOverHttpResponseContentType[] = "application/dns-message";
constexpr base::ByteCount kDnsOverHttpResponseMaximumSize =
base::ByteCount(65535);
int CountLabels(base::span<const uint8_t> name) {
size_t count = 0;
for (size_t i = 0; i < name.size() && name[i]; i += name[i] + 1)
++count;
return count;
}
bool IsIPLiteral(const std::string& hostname) {
IPAddress ip;
return ip.AssignFromIPLiteral(hostname);
}
base::Value::Dict NetLogStartParams(const std::string& hostname,
uint16_t qtype) {
base::Value::Dict dict;
dict.Set("hostname", hostname);
dict.Set("query_type", qtype);
return dict;
}
class DnsAttempt {
public:
explicit DnsAttempt(size_t server_index) : server_index_(server_index) {}
DnsAttempt(const DnsAttempt&) = delete;
DnsAttempt& operator=(const DnsAttempt&) = delete;
virtual ~DnsAttempt() = default;
virtual int Start(CompletionOnceCallback callback) = 0;
virtual const DnsQuery* GetQuery() const = 0;
virtual const DnsResponse* GetResponse() const = 0;
virtual base::Value GetRawResponseBufferForLog() const = 0;
virtual const NetLogWithSource& GetSocketNetLog() const = 0;
size_t server_index() const { return server_index_; }
base::Value::Dict NetLogResponseParams(NetLogCaptureMode capture_mode) const {
base::Value::Dict dict;
if (GetResponse()) {
DCHECK(GetResponse()->IsValid());
dict.Set("rcode", GetResponse()->rcode());
dict.Set("answer_count", static_cast<int>(GetResponse()->answer_count()));
dict.Set("additional_answer_count",
static_cast<int>(GetResponse()->additional_answer_count()));
}
GetSocketNetLog().source().AddToEventParameters(dict);
if (capture_mode == NetLogCaptureMode::kEverything) {
dict.Set("response_buffer", GetRawResponseBufferForLog());
}
return dict;
}
virtual bool IsPending() const = 0;
private:
const size_t server_index_;
};
class DnsUDPAttempt : public DnsAttempt {
public:
DnsUDPAttempt(size_t server_index,
std::unique_ptr<DatagramClientSocket> socket,
const IPEndPoint& server,
std::unique_ptr<DnsQuery> query,
DnsUdpTracker* udp_tracker)
: DnsAttempt(server_index),
socket_(std::move(socket)),
server_(server),
query_(std::move(query)),
udp_tracker_(udp_tracker) {}
DnsUDPAttempt(const DnsUDPAttempt&) = delete;
DnsUDPAttempt& operator=(const DnsUDPAttempt&) = delete;
int Start(CompletionOnceCallback callback) override {
DCHECK_EQ(STATE_NONE, next_state_);
callback_ = std::move(callback);
start_time_ = base::TimeTicks::Now();
next_state_ = STATE_CONNECT_COMPLETE;
int rv = socket_->ConnectAsync(
server_,
base::BindOnce(&DnsUDPAttempt::OnIOComplete, base::Unretained(this)));
if (rv == ERR_IO_PENDING) {
return rv;
}
return DoLoop(rv);
}
const DnsQuery* GetQuery() const override { return query_.get(); }
const DnsResponse* GetResponse() const override {
const DnsResponse* resp = response_.get();
return (resp != nullptr && resp->IsValid()) ? resp : nullptr;
}
base::Value GetRawResponseBufferForLog() const override {
if (!response_)
return base::Value();
return NetLogBinaryValue(response_->io_buffer()->data(), read_size_);
}
const NetLogWithSource& GetSocketNetLog() const override {
return socket_->NetLog();
}
bool IsPending() const override { return next_state_ != STATE_NONE; }
private:
enum State {
STATE_CONNECT_COMPLETE,
STATE_SEND_QUERY,
STATE_SEND_QUERY_COMPLETE,
STATE_READ_RESPONSE,
STATE_READ_RESPONSE_COMPLETE,
STATE_NONE,
};
int DoLoop(int result) {
CHECK_NE(STATE_NONE, next_state_);
int rv = result;
do {
State state = next_state_;
next_state_ = STATE_NONE;
switch (state) {
case STATE_CONNECT_COMPLETE:
rv = DoConnectComplete(rv);
break;
case STATE_SEND_QUERY:
rv = DoSendQuery(rv);
break;
case STATE_SEND_QUERY_COMPLETE:
rv = DoSendQueryComplete(rv);
break;
case STATE_READ_RESPONSE:
rv = DoReadResponse();
break;
case STATE_READ_RESPONSE_COMPLETE:
rv = DoReadResponseComplete(rv);
break;
default:
NOTREACHED();
}
} while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
if (rv != ERR_IO_PENDING)
DCHECK_EQ(STATE_NONE, next_state_);
return rv;
}
int DoConnectComplete(int rv) {
if (rv != OK) {
DVLOG(1) << "Failed to connect socket: " << rv;
udp_tracker_->RecordConnectionError(rv);
return ERR_CONNECTION_REFUSED;
}
next_state_ = STATE_SEND_QUERY;
IPEndPoint local_address;
if (socket_->GetLocalAddress(&local_address) == OK)
udp_tracker_->RecordQuery(local_address.port(), query_->id());
return OK;
}
int DoSendQuery(int rv) {
DCHECK_NE(ERR_IO_PENDING, rv);
if (rv < 0)
return rv;
next_state_ = STATE_SEND_QUERY_COMPLETE;
return socket_->Write(
query_->io_buffer(), query_->io_buffer()->size(),
base::BindOnce(&DnsUDPAttempt::OnIOComplete, base::Unretained(this)),
kTrafficAnnotation);
}
int DoSendQueryComplete(int rv) {
DCHECK_NE(ERR_IO_PENDING, rv);
if (rv < 0)
return rv;
if (rv != query_->io_buffer()->size())
return ERR_MSG_TOO_BIG;
next_state_ = STATE_READ_RESPONSE;
return OK;
}
int DoReadResponse() {
next_state_ = STATE_READ_RESPONSE_COMPLETE;
response_ = std::make_unique<DnsResponse>();
return socket_->Read(
response_->io_buffer(), response_->io_buffer_size(),
base::BindOnce(&DnsUDPAttempt::OnIOComplete, base::Unretained(this)));
}
int DoReadResponseComplete(int rv) {
DCHECK_NE(ERR_IO_PENDING, rv);
if (rv < 0)
return rv;
read_size_ = rv;
bool parse_result = response_->InitParse(rv, *query_);
if (response_->id())
udp_tracker_->RecordResponseId(query_->id(), response_->id().value());
if (!parse_result)
return ERR_DNS_MALFORMED_RESPONSE;
if (response_->flags() & dns_protocol::kFlagTC)
return ERR_DNS_SERVER_REQUIRES_TCP;
if (response_->rcode() == dns_protocol::kRcodeNXDOMAIN)
return ERR_NAME_NOT_RESOLVED;
if (response_->rcode() != dns_protocol::kRcodeNOERROR)
return ERR_DNS_SERVER_FAILED;
return OK;
}
void OnIOComplete(int rv) {
rv = DoLoop(rv);
if (rv != ERR_IO_PENDING)
std::move(callback_).Run(rv);
}
State next_state_ = STATE_NONE;
base::TimeTicks start_time_;
std::unique_ptr<DatagramClientSocket> socket_;
IPEndPoint server_;
std::unique_ptr<DnsQuery> query_;
const raw_ptr<DnsUdpTracker> udp_tracker_;
std::unique_ptr<DnsResponse> response_;
int read_size_ = 0;
CompletionOnceCallback callback_;
};
class DnsHTTPAttempt : public DnsAttempt, public URLRequest::Delegate {
public:
DnsHTTPAttempt(size_t doh_server_index,
std::unique_ptr<DnsQuery> query,
const string& server_template,
const GURL& gurl_without_parameters,
bool use_post,
URLRequestContext* url_request_context,
const IsolationInfo& isolation_info,
RequestPriority request_priority_,
bool is_probe)
: DnsAttempt(doh_server_index),
query_(std::move(query)),
net_log_(NetLogWithSource::Make(NetLog::Get(),
NetLogSourceType::DNS_OVER_HTTPS)) {
GURL url;
if (use_post) {
url = gurl_without_parameters;
} else {
std::string url_string;
std::unordered_map<string, string> parameters;
std::string encoded_query;
base::Base64UrlEncode(std::string_view(query_->io_buffer()->data(),
query_->io_buffer()->size()),
base::Base64UrlEncodePolicy::OMIT_PADDING,
&encoded_query);
parameters.emplace("dns", encoded_query);
uri_template::Expand(server_template, parameters, &url_string);
url = GURL(url_string);
}
net_log_.BeginEvent(NetLogEventType::DOH_URL_REQUEST, [&] {
if (is_probe) {
return NetLogStartParams("(probe)", query_->qtype());
}
std::optional<std::string> hostname =
dns_names_util::NetworkToDottedName(query_->qname());
DCHECK(hostname.has_value());
return NetLogStartParams(*hostname, query_->qtype());
});
HttpRequestHeaders extra_request_headers;
extra_request_headers.SetHeader(HttpRequestHeaders::kAccept,
kDnsOverHttpResponseContentType);
extra_request_headers.SetHeader(HttpRequestHeaders::kAcceptLanguage, "*");
extra_request_headers.SetHeader(HttpRequestHeaders::kUserAgent, "Chrome");
extra_request_headers.SetHeader(HttpRequestHeaders::kAcceptEncoding,
"identity");
DCHECK(url_request_context);
request_ = url_request_context->CreateRequest(
url, request_priority_, this,
net::DefineNetworkTrafficAnnotation("dns_over_https", R"(
semantics {
sender: "DNS over HTTPS"
description: "Domain name resolution over HTTPS"
trigger: "User enters a navigates to a domain or Chrome otherwise "
"makes a connection to a domain whose IP address isn't cached"
data: "The domain name that is being requested"
destination: OTHER
destination_other: "The user configured DNS over HTTPS server, which"
"may be dns.google.com"
}
policy {
cookies_allowed: NO
setting:
"You can configure this feature via that 'dns_over_https_servers' and"
"'dns_over_https.method' prefs. Empty lists imply this feature is"
"disabled"
policy_exception_justification: "Experimental feature that"
"is disabled by default"
}
)"),
false, net_log_.source());
if (use_post) {
request_->set_method("POST");
request_->SetIdempotency(IDEMPOTENT);
std::unique_ptr<UploadElementReader> reader =
std::make_unique<UploadBytesElementReader>(
query_->io_buffer()->span());
request_->set_upload(
ElementsUploadDataStream::CreateWithReader(std::move(reader)));
extra_request_headers.SetHeader(HttpRequestHeaders::kContentType,
kDnsOverHttpResponseContentType);
}
request_->SetExtraRequestHeaders(extra_request_headers);
request_->SetSecureDnsPolicy(SecureDnsPolicy::kBootstrap);
request_->SetLoadFlags(request_->load_flags() | LOAD_DISABLE_CACHE |
LOAD_BYPASS_PROXY);
request_->set_disallow_credentials();
request_->set_isolation_info(isolation_info);
}
DnsHTTPAttempt(const DnsHTTPAttempt&) = delete;
DnsHTTPAttempt& operator=(const DnsHTTPAttempt&) = delete;
int Start(CompletionOnceCallback callback) override {
callback_ = std::move(callback);
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(&DnsHTTPAttempt::StartAsync,
weak_factory_.GetWeakPtr()));
return ERR_IO_PENDING;
}
const DnsQuery* GetQuery() const override { return query_.get(); }
const DnsResponse* GetResponse() const override {
const DnsResponse* resp = response_.get();
return (resp != nullptr && resp->IsValid()) ? resp : nullptr;
}
base::Value GetRawResponseBufferForLog() const override {
if (!response_)
return base::Value();
return NetLogBinaryValue(response_->io_buffer()->data(),
response_->io_buffer_size());
}
const NetLogWithSource& GetSocketNetLog() const override { return net_log_; }
void OnResponseStarted(net::URLRequest* request, int net_error) override {
DCHECK_NE(net::ERR_IO_PENDING, net_error);
std::string content_type;
if (net_error != OK) {
if (IsHostnameResolutionError(net_error))
net_error = ERR_DNS_SECURE_RESOLVER_HOSTNAME_RESOLUTION_FAILED;
ResponseCompleted(net_error);
return;
}
if (request_->GetResponseCode() != 200 ||
!request->response_headers()->GetMimeType(&content_type) ||
0 != content_type.compare(kDnsOverHttpResponseContentType)) {
ResponseCompleted(ERR_DNS_MALFORMED_RESPONSE);
return;
}
buffer_ = base::MakeRefCounted<GrowableIOBuffer>();
std::optional<base::ByteCount> content_length =
request_->response_headers()->GetContentLength();
if (content_length.has_value()) {
if (content_length.value() > kDnsOverHttpResponseMaximumSize) {
ResponseCompleted(ERR_DNS_MALFORMED_RESPONSE);
return;
}
buffer_->SetCapacity(
base::checked_cast<int>(content_length->InBytes() + 1));
} else {
buffer_->SetCapacity(base::checked_cast<int>(
kDnsOverHttpResponseMaximumSize.InBytes() + 1));
}
DCHECK(buffer_->data());
DCHECK_GT(buffer_->capacity(), 0);
int bytes_read =
request_->Read(buffer_.get(), buffer_->RemainingCapacity());
if (bytes_read == net::ERR_IO_PENDING)
return;
OnReadCompleted(request_.get(), bytes_read);
}
void OnReceivedRedirect(URLRequest* request,
const RedirectInfo& redirect_info,
bool* defer_redirect) override {
if (!redirect_info.new_url.SchemeIs(url::kHttpsScheme)) {
request->Cancel();
}
}
void OnReadCompleted(net::URLRequest* request, int bytes_read) override {
if (bytes_read < 0) {
ResponseCompleted(bytes_read);
return;
}
DCHECK_GE(bytes_read, 0);
if (bytes_read > 0) {
if (buffer_->offset() + bytes_read >
kDnsOverHttpResponseMaximumSize.InBytes()) {
ResponseCompleted(ERR_DNS_MALFORMED_RESPONSE);
return;
}
buffer_->set_offset(buffer_->offset() + bytes_read);
if (buffer_->RemainingCapacity() == 0) {
buffer_->SetCapacity(buffer_->capacity() + 16384);
}
DCHECK(buffer_->data());
DCHECK_GT(buffer_->capacity(), 0);
int read_result =
request_->Read(buffer_.get(), buffer_->RemainingCapacity());
if (read_result == net::ERR_IO_PENDING)
return;
if (read_result <= 0) {
OnReadCompleted(request_.get(), read_result);
} else {
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(&DnsHTTPAttempt::OnReadCompleted,
weak_factory_.GetWeakPtr(),
request_.get(), read_result));
}
} else {
DCHECK_EQ(0, bytes_read);
ResponseCompleted(net::OK);
}
}
bool IsPending() const override { return !callback_.is_null(); }
private:
void StartAsync() {
DCHECK(request_);
request_->Start();
}
void ResponseCompleted(int net_error) {
request_.reset();
std::move(callback_).Run(CompleteResponse(net_error));
}
int CompleteResponse(int net_error) {
net_log_.EndEventWithNetErrorCode(NetLogEventType::DOH_URL_REQUEST,
net_error);
DCHECK_NE(net::ERR_IO_PENDING, net_error);
if (net_error != OK) {
return net_error;
}
if (!buffer_.get() || 0 == buffer_->capacity())
return ERR_DNS_MALFORMED_RESPONSE;
size_t size = buffer_->offset();
buffer_->set_offset(0);
if (size == 0u)
return ERR_DNS_MALFORMED_RESPONSE;
response_ = std::make_unique<DnsResponse>(buffer_, size);
if (!response_->InitParse(size, *query_))
return ERR_DNS_MALFORMED_RESPONSE;
if (response_->rcode() == dns_protocol::kRcodeNXDOMAIN)
return ERR_NAME_NOT_RESOLVED;
if (response_->rcode() != dns_protocol::kRcodeNOERROR)
return ERR_DNS_SERVER_FAILED;
return OK;
}
scoped_refptr<GrowableIOBuffer> buffer_;
std::unique_ptr<DnsQuery> query_;
CompletionOnceCallback callback_;
std::unique_ptr<DnsResponse> response_;
std::unique_ptr<URLRequest> request_;
NetLogWithSource net_log_;
base::WeakPtrFactory<DnsHTTPAttempt> weak_factory_{this};
};
void ConstructDnsHTTPAttempt(DnsSession* session,
size_t doh_server_index,
base::span<const uint8_t> qname,
uint16_t qtype,
const OptRecordRdata* opt_rdata,
std::vector<std::unique_ptr<DnsAttempt>>* attempts,
URLRequestContext* url_request_context,
const IsolationInfo& isolation_info,
RequestPriority request_priority,
bool is_probe) {
DCHECK(url_request_context);
std::unique_ptr<DnsQuery> query;
if (attempts->empty()) {
query =
std::make_unique<DnsQuery>(0, qname, qtype, opt_rdata,
DnsQuery::PaddingStrategy::BLOCK_LENGTH_128);
} else {
query = std::make_unique<DnsQuery>(*attempts->at(0)->GetQuery());
}
DCHECK_LT(doh_server_index, session->config().doh_config.servers().size());
const DnsOverHttpsServerConfig& doh_server =
session->config().doh_config.servers()[doh_server_index];
GURL gurl_without_parameters(
GetURLFromTemplateWithoutParameters(doh_server.server_template()));
attempts->push_back(std::make_unique<DnsHTTPAttempt>(
doh_server_index, std::move(query), doh_server.server_template(),
gurl_without_parameters, doh_server.use_post(), url_request_context,
isolation_info, request_priority, is_probe));
}
class DnsTCPAttempt : public DnsAttempt {
public:
DnsTCPAttempt(size_t server_index,
std::unique_ptr<StreamSocket> socket,
std::unique_ptr<DnsQuery> query)
: DnsAttempt(server_index),
socket_(std::move(socket)),
query_(std::move(query)),
length_buffer_(
base::MakeRefCounted<IOBufferWithSize>(sizeof(uint16_t))) {}
DnsTCPAttempt(const DnsTCPAttempt&) = delete;
DnsTCPAttempt& operator=(const DnsTCPAttempt&) = delete;
int Start(CompletionOnceCallback callback) override {
DCHECK_EQ(STATE_NONE, next_state_);
callback_ = std::move(callback);
start_time_ = base::TimeTicks::Now();
next_state_ = STATE_CONNECT_COMPLETE;
int rv = socket_->Connect(
base::BindOnce(&DnsTCPAttempt::OnIOComplete, base::Unretained(this)));
if (rv == ERR_IO_PENDING) {
return rv;
}
return DoLoop(rv);
}
const DnsQuery* GetQuery() const override { return query_.get(); }
const DnsResponse* GetResponse() const override {
const DnsResponse* resp = response_.get();
return (resp != nullptr && resp->IsValid()) ? resp : nullptr;
}
base::Value GetRawResponseBufferForLog() const override {
if (!response_)
return base::Value();
return NetLogBinaryValue(response_->io_buffer()->data(),
response_->io_buffer_size());
}
const NetLogWithSource& GetSocketNetLog() const override {
return socket_->NetLog();
}
bool IsPending() const override { return next_state_ != STATE_NONE; }
private:
enum State {
STATE_CONNECT_COMPLETE,
STATE_SEND_LENGTH,
STATE_SEND_QUERY,
STATE_READ_LENGTH,
STATE_READ_LENGTH_COMPLETE,
STATE_READ_RESPONSE,
STATE_READ_RESPONSE_COMPLETE,
STATE_NONE,
};
int DoLoop(int result) {
CHECK_NE(STATE_NONE, next_state_);
int rv = result;
do {
State state = next_state_;
next_state_ = STATE_NONE;
switch (state) {
case STATE_CONNECT_COMPLETE:
rv = DoConnectComplete(rv);
break;
case STATE_SEND_LENGTH:
rv = DoSendLength(rv);
break;
case STATE_SEND_QUERY:
rv = DoSendQuery(rv);
break;
case STATE_READ_LENGTH:
rv = DoReadLength(rv);
break;
case STATE_READ_LENGTH_COMPLETE:
rv = DoReadLengthComplete(rv);
break;
case STATE_READ_RESPONSE:
rv = DoReadResponse(rv);
break;
case STATE_READ_RESPONSE_COMPLETE:
rv = DoReadResponseComplete(rv);
break;
default:
NOTREACHED();
}
} while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
if (rv != ERR_IO_PENDING)
DCHECK_EQ(STATE_NONE, next_state_);
return rv;
}
int DoConnectComplete(int rv) {
DCHECK_NE(ERR_IO_PENDING, rv);
if (rv < 0)
return rv;
uint16_t query_size = static_cast<uint16_t>(query_->io_buffer()->size());
if (static_cast<int>(query_size) != query_->io_buffer()->size())
return ERR_FAILED;
length_buffer_->span().copy_from(base::U16ToBigEndian(query_size));
buffer_ = base::MakeRefCounted<DrainableIOBuffer>(length_buffer_,
length_buffer_->size());
next_state_ = STATE_SEND_LENGTH;
return OK;
}
int DoSendLength(int rv) {
DCHECK_NE(ERR_IO_PENDING, rv);
if (rv < 0)
return rv;
buffer_->DidConsume(rv);
if (buffer_->BytesRemaining() > 0) {
next_state_ = STATE_SEND_LENGTH;
return socket_->Write(
buffer_.get(), buffer_->BytesRemaining(),
base::BindOnce(&DnsTCPAttempt::OnIOComplete, base::Unretained(this)),
kTrafficAnnotation);
}
buffer_ = base::MakeRefCounted<DrainableIOBuffer>(
query_->io_buffer(), query_->io_buffer()->size());
next_state_ = STATE_SEND_QUERY;
return OK;
}
int DoSendQuery(int rv) {
DCHECK_NE(ERR_IO_PENDING, rv);
if (rv < 0)
return rv;
buffer_->DidConsume(rv);
if (buffer_->BytesRemaining() > 0) {
next_state_ = STATE_SEND_QUERY;
return socket_->Write(
buffer_.get(), buffer_->BytesRemaining(),
base::BindOnce(&DnsTCPAttempt::OnIOComplete, base::Unretained(this)),
kTrafficAnnotation);
}
buffer_ = base::MakeRefCounted<DrainableIOBuffer>(length_buffer_,
length_buffer_->size());
next_state_ = STATE_READ_LENGTH;
return OK;
}
int DoReadLength(int rv) {
DCHECK_EQ(OK, rv);
next_state_ = STATE_READ_LENGTH_COMPLETE;
return ReadIntoBuffer();
}
int DoReadLengthComplete(int rv) {
DCHECK_NE(ERR_IO_PENDING, rv);
if (rv < 0)
return rv;
if (rv == 0)
return ERR_CONNECTION_CLOSED;
buffer_->DidConsume(rv);
if (buffer_->BytesRemaining() > 0) {
next_state_ = STATE_READ_LENGTH;
return OK;
}
response_length_ =
base::U16FromBigEndian(length_buffer_->span().first<2u>());
if (response_length_ < query_->io_buffer()->size())
return ERR_DNS_MALFORMED_RESPONSE;
response_ = std::make_unique<DnsResponse>(response_length_);
buffer_ = base::MakeRefCounted<DrainableIOBuffer>(response_->io_buffer(),
response_length_);
next_state_ = STATE_READ_RESPONSE;
return OK;
}
int DoReadResponse(int rv) {
DCHECK_EQ(OK, rv);
next_state_ = STATE_READ_RESPONSE_COMPLETE;
return ReadIntoBuffer();
}
int DoReadResponseComplete(int rv) {
DCHECK_NE(ERR_IO_PENDING, rv);
if (rv < 0)
return rv;
if (rv == 0)
return ERR_CONNECTION_CLOSED;
buffer_->DidConsume(rv);
if (buffer_->BytesRemaining() > 0) {
next_state_ = STATE_READ_RESPONSE;
return OK;
}
DCHECK_GT(buffer_->BytesConsumed(), 0);
if (!response_->InitParse(buffer_->BytesConsumed(), *query_))
return ERR_DNS_MALFORMED_RESPONSE;
if (response_->flags() & dns_protocol::kFlagTC)
return ERR_UNEXPECTED;
if (response_->rcode() == dns_protocol::kRcodeNXDOMAIN)
return ERR_NAME_NOT_RESOLVED;
if (response_->rcode() != dns_protocol::kRcodeNOERROR)
return ERR_DNS_SERVER_FAILED;
return OK;
}
void OnIOComplete(int rv) {
rv = DoLoop(rv);
if (rv != ERR_IO_PENDING)
std::move(callback_).Run(rv);
}
int ReadIntoBuffer() {
return socket_->Read(
buffer_.get(), buffer_->BytesRemaining(),
base::BindOnce(&DnsTCPAttempt::OnIOComplete, base::Unretained(this)));
}
State next_state_ = STATE_NONE;
base::TimeTicks start_time_;
std::unique_ptr<StreamSocket> socket_;
std::unique_ptr<DnsQuery> query_;
scoped_refptr<IOBufferWithSize> length_buffer_;
scoped_refptr<DrainableIOBuffer> buffer_;
uint16_t response_length_ = 0;
std::unique_ptr<DnsResponse> response_;
CompletionOnceCallback callback_;
};
const net::BackoffEntry::Policy kProbeBackoffPolicy = {
0,
1000,
1.5,
0.2,
1000 * 60 * 60,
-1,
false,
};
class DnsOverHttpsProbeRunner : public DnsProbeRunner {
public:
DnsOverHttpsProbeRunner(base::WeakPtr<DnsSession> session,
base::WeakPtr<ResolveContext> context)
: session_(session), context_(context) {
DCHECK(session_);
DCHECK(!session_->config().doh_config.servers().empty());
DCHECK(context_);
std::optional<std::vector<uint8_t>> qname =
dns_names_util::DottedNameToNetwork(kDohProbeHostname);
DCHECK(qname.has_value());
formatted_probe_qname_ = std::move(qname).value();
for (size_t i = 0; i < session_->config().doh_config.servers().size();
i++) {
probe_stats_list_.push_back(nullptr);
}
}
~DnsOverHttpsProbeRunner() override = default;
void Start(bool network_change) override {
DCHECK(session_);
DCHECK(context_);
const auto& config = session_->config().doh_config;
#if BUILDFLAG(ARKWEB_LOGGER_REPORT)
LOG_FEEDBACK(INFO) << "probe runner running server size: "
<< config.servers().size();
#endif
for (size_t i = 0; i < config.servers().size(); i++) {
if (!probe_stats_list_[i]) {
probe_stats_list_[i] = std::make_unique<ProbeStats>();
ContinueProbe(i, probe_stats_list_[i]->weak_factory.GetWeakPtr(),
network_change,
base::TimeTicks::Now() );
}
}
}
base::TimeDelta GetDelayUntilNextProbeForTest(
size_t doh_server_index) const override {
if (doh_server_index >= probe_stats_list_.size() ||
!probe_stats_list_[doh_server_index])
return base::TimeDelta();
return probe_stats_list_[doh_server_index]
->backoff_entry->GetTimeUntilRelease();
}
private:
struct ProbeStats {
ProbeStats()
: backoff_entry(
std::make_unique<net::BackoffEntry>(&kProbeBackoffPolicy)) {}
std::unique_ptr<net::BackoffEntry> backoff_entry;
std::vector<std::unique_ptr<DnsAttempt>> probe_attempts;
base::WeakPtrFactory<ProbeStats> weak_factory{this};
};
void ContinueProbe(size_t doh_server_index,
base::WeakPtr<ProbeStats> probe_stats,
bool network_change,
base::TimeTicks sequence_start_time) {
if (!session_ || !context_) {
probe_stats_list_.clear();
return;
}
if (!probe_stats)
return;
if (context_->GetDohServerAvailability(doh_server_index, session_.get())) {
probe_stats_list_[doh_server_index] = nullptr;
return;
}
DCHECK(probe_stats);
DCHECK(probe_stats->backoff_entry);
probe_stats->backoff_entry->InformOfRequest(false );
base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask(
FROM_HERE,
base::BindOnce(&DnsOverHttpsProbeRunner::ContinueProbe,
weak_ptr_factory_.GetWeakPtr(), doh_server_index,
probe_stats, network_change, sequence_start_time),
probe_stats->backoff_entry->GetTimeUntilRelease());
uint32_t attempt_number = probe_stats->probe_attempts.size();
ConstructDnsHTTPAttempt(
session_.get(), doh_server_index, formatted_probe_qname_,
dns_protocol::kTypeA, nullptr,
&probe_stats->probe_attempts, context_->url_request_context(),
context_->isolation_info(), RequestPriority::DEFAULT_PRIORITY,
true);
DnsAttempt* probe_attempt = probe_stats->probe_attempts.back().get();
probe_attempt->Start(base::BindOnce(
&DnsOverHttpsProbeRunner::ProbeComplete, weak_ptr_factory_.GetWeakPtr(),
attempt_number, doh_server_index, std::move(probe_stats),
network_change, sequence_start_time,
base::TimeTicks::Now() ));
}
void ProbeComplete(uint32_t attempt_number,
size_t doh_server_index,
base::WeakPtr<ProbeStats> probe_stats,
bool network_change,
base::TimeTicks sequence_start_time,
base::TimeTicks query_start_time,
int rv) {
bool success = false;
#if BUILDFLAG(ARKWEB_LOGGER_REPORT)
LOG_FEEDBACK(INFO) << "probe complete, attempt number " << attempt_number
<< ", server index " << doh_server_index << ", rv" << rv;
#endif
while (probe_stats && session_ && context_) {
if (rv != OK) {
context_->RecordServerFailure(doh_server_index, true,
rv, session_.get());
break;
}
DCHECK_LT(attempt_number, probe_stats->probe_attempts.size());
const DnsAttempt* attempt =
probe_stats->probe_attempts[attempt_number].get();
const DnsResponse* response = attempt->GetResponse();
if (response) {
DnsResponseResultExtractor extractor(*response);
DnsResponseResultExtractor::ResultsOrError results =
extractor.ExtractDnsResults(
DnsQueryType::A,
kDohProbeHostname,
0);
if (results.has_value()) {
for (const auto& result : results.value()) {
if (result->type() == HostResolverInternalResult::Type::kData &&
!result->AsData().endpoints().empty()) {
context_->RecordServerSuccess(
doh_server_index, true, session_.get());
context_->RecordRtt(doh_server_index, true,
base::TimeTicks::Now() - query_start_time, rv,
session_.get());
success = true;
break;
}
}
}
}
if (!success) {
context_->RecordServerFailure(
doh_server_index, true,
ERR_DNS_SECURE_PROBE_RECORD_INVALID, session_.get());
}
break;
}
base::UmaHistogramLongTimes(
base::JoinString({"Net.DNS.ProbeSequence",
network_change ? "NetworkChange" : "ConfigChange",
success ? "Success" : "Failure", "AttemptTime"},
"."),
base::TimeTicks::Now() - sequence_start_time);
}
base::WeakPtr<DnsSession> session_;
base::WeakPtr<ResolveContext> context_;
std::vector<uint8_t> formatted_probe_qname_;
std::vector<std::unique_ptr<ProbeStats>> probe_stats_list_;
base::WeakPtrFactory<DnsOverHttpsProbeRunner> weak_ptr_factory_{this};
};
class DnsTransactionImpl final : public DnsTransaction {
public:
DnsTransactionImpl(DnsSession* session,
std::string hostname,
uint16_t qtype,
const NetLogWithSource& parent_net_log,
const OptRecordRdata* opt_rdata,
bool secure,
SecureDnsMode secure_dns_mode,
ResolveContext* resolve_context,
bool fast_timeout)
: session_(session),
hostname_(std::move(hostname)),
qtype_(qtype),
opt_rdata_(opt_rdata),
secure_(secure),
secure_dns_mode_(secure_dns_mode),
fast_timeout_(fast_timeout),
net_log_(NetLogWithSource::Make(NetLog::Get(),
NetLogSourceType::DNS_TRANSACTION)),
resolve_context_(resolve_context->AsSafeRef()) {
DCHECK(session_.get());
DCHECK(!hostname_.empty());
DCHECK(!IsIPLiteral(hostname_));
parent_net_log.AddEventReferencingSource(NetLogEventType::DNS_TRANSACTION,
net_log_.source());
}
DnsTransactionImpl(const DnsTransactionImpl&) = delete;
DnsTransactionImpl& operator=(const DnsTransactionImpl&) = delete;
~DnsTransactionImpl() override {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
if (!callback_.is_null()) {
net_log_.EndEventWithNetErrorCode(NetLogEventType::DNS_TRANSACTION,
ERR_ABORTED);
}
}
const std::string& GetHostname() const override {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
return hostname_;
}
uint16_t GetType() const override {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
return qtype_;
}
void Start(ResponseCallback callback) override {
DCHECK(!callback.is_null());
DCHECK(callback_.is_null());
DCHECK(attempts_.empty());
callback_ = std::move(callback);
net_log_.BeginEvent(NetLogEventType::DNS_TRANSACTION,
[&] { return NetLogStartParams(hostname_, qtype_); });
time_from_start_ = std::make_unique<base::ElapsedTimer>();
AttemptResult result(PrepareSearch(), nullptr);
if (result.rv == OK) {
qnames_initial_size_ = qnames_.size();
result = ProcessAttemptResult(StartQuery());
}
if (result.rv != ERR_IO_PENDING) {
ClearAttempts(result.attempt);
base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(&DnsTransactionImpl::DoCallback,
weak_ptr_factory_.GetWeakPtr(), result));
}
}
void SetRequestPriority(RequestPriority priority) override {
request_priority_ = priority;
}
#if BUILDFLAG(ARKWEB_EXT_HTTP_DNS_FALLBACK)
void SetNotNeedMoreAttemptIPQueryType(
uint16_t not_need_more_attempt_query_type) override {
not_need_more_attempt_query_type_ = not_need_more_attempt_query_type;
}
#endif
private:
struct AttemptResult {
AttemptResult() = default;
AttemptResult(int rv, const DnsAttempt* attempt)
: rv(rv), attempt(attempt) {}
int rv;
raw_ptr<const DnsAttempt, AcrossTasksDanglingUntriaged> attempt;
};
enum class DnsAttemptType {
kUdp = 0,
kTcpLowEntropy = 1,
kTcpTruncationRetry = 2,
kHttp = 3,
kMaxValue = kHttp,
};
int PrepareSearch() {
const DnsConfig& config = session_->config();
std::optional<std::vector<uint8_t>> labeled_qname =
dns_names_util::DottedNameToNetwork(
hostname_,
true);
if (!labeled_qname.has_value())
return ERR_INVALID_ARGUMENT;
if (hostname_.back() == '.') {
qnames_.push_back(std::move(labeled_qname).value());
return OK;
}
int ndots = CountLabels(labeled_qname.value()) - 1;
if (ndots > 0 && !config.append_to_multi_label_name) {
qnames_.push_back(std::move(labeled_qname).value());
return OK;
}
bool had_qname = false;
if (ndots >= config.ndots) {
qnames_.push_back(labeled_qname.value());
had_qname = true;
}
for (const auto& suffix : config.search) {
std::optional<std::vector<uint8_t>> qname =
dns_names_util::DottedNameToNetwork(
hostname_ + "." + suffix,
true);
if (!qname.has_value())
continue;
if (qname.value().size() == labeled_qname.value().size()) {
if (had_qname)
continue;
had_qname = true;
}
qnames_.push_back(std::move(qname).value());
}
if (ndots > 0 && !had_qname)
qnames_.push_back(std::move(labeled_qname).value());
return qnames_.empty() ? ERR_DNS_SEARCH_EMPTY : OK;
}
void DoCallback(AttemptResult result) {
DCHECK_NE(ERR_IO_PENDING, result.rv);
if (callback_.is_null())
return;
const DnsResponse* response =
result.attempt ? result.attempt->GetResponse() : nullptr;
CHECK(result.rv != OK || response != nullptr);
timer_.Stop();
net_log_.EndEventWithNetErrorCode(NetLogEventType::DNS_TRANSACTION,
result.rv);
std::move(callback_).Run(result.rv, response);
}
void RecordAttemptUma(DnsAttemptType attempt_type) {
UMA_HISTOGRAM_ENUMERATION("Net.DNS.DnsTransaction.AttemptType",
attempt_type);
}
AttemptResult MakeAttempt() {
DCHECK(MoreAttemptsAllowed());
DnsConfig config = session_->config();
if (secure_) {
DCHECK(!config.doh_config.servers().empty());
RecordAttemptUma(DnsAttemptType::kHttp);
return MakeHTTPAttempt();
}
DCHECK_GT(config.nameservers.size(), 0u);
return MakeClassicDnsAttempt();
}
AttemptResult MakeClassicDnsAttempt() {
uint16_t id = session_->NextQueryId();
std::unique_ptr<DnsQuery> query;
if (attempts_.empty()) {
query =
std::make_unique<DnsQuery>(id, qnames_.front(), qtype_, opt_rdata_);
} else {
query = attempts_[0]->GetQuery()->CloneWithNewId(id);
}
DCHECK(dns_server_iterator_->AttemptAvailable());
size_t server_index = dns_server_iterator_->GetNextAttemptIndex();
size_t attempt_number = attempts_.size();
AttemptResult result;
if (session_->udp_tracker()->low_entropy()) {
result = MakeTcpAttempt(server_index, std::move(query));
RecordAttemptUma(DnsAttemptType::kTcpLowEntropy);
} else {
result = MakeUdpAttempt(server_index, std::move(query));
RecordAttemptUma(DnsAttemptType::kUdp);
}
if (result.rv == ERR_IO_PENDING) {
base::TimeDelta fallback_period =
resolve_context_->NextClassicFallbackPeriod(
server_index, attempt_number, session_.get());
timer_.Start(FROM_HERE, fallback_period, this,
&DnsTransactionImpl::OnFallbackPeriodExpired);
}
return result;
}
AttemptResult MakeUdpAttempt(size_t server_index,
std::unique_ptr<DnsQuery> query) {
DCHECK(!secure_);
DCHECK(!session_->udp_tracker()->low_entropy());
const DnsConfig& config = session_->config();
DCHECK_LT(server_index, config.nameservers.size());
size_t attempt_number = attempts_.size();
std::unique_ptr<DatagramClientSocket> socket =
resolve_context_->url_request_context()
->GetNetworkSessionContext()
->client_socket_factory->CreateDatagramClientSocket(
DatagramSocket::RANDOM_BIND, net_log_.net_log(),
net_log_.source());
attempts_.push_back(std::make_unique<DnsUDPAttempt>(
server_index, std::move(socket), config.nameservers[server_index],
std::move(query), session_->udp_tracker()));
++attempts_count_;
DnsAttempt* attempt = attempts_.back().get();
net_log_.AddEventReferencingSource(NetLogEventType::DNS_TRANSACTION_ATTEMPT,
attempt->GetSocketNetLog().source());
int rv = attempt->Start(base::BindOnce(
&DnsTransactionImpl::OnAttemptComplete, base::Unretained(this),
attempt_number, true , base::TimeTicks::Now()));
return AttemptResult(rv, attempt);
}
AttemptResult MakeHTTPAttempt() {
DCHECK(secure_);
size_t doh_server_index = dns_server_iterator_->GetNextAttemptIndex();
uint32_t attempt_number = attempts_.size();
#if BUILDFLAG(ARKWEB_EXT_HTTP_DNS_FALLBACK)
if (resolve_context_->IsHttpsDnsFallbackEnabled()) {
LOG(INFO) << "DOH-Fallback make http fallback attempt for " << hostname_;
}
#endif
ConstructDnsHTTPAttempt(session_.get(), doh_server_index, qnames_.front(),
qtype_, opt_rdata_, &attempts_,
resolve_context_->url_request_context(),
resolve_context_->isolation_info(),
request_priority_, false);
++attempts_count_;
DnsAttempt* attempt = attempts_.back().get();
net_log_.AddEventReferencingSource(
NetLogEventType::DNS_TRANSACTION_HTTPS_ATTEMPT,
attempt->GetSocketNetLog().source());
attempt->GetSocketNetLog().AddEventReferencingSource(
NetLogEventType::DNS_TRANSACTION_HTTPS_ATTEMPT, net_log_.source());
int rv = attempt->Start(base::BindOnce(
&DnsTransactionImpl::OnAttemptComplete, base::Unretained(this),
attempt_number, true , base::TimeTicks::Now()));
if (rv == ERR_IO_PENDING) {
base::TimeDelta fallback_period = resolve_context_->NextDohFallbackPeriod(
doh_server_index, session_.get());
timer_.Start(FROM_HERE, fallback_period, this,
&DnsTransactionImpl::OnFallbackPeriodExpired);
}
return AttemptResult(rv, attempts_.back().get());
}
AttemptResult RetryUdpAttemptAsTcp(const DnsAttempt* previous_attempt) {
DCHECK(previous_attempt);
DCHECK(!had_tcp_retry_);
had_tcp_retry_ = true;
size_t server_index = previous_attempt->server_index();
std::unique_ptr<DnsQuery> query =
previous_attempt->GetQuery()->CloneWithNewId(session_->NextQueryId());
ClearAttempts(nullptr);
AttemptResult result = MakeTcpAttempt(server_index, std::move(query));
RecordAttemptUma(DnsAttemptType::kTcpTruncationRetry);
if (result.rv == ERR_IO_PENDING) {
base::TimeDelta fallback_period = timer_.GetCurrentDelay() * 2;
timer_.Start(FROM_HERE, fallback_period, this,
&DnsTransactionImpl::OnFallbackPeriodExpired);
}
return result;
}
AttemptResult MakeTcpAttempt(size_t server_index,
std::unique_ptr<DnsQuery> query) {
DCHECK(!secure_);
const DnsConfig& config = session_->config();
DCHECK_LT(server_index, config.nameservers.size());
NetworkQualityEstimator* network_quality_estimator = nullptr;
std::unique_ptr<StreamSocket> socket =
resolve_context_->url_request_context()
->GetNetworkSessionContext()
->client_socket_factory->CreateTransportClientSocket(
AddressList(config.nameservers[server_index]), nullptr,
network_quality_estimator, net_log_.net_log(),
net_log_.source());
uint32_t attempt_number = attempts_.size();
attempts_.push_back(std::make_unique<DnsTCPAttempt>(
server_index, std::move(socket), std::move(query)));
++attempts_count_;
DnsAttempt* attempt = attempts_.back().get();
net_log_.AddEventReferencingSource(
NetLogEventType::DNS_TRANSACTION_TCP_ATTEMPT,
attempt->GetSocketNetLog().source());
int rv = attempt->Start(base::BindOnce(
&DnsTransactionImpl::OnAttemptComplete, base::Unretained(this),
attempt_number, false , base::TimeTicks::Now()));
return AttemptResult(rv, attempt);
}
AttemptResult StartQuery() {
std::optional<std::string> dotted_qname =
dns_names_util::NetworkToDottedName(qnames_.front());
net_log_.BeginEventWithStringParams(
NetLogEventType::DNS_TRANSACTION_QUERY, "qname",
dotted_qname.value_or("???MALFORMED_NAME???"));
attempts_.clear();
had_tcp_retry_ = false;
if (secure_) {
dns_server_iterator_ = resolve_context_->GetDohIterator(
session_->config(), secure_dns_mode_, session_.get());
} else {
dns_server_iterator_ = resolve_context_->GetClassicDnsIterator(
session_->config(), session_.get());
}
DCHECK(dns_server_iterator_);
if (!dns_server_iterator_->AttemptAvailable())
return AttemptResult(ERR_BLOCKED_BY_CLIENT, nullptr);
return MakeAttempt();
}
void OnAttemptComplete(uint32_t attempt_number,
bool record_rtt,
base::TimeTicks start,
int rv) {
DCHECK_LT(attempt_number, attempts_.size());
const DnsAttempt* attempt = attempts_[attempt_number].get();
if (record_rtt && attempt->GetResponse()) {
resolve_context_->RecordRtt(
attempt->server_index(), secure_ ,
base::TimeTicks::Now() - start, rv, session_.get());
}
if (callback_.is_null())
return;
AttemptResult result = ProcessAttemptResult(AttemptResult(rv, attempt));
if (result.rv != ERR_IO_PENDING)
DoCallback(result);
}
void LogResponse(const DnsAttempt* attempt) {
if (attempt) {
net_log_.AddEvent(NetLogEventType::DNS_TRANSACTION_RESPONSE,
[&](NetLogCaptureMode capture_mode) {
return attempt->NetLogResponseParams(capture_mode);
});
}
}
bool MoreAttemptsAllowed() const {
if (had_tcp_retry_)
return false;
#if BUILDFLAG(ARKWEB_EXT_HTTP_DNS_FALLBACK)
if (not_need_more_attempt_query_type_ == qtype_) {
return false;
}
#endif
return dns_server_iterator_->AttemptAvailable();
}
AttemptResult ProcessAttemptResult(AttemptResult result) {
while (result.rv != ERR_IO_PENDING) {
LogResponse(result.attempt);
switch (result.rv) {
case OK:
resolve_context_->RecordServerSuccess(result.attempt->server_index(),
secure_ ,
session_.get());
net_log_.EndEventWithNetErrorCode(
NetLogEventType::DNS_TRANSACTION_QUERY, result.rv);
DCHECK(result.attempt);
DCHECK(result.attempt->GetResponse());
return result;
case ERR_NAME_NOT_RESOLVED:
resolve_context_->RecordServerSuccess(result.attempt->server_index(),
secure_ ,
session_.get());
net_log_.EndEventWithNetErrorCode(
NetLogEventType::DNS_TRANSACTION_QUERY, result.rv);
if (!qnames_.empty())
qnames_.pop_front();
if (qnames_.empty()) {
return result;
} else {
result = StartQuery();
}
break;
case ERR_DNS_TIMED_OUT:
timer_.Stop();
if (result.attempt) {
DCHECK(result.attempt == attempts_.back().get());
resolve_context_->RecordServerFailure(
result.attempt->server_index(), secure_ ,
result.rv, session_.get());
}
if (MoreAttemptsAllowed()) {
result = MakeAttempt();
break;
}
if (!fast_timeout_ && AnyAttemptPending()) {
StartTimeoutTimer();
return AttemptResult(ERR_IO_PENDING, nullptr);
}
return result;
case ERR_DNS_SERVER_REQUIRES_TCP:
result = RetryUdpAttemptAsTcp(result.attempt);
break;
case ERR_BLOCKED_BY_CLIENT:
net_log_.EndEventWithNetErrorCode(
NetLogEventType::DNS_TRANSACTION_QUERY, result.rv);
return result;
default:
DCHECK(result.attempt);
if (result.attempt == attempts_.back().get()) {
timer_.Stop();
resolve_context_->RecordServerFailure(
result.attempt->server_index(), secure_ ,
result.rv, session_.get());
if (MoreAttemptsAllowed()) {
result = MakeAttempt();
break;
}
if (fast_timeout_) {
return result;
}
StartTimeoutTimer();
}
if (AnyAttemptPending()) {
DCHECK(timer_.IsRunning());
return AttemptResult(ERR_IO_PENDING, nullptr);
}
return result;
}
}
return result;
}
void ClearAttempts(const DnsAttempt* leave_attempt) {
for (auto it = attempts_.begin(); it != attempts_.end();) {
if ((*it)->IsPending() && it->get() != leave_attempt) {
it = attempts_.erase(it);
} else {
++it;
}
}
}
bool AnyAttemptPending() {
return std::ranges::any_of(attempts_,
[](std::unique_ptr<DnsAttempt>& attempt) {
return attempt->IsPending();
});
}
void OnFallbackPeriodExpired() {
if (callback_.is_null())
return;
DCHECK(!attempts_.empty());
AttemptResult result = ProcessAttemptResult(
AttemptResult(ERR_DNS_TIMED_OUT, attempts_.back().get()));
if (result.rv != ERR_IO_PENDING)
DoCallback(result);
}
void StartTimeoutTimer() {
DCHECK(!fast_timeout_);
DCHECK(!timer_.IsRunning());
DCHECK(!callback_.is_null());
base::TimeDelta timeout;
if (secure_) {
timeout = resolve_context_->SecureTransactionTimeout(secure_dns_mode_,
session_.get());
} else {
timeout = resolve_context_->ClassicTransactionTimeout(session_.get());
}
timeout -= time_from_start_->Elapsed();
timer_.Start(FROM_HERE, timeout, this, &DnsTransactionImpl::OnTimeout);
}
void OnTimeout() {
if (callback_.is_null())
return;
DoCallback(AttemptResult(ERR_DNS_TIMED_OUT, nullptr));
}
scoped_refptr<DnsSession> session_;
std::string hostname_;
uint16_t qtype_;
raw_ptr<const OptRecordRdata, DanglingUntriaged> opt_rdata_;
const bool secure_;
const SecureDnsMode secure_dns_mode_;
ResponseCallback callback_;
const bool fast_timeout_;
NetLogWithSource net_log_;
base::circular_deque<std::vector<uint8_t>> qnames_;
size_t qnames_initial_size_ = 0;
std::vector<std::unique_ptr<DnsAttempt>> attempts_;
int attempts_count_ = 0;
bool had_tcp_retry_ = false;
std::unique_ptr<DnsServerIterator> dns_server_iterator_;
base::OneShotTimer timer_;
std::unique_ptr<base::ElapsedTimer> time_from_start_;
#if BUILDFLAG(ARKWEB_EXT_HTTP_DNS_FALLBACK)
uint16_t not_need_more_attempt_query_type_ = dns_protocol::kTypeANY;
#endif
base::SafeRef<ResolveContext> resolve_context_;
RequestPriority request_priority_ = DEFAULT_PRIORITY;
THREAD_CHECKER(thread_checker_);
base::WeakPtrFactory<DnsTransactionImpl> weak_ptr_factory_{this};
};
class DnsTransactionFactoryImpl : public DnsTransactionFactory {
public:
explicit DnsTransactionFactoryImpl(DnsSession* session) {
session_ = session;
}
std::unique_ptr<DnsTransaction> CreateTransaction(
std::string hostname,
uint16_t qtype,
const NetLogWithSource& net_log,
bool secure,
SecureDnsMode secure_dns_mode,
ResolveContext* resolve_context,
bool fast_timeout) override {
return std::make_unique<DnsTransactionImpl>(
session_.get(), std::move(hostname), qtype, net_log, opt_rdata_.get(),
secure, secure_dns_mode, resolve_context, fast_timeout);
}
std::unique_ptr<DnsProbeRunner> CreateDohProbeRunner(
ResolveContext* resolve_context) override {
resolve_context->StartDohAutoupgradeSuccessTimer(session_.get());
return std::make_unique<DnsOverHttpsProbeRunner>(
session_->GetWeakPtr(), resolve_context->GetWeakPtr());
}
void AddEDNSOption(std::unique_ptr<OptRecordRdata::Opt> opt) override {
DCHECK(opt);
if (opt_rdata_ == nullptr)
opt_rdata_ = std::make_unique<OptRecordRdata>();
opt_rdata_->AddOpt(std::move(opt));
}
SecureDnsMode GetSecureDnsModeForTest() override {
return session_->config().secure_dns_mode;
}
OptRecordRdata* GetOptRdataForTest() override { return opt_rdata_.get(); }
private:
scoped_refptr<DnsSession> session_;
std::unique_ptr<OptRecordRdata> opt_rdata_;
};
}
DnsTransactionFactory::DnsTransactionFactory() = default;
DnsTransactionFactory::~DnsTransactionFactory() = default;
std::unique_ptr<DnsTransactionFactory> DnsTransactionFactory::CreateFactory(
DnsSession* session) {
return std::make_unique<DnsTransactionFactoryImpl>(session);
}
}