910e62b5创建于 1月15日历史提交
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "remoting/base/cloud_session_authz_service_client_factory.h"

#include <limits.h>

#include <memory>
#include <string>
#include <string_view>
#include <utility>

#include "base/memory/weak_ptr.h"
#include "base/time/time.h"
#include "remoting/base/cloud_service_client.h"
#include "remoting/base/instance_identity_token_getter.h"
#include "remoting/base/oauth_token_getter_impl.h"
#include "remoting/base/session_policies.h"
#include "remoting/proto/google/internal/remoting/cloud/v1alpha/duration.pb.h"
#include "remoting/proto/google/internal/remoting/cloud/v1alpha/session_authz_service.pb.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"

namespace remoting {

namespace {

using Duration = google::internal::remoting::cloud::v1alpha::Duration;
using GenerateHostTokenResponse =
    google::internal::remoting::cloud::v1alpha::GenerateHostTokenResponse;
using ReauthorizeHostResponse =
    google::internal::remoting::cloud::v1alpha::ReauthorizeHostResponse;
using VerifySessionTokenResponse =
    google::internal::remoting::cloud::v1alpha::VerifySessionTokenResponse;

base::TimeDelta fromProtoDuration(const Duration& duration) {
  return base::Seconds(duration.seconds()) +
         base::Nanoseconds(duration.nanos());
}

class CloudSessionAuthzServiceClient : public SessionAuthzServiceClient {
 public:
  CloudSessionAuthzServiceClient(
      OAuthTokenGetter* oauth_token_getter,
      InstanceIdentityTokenGetter* instance_identity_token_getter,
      scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory);

  CloudSessionAuthzServiceClient(const CloudSessionAuthzServiceClient&) =
      delete;
  CloudSessionAuthzServiceClient& operator=(
      const CloudSessionAuthzServiceClient&) = delete;

  ~CloudSessionAuthzServiceClient() override;

  // SessionAuthzServiceClient implementation.
  void GenerateHostToken(GenerateHostTokenCallback callback) override;
  void VerifySessionToken(std::string_view session_token,
                          VerifySessionTokenCallback callback) override;
  void ReauthorizeHost(std::string_view session_reauth_token,
                       std::string_view session_id,
                       base::TimeTicks token_expire_time,
                       ReauthorizeHostCallback callback) override;

 private:
  // Overloads used to create callbacks for |instance_identity_token_getter_|.
  void GenerateHostTokenWithIdToken(GenerateHostTokenCallback callback,
                                    std::string_view instance_identity_token);
  void VerifySessionTokenWithIdToken(std::string session_token,
                                     VerifySessionTokenCallback callback,
                                     std::string_view instance_identity_token);
  void ReauthorizeHostWithIdToken(std::string session_reauth_token,
                                  std::string session_id,
                                  base::TimeTicks token_expire_time,
                                  ReauthorizeHostCallback callback,
                                  std::string_view instance_identity_token);

  void OnGenerateHostTokenResponse(
      GenerateHostTokenCallback callback,
      const HttpStatus& status,
      std::unique_ptr<GenerateHostTokenResponse> response);
  void OnVerifySessionTokenResponse(
      VerifySessionTokenCallback callback,
      const HttpStatus& status,
      std::unique_ptr<VerifySessionTokenResponse> response);
  void OnReauthorizeHostResponse(
      ReauthorizeHostCallback callback,
      const HttpStatus& status,
      std::unique_ptr<ReauthorizeHostResponse> response);

  std::unique_ptr<CloudServiceClient> client_;
  const raw_ptr<InstanceIdentityTokenGetter> instance_identity_token_getter_;

