pub mod auto_fix;
pub mod bash;
pub mod blast_radius;
pub mod cd;
pub mod diagnostics;
pub mod edit;
pub mod file_deps;
pub mod file_history;
pub mod find_references;
pub mod glob;
pub mod grep;
pub mod list_dir;
pub mod list_symbols;
pub mod open_file;
pub mod parallel_edit;
pub mod read;
pub mod read_symbol;
pub mod result_store;
pub mod search_replace;
pub mod todo;
pub mod trace_callees;
pub mod trace_callers;
pub mod trace_chain;
pub mod use_skill;
pub mod web_fetch;
pub mod web_search;
pub mod write;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::ffi::{OsStr, OsString};
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
/// Directory names that are never descended into when walking a tree
/// (dependency caches, VCS metadata, build output, editor/agent state, logs).
pub const SKIP_DIRS: &[&str] = &[
    "node_modules",
    ".git",
    "target",
    "__pycache__",
    ".next",
    "dist",
    "build",
    ".cache",
    "vendor",
    ".venv",
    "venv",
    ".idea",
    ".vscode",
    ".DS_Store",
    ".env",
    "datalog",
    "logs",
    "log",
    ".atomcode",
    ".claude",
    "runs",
];

/// Name prefixes that also mark a directory as skippable (e.g. `.venv-3.11`).
pub const SKIP_DIR_PREFIXES: &[&str] = &[".venv-"];

