// 910e62b5 - created Jan 15 (commit history).
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "remoting/base/instance_identity_token_getter_impl.h"

#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <string_view>
#include <utility>

#include "base/base64.h"
#include "base/environment.h"
#include "base/functional/bind.h"
#include "base/json/json_writer.h"
#include "base/memory/ref_counted.h"
#include "base/run_loop.h"
#include "base/strings/stringprintf.h"
#include "base/test/task_environment.h"
#include "net/http/http_status_code.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
#include "services/network/test/test_shared_url_loader_factory.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace remoting {

namespace {
// Audience requested by the getter under test. The same string appears in the
// `audience=` query parameter of both request URLs below and in the `aud`
// claim of payloads built by CreateJwtPayloadDict().
constexpr char kTestAudience[] = "audience_for_testing";
// Matches the URL generated for requests from ComputeEngineServiceClient.
// This is the default URL used by SetTokenResponse() and SetErrorResponse().
constexpr char kHttpMetadataRequestUrl[] =
    "http://metadata.google.internal/computeMetadata/v1/instance/"
    "service-accounts/default/identity?audience=audience_for_testing&"
    "format=full";
// The environment variable used to override the default metadata server host
// when code is run on a Compute Engine instance.
constexpr char kGceMetadataHostVarName[] = "GCE_METADATA_HOST";

// Constants used for testing Compute Engine Metadata server overrides: with
// kGceMetadataHostVarName set to kTestMetadataServerHost, requests are expected
// to go to kOverriddenMetadataRequestUrl instead of kHttpMetadataRequestUrl.
constexpr char kTestMetadataServerHost[] = "override.google.internal";
constexpr char kOverriddenMetadataRequestUrl[] =
    "http://override.google.internal/computeMetadata/v1/instance/"
    "service-accounts/default/identity?audience=audience_for_testing&"
    "format=full";

// A well-formed JWT in the form returned by the metadata server. The base64
// '=' padding has been stripped from the header and payload segments, so their
// lengths are not multiples of 4. Decoding them therefore fails if the strict
// (kStrict) base64 policy is used. The signature segment is a placeholder, not
// a real signature.
constexpr std::string_view kJwtWithoutPadding =
    // Header
    "eyJhbGciOiJSUzI1NiIsImtpZCI6ImtpZC1zaWduYXR1cmUiLCJ0eXAiOiJKV1QifQ."
    // Payload
    "eyJhdWQiOiJhdWRpZW5jZV9mb3JfdGVzdGluZyIsImF6cCI6IjEyMzQ1IiwiZXhwIjo1NzQw"
    "MTcsImdvb2dsZSI6eyJjb21wdXRlX2VuZ2luZSI6eyJpbnN0YW5jZV9pZCI6IjEyMzQ1Njc4"
    "OSIsImluc3RhbmNlX25hbWUiOiJteS1pbnN0YW5jZSIsInByb2plY3RfaWQiOiJteS1wcm9q"
    "ZWN0IiwicHJvamVjdF9udW1iZXIiOjU0MzIxLCJ6b25lIjoidXMtd2lsZC13ZXN0MS16In19"
    "LCJpYXQiOjU3MDQxNywiaXNzIjoiRzAwZ2xlIiwic3ViIjoiMTIzNDUifQ."
    // Signature
    "SIGNATURE!!";

// Builds the header object of a well-formed test JWT: type "JWT", algorithm
// "RS256", key id "kid-signature". base::Value::Dict keeps its keys sorted, so
// the order of the Set() calls does not affect the serialized JSON.
base::Value::Dict CreateJwtHeaderDict() {
  base::Value::Dict header;
  header.Set("alg", "RS256");
  header.Set("kid", "kid-signature");
  header.Set("typ", "JWT");
  return header;
}

// Builds the claims object of a well-formed test JWT. The token is issued at
// `now` and expires 60 minutes later, and it carries the Compute Engine
// instance details under `google.compute_engine`.
//
// NOTE(review): `iat`/`exp` are whole epoch seconds truncated to `int`. On
// platforms with a 32-bit int this stops being representable in 2038.
base::Value::Dict CreateJwtPayloadDict(base::Time now = base::Time::Now()) {
  const base::Time expiry = now + base::Minutes(60);

  base::Value::Dict instance_details;
  instance_details.Set("instance_id", "123456789");
  instance_details.Set("instance_name", "my-instance");
  instance_details.Set("project_id", "my-project");
  instance_details.Set("project_number", 54321);
  instance_details.Set("zone", "us-wild-west1-z");

  base::Value::Dict google_claims;
  google_claims.Set("compute_engine", std::move(instance_details));

  base::Value::Dict claims;
  claims.Set("iss", "G00gle");
  claims.Set("iat", static_cast<int>(now.InSecondsFSinceUnixEpoch()));
  claims.Set("exp", static_cast<int>(expiry.InSecondsFSinceUnixEpoch()));
  claims.Set("aud", kTestAudience);
  claims.Set("sub", "12345");
  claims.Set("azp", "12345");
  claims.Set("google", std::move(google_claims));
  return claims;
}

// Serializes `dict` to JSON and returns that JSON base64-encoded with
// base::Base64Encode().
std::string Base64EncodeDict(base::Value::Dict dict) {
  std::optional<std::string> serialized = base::WriteJson(std::move(dict));
  return base::Base64Encode(*serialized);
}

std::string GetBase64EncodedHeader() {
  return Base64EncodeDict(CreateJwtHeaderDict());
}

// Returns the payload segment of a well-formed test JWT issued at `now` (base64
// of the JSON produced by CreateJwtPayloadDict(now)).
std::string GetBase64EncodedPayload(base::Time now = base::Time::Now()) {
  base::Value::Dict claims = CreateJwtPayloadDict(now);
  return Base64EncodeDict(std::move(claims));
}

// Joins already-encoded `header` and `payload` segments with a placeholder
// "signature" segment, producing "<header>.<payload>.signature".
std::string GenerateValidJwt(std::string header, std::string payload) {
  std::string jwt = std::move(header);
  jwt.append(".");
  jwt.append(payload);
  jwt.append(".signature");
  return jwt;
}

// Returns a well-formed "<header>.<payload>.signature" test JWT. Its claims are
// issued at `now` and expire 60 minutes later; the signature segment is a fixed
// placeholder.
std::string GenerateValidJwt(base::Time now = base::Time::Now()) {
  // Forward `now` so callers actually control the `iat`/`exp` claims.
  return GenerateValidJwt(GetBase64EncodedHeader(),
                          GetBase64EncodedPayload(now));
}

// One case for the InvalidJwtReturnsEmptyToken parameterized test: the header
// and payload segments to serve (joined with a ".signature" suffix), plus the
// name used as the gtest test-name suffix.
struct TokenParams {
  TokenParams(std::string header, std::string payload, std::string test_name)
      : header(std::move(header)),
        payload(std::move(payload)),
        test_name(std::move(test_name)) {}

