use std::convert::Infallible;
use std::sync::Mutex;
use agent_contracts::LoopEventSink;
use agent_types::common::ids::AgentId;
use agent_types::events::{LoopEndSummary, ToolResultEvent};
use agent_types::interaction::InteractionRequest;
use axum::response::sse;
use futures_util::StreamExt;
use serde::Serialize;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
/// One message delivered to an SSE client while an agent loop runs.
///
/// Serialized as an internally tagged JSON object: the variant name, converted to
/// snake_case, is written under the `"type"` key next to the variant's own fields
/// (e.g. `{"type":"turn_start","agent_id":"...","turn":1}`).
///
/// NOTE(review): `InteractionRequested`, `Done`, `Error` and `Cancelled` are not
/// constructed anywhere in the code visible here; presumably the HTTP handler sends
/// them directly on the channel — confirm.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SseStreamEvent {
    /// Emitted when the agent loop begins a new turn.
    TurnStart {
        agent_id: String,
        turn: u32,
    },
    /// Incremental assistant text. `delta` is the suffix not seen earlier in this turn;
    /// `snapshot` is the complete text received so far.
    TextDelta {
        delta: String,
        snapshot: String,
    },
    /// Incremental assistant reasoning text, following the same `delta` / `snapshot`
    /// convention as `TextDelta`.
    ThinkingDelta {
        delta: String,
        snapshot: String,
    },
    /// Result of a tool call; the fields are copied from the loop's `ToolResultEvent`.
    ToolResult {
        call_id: String,
        tool_name: String,
        output_preview: String,
        is_error: bool,
    },
    /// Carries an `InteractionRequest` to the client (presumably the agent is waiting
    /// on user input — confirm against the handler).
    InteractionRequested {
        request: InteractionRequest,
    },
    /// Final message of a run, with the reply text, identifiers, token accounting and
    /// the message history.
    /// NOTE(review): `total_tokens` is `usize` while the other token counters are
    /// `u64`; check whether that asymmetry is intentional.
    Done {
        reply: String,
        raw_reply: String,
        conversation_id: String,
        session_id: String,
        turn_count: u32,
        total_tokens: usize,
        prompt_tokens: u64,
        completion_tokens: u64,
        estimated_input_tokens: u64,
        messages: Vec<llm_client::ChatMessage>,
        stop_reason: String,
    },
    /// Carries an error message for the client.
    Error {
        error: String,
    },
    /// Cancellation notice for the session `session_id`.
    Cancelled {
        session_id: String,
    },
}
impl SseStreamEvent {
fn event_name(&self) -> &'static str {
match self {
SseStreamEvent::TurnStart { .. } => "turn_start",
SseStreamEvent::TextDelta { .. } => "text_delta",
SseStreamEvent::ThinkingDelta { .. } => "thinking_delta",
SseStreamEvent::ToolResult { .. } => "tool_result",
SseStreamEvent::InteractionRequested { .. } => "interaction_requested",
SseStreamEvent::Done { .. } => "done",
SseStreamEvent::Error { .. } => "error",
SseStreamEvent::Cancelled { .. } => "cancelled",
}
}
}
/// Event sink that forwards agent-loop progress onto an unbounded channel of
/// [`SseStreamEvent`]s.
///
/// Assistant text and reasoning arrive as cumulative snapshots, so the sink records the
/// byte length of the last snapshot of each kind to emit only the newly added part.
pub struct SseLoopEventSink {
    /// Sending half of the channel consumed by the SSE response stream.
    tx: mpsc::UnboundedSender<SseStreamEvent>,
    /// Byte length of the last assistant-text snapshot seen in the current turn.
    last_snapshot_len: Mutex<usize>,
    /// Byte length of the last reasoning snapshot seen in the current turn.
    last_thinking_snapshot_len: Mutex<usize>,
    /// Summary stored by `on_loop_end`, held until `take_loop_summary` collects it.
    loop_summary: Mutex<Option<LoopEndSummary>>,
}

impl SseLoopEventSink {
    /// Builds a sink that emits on `tx`, starting with zeroed snapshot counters and no
    /// stored loop summary.
    pub fn new(tx: mpsc::UnboundedSender<SseStreamEvent>) -> Self {
        let last_snapshot_len = Mutex::new(0);
        let last_thinking_snapshot_len = Mutex::new(0);
        let loop_summary = Mutex::new(None);
        Self {
            tx,
            last_snapshot_len,
            last_thinking_snapshot_len,
            loop_summary,
        }
    }

    /// Moves the stored loop summary out of the sink, leaving `None` behind.
    ///
    /// Returns `None` if no summary has been recorded since the last call.
    ///
    /// # Panics
    ///
    /// Panics if the `loop_summary` mutex has been poisoned.
    pub fn take_loop_summary(&self) -> Option<LoopEndSummary> {
        let mut slot = self
            .loop_summary
            .lock()
            .expect("sse sink loop_summary mutex should not be poisoned");
        slot.take()
    }
}
impl LoopEventSink for SseLoopEventSink {
fn on_turn_start(&self, agent_id: &AgentId, turn: u32) {
if let Ok(mut len) = self.last_snapshot_len.lock() {
*len = 0;
}
if let Ok(mut len) = self.last_thinking_snapshot_len.lock() {
*len = 0;
}
let _ = self.tx.send(SseStreamEvent::TurnStart {
agent_id: agent_id.0.clone(),
turn,
});
}
fn on_assistant_message(&self, _agent_id: &AgentId, text: &str) {
let delta = {
let mut last_len = self
.last_snapshot_len
.lock()
.expect("sse sink last_snapshot_len mutex should not be poisoned");
let prev = *last_len;
*last_len = text.len();
if prev < text.len() {
text[prev..].to_string()
} else {
return;
}
};
let _ = self.tx.send(SseStreamEvent::TextDelta {
delta,
snapshot: text.to_string(),
});
}
fn on_assistant_reasoning(&self, _agent_id: &AgentId, text: &str) {
let delta = {
let mut last_len = self
.last_thinking_snapshot_len
.lock()
.expect("sse sink last_thinking_snapshot_len mutex should not be poisoned");
let prev = *last_len;
*last_len = text.len();
if prev < text.len() {
text[prev..].to_string()
} else {
return;
}
};
let _ = self.tx.send(SseStreamEvent::ThinkingDelta {
delta,
snapshot: text.to_string(),
});
}
fn on_tool_result(&self, _agent_id: &AgentId, event: &ToolResultEvent) {
let _ = self.tx.send(SseStreamEvent::ToolResult {
call_id: event.call_id.clone(),
tool_name: event.tool_name.clone(),
output_preview: event.output_preview.clone(),
is_error: event.is_error,
});
}
fn on_loop_end(&self, _agent_id: &AgentId, summary: &LoopEndSummary) {
if let Ok(mut stored) = self.loop_summary.lock() {
*stored = Some(summary.clone());
}
}
}
pub fn sse_stream_from_receiver(
rx: mpsc::UnboundedReceiver<SseStreamEvent>,
) -> impl futures_util::Stream<Item = Result<sse::Event, Infallible>> {
UnboundedReceiverStream::new(rx).map(|event| {
let name = event.event_name();
let data =
serde_json::to_string(&event).unwrap_or_else(|e| format!("{{\"error\":\"{e}\"}}"));
Ok(sse::Event::default().event(name).data(data))
})
}