#include "remoting/test/it2me_cli_host.h"
#include "base/command_line.h"
#include "base/functional/bind.h"
#include "base/json/json_reader.h"
#include "base/json/json_writer.h"
#include "base/run_loop.h"
#include "base/task/single_thread_task_runner.h"
#include "base/time/time.h"
#include "net/base/network_change_notifier.h"
#include "remoting/base/auto_thread_task_runner.h"
#include "remoting/base/logging.h"
#include "remoting/host/chromoting_host_context.h"
#include "remoting/host/it2me/it2me_constants.h"
#include "remoting/host/it2me/it2me_native_messaging_host.h"
#include "remoting/host/policy_watcher.h"
#include "remoting/test/test_oauth_token_getter.h"
#include "remoting/test/test_token_storage.h"
namespace remoting {
namespace {
// Message type that PostMessageFromNativeHost() recognizes and deliberately
// ignores when it arrives from the native host.
constexpr char kCRDDebugLog[] = "_debug_log";
// Command-line switch checked by ShouldPrintHelp().
constexpr char kSwitchNameHelp[] = "help";
// Command-line switch whose ASCII value is passed to the token storage as the
// username in Start().
constexpr char kSwitchNameUsername[] = "username";
// Command-line switch whose path value is passed to the on-disk token storage
// in Start().
constexpr char kSwitchNameStoragePath[] = "storage-path";
// Builds an It2MeNativeMessagingHost wired to a freshly created
// ChromotingHostContext that runs on |ui_task_runner|.
//
// The policy watcher is created from the context's file task runner and
// management service before the context itself is handed over to the host.
std::unique_ptr<It2MeNativeMessagingHost> CreateNativeMessagingHost(
    scoped_refptr<AutoThreadTaskRunner> ui_task_runner) {
  auto host_context = ChromotingHostContext::Create(ui_task_runner);

  // Must be created while |host_context| is still owned here; it is moved
  // into the host below.
  auto watcher = PolicyWatcher::CreateWithTaskRunner(
      host_context->file_task_runner(), host_context->management_service());

  auto host_factory = std::make_unique<It2MeHostFactory>();

  // NOTE(review): the meaning of the leading boolean is defined by the
  // It2MeNativeMessagingHost constructor, which is not visible in this file.
  return std::make_unique<It2MeNativeMessagingHost>(
      false, std::move(watcher), std::move(host_context),
      std::move(host_factory));
}
}
bool It2MeCliHost::ShouldPrintHelp() {
return base::CommandLine::ForCurrentProcess()->HasSwitch(kSwitchNameHelp);
}
// Writes a one-line usage summary, prefixed with the program path taken from
// the process command line, to stderr.
void It2MeCliHost::PrintHelp() {
  const std::string program_name =
      base::CommandLine::ForCurrentProcess()->GetProgram().AsUTF8Unsafe();
  fprintf(stderr,
          "Usage: %s [--storage-path=<storage-path>] "
          "[--username=<example@gmail.com>]\n ",
          program_name.c_str());
}
// Construction and destruction need no custom logic; members are initialized
// and torn down by their own constructors/destructors.
It2MeCliHost::It2MeCliHost() = default;

It2MeCliHost::~It2MeCliHost() = default;
void It2MeCliHost::Start() {
base::CommandLine* cmd_line = base::CommandLine::ForCurrentProcess();
std::string username = cmd_line->GetSwitchValueASCII(kSwitchNameUsername);
base::FilePath storage_path =
cmd_line->GetSwitchValuePath(kSwitchNameStoragePath);
storage_ = test::TestTokenStorage::OnDisk(username, storage_path);
token_getter_ = std::make_unique<test::TestOAuthTokenGetter>(storage_.get());
base::RunLoop initialize_token_getter_loop;
token_getter_->Initialize(initialize_token_getter_loop.QuitClosure());
initialize_token_getter_loop.Run();
base::RunLoop ui_loop;
ui_task_runner_ = new AutoThreadTaskRunner(
base::SingleThreadTaskRunner::GetCurrentDefault(), ui_loop.QuitClosure());
std::unique_ptr<net::NetworkChangeNotifier> network_change_notifier(
net::NetworkChangeNotifier::CreateIfNeeded());
ui_loop.Run();
}
void It2MeCliHost::PostMessageFromNativeHost(const std::string& message) {
auto message_dict =
base::JSONReader::ReadDict(message, base::JSON_PARSE_CHROMIUM_EXTENSIONS);
if (!message_dict) {
OnProtocolBroken("Message is not a dictionary");
return;
}
std::string* type = message_dict->FindString(kMessageType);
if (!type) {
OnProtocolBroken("Message without type");
return;
}
if (*type == kHelloResponse) {
OnHelloResponse();
} else if (*type == kConnectResponse) {
} else if (*type == kDisconnectResponse) {
OnDisconnectResponse();
} else if (*type == kHostStateChangedMessage) {
std::string* state = message_dict->FindString(kState);
if (!state) {
OnProtocolBroken("No state in message");
return;
}
if (*state == kHostStateReceivedAccessCode) {
OnStateReceivedAccessCode(*message_dict);
} else if (*state == kHostStateConnected) {
OnStateRemoteConnected(*message_dict);
} else if (*state == kHostStateDisconnected) {
OnStateRemoteDisconnected();
} else if (*state == kHostStateError || *state == kHostStateDomainError) {
OnStateError(*state, *message_dict);
} else if (*state == kHostStateStarting ||
*state == kHostStateRequestedAccessCode) {
} else {
LOG(WARNING) << "Unhandled state: " << *state;
}
} else if (*type == kCRDDebugLog) {
} else {
LOG(WARNING) << "Unknown message type: " << *type;
}
}
// Handles the host closing its channel: stops waiting for an access code,
// logs |error_message| and schedules host shutdown.
void It2MeCliHost::CloseChannel(const std::string& error_message) {
  command_awaiting_crd_access_code_ = false;
  LOG(ERROR) << "CRD Host closed channel: " << error_message;
  ShutdownHost();
}
// Stamps |params| with the message |type|, serializes it to JSON and posts a
// task on the current default task runner that delivers it through
// DoSendMessage().
//
// If serialization fails, an empty string is delivered instead.
void It2MeCliHost::SendMessageToHost(const std::string& type,
                                     base::Value::Dict params) {
  params.Set(kMessageType, type);
  std::string serialized = base::WriteJson(params).value_or("");

  // Bound to a weak pointer so the task is dropped if |this| goes away before
  // it runs.
  auto deliver =
      base::BindOnce(&It2MeCliHost::DoSendMessage, weak_factory_.GetWeakPtr(),
                     std::move(serialized));
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, std::move(deliver));
}
// Forwards the serialized message |json| to the native messaging host, if one
// currently exists; otherwise the message is silently dropped.
void It2MeCliHost::DoSendMessage(const std::string& json) {
  if (host_) {
    host_->OnMessage(json);
  }
}
// Handles a malformed or unexpected message from the host: stops waiting for
// an access code, logs |message| and schedules host shutdown.
void It2MeCliHost::OnProtocolBroken(const std::string& message) {
  command_awaiting_crd_access_code_ = false;
  LOG(ERROR) << "Error communicating with CRD Host : " << message;
  ShutdownHost();
}
// Creates and starts the native messaging host, then sends it the hello
// message. The connect parameters (user name and access token) are stored so
// OnHelloResponse() can send them once the host has answered.
//
// Must only be called while no host exists.
//
// NOTE(review): |status| and |scopes| are not inspected here; a failed token
// fetch is not distinguished from a successful one.
void It2MeCliHost::StartCRDHostAndGetCode(OAuthTokenGetter::Status status,
                                          const std::string& user_email,
                                          const std::string& access_token,
                                          const std::string& scopes) {
  DCHECK(!host_);

  // Reset session state for the new connection attempt.
  remote_connected_ = false;
  command_awaiting_crd_access_code_ = true;

  base::Value::Dict pending_params;
  pending_params.Set(kUserName, user_email);
  pending_params.Set(kAccessToken, access_token);
  connect_params_ = std::move(pending_params);

  host_ = CreateNativeMessagingHost(ui_task_runner_);
  host_->Start(this);

  SendMessageToHost(kHelloMessage, base::Value::Dict());
}
// Schedules destruction of the native messaging host, if one exists.
//
// The teardown is posted as a task rather than performed inline, so the
// current call stack unwinds before the host is destroyed.
void It2MeCliHost::ShutdownHost() {
  if (host_) {
    auto teardown = base::BindOnce(&It2MeCliHost::DoShutdownHost,
                                   weak_factory_.GetWeakPtr());
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, std::move(teardown));
  }
}
// Destroys the native messaging host and then releases this object's
// reference to the UI task runner.
void It2MeCliHost::DoShutdownHost() {
  host_ = nullptr;
  ui_task_runner_ = nullptr;
}
void It2MeCliHost::OnHelloResponse() {
SendMessageToHost(kConnectMessage, std::move(connect_params_));
}
// The host acknowledged a disconnect request: schedule its shutdown.
//
// By the time this arrives we expect to be neither waiting for an access code
// nor connected to a remote client; both are asserted in debug builds.
void It2MeCliHost::OnDisconnectResponse() {
  DCHECK(!command_awaiting_crd_access_code_);
  DCHECK(!remote_connected_);

  ShutdownHost();
}
// Handles an error state reported by the host and schedules host shutdown.
//
// |error_state| is the state string from the message (either the generic error
// state or the domain-error state). |message| is the full state-change
// dictionary; for non-domain errors its error code, when present, is used as
// the human-readable description.
//
// The error is logged only if an access code was still being awaited, in which
// case that wait is also cancelled.
void It2MeCliHost::OnStateError(const std::string& error_state,
                                const base::Value::Dict& message) {
  std::string error_message;
  if (error_state == kHostStateDomainError) {
    error_message = "CRD Error : Invalid domain";
  } else if (const std::string* error_code =
                 message.FindString(kErrorMessageCode)) {
    error_message = *error_code;
  } else {
    error_message = "Unknown CRD Error";
  }

  if (command_awaiting_crd_access_code_) {
    command_awaiting_crd_access_code_ = false;
    // Include the derived description; previously it was computed and then
    // discarded, so the host's error code never reached the log.
    LOG(ERROR) << "CRD Error state " << error_state << ": " << error_message;
  }

  ShutdownHost();
}
// Records that a remote client is connected and, when the message names the
// client, logs it.
void It2MeCliHost::OnStateRemoteConnected(const base::Value::Dict& message) {
  remote_connected_ = true;

  if (const std::string* client_name = message.FindString(kClient);
      client_name != nullptr) {
    HOST_LOG << "Remote connection by " << *client_name;
  }
}
// Handles the host reporting the disconnected state.
//
// If a remote client had been connected, clears that flag and asks the host to
// disconnect. A disconnected state without a preceding connection is ignored.
void It2MeCliHost::OnStateRemoteDisconnected() {
  if (remote_connected_) {
    remote_connected_ = false;
    SendMessageToHost(kDisconnectMessage, base::Value::Dict());
  }
}
// Handles the host reporting that an access code is available.
//
// When an access code is still awaited, the code and its lifetime are read
// from |message| and logged; a message missing either field is a protocol
// failure. When no code is awaited (one was already received), the extra code
// is ignored, and if no remote client is connected the host is asked to
// disconnect.
void It2MeCliHost::OnStateReceivedAccessCode(const base::Value::Dict& message) {
  const bool awaiting_code = command_awaiting_crd_access_code_;
  if (!awaiting_code) {
    const bool session_idle = !remote_connected_;
    if (session_idle) {
      SendMessageToHost(kDisconnectMessage, base::Value::Dict());
    }
    return;
  }

  const std::string* access_code = message.FindString(kAccessCode);
  const std::optional<int> lifetime_seconds =
      message.FindInt(kAccessCodeLifetime);
  const bool has_both_fields = access_code && lifetime_seconds.has_value();
  if (!has_both_fields) {
    OnProtocolBroken("Can not obtain access code");
    return;
  }

  command_awaiting_crd_access_code_ = false;

  // The lifetime is interpreted as a number of seconds from now.
  const base::TimeDelta time_to_live = base::Seconds(*lifetime_seconds);
  HOST_LOG << "It2Me access code is generated: " << *access_code;
  HOST_LOG << "Expires at: " << (base::Time::Now() + time_to_live);
}
}