pub mod config;
pub mod executor;
pub mod json_config;
pub mod async_batcher;
pub mod built_in;
pub mod config_loader;
pub mod engine;
pub mod script_runner;
pub mod webhook;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
/// Lifecycle event that a [`HookConfig`] is attached to.
///
/// Serialized in `snake_case`, e.g. `PreToolUse` <-> `"pre_tool_use"` and
/// `UserPromptSubmit` <-> `"user_prompt_submit"`; this is the spelling expected
/// in the `event` field of JSON/TOML hook configs.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookEvent {
/// Wire name: `"pre_tool_use"`.
PreToolUse,
/// Wire name: `"post_tool_use"`.
PostToolUse,
/// Wire name: `"session_start"`.
SessionStart,
/// Wire name: `"session_end"`.
SessionEnd,
/// Wire name: `"notification"`.
Notification,
/// Wire name: `"user_prompt_submit"`.
UserPromptSubmit,
}
/// Configuration for one external hook: a `command` bound to a [`HookEvent`].
///
/// Deserializable from JSON or TOML. `matcher` and `plugin_root` may be omitted
/// (they become `None`), and `timeout_ms` falls back to `default_timeout_ms()`
/// (10 000) when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookConfig {
/// Event that triggers this hook.
pub event: HookEvent,
/// Optional filter string (e.g. a tool name such as `"bash"` or `"write"`); `None` when omitted.
/// NOTE(review): matching semantics (exact name vs. pattern) are defined by the consumer — confirm.
pub matcher: Option<String>,
/// Command string for this hook.
/// NOTE(review): presumably executed by the hook executor — confirm how (shell vs. direct exec).
pub command: String,
/// Timeout in milliseconds; filled in by `default_timeout_ms()` when missing from the input.
#[serde(default = "default_timeout_ms")]
pub timeout_ms: u64,
/// Optional plugin root directory. Defaults to `None` and is left out of serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub plugin_root: Option<std::path::PathBuf>,
}
/// Serde default for [`HookConfig::timeout_ms`]: ten seconds, expressed in milliseconds.
fn default_timeout_ms() -> u64 {
    const MILLIS_PER_SECOND: u64 = 1_000;
    10 * MILLIS_PER_SECOND
}
/// Data describing a submitted user prompt.
///
/// Passed to [`OnUserPromptSubmitHook::on_user_prompt_submit`].
/// NOTE(review): presumably also serialized for external `UserPromptSubmit` hook commands — confirm in the executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserPromptSubmitPayload {
/// Identifier of the current session.
pub session_id: String,
/// Name of the hook event, as a plain string.
pub hook_event_name: String,
/// The prompt text that was submitted.
pub prompt: String,
/// Working directory, as a string.
pub cwd: String,
}
/// Interpreted outcome of a user-prompt-submit hook.
///
/// The variants mirror [`UserPromptSubmitResult`].
/// NOTE(review): the producer/consumer of this enum is not visible here; the names suggest
/// "continue unchanged", "inject extra text", "block with a reason" and "warn with a message".
#[derive(Debug, Clone, PartialEq)]
pub enum UserPromptHookResult {
Continue,
Inject(String),
Block(String),
Warning(String),
}
/// Deserialization target for the JSON output of a user-prompt-submit hook.
///
/// The struct-level `#[serde(default)]` makes every field optional, so even `{}` parses.
/// NOTE(review): presumably parsed from the hook command's stdout — confirm.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
pub(crate) struct UserPromptSubmitOutput {
/// Optional decision string reported by the hook.
pub decision: Option<String>,
/// Optional human-readable reason accompanying `decision`.
pub reason: Option<String>,
/// Read from the camelCase key `hookSpecificOutput`.
#[serde(rename = "hookSpecificOutput")]
pub hook_specific_output: Option<UserPromptHookSpecific>,
}
/// Contents of the `hookSpecificOutput` object in [`UserPromptSubmitOutput`].
///
/// Every field is optional thanks to `#[serde(default)]`.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
pub(crate) struct UserPromptHookSpecific {
/// Read from the camelCase key `additionalContext`.
#[serde(rename = "additionalContext")]
pub additional_context: Option<String>,
}
/// Decision produced by a pre-tool-use hook, internally tagged by an `action` field.
///
/// Wire forms:
/// - `{"action": "allow"}`
/// - `{"action": "block", "reason": "..."}`
/// - `{"action": "modify", "args": <any JSON value>}`
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "action", rename_all = "snake_case")]
pub enum PreHookResult {
/// Serialized as `{"action": "allow"}`.
Allow,
/// Carries a textual reason; serialized as `{"action": "block", "reason": ...}`.
Block { reason: String },
/// Carries tool arguments as arbitrary JSON; serialized as `{"action": "modify", "args": ...}`.
Modify { args: Value },
}
/// Serializable description of a hook event and, when relevant, the tool involved.
///
/// `event` holds the snake_case event name as a plain string (e.g. `"pre_tool_use"`).
/// Tool-related fields are `Option` and serialize as JSON `null` when absent, so the
/// same shape covers tool events and session events.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookContext {
/// Event name, e.g. `"pre_tool_use"` or `"session_start"`.
pub event: String,
/// Tool name, if the event concerns a tool.
pub tool_name: Option<String>,
/// Tool arguments as arbitrary JSON, if available.
pub tool_args: Option<Value>,
/// Tool output text, if available.
pub tool_result: Option<String>,
/// Whether the tool succeeded, if known.
pub tool_success: Option<bool>,
/// Identifier of the current session.
pub session_id: String,
/// Working directory, as a string.
pub working_dir: String,
}
/// Generic outcome returned by the in-process hook traits in this module.
///
/// NOTE(review): how each variant is acted on is decided by the engine (not shown); the
/// names suggest "no-op", "warn with message", "deny with reason" and "replace with text".
#[derive(Debug, Clone)]
pub enum HookResult {
Ok,
Warning(String),
Denied(String),
Modified(String),
}
/// Context passed to [`OnMessageReceivedHook::on_message_received`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserMessageContext {
/// Message text.
pub content: String,
/// Session identifier, when one is available.
pub session_id: Option<String>,
/// Attached file references, as strings (presumably paths — confirm).
pub attached_files: Vec<String>,
/// Timestamp as a string. NOTE(review): format is not established here — confirm.
pub timestamp: String,
}
/// Context passed to [`OnTurnStartHook::on_turn_start`] and
/// [`OnModelResponseHook::on_model_response`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurnStartContext {
/// Turn counter.
pub turn_number: u32,
/// Session identifier, when one is available.
pub session_id: Option<String>,
/// Working directory, as a string.
pub working_dir: String,
/// Phase label as a free-form string. NOTE(review): allowed values are defined elsewhere — confirm.
pub phase: String,
/// Whether file context is present for this turn.
pub has_file_context: bool,
}
/// Context passed to [`OnToolCallStartHook::on_tool_call_start`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallStartContext {
/// Name of the tool being called.
pub tool_name: String,
/// Tool arguments as a string (presumably serialized JSON — confirm).
pub tool_args: String,
/// Identifier of this tool call.
pub call_id: String,
/// Turn counter in which the call happens.
pub turn_number: u32,
}
/// Summary passed to [`OnTurnCompleteHook::on_turn_complete`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurnCompleteContext {
/// Turn counter.
pub turn_number: u32,
/// Result category as a free-form string. NOTE(review): allowed values defined elsewhere — confirm.
pub result_type: String,
/// Token count reported for the turn.
pub tokens_used: usize,
/// Number of tool calls in the turn.
pub tool_calls: usize,
/// Turn duration in milliseconds.
pub duration_ms: u64,
/// Whether the turn's output was truncated.
pub truncated: bool,
/// Files edited during the turn, as strings (presumably paths — confirm).
pub edited_files: Vec<String>,
}
/// Context passed to [`OnSessionStartHook::on_session_start`] and
/// [`OnSessionEndHook::on_session_end`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionContext {
/// Identifier of the session.
pub session_id: String,
/// Working directory, as a string.
pub working_dir: String,
/// Name of the model in use.
pub model_name: String,
/// Name of the model provider in use.
pub provider_name: String,
}
/// Context passed to [`OnErrorHook::on_error`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorContext {
/// Error category as a free-form string.
pub error_type: String,
/// Error message text.
pub error_message: String,
/// Phase label as a free-form string.
pub phase: String,
/// Turn counter, if the error is tied to a turn (presumably `None` otherwise — confirm).
pub turn_number: Option<u32>,
}
/// Tool-call context handed to [`PreToolExecutionHook`], [`PostToolExecutionHook`]
/// and [`PostTurnHook`].
///
/// Build it with [`HookCtx::new`] plus the `with_*` methods. `metadata` is left out of
/// serialized output when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookCtx {
/// Name of the tool.
pub tool_name: String,
/// Tool arguments as a string (presumably serialized JSON — confirm).
pub tool_args: String,
/// Working directory, as a string.
pub working_dir: String,
/// Session identifier, when one is available.
pub session_id: Option<String>,
/// Turn counter; `0` unless set via `with_turn`.
pub turn_number: u32,
/// Free-form string key/value pairs; omitted when serializing if `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>,
}
impl HookCtx {
pub fn new(tool_name: String, tool_args: String, working_dir: String) -> Self {
Self {
tool_name,
tool_args,
working_dir,
session_id: None,
turn_number: 0,
metadata: None,
}
}
pub fn with_session(mut self, session_id: String) -> Self {
self.session_id = Some(session_id);
self
}
pub fn with_turn(mut self, turn_number: u32) -> Self {
self.turn_number = turn_number;
self
}
pub fn with_metadata(mut self, key: String, value: String) -> Self {
let metadata = self.metadata.get_or_insert_with(HashMap::new);
metadata.insert(key, value);
self
}
}
/// Tool outcome passed to [`PostToolExecutionHook::on_post_execute`] alongside a [`HookCtx`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultContext {
/// Name of the tool.
pub tool_name: String,
/// Tool arguments as a string (presumably serialized JSON — confirm).
pub tool_args: String,
/// Tool output text.
pub result: String,
/// Whether the tool call succeeded.
pub success: bool,
/// Execution duration in milliseconds.
pub duration_ms: u64,
}
/// Base trait shared by every in-process hook trait in this module.
///
/// Requires `Send + Sync`, so implementors can be shared between threads.
#[async_trait]
pub trait Hook: Send + Sync {
/// Identifier of the hook.
fn name(&self) -> &str;
/// Human-readable description; empty by default.
fn description(&self) -> &str {
""
}
/// Whether the hook is active; `true` by default.
fn is_enabled(&self) -> bool {
true
}
/// Ordering priority; `0` by default.
/// NOTE(review): whether higher or lower values run first is decided by the engine — confirm.
fn priority(&self) -> i32 {
0
}
}
/// Hook that receives the tool-call [`HookCtx`] through `on_pre_execute`.
/// NOTE(review): by its name, invoked before a tool runs; the call site is in the engine.
#[async_trait]
pub trait PreToolExecutionHook: Hook {
/// Inspects the pending tool call and returns a [`HookResult`].
async fn on_pre_execute(&self, ctx: &HookCtx) -> HookResult;
}
/// Hook that receives a tool call's [`HookCtx`] and its [`ToolResultContext`].
/// NOTE(review): by its name, invoked after a tool runs; the call site is in the engine.
#[async_trait]
pub trait PostToolExecutionHook: Hook {
/// Inspects the tool call together with its outcome and returns a [`HookResult`].
async fn on_post_execute(&self, ctx: &HookCtx, result_ctx: &ToolResultContext) -> HookResult;
}
/// Hook that receives a [`HookCtx`] and the turn result as text.
#[async_trait]
pub trait PostTurnHook: Hook {
/// Inspects the turn result string and returns a [`HookResult`].
async fn on_post_turn(&self, ctx: &HookCtx, turn_result: &str) -> HookResult;
}
/// Hook that can contribute text to the system prompt.
#[async_trait]
pub trait SystemPromptHook: Hook {
/// Returns extra system-prompt text, or `None` to contribute nothing.
async fn extend_system_prompt(&self) -> Option<String>;
}
/// Hook that receives a [`UserMessageContext`].
#[async_trait]
pub trait OnMessageReceivedHook: Hook {
/// Inspects the incoming message and returns a [`HookResult`].
async fn on_message_received(&self, ctx: &UserMessageContext) -> HookResult;
}
/// Hook that receives a [`TurnStartContext`].
#[async_trait]
pub trait OnTurnStartHook: Hook {
/// Inspects the starting turn and returns a [`HookResult`].
async fn on_turn_start(&self, ctx: &TurnStartContext) -> HookResult;
}
/// Hook that receives a [`ToolCallStartContext`].
#[async_trait]
pub trait OnToolCallStartHook: Hook {
/// Inspects the tool call being started and returns a [`HookResult`].
async fn on_tool_call_start(&self, ctx: &ToolCallStartContext) -> HookResult;
}
/// Hook that receives a [`TurnCompleteContext`].
#[async_trait]
pub trait OnTurnCompleteHook: Hook {
/// Inspects the completed turn's summary and returns a [`HookResult`].
async fn on_turn_complete(&self, ctx: &TurnCompleteContext) -> HookResult;
}
/// Hook that receives a [`SessionContext`] through `on_session_start`.
#[async_trait]
pub trait OnSessionStartHook: Hook {
/// Inspects the session context and returns a [`HookResult`].
async fn on_session_start(&self, ctx: &SessionContext) -> HookResult;
}
/// Hook that receives a [`SessionContext`] through `on_session_end`.
#[async_trait]
pub trait OnSessionEndHook: Hook {
/// Inspects the session context and returns a [`HookResult`].
async fn on_session_end(&self, ctx: &SessionContext) -> HookResult;
}
/// Hook that receives an [`ErrorContext`].
#[async_trait]
pub trait OnErrorHook: Hook {
/// Inspects the error and returns a [`HookResult`].
async fn on_error(&self, ctx: &ErrorContext) -> HookResult;
}
/// Hook that receives model response text together with a [`TurnStartContext`].
#[async_trait]
pub trait OnModelResponseHook: Hook {
/// Inspects the response text for the given turn and returns a [`HookResult`].
async fn on_model_response(&self, response: &str, turn_ctx: &TurnStartContext) -> HookResult;
}
/// Hook that receives a [`UserPromptSubmitPayload`] and answers with a
/// [`UserPromptSubmitResult`], which converts into [`HookResult`] via `From`.
#[async_trait]
pub trait OnUserPromptSubmitHook: Hook {
/// Inspects the submitted prompt and decides how to proceed.
async fn on_user_prompt_submit(
&self,
payload: &UserPromptSubmitPayload,
) -> UserPromptSubmitResult;
}
/// Outcome returned by [`OnUserPromptSubmitHook::on_user_prompt_submit`].
///
/// Converts into the generic [`HookResult`] via `From`. Its variants mirror
/// [`UserPromptHookResult`], and it derives the same traits so results can be logged
/// (`Debug`), duplicated (`Clone`) and compared in tests (`PartialEq`).
#[derive(Debug, Clone, PartialEq)]
pub enum UserPromptSubmitResult {
    /// Maps to [`HookResult::Ok`].
    Continue,
    /// Carries text; maps to [`HookResult::Modified`].
    Inject(String),
    /// Carries a reason; maps to [`HookResult::Denied`].
    Block(String),
    /// Carries a message; maps to [`HookResult::Warning`].
    Warning(String),
}
/// Maps a [`UserPromptSubmitResult`] onto the generic [`HookResult`] vocabulary.
///
/// `Continue` -> `Ok`, `Inject` -> `Modified`, `Block` -> `Denied`, `Warning` -> `Warning`.
/// Every payload string is carried over unchanged.
impl From<UserPromptSubmitResult> for HookResult {
    fn from(r: UserPromptSubmitResult) -> Self {
        match r {
            UserPromptSubmitResult::Continue => HookResult::Ok,
            UserPromptSubmitResult::Inject(s) => HookResult::Modified(s),
            UserPromptSubmitResult::Block(s) => HookResult::Denied(s),
            // Forward the hook's message so callers can surface why the prompt was
            // flagged; substituting an empty string would lose that explanation.
            UserPromptSubmitResult::Warning(s) => HookResult::Warning(s),
        }
    }
}
/// Per-kind hook counts, one `usize` field per hook trait in this module.
///
/// NOTE(review): presumably produced by [`HookEngine`] from its registered hooks, with each
/// field matching the trait its name suggests (e.g. `pre_tool_hooks` <-> [`PreToolExecutionHook`])
/// — confirm in `engine`.
#[derive(Debug)]
pub struct HookStats {
pub pre_tool_hooks: usize,
pub post_tool_hooks: usize,
pub post_turn_hooks: usize,
pub system_prompt_hooks: usize,
pub on_turn_start_hooks: usize,
pub on_tool_call_start_hooks: usize,
pub on_turn_complete_hooks: usize,
pub on_session_start_hooks: usize,
pub on_session_end_hooks: usize,
pub on_error_hooks: usize,
pub on_model_response_hooks: usize,
pub on_user_prompt_submit_hooks: usize,
}
pub use engine::HookEngine;
#[cfg(test)]
mod tests {
    //! Serde round-trip and builder tests for the hook data types in this module.