  // Found by gtest via ADL so failure messages show the case contents instead
  // of a raw byte dump of the struct.
  friend void PrintTo(const TokenParams& params, std::ostream* os) {
    *os << params.test_name << " {header: \"" << params.header
        << "\", payload: \"" << params.payload << "\"}";
  }

  std::string header;
  std::string payload;
  std::string test_name;
};

}  // namespace

// Fixture for InstanceIdentityTokenGetterImpl. Metadata-server responses come
// from a network::TestURLLoaderFactory, tasks run on an IO-type main thread
// under mock time, and the token delivered to OnTokenRetrieved() is recorded
// for inspection. The fixture is also parameterized by TokenParams for the
// malformed-JWT cases.
class InstanceIdentityTokenGetterImplTest
    : public testing::Test,
      public testing::WithParamInterface<TokenParams> {
 public:
  InstanceIdentityTokenGetterImplTest();
  ~InstanceIdentityTokenGetterImplTest() override;

  // testing::Test:
  void SetUp() override;

  // Callback handed to RetrieveToken(). Checks that every invocation sees the
  // same token and quits RunUntilQuit() once the expected number of calls has
  // arrived.
  void OnTokenRetrieved(std::string_view token);

 protected:
  // Read-only access to fixture state.
  InstanceIdentityTokenGetterImpl& instance_identity_token_getter() {
    return *instance_identity_token_getter_;
  }
  const std::optional<std::string>& token() { return token_; }
  std::string_view valid_jwt() { return valid_jwt_; }
  size_t url_loader_request_count() {
    return test_url_loader_factory_.total_requests();
  }

  // Callback bookkeeping.
  void set_pending_callback_count(int count) {
    pending_callback_count_ = count;
  }
  void clear_token() { token_.reset(); }
  void ResetQuitClosure();

  // Task environment control.
  void RunUntilQuit();
  void FastForwardBy(base::TimeDelta duration);

  // Metadata server configuration.
  void SetTokenResponse(
      std::string_view response_body,
      std::string_view metadata_server_url = kHttpMetadataRequestUrl);
  void SetErrorResponse(net::HttpStatusCode status);
  void ClearTokenResponse();
  void SetMetadataServerEnvVar(std::string_view metadata_server_host);

 private:
  // OnTokenRetrieved() calls still expected before the quit closure runs.
  int pending_callback_count_ = 1;
  // First token delivered; later deliveries are compared against it.
  std::optional<std::string> token_;
  // Generate once and store so this token can be used for comparisons.
  std::string valid_jwt_;

  // Quits RunUntilQuit(); consumed by the final expected OnTokenRetrieved().
  base::RepeatingClosure quit_closure_;

  base::test::TaskEnvironment task_environment_{
      base::test::TaskEnvironment::MainThreadType::IO,
      base::test::TaskEnvironment::TimeSource::MOCK_TIME};

  // Used to set and clear the metadata host override variable.
  std::unique_ptr<base::Environment> environment_ = base::Environment::Create();
  network::TestURLLoaderFactory test_url_loader_factory_;
  scoped_refptr<network::SharedURLLoaderFactory> shared_url_loader_factory_;

  std::unique_ptr<InstanceIdentityTokenGetterImpl>
      instance_identity_token_getter_;
};

