#include "tools/android/forwarder2/forwarder.h"
#include <utility>
#include "base/check.h"
#include "base/posix/eintr_wrapper.h"
#include "tools/android/forwarder2/socket.h"
namespace forwarder2 {
namespace {

// Capacity, in bytes, of the staging buffer held by each BufferedCopier
// (32 KiB).
constexpr int kBufferSize = 32 * 1024;

}  // namespace
// Copies data in one direction between two sockets through a fixed-size
// buffer:
//
//   socket_from_ ---> [BufferedCopier] ---> socket_to_
//
// Forwarder creates two of these, one per direction, and links them together
// with SetPeer(). While READING, a copier only watches its input socket, so it
// cannot notice on its own that its output socket went away; it relies on its
// peer calling Close() when the peer hits an error.
class Forwarder::BufferedCopier {
 public:
  // READING - buffer is empty; waiting for the input socket to be readable.
  // WRITING - buffer holds unsent data; waiting for the output socket to be
  //           writable.
  // CLOSING - like WRITING, but once the buffer is flushed the copier becomes
  //           CLOSED instead of going back to READING.
  // CLOSED  - terminal state: nothing is registered with select() and no I/O
  //           is performed.
  enum State {
    STATE_READING = 0,
    STATE_WRITING = 1,
    STATE_CLOSING = 2,
    STATE_CLOSED = 3,
  };

  // Data flows from |socket_from| to |socket_to|. Only the pointers are
  // stored; this class never deletes the sockets.
  BufferedCopier(Socket* socket_from, Socket* socket_to)
      : socket_from_(socket_from),
        socket_to_(socket_to),
        bytes_read_(0),
        write_offset_(0),
        peer_(nullptr),
        state_(STATE_READING) {}

  BufferedCopier(const BufferedCopier&) = delete;
  BufferedCopier& operator=(const BufferedCopier&) = delete;

  // Sets the copier to notify (through Close()) when this one is force-closed.
  // Must be called at most once.
  void SetPeer(BufferedCopier* peer) {
    DCHECK(!peer_);
    peer_ = peer;
  }

  bool is_closed() const { return state_ == STATE_CLOSED; }

  // Requests a graceful close. An idle (READING) copier closes immediately; a
  // copier with pending data (WRITING) switches to CLOSING so that the buffer
  // is still flushed first. Does not notify the peer.
  void Close() {
    switch (state_) {
      case STATE_READING:
        state_ = STATE_CLOSED;
        break;
      case STATE_WRITING:
        state_ = STATE_CLOSING;
        break;
      case STATE_CLOSING:
      case STATE_CLOSED:
        // Already closing or closed; nothing to do.
        break;
    }
  }

  // Call before select(). Unless the copier is closed, adds the descriptor it
  // is currently waiting on to |read_fds| (READING) or |write_fds|
  // (WRITING/CLOSING) and raises |*max_fd| if needed. If that descriptor
  // cannot be used with select(), the copier is force-closed instead and the
  // sets are left untouched.
  void PrepareSelect(fd_set* read_fds, fd_set* write_fds, int* max_fd) {
    int fd = -1;
    switch (state_) {
      case STATE_READING:
        DCHECK(bytes_read_ == 0);
        DCHECK(write_offset_ == 0);
        fd = socket_from_->fd();
        if (!IsSelectableFd(fd)) {
          ForceClose();
          return;
        }
        FD_SET(fd, read_fds);
        break;

      case STATE_WRITING:
      case STATE_CLOSING:
        DCHECK(bytes_read_ > 0);
        DCHECK(write_offset_ < bytes_read_);
        fd = socket_to_->fd();
        if (!IsSelectableFd(fd)) {
          ForceClose();
          return;
        }
        FD_SET(fd, write_fds);
        break;

      case STATE_CLOSED:
        return;
    }
    if (fd > *max_fd)
      *max_fd = fd;
  }