/// Returns `true` when a directory called `name` should be skipped: either an
/// exact entry of [`SKIP_DIRS`] or a name starting with one of [`SKIP_DIR_PREFIXES`].
pub fn should_skip_dir(name: &str) -> bool {
    if SKIP_DIRS.iter().any(|&dir| dir == name) {
        return true;
    }
    SKIP_DIR_PREFIXES
        .iter()
        .any(|prefix| name.starts_with(prefix))
}
/// Validates raw tool-call arguments and explains precisely what is wrong when
/// they do not fit.
///
/// `required_modes` lists alternative key sets: the arguments are accepted when
/// they contain every key of at least one mode. An empty `required_modes` means
/// there is nothing to enforce, so any JSON object is accepted.
///
/// # Errors
/// Returns a human-readable message ending in `Re-issue: {example}` when `args`
/// is empty or `{}`, is not valid JSON, is not a JSON object, or lacks keys for
/// every mode. In the last case the mode with the fewest missing keys is reported.
pub fn diagnose_args(
    tool: &str,
    args: &str,
    required_modes: &[&[&str]],
    example: &str,
) -> std::result::Result<serde_json::Value, String> {
    let trimmed = args.trim();
    if trimmed.is_empty() || trimmed == "{}" {
        return Err(format!(
            "{tool} called with empty arguments — likely max_tokens cutoff. \
             Re-issue: {example}"
        ));
    }
    let value: serde_json::Value = serde_json::from_str(args)
        .map_err(|_| format!("{tool} arguments are not valid JSON. Re-issue: {example}"))?;
    let Some(obj) = value.as_object() else {
        let kind = match &value {
            serde_json::Value::Null => "null",
            serde_json::Value::Bool(_) => "boolean",
            serde_json::Value::Number(_) => "number",
            serde_json::Value::String(_) => "string",
            serde_json::Value::Array(_) => "array",
            serde_json::Value::Object(_) => unreachable!(),
        };
        return Err(format!(
            "{tool} expected a JSON object, got {kind}. Re-issue: {example}"
        ));
    };
    // With no modes there is nothing to check. Without this guard an empty slice
    // fell through to `min_by_key(..).expect(..)` below and panicked.
    if required_modes.is_empty()
        || required_modes
            .iter()
            .any(|mode| mode.iter().all(|key| obj.contains_key(*key)))
    {
        return Ok(value);
    }
    let (closest, missing) = required_modes
        .iter()
        .map(|mode| {
            let missing: Vec<&str> = mode
                .iter()
                .copied()
                .filter(|key| !obj.contains_key(*key))
                .collect();
            (*mode, missing)
        })
        .min_by_key(|(_, missing)| missing.len())
        .expect("required_modes is non-empty: checked above");
    let provided: Vec<&str> = obj.keys().map(String::as_str).collect();
    Err(format!(
        "{tool}: provided keys [{}], missing required [{}] for mode [{}]. \
         Re-issue: {}",
        provided.join(", "),
        missing.join(", "),
        closest.join("+"),
        example,
    ))
}
pub(crate) fn is_sensitive_input_path(path: &str) -> bool {
let base_dir = std::env::current_dir().ok();
let home_dir = dirs::home_dir();
is_sensitive_input_path_with_context(path, base_dir.as_deref(), home_dir.as_deref())
}
/// Testable core of `is_sensitive_input_path`: `base_dir` anchors relative paths
/// and `home_dir` is used for `~` expansion.
///
/// The raw text is first checked for Windows system roots. The path is then
/// expanded, made absolute (when a base is given) and lexically normalized,
/// so `../` traversal cannot hide a sensitive target. The result is checked
/// against both the Windows and the platform-native rules.
fn is_sensitive_input_path_with_context(
    path: &str,
    base_dir: Option<&Path>,
    home_dir: Option<&Path>,
) -> bool {
    if is_windows_sensitive_path(path) {
        return true;
    }
    let expanded = expand_home_path(path, home_dir);
    let anchored = match base_dir {
        Some(base) if !expanded.is_absolute() => base.join(expanded),
        _ => expanded,
    };
    let normalized = lexical_normalize(&anchored);
    is_windows_sensitive_path(&normalized.to_string_lossy()) || is_sensitive_path(&normalized)
}
/// Expands a leading `~` or `~/` in `path` using `home_dir`.
///
/// Other forms (`~user`, paths without a tilde) and any input when `home_dir`
/// is `None` are returned unchanged as a `PathBuf`.
fn expand_home_path(path: &str, home_dir: Option<&Path>) -> PathBuf {
    match (home_dir, path.strip_prefix("~/")) {
        (Some(home), Some(rest)) => home.join(rest),
        (Some(home), None) if path == "~" => home.to_path_buf(),
        _ => PathBuf::from(path),
    }
}
/// Normalizes `path` purely textually (no filesystem access): drops `.`,
/// resolves `..` against preceding segments, and clamps `..` at the root.
///
/// For relative paths, leading `..` segments that cannot be resolved are kept
/// (`a/../../b` becomes `../b`). Any prefix (e.g. a Windows drive) is preserved.
fn lexical_normalize(path: &Path) -> PathBuf {
    let mut drive: Option<OsString> = None;
    let mut rooted = false;
    let mut segments: Vec<OsString> = Vec::new();
    for component in path.components() {
        match component {
            Component::Prefix(p) => {
                drive = Some(p.as_os_str().to_os_string());
                segments.clear();
            }
            Component::RootDir => {
                rooted = true;
                segments.clear();
            }
            Component::CurDir => {}
            Component::ParentDir => {
                let top_is_real_dir =
                    matches!(segments.last(), Some(top) if top != OsStr::new(".."));
                if top_is_real_dir {
                    segments.pop();
                } else if !rooted {
                    // Relative path climbing above its start: keep the `..`.
                    segments.push(OsString::from(".."));
                }
            }
            Component::Normal(segment) => segments.push(segment.to_os_string()),
        }
    }
    let mut normalized = drive.map(PathBuf::from).unwrap_or_default();
    if rooted {
        normalized.push(std::path::MAIN_SEPARATOR.to_string());
    }
    normalized.extend(segments);
    normalized
}
/// Returns `true` if `path` textually names a Windows system location
/// (`\Windows`, `\Program Files`, `\Program Files (x86)`, `\ProgramData`,
/// `\PerfLogs`) on any drive or on the current drive's root.
///
/// Matching is ASCII case-insensitive and treats `/` like `\`. The Win32
/// namespace prefixes `\\?\` (verbatim, as produced by `std::fs::canonicalize`)
/// and `\\.\` (device) are stripped first, because both address the same drive
/// paths. UNC shares (`\\server\share\...`) are never considered sensitive.
fn is_windows_sensitive_path(path: &str) -> bool {
    const SENSITIVE_ROOTS: &[&str] = &[
        r"\windows",
        r"\program files",
        r"\program files (x86)",
        r"\programdata",
        // Kept in sync with the `C:\PerfLogs` entry protected by `is_sensitive_path`.
        r"\perflogs",
    ];
    let unified = path.replace('/', "\\");
    // Without stripping `\\.\`, `\\.\C:\Windows\...` looked like a UNC path and
    // slipped through the check.
    let unprefixed = unified
        .strip_prefix(r"\\?\")
        .or_else(|| unified.strip_prefix(r"\\.\"))
        .unwrap_or(&unified);
    let lowercase = unprefixed.to_ascii_lowercase();
    let Some(rooted) = windows_rooted_path(&lowercase) else {
        return false;
    };
    SENSITIVE_ROOTS
        .iter()
        .any(|root| windows_path_starts_with(rooted, root))
}
/// Boundary-aware prefix test for backslash-separated paths: `path` equals
/// `root` or continues it with a `\` separator (so `\windows.old` does not
/// match `\windows`).
fn windows_path_starts_with(path: &str, root: &str) -> bool {
    match path.strip_prefix(root) {
        Some("") => true,
        Some(rest) => rest.starts_with('\\'),
        None => false,
    }
}
/// Returns the root-relative part of a rooted Windows path: the text after a
/// `X:` drive for `X:\...`, or the whole input for `\...`.
///
/// Returns `None` for relative paths, drive-relative paths (`C:foo`) and UNC
/// paths starting with `\\`.
fn windows_rooted_path(path: &str) -> Option<&str> {
    strip_windows_drive_prefix(path).or_else(|| {
        let single_leading_backslash = path.starts_with('\\') && !path.starts_with(r"\\");
        single_leading_backslash.then_some(path)
    })
}
/// For an input shaped `X:\...` (ASCII drive letter, colon, backslash), returns
/// the slice starting at that backslash. Returns `None` otherwise.
fn strip_windows_drive_prefix(path: &str) -> Option<&str> {
    match path.as_bytes() {
        // The first two bytes are ASCII here, so index 2 is a char boundary.
        [drive, b':', b'\\', ..] if drive.is_ascii_alphabetic() => Some(&path[2..]),
        _ => None,
    }
}
/// Number of leading `char`s that `a` and `b` have in common.
///
/// Used by read_file and glob 404 recovery to rank candidate suggestions.
pub fn shared_prefix_len(a: &str, b: &str) -> usize {
    let mut matched = 0;
    for (left, right) in a.chars().zip(b.chars()) {
        if left != right {
            break;
        }
        matched += 1;
    }
    matched
}
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use tokio::sync::{Mutex, RwLock};
/// Get the real user's home directory, accounting for sudo scenarios.
///
/// When `SUDO_USER` is set and that user's home can be looked up, that home is
/// returned. Otherwise this falls back to `dirs::home_dir()`.
pub fn real_home_dir() -> Option<PathBuf> {
    std::env::var("SUDO_USER")
        .ok()
        .and_then(|sudo_user| get_user_home(&sudo_user))
        .or_else(dirs::home_dir)
}
/// Looks up `username`'s home directory in the passwd database via `getpwnam_r`.
///
/// Returns `None` if the name contains a NUL byte, the user does not exist,
/// the lookup fails, or the entry has no (or an empty) home directory. The
/// scratch buffer is grown on `ERANGE`, up to a fixed cap.
#[cfg(unix)]
fn get_user_home(username: &str) -> Option<PathBuf> {
    use std::ffi::{CStr, CString};
    use std::ptr;

    // Starting scratch-buffer size; doubled on ERANGE up to MAX_BUF_LEN.
    const INITIAL_BUF_LEN: usize = 4096;
    const MAX_BUF_LEN: usize = 1 << 20;

    let username_c = CString::new(username).ok()?;
    let mut buf_len = INITIAL_BUF_LEN;
    loop {
        let mut buf = vec![0u8; buf_len];
        // SAFETY: `libc::passwd` is a plain C struct of integers and raw
        // pointers, for which the all-zero bit pattern is a valid value.
        let mut pwd: libc::passwd = unsafe { std::mem::zeroed() };
        let mut result: *mut libc::passwd = ptr::null_mut();
        // SAFETY: every pointer is valid for the duration of the call,
        // `username_c` is NUL-terminated, and `buf.len()` is the real size of `buf`.
        let ret = unsafe {
            libc::getpwnam_r(
                username_c.as_ptr(),
                &mut pwd,
                buf.as_mut_ptr().cast::<libc::c_char>(),
                buf.len(),
                &mut result,
            )
        };
        // ERANGE means the entry did not fit: retry with a larger buffer
        // instead of reporting "no such user".
        if ret == libc::ERANGE && buf_len < MAX_BUF_LEN {
            buf_len *= 2;
            continue;
        }
        if ret != 0 || result.is_null() || pwd.pw_dir.is_null() {
            return None;
        }
        // SAFETY: on success `pw_dir` is non-null (checked above) and points to a
        // NUL-terminated string stored inside `buf`, which is still alive here.
        let home = unsafe { CStr::from_ptr(pwd.pw_dir) }
            .to_string_lossy()
            .into_owned();
        // An empty home is useless as a path. Returning None lets callers fall back.
        return (!home.is_empty()).then(|| PathBuf::from(home));
    }
}
/// Non-unix stand-in: there is no passwd database to consult, so no per-user
/// home is ever found and `real_home_dir` falls back to `dirs::home_dir()`.
#[cfg(not(unix))]
fn get_user_home(_username: &str) -> Option<PathBuf> {
    Option::None
}
/// Expands a leading `~` / `~/` using [`real_home_dir`]. If no home directory
/// is known, or `path` has no tilde prefix, `path` is returned unchanged.
fn expand_user_path(path: &str) -> PathBuf {
    let expanded = if path == "~" {
        real_home_dir()
    } else if let Some(rest) = path.strip_prefix("~/") {
        real_home_dir().map(|home| home.join(rest))
    } else {
        None
    };
    expanded.unwrap_or_else(|| PathBuf::from(path))
}
/// Lexically normalizes `path`: removes `.` components, cancels `..` against a
/// preceding normal component, and clamps `..` at the root.
///
/// `..` segments that cannot be cancelled in a path without a root (leading
/// `..`s, or a bare drive prefix such as `C:`) are preserved. Previously only
/// the first one was kept, so `../../x` collapsed to `../x`.
fn normalize_path(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                // Classify the current tail first so the borrow ends before mutating.
                let (can_pop, at_root) = match normalized.components().next_back() {
                    Some(Component::Normal(_)) => (true, false),
                    Some(Component::RootDir) => (false, true),
                    _ => (false, false),
                };
                if can_pop {
                    normalized.pop();
                } else if !at_root {
                    // Empty, a previous `..`, or a bare prefix: climbing above the
                    // start of a non-rooted path must be kept, not silently dropped.
                    normalized.push(component.as_os_str());
                }
            }
            Component::RootDir | Component::Prefix(_) | Component::Normal(_) => {
                normalized.push(component.as_os_str());
            }
        }
    }
    normalized
}
fn canonicalize_candidate_path(path: &Path) -> Result<PathBuf> {
if path.exists() {
return std::fs::canonicalize(path)
.with_context(|| format!("Failed to resolve path {}", path.display()));
}
let mut missing_parts = Vec::new();
let mut current = path;
loop {
if current.exists() {
let mut resolved = std::fs::canonicalize(current)
.with_context(|| format!("Failed to resolve parent path {}", current.display()))?;
for part in missing_parts.iter().rev() {
resolved.push(part);
}
return Ok(resolved);
}
let name = current.file_name().ok_or_else(|| {
anyhow::anyhow!("Path {} has no existing parent directory", path.display())
})?;
missing_parts.push(name.to_os_string());
current = current.parent().ok_or_else(|| {
anyhow::anyhow!("Path {} has no existing parent directory", path.display())
})?;
}
}
/// Result of resolving a user-supplied path against the workspace
/// (see `inspect_path_access`).
pub struct ResolvedPath {
    /// Canonicalized target. Components that do not exist yet are appended
    /// verbatim to the canonical existing ancestor.
    pub path: PathBuf,
    /// Canonicalized working directory the path was checked against.
    pub workspace_root: PathBuf,
    /// Whether `path` lies at or below `workspace_root` (component-wise).
    pub within_workspace: bool,
}
/// Kind of access requested on a path outside the workspace; drives
/// `approval_for_path`'s policy and its reason text.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExternalPathAction {
    /// Listing / stat-like access (reason prefix "Accessing").
    Enumerate,
    /// Reading file contents (reason prefix "Reading").
    Read,
    /// Creating or modifying files (reason prefix "Writing").
    Write,
}
/// Resolves `raw_path` (tilde-expanded, relative to `working_dir` if not
/// absolute) and reports whether it lands inside the canonical workspace.
///
/// # Errors
/// Fails when `working_dir` cannot be canonicalized or the candidate path
/// cannot be resolved.
pub fn inspect_path_access(raw_path: &str, working_dir: &Path) -> Result<ResolvedPath> {
    let workspace_root = std::fs::canonicalize(working_dir).with_context(|| {
        format!("Failed to resolve working directory {}", working_dir.display())
    })?;
    let mut candidate = expand_user_path(raw_path);
    if !candidate.is_absolute() {
        candidate = working_dir.join(&candidate);
    }
    let path = canonicalize_candidate_path(&normalize_path(&candidate))?;
    // `Path::starts_with` compares whole components, so `/ws-other` is not in `/ws`.
    let within_workspace = path.starts_with(&workspace_root);
    Ok(ResolvedPath {
        path,
        workspace_root,
        within_workspace,
    })
}
pub fn resolve_workspace_path(raw_path: &str, working_dir: &Path) -> Result<PathBuf> {
let resolved = inspect_path_access(raw_path, working_dir)?;
if resolved.within_workspace {
Ok(resolved.path)
} else {
bail!(
"Access denied: {} resolves outside working directory {}",
raw_path,
resolved.workspace_root.display()
);
}
}
fn is_sensitive_path(path: &Path) -> bool {
#[cfg(not(target_os = "windows"))]
const SYSTEM_PROTECTED_PREFIXES: &[&str] = &[
"/System",
"/bin",
"/sbin",
"/usr",
"/var",
"/private/etc",
"/private/var",
"/etc",
"/root",
"/var/root",
"/private/var/root",
];
#[cfg(target_os = "windows")]
const SYSTEM_PROTECTED_PREFIXES: &[&str] = &[
r"C:\Windows",
r"C:\Program Files",
r"C:\Program Files (x86)",
r"C:\ProgramData",
r"C:\PerfLogs",
];
#[cfg(not(target_os = "windows"))]
const SYSTEM_PROTECTED_EXCEPTIONS: &[&str] = &[
"/usr/local",
"/private/usr/local",
"/Applications",
"/Library",
"/var/folders",
"/private/var/folders",
"/var/tmp",
"/private/var/tmp",
];
#[cfg(target_os = "windows")]
const SYSTEM_PROTECTED_EXCEPTIONS: &[&str] = &[];
const SECRET_HOME_DIRS: &[&str] = &[".ssh", ".aws", ".gnupg", ".config"];
const SECRET_FILE_NAMES: &[&str] = &[
".bashrc",
".bash_profile",
".zshrc",
".zprofile",
".zshenv",
".npmrc",
".pypirc",
".env",
".env.local",
"credentials",
"config",
"id_rsa",
"id_dsa",
"id_ecdsa",
"id_ed25519",
];
const SECRET_EXTS: &[&str] = &["pem", "key", "p12", "pfx", "der", "crt", "cer"];
let has_protected_prefix = SYSTEM_PROTECTED_PREFIXES
.iter()
.any(|prefix| path == Path::new(prefix) || path.starts_with(prefix));
let has_exception_prefix = SYSTEM_PROTECTED_EXCEPTIONS
.iter()
.any(|prefix| path == Path::new(prefix) || path.starts_with(prefix));
if has_protected_prefix && !has_exception_prefix {
return true;
}
if let Some(home) = real_home_dir() {
for dir in SECRET_HOME_DIRS {
if path.starts_with(home.join(dir)) {
return true;
}
}
for file in SECRET_FILE_NAMES {
if path == home.join(file) {
return true;
}
}
}
if path
.file_name()
.and_then(|n| n.to_str())
.is_some_and(|name| SECRET_FILE_NAMES.contains(&name))
{
return true;
}
path.extension()
.and_then(|ext| ext.to_str())
.is_some_and(|ext| {
SECRET_EXTS
.iter()
.any(|candidate| ext.eq_ignore_ascii_case(candidate))
})
}
fn is_atomcode_owned_path(path: &Path) -> bool {
let Some(home) = real_home_dir() else { return false };
let trusted_roots: &[PathBuf] = &[
home.join(".atomcode").join("plugins"),
home.join(".atomcode").join("skills"),
];
trusted_roots
.iter()
.any(|root| path == root.as_path() || path.starts_with(root))
}
/// Decides what approval touching `raw_path` (resolved against `working_dir`)
/// requires for the given `action`.
///
/// The rules, in order:
/// - Paths inside the workspace are auto-approved.
/// - Non-write access to atomcode-owned plugin/skill directories is auto-approved.
/// - Writes outside the workspace always require confirmation.
/// - Sensitive paths always require confirmation.
/// - Otherwise, enumeration is auto-approved and reads require approval.
///
/// # Errors
/// Propagates resolution failures from [`inspect_path_access`].
pub fn approval_for_path(
    raw_path: &str,
    working_dir: &Path,
    action: ExternalPathAction,
) -> Result<ApprovalRequirement> {
    let access = inspect_path_access(raw_path, working_dir)?;
    let trusted = access.within_workspace
        || (action != ExternalPathAction::Write && is_atomcode_owned_path(&access.path));
    if trusted {
        return Ok(ApprovalRequirement::AutoApprove);
    }
    let verb = match action {
        ExternalPathAction::Enumerate => "Accessing",
        ExternalPathAction::Read => "Reading",
        ExternalPathAction::Write => "Writing",
    };
    let base_reason = format!(
        "{} path outside working directory: {} (working dir: {})",
        verb,
        raw_path,
        access.workspace_root.display()
    );
    let requirement = match (action, is_sensitive_path(&access.path)) {
        (ExternalPathAction::Write, _) => ApprovalRequirement::RequireApprovalAlways(format!(
            "{}. Writing outside the workspace always requires confirmation.",
            base_reason
        )),
        (_, true) => ApprovalRequirement::RequireApprovalAlways(format!(
            "{}. This path looks sensitive and always requires confirmation.",
            base_reason
        )),
        (ExternalPathAction::Enumerate, false) => ApprovalRequirement::AutoApprove,
        (ExternalPathAction::Read, false) => {
            ApprovalRequirement::RequireApproval(format!("{base_reason}."))
        }
    };
    Ok(requirement)
}
/// Static description of a tool as advertised to the model.
#[derive(Debug, Clone)]
pub struct ToolDef {
    /// Unique tool name; `ToolRegistry` uses it as the registration key.
    pub name: &'static str,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Parameter schema. `ToolRegistry::expected_top_keys` reads the keys of its
    /// `"properties"` object.
    pub parameters: serde_json::Value,
}
/// A single tool invocation. NOTE(review): presumably as emitted by the model.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Raw, unparsed JSON argument text.
    pub arguments: String,
}
/// Outcome of executing a tool.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ToolResult {
    /// Id of the call this result answers (presumably `ToolCall::id`; confirm with callers).
    pub call_id: String,
    /// Text output of the tool.
    pub output: String,
    /// Whether the tool reported success.
    pub success: bool,
}
/// A tool call that is still being assembled.
/// NOTE(review): presumably filled in from streamed chunks.
#[derive(Debug, Clone)]
pub struct ToolCallBuffer {
    /// Call identifier.
    pub id: String,
    /// Tool name.
    pub name: String,
    /// Argument text accumulated so far.
    pub arguments: String,
    /// Whether a hint was already sent for this call. Semantics are owned by the
    /// caller; TODO confirm.
    pub hint_sent: bool,
}
/// What a tool needs from the user before it may run.
pub enum ApprovalRequirement {
    /// Runs without asking.
    AutoApprove,
    /// Needs confirmation; the string is the reason shown to the user. A session
    /// grant for the tool can satisfy it.
    RequireApproval(String),
    /// Needs confirmation every time; session grants never satisfy it.
    RequireApprovalAlways(String),
}
/// Per-tool override level stored in a `PermissionStore`.
#[derive(Debug, Clone, PartialEq)]
pub enum PermissionLevel {
    /// Permanently allowed.
    AlwaysAllow,
    /// Ask the user.
    Ask,
    /// Allowed for the current session.
    SessionAllow,
    /// Permanently denied.
    AlwaysDeny,
}
/// Result of `PermissionStore::check`.
#[derive(Debug, Clone)]
pub enum PermissionDecision {
    /// Run the tool.
    Allow,
    /// Prompt the user; carries the reason from the approval requirement.
    Ask(String),
    /// Refuse to run the tool.
    Deny,
}
/// Per-tool permission state: explicit overrides plus grants that last for the
/// current session.
pub struct PermissionStore {
    /// Explicit per-tool levels set via [`PermissionStore::set_override`].
    overrides: HashMap<String, PermissionLevel>,
    /// Tool names the user allowed for the rest of the session.
    session_grants: HashSet<String>,
}

