#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
pub mod config_manager;
pub mod crypto;
pub mod http_client;
pub use common::psk;
pub mod types;
use std::ffi::CString;
use std::os::raw::c_char;
use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine as _};
use parking_lot::Mutex;
use common::constants::MAX_REQUEST_SIZE;
use crate::config_manager::{get_config_center, get_log_callback_holder, SecurePsk};
use crate::crypto::{ecies_decrypt, generate_ec_keypair};
use crate::http_client::{request_token_from_ra_agent, send_to_kdc_agent};
use crate::types::{SetLogCallbackFunc, KDC_LOG_ERROR, KDC_LOG_INFO, KDC_LOG_WARN};
use common::constants::{CommonError, PSK_LEN, TRUSTED_ENV_TOKEN_REQ, TRUSTED_REGISTER_PSK_REQ};
/// Serializes PSK registration. `KdcProxyIdentityRegister` holds it across its
/// "is a PSK already stored?" check and the call to `do_register_psk`, so concurrent
/// callers of that entry point do not both register.
static PSK_REGISTER_LOCK: Mutex<()> = Mutex::new(());
/// Returns `true` if `s` contains any Unicode control character (general category `Cc`):
/// the C0 range U+0000–U+001F, DEL (U+007F), and the C1 range U+0080–U+009F.
///
/// C1 characters such as NEL (U+0085) are encoded in UTF-8 as two bytes >= 0x80, so a
/// byte-wise `< 0x20 || == 0x7F` test lets them through; checking decoded `char`s does not.
fn has_control_chars(s: &str) -> bool {
    s.chars().any(char::is_control)
}
/// Accepts `path` only when it is non-empty, absolute, and `std::fs::canonicalize`
/// resolves it successfully (i.e. it currently exists; directories are accepted as well).
fn is_valid_path(path: &str) -> bool {
    let candidate = std::path::Path::new(path);
    !path.is_empty() && candidate.is_absolute() && std::fs::canonicalize(candidate).is_ok()
}
/// Converts a C string pointer into an owned `String` via `common::cstr::cstr_to_string`.
///
/// Any failure from the underlying conversion is reported as `CommonError::RequestError`;
/// the original error value is discarded.
fn cstr_to_string(ptr: *const c_char) -> Result<String, CommonError> {
    match common::cstr::cstr_to_string(ptr) {
        Ok(owned) => Ok(owned),
        Err(_) => Err(CommonError::RequestError),
    }
}
fn do_register_psk() -> Result<(), CommonError> {
proxy_log!(KDC_LOG_INFO, "PSK registration: generating EC keypair...");
let (ec_key, _pub_pem) = generate_ec_keypair().map_err(|e| {
proxy_log!(KDC_LOG_ERROR, &format!("failed to generate EC keypair: {}", e));
CommonError::InternalError
})?;
proxy_log!(KDC_LOG_INFO, "PSK registration: requesting attestation token from RA agent...");
let token = request_token_from_ra_agent(Some(&ec_key))?;
proxy_log!(KDC_LOG_INFO, &format!("PSK registration: sending token to kdc_agent ({} bytes)...", token.len()));
let register_json = serde_json::json!({"token": token}).to_string();
let resp = send_to_kdc_agent(TRUSTED_REGISTER_PSK_REQ, Some(®ister_json))?;
let encrypted_psk_b64 = resp.data.ok_or_else(|| {
proxy_log!(KDC_LOG_ERROR, "PSK registration: kdc_agent response has no data field");
CommonError::InternalError
})?;
proxy_log!(KDC_LOG_INFO, "PSK registration: decrypting PSK from kdc_agent response...");
let encrypted_psk = BASE64_STANDARD.decode(&encrypted_psk_b64).map_err(|e| {
proxy_log!(KDC_LOG_ERROR, &format!("PSK registration: base64 decode failed: {}", e));
CommonError::InternalError
})?;
let psk_bytes = ecies_decrypt(&encrypted_psk, &ec_key).map_err(|e| {
proxy_log!(KDC_LOG_ERROR, &format!("PSK registration: ECIES decrypt failed: {}", e));
CommonError::InternalError
})?;
if psk_bytes.len() != PSK_LEN {
proxy_log!(KDC_LOG_ERROR, &format!("PSK registration: unexpected PSK length {} (expected {})", psk_bytes.len(), PSK_LEN));
return Err(CommonError::InternalError);
}
let mut psk = [0u8; PSK_LEN];
psk.copy_from_slice(&psk_bytes);
let config = get_config_center();
let mut guard = config.lock();
guard.psk = Some(SecurePsk(psk));
drop(guard);
proxy_log!(KDC_LOG_INFO, "PSK registered successfully");
Ok(())
}
#[no_mangle]
pub extern "C" fn SetRAAgentTokenUrl(url: *const c_char) -> u32 {
match cstr_to_string(url) {
Ok(s) => {
if s.is_empty() || has_control_chars(&s) {
proxy_log!(KDC_LOG_ERROR, "SetRAAgentTokenUrl: URL is empty or contains control characters");
return CommonError::RequestError as u32;
}
let config = get_config_center();
let mut guard = config.lock();
guard.ra_agent_url = s;
0
}
Err(e) => {
proxy_log!(KDC_LOG_ERROR, &format!("SetRAAgentTokenUrl: failed to parse URL parameter, error code {}", e));
e as u32
}
}
}
#[no_mangle]
pub extern "C" fn SetKdcAgentUrl(url: *const c_char) -> u32 {
match cstr_to_string(url) {
Ok(s) => {
if s.is_empty() || has_control_chars(&s) {
proxy_log!(KDC_LOG_ERROR, "SetKdcAgentUrl: URL is empty or contains control characters");
return CommonError::RequestError as u32;
}
let config = get_config_center();
let mut guard = config.lock();
guard.kdc_agent_url = s;
0
}
Err(e) => {
proxy_log!(KDC_LOG_ERROR, &format!("SetKdcAgentUrl: failed to parse URL parameter, error code {}", e));
e as u32
}
}
}
#[no_mangle]
pub extern "C" fn SetCaFile(ca: *const c_char) -> u32 {
match cstr_to_string(ca) {
Ok(s) => {
if !is_valid_path(&s) {
proxy_log!(KDC_LOG_ERROR, "SetCaFile: invalid CA file path");
return CommonError::RequestError as u32;
}
let config = get_config_center();
let mut guard = config.lock();
guard.capath = s;
0
}
Err(e) => {
proxy_log!(KDC_LOG_ERROR, &format!("SetCaFile: failed to parse CA parameter, error code {}", e));
e as u32
}
}
}
#[no_mangle]
pub extern "C" fn SetTlsCertAndKeyFile(
tls_cert: *const c_char,
tls_key: *const c_char,
key_pwd: *const c_char,
) -> u32 {
let cert = match cstr_to_string(tls_cert) {
Ok(s) => s,
Err(e) => {
proxy_log!(KDC_LOG_ERROR, &format!("SetTlsCertAndKeyFile: failed to parse certificate parameter, error code {}", e));
return e as u32;
}
};
if !is_valid_path(&cert) {
proxy_log!(KDC_LOG_ERROR, "SetTlsCertAndKeyFile: invalid certificate file path");
return CommonError::RequestError as u32;
}
let key = match cstr_to_string(tls_key) {
Ok(s) => s,
Err(e) => {
proxy_log!(KDC_LOG_ERROR, &format!("SetTlsCertAndKeyFile: failed to parse key parameter, error code {}", e));
return e as u32;
}
};
if !is_valid_path(&key) {
proxy_log!(KDC_LOG_ERROR, "SetTlsCertAndKeyFile: invalid key file path");
return CommonError::RequestError as u32;
}
let pwd = if key_pwd.is_null() {
String::new()
} else {
match cstr_to_string(key_pwd) {
Ok(s) => s,
Err(e) => {
proxy_log!(KDC_LOG_ERROR, &format!("SetTlsCertAndKeyFile: failed to parse key password parameter, error code {}", e));
return e as u32;
}
}
};
let config = get_config_center();
let mut guard = config.lock();
guard.certpath = cert;
guard.privatepath = key;
guard.key_pwd = zeroize::Zeroizing::new(pwd);
0
}
#[no_mangle]
pub extern "C" fn SetCrlFile(crl: *const c_char) -> u32 {
match cstr_to_string(crl) {
Ok(s) => {
if !is_valid_path(&s) {
proxy_log!(KDC_LOG_ERROR, "SetCrlFile: invalid CRL file path");
return CommonError::RequestError as u32;
}
let config = get_config_center();
let mut guard = config.lock();
guard.crlpath = s;
0
}
Err(e) => {
proxy_log!(KDC_LOG_ERROR, &format!("SetCrlFile: failed to parse CRL parameter, error code {}", e));
e as u32
}
}
}
/// FFI entry point: installs the log callback used by `proxy_log!`.
///
/// A null (`None`) callback is rejected with `CommonError::RequestError`, and any
/// previously installed callback is left in place. Returns 0 on success.
#[no_mangle]
pub extern "C" fn SetLogCallback(cb: Option<SetLogCallbackFunc>) -> u32 {
    if let Some(callback) = cb {
        *get_log_callback_holder().lock() = Some(callback);
        return 0;
    }
    proxy_log!(KDC_LOG_ERROR, "SetLogCallback: callback function pointer is null");
    CommonError::RequestError as u32
}
fn execute_trusted_request(
request_type: u32,
input_str: &str,
) -> Result<(u32, String), CommonError> {
if request_type == TRUSTED_ENV_TOKEN_REQ {
proxy_log!(KDC_LOG_INFO, "TRUSTED_ENV_TOKEN_REQ: requesting token from RA agent...");
let token = request_token_from_ra_agent(None)?;
let json_token = serde_json::json!({"token": token}).to_string();
return Ok((0, json_token));
}
let config = get_config_center().lock();
if config.psk.is_none() {
drop(config);
proxy_log!(KDC_LOG_ERROR, "PSK not registered, cannot process request");
return Err(CommonError::PskMismatchError);
}
drop(config);
let resp = match send_to_kdc_agent(request_type, Some(input_str)) {
Ok(resp) => resp,
Err(CommonError::PskMismatchError) => {
proxy_log!(KDC_LOG_WARN, "PSK mismatch on request, re-registering");
do_register_psk()?;
send_to_kdc_agent(request_type, Some(input_str))?
}
Err(e) => {
proxy_log!(KDC_LOG_ERROR, &format!("request type=0x{:04X} failed: {}", request_type, e));
return Err(e);
}
};
Ok((resp.ret_code, resp.data.unwrap_or_else(|| "{}".to_string())))
}
/// Hands a trusted-request result back to the C caller through `output` / `output_len`.
///
/// On `Ok((ret_code, data))`:
/// - `data` is copied into a newly allocated NUL-terminated buffer.
/// - Its byte length (excluding the terminator) is written to `*output_len`.
/// - `ret_code` is returned.
///
/// On `Err(e)` the buffer is `"{}"` (length 2) and `e` is returned. Data that contains an
/// interior NUL, or is too long to describe with a `u32`, is also replaced by `"{}"`, and
/// `CommonError::InternalError` is returned.
///
/// Every buffer stored in `*output` comes from `CString::into_raw` and must be released
/// with `KdcFreePtr`.
///
/// Precondition (not checked here): `output` and `output_len` are non-null and valid for
/// writes. `KdcTrustedRequest` checks both for null before calling.
fn write_cstr_output(
    output: *mut *mut c_char,
    output_len: *mut u32,
    result: Result<(u32, String), CommonError>,
) -> u32 {
    // Fallback payload. `CString::new` appends exactly one NUL, so the allocation matches
    // what `CString::from_raw` (used by `KdcFreePtr`) reconstructs via strlen. Building it
    // from a vec that already ends in NUL would add a second terminator and make the free
    // use a layout that differs from the allocation.
    const EMPTY_JSON: &str = "{}";
    let empty_json = || CString::new(EMPTY_JSON).expect("\"{}\" has no interior NUL byte");

    let (ret_code, payload) = match result {
        Ok((ret_code, data)) => match CString::new(data) {
            Ok(c_str) => (ret_code, c_str),
            Err(_) => {
                proxy_log!(KDC_LOG_ERROR, "write_cstr_output: response data contains embedded NUL byte, returning empty JSON");
                (CommonError::InternalError as u32, empty_json())
            }
        },
        Err(e) => (e as u32, empty_json()),
    };

    // `as u32` would silently report a truncated length for payloads of 4 GiB or more.
    let (ret_code, payload, len) = match u32::try_from(payload.as_bytes().len()) {
        Ok(len) => (ret_code, payload, len),
        Err(_) => {
            proxy_log!(KDC_LOG_ERROR, "write_cstr_output: response data too large for u32 length, returning empty JSON");
            (CommonError::InternalError as u32, empty_json(), EMPTY_JSON.len() as u32)
        }
    };

    // SAFETY: per the precondition above, `output` and `output_len` are non-null and
    // writable. Ownership of the buffer passes to the caller until `KdcFreePtr`.
    unsafe {
        *output = payload.into_raw();
        *output_len = len;
    }
    ret_code
}
/// FFI entry point: registers a PSK with the kdc_agent unless one is already stored.
///
/// The check and the registration both run under `PSK_REGISTER_LOCK`, so a concurrent
/// caller waits and then sees the PSK stored by the first one. Returns 0 if a PSK was
/// already present or registration succeeded; otherwise returns the numeric `CommonError`
/// code from `do_register_psk`.
#[no_mangle]
pub extern "C" fn KdcProxyIdentityRegister() -> u32 {
    let _registration_guard = PSK_REGISTER_LOCK.lock();
    // The config guard is a temporary here and is released at the end of this statement.
    let already_registered = get_config_center().lock().psk.is_some();
    if already_registered {
        return 0;
    }
    do_register_psk().map_or_else(|err| err as u32, |()| 0)
}
/// FFI entry point: runs a trusted request and returns its result to the C caller.
///
/// Arguments:
/// - `input` is converted with `cstr_to_string`. Its byte length must equal `input_len`
///   and must not exceed `MAX_REQUEST_SIZE`.
/// - `output` and `output_len` must be non-null and writable.
///
/// Once both output pointers have passed the null check, `*output` is always either null
/// (input validation failed, and `*output_len` is 0) or a buffer that must be released with
/// `KdcFreePtr`.
///
/// Returns the request's return code on success, otherwise a numeric `CommonError` code.
#[no_mangle]
pub extern "C" fn KdcTrustedRequest(
    request_type: u32,
    input: *const c_char,
    input_len: u32,
    output: *mut *mut c_char,
    output_len: *mut u32,
) -> u32 {
    if output.is_null() || output_len.is_null() {
        proxy_log!(KDC_LOG_ERROR, "KdcTrustedRequest: output or output_len pointer is null");
        return CommonError::RequestError as u32;
    }
    // SAFETY: both pointers were checked non-null above, and the caller guarantees they are
    // writable. Start from a defined state so that the early validation failures below leave
    // a null pointer (safe to pass to KdcFreePtr) instead of whatever the caller had there.
    unsafe {
        *output = std::ptr::null_mut();
        *output_len = 0;
    }
    let input_str = match cstr_to_string(input) {
        Ok(s) => s,
        Err(e) => {
            proxy_log!(KDC_LOG_ERROR, &format!("KdcTrustedRequest: failed to parse input parameter, error code {}", e));
            return e as u32;
        }
    };
    if input_str.len() != input_len as usize {
        proxy_log!(KDC_LOG_ERROR, &format!("input length mismatch: actual {} != declared {}", input_str.len(), input_len));
        return CommonError::RequestError as u32;
    }
    if input_str.len() > MAX_REQUEST_SIZE {
        proxy_log!(KDC_LOG_ERROR, &format!("request rejected: message too large ({} bytes). please check your input.", input_str.len()));
        return CommonError::RequestError as u32;
    }
    let result = execute_trusted_request(request_type, &input_str);
    write_cstr_output(output, output_len, result)
}
/// FFI entry point: releases a buffer this library handed to the caller (for example the
/// `output` of `KdcTrustedRequest`). A null pointer is ignored.
///
/// `ptr` must be null or a pointer produced by `CString::into_raw` in this library that
/// has not been freed yet.
#[no_mangle]
pub extern "C" fn KdcFreePtr(ptr: *mut c_char) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` came from `CString::into_raw` and is
    // reclaimed exactly once here.
    let reclaimed = unsafe { CString::from_raw(ptr) };
    drop(reclaimed);
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
use super::*;
use crate::config_manager::SecurePsk;
use std::ffi::CStr;
use zeroize::Zeroizing;
extern "C" fn test_log_cb(
_severity: i32,
_msg: *const c_char,
_file: *const c_char,
_line: i32,
) {
}
fn reset_config() {
let config = get_config_center();
let mut guard = config.lock();
guard.ra_agent_url = String::new();
guard.kdc_agent_url = String::new();
guard.capath = String::new();
guard.certpath = String::new();
guard.privatepath = String::new();
guard.crlpath = String::new();
guard.key_pwd = Zeroizing::new(String::new());
guard.psk = None;
}
#[test]
fn test_has_control_chars_normal() {
assert!(!has_control_chars("hello world"));
assert!(!has_control_chars("https://example.com/path"));
}
#[test]
fn test_has_control_chars_with_ctrl() {
assert!(has_control_chars("hello\x01world"));
assert!(has_control_chars("hello\x00world"));
assert!(has_control_chars("\x1F"));
assert!(has_control_chars("hello\x7F"));
}
#[test]
fn test_has_control_chars_empty() {
assert!(!has_control_chars(""));
}
#[test]
fn test_is_valid_path_empty() {
assert!(!is_valid_path(""));
}
#[test]
fn test_is_valid_path_relative() {
assert!(!is_valid_path("relative/path/file.pem"));
}
#[test]
fn test_is_valid_path_nonexistent_absolute() {
assert!(!is_valid_path("/nonexistent/path/file.pem"));
}
#[test]
fn test_is_valid_path_existing_file() {
let temp_path = std::env::temp_dir().join("kdc_proxy_test_valid_path.pem");
std::fs::write(&temp_path, b"test").unwrap();
assert!(is_valid_path(temp_path.to_str().unwrap()));
let _ = std::fs::remove_file(&temp_path);
}
#[test]
fn test_cstr_to_string_null() {
let result = cstr_to_string(std::ptr::null());
assert!(result.is_err());
assert_eq!(result.unwrap_err(), CommonError::RequestError);
}
#[test]
fn test_cstr_to_string_valid() {
let s = CString::new("hello world").unwrap();
let result = cstr_to_string(s.as_ptr());
assert_eq!(result.unwrap(), "hello world");
}
#[test]
fn test_cstr_to_string_invalid_utf8() {
let bytes: &[u8] = &[0xFF, 0xFE, 0x00];
let ptr = bytes.as_ptr() as *const c_char;
let result = cstr_to_string(ptr);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), CommonError::RequestError);
}
#[test]
fn test_write_cstr_output_success() {
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let mut output_len: u32 = 0;
let result = write_cstr_output(
&mut output_ptr,
&mut output_len,
Ok((0u32, "hello".to_string())),
);
assert_eq!(result, 0);
assert_eq!(output_len, 5);
let s = unsafe { CStr::from_ptr(output_ptr) }.to_str().unwrap();
assert_eq!(s, "hello");
KdcFreePtr(output_ptr);
}
#[test]
fn test_write_cstr_output_error() {
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let mut output_len: u32 = 0;
let result = write_cstr_output(
&mut output_ptr,
&mut output_len,
Err(CommonError::InternalError),
);
assert_eq!(result, CommonError::InternalError as u32);
assert_eq!(output_len, 2);
let s = unsafe { CStr::from_ptr(output_ptr) }.to_str().unwrap();
assert_eq!(s, "{}");
KdcFreePtr(output_ptr);
}
#[test]
fn test_write_cstr_output_null_byte_in_data() {
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let mut output_len: u32 = 0;
let result = write_cstr_output(
&mut output_ptr,
&mut output_len,
Ok((0u32, "hello\0world".to_string())),
);
assert_eq!(result, CommonError::InternalError as u32);
assert_eq!(output_len, 2);
let s = unsafe { CStr::from_ptr(output_ptr) }.to_str().unwrap();
assert_eq!(s, "{}");
KdcFreePtr(output_ptr);
}
#[test]
fn test_kdc_trusted_request_null_output() {
let result = KdcTrustedRequest(
0x0010,
std::ptr::null(),
0,
std::ptr::null_mut(),
std::ptr::null_mut(),
);
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_kdc_trusted_request_output_nonnull_len_null() {
let input = CString::new("test").unwrap();
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let result = KdcTrustedRequest(
0x0010,
input.as_ptr(),
4,
&mut output_ptr,
std::ptr::null_mut(),
);
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_kdc_trusted_request_output_null_len_nonnull() {
let input = CString::new("test").unwrap();
let mut output_len: u32 = 0;
let result = KdcTrustedRequest(
0x0010,
input.as_ptr(),
4,
std::ptr::null_mut(),
&mut output_len,
);
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_kdc_trusted_request_null_input() {
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let mut output_len: u32 = 0;
let result = KdcTrustedRequest(
0x0010,
std::ptr::null(),
0,
&mut output_ptr,
&mut output_len,
);
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_kdc_trusted_request_input_length_mismatch() {
reset_config();
let input = CString::new("test data").unwrap();
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let mut output_len: u32 = 0;
let result = KdcTrustedRequest(
0x0020,
input.as_ptr(),
0,
&mut output_ptr,
&mut output_len,
);
assert_eq!(result, CommonError::RequestError as u32);
reset_config();
}
#[test]
fn test_kdc_trusted_request_input_too_large() {
reset_config();
let large_input = CString::new("x".repeat(20481)).unwrap();
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let mut output_len: u32 = 0;
let result = KdcTrustedRequest(
0x0020,
large_input.as_ptr(),
20481,
&mut output_ptr,
&mut output_len,
);
assert_eq!(result, CommonError::RequestError as u32);
reset_config();
}
#[test]
fn test_kdc_trusted_request_token_req_no_url() {
reset_config();
let input = CString::new("").unwrap();
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let mut output_len: u32 = 0;
let result = KdcTrustedRequest(
TRUSTED_ENV_TOKEN_REQ,
input.as_ptr(),
0,
&mut output_ptr,
&mut output_len,
);
assert_ne!(result, 0);
assert_eq!(output_len, 2);
let s = unsafe { CStr::from_ptr(output_ptr) }.to_str().unwrap();
assert_eq!(s, "{}");
KdcFreePtr(output_ptr);
reset_config();
}
#[test]
fn test_kdc_trusted_request_do_register_psk_fails() {
reset_config();
let input = CString::new("test data").unwrap();
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let mut output_len: u32 = 0;
let result = KdcTrustedRequest(
0x0020,
input.as_ptr(),
9,
&mut output_ptr,
&mut output_len,
);
assert_ne!(result, 0);
assert_eq!(output_len, 2);
let s = unsafe { CStr::from_ptr(output_ptr) }.to_str().unwrap();
assert_eq!(s, "{}");
KdcFreePtr(output_ptr);
reset_config();
}
#[test]
fn test_kdc_trusted_request_with_psk_no_url() {
reset_config();
{
let config = get_config_center();
let mut guard = config.lock();
guard.psk = Some(SecurePsk([0u8; PSK_LEN]));
}
let input = CString::new("test data").unwrap();
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let mut output_len: u32 = 0;
let result = KdcTrustedRequest(
0x0020,
input.as_ptr(),
9,
&mut output_ptr,
&mut output_len,
);
assert_eq!(result, CommonError::ConfigError as u32);
assert_eq!(output_len, 2);
let s = unsafe { CStr::from_ptr(output_ptr) }.to_str().unwrap();
assert_eq!(s, "{}");
KdcFreePtr(output_ptr);
reset_config();
}
#[test]
fn test_kdc_free_ptr_null() {
KdcFreePtr(std::ptr::null_mut());
}
#[test]
fn test_kdc_free_ptr_valid() {
let c_str = CString::new("test data").unwrap();
let ptr = c_str.into_raw();
KdcFreePtr(ptr);
}
#[test]
fn test_set_ra_agent_url() {
reset_config();
let url = CString::new("https://ra.example.com/token").unwrap();
let result = SetRAAgentTokenUrl(url.as_ptr());
assert_eq!(result, 0);
let config = get_config_center();
let guard = config.lock();
assert_eq!(guard.ra_agent_url, "https://ra.example.com/token");
}
#[test]
fn test_set_ra_agent_url_null() {
assert_eq!(
SetRAAgentTokenUrl(std::ptr::null()),
CommonError::RequestError as u32
);
}
#[test]
fn test_set_ra_agent_url_empty() {
reset_config();
let url = CString::new("").unwrap();
let result = SetRAAgentTokenUrl(url.as_ptr());
assert_ne!(result, 0);
}
#[test]
fn test_set_ra_agent_url_with_control_chars() {
reset_config();
let url = CString::new("https://ra.example.com\x01/token").unwrap();
let result = SetRAAgentTokenUrl(url.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_set_kdc_agent_url() {
reset_config();
let url = CString::new("https://kdc.example.com/api").unwrap();
let result = SetKdcAgentUrl(url.as_ptr());
assert_eq!(result, 0);
let config = get_config_center();
let guard = config.lock();
assert_eq!(guard.kdc_agent_url, "https://kdc.example.com/api");
}
#[test]
fn test_set_kdc_agent_url_null() {
assert_eq!(
SetKdcAgentUrl(std::ptr::null()),
CommonError::RequestError as u32
);
}
#[test]
fn test_set_kdc_agent_url_with_control_chars() {
reset_config();
let url = CString::new("https://kdc.example.com\x01/api").unwrap();
let result = SetKdcAgentUrl(url.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_set_ca_file() {
reset_config();
let temp_path = std::env::temp_dir().join("kdc_proxy_test_ca.pem");
std::fs::write(&temp_path, b"test ca data").unwrap();
let ca = CString::new(temp_path.to_str().unwrap()).unwrap();
let result = SetCaFile(ca.as_ptr());
assert_eq!(result, 0);
{
let config = get_config_center();
let guard = config.lock();
assert_eq!(guard.capath, temp_path.to_str().unwrap());
}
let _ = std::fs::remove_file(&temp_path);
reset_config();
}
#[test]
fn test_set_ca_file_null() {
assert_eq!(
SetCaFile(std::ptr::null()),
CommonError::RequestError as u32
);
}
#[test]
fn test_set_ca_file_relative_path() {
reset_config();
let ca = CString::new("relative/path/ca.pem").unwrap();
let result = SetCaFile(ca.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_set_ca_file_nonexistent() {
reset_config();
let ca = CString::new("/nonexistent/path/ca.pem").unwrap();
let result = SetCaFile(ca.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_set_ca_file_empty() {
reset_config();
let ca = CString::new("").unwrap();
let result = SetCaFile(ca.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_set_crl_file() {
reset_config();
let temp_path = std::env::temp_dir().join("kdc_proxy_test_crl.pem");
std::fs::write(&temp_path, b"test crl data").unwrap();
let crl = CString::new(temp_path.to_str().unwrap()).unwrap();
let result = SetCrlFile(crl.as_ptr());
assert_eq!(result, 0);
{
let config = get_config_center();
let guard = config.lock();
assert_eq!(guard.crlpath, temp_path.to_str().unwrap());
}
let _ = std::fs::remove_file(&temp_path);
reset_config();
}
#[test]
fn test_set_crl_file_null() {
assert_eq!(
SetCrlFile(std::ptr::null()),
CommonError::RequestError as u32
);
}
#[test]
fn test_set_crl_file_relative_path() {
reset_config();
let crl = CString::new("relative/crl.pem").unwrap();
let result = SetCrlFile(crl.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_set_crl_file_empty_rejected() {
reset_config();
let crl = CString::new("").unwrap();
let result = SetCrlFile(crl.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_set_tls_cert_and_key_file() {
reset_config();
let cert_path = std::env::temp_dir().join("kdc_proxy_test_cert.pem");
let key_path = std::env::temp_dir().join("kdc_proxy_test_key.pem");
std::fs::write(&cert_path, b"cert data").unwrap();
std::fs::write(&key_path, b"key data").unwrap();
let cert = CString::new(cert_path.to_str().unwrap()).unwrap();
let key = CString::new(key_path.to_str().unwrap()).unwrap();
let pwd = CString::new("password123").unwrap();
let result = SetTlsCertAndKeyFile(cert.as_ptr(), key.as_ptr(), pwd.as_ptr());
assert_eq!(result, 0);
{
let config = get_config_center();
let guard = config.lock();
assert_eq!(guard.certpath, cert_path.to_str().unwrap());
assert_eq!(guard.privatepath, key_path.to_str().unwrap());
assert_eq!(&*guard.key_pwd, "password123");
}
let _ = std::fs::remove_file(&cert_path);
let _ = std::fs::remove_file(&key_path);
reset_config();
}
#[test]
fn test_set_tls_cert_null_cert() {
let key = CString::new("/etc/ssl/key.pem").unwrap();
let pwd = CString::new("password").unwrap();
let result = SetTlsCertAndKeyFile(std::ptr::null(), key.as_ptr(), pwd.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_set_tls_cert_null_key() {
let cert = CString::new("/etc/ssl/cert.pem").unwrap();
let pwd = CString::new("password").unwrap();
let result = SetTlsCertAndKeyFile(cert.as_ptr(), std::ptr::null(), pwd.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_set_tls_cert_null_pwd() {
let cert = CString::new("/etc/ssl/cert.pem").unwrap();
let key = CString::new("/etc/ssl/key.pem").unwrap();
let result = SetTlsCertAndKeyFile(cert.as_ptr(), key.as_ptr(), std::ptr::null());
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_set_log_callback() {
let result = SetLogCallback(Some(test_log_cb));
assert_eq!(result, 0);
let holder = get_log_callback_holder();
let guard = holder.lock();
assert!(guard.is_some());
}
#[test]
fn test_set_log_callback_null() {
*get_log_callback_holder().lock() = Some(test_log_cb);
let result = SetLogCallback(None);
assert_eq!(result, CommonError::RequestError as u32);
{
let holder = get_log_callback_holder();
let guard = holder.lock();
assert!(guard.is_some());
}
*get_log_callback_holder().lock() = None;
}
#[test]
fn test_kdc_proxy_identity_register_already_has_psk() {
reset_config();
{
let config = get_config_center();
let mut guard = config.lock();
guard.psk = Some(SecurePsk([42u8; PSK_LEN]));
}
let result = KdcProxyIdentityRegister();
assert_eq!(result, 0, "should return 0 when PSK already exists");
reset_config();
}
#[test]
fn test_kdc_proxy_identity_register_no_url_fails() {
reset_config();
let result = KdcProxyIdentityRegister();
assert_ne!(result, 0, "should fail without RA agent URL configured");
reset_config();
}
#[test]
fn test_set_null_pointer_returns_error() {
assert_eq!(
SetRAAgentTokenUrl(std::ptr::null()),
CommonError::RequestError as u32
);
assert_eq!(
SetKdcAgentUrl(std::ptr::null()),
CommonError::RequestError as u32
);
assert_eq!(
SetCaFile(std::ptr::null()),
CommonError::RequestError as u32
);
assert_eq!(
SetCrlFile(std::ptr::null()),
CommonError::RequestError as u32
);
}
#[test]
fn test_error_codes_are_distinct() {
let codes = [
CommonError::Ok as u32,
CommonError::RequestError as u32,
CommonError::PskMismatchError as u32,
CommonError::NetworkError as u32,
CommonError::ConfigError as u32,
CommonError::InternalError as u32,
];
for i in 0..codes.len() {
for j in (i + 1)..codes.len() {
assert_ne!(codes[i], codes[j], "error codes must be distinct");
}
}
}
#[test]
fn test_error_codes_are_nonzero_except_ok() {
assert_eq!(CommonError::Ok as u32, 0);
assert_ne!(CommonError::RequestError as u32, 0);
assert_ne!(CommonError::NetworkError as u32, 0);
assert_ne!(CommonError::ConfigError as u32, 0);
assert_ne!(CommonError::InternalError as u32, 0);
}
#[test]
fn test_set_kdc_agent_url_empty() {
reset_config();
let url = CString::new("").unwrap();
let result = SetKdcAgentUrl(url.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
}
#[test]
fn test_set_tls_cert_and_key_file_invalid_key_path() {
reset_config();
let cert_path = std::env::temp_dir().join("kdc_proxy_test_cert2.pem");
std::fs::write(&cert_path, b"cert").unwrap();
let cert = CString::new(cert_path.to_str().unwrap()).unwrap();
let key = CString::new("relative_key.pem").unwrap();
let pwd = CString::new("pw").unwrap();
let result = SetTlsCertAndKeyFile(cert.as_ptr(), key.as_ptr(), pwd.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
let _ = std::fs::remove_file(&cert_path);
}
#[test]
fn test_set_tls_cert_null_pwd_valid_paths() {
reset_config();
let cert_path = std::env::temp_dir().join("kdc_proxy_test_cert_npwd.pem");
let key_path = std::env::temp_dir().join("kdc_proxy_test_key_npwd.pem");
std::fs::write(&cert_path, b"cert data").unwrap();
std::fs::write(&key_path, b"key data").unwrap();
let cert = CString::new(cert_path.to_str().unwrap()).unwrap();
let key = CString::new(key_path.to_str().unwrap()).unwrap();
let result = SetTlsCertAndKeyFile(cert.as_ptr(), key.as_ptr(), std::ptr::null());
assert_eq!(result, 0);
{
let config = get_config_center();
let guard = config.lock();
assert_eq!(&*guard.key_pwd, "");
}
let _ = std::fs::remove_file(&cert_path);
let _ = std::fs::remove_file(&key_path);
reset_config();
}
#[test]
fn test_write_cstr_output_empty_string() {
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let mut output_len: u32 = 99;
let result = write_cstr_output(&mut output_ptr, &mut output_len, Ok((0u32, String::new())));
assert_eq!(result, 0);
assert_eq!(output_len, 0);
let s = unsafe { CStr::from_ptr(output_ptr) }.to_str().unwrap();
assert_eq!(s, "");
KdcFreePtr(output_ptr);
}
#[test]
fn test_write_cstr_output_nonzero_ret_code() {
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let mut output_len: u32 = 0;
let result = write_cstr_output(
&mut output_ptr,
&mut output_len,
Ok((42u32, "data".to_string())),
);
assert_eq!(result, 42);
assert_eq!(output_len, 4);
let s = unsafe { CStr::from_ptr(output_ptr) }.to_str().unwrap();
assert_eq!(s, "data");
KdcFreePtr(output_ptr);
}
#[test]
fn test_set_kdc_agent_url_empty_string() {
reset_config();
let url = CString::new("").unwrap();
let result = SetKdcAgentUrl(url.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
reset_config();
}
#[test]
fn test_set_tls_cert_and_key_file_nonexistent_cert() {
reset_config();
let cert = CString::new("/nonexistent/cert.pem").unwrap();
let key = CString::new("/nonexistent/key.pem").unwrap();
let pwd = CString::new("pw").unwrap();
let result = SetTlsCertAndKeyFile(cert.as_ptr(), key.as_ptr(), pwd.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
reset_config();
}
#[test]
fn test_set_tls_cert_valid_key_nonexistent() {
reset_config();
let cert_path = std::env::temp_dir().join("kdc_proxy_test_cert_vkn.pem");
std::fs::write(&cert_path, b"cert").unwrap();
let cert = CString::new(cert_path.to_str().unwrap()).unwrap();
let key = CString::new("/nonexistent/key.pem").unwrap();
let pwd = CString::new("pw").unwrap();
let result = SetTlsCertAndKeyFile(cert.as_ptr(), key.as_ptr(), pwd.as_ptr());
assert_eq!(result, CommonError::RequestError as u32);
let _ = std::fs::remove_file(&cert_path);
reset_config();
}
#[test]
fn test_kdc_trusted_request_token_req_output_valid() {
reset_config();
let input = CString::new("").unwrap();
let mut output_ptr: *mut c_char = std::ptr::null_mut();
let mut output_len: u32 = 0;
let result = KdcTrustedRequest(
TRUSTED_ENV_TOKEN_REQ,
input.as_ptr(),
0,
&mut output_ptr,
&mut output_len,
);
assert_ne!(result, 0);
assert_eq!(output_len, 2);
KdcFreePtr(output_ptr);
reset_config();
}
#[test]
fn test_psk_mismatch_error_code() {
assert_ne!(CommonError::PskMismatchError as u32, CommonError::Ok as u32);
assert_ne!(
CommonError::PskMismatchError as u32,
CommonError::RequestError as u32
);
assert_eq!(CommonError::PskMismatchError as u32, 2);
}
#[test]
fn test_do_register_psk_keypair_failure() {
reset_config();
fn mock_generate_ec_keypair(
) -> Result<(openssl::ec::EcKey<openssl::pkey::Private>, String), String> {
Err("mock keypair failure".to_string())
}
let _m = mockrs::mock!(crate::crypto::generate_ec_keypair, mock_generate_ec_keypair);
let result = do_register_psk();
assert!(result.is_err());
assert_eq!(result.unwrap_err(), CommonError::InternalError);
reset_config();
}
#[test]
fn test_do_register_psk_base64_failure() {
reset_config();
let (_ec_key, _pub_pem) = crate::crypto::generate_ec_keypair().unwrap();
fn mock_generate_ec_keypair(
) -> Result<(openssl::ec::EcKey<openssl::pkey::Private>, String), String> {
let group = openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1)
.map_err(|e| e.to_string())?;
let ec_key = openssl::ec::EcKey::generate(&group).map_err(|e| e.to_string())?;
let pub_pkey =
openssl::pkey::PKey::from_ec_key(ec_key.clone()).map_err(|e| e.to_string())?;
let pub_pem = pub_pkey.public_key_to_pem().map_err(|e| e.to_string())?;
let pub_pem_str = String::from_utf8(pub_pem).map_err(|e| e.to_string())?;
Ok((ec_key, pub_pem_str))
}
fn mock_request_token_from_ra_agent(
_ec_key: Option<&openssl::ec::EcKey<openssl::pkey::Private>>,
) -> Result<String, CommonError> {
Ok("mock_token".to_string())
}
fn mock_send_to_kdc_agent(
_msg_type: u32,
_data: Option<&str>,
) -> Result<common::http_types::KdcHttpResponse, CommonError> {
Ok(common::http_types::KdcHttpResponse {
ret_code: 0,
ret_msg: String::new(),
data: Some("!!!not-valid-base64!!!".to_string()),
})
}
let _m1 = mockrs::mock!(crate::crypto::generate_ec_keypair, mock_generate_ec_keypair);
let _m2 = mockrs::mock!(
crate::http_client::request_token_from_ra_agent,
mock_request_token_from_ra_agent
);
let _m3 = mockrs::mock!(
crate::http_client::send_to_kdc_agent,
mock_send_to_kdc_agent
);
let result = do_register_psk();
assert!(result.is_err());
assert_eq!(result.unwrap_err(), CommonError::InternalError);
reset_config();
}
#[test]
fn test_do_register_psk_ecies_decrypt_failure() {
    // Scenario: kdc_agent returns well-formed base64 whose decoded bytes are
    // not a ciphertext produced for the generated key; do_register_psk is
    // expected to report that as InternalError.
    reset_config();

    /// Replacement for `generate_ec_keypair`: a fresh P-256 key plus the
    /// PEM encoding of its public half, built directly with openssl.
    fn fake_keypair() -> Result<(openssl::ec::EcKey<openssl::pkey::Private>, String), String> {
        use openssl::ec::{EcGroup, EcKey};
        use openssl::nid::Nid;
        use openssl::pkey::PKey;

        let curve =
            EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).map_err(|err| err.to_string())?;
        let private_key = EcKey::generate(&curve).map_err(|err| err.to_string())?;
        let pem_bytes = PKey::from_ec_key(private_key.clone())
            .and_then(|pkey| pkey.public_key_to_pem())
            .map_err(|err| err.to_string())?;
        String::from_utf8(pem_bytes)
            .map(|pem| (private_key, pem))
            .map_err(|err| err.to_string())
    }

    /// Replacement for `request_token_from_ra_agent`: always yields a fixed token.
    fn fake_ra_token(
        _key: Option<&openssl::ec::EcKey<openssl::pkey::Private>>,
    ) -> Result<String, CommonError> {
        Ok(String::from("mock_token"))
    }

    /// Replacement for `send_to_kdc_agent`: a success reply carrying the
    /// base64 encoding of 100 bytes of 0xAA as its payload.
    fn fake_kdc_reply(
        _msg_type: u32,
        _data: Option<&str>,
    ) -> Result<common::http_types::KdcHttpResponse, CommonError> {
        let bogus_ciphertext = [0xAAu8; 100];
        Ok(common::http_types::KdcHttpResponse {
            ret_code: 0,
            ret_msg: String::new(),
            data: Some(BASE64_STANDARD.encode(bogus_ciphertext)),
        })
    }

    // The guards must stay bound (not `_`) so the mocks remain installed.
    let _keypair_mock = mockrs::mock!(crate::crypto::generate_ec_keypair, fake_keypair);
    let _token_mock = mockrs::mock!(
        crate::http_client::request_token_from_ra_agent,
        fake_ra_token
    );
    let _kdc_mock = mockrs::mock!(crate::http_client::send_to_kdc_agent, fake_kdc_reply);

    assert_eq!(do_register_psk().err(), Some(CommonError::InternalError));
    reset_config();
}
#[test]
fn test_execute_trusted_request_send_error_closure() {
    // A failure from send_to_kdc_agent (NetworkError here) is expected to be
    // returned unchanged by execute_trusted_request.
    reset_config();
    // Seed an all-zero PSK in the global config; the lock guard is dropped at
    // the end of this statement.
    get_config_center().lock().psk = Some(SecurePsk([0u8; PSK_LEN]));

    /// Replacement for `send_to_kdc_agent` that always fails.
    fn failing_kdc_send(
        _msg_type: u32,
        _data: Option<&str>,
    ) -> Result<common::http_types::KdcHttpResponse, CommonError> {
        Err(CommonError::NetworkError)
    }
    let _send_mock = mockrs::mock!(crate::http_client::send_to_kdc_agent, failing_kdc_send);

    let outcome = execute_trusted_request(0x0020, "test data");
    assert_eq!(outcome.err(), Some(CommonError::NetworkError));
    reset_config();
}
#[test]
fn test_do_register_psk_no_data_field() {
    // Scenario: kdc_agent replies with ret_code 0 but no `data` field;
    // do_register_psk is expected to report InternalError.
    reset_config();

    /// Replacement for `generate_ec_keypair`: a fresh P-256 key plus the
    /// PEM encoding of its public half, built directly with openssl.
    fn fake_keypair() -> Result<(openssl::ec::EcKey<openssl::pkey::Private>, String), String> {
        use openssl::ec::{EcGroup, EcKey};
        use openssl::nid::Nid;
        use openssl::pkey::PKey;

        let curve =
            EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).map_err(|err| err.to_string())?;
        let private_key = EcKey::generate(&curve).map_err(|err| err.to_string())?;
        let pem_bytes = PKey::from_ec_key(private_key.clone())
            .and_then(|pkey| pkey.public_key_to_pem())
            .map_err(|err| err.to_string())?;
        String::from_utf8(pem_bytes)
            .map(|pem| (private_key, pem))
            .map_err(|err| err.to_string())
    }

    /// Replacement for `request_token_from_ra_agent`: always yields a fixed token.
    fn fake_ra_token(
        _key: Option<&openssl::ec::EcKey<openssl::pkey::Private>>,
    ) -> Result<String, CommonError> {
        Ok(String::from("mock_token"))
    }

    /// Replacement for `send_to_kdc_agent`: a success reply with no payload.
    fn payloadless_kdc_reply(
        _msg_type: u32,
        _data: Option<&str>,
    ) -> Result<common::http_types::KdcHttpResponse, CommonError> {
        let reply = common::http_types::KdcHttpResponse {
            ret_code: 0,
            ret_msg: String::new(),
            data: None,
        };
        Ok(reply)
    }

    // The guards must stay bound (not `_`) so the mocks remain installed.
    let _keypair_mock = mockrs::mock!(crate::crypto::generate_ec_keypair, fake_keypair);
    let _token_mock = mockrs::mock!(
        crate::http_client::request_token_from_ra_agent,
        fake_ra_token
    );
    let _kdc_mock = mockrs::mock!(
        crate::http_client::send_to_kdc_agent,
        payloadless_kdc_reply
    );

    assert_eq!(do_register_psk().err(), Some(CommonError::InternalError));
    reset_config();
}
#[test]
fn test_execute_trusted_request_success_no_data() {
    // A successful kdc_agent reply without a `data` payload is expected to
    // come back from execute_trusted_request as (0, "{}").
    reset_config();
    // Seed an all-zero PSK in the global config; the lock guard is dropped at
    // the end of this statement.
    get_config_center().lock().psk = Some(SecurePsk([0u8; PSK_LEN]));

    /// Replacement for `send_to_kdc_agent`: success, ret_code 0, no payload.
    fn empty_ok_reply(
        _msg_type: u32,
        _data: Option<&str>,
    ) -> Result<common::http_types::KdcHttpResponse, CommonError> {
        let reply = common::http_types::KdcHttpResponse {
            ret_code: 0,
            ret_msg: String::new(),
            data: None,
        };
        Ok(reply)
    }
    let _send_mock = mockrs::mock!(crate::http_client::send_to_kdc_agent, empty_ok_reply);

    // `unwrap` fails the test if the request returned an error.
    let (code, body) = execute_trusted_request(0x0020, "test data").unwrap();
    assert_eq!(code, 0);
    assert_eq!(body, "{}");
    reset_config();
}
#[test]
fn test_execute_trusted_request_psk_mismatch_reregister() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // How many times the mocked send_to_kdc_agent has been invoked.
    static SEND_CALL_COUNT: AtomicUsize = AtomicUsize::new(0);

    // Scenario: the first send fails with PskMismatchError and do_register_psk
    // is mocked as a no-op; the request is expected to succeed after exactly
    // two sends, and the configured PSK must still be present afterwards.
    reset_config();
    // Seed an all-zero PSK in the global config; the lock guard is dropped at
    // the end of this statement.
    get_config_center().lock().psk = Some(SecurePsk([0u8; PSK_LEN]));

    /// Replacement for `send_to_kdc_agent`: the first call reports a PSK
    /// mismatch, every later call succeeds with payload "{}".
    fn mismatch_once_then_ok(
        _msg_type: u32,
        _data: Option<&str>,
    ) -> Result<common::http_types::KdcHttpResponse, CommonError> {
        let previous_calls = SEND_CALL_COUNT.fetch_add(1, Ordering::SeqCst);
        if previous_calls == 0 {
            return Err(CommonError::PskMismatchError);
        }
        Ok(common::http_types::KdcHttpResponse {
            ret_code: 0,
            ret_msg: String::new(),
            data: Some("{}".to_string()),
        })
    }

    /// Replacement for `do_register_psk` that does nothing and succeeds.
    fn register_psk_noop() -> Result<(), CommonError> {
        Ok(())
    }

    SEND_CALL_COUNT.store(0, Ordering::SeqCst);
    let _send_mock = mockrs::mock!(
        crate::http_client::send_to_kdc_agent,
        mismatch_once_then_ok
    );
    let _register_mock = mockrs::mock!(do_register_psk, register_psk_noop);

    // `unwrap` fails the test if the request returned an error.
    let (code, body) = execute_trusted_request(0x0020, "test data").unwrap();
    assert_eq!(code, 0);
    assert_eq!(body, "{}");
    assert_eq!(SEND_CALL_COUNT.load(Ordering::SeqCst), 2);

    // Read the flag under the lock, releasing it before asserting.
    let psk_present = get_config_center().lock().psk.is_some();
    assert!(psk_present, "PSK should remain after re-registration");
    reset_config();
}
}