#ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
#define NET_SOCKET_SOCKET_TEST_UTIL_H_
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "base/check_op.h"
#include "base/containers/span.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/raw_span.h"
#include "base/memory/ref_counted.h"
#include "base/memory/weak_ptr.h"
#include "base/run_loop.h"
#include "base/strings/string_view_util.h"
#include "build/build_config.h"
#include "net/base/address_list.h"
#include "net/base/completion_once_callback.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/base/test_completion_callback.h"
#include "net/http/http_auth_controller.h"
#include "net/log/net_log_with_source.h"
#include "net/socket/client_socket_factory.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool.h"
#include "net/socket/datagram_client_socket.h"
#include "net/socket/socket_performance_watcher.h"
#include "net/socket/socket_tag.h"
#include "net/socket/ssl_client_socket.h"
#include "net/socket/transport_client_socket.h"
#include "net/socket/transport_client_socket_pool.h"
#include "net/ssl/ssl_config_service.h"
#include "net/ssl/ssl_info.h"
#include "testing/gtest/include/gtest/gtest.h"
// Forward declaration of base::RunLoop ("base/run_loop.h" is also included
// above).
namespace base {
class RunLoop;
}
namespace net {
// Forward declarations of net types referenced further down in this header.
struct CommonConnectJobParams;
class NetLog;
struct NetworkTrafficAnnotationTag;
class X509Certificate;
// Fixed network handle values for tests to use.
const handles::NetworkHandle kDefaultNetworkForTests = 1;
const handles::NetworkHandle kNewNetworkForTests = 2;
// Error code private to the socket test utilities.
// NOTE(review): judging by the name, a MockRead whose result is this value
// marks that the peer closes the connection after the next MockRead --
// confirm against the implementation file.
enum {
ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000,
};
// Forward declarations of socket classes referenced in this header.
class AsyncSocket;
class MockClientSocket;
class MockTCPClientSocket;
class MockSSLClientSocket;
class SSLClientSocket;
class StreamSocket;
// Completion mode of a mocked operation: asynchronous or synchronous.
enum IoMode { ASYNC, SYNCHRONOUS };
// Helper object that a MockConnect can point at (see MockConnect::completer).
// The mock socket classes befriended below hand it a completion callback
// through the private SetCallback().
// NOTE(review): the method bodies are not in this header; the per-method
// notes are inferred from the signatures -- confirm against the
// implementation.
class MockConnectCompleter {
public:
MockConnectCompleter();
// Not copyable.
MockConnectCompleter(const MockConnectCompleter&) = delete;
MockConnectCompleter& operator=(const MockConnectCompleter&) = delete;
~MockConnectCompleter();
// Presumably waits (using run_loop_) until a connect has been requested.
void WaitForConnect();
// Presumably runs the stored callback with |result|.
void Complete(int result);
// Presumably WaitForConnect() followed by Complete(result).
void WaitForConnectAndComplete(int result);
private:
// Only these mock sockets may call SetCallback().
friend class MockTCPClientSocket;
friend class MockSSLClientSocket;
friend class MockUDPClientSocket;
void SetCallback(CompletionOnceCallback callback);
// Completion callback received through SetCallback().
CompletionOnceCallback callback_;
base::RunLoop run_loop_;
};
// Scripted outcome of a mocked connect attempt.
// NOTE(review): the constructors are defined outside this header, so the
// values they give to fields that have no in-class initializer (mode,
// result, peer_addr) are not visible here.
struct MockConnect {
MockConnect();
MockConnect(IoMode io_mode, int r);
MockConnect(IoMode io_mode, int r, IPEndPoint addr);
MockConnect(IoMode io_mode, int r, IPEndPoint addr, bool first_attempt_fails);
// Takes a non-owning pointer to a MockConnectCompleter.
explicit MockConnect(MockConnectCompleter* completer);
~MockConnect();
// Synchronous or asynchronous completion.
IoMode mode;
// Result code of the connect.
int result;
// Peer address associated with the connection.
IPEndPoint peer_addr;
bool first_attempt_fails = false;
// Non-owning; see MockConnectCompleter.
raw_ptr<MockConnectCompleter> completer;
};
// Scripted outcome (completion mode plus result code) of a mocked confirm
// step; SSLSocketDataProvider holds one as its |confirm| member.
// NOTE(review): the default constructor's values are set outside this header.
struct MockConfirm {
MockConfirm();
MockConfirm(IoMode io_mode, int r);
~MockConfirm();
IoMode mode;
int result;
};
// Template argument that distinguishes the read and write instantiations of
// MockReadWrite.
enum MockReadWriteType { MOCK_READ, MOCK_WRITE };
// One scripted read or write for a mock socket. The |type| template argument
// only separates the read and write instantiations into distinct types; no
// member depends on it.
template <MockReadWriteType type>
struct MockReadWrite {
  enum { STOPLOOP = 1 << 31 };

  // Adapter that lets the constructors below accept several string-like
  // argument kinds and treat them uniformly as a std::string_view. Nothing is
  // copied: the viewed bytes have to outlive whatever is built from the view.
  class ToStringView {
   public:
    ToStringView(std::string_view data) : data_(data) {}
    ToStringView(std::string& data) : data_(data) {}
    ToStringView(const std::string& data) : data_(data) {}
    // Temporary strings are rejected.
    ToStringView(std::string&&) = delete;

    // Character arrays (e.g. string literals): the final element must be the
    // NUL terminator, which is excluded from the view, and no other element
    // may be NUL.
    template <size_t N>
    ToStringView(const char (&data)[N]) : data_(data, N - 1) {
      CHECK_EQ(data[N - 1], '\0');
      CHECK(std::ranges::none_of(data_, [](char c) { return c == '\0'; }));
    }

    // Spans of char / uint8_t, const or mutable.
    template <size_t Extent>
    ToStringView(base::span<const char, Extent> data)
        : data_(base::as_string_view(data)) {}
    template <size_t Extent>
    ToStringView(base::span<const uint8_t, Extent> data)
        : data_(base::as_string_view(data)) {}
    template <size_t Extent>
    ToStringView(base::span<char, Extent> data)
        : data_(base::as_string_view(data)) {}
    template <size_t Extent>
    ToStringView(base::span<uint8_t, Extent> data)
        : data_(base::as_string_view(data)) {}

