// Copyright (c) Huawei Technologies Co., Ltd. 2026-2026. All rights reserved.

use openssl::bn::BigNumContext;
use openssl::derive::Deriver;
use openssl::ec::{EcGroup, EcKey, EcPoint};
use openssl::kdf::{hkdf, HkdfMode};
use openssl::md::Md;
use openssl::nid::Nid;
use openssl::pkey::PKey;
use openssl::symm::{decrypt_aead, Cipher};

/// Label bytes passed to `hkdf` together with the ECDH shared secret when deriving the
/// 32-byte AES-256-GCM key. Any encryptor talking to `ecies_decrypt` must use exactly the
/// same bytes, otherwise the derived keys differ and GCM authentication fails.
const HKDF_LABEL: &[u8] = b"KDC-PSK";

/// Generates a fresh key pair on the P-256 (`X9_62_PRIME256V1`) curve.
///
/// Returns the private key together with its public half PEM-encoded
/// (a `-----BEGIN PUBLIC KEY-----` block) so it can be handed to a peer.
///
/// # Errors
///
/// Returns the error's text if curve lookup, key generation, `PKey` wrapping or
/// PEM encoding fails, or if the PEM output is not valid UTF-8.
pub fn generate_ec_keypair() -> Result<(EcKey<openssl::pkey::Private>, String), String> {
    let curve =
        EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).map_err(|err| err.to_string())?;
    let private_key = EcKey::generate(&curve).map_err(|err| err.to_string())?;

    // Wrap a clone of the key in a PKey purely for encoding; `public_key_to_pem`
    // emits only the public half, so the private key itself is never serialized.
    let pem_bytes = PKey::from_ec_key(private_key.clone())
        .and_then(|pkey| pkey.public_key_to_pem())
        .map_err(|err| err.to_string())?;
    let public_pem = String::from_utf8(pem_bytes).map_err(|err| err.to_string())?;

    Ok((private_key, public_pem))
}

/// Decrypts an ECIES message (ECDH on `ec_key`'s curve + HKDF-SHA256 + AES-256-GCM).
///
/// Expected layout of `encrypted`:
///
/// ```text
/// | ephemeral public key (65 B) | IV (12 B) | ciphertext (n B, may be 0) | GCM tag (16 B) |
/// ```
///
/// The 65-byte ephemeral key is expected to be an uncompressed EC point (leading `0x04`).
/// The AES key is derived with `hkdf` (SHA-256) from the ECDH shared secret and
/// [`HKDF_LABEL`]. No additional authenticated data is used.
///
/// Intermediate secrets (the ECDH shared secret and the derived AES key) are zeroed
/// before this function returns, on both success and failure paths.
///
/// # Errors
///
/// Returns a descriptive string if the input is shorter than 93 bytes, the ephemeral
/// key cannot be parsed or fails validation, key agreement or HKDF fails, or GCM
/// authentication/decryption fails (wrong key, tampered IV/ciphertext/tag).
pub fn ecies_decrypt(
    encrypted: &[u8],
    ec_key: &EcKey<openssl::pkey::Private>,
) -> Result<Vec<u8>, String> {
    const EPHEMERAL_PUB_LEN: usize = 65;
    const IV_LEN: usize = 12;
    const TAG_LEN: usize = 16;
    const AES_KEY_LEN: usize = 32;

    if encrypted.len() < EPHEMERAL_PUB_LEN + IV_LEN + TAG_LEN {
        return Err("encrypted data too short".to_string());
    }
    // The length check above guarantees every split point below is in bounds.
    let (ephemeral_bytes, rest) = encrypted.split_at(EPHEMERAL_PUB_LEN);
    let (iv, rest) = rest.split_at(IV_LEN);
    let (ciphertext, tag) = rest.split_at(rest.len() - TAG_LEN);

    // Parse and validate the sender-supplied ephemeral key before it is used in ECDH,
    // so malformed or invalid points are rejected up front.
    let mut ctx = BigNumContext::new().map_err(|e| e.to_string())?;
    let group = ec_key.group();
    let point = EcPoint::from_bytes(group, ephemeral_bytes, &mut ctx)
        .map_err(|e| format!("failed to parse ephemeral public key: {}", e))?;
    let ephemeral_pub = EcKey::from_public_key(group, &point).map_err(|e| e.to_string())?;
    ephemeral_pub
        .check_key()
        .map_err(|e| format!("invalid ephemeral public key: {}", e))?;

    // ECDH: our static private key combined with the sender's ephemeral public key.
    let peer_pkey = PKey::from_ec_key(ephemeral_pub).map_err(|e| e.to_string())?;
    let private_pkey = PKey::from_ec_key(ec_key.clone()).map_err(|e| e.to_string())?;
    let mut deriver = Deriver::new(&private_pkey).map_err(|e| e.to_string())?;
    deriver.set_peer(&peer_pkey).map_err(|e| e.to_string())?;
    let mut shared_secret = deriver.derive_to_vec().map_err(|e| e.to_string())?;

    // Derive the AES-256 key. The raw shared secret is wiped as soon as HKDF has
    // consumed it, regardless of whether HKDF succeeded.
    let mut aes_key = [0u8; AES_KEY_LEN];
    let kdf_result = hkdf(
        Md::sha256(),
        &shared_secret,
        Some(&[]),
        Some(HKDF_LABEL),
        HkdfMode::ExtractAndExpand,
        None,
        &mut aes_key,
    );
    wipe_secret(&mut shared_secret);
    if let Err(e) = kdf_result {
        // The output buffer may hold partial key material on failure.
        wipe_secret(&mut aes_key);
        return Err(format!("HKDF failed: {}", e));
    }

    let decrypted = decrypt_aead(
        Cipher::aes_256_gcm(),
        &aes_key,
        Some(iv),
        &[],
        ciphertext,
        tag,
    );
    wipe_secret(&mut aes_key);
    decrypted.map_err(|e| format!("AES-GCM decrypt failed: {}", e))
}