impl Default for PermissionStore {
    fn default() -> Self {
        Self::new()
    }
}

impl PermissionStore {
    /// Creates an empty store: no overrides, no session grants.
    pub fn new() -> Self {
        Self {
            overrides: HashMap::new(),
            session_grants: HashSet::new(),
        }
    }

    /// Decides whether `tool_name` may run given the tool's `approval` requirement.
    ///
    /// Precedence:
    /// 1. An `AlwaysDeny` override denies. Previously it was ignored whenever the
    ///    tool asked for approval, and a session grant could override it.
    /// 2. `RequireApprovalAlways` asks.
    /// 3. A session grant allows.
    /// 4. `RequireApproval` asks.
    /// 5. Otherwise the call is allowed.
    ///
    /// NOTE(review): an `Ask` override currently has no effect, as in the
    /// original implementation.
    pub fn check(&self, tool_name: &str, approval: &ApprovalRequirement) -> PermissionDecision {
        if matches!(
            self.overrides.get(tool_name),
            Some(PermissionLevel::AlwaysDeny)
        ) {
            return PermissionDecision::Deny;
        }
        match approval {
            ApprovalRequirement::RequireApprovalAlways(reason) => {
                PermissionDecision::Ask(reason.clone())
            }
            _ if self.session_grants.contains(tool_name) => PermissionDecision::Allow,
            ApprovalRequirement::RequireApproval(reason) => PermissionDecision::Ask(reason.clone()),
            ApprovalRequirement::AutoApprove => PermissionDecision::Allow,
        }
    }

