#include "remoting/host/setup/me2me_native_messaging_host.h"
#include <stddef.h>
#include <array>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include "base/compiler_specific.h"
#include "base/containers/span.h"
#include "base/functional/bind.h"
#include "base/json/json_reader.h"
#include "base/json/json_writer.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/numerics/safe_conversions.h"
#include "base/run_loop.h"
#include "base/strings/stringize_macros.h"
#include "base/test/task_environment.h"
#include "base/values.h"
#include "google_apis/gaia/gaia_oauth_client.h"
#include "net/base/file_stream.h"
#include "net/base/network_interfaces.h"
#include "remoting/base/auto_thread_task_runner.h"
#include "remoting/base/mock_oauth_client.h"
#include "remoting/host/chromoting_host_context.h"
#include "remoting/host/native_messaging/log_message_handler.h"
#include "remoting/host/native_messaging/native_messaging_pipe.h"
#include "remoting/host/native_messaging/pipe_messaging_channel.h"
#include "remoting/host/pin_hash.h"
#include "remoting/host/setup/test_util.h"
#include "remoting/protocol/pairing_registry.h"
#include "remoting/protocol/protocol_mock_objects.h"
#include "services/network/test/test_shared_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace {
using remoting::protocol::MockPairingRegistryDelegate;
using remoting::protocol::PairingRegistry;
using remoting::protocol::SynchronousPairingRegistry;
using ::testing::Optional;
#ifndef VERSION
#error VERSION must be defined
#endif

