#include "net/test/embedded_test_server/http_connect_proxy_handler.h"
#include <stdint.h>
#include <memory>
#include <optional>
#include <set>
#include "base/check_op.h"
#include "base/containers/flat_set.h"
#include "base/containers/span.h"
#include "base/containers/unique_ptr_adapters.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/memory/raw_ptr.h"
#include "base/strings/strcat.h"
#include "base/task/sequenced_task_runner.h"
#include "net/base/address_list.h"
#include "net/base/host_port_pair.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_address.h"
#include "net/base/net_errors.h"
#include "net/http/http_status_code.h"
#include "net/log/net_log_source.h"
#include "net/socket/stream_socket.h"
#include "net/socket/tcp_client_socket.h"
#include "net/test/embedded_test_server/http_connection.h"
#include "net/test/embedded_test_server/http_request.h"
#include "net/test/embedded_test_server/http_response.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/origin.h"
namespace net::test_server {
// Relays bytes in both directions between the client socket a CONNECT request
// arrived on and a new TCP connection to a port on IPv4 localhost. Instances
// are owned by the HttpConnectProxyHandler passed to the constructor and are
// destroyed through HttpConnectProxyHandler::DeleteTunnel().
class HttpConnectProxyHandler::ConnectTunnel {
 public:
  // Default capacity of the buffer used for each direction of the tunnel.
  static constexpr size_t kCapacity = 32 * 1024;

  using DeleteCallback = base::OnceCallback<void(ConnectTunnel*)>;

  // `http_proxy_handler` owns `this`. `socket` is the client connection on
  // which the CONNECT request was received.
  ConnectTunnel(HttpConnectProxyHandler* http_proxy_handler,
                std::unique_ptr<StreamSocket> socket)
      : http_proxy_handler_(http_proxy_handler), socket_(std::move(socket)) {}

  ~ConnectTunnel() = default;

  // Connects to `dest_port` on IPv4 localhost. On success the client is sent a
  // "Connection established" response and data starts flowing both ways; on
  // failure the client is sent a "Bad Gateway" response and the tunnel is then
  // deleted. `this` may already be deleted when Start() returns.
  void Start(uint16_t dest_port) {
    dest_socket_ = std::make_unique<TCPClientSocket>(
        AddressList::CreateFromIPAddress(IPAddress::IPv4Localhost(), dest_port),
        nullptr, nullptr, nullptr, NetLogSource());
    int result = dest_socket_->Connect(base::BindOnce(
        &ConnectTunnel::OnConnectComplete, base::Unretained(this)));
    if (result != ERR_IO_PENDING) {
      OnConnectComplete(result);
    }
  }

 private:
  // Called once the connection to the destination port completes with
  // `result`.
  void OnConnectComplete(int result) {
    if (result != OK) {
      VLOG(1) << "Failed to establish tunnel connection.";
      BasicHttpResponse response;
      response.set_code(HttpStatusCode::HTTP_BAD_GATEWAY);
      response.set_reason("Bad Gateway");
      std::string response_string = response.ToResponseString();
      scoped_refptr<GrowableIOBuffer> buffer =
          base::MakeRefCounted<GrowableIOBuffer>();
      buffer->SetCapacity(response_string.size());
      buffer->span().copy_prefix_from(base::as_byte_span(response_string));
      // A null `src` makes OnWriteComplete() close the tunnel once the whole
      // response has been written, instead of reading more data.
      DoWrite(nullptr, socket_.get(), std::move(buffer),
              response_string.size());
      return;
    }

    BasicHttpResponse response;
    response.set_reason("Connection established");

    // Either StartTunneling() call may complete its first I/O synchronously
    // with an error, which closes the tunnel. Deleting `this` at that point
    // would leave the remainder of this method running on a destroyed object,
    // so Close() only records the request until both directions are started.
    starting_tunnels_ = true;
    // Destination -> client. The CONNECT response is written to the client
    // before anything read from the destination.
    StartTunneling(dest_socket_.get(), socket_.get(),
                   response.ToResponseString());
    // Client -> destination. Pointless if the tunnel is already closing.
    if (!close_pending_) {
      StartTunneling(socket_.get(), dest_socket_.get());
    }
    starting_tunnels_ = false;
    if (close_pending_) {
      // Deletes `this`; nothing may follow.
      Close();
    }
  }

  // Starts relaying data from `src` to `dest`. If `initial_data` is non-empty,
  // it is written to `dest` before anything is read from `src`.
  void StartTunneling(StreamSocket* src,
                      StreamSocket* dest,
                      std::string initial_data = {}) {
    scoped_refptr<GrowableIOBuffer> buffer =
        base::MakeRefCounted<GrowableIOBuffer>();
    buffer->SetCapacity(std::max(kCapacity, initial_data.size()));
    if (!initial_data.empty()) {
      buffer->span().copy_prefix_from(base::as_byte_span(initial_data));
      DoWrite(src, dest, std::move(buffer), initial_data.size());
      return;
    }
    DoRead(src, dest, std::move(buffer));
  }

  // Reads up to `buffer->size()` bytes from `src` into `buffer`.
  void DoRead(StreamSocket* src,
              StreamSocket* dest,
              scoped_refptr<GrowableIOBuffer> buffer) {
    int result =
        src->Read(buffer.get(), buffer->size(),
                  base::BindOnce(&ConnectTunnel::OnReadComplete,
                                 base::Unretained(this), src, dest, buffer));
    if (result == ERR_IO_PENDING) {
      return;
    }
    OnReadComplete(src, dest, std::move(buffer), result);
  }