    /// Allows `tool_name` for the rest of the session, except for requests that
    /// always require approval.
    pub fn grant_session(&mut self, tool_name: &str) {
        self.session_grants.insert(tool_name.to_string());
    }

    /// Sets or replaces the override level for `tool_name`.
    pub fn set_override(&mut self, tool_name: &str, level: PermissionLevel) {
        self.overrides.insert(tool_name.to_string(), level);
    }
}
/// Key of `ToolContext::read_cache`: a path plus two optional numbers.
/// NOTE(review): presumably (path, offset, limit) of a read; confirm with read.rs.
pub type ReadCacheKey = (PathBuf, Option<usize>, Option<usize>);
/// Value of `ToolContext::read_cache`: a timestamp, cached text and a count.
/// NOTE(review): presumably (mtime, rendered content, token/line count); confirm.
pub type ReadCacheEntry = (std::time::SystemTime, String, usize);
#[derive(Clone)]
pub struct ToolContext {
pub working_dir: Arc<RwLock<PathBuf>>,
pub semantic: Arc<Mutex<crate::semantic::SemanticSearcher>>,
pub file_history: Arc<Mutex<file_history::FileHistory>>,
pub graph: Arc<RwLock<crate::graph::CodeGraph>>,
pub ctx_budget_hint: Arc<std::sync::atomic::AtomicUsize>,
pub read_budget_tokens: Arc<std::sync::atomic::AtomicUsize>,
pub read_cache: Arc<RwLock<std::collections::HashMap<ReadCacheKey, ReadCacheEntry>>>,
pub first_error_signatures: Arc<RwLock<Vec<String>>>,
pub telemetry: std::sync::Arc<atomcode_telemetry::Telemetry>,
pub lsp: Option<std::sync::Arc<crate::lsp::manager::LspManager>>,
pub event_tx: Option<Arc<tokio::sync::mpsc::UnboundedSender<crate::turn::event::TurnEvent>>>,
pub current_call_id: Option<String>,
pub tool_registry: Option<Arc<ToolRegistry>>,
pub file_store: Arc<RwLock<crate::ctx::file_store::FileStore>>,
}
impl ToolContext {
    /// Creates a context rooted at `working_dir` with session id `"default"` and
    /// disabled telemetry.
    pub fn new(working_dir: PathBuf) -> Self {
        let telemetry = disabled_telemetry();
        Self::with_telemetry(working_dir, "default", telemetry)
    }

    /// Like [`ToolContext::new`] but with an explicit `session_id` for file history.
    pub fn with_session(working_dir: PathBuf, session_id: &str) -> Self {
        let telemetry = disabled_telemetry();
        Self::with_telemetry(working_dir, session_id, telemetry)
    }

    /// Fully specified constructor.
    ///
    /// Every shared component is fresh. Both budgets start at `usize::MAX`, and
    /// the optional handles (`lsp`, `event_tx`, `current_call_id`,
    /// `tool_registry`) are `None`.
    pub fn with_telemetry(
        working_dir: PathBuf,
        session_id: &str,
        telemetry: std::sync::Arc<atomcode_telemetry::Telemetry>,
    ) -> Self {
        Self {
            working_dir: Arc::new(RwLock::new(working_dir)),
            semantic: Arc::new(Mutex::new(crate::semantic::SemanticSearcher::new())),
            file_history: Arc::new(Mutex::new(file_history::FileHistory::new(session_id))),
            ctx_budget_hint: Arc::new(std::sync::atomic::AtomicUsize::new(usize::MAX)),
            read_budget_tokens: Arc::new(std::sync::atomic::AtomicUsize::new(usize::MAX)),
            graph: Arc::new(RwLock::new(crate::graph::CodeGraph::new())),
            read_cache: Arc::new(RwLock::new(std::collections::HashMap::new())),
            first_error_signatures: Arc::new(RwLock::new(Vec::new())),
            telemetry,
            lsp: None,
            event_tx: None,
            current_call_id: None,
            tool_registry: None,
            file_store: Arc::new(RwLock::new(crate::ctx::file_store::FileStore::new())),
        }
    }

    /// Creates a context with its own working directory (a snapshot of the
    /// current one), file history, caches and budgets.
    ///
    /// It shares the code graph, telemetry, LSP manager and file store with `self`.
    pub async fn isolate(&self) -> Self {
        let wd = self.working_dir.read().await.clone();
        // Pass the shared telemetry straight in. Going through `Self::new` would
        // initialise a throwaway disabled Telemetry only to overwrite it.
        let mut ctx = Self::with_telemetry(wd, "default", Arc::clone(&self.telemetry));
        ctx.graph = Arc::clone(&self.graph);
        ctx.lsp = self.lsp.clone();
        ctx.file_store = Arc::clone(&self.file_store);
        ctx
    }

