#include "net/base/network_change_notifier_win.h"
#include <iphlpapi.h>
#include <winsock2.h>
#include <utility>
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/metrics/histogram_macros.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/single_thread_task_runner.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/threading/thread.h"
#include "base/time/time.h"
#include "net/base/winsock_init.h"
#include "net/base/winsock_util.h"
namespace net {
namespace {

// Delay, in milliseconds, before re-attempting to arm NotifyAddrChange() after
// an attempt fails.
constexpr int kWatchForAddressChangeRetryIntervalMs = 500;

// Finds the connection point for |IIDSyncInterface| on |manager|.
//
// |connection_point_raw| is always written: it is set to nullptr up front and,
// on success, receives an AddRef'd IConnectionPoint that the caller owns.
// Returns the HRESULT of the first failing step (QueryInterface for
// IConnectionPointContainer, then FindConnectionPoint), or the success code
// from FindConnectionPoint.
HRESULT GetConnectionPoints(IUnknown* manager,
                            REFIID IIDSyncInterface,
                            IConnectionPoint** connection_point_raw) {
  *connection_point_raw = nullptr;

  Microsoft::WRL::ComPtr<IConnectionPointContainer> container;
  HRESULT status = manager->QueryInterface(IID_PPV_ARGS(&container));
  if (!FAILED(status)) {
    Microsoft::WRL::ComPtr<IConnectionPoint> point;
    status = container->FindConnectionPoint(IIDSyncInterface, &point);
    if (!FAILED(status)) {
      // Hand the caller its own reference; |point| releases ours on scope exit.
      *connection_point_raw = point.Get();
      (*connection_point_raw)->AddRef();
    }
  }
  return status;
}

}
// COM event sink for INetworkCostManagerEvents.
//
// CostChanged() forwards to |cost_changed_callback_| (the reported cost value
// itself is not passed on). DataPlanStatusChanged() is acknowledged and
// ignored. The sink is attached through the cost manager's connection point
// via RegisterForNotifications() and must be detached explicitly with
// UnRegisterForNotifications(); the destructor does not unadvise.
class NetworkCostManagerEventSink
    : public Microsoft::WRL::RuntimeClass<
          Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
          INetworkCostManagerEvents> {
 public:
  using CostChangedCallback = base::RepeatingCallback<void()>;

  NetworkCostManagerEventSink(INetworkCostManager* cost_manager,
                              const CostChangedCallback& callback)
      : network_cost_manager_(cost_manager), cost_changed_callback_(callback) {}
  ~NetworkCostManagerEventSink() override = default;

  // INetworkCostManagerEvents:
  IFACEMETHODIMP CostChanged(_In_ DWORD cost,
                             _In_opt_ NLM_SOCKADDR* ) override {
    cost_changed_callback_.Run();
    return S_OK;
  }

  IFACEMETHODIMP DataPlanStatusChanged(
      _In_opt_ NLM_SOCKADDR* ) override {
    return S_OK;
  }

  // Advises this sink on the cost manager's INetworkCostManagerEvents
  // connection point. Returns S_OK on success; otherwise the failing HRESULT.
  // On any failure the sink is left unregistered (no connection point held,
  // cookie cleared), so a later UnRegisterForNotifications() is a no-op.
  HRESULT RegisterForNotifications() {
    Microsoft::WRL::ComPtr<IUnknown> unknown;
    HRESULT hr = QueryInterface(IID_PPV_ARGS(&unknown));
    if (hr != S_OK)
      return hr;

    // GetConnectionPoints() nulls |connection_point_| itself on failure.
    hr = GetConnectionPoints(network_cost_manager_.Get(),
                             IID_INetworkCostManagerEvents, &connection_point_);
    if (hr != S_OK)
      return hr;

    hr = connection_point_->Advise(unknown.Get(), &cookie_);
    if (FAILED(hr)) {
      // Advise() did not take: drop the connection point so that
      // UnRegisterForNotifications() does not Unadvise() a cookie that was
      // never issued.
      connection_point_ = nullptr;
      cookie_ = 0;
    }
    return hr;
  }

  // Unadvises the sink if it is currently attached. Best effort: the
  // Unadvise() result is intentionally ignored during teardown.
  void UnRegisterForNotifications() {
    if (connection_point_) {
      connection_point_->Unadvise(cookie_);
      connection_point_ = nullptr;
      cookie_ = 0;
    }
  }

 private:
  Microsoft::WRL::ComPtr<INetworkCostManager> network_cost_manager_;
  // Non-null only while advised (or between a successful lookup and Advise()).
  Microsoft::WRL::ComPtr<IConnectionPoint> connection_point_;
  // Cookie returned by Advise(); 0 when not advised.
  DWORD cookie_ = 0;
  CostChangedCallback cost_changed_callback_;
};
// Builds the notifier with the Windows-specific calculator delays.
//
// The current connection type is computed synchronously here (on the
// constructing thread) and seeds last_announced_offline_. The constructing
// sequence is remembered in sequence_runner_for_registration_ so that
// cost-manager registration can later be hopped back onto it. Finally, an
// event is created for the overlapped NotifyAddrChange() requests.
//
// NOTE(review): members are initialized in header declaration order, not in
// the order listed below. last_announced_offline_ reads
// last_computed_connection_type_, so confirm the header declares
// last_computed_connection_type_ first.
// NOTE(review): a WSACreateEvent() failure (WSA_INVALID_EVENT) is not checked.
NetworkChangeNotifierWin::NetworkChangeNotifierWin()
    : NetworkChangeNotifier(NetworkChangeCalculatorParamsWin()),
      // Sequence on which the (possibly blocking) Winsock NLA query runs.
      blocking_task_runner_(
          base::ThreadPool::CreateSequencedTaskRunner({base::MayBlock()})),
      last_computed_connection_type_(RecomputeCurrentConnectionType()),
      last_announced_offline_(last_computed_connection_type_ ==
                              CONNECTION_NONE),
      sequence_runner_for_registration_(
          base::SequencedTaskRunner::GetCurrentDefault()) {
  memset(&addr_overlapped_, 0, sizeof addr_overlapped_);
  // Winsock is already initialized at this point: the initializer for
  // last_computed_connection_type_ calls RecomputeCurrentConnectionType(),
  // which calls EnsureWinsockInit().
  addr_overlapped_.hEvent = WSACreateEvent();
}
// Tears the notifier down on its owning sequence. The steps run in this order:
// ClearGlobalPointer(); cancel any pending NotifyAddrChange() request and stop
// the object watcher (only if a watch is armed); close the overlapped event;
// unregister and release the cost-manager event sink.
NetworkChangeNotifierWin::~NetworkChangeNotifierWin() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  ClearGlobalPointer();
  if (is_watching_) {
    // The pending request was issued with &addr_overlapped_, so cancel it
    // before the event it signals is closed below.
    CancelIPChangeNotify(&addr_overlapped_);
    addr_watcher_.StopWatching();
  }
  WSACloseEvent(addr_overlapped_.hEvent);
  if (network_cost_manager_event_sink_) {
    // Unregister explicitly. Advise() gives the connection point its own
    // reference to the sink, so dropping our ComPtr alone would not detach it.
    network_cost_manager_event_sink_->UnRegisterForNotifications();
    network_cost_manager_event_sink_ = nullptr;
  }
}
// Returns the calculator delays used on Windows. Every delay is 1.5 s except
// the connection-type "online" delay, which is 0.5 s.
NetworkChangeNotifier::NetworkChangeCalculatorParams
NetworkChangeNotifierWin::NetworkChangeCalculatorParamsWin() {
  const base::TimeDelta kStandardDelay = base::Milliseconds(1500);
  const base::TimeDelta kConnectionTypeOnlineDelay = base::Milliseconds(500);

  NetworkChangeCalculatorParams calculator_params;
  calculator_params.ip_address_offline_delay_ = kStandardDelay;
  calculator_params.ip_address_online_delay_ = kStandardDelay;
  calculator_params.connection_type_offline_delay_ = kStandardDelay;
  calculator_params.connection_type_online_delay_ = kConnectionTypeOnlineDelay;
  return calculator_params;
}
// static
// Determines connectivity by asking the Winsock NLA namespace for its first
// network.
//
// Returns CONNECTION_UNKNOWN if the lookup cannot be started and
// CONNECTION_NONE if NLA reports no networks. Otherwise it returns
// ConnectionTypeFromInterfaces(). Only the first NLA result is inspected,
// because all we need to know is whether at least one network exists.
NetworkChangeNotifier::ConnectionType
NetworkChangeNotifierWin::RecomputeCurrentConnectionType() {
  EnsureWinsockInit();

  HANDLE ws_handle = nullptr;
  WSAQUERYSET query_set = {0};
  query_set.dwSize = sizeof(WSAQUERYSET);
  query_set.dwNameSpace = NS_NLA;
  if (0 != WSALookupServiceBegin(&query_set, LUP_RETURN_ALL, &ws_handle)) {
    LOG(ERROR) << "WSALookupServiceBegin failed with: " << WSAGetLastError();
    return NetworkChangeNotifier::CONNECTION_UNKNOWN;
  }

  bool found_connection = false;

  // The buffer is accessed through a WSAQUERYSET*, so it must be aligned for
  // that type; a plain char array only guarantees 1-byte alignment. The extra
  // 256 bytes leave room for the network name. A longer name produces
  // WSAEFAULT, which is handled below.
  alignas(WSAQUERYSET) char result_buffer[sizeof(WSAQUERYSET) + 256] = {0};
  auto* result_set = reinterpret_cast<WSAQUERYSET*>(result_buffer);
  DWORD length = sizeof(result_buffer);
  result_set->dwSize = sizeof(WSAQUERYSET);
  int result =
      WSALookupServiceNext(ws_handle, LUP_RETURN_NAME, &length, result_set);

  if (result == 0) {
    found_connection = true;
  } else {
    DCHECK_EQ(SOCKET_ERROR, result);
    result = WSAGetLastError();
    if (result == WSAEFAULT) {
      // A network exists, but its record did not fit in |result_buffer|.
      // Its existence is all we need.
      found_connection = true;
    } else if (result == WSA_E_NO_MORE || result == WSAENOMORE) {
      // Nothing to enumerate: no connected networks.
    } else {
      LOG(WARNING) << "WSALookupServiceNext() failed with:" << result;
    }
  }

  // On failure WSALookupServiceEnd() returns the SOCKET_ERROR sentinel, so log
  // the real error code from WSAGetLastError() instead.
  if (WSALookupServiceEnd(ws_handle) != 0) {
    LOG(ERROR) << "WSALookupServiceEnd() failed with: " << WSAGetLastError();
  }

  return found_connection ? ConnectionTypeFromInterfaces()
                          : NetworkChangeNotifier::CONNECTION_NONE;
}
// Runs RecomputeCurrentConnectionType() on |blocking_task_runner_| (created
// with base::MayBlock()). |reply_callback| then receives the result on the
// sequence that called this method, per PostTaskAndReplyWithResult().
void NetworkChangeNotifierWin::RecomputeCurrentConnectionTypeOnBlockingSequence(
    base::OnceCallback<void(ConnectionType)> reply_callback) const {
  auto recompute_task =
      base::BindOnce(&NetworkChangeNotifierWin::RecomputeCurrentConnectionType);
  blocking_task_runner_->PostTaskAndReplyWithResult(
      FROM_HERE, std::move(recompute_task), std::move(reply_callback));
}
// Returns the current connection cost, lazily creating the cost manager on
// first use.
NetworkChangeNotifier::ConnectionCost
NetworkChangeNotifierWin::GetCurrentConnectionCost() {
  InitializeConnectionCost();

  // With an event sink registered, OnCostChanged() keeps the cached value up
  // to date, so it can be returned directly.
  if (network_cost_manager_event_sink_)
    return last_computed_connection_cost_;

  // Without a sink there are no push updates, so refresh on demand.
  UpdateConnectionCostFromCostManager();
  return last_computed_connection_cost_;
}
bool NetworkChangeNotifierWin::InitializeConnectionCostOnce() {
HRESULT hr =
::CoCreateInstance(CLSID_NetworkListManager, nullptr, CLSCTX_ALL,
IID_INetworkCostManager, &network_cost_manager_);
if (FAILED(hr)) {
SetCurrentConnectionCost(CONNECTION_COST_UNKNOWN);
return true;
}
UpdateConnectionCostFromCostManager();
return true;
}
// Runs InitializeConnectionCostOnce() exactly once, using the thread-safe
// initialization of a function-local static.
// NOTE(review): the once-flag is process-wide, but network_cost_manager_ is
// per-instance. A second NetworkChangeNotifierWin in the same process (e.g. in
// tests) would never create its own cost manager — confirm that is intended.
void NetworkChangeNotifierWin::InitializeConnectionCost() {
  static const bool cost_manager_initialized = InitializeConnectionCostOnce();
  DCHECK(cost_manager_initialized);
}
// Queries the cost manager for the current machine-wide cost and caches it in
// last_computed_connection_cost_.
//
// Returns E_ABORT (leaving the cache untouched) when no cost manager exists.
// Otherwise returns GetCost()'s HRESULT; if GetCost() fails, the cache is set
// to CONNECTION_COST_UNKNOWN.
HRESULT NetworkChangeNotifierWin::UpdateConnectionCostFromCostManager() {
  if (!network_cost_manager_)
    return E_ABORT;

  DWORD cost = NLM_CONNECTION_COST_UNKNOWN;
  // A null destination asks for the cost of the machine's preferred
  // connection (see the INetworkCostManager::GetCost documentation).
  HRESULT hr = network_cost_manager_->GetCost(&cost, nullptr);
  if (FAILED(hr)) {
    SetCurrentConnectionCost(CONNECTION_COST_UNKNOWN);
  } else {
    SetCurrentConnectionCost(
        ConnectionCostFromNlmCost(static_cast<NLM_CONNECTION_COST>(cost)));
  }
  return hr;
}
// Maps an NLM_CONNECTION_COST bit mask to a ConnectionCost. The UNKNOWN value
// maps to CONNECTION_COST_UNKNOWN. Any value with the UNRESTRICTED bit set is
// unmetered, and every other value is treated as metered.
NetworkChangeNotifier::ConnectionCost
NetworkChangeNotifierWin::ConnectionCostFromNlmCost(NLM_CONNECTION_COST cost) {
  if (cost == NLM_CONNECTION_COST_UNKNOWN)
    return CONNECTION_COST_UNKNOWN;

  const bool is_unrestricted = (cost & NLM_CONNECTION_COST_UNRESTRICTED) != 0;
  return is_unrestricted ? CONNECTION_COST_UNMETERED : CONNECTION_COST_METERED;
}
// Caches |connection_cost| as the latest known cost.
// NOTE(review): this is a plain, unlocked store, unlike the connection type,
// which is guarded by last_computed_connection_type_lock_. Confirm that all
// readers run on the same sequence as the writers.
void NetworkChangeNotifierWin::SetCurrentConnectionCost(
    ConnectionCost connection_cost) {
  this->last_computed_connection_cost_ = connection_cost;
}
void NetworkChangeNotifierWin::OnCostChanged() {
ConnectionCost old_cost = last_computed_connection_cost_;
UpdateConnectionCostFromCostManager();
if (old_cost != GetCurrentConnectionCost())
NotifyObserversOfConnectionCostChange();
}
// Schedules cost-manager event registration on
// sequence_runner_for_registration_ (the sequence captured at construction).
// NOTE(review): this presumably may be called from other sequences, hence the
// hop rather than registering inline.
void NetworkChangeNotifierWin::ConnectionCostObserverAdded() {
  auto register_task =
      base::BindOnce(&NetworkChangeNotifierWin::OnConnectionCostObserverAdded,
                     weak_factory_.GetWeakPtr());
  sequence_runner_for_registration_->PostTask(FROM_HERE,
                                              std::move(register_task));
}
// Registers a cost-change event sink with the cost manager, if possible and
// not already done. Must run on sequence_runner_for_registration_.
void NetworkChangeNotifierWin::OnConnectionCostObserverAdded() {
  DCHECK(sequence_runner_for_registration_->RunsTasksInCurrentSequence());
  InitializeConnectionCost();

  // Nothing to do without a cost manager, or if a sink is already registered.
  if (!network_cost_manager_ || network_cost_manager_event_sink_)
    return;

  network_cost_manager_event_sink_ =
      Microsoft::WRL::Make<net::NetworkCostManagerEventSink>(
          network_cost_manager_.Get(),
          base::BindRepeating(&NetworkChangeNotifierWin::OnCostChanged,
                              weak_factory_.GetWeakPtr()));
  // Make() yields a null ComPtr if allocation fails; do not dereference it.
  // Leaving the sink null allows a later observer addition to retry.
  if (!network_cost_manager_event_sink_)
    return;

  HRESULT hr = network_cost_manager_event_sink_->RegisterForNotifications();
  if (hr != S_OK) {
    // Registration failed: drop the sink. Observers stay attached but receive
    // no push updates (GetCurrentConnectionCost() then polls on demand), and
    // the next observer addition re-attempts registration.
    network_cost_manager_event_sink_ = nullptr;
  }
}
// Returns the most recently computed connection type. The read happens under
// last_computed_connection_type_lock_, the same lock
// SetCurrentConnectionType() holds while writing.
NetworkChangeNotifier::ConnectionType
NetworkChangeNotifierWin::GetCurrentConnectionType() const {
  ConnectionType snapshot = CONNECTION_UNKNOWN;
  {
    base::AutoLock guard(last_computed_connection_type_lock_);
    snapshot = last_computed_connection_type_;
  }
  return snapshot;
}
// Publishes |connection_type| under last_computed_connection_type_lock_, the
// lock GetCurrentConnectionType() reads with.
void NetworkChangeNotifierWin::SetCurrentConnectionType(
    ConnectionType connection_type) {
  const base::AutoLock guard(last_computed_connection_type_lock_);
  last_computed_connection_type_ = connection_type;
}
// Called by addr_watcher_ when the NotifyAddrChange() event fires.
void NetworkChangeNotifierWin::OnObjectSignaled(HANDLE object) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(is_watching_);

  // The one-shot watch has been consumed. Re-arm it first, presumably so that
  // changes occurring during the recompute below are still caught.
  is_watching_ = false;
  WatchForAddressChange();

  // Recompute the connection type off-sequence, then notify observers.
  auto on_type_recomputed = base::BindOnce(
      &NetworkChangeNotifierWin::NotifyObservers, weak_factory_.GetWeakPtr());
  RecomputeCurrentConnectionTypeOnBlockingSequence(
      std::move(on_type_recomputed));
}
// Records |connection_type| and announces an IP address change. It then resets
// the offline-poll counter and schedules the connection-type change check
// (NotifyParentOfConnectionTypeChange) one second later.
void NetworkChangeNotifierWin::NotifyObservers(ConnectionType connection_type) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  SetCurrentConnectionType(connection_type);
  NotifyObserversOfIPAddressChange();

  offline_polls_ = 0;
  const base::TimeDelta type_check_delay = base::Seconds(1);
  timer_.Start(FROM_HERE, type_check_delay, this,
               &NetworkChangeNotifierWin::NotifyParentOfConnectionTypeChange);
}
// Arms a one-shot address-change watch.
//
// If arming fails, it retries every kWatchForAddressChangeRetryIntervalMs.
// When arming succeeds after one or more failures, the recovery is treated as
// a network change, because changes were not being observed in the meantime.
void NetworkChangeNotifierWin::WatchForAddressChange() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(!is_watching_);

  if (!WatchForAddressChangeInternal()) {
    ++sequential_failures_;

    // Retry on the current *sequenced* runner. This object is sequence-affine
    // (sequence_checker_, SequencedTaskRunner in the constructor), and
    // SingleThreadTaskRunner::GetCurrentDefault() CHECK-fails on a sequence
    // that is not bound to a single thread.
    base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask(
        FROM_HERE,
        base::BindOnce(&NetworkChangeNotifierWin::WatchForAddressChange,
                       weak_factory_.GetWeakPtr()),
        base::Milliseconds(kWatchForAddressChangeRetryIntervalMs));
    return;
  }

  if (sequential_failures_ > 0) {
    RecomputeCurrentConnectionTypeOnBlockingSequence(
        base::BindOnce(&NetworkChangeNotifierWin::NotifyObservers,
                       weak_factory_.GetWeakPtr()));
  }

  is_watching_ = true;
  sequential_failures_ = 0;
}
bool NetworkChangeNotifierWin::WatchForAddressChangeInternal() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
ResetEventIfSignaled(addr_overlapped_.hEvent);
HANDLE handle = nullptr;
DWORD ret = NotifyAddrChange(&handle, &addr_overlapped_);
if (ret != ERROR_IO_PENDING)
return false;
addr_watcher_.StartWatchingOnce(addr_overlapped_.hEvent, this);
return true;
}
// Recomputes the connection type off-sequence; the result is delivered to
// NotifyParentOfConnectionTypeChangeImpl().
void NetworkChangeNotifierWin::NotifyParentOfConnectionTypeChange() {
  auto on_type_recomputed = base::BindOnce(
      &NetworkChangeNotifierWin::NotifyParentOfConnectionTypeChangeImpl,
      weak_factory_.GetWeakPtr());
  RecomputeCurrentConnectionTypeOnBlockingSequence(
      std::move(on_type_recomputed));
}
// Stores |connection_type| and decides whether to announce it.
//
// If offline was already announced and we still appear offline, announcing is
// deferred. Polling continues once per second for up to kMaxOfflinePolls polls
// in case connectivity returns. Otherwise the NCN.OfflinePolls histogram is
// recorded (after an offline period), and both the connection-type change and
// the max-bandwidth change are announced.
void NetworkChangeNotifierWin::NotifyParentOfConnectionTypeChangeImpl(
    ConnectionType connection_type) {
  const int kMaxOfflinePolls = 20;

  SetCurrentConnectionType(connection_type);
  const bool current_offline = IsOffline();
  ++offline_polls_;

  const bool remains_offline = last_announced_offline_ && current_offline;
  if (remains_offline && offline_polls_ <= kMaxOfflinePolls) {
    timer_.Start(FROM_HERE, base::Seconds(1), this,
                 &NetworkChangeNotifierWin::NotifyParentOfConnectionTypeChange);
    return;
  }

  if (last_announced_offline_)
    UMA_HISTOGRAM_CUSTOM_COUNTS("NCN.OfflinePolls", offline_polls_, 1, 50, 50);
  last_announced_offline_ = current_offline;

  NotifyObserversOfConnectionTypeChange();

  double max_bandwidth_mbps = 0.0;
  ConnectionType max_connection_type = CONNECTION_NONE;
  GetCurrentMaxBandwidthAndConnectionType(&max_bandwidth_mbps,
                                          &max_connection_type);
  NotifyObserversOfMaxBandwidthChange(max_bandwidth_mbps, max_connection_type);
}
}