InstanceIdentityTokenGetterImplTest::InstanceIdentityTokenGetterImplTest() =
    default;

// SetMetadataServerEnvVar() sets GCE_METADATA_HOST in the process environment,
// which outlives this fixture. Unset it again so the override cannot leak into
// tests that run later in the same binary.
InstanceIdentityTokenGetterImplTest::~InstanceIdentityTokenGetterImplTest() {
  environment_->UnSetVar(kGceMetadataHostVarName);
}

void InstanceIdentityTokenGetterImplTest::SetUp() {
  // Generated once so tests can compare delivered tokens against it.
  valid_jwt_ = GenerateValidJwt();

  // Clear any externally set metadata host override *before* the getter is
  // constructed, so that value is not picked up.
  environment_->UnSetVar(kGceMetadataHostVarName);

  shared_url_loader_factory_ = test_url_loader_factory_.GetSafeWeakWrapper();
  instance_identity_token_getter_ =
      std::make_unique<InstanceIdentityTokenGetterImpl>(
          kTestAudience, shared_url_loader_factory_);

  // Armed for the first RunUntilQuit() in each test.
  quit_closure_ = task_environment_.QuitClosure();
}

// Runs tasks until the currently armed quit closure (from SetUp() or
// ResetQuitClosure()) is invoked.
void InstanceIdentityTokenGetterImplTest::RunUntilQuit() {
  task_environment_.RunUntilQuit();
}

// Drops every previously registered response, then serves `response_body`
// with the default (OK) status for requests to `metadata_server_url`.
void InstanceIdentityTokenGetterImplTest::SetTokenResponse(
    std::string_view response_body,
    std::string_view metadata_server_url) {
  test_url_loader_factory_.ClearResponses();
  test_url_loader_factory_.AddResponse(metadata_server_url, response_body);
}

// Removes all canned responses from the test URL loader factory.
void InstanceIdentityTokenGetterImplTest::ClearTokenResponse() {
  test_url_loader_factory_.ClearResponses();
}