  // Forwards the `result` bytes read into `buffer` on to `dest`. Closes the
  // tunnel when `result` is 0 or a negative error code.
  void OnReadComplete(StreamSocket* src,
                      StreamSocket* dest,
                      scoped_refptr<GrowableIOBuffer> buffer,
                      int result) {
    CHECK_NE(result, ERR_IO_PENDING);
    if (result <= 0) {
      Close();
      return;
    }
    DoWrite(src, dest, std::move(buffer), result);
  }

  // Writes the next `remaining_bytes` bytes of `buffer` to `dest`. A null
  // `src` marks a one-shot write after which the tunnel is closed.
  void DoWrite(StreamSocket* src,
               StreamSocket* dest,
               scoped_refptr<GrowableIOBuffer> buffer,
               int remaining_bytes) {
    CHECK_GE(remaining_bytes, 0);
    int result = dest->Write(
        buffer.get(), remaining_bytes,
        base::BindOnce(&ConnectTunnel::OnWriteComplete, base::Unretained(this),
                       src, dest, buffer, remaining_bytes),
        TRAFFIC_ANNOTATION_FOR_TESTS);
    if (result == ERR_IO_PENDING) {
      return;
    }
    OnWriteComplete(src, dest, std::move(buffer), remaining_bytes, result);
  }

  // Handles a write of `result` bytes to `dest`. Keeps writing until
  // `remaining_bytes` is drained. It then closes the tunnel (null `src`) or
  // rewinds `buffer` and reads more from `src`. Write errors close the tunnel.
  void OnWriteComplete(StreamSocket* src,
                       StreamSocket* dest,
                       scoped_refptr<GrowableIOBuffer> buffer,
                       int remaining_bytes,
                       int result) {
    CHECK_NE(result, ERR_IO_PENDING);
    if (result < 0) {
      Close();
      return;
    }
    CHECK_LE(result, remaining_bytes);
    buffer->DidConsume(result);
    remaining_bytes -= result;
    if (remaining_bytes > 0) {
      DoWrite(src, dest, std::move(buffer), remaining_bytes);
      return;
    }
    if (!src) {
      Close();
      return;
    }
    buffer->set_offset(0);
    DoRead(src, dest, std::move(buffer));
  }

  // Deletes `this` via the owning handler. While OnConnectComplete() is still
  // starting the two directions, only records the request instead. In that
  // case OnConnectComplete() performs the deletion once it is done.
  void Close() {
    if (starting_tunnels_) {
      close_pending_ = true;
      return;
    }
    http_proxy_handler_->DeleteTunnel(this);
  }

  raw_ptr<HttpConnectProxyHandler> http_proxy_handler_;
  // Client side of the tunnel.
  std::unique_ptr<StreamSocket> socket_;
  // Connection to the destination port on localhost; created by Start().
  std::unique_ptr<TCPClientSocket> dest_socket_;
  // True while OnConnectComplete() is starting the two tunnel directions.
  bool starting_tunnels_ = false;
  // Set by Close() when it is called while `starting_tunnels_` is true.
  bool close_pending_ = false;
};
// Builds the allow-list of host/port pairs this handler will tunnel to.
// HandleProxyRequest() returns false for CONNECT requests naming any other
// destination.
HttpConnectProxyHandler::HttpConnectProxyHandler(
    base::span<const HostPortPair> destinations)
    : proxied_destinations_(destinations.begin(), destinations.end()) {}

// Destroying the handler also destroys every tunnel still held in
// `connect_tunnels_`, together with the sockets those tunnels own.
HttpConnectProxyHandler::~HttpConnectProxyHandler() = default;
// Handles a CONNECT request received on `connection`.
//
// Returns false without touching `connection` when:
// - the requested destination cannot be parsed (this also records a test
//   failure), or
// - the destination is not in `proxied_destinations_`.
//
// Otherwise takes the connection's socket, hands it to a new ConnectTunnel
// owned by this handler, starts the tunnel and returns true.
bool HttpConnectProxyHandler::HandleProxyRequest(HttpConnection& connection,
                                                 const HttpRequest& request) {
  CHECK_EQ(connection.protocol(), HttpConnection::Protocol::kHttp1);
  CHECK_EQ(request.method, METHOD_CONNECT);

  // A CONNECT request target is a "host:port" authority.
  HostPortPair dest = HostPortPair::FromString(request.relative_url);
  if (dest.IsEmpty()) {
    ADD_FAILURE() << "Invalid CONNECT destination: " << request.relative_url;
    return false;
  }
  if (!proxied_destinations_.contains(dest)) {
    return false;
  }

  auto tunnel = std::make_unique<ConnectTunnel>(this, connection.TakeSocket());
  ConnectTunnel* tunnel_ptr = tunnel.get();
  connect_tunnels_.insert(std::move(tunnel));
  // Start() may synchronously delete the tunnel (and erase it from
  // `connect_tunnels_`), so `tunnel_ptr` must not be used after this call.
  tunnel_ptr->Start(dest.port());
  return true;
}
// Destroys `tunnel`, which must currently be owned by this handler. Erasing the
// owning entry from `connect_tunnels_` deletes the tunnel before this returns,
// so a tunnel calling this on itself must not touch its own state afterwards.
void HttpConnectProxyHandler::DeleteTunnel(ConnectTunnel* tunnel) {
  auto owned_entry = connect_tunnels_.find(tunnel);
  CHECK(connect_tunnels_.end() != owned_entry);
  connect_tunnels_.erase(owned_entry);
}
}