use anyhow::Result;
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::json;
use super::{ApprovalRequirement, Tool, ToolContext, ToolDef, ToolResult};
/// Tool that writes a whole file, creating it or overwriting an existing one.
///
/// Argument validation, approval rules and the actual write live in the
/// `Tool` implementation for this type.
pub struct WriteFileTool;
/// Arguments of a `write_file` tool call, deserialized from its JSON payload.
#[derive(Deserialize)]
struct WriteFileArgs {
    /// Target path as supplied by the caller; it is resolved against the
    /// working directory (via `super::inspect_path_access`) before writing.
    file_path: String,
    /// Complete new file contents, written as-is over any existing file.
    content: String,
}
#[async_trait]
impl Tool for WriteFileTool {
fn definition(&self) -> ToolDef {
ToolDef {
name: "write_file",
description:
"Write content to a file. Creates new files or overwrites existing ones.\n\
Use this for: creating new files, or rewriting an entire file from scratch.\n\
For small edits to existing files, prefer edit_file instead.\n\
Parent directories are auto-created if they don't exist."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"file_path": { "type": "string", "description": "Absolute path to the file" },
"content": { "type": "string", "description": "The full content to write" }
},
"required": ["file_path", "content"]
}),
}
}
fn validate_args(&self, args: &str) -> std::result::Result<(), String> {
super::diagnose_args(
"write_file",
args,
&[&["file_path", "content"]],
"write_file({\"file_path\": \"<absolute path>\", \"content\": \"<file body>\"})",
)?;
serde_json::from_str::<WriteFileArgs>(args)
.map(|_| ())
.map_err(|e| {
format!(
"write_file: {e}. Re-issue with file_path as a string and content as a string."
)
})
}
fn approval(&self, args: &str) -> ApprovalRequirement {
let parsed = match serde_json::from_str::<WriteFileArgs>(args) {
Ok(p) => p,
Err(_) => {
return ApprovalRequirement::RequireApproval(
"Could not parse create_file arguments for safety check.".to_string(),
);
}
};
if super::is_sensitive_input_path(&parsed.file_path) {
return ApprovalRequirement::RequireApproval(
format!("Writing to sensitive system path: {}", parsed.file_path),
);
}
ApprovalRequirement::AutoApprove
}
fn approval_with_context(&self, args: &str, ctx: &ToolContext) -> ApprovalRequirement {
let base = self.approval(args);
let parsed = match serde_json::from_str::<WriteFileArgs>(args) {
Ok(parsed) => parsed,
Err(_) => return base,
};
let working_dir = match ctx.working_dir.try_read() {
Ok(wd) => wd.clone(),
Err(_) => return base,
};
match super::approval_for_path(
&parsed.file_path,
&working_dir,
super::ExternalPathAction::Write,
) {
Ok(ApprovalRequirement::RequireApprovalAlways(reason)) => {
ApprovalRequirement::RequireApprovalAlways(reason)
}
Ok(ApprovalRequirement::RequireApproval(reason)) => {
ApprovalRequirement::RequireApproval(reason)
}
Ok(ApprovalRequirement::AutoApprove) => match base {
ApprovalRequirement::RequireApproval(reason) => {
ApprovalRequirement::RequireApprovalAlways(reason)
}
other => other,
},
Err(_) => base,
}
}
async fn execute(&self, args: &str, ctx: &ToolContext) -> Result<ToolResult> {
if let Err(msg) = super::diagnose_args(
"write_file",
args,
&[&["file_path", "content"]],
"write_file({\"file_path\": \"<absolute path>\", \"content\": \"<file body>\"})",
) {
return Ok(ToolResult {
call_id: String::new(),
output: msg,
success: false,
});
}
let parsed: WriteFileArgs = match serde_json::from_str(args) {
Ok(p) => p,
Err(e) => {
return Ok(ToolResult {
call_id: String::new(),
output: format!(
"write_file: {e}. Re-issue with file_path as a string and content as a string."
),
success: false,
});
}
};
let working_dir = ctx.working_dir.read().await.clone();
let path = match super::inspect_path_access(&parsed.file_path, &working_dir) {
Ok(access) => access.path,
Err(err) => {
return Ok(ToolResult {
call_id: String::new(),
output: err.to_string(),
success: false,
});
}
};
ctx.file_history
.lock()
.await
.backup_before_write(&path.to_string_lossy())
.await;
let overwrite_info = if path.exists() {
let old_lines = tokio::fs::read_to_string(&path)
.await
.map(|c| c.lines().count())
.unwrap_or(0);
Some(old_lines)
} else {
None
};
if let Some(parent) = path.parent() {
tokio::fs::create_dir_all(parent).await?;
}
let new_lines = parsed.content.lines().count();
let bytes = parsed.content.len();
tokio::fs::write(&path, &parsed.content).await?;
ctx.file_store.write().await.invalidate(&path);
ctx.read_cache
.write()
.await
.retain(|(p, _, _), _| p != &path);
ctx.notify_lsp_file_changed(&path, &parsed.content).await;
let output = if let Some(old_lines) = overwrite_info {
let diff = new_lines as i64 - old_lines as i64;
let sign = if diff >= 0 { "+" } else { "" };
let mut msg = format!(
"Overwrote {} (was {} lines, now {} lines, {}{})",
path.display(),
old_lines,
new_lines,
sign,
diff
);
if old_lines > 20 && new_lines < old_lines / 2 {
msg.push_str(&format!(
"\n⚠ WARNING: File shrank by {}%. Verify no important code was lost. Use /undo to revert if needed.",
100 - (new_lines * 100 / old_lines)
));
}
msg
} else {
format!(
"Created new file {} ({} bytes, {} lines)",
path.display(),
bytes,
new_lines
)
};
Ok(ToolResult {
call_id: String::new(),
output,
success: true,
})
}
}