use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::process::Stdio;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::{timeout, Duration};
use crate::hook::{
Hook, HookCtx, HookResult, PreToolExecutionHook, PostToolExecutionHook,
PostTurnHook, SystemPromptHook, ToolResultContext,
};
/// Configuration for a single external-script hook.
///
/// Only `name`, `trigger` and `script` are required when deserializing; the other
/// fields fall back to the `default_*` functions (or `Default`) via `#[serde(default)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptHookConfig {
    /// Hook name; this is what `Hook::name` returns.
    pub name: String,
    /// When the hook should fire (the tests use values such as `"pre_tool"` and
    /// `"post_tool"`).
    /// NOTE(review): not read anywhere in this file — presumably consumed by whatever
    /// registers the hook; confirm against the caller.
    pub trigger: String,
    /// Path of the script to execute; running the hook fails if it does not exist.
    pub script: PathBuf,
    /// Interpreter selector: `"python"`, `"shell"` or `"bash"`; any other value makes
    /// the hook report an "Unsupported script type" error when it runs.
    /// Defaults to `"shell"`.
    #[serde(default = "default_script_type")]
    pub script_type: String,
    /// Whether the hook is active. Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Timeout, in seconds, applied while waiting on the script; exceeding it yields a
    /// "Script execution timed out" error. Defaults to 2.
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
    /// Human-readable description; empty when omitted.
    #[serde(default)]
    pub description: String,
}
/// Serde default for `ScriptHookConfig::script_type`: run scripts through the shell.
fn default_script_type() -> String {
    String::from("shell")
}
/// Serde default for `ScriptHookConfig::enabled`: hooks are active unless the
/// configuration explicitly turns them off.
fn default_true() -> bool {
    true
}
/// Serde default for `ScriptHookConfig::timeout_secs`, in seconds.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 2;
    DEFAULT_TIMEOUT_SECS
}
/// A hook whose behaviour is delegated to an external script.
///
/// The same type implements the pre-tool, post-tool, post-turn and system-prompt hook
/// traits: each serializes its context to JSON, feeds it to the script on stdin, and
/// interprets what the script prints on stdout.
pub struct ScriptHook {
    /// Settings the hook was constructed with.
    config: ScriptHookConfig,
}
impl ScriptHook {
pub fn new(config: ScriptHookConfig) -> Self {
Self { config }
}
async fn run_script(&self, input_json: &str) -> Result<String, String> {
let script_path = &self.config.script;
if !script_path.exists() {
return Err(format!("Script not found: {}", script_path.display()));
}
let (cmd, args) = match self.config.script_type.as_str() {
"python" => ("python", vec![script_path.to_string_lossy().to_string()]),
"shell" | "bash" => {
if cfg!(windows) {
("cmd", vec!["/C".to_string(), script_path.to_string_lossy().to_string()])
} else {
("sh", vec![script_path.to_string_lossy().to_string()])
}
}
_ => return Err(format!("Unsupported script type: {}", self.config.script_type)),
};
let mut child = tokio::process::Command::new(cmd)
.args(&args)
.kill_on_drop(true)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| format!("Failed to spawn script: {}", e))?;
if let Some(mut stdin) = child.stdin.take() {
stdin
.write_all(input_json.as_bytes())
.await
.map_err(|e| format!("Failed to write to script: {}", e))?;
}
let result = timeout(
Duration::from_secs(self.config.timeout_secs),
Self::wait_for_output(&mut child),
)
.await
.map_err(|_| "Script execution timed out".to_string())?;
result
}
async fn wait_for_output(child: &mut tokio::process::Child) -> Result<String, String> {
let mut stdout = String::new();
let mut stderr = String::new();
if let Some(ref mut out) = child.stdout {
out.read_to_string(&mut stdout)
.await
.map_err(|e| format!("Failed to read stdout: {}", e))?;
}
if let Some(ref mut err) = child.stderr {
err.read_to_string(&mut stderr)
.await
.map_err(|e| format!("Failed to read stderr: {}", e))?;
}
let status = child
.wait()
.await
.map_err(|e| format!("Script failed: {}", e))?;
if !status.success() {
return Err(format!("Script exited with status {}: {}", status, stderr));
}
Ok(stdout.trim().to_string())
}
fn parse_output(&self, output: &str) -> HookResult {
let output = output.trim();
if output.is_empty() {
return HookResult::Ok;
}
if let Ok(json) = serde_json::from_str::<serde_json::Value>(output) {
if let Some(result) = json.get("result").and_then(|v| v.as_str()) {
let message = json
.get("message")
.or_else(|| json.get("reason"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
return match result {
"ok" => HookResult::Ok,
"warning" => HookResult::Warning(message),
"deny" => HookResult::Denied(message),
"modify" => HookResult::Modified(message),
_ => HookResult::Warning(format!("Unknown result: {}", result)),
};
}
}
if output.starts_with("warning:") || output.starts_with("WARN:") {
return HookResult::Warning(output.splitn(2, ':').nth(1).unwrap_or("").trim().to_string());
}
if output.starts_with("deny:") || output.starts_with("DENY:") {
return HookResult::Denied(output.splitn(2, ':').nth(1).unwrap_or("").trim().to_string());
}
if output.starts_with("modify:") || output.starts_with("MODIFY:") {
return HookResult::Modified(output.splitn(2, ':').nth(1).unwrap_or("").trim().to_string());
}
HookResult::Ok
}
}
impl Hook for ScriptHook {
    /// Returns the hook name taken from the configuration.
    fn name(&self) -> &str {
        self.config.name.as_str()
    }

    /// Returns the configured description (empty when none was given).
    fn description(&self) -> &str {
        self.config.description.as_str()
    }

    /// Reports the configuration's `enabled` flag.
    fn is_enabled(&self) -> bool {
        let config = &self.config;
        config.enabled
    }
}
#[async_trait]
impl PreToolExecutionHook for ScriptHook {
async fn on_pre_execute(&self, ctx: &HookCtx) -> HookResult {
let input = serde_json::to_string(ctx).unwrap_or_default();
match self.run_script(&input).await {
Ok(output) => self.parse_output(&output),
Err(e) => HookResult::Warning(format!("Script error: {}", e)),
}
}
}
#[async_trait]
impl PostToolExecutionHook for ScriptHook {
async fn on_post_execute(&self, ctx: &HookCtx, result_ctx: &ToolResultContext) -> HookResult {
let mut combined = serde_json::Map::new();
combined.insert("hook_context".to_string(), serde_json::to_value(ctx).unwrap_or_default());
combined.insert("result_context".to_string(), serde_json::to_value(result_ctx).unwrap_or_default());
let input = serde_json::to_string(&combined).unwrap_or_default();
match self.run_script(&input).await {
Ok(output) => self.parse_output(&output),
Err(e) => HookResult::Warning(format!("Script error: {}", e)),
}
}
}
#[async_trait]
impl PostTurnHook for ScriptHook {
async fn on_post_turn(&self, ctx: &HookCtx, turn_result: &str) -> HookResult {
let mut combined = serde_json::Map::new();
combined.insert("hook_context".to_string(), serde_json::to_value(ctx).unwrap_or_default());
combined.insert("turn_result".to_string(), serde_json::Value::String(turn_result.to_string()));
let input = serde_json::to_string(&combined).unwrap_or_default();
match self.run_script(&input).await {
Ok(output) => self.parse_output(&output),
Err(e) => HookResult::Warning(format!("Script error: {}", e)),
}
}
}
#[async_trait]
impl SystemPromptHook for ScriptHook {
    /// Runs the script with a `HookCtx` built from three empty strings and uses its
    /// stdout as extra system-prompt text.
    ///
    /// Returns `None` when the script fails or prints nothing.
    async fn extend_system_prompt(&self) -> Option<String> {
        let blank_ctx = HookCtx::new(String::new(), String::new(), String::new());
        let payload = serde_json::to_string(&blank_ctx).unwrap_or_default();
        self.run_script(&payload)
            .await
            .ok()
            .filter(|text| !text.is_empty())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hook::Hook;

    /// Config JSON with only the required fields, so every serde default applies.
    const MINIMAL_CONFIG_JSON: &str =
        r#"{"name":"test","trigger":"pre_tool","script":"test.sh"}"#;

    /// Deserializes [`MINIMAL_CONFIG_JSON`].
    fn minimal_config() -> ScriptHookConfig {
        serde_json::from_str(MINIMAL_CONFIG_JSON).unwrap()
    }

    /// Hook used by the `parse_output_*` tests; parsing does not depend on its settings.
    fn parser_hook() -> ScriptHook {
        ScriptHook::new(ScriptHookConfig {
            name: "t".into(),
            trigger: "pre_tool".into(),
            script: PathBuf::from("t.sh"),
            script_type: "shell".into(),
            enabled: true,
            timeout_secs: 2,
            description: String::new(),
        })
    }

    #[test]
    fn config_default_script_type() {
        assert_eq!(minimal_config().script_type, "shell");
    }

    #[test]
    fn config_default_enabled() {
        assert!(minimal_config().enabled);
    }

    #[test]
    fn config_default_timeout() {
        assert_eq!(minimal_config().timeout_secs, 2);
    }

    #[test]
    fn script_hook_new_and_trait_methods() {
        let hook = ScriptHook::new(ScriptHookConfig {
            name: "my-hook".into(),
            trigger: "post_tool".into(),
            script: PathBuf::from("/tmp/dummy.sh"),
            script_type: "shell".into(),
            enabled: true,
            timeout_secs: 5,
            description: "My test hook".into(),
        });
        assert_eq!(hook.name(), "my-hook");
        assert_eq!(hook.description(), "My test hook");
        assert!(hook.is_enabled());
    }

    #[test]
    fn script_hook_disabled() {
        let hook = ScriptHook::new(ScriptHookConfig {
            name: "disabled-hook".into(),
            trigger: "pre_tool".into(),
            script: PathBuf::from("ignored.sh"),
            script_type: "shell".into(),
            enabled: false,
            timeout_secs: 2,
            description: String::new(),
        });
        assert!(!hook.is_enabled());
    }

    #[test]
    fn parse_output_empty() {
        assert!(matches!(parser_hook().parse_output(""), HookResult::Ok));
    }

    #[test]
    fn parse_output_ok() {
        assert!(matches!(parser_hook().parse_output("ok"), HookResult::Ok));
    }

    #[test]
    fn parse_output_warning() {
        let verdict = parser_hook().parse_output("warning: something");
        assert!(matches!(verdict, HookResult::Warning(msg) if msg == "something"));
    }

    #[test]
    fn parse_output_deny() {
        let verdict = parser_hook().parse_output("deny: access denied");
        assert!(matches!(verdict, HookResult::Denied(msg) if msg == "access denied"));
    }

    #[test]
    fn parse_output_modify() {
        let verdict = parser_hook().parse_output("modify: new_args");
        assert!(matches!(verdict, HookResult::Modified(msg) if msg == "new_args"));
    }

    #[test]
    fn parse_output_json_ok() {
        let verdict = parser_hook().parse_output(r#"{"result":"ok"}"#);
        assert!(matches!(verdict, HookResult::Ok));
    }

    #[test]
    fn parse_output_json_warning() {
        let verdict = parser_hook().parse_output(r#"{"result":"warning","message":"be careful"}"#);
        assert!(matches!(verdict, HookResult::Warning(msg) if msg == "be careful"));
    }

    #[test]
    fn parse_output_unrecognized_fallback_to_ok() {
        let verdict = parser_hook().parse_output("some random output");
        assert!(matches!(verdict, HookResult::Ok));
    }

    #[test]
    fn script_hook_impl_hook_trait() {
        fn require_hook<T: Hook>() {}
        require_hook::<ScriptHook>();
    }
}