    use super::*;
    use serde_json::json;

    #[test]
    fn hook_event_serializes_to_snake_case() {
        assert_eq!(
            serde_json::to_string(&HookEvent::PreToolUse).unwrap(),
            r#""pre_tool_use""#
        );
        assert_eq!(
            serde_json::to_string(&HookEvent::PostToolUse).unwrap(),
            r#""post_tool_use""#
        );
        assert_eq!(
            serde_json::to_string(&HookEvent::SessionStart).unwrap(),
            r#""session_start""#
        );
        assert_eq!(
            serde_json::to_string(&HookEvent::SessionEnd).unwrap(),
            r#""session_end""#
        );
        assert_eq!(
            serde_json::to_string(&HookEvent::Notification).unwrap(),
            r#""notification""#
        );
    }

    #[test]
    fn hook_event_deserializes_from_snake_case() {
        let event: HookEvent = serde_json::from_str(r#""pre_tool_use""#).unwrap();
        assert_eq!(event, HookEvent::PreToolUse);
        let event: HookEvent = serde_json::from_str(r#""post_tool_use""#).unwrap();
        assert_eq!(event, HookEvent::PostToolUse);
    }

    /// Every variant (including `UserPromptSubmit`, which the tests above skip) must
    /// round-trip through its snake_case wire name.
    #[test]
    fn hook_event_roundtrips_every_variant() {
        let cases = [
            (HookEvent::PreToolUse, "pre_tool_use"),
            (HookEvent::PostToolUse, "post_tool_use"),
            (HookEvent::SessionStart, "session_start"),
            (HookEvent::SessionEnd, "session_end"),
            (HookEvent::Notification, "notification"),
            (HookEvent::UserPromptSubmit, "user_prompt_submit"),
        ];
        for (event, name) in cases.iter() {
            let encoded = serde_json::to_string(event).unwrap();
            assert_eq!(encoded, format!("\"{}\"", name));
            let decoded: HookEvent = serde_json::from_str(&encoded).unwrap();
            assert_eq!(&decoded, event);
        }
    }

    #[test]
    fn hook_config_roundtrip_json() {
        let cfg = HookConfig {
            event: HookEvent::PreToolUse,
            matcher: Some("bash".into()),
            command: "echo ok".into(),
            timeout_ms: 5000,
            plugin_root: None,
        };
        let json = serde_json::to_string(&cfg).unwrap();
        let back: HookConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(back.event, HookEvent::PreToolUse);
        assert_eq!(back.matcher.as_deref(), Some("bash"));
        assert_eq!(back.command, "echo ok");
        assert_eq!(back.timeout_ms, 5000);
    }

    #[test]
    fn hook_config_timeout_defaults_to_10000() {
        let json = r#"{
            "event": "session_start",
            "command": "notify-send hello"
        }"#;
        let cfg: HookConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.timeout_ms, 10_000);
        assert!(cfg.matcher.is_none());
    }

