#include "chromecast/net/fake_stream_socket.h"
#include <algorithm>
#include <cstring>
#include <vector>
#include "base/check_op.h"
#include "base/compiler_specific.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/location.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/task/sequenced_task_runner.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/socket/next_proto.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
namespace chromecast {
// In-memory byte queue used as the receive side of a FakeStreamSocket. The
// sending socket's Write() appends to it and the owning socket's Read() drains
// it. At most one read may be outstanding; a pending read is completed by a
// task posted to the current sequence once data or end-of-stream arrives.
class SocketBuffer {
 public:
  SocketBuffer() : pending_read_data_(nullptr), pending_read_len_(0) {}

  SocketBuffer(const SocketBuffer&) = delete;
  SocketBuffer& operator=(const SocketBuffer&) = delete;

  ~SocketBuffer() {}

  // Copies up to |len| buffered bytes into |data| and returns the number of
  // bytes copied. When nothing is buffered:
  //  - returns 0 if end-of-stream has been received;
  //  - otherwise remembers |data|/|len|/|callback|, returns
  //    net::ERR_IO_PENDING, and posts |callback| with the result later.
  int Read(char* data, size_t len, net::CompletionOnceCallback callback) {
    DCHECK(data);
    DCHECK_GT(len, 0u);
    DCHECK(callback);
    if (data_.empty()) {
      if (eos_) {
        return 0;
      }
      // Only one read may be outstanding; a second one would overwrite and
      // silently drop the first callback.
      DCHECK(!pending_read_callback_);
      pending_read_data_ = data;
      pending_read_len_ = len;
      pending_read_callback_ = std::move(callback);
      return net::ERR_IO_PENDING;
    }
    return ReadInternal(data, len);
  }

  // Appends |len| bytes from |data|. If a read is pending, it is satisfied
  // from the buffered bytes and its callback is posted with the byte count.
  void Write(const char* data, size_t len) {
    DCHECK(data);
    DCHECK_GT(len, 0u);
    data_.insert(data_.end(), data, UNSAFE_TODO(data + len));
    if (pending_read_callback_) {
      int result = ReadInternal(pending_read_data_, pending_read_len_);
      ClearPendingReadBuffer();
      PostReadCallback(std::move(pending_read_callback_), result);
    }
  }

  // Marks end-of-stream. A read that is pending on an empty buffer is
  // completed with 0.
  void ReceiveEOS() {
    eos_ = true;
    if (pending_read_callback_ && data_.empty()) {
      // Drop our reference to the reader's buffer before completing the read:
      // the reader may free it once the callback runs, and keeping it here
      // would leave a dangling raw_ptr (Write() already clears it likewise).
      ClearPendingReadBuffer();
      PostReadCallback(std::move(pending_read_callback_), 0);
    }
  }

 private:
  // Moves min(|len|, buffered size) bytes into |data| and returns that count.
  int ReadInternal(char* data, size_t len) {
    DCHECK(data);
    DCHECK_GT(len, 0u);
    len = std::min(len, data_.size());
    UNSAFE_TODO(std::memcpy(data, data_.data(), len));
    data_.erase(data_.begin(), data_.begin() + len);
    // |len| never exceeds the requested read length, which FakeStreamSocket
    // receives as an int, so the narrowing conversion cannot overflow.
    return static_cast<int>(len);
  }

  // Forgets the destination buffer recorded for the pending read.
  void ClearPendingReadBuffer() {
    pending_read_data_ = nullptr;
    pending_read_len_ = 0;
  }

  // Runs |callback| with |result| asynchronously; the weak pointer ensures the
  // callback is dropped if this buffer is destroyed first.
  void PostReadCallback(net::CompletionOnceCallback callback, int result) {
    base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(&SocketBuffer::CallReadCallback,
                                  weak_factory_.GetWeakPtr(),
                                  std::move(callback), result));
  }

  void CallReadCallback(net::CompletionOnceCallback callback, int result) {
    std::move(callback).Run(result);
  }