// Replaces all canned responses with an empty-bodied `status` response for the
// default metadata server URL.
void InstanceIdentityTokenGetterImplTest::SetErrorResponse(
    net::HttpStatusCode status) {
  test_url_loader_factory_.ClearResponses();
  const std::string empty_body;
  test_url_loader_factory_.AddResponse(kHttpMetadataRequestUrl, empty_body,
                                       status);
}

// Arms a fresh quit closure for another RunUntilQuit() cycle. The previous
// closure must already have been consumed by OnTokenRetrieved().
void InstanceIdentityTokenGetterImplTest::ResetQuitClosure() {
  const bool previous_closure_consumed = quit_closure_.is_null();
  ASSERT_TRUE(previous_closure_consumed);
  quit_closure_ = task_environment_.QuitClosure();
}

// Records the first delivered token and requires every later delivery to
// match it. The quit closure runs when the pending count reaches zero.
void InstanceIdentityTokenGetterImplTest::OnTokenRetrieved(
    std::string_view token) {
  if (!token_.has_value()) {
    token_ = token;
  } else {
    EXPECT_EQ(*token_, token);
  }

  if (--pending_callback_count_ == 0) {
    std::move(quit_closure_).Run();
  }
}

// Advances the mock clock by `duration`, running any tasks that become due.
void InstanceIdentityTokenGetterImplTest::FastForwardBy(
    base::TimeDelta duration) {
  task_environment_.FastForwardBy(duration);
}

// Points GCE_METADATA_HOST at `metadata_server_host`, then replaces the getter
// so the new instance is constructed with the override in place.
void InstanceIdentityTokenGetterImplTest::SetMetadataServerEnvVar(
    std::string_view metadata_server_host) {
  const std::string host_value(metadata_server_host);
  environment_->SetVar(kGceMetadataHostVarName, host_value);

  instance_identity_token_getter_ =
      std::make_unique<InstanceIdentityTokenGetterImpl>(
          kTestAudience, shared_url_loader_factory_);
}

TEST_F(InstanceIdentityTokenGetterImplTest, SingleRequest) {
  // One request, one metadata-server fetch, and the served token is delivered
  // unchanged.
  const std::string_view expected_token = valid_jwt();
  SetTokenResponse(expected_token);

  auto on_token =
      base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
                     base::Unretained(this));
  instance_identity_token_getter().RetrieveToken(std::move(on_token));
  RunUntilQuit();

  ASSERT_TRUE(token().has_value());
  ASSERT_EQ(*token(), expected_token);
  ASSERT_EQ(url_loader_request_count(), 1U);
}

TEST_F(InstanceIdentityTokenGetterImplTest,
       SingleRequestWithCustomMetadataServer) {
  // With the host override set, the token response is registered only for the
  // overridden URL, so the request has to go there.
  SetMetadataServerEnvVar(kTestMetadataServerHost);
  const std::string_view expected_token = valid_jwt();
  SetTokenResponse(expected_token, kOverriddenMetadataRequestUrl);

  auto on_token =
      base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
                     base::Unretained(this));
  instance_identity_token_getter().RetrieveToken(std::move(on_token));
  RunUntilQuit();

  ASSERT_TRUE(token().has_value());
  ASSERT_EQ(*token(), expected_token);
  ASSERT_EQ(url_loader_request_count(), 1U);
}

TEST_F(InstanceIdentityTokenGetterImplTest, JwtWithoutPadding) {
  // A token whose segments lack base64 padding must still be accepted; a
  // kStrict base64 decode would reject it.
  SetTokenResponse(kJwtWithoutPadding);

  auto on_token =
      base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
                     base::Unretained(this));
  instance_identity_token_getter().RetrieveToken(std::move(on_token));
  RunUntilQuit();

  ASSERT_TRUE(token().has_value());
  ASSERT_EQ(*token(), kJwtWithoutPadding);
  ASSERT_EQ(url_loader_request_count(), 1U);
}

TEST_F(InstanceIdentityTokenGetterImplTest, MultipleRequests) {
  // Requests queued before any response exists share a single fetch, and every
  // callback receives the same token.
  constexpr int kQueuedCallbackCount = 10;
  set_pending_callback_count(kQueuedCallbackCount);

  for (int remaining = kQueuedCallbackCount; remaining > 0; --remaining) {
    instance_identity_token_getter().RetrieveToken(
        base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
                       base::Unretained(this)));
  }

  // Only now make the response available.
  SetTokenResponse(valid_jwt());
  RunUntilQuit();

  ASSERT_TRUE(token().has_value());
  ASSERT_EQ(*token(), valid_jwt());
  ASSERT_EQ(url_loader_request_count(), 1U);
}

