use super::message::{Message, Role};
/// Lifecycle state of a [`Turn`].
///
/// Derives `Eq` in addition to `PartialEq`: the enum has no payload, so
/// equality is total and callers can rely on that.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum TurnStatus {
    /// The turn is still open; `TurnTracker::on_message_added` only extends
    /// turns in this state.
    Active,
    /// The turn has been closed, either by a following user message or by
    /// `TurnTracker::complete_current`.
    Completed,
    /// NOTE(review): never assigned in this file; presumably marks a turn whose
    /// messages were condensed into `Turn::summary` — confirm against callers.
    Summarized,
}
/// A contiguous span of conversation messages treated as a single turn.
///
/// The span covers message indices `start_idx..start_idx + msg_count`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Turn {
/// Index of the first message of the turn. `TurnTracker` sets this to the
/// index of the user message that opened the turn.
pub start_idx: usize,
/// Number of messages in the turn, counting the opening message.
pub msg_count: usize,
/// Current lifecycle state of the turn.
pub status: TurnStatus,
/// Optional summary text. `TurnTracker` always creates turns with `None`.
/// NOTE(review): presumably filled in when the turn becomes `Summarized` —
/// confirm against the code that writes it.
pub summary: Option<String>,
}
impl Turn {
    /// Returns the exclusive end index of the turn, i.e. the index one past
    /// its last message (`start_idx + msg_count`).
    pub fn end_idx(&self) -> usize {
        // Both fields are `Copy`, so they can be copied out of the borrow.
        let Self {
            start_idx,
            msg_count,
            ..
        } = *self;
        start_idx + msg_count
    }
}
/// Tracks how a conversation's messages are grouped into [`Turn`]s.
#[derive(Debug, Clone, Default)]
pub struct TurnTracker {
/// Turns in the order they were opened. The last entry is the one that
/// `on_message_added` and `complete_current` operate on.
pub turns: Vec<Turn>,
}
impl TurnTracker {
    /// Creates a tracker with no recorded turns.
    pub fn new() -> Self {
        Self { turns: Vec::new() }
    }

    /// Reconstructs the turn structure from a complete message history.
    ///
    /// Every `Role::User` message opens a new turn and closes the one before
    /// it; every other message extends the turn that is currently open.
    /// Messages that precede the first user message belong to no turn. The
    /// last turn (if any) is left `Active`; all earlier turns are `Completed`.
    pub fn rebuild(messages: &[Message]) -> Self {
        let mut tracker = Self::new();
        // Replay the history through the incremental API so that a rebuilt
        // tracker can never disagree with one that was updated live.
        for (idx, msg) in messages.iter().enumerate() {
            if matches!(msg.role, Role::User) {
                tracker.on_user_message(idx);
            } else {
                tracker.on_message_added(idx);
            }
        }
        tracker
    }

    /// Records a user message at `msg_idx`, which starts a new turn.
    ///
    /// If the latest turn is still `Active` it is closed first: its
    /// `msg_count` becomes `msg_idx - start_idx` (saturating at zero) and its
    /// status becomes `Completed`. A new `Active` turn of length 1 starting at
    /// `msg_idx` is then appended.
    pub fn on_user_message(&mut self, msg_idx: usize) {
        if let Some(prev) = self.turns.last_mut() {
            if prev.status == TurnStatus::Active {
                prev.msg_count = msg_idx.saturating_sub(prev.start_idx);
                prev.status = TurnStatus::Completed;
            }
        }
        self.turns.push(Turn {
            start_idx: msg_idx,
            msg_count: 1,
            status: TurnStatus::Active,
            summary: None,
        });
    }

    /// Records a non-user message at `msg_idx` as part of the active turn.
    ///
    /// Sets the active turn's `msg_count` so the turn spans
    /// `start_idx..=msg_idx`. Does nothing when there is no turn, when the
    /// latest turn is not `Active`, or when `msg_idx` lies before the turn's
    /// `start_idx`.
    pub fn on_message_added(&mut self, msg_idx: usize) {
        if let Some(current) = self.turns.last_mut() {
            if current.status == TurnStatus::Active {
                // `checked_sub` guards against an index that precedes the turn:
                // a plain subtraction would panic in debug builds and wrap to
                // a nonsensical count in release builds.
                if let Some(offset) = msg_idx.checked_sub(current.start_idx) {
                    current.msg_count = offset + 1;
                }
            }
        }
    }

    /// Marks the latest turn as `Completed` if it is currently `Active`.
    ///
    /// Has no effect when there are no turns or the latest turn is already in
    /// another state.
    pub fn complete_current(&mut self) {
        if let Some(current) = self.turns.last_mut() {
            if current.status == TurnStatus::Active {
                current.status = TurnStatus::Completed;
            }
        }
    }

    /// Returns the latest turn if it is still `Active`, otherwise `None`.
    pub fn active_turn(&self) -> Option<&Turn> {
        self.turns.last().filter(|t| t.status == TurnStatus::Active)
    }