  base::WeakPtrFactory<CloudSessionAuthzServiceClient> weak_factory_{this};
};

CloudSessionAuthzServiceClient::CloudSessionAuthzServiceClient(
    OAuthTokenGetter* oauth_token_getter,
    InstanceIdentityTokenGetter* instance_identity_token_getter,
    scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
    : client_(CloudServiceClient::CreateForChromotingRobotAccount(
          oauth_token_getter,
          url_loader_factory)),
      instance_identity_token_getter_(instance_identity_token_getter) {}

CloudSessionAuthzServiceClient::~CloudSessionAuthzServiceClient() = default;

void CloudSessionAuthzServiceClient::GenerateHostToken(
    GenerateHostTokenCallback callback) {
  instance_identity_token_getter_->RetrieveToken(base::BindOnce(
      &CloudSessionAuthzServiceClient::GenerateHostTokenWithIdToken,
      weak_factory_.GetWeakPtr(), std::move(callback)));
}

void CloudSessionAuthzServiceClient::GenerateHostTokenWithIdToken(
    GenerateHostTokenCallback callback,
    std::string_view instance_identity_token) {
  client_->GenerateHostToken(
      instance_identity_token,
      base::BindOnce(
          &CloudSessionAuthzServiceClient::OnGenerateHostTokenResponse,
          weak_factory_.GetWeakPtr(), std::move(callback)));
}

void CloudSessionAuthzServiceClient::VerifySessionToken(
    std::string_view session_token,
    VerifySessionTokenCallback callback) {
  instance_identity_token_getter_->RetrieveToken(base::BindOnce(
      &CloudSessionAuthzServiceClient::VerifySessionTokenWithIdToken,
      weak_factory_.GetWeakPtr(), std::string(session_token),
      std::move(callback)));
}

void CloudSessionAuthzServiceClient::VerifySessionTokenWithIdToken(
    std::string session_token,
    VerifySessionTokenCallback callback,
    std::string_view instance_identity_token) {
  client_->VerifySessionToken(
      session_token, instance_identity_token,
      base::BindOnce(
          &CloudSessionAuthzServiceClient::OnVerifySessionTokenResponse,
          weak_factory_.GetWeakPtr(), std::move(callback)));
}

void CloudSessionAuthzServiceClient::ReauthorizeHost(
    std::string_view session_reauth_token,
    std::string_view session_id,
    base::TimeTicks token_expire_time,
    ReauthorizeHostCallback callback) {
  instance_identity_token_getter_->RetrieveToken(base::BindOnce(
      &CloudSessionAuthzServiceClient::ReauthorizeHostWithIdToken,
      weak_factory_.GetWeakPtr(), std::string(session_reauth_token),
      std::string(session_id), token_expire_time, std::move(callback)));
}

void CloudSessionAuthzServiceClient::ReauthorizeHostWithIdToken(
    std::string session_reauth_token,
    std::string session_id,
    base::TimeTicks token_expire_time,
    ReauthorizeHostCallback callback,
    std::string_view instance_identity_token) {
  client_->ReauthorizeHost(
      session_reauth_token, session_id, instance_identity_token,
      GetReauthRetryPolicy(token_expire_time),
      base::BindOnce(&CloudSessionAuthzServiceClient::OnReauthorizeHostResponse,
                     weak_factory_.GetWeakPtr(), std::move(callback)));
}

void CloudSessionAuthzServiceClient::OnGenerateHostTokenResponse(
    GenerateHostTokenCallback callback,
    const HttpStatus& status,
    std::unique_ptr<GenerateHostTokenResponse> response) {
  auto response_struct =
      std::make_unique<internal::GenerateHostTokenResponseStruct>();
  response_struct->host_token = response->host_token();
  response_struct->session_id = response->session_id();
  std::move(callback).Run(status, std::move(response_struct));
}

void CloudSessionAuthzServiceClient::OnVerifySessionTokenResponse(
    VerifySessionTokenCallback callback,
    const HttpStatus& status,
    std::unique_ptr<VerifySessionTokenResponse> response) {
  auto response_struct =
      std::make_unique<internal::VerifySessionTokenResponseStruct>();
  response_struct->session_id = response->session_id();
  response_struct->shared_secret = response->shared_secret();
  response_struct->session_reauth_token = response->session_reauth_token();
  auto token_lifetime = response->session_reauth_token_lifetime();
  response_struct->session_reauth_token_lifetime =
      fromProtoDuration(response->session_reauth_token_lifetime());
  if (response->has_session_policies()) {
    SessionPolicies session_policies;
    if (response->session_policies().has_allow_file_transfer()) {
      session_policies.allow_file_transfer =
          response->session_policies().allow_file_transfer();
    }
    if (response->session_policies().has_allow_relayed_connections()) {
      session_policies.allow_relayed_connections =
          response->session_policies().allow_relayed_connections();
    }
    if (response->session_policies().has_allow_stun_connections()) {
      session_policies.allow_stun_connections =
          response->session_policies().allow_stun_connections();
    }
    if (response->session_policies().has_allow_uri_forwarding()) {
      session_policies.allow_uri_forwarding =
          response->session_policies().allow_uri_forwarding();
    }
    if (response->session_policies().has_allow_web_authn_forwarding()) {
      session_policies.allow_webauthn_forwarding =
          response->session_policies().allow_web_authn_forwarding();
    }
    if (response->session_policies().has_clipboard_size_bytes()) {
      session_policies.clipboard_size_bytes =
          response->session_policies().clipboard_size_bytes();
    }
    if (response->session_policies().has_curtain_required()) {
      session_policies.curtain_required =
          response->session_policies().curtain_required();
    }
    if (response->session_policies().has_host_udp_port_range()) {
      const auto& proto_port_range =
          response->session_policies().host_udp_port_range();
      int min_port =
          proto_port_range.has_start() ? proto_port_range.start() : 1;
      int max_port =
          proto_port_range.has_end() ? proto_port_range.end() : USHRT_MAX;
      if (min_port < 1 || min_port > max_port || max_port > USHRT_MAX) {
        LOG(ERROR) << "Invalid port range: [" << min_port << ", " << max_port
                   << "]";
      } else {
        session_policies.host_udp_port_range.min_port = min_port;
        session_policies.host_udp_port_range.max_port = max_port;
      }
    }
    if (response->session_policies().has_maximum_session_duration()) {
      auto maximum_session_duration = base::Seconds(
          response->session_policies().maximum_session_duration().seconds());
      session_policies.maximum_session_duration =
          std::max(maximum_session_duration,
                   SessionPolicies::kMinMaximumSessionDuration);
    }
    response_struct->session_policies.emplace(std::move(session_policies));
  }

  std::move(callback).Run(status, std::move(response_struct));
}

void CloudSessionAuthzServiceClient::OnReauthorizeHostResponse(
    ReauthorizeHostCallback callback,
    const HttpStatus& status,
    std::unique_ptr<ReauthorizeHostResponse> response) {
  auto response_struct =
      std::make_unique<internal::ReauthorizeHostResponseStruct>();
  response_struct->session_reauth_token = response->session_reauth_token();
  response_struct->session_reauth_token_lifetime =
      fromProtoDuration(response->session_reauth_token_lifetime());
  std::move(callback).Run(status, std::move(response_struct));
}

}  // namespace

CloudSessionAuthzServiceClientFactory::CloudSessionAuthzServiceClientFactory(
    OAuthTokenGetter* oauth_token_getter,
    InstanceIdentityTokenGetter* instance_identity_token_getter,
    scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
    : oauth_token_getter_(oauth_token_getter),
      instance_identity_token_getter_(instance_identity_token_getter),
      url_loader_factory_(url_loader_factory) {}

CloudSessionAuthzServiceClientFactory::
    ~CloudSessionAuthzServiceClientFactory() = default;

std::unique_ptr<SessionAuthzServiceClient>
CloudSessionAuthzServiceClientFactory::Create() {
  return std::make_unique<CloudSessionAuthzServiceClient>(
      oauth_token_getter_, instance_identity_token_getter_,
      url_loader_factory_);
}

AuthenticationMethod CloudSessionAuthzServiceClientFactory::method() {
  return AuthenticationMethod::CLOUD_SESSION_AUTHZ_SPAKE2_CURVE25519;
}

}  // namespace remoting