  // Bytes written by the peer that have not been read yet.
  std::vector<char> data_;
  // Destination buffer and length of the outstanding read, if any.
  raw_ptr<char> pending_read_data_;
  size_t pending_read_len_;
  net::CompletionOnceCallback pending_read_callback_;
  // Set once the remote end has disconnected.
  bool eos_ = false;
  base::WeakPtrFactory<SocketBuffer> weak_factory_{this};
};
FakeStreamSocket::FakeStreamSocket() : FakeStreamSocket(net::IPEndPoint()) {}
FakeStreamSocket::FakeStreamSocket(const net::IPEndPoint& local_address)
: local_address_(local_address),
buffer_(std::make_unique<SocketBuffer>()),
peer_(nullptr) {}
// Tells the attached peer, if any, that this end is going away so it clears
// its pointer to us and its receive buffer sees end-of-stream.
FakeStreamSocket::~FakeStreamSocket() {
  if (FakeStreamSocket* const remote = peer_) {
    remote->RemoteDisconnected();
  }
}
// Attaches |peer| (must be non-null) as the socket that receives this socket's
// writes and whose address GetPeerAddress() reports. Only this direction of
// the link is set.
void FakeStreamSocket::SetPeer(FakeStreamSocket* peer) {
  DCHECK(peer != nullptr);
  peer_ = peer;
}
// Called when the peer is destroyed: signal end-of-stream to our reader and
// forget the peer pointer.
void FakeStreamSocket::RemoteDisconnected() {
  buffer_->ReceiveEOS();
  peer_ = nullptr;
}
// Toggles partial-write behavior for subsequent Write() calls.
void FakeStreamSocket::SetBadSenderMode(bool bad_sender) {
  bad_sender_mode_ = bad_sender;
}
// Reads up to |buf_len| bytes written by the peer into |buf|. Returns the byte
// count if data is buffered, 0 once the peer has disconnected and the buffer
// is drained, or net::ERR_IO_PENDING, in which case |callback| is run later.
int FakeStreamSocket::Read(net::IOBuffer* buf,
                           int buf_len,
                           net::CompletionOnceCallback callback) {
  DCHECK(buf);
  // Validate the sign before converting to size_t: a negative length would
  // wrap to a huge value, slip past SocketBuffer's `len > 0` check, and allow
  // copying more bytes than |buf| holds.
  DCHECK_GT(buf_len, 0);
  return buffer_->Read(buf->data(), static_cast<size_t>(buf_len),
                       std::move(callback));
}
// Copies bytes from |buf| straight into the peer's receive buffer and returns
// how many were delivered, or net::ERR_SOCKET_NOT_CONNECTED without a peer.
// Always completes synchronously, so the callback and traffic annotation are
// unused. In bad-sender mode only min(buf_len, buf_len / 2 + 1) bytes are
// delivered, presumably to exercise callers' partial-write handling.
int FakeStreamSocket::Write(
    net::IOBuffer* buf,
    int buf_len,
    net::CompletionOnceCallback /* callback */,
    const net::NetworkTrafficAnnotationTag& /* traffic_annotation */) {
  DCHECK(buf);
  if (!peer_) {
    return net::ERR_SOCKET_NOT_CONNECTED;
  }
  const int sent =
      bad_sender_mode_ ? std::min(buf_len, buf_len / 2 + 1) : buf_len;
  peer_->buffer_->Write(buf->data(), sent);
  return sent;
}
// The fake buffer is unbounded; the requested size is accepted and ignored.
int FakeStreamSocket::SetReceiveBufferSize(int32_t /* size */) {
  return net::OK;
}
// Writes go straight to the peer; the requested size is accepted and ignored.
int FakeStreamSocket::SetSendBufferSize(int32_t /* size */) {
  return net::OK;
}
// Succeeds synchronously; pairing is done via SetPeer(), so the callback is
// never used.
int FakeStreamSocket::Connect(net::CompletionOnceCallback /* callback */) {
  return net::OK;
}
void FakeStreamSocket::Disconnect() {
  // Intentionally a no-op: the peer link is only torn down on destruction.
}
// Unconditionally reports a connected socket, independent of whether a peer
// is attached.
bool FakeStreamSocket::IsConnected() const {
  constexpr bool kAlwaysConnected = true;
  return kAlwaysConnected;
}
// Never reports the socket as idle.
bool FakeStreamSocket::IsConnectedAndIdle() const {
  constexpr bool kIdle = false;
  return kIdle;
}
// Writes the attached peer's local address into |address| (must be non-null).
// Returns net::ERR_SOCKET_NOT_CONNECTED, leaving |address| untouched, when no
// peer is attached.
int FakeStreamSocket::GetPeerAddress(net::IPEndPoint* address) const {
  DCHECK(address);
  if (peer_) {
    *address = peer_->local_address_;
    return net::OK;
  }
  return net::ERR_SOCKET_NOT_CONNECTED;
}
// Writes the address supplied at construction into |address| (must be
// non-null). Always succeeds.
int FakeStreamSocket::GetLocalAddress(net::IPEndPoint* address) const {
  DCHECK(address);
  net::IPEndPoint& out = *address;
  out = local_address_;
  return net::OK;
}
// Exposes the socket's NetLogWithSource member.
const net::NetLogWithSource& FakeStreamSocket::NetLog() const {
  const net::NetLogWithSource& log = net_log_;
  return log;
}
// Usage is not tracked; always reports the socket as unused.
bool FakeStreamSocket::WasEverUsed() const {
  constexpr bool kUsed = false;
  return kUsed;
}
// No protocol negotiation takes place on the fake socket.
net::NextProto FakeStreamSocket::GetNegotiatedProtocol() const {
  constexpr net::NextProto kNegotiated = net::NextProto::kProtoUnknown;
  return kNegotiated;
}
// The fake socket carries no TLS state; |ssl_info| is left untouched.
bool FakeStreamSocket::GetSSLInfo(net::SSLInfo* /* ssl_info */) {
  constexpr bool kHasSSLInfo = false;
  return kHasSSLInfo;
}
// Received-byte accounting is not implemented; always reports zero.
int64_t FakeStreamSocket::GetTotalReceivedBytes() const {
  constexpr int64_t kNoBytes = 0;
  return kNoBytes;
}
void FakeStreamSocket::ApplySocketTag(const net::SocketTag& tag) {}
}