    /// Counts the turns whose status is exactly `Completed`.
    ///
    /// Turns in the `Summarized` state are not included.
    pub fn completed_count(&self) -> usize {
        self.turns
            .iter()
            .filter(|t| t.status == TurnStatus::Completed)
            .count()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conversation::message::{Message, Role};

    /// Builds a plain-text user message.
    fn user(text: &str) -> Message {
        Message::new(Role::User, text)
    }

    /// Builds a plain-text assistant message.
    fn assistant(text: &str) -> Message {
        Message::new(Role::Assistant, text)
    }

    /// Rebuilding from an empty history produces no turns.
    #[test]
    fn test_rebuild_empty() {
        let rebuilt = TurnTracker::rebuild(&[]);
        assert!(rebuilt.turns.is_empty());
    }

    /// A user message followed by a reply forms one turn that stays active.
    #[test]
    fn test_rebuild_single_turn() {
        let history = vec![user("hello"), assistant("hi there")];
        let rebuilt = TurnTracker::rebuild(&history);
        assert_eq!(rebuilt.turns.len(), 1);

        let only = &rebuilt.turns[0];
        assert_eq!(only.start_idx, 0);
        assert_eq!(only.msg_count, 2);
        assert_eq!(only.status, TurnStatus::Active);
    }

    /// Each user message opens a turn; all but the last turn end up completed.
    #[test]
    fn test_rebuild_multi_turn() {
        let history = vec![
            user("task 1"),
            assistant("done 1"),
            user("task 2"),
            assistant("done 2"),
            user("task 3"),
        ];
        let rebuilt = TurnTracker::rebuild(&history);

        // (start_idx, msg_count, status) expected for each turn, in order.
        let expected = [
            (0, 2, TurnStatus::Completed),
            (2, 2, TurnStatus::Completed),
            (4, 1, TurnStatus::Active),
        ];
        assert_eq!(rebuilt.turns.len(), 3);
        for (turn, (start, count, status)) in rebuilt.turns.iter().zip(expected.iter()) {
            assert_eq!(turn.start_idx, *start);
            assert_eq!(turn.msg_count, *count);
            assert_eq!(&turn.status, status);
        }
    }

    /// A new user message completes the previous turn and starts a fresh one.
    #[test]
    fn test_on_user_message_closes_previous() {
        let mut tracker = TurnTracker::new();
        tracker.on_user_message(0);
        assert_eq!(tracker.turns.len(), 1);
        assert_eq!(tracker.turns[0].status, TurnStatus::Active);

        for idx in 1..=2 {
            tracker.on_message_added(idx);
        }
        tracker.on_user_message(3);
        assert_eq!(tracker.turns.len(), 2);

        let (first, second) = (&tracker.turns[0], &tracker.turns[1]);
        assert_eq!(first.status, TurnStatus::Completed);
        assert_eq!(first.msg_count, 3);
        assert_eq!(second.status, TurnStatus::Active);
        assert_eq!(second.start_idx, 3);
    }

    /// `complete_current` flips the active turn to completed.
    #[test]
    fn test_complete_current() {
        let mut tracker = TurnTracker::new();
        tracker.on_user_message(0);
        tracker.on_message_added(1);
        tracker.complete_current();

        let first = &tracker.turns[0];
        assert_eq!(first.status, TurnStatus::Completed);
    }

    /// Only completed turns are counted; a freshly opened turn is not.
    #[test]
    fn test_completed_count() {
        let mut tracker = TurnTracker::new();
        tracker.on_user_message(0);
        tracker.on_message_added(1);
        assert_eq!(tracker.completed_count(), 0);

        tracker.complete_current();
        assert_eq!(tracker.completed_count(), 1);

        // Opening a second turn must not change the number of completed ones.
        tracker.on_user_message(2);
        assert_eq!(tracker.completed_count(), 1);
    }

    /// After truncating the history, no rebuilt turn may reach past its end.
    #[test]
    fn test_rebuild_matches_truncated_messages_length() {
        use super::super::message::MessageContent;
        use crate::tool::{ToolCall, ToolResult};

        // Three rounds of: user request, tool call, tool result, final reply.
        let mut history: Vec<Message> = Vec::new();
        for round in 0..3 {
            let call_id = format!("c{}", round);
            history.push(user(&format!("task {}", round)));
            history.push(Message {
                role: Role::Assistant,
                content: MessageContent::AssistantWithToolCalls {
                    text: Some("working".into()),
                    tool_calls: vec![ToolCall {
                        id: call_id.clone(),
                        name: "bash".into(),
                        arguments: "{}".into(),
                    }],
                    reasoning_content: None,
                    thinking_blocks: Vec::new(),
                },
                synthetic: false,
            });
            history.push(Message {
                role: Role::Tool,
                content: MessageContent::ToolResult(ToolResult {
                    call_id,
                    output: "ok".into(),
                    success: true,
                }),
                synthetic: false,
            });
            history.push(assistant(&format!("done {}", round)));
        }
        assert_eq!(history.len(), 12);

        // Drop the whole last round before rebuilding.
        let kept = history.len() - 4;
        history.truncate(kept);

        let rebuilt = TurnTracker::rebuild(&history);
        for (i, turn) in rebuilt.turns.iter().enumerate() {
            assert!(
                turn.end_idx() <= history.len(),
                "turn {} end_idx {} exceeds messages.len() {}",
                i,
                turn.end_idx(),
                history.len(),
            );
        }
    }
}