    /// Tells the LSP manager (if any) that `path` now has `content`.
    ///
    /// This is best-effort: failures are logged as warnings, not returned.
    pub async fn notify_lsp_file_changed(&self, path: &Path, content: &str) {
        if let Some(ref lsp) = self.lsp {
            if let Err(e) = lsp.notify_file_changed(path, content).await {
                tracing::warn!(
                    "[lsp] Failed to refresh diagnostics for {}: {}",
                    path.display(),
                    e
                );
            }
        }
    }
}
fn disabled_telemetry() -> std::sync::Arc<atomcode_telemetry::Telemetry> {
let cfg = atomcode_telemetry::ResolvedConfig {
state: atomcode_telemetry::TelemetryState::Disabled("default"),
endpoint: "http://localhost/v1/events".into(),
atomcode_dir: std::path::PathBuf::from("/tmp"),
};
atomcode_telemetry::Telemetry::init(cfg, env!("CARGO_PKG_VERSION").into())
}
/// Picks up to five distinctive lines from tool `output` to use as error signatures.
///
/// Each line is trimmed and then skipped if it is shorter than 15 bytes, starts
/// with `[`, or is exactly `STDERR:`. Kept lines are truncated to 120 chars and
/// de-duplicated. The longest are returned first; equal lengths keep their
/// original order.
pub fn extract_error_signatures(output: &str) -> Vec<String> {
    const MIN_LEN_BYTES: usize = 15;
    const MAX_CHARS: usize = 120;
    const MAX_SIGNATURES: usize = 5;

    // A HashSet keeps de-duplication O(1) per line. `Vec::contains` made this
    // quadratic on long build logs.
    let mut seen: HashSet<String> = HashSet::new();
    let mut lines: Vec<String> = Vec::new();
    for line in output.lines() {
        let trimmed = line.trim();
        // The length check also rejects blank lines.
        if trimmed.len() < MIN_LEN_BYTES || trimmed.starts_with('[') || trimmed == "STDERR:" {
            continue;
        }
        let signature: String = trimmed.chars().take(MAX_CHARS).collect();
        if seen.insert(signature.clone()) {
            lines.push(signature);
        }
    }
    // Stable sort, so equally long lines keep first-seen order.
    lines.sort_by_key(|s| std::cmp::Reverse(s.len()));
    lines.truncate(MAX_SIGNATURES);
    lines
}
/// A capability the agent can invoke by name.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Name, description and parameter schema advertised for this tool.
    fn definition(&self) -> ToolDef;

    /// Approval needed to run with `args`, decided without runtime context.
    fn approval(&self, args: &str) -> ApprovalRequirement;

    /// Context-aware approval. The default ignores `_ctx` and delegates to
    /// [`Tool::approval`].
    fn approval_with_context(&self, args: &str, _ctx: &ToolContext) -> ApprovalRequirement {
        self.approval(args)
    }

    /// Runs the tool with raw JSON `args`.
    async fn execute(&self, args: &str, ctx: &ToolContext) -> Result<ToolResult>;

    /// Optional argument pre-check returning a message on failure. The default
    /// accepts everything.
    fn validate_args(&self, _args: &str) -> std::result::Result<(), String> {
        Ok(())
    }
}
/// Name-keyed registry of available tools.
///
/// Tools are stored as `Arc<dyn Tool>` in a `BTreeMap`, so lookups hand out
/// cheap clones, and definitions and iteration come back in name order.
pub struct ToolRegistry {
    tools: tokio::sync::RwLock<BTreeMap<String, Arc<dyn Tool>>>,
}

impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        let tools = tokio::sync::RwLock::new(BTreeMap::new());
        Self { tools }
    }

    /// Registers `tool` under its `definition().name`, replacing any previous
    /// tool with that name.
    pub async fn register(&self, tool: Box<dyn Tool>) {
        let key = String::from(tool.definition().name);
        self.tools.write().await.insert(key, Arc::from(tool));
    }

    /// Synchronous registration through `&mut self`. It bypasses the lock, so no
    /// await is needed.
    pub fn register_sync(&mut self, tool: Box<dyn Tool>) {
        let key = String::from(tool.definition().name);
        self.tools.get_mut().insert(key, Arc::from(tool));
    }

    /// Definitions of all registered tools, in name order.
    pub async fn get_definitions(&self) -> Vec<ToolDef> {
        let guard = self.tools.read().await;
        guard.values().map(|tool| tool.definition()).collect()
    }

    /// Looks up a tool by name.
    pub async fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
        let guard = self.tools.read().await;
        guard.get(name).map(Arc::clone)
    }

    /// Snapshot of `(name, tool)` pairs. The read lock is released before the
    /// iterator is returned.
    pub async fn iter(&self) -> impl Iterator<Item = (String, Arc<dyn Tool>)> {
        let guard = self.tools.read().await;
        let snapshot: Vec<(String, Arc<dyn Tool>)> = guard
            .iter()
            .map(|(name, tool)| (name.clone(), Arc::clone(tool)))
            .collect();
        snapshot.into_iter()
    }

    /// Registers an already shared tool under an explicit `name`.
    pub async fn register_arc(&self, name: String, tool: Arc<dyn Tool>) {
        self.tools.write().await.insert(name, tool);
    }

    /// Top-level parameter names from the tool's schema (`parameters.properties`
    /// keys). Empty when the tool or the schema section is missing.
    pub async fn expected_top_keys(&self, name: &str) -> Vec<String> {
        let guard = self.tools.read().await;
        guard
            .get(name)
            .and_then(|tool| {
                tool.definition()
                    .parameters
                    .get("properties")?
                    .as_object()
                    .map(|props| props.keys().cloned().collect::<Vec<String>>())
            })
            .unwrap_or_default()
    }

    /// Removes every tool whose name starts with `prefix` and returns how many
    /// were removed.
    pub async fn unregister_prefix(&self, prefix: &str) -> usize {
        let mut guard = self.tools.write().await;
        let before = guard.len();
        guard.retain(|name, _| !name.starts_with(prefix));
        before - guard.len()
    }
}
/// Keys under which real tool arguments are sometimes wrapped, either as a
/// nested object or as a string holding a JSON object.
const ARGS_WRAPPER_KEYS: &[&str] = &["arguments", "input", "content"];

