#include "media/cdm/json_web_key.h"
#include <stddef.h>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include "base/base64url.h"
#include "base/containers/span.h"
#include "base/json/json_reader.h"
#include "base/json/json_writer.h"
#include "base/json/string_escape.h"
#include "base/logging.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_util.h"
#include "base/values.h"
#include "media/base/content_decryption_module.h"
namespace media {
// Member names and string values used by the JSON messages that the functions
// in this file build and parse.

// Name of the list of JWK entries in a JWK Set.
const char kKeysTag[] = "keys";
// Name of the key-type member of a JWK entry.
const char kKeyTypeTag[] = "kty";
// The only key-type value written by and accepted by this file.
const char kKeyTypeOct[] = "oct";
// Name of the member holding the base64url-encoded key.
const char kKeyTag[] = "k";
// Name of the member holding the base64url-encoded key ID.
const char kKeyIdTag[] = "kid";
// Name of the list of base64url-encoded key IDs.
const char kKeyIdsTag[] = "kids";
// Name of the member holding the session type string.
const char kTypeTag[] = "type";
// Session type strings stored under |kTypeTag|.
const char kTemporarySession[] = "temporary";
const char kPersistentLicenseSession[] = "persistent-license";
// Returns an escaped, length-limited copy of |input| for use in error
// messages. The first 65 bytes of |input| are escaped with
// base::EscapeBytesAsInvalidJSONString() (without surrounding quotes). If the
// escaped text is longer than 64 characters, it is cut to 61 characters and
// "..." is appended.
static std::string ShortenTo64Characters(const std::string& input) {
  constexpr size_t kMaxLength = 64u;
  constexpr size_t kKeptWhenTruncated = 61u;

  // One byte more than the limit is taken so that over-long input can be
  // told apart from input that fits exactly.
  std::string result = base::EscapeBytesAsInvalidJSONString(
      input.substr(0, kMaxLength + 1), false);
  if (result.length() > kMaxLength) {
    result.resize(kKeptWhenTruncated);
    result += "...";
  }
  return result;
}
// Builds a single JWK dictionary of the form
//   { "kty": "oct", "k": <base64url key>, "kid": <base64url key id> }
// where both |key| and |key_id| are base64url-encoded without padding.
static base::Value::Dict CreateJSONDictionary(
    base::span<const uint8_t> key,
    base::span<const uint8_t> key_id) {
  constexpr auto kPolicy = base::Base64UrlEncodePolicy::OMIT_PADDING;

  base::Value::Dict entry;
  entry.Set(kKeyTypeTag, kKeyTypeOct);

  std::string encoded_key;
  base::Base64UrlEncode(key, kPolicy, &encoded_key);
  entry.Set(kKeyTag, encoded_key);

  std::string encoded_key_id;
  base::Base64UrlEncode(key_id, kPolicy, &encoded_key_id);
  entry.Set(kKeyIdTag, encoded_key_id);

  return entry;
}
// Serializes a JWK Set containing exactly one entry built from |key| and
// |key_id|. Returns an empty string if JSON serialization fails.
std::string GenerateJWKSet(base::span<const uint8_t> key,
                           base::span<const uint8_t> key_id) {
  base::Value::List jwk_entries;
  jwk_entries.Append(CreateJSONDictionary(key, key_id));

  base::Value::Dict root;
  root.Set(kKeysTag, std::move(jwk_entries));

  auto serialized = base::WriteJson(root);
  return serialized ? std::move(*serialized) : std::string();
}
std::string GenerateJWKSet(const KeyIdAndKeyPairs& keys,
CdmSessionType session_type) {
base::Value::List list;
for (const auto& key_pair : keys) {
list.Append(
base::Value(CreateJSONDictionary(base::as_byte_span(key_pair.second),
base::as_byte_span(key_pair.first))));
}
base::Value::Dict jwk_set;
jwk_set.Set(kKeysTag, std::move(list));
switch (session_type) {
case CdmSessionType::kTemporary:
jwk_set.Set(kTypeTag, kTemporarySession);
break;
case CdmSessionType::kPersistentLicense:
jwk_set.Set(kTypeTag, kPersistentLicenseSession);
break;
}
return base::WriteJson(jwk_set).value_or(std::string());
}
// Converts a single JWK dictionary into a (key id, key) pair.
//
// |jwk| must have "kty" equal to "oct" and string members "kid" and "k", each
// holding a non-empty base64url value without padding. On success the decoded
// key ID and key are stored in |*jwk_key| and true is returned. On any failure
// false is returned and |*jwk_key| is left untouched.
static bool ConvertJwkToKeyPair(const base::Value::Dict& jwk,
                                KeyIdAndKeyPair* jwk_key) {
  const base::Value* type = jwk.Find(kKeyTypeTag);
  if (!type || *type != kKeyTypeOct) {
    DVLOG(1) << "Missing or invalid '" << kKeyTypeTag
             << "': " << (type ? type->DebugString() : "");
    return false;
  }

  // Both the key ID and the key must be present and must be strings.
  const base::Value* encoded_key_id = jwk.Find(kKeyIdTag);
  const base::Value* encoded_key = jwk.Find(kKeyTag);
  if (!encoded_key_id || !encoded_key_id->is_string()) {
    DVLOG(1) << "Missing '" << kKeyIdTag << "' parameter";
    return false;
  }
  if (!encoded_key || !encoded_key->is_string()) {
    DVLOG(1) << "Missing '" << kKeyTag << "' parameter";
    return false;
  }

  std::string raw_key_id;
  if (!base::Base64UrlDecode(encoded_key_id->GetString(),
                             base::Base64UrlDecodePolicy::DISALLOW_PADDING,
                             &raw_key_id) ||
      raw_key_id.empty()) {
    // Stream the string itself; streaming the base::Value pointer would only
    // print its address.
    DVLOG(1) << "Invalid '" << kKeyIdTag
             << "' value: " << encoded_key_id->GetString();
    return false;
  }

  std::string raw_key;
  if (!base::Base64UrlDecode(encoded_key->GetString(),
                             base::Base64UrlDecodePolicy::DISALLOW_PADDING,
                             &raw_key) ||
      raw_key.empty()) {
    // Only the length is reported so that key material is not written to logs.
    DVLOG(1) << "Invalid '" << kKeyTag << "' value (encoded length "
             << encoded_key->GetString().size() << ")";
    return false;
  }

  *jwk_key = std::make_pair(std::move(raw_key_id), std::move(raw_key));
  return true;
}
bool ExtractKeysFromJWKSet(const std::string& jwk_set,
KeyIdAndKeyPairs* keys,
CdmSessionType* session_type) {
if (!base::IsStringASCII(jwk_set)) {
DVLOG(1) << "Non ASCII JWK Set: " << jwk_set;
return false;
}
std::optional<base::Value> root =
base::JSONReader::Read(jwk_set, base::JSON_PARSE_CHROMIUM_EXTENSIONS);
if (!root.has_value() || !root->is_dict()) {
DVLOG(1) << "Not valid JSON: " << jwk_set;
return false;
}
base::Value::Dict* dictionary = root.value().GetIfDict();
base::Value::List* list_val = dictionary->FindList(kKeysTag);
if (!list_val) {
DVLOG(1) << "Missing '" << kKeysTag
<< "' parameter or not a list in JWK Set";
return false;
}
KeyIdAndKeyPairs local_keys;
for (size_t i = 0; i < list_val->size(); ++i) {
base::Value& jwk = (*list_val)[i];
if (!jwk.is_dict()) {
DVLOG(1) << "Unable to access '" << kKeysTag << "'[" << i
<< "] in JWK Set";
return false;
}
KeyIdAndKeyPair key_pair;
if (!ConvertJwkToKeyPair(jwk.GetDict(), &key_pair)) {
DVLOG(1) << "Error from '" << kKeysTag << "'[" << i << "]";
return false;
}
local_keys.push_back(key_pair);
}
base::Value* value = dictionary->Find(kTypeTag);
if (!value) {
*session_type = CdmSessionType::kTemporary;
} else {
if (!value->is_string()) {
DVLOG(1) << "Invalid '" << kTypeTag << "' value";
return false;
}
const std::string session_type_id = value->GetString();
if (session_type_id == kTemporarySession) {
*session_type = CdmSessionType::kTemporary;
} else if (session_type_id == kPersistentLicenseSession) {
*session_type = CdmSessionType::kPersistentLicense;
} else {
DVLOG(1) << "Invalid '" << kTypeTag << "' value: " << session_type_id;
return false;
}
}
keys->swap(local_keys);
return true;
}
bool ExtractKeyIdsFromKeyIdsInitData(const std::string& input,
KeyIdList* key_ids,
std::string* error_message) {
if (!base::IsStringASCII(input)) {
error_message->assign("Non ASCII: ");
error_message->append(ShortenTo64Characters(input));
return false;
}
std::optional<base::Value> root =
base::JSONReader::Read(input, base::JSON_PARSE_CHROMIUM_EXTENSIONS);
if (!root.has_value() || !root->is_dict()) {
error_message->assign("Not valid JSON: ");
error_message->append(ShortenTo64Characters(input));
return false;
}
const base::Value::List* list_val = root->GetDict().FindList(kKeyIdsTag);
if (!list_val) {
error_message->assign("Missing '");
error_message->append(kKeyIdsTag);
error_message->append("' parameter or not a list");
return false;
}
KeyIdList local_key_ids;
for (size_t i = 0; i < list_val->size(); ++i) {
const std::string* encoded_key_id = (*list_val)[i].GetIfString();
if (!encoded_key_id) {
error_message->assign("'");
error_message->append(kKeyIdsTag);
error_message->append("'[");
error_message->append(base::NumberToString(i));
error_message->append("] is not string.");
return false;
}
std::string raw_key_id;
if (!base::Base64UrlDecode(*encoded_key_id,
base::Base64UrlDecodePolicy::DISALLOW_PADDING,
&raw_key_id) ||
raw_key_id.empty()) {
error_message->assign("'");
error_message->append(kKeyIdsTag);
error_message->append("'[");
error_message->append(base::NumberToString(i));
error_message->append("] is not valid base64url encoded. Value: ");
error_message->append(ShortenTo64Characters(*encoded_key_id));
return false;
}
local_key_ids.emplace_back(raw_key_id.begin(), raw_key_id.end());
}
key_ids->swap(local_key_ids);
error_message->clear();
return true;
}
void CreateLicenseRequest(const KeyIdList& key_ids,
CdmSessionType session_type,
std::vector<uint8_t>* license) {
base::Value::Dict request;
base::Value::List list;
for (const auto& key_id : key_ids) {
std::string key_id_string;
base::Base64UrlEncode(
std::string_view(reinterpret_cast<const char*>(key_id.data()),
key_id.size()),
base::Base64UrlEncodePolicy::OMIT_PADDING, &key_id_string);
list.Append(key_id_string);
}
request.Set(kKeyIdsTag, std::move(list));
switch (session_type) {
case CdmSessionType::kTemporary:
request.Set(kTypeTag, kTemporarySession);
break;
case CdmSessionType::kPersistentLicense:
request.Set(kTypeTag, kPersistentLicenseSession);
break;
}
std::optional<std::string> json =
base::WriteJson(request).value_or(std::string());
std::vector<uint8_t> result(json->begin(), json->end());
license->swap(result);
}
// Returns a dictionary of the form { "kids": [ <base64url key id>, ... ] }
// with one unpadded base64url string per entry of |key_ids|.
base::Value::Dict MakeKeyIdsDictionary(const KeyIdList& key_ids) {
  base::Value::List encoded_ids;
  for (const auto& key_id : key_ids) {
    std::string encoded;
    base::Base64UrlEncode(base::as_byte_span(key_id),
                          base::Base64UrlEncodePolicy::OMIT_PADDING, &encoded);
    encoded_ids.Append(std::move(encoded));
  }

  base::Value::Dict result;
  result.Set(kKeyIdsTag, std::move(encoded_ids));
  return result;
}
// Serializes |dictionary| to JSON and returns the text as a byte vector. The
// result is empty if serialization fails.
std::vector<uint8_t> SerializeDictionaryToVector(
    const base::Value::Dict& dictionary) {
  const std::string json =
      base::WriteJson(dictionary).value_or(std::string());
  return std::vector<uint8_t>(json.begin(), json.end());
}
// Replaces the contents of |*init_data| with the serialized JSON dictionary
// that MakeKeyIdsDictionary() builds from |key_ids|.
void CreateKeyIdsInitData(const KeyIdList& key_ids,
                          std::vector<uint8_t>* init_data) {
  *init_data = SerializeDictionaryToVector(MakeKeyIdsDictionary(key_ids));
}
// Returns the serialized JSON dictionary that MakeKeyIdsDictionary() builds
// from |key_ids|.
std::vector<uint8_t> CreateLicenseReleaseMessage(const KeyIdList& key_ids) {
  return SerializeDictionaryToVector(MakeKeyIdsDictionary(key_ids));
}
bool ExtractFirstKeyIdFromLicenseRequest(const std::vector<uint8_t>& license,
std::vector<uint8_t>* first_key) {
const std::string license_as_str(
reinterpret_cast<const char*>(!license.empty() ? &license[0] : nullptr),
license.size());
if (!base::IsStringASCII(license_as_str)) {
DVLOG(1) << "Non ASCII license: " << license_as_str;
return false;
}
std::optional<base::Value> root = base::JSONReader::Read(
license_as_str, base::JSON_PARSE_CHROMIUM_EXTENSIONS);
if (!root.has_value() || !root->is_dict()) {
DVLOG(1) << "Not valid JSON: " << license_as_str;
return false;
}
const base::Value::List* list_val = root->GetDict().FindList(kKeyIdsTag);
if (!list_val) {
DVLOG(1) << "Missing '" << kKeyIdsTag << "' parameter or not a list";
return false;
}
if (list_val->size() < 1) {
DVLOG(1) << "Empty '" << kKeyIdsTag << "' list";
return false;
}
const std::string* encoded_key = (*list_val)[0].GetIfString();
if (!encoded_key) {
DVLOG(1) << "First entry in '" << kKeyIdsTag << "' not a string";
return false;
}
std::string decoded_string;
if (!base::Base64UrlDecode(*encoded_key,
base::Base64UrlDecodePolicy::DISALLOW_PADDING,
&decoded_string) ||
decoded_string.empty()) {
DVLOG(1) << "Invalid '" << kKeyIdsTag << "' value: " << *encoded_key;
return false;
}
std::vector<uint8_t> result(decoded_string.begin(), decoded_string.end());
first_key->swap(result);
return true;
}
}