#include "crypto/encryptor.h"
#include <stddef.h>
#include <stdint.h>
#include "base/logging.h"
#include "base/strings/string_util.h"
#include "base/sys_byteorder.h"
#include "crypto/openssl_util.h"
#include "crypto/symmetric_key.h"
#include "third_party/boringssl/src/include/openssl/aes.h"
#include "third_party/boringssl/src/include/openssl/evp.h"
#if defined(OHOS_ENCRYPT)
#define GCM_TAG_SIZE 16
#define GCM_IV_SIZE 12
#endif
namespace crypto {
namespace {
// Picks the AES-CBC cipher that matches the length of |key|: 16-byte keys map
// to AES-128-CBC and 32-byte keys to AES-256-CBC. Any other key length yields
// nullptr.
const EVP_CIPHER* GetCipherForKey(const SymmetricKey* key) {
  const size_t key_bytes = key->key().length();
  if (key_bytes == 16)
    return EVP_aes_128_cbc();
  if (key_bytes == 32)
    return EVP_aes_256_cbc();
  return nullptr;
}
#if defined(OHOS_ENCRYPT)
// Picks the AES-GCM cipher that matches the length of |key|: 16-byte keys map
// to AES-128-GCM and 32-byte keys to AES-256-GCM. Any other key length yields
// nullptr.
const EVP_CIPHER* GetCipherForKeyGCM(const SymmetricKey* key) {
  const size_t key_bytes = key->key().length();
  if (key_bytes == 16)
    return EVP_aes_128_gcm();
  if (key_bytes == 32)
    return EVP_aes_256_gcm();
  return nullptr;
}
#endif
}
Encryptor::Encryptor() : key_(nullptr), mode_(CBC) {}
// Compiler-generated destructor; |key_| is only assigned from the caller's
// pointer in Init() and nothing in this file deletes it.
Encryptor::~Encryptor() = default;
bool Encryptor::Init(const SymmetricKey* key, Mode mode, base::StringPiece iv) {
return Init(key, mode, base::as_bytes(base::make_span(iv)));
}
// Validates and stores the key, mode and IV used by later Encrypt()/Decrypt()
// calls.
//
// Returns false, leaving the object untouched, when:
//   * |mode| is CBC and |iv| is not AES_BLOCK_SIZE bytes long,
//   * |mode| is CTR and |iv| is not empty (the counter comes from
//     SetCounter()),
//   * the key length has no matching AES cipher (see GetCipherForKey()), or
//   * (OHOS_ENCRYPT builds) |mode| is GCM and |iv| is not GCM_IV_SIZE bytes.
bool Encryptor::Init(const SymmetricKey* key,
                     Mode mode,
                     base::span<const uint8_t> iv) {
  DCHECK(key);
#if defined(OHOS_ENCRYPT)
  DCHECK(mode == CBC || mode == CTR || mode == GCM);
#else
  DCHECK(mode == CBC || mode == CTR);
#endif

  EnsureOpenSSLInit();

  const size_t iv_length = iv.size();
  const bool bad_cbc_iv = (mode == CBC) && (iv_length != AES_BLOCK_SIZE);
  const bool bad_ctr_iv = (mode == CTR) && (iv_length != 0);
  if (bad_cbc_iv || bad_ctr_iv)
    return false;

  // Only the key length matters here; 16- and 32-byte keys are accepted.
  if (!GetCipherForKey(key))
    return false;

#if defined(OHOS_ENCRYPT)
  if (mode == GCM && iv_length != GCM_IV_SIZE)
    return false;
#endif

  // All checks passed; commit the new configuration.
  key_ = key;
  mode_ = mode;
  iv_.assign(iv.begin(), iv.end());
  return true;
}
// Encrypts |plaintext| into |*ciphertext| by delegating to CryptString() in
// the encrypt direction. Returns CryptString()'s success flag.
bool Encryptor::Encrypt(base::StringPiece plaintext, std::string* ciphertext) {
  constexpr bool kDoEncrypt = true;
  return CryptString(kDoEncrypt, plaintext, ciphertext);
}
// Encrypts |plaintext| into |*ciphertext| by delegating to CryptBytes() in
// the encrypt direction. Returns CryptBytes()'s success flag.
bool Encryptor::Encrypt(base::span<const uint8_t> plaintext,
                        std::vector<uint8_t>* ciphertext) {
  constexpr bool kDoEncrypt = true;
  return CryptBytes(kDoEncrypt, plaintext, ciphertext);
}
// Decrypts |ciphertext| into |*plaintext| by delegating to CryptString() in
// the decrypt direction. Returns CryptString()'s success flag.
bool Encryptor::Decrypt(base::StringPiece ciphertext, std::string* plaintext) {
  constexpr bool kDoEncrypt = false;
  return CryptString(kDoEncrypt, ciphertext, plaintext);
}
// Decrypts |ciphertext| into |*plaintext| by delegating to CryptBytes() in
// the decrypt direction. Returns CryptBytes()'s success flag.
bool Encryptor::Decrypt(base::span<const uint8_t> ciphertext,
                        std::vector<uint8_t>* plaintext) {
  constexpr bool kDoEncrypt = false;
  return CryptBytes(kDoEncrypt, ciphertext, plaintext);
}
// Convenience overload: reinterprets the characters of |counter| as raw bytes
// and forwards to the span-based SetCounter().
bool Encryptor::SetCounter(base::StringPiece counter) {
  const base::span<const uint8_t> counter_bytes =
      base::as_bytes(base::make_span(counter));
  return SetCounter(counter_bytes);
}
// Installs |counter| as the counter block consumed by CryptCTR(). Succeeds
// only when the encryptor is in CTR mode and |counter| is exactly 16 bytes;
// otherwise returns false and leaves |iv_| unchanged.
bool Encryptor::SetCounter(base::span<const uint8_t> counter) {
  const bool acceptable = (mode_ == CTR) && (counter.size() == 16u);
  if (!acceptable)
    return false;

  iv_.assign(counter.begin(), counter.end());
  return true;
}
bool Encryptor::CryptString(bool do_encrypt,
base::StringPiece input,
std::string* output) {
std::string result(MaxOutput(do_encrypt, input.size()), '\0');
absl::optional<size_t> len;
#if defined(OHOS_ENCRYPT)
std::string tag(GCM_TAG_SIZE, 0);
if (mode_ == CTR) {
len = CryptCTR(do_encrypt, base::as_bytes(base::make_span(input)),
base::as_writable_bytes(base::make_span(result)));
} else if (mode_ == GCM) {
if (do_encrypt) {
len = EncryptGCM(base::as_bytes(base::make_span(input)),
base::as_writable_bytes(base::make_span(result)), &tag);
} else {
if (input.length() <= GCM_TAG_SIZE) {
LOG(WARNING) << "input size less than gcm tag size";
return false;
}
tag = std::string(
input.substr(input.length() - GCM_TAG_SIZE, GCM_TAG_SIZE));
std::string ciphertext =
std::string(input.substr(0, input.length() - GCM_TAG_SIZE));
const size_t output_size = ciphertext.length();
if (output_size + 1 <= ciphertext.length()) {
LOG(WARNING) << "output size occur overflow";
return false;
}
len = DecryptGCM(ciphertext,
base::as_writable_bytes(base::make_span(result)), &tag);
}
} else {
len = Crypt(do_encrypt, base::as_bytes(base::make_span(input)),
base::as_writable_bytes(base::make_span(result)));
}
#else
len = (mode_ == CTR)
? CryptCTR(do_encrypt, base::as_bytes(base::make_span(input)),
base::as_writable_bytes(base::make_span(result)))
: Crypt(do_encrypt, base::as_bytes(base::make_span(input)),
base::as_writable_bytes(base::make_span(result)));
#endif
if (!len)
return false;
result.resize(*len);
#if defined(OHOS_ENCRYPT)
if (mode_ == GCM && do_encrypt) {
result += tag;
}
#endif
*output = std::move(result);
return true;
}
// Shared implementation of the byte-span Encrypt()/Decrypt().
//
// Runs the cipher selected by |mode_| over |input| into a scratch buffer of
// MaxOutput() bytes, trims the buffer to the number of bytes actually
// produced, and moves it into |*output|. |*output| is only written on
// success.
//
// In OHOS_ENCRYPT builds GCM mode is dispatched to EncryptGCM()/DecryptGCM()
// using the same layout as CryptString(): ciphertext followed by a
// GCM_TAG_SIZE-byte tag. Without this dispatch a GCM-configured encryptor
// would fall through to Crypt(), which runs AES-CBC and expects a 16-byte IV
// while |iv_| holds only the GCM_IV_SIZE-byte GCM nonce.
bool Encryptor::CryptBytes(bool do_encrypt,
                           base::span<const uint8_t> input,
                           std::vector<uint8_t>* output) {
  std::vector<uint8_t> result(MaxOutput(do_encrypt, input.size()));
  absl::optional<size_t> len;
#if defined(OHOS_ENCRYPT)
  std::string tag(GCM_TAG_SIZE, '\0');
  if (mode_ == GCM) {
    if (do_encrypt) {
      len = EncryptGCM(input, result, &tag);
    } else {
      // Same acceptance rule as CryptString(): the input must hold more than
      // just the tag.
      if (input.size() <= GCM_TAG_SIZE) {
        LOG(WARNING) << "input size less than gcm tag size";
        return false;
      }
      // Split |input| into the ciphertext body and the trailing tag.
      const size_t ciphertext_size = input.size() - GCM_TAG_SIZE;
      const char* raw = reinterpret_cast<const char*>(input.data());
      tag.assign(raw + ciphertext_size, GCM_TAG_SIZE);
      const std::string ciphertext(raw, ciphertext_size);
      len = DecryptGCM(ciphertext, result, &tag);
    }
  } else if (mode_ == CTR) {
    len = CryptCTR(do_encrypt, input, result);
  } else {
    len = Crypt(do_encrypt, input, result);
  }
#else
  len = (mode_ == CTR) ? CryptCTR(do_encrypt, input, result)
                       : Crypt(do_encrypt, input, result);
#endif
  if (!len)
    return false;

  result.resize(*len);
#if defined(OHOS_ENCRYPT)
  // The authentication tag travels at the end of the GCM ciphertext.
  if (mode_ == GCM && do_encrypt) {
    result.insert(result.end(), tag.begin(), tag.end());
  }
#endif
  *output = std::move(result);
  return true;
}
// Returns the scratch-buffer size to allocate for an operation over |length|
// input bytes. Encryption reserves 16 extra bytes in CBC mode and, in
// OHOS_ENCRYPT builds, 12 extra bytes in GCM mode; every other case needs
// exactly |length| bytes. CHECK-fails if the sum wraps around.
size_t Encryptor::MaxOutput(bool do_encrypt, size_t length) {
  size_t extra = 0;
  if (do_encrypt && mode_ == CBC)
    extra = 16;
#if defined(OHOS_ENCRYPT)
  // NOTE(review): 12 equals GCM_IV_SIZE and matches the bound EncryptGCM()
  // computes from |iv_.size()| - confirm that is the intent.
  if (do_encrypt && mode_ == GCM)
    extra = 12;
#endif
  const size_t result = length + extra;
  CHECK_GE(result, length);
  return result;
}
// Runs the cipher returned by GetCipherForKey() (AES-CBC) over |input| in the
// direction given by |do_encrypt| and writes the result to |output|.
//
// Returns the number of bytes written, or absl::nullopt if any EVP call
// fails. CHECK-fails when |output| is smaller than |input| (plus |iv_.size()|
// extra bytes when encrypting).
absl::optional<size_t> Encryptor::Crypt(bool do_encrypt,
                                        base::span<const uint8_t> input,
                                        base::span<uint8_t> output) {
  DCHECK(key_);

  const EVP_CIPHER* cipher = GetCipherForKey(key_);
  DCHECK(cipher);

  const std::string& raw_key = key_->key();
  DCHECK_EQ(EVP_CIPHER_iv_length(cipher), iv_.size());
  DCHECK_EQ(EVP_CIPHER_key_length(cipher), raw_key.size());

  OpenSSLErrStackTracer err_tracer(FROM_HERE);
  bssl::ScopedEVP_CIPHER_CTX ctx;
  const uint8_t* key_bytes = reinterpret_cast<const uint8_t*>(raw_key.data());
  if (!EVP_CipherInit_ex(ctx.get(), cipher, nullptr, key_bytes, iv_.data(),
                         do_encrypt)) {
    return absl::nullopt;
  }

  // Encryption needs additional room beyond the input length.
  const size_t slack = do_encrypt ? iv_.size() : 0;
  CHECK_GE(output.size(), input.size() + slack);

  int written;
  if (!EVP_CipherUpdate(ctx.get(), output.data(), &written, input.data(),
                        input.size())) {
    return absl::nullopt;
  }

  // Write out the final block after the bytes already produced.
  int final_len;
  if (!EVP_CipherFinal_ex(ctx.get(), output.data() + written, &final_len)) {
    return absl::nullopt;
  }
  written += final_len;

  DCHECK_LE(written, static_cast<int>(output.size()));
  return written;
}
// Applies AES-CTR to |input| and writes the same number of bytes to |output|.
// |do_encrypt| is not consulted: the same AES_ctr128_encrypt() call serves
// both directions.
//
// Returns |input.size()| on success, or absl::nullopt when no AES_BLOCK_SIZE
// counter is stored in |iv_| or the AES key schedule cannot be set up.
// CHECK-fails when |output| is smaller than |input|.
absl::optional<size_t> Encryptor::CryptCTR(bool do_encrypt,
                                           base::span<const uint8_t> input,
                                           base::span<uint8_t> output) {
  if (iv_.size() != AES_BLOCK_SIZE) {
    LOG(ERROR) << "Counter value not set in CTR mode.";
    return absl::nullopt;
  }

  const std::string& raw_key = key_->key();
  AES_KEY aes_key;
  const int key_status = AES_set_encrypt_key(
      reinterpret_cast<const uint8_t*>(raw_key.data()), raw_key.size() * 8,
      &aes_key);
  if (key_status != 0) {
    return absl::nullopt;
  }

  uint8_t ecount_buf[AES_BLOCK_SIZE] = {};
  unsigned int block_offset = 0;

  CHECK_GE(output.size(), input.size());
  const size_t byte_count = input.size();
  // |iv_| is handed over as the counter block for this call.
  AES_ctr128_encrypt(input.data(), output.data(), byte_count, &aes_key,
                     iv_.data(), ecount_buf, &block_offset);
  return byte_count;
}
#if defined(OHOS_ENCRYPT)
// Encrypts |input| with AES-GCM using the key and IV stored by Init().
//
// |output| receives the ciphertext and must be at least |input.size()| bytes
// (CHECK-enforced). |*tag| is resized to GCM_TAG_SIZE bytes and receives the
// authentication tag.
//
// Returns the number of ciphertext bytes written, or absl::nullopt if any
// EVP call fails.
absl::optional<size_t> Encryptor::EncryptGCM(base::span<const uint8_t> input,
                                             base::span<uint8_t> output,
                                             std::string* tag) {
  DCHECK(key_);
  DCHECK(output.data());
  DCHECK(tag);
  const EVP_CIPHER* cipher = GetCipherForKeyGCM(key_);
  DCHECK(cipher);
  const std::string& key = key_->key();
  DCHECK_EQ(EVP_CIPHER_iv_length(cipher), iv_.size());
  DCHECK_EQ(EVP_CIPHER_key_length(cipher), key.size());

  OpenSSLErrStackTracer err_tracer(FROM_HERE);
  bssl::ScopedEVP_CIPHER_CTX ctx;

  // Select the cipher first, then set the IV length, and only then supply
  // the key and IV.
  if (!EVP_EncryptInit_ex(ctx.get(), cipher, nullptr, nullptr, nullptr)) {
    return absl::nullopt;
  }
  if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_IVLEN, iv_.size(),
                           nullptr)) {
    return absl::nullopt;
  }
  if (!EVP_EncryptInit_ex(ctx.get(), nullptr, nullptr,
                          reinterpret_cast<const uint8_t*>(key.data()),
                          iv_.data())) {
    return absl::nullopt;
  }

  const size_t output_size = input.size() + iv_.size();
  CHECK_GT(output_size, 0u);
  CHECK_GT(output_size + 1, input.size());
  // The ciphertext is written straight into |output|; never write past it.
  CHECK_GE(output.size(), input.size());

  // Provide the message to be encrypted and obtain the encrypted output.
  // EVP_EncryptUpdate may be called multiple times if necessary.
  int out_len;
  if (!EVP_EncryptUpdate(ctx.get(), output.data(), &out_len, input.data(),
                         input.size())) {
    return absl::nullopt;
  }

  // Finalise the encryption. Normally ciphertext bytes may be written at this
  // stage, but this does not occur in GCM mode.
  int len;
  if (!EVP_EncryptFinal_ex(ctx.get(), output.data() + out_len, &len)) {
    return absl::nullopt;
  }
  out_len += len;

  // Make sure the tag buffer has room for exactly GCM_TAG_SIZE bytes before
  // the cipher writes into it.
  tag->resize(GCM_TAG_SIZE);
  if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, GCM_TAG_SIZE,
                           &(*tag)[0])) {
    return absl::nullopt;
  }

  DCHECK_LE(out_len, static_cast<int>(output_size));
  return out_len;
}
// Decrypts and authenticates AES-GCM ciphertext using the key and IV stored
// by Init().
//
// |input| is the ciphertext without its tag; |*tag| must hold exactly
// GCM_TAG_SIZE bytes. |output| receives the plaintext and must be at least
// |input.length()| bytes (CHECK-enforced).
//
// Returns the number of plaintext bytes written, or absl::nullopt if the tag
// has the wrong size or any EVP call fails, including the final step.
absl::optional<size_t> Encryptor::DecryptGCM(const std::string& input,
                                             base::span<uint8_t> output,
                                             std::string* tag) {
  DCHECK(key_);
  DCHECK(output.data());
  DCHECK(tag);
  const EVP_CIPHER* cipher = GetCipherForKeyGCM(key_);
  DCHECK(cipher);
  const std::string& key = key_->key();
  DCHECK_EQ(EVP_CIPHER_iv_length(cipher), iv_.size());
  DCHECK_EQ(EVP_CIPHER_key_length(cipher), key.size());

  // The cipher reads GCM_TAG_SIZE bytes from the tag buffer; refuse anything
  // else rather than reading out of bounds.
  if (!tag || tag->size() != GCM_TAG_SIZE) {
    return absl::nullopt;
  }
  // The plaintext is written straight into |output|; never write past it.
  CHECK_GE(output.size(), input.length());

  OpenSSLErrStackTracer err_tracer(FROM_HERE);
  bssl::ScopedEVP_CIPHER_CTX ctx;

  // Select the cipher first, then set the IV length, and only then supply
  // the key and IV.
  if (!EVP_DecryptInit_ex(ctx.get(), cipher, nullptr, nullptr, nullptr)) {
    return absl::nullopt;
  }
  if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_IVLEN, iv_.size(),
                           nullptr)) {
    return absl::nullopt;
  }
  if (!EVP_DecryptInit_ex(ctx.get(), nullptr, nullptr,
                          reinterpret_cast<const uint8_t*>(key.data()),
                          iv_.data())) {
    return absl::nullopt;
  }

  int out_len;
  if (!EVP_DecryptUpdate(ctx.get(), output.data(), &out_len,
                         reinterpret_cast<const uint8_t*>(input.data()),
                         input.length())) {
    return absl::nullopt;
  }

  // Supply the tag before the final step.
  if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_TAG, GCM_TAG_SIZE,
                           &(*tag)[0])) {
    return absl::nullopt;
  }

  // Finalise the decryption. Only a positive return value counts as success;
  // anything else is reported as failure.
  int len;
  const int ret =
      EVP_DecryptFinal_ex(ctx.get(), output.data() + out_len, &len);
  if (ret <= 0) {
    return absl::nullopt;
  }
  out_len += len;

  DCHECK_LE(out_len, static_cast<int>(input.length()));
  return out_len;
}
#endif
}