    #[test]
    fn hook_config_roundtrip_toml() {
        let toml_str = r#"
            event = "pre_tool_use"
            matcher = "write"
            command = "check-write.sh"
            timeout_ms = 3000
        "#;
        let cfg: HookConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(cfg.event, HookEvent::PreToolUse);
        assert_eq!(cfg.matcher.as_deref(), Some("write"));
        assert_eq!(cfg.timeout_ms, 3000);
    }

    #[test]
    fn pre_hook_result_allow_roundtrip() {
        let r = PreHookResult::Allow;
        let json = serde_json::to_value(&r).unwrap();
        assert_eq!(json, json!({"action": "allow"}));
        let back: PreHookResult = serde_json::from_value(json).unwrap();
        assert_eq!(back, PreHookResult::Allow);
    }

    #[test]
    fn pre_hook_result_block_roundtrip() {
        let r = PreHookResult::Block {
            reason: "unsafe".into(),
        };
        let json = serde_json::to_value(&r).unwrap();
        assert_eq!(json, json!({"action": "block", "reason": "unsafe"}));
        let back: PreHookResult = serde_json::from_value(json).unwrap();
        assert_eq!(back, r);
    }

    #[test]
    fn pre_hook_result_modify_roundtrip() {
        let new_args = json!({"path": "/safe/dir", "content": "ok"});
        let r = PreHookResult::Modify {
            args: new_args.clone(),
        };
        let json = serde_json::to_value(&r).unwrap();
        assert_eq!(
            json,
            json!({"action": "modify", "args": {"path": "/safe/dir", "content": "ok"}})
        );
        let back: PreHookResult = serde_json::from_value(json).unwrap();
        assert_eq!(back, r);
    }