/// Tries to repair tool arguments that arrived wrapped (e.g.
/// `{"arguments": "{\"command\":\"ls\"}"}`), unwrapping at most five layers.
///
/// Returns `Some(json)` only if at least one layer was removed, the result has
/// at least one expected key (when `expected_top_keys` is non-empty), and the
/// result is not still wrapper-shaped. Returns `None` when `raw` is not a JSON
/// object or is already shaped like real arguments.
pub fn recover_tool_args(raw: &str, expected_top_keys: &[String]) -> Option<String> {
    const MAX_UNWRAP_DEPTH: usize = 5;

    let mut value: serde_json::Value = serde_json::from_str(raw).ok()?;
    if !value.is_object() {
        return None;
    }
    // Every key is a known parameter, so the value already looks like real args.
    let is_well_shaped = |v: &serde_json::Value| {
        !expected_top_keys.is_empty() && all_keys_in_expected(v, expected_top_keys)
    };
    if is_well_shaped(&value) {
        return None;
    }
    let mut progressed = false;
    for _ in 0..MAX_UNWRAP_DEPTH {
        match try_unwrap_once(value, expected_top_keys) {
            UnwrapStep::Stable(v) => {
                value = v;
                break;
            }
            UnwrapStep::Progressed(v) => {
                value = v;
                progressed = true;
                // Stop as soon as we reach real arguments. Otherwise a genuine
                // parameter that is also a wrapper key (e.g. `content` holding
                // JSON text when writing package.json) would be unwrapped too,
                // destroying the payload.
                if is_well_shaped(&value) {
                    break;
                }
            }
        }
    }
    if !progressed {
        return None;
    }
    if !expected_top_keys.is_empty() && !has_expected_key(&value, expected_top_keys) {
        return None;
    }
    // A wrapper-looking key is fine when it is itself a declared parameter.
    if has_wrapper_shape(&value) && !is_well_shaped(&value) {
        return None;
    }
    serde_json::to_string(&value).ok()
}
/// `true` when `v` is an object containing at least one of the `expected` keys.
fn has_expected_key(v: &serde_json::Value, expected: &[String]) -> bool {
    v.as_object()
        .is_some_and(|map| expected.iter().any(|key| map.contains_key(key.as_str())))
}
/// `true` when `v` is a non-empty object whose keys all appear in `expected`.
fn all_keys_in_expected(v: &serde_json::Value, expected: &[String]) -> bool {
    match v.as_object() {
        Some(map) if !map.is_empty() => map.keys().all(|key| expected.contains(key)),
        _ => false,
    }
}
/// `true` when some `ARGS_WRAPPER_KEYS` entry of object `v` holds an object, or
/// a string that parses as a JSON object.
fn has_wrapper_shape(v: &serde_json::Value) -> bool {
    let Some(map) = v.as_object() else { return false };
    ARGS_WRAPPER_KEYS
        .iter()
        .filter_map(|key| map.get(*key))
        .any(|inner| {
            inner.is_object()
                || inner
                    .as_str()
                    .and_then(|text| serde_json::from_str::<serde_json::Value>(text).ok())
                    .is_some_and(|parsed| parsed.is_object())
        })
}
/// Outcome of a single `try_unwrap_once` attempt.
enum UnwrapStep {
    /// A wrapper layer was removed; carries the inner object.
    Progressed(serde_json::Value),
    /// No wrapper was found; carries the input unchanged.
    Stable(serde_json::Value),
}
/// Removes one wrapper layer from `value`, if any.
///
/// The first `ARGS_WRAPPER_KEYS` entry (in that order) holding an object, or a
/// string that parses as a JSON object, becomes the new value. Sibling keys of
/// the wrapper that are `expected` parameters and absent from the inner object
/// are carried over into it.
fn try_unwrap_once(value: serde_json::Value, expected: &[String]) -> UnwrapStep {
    let Some(outer) = value.as_object() else {
        return UnwrapStep::Stable(value);
    };
    let found = ARGS_WRAPPER_KEYS.iter().find_map(|&key| {
        let candidate = outer.get(key)?;
        if let Some(obj) = candidate.as_object() {
            return Some((key, serde_json::Value::Object(obj.clone())));
        }
        let parsed: serde_json::Value = serde_json::from_str(candidate.as_str()?).ok()?;
        parsed.is_object().then_some((key, parsed))
    });
    let Some((wrapper, mut inner)) = found else {
        return UnwrapStep::Stable(value);
    };
    if let Some(inner_map) = inner.as_object_mut() {
        for (name, field) in outer {
            let carry_over = name != wrapper
                && expected.iter().any(|e| e == name)
                && !inner_map.contains_key(name);
            if carry_over {
                inner_map.insert(name.clone(), field.clone());
            }
        }
    }
    UnwrapStep::Progressed(inner)
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
struct DummyTool;
#[async_trait::async_trait]
impl Tool for DummyTool {
fn definition(&self) -> ToolDef {
ToolDef {
name: "dummy",
description: "A dummy tool".to_string(),
parameters: serde_json::json!({
"type": "object",
"properties": {},
}),
}
}
fn approval(&self, _args: &str) -> ApprovalRequirement {
ApprovalRequirement::AutoApprove
}
async fn execute(&self, _args: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
Ok(ToolResult {
call_id: "test".to_string(),
output: "ok".to_string(),
success: true,
})
}
}
#[tokio::test]
async fn test_registry_register_and_get() {
let reg = ToolRegistry::new();
reg.register(Box::new(DummyTool)).await;
assert!(reg.get("dummy").await.is_some());
assert!(reg.get("nonexistent").await.is_none());
}
#[tokio::test]
async fn test_registry_definitions() {
let reg = ToolRegistry::new();
reg.register(Box::new(DummyTool)).await;
let defs = reg.get_definitions().await;
assert_eq!(defs.len(), 1);
assert_eq!(defs[0].name, "dummy");
}
#[test]
fn sensitive_path_detects_relative_traversal_to_unix_root() {
assert!(is_sensitive_input_path_with_context(
"../../../etc/passwd",
Some(Path::new("/home/alice/project")),
Some(Path::new("/home/alice")),
));
}
#[test]
fn sensitive_path_detects_windows_system_roots() {
assert!(is_sensitive_input_path_with_context(
r"C:\Windows\System32\drivers\etc\hosts",
None,
None,
));
assert!(is_sensitive_input_path_with_context(
r"D:\Windows\System32\drivers\etc\hosts",
None,
None,
));
assert!(is_sensitive_input_path_with_context(
r"\Windows\System32\drivers\etc\hosts",
None,
None,
));
assert!(is_sensitive_input_path_with_context(
r"C:\Program Files\AtomCode\config.toml",
None,
None,
));
assert!(is_sensitive_input_path_with_context(
r"C:\ProgramData\AtomCode\config.toml",
None,
None,
));
}
#[test]
fn sensitive_path_uses_path_boundaries() {
assert!(!is_sensitive_input_path_with_context(
"/etc-old/passwd",
None,
None,
));
assert!(!is_sensitive_input_path_with_context(
r"C:\Windows.old\system.ini",
None,
None,
));
assert!(!is_sensitive_input_path_with_context(
r"D:\Windows.old\system.ini",
None,
None,
));
assert!(!is_sensitive_input_path_with_context(
r"\Windows.old\system.ini",
None,
None,
));
assert!(!is_sensitive_input_path_with_context(
r"\\server\share\Windows\system.ini",
None,
None,
));
}
#[tokio::test]
async fn test_tool_execute() {
let tool = DummyTool;
let ctx = ToolContext::new(std::env::current_dir().unwrap());
let result = tool.execute("{}", &ctx).await.unwrap();
assert!(result.success);
assert_eq!(result.output, "ok");
}
#[test]
fn resolve_workspace_path_rejects_parent_escape() {
let workspace = TempDir::new().unwrap();
let outside = TempDir::new().unwrap();
let path = format!("{}/secret.txt", outside.path().display());
std::fs::write(outside.path().join("secret.txt"), "top-secret").unwrap();
let err = resolve_workspace_path(&path, workspace.path()).unwrap_err();
assert!(err.to_string().contains("outside working directory"));
}
#[cfg(unix)]
#[test]
fn resolve_workspace_path_rejects_symlink_escape() {
let workspace = TempDir::new().unwrap();
let outside = TempDir::new().unwrap();
let target = outside.path().join("secret.txt");
std::fs::write(&target, "top-secret").unwrap();
let link = workspace.path().join("secret-link");
std::os::unix::fs::symlink(&target, &link).unwrap();
let err =
resolve_workspace_path(link.to_string_lossy().as_ref(), workspace.path()).unwrap_err();
assert!(err.to_string().contains("outside working directory"));
}
#[test]
fn inspect_path_access_marks_workspace_escape() {
let workspace = TempDir::new().unwrap();
let outside = TempDir::new().unwrap();
let target = outside.path().join("secret.txt");
std::fs::write(&target, "top-secret").unwrap();
let access = inspect_path_access(&target.to_string_lossy(), workspace.path()).unwrap();
assert!(!access.within_workspace);
assert_eq!(access.path, target.canonicalize().unwrap());
}
#[test]
fn approval_for_non_sensitive_enumeration_outside_workspace_is_auto() {
let workspace = TempDir::new().unwrap();
let outside = TempDir::new().unwrap();
let approval = approval_for_path(
&outside.path().to_string_lossy(),
workspace.path(),
ExternalPathAction::Enumerate,
)
.unwrap();
assert!(matches!(approval, ApprovalRequirement::AutoApprove));
}
#[test]
fn approval_for_non_sensitive_read_outside_workspace_requires_confirmation() {
let workspace = TempDir::new().unwrap();
let outside = TempDir::new().unwrap();
let target = outside.path().join("notes.txt");
std::fs::write(&target, "hello").unwrap();
let approval = approval_for_path(
&target.to_string_lossy(),
workspace.path(),
ExternalPathAction::Read,
)
.unwrap();
assert!(matches!(approval, ApprovalRequirement::RequireApproval(_)));
}
#[test]
fn approval_for_sensitive_read_outside_workspace_requires_always() {
let workspace = TempDir::new().unwrap();
let outside = TempDir::new().unwrap();
let target = outside.path().join("id_rsa");
std::fs::write(&target, "private-key").unwrap();
let approval = approval_for_path(
&target.to_string_lossy(),
workspace.path(),
ExternalPathAction::Read,
)
.unwrap();
assert!(matches!(
approval,
ApprovalRequirement::RequireApprovalAlways(_)
));
}
#[test]
fn approval_for_system_protected_prefix_requires_always() {
assert!(is_sensitive_path(Path::new(
"/System/Library/CoreServices/boot.efi"
)));
}
#[test]
fn approval_for_usr_local_exception_is_not_sensitive() {
assert!(!is_sensitive_path(Path::new("/usr/local/bin/tool")));
}
#[test]
fn approval_for_private_var_prefix_requires_always() {
assert!(is_sensitive_path(Path::new("/private/var/db/config")));
}
#[test]
fn approval_for_private_var_folders_exception_is_not_sensitive() {
assert!(!is_sensitive_path(Path::new(
"/private/var/folders/xx/yy/T/file.txt"
)));
}
#[test]
fn approval_for_write_outside_workspace_requires_always() {
let workspace = TempDir::new().unwrap();
let outside = TempDir::new().unwrap();
let target = outside.path().join("notes.txt");
let approval = approval_for_path(
&target.to_string_lossy(),
workspace.path(),
ExternalPathAction::Write,
)
.unwrap();
assert!(matches!(
approval,
ApprovalRequirement::RequireApprovalAlways(_)
));
}
#[tokio::test]
async fn read_file_requests_approval_for_workspace_escape() {
let workspace = TempDir::new().unwrap();
let outside = TempDir::new().unwrap();
let target = outside.path().join("secret.txt");
std::fs::write(&target, "top-secret").unwrap();
let tool = crate::tool::read::ReadFileTool;
let ctx = ToolContext::new(workspace.path().to_path_buf());
let args = format!(r#"{{"file_path":"{}"}}"#, target.display());
assert!(matches!(
tool.approval_with_context(&args, &ctx),
ApprovalRequirement::RequireApproval(_)
));
}
#[tokio::test]
async fn edit_file_requests_approval_for_workspace_escape() {
let workspace = TempDir::new().unwrap();
let outside = TempDir::new().unwrap();
let target = outside.path().join("secret.txt");
std::fs::write(&target, "top-secret").unwrap();
let tool = crate::tool::edit::EditFileTool;
let ctx = ToolContext::new(workspace.path().to_path_buf());
let args = format!(
r#"{{"file_path":"{}","old_string":"top-secret","new_string":"changed"}}"#,
target.display()
);
assert!(matches!(
tool.approval_with_context(&args, &ctx),
ApprovalRequirement::RequireApprovalAlways(_)
));
}
#[test]
fn test_permission_store_auto_approve() {
let store = PermissionStore::new();
let decision = store.check("bash", &ApprovalRequirement::AutoApprove);
assert!(matches!(decision, PermissionDecision::Allow));
}
#[test]
fn test_permission_store_require_approval() {
let store = PermissionStore::new();
let decision = store.check(
"bash",
&ApprovalRequirement::RequireApproval("Destructive".into()),
);
assert!(matches!(decision, PermissionDecision::Ask(_)));
}
#[test]
fn test_permission_store_session_grant_bypasses_require_approval() {
let mut store = PermissionStore::new();
store.grant_session("bash");
let decision = store.check(
"bash",
&ApprovalRequirement::RequireApproval("non-destructive".into()),
);
assert!(matches!(decision, PermissionDecision::Allow));
}
#[test]
fn test_permission_store_session_grant_does_not_bypass_require_approval_always() {
let mut store = PermissionStore::new();
store.grant_session("bash");
let decision = store.check(
"bash",
&ApprovalRequirement::RequireApprovalAlways("Sensitive".into()),
);
assert!(matches!(decision, PermissionDecision::Ask(_)));
}
#[test]
fn test_permission_store_session_grant_allows_auto_approve() {
    // A session grant does not interfere with AutoApprove requirements.
    let mut store = PermissionStore::new();
    store.grant_session("bash");
    assert!(matches!(
        store.check("bash", &ApprovalRequirement::AutoApprove),
        PermissionDecision::Allow
    ));
}
#[test]
fn test_permission_store_always_deny_override() {
    // An AlwaysDeny override wins even over an AutoApprove requirement.
    let mut store = PermissionStore::new();
    store.set_override("bash", PermissionLevel::AlwaysDeny);
    assert!(matches!(
        store.check("bash", &ApprovalRequirement::AutoApprove),
        PermissionDecision::Deny
    ));
}
#[test]
fn test_permission_store_always_allow_cannot_bypass_destructive() {
    // An AlwaysAllow override must not silence a RequireApproval prompt.
    let mut store = PermissionStore::new();
    store.set_override("bash", PermissionLevel::AlwaysAllow);
    let requirement = ApprovalRequirement::RequireApproval("Destructive".into());
    let outcome = store.check("bash", &requirement);
    assert!(matches!(outcome, PermissionDecision::Ask(_)));
}
#[tokio::test]
async fn test_tool_context_isolate() {
    // Changing the working dir of an isolated context must leave the original untouched.
    let original = ToolContext::new(PathBuf::from("/original"));
    let isolated = original.isolate().await;
    {
        let mut wd = isolated.working_dir.write().await;
        *wd = PathBuf::from("/changed");
    }
    let current = original.working_dir.read().await.clone();
    assert_eq!(current, PathBuf::from("/original"));
}
#[tokio::test]
async fn test_registry_iter() {
    // A registry holding one tool yields exactly that tool, keyed by its name.
    let registry = ToolRegistry::new();
    registry.register(Box::new(DummyTool)).await;
    let entries = registry.iter().await.collect::<Vec<_>>();
    assert_eq!(entries.len(), 1);
    assert_eq!(entries[0].0, "dummy");
}
#[tokio::test]
async fn test_registry_register_arc() {
    // Tools can be shared across registries by re-registering the handles yielded by iter().
    let source = ToolRegistry::new();
    source.register(Box::new(DummyTool)).await;
    let dest = ToolRegistry::new();
    for entry in source.iter().await {
        let (name, shared) = entry;
        dest.register_arc(name, shared).await;
    }
    assert!(dest.get("dummy").await.is_some());
}
#[test]
fn test_permission_store_session_grant_only_affects_named_tool() {
    // A grant for "bash" must not leak to a different tool such as "create_file".
    let mut store = PermissionStore::new();
    store.grant_session("bash");
    let requirement = ApprovalRequirement::RequireApproval("write".into());
    let outcome = store.check("create_file", &requirement);
    assert!(matches!(outcome, PermissionDecision::Ask(_)));
}
/// Expected argument keys for the bash/command tool.
fn cmd_keys() -> Vec<String> {
    ["command", "timeout"].iter().map(|k| k.to_string()).collect()
}
/// Expected argument keys for the read-file tool.
fn read_keys() -> Vec<String> {
    ["file_path", "offset", "limit"].iter().map(|k| k.to_string()).collect()
}
/// Expected argument keys for the grep tool.
fn grep_keys() -> Vec<String> {
    ["pattern", "path", "max_results", "context"]
        .iter()
        .map(|k| k.to_string())
        .collect()
}
/// Expected argument keys for the write-file tool.
fn write_keys() -> Vec<String> {
    ["file_path", "content"].iter().map(|k| k.to_string()).collect()
}
/// Expected argument keys for the todo tool.
fn todo_keys() -> Vec<String> {
    ["action", "content", "id"].iter().map(|k| k.to_string()).collect()
}
/// Parses a JSON string into a `serde_json::Value`, panicking on malformed input.
fn parse(s: &str) -> serde_json::Value {
    serde_json::from_str::<serde_json::Value>(s).unwrap()
}
#[test]
fn recover_flat_passes_through() {
    // Arguments that are already flat need no recovery, so None is returned.
    let flat = r#"{"command":"ls -la"}"#;
    let outcome = recover_tool_args(flat, &cmd_keys());
    assert!(outcome.is_none());
}
#[test]
fn recover_variant_a1_string_inner() {
    // `arguments` carrying a JSON-encoded string is decoded and unwrapped.
    let wrapped = r#"{"arguments":"{\"command\":\"ls\"}"}"#;
    let value = parse(&recover_tool_args(wrapped, &cmd_keys()).unwrap());
    assert_eq!(value["command"], "ls");
}
#[test]
fn recover_variant_a2_object_inner() {
    // `arguments` carrying a nested object is unwrapped with all its fields.
    let wrapped = r#"{"arguments":{"command":"ls","timeout":30}}"#;
    let value = parse(&recover_tool_args(wrapped, &cmd_keys()).unwrap());
    assert_eq!(value["command"], "ls");
    assert_eq!(value["timeout"], 30);
}
#[test]
fn recover_variant_b_double_string() {
    // Two layers of string-encoded `arguments` wrappers are peeled off.
    let wrapped = r#"{"arguments":"{\"arguments\":\"{\\\"command\\\":\\\"ls\\\"}\"}"}"#;
    let value = parse(&recover_tool_args(wrapped, &cmd_keys()).unwrap());
    assert_eq!(value["command"], "ls");
}
#[test]
fn recover_variant_b_triple_object() {
    // Nested object-valued `arguments` wrappers are peeled off.
    let wrapped = r#"{"arguments":{"arguments":{"command":"ls"}}}"#;
    let value = parse(&recover_tool_args(wrapped, &cmd_keys()).unwrap());
    assert_eq!(value["command"], "ls");
}
#[test]
fn recover_variant_c_multi_key_merges_siblings() {
    // Sibling keys next to the wrapper are merged into the unwrapped object.
    let wrapped = r#"{"arguments":"{\"command\":\"ls\"}","timeout":120}"#;
    let value = parse(&recover_tool_args(wrapped, &cmd_keys()).unwrap());
    assert_eq!(value["command"], "ls");
    assert_eq!(value["timeout"], 120);
}
#[test]
fn recover_variant_d_content_wrapper() {
    // A `content` wrapper is unwrapped when its payload matches the tool's keys.
    let wrapped = r#"{"content":"{\"pattern\":\"foo\",\"path\":\"/x\"}"}"#;
    let value = parse(&recover_tool_args(wrapped, &grep_keys()).unwrap());
    assert_eq!(value["pattern"], "foo");
    assert_eq!(value["path"], "/x");
}
#[test]
fn recover_variant_d_input_wrapper() {
    // An `input` wrapper holding an object is unwrapped.
    let wrapped = r#"{"input":{"file_path":"/tmp/a.rs"}}"#;
    let value = parse(&recover_tool_args(wrapped, &read_keys()).unwrap());
    assert_eq!(value["file_path"], "/tmp/a.rs");
}
#[test]
fn recover_unrecoverable_returns_none() {
    // A wrapper whose payload has none of the expected keys is not recovered.
    let junk = r#"{"arguments":{"random":"junk"}}"#;
    let outcome = recover_tool_args(junk, &cmd_keys());
    assert!(outcome.is_none());
}
#[test]
fn recover_iteration_bound_pathological_input() {
    // 100 nested `arguments` wrappers. Recovery must terminate without hanging
    // or overflowing the stack, which returning from the call demonstrates.
    // Whether it gives up (None) or returns a partially unwrapped payload is
    // left to the implementation's iteration bound. Whatever it does return
    // must still be well-formed JSON. (The previous `is_none() || is_some()`
    // assertion was a tautology and checked nothing.)
    let mut deep = String::from(r#"{"command":"ls"}"#);
    for _ in 0..100 {
        deep = format!(r#"{{"arguments":{}}}"#, deep);
    }
    if let Some(recovered) = recover_tool_args(&deep, &cmd_keys()) {
        assert!(
            serde_json::from_str::<serde_json::Value>(&recovered).is_ok(),
            "recovered payload must be valid JSON: {recovered}"
        );
    }
}
#[test]
fn recover_no_expected_keys_falls_back_permissive() {
    // With no expected keys, a known wrapper is still unwrapped permissively...
    let recovered = recover_tool_args(r#"{"arguments":{"x":1}}"#, &[]).unwrap();
    assert_eq!(parse(&recovered)["x"], 1);
    // ...while an already-flat object is left alone.
    assert!(recover_tool_args(r#"{"x":1}"#, &[]).is_none());
}
#[test]
fn recover_real_datalog_payload() {
    // Regression fixture captured from a real model response (string-encoded wrapper).
    let raw = r#"{"arguments": "{\"command\": \"cd /Users/lichao/project/gitcode/ai/atomcode && cargo check 2>&1 | grep -iE 'warning.*(dead_code|unused)' | head -20\"}"}"#;
    let value = parse(&recover_tool_args(raw, &cmd_keys()).unwrap());
    let command = value["command"].as_str().unwrap();
    assert!(command.contains("cargo check"));
}
#[test]
fn recover_real_bruno_object_payload() {
    // Regression fixture captured from a real model response (object-valued wrapper).
    let raw = r#"{"arguments": {"command": "grep -rn '#\\[allow(dead_code)\\]' /Users/lichao/project/gitcode/ai/atomcode/crates/ --include='*.rs' | head -50", "timeout": 10}}"#;
    let value = parse(&recover_tool_args(raw, &cmd_keys()).unwrap());
    assert_eq!(value["timeout"], 10);
    let command = value["command"].as_str().unwrap();
    assert!(command.contains("dead_code"));
}
#[test]
fn recover_malformed_json_returns_none() {
    // Non-JSON text, an empty string and a non-object value are all unrecoverable.
    let keys = cmd_keys();
    for bad in ["not json", "", "[]"] {
        assert!(recover_tool_args(bad, &keys).is_none());
    }
}
#[test]
fn recover_write_with_json_object_content_passthrough() {
    // A JSON-looking `content` value is legitimate file content for write, not a wrapper.
    let payload = r#"{"file_path":"/tmp/x.json","content":"{\"foo\":1}"}"#;
    let outcome = recover_tool_args(payload, &write_keys());
    assert!(outcome.is_none());
}
#[test]
fn recover_write_with_nested_json_content_passthrough() {
    // Deeply nested JSON in `content` is still passed through untouched.
    let payload = r#"{"file_path":"/tmp/cfg.json","content":"{\"a\":{\"b\":{\"c\":1}}}"}"#;
    let outcome = recover_tool_args(payload, &write_keys());
    assert!(outcome.is_none());
}
#[test]
fn recover_todo_with_json_content_passthrough() {
    // For the todo tool, `content` is an expected key, so JSON inside it is not unwrapped.
    let payload = r#"{"action":"add","content":"{\"task\":\"refactor\"}"}"#;
    let outcome = recover_tool_args(payload, &todo_keys());
    assert!(outcome.is_none());
}
#[test]
fn recover_write_genuine_wrap_still_recovered() {
    // A real `arguments` wrapper around write args is still unwrapped.
    let wrapped = r#"{"arguments":{"file_path":"/tmp/x","content":"hello"}}"#;
    let value = parse(&recover_tool_args(wrapped, &write_keys()).unwrap());
    assert_eq!(value["file_path"], "/tmp/x");
    assert_eq!(value["content"], "hello");
}
#[test]
fn recover_partial_keys_still_recoverable_via_wrapper() {
    // An unexpected sibling key does not prevent unwrapping the wrapper.
    let wrapped = r#"{"arguments":"{\"command\":\"ls\"}","foo":1}"#;
    let value = parse(&recover_tool_args(wrapped, &cmd_keys()).unwrap());
    assert_eq!(value["command"], "ls");
}
#[test]
fn test_real_home_dir_returns_something() {
    // In a normal test environment a home directory resolves, and it is absolute.
    let path = match real_home_dir() {
        Some(p) => p,
        None => panic!("real_home_dir should return Some in normal conditions"),
    };
    assert!(path.is_absolute(), "Home directory should be an absolute path");
}
#[test]
fn test_real_home_dir_with_simulated_sudo() {
    // Sanity check that the platform home directory, when resolvable, is absolute.
    //
    // This test deliberately does NOT mutate SUDO_USER / HOME. An earlier version
    // "restored" those variables with `std::env::set_var` / `remove_var` even though
    // it never changed them. Mutating the process environment while the
    // multi-threaded test harness runs other tests that read HOME (for example
    // real_home_dir / expand_user_path) is a data race. `set_var` is `unsafe` as
    // of the 2024 edition for this reason.
    #[cfg(unix)]
    {
        if let Some(home) = dirs::home_dir() {
            assert!(home.is_absolute());
        }
    }
}
#[test]
fn test_expand_user_path_with_tilde() {
    let home = real_home_dir().unwrap();
    // `~/<rest>` resolves under the home directory; a bare `~` is the home itself.
    assert_eq!(expand_user_path("~/test"), home.join("test"));
    assert_eq!(expand_user_path("~"), home);
    // Paths without a leading tilde come back unchanged.
    assert_eq!(expand_user_path("/absolute/path"), PathBuf::from("/absolute/path"));
}
#[test]
#[cfg(target_os = "windows")]
fn approval_for_windows_system_protected_prefix_requires_always() {
    // Paths under Windows system-owned roots must be flagged as sensitive.
    let protected = [
        r"C:\Windows\System32\config.sys",
        r"C:\Program Files\SomeApp\app.exe",
        r"C:\ProgramData\secrets.txt",
    ];
    for candidate in protected {
        assert!(is_sensitive_path(Path::new(candidate)));
    }
}
#[test]
#[cfg(not(target_os = "windows"))]
fn approval_for_unix_system_protected_prefix_requires_always() {
    // Paths under Unix/macOS system-owned roots must be flagged as sensitive.
    let protected = [
        "/System/Library/CoreServices/boot.efi",
        "/etc/passwd",
        "/var/log/syslog",
    ];
    for candidate in protected {
        assert!(is_sensitive_path(Path::new(candidate)));
    }
}
}