#include "remoting/host/native_messaging/native_messaging_reader.h"
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include "base/compiler_specific.h"
#include "base/files/file.h"
#include "base/functional/bind.h"
#include "base/json/json_reader.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/message_loop/message_pump_type.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/single_thread_task_runner.h"
#include "base/threading/thread.h"
#include "base/values.h"
#include "build/build_config.h"
#if BUILDFLAG(IS_WIN)
#include <windows.h>
#include "base/threading/platform_thread.h"
#include "base/win/scoped_handle.h"
#endif
namespace {

// Type of the length prefix that precedes every JSON message body. The prefix
// is read straight from the stream into a variable of this type, i.e. in the
// host's native byte order.
using MessageLengthType = uint32_t;

// Size in bytes of the length prefix.
constexpr int kMessageHeaderSize = static_cast<int>(sizeof(MessageLengthType));

// Upper bound (1 MiB) on the declared body length; larger messages are
// rejected before any buffer is allocated for them.
constexpr MessageLengthType kMaximumMessageSize = 1024 * 1024;

}  // namespace
namespace remoting {
// Owns the input stream and performs all blocking reads on the reader sequence
// (`read_task_runner_`). Parsed messages and the end-of-stream notification are
// posted to `caller_task_runner_` and delivered through `reader_`.
class NativeMessagingReader::Core {
 public:
  Core(base::File file,
       scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner,
       scoped_refptr<base::SequencedTaskRunner> read_task_runner,
       base::WeakPtr<NativeMessagingReader> reader);

  Core(const Core&) = delete;
  Core& operator=(const Core&) = delete;

  ~Core();

  // Reads and dispatches messages until end-of-stream, a read error, or a
  // malformed message. Must run on `read_task_runner_`.
  void ReadMessage();

 private:
  // Posts NativeMessagingReader::InvokeEofCallback() to the caller sequence.
  void NotifyEof();

  // Stream that length-prefixed JSON messages are read from.
  base::File read_stream_;

  // Destination for results; only used as the bound receiver of tasks posted
  // to `caller_task_runner_`.
  base::WeakPtr<NativeMessagingReader> reader_;

  // Sequence on which the owning NativeMessagingReader's callbacks run.
  scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner_;

