// Copyright (c) Huawei Technologies Co., Ltd. 2026-2026. All rights reserved.

//! PSK (Pre-Shared Key) encryption utilities using AES-256-GCM.

use openssl::rand::rand_bytes;
use openssl::symm::{decrypt_aead, encrypt_aead, Cipher};

use crate::constants::PSK_LEN;

pub const IV_LEN: usize = 12;
pub const TAG_LEN: usize = 16;

#[derive(Debug, thiserror::Error)]
pub enum PskError {
    #[error("encryption failed: {0}")]
    Encrypt(#[from] openssl::error::ErrorStack),
    #[error("data too short to decrypt")]
    DataTooShort,
}

/// Encrypts `plaintext` under `key` with AES-256-GCM using a freshly generated
/// random nonce.
///
/// The returned buffer is framed as `nonce (IV_LEN) || ciphertext || tag (TAG_LEN)`,
/// which is the layout [`psk_decrypt_with_key`] expects. No additional
/// authenticated data is bound to the message (the AAD is empty).
///
/// # Errors
///
/// Returns [`PskError::Encrypt`] if generating the nonce or running the OpenSSL
/// AEAD encryption fails.
pub fn psk_encrypt_with_key(plaintext: &[u8], key: &[u8; PSK_LEN]) -> Result<Vec<u8>, PskError> {
    // A new random nonce per call, so identical plaintexts encrypt differently.
    let mut nonce = [0u8; IV_LEN];
    rand_bytes(&mut nonce)?;

    let mut auth_tag = [0u8; TAG_LEN];
    let sealed = encrypt_aead(
        Cipher::aes_256_gcm(),
        key,
        Some(&nonce),
        &[],
        plaintext,
        &mut auth_tag,
    )?;

    // Frame the pieces back-to-back: nonce, then ciphertext, then tag.
    Ok([&nonce[..], &sealed[..], &auth_tag[..]].concat())
}

/// Decrypts a buffer produced by [`psk_encrypt_with_key`].
///
/// `data` must be framed as `nonce (IV_LEN) || ciphertext || tag (TAG_LEN)` and
/// is decrypted with AES-256-GCM under `key`, with empty additional
/// authenticated data.
///
/// # Errors
///
/// * [`PskError::DataTooShort`] if `data` is shorter than `IV_LEN + TAG_LEN`.
/// * [`PskError::Encrypt`] if OpenSSL rejects the input, for example when the
///   key is wrong or the nonce, ciphertext or tag were modified.
pub fn psk_decrypt_with_key(data: &[u8], key: &[u8; PSK_LEN]) -> Result<Vec<u8>, PskError> {
    // Number of ciphertext bytes between the leading nonce and trailing tag;
    // `None` means the buffer cannot even hold a nonce and a tag.
    let body_len = data
        .len()
        .checked_sub(IV_LEN + TAG_LEN)
        .ok_or(PskError::DataTooShort)?;

    let (iv, sealed) = data.split_at(IV_LEN);
    let (ciphertext, tag) = sealed.split_at(body_len);

    decrypt_aead(
        Cipher::aes_256_gcm(),
        key,
        Some(iv),
        &[],
        ciphertext,
        tag,
    )
    .map_err(PskError::from)
}

pub fn generate_random_psk() -> Result<[u8; PSK_LEN], PskError> {
    let mut key = [0u8; PSK_LEN];
    rand_bytes(&mut key)?;
    Ok(key)
}

#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use super::*;

    fn make_key() -> [u8; PSK_LEN] {
        generate_random_psk().unwrap()
    }

