use std::ffi::{CStr, CString};
use std::fs::File;
use std::io::BufReader;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use axum::routing::post;
use axum::Router;
use axum::http::StatusCode;
use openssl::asn1::Asn1Time;
use thiserror::Error;
use tokio::sync::watch;
use tower::limit::ConcurrencyLimitLayer;
use tower_http::timeout::TimeoutLayer;
use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine as _};
use common::http_types::{KdcHttpRequest, KdcHttpResponse};
use common::constants::{PSK_LEN, CommonError, MAX_INNER_REQUEST_SIZE, TRUSTED_ENV_VERIFY_REQ, TRUSTED_REGISTER_PSK_REQ};
use crate::config_manager::get_agent_config;
use crate::handler_registry::get_handler;
use crate::psk_manager::{get_psk, psk_decrypt_with_key, psk_encrypt_with_key};
use crate::types::{
KdcFreeResponsePtr, KdcRequestHandlerCbFunc, KdcRequestMsg,
KdcResponseMsg,
};
/// Seconds between certificate-expiry checks run by `cert_expiry_monitor` (24 hours).
const CERT_CHECK_INTERVAL_SECS: u64 = 24 * 3600;
/// Errors returned by the HTTPS server's start/stop entry points and its TLS setup helpers.
#[derive(Debug, Error)]
pub enum ServerError {
    /// Certificate / private-key parsing or rustls server-configuration failure.
    #[error("TLS error: {0}")]
    TlsError(String),
    /// Missing configuration, an invalid listen address, or a start/stop call made in the
    /// wrong server state (already started, not started, or shutdown signal not delivered).
    #[error("Config error: {0}")]
    ConfigError(String),
    /// Failure to open the certificate/key files, or an error from the running server.
    #[error("IO error: {0}")]
    IoError(String),
    /// NOTE(review): not constructed by any non-test code visible in this module — confirm
    /// whether it is still produced elsewhere.
    #[error("Handler error: {0}")]
    HandlerError(String),
}
/// Sender half of the server shutdown channel: set by `start_http_server` and signalled by
/// `stop_http_server`. Because it is a `OnceLock`, it can be set only once per process, so a
/// later `start_http_server` call fails with "server already started".
static SHUTDOWN_TX: OnceLock<watch::Sender<bool>> = OnceLock::new();
/// Builds the axum router exposing the single proxy endpoint `POST /rest/kdc_agent/v1/proxy`.
///
/// In axum the last `layer` call is the outermost wrapper, so requests pass through the
/// body-size limit, then the 30 s timeout (answered with `408 Request Timeout`), then the
/// 10-request concurrency limit. Because the timeout wraps the concurrency limit, time spent
/// waiting for a free slot counts toward the 30 s.
pub fn build_router() -> Router {
    let max_body = axum::extract::DefaultBodyLimit::max(MAX_INNER_REQUEST_SIZE);
    let request_timeout =
        TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(30));
    let in_flight_cap = ConcurrencyLimitLayer::new(10);

    let routes = Router::new().route("/rest/kdc_agent/v1/proxy", post(proxy_handler));
    routes
        .layer(in_flight_cap)
        .layer(request_timeout)
        .layer(max_body)
}
/// Starts the HTTPS proxy server and runs it until it is shut down.
///
/// This function:
/// - reads the agent configuration;
/// - builds the rustls TLS configuration from the configured certificate and private key;
/// - registers the process-wide shutdown channel used by `stop_http_server`;
/// - spawns the certificate-expiry monitor;
/// - serves `build_router()` on `config.ip:config.port`.
///
/// The returned future resolves when the server stops.
///
/// # Errors
/// - `ConfigError` if the configuration is not initialized, the listen address is invalid,
///   or the server has already been started in this process.
/// - `TlsError` / `IoError` if the certificate or private key cannot be loaded.
/// - `IoError` if the server fails while binding or serving.
pub async fn start_http_server() -> Result<(), ServerError> {
    let config = get_agent_config()
        .ok_or_else(|| ServerError::ConfigError("ConfigCenter not initialized".to_string()))?;
    // `install_default` fails only when a process-wide provider is already installed, in
    // which case that provider is used; the error is deliberately ignored.
    let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    let server_config = build_rustls_server_config(config)?;
    let tls_config = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(server_config));
    // Validate the listen address before claiming SHUTDOWN_TX. A OnceLock can never be reset,
    // so failing after `set` would make every later call report "server already started"
    // even though nothing was ever served.
    let addr: std::net::SocketAddr = format_listen_addr(&config.ip, &config.port)
        .parse()
        .map_err(|e| ServerError::ConfigError(format!("invalid address: {}", e)))?;
    let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
    SHUTDOWN_TX
        .set(shutdown_tx)
        .map_err(|_| ServerError::ConfigError("server already started".to_string()))?;
    let app = build_router();
    tokio::spawn(cert_expiry_monitor(config.cert_path.clone()));
    let handle = axum_server::Handle::new();
    let shutdown_handle = handle.clone();
    // Wait for `stop_http_server` to publish the signal, then shut down gracefully with no
    // deadline for in-flight requests.
    tokio::spawn(async move {
        let _ = shutdown_rx.changed().await;
        shutdown_handle.graceful_shutdown(None);
    });
    axum_server::bind_rustls(addr, tls_config)
        .handle(handle)
        .serve(app.into_make_service())
        .await
        .map_err(|e: std::io::Error| ServerError::IoError(e.to_string()))?;
    Ok(())
}

/// Formats `ip` and `port` as a `host:port` string that `SocketAddr` can parse.
///
/// A bare IPv6 address (e.g. `::1`) is wrapped in brackets (`[::1]:8080`). IPv4 addresses
/// and already-bracketed IPv6 addresses are used unchanged.
fn format_listen_addr(ip: &str, port: impl std::fmt::Display) -> String {
    if ip.contains(':') && !ip.starts_with('[') {
        format!("[{}]:{}", ip, port)
    } else {
        format!("{}:{}", ip, port)
    }
}
/// Asks a server started by `start_http_server` to shut down gracefully.
///
/// # Errors
/// - `ConfigError("server not started")` if `start_http_server` never registered the
///   shutdown channel.
/// - `ConfigError("shutdown signal failed: ...")` if the signal cannot be delivered. For
///   example, the shutdown listener may already have consumed an earlier signal and exited.
pub fn stop_http_server() -> Result<(), ServerError> {
    let sender = SHUTDOWN_TX
        .get()
        .ok_or_else(|| ServerError::ConfigError("server not started".to_string()))?;
    sender
        .send(true)
        .map_err(|e| ServerError::ConfigError(format!("shutdown signal failed: {}", e)))
}
/// Hands a plugin-filled response back to the plugin's deallocator when it goes out of scope.
///
/// Dropping the guard does nothing when `res` is null or no `free_fn` was supplied.
struct ResponseGuard {
    /// Response struct whose contents were written by the plugin handler.
    res: *mut KdcResponseMsg,
    /// Plugin-provided function that releases what the handler stored in `res`.
    free_fn: Option<KdcFreeResponsePtr>,
}

