#ifdef UNSAFE_BUFFERS_BUILD
#pragma allow_unsafe_buffers
#endif
#include "mojo/core/channel.h"
#include <windows.h>
#include <stdint.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <tuple>
#include "base/containers/queue.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/ref_counted.h"
#include "base/message_loop/message_pump_for_io.h"
#include "base/process/process_handle.h"
#include "base/synchronization/lock.h"
#include "base/task/current_thread.h"
#include "base/task/single_thread_task_runner.h"
#include "base/task/task_runner.h"
#include "base/win/scoped_handle.h"
#include "base/win/win_util.h"
namespace mojo {
namespace core {
namespace {
// Minimal FIFO of owned Channel::MessagePtr entries, backed by a
// base::circular_deque. The class does no locking and no emptiness checks of
// its own; both are left to the code that uses it.
class ChannelWinMessageQueue {
 public:
  ChannelWinMessageQueue() = default;
  ~ChannelWinMessageQueue() = default;

  // Takes ownership of |message| and places it at the back of the queue.
  void Append(Channel::MessagePtr message) {
    queue_.push_back(std::move(message));
  }

  // Returns a non-owning pointer to the oldest queued message; the queue keeps
  // ownership. The queue is not checked for emptiness here, so callers are
  // expected to consult IsEmpty() first.
  Channel::Message* GetFirst() const {
    const Channel::MessagePtr& head = queue_.front();
    return head.get();
  }

  // Removes the oldest queued message and hands ownership of it to the caller.
  // As with GetFirst(), emptiness is not checked here.
  Channel::MessagePtr TakeFirst() {
    Channel::MessagePtr head(std::move(queue_.front()));
    queue_.pop_front();
    return head;
  }

  // True when no messages are queued.
  bool IsEmpty() const { return queue_.empty(); }

