use async_trait::async_trait;
use chrono::Local;
use crate::hook::{
Hook, HookResult,
OnToolCallStartHook, OnTurnStartHook, OnTurnCompleteHook,
OnSessionStartHook, OnSessionEndHook, OnErrorHook, OnModelResponseHook,
ToolCallStartContext, TurnStartContext, TurnCompleteContext,
SessionContext, ErrorContext,
};
/// Hook that writes one audit line per tool call.
///
/// Lines are appended to `log_file` when it is set, or emitted through `tracing`
/// at INFO level when it is `None`.
#[derive(Debug)]
pub struct ToolAuditLogHook {
    /// Whether the hook is active (reported by `Hook::is_enabled`).
    pub enabled: bool,
    /// Destination file for audit lines; `None` routes them to `tracing` instead.
    pub log_file: Option<std::path::PathBuf>,
}

impl Hook for ToolAuditLogHook {
    fn name(&self) -> &str {
        "tool-audit-log"
    }

    /// Human-readable description ("record all tool calls to the audit log").
    fn description(&self) -> &str {
        "记录所有工具调用到审计日志"
    }

    fn is_enabled(&self) -> bool {
        self.enabled
    }
}
#[async_trait]
impl OnToolCallStartHook for ToolAuditLogHook {
    /// Records one audit line for the tool call that is about to start.
    ///
    /// The line has the form `[timestamp] TURN #n | Tool: name | CallID: id`.
    /// With `log_file` set, it is appended to that file, which is created if missing.
    /// Otherwise it is emitted through `tracing` at INFO level.
    ///
    /// # Returns
    /// `HookResult::Ok` when the entry was recorded at its configured destination.
    /// `HookResult::Warning` when the log file could not be opened or written. In that
    /// case the entry is still emitted through `tracing` so the audit record is not lost.
    async fn on_tool_call_start(&self, ctx: &ToolCallStartContext) -> HookResult {
        use std::io::Write;

        let timestamp = Local::now().format("%Y-%m-%d %H:%M:%S");
        let log_entry = format!(
            "[{}] TURN #{} | Tool: {} | CallID: {}\n",
            timestamp, ctx.turn_number, ctx.tool_name, ctx.call_id
        );

        let log_file = match &self.log_file {
            Some(path) => path,
            None => {
                tracing::info!("{}", log_entry.trim_end());
                return HookResult::Ok;
            }
        };

        // NOTE(review): this is blocking std::fs I/O inside an async fn. It is a single
        // short append, but consider moving it off the executor if it shows up in profiles.
        let written = std::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(log_file)
            .and_then(|mut file| file.write_all(log_entry.as_bytes()));

        match written {
            Ok(()) => HookResult::Ok,
            Err(e) => {
                // Keep the audit record somewhere even though the file write failed.
                tracing::info!("{}", log_entry.trim_end());
                HookResult::Warning(format!(
                    "Failed to write tool audit log to {}: {}",
                    log_file.display(),
                    e
                ))
            }
        }
    }
}
/// Hook that logs per-turn statistics through `tracing` when a turn completes.
#[derive(Debug)]
pub struct TurnStatsHook {
    /// Whether the hook is active (reported by `Hook::is_enabled`).
    pub enabled: bool,
}

impl Hook for TurnStatsHook {
    fn name(&self) -> &str {
        "turn-stats"
    }

    /// Human-readable description ("collect turn-level statistics").
    fn description(&self) -> &str {
        "收集 Turn 级别的统计信息"
    }

    fn is_enabled(&self) -> bool {
        self.enabled
    }
}
#[async_trait]
impl OnTurnStartHook for TurnStatsHook {
async fn on_turn_start(&self, _ctx: &TurnStartContext) -> HookResult {
HookResult::Ok
}
}
#[async_trait]
impl OnTurnCompleteHook for TurnStatsHook {
async fn on_turn_complete(&self, ctx: &TurnCompleteContext) -> HookResult {
tracing::info!(
"[Turn #{}] Result: {} | Tokens: {} | Tools: {} | Duration: {}ms | Files: {:?}",
ctx.turn_number,
ctx.result_type,
ctx.tokens_used,
ctx.tool_calls,
ctx.duration_ms,
ctx.edited_files
);
HookResult::Ok
}
}
/// Hook that commits working-tree changes with git on a fixed turn cadence.
#[derive(Debug)]
pub struct AutoCommitHook {
    /// Whether the hook is active (reported by `Hook::is_enabled`).
    pub enabled: bool,
    /// Commit cadence: a commit is attempted on turns whose number is a multiple of
    /// this value, and only when that turn edited files.
    pub interval: u32,
}

