use agent_contracts::runtime::RuntimeView;
use agent_contracts::trace::{TraceOutcome, TraceSpanHandle, TraceSpanKind};
use agent_types::hook::{HookInvokeInput, HookInvokeMetadata, HookInvokeOutput, HookPointId};
use agent_types::tool::{
    ErrorHookResult, ErrorToolHookInput, PostHookResult, PostToolHookInput, PreHookResult,
    PreToolHookInput, RawToolOutcome, ToolExecutionError,
};
use hook::{resolve_hook_point_category, HookPointCategory};
use serde_json::json;
use std::borrow::Cow;

use super::state::ToolExecutionState;
use super::ToolCallImpl;

impl ToolCallImpl {
    /// Runs the enabled pre-hooks registered for the tool call held in `state`.
    ///
    /// The hook point is built for the `"pre"` stage and must resolve to
    /// [`HookPointCategory::ToolPre`]. Enabled hookers are invoked one at a
    /// time, ordered by comparing the inner value of their ids, and each
    /// invocation is wrapped in its own trace span.
    ///
    /// Per-hooker results are handled as follows:
    ///
    /// * `Allow` is recorded and the sequence continues.
    /// * `Transform` replaces `state.final_call.input`, so hookers that run
    ///   afterwards receive the modified call.
    /// * `Deny` is recorded and ends the sequence immediately.
    ///
    /// A hooker whose invocation fails, or which returns a non-pre output, is
    /// logged, has its span closed with [`TraceOutcome::Error`], and is skipped
    /// without contributing a result.
    ///
    /// # Errors
    ///
    /// Returns [`ToolExecutionError::ExecutionFailed`] when the hook point
    /// category cannot be resolved or is not `ToolPre`.
    pub(super) async fn run_pre_hook_sequence(
        &self,
        state: &mut ToolExecutionState,
        runtime: &dyn RuntimeView,
    ) -> Result<Vec<PreHookResult>, ToolExecutionError> {
        let hook_point = self.build_tool_hook_point(runtime, &state.final_call.tool_name, "pre");

        let category = match resolve_hook_point_category(&hook_point) {
            Ok(resolved) => resolved,
            Err(error) => {
                return Err(ToolExecutionError::ExecutionFailed {
                    message: format!(
                        "failed to resolve pre-hook category for tool call '{}' (hook_point='{}'): {}",
                        state.final_call.call_id, hook_point.0, error
                    ),
                });
            }
        };

        if category != HookPointCategory::ToolPre {
            return Err(ToolExecutionError::ExecutionFailed {
                message: format!(
                    "pre-hook sequence expected ToolPre category but got {:?} for hook point {}",
                    category, hook_point.0
                ),
            });
        }

        // Only enabled hookers take part; sorting keeps the run order stable.
        let mut hookers = runtime.hookers().list_for_hook_point(&hook_point);
        hookers.retain(|candidate| runtime.hookers().is_enabled(candidate.id()));
        hookers.sort_by(|a, b| a.id().0.cmp(&b.id().0));

        // With no hookers the loop body never runs and an empty list is returned.
        let mut results = Vec::new();

        for hooker in hookers {
            let span_fields = json!({
                "hook_kind": "tool_pre",
                "hooker_id": hooker.id().to_string(),
                "hook_point": hook_point.0,
                "tool_name": state.final_call.tool_name,
                "call_id": state.final_call.call_id,
            });
            let hook_span = runtime
                .trace_recorder()
                .begin_span(
                    TraceSpanKind::Hook,
                    Cow::Borrowed("tool_pre_hook"),
                    span_fields,
                )
                .await;

            let hook_input = HookInvokeInput::Pre {
                input: PreToolHookInput {
                    call: state.final_call.clone(),
                },
                metadata: hook_invoke_metadata(&hook_span),
            };

            // A failing or misbehaving hooker is skipped rather than aborting the call.
            let pre_result = match hooker.invoke(hook_input, runtime).await {
                Ok(HookInvokeOutput::Pre(pre_result)) => pre_result,
                Ok(other) => {
                    tracing::warn!(
                        hooker_id = %hooker.id(),
                        call_id = %state.final_call.call_id,
                        output = ?other,
                        "pre-hooker returned non-pre output"
                    );
                    runtime
                        .trace_recorder()
                        .end_span(
                            hook_span,
                            TraceOutcome::Error,
                            json!({"error": "unexpected output variant"}),
                        )
                        .await;
                    continue;
                }
                Err(error) => {
                    tracing::warn!(
                        hooker_id = %hooker.id(),
                        call_id = %state.final_call.call_id,
                        tool = %state.final_call.tool_name,
                        error = %error,
                        "pre-hook invoke failed"
                    );
                    runtime
                        .trace_recorder()
                        .end_span(
                            hook_span,
                            TraceOutcome::Error,
                            json!({"error": error.to_string()}),
                        )
                        .await;
                    continue;
                }
            };

            // A denial short-circuits the sequence once its span has been closed.
            let stop_sequence = matches!(pre_result, PreHookResult::Deny { .. });

            let (trace_outcome, trace_fields) = match pre_result {
                PreHookResult::Allow => {
                    results.push(PreHookResult::Allow);
                    (TraceOutcome::Ok, json!({"result": "allow"}))
                }
                PreHookResult::Deny { reason } => {
                    let deny_fields = json!({"result": "deny", "reason": &reason});
                    results.push(PreHookResult::Deny { reason });
                    (TraceOutcome::Denied, deny_fields)
                }
                PreHookResult::Transform { modified_input } => {
                    // Later hookers observe the transformed input.
                    state.final_call.input = modified_input;
                    results.push(PreHookResult::Transform {
                        modified_input: state.final_call.input.clone(),
                    });
                    (TraceOutcome::Ok, json!({"result": "transform"}))
                }
            };

            runtime
                .trace_recorder()
                .end_span(hook_span, trace_outcome, trace_fields)
                .await;

            if stop_sequence {
                return Ok(results);
            }
        }

        Ok(results)
    }