TEST_F(InstanceIdentityTokenGetterImplTest, CachedTokenReturned) {
  // The first request fetches the token from the metadata server.
  instance_identity_token_getter().RetrieveToken(
      base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
                     base::Unretained(this)));

  SetTokenResponse(valid_jwt());

  RunUntilQuit();

  ASSERT_TRUE(token().has_value());
  ASSERT_EQ(*token(), valid_jwt());

  // Call a second time and verify a token is provided w/o calling the service.
  ClearTokenResponse();
  ResetQuitClosure();
  set_pending_callback_count(1);
  // Forget the first token so the assertions below check the value delivered
  // by the second callback, not the one left over from the first request.
  clear_token();

  instance_identity_token_getter().RetrieveToken(
      base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
                     base::Unretained(this)));

  RunUntilQuit();

  ASSERT_TRUE(token().has_value());
  ASSERT_EQ(*token(), valid_jwt());
  // Still only the original fetch: the second token came from the cache.
  ASSERT_EQ(url_loader_request_count(), 1U);
}

TEST_F(InstanceIdentityTokenGetterImplTest, CachedTokenIgnored) {
  // First fetch populates the cache.
  const std::string first_jwt_response(valid_jwt());
  instance_identity_token_getter().RetrieveToken(
      base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
                     base::Unretained(this)));
  SetTokenResponse(first_jwt_response);
  RunUntilQuit();

  ASSERT_TRUE(token().has_value());
  ASSERT_EQ(*token(), first_jwt_response);

  // Advance mock time by the token's 60-minute lifetime, then serve a token
  // stamped with the new time. It must be fetched rather than served from the
  // cache.
  FastForwardBy(base::Hours(1));
  ResetQuitClosure();
  set_pending_callback_count(1);
  clear_token();
  const std::string second_jwt_response = GenerateValidJwt();
  SetTokenResponse(second_jwt_response);

  instance_identity_token_getter().RetrieveToken(
      base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
                     base::Unretained(this)));
  RunUntilQuit();

  ASSERT_TRUE(token().has_value());
  ASSERT_EQ(*token(), second_jwt_response);
  ASSERT_EQ(url_loader_request_count(), 2U);
}

TEST_F(InstanceIdentityTokenGetterImplTest, ServiceFailureReturnsEmptyString) {
  // An HTTP error from the metadata server is reported as an empty token.
  auto on_token =
      base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
                     base::Unretained(this));
  instance_identity_token_getter().RetrieveToken(std::move(on_token));
  SetErrorResponse(net::HTTP_BAD_REQUEST);
  RunUntilQuit();

  ASSERT_TRUE(token().has_value());
  ASSERT_TRUE(token()->empty());
}

TEST_F(InstanceIdentityTokenGetterImplTest, UnsignedJwtReturnsEmptyToken) {
  // "<header>.<payload>" with no signature segment must be rejected.
  std::string unsigned_jwt = GetBase64EncodedHeader();
  unsigned_jwt += '.';
  unsigned_jwt += GetBase64EncodedPayload();
  SetTokenResponse(unsigned_jwt);

  auto on_token =
      base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
                     base::Unretained(this));
  instance_identity_token_getter().RetrieveToken(std::move(on_token));
  RunUntilQuit();

  ASSERT_TRUE(token().has_value());
  ASSERT_TRUE(token()->empty());
}

TEST_P(InstanceIdentityTokenGetterImplTest, InvalidJwtReturnsEmptyToken) {
  // Serve "<header>.<payload>.signature" built from the malformed case and
  // expect the getter to reject it with an empty token.
  const TokenParams& params = GetParam();
  std::string malformed_jwt = params.header;
  malformed_jwt += '.';
  malformed_jwt += params.payload;
  malformed_jwt += ".signature";
  SetTokenResponse(malformed_jwt);

  auto on_token =
      base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
                     base::Unretained(this));
  instance_identity_token_getter().RetrieveToken(std::move(on_token));
  RunUntilQuit();

  ASSERT_TRUE(token().has_value());
  ASSERT_TRUE(token()->empty());
}