impl Hook for AutoCommitHook {
    fn name(&self) -> &str {
        "auto-commit"
    }

    /// Human-readable description ("periodically auto-commit code changes").
    fn description(&self) -> &str {
        "定期自动提交代码变更"
    }

    fn is_enabled(&self) -> bool {
        self.enabled
    }
}
#[async_trait]
impl OnTurnCompleteHook for AutoCommitHook {
    /// Commits pending changes on turns that are a multiple of `interval`.
    ///
    /// A commit is attempted only when the turn edited at least one file. An `interval`
    /// of 0 never triggers a commit; with the old plain `%` it panicked with a
    /// remainder-by-zero instead.
    ///
    /// # Returns
    /// `HookResult::Ok` when the turn is skipped, otherwise whatever the git commit
    /// step reports.
    async fn on_turn_complete(&self, ctx: &TurnCompleteContext) -> HookResult {
        // `checked_rem` yields `None` for a zero interval, which is treated as "never".
        if ctx.turn_number.checked_rem(self.interval) != Some(0) {
            return HookResult::Ok;
        }
        if ctx.edited_files.is_empty() {
            return HookResult::Ok;
        }
        self.run_git_commit(ctx).await
    }
}
impl AutoCommitHook {
    /// Stages every change in the current directory's git repository and commits it.
    ///
    /// Runs `git status --porcelain`, `git add -A`, `git diff --cached --name-only`
    /// and `git commit -m ...` in the process's current working directory. The commit
    /// message records the turn number and the staged file count. If the staged count
    /// cannot be read or is zero, the count falls back to `ctx.edited_files.len()`.
    ///
    /// # Returns
    /// `HookResult::Ok` when there is nothing to commit or the commit succeeds.
    /// `HookResult::Warning` when a git command cannot be spawned or exits unsuccessfully.
    ///
    /// NOTE(review): these are blocking `std::process::Command` calls inside an async fn.
    /// They stall the executor thread while git runs.
    async fn run_git_commit(&self, ctx: &TurnCompleteContext) -> HookResult {
        let status = match Self::git_output(&["status", "--porcelain"]) {
            Ok(o) => o,
            Err(e) => return HookResult::Warning(format!("Failed to run git status: {}", e)),
        };
        if !status.status.success() {
            return HookResult::Warning(format!(
                "Git status failed: {}",
                Self::failure_text(&status)
            ));
        }
        if status.stdout.is_empty() {
            return HookResult::Ok;
        }

        // A failed `git add` must not be ignored: the commit below would then record a
        // partial or empty change set.
        match Self::git_output(&["add", "-A"]) {
            Ok(o) if o.status.success() => {}
            Ok(o) => {
                return HookResult::Warning(format!(
                    "Git add failed: {}",
                    Self::failure_text(&o)
                ))
            }
            Err(e) => return HookResult::Warning(format!("Failed to run git add: {}", e)),
        }

        let file_count = Self::git_output(&["diff", "--cached", "--name-only"])
            .ok()
            .map(|o| {
                String::from_utf8_lossy(&o.stdout)
                    .lines()
                    .filter(|l| !l.is_empty())
                    .count()
            })
            .filter(|&count| count > 0)
            .unwrap_or(ctx.edited_files.len());

        let commit_msg = format!(
            "Auto-commit at turn #{} ({} files changed)",
            ctx.turn_number, file_count
        );
        match Self::git_output(&["commit", "-m", commit_msg.as_str()]) {
            Ok(output) if output.status.success() => {
                tracing::info!("[AutoCommit] Committed at turn #{}", ctx.turn_number);
                HookResult::Ok
            }
            Ok(output) => HookResult::Warning(format!(
                "Git commit failed: {}",
                Self::failure_text(&output)
            )),
            Err(e) => HookResult::Warning(format!("Failed to run git commit: {}", e)),
        }
    }