    // Test: Encrypting then decrypting with the same key returns the original plaintext
    #[test]
    fn test_psk_encrypt_decrypt_roundtrip() {
        let key = make_key();
        let plaintext = b"hello kdc common";
        let encrypted = psk_encrypt_with_key(plaintext, &key).unwrap();
        let decrypted = psk_decrypt_with_key(&encrypted, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    // Test: Empty plaintext encrypts and decrypts successfully to empty bytes
    #[test]
    fn test_psk_empty_plaintext() {
        let key = make_key();
        let encrypted = psk_encrypt_with_key(b"", &key).unwrap();
        assert_eq!(encrypted.len(), IV_LEN + TAG_LEN);
        let decrypted = psk_decrypt_with_key(&encrypted, &key).unwrap();
        assert!(decrypted.is_empty());
    }

    // Test: Large (1 MiB) plaintext encrypts and decrypts without errors
    #[test]
    fn test_psk_large_plaintext() {
        let key = make_key();
        let plaintext = vec![0xAB_u8; 1024 * 1024];
        let encrypted = psk_encrypt_with_key(&plaintext, &key).unwrap();
        let decrypted = psk_decrypt_with_key(&encrypted, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    // Test: Decrypting with a different key than used for encryption fails
    #[test]
    fn test_psk_wrong_key() {
        let key_a = make_key();
        let key_b = make_key();
        let encrypted = psk_encrypt_with_key(b"secret data", &key_a).unwrap();
        let result = psk_decrypt_with_key(&encrypted, &key_b);
        assert!(matches!(result, Err(PskError::Encrypt(_))));
    }

    // Test: Flipping a byte inside the ciphertext body (not the tag) causes decryption to fail
    #[test]
    fn test_psk_tampered_ciphertext() {
        let key = make_key();
        let mut encrypted = psk_encrypt_with_key(b"secret data", &key).unwrap();
        // The first ciphertext byte sits immediately after the nonce.
        encrypted[IV_LEN] ^= 0xFF;
        let result = psk_decrypt_with_key(&encrypted, &key);
        assert!(matches!(result, Err(PskError::Encrypt(_))));
    }

    // Test: Flipping a byte of the trailing authentication tag causes decryption to fail
    #[test]
    fn test_psk_tampered_tag() {
        let key = make_key();
        let mut encrypted = psk_encrypt_with_key(b"secret data", &key).unwrap();
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0xFF;
        let result = psk_decrypt_with_key(&encrypted, &key);
        assert!(matches!(result, Err(PskError::Encrypt(_))));
    }

    // Test: Flipping a byte of the leading nonce causes decryption to fail
    #[test]
    fn test_psk_tampered_iv() {
        let key = make_key();
        let mut encrypted = psk_encrypt_with_key(b"secret data", &key).unwrap();
        encrypted[0] ^= 0xFF;
        let result = psk_decrypt_with_key(&encrypted, &key);
        assert!(matches!(result, Err(PskError::Encrypt(_))));
    }

    // Test: Dropping the final byte (shifting the tag boundary) causes decryption to fail
    #[test]
    fn test_psk_truncated_input() {
        let key = make_key();
        let mut encrypted = psk_encrypt_with_key(b"secret data", &key).unwrap();
        encrypted.pop();
        // Still long enough to pass the length check, so authentication must reject it.
        assert!(encrypted.len() >= IV_LEN + TAG_LEN);
        let result = psk_decrypt_with_key(&encrypted, &key);
        assert!(matches!(result, Err(PskError::Encrypt(_))));
    }

    // Test: Encrypting the same plaintext twice produces different ciphertext (random IV)
    #[test]
    fn test_psk_different_ivs() {
        let key = make_key();
        let plaintext = b"same plaintext";
        let encrypted_a = psk_encrypt_with_key(plaintext, &key).unwrap();
        let encrypted_b = psk_encrypt_with_key(plaintext, &key).unwrap();
        assert_ne!(encrypted_a, encrypted_b);
    }

    // Test: generate_random_psk returns a PSK_LEN-byte key and successive calls produce different keys
    #[test]
    fn test_generate_random_psk() {
        let key_a = generate_random_psk().unwrap();
        let key_b = generate_random_psk().unwrap();
        assert_eq!(key_a.len(), PSK_LEN);
        assert_eq!(key_b.len(), PSK_LEN);
        assert_ne!(key_a, key_b);
    }

    // Test: Data shorter than IV_LEN + TAG_LEN returns DataTooShort error
    #[test]
    fn test_psk_short_data() {
        let key = make_key();
        let short_data = [0u8; IV_LEN];
        let result = psk_decrypt_with_key(&short_data, &key);
        assert!(matches!(result, Err(PskError::DataTooShort)));
    }

    // Test: The length check is exact — one byte short is DataTooShort, while exactly
    // IV_LEN + TAG_LEN (empty ciphertext) passes the check and is rejected by authentication
    #[test]
    fn test_psk_length_boundary() {
        let key = make_key();
        let one_short = [0u8; IV_LEN + TAG_LEN - 1];
        assert!(matches!(
            psk_decrypt_with_key(&one_short, &key),
            Err(PskError::DataTooShort)
        ));
        let minimal = [0u8; IV_LEN + TAG_LEN];
        assert!(matches!(
            psk_decrypt_with_key(&minimal, &key),
            Err(PskError::Encrypt(_))
        ));
    }

    // Test: Encrypted output length equals IV_LEN + ciphertext_len + TAG_LEN
    #[test]
    fn test_psk_encrypted_output_format() {
        let key = make_key();
        let plaintext = b"test";
        let encrypted = psk_encrypt_with_key(plaintext, &key).unwrap();
        // encrypted = IV (12) + ciphertext (4, same as plaintext for GCM) + tag (16)
        assert_eq!(encrypted.len(), IV_LEN + plaintext.len() + TAG_LEN);
    }
}