#include "remoting/host/token_validator_base.h"
#include <memory>
#include <vector>
#include "base/atomic_sequence_num.h"
#include "base/memory/raw_ptr.h"
#include "crypto/rsa_private_key.h"
#include "net/cert/x509_util.h"
#include "net/ssl/client_cert_identity_test_util.h"
#include "net/ssl/ssl_private_key.h"
#include "net/ssl/test_ssl_private_key.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace {
// Values copied into the ThirdPartyAuthConfig that the test fixture builds
// for the validator under test.
constexpr char kTokenUrl[] = "https://example.com/token";
constexpr char kTokenValidationUrl[] = "https://example.com/validate";
// NOTE(review): presumably a wildcard issuer pattern; confirm against
// TokenValidatorBase.
constexpr char kTokenValidationCertIssuer[] = "*";
// CreateFakeCert() takes the serial number of every certificate it generates
// from this counter's GetNext().
base::AtomicSequenceNumber g_serial_number;
// Creates a self-signed test client identity with subject "CN=subject".
// |valid_start| and |valid_expiry| are passed through as the certificate's
// validity bounds, and the serial number comes from |g_serial_number|.
//
// Returns nullptr if key/certificate generation reports failure, if the DER
// output cannot be parsed into an X509Certificate, or if the RSA key cannot
// be wrapped as an SSLPrivateKey.
std::unique_ptr<net::FakeClientCertIdentity> CreateFakeCert(
    base::Time valid_start,
    base::Time valid_expiry) {
  std::unique_ptr<crypto::RSAPrivateKey> rsa_private_key;
  std::string cert_der;
  // Check the reported result directly instead of relying on the later parse
  // step to trip over an empty |cert_der|.
  if (!net::x509_util::CreateKeyAndSelfSignedCert(
          "CN=subject", g_serial_number.GetNext(), valid_start, valid_expiry,
          &rsa_private_key, &cert_der)) {
    return nullptr;
  }

  scoped_refptr<net::X509Certificate> cert =
      net::X509Certificate::CreateFromBytes(
          base::as_bytes(base::make_span(cert_der)));
  if (!cert) {
    return nullptr;
  }

  scoped_refptr<net::SSLPrivateKey> ssl_private_key =
      net::WrapRSAPrivateKey(rsa_private_key.get());
  if (!ssl_private_key) {
    return nullptr;
  }

  return std::make_unique<net::FakeClientCertIdentity>(cert, ssl_private_key);
}
}
namespace remoting {
// Test subclass of TokenValidatorBase. It lets a test feed a list of client
// certificate identities into the selection path and checks which certificate
// and private key are handed to ContinueWithCertificate().
//
// Inheritance is spelled out as public: a bare `class D : B` is private
// inheritance, which the style guide disallows.
class TestTokenValidator : public TokenValidatorBase {
 public:
  explicit TestTokenValidator(const ThirdPartyAuthConfig& config);
  ~TestTokenValidator() override;

  // Passes |selected_certs| to OnCertificatesSelected().
  void SelectCertificates(net::ClientCertIdentityList selected_certs);

  // Records the certificate and private key of |identity| as the values the
  // next ContinueWithCertificate() call is compared against. A null
  // |identity| records null for both.
  void ExpectContinueWithCertificate(
      const net::FakeClientCertIdentity* identity);

 protected:
  // Checks the arguments against the recorded expectations with EXPECT_EQ.
  void ContinueWithCertificate(
      scoped_refptr<net::X509Certificate> client_cert,
      scoped_refptr<net::SSLPrivateKey> client_private_key) override;

 private:
  // Intentionally a no-op in this test subclass.
  void StartValidateRequest(const std::string& token) override {}