// Malformed-token cases for InvalidJwtReturnsEmptyToken. Each case is served as
// "<header>.<payload>.signature" and must produce an empty token. The third
// TokenParams argument becomes the gtest test name (see the name generator at
// the end).
INSTANTIATE_TEST_SUITE_P(
    InstanceIdentityTokenGetterImplTest,
    InstanceIdentityTokenGetterImplTest,
    ::testing::Values(
        // Wrong segment structure: empty segments, a fourth segment, or extra
        // '=' padding.
        TokenParams("", "", "EmptyHeaderAndPayload"),
        TokenParams("a." + GetBase64EncodedHeader(),
                    GetBase64EncodedPayload(),
                    "ExtraTokenSegment"),
        TokenParams(GetBase64EncodedHeader() + "====",
                    GetBase64EncodedPayload() + "====",
                    "ExtraPadding"),
        TokenParams("", GetBase64EncodedPayload(), "EmptyHeader"),
        TokenParams(GetBase64EncodedHeader(), "", "EmptyPayload"),
        // Base64 of an empty JSON object in place of the header or payload.
        TokenParams(Base64EncodeDict(base::Value::Dict()),
                    GetBase64EncodedPayload(),
                    "HeaderIsAnEmptyDict"),
        TokenParams(GetBase64EncodedHeader(),
                    Base64EncodeDict(base::Value::Dict()),
                    "PayloadIsAnEmptyDict"),
        // Raw JSON that was never base64-encoded.
        TokenParams(*base::WriteJson(CreateJwtHeaderDict()),
                    GetBase64EncodedPayload(),
                    "HeaderNotBase64Encoded"),
        TokenParams(GetBase64EncodedHeader(),
                    *base::WriteJson(CreateJwtPayloadDict()),
                    "PayloadNotBase64Encoded"),
        // Valid base64 whose decoded text is not JSON.
        TokenParams(base::Base64Encode("I'm JSON!"),
                    GetBase64EncodedPayload(),
                    "HeaderIsNotValidJson"),
        TokenParams(GetBase64EncodedHeader(),
                    base::Base64Encode("I'm JSON!"),
                    "PayloadIsNotValidJson"),
        // Valid JSON that lacks members present in the well-formed dicts built
        // by CreateJwtHeaderDict() / CreateJwtPayloadDict().
        TokenParams(Base64EncodeDict(base::Value::Dict()
                                         .Set("alg", "RS256")
                                         .Set("typ", "JWT")),
                    GetBase64EncodedPayload(),
                    "HeaderIsMissingMembers"),
        TokenParams(GetBase64EncodedHeader(),
                    Base64EncodeDict(base::Value::Dict().Set("iss", "blergh")),
                    "PayloadIsMissingMembers"),
        // Otherwise-valid payloads whose `google.compute_engine` details are
        // absent, empty, or incomplete.
        TokenParams(GetBase64EncodedHeader(),
                    Base64EncodeDict(CreateJwtPayloadDict().SetByDottedPath(
                        "google",
                        base::Value::Dict())),
                    "NoComputeEngineDict"),
        TokenParams(GetBase64EncodedHeader(),
                    Base64EncodeDict(CreateJwtPayloadDict().SetByDottedPath(
                        "google.compute_engine",
                        base::Value::Dict())),
                    "EmptyComputeEngineDict"),
        TokenParams(GetBase64EncodedHeader(),
                    Base64EncodeDict(CreateJwtPayloadDict().SetByDottedPath(
                        "google.compute_engine",
                        base::Value::Dict().Set("instance_id",
                                                "test-instance-id"))),
                    "ComputeEngineDictMissingValues")),
    // Name each instantiation after its TokenParams::test_name.
    [](const testing::TestParamInfo<
        InstanceIdentityTokenGetterImplTest::ParamType>& info) {
      return info.param.test_name;
    });

}  // namespace remoting