    ~ToStringView() = default;

    explicit operator std::string_view() const { return data_; }

   private:
    const std::string_view data_;
  };

  // Synchronous, result 0, no data.
  MockReadWrite() {}

  // No data: just a completion mode and a result code.
  MockReadWrite(IoMode io_mode, int result) : mode(io_mode), result(result) {}

  // As above, plus a sequence number.
  MockReadWrite(IoMode io_mode, int result, int seq)
      : mode(io_mode), result(result), sequence_number(seq) {}

  // Asynchronous, result 0, carrying |data|.
  explicit MockReadWrite(ToStringView data) : mode(ASYNC), data(data) {}

  // The pointer-plus-length form is deliberately unavailable.
  MockReadWrite(IoMode io_mode, const char* data, int data_len) = delete;

  // Result 0, carrying |data|, with a sequence number.
  MockReadWrite(IoMode io_mode, int seq, ToStringView data)
      : mode(io_mode), data(data), sequence_number(seq) {}

  // Every field spelled out; the trailing ones default to 0.
  MockReadWrite(IoMode io_mode,
                ToStringView data,
                int result = 0,
                int seq = 0,
                uint8_t tos_byte = 0)
      : mode(io_mode),
        result(result),
        data(data),
        sequence_number(seq),
        tos(tos_byte) {}

  // The initializers below are the values every constructor falls back to
  // for fields it does not set explicitly.
  IoMode mode = SYNCHRONOUS;
  int result = 0;
  std::string_view data;  // Non-owning view of the payload.
  int sequence_number = 0;
  uint8_t tos = 0;
};
// Convenience names for the two MockReadWrite instantiations.
using MockRead = MockReadWrite<MOCK_READ>;
using MockWrite = MockReadWrite<MOCK_WRITE>;
struct MockWriteResult {
MockWriteResult(IoMode io_mode, int result) : mode(io_mode), result(result) {}
IoMode mode;
int result;
};
// Interface for turning the bytes of a write into a string; how the bytes
// are formatted is up to the implementation.
class SocketDataPrinter {
 public:
  // Virtual because this is a polymorphic base: without it, destroying an
  // implementation through a SocketDataPrinter pointer would be undefined
  // behaviour and would skip the derived destructor.
  virtual ~SocketDataPrinter() = default;