    /// Runs `git` with `args` in the process's current directory, capturing its output.
    fn git_output(args: &[&str]) -> std::io::Result<std::process::Output> {
        std::process::Command::new("git").args(args).output()
    }

    /// Returns a trimmed, lossily-decoded explanation of a failed command.
    ///
    /// Uses stderr, or stdout when stderr is empty (git prints some failures, such as
    /// "nothing to commit", to stdout).
    fn failure_text(output: &std::process::Output) -> String {
        let stderr = String::from_utf8_lossy(&output.stderr);
        let stderr = stderr.trim();
        if stderr.is_empty() {
            String::from_utf8_lossy(&output.stdout).trim().to_string()
        } else {
            stderr.to_string()
        }
    }
}
/// Hook that logs a summary report through `tracing` when a session ends.
#[derive(Debug)]
pub struct SessionSummaryHook {
    /// Whether the hook is active (reported by `Hook::is_enabled`).
    pub enabled: bool,
}

impl SessionSummaryHook {
    /// Creates an enabled hook.
    pub fn new() -> Self {
        Self { enabled: true }
    }
}

impl Default for SessionSummaryHook {
    /// Same as [`SessionSummaryHook::new`]: an enabled hook.
    fn default() -> Self {
        Self::new()
    }
}

impl Hook for SessionSummaryHook {
    fn name(&self) -> &str {
        "session-summary"
    }

    /// Human-readable description ("generate a summary report when the session ends").
    fn description(&self) -> &str {
        "在会话结束时生成总结报告"
    }

    fn is_enabled(&self) -> bool {
        self.enabled
    }
}
#[async_trait]
impl OnSessionStartHook for SessionSummaryHook {
async fn on_session_start(&self, _ctx: &SessionContext) -> HookResult {
HookResult::Ok
}
}
#[async_trait]
impl OnSessionEndHook for SessionSummaryHook {
async fn on_session_end(&self, ctx: &SessionContext) -> HookResult {
let sep = "=".repeat(60);
tracing::info!(
"\n{sep}\n[Session Summary]\nSession ID: {}\nWorking Dir: {}\nModel: {} ({})\n{sep}",
ctx.session_id,
ctx.working_dir,
ctx.model_name,
ctx.provider_name,
);
HookResult::Ok
}
}
/// Hook that logs error details through `tracing` at ERROR level.
#[derive(Debug)]
pub struct ErrorReportHook {
    /// Whether the hook is active (reported by `Hook::is_enabled`).
    pub enabled: bool,
}

impl Hook for ErrorReportHook {
    fn name(&self) -> &str {
        "error-report"
    }

    /// Human-readable description ("record detailed error information to the log").
    fn description(&self) -> &str {
        "记录错误详细信息到日志"
    }

    fn is_enabled(&self) -> bool {
        self.enabled
    }
}
#[async_trait]
impl OnErrorHook for ErrorReportHook {
async fn on_error(&self, ctx: &ErrorContext) -> HookResult {
let timestamp = Local::now().format("%Y-%m-%d %H:%M:%S");
tracing::error!(
"[ERROR REPORT] {}\nType: {}\nPhase: {}\nTurn: {:?}\nMessage: {}",
timestamp, ctx.error_type, ctx.phase, ctx.turn_number, ctx.error_message
);
HookResult::Ok
}
}
/// Hook that warns when a model response contains any configured sensitive pattern.
///
/// Patterns are compared case-insensitively.
#[derive(Debug)]
pub struct ResponseValidationHook {
    /// Whether the hook is active (reported by `Hook::is_enabled`).
    pub enabled: bool,
    /// Substrings to look for in model responses; stored as given.
    sensitive_patterns: Vec<String>,
}

impl ResponseValidationHook {
    /// Creates an enabled hook that checks responses against `patterns`.
    pub fn new(patterns: Vec<String>) -> Self {
        Self {
            enabled: true,
            sensitive_patterns: patterns,
        }
    }
}

impl Hook for ResponseValidationHook {
    fn name(&self) -> &str {
        "response-validation"
    }

    /// Human-readable description ("validate model responses, detect sensitive info leaks").
    fn description(&self) -> &str {
        "验证模型响应,检测敏感信息泄露"
    }