impl Drop for ResponseGuard {
    fn drop(&mut self) {
        let release = self.free_fn.filter(|_| !self.res.is_null());
        if let Some(release) = release {
            release(self.res);
        }
    }
}
fn extract_token_field(request_data: &str) -> Result<String, KdcHttpResponse> {
let json: serde_json::Value = serde_json::from_str(request_data).map_err(|e| {
log::error!("invalid request JSON: {}", e);
KdcHttpResponse {
ret_code: CommonError::RequestError as u32,
ret_msg: format!("invalid request JSON: {}", e),
data: None,
}
})?;
json.get("token")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.ok_or_else(|| {
log::error!("request rejected: missing token field");
KdcHttpResponse {
ret_code: CommonError::RequestError as u32,
ret_msg: "missing token field".to_string(),
data: None,
}
})
}
fn handle_psk_registration_request(request_data: &str) -> KdcHttpResponse {
let token = match extract_token_field(request_data) {
Ok(t) => t,
Err(resp) => return resp,
};
match crate::psk_manager::handle_psk_registration(&token) {
Ok(encrypted_psk) => KdcHttpResponse {
ret_code: 0,
ret_msg: "Success".to_string(),
data: Some(encrypted_psk),
},
Err(e) => {
log::error!("PSK registration failed: {}", e);
KdcHttpResponse {
ret_code: CommonError::RequestError as u32,
ret_msg: e,
data: None,
}
},
}
}
fn handle_verify_token_request(request_data: &str) -> KdcHttpResponse {
let token = match extract_token_field(request_data) {
Ok(t) => t,
Err(resp) => return resp,
};
match crate::token_verifier::verify_token(&token) {
Ok(_) => KdcHttpResponse {
ret_code: 0,
ret_msg: "Success".to_string(),
data: Some(serde_json::json!({"status": "pass"}).to_string()),
},
Err(e) => KdcHttpResponse {
ret_code: CommonError::RequestError as u32,
ret_msg: e.to_string(),
data: None,
},
}
}
fn decrypt_request_data(
data: &str,
psk: Option<[u8; PSK_LEN]>,
) -> Result<String, KdcHttpResponse> {
match psk {
Some(key) => {
let decoded_bytes = BASE64_STANDARD.decode(data).map_err(|_| {
log::error!("PSK mismatch: invalid base64 in encrypted data");
KdcHttpResponse {
ret_code: CommonError::PskMismatchError as u32,
ret_msg: "PSK mismatch: invalid encrypted data".to_string(),
data: None,
}
})?;
match psk_decrypt_with_key(Some(key), &decoded_bytes) {
Ok(plain) => Ok(String::from_utf8(plain).unwrap_or_else(|_| String::new())),
Err(_) => {
log::error!("PSK mismatch: decryption failed");
Err(KdcHttpResponse {
ret_code: CommonError::PskMismatchError as u32,
ret_msg: "PSK mismatch".to_string(),
data: None,
})
},
}
}
None => {
log::error!("request rejected: PSK not initialized");
Err(KdcHttpResponse {
ret_code: CommonError::PskMismatchError as u32,
ret_msg: "PSK not initialized".to_string(),
data: None,
})
},
}
}
/// Invokes a plugin's C request handler synchronously and copies out its response body.
///
/// `data` is copied into a NUL-terminated C string and passed to `handler_fn` as the
/// request body. The response struct starts zero-initialized and is filled in by the
/// plugin. When this function returns, the plugin's response is released with `free_fn`
/// (via `ResponseGuard`), after the body has been copied into an owned `String`.
///
/// Returns `Ok(Some(body))` with the response body decoded lossily as UTF-8 (invalid
/// sequences become U+FFFD). Returns `Ok(None)` if the plugin left the body pointer null.
///
/// # Errors
/// Returns a `RequestError` response if `data` contains an interior NUL byte, because it
/// then cannot be passed as a C string.
fn call_plugin_handler(
    handler_fn: KdcRequestHandlerCbFunc,
    free_fn: KdcFreeResponsePtr,
    data: &str,
) -> Result<Option<String>, KdcHttpResponse> {
    let c_data = match CString::new(data) {
        Ok(s) => s,
        Err(_) => {
            log::error!("request data contains embedded null byte");
            return Err(KdcHttpResponse {
                ret_code: CommonError::RequestError as u32,
                ret_msg: "invalid request data".to_string(),
                data: None,
            });
        }
    };
    // Length excludes the trailing NUL (`as_bytes`, not `as_bytes_with_nul`).
    let body_length = c_data.as_bytes().len() as u32;
    // Give the buffer to the plugin as a raw pointer. Ownership is reclaimed below with
    // `CString::from_raw` once the handler has returned.
    let original_body = c_data.into_raw();
    let kdc_req = KdcRequestMsg {
        body: original_body,
        body_length,
        context: std::ptr::null_mut(),
    };
    // NOTE(review): assumes every field of `KdcResponseMsg` is valid when all-zero
    // (null pointers / zero integers) — confirm against the type's definition.
    let mut kdc_res: KdcResponseMsg = unsafe { std::mem::zeroed() };
    // NOTE(review): any status value the handler returns is not inspected. The outcome is
    // taken solely from the response body.
    handler_fn(&kdc_req, &mut kdc_res);
    // Declared after `kdc_res`, so it is dropped first and frees the plugin's response while
    // `kdc_res` is still alive.
    let _guard = ResponseGuard {
        res: &mut kdc_res,
        free_fn: Some(free_fn),
    };
    // Reclaim and drop the request buffer handed out via `into_raw` above.
    // NOTE(review): assumes the plugin does not keep `body` after `handler_fn` returns.
    unsafe {
        let _ = CString::from_raw(original_body);
    }
    // NOTE(review): the body is read up to its first NUL and `kdc_res.body_length` is not
    // consulted. This assumes the plugin NUL-terminates the body.
    let response_data = if !kdc_res.body.is_null() {
        let c_str = unsafe { CStr::from_ptr(kdc_res.body) };
        Some(c_str.to_string_lossy().into_owned())
    } else {
        log::warn!("plugin handler returned null response body");
        None
    };
    Ok(response_data)
}
/// Splits a plugin's reply body into `(ret_code, ret_msg, data)`.
///
/// - No body: `(0, "Success", None)`.
/// - Body that is not JSON: passed through unchanged as data with `(0, "Success")`.
/// - JSON body: `retCode` is converted by `ret_code_from_json`, and `retMsg` defaults to
///   "Success". A non-null `data` value is re-serialized as JSON text, so a string value
///   keeps its quotes.
fn parse_handler_response(response_data: &Option<String>) -> (u32, String, Option<String>) {
    let Some(body_str) = response_data else {
        return (0, "Success".to_string(), None);
    };
    let json = match serde_json::from_str::<serde_json::Value>(body_str) {
        Ok(json) => json,
        Err(_) => {
            log::warn!("handler response is not valid JSON, passing through as raw data");
            return (0, "Success".to_string(), response_data.clone());
        }
    };
    let ret_code = ret_code_from_json(json.get("retCode"));
    let ret_msg = json
        .get("retMsg")
        .and_then(|v| v.as_str())
        .unwrap_or("Success")
        .to_string();
    let parsed_data = json
        .get("data")
        .filter(|v| !v.is_null())
        .map(|v| v.to_string());
    (ret_code, ret_msg, parsed_data)
}

/// Converts a plugin's `retCode` JSON value into a `u32` status code.
///
/// A missing or non-numeric value yields 0. A number that is not an exact `u32`
/// (negative, fractional, or above `u32::MAX`) yields `u32::MAX`, so it is never mistaken
/// for success. A plain `as u32` cast would truncate 4294967296 to 0, and `as_u64()` alone
/// rejects -1, which would then fall back to 0.
fn ret_code_from_json(value: Option<&serde_json::Value>) -> u32 {
    const UNREPRESENTABLE: u32 = u32::MAX;
    let Some(number) = value.filter(|v| v.is_number()) else {
        return 0;
    };
    if let Some(code) = number.as_u64() {
        return u32::try_from(code).unwrap_or(UNREPRESENTABLE);
    }
    // Integral floats such as `3.0` are accepted when they fit in a u32.
    match number.as_f64() {
        Some(f) if f.fract() == 0.0 && (0.0..=f64::from(u32::MAX)).contains(&f) => f as u32,
        _ => UNREPRESENTABLE,
    }
}
fn encrypt_response_data(
data: &Option<String>,
psk: Option<[u8; PSK_LEN]>,
) -> Option<String> {
match (data, psk) {
(Some(d), Some(key)) => match psk_encrypt_with_key(Some(key), d.as_bytes()) {
Ok(encrypted) => Some(BASE64_STANDARD.encode(&encrypted)),
Err(e) => {
log::error!("response encryption failed: {}", e);
None
}
},
_ => data.clone(),
}
}
async fn proxy_handler(
axum::Json(req): axum::Json<KdcHttpRequest>,
) -> axum::Json<KdcHttpResponse> {
if req.msg_type == TRUSTED_REGISTER_PSK_REQ {
return axum::Json(handle_psk_registration_request(
&req.data.unwrap_or_default(),
));
}
let data = req.data.unwrap_or_default();
let psk = get_psk();
let decrypted = match decrypt_request_data(&data, psk) {
Ok(d) => d,
Err(resp) => return axum::Json(resp),
};
if req.msg_type == TRUSTED_ENV_VERIFY_REQ {
let resp = handle_verify_token_request(&decrypted);
let final_data = encrypt_response_data(&resp.data, psk);
log::info!(
"proxy response sent: task_id={}, msg_type=0x{:04X}, ret_code={}",
req.task_id,
req.msg_type,
resp.ret_code
);
return axum::Json(KdcHttpResponse {
ret_code: resp.ret_code,
ret_msg: resp.ret_msg,
data: final_data,
});
}
let (handler_fn, free_fn) = match get_handler(req.msg_type) {
Some(pair) => pair,
None => {
log::error!("no handler registered for msg_type=0x{:04X}", req.msg_type);
return axum::Json(KdcHttpResponse {
ret_code: CommonError::InternalError as u32,
ret_msg: "no handler registered".to_string(),
data: None,
});
}
};
let response_data = match tokio::task::spawn_blocking(move || {
call_plugin_handler(handler_fn, free_fn, &decrypted)
}).await {
Ok(Ok(d)) => d,
Ok(Err(resp)) => return axum::Json(resp),
Err(_) => {
log::error!("handler task panicked for msg_type=0x{:04X}", req.msg_type);
return axum::Json(KdcHttpResponse {
ret_code: CommonError::InternalError as u32,
ret_msg: "handler task failed".to_string(),
data: None,
})
},
};
let (ret_code, ret_msg, parsed_data) = parse_handler_response(&response_data);
let final_data = encrypt_response_data(&parsed_data, psk);
log::info!(
"proxy response sent: task_id={}, msg_type=0x{:04X}, ret_code={}",
req.task_id,
req.msg_type,
ret_code
);
axum::Json(KdcHttpResponse {
ret_code,
ret_msg,
data: final_data,
})
}
/// Reads every PEM certificate in `path`, in file order.
///
/// A readable file that contains no certificates yields an empty vector, not an error.
///
/// # Errors
/// - `ServerError::IoError` if the file cannot be opened; the message names `path`.
/// - `ServerError::TlsError` if reading or decoding the PEM content fails.
fn load_certs(
    path: &str,
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, ServerError> {
    // Name the file: certificate and key are loaded back to back, and a bare OS error does
    // not say which one is missing.
    let file = File::open(path).map_err(|e| {
        ServerError::IoError(format!("failed to open certificate file {}: {}", path, e))
    })?;
    let mut reader = BufReader::new(file);
    rustls_pemfile::certs(&mut reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| ServerError::TlsError(format!("failed to load certs from {}: {}", path, e)))
}
/// Reads the first PEM private key found in `path`.
///
/// # Errors
/// - `ServerError::IoError` if the file cannot be opened; the message names `path`.
/// - `ServerError::TlsError` if the PEM content cannot be read or decoded, or if it
///   contains no private key.
fn load_private_key(
    path: &str,
) -> Result<rustls::pki_types::PrivateKeyDer<'static>, ServerError> {
    // Name the file so an open failure is distinguishable from one on the certificate.
    let file = File::open(path).map_err(|e| {
        ServerError::IoError(format!("failed to open private key file {}: {}", path, e))
    })?;
    let mut reader = BufReader::new(file);
    rustls_pemfile::private_key(&mut reader)
        .map_err(|e| {
            ServerError::TlsError(format!(
                "failed to load private key from {}: {}",
                path, e
            ))
        })?
        .ok_or_else(|| ServerError::TlsError(format!("no private key found in {}", path)))
}
/// Builds a rustls server configuration from the configured certificate chain and key.
///
/// Client certificates are not requested (`with_no_client_auth`).
///
/// # Errors
/// Propagates `load_certs` / `load_private_key` errors. Returns `ServerError::TlsError` if
/// rustls rejects the certificate/key pair.
fn build_rustls_server_config(
    config: &crate::config_manager::ConfigCenter,
) -> Result<rustls::ServerConfig, ServerError> {
    let cert_chain = load_certs(&config.cert_path)?;
    let key = load_private_key(&config.private_path)?;
    rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert_chain, key)
        .map_err(|e| ServerError::TlsError(format!("failed to build server config: {}", e)))
}
/// Checks the certificate at `cert_path` immediately, then every
/// `CERT_CHECK_INTERVAL_SECS`, forever.
async fn cert_expiry_monitor(cert_path: String) {
    let mut ticker = tokio::time::interval(Duration::from_secs(CERT_CHECK_INTERVAL_SECS));
    loop {
        // A tokio interval's first tick completes immediately, so the first check runs at
        // startup; later ticks wait a full period.
        ticker.tick().await;
        check_cert_expiry(&cert_path);
    }
}
/// Outcome of a certificate expiry check.
#[derive(Debug, PartialEq)]
enum CertStatus {
    /// The certificate has not expired yet. `remaining_days` is the whole-day part of the
    /// time left, and `not_after` is the expiry time as rendered by OpenSSL.
    Valid { remaining_days: u64, not_after: String },
    /// The certificate's expiry time is at or before the time of the check.
    Expired { not_after: String },
}