  // Returns a string produced from |data|.
  virtual std::string PrintWrite(std::string_view data) = 0;
};
// Abstract base for objects that script the behaviour of a mock socket.
// Subclasses supply the reads and handle the writes; this base class stores
// socket-option values, the connect outcome and a pointer to the attached
// AsyncSocket.
class SocketDataProvider {
public:
SocketDataProvider();
// Not copyable.
SocketDataProvider(const SocketDataProvider&) = delete;
SocketDataProvider& operator=(const SocketDataProvider&) = delete;
virtual ~SocketDataProvider();
// Returns the next scripted read.
virtual MockRead OnRead() = 0;
// Receives the bytes being written and returns the scripted outcome.
virtual MockWriteResult OnWrite(const std::string& data) = 0;
virtual bool AllReadDataConsumed() const = 0;
virtual bool AllWriteDataConsumed() const = 0;
// Does nothing by default; subclasses may override.
virtual void CancelPendingRead() {}
// Stored socket-option values. The buffer sizes start at -1 and no_delay
// starts at true (see the member initializers below).
int receive_buffer_size() const { return receive_buffer_size_; }
void set_receive_buffer_size(int receive_buffer_size) {
receive_buffer_size_ = receive_buffer_size;
}
int send_buffer_size() const { return send_buffer_size_; }
void set_send_buffer_size(int send_buffer_size) {
send_buffer_size_ = send_buffer_size;
}
bool no_delay() const { return no_delay_; }
void set_no_delay(bool no_delay) { no_delay_ = no_delay; }
// Keep-alive is tri-state: kDefault until set_keep_alive() is called.
enum class KeepAliveState { kEnabled, kDisabled, kDefault };
KeepAliveState keep_alive_state() const { return keep_alive_state_; }
int keep_alive_delay() const { return keep_alive_delay_; }
// Records kEnabled/kDisabled according to |enable| and stores |delay|.
void set_keep_alive(bool enable, int delay) {
keep_alive_state_ =
enable ? KeepAliveState::kEnabled : KeepAliveState::kDisabled;
keep_alive_delay_ = delay;
}
// The "set_*_result" values below default to net::OK / true.
// NOTE(review): presumably these are what the mock socket returns from the
// corresponding Set*() call -- confirm against the socket implementations.
void set_set_receive_buffer_size_result(int receive_buffer_size_result) {
set_receive_buffer_size_result_ = receive_buffer_size_result;
}
int set_receive_buffer_size_result() const {
return set_receive_buffer_size_result_;
}
void set_set_send_buffer_size_result(int set_send_buffer_size_result) {
set_send_buffer_size_result_ = set_send_buffer_size_result;
}
int set_send_buffer_size_result() const {
return set_send_buffer_size_result_;
}
void set_set_no_delay_result(bool set_no_delay_result) {
set_no_delay_result_ = set_no_delay_result;
}
bool set_no_delay_result() const { return set_no_delay_result_; }
void set_set_keep_alive_result(bool set_keep_alive_result) {
set_keep_alive_result_ = set_keep_alive_result;
}
bool set_keep_alive_result() const { return set_keep_alive_result_; }
// Optional address list; empty (std::nullopt) unless set_expected_addresses()
// has been called.
const std::optional<AddressList>& expected_addresses() const {
return expected_addresses_;
}
void set_expected_addresses(net::AddressList addresses) {
expected_addresses_ = std::move(addresses);
}
virtual bool IsIdle() const;
// NOTE(review): presumably attaches |socket| and calls Reset(); DetachSocket()
// presumably clears the pointer -- bodies are not in this header.
void Initialize(AsyncSocket* socket);
void DetachSocket();
// The attached socket, or nullptr. Not owned.
AsyncSocket* socket() { return socket_; }
// Returns a copy of the connect data.
MockConnect connect_data() const { return connect_; }
void set_connect_data(const MockConnect& connect) { connect_ = connect; }
private:
// Subclass hook for resetting scripted state.
virtual void Reset() = 0;
MockConnect connect_;
raw_ptr<AsyncSocket> socket_ = nullptr;
int receive_buffer_size_ = -1;
int send_buffer_size_ = -1;
bool no_delay_ = true;
KeepAliveState keep_alive_state_ = KeepAliveState::kDefault;
int keep_alive_delay_ = 0;
int set_receive_buffer_size_result_ = net::OK;
int set_send_buffer_size_result_ = net::OK;
bool set_no_delay_result_ = true;
bool set_keep_alive_result_ = true;
std::optional<AddressList> expected_addresses_;
};
// Notification interface implemented by the mock sockets in this header.
// SocketDataProvider keeps a non-owning AsyncSocket pointer (see its
// socket() accessor).
// NOTE(review): there is no virtual destructor, so an implementation must
// never be deleted through an AsyncSocket pointer.
class AsyncSocket {
public:
// Delivers the read described by |data|.
virtual void OnReadComplete(const MockRead& data) = 0;
// Delivers the result |rv| of a write.
virtual void OnWriteComplete(int rv) = 0;
// Delivers the connect outcome described by |data|.
virtual void OnConnectComplete(const MockConnect& data) = 0;
// Signals that the data provider is being destroyed.
virtual void OnDataProviderDestroyed() = 0;
};
// Cursor over a fixed sequence of MockReads and a fixed sequence of
// MockWrites. Both sequences are held as non-owning spans, so the caller's
// arrays must stay alive for as long as the helper is used.
class StaticSocketDataHelper {
public:
StaticSocketDataHelper(base::span<const MockRead> reads,
base::span<const MockWrite> writes);
// Not copyable.
StaticSocketDataHelper(const StaticSocketDataHelper&) = delete;
StaticSocketDataHelper& operator=(const StaticSocketDataHelper&) = delete;
~StaticSocketDataHelper();
// Peek* return the current entry; Advance* presumably also move the
// corresponding index forward (bodies are not in this header).
const MockRead& PeekRead() const;
const MockWrite& PeekWrite() const;
const MockRead& AdvanceRead();
const MockWrite& AdvanceWrite();
void Reset();
// NOTE(review): presumably compares |data| with the expected write and uses
// |printer| when reporting a mismatch -- confirm against the implementation.
bool VerifyWriteData(const std::string& data, SocketDataPrinter* printer);
size_t read_index() const { return read_index_; }
size_t write_index() const { return write_index_; }
size_t read_count() const { return reads_.size(); }
size_t write_count() const { return writes_.size(); }
// True once the index has reached (or passed) the end of the sequence.
bool AllReadDataConsumed() const { return read_index() >= read_count(); }
bool AllWriteDataConsumed() const { return write_index() >= write_count(); }
void ExpectAllReadDataConsumed(SocketDataPrinter* printer) const;
void ExpectAllWriteDataConsumed(SocketDataPrinter* printer) const;
private:
const MockWrite& PeekRealWrite() const;
const base::raw_span<const MockRead, DanglingUntriaged> reads_;
size_t read_index_ = 0;
const base::raw_span<const MockWrite, DanglingUntriaged> writes_;
size_t write_index_ = 0;
};
// SocketDataProvider backed by fixed read and write sequences, tracked by a
// StaticSocketDataHelper.
class StaticSocketDataProvider : public SocketDataProvider {
public:
StaticSocketDataProvider();
StaticSocketDataProvider(base::span<const MockRead> reads,
base::span<const MockWrite> writes);
// Not copyable.
StaticSocketDataProvider(const StaticSocketDataProvider&) = delete;
StaticSocketDataProvider& operator=(const StaticSocketDataProvider&) = delete;
~StaticSocketDataProvider() override;
// NOTE(review): presumably toggle paused_ to hold back / release reads --
// bodies are not in this header.
void Pause();
void Resume();
void ExpectAllReadDataConsumed() const;
void ExpectAllWriteDataConsumed() const;
// SocketDataProvider implementation.
MockRead OnRead() override;
MockWriteResult OnWrite(const std::string& data) override;
bool AllReadDataConsumed() const override;
bool AllWriteDataConsumed() const override;
// Progress counters, forwarded to the helper.
size_t read_index() const { return helper_.read_index(); }
size_t write_index() const { return helper_.write_index(); }
size_t read_count() const { return helper_.read_count(); }
size_t write_count() const { return helper_.write_count(); }
// |printer| is stored as a non-owning pointer.
void set_printer(SocketDataPrinter* printer) { printer_ = printer; }
private:
void Reset() override;
StaticSocketDataHelper helper_;
raw_ptr<SocketDataPrinter> printer_ = nullptr;
bool paused_ = false;
};
// Scripted data for a mock SSL socket: the connect and confirm outcomes, the
// values the socket reports, and the values a test expects to see.
// NOTE(review): the per-field notes are inferred from names and types; the
// code that reads these fields is not in this header.
struct SSLSocketDataProvider {
SSLSocketDataProvider(IoMode mode, int result);
explicit SSLSocketDataProvider(MockConnectCompleter* completer);
SSLSocketDataProvider(const SSLSocketDataProvider& other);
~SSLSocketDataProvider();
// Accessors for the bookkeeping flags at the end of the struct.
bool ConnectDataConsumed() const { return is_connect_data_consumed; }
bool ConfirmDataConsumed() const { return is_confirm_data_consumed; }
bool WriteBeforeConfirm() const { return write_called_before_confirm; }
// Connect outcome and a closure associated with it.
MockConnect connect;
base::RepeatingClosure connect_callback;
// Confirm outcome and a closure associated with it.
MockConfirm confirm;
base::RepeatingClosure confirm_callback;
// Values presumably reported by the mock socket.
NextProto next_proto = NextProto::kProtoUnknown;
std::optional<std::string> peer_application_settings;
SSLInfo ssl_info;
scoped_refptr<SSLCertRequestInfo> cert_request_info;
std::vector<uint8_t> ech_retry_configs;
std::vector<std::vector<uint8_t>> server_trust_anchor_ids_for_retry;
// Expectations; the std::optional ones are empty unless a test sets them.
std::optional<NextProtoVector> next_protos_expected_in_ssl_config;
std::optional<SSLConfig::ApplicationSettings> expected_application_settings;
// No in-class initializer: presumably set by the constructors.
uint16_t expected_ssl_version_min;
uint16_t expected_ssl_version_max;
std::optional<bool> expected_early_data_enabled;
std::optional<bool> expected_send_client_cert;
scoped_refptr<X509Certificate> expected_client_cert;
std::optional<HostPortPair> expected_host_and_port;
std::optional<bool> expected_ignore_certificate_errors;
std::optional<NetworkAnonymizationKey> expected_network_anonymization_key;
std::optional<std::vector<uint8_t>> expected_ech_config_list;
std::optional<std::vector<uint8_t>> expected_trust_anchor_ids;
bool expect_no_trust_anchor_ids = false;
// Bookkeeping flags, all initially false.
bool is_connect_data_consumed = false;
bool is_confirm_data_consumed = false;
bool write_called_before_confirm = false;
};
// SocketDataProvider over fixed read and write sequences that also tracks a
// sequence number and separate read/write I/O states.
// NOTE(review): the ordering rules are implemented outside this header;
// presumably reads and writes must occur in MockReadWrite::sequence_number
// order -- confirm against the implementation.
class SequencedSocketData : public SocketDataProvider {
public:
SequencedSocketData();
SequencedSocketData(base::span<const MockRead> reads,
base::span<const MockWrite> writes);
SequencedSocketData(const MockConnect& connect,
base::span<const MockRead> reads,
base::span<const MockWrite> writes);
// Not copyable.
SequencedSocketData(const SequencedSocketData&) = delete;
SequencedSocketData& operator=(const SequencedSocketData&) = delete;
~SequencedSocketData() override;
// SocketDataProvider implementation.
MockRead OnRead() override;
MockWriteResult OnWrite(const std::string& data) override;
bool AllReadDataConsumed() const override;
bool AllWriteDataConsumed() const override;
bool IsIdle() const override;
void CancelPendingRead() override;
void ExpectAllReadDataConsumed() const;
void ExpectAllWriteDataConsumed() const;
// Pause control. RunUntilPaused() presumably spins
// run_until_paused_run_loop_ until a paused state is reached.
bool IsPaused() const;
void Resume();
void RunUntilPaused();
void set_busy_before_sync_reads(bool busy_before_sync_reads) {
busy_before_sync_reads_ = busy_before_sync_reads;
}
// |printer| is stored as a non-owning pointer.
void set_printer(SocketDataPrinter* printer) { printer_ = printer; }
private:
// State of the read side or the write side.
enum class IoState {
kIdle,
kPending,
kCompleting,
kPaused,
};
void Reset() override;
void OnReadComplete();
void OnWriteComplete();
void MaybePostReadCompleteTask();
void MaybePostWriteCompleteTask();
StaticSocketDataHelper helper_;
raw_ptr<SocketDataPrinter> printer_ = nullptr;
int sequence_number_ = 0;
IoState read_state_ = IoState::kIdle;
IoState write_state_ = IoState::kIdle;
bool busy_before_sync_reads_ = false;
std::unique_ptr<base::RunLoop> run_until_paused_run_loop_;
base::WeakPtrFactory<SequencedSocketData> weak_factory_{this};
};
// Ordered collection of non-owning T pointers that are handed out one at a
// time, in the order they were added.
template <typename T>
class SocketDataProviderArray {
 public:
  SocketDataProviderArray() = default;

