#include "net/test/embedded_test_server/connection_tracker.h"
#include "base/containers/contains.h"
#include "base/run_loop.h"
#include "base/task/single_thread_task_runner.h"
#include "net/test/embedded_test_server/embedded_test_server.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace {

// Writes the peer port of |connection| to |*port| and returns true. Returns
// false, leaving |*port| untouched, when GetPeerAddress() does not succeed.
bool GetPort(const net::StreamSocket& connection, uint16_t* port) {
  net::IPEndPoint peer;
  if (connection.GetPeerAddress(&peer) != net::OK)
    return false;
  *port = peer.port();
  return true;
}

}  // namespace
namespace net::test_server {
// Installs |connection_listener_| on |test_server| so that accept/read events
// are forwarded (via posted tasks) to this tracker.
// NOTE(review): |test_server| keeps a pointer to |connection_listener_|;
// presumably this tracker must outlive the server's use of it -- confirm.
ConnectionTracker::ConnectionTracker(EmbeddedTestServer* test_server)
    : connection_listener_(/*tracker=*/this) {
  EmbeddedTestServer* const server = test_server;
  server->SetConnectionListener(&connection_listener_);
}

ConnectionTracker::~ConnectionTracker() = default;
// Marks the socket on |port| as accepted and bumps the accepted counter, then
// gives CheckAccepted() a chance to release a pending
// WaitForAcceptedConnections() call.
void ConnectionTracker::AcceptedSocketWithPort(uint16_t port) {
  sockets_[port] = SocketStatus::kAccepted;
  ++num_connected_sockets_;
  CheckAccepted();
}
void ConnectionTracker::ReadFromSocketWithPort(uint16_t port) {
EXPECT_TRUE(base::Contains(sockets_, port));
if (sockets_[port] == SocketStatus::kAccepted)
num_read_sockets_++;
sockets_[port] = SocketStatus::kReadFrom;
if (read_loop_) {
read_loop_->Quit();
read_loop_ = nullptr;
}
}
size_t ConnectionTracker::GetAcceptedSocketCount() const {
return num_connected_sockets_;
}
size_t ConnectionTracker::GetReadSocketCount() const {
return num_read_sockets_;
}
void ConnectionTracker::WaitUntilConnectionRead() {
base::RunLoop run_loop;
read_loop_ = &run_loop;
read_loop_->Run();
}
void ConnectionTracker::WaitForAcceptedConnections(size_t num_connections) {
DCHECK(!num_accepted_connections_loop_);
DCHECK_GT(num_connections, 0u);
base::RunLoop run_loop;
EXPECT_GE(num_connections, num_connected_sockets_);
num_accepted_connections_loop_ = &run_loop;
num_accepted_connections_needed_ = num_connections;
CheckAccepted();
run_loop.Run();
EXPECT_EQ(num_connections, num_connected_sockets_);
}
// If a WaitForAcceptedConnections() loop is active and the accepted count
// equals its target, quits that loop and clears the wait state.
void ConnectionTracker::CheckAccepted() {
  // A nonzero target must only exist while a wait loop is registered.
  DCHECK(num_accepted_connections_loop_ ||
         num_accepted_connections_needed_ == 0);
  if (num_accepted_connections_loop_ &&
      num_connected_sockets_ == num_accepted_connections_needed_) {
    num_accepted_connections_loop_->Quit();
    num_accepted_connections_needed_ = 0;
    num_accepted_connections_loop_ = nullptr;
  }
}
// Zeroes both counters and forgets every tracked socket.
void ConnectionTracker::ResetCounts() {
  num_connected_sockets_ = 0;
  num_read_sockets_ = 0;
  sockets_.clear();
}
// Remembers |tracker| and the current thread's default task runner; every
// notification for |tracker| is posted to that runner.
ConnectionTracker::ConnectionListener::ConnectionListener(
    ConnectionTracker* tracker)
    : task_runner_(base::SingleThreadTaskRunner::GetCurrentDefault()),
      tracker_(tracker) {}

ConnectionTracker::ConnectionListener::~ConnectionListener() = default;
// Called with each newly accepted |connection|. Posts
// ConnectionTracker::AcceptedSocketWithPort() with the peer port to
// |task_runner_|, then returns the socket unchanged. Sockets whose peer
// address cannot be determined are passed through without being reported.
std::unique_ptr<net::StreamSocket>
ConnectionTracker::ConnectionListener::AcceptedSocket(
    std::unique_ptr<net::StreamSocket> connection) {
  uint16_t peer_port = 0;
  if (!GetPort(*connection, &peer_port))
    return connection;
  // NOTE(review): base::Unretained assumes |tracker_| outlives the posted
  // task; confirm with the owning test fixture.
  task_runner_->PostTask(
      FROM_HERE, base::BindOnce(&ConnectionTracker::AcceptedSocketWithPort,
                                base::Unretained(tracker_), peer_port));
  return connection;
}
// Called after a read on |connection| with result |rv|. When |rv| is positive
// and the peer port is available, posts
// ConnectionTracker::ReadFromSocketWithPort() to |task_runner_|. Results of
// zero or below (no data; presumably also error codes) are ignored.
void ConnectionTracker::ConnectionListener::ReadFromSocket(
    const net::StreamSocket& connection,
    int rv) {
  uint16_t peer_port = 0;
  if (rv > 0 && GetPort(connection, &peer_port)) {
    task_runner_->PostTask(
        FROM_HERE, base::BindOnce(&ConnectionTracker::ReadFromSocketWithPort,
                                  base::Unretained(tracker_), peer_port));
  }
}
}