// 910e62b5 created on Jan 15 (commit history)
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "remoting/test/session_authz_playground.h"

#include <cstdio>
#include <cstdlib>
#include <memory>
#include <string_view>
#include <vector>

#include "base/command_line.h"
#include "base/compiler_specific.h"
#include "base/files/file_path.h"
#include "base/functional/bind.h"
#include "base/memory/scoped_refptr.h"
#include "base/run_loop.h"
#include "base/task/bind_post_task.h"
#include "base/task/single_thread_task_runner.h"
#include "base/time/time.h"
#include "remoting/base/certificate_helpers.h"
#include "remoting/base/oauth_token_getter.h"
#include "remoting/base/oauth_token_getter_impl.h"
#include "remoting/base/url_request_context_getter.h"
#include "remoting/host/host_config.h"
#include "remoting/proto/session_authz_service.h"
#include "remoting/test/cli_util.h"

namespace remoting {

namespace {

// OAuth scope requested when creating the host's OAuth credentials.
constexpr char kOAuthScope[] =
    "https://www.googleapis.com/auth/chromoting.me2me.host";

// Reports a failed service call to stderr and terminates the process with exit
// code 1. The report contains |api_name|, the numeric error code, the error
// message and the raw response body carried by |status|.
//
// Marked [[noreturn]] because it always ends in std::exit(); this lets callers
// (and the compiler) treat the code following a call as unreachable.
[[noreturn]] void PrintErrorAndExit(const std::string& api_name,
                                    const HttpStatus& status) {
  UNSAFE_TODO(fprintf(stderr, "Failed to call API: %s, error code: %d\n",
                      api_name.data(), static_cast<int>(status.error_code())));
  UNSAFE_TODO(fprintf(stderr, "%s\n", status.error_message().data()));
  UNSAFE_TODO(fprintf(stderr, "%s\n", status.response_body().data()));
  std::exit(1);
}

}  // namespace

// Builds the URL loader factory owner that all network requests made by this
// playground (service client and OAuth token fetches) are issued through. The
// request context getter is bound to the task runner of the constructing
// thread.
SessionAuthzPlayground::SessionAuthzPlayground() {
  scoped_refptr<URLRequestContextGetter> context_getter =
      base::MakeRefCounted<URLRequestContextGetter>(
          base::SingleThreadTaskRunner::GetCurrentDefault());
  constexpr bool kIsTrusted = true;
  url_loader_factory_owner_ =
      std::make_unique<network::TransitionalURLLoaderFactoryOwner>(
          context_getter, kIsTrusted);
}

SessionAuthzPlayground::~SessionAuthzPlayground() = default;

void SessionAuthzPlayground::Start() {
  auto* cmd_line = base::CommandLine::ForCurrentProcess();
  base::CommandLine::StringVector args = cmd_line->GetArgs();
  base::FilePath host_config_file_path;
  if (args.size() == 1) {
    host_config_file_path = base::FilePath(args[0]);
  }
  if (host_config_file_path.empty()) {
    printf("Usage: %s <path-to-host-config-json>\n",
           cmd_line->GetProgram().MaybeAsASCII().c_str());
    std::exit(1);
  }

  service_client_ = std::make_unique<CorpSessionAuthzServiceClient>(
      url_loader_factory_owner_->GetURLLoaderFactory(),
      CreateClientCertStoreInstance(),
      CreateOAuthTokenGetter(host_config_file_path),
      /* support_id= */ std::string_view());

  run_loop_ = std::make_unique<base::RunLoop>();
  GenerateHostToken();
  run_loop_->Run();
}

// Asks the SessionAuthz service for a host token and prints the returned host
// token and session ID. The session ID is then posted to the current task
// runner to continue the flow with VerifySessionToken().
//
// Exits the process with code 1 if the request fails or no response body is
// received.
void SessionAuthzPlayground::GenerateHostToken() {
  printf("Fetching host token...\n");
  service_client_->GenerateHostToken(
      base::BindOnce(
          [](const HttpStatus& status,
             std::unique_ptr<internal::GenerateHostTokenResponseStruct>
                 response) {
            if (!status.ok()) {
              PrintErrorAndExit("GenerateHostToken", status);
            }
            if (!response) {
              fprintf(stderr, "No response was received.\n");
              // Bail out instead of dereferencing the null |response| below.
              std::exit(1);
            }
            UNSAFE_TODO(
                printf("Host token: %s\n", response->host_token.data()));
            UNSAFE_TODO(
                printf("Session ID: %s\n", response->session_id.data()));
            // |response| is destroyed when this lambda returns, so the session
            // ID can be moved out rather than copied.
            return std::move(response->session_id);
          })
          .Then(base::BindPostTaskToCurrentDefault(
              base::BindOnce(&SessionAuthzPlayground::VerifySessionToken,
                             base::Unretained(this)))));
}

// Reads a client-generated session token from stdin and asks the SessionAuthz
// service to verify it. Prints the reauth token, its lifetime, the session ID
// and the shared secret from the response.
//
// Exits the process with code 1 if:
//  - the request fails,
//  - no response body is received, or
//  - the returned session ID does not match |session_id| (the ID obtained in
//    GenerateHostToken()).
// Otherwise the reauth token is posted to the current task runner to continue
// with ReauthorizeHost().
void SessionAuthzPlayground::VerifySessionToken(const std::string& session_id) {
  printf("Enter session token generated by the client: ");
  std::string session_token = test::ReadString();
  service_client_->VerifySessionToken(
      session_token,
      base::BindOnce(
          [](const std::string& session_id, const HttpStatus& status,
             std::unique_ptr<internal::VerifySessionTokenResponseStruct>
                 response) {
            if (!status.ok()) {
              PrintErrorAndExit("VerifySessionToken", status);
            }
            if (!response) {
              fprintf(stderr, "No response was received.\n");
              // Bail out instead of dereferencing the null |response| below.
              std::exit(1);
            }
            UNSAFE_TODO(printf("Reauth token: %s\n",
                               response->session_reauth_token.data()));
            printf("Reauth token expires in: %fs\n",
                   response->session_reauth_token_lifetime.InSecondsF());
            UNSAFE_TODO(
                printf("Session ID: %s\n", response->session_id.data()));
            UNSAFE_TODO(
                printf("Shared secret: %s\n", response->shared_secret.data()));
            if (session_id != response->session_id) {
              UNSAFE_TODO(fprintf(stderr,
                                  "Unexpected session ID. Expected: %s\n",
                                  session_id.data()));
              std::exit(1);
            }
            // |response| is destroyed when this lambda returns.
            return std::move(response->session_reauth_token);
          },
          session_id)
          // A single post-task wrapper is enough to hop back to the current
          // sequence; wrapping it twice only added a redundant extra post.
          .Then(base::BindPostTaskToCurrentDefault(
              base::BindOnce(&SessionAuthzPlayground::ReauthorizeHost,
                             base::Unretained(this), session_id))));
}

// Asks the user whether to reauthorize the host for |session_id|.
//
// If the user declines, the run loop is quit and nothing further is sent.
// Otherwise |reauth_token| is sent to the SessionAuthz service; the new reauth
// token and its lifetime are printed, and the new token is posted back to this
// method so the user can reauthorize again.
//
// Exits the process with code 1 if the request fails or no response body is
// received.
void SessionAuthzPlayground::ReauthorizeHost(const std::string& session_id,
                                             const std::string& reauth_token) {
  printf("Reauthorize host? [y/N]: ");
  bool should_reauthorize = test::ReadYNBool();
  if (!should_reauthorize) {
    run_loop_->Quit();
    // Without this return, a reauthorization request would still be sent (and
    // the user re-prompted) after the run loop had been asked to quit.
    return;
  }
  service_client_->ReauthorizeHost(
      reauth_token, session_id, base::TimeTicks::Max(),
      base::BindOnce([](const HttpStatus& status,
                        std::unique_ptr<internal::ReauthorizeHostResponseStruct>
                            response) {
        if (!status.ok()) {
          PrintErrorAndExit("ReauthorizeHost", status);
        }
        if (!response) {
          fprintf(stderr, "No response was received.\n");
          // Bail out instead of dereferencing the null |response| below.
          std::exit(1);
        }
        UNSAFE_TODO(printf("Reauth token: %s\n",
                           response->session_reauth_token.data()));
        printf("Reauth token expires in: %fs\n",
               response->session_reauth_token_lifetime.InSecondsF());
        // |response| is destroyed when this lambda returns.
        return std::move(response->session_reauth_token);
      })
          .Then(base::BindPostTaskToCurrentDefault(
              base::BindOnce(&SessionAuthzPlayground::ReauthorizeHost,
                             base::Unretained(this), session_id))));
}

// Builds an OAuth token getter from the host config JSON file at
// |host_config_file_path|.
//
// The config must contain both an OAuth refresh token
// (kOAuthRefreshTokenConfigPath) and a service account email
// (kServiceAccountConfigPath). The resulting credentials are flagged as a
// service account and request kOAuthScope.
//
// Exits the process with code 1 if the file cannot be read or either value is
// missing.
std::unique_ptr<OAuthTokenGetter>
SessionAuthzPlayground::CreateOAuthTokenGetter(
    const base::FilePath& host_config_file_path) {
  const auto config = HostConfigFromJsonFile(host_config_file_path);
  if (!config.has_value()) {
    UNSAFE_TODO(fprintf(stderr, "Cannot read host config file: %s\n",
                        host_config_file_path.MaybeAsASCII().data()));
    std::exit(1);
  }

  const std::string* const oauth_refresh_token =
      config->FindString(kOAuthRefreshTokenConfigPath);
  if (oauth_refresh_token == nullptr) {
    fprintf(stderr, "Cannot find OAuth refresh token from host config file.\n");
    std::exit(1);
  }

  const std::string* const account_email =
      config->FindString(kServiceAccountConfigPath);
  if (account_email == nullptr) {
    fprintf(stderr, "Cannot find service account email.\n");
    std::exit(1);
  }

  std::vector<std::string> scopes = {kOAuthScope};
  constexpr bool kIsServiceAccount = true;
  auto credentials =
      std::make_unique<OAuthTokenGetter::OAuthAuthorizationCredentials>(
          *account_email, *oauth_refresh_token, kIsServiceAccount,
          std::move(scopes));

  // NOTE(review): the trailing `false` is forwarded positionally to
  // OAuthTokenGetterImpl; its meaning is defined by that constructor
  // (presumably an auto-refresh toggle) — confirm against its declaration.
  return std::make_unique<OAuthTokenGetterImpl>(
      std::move(credentials), url_loader_factory_owner_->GetURLLoaderFactory(),
      false);
}

}  // namespace remoting