  // Sequence on which all blocking reads are performed.
  scoped_refptr<base::SequencedTaskRunner> read_task_runner_;
};
// Takes ownership of `file` and of the task-runner/weak-pointer handles. The
// by-value sink parameters are moved into the members to avoid redundant
// reference-count increments/decrements on the scoped_refptrs.
NativeMessagingReader::Core::Core(
    base::File file,
    scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner,
    scoped_refptr<base::SequencedTaskRunner> read_task_runner,
    base::WeakPtr<NativeMessagingReader> reader)
    : read_stream_(std::move(file)),
      reader_(std::move(reader)),
      caller_task_runner_(std::move(caller_task_runner)),
      read_task_runner_(std::move(read_task_runner)) {}
// No explicit cleanup: the file and the task-runner/weak-pointer handles are
// released by their own destructors. ~NativeMessagingReader() schedules this
// via DeleteSoon() on the reader task runner.
NativeMessagingReader::Core::~Core() {}
void NativeMessagingReader::Core::ReadMessage() {
DCHECK(read_task_runner_->RunsTasksInCurrentSequence());
while (true) {
MessageLengthType message_length;
int read_result = UNSAFE_TODO(read_stream_.ReadAtCurrentPos(
reinterpret_cast<char*>(&message_length), kMessageHeaderSize));
if (read_result != kMessageHeaderSize) {
if (read_result != 0) {
LOG(ERROR) << "Failed to read message header, read returned "
<< read_result;
}
NotifyEof();
return;
}
if (message_length > kMaximumMessageSize) {
LOG(ERROR) << "Message size too large: " << message_length;
NotifyEof();
return;
}
std::string message_json(message_length, '\0');
read_result = UNSAFE_TODO(
read_stream_.ReadAtCurrentPos(std::data(message_json), message_length));
if (read_result != static_cast<int>(message_length)) {
LOG(ERROR) << "Failed to read message body, read returned "
<< read_result;
NotifyEof();
return;
}
std::optional<base::Value> message = base::JSONReader::Read(
message_json, base::JSON_PARSE_CHROMIUM_EXTENSIONS);
if (!message) {
LOG(ERROR) << "Failed to parse JSON message: " << message_json;
NotifyEof();
return;
}
caller_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&NativeMessagingReader::InvokeMessageCallback,
reader_, std::move(*message)));
}
}
void NativeMessagingReader::Core::NotifyEof() {
DCHECK(read_task_runner_->RunsTasksInCurrentSequence());
caller_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&NativeMessagingReader::InvokeEofCallback, reader_));
}
// Starts a dedicated IO-pump thread named "Reader" and creates the Core that
// reads `file` on it. Results are posted back to the task runner that is
// current when this constructor runs.
NativeMessagingReader::NativeMessagingReader(base::File file)
    : reader_thread_("Reader") {
  base::Thread::Options thread_options(base::MessagePumpType::IO, 0);
  reader_thread_.StartWithOptions(std::move(thread_options));
  read_task_runner_ = reader_thread_.task_runner();

  scoped_refptr<base::SingleThreadTaskRunner> caller_runner =
      base::SingleThreadTaskRunner::GetCurrentDefault();
  core_ = std::make_unique<Core>(std::move(file), std::move(caller_runner),
                                 read_task_runner_, weak_factory_.GetWeakPtr());
}
// Hands `core_` to the reader sequence for deletion. The deletion task cannot
// run while Core::ReadMessage() is still looping on that sequence. On Windows
// the reader thread's pending synchronous read is therefore cancelled so the
// loop can end and the deletion can proceed.
NativeMessagingReader::~NativeMessagingReader() {
  read_task_runner_->DeleteSoon(FROM_HERE, core_.release());

#if BUILDFLAG(IS_WIN)
  base::PlatformThreadId thread_id = reader_thread_.GetThreadId();
  // CancelSynchronousIo() requires a handle with THREAD_TERMINATE access.
  base::win::ScopedHandle thread_handle(
      OpenThread(THREAD_TERMINATE, FALSE, thread_id.raw()));
  if (!thread_handle.is_valid()) {
    // Report the real failure instead of calling CancelSynchronousIo() on a
    // null handle and logging a misleading error.
    PLOG(ERROR) << "OpenThread() failed";
  } else if (!CancelSynchronousIo(thread_handle.Get()) &&
             GetLastError() != ERROR_NOT_FOUND) {
    // ERROR_NOT_FOUND only means no synchronous I/O was pending.
    PLOG(ERROR) << "CancelSynchronousIo() failed";
  }
#endif
}
// Stores the client callbacks and starts the blocking read loop on the reader
// task runner. Callbacks are later invoked via tasks posted to the task runner
// that was current when this reader was constructed.
void NativeMessagingReader::Start(const MessageCallback& message_callback,
                                  base::OnceClosure eof_callback) {
  eof_callback_ = std::move(eof_callback);
  message_callback_ = message_callback;

  // Unretained is safe: `core_` is deleted via DeleteSoon() on the same
  // sequenced reader task runner, so that deletion is ordered after this task.
  auto read_loop = base::BindOnce(&NativeMessagingReader::Core::ReadMessage,
                                  base::Unretained(core_.get()));
  read_task_runner_->PostTask(FROM_HERE, std::move(read_loop));
}
// Caller-sequence trampoline: forwards one parsed message to the client.
void NativeMessagingReader::InvokeMessageCallback(base::Value message) {
  const MessageCallback& callback = message_callback_;
  callback.Run(std::move(message));
}
void NativeMessagingReader::InvokeEofCallback() {
std::move(eof_callback_).Run();
}
}