use std::collections::HashMap;
use std::io::{self, BufRead, Write};
use std::net::TcpListener;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{mpsc, Arc};
use std::thread;
use std::time::Duration;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use atomcode_telemetry::{Event, Telemetry};
use crate::config::Config;
/// Platform server base URL used when the `ATOMCODE_PLATFORM_SERVER`
/// environment variable cannot be read.
const DEFAULT_PLATFORM_SERVER: &str = "https://acs.atomgit.com";
/// Normalizes a user-supplied server address into a base URL.
///
/// Surrounding whitespace is removed, `http://` is prepended when the input
/// contains no `://`, and all trailing `/` characters are dropped so callers
/// can append paths that start with `/`.
fn sanitize_base_url(raw: &str) -> String {
    const FALLBACK_SCHEME: &str = "http://";

    let candidate = raw.trim();
    let mut url = String::with_capacity(FALLBACK_SCHEME.len() + candidate.len());
    if !candidate.contains("://") {
        url.push_str(FALLBACK_SCHEME);
    }
    url.push_str(candidate);

    // Cut the string back to the part that precedes the trailing slashes.
    let kept = url.trim_end_matches('/').len();
    url.truncate(kept);
    url
}
/// Returns the platform server base URL, resolved once per process.
///
/// The value is read from the `ATOMCODE_PLATFORM_SERVER` environment
/// variable and falls back to `DEFAULT_PLATFORM_SERVER` when the variable
/// cannot be read or is blank. The result is normalized with
/// `sanitize_base_url` and cached in a `OnceLock`, so later changes to the
/// environment are not observed.
fn platform_base_url() -> &'static str {
    use std::sync::OnceLock;
    static BASE: OnceLock<String> = OnceLock::new();
    BASE.get_or_init(|| {
        // A set-but-blank variable (e.g. `ATOMCODE_PLATFORM_SERVER=`) would
        // otherwise be normalized to the unusable base "http:", so it is
        // treated the same as an unset variable.
        let raw = std::env::var("ATOMCODE_PLATFORM_SERVER")
            .ok()
            .filter(|value| !value.trim().is_empty())
            .unwrap_or_else(|| DEFAULT_PLATFORM_SERVER.to_string());
        sanitize_base_url(&raw)
    })
}
/// Base URL of the platform server, as an owned string.
pub fn platform_broker_url() -> String {
    String::from(platform_base_url())
}

/// URL of the `/auth/login` endpoint.
pub fn platform_login_url() -> String {
    let base = platform_base_url();
    format!("{base}/auth/login")
}

/// URL of the `/auth/check` endpoint.
pub fn platform_check_url() -> String {
    let base = platform_base_url();
    format!("{base}/auth/check")
}

/// URL of the `/auth/token` endpoint.
pub fn platform_token_url() -> String {
    let base = platform_base_url();
    format!("{base}/auth/token")
}

/// URL of the `/oauth/exchange` endpoint.
pub fn platform_exchange_url() -> String {
    let base = platform_base_url();
    format!("{base}/oauth/exchange")
}

/// URL of the `/oauth/refresh` endpoint.
pub fn platform_refresh_url() -> String {
    let base = platform_base_url();
    format!("{base}/oauth/refresh")
}
/// URL of the `/oauth/authorize` endpoint.
#[allow(dead_code)]
pub fn authorize_url() -> String {
    let base = platform_base_url();
    format!("{base}/oauth/authorize")
}

/// URL of the `/oauth/token` endpoint.
#[allow(dead_code)]
pub fn token_url() -> String {
    let base = platform_base_url();
    format!("{base}/oauth/token")
}

/// URL of the `/api/v5/user` endpoint.
#[allow(dead_code)]
pub fn user_url() -> String {
    let base = platform_base_url();
    format!("{base}/api/v5/user")
}
fn blocking_client() -> Result<reqwest::blocking::Client> {
reqwest::blocking::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(10))
.user_agent(crate::ATOMCODE_USER_AGENT)
.build()
.context("failed to build OAuth HTTP client")
}
/// Loads the pending invite stored under the config directory, if any.
///
/// Returns `(invite_code, install_uuid)`; both are `None` when no pending
/// invite could be loaded.
fn pending_invite_for_login() -> (Option<String>, Option<uuid::Uuid>) {
    let config_dir = Config::config_dir();
    if let Some(invite) = atomcode_telemetry::pending_invite::load(&config_dir) {
        return (Some(invite.invite_code), Some(invite.install_uuid));
    }
    (None, None)
}
/// Credentials and profile of the logged-in user.
///
/// This is the value serialized to TOML by `save_auth` and read back by
/// `get_stored_auth`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthInfo {
// Token returned to callers of `get_valid_token`.
pub access_token: String,
// Token sent to the refresh endpoint; refresh is impossible when absent.
pub refresh_token: Option<String>,
pub token_type: String,
// Lifetime of the access token; added to `created_at` (Unix seconds) to
// compute the expiry time.
pub expires_in: Option<i64>,
// Unix time in seconds at which the token was obtained; deserializes to 0
// when the field is missing.
#[serde(default)]
pub created_at: i64,
pub user: UserInfo,
}
/// Profile of the authenticated user as stored alongside the tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
// Also reported to telemetry as the account id after a successful login.
pub id: String,
pub username: String,
pub name: Option<String>,
pub email: Option<String>,
pub avatar_url: Option<String>,
}
/// JSON body returned by the `/auth/login` endpoint.
#[derive(Debug, Deserialize)]
struct PlatformLoginResponse {
// URL the user opens in a browser to sign in.
login_url: String,
// Value sent back as the `state` query parameter to `/auth/check` and
// `/auth/token`.
state: String,
}
/// JSON body returned by the `/auth/check` endpoint.
#[derive(Debug, Deserialize)]
struct PlatformCheckResponse {
// `true` is treated as "login authorized" by `LoginSession::poll_once`.
valid: bool,
}
/// User object embedded in platform token responses; copied field by field
/// into `UserInfo`.
#[derive(Debug, Deserialize)]
struct PlatformUserInfo {
id: String,
username: String,
name: Option<String>,
email: Option<String>,
avatar_url: Option<String>,
}
/// JSON body returned by the `/auth/token` endpoint; converted into
/// `AuthInfo` by `LoginSession::finish`.
#[derive(Debug, Deserialize)]
struct PlatformTokenResponse {
access_token: String,
token_type: String,
expires_in: Option<i64>,
refresh_token: Option<String>,
user: PlatformUserInfo,
}
/// Result of one wait for keyboard input while a login is pending.
#[cfg_attr(target_os = "windows", allow(dead_code))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EscOutcome {
    /// The input was exactly one ESC byte.
    Cancelled,
    /// No input was available.
    Timeout,
    /// Some input other than a lone ESC byte was read.
    OtherInput,
}