    /// Builds the hook point id for a tool at a given lifecycle stage.
    ///
    /// The id has the shape `<agent_id>.Tool.<tool_name>.<stage>`, where
    /// `agent_id` is read from the runtime's agent context metadata.
    pub(super) fn build_tool_hook_point(
        &self,
        runtime: &dyn RuntimeView,
        tool_name: &str,
        stage: &str,
    ) -> HookPointId {
        let qualified_name = format!(
            "{}.Tool.{}.{}",
            runtime.agent_context().metadata().agent_id,
            tool_name,
            stage
        );
        HookPointId(qualified_name)
    }

    /// Runs the enabled post-hooks registered for the tool call held in `state`.
    ///
    /// Returns an empty list right away when `state.raw_outcome` is `None`.
    /// Otherwise the hook point is built for the `"post"` stage and must
    /// resolve to [`HookPointCategory::ToolPost`]. Enabled hookers are invoked
    /// one at a time, ordered by comparing the inner value of their ids, each
    /// inside its own trace span, and each receives the outcome as currently
    /// stored in `state.raw_outcome`.
    ///
    /// Per-hooker results are handled as follows:
    ///
    /// * `Accept` is recorded and the outcome is left untouched.
    /// * `Transform` is recorded and `state.raw_outcome` is replaced with
    ///   `RawToolOutcome::Success` carrying the modified output, so hookers
    ///   that run afterwards see the transformed outcome.
    ///
    /// A hooker whose invocation fails, or which returns a non-post output, is
    /// logged, has its span closed with [`TraceOutcome::Error`], and is skipped
    /// without contributing a result.
    ///
    /// NOTE(review): `Transform` always stores a `Success` outcome, whatever
    /// variant the previous outcome was — confirm this is intended.
    ///
    /// # Errors
    ///
    /// Returns [`ToolExecutionError::ExecutionFailed`] when the hook point
    /// category cannot be resolved or is not `ToolPost`.
    pub(super) async fn run_post_hook_sequence(
        &self,
        state: &mut ToolExecutionState,
        runtime: &dyn RuntimeView,
    ) -> Result<Vec<PostHookResult>, ToolExecutionError> {
        // Without a raw outcome there is nothing for post-hooks to inspect.
        // A single match replaces the former `is_none()` guard plus an
        // `unreachable!()` arm, so no panic path remains.
        let initial_outcome = match state.raw_outcome.as_ref() {
            Some(raw_outcome) => raw_outcome.clone(),
            None => return Ok(Vec::new()),
        };

        let hook_point = self.build_tool_hook_point(runtime, &state.final_call.tool_name, "post");
        let category = resolve_hook_point_category(&hook_point).map_err(|error| {
            ToolExecutionError::ExecutionFailed {
                message: format!(
                    "failed to resolve post-hook category for tool call '{}' (hook_point='{}'): {}",
                    state.final_call.call_id, hook_point.0, error
                ),
            }
        })?;

        if category != HookPointCategory::ToolPost {
            return Err(ToolExecutionError::ExecutionFailed {
                message: format!(
                    "post-hook sequence expected ToolPost category but got {:?} for hook point {}",
                    category, hook_point.0
                ),
            });
        }

        // Only enabled hookers take part; sorting keeps the run order stable.
        let mut hookers = runtime.hookers().list_for_hook_point(&hook_point);
        hookers.retain(|hooker| runtime.hookers().is_enabled(hooker.id()));
        hookers.sort_by(|left, right| left.id().0.cmp(&right.id().0));

        if hookers.is_empty() {
            return Ok(Vec::new());
        }

        let mut results = Vec::new();

        for hooker in hookers {
            // Each hooker sees the latest outcome, including earlier transforms.
            let current_outcome = state
                .raw_outcome
                .clone()
                .unwrap_or_else(|| initial_outcome.clone());

            let hook_span = runtime
                .trace_recorder()
                .begin_span(
                    TraceSpanKind::Hook,
                    Cow::Borrowed("tool_post_hook"),
                    json!({
                        "hook_kind": "tool_post",
                        "hooker_id": hooker.id().to_string(),
                        "hook_point": hook_point.0,
                        "tool_name": state.final_call.tool_name,
                        "call_id": state.final_call.call_id,
                    }),
                )
                .await;

            let input = HookInvokeInput::Post {
                input: PostToolHookInput {
                    call: state.final_call.clone(),
                    outcome: current_outcome,
                },
                metadata: hook_invoke_metadata(&hook_span),
            };

            // A failing or misbehaving hooker is skipped rather than aborting the call.
            let post_result = match hooker.invoke(input, runtime).await {
                Ok(HookInvokeOutput::Post(post_result)) => post_result,
                Ok(other) => {
                    let error = format!(
                        "post-hooker '{}' returned non-post output {:?} for tool call '{}'",
                        hooker.id(),
                        other,
                        state.final_call.call_id
                    );
                    tracing::warn!(
                        hooker_id = %hooker.id(),
                        call_id = %state.final_call.call_id,
                        output = ?other,
                        "post-hooker returned non-post output"
                    );
                    runtime
                        .trace_recorder()
                        .end_span(hook_span, TraceOutcome::Error, json!({"error": error}))
                        .await;
                    continue;
                }
                Err(error) => {
                    tracing::warn!(
                        hooker_id = %hooker.id(),
                        call_id = %state.final_call.call_id,
                        tool = %state.final_call.tool_name,
                        error = %error,
                        "post-hook invoke failed"
                    );
                    runtime
                        .trace_recorder()
                        .end_span(
                            hook_span,
                            TraceOutcome::Error,
                            json!({"error": error.to_string()}),
                        )
                        .await;
                    continue;
                }
            };

            match post_result {
                PostHookResult::Accept => {
                    results.push(PostHookResult::Accept);
                    runtime
                        .trace_recorder()
                        .end_span(hook_span, TraceOutcome::Ok, json!({"result": "accept"}))
                        .await;
                }
                PostHookResult::Transform { modified_output } => {
                    // Later hookers observe the transformed output.
                    state.raw_outcome = Some(RawToolOutcome::Success {
                        output: modified_output.clone(),
                    });
                    results.push(PostHookResult::Transform { modified_output });
                    runtime
                        .trace_recorder()
                        .end_span(hook_span, TraceOutcome::Ok, json!({"result": "transform"}))
                        .await;
                }
            }
        }

        Ok(results)
    }