    #[test]
    fn hook_context_full_roundtrip() {
        let ctx = HookContext {
            event: "pre_tool_use".into(),
            tool_name: Some("bash".into()),
            tool_args: Some(json!({"command": "ls"})),
            tool_result: None,
            tool_success: None,
            session_id: "abc-123".into(),
            working_dir: "/home/user/project".into(),
        };
        let json = serde_json::to_string(&ctx).unwrap();
        let back: HookContext = serde_json::from_str(&json).unwrap();
        assert_eq!(back.event, "pre_tool_use");
        assert_eq!(back.tool_name.as_deref(), Some("bash"));
        assert!(back.tool_result.is_none());
        assert!(back.tool_success.is_none());
        assert_eq!(back.session_id, "abc-123");
    }

    #[test]
    fn hook_context_post_tool_use() {
        let ctx = HookContext {
            event: "post_tool_use".into(),
            tool_name: Some("write".into()),
            tool_args: None,
            tool_result: Some("file written".into()),
            tool_success: Some(true),
            session_id: "xyz-789".into(),
            working_dir: "/tmp".into(),
        };
        let v = serde_json::to_value(&ctx).unwrap();
        assert_eq!(v["tool_success"], json!(true));
        assert_eq!(v["tool_result"], json!("file written"));
    }