  // Returns the next entry and advances the cursor. DCHECKs that an entry is
  // still available.
  T* GetNext() {
    DCHECK_LT(next_index_, data_providers_.size());
    T* const provider = data_providers_[next_index_];
    ++next_index_;
    return provider;
  }

  // Same as GetNext(), except that it returns nullptr (and leaves the cursor
  // untouched) once every entry has been handed out.
  T* GetNextWithoutAsserting() {
    const bool exhausted = next_index_ == data_providers_.size();
    return exhausted ? nullptr : data_providers_[next_index_++];
  }

  // Appends |data_provider|, which must not be null. The pointer is not owned.
  void Add(T* data_provider) {
    DCHECK(data_provider);
    data_providers_.push_back(data_provider);
  }

  // Index of the entry the next GetNext*() call will return.
  size_t next_index() { return next_index_; }

  // Rewinds the cursor to the first entry; the entries themselves are kept.
  void ResetNextIndex() { next_index_ = 0; }

 private:
  size_t next_index_ = 0;
  std::vector<T*> data_providers_;
};
// Forward declarations of the mock socket classes defined later in this
// header.
class MockUDPClientSocket;
class MockTCPClientSocket;
class MockSSLClientSocket;
// ClientSocketFactory for tests. Data providers are registered up front in
// three separate SocketDataProviderArray queues (general, TCP-only and SSL).
// NOTE(review): presumably each Create*() call consumes the next provider
// from the matching queue -- the bodies are not in this header.
class MockClientSocketFactory : public ClientSocketFactory {
public:
MockClientSocketFactory();
// Not copyable.
MockClientSocketFactory(const MockClientSocketFactory&) = delete;
MockClientSocketFactory& operator=(const MockClientSocketFactory&) = delete;
~MockClientSocketFactory() override;
// Register data providers. The factory takes raw pointers and does not own
// them.
void AddSocketDataProvider(SocketDataProvider* socket);
void AddTcpSocketDataProvider(SocketDataProvider* socket);
void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
void ResetNextMockIndexes();
// Direct access to the general provider queue.
SocketDataProviderArray<SocketDataProvider>& mock_data() {
return mock_data_;
}
void set_enable_read_if_ready(bool enable_read_if_ready) {
enable_read_if_ready_ = enable_read_if_ready;
}
// ClientSocketFactory implementation.
std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
NetLog* net_log,
const NetLogSource& source) override;
std::unique_ptr<TransportClientSocket> CreateTransportClientSocket(
const AddressList& addresses,
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetworkQualityEstimator* network_quality_estimator,
NetLog* net_log,
const NetLogSource& source) override;
std::unique_ptr<SSLClientSocket> CreateSSLClientSocket(
SSLClientContext* context,
std::unique_ptr<StreamSocket> stream_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config) override;
// Ports recorded for UDP client sockets (populated outside this header).
const std::vector<uint16_t>& udp_client_socket_ports() const {
return udp_client_socket_ports_;
}
private:
SocketDataProviderArray<SocketDataProvider> mock_data_;
SocketDataProviderArray<SocketDataProvider> mock_tcp_data_;
SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
std::vector<uint16_t> udp_client_socket_ports_;
bool enable_read_if_ready_ = false;
};
// Abstract base for mock TransportClientSockets. Read(), Write() and
// Connect() stay pure virtual; the remaining TransportClientSocket methods
// declared here are implemented outside this header.
class MockClientSocket : public TransportClientSocket {
public:
explicit MockClientSocket(const NetLogWithSource& net_log);
// Not copyable.
MockClientSocket(const MockClientSocket&) = delete;
MockClientSocket& operator=(const MockClientSocket&) = delete;
// Must be provided by subclasses.
int Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override = 0;
int Write(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& traffic_annotation) override = 0;
int SetReceiveBufferSize(int32_t size) override;
int SetSendBufferSize(int32_t size) override;
int Bind(const net::IPEndPoint& local_addr) override;
bool SetNoDelay(bool no_delay) override;
bool SetKeepAlive(bool enable, int delay) override;
// Must be provided by subclasses.
int Connect(CompletionOnceCallback callback) override = 0;
void Disconnect() override;
bool IsConnected() const override;
bool IsConnectedAndIdle() const override;
int GetPeerAddress(IPEndPoint* address) const override;
int GetLocalAddress(IPEndPoint* address) const override;
const NetLogWithSource& NetLog() const override;
NextProto GetNegotiatedProtocol() const override;
int64_t GetTotalReceivedBytes() const override;
// Socket tags are ignored by this base class.
void ApplySocketTag(const SocketTag& tag) override {}
protected:
~MockClientSocket() override;
// Helpers for invoking |callback| with |result|; RunCallbackAsync presumably
// posts a task (weak_factory_ is available for that) instead of calling
// directly.
void RunCallbackAsync(CompletionOnceCallback callback, int result);
void RunCallback(CompletionOnceCallback callback, int result);
bool connected_ = false;
IPEndPoint local_addr_;
IPEndPoint peer_addr_;
NetLogWithSource net_log_;
private:
base::WeakPtrFactory<MockClientSocket> weak_factory_{this};
};
// Mock TCP socket driven by a SocketDataProvider. It is both a
// MockClientSocket and the AsyncSocket that the provider notifies.
class MockTCPClientSocket : public MockClientSocket, public AsyncSocket {
public:
// |socket| is the data provider; it is held as a non-owning pointer (data_).
MockTCPClientSocket(const AddressList& addresses,
net::NetLog* net_log,
SocketDataProvider* socket);
// Not copyable.
MockTCPClientSocket(const MockTCPClientSocket&) = delete;
MockTCPClientSocket& operator=(const MockTCPClientSocket&) = delete;
~MockTCPClientSocket() override;
const AddressList& addresses() const { return addresses_; }
// Socket overrides.
int Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int ReadIfReady(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int CancelReadIfReady() override;
int Write(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& traffic_annotation) override;
int SetReceiveBufferSize(int32_t size) override;
int SetSendBufferSize(int32_t size) override;
bool SetNoDelay(bool no_delay) override;
bool SetKeepAlive(bool enable, int delay) override;
void SetBeforeConnectCallback(
const BeforeConnectCallback& before_connect_callback) override;
int Connect(CompletionOnceCallback callback) override;
void Disconnect() override;
bool IsConnected() const override;
bool IsConnectedAndIdle() const override;
int GetPeerAddress(IPEndPoint* address) const override;
bool WasEverUsed() const override;
bool GetSSLInfo(SSLInfo* ssl_info) override;
// AsyncSocket implementation.
void OnReadComplete(const MockRead& data) override;
void OnWriteComplete(int rv) override;
void OnConnectComplete(const MockConnect& data) override;
void OnDataProviderDestroyed() override;
void set_enable_read_if_ready(bool enable_read_if_ready) {
enable_read_if_ready_ = enable_read_if_ready;
}
private:
void RetryRead(int rv);
int ReadIfReadyImpl(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback);
void RunReadIfReadyCallback(int result);
AddressList addresses_;
// Data provider; not owned.
raw_ptr<SocketDataProvider> data_;
// Read bookkeeping: the current MockRead and, presumably, the offset already
// consumed from it.
int read_offset_ = 0;
MockRead read_data_;
bool need_read_data_ = true;
bool peer_closed_connection_ = false;
// Buffer, length and callbacks of operations that have not completed yet.
scoped_refptr<IOBuffer> pending_read_buf_ = nullptr;
int pending_read_buf_len_ = 0;
CompletionOnceCallback pending_read_callback_;
CompletionOnceCallback pending_read_if_ready_callback_;
CompletionOnceCallback pending_connect_callback_;
CompletionOnceCallback pending_write_callback_;
bool was_used_to_convey_data_ = false;
bool enable_read_if_ready_ = false;
BeforeConnectCallback before_connect_callback_;
};
// Mock SSLClientSocket that owns an underlying StreamSocket and is configured
// by an SSLSocketDataProvider (held as a non-owning pointer).
class MockSSLClientSocket : public AsyncSocket, public SSLClientSocket {
public:
MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
SSLSocketDataProvider* socket);
// Not copyable.
MockSSLClientSocket(const MockSSLClientSocket&) = delete;
MockSSLClientSocket& operator=(const MockSSLClientSocket&) = delete;
~MockSSLClientSocket() override;
// Socket / SSLClientSocket overrides.
int Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int ReadIfReady(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int Write(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& traffic_annotation) override;
int CancelReadIfReady() override;
int Connect(CompletionOnceCallback callback) override;
void Disconnect() override;
int ConfirmHandshake(CompletionOnceCallback callback) override;
bool IsConnected() const override;
bool IsConnectedAndIdle() const override;
bool WasEverUsed() const override;
int GetPeerAddress(IPEndPoint* address) const override;
int GetLocalAddress(IPEndPoint* address) const override;
NextProto GetNegotiatedProtocol() const override;
std::optional<std::string_view> GetPeerApplicationSettings() const override;
bool GetSSLInfo(SSLInfo* ssl_info) override;
void GetSSLCertRequestInfo(
SSLCertRequestInfo* cert_request_info) const override;
void ApplySocketTag(const SocketTag& tag) override;
const NetLogWithSource& NetLog() const override;
int64_t GetTotalReceivedBytes() const override;
int SetReceiveBufferSize(int32_t size) override;
int SetSendBufferSize(int32_t size) override;
int ExportKeyingMaterial(std::string_view label,
std::optional<base::span<const uint8_t>> context,
base::span<uint8_t> out) override;
std::vector<uint8_t> GetECHRetryConfigs() override;
std::vector<std::vector<uint8_t>> GetServerTrustAnchorIDsForRetry() override;
// AsyncSocket implementation.
void OnReadComplete(const MockRead& data) override;
void OnWriteComplete(int rv) override;
void OnConnectComplete(const MockConnect& data) override;
// Nothing to do when the data provider goes away.
void OnDataProviderDestroyed() override {}
private:
static void ConnectCallback(MockSSLClientSocket* ssl_client_socket,
CompletionOnceCallback callback,
int rv);
void RunCallbackAsync(CompletionOnceCallback callback, int result);
void RunCallback(CompletionOnceCallback callback, int result);
void RunConfirmHandshakeCallback(CompletionOnceCallback callback, int result);
bool connected_ = false;
bool in_confirm_handshake_ = false;
NetLogWithSource net_log_;
// Owned underlying transport.
std::unique_ptr<StreamSocket> stream_socket_;
// SSL data provider; not owned.
raw_ptr<SSLSocketDataProvider, AcrossTasksDanglingUntriaged> data_;
IPEndPoint peer_addr_;
base::WeakPtrFactory<MockSSLClientSocket> weak_factory_{this};
};
// Mock DatagramClientSocket driven by a SocketDataProvider (non-owning). It
// also implements AsyncSocket so that the provider can notify it.
class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket {
public:
// Both arguments are optional and default to nullptr.
explicit MockUDPClientSocket(SocketDataProvider* data = nullptr,
net::NetLog* net_log = nullptr);
// Not copyable.
MockUDPClientSocket(const MockUDPClientSocket&) = delete;
MockUDPClientSocket& operator=(const MockUDPClientSocket&) = delete;
~MockUDPClientSocket() override;
// DatagramClientSocket overrides.
int Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int Write(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& traffic_annotation) override;
int SetReceiveBufferSize(int32_t size) override;
int SetSendBufferSize(int32_t size) override;
int SetDoNotFragment() override;
int SetRecvTos() override;
int SetTos(DiffServCodePoint dscp, EcnCodePoint ecn) override;
void Close() override;
int GetPeerAddress(IPEndPoint* address) const override;
int GetLocalAddress(IPEndPoint* address) const override;
void UseNonBlockingIO() override;
int SetMulticastInterface(uint32_t interface_index) override;
const NetLogWithSource& NetLog() const override;
int Connect(const IPEndPoint& address) override;
int ConnectUsingNetwork(handles::NetworkHandle network,
const IPEndPoint& address) override;
int ConnectUsingDefaultNetwork(const IPEndPoint& address) override;
int ConnectAsync(const IPEndPoint& address,
CompletionOnceCallback callback) override;
int ConnectUsingNetworkAsync(handles::NetworkHandle network,
const IPEndPoint& address,
CompletionOnceCallback callback) override;
int ConnectUsingDefaultNetworkAsync(const IPEndPoint& address,
CompletionOnceCallback callback) override;
handles::NetworkHandle GetBoundNetwork() const override;
void ApplySocketTag(const SocketTag& tag) override;
// Ignored by this mock.
void SetMsgConfirm(bool confirm) override {}
DscpAndEcn GetLastTos() const override;
// AsyncSocket implementation.
void OnReadComplete(const MockRead& data) override;
void OnWriteComplete(int rv) override;
void OnConnectComplete(const MockConnect& data) override;
void OnDataProviderDestroyed() override;
// Test-side accessors. The source port starts at 123.
void set_source_port(uint16_t port) { source_port_ = port; }
uint16_t source_port() const { return source_port_; }
void set_source_host(IPAddress addr) { source_host_ = addr; }
IPAddress source_host() const { return source_host_; }
SocketTag tag() const { return tag_; }
// Starts out true; presumably cleared if a tag is applied after data has
// already been transferred.
bool tagged_before_data_transferred() const {
return tagged_before_data_transferred_;
}
EcnCodePoint outgoing_ecn() const { return outgoing_ecn_; }
private:
int CompleteRead();
void RunCallbackAsync(CompletionOnceCallback callback, int result);
void RunCallback(CompletionOnceCallback callback, int result);
bool connected_ = false;
// Data provider; not owned.
raw_ptr<SocketDataProvider> data_;
int read_offset_ = 0;
MockRead read_data_;
bool need_read_data_ = true;
IPAddress source_host_;
uint16_t source_port_ = 123;
IPEndPoint peer_addr_;
handles::NetworkHandle network_ = handles::kInvalidNetworkHandle;
// Buffer, length and callbacks of operations that have not completed yet.
scoped_refptr<IOBuffer> pending_read_buf_ = nullptr;
int pending_read_buf_len_ = 0;
CompletionOnceCallback pending_read_callback_;
CompletionOnceCallback pending_write_callback_;
NetLogWithSource net_log_;
SocketTag tag_;
bool data_transferred_ = false;
bool tagged_before_data_transferred_ = true;
uint8_t last_tos_ = 0;
EcnCodePoint outgoing_ecn_ = net::ECN_NOT_ECT;
base::WeakPtrFactory<MockUDPClientSocket> weak_factory_{this};
};
// A single socket request used by ClientSocketPoolTest: it bundles a
// ClientSocketHandle with a completion callback.
// NOTE(review): OnComplete() is defined outside this header; presumably it
// appends this request to |request_order| and increments |completion_count|.
class TestSocketRequest : public TestCompletionCallbackBase {
public:
// Both pointers are stored without taking ownership.
TestSocketRequest(std::vector<raw_ptr<TestSocketRequest, VectorExperimental>>*
request_order,
size_t* completion_count);
// Not copyable.
TestSocketRequest(const TestSocketRequest&) = delete;
TestSocketRequest& operator=(const TestSocketRequest&) = delete;
~TestSocketRequest() override;
ClientSocketHandle* handle() { return &handle_; }
// Returns a callback bound to OnComplete(). It uses base::Unretained(this),
// so this object must outlive any callback obtained here.
CompletionOnceCallback callback() {
return base::BindOnce(&TestSocketRequest::OnComplete,
base::Unretained(this));
}
private:
void OnComplete(int result);
ClientSocketHandle handle_;
raw_ptr<std::vector<raw_ptr<TestSocketRequest, VectorExperimental>>>
request_order_;
raw_ptr<size_t> completion_count_;
};
// Test helper that owns a list of TestSocketRequests, starts them against a
// socket pool, and records the order in which they finish.
class ClientSocketPoolTest {
 public:
  enum KeepAlive {
    KEEP_ALIVE,
    NO_KEEP_ALIVE,
  };