// Checks that |response| is a "helloResponse" whose "version" equals the
// build-time VERSION macro (stringized).
void VerifyHelloResponse(const base::Value::Dict& response) {
  const std::string* response_type = response.FindString("type");
  ASSERT_TRUE(response_type);
  EXPECT_EQ("helloResponse", *response_type);

  const std::string* reported_version = response.FindString("version");
  ASSERT_TRUE(reported_version);
  EXPECT_EQ(STRINGIZE(VERSION), *reported_version);
}
// Checks that |response| is a "getHostNameResponse" whose "hostname" matches
// what net::GetHostName() reports for this machine.
void VerifyGetHostNameResponse(const base::Value::Dict& response) {
  const std::string* response_type = response.FindString("type");
  ASSERT_TRUE(response_type);
  EXPECT_EQ("getHostNameResponse", *response_type);

  const std::string* reported_host = response.FindString("hostname");
  ASSERT_TRUE(reported_host);
  const std::string expected_host = net::GetHostName();
  EXPECT_EQ(expected_host, *reported_host);
}
// Checks that |response| is a "getPinHashResponse" whose "hash" equals
// MakeHostPinHash() for host "my_host" and PIN "1234" (the values sent by the
// All test).
void VerifyGetPinHashResponse(const base::Value::Dict& response) {
  const std::string* response_type = response.FindString("type");
  ASSERT_TRUE(response_type);
  EXPECT_EQ("getPinHashResponse", *response_type);

  const std::string* pin_hash = response.FindString("hash");
  ASSERT_TRUE(pin_hash);
  const std::string expected_hash =
      remoting::MakeHostPinHash("my_host", "1234");
  EXPECT_EQ(expected_hash, *pin_hash);
}
// Checks that |response| is a "generateKeyPairResponse" that carries string
// "privateKey" and "publicKey" fields. Their contents are not inspected.
void VerifyGenerateKeyPairResponse(const base::Value::Dict& response) {
  const std::string* response_type = response.FindString("type");
  ASSERT_TRUE(response_type);
  EXPECT_EQ("generateKeyPairResponse", *response_type);

  EXPECT_NE(nullptr, response.FindString("privateKey"));
  EXPECT_NE(nullptr, response.FindString("publicKey"));
}
// Checks that |response| is a "getDaemonConfigResponse" whose "config" is an
// empty dictionary (the fake delegate's GetConfig() returns an empty Dict).
void VerifyGetDaemonConfigResponse(const base::Value::Dict& response) {
  const std::string* response_type = response.FindString("type");
  ASSERT_TRUE(response_type);
  EXPECT_EQ("getDaemonConfigResponse", *response_type);

  const base::Value::Dict* daemon_config = response.FindDict("config");
  ASSERT_TRUE(daemon_config);
  EXPECT_TRUE(daemon_config->empty());
}
// Checks that |response| is a "getUsageStatsConsentResponse" whose
// "supported", "allowed" and "setByPolicy" flags are all present and true.
void VerifyGetUsageStatsConsentResponse(const base::Value::Dict& response) {
  const std::string* response_type = response.FindString("type");
  ASSERT_TRUE(response_type);
  EXPECT_EQ("getUsageStatsConsentResponse", *response_type);

  for (const char* flag_name : {"supported", "allowed", "setByPolicy"}) {
    SCOPED_TRACE(flag_name);
    EXPECT_THAT(response.FindBool(flag_name), Optional(true));
  }
}
// Checks that |response| is a "stopDaemonResponse" reporting result "OK".
void VerifyStopDaemonResponse(const base::Value::Dict& response) {
  const std::string* response_type = response.FindString("type");
  ASSERT_TRUE(response_type);
  EXPECT_EQ("stopDaemonResponse", *response_type);

  const std::string* outcome = response.FindString("result");
  ASSERT_TRUE(outcome);
  EXPECT_EQ("OK", *outcome);
}
// Checks that |response| is a "getDaemonStateResponse" reporting state
// "STARTED" (the fake delegate always returns STATE_STARTED).
void VerifyGetDaemonStateResponse(const base::Value::Dict& response) {
  const std::string* response_type = response.FindString("type");
  ASSERT_TRUE(response_type);
  EXPECT_EQ("getDaemonStateResponse", *response_type);

  const std::string* daemon_state = response.FindString("state");
  ASSERT_TRUE(daemon_state);
  EXPECT_EQ("STARTED", *daemon_state);
}
// Checks that |response| is an "updateDaemonConfigResponse" reporting result
// "OK".
void VerifyUpdateDaemonConfigResponse(const base::Value::Dict& response) {
  const std::string* response_type = response.FindString("type");
  ASSERT_TRUE(response_type);
  EXPECT_EQ("updateDaemonConfigResponse", *response_type);

  const std::string* outcome = response.FindString("result");
  ASSERT_TRUE(outcome);
  EXPECT_EQ("OK", *outcome);
}
// Checks that |response| is a "startDaemonResponse" reporting result "OK".
void VerifyStartDaemonResponse(const base::Value::Dict& response) {
  const std::string* response_type = response.FindString("type");
  ASSERT_TRUE(response_type);
  EXPECT_EQ("startDaemonResponse", *response_type);

  const std::string* outcome = response.FindString("result");
  ASSERT_TRUE(outcome);
  EXPECT_EQ("OK", *outcome);
}
// Checks that |response| is a "getCredentialsFromAuthCodeResponse" carrying
// the e-mail and refresh token that the test's MockOAuthClient was built with.
void VerifyGetCredentialsFromAuthCodeResponse(
    const base::Value::Dict& response) {
  const std::string* response_type = response.FindString("type");
  ASSERT_TRUE(response_type);
  EXPECT_EQ("getCredentialsFromAuthCodeResponse", *response_type);

  const std::string* user_email = response.FindString("userEmail");
  ASSERT_TRUE(user_email);
  EXPECT_EQ("fake_user_email", *user_email);

  const std::string* refresh_token = response.FindString("refreshToken");
  ASSERT_TRUE(refresh_token);
  EXPECT_EQ("fake_refresh_token", *refresh_token);
}
}
namespace remoting {
class MockDaemonControllerDelegate : public DaemonController::Delegate {
public:
MockDaemonControllerDelegate();
MockDaemonControllerDelegate(const MockDaemonControllerDelegate&) = delete;
MockDaemonControllerDelegate& operator=(const MockDaemonControllerDelegate&) =
delete;
~MockDaemonControllerDelegate() override;
DaemonController::State GetState() override;
std::optional<base::Value::Dict> GetConfig() override;
void CheckPermission(bool it2me,
DaemonController::BoolCallback callback) override;
void SetConfigAndStart(base::Value::Dict config,
bool consent,
DaemonController::CompletionCallback done) override;
void UpdateConfig(base::Value::Dict config,
DaemonController::CompletionCallback done) override;
void Stop(DaemonController::CompletionCallback done) override;
DaemonController::UsageStatsConsent GetUsageStatsConsent() override;
};
MockDaemonControllerDelegate::MockDaemonControllerDelegate() = default;
MockDaemonControllerDelegate::~MockDaemonControllerDelegate() = default;
DaemonController::State MockDaemonControllerDelegate::GetState() {
return DaemonController::STATE_STARTED;
}
std::optional<base::Value::Dict> MockDaemonControllerDelegate::GetConfig() {
return base::Value::Dict();
}
void MockDaemonControllerDelegate::CheckPermission(
bool it2me,
DaemonController::BoolCallback callback) {
std::move(callback).Run(true);
}
void MockDaemonControllerDelegate::SetConfigAndStart(
base::Value::Dict config,
bool consent,
DaemonController::CompletionCallback done) {
if (consent && config.Find("start")) {
std::move(done).Run(DaemonController::RESULT_OK);
} else {
std::move(done).Run(DaemonController::RESULT_FAILED);
}
}
void MockDaemonControllerDelegate::UpdateConfig(
base::Value::Dict config,
DaemonController::CompletionCallback done) {
if (config.Find("update")) {
std::move(done).Run(DaemonController::RESULT_OK);
} else {
std::move(done).Run(DaemonController::RESULT_FAILED);
}
}
void MockDaemonControllerDelegate::Stop(
DaemonController::CompletionCallback done) {
std::move(done).Run(DaemonController::RESULT_OK);
}
DaemonController::UsageStatsConsent
MockDaemonControllerDelegate::GetUsageStatsConsent() {
DaemonController::UsageStatsConsent consent;
consent.supported = true;
consent.allowed = true;
consent.set_by_policy = true;
return consent;
}
// Fixture that runs a Me2MeNativeMessagingHost on a dedicated thread and talks
// to it through two pipes. Each message is a uint32_t length prefix followed
// by that many bytes of JSON. Requests go in through |input_write_file_| and
// responses come back through |output_read_file_|.
class Me2MeNativeMessagingHostTest : public testing::Test {
 public:
  Me2MeNativeMessagingHostTest();
  Me2MeNativeMessagingHostTest(const Me2MeNativeMessagingHostTest&) = delete;
  Me2MeNativeMessagingHostTest& operator=(const Me2MeNativeMessagingHostTest&) =
      delete;
  ~Me2MeNativeMessagingHostTest() override;
  void SetUp() override;
  void TearDown() override;
  // Reads the next message from the host, skipping debug-log messages. Returns
  // nullopt on EOF, on a short read, or if the payload is not a JSON dict.
  std::optional<base::Value::Dict> ReadMessageFromOutputPipe();
  // Sends |message| to the host as a length prefix followed by its JSON.
  void WriteMessageToInputPipe(const base::ValueView& message);
  // Sends |message| between two "hello" requests and expects exactly one
  // hello response followed by no further messages.
  void TestBadRequest(const base::Value& message);
 protected:
  // Non-owning. Set in StartHost() to the delegate that is then owned by the
  // DaemonController created there.
  raw_ptr<MockDaemonControllerDelegate, AcrossTasksDanglingUntriaged>
      daemon_controller_delegate_;
 private:
  void StartHost();
  void StopHost();
  void ExitTest();
  // NOTE(review): members are destroyed in reverse declaration order, so
  // confirm teardown ordering (e.g. |host_thread_| relative to objects used on
  // it) before reordering these.
  // Test-side end of the pipe the host reads requests from.
  base::File input_write_file_;
  // Test-side end of the pipe the host writes responses to.
  base::File output_read_file_;
  std::unique_ptr<base::test::TaskEnvironment> task_environment_;
  // Quit by StartHost() after startup, and by ExitTest() after shutdown.
  std::unique_ptr<base::RunLoop> test_run_loop_;
  std::unique_ptr<base::Thread> host_thread_;
  // NOTE(review): not referenced anywhere in this file; appears unused.
  std::unique_ptr<base::RunLoop> host_run_loop_;
  // Only created on ChromeOS (see SetUp()); null on other platforms.
  scoped_refptr<network::TestSharedURLLoaderFactory> test_url_loader_factory_;
  // Runner for |host_thread_|, created in SetUp() with ExitTest() as its exit
  // callback. Released in StopHost().
  scoped_refptr<AutoThreadTaskRunner> host_task_runner_;
  // Created in StartHost() and reset in StopHost(), both on the host thread.
  std::unique_ptr<NativeMessagingPipe> native_messaging_pipe_;
};
Me2MeNativeMessagingHostTest::Me2MeNativeMessagingHostTest() = default;
Me2MeNativeMessagingHostTest::~Me2MeNativeMessagingHostTest() = default;
void Me2MeNativeMessagingHostTest::SetUp() {
base::File input_read_file;
base::File output_write_file;
ASSERT_TRUE(MakePipe(&input_read_file, &input_write_file_));
ASSERT_TRUE(MakePipe(&output_read_file_, &output_write_file));
task_environment_ = std::make_unique<base::test::TaskEnvironment>();
test_run_loop_ = std::make_unique<base::RunLoop>();
host_thread_ = std::make_unique<base::Thread>("host_thread");
host_thread_->Start();
host_task_runner_ = new AutoThreadTaskRunner(
host_thread_->task_runner(),
base::BindOnce(&Me2MeNativeMessagingHostTest::ExitTest,
base::Unretained(this)));
#if BUILDFLAG(IS_CHROMEOS)
test_url_loader_factory_ = new network::TestSharedURLLoaderFactory();
#endif
host_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&Me2MeNativeMessagingHostTest::StartHost,
base::Unretained(this)));
test_run_loop_->Run();
}
void Me2MeNativeMessagingHostTest::StartHost() {
DCHECK(host_task_runner_->RunsTasksInCurrentSequence());
base::File input_read_file;
base::File output_write_file;
ASSERT_TRUE(MakePipe(&input_read_file, &input_write_file_));
ASSERT_TRUE(MakePipe(&output_read_file_, &output_write_file));
daemon_controller_delegate_ = new MockDaemonControllerDelegate();
scoped_refptr<DaemonController> daemon_controller(new DaemonController(
base::WrapUnique(daemon_controller_delegate_.get())));
scoped_refptr<PairingRegistry> pairing_registry =
new SynchronousPairingRegistry(
base::WrapUnique(new MockPairingRegistryDelegate()));
native_messaging_pipe_ = std::make_unique<NativeMessagingPipe>();
std::unique_ptr<extensions::NativeMessagingChannel> channel(
new PipeMessagingChannel(std::move(input_read_file),
std::move(output_write_file)));
std::unique_ptr<OAuthClient> oauth_client(
new MockOAuthClient("fake_user_email", "fake_refresh_token"));
std::unique_ptr<ChromotingHostContext> context =
ChromotingHostContext::CreateForTesting(
new remoting::AutoThreadTaskRunner(
host_task_runner_,
base::BindOnce(&Me2MeNativeMessagingHostTest::StopHost,
base::Unretained(this))),
test_url_loader_factory_);
std::unique_ptr<remoting::Me2MeNativeMessagingHost> host(
new Me2MeNativeMessagingHost(false, 0, std::move(context),
daemon_controller, pairing_registry,
std::move(oauth_client)));
host->Start(native_messaging_pipe_.get());
native_messaging_pipe_->Start(std::move(host), std::move(channel));
test_run_loop_->Quit();
}
// Runs on the host thread. It is the exit callback of the AutoThreadTaskRunner
// handed to ChromotingHostContext in StartHost(). It destroys the messaging
// pipe, drains pending tasks, and then drops the fixture's reference to
// |host_task_runner_|.
void Me2MeNativeMessagingHostTest::StopHost() {
  DCHECK(host_task_runner_->RunsTasksInCurrentSequence());
  native_messaging_pipe_.reset();
  // Let tasks posted while tearing down the pipe run before releasing the
  // runner.
  base::RunLoop().RunUntilIdle();
  // NOTE(review): presumably this releases the last reference, so the runner's
  // exit callback (ExitTest(), bound in SetUp()) fires. Confirm.
  host_task_runner_ = nullptr;
}
// Quits |test_run_loop_|. The quit must happen on the main thread, so a call
// arriving on any other thread re-posts itself there first.
void Me2MeNativeMessagingHostTest::ExitTest() {
  auto main_runner = task_environment_->GetMainThreadTaskRunner();
  if (main_runner->RunsTasksInCurrentSequence()) {
    test_run_loop_->Quit();
  } else {
    main_runner->PostTask(
        FROM_HERE, base::BindOnce(&Me2MeNativeMessagingHostTest::ExitTest,
                                  base::Unretained(this)));
  }
}
// Shuts the host down, waits for shutdown to complete, and checks that no
// unread messages are left in the output pipe.
void Me2MeNativeMessagingHostTest::TearDown() {
  // NOTE(review): closing the write end presumably gives the host's reader an
  // EOF, which starts host shutdown. Confirm against PipeMessagingChannel.
  input_write_file_.Close();
  // Fresh loop for shutdown; ExitTest() quits it.
  test_run_loop_ = std::make_unique<base::RunLoop>();
  test_run_loop_->Run();
  // Every response should already have been consumed by the test.
  std::optional<base::Value::Dict> response = ReadMessageFromOutputPipe();
  EXPECT_FALSE(response);
  output_read_file_.Close();
}
// Reads messages from the host until one that is not a debug-log message
// arrives, and returns it.
//
// Each message is a uint32_t length, in the same native byte representation
// WriteMessageToInputPipe() uses, followed by that many bytes of JSON.
//
// Returns nullopt if the header or payload is short (e.g. EOF) or if the
// payload does not parse as a JSON dictionary.
std::optional<base::Value::Dict>
Me2MeNativeMessagingHostTest::ReadMessageFromOutputPipe() {
  for (;;) {
    uint32_t payload_size;
    const int header_bytes = UNSAFE_TODO(output_read_file_.ReadAtCurrentPos(
        reinterpret_cast<char*>(&payload_size), sizeof(payload_size)));
    if (header_bytes != sizeof(payload_size)) {
      return std::nullopt;
    }

    std::string payload(payload_size, '\0');
    const int payload_bytes = UNSAFE_TODO(
        output_read_file_.ReadAtCurrentPos(std::data(payload), payload_size));
    if (payload_bytes != static_cast<int>(payload_size)) {
      return std::nullopt;
    }

    std::optional<base::Value::Dict> parsed = base::JSONReader::ReadDict(
        payload, base::JSON_PARSE_CHROMIUM_EXTENSIONS);
    if (!parsed) {
      return std::nullopt;
    }

    // Debug-log messages are skipped; anything else goes back to the caller.
    const std::string* message_type = parsed->FindString("type");
    const bool is_debug_log =
        message_type &&
        *message_type == LogMessageHandler::kDebugMessageTypeName;
    if (!is_debug_log) {
      return std::move(*parsed);
    }
  }
}
// Sends |message| to the host as a uint32_t length (native byte
// representation) followed by the JSON payload. A value that fails to
// serialize is sent as an empty payload. Write results are deliberately not
// checked.
void Me2MeNativeMessagingHostTest::WriteMessageToInputPipe(
    const base::ValueView& message) {
  const std::string payload = base::WriteJson(message).value_or("");
  uint32_t payload_size = base::checked_cast<uint32_t>(payload.length());

  input_write_file_.WriteAtCurrentPos(base::byte_span_from_ref(payload_size));
  input_write_file_.WriteAtCurrentPos(base::as_byte_span(payload));
}
// Sends |message| sandwiched between two "hello" requests. The test expects
// exactly one hello response, with no response to |message| and none to the
// trailing hello either: the next read must yield nothing.
void Me2MeNativeMessagingHostTest::TestBadRequest(const base::Value& message) {
  base::Value::Dict hello_request;
  hello_request.Set("type", "hello");

  WriteMessageToInputPipe(hello_request);
  WriteMessageToInputPipe(message);
  WriteMessageToInputPipe(hello_request);

  std::optional<base::Value::Dict> first_reply = ReadMessageFromOutputPipe();
  ASSERT_TRUE(first_reply);
  VerifyHelloResponse(*first_reply);

  EXPECT_FALSE(ReadMessageFromOutputPipe());
}
// Sends one request of every supported type, each tagged with a distinct
// integer "id". Each response is then checked by the verifier at index "id",
// so responses are matched by id rather than by arrival order.
TEST_F(Me2MeNativeMessagingHostTest, All) {
  int next_id = 0;
  base::Value::Dict message;
  message.Set("id", next_id++);
  message.Set("type", "hello");
  WriteMessageToInputPipe(message);
  message.Set("id", next_id++);
  message.Set("type", "getHostName");
  WriteMessageToInputPipe(message);
  message.Set("id", next_id++);
  message.Set("type", "getPinHash");
  message.Set("hostId", "my_host");
  message.Set("pin", "1234");
  WriteMessageToInputPipe(message);
  // Drop "hostId" and "pin" so they are not sent with later requests.
  message.clear();
  message.Set("id", next_id++);
  message.Set("type", "generateKeyPair");
  WriteMessageToInputPipe(message);
  message.Set("id", next_id++);
  message.Set("type", "getDaemonConfig");
  WriteMessageToInputPipe(message);
  message.Set("id", next_id++);
  message.Set("type", "getUsageStatsConsent");
  WriteMessageToInputPipe(message);
  message.Set("id", next_id++);
  message.Set("type", "stopDaemon");
  WriteMessageToInputPipe(message);
  message.Set("id", next_id++);
  message.Set("type", "getDaemonState");
  WriteMessageToInputPipe(message);
  // The fake delegate accepts an update only if the config has an "update"
  // key.
  base::Value::Dict config;
  config.Set("update", true);
  message.Set("config", config.Clone());
  message.Set("id", next_id++);
  message.Set("type", "updateDaemonConfig");
  WriteMessageToInputPipe(message);
  // The fake delegate accepts a start only with consent and a "start" key.
  config.clear();
  config.Set("start", true);
  message.Set("config", config.Clone());
  message.Set("consent", true);
  message.Set("id", next_id++);
  message.Set("type", "startDaemon");
  WriteMessageToInputPipe(message);
  // "config" and "consent" from the previous request are still set on
  // |message| here.
  message.Set("id", next_id++);
  message.Set("type", "getCredentialsFromAuthCode");
  message.Set("authorizationCode", "fake_auth_code");
  WriteMessageToInputPipe(message);
  // Indexed by request id; this must stay in the same order as the requests
  // above.
  auto verify_routines = std::to_array<void (*)(const base::Value::Dict&)>({
      &VerifyHelloResponse,
      &VerifyGetHostNameResponse,
      &VerifyGetPinHashResponse,
      &VerifyGenerateKeyPairResponse,
      &VerifyGetDaemonConfigResponse,
      &VerifyGetUsageStatsConsentResponse,
      &VerifyStopDaemonResponse,
      &VerifyGetDaemonStateResponse,
      &VerifyUpdateDaemonConfigResponse,
      &VerifyStartDaemonResponse,
      &VerifyGetCredentialsFromAuthCodeResponse,
  });
  ASSERT_EQ(std::size(verify_routines), static_cast<size_t>(next_id));
  for (int i = 0; i < next_id; ++i) {
    std::optional<base::Value::Dict> response = ReadMessageFromOutputPipe();
    ASSERT_TRUE(response);
    std::optional<int> id = response->FindInt("id");
    ASSERT_TRUE(id);
    ASSERT_TRUE(0 <= *id && *id < next_id);
    // A null entry means this id was already answered once.
    ASSERT_TRUE(verify_routines[*id]);
    verify_routines[*id](std::move(*response));
    verify_routines[*id] = nullptr;
  }
}
// Verifies that a response carries the request's string "id" when one was
// supplied, and carries no string "id" when the request had none.
TEST_F(Me2MeNativeMessagingHostTest, Id) {
  base::Value::Dict message;
  message.Set("type", "hello");
  WriteMessageToInputPipe(message);
  message.Set("id", "42");
  WriteMessageToInputPipe(message);

  // Use ASSERT (not EXPECT) before each dereference. EXPECT_* does not stop
  // the test, so a missing response or id would otherwise dereference an
  // empty optional or a null pointer instead of failing cleanly.
  std::optional<base::Value::Dict> response = ReadMessageFromOutputPipe();
  ASSERT_TRUE(response);
  const std::string* value = response->FindString("id");
  EXPECT_FALSE(value);

  response = ReadMessageFromOutputPipe();
  ASSERT_TRUE(response);
  value = response->FindString("id");
  ASSERT_TRUE(value);
  EXPECT_EQ("42", *value);
}
// A top-level JSON list (rather than a dictionary) is a bad request.
TEST_F(Me2MeNativeMessagingHostTest, WrongFormat) {
  TestBadRequest(base::Value(base::Value::List()));
}
// A dictionary with no "type" key is a bad request.
TEST_F(Me2MeNativeMessagingHostTest, MissingType) {
  TestBadRequest(base::Value(base::Value::Dict()));
}
// An unrecognized "type" is a bad request.
TEST_F(Me2MeNativeMessagingHostTest, InvalidType) {
  base::Value request(base::Value::Type::DICT);
  request.GetDict().Set("type", "xxx");
  TestBadRequest(request);
}
// getPinHash without "hostId" is a bad request.
TEST_F(Me2MeNativeMessagingHostTest, GetPinHashNoHostId) {
  base::Value request(base::Value::Type::DICT);
  base::Value::Dict& fields = request.GetDict();
  fields.Set("type", "getPinHash");
  fields.Set("pin", "1234");
  TestBadRequest(request);
}
// getPinHash without "pin" is a bad request.
TEST_F(Me2MeNativeMessagingHostTest, GetPinHashNoPin) {
  base::Value request(base::Value::Type::DICT);
  base::Value::Dict& fields = request.GetDict();
  fields.Set("type", "getPinHash");
  fields.Set("hostId", "my_host");
  TestBadRequest(request);
}
// updateDaemonConfig whose "config" is a string, not a dictionary, is a bad
// request.
TEST_F(Me2MeNativeMessagingHostTest, UpdateDaemonConfigInvalidConfig) {
  base::Value request(base::Value::Type::DICT);
  base::Value::Dict& fields = request.GetDict();
  fields.Set("type", "updateDaemonConfig");
  fields.Set("config", "xxx");
  TestBadRequest(request);
}
// startDaemon whose "config" is a string, not a dictionary, is a bad request
// even with consent given.
TEST_F(Me2MeNativeMessagingHostTest, StartDaemonInvalidConfig) {
  base::Value request(base::Value::Type::DICT);
  base::Value::Dict& fields = request.GetDict();
  fields.Set("type", "startDaemon");
  fields.Set("config", "xxx");
  fields.Set("consent", true);
  TestBadRequest(request);
}
// startDaemon without a "consent" field is a bad request.
TEST_F(Me2MeNativeMessagingHostTest, StartDaemonNoConsent) {
  base::Value request(base::Value::Type::DICT);
  base::Value::Dict& fields = request.GetDict();
  fields.Set("type", "startDaemon");
  fields.Set("config", base::Value::Dict());
  TestBadRequest(request);
}
// getCredentialsFromAuthCode without "authorizationCode" is a bad request.
TEST_F(Me2MeNativeMessagingHostTest, GetCredentialsFromAuthCodeNoAuthCode) {
  base::Value request(base::Value::Type::DICT);
  request.GetDict().Set("type", "getCredentialsFromAuthCode");
  TestBadRequest(request);
}
}