use openssl::rand::rand_bytes;
use openssl::symm::{decrypt_aead, encrypt_aead, Cipher};
use crate::constants::PSK_LEN;
/// Length in bytes of the random AES-GCM nonce (96 bits) placed at the front of every
/// encrypted buffer.
pub const IV_LEN: usize = 96 / 8;
/// Length in bytes of the AES-GCM authentication tag (128 bits) placed at the end of every
/// encrypted buffer.
pub const TAG_LEN: usize = 128 / 8;
#[derive(Debug, thiserror::Error)]
pub enum PskError {
#[error("encryption failed: {0}")]
Encrypt(#[from] openssl::error::ErrorStack),
#[error("data too short to decrypt")]
DataTooShort,
}
/// Encrypts `plaintext` with AES-256-GCM under `key`, using a freshly generated random IV.
///
/// The returned buffer has the layout `IV (IV_LEN bytes) || ciphertext || tag (TAG_LEN bytes)`,
/// which is exactly what [`psk_decrypt_with_key`] consumes. No additional authenticated data
/// is bound into the tag.
///
/// # Errors
///
/// Returns [`PskError::Encrypt`] if generating the random IV or the OpenSSL encryption call
/// fails.
pub fn psk_encrypt_with_key(plaintext: &[u8], key: &[u8; PSK_LEN]) -> Result<Vec<u8>, PskError> {
    // A new random nonce per message: GCM security relies on never repeating an IV under
    // the same key.
    let mut nonce = [0u8; IV_LEN];
    rand_bytes(&mut nonce)?;

    // OpenSSL writes the authentication tag into this buffer as a side output.
    let mut auth_tag = [0u8; TAG_LEN];
    let body = encrypt_aead(
        Cipher::aes_256_gcm(),
        key,
        Some(&nonce[..]),
        &[],
        plaintext,
        &mut auth_tag,
    )?;

    // `concat` sizes the output once from the total length of the three parts.
    Ok([&nonce[..], body.as_slice(), &auth_tag[..]].concat())
}
/// Decrypts a buffer produced by [`psk_encrypt_with_key`] using AES-256-GCM under `key`.
///
/// `data` must be laid out as `IV (IV_LEN bytes) || ciphertext || tag (TAG_LEN bytes)`.
/// OpenSSL verifies the tag, so a wrong key or any modification of the buffer makes this
/// return an error instead of plaintext.
///
/// # Errors
///
/// * [`PskError::DataTooShort`] if `data` is shorter than `IV_LEN + TAG_LEN` bytes.
/// * [`PskError::Encrypt`] if OpenSSL rejects the input, including authentication failures.
pub fn psk_decrypt_with_key(data: &[u8], key: &[u8; PSK_LEN]) -> Result<Vec<u8>, PskError> {
    // Size of the ciphertext body left after removing the fixed-size IV prefix and tag suffix;
    // `None` means the buffer cannot even hold those two parts.
    let body_len = data
        .len()
        .checked_sub(IV_LEN + TAG_LEN)
        .ok_or(PskError::DataTooShort)?;
    let tag_start = IV_LEN + body_len;

    let nonce = &data[..IV_LEN];
    let body = &data[IV_LEN..tag_start];
    let auth_tag = &data[tag_start..];

    decrypt_aead(Cipher::aes_256_gcm(), key, Some(nonce), &[], body, auth_tag)
        .map_err(PskError::Encrypt)
}
pub fn generate_random_psk() -> Result<[u8; PSK_LEN], PskError> {
let mut key = [0u8; PSK_LEN];
rand_bytes(&mut key)?;
Ok(key)
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a fresh random key so tests never share key material.
    fn make_key() -> [u8; PSK_LEN] {
        generate_random_psk().unwrap()
    }

    #[test]
    fn test_psk_encrypt_decrypt_roundtrip() {
        let key = make_key();
        let plaintext = b"hello kdc common";
        let encrypted = psk_encrypt_with_key(plaintext, &key).unwrap();
        let decrypted = psk_decrypt_with_key(&encrypted, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_psk_empty_plaintext() {
        let key = make_key();
        let encrypted = psk_encrypt_with_key(b"", &key).unwrap();
        // An empty plaintext yields the minimum valid buffer: IV followed directly by the tag.
        assert_eq!(encrypted.len(), IV_LEN + TAG_LEN);
        let decrypted = psk_decrypt_with_key(&encrypted, &key).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_psk_large_plaintext() {
        let key = make_key();
        let plaintext = vec![0xAB_u8; 1024 * 1024];
        let encrypted = psk_encrypt_with_key(&plaintext, &key).unwrap();
        let decrypted = psk_decrypt_with_key(&encrypted, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_psk_wrong_key() {
        let key_a = make_key();
        let key_b = make_key();
        let encrypted = psk_encrypt_with_key(b"secret data", &key_a).unwrap();
        let result = psk_decrypt_with_key(&encrypted, &key_b);
        assert!(matches!(result, Err(PskError::Encrypt(_))));
    }

    #[test]
    fn test_psk_tampered_ciphertext() {
        let key = make_key();
        let mut encrypted = psk_encrypt_with_key(b"secret data", &key).unwrap();
        // Index IV_LEN is the first byte of the ciphertext body (not the IV, not the tag).
        encrypted[IV_LEN] ^= 0xFF;
        let result = psk_decrypt_with_key(&encrypted, &key);
        assert!(matches!(result, Err(PskError::Encrypt(_))));
    }

    #[test]
    fn test_psk_tampered_tag() {
        let key = make_key();
        let mut encrypted = psk_encrypt_with_key(b"secret data", &key).unwrap();
        // The last byte belongs to the authentication tag.
        let tamper_idx = encrypted.len() - 1;
        encrypted[tamper_idx] ^= 0xFF;
        let result = psk_decrypt_with_key(&encrypted, &key);
        assert!(matches!(result, Err(PskError::Encrypt(_))));
    }

    #[test]
    fn test_psk_tampered_iv() {
        let key = make_key();
        let mut encrypted = psk_encrypt_with_key(b"secret data", &key).unwrap();
        // The first byte belongs to the IV prefix.
        encrypted[0] ^= 0xFF;
        let result = psk_decrypt_with_key(&encrypted, &key);
        assert!(matches!(result, Err(PskError::Encrypt(_))));
    }

    #[test]
    fn test_psk_different_ivs() {
        let key = make_key();
        let plaintext = b"same plaintext";
        let encrypted_a = psk_encrypt_with_key(plaintext, &key).unwrap();
        let encrypted_b = psk_encrypt_with_key(plaintext, &key).unwrap();
        assert_ne!(encrypted_a, encrypted_b);
        // The difference must come from a fresh IV, not just from elsewhere in the buffer.
        assert_ne!(encrypted_a[..IV_LEN], encrypted_b[..IV_LEN]);
    }

    #[test]
    fn test_generate_random_psk() {
        let key_a = generate_random_psk().unwrap();
        let key_b = generate_random_psk().unwrap();
        assert_eq!(key_a.len(), PSK_LEN);
        assert_eq!(key_b.len(), PSK_LEN);
        assert_ne!(key_a, key_b);
    }

    #[test]
    fn test_psk_short_data() {
        let key = make_key();
        // Empty input, IV only, and one byte short of the minimum IV + tag length.
        for len in [0, IV_LEN, IV_LEN + TAG_LEN - 1] {
            let short_data = vec![0u8; len];
            let result = psk_decrypt_with_key(&short_data, &key);
            assert!(
                matches!(result, Err(PskError::DataTooShort)),
                "expected DataTooShort for len = {}",
                len
            );
        }
    }

    #[test]
    fn test_psk_encrypted_output_format() {
        let key = make_key();
        let plaintext = b"test";
        let encrypted = psk_encrypt_with_key(plaintext, &key).unwrap();
        assert_eq!(encrypted.len(), IV_LEN + plaintext.len() + TAG_LEN);
    }
}