    fn is_enabled(&self) -> bool {
        self.enabled
    }
}
#[async_trait]
impl OnModelResponseHook for ResponseValidationHook {
    /// Scans `response` for the configured sensitive patterns, ignoring case.
    ///
    /// Patterns are checked in configuration order. Empty patterns are skipped,
    /// because the empty string is a substring of every response and would flag
    /// every reply.
    ///
    /// # Returns
    /// `HookResult::Warning` naming the first matching pattern (as configured),
    /// or `HookResult::Ok` when nothing matches.
    async fn on_model_response(&self, response: &str, _turn_ctx: &TurnStartContext) -> HookResult {
        let response_lower = response.to_lowercase();
        let hit = self
            .sensitive_patterns
            .iter()
            .filter(|pattern| !pattern.is_empty())
            .find(|pattern| response_lower.contains(&pattern.to_lowercase()));

        match hit {
            Some(pattern) => HookResult::Warning(format!(
                "Response may contain sensitive information: '{}'",
                pattern
            )),
            None => HookResult::Ok,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plausible `TurnStartContext` for turn 1 with the given session id.
    fn turn_start_ctx(session_id: &str) -> TurnStartContext {
        TurnStartContext {
            turn_number: 1,
            session_id: Some(session_id.into()),
            working_dir: "/tmp".into(),
            phase: "execution".into(),
            has_file_context: false,
        }
    }

    /// Builds a `TurnCompleteContext` for turn 1 that edited `files`.
    ///
    /// Tests override individual fields with struct-update syntax.
    fn turn_complete_ctx(files: &[&str]) -> TurnCompleteContext {
        TurnCompleteContext {
            turn_number: 1,
            result_type: "Responded".into(),
            tokens_used: 100,
            tool_calls: 1,
            duration_ms: 100,
            truncated: false,
            edited_files: files.iter().map(|&f| f.into()).collect(),
        }
    }

    /// Builds the `bash` tool-call context shared by the audit-log tests.
    fn tool_call_ctx() -> ToolCallStartContext {
        ToolCallStartContext {
            tool_name: "bash".into(),
            tool_args: "echo hello".into(),
            call_id: "call_123".into(),
            turn_number: 1,
        }
    }

    #[test]
    fn test_tool_audit_log_hook_trait() {
        let hook = ToolAuditLogHook { enabled: true, log_file: None };
        assert_eq!(hook.name(), "tool-audit-log");
        assert!(hook.is_enabled());
        let hook_disabled = ToolAuditLogHook { enabled: false, log_file: None };
        assert!(!hook_disabled.is_enabled());
    }

    #[tokio::test]
    async fn test_tool_audit_log_on_tool_call_start() {
        let hook = ToolAuditLogHook { enabled: true, log_file: None };
        let result = hook.on_tool_call_start(&tool_call_ctx()).await;
        assert!(matches!(result, HookResult::Ok));
    }

    #[tokio::test]
    async fn test_tool_audit_log_appends_to_file() {
        let path = std::env::temp_dir().join(format!(
            "tool-audit-log-test-{}.log",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&path);

        let hook = ToolAuditLogHook { enabled: true, log_file: Some(path.clone()) };
        let result = hook.on_tool_call_start(&tool_call_ctx()).await;
        let contents = std::fs::read_to_string(&path);
        let _ = std::fs::remove_file(&path);

        assert!(matches!(result, HookResult::Ok));
        let contents = contents.expect("audit log file should have been created");
        assert!(contents.contains("TURN #1 | Tool: bash | CallID: call_123"));
        assert!(contents.ends_with('\n'));
    }

    #[test]
    fn test_turn_stats_hook_trait() {
        let hook = TurnStatsHook { enabled: true };
        assert_eq!(hook.name(), "turn-stats");
        assert!(hook.is_enabled());
    }

    #[tokio::test]
    async fn test_turn_stats_on_turn_start_and_complete() {
        let hook = TurnStatsHook { enabled: true };
        let result = hook.on_turn_start(&turn_start_ctx("test-session")).await;
        assert!(matches!(result, HookResult::Ok));

        let complete_ctx = TurnCompleteContext {
            tokens_used: 500,
            tool_calls: 3,
            duration_ms: 1500,
            ..turn_complete_ctx(&["a.rs", "b.rs"])
        };
        let result = hook.on_turn_complete(&complete_ctx).await;
        assert!(matches!(result, HookResult::Ok));
    }

    #[test]
    fn test_auto_commit_hook_trait() {
        let hook = AutoCommitHook { enabled: true, interval: 1 };
        assert_eq!(hook.name(), "auto-commit");
        assert!(hook.is_enabled());
        let hook_disabled = AutoCommitHook { enabled: false, interval: 1 };
        assert!(!hook_disabled.is_enabled());
    }

    /// A turn with no edited files must return before any git command runs.
    ///
    /// This goes through `on_turn_complete` rather than `run_git_commit`: the latter
    /// runs `git add -A` / `git commit` against whatever repository the test process
    /// sits in, i.e. the developer's checkout.
    #[tokio::test]
    async fn test_auto_commit_no_changes() {
        let hook = AutoCommitHook { enabled: true, interval: 1 };
        let result = hook.on_turn_complete(&turn_complete_ctx(&[])).await;
        assert!(matches!(result, HookResult::Ok));
    }

    /// Turns that are not a multiple of the interval are skipped without touching git.
    #[tokio::test]
    async fn test_auto_commit_skips_off_interval_turns() {
        let hook = AutoCommitHook { enabled: true, interval: 2 };
        let ctx = TurnCompleteContext {
            turn_number: 3,
            ..turn_complete_ctx(&["a.rs"])
        };
        let result = hook.on_turn_complete(&ctx).await;
        assert!(matches!(result, HookResult::Ok));
    }

    #[test]
    fn test_session_summary_hook_trait() {
        let hook = SessionSummaryHook::new();
        assert_eq!(hook.name(), "session-summary");
        assert!(hook.is_enabled());
    }

    #[tokio::test]
    async fn test_session_summary_start_end() {
        let hook = SessionSummaryHook::new();
        let ctx = SessionContext {
            session_id: "session-1".into(),
            working_dir: "/tmp".into(),
            model_name: "deepseek-v4".into(),
            provider_name: "atomgit".into(),
        };
        let result = hook.on_session_start(&ctx).await;
        assert!(matches!(result, HookResult::Ok));
        let result = hook.on_session_end(&ctx).await;
        assert!(matches!(result, HookResult::Ok));
    }

    #[test]
    fn test_error_report_hook_trait() {
        let hook = ErrorReportHook { enabled: true };
        assert_eq!(hook.name(), "error-report");
        assert!(hook.is_enabled());
    }

    #[tokio::test]
    async fn test_error_report_on_error() {
        let hook = ErrorReportHook { enabled: true };
        let ctx = ErrorContext {
            error_type: "timeout".into(),
            error_message: "Tool execution timed out after 30s".into(),
            phase: "tool_execution".into(),
            turn_number: Some(3),
        };
        let result = hook.on_error(&ctx).await;
        assert!(matches!(result, HookResult::Ok));
    }

    #[test]
    fn test_response_validation_hook_trait() {
        let hook = ResponseValidationHook::new(vec!["password".into(), "secret".into()]);
        assert_eq!(hook.name(), "response-validation");
        assert!(hook.is_enabled());
    }

    #[tokio::test]
    async fn test_response_validation_no_sensitive_content() {
        let hook = ResponseValidationHook::new(vec!["password".into(), "secret".into()]);
        let result = hook
            .on_model_response("Hello, how can I help you?", &turn_start_ctx("test"))
            .await;
        assert!(matches!(result, HookResult::Ok));
    }

    #[tokio::test]
    async fn test_response_validation_detects_sensitive_content() {
        let hook = ResponseValidationHook::new(vec!["password".into(), "secret".into()]);
        let result = hook
            .on_model_response("The admin password is super_secret_123", &turn_start_ctx("test"))
            .await;
        match result {
            HookResult::Warning(msg) => {
                assert!(msg.contains("password"), "Warning should mention 'password'");
            }
            other => panic!("Expected Warning, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_response_validation_case_insensitive() {
        let hook = ResponseValidationHook::new(vec!["SECRET".into()]);
        let result = hook
            .on_model_response("This is a Secret value", &turn_start_ctx("test"))
            .await;
        assert!(matches!(result, HookResult::Warning(_)));
    }
}