/// Overwrites `buf` with zeros using volatile stores, so the compiler cannot drop the
/// writes as dead stores just because the buffer is about to go out of scope.
fn wipe_secret(buf: &mut [u8]) {
    for byte in buf.iter_mut() {
        // SAFETY: `byte` is a valid, properly aligned, exclusive reference into `buf`.
        unsafe { std::ptr::write_volatile(byte, 0) };
    }
    // Keep the zeroing stores from being reordered past subsequent code.
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}

#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use openssl::bn::BigNumContext;
    use openssl::derive::Deriver;
    use openssl::ec::{EcGroup, EcKey, PointConversionForm};
    use openssl::kdf::{hkdf, HkdfMode};
    use openssl::md::Md;
    use openssl::nid::Nid;
    use openssl::pkey::PKey;
    use openssl::rand::rand_bytes;
    use openssl::symm::{encrypt_aead, Cipher};

    use super::*;

    /// Test-only ECIES encryptor that mirrors `ecies_decrypt`.
    ///
    /// Produces `ephemeral_pub (65 B, uncompressed) || iv (12 B) || ciphertext || tag (16 B)`,
    /// deriving the AES-256-GCM key with the same HKDF parameters as the decryptor.
    fn ecies_encrypt_test(plaintext: &[u8], pub_pem: &str) -> Result<Vec<u8>, String> {
        let peer_pkey =
            PKey::public_key_from_pem(pub_pem.as_bytes()).map_err(|e| e.to_string())?;
        let group =
            EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).map_err(|e| e.to_string())?;
        let ephemeral = EcKey::generate(&group).map_err(|e| e.to_string())?;
        let mut ctx = BigNumContext::new().map_err(|e| e.to_string())?;
        let ephemeral_pub_bytes = ephemeral
            .public_key()
            .to_bytes(&group, PointConversionForm::UNCOMPRESSED, &mut ctx)
            .map_err(|e| e.to_string())?;

        let ephemeral_pkey = PKey::from_ec_key(ephemeral).map_err(|e| e.to_string())?;
        let mut deriver = Deriver::new(&ephemeral_pkey).map_err(|e| e.to_string())?;
        deriver.set_peer(&peer_pkey).map_err(|e| e.to_string())?;
        let shared_secret = deriver.derive_to_vec().map_err(|e| e.to_string())?;

        let mut aes_key = [0u8; 32];
        hkdf(
            Md::sha256(),
            &shared_secret,
            Some(&[]),
            Some(HKDF_LABEL),
            HkdfMode::ExtractAndExpand,
            None,
            &mut aes_key,
        )
        .map_err(|e| e.to_string())?;

        let mut iv = [0u8; 12];
        rand_bytes(&mut iv).map_err(|e| e.to_string())?;
        let mut tag = [0u8; 16];
        let ciphertext = encrypt_aead(
            Cipher::aes_256_gcm(),
            &aes_key,
            Some(&iv),
            &[],
            plaintext,
            &mut tag,
        )
        .map_err(|e| e.to_string())?;

        let mut result = Vec::with_capacity(65 + 12 + ciphertext.len() + 16);
        result.extend_from_slice(&ephemeral_pub_bytes);
        result.extend_from_slice(&iv);
        result.extend_from_slice(&ciphertext);
        result.extend_from_slice(&tag);
        Ok(result)
    }

    // Test: verify generate_ec_keypair() returns valid P-256 keypair with PEM
    #[test]
    fn test_generate_ec_keypair() {
        let (ec_key, pub_pem) = generate_ec_keypair().unwrap();
        assert!(pub_pem.starts_with("-----BEGIN PUBLIC KEY-----"));
        assert!(pub_pem.contains("-----END PUBLIC KEY-----"));
        let group = ec_key.group();
        assert_eq!(group.curve_name().unwrap(), Nid::X9_62_PRIME256V1);
    }

    // Test: verify ECIES encrypt/decrypt roundtrip produces original plaintext
    #[test]
    fn test_ecies_encrypt_decrypt_roundtrip() {
        let (ec_key, pub_pem) = generate_ec_keypair().unwrap();
        let plaintext = b"hello ECIES-KEM/DEM world!";
        let encrypted = ecies_encrypt_test(plaintext, &pub_pem).unwrap();
        let decrypted = ecies_decrypt(&encrypted, &ec_key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    // Test: verify decrypting with wrong key fails
    #[test]
    fn test_ecies_wrong_key_fails() {
        let (ec_key_a, _pub_pem_a) = generate_ec_keypair().unwrap();
        let (_ec_key_b, pub_pem_b) = generate_ec_keypair().unwrap();
        let plaintext = b"secret message";
        let encrypted = ecies_encrypt_test(plaintext, &pub_pem_b).unwrap();
        let result = ecies_decrypt(&encrypted, &ec_key_a);
        assert!(result.is_err());
    }

    // Test: verify ecies_decrypt rejects empty data
    #[test]
    fn test_ecies_decrypt_empty_data() {
        let (ec_key, _) = generate_ec_keypair().unwrap();
        let result = ecies_decrypt(&[], &ec_key);
        assert!(result.is_err());
    }

    // Test: verify ecies_decrypt rejects truncated data
    #[test]
    fn test_ecies_decrypt_truncated_data() {
        let (ec_key, _) = generate_ec_keypair().unwrap();
        let result = ecies_decrypt(&[0u8; 10], &ec_key);
        assert!(result.is_err());
    }

    // Test: verify ECIES roundtrip with 32-byte PSK payload and correct wire format
    #[test]
    fn test_ecies_32byte_psk_roundtrip() {
        let (ec_key, pub_pem) = generate_ec_keypair().unwrap();
        let psk = [0xABu8; 32];
        let encrypted = ecies_encrypt_test(&psk, &pub_pem).unwrap();
        assert_eq!(encrypted.len(), 65 + 12 + 32 + 16);
        assert_eq!(encrypted[0], 0x04);
        let decrypted = ecies_decrypt(&encrypted, &ec_key).unwrap();
        assert_eq!(decrypted, &psk[..]);
    }

    // Test: verify ecies_decrypt rejects invalid ephemeral key bytes
    #[test]
    fn test_ecies_decrypt_invalid_ephemeral_key() {
        let (ec_key, _) = generate_ec_keypair().unwrap();
        let mut data = vec![0u8; 65 + 12 + 16];
        data[0] = 0x04;
        data[1..65].fill(0xFF);
        let result = ecies_decrypt(&data, &ec_key);
        assert!(result.is_err());
    }

    // Test: verify an empty plaintext yields the minimum-length message and roundtrips
    #[test]
    fn test_ecies_decrypt_minimum_length_empty_plaintext() {
        let (ec_key, pub_pem) = generate_ec_keypair().unwrap();
        let encrypted = ecies_encrypt_test(b"", &pub_pem).unwrap();
        // Ephemeral key + IV + tag only; the ciphertext is empty.
        assert_eq!(encrypted.len(), 65 + 12 + 16);
        let decrypted = ecies_decrypt(&encrypted, &ec_key).unwrap();
        assert!(decrypted.is_empty());
    }

    // Test: verify flipping the first ciphertext byte fails GCM authentication
    #[test]
    fn test_ecies_decrypt_corrupted_ciphertext() {
        let (ec_key, pub_pem) = generate_ec_keypair().unwrap();
        let plaintext = b"some secret data";
        let mut encrypted = ecies_encrypt_test(plaintext, &pub_pem).unwrap();
        // Flip a byte in the ciphertext portion (bytes 77..len-16)
        let ciphertext_end = encrypted.len() - 16;
        assert!(ciphertext_end > 77, "ciphertext must be non-empty for this test");
        encrypted[77] ^= 0xFF;
        let result = ecies_decrypt(&encrypted, &ec_key);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("AES-GCM decrypt failed"));
    }

    // Test: verify an all-zero message of exactly the minimum length (93 bytes) is
    // rejected past the length check, at ephemeral-key parsing
    #[test]
    fn test_ecies_decrypt_all_zeros_93_bytes() {
        let (ec_key, _) = generate_ec_keypair().unwrap();
        let data = [0u8; 93];
        let result = ecies_decrypt(&data, &ec_key);
        assert!(result.is_err());
    }

    // Test: verify two generated keypairs have different public keys
    #[test]
    fn test_generate_ec_keypair_different_keys() {
        let (_, pub_pem_a) = generate_ec_keypair().unwrap();
        let (_, pub_pem_b) = generate_ec_keypair().unwrap();
        assert_ne!(pub_pem_a, pub_pem_b);
    }

    // Test: verify one byte below the minimum length hits the length check
    #[test]
    fn test_ecies_decrypt_92_bytes_too_short() {
        let (ec_key, _) = generate_ec_keypair().unwrap();
        let data = [0u8; 92];
        let result = ecies_decrypt(&data, &ec_key);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("encrypted data too short"));
    }

    // Test: verify a tampered IV fails GCM authentication
    #[test]
    fn test_ecies_decrypt_valid_point_modified_iv() {
        let (ec_key, pub_pem) = generate_ec_keypair().unwrap();
        let plaintext = b"test data for iv corruption";
        let mut encrypted = ecies_encrypt_test(plaintext, &pub_pem).unwrap();
        // Modify the IV (bytes 65..77) instead of the ephemeral key
        encrypted[70] ^= 0xFF;
        let result = ecies_decrypt(&encrypted, &ec_key);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("AES-GCM decrypt failed"));
    }

    // Test: verify a tampered GCM tag fails authentication
    #[test]
    fn test_ecies_decrypt_valid_point_modified_tag() {
        let (ec_key, pub_pem) = generate_ec_keypair().unwrap();
        let plaintext = b"test data for tag corruption";
        let mut encrypted = ecies_encrypt_test(plaintext, &pub_pem).unwrap();
        // Modify the last byte of the tag (last 16 bytes)
        let last_tag_byte = encrypted.len() - 1;
        encrypted[last_tag_byte] ^= 0xFF;
        let result = ecies_decrypt(&encrypted, &ec_key);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("AES-GCM decrypt failed"));
    }

    // Test: verify a wrong recipient key gets through ECDH/HKDF and fails at GCM
    #[test]
    fn test_ecies_decrypt_valid_point_wrong_key_full_path() {
        let (ec_key_a, _pub_pem_a) = generate_ec_keypair().unwrap();
        let (_ec_key_b, pub_pem_b) = generate_ec_keypair().unwrap();
        let plaintext = b"secret message for full path";
        // Encrypt for B, try to decrypt with A
        let encrypted = ecies_encrypt_test(plaintext, &pub_pem_b).unwrap();
        let result = ecies_decrypt(&encrypted, &ec_key_a);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("AES-GCM decrypt failed"));
    }

    // Test: verify a compressed-point prefix in the 65-byte key slot is rejected
    #[test]
    fn test_ecies_decrypt_invalid_point_format() {
        let (ec_key, _) = generate_ec_keypair().unwrap();
        let mut data = vec![0u8; 93];
        // Use 0x02 (compressed point format) instead of 0x04 (uncompressed)
        data[0] = 0x02;
        let result = ecies_decrypt(&data, &ec_key);
        assert!(result.is_err());
    }

    // Test: verify a single-bit flip mid-ciphertext fails GCM authentication
    #[test]
    fn test_ecies_decrypt_corrupted_middle_ciphertext() {
        let (ec_key, pub_pem) = generate_ec_keypair().unwrap();
        let plaintext = b"0123456789ABCDEF0123456789ABCDEF";
        let mut encrypted = ecies_encrypt_test(plaintext, &pub_pem).unwrap();
        // Corrupt a byte in the middle of the ciphertext area
        let mid = (77 + encrypted.len() - 16) / 2;
        encrypted[mid] ^= 0x01;
        let result = ecies_decrypt(&encrypted, &ec_key);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("AES-GCM decrypt failed"));
    }
}