 private:
  // Oldest message at the front, newest at the back.
  base::circular_deque<Channel::MessagePtr> queue_;
};
// Channel implementation for Windows. Reads and writes are issued as
// overlapped ReadFile()/WriteFile() calls on |handle_|; their completions are
// delivered to OnIOCompleted() after the handle has been registered with the
// current IO thread in StartOnIOThread().
//
// Lifetime: the object holds a reference to itself (|self_|) from construction
// until ShutDownOnIOThread(), and takes one extra reference (AddRef()) for
// every read or write that was successfully issued. That extra reference is
// dropped by the Release() at the end of OnIOCompleted().
class ChannelWin : public Channel,
public base::CurrentThread::DestructionObserver,
public base::MessagePumpForIO::IOHandler {
public:
// Takes the platform handle out of |connection_params| and CHECK-fails if it
// is not valid. |io_task_runner| is the runner onto which Start(),
// ShutDownImpl() and deferred write-error reports are posted.
ChannelWin(Delegate* delegate,
ConnectionParams connection_params,
HandlePolicy handle_policy,
scoped_refptr<base::SingleThreadTaskRunner> io_task_runner)
: Channel(delegate, handle_policy),
base::MessagePumpForIO::IOHandler(FROM_HERE),
is_untrusted_process_(connection_params.is_untrusted_process()),
self_(this),
io_task_runner_(io_task_runner) {
handle_ =
connection_params.TakeEndpoint().TakePlatformHandle().TakeHandle();
CHECK(handle_.is_valid());
}
ChannelWin(const ChannelWin&) = delete;
ChannelWin& operator=(const ChannelWin&) = delete;
// Channel: posts StartOnIOThread() to |io_task_runner_|.
void Start() override {
io_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&ChannelWin::StartOnIOThread, this));
}
// Channel: posts ShutDownOnIOThread() to |io_task_runner_|, so shutdown never
// runs synchronously inside this call.
void ShutDownImpl() override {
io_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&ChannelWin::ShutDownOnIOThread, this));
}
// Channel: queues |message| for transmission and, when possible, starts
// writing it right away. The message is silently dropped if writes are
// already being rejected.
void Write(MessagePtr message) override {
RecordSentMessageMetrics(message->data_num_bytes());
// When the remote process handle is known, every valid attached handle is
// transferred to that process up front, then the handles are put back on the
// message.
if (remote_process().IsValid()) {
std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
for (auto& handle : handles) {
if (handle.handle().is_valid()) {
handle.TransferToProcess(
remote_process().Duplicate(),
is_untrusted_process_ ? PlatformHandleInTransit::kUntrustedTarget
: PlatformHandleInTransit::kTrustedTarget);
}
}
message->SetHandles(std::move(handles));
}
bool write_error = false;
{
base::AutoLock lock(write_lock_);
if (reject_writes_)
return;
// Write immediately only if startup has finished (|delay_writes_| cleared)
// and nothing is queued ahead of this message; otherwise it just waits in
// the queue for StartOnIOThread() or OnWriteDone() to pick it up.
bool write_now = !delay_writes_ && outgoing_messages_.IsEmpty();
outgoing_messages_.Append(std::move(message));
if (write_now && !WriteNoLock(outgoing_messages_.GetFirst()))
reject_writes_ = write_error = true;
}
// The error is reported from a posted task, after |write_lock_| has been
// released, rather than synchronously from inside Write() (presumably so the
// caller is not re-entered).
if (write_error) {
io_task_runner_->PostTask(FROM_HERE,
base::BindOnce(&ChannelWin::OnWriteError, this,
Error::kDisconnected));
}
}
// Channel: makes ShutDownOnIOThread() release |handle_| via Take() instead of
// closing it. DCHECKs that it runs on |io_task_runner_|'s sequence.
void LeakHandle() override {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
leak_handle_ = true;
}
// Channel: decodes |num_handles| HandleEntry records from |extra_header| and
// appends the resulting handles to |handles|. |payload| and |payload_size| are
// not used by this implementation.
//
// Returns false if |num_handles| exceeds the uint16_t maximum, if that many
// entries would not fit in |extra_header_size| bytes, or if any entry holds a
// pseudo handle value. Handles appended before a failure remain in |handles|.
bool GetReadPlatformHandles(const void* payload,
size_t payload_size,
size_t num_handles,
const void* extra_header,
size_t extra_header_size,
std::vector<PlatformHandle>* handles) override {
DCHECK(extra_header);
// Bounding |num_handles| first also keeps the multiplication below from
// overflowing size_t.
if (num_handles > std::numeric_limits<uint16_t>::max())
return false;
using HandleEntry = Channel::Message::HandleEntry;
size_t handles_size = sizeof(HandleEntry) * num_handles;
if (handles_size > extra_header_size)
return false;
handles->reserve(num_handles);
const HandleEntry* extra_header_handles =
reinterpret_cast<const HandleEntry*>(extra_header);
for (size_t i = 0; i < num_handles; i++) {
HANDLE handle_value =
base::win::Uint32ToHandle(extra_header_handles[i].handle);
if (PlatformHandleInTransit::IsPseudoHandle(handle_value))
return false;
// With a known remote process, any value other than INVALID_HANDLE_VALUE is
// taken over from that process via TakeIncomingRemoteHandle(); otherwise the
// received value is used as-is.
if (remote_process().IsValid() && handle_value != INVALID_HANDLE_VALUE) {
handle_value = PlatformHandleInTransit::TakeIncomingRemoteHandle(
handle_value, remote_process().Handle())
.ReleaseHandle();
}
handles->emplace_back(base::win::ScopedHandle(std::move(handle_value)));
}
return true;
}
// Channel: always returns false on this platform; |num_handles| and |handles|
// are not touched.
bool GetReadPlatformHandlesForIpcz(
size_t num_handles,
std::vector<PlatformHandle>& handles) override {
return false;
}
private:
// Private destructor: instances are presumably destroyed only through
// reference counting (see |self_| and the AddRef()/Release() pairs below).
~ChannelWin() override = default;
// Registers this object as a destruction observer and as the IO handler for
// |handle_|, starts any write that was queued before startup, and issues the
// first read. Reports Error::kConnectionFailed and stops if the IO handler
// cannot be registered.
void StartOnIOThread() {
base::CurrentThread::Get()->AddDestructionObserver(this);
if (!base::CurrentIOThread::Get()->RegisterIOHandler(handle_.get(), this)) {
OnError(Error::kConnectionFailed);
return;
}
{
base::AutoLock lock(write_lock_);
// Messages passed to Write() before this point were only queued. Lift the
// delay and start the first one, if any.
// NOTE(review): the result of WriteNextNoLock() is ignored here, so a failure
// to start this first write is not reported from this function.
if (delay_writes_) {
delay_writes_ = false;
WriteNextNoLock();
}
}
// ReadMore() can call OnError() synchronously; hold a reference across the
// call in case that leads to the other references being dropped.
scoped_refptr<ChannelWin> keep_alive(this);
ReadMore(0);
}
// Tears the channel down: stops observing thread destruction, rejects all
// further writes, cancels I/O on |handle_|, then either closes the handle or
// (after LeakHandle()) releases it unclosed, and finally drops the
// self-reference. CHECK-fails if |handle_| is no longer valid, e.g. if this
// were to run a second time.
void ShutDownOnIOThread() {
base::CurrentThread::Get()->RemoveDestructionObserver(this);
{
base::AutoLock lock(write_lock_);
reject_writes_ = true;
}
CHECK(handle_.is_valid());
CancelIo(handle_.get());
if (leak_handle_) {
std::ignore = handle_.Take();
} else {
handle_.Close();
}
// Dropping the self-reference lets the object go away once any references
// still held for issued I/O have been released in OnIOCompleted().
self_ = nullptr;
}
// base::CurrentThread::DestructionObserver: shuts the channel down if that
// has not already happened (|self_| is still set).
void WillDestroyCurrentMessageLoop() override {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
if (self_)
ShutDownOnIOThread();
}
// base::MessagePumpForIO::IOHandler: completion callback shared by reads and
// writes; |context| tells which of the two finished. A failed write marks
// writes as rejected and goes through OnWriteError(); any other failure is
// reported as Error::kDisconnected. Successful completions are forwarded to
// OnReadDone() / OnWriteDone().
void OnIOCompleted(base::MessagePumpForIO::IOContext* context,
DWORD bytes_transfered,
DWORD error) override {
if (error != ERROR_SUCCESS) {
if (context == &write_context_) {
{
base::AutoLock lock(write_lock_);
reject_writes_ = true;
}
OnWriteError(Error::kDisconnected);
} else {
OnError(Error::kDisconnected);
}
} else if (context == &read_context_) {
OnReadDone(static_cast<size_t>(bytes_transfered));
} else {
CHECK(context == &write_context_);
OnWriteDone(static_cast<size_t>(bytes_transfered));
}
// Balances the AddRef() made in ReadMore() or WriteNoLock() when the
// operation was issued.
Release();
}
// Handles a successful read of |bytes_read| bytes: hands the data to
// OnReadComplete() and issues the next read, or reports
// Error::kReceivedMalformedData if OnReadComplete() returns false. A
// zero-byte read is reported as Error::kDisconnected.
void OnReadDone(size_t bytes_read) {
DCHECK(is_read_pending_);
is_read_pending_ = false;
if (bytes_read > 0) {
size_t next_read_size = 0;
if (OnReadComplete(bytes_read, &next_read_size)) {
ReadMore(next_read_size);
} else {
OnError(Error::kReceivedMalformedData);
}
} else if (bytes_read == 0) {
OnError(Error::kDisconnected);
}
}
// Handles a successful write completion. The message at the head of the queue
// is removed and the next queued message, if any, is started. A completion
// whose byte count differs from the message size, or a failure to start the
// next write, rejects further writes and is reported via OnWriteError().
void OnWriteDone(size_t bytes_written) {
// A zero-byte completion is ignored entirely: no state is changed.
if (bytes_written == 0)
return;
bool write_error = false;
{
base::AutoLock lock(write_lock_);
DCHECK(is_write_pending_);
is_write_pending_ = false;
DCHECK(!outgoing_messages_.IsEmpty());
Channel::MessagePtr message = outgoing_messages_.TakeFirst();
if (message->data_num_bytes() != bytes_written)
reject_writes_ = write_error = true;
else if (!WriteNextNoLock())
reject_writes_ = write_error = true;
}
// Reported only after |write_lock_| has been released.
if (write_error)
OnWriteError(Error::kDisconnected);
}
// Issues an overlapped ReadFile() on |handle_| into the buffer returned by
// GetReadBuffer(). On success (including ERROR_IO_PENDING) marks a read as
// pending and takes a reference that OnIOCompleted() releases; otherwise
// reports Error::kDisconnected synchronously.
void ReadMore(size_t next_read_size_hint) {
DCHECK(!is_read_pending_);
// |buffer_capacity| carries the hint in; GetReadBuffer() appears to write the
// usable capacity back through the pointer (see the DCHECK_GT below) --
// confirm against Channel::GetReadBuffer().
size_t buffer_capacity = next_read_size_hint;
char* buffer = GetReadBuffer(&buffer_capacity);
DCHECK_GT(buffer_capacity, 0u);
BOOL ok =
::ReadFile(handle_.get(), buffer, static_cast<DWORD>(buffer_capacity),
NULL, read_context_.GetOverlapped());
if (ok || GetLastError() == ERROR_IO_PENDING) {
is_read_pending_ = true;
AddRef();
} else {
OnError(Error::kDisconnected);
}
}
// Starts an overlapped WriteFile() of |message|'s data on |handle_|. Every
// caller in this class holds |write_lock_|. Returns true if the write was
// issued (including ERROR_IO_PENDING), in which case a reference is taken that
// OnIOCompleted() releases; returns false if WriteFile() failed outright.
bool WriteNoLock(Channel::Message* message) {
// The attached handles are taken off the message and their transit is marked
// complete before the bytes are written.
std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
for (auto& handle : handles)
handle.CompleteTransit();
DCHECK(handle_.is_valid());
// NOTE(review): the size is narrowed to DWORD, so a message whose size does
// not fit in a DWORD would be written only partially; OnWriteDone() treats
// such a byte-count mismatch as a write error.
BOOL ok = WriteFile(handle_.get(), message->data(),
static_cast<DWORD>(message->data_num_bytes()), NULL,
write_context_.GetOverlapped());
if (ok || GetLastError() == ERROR_IO_PENDING) {
is_write_pending_ = true;
AddRef();
return true;
}
return false;
}
// Starts writing the message at the head of the queue. Every caller in this
// class holds |write_lock_|. Returns true if the queue is empty or the write
// was issued; false if writes are being rejected or WriteNoLock() failed.
bool WriteNextNoLock() {
if (outgoing_messages_.IsEmpty()) {
return true;
}
if (reject_writes_) {
return false;
}
return WriteNoLock(outgoing_messages_.GetFirst());
}
// Reports a write failure. DCHECKs that it runs on |io_task_runner_|'s
// sequence and that writes are already being rejected.
void OnWriteError(Error error) {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
DCHECK(reject_writes_);
// A disconnect seen on the write side is not reported while a read is still
// pending; presumably the read's own completion surfaces the disconnection
// after any data already in flight has been received.
if (error == Error::kDisconnected) {
if (is_read_pending_) {
return;
}
}
OnError(error);
}
// Selects kUntrustedTarget over kTrustedTarget when transferring handles in
// Write().
const bool is_untrusted_process_;
// Self-reference set in the constructor and cleared in ShutDownOnIOThread().
scoped_refptr<Channel> self_;
// The handle all reads and writes are issued on.
base::win::ScopedHandle handle_;
const scoped_refptr<base::SingleThreadTaskRunner> io_task_runner_;
// Context passed to ReadFile(); identifies read completions in
// OnIOCompleted().
base::MessagePumpForIO::IOContext read_context_;
// Set in ReadMore() when a read is issued, cleared in OnReadDone().
bool is_read_pending_ = false;
// Held while touching |outgoing_messages_|, |delay_writes_| and
// |is_write_pending_|, and whenever |reject_writes_| is written.
base::Lock write_lock_;
// Context passed to WriteFile(); identifies write completions in
// OnIOCompleted().
base::MessagePumpForIO::IOContext write_context_;
// Messages waiting to be written; the head is the one currently being written
// once a write has been issued.
ChannelWinMessageQueue outgoing_messages_;
// True until StartOnIOThread() has registered the IO handler; while true,
// Write() only queues messages.
bool delay_writes_ = true;
// Once true, Write() drops messages and WriteNextNoLock() refuses to start
// queued ones.
bool reject_writes_ = false;
// Set in WriteNoLock() when a write is issued, cleared in OnWriteDone().
bool is_write_pending_ = false;
// Set by LeakHandle(); see ShutDownOnIOThread().
bool leak_handle_ = false;
};
}
// Creates the Windows implementation of Channel (ChannelWin) for |delegate|.
//
// |connection_params| and |io_task_runner| are sink parameters: both are moved
// into the new ChannelWin. Moving |io_task_runner| instead of copying it avoids
// an extra reference-count increment/decrement pair on the task runner.
// |handle_policy| is forwarded unchanged.
//
// Returns an owning reference to the newly created channel.
scoped_refptr<Channel> Channel::Create(
    Delegate* delegate,
    ConnectionParams connection_params,
    HandlePolicy handle_policy,
    scoped_refptr<base::SingleThreadTaskRunner> io_task_runner) {
  return new ChannelWin(delegate, std::move(connection_params), handle_policy,
                        std::move(io_task_runner));
}
}
}