#include "device/fido/cable/noise.h"
#include <string.h>
#include "base/check_op.h"
#include "base/numerics/byte_conversions.h"
#include "crypto/aead.h"
#include "crypto/hash.h"
#include "crypto/kdf.h"
#include "device/fido/fido_constants.h"
#include "third_party/boringssl/src/include/openssl/ec.h"
#include "third_party/boringssl/src/include/openssl/nid.h"
namespace {
// Runs a single HKDF-SHA256 expansion producing 64 bytes and returns it as
// two 32-byte halves, in output order. |ikm| and |ck| are forwarded as the
// second and third arguments of crypto::kdf::Hkdf and the final argument is
// empty (presumably secret, salt and info respectively -- confirm against
// crypto/kdf.h).
std::tuple<std::array<uint8_t, 32>, std::array<uint8_t, 32>> HKDF2(
    base::span<const uint8_t, 32> ck,
    base::span<const uint8_t> ikm) {
  auto okm = crypto::kdf::Hkdf<32 * 2>(crypto::hash::kSha256, ikm, ck, {});

  // Carve the 64 output bytes into a leading and a trailing 32-byte block and
  // copy each one into its slot of the result tuple.
  std::tuple<std::array<uint8_t, 32>, std::array<uint8_t, 32>> halves;
  auto [leading, trailing] = base::span(okm).split_at<32>();
  base::span(std::get<0>(halves)).copy_from(leading);
  base::span(std::get<1>(halves)).copy_from(trailing);
  return halves;
}
// Maps a handshake type to the Noise protocol name string that Noise::Init
// uses to seed the chaining key. The returned view refers to a string literal
// and therefore stays valid for the lifetime of the program.
std::string_view ProtocolNameForHandshakeType(
    device::Noise::HandshakeType type) {
  // No default case on purpose: every enumerator visible here is handled
  // explicitly.
  switch (type) {
    case device::Noise::HandshakeType::kNKpsk0:
      return "Noise_NKpsk0_P256_AESGCM_SHA256";
    case device::Noise::HandshakeType::kKNpsk0:
      return "Noise_KNpsk0_P256_AESGCM_SHA256";
    case device::Noise::HandshakeType::kNK:
      return "Noise_NK_P256_AESGCM_SHA256";
  }
}
}
namespace device {
// The constructor and destructor are compiler-generated. The chaining key and
// handshake hash are seeded by Init(), not here.
Noise::Noise() = default;
Noise::~Noise() = default;
// Seeds the symmetric state for the given handshake pattern: the chaining key
// becomes the protocol name, zero-padded on the right to the full array
// length, and the handshake hash starts out as a copy of that same value.
void Noise::Init(Noise::HandshakeType type) {
  const std::string_view protocol_name = ProtocolNameForHandshakeType(type);

  // Zero first so that any bytes past the end of the name are padding.
  chaining_key_.fill(0);
  base::span(chaining_key_)
      .copy_prefix_from(base::as_byte_span(protocol_name));

  h_ = chaining_key_;
}
// Folds |in| into the handshake hash: h_ is replaced by the SHA-256 digest of
// the previous h_ followed by |in|.
void Noise::MixHash(base::span<const uint8_t> in) {
  crypto::hash::Hasher digest(crypto::hash::kSha256);
  digest.Update(h_);  // Previous hash value goes in first...
  digest.Update(in);  // ...followed by the new data.
  digest.Finish(h_);  // The result overwrites h_ in place.
}
// Mixes |ikm| into the chaining key. HKDF2 over the current chaining key and
// |ikm| yields two 32-byte values: the first becomes the new chaining key and
// the second is installed as the symmetric key (which also resets the nonce
// counter, see InitializeKey).
void Noise::MixKey(base::span<const uint8_t> ikm) {
  auto [next_chaining_key, next_symmetric_key] = HKDF2(chaining_key_, ikm);
  chaining_key_ = next_chaining_key;
  InitializeKey(next_symmetric_key);
}
// Mixes |ikm| into the chaining key, the handshake hash and the symmetric key
// at once. A 96-byte HKDF-SHA256 output is split into three 32-byte blocks
// which, in order, replace the chaining key, are fed to MixHash, and are
// installed as the symmetric key.
void Noise::MixKeyAndHash(base::span<const uint8_t> ikm) {
  auto okm = crypto::kdf::Hkdf<32 * 3>(crypto::hash::kSha256, ikm,
                                       chaining_key_, {});

  // Peel the blocks off the front of the output one at a time.
  const auto [ck_block, remainder] = base::span(okm).split_at<32>();
  const auto [hash_block, key_block] = remainder.split_at<32>();

  base::span(chaining_key_).copy_from(ck_block);
  MixHash(hash_block);
  InitializeKey(key_block);
}
// Seals |plaintext| with AES-256-GCM under the current symmetric key, using
// the current handshake hash as associated data, then mixes the resulting
// ciphertext into the handshake hash and returns it.
std::vector<uint8_t> Noise::EncryptAndHash(
    base::span<const uint8_t> plaintext) {
  // 12-byte nonce: the 32-bit counter, big-endian, in the first four bytes and
  // zeros in the remaining eight. The counter is advanced as it is consumed.
  uint8_t nonce[12] = {};
  base::span(nonce).first<4u>().copy_from(
      base::U32ToBigEndian(symmetric_nonce_++));

  crypto::Aead cipher(crypto::Aead::AES_256_GCM);
  cipher.Init(symmetric_key_);

  // The associated data is h_ as it was *before* this ciphertext is mixed in.
  std::vector<uint8_t> sealed = cipher.Seal(plaintext, nonce, h_);
  MixHash(sealed);
  return sealed;
}
// Opens |ciphertext| with AES-256-GCM under the current symmetric key, using
// the current handshake hash as associated data. On success the ciphertext is
// mixed into the handshake hash and the plaintext is returned; on failure an
// empty optional is returned and the handshake hash is left untouched. The
// nonce counter is advanced in both cases.
std::optional<std::vector<uint8_t>> Noise::DecryptAndHash(
    base::span<const uint8_t> ciphertext) {
  // 12-byte nonce: the 32-bit counter, big-endian, in the first four bytes and
  // zeros in the remaining eight. The counter is advanced as it is consumed.
  uint8_t nonce[12] = {};
  base::span(nonce).first<4u>().copy_from(
      base::U32ToBigEndian(symmetric_nonce_++));

  crypto::Aead cipher(crypto::Aead::AES_256_GCM);
  cipher.Init(symmetric_key_);

  auto opened = cipher.Open(ciphertext, nonce, h_);
  if (!opened) {
    return opened;
  }

  // Only ciphertext that opened successfully contributes to the hash.
  MixHash(ciphertext);
  return opened;
}
// Returns a copy of the current 32-byte handshake hash |h_|.
std::array<uint8_t, 32> Noise::handshake_hash() const {
return h_;
}
void Noise::MixHashPoint(const EC_POINT* point) {
uint8_t x962[kP256X962Length];
bssl::UniquePtr<EC_GROUP> p256(
EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
CHECK_EQ(sizeof(x962),
EC_POINT_point2oct(p256.get(), point, POINT_CONVERSION_UNCOMPRESSED,
x962, sizeof(x962), nullptr));
MixHash(x962);
}
// Derives a pair of 32-byte keys by running HKDF2 over the current chaining
// key with empty input keying material. This is a const method; it leaves
// the chaining key, hash, symmetric key and nonce unchanged.
std::tuple<std::array<uint8_t, 32>, std::array<uint8_t, 32>>
Noise::traffic_keys() const {
return HKDF2(chaining_key_, {});
}
// Installs |key| as the symmetric key and restarts the nonce counter at zero,
// so the first message under the new key uses counter value 0.
void Noise::InitializeKey(base::span<const uint8_t, 32> key) {
  symmetric_nonce_ = 0;
  base::span(symmetric_key_).copy_from(key);
}
}