    #[test]
    fn hook_context_minimal_session_event() {
        let json_str = r#"{
            "event": "session_start",
            "tool_name": null,
            "tool_args": null,
            "tool_result": null,
            "tool_success": null,
            "session_id": "s1",
            "working_dir": "/home"
        }"#;
        let ctx: HookContext = serde_json::from_str(json_str).unwrap();
        assert_eq!(ctx.event, "session_start");
        assert!(ctx.tool_name.is_none());
    }

    /// The camelCase renames (`hookSpecificOutput`, `additionalContext`) must be honored.
    #[test]
    fn user_prompt_submit_output_reads_camel_case_fields() {
        let out: UserPromptSubmitOutput = serde_json::from_str(
            r#"{"decision":"block","reason":"nope","hookSpecificOutput":{"additionalContext":"extra"}}"#,
        )
        .unwrap();
        assert_eq!(out.decision.as_deref(), Some("block"));
        assert_eq!(out.reason.as_deref(), Some("nope"));
        let extra = out.hook_specific_output.and_then(|h| h.additional_context);
        assert_eq!(extra.as_deref(), Some("extra"));
    }

    /// Struct-level `#[serde(default)]` lets an empty object parse.
    #[test]
    fn user_prompt_submit_output_tolerates_empty_object() {
        let out: UserPromptSubmitOutput = serde_json::from_str("{}").unwrap();
        assert!(out.decision.is_none());
        assert!(out.reason.is_none());
        assert!(out.hook_specific_output.is_none());
    }

    #[test]
    fn hook_ctx_builder_sets_fields() {
        let ctx = HookCtx::new("bash".into(), "{}".into(), "/tmp".into())
            .with_session("s1".into())
            .with_turn(3)
            .with_metadata("k".into(), "v".into())
            .with_metadata("k".into(), "v2".into());
        assert_eq!(ctx.tool_name, "bash");
        assert_eq!(ctx.session_id.as_deref(), Some("s1"));
        assert_eq!(ctx.turn_number, 3);
        let value = ctx
            .metadata
            .as_ref()
            .and_then(|m| m.get("k"))
            .map(String::as_str);
        assert_eq!(value, Some("v2"));
    }

    /// `metadata` carries `skip_serializing_if = "Option::is_none"`.
    #[test]
    fn hook_ctx_omits_absent_metadata_when_serialized() {
        let ctx = HookCtx::new("bash".into(), "{}".into(), "/tmp".into());
        let v = serde_json::to_value(&ctx).unwrap();
        assert!(v.get("metadata").is_none());
        assert_eq!(v["turn_number"], json!(0));
        assert_eq!(v["session_id"], json!(null));
    }
}