  static const int kIndexOutOfBounds;
  static const int kRequestNotFound;

  ClientSocketPoolTest();
  // Not copyable.
  ClientSocketPoolTest(const ClientSocketPoolTest&) = delete;
  ClientSocketPoolTest& operator=(const ClientSocketPoolTest&) = delete;
  ~ClientSocketPoolTest();

  // Creates a new TestSocketRequest, stores it in requests_, and initialises
  // its handle against |socket_pool|. Returns the result of Init(). A request
  // that does not return ERR_IO_PENDING has already finished, so it is
  // appended to request_order_ here.
  template <typename PoolType>
  int StartRequestUsingPool(
      PoolType* socket_pool,
      const ClientSocketPool::GroupId& group_id,
      RequestPriority priority,
      ClientSocketPool::RespectLimits respect_limits,
      const scoped_refptr<typename PoolType::SocketParams>& socket_params) {
    DCHECK(socket_pool);

    auto owned_request = std::make_unique<TestSocketRequest>(
        &request_order_, &completion_count_);
    // Keep a plain pointer: ownership moves into requests_ on the next line.
    TestSocketRequest* const pending = owned_request.get();
    requests_.push_back(std::move(owned_request));

    const int rv = pending->handle()->Init(
        group_id, socket_params, std::nullopt, priority, SocketTag(),
        respect_limits, pending->callback(),
        ClientSocketPool::ProxyAuthCallback(), false, socket_pool,
        NetLogWithSource());
    if (rv == ERR_IO_PENDING) {
      return rv;
    }
    request_order_.push_back(pending);
    return rv;
  }

