use std::collections::HashSet;
use std::sync::{Arc, RwLock};
use uuid::Uuid;
use axum::{
extract::State,
http::{header::AUTHORIZATION, StatusCode},
middleware::Next,
response::Response,
};
/// In-memory registry of bearer tokens issued to the web UI.
///
/// Cloning is cheap: every clone shares the same underlying set through an [`Arc`], so a
/// token minted via one handle is immediately accepted by all other handles.
///
/// Nothing in this type removes a token once it has been minted.
#[derive(Clone, Default)]
pub struct WebuiTokenStore {
    inner: Arc<RwLock<HashSet<String>>>,
}

impl WebuiTokenStore {
    /// Creates an empty store in which no token is valid yet.
    pub fn new() -> Self {
        Self::default()
    }

    /// Generates a fresh token, records it as valid, and returns it.
    ///
    /// The token is the hyphen-less (`simple`) textual form of a random version-4 UUID.
    pub fn mint(&self) -> String {
        let token = Uuid::new_v4().simple().to_string();
        // Poisoning only signals that another thread panicked while holding the lock; a
        // single `HashSet::insert` cannot leave the set half-updated, so recover the guard
        // rather than turning one panic into a permanent failure of every later call.
        self.inner
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(token.clone());
        token
    }

    /// Returns `true` if `token` was previously returned by [`mint`](Self::mint) on this
    /// store or on a clone of it.
    ///
    /// The empty string is always rejected without touching the lock.
    pub fn is_valid(&self, token: &str) -> bool {
        if token.is_empty() {
            return false;
        }
        // See `mint` for why a poisoned lock is recovered instead of unwrapped.
        self.inner
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .contains(token)
    }
}
/// Extracts the credential from an `Authorization: Bearer <token>` header value.
///
/// The scheme name is matched case-insensitively (`Bearer`, `bearer`, `BEARER`, ...), as HTTP
/// requires for authentication schemes. Whitespace surrounding the token is trimmed.
///
/// Returns `None` when the header is absent, has no space after the scheme, names a scheme
/// other than `Bearer`, or carries an empty token.
pub fn token_from_header(value: Option<&str>) -> Option<String> {
    // The scheme and the credentials are separated by a space; any extra padding after it
    // is removed by the `trim` below.
    let (scheme, rest) = value?.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let token = rest.trim();
    (!token.is_empty()).then(|| token.to_string())
}
/// Axum middleware that admits a request only if it presents a web-UI token known to
/// `state.webui_tokens` in its `Authorization: Bearer ...` header.
///
/// When `state.enforce_token` is `false` every request is forwarded unchanged. Otherwise a
/// missing, malformed, or unknown token short-circuits with `401 Unauthorized` and the inner
/// handler is never run.
pub async fn require_webui_token(
    State(state): State<crate::AppState>,
    req: axum::extract::Request,
    next: Next,
) -> Result<Response, StatusCode> {
    if state.enforce_token {
        // Header values that are not valid visible ASCII are treated as absent.
        let presented = req
            .headers()
            .get(AUTHORIZATION)
            .and_then(|raw| raw.to_str().ok());
        let Some(candidate) = token_from_header(presented) else {
            return Err(StatusCode::UNAUTHORIZED);
        };
        if !state.webui_tokens.is_valid(&candidate) {
            return Err(StatusCode::UNAUTHORIZED);
        }
    }
    Ok(next.run(req).await)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each case pairs a raw header value with the token that should be extracted from it.
    #[test]
    fn extracts_bearer_token() {
        let cases: [(Option<&str>, Option<&str>); 4] = [
            (Some("Bearer abc123"), Some("abc123")),
            (Some("bearer abc123"), Some("abc123")),
            (Some("abc123"), None),
            (None, None),
        ];
        for (header, expected) in cases {
            assert_eq!(token_from_header(header), expected.map(str::to_string));
        }
    }

    #[test]
    fn mints_and_validates_token() {
        let store = WebuiTokenStore::new();
        let minted = store.mint();
        assert!(store.is_valid(&minted), "freshly minted token must validate");
    }

    // A populated store must still refuse strings it never issued, including the empty one.
    #[test]
    fn rejects_unknown_token() {
        let store = WebuiTokenStore::new();
        let _ = store.mint();
        for bogus in ["not-a-real-token", ""] {
            assert!(!store.is_valid(bogus));
        }
    }

    #[test]
    fn mint_is_unique() {
        let store = WebuiTokenStore::new();
        let (first, second) = (store.mint(), store.mint());
        assert_ne!(first, second);
    }
}