#include "google_apis/gcm/base/socket_stream.h"
#include <stdint.h>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "base/compiler_specific.h"
#include "base/containers/span.h"
#include "base/functional/bind.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "base/test/task_environment.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "net/base/ip_address.h"
#include "net/base/network_isolation_key.h"
#include "net/log/net_log_source.h"
#include "net/socket/socket_test_util.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "net/url_request/url_request_context.h"
#include "net/url_request/url_request_context_builder.h"
#include "net/url_request/url_request_test_util.h"
#include "services/network/network_context.h"
#include "services/network/network_service.h"
#include "services/network/public/mojom/proxy_resolving_socket.mojom.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/origin.h"
namespace gcm {
namespace {
// Scripted sequences of mock socket reads / writes fed to the socket factory.
using ReadList = std::vector<net::MockRead>;
using WriteList = std::vector<net::MockWrite>;
// Payloads delivered by the mock socket's scripted reads.
constexpr std::string_view kReadData{"read_data"};
constexpr std::string_view kReadData2{"read_alternate_data"};
// Payload pushed through the output stream by the write tests.
constexpr std::string_view kWriteData{"write_data"};
// Test fixture: builds a test URLRequestContext whose client sockets come from
// |socket_factory_|, wraps it in a NetworkContext, opens a TLS
// ProxyResolvingSocket through it, and wraps the socket's data pipes in the
// SocketInputStream / SocketOutputStream under test.
class GCMSocketStreamTest : public testing::Test {
public:
GCMSocketStreamTest();
~GCMSocketStreamTest() override;
// Scripts the next mock socket with |read_list| / |write_list| (plus a
// synchronously-succeeding SSL handshake), connects, and recreates both
// streams from the new socket's data pipes.
void BuildSocket(const ReadList& read_list, const WriteList& write_list);
// Runs the current task queue until it is idle.
void PumpLoop();
// Consumes up to |bytes| bytes from the input stream and returns a view of
// them (pointing into the stream's buffer).
std::string_view DoInputStreamRead(int bytes);
// Copies |write_src| into the output stream, then flushes and waits for the
// flush if it is pending. Returns the number of bytes copied.
size_t DoOutputStreamWrite(std::string_view write_src);
// Like DoOutputStreamWrite() but does not flush.
size_t DoOutputStreamWriteWithoutFlush(std::string_view write_src);
// Refreshes the input stream until |msg_size| bytes are unread, stopping
// early if the stream reports CLOSED.
void WaitForData(int msg_size);
SocketInputStream* input_stream() { return socket_input_stream_.get(); }
SocketOutputStream* output_stream() { return socket_output_stream_.get(); }
// Public: the *Disconnected tests reset() it to drop the socket.
mojo::Remote<network::mojom::ProxyResolvingSocket> mojo_socket_remote_;
// Replaces the output stream, e.g. with one over a test-owned data pipe.
void set_socket_output_stream(std::unique_ptr<SocketOutputStream> stream) {
socket_output_stream_ = std::move(stream);
}
private:
// Creates a TLS ProxyResolvingSocket and stores its data pipe handles in
// |receive_pipe_handle_| / |send_pipe_handle_|.
void OpenConnection();
// Builds |socket_input_stream_| from |receive_pipe_handle_|.
void ResetInputStream();
// Builds |socket_output_stream_| from |send_pipe_handle_|.
void ResetOutputStream();
// NOTE: members are destroyed in reverse declaration order. |socket_factory_|
// and |url_request_context_| are declared before |network_context_|, which is
// given a raw pointer to the context, so the NetworkContext is torn down first.
base::test::TaskEnvironment task_environment_;
// Fixture-owned copies of the scripts passed to BuildSocket(); the data
// provider is constructed from these.
ReadList mock_reads_;
WriteList mock_writes_;
std::unique_ptr<net::StaticSocketDataProvider> data_provider_;
std::unique_ptr<net::SSLSocketDataProvider> ssl_data_provider_;
std::unique_ptr<SocketInputStream> socket_input_stream_;
std::unique_ptr<SocketOutputStream> socket_output_stream_;
// NOTE(review): assigned in the constructor but never read in this file;
// looks unused - confirm before removing.
net::AddressList address_list_;
std::unique_ptr<net::NetworkChangeNotifier> network_change_notifier_;
std::unique_ptr<network::NetworkService> network_service_;
mojo::Remote<network::mojom::NetworkContext> network_context_remote_;
net::MockClientSocketFactory socket_factory_;
std::unique_ptr<net::URLRequestContext> url_request_context_;
std::unique_ptr<network::NetworkContext> network_context_;
mojo::Remote<network::mojom::ProxyResolvingSocketFactory>
mojo_socket_factory_remote_;
// Pipe handles received from the connect callback; moved into the streams by
// ResetInputStream() / ResetOutputStream().
mojo::ScopedDataPipeConsumerHandle receive_pipe_handle_;
mojo::ScopedDataPipeProducerHandle send_pipe_handle_;
};
// Sets up an IO main-thread task environment, a mock NetworkChangeNotifier and
// an in-process NetworkService, then builds a test URLRequestContext that uses
// |socket_factory_| for client sockets and wraps it in a NetworkContext bound
// to |network_context_remote_|.
GCMSocketStreamTest::GCMSocketStreamTest()
    : task_environment_(base::test::TaskEnvironment::MainThreadType::IO),
      address_list_(net::AddressList::CreateFromIPAddress(
          net::IPAddress::IPv4Localhost(), 5228)),
      network_change_notifier_(
          net::NetworkChangeNotifier::CreateMockIfNeeded()),
      network_service_(network::NetworkService::CreateForTesting()) {
  socket_factory_.set_enable_read_if_ready(true);

  auto builder = net::CreateTestURLRequestContextBuilder();
  builder->set_client_socket_factory_for_testing(&socket_factory_);
  url_request_context_ = builder->Build();

  network_context_ = std::make_unique<network::NetworkContext>(
      network_service_.get(),
      network_context_remote_.BindNewPipeAndPassReceiver(),
      url_request_context_.get(), std::vector<std::string>{});
}
// Nothing beyond member destruction (reverse declaration order) is needed.
GCMSocketStreamTest::~GCMSocketStreamTest() = default;
// Registers |read_list| / |write_list| as the script for the next mock socket,
// together with an SSL handshake that succeeds synchronously, then connects and
// recreates both streams over the new socket's data pipes.
void GCMSocketStreamTest::BuildSocket(const ReadList& read_list,
                                      const WriteList& write_list) {
  // Keep fixture-owned copies of the scripts and build the provider from those
  // copies, so the scripts outlive this call.
  mock_reads_ = read_list;
  mock_writes_ = write_list;

  data_provider_ =
      std::make_unique<net::StaticSocketDataProvider>(mock_reads_, mock_writes_);
  ssl_data_provider_ =
      std::make_unique<net::SSLSocketDataProvider>(net::SYNCHRONOUS, net::OK);

  socket_factory_.AddSocketDataProvider(data_provider_.get());
  socket_factory_.AddSSLSocketDataProvider(ssl_data_provider_.get());

  OpenConnection();
  ResetInputStream();
  ResetOutputStream();
}
// Drains every task that is currently runnable on this thread.
void GCMSocketStreamTest::PumpLoop() {
  base::RunLoop().RunUntilIdle();
}
// Pulls chunks out of |socket_input_stream_| with Next() until at least |bytes|
// bytes have been seen (or Next() fails). It checks that each chunk starts
// where the previous one ended, and backs up any surplus so it stays unread.
// Returns a view over the consumed bytes, which may be shorter than |bytes| if
// Next() fails first.
std::string_view GCMSocketStreamTest::DoInputStreamRead(int bytes) {
  const void* start = nullptr;
  const void* chunk = nullptr;
  int chunk_size = 0;
  int consumed = 0;
  do {
    DCHECK(socket_input_stream_->GetState() == SocketInputStream::EMPTY ||
           socket_input_stream_->GetState() == SocketInputStream::READY);
    if (!socket_input_stream_->Next(&chunk, &chunk_size))
      break;
    consumed += chunk_size;
    if (!start) {
      // First chunk: remember where the contiguous region begins.
      start = chunk;
      continue;
    }
    // Every later chunk must extend the region without gaps.
    UNSAFE_TODO(EXPECT_EQ(static_cast<const uint8_t*>(start) + consumed,
                          static_cast<const uint8_t*>(chunk) + chunk_size));
  } while (consumed < bytes);

  const int surplus = consumed - bytes;
  if (surplus > 0) {
    socket_input_stream_->BackUp(surplus);
    consumed = bytes;
  }
  return std::string_view(static_cast<const char*>(start), consumed);
}
size_t GCMSocketStreamTest::DoOutputStreamWrite(std::string_view write_src) {
size_t total_bytes_written = DoOutputStreamWriteWithoutFlush(write_src);
base::RunLoop run_loop;
if (socket_output_stream_->Flush(run_loop.QuitClosure()) ==
net::ERR_IO_PENDING) {
run_loop.Run();
}
return total_bytes_written;
}
size_t GCMSocketStreamTest::DoOutputStreamWriteWithoutFlush(
std::string_view write_src) {
DCHECK_EQ(socket_output_stream_->GetState(), SocketOutputStream::EMPTY);
int total_bytes_written = 0;
void* buffer = nullptr;
int size = 0;
const int bytes = write_src.size();
do {
if (!socket_output_stream_->Next(&buffer, &size))
break;
int bytes_to_write = (size < bytes ? size : bytes);
UNSAFE_TODO(
memcpy(buffer, write_src.data() + total_bytes_written, bytes_to_write));
if (bytes_to_write < size)
socket_output_stream_->BackUp(size - bytes_to_write);
total_bytes_written += bytes_to_write;
} while (total_bytes_written < bytes);
return base::checked_cast<size_t>(total_bytes_written);
}
void GCMSocketStreamTest::WaitForData(int msg_size) {
while (input_stream()->UnreadByteCount() < msg_size) {
base::RunLoop run_loop;
if (input_stream()->Refresh(run_loop.QuitClosure(),
msg_size - input_stream()->UnreadByteCount()) ==
net::ERR_IO_PENDING) {
run_loop.Run();
}
if (input_stream()->GetState() == SocketInputStream::CLOSED)
return;
}
}
void GCMSocketStreamTest::OpenConnection() {
network_context_->CreateProxyResolvingSocketFactory(
mojo_socket_factory_remote_.BindNewPipeAndPassReceiver());
base::RunLoop run_loop;
int net_error = net::ERR_FAILED;
const GURL kDestination("https://example.com");
network::mojom::ProxyResolvingSocketOptionsPtr options =
network::mojom::ProxyResolvingSocketOptions::New();
options->use_tls = true;
const url::Origin kOrigin = url::Origin::Create(kDestination);
mojo_socket_factory_remote_->CreateProxyResolvingSocket(
kDestination,
net::NetworkAnonymizationKey::CreateSameSite(net::SchemefulSite(kOrigin)),
std::move(options),
net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
mojo_socket_remote_.BindNewPipeAndPassReceiver(),
mojo::NullRemote() ,
base::BindLambdaForTesting(
[&](int result, const std::optional<net::IPEndPoint>& local_addr,
const std::optional<net::IPEndPoint>& peer_addr,
mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
mojo::ScopedDataPipeProducerHandle send_pipe_handle) {
net_error = result;
receive_pipe_handle_ = std::move(receive_pipe_handle);
send_pipe_handle_ = std::move(send_pipe_handle);
run_loop.Quit();
}));
run_loop.Run();
PumpLoop();
}
// Replaces the input stream with one that reads from the socket's receive
// pipe. The stored handle is moved out, so it is consumed by this call.
void GCMSocketStreamTest::ResetInputStream() {
  DCHECK(mojo_socket_remote_);
  auto fresh_stream =
      std::make_unique<SocketInputStream>(std::move(receive_pipe_handle_));
  socket_input_stream_ = std::move(fresh_stream);
}
// Replaces the output stream with one that writes into the socket's send pipe.
// The stored handle is moved out, so it is consumed by this call.
void GCMSocketStreamTest::ResetOutputStream() {
  DCHECK(mojo_socket_remote_);
  auto fresh_stream =
      std::make_unique<SocketOutputStream>(std::move(send_pipe_handle_));
  socket_output_stream_ = std::move(fresh_stream);
}
// The whole payload arrives in one synchronous read, followed by a 0-byte
// (net::OK) read.
TEST_F(GCMSocketStreamTest, ReadDataSync) {
  const ReadList reads = {
      net::MockRead(net::SYNCHRONOUS, kReadData),
      net::MockRead(net::ASYNC, net::OK),
  };
  BuildSocket(reads, WriteList());

  WaitForData(kReadData.size());
  ASSERT_EQ(kReadData, DoInputStreamRead(kReadData.size()));
}
// The payload is split across two synchronous reads and must be reassembled.
TEST_F(GCMSocketStreamTest, ReadPartialDataSync) {
  const size_t half = kReadData.size() / 2;
  const ReadList reads = {
      net::MockRead(net::SYNCHRONOUS, kReadData.substr(0, half)),
      net::MockRead(net::SYNCHRONOUS, kReadData.substr(half)),
      net::MockRead(net::SYNCHRONOUS, net::OK),
  };
  BuildSocket(reads, WriteList());

  WaitForData(kReadData.size());
  ASSERT_EQ(kReadData, DoInputStreamRead(kReadData.size()));
}
// Same split payload as ReadPartialDataSync, but both halves arrive
// asynchronously.
TEST_F(GCMSocketStreamTest, ReadAsync) {
  const size_t half = kReadData.size() / 2;
  const ReadList reads = {
      net::MockRead(net::ASYNC, kReadData.substr(0, half)),
      net::MockRead(net::ASYNC, kReadData.substr(half)),
      net::MockRead(net::ASYNC, net::OK),
  };
  BuildSocket(reads, WriteList());

  WaitForData(kReadData.size());
  ASSERT_EQ(kReadData, DoInputStreamRead(kReadData.size()));
}
// Two messages arrive in one read. They must be consumable one after the other.
TEST_F(GCMSocketStreamTest, TwoReadsAtOnce) {
  std::string long_data(kReadData);
  long_data.append(kReadData2);
  const ReadList reads = {
      net::MockRead(net::SYNCHRONOUS, long_data),
      net::MockRead(net::SYNCHRONOUS, net::OK),
  };
  BuildSocket(reads, WriteList());

  WaitForData(kReadData.size());
  ASSERT_EQ(kReadData, DoInputStreamRead(kReadData.size()));

  WaitForData(kReadData2.size());
  ASSERT_EQ(kReadData2, DoInputStreamRead(kReadData2.size()));
}
// Like TwoReadsAtOnce, but calls RebuildBuffer() between the two messages.
TEST_F(GCMSocketStreamTest, TwoReadsAtOnceWithRebuild) {
  std::string long_data(kReadData);
  long_data.append(kReadData2);
  const ReadList reads = {
      net::MockRead(net::SYNCHRONOUS, long_data),
      net::MockRead(net::SYNCHRONOUS, net::OK),
  };
  BuildSocket(reads, WriteList());

  WaitForData(kReadData.size());
  ASSERT_EQ(kReadData, DoInputStreamRead(kReadData.size()));

  input_stream()->RebuildBuffer();

  WaitForData(kReadData2.size());
  ASSERT_EQ(kReadData2, DoInputStreamRead(kReadData2.size()));
}
// A failing socket read closes the input stream. The stream reports
// net::ERR_FAILED rather than the socket's ERR_ABORTED.
TEST_F(GCMSocketStreamTest, ReadError) {
  BuildSocket(ReadList{net::MockRead(net::SYNCHRONOUS, net::ERR_ABORTED)},
              WriteList());

  WaitForData(kReadData.size());
  ASSERT_EQ(SocketInputStream::CLOSED, input_stream()->GetState());
  ASSERT_EQ(net::ERR_FAILED, input_stream()->last_error());
}
// Dropping the socket remote while a read is outstanding closes the input
// stream with net::ERR_FAILED.
TEST_F(GCMSocketStreamTest, ReadDisconnected) {
  // The ERR_IO_PENDING read presumably never completes.
  BuildSocket(
      ReadList{net::MockRead(net::SYNCHRONOUS, net::ERR_IO_PENDING)},
      WriteList());
  mojo_socket_remote_.reset();

  WaitForData(kReadData.size());
  ASSERT_EQ(SocketInputStream::CLOSED, input_stream()->GetState());
  ASSERT_EQ(net::ERR_FAILED, input_stream()->last_error());
}
// The full payload is accepted by a single synchronous mock write.
TEST_F(GCMSocketStreamTest, WriteFull) {
  const ReadList reads = {net::MockRead(net::SYNCHRONOUS, net::ERR_IO_PENDING)};
  const WriteList writes = {net::MockWrite(net::SYNCHRONOUS, kWriteData)};
  BuildSocket(reads, writes);

  ASSERT_EQ(kWriteData.size(), DoOutputStreamWrite(kWriteData));
}
// The socket accepts the payload in two halves. The stream must still deliver
// all of it.
TEST_F(GCMSocketStreamTest, WritePartial) {
  const size_t half = kWriteData.size() / 2;
  const WriteList writes = {
      net::MockWrite(net::SYNCHRONOUS, kWriteData.substr(0, half)),
      net::MockWrite(net::SYNCHRONOUS, kWriteData.substr(half)),
  };
  BuildSocket(ReadList{net::MockRead(net::SYNCHRONOUS, net::ERR_IO_PENDING)},
              writes);

  ASSERT_EQ(kWriteData.size(), DoOutputStreamWrite(kWriteData));
}
// Writes through a SocketOutputStream into a test-owned data pipe that already
// holds a 5-byte prefix. The pipe is created one byte too small for
// prefix + payload, so the flush can only finish after the consumer drains some
// bytes. Checks that the pipe ends up carrying the prefix followed by the whole
// payload.
TEST_F(GCMSocketStreamTest, WritePartialWithLengthChecking) {
std::string prefix_data("xxxxx");
// NOTE(review): duplicates prefix_data.size(); keep the two in sync.
const size_t kPrefixDataSize = 5;
mojo::ScopedDataPipeProducerHandle producer_handle;
mojo::ScopedDataPipeConsumerHandle consumer_handle;
// Capacity (presumably in bytes) is one less than prefix + payload.
ASSERT_EQ(mojo::CreateDataPipe(
kWriteData.size() + prefix_data.size() - 1 ,
producer_handle, consumer_handle),
MOJO_RESULT_OK);
// Pre-fill the pipe with the prefix before the stream under test writes.
size_t bytes_written = 0;
MojoResult r =
producer_handle->WriteData(base::as_byte_span(prefix_data),
MOJO_WRITE_DATA_FLAG_NONE, bytes_written);
ASSERT_EQ(MOJO_RESULT_OK, r);
ASSERT_EQ(prefix_data.size(), bytes_written);
// Hand the producer end to a new stream that replaces the fixture's stream.
auto socket_output_stream =
std::make_unique<SocketOutputStream>(std::move(producer_handle));
set_socket_output_stream(std::move(socket_output_stream));
EXPECT_EQ(kWriteData.size(), DoOutputStreamWriteWithoutFlush(kWriteData));
// Start the flush. Its completion quits |run_loop|, which is only Run() after
// space has been freed below. Flush()'s return value is not checked here.
base::RunLoop run_loop;
output_stream()->Flush(run_loop.QuitClosure());
base::RunLoop().RunUntilIdle();
// Drain the prefix to make room for the rest of the payload.
std::string contents;
std::string buffer(kPrefixDataSize, '\0');
size_t bytes_read = 0;
ASSERT_EQ(MOJO_RESULT_OK,
consumer_handle->ReadData(MOJO_READ_DATA_FLAG_NONE,
base::as_writable_byte_span(buffer),
bytes_read));
ASSERT_EQ(kPrefixDataSize, bytes_read);
contents += buffer.substr(0, bytes_read);
base::RunLoop().RunUntilIdle();
// Wait for the flush callback.
run_loop.Run();
// Destroying the stream presumably closes the producer handle, letting the
// drain loop below end once ReadData stops returning OK.
set_socket_output_stream(nullptr);
// Read the rest in chunks of up to |kPrefixDataSize| bytes. This busy-waits on
// MOJO_RESULT_SHOULD_WAIT and exits on any other non-OK result.
while (true) {
r = consumer_handle->ReadData(MOJO_READ_DATA_FLAG_NONE,
base::as_writable_byte_span(buffer),
bytes_read);
if (r == MOJO_RESULT_SHOULD_WAIT)
continue;
if (r != MOJO_RESULT_OK)
break;
// NOTE(review): always true here because of the break above.
ASSERT_EQ(MOJO_RESULT_OK, r);
contents += buffer.substr(0, bytes_read);
}
// Expect the prefix followed by the complete payload.
std::string expected(prefix_data);
expected.append(kWriteData);
EXPECT_EQ(expected, contents);
}
// NOTE(review): despite its name, this body is identical to WritePartial (the
// full payload is written in two halves). Confirm whether an empty write was
// intended.
TEST_F(GCMSocketStreamTest, WriteNone) {
  const size_t half = kWriteData.size() / 2;
  const WriteList writes = {
      net::MockWrite(net::SYNCHRONOUS, kWriteData.substr(0, half)),
      net::MockWrite(net::SYNCHRONOUS, kWriteData.substr(half)),
  };
  BuildSocket(ReadList{net::MockRead(net::SYNCHRONOUS, net::ERR_IO_PENDING)},
              writes);

  ASSERT_EQ(kWriteData.size(), DoOutputStreamWrite(kWriteData));
}
// A full write followed by a full read on the same socket.
TEST_F(GCMSocketStreamTest, WriteThenRead) {
  const ReadList reads = {
      net::MockRead(net::SYNCHRONOUS, kReadData),
      net::MockRead(net::SYNCHRONOUS, net::OK),
  };
  const WriteList writes = {net::MockWrite(net::SYNCHRONOUS, kWriteData)};
  BuildSocket(reads, writes);

  ASSERT_EQ(kWriteData.size(), DoOutputStreamWrite(kWriteData));

  WaitForData(kReadData.size());
  ASSERT_EQ(kReadData, DoInputStreamRead(kReadData.size()));
}
// A full read followed by a full write on the same socket.
TEST_F(GCMSocketStreamTest, ReadThenWrite) {
  const ReadList reads = {
      net::MockRead(net::SYNCHRONOUS, kReadData),
      net::MockRead(net::SYNCHRONOUS, net::OK),
  };
  const WriteList writes = {net::MockWrite(net::SYNCHRONOUS, kWriteData)};
  BuildSocket(reads, writes);

  WaitForData(kReadData.size());
  ASSERT_EQ(kReadData, DoInputStreamRead(kReadData.size()));

  ASSERT_EQ(kWriteData.size(), DoOutputStreamWrite(kWriteData));
}
// A failing socket write eventually closes the output stream with
// net::ERR_FAILED.
TEST_F(GCMSocketStreamTest, WriteError) {
  BuildSocket(ReadList{net::MockRead(net::SYNCHRONOUS, net::ERR_IO_PENDING)},
              WriteList{net::MockWrite(net::SYNCHRONOUS, net::ERR_ABORTED)});

  // The failure presumably surfaces only after data crosses the pipe, so keep
  // writing until the stream reports CLOSED.
  while (output_stream()->GetState() != SocketOutputStream::CLOSED)
    DoOutputStreamWrite(kWriteData);

  ASSERT_EQ(SocketOutputStream::CLOSED, output_stream()->GetState());
  ASSERT_EQ(net::ERR_FAILED, output_stream()->last_error());
}
// Writing after the socket remote is dropped leaves the output stream CLOSED
// with net::ERR_FAILED.
TEST_F(GCMSocketStreamTest, WriteDisconnected) {
  BuildSocket(
      ReadList{net::MockRead(net::SYNCHRONOUS, net::ERR_IO_PENDING)},
      WriteList());
  mojo_socket_remote_.reset();

  DoOutputStreamWrite(kWriteData);
  ASSERT_EQ(SocketOutputStream::CLOSED, output_stream()->GetState());
  ASSERT_EQ(net::ERR_FAILED, output_stream()->last_error());
}
}
}