  // Call after select() returned, with the resulting descriptor sets. Performs
  // at most one non-blocking read (READING) or write (WRITING/CLOSING) if the
  // relevant descriptor is flagged, and advances the state machine. Any I/O
  // error, end of stream, or unusable descriptor force-closes the copier and
  // notifies its peer.
  void ProcessSelect(const fd_set& read_fds, const fd_set& write_fds) {
    int fd;
    int ret;
    // FD_ISSET is handed a pointer to a mutable fd_set below; work on local
    // copies of the const arguments. Presumably this is needed because some
    // libc builds declare FD_ISSET as taking a non-const fd_set*.
    fd_set read_fds_copy = read_fds;
    fd_set write_fds_copy = write_fds;
    switch (state_) {
      case STATE_READING:
        fd = socket_from_->fd();
        if (!IsSelectableFd(fd)) {
          // Go through ForceClose() rather than assigning STATE_CLOSED
          // directly so that the peer learns about the failure, exactly as
          // PrepareSelect() does for the same condition.
          ForceClose();
          return;
        }
        if (!FD_ISSET(fd, &read_fds_copy))
          return;

        ret = socket_from_->NonBlockingRead(buffer_, kBufferSize);
        if (ret <= 0) {
          ForceClose();
          return;
        }
        bytes_read_ = ret;
        write_offset_ = 0;
        state_ = STATE_WRITING;
        break;

      case STATE_WRITING:
      case STATE_CLOSING:
        fd = socket_to_->fd();
        if (!IsSelectableFd(fd)) {
          ForceClose();
          return;
        }
        if (!FD_ISSET(fd, &write_fds_copy))
          return;

        ret = socket_to_->NonBlockingWrite(buffer_ + write_offset_,
                                           bytes_read_ - write_offset_);
        if (ret <= 0) {
          ForceClose();
          return;
        }

        write_offset_ += ret;
        if (write_offset_ < bytes_read_)
          return;  // Partial write: stay in the current state.

        // Buffer fully flushed.
        write_offset_ = 0;
        bytes_read_ = 0;
        if (state_ == STATE_CLOSING) {
          ForceClose();
          return;
        }
        state_ = STATE_READING;
        break;

      case STATE_CLOSED:
        break;
    }
  }

 private:
  // Returns true if |fd| may be passed to FD_SET()/FD_ISSET(). POSIX leaves
  // those macros undefined for negative descriptors and for descriptors that
  // are greater than or equal to FD_SETSIZE.
  static bool IsSelectableFd(int fd) {
    return fd >= 0 && fd < static_cast<int>(FD_SETSIZE);
  }

  // Closes this copier immediately, discarding any buffered data, and asks the
  // peer (if any) to close as well. The peer link is cleared so the peer is
  // notified at most once.
  void ForceClose() {
    if (peer_) {
      peer_->Close();
      peer_ = nullptr;
    }
    state_ = STATE_CLOSED;
  }

  Socket* socket_from_;   // Input socket (not owned).
  Socket* socket_to_;     // Output socket (not owned).
  int bytes_read_;        // Number of valid bytes in |buffer_|.
  int write_offset_;      // Number of bytes of |buffer_| already written.
  BufferedCopier* peer_;  // Copier to notify on ForceClose(); may be null.
  State state_;
  char buffer_[kBufferSize];
};
// Takes ownership of both sockets and creates one BufferedCopier per
// direction (socket1 -> socket2 and socket2 -> socket1). The two copiers are
// registered as each other's peer.
Forwarder::Forwarder(std::unique_ptr<Socket> socket1,
                     std::unique_ptr<Socket> socket2)
    : socket1_(std::move(socket1)),
      socket2_(std::move(socket2)),
      buffer1_(new BufferedCopier(socket1_.get(), socket2_.get())),
      buffer2_(new BufferedCopier(socket2_.get(), socket1_.get())) {
  BufferedCopier* const first_to_second = buffer1_.get();
  BufferedCopier* const second_to_first = buffer2_.get();
  first_to_second->SetPeer(second_to_first);
  second_to_first->SetPeer(first_to_second);
}
// Only verifies (in DCHECK-enabled builds) that destruction happens on the
// thread accepted by |thread_checker_|; members clean up after themselves.
Forwarder::~Forwarder() {
  DCHECK(thread_checker_.CalledOnValidThread());
}
void Forwarder::RegisterFDs(fd_set* read_fds, fd_set* write_fds, int* max_fd) {
DCHECK(thread_checker_.CalledOnValidThread());
buffer1_->PrepareSelect(read_fds, write_fds, max_fd);
buffer2_->PrepareSelect(read_fds, write_fds, max_fd);
}
// Call after select(): hands the resulting descriptor sets to both copiers so
// each can perform its pending read or write.
void Forwarder::ProcessEvents(const fd_set& read_fds, const fd_set& write_fds) {
  DCHECK(thread_checker_.CalledOnValidThread());
  auto& first_to_second = *buffer1_;
  auto& second_to_first = *buffer2_;
  first_to_second.ProcessSelect(read_fds, write_fds);
  second_to_first.ProcessSelect(read_fds, write_fds);
}
// Returns true only once both directions have reached the closed state.
bool Forwarder::IsClosed() const {
  DCHECK(thread_checker_.CalledOnValidThread());
  if (!buffer1_->is_closed())
    return false;
  return buffer2_->is_closed();
}
void Forwarder::Shutdown() {
DCHECK(thread_checker_.CalledOnValidThread());
buffer1_->Close();
buffer2_->Close();
}
}