  // Non-owning expectations set by ExpectContinueWithCertificate().
  // NOTE(review): these are not ref-counted here, so the identity that
  // supplied them presumably must outlive any comparison that uses them;
  // confirm the lifetime expectations.
  raw_ptr<net::X509Certificate> expected_client_cert_ = nullptr;
  raw_ptr<net::SSLPrivateKey> expected_private_key_ = nullptr;
};
// Forwards |config| to the base class together with an empty string and a
// null third argument.
// NOTE(review): those two are presumably the token scope and a request
// context that these tests do not need; confirm against the
// TokenValidatorBase constructor.
TestTokenValidator::TestTokenValidator(const ThirdPartyAuthConfig& config)
: TokenValidatorBase(config, "", nullptr) {}
// The defaulted destructor is defined out of line.
TestTokenValidator::~TestTokenValidator() = default;
// Test entry point: passes |selected_certs| on to OnCertificatesSelected()
// with a null first argument.
// NOTE(review): OnCertificatesSelected() is not declared in this file;
// presumably it is the base-class certificate selection callback. Confirm.
void TestTokenValidator::SelectCertificates(
net::ClientCertIdentityList selected_certs) {
OnCertificatesSelected(nullptr, std::move(selected_certs));
}
// Stores the certificate and private key of |identity| as the expected
// arguments for the next ContinueWithCertificate() call. Passing nullptr
// clears both expectations.
void TestTokenValidator::ExpectContinueWithCertificate(
    const net::FakeClientCertIdentity* identity) {
  // Guard clause: no identity means both expectations are null.
  if (!identity) {
    expected_client_cert_ = nullptr;
    expected_private_key_ = nullptr;
    return;
  }

  expected_client_cert_ = identity->certificate();
  expected_private_key_ = identity->ssl_private_key();
}
// Override that verifies the certificate and key it receives.
// NOTE(review): presumably called by the base class once it has chosen among
// the identities offered via SelectCertificates(); confirm.
void TestTokenValidator::ContinueWithCertificate(
scoped_refptr<net::X509Certificate> client_cert,
scoped_refptr<net::SSLPrivateKey> client_private_key) {
// Pointer comparison against the values recorded by
// ExpectContinueWithCertificate(); both are null when none is expected.
EXPECT_EQ(expected_client_cert_, client_cert.get());
EXPECT_EQ(expected_private_key_, client_private_key.get());
}
// Fixture that owns the TestTokenValidator used by the tests in this file.
class TokenValidatorBaseTest : public testing::Test {
public:
// Creates |token_validator_|.
void SetUp() override;
protected:
// Validator under test; created in SetUp().
std::unique_ptr<TestTokenValidator> token_validator_;
};
void TokenValidatorBaseTest::SetUp() {
ThirdPartyAuthConfig config;
config.token_url = GURL(kTokenUrl);
config.token_validation_url = GURL(kTokenValidationUrl);
config.token_validation_cert_issuer = kTokenValidationCertIssuer;
token_validator_ = std::make_unique<TestTokenValidator>(config);
}
// Offers several combinations of client certificates to the validator and
// checks which one (if any) is passed on to ContinueWithCertificate().
TEST_F(TokenValidatorBaseTest, TestSelectCertificate) {
  const base::Time now = base::Time::Now();

  // Test certificates, each named after its validity window relative to
  // |now| (start offset in the past, expiry offset in the past or future).
  const auto cert_expired_5_minutes_ago =
      CreateFakeCert(now - base::Minutes(10), now - base::Minutes(5));
  ASSERT_TRUE(cert_expired_5_minutes_ago);

  const auto cert_start_5min_expire_5min =
      CreateFakeCert(now - base::Minutes(5), now + base::Minutes(5));
  ASSERT_TRUE(cert_start_5min_expire_5min);

  const auto cert_start_10min_expire_5min =
      CreateFakeCert(now - base::Minutes(10), now + base::Minutes(5));
  ASSERT_TRUE(cert_start_10min_expire_5min);

  const auto cert_start_5min_expire_10min =
      CreateFakeCert(now - base::Minutes(5), now + base::Minutes(10));
  ASSERT_TRUE(cert_start_5min_expire_10min);

  // Runs one scenario: copies every identity in |offered| into a fresh list,
  // records |expected| (nullptr for "no certificate"), then triggers
  // selection with that list.
  const auto offer_and_expect =
      [this](std::initializer_list<net::FakeClientCertIdentity*> offered,
             const net::FakeClientCertIdentity* expected) {
        net::ClientCertIdentityList client_certs;
        for (net::FakeClientCertIdentity* identity : offered) {
          client_certs.push_back(identity->Copy());
        }
        token_validator_->ExpectContinueWithCertificate(expected);
        token_validator_->SelectCertificates(std::move(client_certs));
      };

  // Nothing offered: nothing expected.
  offer_and_expect({}, nullptr);

  // Only an expired certificate: nothing expected.
  offer_and_expect({cert_expired_5_minutes_ago.get()}, nullptr);

  // A single currently-valid certificate is the expected one.
  offer_and_expect({cert_start_5min_expire_5min.get()},
                   cert_start_5min_expire_5min.get());

  // Expired plus valid: the valid one is expected.
  offer_and_expect({cert_expired_5_minutes_ago.get(),
                    cert_start_5min_expire_5min.get()},
                   cert_start_5min_expire_5min.get());

  // Same expiry, different start: the later start is expected.
  offer_and_expect({cert_start_10min_expire_5min.get(),
                    cert_start_5min_expire_5min.get()},
                   cert_start_5min_expire_5min.get());

  // Same start, different expiry: the later expiry is expected.
  offer_and_expect({cert_start_5min_expire_5min.get(),
                    cert_start_5min_expire_10min.get()},
                   cert_start_5min_expire_10min.get());

  // All four together: the 5-minute-start / 10-minute-expiry one is expected.
  offer_and_expect({cert_expired_5_minutes_ago.get(),
                    cert_start_5min_expire_5min.get(),
                    cert_start_5min_expire_10min.get(),
                    cert_start_10min_expire_5min.get()},
                   cert_start_5min_expire_10min.get());
}
}