/// Classifies the bytes obtained from a single read of the terminal.
///
/// Only a lone `0x1B` counts as a cancel request; longer sequences that
/// merely start with ESC (arrow keys, Alt+letter, CSI codes) are reported as
/// `OtherInput`, and an empty slice as `Timeout`.
#[cfg_attr(target_os = "windows", allow(dead_code))]
fn classify_input(bytes: &[u8]) -> EscOutcome {
    match bytes.split_first() {
        None => EscOutcome::Timeout,
        Some((&0x1B, rest)) if rest.is_empty() => EscOutcome::Cancelled,
        Some(_) => EscOutcome::OtherInput,
    }
}
/// Guard that switches stdin's terminal to a non-canonical, no-echo mode and
/// writes the saved settings back when dropped (non-Windows variant).
#[cfg(not(target_os = "windows"))]
struct CbreakGuard {
// File descriptor of stdin whose terminal attributes were changed.
fd: std::os::unix::io::RawFd,
// Terminal attributes captured before modification; restored on drop.
orig: libc::termios,
}
/// Windows placeholder; never constructed because `new` always returns `None`.
#[cfg(target_os = "windows")]
struct CbreakGuard;
impl CbreakGuard {
/// Tries to put stdin into the mode described on the type.
///
/// Returns `None` when the current attributes cannot be read or the new ones
/// cannot be applied; callers then fall back to waiting without key input.
#[cfg(not(target_os = "windows"))]
fn new() -> Option<Self> {
use std::os::unix::io::AsRawFd;
let fd = io::stdin().as_raw_fd();
// SAFETY: `termios` is a plain C struct; the zeroed value is only a
// placeholder that `tcgetattr` overwrites before it is used.
let mut orig: libc::termios = unsafe { std::mem::zeroed() };
// SAFETY: `fd` is stdin's descriptor and `orig` is a valid, writable termios.
if unsafe { libc::tcgetattr(fd, &mut orig) } != 0 {
return None;
}
let mut new = orig;
// Disable line buffering and echo so single key presses can be read.
new.c_lflag &= !(libc::ICANON | libc::ECHO);
// VMIN = 0 and VTIME = 0: reads return immediately with whatever is available.
new.c_cc[libc::VMIN] = 0;
new.c_cc[libc::VTIME] = 0;
// SAFETY: `fd` is stdin's descriptor and `new` is a fully initialized termios.
if unsafe { libc::tcsetattr(fd, libc::TCSANOW, &new) } != 0 {
return None;
}
Some(Self { fd, orig })
}
/// Windows variant: terminal mode switching is not implemented, so no guard
/// is ever created.
#[cfg(target_os = "windows")]
fn new() -> Option<Self> {
None
}
}
#[cfg(not(target_os = "windows"))]
impl Drop for CbreakGuard {
/// Restores the terminal attributes saved in `new`; the result of the call is
/// ignored.
fn drop(&mut self) {
// SAFETY: `self.fd` and `self.orig` were obtained together in `new`.
unsafe {
libc::tcsetattr(self.fd, libc::TCSANOW, &self.orig);
}
}
}
/// Waits up to `timeout` for keyboard input and classifies what was read.
///
/// Without a guard the function simply sleeps for `timeout` and reports
/// `Timeout`. With a guard it polls the guard's descriptor for readable data,
/// reads up to 32 bytes and passes them to `classify_input`.
///
/// NOTE(review): a failing `poll` and a failing or empty `read` are reported
/// as `Timeout` without sleeping, so a caller looping on this function may
/// iterate faster than `timeout` in those cases — confirm this is acceptable.
#[cfg(not(target_os = "windows"))]
fn wait_for_esc_or_timeout(guard: &Option<CbreakGuard>, timeout: Duration) -> EscOutcome {
let Some(g) = guard.as_ref() else {
thread::sleep(timeout);
return EscOutcome::Timeout;
};
let mut pfd = libc::pollfd {
fd: g.fd,
events: libc::POLLIN,
revents: 0,
};
// `poll` takes milliseconds as an `i32`; clamp instead of truncating.
let timeout_ms = timeout.as_millis().min(i32::MAX as u128) as i32;
// SAFETY: `pfd` is a valid pollfd and the count of 1 matches it.
let rc = unsafe { libc::poll(&mut pfd, 1, timeout_ms) };
// 0 means the timeout elapsed; a negative value means `poll` failed.
if rc <= 0 {
return EscOutcome::Timeout;
}
let mut buf = [0u8; 32];
// SAFETY: `buf` is writable for `buf.len()` bytes.
let n = unsafe { libc::read(g.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
if n <= 0 {
return EscOutcome::Timeout;
}
classify_input(&buf[..n as usize])
}
/// Windows variant: key input is not read, so this only sleeps for `timeout`.
#[cfg(target_os = "windows")]
fn wait_for_esc_or_timeout(_guard: &Option<CbreakGuard>, timeout: Duration) -> EscOutcome {
thread::sleep(timeout);
EscOutcome::Timeout
}
/// Result of a single `LoginSession::poll_once` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PollOutcome {
/// The check endpoint did not (yet) report a valid login.
Pending,
/// The check endpoint reported the login as valid.
Authorized,
}
/// An in-progress browser login created by `start_login`.
pub struct LoginSession {
// Value sent as the `state` query parameter when polling and fetching the token.
state: String,
// URL the user opens in a browser to sign in.
login_url: String,
// HTTP client reused for every request of this session.
client: reqwest::blocking::Client,
}
impl LoginSession {
pub fn url(&self) -> &str {
&self.login_url
}
pub fn open_browser_best_effort(&self) {
let _ = open_browser(&self.login_url);
}
pub fn poll_once(&self) -> Result<PollOutcome> {
let resp = self
.client
.get(platform_check_url())
.query(&[("state", &self.state)])
.send()
.context("Failed to call /auth/check")?;
if resp.status().is_success() {
if let Ok(check) = resp.json::<PlatformCheckResponse>() {
if check.valid {
return Ok(PollOutcome::Authorized);
}
}
}
Ok(PollOutcome::Pending)
}
pub fn finish(self, tel: Option<&Arc<Telemetry>>) -> Result<AuthInfo> {
let token_resp: PlatformTokenResponse = self
.client
.get(platform_token_url())
.query(&[("state", &self.state)])
.send()
.context("Failed to call /auth/token")?
.json()
.context("Failed to parse /auth/token response")?;
let created_at = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or(0);
let auth_info = AuthInfo {
access_token: token_resp.access_token,
refresh_token: token_resp.refresh_token,
token_type: token_resp.token_type,
expires_in: token_resp.expires_in,
created_at,
user: UserInfo {
id: token_resp.user.id,
username: token_resp.user.username,
name: token_resp.user.name,
email: token_resp.user.email,
avatar_url: token_resp.user.avatar_url,
},
};
if let Some(t) = tel {
t.set_account_id(Some(auth_info.user.id.to_string()));
let (invite_code, install_uuid) = pending_invite_for_login();
let event = Event::LoginSuccess {
invite_code,
install_uuid,
};
if let Err(e) = t.track_durable_sync(event.clone()) {
tracing::warn!(
?e,
"login_success durable enqueue failed; falling back to async telemetry"
);
t.track(event);
}
}
Ok(auth_info)
}
}
pub fn start_login() -> Result<LoginSession> {
let client = reqwest::blocking::Client::builder()
.user_agent(crate::ATOMCODE_USER_AGENT)
.build()
.context("failed to build OAuth login HTTP client")?;
let resp: PlatformLoginResponse = client
.get(platform_login_url())
.query(&[("provider", "atomgit")])
.send()
.context("Failed to call /auth/login")?
.json()
.context("Failed to parse /auth/login response")?;
Ok(LoginSession {
state: resp.state,
login_url: strip_force_login(&resp.login_url),
client,
})
}
/// Removes every `force_login=true` query parameter from `url`.
///
/// Only whole `key=value` pairs in the query string are removed. A parameter
/// whose value merely starts with `true` (for example `force_login=true2`)
/// and any text in the fragment are left untouched. When removing the
/// parameter empties the query, the `?` is dropped as well. URLs without a
/// query string are returned unchanged.
fn strip_force_login(url: &str) -> String {
    const UNWANTED_PAIR: &str = "force_login=true";

    // The fragment (after `#`) is not part of the query and must survive as-is.
    let (head, fragment) = match url.split_once('#') {
        Some((head, fragment)) => (head, Some(fragment)),
        None => (url, None),
    };
    let Some((base, query)) = head.split_once('?') else {
        return url.to_string();
    };

    let kept: Vec<&str> = query
        .split('&')
        .filter(|pair| *pair != UNWANTED_PAIR)
        .collect();

    let mut stripped = String::with_capacity(url.len());
    stripped.push_str(base);
    if !kept.is_empty() {
        stripped.push('?');
        stripped.push_str(&kept.join("&"));
    }
    if let Some(fragment) = fragment {
        stripped.push('#');
        stripped.push_str(fragment);
    }
    stripped
}
/// Runs the interactive browser login and returns the resulting credentials.
///
/// The login URL is printed, the browser is opened on a best-effort basis,
/// and the check endpoint is polled until it reports the login as
/// authorized, waiting up to two seconds between polls. When the terminal
/// guard could be created, a lone ESC key press cancels the login.
///
/// # Errors
/// Fails when the session cannot be started, a poll request cannot be sent,
/// the user cancels, or the final token fetch fails.
pub fn login(tel: Option<&Arc<Telemetry>>) -> Result<AuthInfo> {
    let session = start_login()?;
    println!(" Browser didn't open? Open the URL below in any browser to sign in:");
    println!(" {}", session.url());

    // The guard stays alive until the function returns so the terminal mode
    // is restored only after the login has finished.
    let cbreak = CbreakGuard::new();
    if cbreak.is_some() {
        println!();
        println!(" Press ESC to cancel");
    }
    session.open_browser_best_effort();

    let poll_interval = Duration::from_secs(2);
    while session.poll_once()? != PollOutcome::Authorized {
        if wait_for_esc_or_timeout(&cbreak, poll_interval) == EscOutcome::Cancelled {
            anyhow::bail!("login cancelled by user");
        }
    }
    session.finish(tel)
}
/// Extracts and percent-decodes the first `state` query parameter of `url`.
///
/// The query is taken to be the text between the first and second `?`.
/// Returns `None` when there is no query or no `state=...` pair.
#[allow(dead_code)]
fn pasted_state(url: &str) -> Option<String> {
    let query = url.split('?').nth(1)?;
    for pair in query.split('&') {
        if let Some(("state", value)) = pair.split_once('=') {
            return Some(urlencoding_decode(value));
        }
    }
    None
}
/// Builds a state string of the form `atomcode_<nanoseconds since the Unix epoch>`.
///
/// The numeric part is 0 when the system clock is before the epoch. The value
/// is derived from the clock alone, so it is predictable.
#[allow(dead_code)]
fn generate_state() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    };
    let mut state = String::from("atomcode_");
    state.push_str(&nanos.to_string());
    state
}
#[cfg(target_os = "macos")]
pub fn open_browser(url: &str) -> Result<()> {
std::process::Command::new("open")
.arg(url)
.spawn()
.context("Failed to open browser")?;
Ok(())
}
#[cfg(target_os = "linux")]
pub fn open_browser(url: &str) -> Result<()> {
std::process::Command::new("xdg-open")
.arg(url)
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::null())
.spawn()
.context("Failed to open browser")?;
Ok(())
}
#[cfg(target_os = "windows")]
pub fn open_browser(url: &str) -> Result<()> {
use std::os::windows::process::CommandExt;
std::process::Command::new("cmd")
.raw_arg(format!("/C start \"\" \"{}\"", url))
.spawn()
.context("Failed to open browser")?;
Ok(())
}
#[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
pub fn open_browser(_url: &str) -> Result<()> {
anyhow::bail!("Unsupported platform for browser auto-open");
}
/// Waits for the OAuth callback and returns its `(code, state)` pair.
///
/// Two input sources are supported: an HTTP request to a listener bound on
/// `127.0.0.1:port`, and a callback URL pasted on stdin. If the port cannot be
/// bound, only the paste path is available. Each source runs on its own
/// thread and reports through a shared channel; the first result received is
/// returned, after which the `stop` flag is raised for workers that poll it.
///
/// On Windows the stdin reader is started only when the listener could not
/// be bound.
///
/// # Errors
/// Returns the error of the first worker to report, or "login cancelled" when
/// every sender is dropped without a result having been sent.
#[allow(dead_code)]
fn await_callback(port: u16) -> Result<(String, String)> {
let listener = match TcpListener::bind(("127.0.0.1", port)) {
Ok(l) => Some(l),
Err(e) => {
println!(" Could not bind port {} ({}). Paste path only.", port, e);
None
}
};
println!(
" Waiting for callback on http://127.0.0.1:{}/callback",
port
);
println!(" Or paste the full callback URL here and press Enter:");
println!(" (Ctrl+C to cancel)\n");
let (tx, rx) = mpsc::channel::<Result<(String, String)>>();
// Shared flag the workers check to know when to give up.
let stop = Arc::new(AtomicBool::new(false));
// Only read on Windows, hence the conditional `allow`.
#[cfg_attr(not(target_os = "windows"), allow(unused_variables))]
let has_listener = listener.is_some();
if let Some(listener) = listener {
let tx_l = tx.clone();
let stop_l = Arc::clone(&stop);
thread::spawn(move || {
let r = accept_callback_until_stopped(listener, &stop_l);
// The receiver may already be gone; a failed send is ignored.
let _ = tx_l.send(r);
});
}
#[cfg(not(target_os = "windows"))]
{
let tx_stdin = tx.clone();
let stop_stdin = Arc::clone(&stop);
thread::spawn(move || {
let r = read_callback_from_stdin_until_stopped(&stop_stdin);
let _ = tx_stdin.send(r);
});
}
#[cfg(target_os = "windows")]
{
// This reader uses a blocking `read_line` and does not observe `stop`.
if !has_listener {
let tx_stdin = tx.clone();
thread::spawn(move || {
let stdin = io::stdin();
let mut line = String::new();
let r = match stdin.lock().read_line(&mut line) {
Ok(0) => Err(anyhow::anyhow!("stdin closed")),
Ok(_) => parse_pasted_callback(&line),
Err(e) => Err(anyhow::Error::new(e).context("Failed to read from stdin")),
};
let _ = tx_stdin.send(r);
});
}
}
// Drop the original sender so `recv` fails once all workers have finished
// without sending.
drop(tx);
let result = rx.recv().context("login cancelled")?;
stop.store(true, Ordering::Relaxed);
result
}
/// Reads a pasted callback URL from stdin, giving up once `stop` is set.
///
/// Stdin is switched to non-blocking mode for the duration of the call and
/// its original flags are restored on return. Input is polled in 100 ms
/// slices so the `stop` flag is checked regularly. Bytes are accumulated
/// until a newline has been seen, then handed to `parse_pasted_callback`.
///
/// NOTE(review): each chunk is decoded with `from_utf8_lossy` on its own, so
/// a multi-byte character split across two reads would be replaced.
///
/// # Errors
/// Returns "stdin cancelled" when `stop` is set, "stdin closed" on end of
/// input, an error for failed `poll`/`read` calls, or whatever
/// `parse_pasted_callback` reports.
#[cfg(not(target_os = "windows"))]
#[allow(dead_code)]
fn read_callback_from_stdin_until_stopped(stop: &AtomicBool) -> Result<(String, String)> {
use std::os::unix::io::AsRawFd;
let stdin = io::stdin();
let fd = stdin.as_raw_fd();
// SAFETY: `fd` is stdin's descriptor; F_GETFL takes no further argument.
let orig_flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
// A negative value means the flags could not be read; leave them alone then.
if orig_flags >= 0 {
// SAFETY: `fd` is stdin's descriptor and the flag value is derived from
// the flags just read.
unsafe {
libc::fcntl(fd, libc::F_SETFL, orig_flags | libc::O_NONBLOCK);
}
}
// Restores the original descriptor flags when the function returns.
struct FlagGuard {
fd: std::os::unix::io::RawFd,
orig_flags: i32,
}
impl Drop for FlagGuard {
fn drop(&mut self) {
if self.orig_flags >= 0 {
// SAFETY: writes back the flags that were read from this same descriptor.
unsafe {
libc::fcntl(self.fd, libc::F_SETFL, self.orig_flags);
}
}
}
}
let _guard = FlagGuard { fd, orig_flags };
let mut line = String::new();
let mut buf = [0u8; 256];
loop {
if stop.load(Ordering::Relaxed) {
anyhow::bail!("stdin cancelled");
}
let mut pfd = libc::pollfd {
fd,
events: libc::POLLIN,
revents: 0,
};
// SAFETY: `pfd` is a valid pollfd and the count of 1 matches it.
let poll_rc = unsafe { libc::poll(&mut pfd, 1, 100) };
if poll_rc < 0 {
let err = io::Error::last_os_error();
// An interrupted call is retried; anything else is fatal.
if err.kind() == io::ErrorKind::Interrupted {
continue;
}
return Err(anyhow::Error::new(err).context("poll(stdin)"));
}
// Timed out without data: loop around to re-check `stop`.
if poll_rc == 0 {
continue;
}
// SAFETY: `buf` is writable for `buf.len()` bytes.
let n = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
if n < 0 {
let err = io::Error::last_os_error();
if err.kind() == io::ErrorKind::WouldBlock || err.kind() == io::ErrorKind::Interrupted {
continue;
}
return Err(anyhow::Error::new(err).context("read(stdin)"));
}
if n == 0 {
anyhow::bail!("stdin closed");
}
line.push_str(&String::from_utf8_lossy(&buf[..n as usize]));
if line.contains('\n') {
return parse_pasted_callback(&line);
}
}
}
/// Accepts one HTTP connection on `listener` and extracts the OAuth
/// `(code, state)` pair from the query string of its request line.
///
/// The listener is polled in non-blocking mode, sleeping 200 ms between
/// attempts, until a connection arrives or `stop` is set. When the query
/// contains an `error` parameter the client is redirected to
/// `https://atomgit.com` and an error is returned; otherwise a success page
/// is sent. A missing `state` parameter yields an empty string.
///
/// # Errors
/// Returns "listener cancelled" when `stop` is set, and errors for socket
/// failures, a malformed request line, a missing query string, an OAuth
/// error response, or a missing `code` parameter.
#[allow(dead_code)]
fn accept_callback_until_stopped(
listener: TcpListener,
stop: &AtomicBool,
) -> Result<(String, String)> {
listener
.set_nonblocking(true)
.context("Failed to set non-blocking mode")?;
let mut stream = loop {
if stop.load(Ordering::Relaxed) {
anyhow::bail!("listener cancelled");
}
match listener.accept() {
Ok((stream, _)) => break stream,
// No pending connection yet: wait a little and re-check `stop`.
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
thread::sleep(Duration::from_millis(200));
continue;
}
Err(e) => return Err(e).context("Failed to accept connection"),
}
};
// The accepted stream is read and written in blocking mode.
stream.set_nonblocking(false)?;
let mut reader = io::BufReader::new(&mut stream);
let mut request_line = String::new();
reader.read_line(&mut request_line)?;
// The request target is the second whitespace-separated token of the
// request line.
let url: String = request_line
.split_whitespace()
.nth(1)
.context("Invalid HTTP request")?
.to_string();
let query_start = url.find('?').context("No query parameters in callback")?;
let query = &url[query_start + 1..];
// Values are percent-decoded; keys are used as-is. A pair without `=` maps
// to an empty value.
let params: HashMap<String, String> = query
.split('&')
.filter_map(|pair| {
let mut parts = pair.splitn(2, '=');
let key = parts.next()?;
let value = parts
.next()
.map(|v| urlencoding_decode(v))
.unwrap_or_default();
Some((key.to_string(), value))
})
.collect();
if let Some(error) = params.get("error") {
// Prefer the description, falling back to the error code itself.
let error_desc = params
.get("error_description")
.map(|s| s.as_str())
.unwrap_or(error);
let response = "HTTP/1.1 302 Found\r\nLocation: https://atomgit.com\r\n\r\n";
// Best effort: the OAuth error is reported even if the reply cannot be sent.
let _ = stream.write_all(response.as_bytes());
let _ = stream.flush();
anyhow::bail!("OAuth error: {}", error_desc);
}
let code = params.get("code").context("No code in callback")?.clone();
let state = params.get("state").cloned().unwrap_or_default();
let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\n\r\n\
<html><head><title>AtomCode Login</title>\
<style>body{font-family:system-ui;display:flex;justify-content:center;align-items:center;height:100vh;margin:0;background:#1a1a2e;color:#eee}\
.container{text-align:center;padding:2rem}h1{color:#7c3aed;margin:0}p{color:#888}\
.success{color:#22c55e;font-size:4rem}</style></head>\
<body><div class=\"container\">\
<div class=\"success\">✓</div>\
<h1>Authorization Successful</h1>\
<p>You can close this window and return to AtomCode.</p>\
</div></body></html>";
stream.write_all(response.as_bytes())?;
stream.flush()?;
Ok((code, state))
}
/// Decodes a query-string component: `+` becomes a space and `%XX` escapes
/// become the byte they encode.
///
/// Decoding is done on bytes and the result is interpreted as UTF-8, so
/// multi-byte characters such as `%E2%9C%93` decode correctly; byte sequences
/// that are not valid UTF-8 are replaced with U+FFFD. A `%` that is not
/// followed by two hexadecimal digits is kept literally instead of being
/// dropped.
fn urlencoding_decode(s: &str) -> String {
    /// Value of a single ASCII hexadecimal digit, if `byte` is one.
    fn hex_value(byte: u8) -> Option<u8> {
        (byte as char).to_digit(16).map(|digit| digit as u8)
    }

    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut index = 0;
    while index < bytes.len() {
        let current = bytes[index];
        if current == b'%' {
            let high = bytes.get(index + 1).copied().and_then(hex_value);
            let low = bytes.get(index + 2).copied().and_then(hex_value);
            if let Some((high, low)) = high.zip(low) {
                decoded.push((high << 4) | low);
                index += 3;
                continue;
            }
            // Malformed escape: keep the `%` and decode what follows normally.
            decoded.push(b'%');
        } else if current == b'+' {
            decoded.push(b' ');
        } else {
            decoded.push(current);
        }
        index += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Exchanges the stored refresh token for a new access token and saves the
/// result.
///
/// Fields missing from the broker's response (refresh token, token type,
/// lifetime, user) are carried over from `auth`. The new credentials are
/// written with `save_auth` before being returned.
///
/// # Errors
/// Fails when `auth` has no refresh token, the request cannot be sent, the
/// broker answers with a non-success status, the body cannot be parsed, or
/// the credentials cannot be saved.
pub fn refresh_access_token(auth: &AuthInfo) -> Result<AuthInfo> {
    /// JSON body returned by the refresh endpoint.
    #[derive(Deserialize)]
    struct BrokerResponse {
        access_token: String,
        token_type: Option<String>,
        expires_in: Option<i64>,
        refresh_token: Option<String>,
        user: Option<PlatformUserInfo>,
    }

    let token = auth
        .refresh_token
        .as_deref()
        .context("No refresh_token available — please /login again")?;
    let http = blocking_client()?;
    let payload = serde_json::json!({ "refresh_token": token });
    let response = http
        .post(platform_refresh_url())
        .json(&payload)
        .send()
        .context("Failed to send refresh token request to broker")?;

    let status = response.status();
    if !status.is_success() {
        let body = response.text().unwrap_or_default();
        anyhow::bail!(
            "Token refresh failed ({}): {} — please /login again",
            status,
            body
        );
    }

    let BrokerResponse {
        access_token,
        token_type,
        expires_in,
        refresh_token,
        user,
    } = response
        .json::<BrokerResponse>()
        .context("Failed to parse broker response")?;

    // Seconds since the Unix epoch; 0 if the clock is before the epoch.
    let created_at = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    };
    let user = match user {
        Some(profile) => UserInfo {
            id: profile.id,
            username: profile.username,
            name: profile.name,
            email: profile.email,
            avatar_url: profile.avatar_url,
        },
        None => auth.user.clone(),
    };

    let new_auth = AuthInfo {
        access_token,
        refresh_token: refresh_token.or_else(|| auth.refresh_token.clone()),
        token_type: token_type.unwrap_or_else(|| auth.token_type.clone()),
        expires_in: expires_in.or(auth.expires_in),
        created_at,
        user,
    };
    save_auth(&new_auth)?;
    Ok(new_auth)
}
/// Returns an access token for the stored login, refreshing it when needed.
///
/// * When the stored credentials carry a lifetime, the token is refreshed
///   once the current time is within five minutes of `created_at +
///   expires_in`; a failed refresh is an error.
/// * When there is no lifetime and `created_at` is 0, a refresh is attempted
///   if a refresh token exists; a failed refresh falls back to the stored
///   token.
/// * Otherwise the stored token is returned as-is.
///
/// # Errors
/// Fails when nobody is logged in or when a required refresh fails.
pub fn get_valid_token() -> Result<String> {
    // Refresh this many seconds before the nominal expiry.
    const REFRESH_MARGIN_SECS: i64 = 300;

    let auth = get_stored_auth().context("Not logged in — please use /login first")?;
    if let Some(expires_in) = auth.expires_in {
        // If the clock cannot be read, assume the token has expired.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs() as i64)
            .unwrap_or(i64::MAX);
        // Both values come from the auth file; saturate so extreme values
        // cannot overflow (panic in debug builds, wrap-around in release).
        let expires_at = auth.created_at.saturating_add(expires_in);
        if now >= expires_at.saturating_sub(REFRESH_MARGIN_SECS) {
            match refresh_access_token(&auth) {
                Ok(new_auth) => return Ok(new_auth.access_token),
                Err(e) => anyhow::bail!("Token expired and refresh failed: {}", e),
            }
        }
    } else if auth.created_at == 0 && auth.refresh_token.is_some() {
        // No expiry information at all: try a refresh, but keep the stored
        // token if that does not work.
        if let Ok(new_auth) = refresh_access_token(&auth) {
            return Ok(new_auth.access_token);
        }
    }
    Ok(auth.access_token)
}
pub fn logout() -> Result<()> {
let auth_path = auth_file_path();
if auth_path.exists() {
std::fs::remove_file(&auth_path).context("Failed to remove auth file")?;
}
Ok(())
}
/// Loads the stored credentials, if any.
///
/// Returns `None` when the auth file does not exist, cannot be read, or does
/// not parse as TOML for [`AuthInfo`].
pub fn get_stored_auth() -> Option<AuthInfo> {
    let path = auth_file_path();
    path.exists()
        .then(|| std::fs::read_to_string(&path).ok())
        .flatten()
        .and_then(|text| toml::from_str(&text).ok())
}
/// Serializes `auth` as TOML and writes it to the auth file.
///
/// The parent directory is created when missing. On Unix the directory is
/// set to mode `0o700` (failure ignored) and the file to mode `0o600`
/// (failure reported).
///
/// # Errors
/// Fails when the directory cannot be created, serialization fails, the file
/// cannot be written, or (on Unix) the file permissions cannot be set.
pub fn save_auth(auth: &AuthInfo) -> Result<()> {
    #[cfg(unix)]
    use std::os::unix::fs::PermissionsExt;

    let target = auth_file_path();
    if let Some(dir) = target.parent() {
        std::fs::create_dir_all(dir).context("Failed to create auth directory")?;
        #[cfg(unix)]
        {
            // Best effort: tightening the directory mode is not required for
            // the save to succeed.
            let _ = std::fs::set_permissions(dir, std::fs::Permissions::from_mode(0o700));
        }
    }

    let serialized = toml::to_string_pretty(auth).context("Failed to serialize auth info")?;
    super::write_auth_file_secure(&target, &serialized).context("Failed to write auth file")?;

    #[cfg(unix)]
    {
        std::fs::set_permissions(&target, std::fs::Permissions::from_mode(0o600))
            .context("Failed to set auth file permissions")?;
    }
    Ok(())
}
/// Path of the credentials file: `auth.toml` inside the config directory.
pub fn auth_file_path() -> std::path::PathBuf {
    let config_dir = crate::config::Config::config_dir();
    config_dir.join("auth.toml")
}
/// Whether stored credentials can be loaded.
pub fn is_logged_in() -> bool {
    matches!(get_stored_auth(), Some(_))
}
/// Profile of the logged-in user, or `None` when no credentials are stored.
pub fn current_user() -> Option<UserInfo> {
    let auth = get_stored_auth()?;
    Some(auth.user)
}
/// Parses a callback URL pasted by the user into its `(code, state)` pair.
///
/// Surrounding whitespace and bracketed-paste markers are removed first.
/// Everything after the first `?` is treated as the query string; values are
/// percent-decoded and, for repeated keys, the last occurrence wins.
///
/// # Errors
/// Fails when the input has no `?`, when the query carries an `error`
/// parameter (its description is reported when present), or when `code` or
/// `state` is missing.
#[allow(dead_code)]
fn parse_pasted_callback(input: &str) -> Result<(String, String)> {
    // Escape sequences terminals wrap around pasted text in bracketed-paste mode.
    const PASTE_START: &str = "\x1b[200~";
    const PASTE_END: &str = "\x1b[201~";

    let cleaned = input
        .trim()
        .trim_start_matches(PASTE_START)
        .trim_end_matches(PASTE_END)
        .trim();
    let (_, query) = cleaned.split_once('?').context(
        "Could not parse callback URL — paste the full http://127.0.0.1:8765/callback?... URL",
    )?;

    let mut params: HashMap<String, String> = HashMap::new();
    for pair in query.split('&') {
        let (key, value) = match pair.split_once('=') {
            Some((key, raw_value)) => (key, urlencoding_decode(raw_value)),
            None => (pair, String::new()),
        };
        params.insert(key.to_string(), value);
    }

    if let Some(error) = params.get("error") {
        let desc = params.get("error_description").unwrap_or(error);
        anyhow::bail!("OAuth error: {}", desc);
    }

    let code = params
        .remove("code")
        .context("Callback URL missing 'code' parameter")?;
    let state = params
        .remove("state")
        .context("Callback URL missing 'state' parameter (paste the full URL, not just the code)")?;
    Ok((code, state))
}
/// Unit tests for the pure helpers of this module: `strip_force_login`,
/// `parse_pasted_callback`, `classify_input` and `sanitize_base_url`.
#[cfg(test)]
mod tests {
use super::*;
// --- strip_force_login: the parameter is removed wherever it appears ---
#[test]
fn strip_force_login_removes_trailing_param() {
let url = "https://atomgit.com/oauth/authorize?client_id=abc&state=xyz&force_login=true";
assert_eq!(
strip_force_login(url),
"https://atomgit.com/oauth/authorize?client_id=abc&state=xyz"
);
}
#[test]
fn strip_force_login_removes_middle_param() {
let url = "https://atomgit.com/oauth/authorize?client_id=abc&force_login=true&state=xyz";
assert_eq!(
strip_force_login(url),
"https://atomgit.com/oauth/authorize?client_id=abc&state=xyz"
);
}
#[test]
fn strip_force_login_removes_only_param() {
let url = "https://atomgit.com/oauth/authorize?force_login=true";
assert_eq!(
strip_force_login(url),
"https://atomgit.com/oauth/authorize"
);
}
#[test]
fn strip_force_login_removes_first_of_many() {
let url = "https://atomgit.com/oauth/authorize?force_login=true&state=xyz";
assert_eq!(
strip_force_login(url),
"https://atomgit.com/oauth/authorize?state=xyz"
);
}
#[test]
fn strip_force_login_passthrough_when_absent() {
let url = "https://atomgit.com/oauth/authorize?client_id=abc&state=xyz";
assert_eq!(strip_force_login(url), url);
}
// --- parse_pasted_callback: accepted inputs and reported errors ---
#[test]
fn parse_happy_path_loopback_url() {
let (code, state) =
parse_pasted_callback("http://127.0.0.1:8765/callback?code=abc&state=xyz").unwrap();
assert_eq!(code, "abc");
assert_eq!(state, "xyz");
}
// The host and unrelated parameters are irrelevant; only the query is inspected.
#[test]
fn parse_any_host_with_extra_params() {
let (code, state) =
parse_pasted_callback("https://example.com/x?foo=1&code=abc&state=xyz&bar=2").unwrap();
assert_eq!(code, "abc");
assert_eq!(state, "xyz");
}
#[test]
fn parse_missing_state_errors_with_full_url_hint() {
let err = parse_pasted_callback("http://127.0.0.1:8765/callback?code=abc")
.unwrap_err()
.to_string();
assert!(err.contains("state"), "got: {err}");
assert!(err.contains("full URL"), "got: {err}");
}
#[test]
fn parse_missing_code_errors() {
let err = parse_pasted_callback("http://127.0.0.1:8765/callback?state=xyz")
.unwrap_err()
.to_string();
assert!(err.contains("code"), "got: {err}");
}
// `+` in the description is decoded to a space.
#[test]
fn parse_error_response_includes_description() {
let err = parse_pasted_callback(
"http://127.0.0.1:8765/callback?error=access_denied&error_description=User+denied",
)
.unwrap_err()
.to_string();
assert!(err.contains("User denied"), "got: {err}");
}
#[test]
fn parse_not_a_url_errors() {
let err = parse_pasted_callback("this is not a url")
.unwrap_err()
.to_string();
assert!(err.contains("full"), "got: {err}");
}
#[test]
fn parse_url_encoded_state_is_decoded() {
let (_, state) =
parse_pasted_callback("http://127.0.0.1:8765/callback?code=c&state=atomcode_%3Atest")
.unwrap();
assert_eq!(state, "atomcode_:test");
}
#[test]
fn parse_strips_bracketed_paste_markers() {
let input = "\x1b[200~http://127.0.0.1:8765/callback?code=abc&state=xyz\x1b[201~";
let (code, state) = parse_pasted_callback(input).unwrap();
assert_eq!(code, "abc");
assert_eq!(state, "xyz");
}
#[test]
fn parse_trims_surrounding_whitespace() {
let (code, state) =
parse_pasted_callback(" http://127.0.0.1:8765/callback?code=abc&state=xyz\n")
.unwrap();
assert_eq!(code, "abc");
assert_eq!(state, "xyz");
}
// --- classify_input: only a lone ESC byte cancels ---
#[test]
fn classify_input_bare_esc_cancels() {
assert_eq!(classify_input(&[0x1B]), EscOutcome::Cancelled);
}
#[test]
fn classify_input_arrow_key_ignored() {
assert_eq!(classify_input(b"\x1B[A"), EscOutcome::OtherInput);
}
#[test]
fn classify_input_alt_letter_ignored() {
assert_eq!(classify_input(b"\x1Ba"), EscOutcome::OtherInput);
}
#[test]
fn classify_input_normal_byte_ignored() {
assert_eq!(classify_input(b"q"), EscOutcome::OtherInput);
}
#[test]
fn classify_input_empty_is_timeout() {
assert_eq!(classify_input(&[]), EscOutcome::Timeout);
}
#[test]
fn classify_input_pasted_text_ignored() {
assert_eq!(classify_input(b"hello\n"), EscOutcome::OtherInput);
}
#[test]
fn classify_input_csi_color_code_ignored() {
assert_eq!(classify_input(b"\x1B[31m"), EscOutcome::OtherInput);
}
// --- sanitize_base_url: scheme defaulting, trimming, trailing slashes ---
#[test]
fn sanitize_adds_http_if_no_scheme() {
assert_eq!(sanitize_base_url("127.0.0.1:8765"), "http://127.0.0.1:8765");
}
#[test]
fn sanitize_preserves_http_scheme() {
assert_eq!(sanitize_base_url("http://127.0.0.1:8765"), "http://127.0.0.1:8765");
}
#[test]
fn sanitize_preserves_https_scheme() {
assert_eq!(sanitize_base_url("https://acs.example.com"), "https://acs.example.com");
}
#[test]
fn sanitize_strips_trailing_slash() {
assert_eq!(sanitize_base_url("http://127.0.0.1:8765/"), "http://127.0.0.1:8765");
assert_eq!(sanitize_base_url("http://127.0.0.1:8765///"), "http://127.0.0.1:8765");
}
#[test]
fn sanitize_trims_whitespace() {
assert_eq!(sanitize_base_url(" http://127.0.0.1:8765 "), "http://127.0.0.1:8765");
}
#[test]
fn sanitize_no_scheme_with_trailing_slash() {
assert_eq!(sanitize_base_url("127.0.0.1:8765/"), "http://127.0.0.1:8765");
}
}