  int GetOrderOfRequest(size_t index) const;
  bool ReleaseOneConnection(KeepAlive keep_alive);
  void ReleaseAllConnections(KeepAlive keep_alive);

  // Returns the i-th request; no bounds checking is done.
  TestSocketRequest* request(int i) { return requests_[i].get(); }
  size_t requests_size() const { return requests_.size(); }
  std::vector<std::unique_ptr<TestSocketRequest>>* requests() {
    return &requests_;
  }
  size_t completion_count() const { return completion_count_; }

 private:
  // Owns every request created by StartRequestUsingPool().
  std::vector<std::unique_ptr<TestSocketRequest>> requests_;
  // Non-owning pointers into requests_.
  std::vector<raw_ptr<TestSocketRequest, VectorExperimental>> request_order_;
  size_t completion_count_ = 0;
};
// Empty ref-counted params type. The destructor is private and reachable only
// through the base::RefCounted friend declaration.
class MockTransportSocketParams
: public base::RefCounted<MockTransportSocketParams> {
public:
// Not copyable.
MockTransportSocketParams(const MockTransportSocketParams&) = delete;
MockTransportSocketParams& operator=(const MockTransportSocketParams&) =
delete;
private:
friend class base::RefCounted<MockTransportSocketParams>;
~MockTransportSocketParams() = default;
};
// TransportClientSocketPool for tests. It keeps a list of MockConnectJobs
// and exposes counters and the last request priority for inspection.
// NOTE(review): how the counters are updated is defined outside this header.
class MockTransportClientSocketPool : public TransportClientSocketPool {
public:
// One connect attempt: owns the socket and remembers the handle (not owned),
// the socket tag, the user's callback and the priority.
class MockConnectJob {
public:
MockConnectJob(std::unique_ptr<StreamSocket> socket,
ClientSocketHandle* handle,
const SocketTag& socket_tag,
CompletionOnceCallback callback,
RequestPriority priority);
// Not copyable.
MockConnectJob(const MockConnectJob&) = delete;
MockConnectJob& operator=(const MockConnectJob&) = delete;
~MockConnectJob();
int Connect();
// NOTE(review): presumably returns true when |handle| belongs to this job
// and was cancelled -- confirm against the implementation.
bool CancelHandle(const ClientSocketHandle* handle);
ClientSocketHandle* handle() const { return handle_; }
RequestPriority priority() const { return priority_; }
void set_priority(RequestPriority priority) { priority_ = priority; }
private:
void OnConnect(int rv);
std::unique_ptr<StreamSocket> socket_;
raw_ptr<ClientSocketHandle> handle_;
const SocketTag socket_tag_;
CompletionOnceCallback user_callback_;
RequestPriority priority_;
};
MockTransportClientSocketPool(
int max_sockets,
int max_sockets_per_group,
const CommonConnectJobParams* common_connect_job_params);
// Not copyable.
MockTransportClientSocketPool(const MockTransportClientSocketPool&) = delete;
MockTransportClientSocketPool& operator=(
const MockTransportClientSocketPool&) = delete;
~MockTransportClientSocketPool() override;
// Inspection helpers for tests.
RequestPriority last_request_priority() const {
return last_request_priority_;
}
const std::vector<std::unique_ptr<MockConnectJob>>& requests() const {
return job_list_;
}
int release_count() const { return release_count_; }
int cancel_count() const { return cancel_count_; }
// TransportClientSocketPool overrides.
int RequestSocket(
const GroupId& group_id,
scoped_refptr<ClientSocketPool::SocketParams> socket_params,
const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
RequestPriority priority,
const SocketTag& socket_tag,
RespectLimits respect_limits,
ClientSocketHandle* handle,
CompletionOnceCallback callback,
const ProxyAuthCallback& on_auth_callback,
bool fail_if_alias_requires_proxy_override,
const NetLogWithSource& net_log) override;
void SetPriority(const GroupId& group_id,
ClientSocketHandle* handle,
RequestPriority priority) override;
void CancelRequest(const GroupId& group_id,
ClientSocketHandle* handle,
bool cancel_connect_job) override;
void ReleaseSocket(const GroupId& group_id,
std::unique_ptr<StreamSocket> socket,
int64_t generation) override;
private:
// Not owned.
raw_ptr<ClientSocketFactory> client_socket_factory_;
std::vector<std::unique_ptr<MockConnectJob>> job_list_;
RequestPriority last_request_priority_ = DEFAULT_PRIORITY;
int release_count_ = 0;
int cancel_count_ = 0;
};
// TransportClientSocket that owns another StreamSocket (transport_).
// NOTE(review): presumably every method simply forwards to transport_ -- the
// bodies are not in this header.
class WrappedStreamSocket : public TransportClientSocket {
public:
// Takes ownership of |transport|.
explicit WrappedStreamSocket(std::unique_ptr<StreamSocket> transport);
~WrappedStreamSocket() override;
// TransportClientSocket / StreamSocket overrides.
int Bind(const net::IPEndPoint& local_addr) override;
int Connect(CompletionOnceCallback callback) override;
void Disconnect() override;
bool IsConnected() const override;
bool IsConnectedAndIdle() const override;
int GetPeerAddress(IPEndPoint* address) const override;
int GetLocalAddress(IPEndPoint* address) const override;
const NetLogWithSource& NetLog() const override;
bool WasEverUsed() const override;
NextProto GetNegotiatedProtocol() const override;
bool GetSSLInfo(SSLInfo* ssl_info) override;
int64_t GetTotalReceivedBytes() const override;
void ApplySocketTag(const SocketTag& tag) override;
int Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int ReadIfReady(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int Write(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& traffic_annotation) override;
int SetReceiveBufferSize(int32_t size) override;
int SetSendBufferSize(int32_t size) override;
protected:
// The wrapped socket; accessible to subclasses.
std::unique_ptr<StreamSocket> transport_;
};
// WrappedStreamSocket that overrides Connect() and ApplySocketTag() and
// exposes the stored tag plus a tagged_before_connected flag.
// NOTE(review): presumably the flag is cleared when a tag is applied after
// Connect() -- the overrides are defined outside this header.
class MockTaggingStreamSocket : public WrappedStreamSocket {
public:
explicit MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport)
: WrappedStreamSocket(std::move(transport)) {}
// Not copyable.
MockTaggingStreamSocket(const MockTaggingStreamSocket&) = delete;
MockTaggingStreamSocket& operator=(const MockTaggingStreamSocket&) = delete;
~MockTaggingStreamSocket() override = default;
int Connect(CompletionOnceCallback callback) override;
void ApplySocketTag(const SocketTag& tag) override;
// Starts out true.
bool tagged_before_connected() const { return tagged_before_connected_; }
// Returns a copy of the stored tag.
SocketTag tag() const { return tag_; }
private:
bool connected_ = false;
bool tagged_before_connected_ = true;
SocketTag tag_;
};
// MockClientSocketFactory that overrides datagram and transport socket
// creation and remembers a pointer to the most recent socket of each kind.
// NOTE(review): the stored pointers are non-owning and annotated as possibly
// dangling; presumably the created sockets are MockTaggingStreamSocket /
// MockUDPClientSocket instances -- confirm against the implementation.
class MockTaggingClientSocketFactory : public MockClientSocketFactory {
public:
MockTaggingClientSocketFactory() = default;
// Not copyable.
MockTaggingClientSocketFactory(const MockTaggingClientSocketFactory&) =
delete;
MockTaggingClientSocketFactory& operator=(
const MockTaggingClientSocketFactory&) = delete;
// ClientSocketFactory overrides.
std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
NetLog* net_log,
const NetLogSource& source) override;
std::unique_ptr<TransportClientSocket> CreateTransportClientSocket(
const AddressList& addresses,
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetworkQualityEstimator* network_quality_estimator,
NetLog* net_log,
const NetLogSource& source) override;
// Both return nullptr until a socket of that kind has been recorded.
MockTaggingStreamSocket* GetLastProducedTCPSocket() const {
return tcp_socket_;
}
MockUDPClientSocket* GetLastProducedUDPSocket() const { return udp_socket_; }
private:
raw_ptr<MockTaggingStreamSocket, AcrossTasksDanglingUntriaged> tcp_socket_ =
nullptr;
raw_ptr<MockUDPClientSocket, AcrossTasksDanglingUntriaged> udp_socket_ =
nullptr;
};
// SOCKS4 test host/port and canned request/reply bytes; all are defined
// outside this header.
extern const char kSOCKS4TestHost[];
extern const int kSOCKS4TestPort;
extern const std::string_view kSOCKS4OkRequestLocalHostPort80;
extern const std::string_view kSOCKS4OkReply;
// SOCKS5 test host/port and canned greeting/request/response bytes.
extern const char kSOCKS5TestHost[];
extern const int kSOCKS5TestPort;
extern const std::string_view kSOCKS5GreetRequest;
extern const std::string_view kSOCKS5GreetResponse;
extern const std::string_view kSOCKS5OkRequest;
extern const std::string_view kSOCKS5OkResponse;
// Byte counters over a sequence of mock reads / writes.
// NOTE(review): presumably they sum the data sizes of the entries -- confirm
// against the implementation.
int64_t CountReadBytes(base::span<const MockRead> reads);
int64_t CountWriteBytes(base::span<const MockWrite> writes);
// Declared only for Android builds.
#if BUILDFLAG(IS_ANDROID)
bool CanGetTaggedBytes();
uint64_t GetTaggedBytes(int32_t expected_tag);
#endif
}
#endif