    pub(super) async fn run_error_hook_sequence(
        &self,
        state: &ToolExecutionState,
        execution_error: &ToolExecutionError,
        runtime: &dyn RuntimeView,
    ) -> Result<Vec<ErrorHookResult>, ToolExecutionError> {
        let hook_point = self.build_tool_hook_point(runtime, &state.final_call.tool_name, "error");
        let category = resolve_hook_point_category(&hook_point).map_err(|error| {
            ToolExecutionError::ExecutionFailed {
                message: format!(
                    "failed to resolve error-hook category for tool call '{}' (hook_point='{}'): {}",
                    state.final_call.call_id, hook_point.0, error
                ),
            }
        })?;

        if category != HookPointCategory::ToolError {
            return Err(ToolExecutionError::ExecutionFailed {
                message: format!(
                    "error-hook sequence expected ToolError category but got {:?} for hook point {}",
                    category, hook_point.0
                ),
            });
        }

        let mut hookers = runtime.hookers().list_for_hook_point(&hook_point);
        hookers.retain(|hooker| runtime.hookers().is_enabled(hooker.id()));
        hookers.sort_by(|left, right| left.id().0.cmp(&right.id().0));

        if hookers.is_empty() {
            return Ok(Vec::new());
        }

        let mut results = Vec::new();

        for hooker in hookers {
            let hook_span = runtime
                .trace_recorder()
                .begin_span(
                    TraceSpanKind::Hook,
                    Cow::Borrowed("tool_error_hook"),
                    json!({
                        "hook_kind": "tool_error",
                        "hooker_id": hooker.id().to_string(),
                        "hook_point": hook_point.0,
                        "tool_name": state.final_call.tool_name,
                        "call_id": state.final_call.call_id,
                        "execution_error": execution_error.to_string(),
                    }),
                )
                .await;

            let input = HookInvokeInput::Error {
                input: ErrorToolHookInput {
                    call: state.final_call.clone(),
                    error: execution_error.clone(),
                },
                metadata: hook_invoke_metadata(&hook_span),
            };

            let output = match hooker.invoke(input, runtime).await {
                Ok(output) => output,
                Err(error) => {
                    tracing::warn!(
                        hooker_id = %hooker.id(),
                        call_id = %state.final_call.call_id,
                        tool = %state.final_call.tool_name,
                        error = %error,
                        "error-hook invoke failed"
                    );
                    runtime
                        .trace_recorder()
                        .end_span(
                            hook_span,
                            TraceOutcome::Error,
                            json!({"error": error.to_string()}),
                        )
                        .await;
                    continue;
                }
            };

            let error_result = match output {
                HookInvokeOutput::Error(error_result) => error_result,
                other => {
                    let error = format!(
                        "error-hooker '{}' returned non-error output {:?} for tool call '{}'",
                        hooker.id(),
                        other,
                        state.final_call.call_id
                    );
                    tracing::warn!(
                        hooker_id = %hooker.id(),
                        call_id = %state.final_call.call_id,
                        output = ?other,
                        "error-hooker returned non-error output"
                    );
                    runtime
                        .trace_recorder()
                        .end_span(hook_span, TraceOutcome::Error, json!({"error": error}))
                        .await;
                    continue;
                }
            };

            let (trace_outcome, result_label) = match &error_result {
                ErrorHookResult::Propagate => (TraceOutcome::Ok, "propagate"),
                ErrorHookResult::Recover { .. } => (TraceOutcome::Ok, "recover"),
            };
            runtime
                .trace_recorder()
                .end_span(hook_span, trace_outcome, json!({"result": result_label}))
                .await;

            results.push(error_result);
        }

        Ok(results)
    }