impl CertStatus {
    /// Returns the certificate's expiry time string, whichever variant this is.
    // Kept although no non-test code in this module calls it.
    #[allow(dead_code)]
    fn not_after(&self) -> &str {
        let (Self::Valid { not_after, .. } | Self::Expired { not_after }) = self;
        not_after
    }
}
/// Reads the first PEM certificate in `cert_path` and compares its expiry time with now.
///
/// # Errors
/// Returns a descriptive message if the file cannot be read, the certificate cannot be
/// parsed, or the current time / time difference cannot be computed.
fn check_cert_expiry_result(cert_path: &str) -> Result<CertStatus, String> {
    let pem = std::fs::read(cert_path)
        .map_err(|err| format!("failed to read certificate '{}': {}", cert_path, err))?;
    let certificate = openssl::x509::X509::from_pem(&pem)
        .map_err(|err| format!("failed to parse certificate '{}': {}", cert_path, err))?;
    let expiry = certificate.not_after();
    let now = Asn1Time::days_from_now(0)
        .map_err(|err| format!("failed to get current time: {}", err))?;
    let remaining = now
        .diff(expiry)
        .map_err(|err| format!("failed to compute cert time difference: {}", err))?;
    let not_after = expiry.to_string();
    Ok(if is_expired(&remaining) {
        CertStatus::Expired { not_after }
    } else {
        CertStatus::Valid {
            remaining_days: remaining.days as u64,
            not_after,
        }
    })
}
/// Returns true when the time left (`now` → `notAfter`) is zero or negative.
///
/// Lexicographic tuple ordering gives exactly "days < 0, or days == 0 and secs <= 0".
fn is_expired(diff: &openssl::asn1::TimeDiff) -> bool {
    (diff.days, diff.secs) <= (0, 0)
}
/// Logs the expiry status of the certificate at `cert_path`.
///
/// - Expired certificates and check failures are logged at error level.
/// - Certificates with 7 or fewer remaining days are logged at warn level.
/// - All other certificates are logged at info level.
fn check_cert_expiry(cert_path: &str) {
    const WARN_THRESHOLD_DAYS: u64 = 7;
    let status = match check_cert_expiry_result(cert_path) {
        Ok(status) => status,
        Err(e) => {
            log::error!("{}", e);
            return;
        }
    };
    match status {
        CertStatus::Expired { not_after } => {
            log::error!("Certificate has expired! Expired at: {}", not_after);
        }
        CertStatus::Valid { remaining_days, .. } if remaining_days <= WARN_THRESHOLD_DAYS => {
            log::warn!(
                "Certificate expiring soon, remaining: {} days",
                remaining_days
            );
        }
        CertStatus::Valid {
            remaining_days,
            not_after,
        } => {
            log::info!(
                "Certificate expires at {}, remaining: {} days",
                not_after,
                remaining_days
            );
        }
    }
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
use super::*;
use std::ffi::CString;
use std::fs;
use common::psk;
static TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
extern "C" fn mock_handler(
_req: *const KdcRequestMsg,
res: *mut KdcResponseMsg,
) -> u32 {
let json = r#"{"retCode":0,"retMsg":"ok","data":"handled"}"#;
let c_str = CString::new(json).unwrap();
unsafe {
(*res).body = c_str.into_raw();
(*res).body_length = json.len() as u32;
}
0
}
extern "C" fn mock_handler_null_body(
_req: *const KdcRequestMsg,
res: *mut KdcResponseMsg,
) -> u32 {
unsafe {
(*res).body = std::ptr::null_mut();
(*res).body_length = 0;
}
0
}
#[allow(dead_code)]
extern "C" fn mock_handler_error(
_req: *const KdcRequestMsg,
res: *mut KdcResponseMsg,
) -> u32 {
let json = r#"{"retCode":1,"retMsg":"plugin error","data":null}"#;
let c_str = CString::new(json).unwrap();
unsafe {
(*res).body = c_str.into_raw();
(*res).body_length = json.len() as u32;
}
1
}
extern "C" fn mock_free(res: *mut KdcResponseMsg) {
if !res.is_null() {
unsafe {
if !(*res).body.is_null() {
let _ = CString::from_raw((*res).body);
}
}
}
}
#[test]
fn test_extract_token_field_valid() {
let result = extract_token_field(r#"{"token":"abc123"}"#);
assert!(result.is_ok());
assert_eq!(result.unwrap(), "abc123");
}
#[test]
fn test_extract_token_field_invalid_json() {
let result = extract_token_field("not json");
assert!(result.is_err());
let resp = result.unwrap_err();
assert_eq!(resp.ret_code, CommonError::RequestError as u32);
assert!(resp.ret_msg.contains("invalid request JSON"));
}
#[test]
fn test_extract_token_field_missing_token() {
let result = extract_token_field(r#"{"data":"value"}"#);
assert!(result.is_err());
let resp = result.unwrap_err();
assert!(resp.ret_msg.contains("missing token field"));
}
#[test]
fn test_extract_token_field_token_not_string() {
let result = extract_token_field(r#"{"token":12345}"#);
assert!(result.is_err());
}
#[test]
fn test_decrypt_request_data_no_psk() {
let result = decrypt_request_data("data", None);
assert!(result.is_err());
let resp = result.unwrap_err();
assert_eq!(resp.ret_code, CommonError::PskMismatchError as u32);
assert!(resp.ret_msg.contains("PSK not initialized"));
}
#[test]
fn test_decrypt_request_data_invalid_hex() {
let key = [0u8; PSK_LEN];
let result = decrypt_request_data("zzzz", Some(key));
assert!(result.is_err());
let resp = result.unwrap_err();
assert!(resp.ret_msg.contains("PSK mismatch"));
}
#[test]
fn test_decrypt_request_data_valid_roundtrip() {
let _lock = TEST_LOCK.lock().unwrap();
let key = common::psk::generate_random_psk().unwrap();
let plaintext = b"hello world";
let encrypted = psk::psk_encrypt_with_key(plaintext, &key).unwrap();
let encoded_data = BASE64_STANDARD.encode(&encrypted);
let result = decrypt_request_data(&encoded_data, Some(key));
assert!(result.is_ok());
assert_eq!(result.unwrap(), "hello world");
}
#[test]
fn test_decrypt_request_data_wrong_key() {
let key_a = common::psk::generate_random_psk().unwrap();
let key_b = common::psk::generate_random_psk().unwrap();
let encrypted = psk::psk_encrypt_with_key(b"secret", &key_a).unwrap();
let encoded_data = BASE64_STANDARD.encode(&encrypted);
let result = decrypt_request_data(&encoded_data, Some(key_b));
assert!(result.is_err());
}
#[test]
fn test_parse_handler_response_none() {
let (code, msg, data) = parse_handler_response(&None);
assert_eq!(code, 0);
assert_eq!(msg, "Success");
assert!(data.is_none());
}
#[test]
fn test_parse_handler_response_valid_json() {
let resp = Some(r#"{"retCode":0,"retMsg":"ok","data":"result_data"}"#.to_string());
let (code, msg, data) = parse_handler_response(&resp);
assert_eq!(code, 0);
assert_eq!(msg, "ok");
assert_eq!(data, Some("\"result_data\"".to_string()));
}
#[test]
fn test_parse_handler_response_null_data() {
let resp = Some(r#"{"retCode":1,"retMsg":"error","data":null}"#.to_string());
let (code, msg, data) = parse_handler_response(&resp);
assert_eq!(code, 1);
assert_eq!(msg, "error");
assert!(data.is_none());
}
#[test]
fn test_parse_handler_response_missing_fields() {
let resp = Some(r#"{"retCode":5}"#.to_string());
let (code, msg, data) = parse_handler_response(&resp);
assert_eq!(code, 5);
assert_eq!(msg, "Success");
assert!(data.is_none());
}
#[test]
fn test_parse_handler_response_invalid_json() {
let resp = Some("not json".to_string());
let (code, msg, data) = parse_handler_response(&resp);
assert_eq!(code, 0);
assert_eq!(msg, "Success");
assert_eq!(data, Some("not json".to_string()));
}
#[test]
fn test_encrypt_response_data_with_psk() {
let key = common::psk::generate_random_psk().unwrap();
let data = Some("secret response".to_string());
let result = encrypt_response_data(&data, Some(key));
assert!(result.is_some());
let encrypted_hex = result.unwrap();
assert_ne!(encrypted_hex, "secret response");
}
#[test]
fn test_encrypt_response_data_none_data() {
let key = common::psk::generate_random_psk().unwrap();
let result = encrypt_response_data(&None, Some(key));
assert!(result.is_none());
}
#[test]
fn test_encrypt_response_data_no_psk() {
let data = Some("plain response".to_string());
let result = encrypt_response_data(&data, None);
assert_eq!(result, data);
}
#[test]
fn test_is_expired_negative_days() {
let diff = openssl::asn1::TimeDiff { days: -1, secs: 0 };
assert!(is_expired(&diff));
}
#[test]
fn test_is_expired_positive_days() {
let diff = openssl::asn1::TimeDiff { days: 30, secs: 0 };
assert!(!is_expired(&diff));
}
#[test]
fn test_call_plugin_handler_valid() {
let result = call_plugin_handler(mock_handler, mock_free, r#"{"key":"value"}"#);
assert!(result.is_ok());
let data = result.unwrap();
assert!(data.is_some());
assert!(data.unwrap().contains("handled"));
}
#[test]
fn test_call_plugin_handler_null_body() {
let result = call_plugin_handler(mock_handler_null_body, mock_free, "data");
assert!(result.is_ok());
assert!(result.unwrap().is_none());
}
#[test]
fn test_call_plugin_handler_null_byte_in_data() {
let result = call_plugin_handler(mock_handler, mock_free, "data\0with\0nulls");
assert!(result.is_err());
let resp = result.unwrap_err();
assert_eq!(resp.ret_code, CommonError::RequestError as u32);
}
#[test]
fn test_handle_psk_registration_invalid_json() {
let resp = handle_psk_registration_request("not json");
assert_ne!(resp.ret_code, 0);
assert!(resp.ret_msg.contains("invalid request JSON"));
}
#[test]
fn test_handle_psk_registration_missing_token() {
let resp = handle_psk_registration_request(r#"{"data":"no_token"}"#);
assert_ne!(resp.ret_code, 0);
assert!(resp.ret_msg.contains("missing token field"));
}
#[test]
fn test_handle_verify_token_invalid_json() {
let resp = handle_verify_token_request("not json");
assert_ne!(resp.ret_code, 0);
assert!(resp.ret_msg.contains("invalid request JSON"));
}
#[test]
fn test_handle_verify_token_missing_token() {
let resp = handle_verify_token_request(r#"{"data":"no_token"}"#);
assert_ne!(resp.ret_code, 0);
assert!(resp.ret_msg.contains("missing token field"));
}
#[test]
fn test_check_cert_expiry_nonexistent() {
let result = check_cert_expiry_result("/nonexistent/path/cert.pem");
assert!(result.is_err());
}
#[test]
fn test_check_cert_expiry_invalid_pem() {
let dir = std::env::temp_dir().join("kdc_cert_test");
let _ = fs::create_dir_all(&dir);
let path = dir.join("invalid.pem");
fs::write(&path, "not a pem file").unwrap();
let result = check_cert_expiry_result(path.to_str().unwrap());
assert!(result.is_err());
let _ = fs::remove_dir_all(&dir);
}
#[test]
fn test_check_cert_expiry_valid_cert() {
let dir = std::env::temp_dir().join("kdc_cert_test_valid");
let _ = fs::create_dir_all(&dir);
let path = dir.join("valid.pem");
let group = openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
let key = openssl::ec::EcKey::generate(&group).unwrap();
let pkey = openssl::pkey::PKey::from_ec_key(key).unwrap();
let mut builder = openssl::x509::X509Builder::new().unwrap();
builder.set_version(2).unwrap();
let serial = openssl::bn::BigNum::from_u32(1).unwrap();
builder.set_serial_number(&serial.to_asn1_integer().unwrap()).unwrap();
let name = openssl::x509::X509NameBuilder::new().unwrap().build();
builder.set_issuer_name(&name).unwrap();
builder.set_subject_name(&name).unwrap();
builder.set_not_before(openssl::asn1::Asn1Time::days_from_now(0).unwrap().as_ref()).unwrap();
builder.set_not_after(openssl::asn1::Asn1Time::days_from_now(365).unwrap().as_ref()).unwrap();
builder.set_pubkey(&pkey).unwrap();
builder.sign(&pkey, openssl::hash::MessageDigest::sha256()).unwrap();
let cert = builder.build();
let pem = cert.to_pem().unwrap();
fs::write(&path, pem).unwrap();
let result = check_cert_expiry_result(path.to_str().unwrap());
assert!(result.is_ok());
match result.unwrap() {
CertStatus::Valid { remaining_days, .. } => {
assert!(remaining_days > 0);
}
CertStatus::Expired { .. } => {
panic!("cert should be valid, not expired");
}
}
let _ = fs::remove_dir_all(&dir);
}
#[test]
fn test_cert_status_not_after() {
let valid = CertStatus::Valid {
remaining_days: 30,
not_after: "Jan 1 00:00:00 2030 GMT".to_string(),
};
assert_eq!(valid.not_after(), "Jan 1 00:00:00 2030 GMT");
let expired = CertStatus::Expired {
not_after: "Jan 1 00:00:00 2020 GMT".to_string(),
};
assert_eq!(expired.not_after(), "Jan 1 00:00:00 2020 GMT");
}
#[test]
fn test_server_error_display() {
let e = ServerError::TlsError("tls fail".to_string());
assert!(e.to_string().contains("tls fail"));
let e = ServerError::ConfigError("config fail".to_string());
assert!(e.to_string().contains("config fail"));
let e = ServerError::IoError("io fail".to_string());
assert!(e.to_string().contains("io fail"));
let e = ServerError::HandlerError("handler fail".to_string());
assert!(e.to_string().contains("handler fail"));
}
#[test]
fn test_load_certs_nonexistent() {
let result = load_certs("/nonexistent/cert.pem");
assert!(result.is_err());
let err = result.unwrap_err();
match err {
ServerError::IoError(_) => {}
other => panic!("expected IoError, got {:?}", other),
}
}
#[test]
fn test_load_certs_invalid_pem() {
let dir = std::env::temp_dir().join("kdc_load_certs_test");
let _ = fs::create_dir_all(&dir);
let path = dir.join("invalid.pem");
fs::write(&path, "not a valid PEM").unwrap();
let result = load_certs(path.to_str().unwrap());
assert!(result.is_ok());
assert!(result.unwrap().is_empty());
let _ = fs::remove_dir_all(&dir);
}
#[test]
fn test_load_certs_valid() {
let dir = std::env::temp_dir().join("kdc_load_certs_valid");
let _ = fs::create_dir_all(&dir);
let path = dir.join("cert.pem");
let group = openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
let key = openssl::ec::EcKey::generate(&group).unwrap();
let pkey = openssl::pkey::PKey::from_ec_key(key).unwrap();
let mut builder = openssl::x509::X509Builder::new().unwrap();
builder.set_version(2).unwrap();
let serial = openssl::bn::BigNum::from_u32(1).unwrap();
builder.set_serial_number(&serial.to_asn1_integer().unwrap()).unwrap();
let name = openssl::x509::X509NameBuilder::new().unwrap().build();
builder.set_issuer_name(&name).unwrap();
builder.set_subject_name(&name).unwrap();
builder.set_not_before(openssl::asn1::Asn1Time::days_from_now(0).unwrap().as_ref()).unwrap();
builder.set_not_after(openssl::asn1::Asn1Time::days_from_now(365).unwrap().as_ref()).unwrap();
builder.set_pubkey(&pkey).unwrap();
builder.sign(&pkey, openssl::hash::MessageDigest::sha256()).unwrap();
let cert = builder.build();
let pem = cert.to_pem().unwrap();
fs::write(&path, pem).unwrap();
let result = load_certs(path.to_str().unwrap());
assert!(result.is_ok(), "load_certs should succeed with valid PEM");
let certs = result.unwrap();
assert!(!certs.is_empty(), "should have loaded at least one cert");
let _ = fs::remove_dir_all(&dir);
}
#[test]
fn test_load_private_key_nonexistent() {
let result = load_private_key("/nonexistent/key.pem");
assert!(result.is_err());
match result.unwrap_err() {
ServerError::IoError(_) => {}
other => panic!("expected IoError, got {:?}", other),
}
}
#[test]
fn test_load_private_key_invalid_pem() {
let dir = std::env::temp_dir().join("kdc_load_key_invalid");
let _ = fs::create_dir_all(&dir);
let path = dir.join("key.pem");
fs::write(&path, "not a valid key PEM").unwrap();
let result = load_private_key(path.to_str().unwrap());
assert!(result.is_err());
let _ = fs::remove_dir_all(&dir);
}
#[test]
fn test_load_private_key_valid() {
let dir = std::env::temp_dir().join("kdc_load_key_valid");
let _ = fs::create_dir_all(&dir);
let path = dir.join("key.pem");
let group = openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
let key = openssl::ec::EcKey::generate(&group).unwrap();
let pkey = openssl::pkey::PKey::from_ec_key(key).unwrap();
let pem = pkey.private_key_to_pem_pkcs8().unwrap();
fs::write(&path, pem).unwrap();
let result = load_private_key(path.to_str().unwrap());
assert!(result.is_ok(), "load_private_key should succeed with valid key PEM");
let _ = fs::remove_dir_all(&dir);
}
#[test]
fn test_load_private_key_no_key_in_pem() {
let dir = std::env::temp_dir().join("kdc_load_key_empty");
let _ = fs::create_dir_all(&dir);
let path = dir.join("nokey.pem");
let group = openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
let key = openssl::ec::EcKey::generate(&group).unwrap();
let pkey = openssl::pkey::PKey::from_ec_key(key).unwrap();
let mut builder = openssl::x509::X509Builder::new().unwrap();
builder.set_version(2).unwrap();
let serial = openssl::bn::BigNum::from_u32(1).unwrap();
builder.set_serial_number(&serial.to_asn1_integer().unwrap()).unwrap();
builder.set_pubkey(&pkey).unwrap();
builder.sign(&pkey, openssl::hash::MessageDigest::sha256()).unwrap();
let cert_pem = builder.build().to_pem().unwrap();
fs::write(&path, cert_pem).unwrap();
let result = load_private_key(path.to_str().unwrap());
assert!(result.is_err(), "should fail when no private key found");
let _ = fs::remove_dir_all(&dir);
}
#[test]
fn test_build_rustls_config_nonexistent_cert() {
let dir = std::env::temp_dir().join("kdc_rustls_ne");
let _ = fs::create_dir_all(&dir);
let key_path = dir.join("key.pem");
let group = openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
let key = openssl::ec::EcKey::generate(&group).unwrap();
let pkey = openssl::pkey::PKey::from_ec_key(key).unwrap();
fs::write(&key_path, pkey.private_key_to_pem_pkcs8().unwrap()).unwrap();
let config = crate::config_manager::ConfigCenter {
ip: "127.0.0.1".to_string(),
port: 8080,
ca_path: String::new(),
cert_path: "/nonexistent/cert.pem".to_string(),
private_path: key_path.to_str().unwrap().to_string(),
crl_path: String::new(),
ra_public_key_path: String::new(),
log_path: String::new(),
data_path: String::new(),
plugins: vec![],
log_max_size: 10,
log_backup_count: 5,
min_log_level: crate::types::KdcLogLevel::Info,
};
let result = build_rustls_server_config(&config);
assert!(result.is_err());
let _ = fs::remove_dir_all(&dir);
}
#[test]
fn test_build_rustls_config_nonexistent_key() {
let dir = std::env::temp_dir().join("kdc_rustls_nk");
let _ = fs::create_dir_all(&dir);
let cert_path = dir.join("cert.pem");
let group = openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
let key = openssl::ec::EcKey::generate(&group).unwrap();
let pkey = openssl::pkey::PKey::from_ec_key(key).unwrap();
let mut builder = openssl::x509::X509Builder::new().unwrap();
builder.set_version(2).unwrap();
let serial = openssl::bn::BigNum::from_u32(1).unwrap();
builder.set_serial_number(&serial.to_asn1_integer().unwrap()).unwrap();
builder.set_pubkey(&pkey).unwrap();
builder.sign(&pkey, openssl::hash::MessageDigest::sha256()).unwrap();
fs::write(&cert_path, builder.build().to_pem().unwrap()).unwrap();
let config = crate::config_manager::ConfigCenter {
ip: "127.0.0.1".to_string(),
port: 8080,
ca_path: String::new(),
cert_path: cert_path.to_str().unwrap().to_string(),
private_path: "/nonexistent/key.pem".to_string(),
crl_path: String::new(),
ra_public_key_path: String::new(),
log_path: String::new(),
data_path: String::new(),
plugins: vec![],
log_max_size: 10,
log_backup_count: 5,
min_log_level: crate::types::KdcLogLevel::Info,
};
let result = build_rustls_server_config(&config);
assert!(result.is_err());
let _ = fs::remove_dir_all(&dir);
}
#[test]
fn test_build_rustls_config_valid() {
let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
let dir = std::env::temp_dir().join("kdc_rustls_valid");
let _ = fs::create_dir_all(&dir);
let group = openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
let ec_key = openssl::ec::EcKey::generate(&group).unwrap();
let pkey = openssl::pkey::PKey::from_ec_key(ec_key).unwrap();
let mut builder = openssl::x509::X509Builder::new().unwrap();
builder.set_version(2).unwrap();
let serial = openssl::bn::BigNum::from_u32(1).unwrap();
builder.set_serial_number(&serial.to_asn1_integer().unwrap()).unwrap();
let name = openssl::x509::X509NameBuilder::new().unwrap().build();
builder.set_issuer_name(&name).unwrap();
builder.set_subject_name(&name).unwrap();
builder.set_not_before(openssl::asn1::Asn1Time::days_from_now(0).unwrap().as_ref()).unwrap();
builder.set_not_after(openssl::asn1::Asn1Time::days_from_now(365).unwrap().as_ref()).unwrap();
builder.set_pubkey(&pkey).unwrap();
builder.sign(&pkey, openssl::hash::MessageDigest::sha256()).unwrap();
let cert = builder.build();
let cert_path = dir.join("cert.pem");
let key_path = dir.join("key.pem");
fs::write(&cert_path, cert.to_pem().unwrap()).unwrap();
fs::write(&key_path, pkey.private_key_to_pem_pkcs8().unwrap()).unwrap();
let config = crate::config_manager::ConfigCenter {
ip: "127.0.0.1".to_string(),
port: 8080,
ca_path: String::new(),
cert_path: cert_path.to_str().unwrap().to_string(),
private_path: key_path.to_str().unwrap().to_string(),
crl_path: String::new(),
ra_public_key_path: String::new(),
log_path: String::new(),
data_path: String::new(),
plugins: vec![],
log_max_size: 10,
log_backup_count: 5,
min_log_level: crate::types::KdcLogLevel::Info,
};
let result = build_rustls_server_config(&config);
assert!(result.is_ok(), "should succeed with valid cert and key");
let _ = fs::remove_dir_all(&dir);
}
#[test]
fn test_check_cert_expiry_expiring_soon() {
    // A certificate that expires in 3 days must still be reported as valid,
    // but with few days remaining (the long-expiry test treats > 7 days as
    // comfortably valid, so this one must sit at or below that boundary).
    let dir = std::env::temp_dir().join("kdc_cert_expiring");
    let _ = fs::create_dir_all(&dir);
    let path = dir.join("expiring.pem");
    let group = openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
    let key = openssl::ec::EcKey::generate(&group).unwrap();
    let pkey = openssl::pkey::PKey::from_ec_key(key).unwrap();
    let mut builder = openssl::x509::X509Builder::new().unwrap();
    builder.set_version(2).unwrap();
    let serial = openssl::bn::BigNum::from_u32(1).unwrap();
    builder.set_serial_number(&serial.to_asn1_integer().unwrap()).unwrap();
    let name = openssl::x509::X509NameBuilder::new().unwrap().build();
    builder.set_issuer_name(&name).unwrap();
    builder.set_subject_name(&name).unwrap();
    builder.set_not_before(openssl::asn1::Asn1Time::days_from_now(0).unwrap().as_ref()).unwrap();
    builder.set_not_after(openssl::asn1::Asn1Time::days_from_now(3).unwrap().as_ref()).unwrap();
    builder.set_pubkey(&pkey).unwrap();
    builder.sign(&pkey, openssl::hash::MessageDigest::sha256()).unwrap();
    let cert = builder.build();
    fs::write(&path, cert.to_pem().unwrap()).unwrap();
    let cert_file = path.to_str().unwrap();
    // The logging wrapper must not panic for a soon-to-expire certificate.
    check_cert_expiry(cert_file);
    let status = check_cert_expiry_result(cert_file);
    // Clean up before asserting so a failure does not leave files behind.
    let _ = fs::remove_dir_all(&dir);
    match status {
        Ok(CertStatus::Valid { remaining_days, .. }) => assert!(
            remaining_days <= 7,
            "cert expiring in 3 days reported {remaining_days} days remaining"
        ),
        Ok(CertStatus::Expired { .. }) => {
            panic!("cert expiring in 3 days must not be reported as expired")
        }
        Err(e) => panic!("failed to check freshly written cert: {e:?}"),
    }
}
#[test]
fn test_check_cert_expiry_expired_cert() {
    // A certificate whose validity window closed long ago must be reported as
    // expired. Fixed past timestamps keep the outcome independent of the wall
    // clock; a not_after of "now" could be classified either way.
    let dir = std::env::temp_dir().join("kdc_cert_expired");
    let _ = fs::create_dir_all(&dir);
    let path = dir.join("expired.pem");
    let group = openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
    let key = openssl::ec::EcKey::generate(&group).unwrap();
    let pkey = openssl::pkey::PKey::from_ec_key(key).unwrap();
    let mut builder = openssl::x509::X509Builder::new().unwrap();
    builder.set_version(2).unwrap();
    let serial = openssl::bn::BigNum::from_u32(1).unwrap();
    builder.set_serial_number(&serial.to_asn1_integer().unwrap()).unwrap();
    let name = openssl::x509::X509NameBuilder::new().unwrap().build();
    builder.set_issuer_name(&name).unwrap();
    builder.set_subject_name(&name).unwrap();
    // Validity: 2000-01-01T00:00:00Z .. 2001-01-01T00:00:00Z.
    let not_before = openssl::asn1::Asn1Time::from_unix(946_684_800).unwrap();
    let not_after = openssl::asn1::Asn1Time::from_unix(978_307_200).unwrap();
    builder.set_not_before(&not_before).unwrap();
    builder.set_not_after(&not_after).unwrap();
    builder.set_pubkey(&pkey).unwrap();
    builder.sign(&pkey, openssl::hash::MessageDigest::sha256()).unwrap();
    let cert = builder.build();
    fs::write(&path, cert.to_pem().unwrap()).unwrap();
    let cert_file = path.to_str().unwrap();
    let status = check_cert_expiry_result(cert_file);
    // The logging wrapper must not panic on an expired certificate either.
    check_cert_expiry(cert_file);
    // Clean up before asserting so a failure does not leave files behind.
    let _ = fs::remove_dir_all(&dir);
    match status {
        Ok(CertStatus::Expired { .. }) => {}
        Ok(CertStatus::Valid { remaining_days, .. }) => panic!(
            "cert that expired in 2001 was reported valid with {remaining_days} days remaining"
        ),
        Err(e) => panic!("failed to check freshly written cert: {e:?}"),
    }
}
#[test]
fn test_check_cert_expiry_error_path() {
    // The logging wrapper is expected to tolerate an unreadable path rather
    // than panic; reaching the end of the test is the assertion.
    let missing_cert = "/nonexistent/cert.pem";
    check_cert_expiry(missing_cert);
}
#[tokio::test]
async fn test_proxy_handler_no_handler() {
    // Serialize with other tests that mutate the process-wide PSK manager.
    let _lock = TEST_LOCK.lock().unwrap();
    let session_key = common::psk::generate_random_psk().unwrap();
    let payload = BASE64_STANDARD.encode(psk::psk_encrypt_with_key(b"test", &session_key).unwrap());
    *crate::psk_manager::get_psk_manager().lock() = Some(session_key);

    // 0xFFFE is used as a message type with no registered plugin handler.
    let reply = proxy_handler(axum::Json(KdcHttpRequest {
        task_id: "test".to_string(),
        version: "1.0".to_string(),
        msg_type: 0xFFFE,
        data: Some(payload),
    }))
    .await
    .0;

    assert_eq!(reply.ret_code, CommonError::InternalError as u32);
    assert!(reply.ret_msg.contains("no handler registered"));
    *crate::psk_manager::get_psk_manager().lock() = None;
}
#[tokio::test]
async fn test_proxy_handler_psk_registration() {
    // A PSK registration request whose payload is not JSON must yield a
    // non-zero return code.
    let request = KdcHttpRequest {
        task_id: String::from("test"),
        version: String::from("1.0"),
        msg_type: TRUSTED_REGISTER_PSK_REQ,
        data: Some(String::from("not json")),
    };
    let ret_code = proxy_handler(axum::Json(request)).await.0.ret_code;
    assert_ne!(ret_code, 0);
}
#[tokio::test]
async fn test_proxy_handler_decrypt_no_psk() {
    // With no PSK installed, an encrypted request must fail with a PSK mismatch.
    let _lock = TEST_LOCK.lock().unwrap();
    *crate::psk_manager::get_psk_manager().lock() = None;

    let reply = proxy_handler(axum::Json(KdcHttpRequest {
        task_id: String::from("test"),
        version: String::from("1.0"),
        msg_type: 0x0020,
        data: Some(String::from("somehexdata")),
    }))
    .await
    .0;

    let expected = CommonError::PskMismatchError as u32;
    assert_eq!(reply.ret_code, expected);
}
#[tokio::test]
async fn test_proxy_handler_with_registered_handler() {
    let _lock = TEST_LOCK.lock().unwrap();
    let msg_type = 0x0099u32;

    // Plugin stub: writes a fixed JSON reply into `res` and reports success (0).
    extern "C" fn test_dispatch_handler(
        _req: *const KdcRequestMsg,
        res: *mut KdcResponseMsg,
    ) -> u32 {
        const REPLY: &str = r#"{"retCode":0,"retMsg":"dispatched","data":"result"}"#;
        let owned_body = CString::new(REPLY).unwrap().into_raw();
        // SAFETY: assumes the dispatcher passes a valid, writable response
        // slot (TODO confirm); the buffer is handed to `mock_free` for release.
        unsafe {
            (*res).body = owned_body;
            (*res).body_length = REPLY.len() as u32;
        }
        0
    }

    // Clear any stale registration, then install the stub.
    let _ = crate::handler_registry::unregister_handler(msg_type);
    let _ = crate::handler_registry::register_handler(msg_type, test_dispatch_handler, mock_free);

    let session_key = common::psk::generate_random_psk().unwrap();
    let sealed = psk::psk_encrypt_with_key(b"dispatch test", &session_key).unwrap();
    let payload = BASE64_STANDARD.encode(&sealed);
    *crate::psk_manager::get_psk_manager().lock() = Some(session_key);

    let request = KdcHttpRequest {
        task_id: "test_dispatch".to_string(),
        version: "1.0".to_string(),
        msg_type,
        data: Some(payload),
    };
    let reply = proxy_handler(axum::Json(request)).await.0;
    assert_eq!(reply.ret_code, 0);
    assert_eq!(reply.ret_msg, "dispatched");

    let _ = crate::handler_registry::unregister_handler(msg_type);
    *crate::psk_manager::get_psk_manager().lock() = None;
}
#[tokio::test]
async fn test_proxy_handler_verify_token_invalid() {
    // A verify request whose decrypted payload is not JSON must be rejected.
    let _lock = TEST_LOCK.lock().unwrap();
    let session_key = common::psk::generate_random_psk().unwrap();
    let ciphertext = psk::psk_encrypt_with_key(b"not json", &session_key).unwrap();
    let request = KdcHttpRequest {
        task_id: "test_verify".to_string(),
        version: "1.0".to_string(),
        msg_type: TRUSTED_ENV_VERIFY_REQ,
        data: Some(BASE64_STANDARD.encode(&ciphertext)),
    };
    *crate::psk_manager::get_psk_manager().lock() = Some(session_key);

    let reply = proxy_handler(axum::Json(request)).await.0;
    assert_ne!(reply.ret_code, 0);

    *crate::psk_manager::get_psk_manager().lock() = None;
}
#[tokio::test]
async fn test_start_http_server_no_config() {
    // Only meaningful while the global agent config has not been set by
    // another test in the same process.
    let None = get_agent_config() else {
        eprintln!("skipping: config already initialized");
        return;
    };
    let outcome = start_http_server().await;
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("ConfigCenter not initialized"));
}
#[tokio::test]
async fn test_proxy_handler_verify_with_psk_no_token() {
    // Well-formed JSON without a "token" field must still be rejected.
    let _lock = TEST_LOCK.lock().unwrap();
    let session_key = common::psk::generate_random_psk().unwrap();
    let sealed = psk::psk_encrypt_with_key(br#"{"data":"no_token"}"#, &session_key).unwrap();
    let data = Some(BASE64_STANDARD.encode(&sealed));
    *crate::psk_manager::get_psk_manager().lock() = Some(session_key);

    let reply = proxy_handler(axum::Json(KdcHttpRequest {
        task_id: "test".to_string(),
        version: "1.0".to_string(),
        msg_type: TRUSTED_ENV_VERIFY_REQ,
        data,
    }))
    .await
    .0;
    assert_ne!(reply.ret_code, 0);

    *crate::psk_manager::get_psk_manager().lock() = None;
}
#[tokio::test]
async fn test_proxy_handler_verify_encrypts_response() {
    // Sends an encrypted verify request carrying an invalid token; only the
    // non-zero return code is checked here.
    let _lock = TEST_LOCK.lock().unwrap();
    let session_key = common::psk::generate_random_psk().unwrap();
    let wire_data = {
        let sealed = psk::psk_encrypt_with_key(br#"{"token":"bad_token"}"#, &session_key).unwrap();
        BASE64_STANDARD.encode(&sealed)
    };
    *crate::psk_manager::get_psk_manager().lock() = Some(session_key);

    let request = KdcHttpRequest {
        task_id: String::from("test"),
        version: String::from("1.0"),
        msg_type: TRUSTED_ENV_VERIFY_REQ,
        data: Some(wire_data),
    };
    let ret_code = proxy_handler(axum::Json(request)).await.0.ret_code;
    assert_ne!(ret_code, 0);

    *crate::psk_manager::get_psk_manager().lock() = None;
}
#[tokio::test]
async fn test_proxy_handler_plugin_handler_encrypted_response() {
    let _lock = TEST_LOCK.lock().unwrap();
    let msg_type = 0x0088u32;

    // Plugin stub returning a successful JSON reply with a data field.
    extern "C" fn encrypted_resp_handler(
        _req: *const KdcRequestMsg,
        res: *mut KdcResponseMsg,
    ) -> u32 {
        let reply = r#"{"retCode":0,"retMsg":"ok","data":"encrypted_result"}"#;
        let length = reply.len() as u32;
        let raw = CString::new(reply).unwrap().into_raw();
        // SAFETY: assumes the dispatcher provides a valid, writable `res`
        // (TODO confirm); `mock_free` is registered to release `raw`.
        unsafe {
            (*res).body = raw;
            (*res).body_length = length;
        }
        0
    }

    let _ = crate::handler_registry::unregister_handler(msg_type);
    let _ = crate::handler_registry::register_handler(msg_type, encrypted_resp_handler, mock_free);

    let session_key = common::psk::generate_random_psk().unwrap();
    let encoded = BASE64_STANDARD.encode(psk::psk_encrypt_with_key(b"test payload", &session_key).unwrap());
    *crate::psk_manager::get_psk_manager().lock() = Some(session_key);

    let reply = proxy_handler(axum::Json(KdcHttpRequest {
        task_id: String::from("test_enc"),
        version: String::from("1.0"),
        msg_type,
        data: Some(encoded),
    }))
    .await
    .0;
    assert_eq!(reply.ret_code, 0);
    assert!(reply.data.is_some());

    let _ = crate::handler_registry::unregister_handler(msg_type);
    *crate::psk_manager::get_psk_manager().lock() = None;
}
#[test]
fn test_response_guard_drop_no_free_fn() {
    // With no free callback the guard is expected not to dereference `res`,
    // so a dangling (never-dereferenced) pointer is enough to run its Drop.
    drop(ResponseGuard {
        free_fn: None,
        res: std::ptr::NonNull::dangling().as_ptr(),
    });
}
#[test]
fn test_response_guard_drop_with_free_fn() {
    // Build a response whose body came from `CString::into_raw`, then let the
    // guard hand it to `mock_free` when dropped.
    let body = "test response";
    // SAFETY: assumes KdcResponseMsg is a plain #[repr(C)] struct of raw
    // pointers and integers, for which all-zero bytes are valid — TODO confirm.
    let mut kdc_res: KdcResponseMsg = unsafe { std::mem::zeroed() };
    kdc_res.body = CString::new(body).unwrap().into_raw();
    // Byte length of the body excluding the trailing NUL, so the message is
    // self-consistent (a hard-coded 12 would disagree with the 13-byte body).
    kdc_res.body_length = u32::try_from(body.len()).expect("test body length fits in u32");
    let guard = ResponseGuard {
        res: &mut kdc_res,
        free_fn: Some(mock_free),
    };
    drop(guard);
}
#[test]
fn test_stop_http_server_not_started() {
    // If some other test already started the server, the "not started" path
    // cannot be exercised in this process; nothing to check.
    if SHUTDOWN_TX.get().is_some() {
        return;
    }
    let outcome = stop_http_server();
    assert!(outcome.is_err());
    match outcome.unwrap_err() {
        ServerError::ConfigError(msg) => assert!(msg.contains("server not started")),
        other => panic!("expected ConfigError, got {:?}", other),
    }
}
#[test]
fn test_is_expired_zero_days_zero_secs() {
    // Exactly zero remaining time is treated as already expired.
    let remaining = openssl::asn1::TimeDiff { secs: 0, days: 0 };
    assert!(is_expired(&remaining));
}
#[test]
fn test_is_expired_zero_days_one_sec() {
    // One second of validity left still counts as not expired.
    let remaining = openssl::asn1::TimeDiff { secs: 1, days: 0 };
    let expired = is_expired(&remaining);
    assert!(!expired);
}
#[test]
fn test_is_expired_zero_days_neg_one_sec() {
    // One second past the deadline is expired.
    let overdue = openssl::asn1::TimeDiff { secs: -1, days: 0 };
    assert!(is_expired(&overdue));
}
#[test]
fn test_build_router_no_panic() {
    // Constructing the router (routes + middleware layers) must not panic.
    let router = build_router();
    drop(router);
}
#[test]
fn test_check_cert_expiry_valid_long_expiry() {
    use openssl::asn1::Asn1Time;
    use openssl::bn::BigNum;
    use openssl::ec::{EcGroup, EcKey};
    use openssl::hash::MessageDigest;
    use openssl::nid::Nid;
    use openssl::pkey::PKey;
    use openssl::x509::{X509Builder, X509NameBuilder};

    // Self-signed P-256 certificate valid for a year from now.
    let workdir = std::env::temp_dir().join("kdc_cert_test_long_valid");
    let _ = fs::create_dir_all(&workdir);
    let pem_path = workdir.join("long_valid.pem");

    let curve = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).unwrap();
    let signing_key = PKey::from_ec_key(EcKey::generate(&curve).unwrap()).unwrap();
    let mut cert_builder = X509Builder::new().unwrap();
    cert_builder.set_version(2).unwrap();
    cert_builder
        .set_serial_number(&BigNum::from_u32(1).unwrap().to_asn1_integer().unwrap())
        .unwrap();
    let subject = X509NameBuilder::new().unwrap().build();
    cert_builder.set_issuer_name(&subject).unwrap();
    cert_builder.set_subject_name(&subject).unwrap();
    cert_builder.set_not_before(Asn1Time::days_from_now(0).unwrap().as_ref()).unwrap();
    cert_builder.set_not_after(Asn1Time::days_from_now(365).unwrap().as_ref()).unwrap();
    cert_builder.set_pubkey(&signing_key).unwrap();
    cert_builder.sign(&signing_key, MessageDigest::sha256()).unwrap();
    fs::write(&pem_path, cert_builder.build().to_pem().unwrap()).unwrap();

    let cert_file = pem_path.to_str().unwrap();
    check_cert_expiry(cert_file);
    let status = check_cert_expiry_result(cert_file);
    assert!(status.is_ok());
    let CertStatus::Valid { remaining_days, .. } = status.unwrap() else {
        panic!("cert should be valid with >7 days remaining");
    };
    assert!(remaining_days > 7);
    let _ = fs::remove_dir_all(&workdir);
}
#[test]
fn test_handle_psk_registration_request_success() {
    // Stand-in for `handle_psk_registration` that always succeeds.
    fn always_registers(_token: &str) -> Result<String, String> {
        Ok(String::from("encrypted_psk_b64"))
    }
    // Keep the mock guard alive for the whole test.
    let _mocker = mockrs::mock!(crate::psk_manager::handle_psk_registration, always_registers);

    let resp = handle_psk_registration_request(r#"{"token":"valid_token"}"#);
    assert_eq!(resp.ret_code, 0);
    assert_eq!(resp.data.as_deref(), Some("encrypted_psk_b64"));
}
#[test]
fn test_handle_verify_token_request_success() {
    // Force `verify_token` to succeed with a fixed claim set.
    fn passing_verifier(
        _token: &str,
    ) -> Result<serde_json::Value, crate::token_verifier::TokenVerifierError> {
        Ok(serde_json::json!({"status": "pass"}))
    }
    let _mocker = mockrs::mock!(crate::token_verifier::verify_token, passing_verifier);

    let resp = handle_verify_token_request(r#"{"token":"valid_token"}"#);
    assert_eq!(resp.ret_code, 0);
    let body = resp.data.as_deref();
    assert!(body.is_some());
    assert!(body.unwrap().contains("pass"));
}
#[test]
fn test_handle_verify_token_request_failure() {
    // Force `verify_token` to fail; the reason must surface in ret_msg.
    fn rejecting_verifier(
        _token: &str,
    ) -> Result<serde_json::Value, crate::token_verifier::TokenVerifierError> {
        Err(crate::token_verifier::TokenVerifierError::VerificationFailed(
            String::from("bad token"),
        ))
    }
    let _mocker = mockrs::mock!(crate::token_verifier::verify_token, rejecting_verifier);

    let resp = handle_verify_token_request(r#"{"token":"bad_token"}"#);
    assert_ne!(resp.ret_code, 0);
    assert!(resp.ret_msg.contains("bad token"));
}
#[test]
fn test_handle_psk_registration_request_failure() {
    // Force registration to fail; the reason must surface in ret_msg.
    fn failing_registration(_token: &str) -> Result<String, String> {
        Err(String::from("registration failed"))
    }
    let _mocker = mockrs::mock!(crate::psk_manager::handle_psk_registration, failing_registration);

    let resp = handle_psk_registration_request(r#"{"token":"some_token"}"#);
    assert_ne!(resp.ret_code, 0);
    assert!(resp.ret_msg.contains("registration failed"));
}
#[test]
fn test_load_certs_corrupted_pem() {
    // A PEM block with a non-base64 body must either error or yield no certs.
    let workdir = std::env::temp_dir().join("kdc_load_certs_corrupt");
    let _ = fs::create_dir_all(&workdir);
    let corrupt = workdir.join("corrupted.pem");
    fs::write(
        &corrupt,
        "-----BEGIN CERTIFICATE-----\nnot valid base64!!!\n-----END CERTIFICATE-----\n",
    )
    .unwrap();
    if let Ok(certs) = load_certs(corrupt.to_str().unwrap()) {
        assert!(certs.is_empty());
    }
    let _ = fs::remove_dir_all(&workdir);
}
#[test]
fn test_load_private_key_corrupted_pem() {
    // A PRIVATE KEY block with garbage contents must be rejected.
    let workdir = std::env::temp_dir().join("kdc_load_key_corrupt");
    let _ = fs::create_dir_all(&workdir);
    let bad_key = workdir.join("corrupted_key.pem");
    let garbage = "-----BEGIN PRIVATE KEY-----\n!!!invalid!!!\n-----END PRIVATE KEY-----\n";
    fs::write(&bad_key, garbage).unwrap();
    assert!(load_private_key(bad_key.to_str().unwrap()).is_err());
    let _ = fs::remove_dir_all(&workdir);
}
#[test]
fn test_build_rustls_config_mismatched_key() {
    use openssl::asn1::Asn1Time;
    use openssl::ec::{EcGroup, EcKey};
    use openssl::hash::MessageDigest;
    use openssl::nid::Nid;
    use openssl::pkey::PKey;

    let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    let workdir = std::env::temp_dir().join("kdc_rustls_mismatch");
    let _ = fs::create_dir_all(&workdir);

    // Each call yields an independent P-256 key pair.
    let fresh_key = || {
        let curve = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).unwrap();
        PKey::from_ec_key(EcKey::generate(&curve).unwrap()).unwrap()
    };

    // The certificate is self-signed with one key...
    let cert_key = fresh_key();
    let mut builder = openssl::x509::X509Builder::new().unwrap();
    builder.set_version(2).unwrap();
    let serial = openssl::bn::BigNum::from_u32(1).unwrap();
    builder.set_serial_number(&serial.to_asn1_integer().unwrap()).unwrap();
    let subject = openssl::x509::X509NameBuilder::new().unwrap().build();
    builder.set_issuer_name(&subject).unwrap();
    builder.set_subject_name(&subject).unwrap();
    builder.set_not_before(Asn1Time::days_from_now(0).unwrap().as_ref()).unwrap();
    builder.set_not_after(Asn1Time::days_from_now(365).unwrap().as_ref()).unwrap();
    builder.set_pubkey(&cert_key).unwrap();
    builder.sign(&cert_key, MessageDigest::sha256()).unwrap();
    let cert = builder.build();

    // ...while the private key written to disk belongs to an unrelated pair.
    let unrelated_key = fresh_key();
    let cert_path = workdir.join("cert.pem");
    let key_path = workdir.join("key.pem");
    fs::write(&cert_path, cert.to_pem().unwrap()).unwrap();
    fs::write(&key_path, unrelated_key.private_key_to_pem_pkcs8().unwrap()).unwrap();

    let config = crate::config_manager::ConfigCenter {
        ip: "127.0.0.1".to_string(),
        port: 8080,
        ca_path: String::new(),
        cert_path: cert_path.to_str().unwrap().to_string(),
        private_path: key_path.to_str().unwrap().to_string(),
        crl_path: String::new(),
        ra_public_key_path: String::new(),
        log_path: String::new(),
        data_path: String::new(),
        plugins: vec![],
        log_max_size: 10,
        log_backup_count: 5,
        min_log_level: crate::types::KdcLogLevel::Info,
    };
    let outcome = build_rustls_server_config(&config);
    assert!(outcome.is_err(), "should fail with mismatched cert and key");
    let _ = fs::remove_dir_all(&workdir);
}
#[test]
fn test_load_certs_multiple_certs() {
    use openssl::asn1::Asn1Time;
    use openssl::ec::{EcGroup, EcKey};
    use openssl::pkey::PKey;

    let workdir = std::env::temp_dir().join("kdc_load_certs_multi");
    let _ = fs::create_dir_all(&workdir);
    let bundle_path = workdir.join("multi.pem");
    let curve = EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();

    // Concatenate three independent self-signed certificates into one PEM bundle.
    let bundle: Vec<u8> = (0..3)
        .flat_map(|_| {
            let pkey = PKey::from_ec_key(EcKey::generate(&curve).unwrap()).unwrap();
            let mut builder = openssl::x509::X509Builder::new().unwrap();
            builder.set_version(2).unwrap();
            let serial = openssl::bn::BigNum::from_u32(1).unwrap();
            builder.set_serial_number(&serial.to_asn1_integer().unwrap()).unwrap();
            let subject = openssl::x509::X509NameBuilder::new().unwrap().build();
            builder.set_issuer_name(&subject).unwrap();
            builder.set_subject_name(&subject).unwrap();
            builder.set_not_before(Asn1Time::days_from_now(0).unwrap().as_ref()).unwrap();
            builder.set_not_after(Asn1Time::days_from_now(365).unwrap().as_ref()).unwrap();
            builder.set_pubkey(&pkey).unwrap();
            builder.sign(&pkey, openssl::hash::MessageDigest::sha256()).unwrap();
            builder.build().to_pem().unwrap()
        })
        .collect();
    fs::write(&bundle_path, &bundle).unwrap();

    let loaded = load_certs(bundle_path.to_str().unwrap());
    assert!(loaded.is_ok());
    assert_eq!(loaded.unwrap().len(), 3, "should load all 3 certs");
    let _ = fs::remove_dir_all(&workdir);
}
#[test]
fn test_decrypt_request_data_non_utf8() {
    let _lock = TEST_LOCK.lock().unwrap();
    let session_key = common::psk::generate_random_psk().unwrap();
    // Bytes that can never form valid UTF-8 once decrypted.
    let invalid_utf8 = [0xFFu8, 0xFE, 0xFD, 0xFC, 0xFB];
    let wire = BASE64_STANDARD.encode(psk::psk_encrypt_with_key(&invalid_utf8, &session_key).unwrap());

    // Current contract pinned by this test: non-UTF-8 plaintext decodes to "".
    let decoded = decrypt_request_data(&wire, Some(session_key));
    assert!(decoded.is_ok());
    assert_eq!(decoded.unwrap(), "");
}
#[test]
fn test_encrypt_response_data_encrypt_fails() {
    // When the underlying encryption fails, no response data is produced.
    fn always_fails(_key: Option<[u8; PSK_LEN]>, _data: &[u8]) -> Result<Vec<u8>, String> {
        Err(String::from("encryption failed"))
    }
    let _mocker = mockrs::mock!(crate::psk_manager::psk_encrypt_with_key, always_fails);

    let session_key = common::psk::generate_random_psk().unwrap();
    let plaintext = Some(String::from("secret data"));
    assert!(encrypt_response_data(&plaintext, Some(session_key)).is_none());
}
#[test]
fn test_handle_psk_registration_request_success_format() {
    // A successful registration reports code 0, "Success", and the payload verbatim.
    fn canned_registration(_token: &str) -> Result<String, String> {
        Ok(String::from("base64encrypteddata"))
    }
    let _mocker = mockrs::mock!(crate::psk_manager::handle_psk_registration, canned_registration);

    let resp = handle_psk_registration_request(r#"{"token":"tok123"}"#);
    assert_eq!((resp.ret_code, resp.ret_msg.as_str()), (0, "Success"));
    assert_eq!(resp.data.as_deref(), Some("base64encrypteddata"));
}
#[test]
fn test_handle_verify_token_request_verify_error() {
    // A signature failure must be reported with a message mentioning "signature".
    fn bad_signature(
        _token: &str,
    ) -> Result<serde_json::Value, crate::token_verifier::TokenVerifierError> {
        Err(crate::token_verifier::TokenVerifierError::SignatureInvalid)
    }
    let _mocker = mockrs::mock!(crate::token_verifier::verify_token, bad_signature);

    let resp = handle_verify_token_request(r#"{"token":"signed_token"}"#);
    assert_ne!(resp.ret_code, 0);
    assert!(resp.ret_msg.contains("signature"));
}
#[test]
fn test_parse_handler_response_string_data() {
    // A JSON string in "data" is returned in its serialized (quoted) form.
    let raw = r#"{"retCode":0,"retMsg":"ok","data":"plain string"}"#;
    let (ret_code, ret_msg, payload) = parse_handler_response(&Some(raw.to_string()));
    assert_eq!(ret_code, 0);
    assert_eq!(ret_msg, "ok");
    assert_eq!(payload.as_deref(), Some("\"plain string\""));
}
#[test]
fn test_load_certs_empty_file() {
    // An empty file loads successfully and yields no certificates.
    let workdir = std::env::temp_dir().join("kdc_load_certs_empty");
    let _ = fs::create_dir_all(&workdir);
    let empty_pem = workdir.join("empty.pem");
    fs::write(&empty_pem, "").unwrap();
    let loaded = load_certs(empty_pem.to_str().unwrap());
    assert!(matches!(&loaded, Ok(certs) if certs.is_empty()));
    let _ = fs::remove_dir_all(&workdir);
}
#[test]
fn test_verify_token_wrong_key_jwt() {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine;
    use openssl::ecdsa::EcdsaSig;
    use openssl::hash::{Hasher, MessageDigest};

    // Packs r and s, each left-padded to 32 bytes, into the 64-byte `r || s`
    // layout used for ES256 JWS signatures.
    fn raw_es256_signature(sig: &EcdsaSig) -> [u8; 64] {
        let mut packed = [0u8; 64];
        for (half, component) in [sig.r(), sig.s()].iter().enumerate() {
            let bytes = component.to_vec();
            let end = 32 * (half + 1);
            packed[end - bytes.len()..end].copy_from_slice(&bytes);
        }
        packed
    }

    let curve =
        openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
    let forger = openssl::ec::EcKey::generate(&curve).unwrap();

    let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"ES256","typ":"JWT"}"#);
    let claims = URL_SAFE_NO_PAD.encode(br#"{"status":"pass"}"#);
    let signing_input = format!("{}.{}", header, claims);

    let mut digest = Hasher::new(MessageDigest::sha256()).unwrap();
    digest.update(signing_input.as_bytes()).unwrap();
    let hashed = digest.finish().unwrap();
    let signature = EcdsaSig::sign(&hashed, &forger).unwrap();

    let sig_b64 = URL_SAFE_NO_PAD.encode(raw_es256_signature(&signature));
    let jwt = format!("{}.{}.{}", header, claims, sig_b64);
    assert!(crate::token_verifier::verify_token(&jwt).is_err());
}
#[test]
fn test_verify_token_signature_invalid_with_ra_key_set() {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine;
    use openssl::ecdsa::EcdsaSig;
    use openssl::hash::{Hasher, MessageDigest};

    /// Packs r and s, each left-padded to 32 bytes, into the 64-byte `r || s`
    /// layout used for ES256 JWS signatures.
    fn raw_es256_signature(sig: &EcdsaSig) -> [u8; 64] {
        let r_bytes = sig.r().to_vec();
        let s_bytes = sig.s().to_vec();
        let mut raw_sig = [0u8; 64];
        raw_sig[32 - r_bytes.len()..32].copy_from_slice(&r_bytes);
        raw_sig[64 - s_bytes.len()..].copy_from_slice(&s_bytes);
        raw_sig
    }

    // Publish an RA public key on disk for the verifier to load.
    let group =
        openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
    let ra_key = openssl::ec::EcKey::generate(&group).unwrap();
    let ra_pkey = openssl::pkey::PKey::from_ec_key(ra_key).unwrap();
    let ra_pub_pem = String::from_utf8(ra_pkey.public_key_to_pem().unwrap()).unwrap();
    let ra_key_path = std::env::temp_dir().join("kdc_test_ra_pub_key.pem");
    std::fs::write(&ra_key_path, &ra_pub_pem).unwrap();

    // `fn` items cannot capture locals, so the mock rebuilds the same key path.
    fn mock_get_agent_config() -> Option<&'static crate::config_manager::ConfigCenter> {
        use crate::types::KdcLogLevel;
        let config = crate::config_manager::ConfigCenter {
            ip: "127.0.0.1".to_string(),
            port: 8081,
            ca_path: String::new(),
            cert_path: String::new(),
            private_path: String::new(),
            crl_path: String::new(),
            ra_public_key_path: std::env::temp_dir()
                .join("kdc_test_ra_pub_key.pem")
                .to_str()
                .unwrap()
                .to_string(),
            log_path: "/tmp/kdc_test_sig.log".to_string(),
            data_path: "/tmp".to_string(),
            plugins: vec![],
            log_max_size: 10,
            log_backup_count: 5,
            min_log_level: KdcLogLevel::Info,
        };
        Some(Box::leak(Box::new(config)))
    }
    let _cfg_mock = mockrs::mock!(
        crate::config_manager::get_agent_config,
        mock_get_agent_config
    );

    // NOTE(review): the init outcome is intentionally not asserted — it may
    // fail if another test already initialized the verifier (TODO confirm).
    // The token below is signed by an unrelated key, so it must be rejected
    // either way. The leading underscore silences the unused-variable lint.
    let _init_result = crate::token_verifier::init_token_verifier();

    // Forge an ES256 JWT signed by a key that is NOT the RA key.
    let wrong_key = openssl::ec::EcKey::generate(&group).unwrap();
    let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"ES256","typ":"JWT"}"#);
    let body = URL_SAFE_NO_PAD.encode(br#"{"status":"pass"}"#);
    let signing_input = format!("{}.{}", header, body);
    let mut hasher = Hasher::new(MessageDigest::sha256()).unwrap();
    hasher.update(signing_input.as_bytes()).unwrap();
    let hash = hasher.finish().unwrap();
    let sig = EcdsaSig::sign(&hash, &wrong_key).unwrap();
    let sig_b64 = URL_SAFE_NO_PAD.encode(raw_es256_signature(&sig));
    let jwt = format!("{}.{}.{}", header, body, sig_b64);

    let result = crate::token_verifier::verify_token(&jwt);
    // Remove the key file before asserting so a failure does not leak it.
    let _ = std::fs::remove_file(&ra_key_path);
    assert!(result.is_err());
}
}