#include "media/mojo/clients/mojo_cdm.h"
#include <stdint.h>
#include <array>
#include <memory>
#include "base/functional/bind.h"
#include "base/memory/ref_counted.h"
#include "base/run_loop.h"
#include "base/test/test_message_loop.h"
#include "base/time/time.h"
#include "build/build_config.h"
#include "media/base/cdm_config.h"
#include "media/base/cdm_factory.h"
#include "media/base/content_decryption_module.h"
#include "media/base/mock_filters.h"
#include "media/cdm/clear_key_cdm_common.h"
#include "media/cdm/default_cdm_factory.h"
#include "media/mojo/mojom/content_decryption_module.mojom.h"
#include "media/mojo/services/mojo_cdm_service.h"
#include "media/mojo/services/mojo_cdm_service_context.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/gurl.h"
using ::testing::_;
using ::testing::DoAll;
using ::testing::Invoke;
using ::testing::Return;
using ::testing::ReturnNull;
using ::testing::ReturnPointee;
using ::testing::StrictMock;
using ::testing::WithArg;
using ::testing::WithArgs;
using ::testing::WithoutArgs;
MATCHER(NotEmpty, "") {
return !arg.empty();
}
ACTION_P2(CdmCreated, cdm, status) {
arg0.Run(cdm, status);
}
namespace media {
namespace {
const auto kKeyId = std::to_array<uint8_t>({
0x01,
0x02,
0x03,
0x04,
0x05,
0x06,
0x07,
0x08,
0x09,
0x0a,
0x0b,
0x0c,
0x0d,
0x0e,
0x0f,
0x10,
});
}
class MojoCdmTest : public ::testing::Test {
public:
enum ExpectedResult {
SUCCESS,
FAILURE,
PENDING,
CONNECTION_ERROR_BEFORE,
CONNECTION_ERROR_DURING,
};
MojoCdmTest() = default;
MojoCdmTest(const MojoCdmTest&) = delete;
MojoCdmTest& operator=(const MojoCdmTest&) = delete;
~MojoCdmTest() override = default;
void Initialize(ExpectedResult expected_result) {
EXPECT_CALL(*remote_cdm_, GetCdmContext())
.WillRepeatedly(Return(&cdm_context_));
EXPECT_CALL(cdm_context_, GetDecryptor()).WillRepeatedly(ReturnNull());
#if BUILDFLAG(IS_WIN)
EXPECT_CALL(cdm_context_, RequiresMediaFoundationRenderer())
.WillRepeatedly(ReturnPointee(&requires_media_foundation_renderer_));
#endif
if (expected_result == CONNECTION_ERROR_DURING) {
cdm_factory_.SetBeforeCreationCB(base::BindRepeating(
&MojoCdmTest::ForceConnectionError, base::Unretained(this)));
}
mojo_cdm_service_ =
std::make_unique<MojoCdmService>(&mojo_cdm_service_context_);
mojo_cdm_service_->Initialize(
&cdm_factory_, {kClearKeyKeySystem, false, false, false},
base::BindOnce(&MojoCdmTest::OnCdmServiceInitialized,
base::Unretained(this), expected_result));
}
void OnCdmServiceInitialized(ExpectedResult expected_result,
mojom::CdmContextPtr cdm_context,
CreateCdmStatus status) {
cdm_receiver_ =
std::make_unique<mojo::Receiver<mojom::ContentDecryptionModule>>(
mojo_cdm_service_.get());
if (expected_result == CONNECTION_ERROR_BEFORE) {
ForceConnectionError();
}
mojo::Remote<mojom::ContentDecryptionModule> cdm_remote(
cdm_receiver_->BindNewPipeAndPassRemote());
CdmConfig cdm_config = {"com.foo.bar", false, false, false};
mojo_cdm_ = base::MakeRefCounted<MojoCdm>(
std::move(cdm_remote), std::move(cdm_context), cdm_config,
base::BindRepeating(&MockCdmClient::OnSessionMessage,
base::Unretained(&cdm_client_)),
base::BindRepeating(&MockCdmClient::OnSessionClosed,
base::Unretained(&cdm_client_)),
base::BindRepeating(&MockCdmClient::OnSessionKeysChange,
base::Unretained(&cdm_client_)),
base::BindRepeating(&MockCdmClient::OnSessionExpirationUpdate,
base::Unretained(&cdm_client_)));
EXPECT_EQ(kClearKeyKeySystem, remote_cdm_->GetKeySystem());
base::RunLoop().RunUntilIdle();
}
void ForceConnectionError() {
cdm_receiver_->ResetWithReason(2, "Test closed connection.");
base::RunLoop().RunUntilIdle();
}
void SetServerCertificateAndExpect(const std::vector<uint8_t>& certificate,
ExpectedResult expected_result) {
if (expected_result == CONNECTION_ERROR_BEFORE) {
ForceConnectionError();
} else {
EXPECT_CALL(*remote_cdm_, SetServerCertificate(certificate, _))
.WillOnce([&](const auto& certificate, auto promise) {
HandlePromise(std::move(promise), expected_result);
});
}
mojo_cdm_->SetServerCertificate(
certificate,
std::make_unique<MockCdmPromise>(expected_result == SUCCESS));
base::RunLoop().RunUntilIdle();
}
void CreateSessionAndExpect(const std::string& session_id,
ExpectedResult expected_result) {
const CdmSessionType session_type = CdmSessionType::kTemporary;
const EmeInitDataType data_type = EmeInitDataType::WEBM;
const std::vector<uint8_t> key_id(
kKeyId.data(),
base::span<const uint8_t>(kKeyId).subspan(std::size(kKeyId)).data());
std::string created_session_id;
if (expected_result == CONNECTION_ERROR_BEFORE) {
ForceConnectionError();
} else {
EXPECT_CALL(*remote_cdm_, CreateSessionAndGenerateRequest(
session_type, data_type, key_id, _))
.WillOnce([&](auto session_type, auto init_data_type,
const auto& init_data, auto promise) {
HandleSessionPromise(std::move(promise), session_id,
expected_result);
});
}
mojo_cdm_->CreateSessionAndGenerateRequest(
session_type, data_type, key_id,
std::make_unique<MockCdmSessionPromise>(expected_result == SUCCESS,
&created_session_id));
base::RunLoop().RunUntilIdle();
if (expected_result == SUCCESS) {
EXPECT_EQ(session_id, created_session_id);
base::RunLoop().RunUntilIdle();
}
}
void LoadSessionAndExpect(const std::string& session_id,
ExpectedResult expected_result) {
const CdmSessionType session_type = CdmSessionType::kPersistentLicense;
std::string loaded_session_id;
if (expected_result == CONNECTION_ERROR_BEFORE) {
ForceConnectionError();
} else {
EXPECT_CALL(*remote_cdm_, LoadSession(session_type, session_id, _))
.WillOnce([&](auto session_type, auto session_id, auto promise) {
HandleSessionPromise(std::move(promise), session_id,
expected_result);
});
}
mojo_cdm_->LoadSession(session_type, session_id,
std::make_unique<MockCdmSessionPromise>(
expected_result == SUCCESS, &loaded_session_id));
base::RunLoop().RunUntilIdle();
if (expected_result == SUCCESS) {
EXPECT_EQ(session_id, loaded_session_id);
EXPECT_CALL(
cdm_client_,
OnSessionClosed(session_id, CdmSessionClosedReason::kInternalError));
remote_cdm_->CallSessionClosedCB(session_id,
CdmSessionClosedReason::kInternalError);
base::RunLoop().RunUntilIdle();
}
}
void UpdateSessionAndExpect(const std::string& session_id,
ExpectedResult expected_result) {
const std::vector<uint8_t> response = {1, 2, 3, 4, 5, 6};
if (expected_result == CONNECTION_ERROR_BEFORE) {
ForceConnectionError();
} else {
EXPECT_CALL(*remote_cdm_, UpdateSession(session_id, response, _))
.WillOnce([&](auto session_id, auto response, auto promise) {
HandlePromise(std::move(promise), expected_result);
});
}
mojo_cdm_->UpdateSession(
session_id, response,
std::make_unique<MockCdmPromise>(expected_result == SUCCESS));
base::RunLoop().RunUntilIdle();
}
void CloseSessionAndExpect(const std::string& session_id,
ExpectedResult expected_result) {
if (expected_result == CONNECTION_ERROR_BEFORE) {
ForceConnectionError();
} else {
EXPECT_CALL(*remote_cdm_, CloseSession(session_id, _))
.WillOnce([&](auto session_id, auto promise) {
HandlePromise(std::move(promise), expected_result);
});
}
mojo_cdm_->CloseSession(session_id, std::make_unique<MockCdmPromise>(
expected_result == SUCCESS));
base::RunLoop().RunUntilIdle();
}
void RemoveSessionAndExpect(const std::string& session_id,
ExpectedResult expected_result) {
if (expected_result == CONNECTION_ERROR_BEFORE) {
ForceConnectionError();
} else {
EXPECT_CALL(*remote_cdm_, RemoveSession(session_id, _))
.WillOnce([&](auto session_id, auto promise) {
HandlePromise(std::move(promise), expected_result);
});
}
mojo_cdm_->RemoveSession(session_id, std::make_unique<MockCdmPromise>(
expected_result == SUCCESS));
base::RunLoop().RunUntilIdle();
}
void HandlePromise(std::unique_ptr<SimpleCdmPromise> promise,
ExpectedResult expected_result) {
switch (expected_result) {
case SUCCESS:
promise->resolve();
break;
case FAILURE:
promise->reject(CdmPromise::Exception::NOT_SUPPORTED_ERROR, 0,
"Promise rejected");
break;
case PENDING:
pending_simple_cdm_promise_ = std::move(promise);
break;
case CONNECTION_ERROR_BEFORE:
NOTREACHED();
case CONNECTION_ERROR_DURING:
ForceConnectionError();
promise->resolve();
break;
}
}
void HandleSessionPromise(std::unique_ptr<NewSessionCdmPromise> promise,
const std::string& session_id,
ExpectedResult expected_result) {
switch (expected_result) {
case SUCCESS:
promise->resolve(session_id);
break;
case FAILURE:
promise->reject(CdmPromise::Exception::NOT_SUPPORTED_ERROR, 0,
"Promise rejected");
break;
case PENDING:
pending_new_session_cdm_promise_ = std::move(promise);
break;
case CONNECTION_ERROR_BEFORE:
NOTREACHED();
case CONNECTION_ERROR_DURING:
ForceConnectionError();
promise->resolve(session_id);
break;
}
}
base::TestMessageLoop message_loop_;
scoped_refptr<MockCdm> remote_cdm_{base::MakeRefCounted<MockCdm>()};
MockCdmFactory cdm_factory_{remote_cdm_};
MockCdmContext cdm_context_;
MojoCdmServiceContext mojo_cdm_service_context_;
StrictMock<MockCdmClient> cdm_client_;
std::unique_ptr<SimpleCdmPromise> pending_simple_cdm_promise_;
std::unique_ptr<NewSessionCdmPromise> pending_new_session_cdm_promise_;
std::unique_ptr<MojoCdmService> mojo_cdm_service_;
std::unique_ptr<mojo::Receiver<mojom::ContentDecryptionModule>> cdm_receiver_;
scoped_refptr<ContentDecryptionModule> mojo_cdm_;
#if BUILDFLAG(IS_WIN)
bool requires_media_foundation_renderer_ = false;
#endif
};
TEST_F(MojoCdmTest, Create_Success) {
Initialize(SUCCESS);
}
TEST_F(MojoCdmTest, Create_Failure) {
Initialize(FAILURE);
}
TEST_F(MojoCdmTest, Create_ConnectionErrorBefore) {
Initialize(CONNECTION_ERROR_BEFORE);
}
TEST_F(MojoCdmTest, SetServerCertificate_Success) {
const std::vector<uint8_t> certificate = {0, 1, 2};
Initialize(SUCCESS);
SetServerCertificateAndExpect(certificate, SUCCESS);
}
TEST_F(MojoCdmTest, SetServerCertificate_Failure) {
const std::vector<uint8_t> certificate = {1, 2, 3, 4, 5};
Initialize(SUCCESS);
SetServerCertificateAndExpect(certificate, FAILURE);
}
TEST_F(MojoCdmTest, SetServerCertificate_Pending) {
const std::vector<uint8_t> certificate = {1, 2, 3, 4, 5};
Initialize(SUCCESS);
SetServerCertificateAndExpect(certificate, PENDING);
}
TEST_F(MojoCdmTest, SetServerCertificate_ConnectionErrorBefore) {
const std::vector<uint8_t> certificate = {3, 4};
Initialize(SUCCESS);
SetServerCertificateAndExpect(certificate, CONNECTION_ERROR_BEFORE);
}
TEST_F(MojoCdmTest, SetServerCertificate_ConnectionErrorDuring) {
const std::vector<uint8_t> certificate = {10, 11, 12};
Initialize(SUCCESS);
SetServerCertificateAndExpect(certificate, CONNECTION_ERROR_DURING);
}
TEST_F(MojoCdmTest, CreateSession_Success) {
const std::string session_id = "create1";
Initialize(SUCCESS);
CreateSessionAndExpect(session_id, SUCCESS);
EXPECT_CALL(
cdm_client_,
OnSessionClosed(session_id, CdmSessionClosedReason::kInternalError));
}
TEST_F(MojoCdmTest, CreateSession_Failure) {
const std::string session_id = "create2";
Initialize(SUCCESS);
CreateSessionAndExpect(session_id, FAILURE);
}
TEST_F(MojoCdmTest, CreateSession_Pending) {
const std::string session_id = "create2";
Initialize(SUCCESS);
CreateSessionAndExpect(session_id, PENDING);
}
TEST_F(MojoCdmTest, CreateSession_ConnectionErrorBefore) {
const std::string session_id = "create3";
Initialize(SUCCESS);
CreateSessionAndExpect(session_id, CONNECTION_ERROR_BEFORE);
}
TEST_F(MojoCdmTest, CreateSession_ConnectionErrorDuring) {
const std::string session_id = "create4";
Initialize(SUCCESS);
CreateSessionAndExpect(session_id, CONNECTION_ERROR_DURING);
}
TEST_F(MojoCdmTest, LoadSession_Success) {
const std::string session_id = "load1";
Initialize(SUCCESS);
LoadSessionAndExpect(session_id, SUCCESS);
}
TEST_F(MojoCdmTest, LoadSession_Failure) {
const std::string session_id = "load2";
Initialize(SUCCESS);
LoadSessionAndExpect(session_id, FAILURE);
}
TEST_F(MojoCdmTest, LoadSession_Pending) {
const std::string session_id = "load2";
Initialize(SUCCESS);
LoadSessionAndExpect(session_id, PENDING);
}
TEST_F(MojoCdmTest, LoadSession_ConnectionErrorBefore) {
const std::string session_id = "load3";
Initialize(SUCCESS);
LoadSessionAndExpect(session_id, CONNECTION_ERROR_BEFORE);
}
TEST_F(MojoCdmTest, LoadSession_ConnectionErrorDuring) {
const std::string session_id = "load4";
Initialize(SUCCESS);
LoadSessionAndExpect(session_id, CONNECTION_ERROR_DURING);
}
TEST_F(MojoCdmTest, UpdateSession_Success) {
const std::string session_id = "update1";
Initialize(SUCCESS);
UpdateSessionAndExpect(session_id, SUCCESS);
}
TEST_F(MojoCdmTest, UpdateSession_Failure) {
const std::string session_id = "update2";
Initialize(SUCCESS);
UpdateSessionAndExpect(session_id, FAILURE);
}
TEST_F(MojoCdmTest, UpdateSession_Pending) {
const std::string session_id = "update2";
Initialize(SUCCESS);
UpdateSessionAndExpect(session_id, PENDING);
}
TEST_F(MojoCdmTest, UpdateSession_ConnectionErrorBefore) {
const std::string session_id = "update3";
Initialize(SUCCESS);
UpdateSessionAndExpect(session_id, CONNECTION_ERROR_BEFORE);
}
TEST_F(MojoCdmTest, UpdateSession_ConnectionErrorDuring) {
const std::string session_id = "update4";
Initialize(SUCCESS);
UpdateSessionAndExpect(session_id, CONNECTION_ERROR_DURING);
}
TEST_F(MojoCdmTest, CloseSession_Success) {
const std::string session_id = "close1";
Initialize(SUCCESS);
CloseSessionAndExpect(session_id, SUCCESS);
}
TEST_F(MojoCdmTest, CloseSession_Failure) {
const std::string session_id = "close2";
Initialize(SUCCESS);
CloseSessionAndExpect(session_id, FAILURE);
}
TEST_F(MojoCdmTest, CloseSession_Pending) {
const std::string session_id = "close2";
Initialize(SUCCESS);
CloseSessionAndExpect(session_id, PENDING);
}
TEST_F(MojoCdmTest, CloseSession_ConnectionErrorBefore) {
const std::string session_id = "close3";
Initialize(SUCCESS);
CloseSessionAndExpect(session_id, CONNECTION_ERROR_BEFORE);
}
TEST_F(MojoCdmTest, CloseSession_ConnectionErrorDuring) {
const std::string session_id = "close4";
Initialize(SUCCESS);
CloseSessionAndExpect(session_id, CONNECTION_ERROR_DURING);
}
TEST_F(MojoCdmTest, RemoveSession_Success) {
const std::string session_id = "remove1";
Initialize(SUCCESS);
RemoveSessionAndExpect(session_id, SUCCESS);
}
TEST_F(MojoCdmTest, RemoveSession_Failure) {
const std::string session_id = "remove2";
Initialize(SUCCESS);
RemoveSessionAndExpect(session_id, FAILURE);
}
TEST_F(MojoCdmTest, RemoveSession_Pending) {
const std::string session_id = "remove2";
Initialize(SUCCESS);
RemoveSessionAndExpect(session_id, PENDING);
}
TEST_F(MojoCdmTest, RemoveSession_ConnectionErrorBefore) {
const std::string session_id = "remove3";
Initialize(SUCCESS);
RemoveSessionAndExpect(session_id, CONNECTION_ERROR_BEFORE);
}
TEST_F(MojoCdmTest, RemoveSession_ConnectionErrorDuring) {
const std::string session_id = "remove4";
Initialize(SUCCESS);
RemoveSessionAndExpect(session_id, CONNECTION_ERROR_DURING);
}
TEST_F(MojoCdmTest, SessionMessageCB_Success) {
const std::string session_id = "message";
const CdmMessageType message_type = CdmMessageType::LICENSE_REQUEST;
const std::vector<uint8_t> message = {0, 1, 2};
Initialize(SUCCESS);
EXPECT_CALL(cdm_client_, OnSessionMessage(session_id, message_type, message));
remote_cdm_->CallSessionMessageCB(session_id, message_type, message);
base::RunLoop().RunUntilIdle();
}
TEST_F(MojoCdmTest, SessionExpirationChangeCB_Success) {
const std::string session_id = "expiration";
const base::Time time = base::Time::Now();
Initialize(SUCCESS);
EXPECT_CALL(cdm_client_, OnSessionExpirationUpdate(session_id, time));
remote_cdm_->CallSessionExpirationUpdateCB(session_id, time);
base::RunLoop().RunUntilIdle();
}
TEST_F(MojoCdmTest, SessionKeysChangeCB_Success) {
const std::string session_id = "change";
bool has_additional_usable_key = true;
CdmKeysInfo keys_info;
Initialize(SUCCESS);
EXPECT_CALL(cdm_client_,
OnSessionKeysChangeCalled(session_id, has_additional_usable_key));
remote_cdm_->CallSessionKeysChangeCB(session_id, has_additional_usable_key,
std::move(keys_info));
base::RunLoop().RunUntilIdle();
}
TEST_F(MojoCdmTest, NoDecryptor) {
Initialize(SUCCESS);
auto* cdm_context = mojo_cdm_->GetCdmContext();
EXPECT_TRUE(cdm_context) << "All CDMs should support CdmContext";
auto* decryptor = cdm_context->GetDecryptor();
EXPECT_FALSE(decryptor);
}
#if BUILDFLAG(IS_WIN)
TEST_F(MojoCdmTest, RequiresMediaFoundationRenderer) {
requires_media_foundation_renderer_ = true;
Initialize(SUCCESS);
auto* cdm_context = mojo_cdm_->GetCdmContext();
EXPECT_TRUE(cdm_context) << "All CDMs should support CdmContext";
EXPECT_TRUE(cdm_context->RequiresMediaFoundationRenderer());
}
TEST_F(MojoCdmTest, NotRequireMediaFoundationRenderer) {
requires_media_foundation_renderer_ = false;
Initialize(SUCCESS);
auto* cdm_context = mojo_cdm_->GetCdmContext();
EXPECT_TRUE(cdm_context) << "All CDMs should support CdmContext";
EXPECT_FALSE(cdm_context->RequiresMediaFoundationRenderer());
}
#endif
}