    /// Runs the error-hook sequence and stores its results on `state`.
    ///
    /// On success `state.error_hook_results` is overwritten with the collected
    /// results. If the sequence itself fails, the failure is logged as a
    /// warning and `state.error_hook_results` is left as it was; the failure is
    /// not returned to the caller.
    pub(super) async fn collect_error_hook_results_after_begin(
        &self,
        state: &mut ToolExecutionState,
        execution_error: &ToolExecutionError,
        runtime: &dyn RuntimeView,
    ) {
        let sequence_outcome = self
            .run_error_hook_sequence(state, execution_error, runtime)
            .await;

        let error_hook_failure = match sequence_outcome {
            Ok(collected) => {
                state.error_hook_results = collected;
                return;
            }
            Err(failure) => failure,
        };

        // Best effort: a failing error-hook phase is only reported, not raised.
        tracing::warn!(
            call_id = %state.final_call.call_id,
            tool = %state.final_call.tool_name,
            error = %error_hook_failure,
            "error-hook phase failed after begin"
        );
    }
}

/// Builds the metadata passed to a hooker from the span opened for its invocation.
///
/// The trace id and span id are always set; the parent span id is set only
/// when the span handle reports one. All ids are converted to strings.
fn hook_invoke_metadata(hook_span: &TraceSpanHandle) -> HookInvokeMetadata {
    let trace_id = hook_span.trace_id().to_string();
    let span_id = hook_span.span_id().to_string();
    let parent_span_id = hook_span
        .parent_span_id()
        .map(|parent| parent.to_string());

    HookInvokeMetadata {
        trace_id: Some(trace_id),
        span_id: Some(span_id),
        parent_span_id,
    }
}