use std::collections::HashSet;
use agent_types::CompressionMeta;
use agent_types::{ChatMessage, ContentBlock, MessageRole};
use serde::{Deserialize, Serialize};
use crate::{CompactError, CompactResult};
/// Selects which compaction strategy a request or decision refers to.
///
/// NOTE(review): the behavior behind each variant is implemented by the code
/// that consumes this enum, not here — confirm the variant semantics there.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompactMode {
    /// "History snip" strategy.
    HistorySnip,
    /// "Context collapse" strategy.
    ContextCollapse,
    /// "Auto compact" strategy.
    AutoCompact,
}
/// Which side of a pivot message a partial compaction covers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PartialDirection {
    /// The compaction window spans from the start of the history up to (not
    /// including) the pivot message.
    BeforePivot,
    /// The compaction window starts right after the pivot message and ends
    /// where the protected tail begins.
    AfterPivot,
}
/// A pivot message plus the side of it that should be compacted.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompactionBoundary {
    /// `message_id` of the message acting as the pivot; it is looked up by
    /// exact string equality when the boundary is resolved.
    pub pivot_message_id: String,
    /// Side of the pivot covered by the compaction window.
    pub direction: PartialDirection,
}
/// Input describing a compaction to perform.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompactRequest {
    /// Conversation history the compaction operates on.
    pub messages: Vec<ChatMessage>,
    /// Requested compaction strategy.
    pub mode: CompactMode,
    /// Optional explicit pivot; `None` leaves the window unconstrained by a pivot.
    pub boundary: Option<CompactionBoundary>,
}
/// Outcome of deciding how to compact: the strategy, an optional pivot, and a
/// textual justification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompactionDecision {
    /// Strategy that was chosen.
    pub mode: CompactMode,
    /// Pivot chosen for the compaction, if any.
    pub boundary: Option<CompactionBoundary>,
    /// Free-form explanation of the decision.
    /// NOTE(review): presumably human-readable diagnostics — confirm with the producer.
    pub reason: String,
}
/// Output of a compaction run.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompactionResult {
    /// Message list after compaction.
    pub messages: Vec<ChatMessage>,
    /// Summary text, when the run produced one.
    pub summary: Option<String>,
    /// Boundary associated with this result, if any.
    pub boundary: Option<CompactionBoundary>,
    /// Compression bookkeeping to carry forward.
    pub updated_meta: CompressionMeta,
    /// Token estimate.
    /// NOTE(review): presumably computed for `messages` after compaction — confirm.
    pub estimated_tokens: usize,
}
/// Half-open index range `[start, end)` of messages considered for compaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CompactionWindow {
    /// Index of the first message inside the window.
    start: usize,
    /// Index one past the last message inside the window.
    end: usize,
}
/// Picks the latest message that can act as a `BeforePivot` compaction pivot.
///
/// A message qualifies when all of the following hold:
/// * it lies outside the last `preserve_tail_messages` messages,
/// * it carries a `message_id`,
/// * at least one non-system message precedes it (so there is something to
///   remove in front of the pivot).
///
/// Returns `None` when no message qualifies, including when `messages` is
/// empty or consists solely of system messages.
pub fn discover_boundary(
    messages: &[ChatMessage],
    preserve_tail_messages: usize,
) -> Option<CompactionBoundary> {
    let boundary_search_end = messages.len().saturating_sub(preserve_tail_messages);

    // A pivot at `index` has removable history exactly when the first
    // non-system message sits strictly before it, so locating that message
    // once replaces a prefix scan per candidate.
    let first_non_system = messages
        .iter()
        .position(|message| !matches!(message.role, MessageRole::System))?;
    // `first_non_system < messages.len()`, so this cannot overflow.
    let first_pivot_candidate = first_non_system + 1;

    // `get` yields `None` when the candidate range is inverted, i.e. when the
    // protected tail already covers every possible pivot.
    let pivot_message_id = messages
        .get(first_pivot_candidate..boundary_search_end)?
        .iter()
        .rev()
        .find_map(|message| message.message_id.clone())?;

    Some(CompactionBoundary {
        pivot_message_id,
        direction: PartialDirection::BeforePivot,
    })
}
pub fn candidate_indexes(
messages: &[ChatMessage],
preserve_tail_messages: usize,
boundary: Option<&CompactionBoundary>,
) -> CompactResult<Vec<usize>> {
let window = resolve_compaction_window(messages, preserve_tail_messages, boundary)?;
let adjusted_end = adjust_keep_start_for_invariants(messages, window.end, window.start);
let first_user_index = messages
.iter()
.position(|m| matches!(m.role, MessageRole::User));
Ok(messages
.iter()
.enumerate()
.filter_map(|(index, message)| {
if matches!(message.role, MessageRole::System) {
return None;
}
if Some(index) == first_user_index {
return None;
}
if index < window.start || index >= adjusted_end {
return None;
}
Some(index)
})
.collect())
}
/// Builds a new message list with the candidate messages removed and an
/// optional replacement spliced in.
///
/// Every message whose index appears in `candidate_indexes` is dropped. When
/// `replacement` is `Some`, it is inserted at the position of the smallest
/// candidate index. If `candidate_indexes` is empty, or its smallest index is
/// outside `messages`, the replacement is not inserted at all.
pub fn apply_candidate_replacement(
    messages: &[ChatMessage],
    candidate_indexes: &[usize],
    replacement: Option<ChatMessage>,
) -> Vec<ChatMessage> {
    let removable = candidate_indexes.iter().copied().collect::<HashSet<_>>();
    let first_candidate = candidate_indexes.iter().copied().min();
    // The replacement is inserted at most once, so it is moved out of the
    // option instead of being cloned.
    let mut replacement = replacement;
    let mut result = Vec::new();
    for (index, message) in messages.iter().enumerate() {
        if Some(index) == first_candidate {
            result.extend(replacement.take());
        }
        if !removable.contains(&index) {
            result.push(message.clone());
        }
    }
    result
}
/// Finds the index of the first message whose `message_id` equals the
/// boundary's pivot id.
///
/// # Errors
///
/// Returns `CompactError::BoundaryNotFound` carrying the pivot id when no
/// message matches.
fn resolve_boundary_index(
    messages: &[ChatMessage],
    boundary: &CompactionBoundary,
) -> CompactResult<usize> {
    let wanted_id = boundary.pivot_message_id.as_str();
    for (index, message) in messages.iter().enumerate() {
        if message.message_id.as_deref() == Some(wanted_id) {
            return Ok(index);
        }
    }
    Err(CompactError::BoundaryNotFound {
        pivot_message_id: boundary.pivot_message_id.clone(),
    })
}
/// Computes the half-open index window that compaction may touch.
///
/// * No boundary: from index 0 up to the start of the protected tail.
/// * `BeforePivot`: from index 0 up to (not including) the pivot. The
///   protected tail is not applied in this case.
/// * `AfterPivot`: from just after the pivot up to the start of the protected
///   tail.
///
/// Both ends are clamped to `messages.len()`. The window may be empty or
/// inverted (`start >= end`); callers treat that as "nothing to compact".
///
/// # Errors
///
/// Propagates the error from [`resolve_boundary_index`] when the pivot id is
/// not present in `messages`.
fn resolve_compaction_window(
    messages: &[ChatMessage],
    preserve_tail_messages: usize,
    boundary: Option<&CompactionBoundary>,
) -> CompactResult<CompactionWindow> {
    let message_count = messages.len();
    let protected_tail_start = message_count.saturating_sub(preserve_tail_messages);

    let (start, end) = match boundary {
        None => (0, protected_tail_start),
        Some(boundary) => {
            let pivot_index = resolve_boundary_index(messages, boundary)?;
            match boundary.direction {
                PartialDirection::BeforePivot => (0, pivot_index),
                PartialDirection::AfterPivot => {
                    (pivot_index.saturating_add(1), protected_tail_start)
                }
            }
        }
    };

    Ok(CompactionWindow {
        start: start.min(message_count),
        end: end.min(message_count),
    })
}
/// Moves `keep_start` backwards until the kept suffix `messages[keep_start..]`
/// no longer depends on messages that would be removed.
///
/// The dependencies themselves are computed by [`find_required_keep_start`];
/// this function repeats that computation until the value stops changing,
/// because each extension of the kept suffix can introduce new dependencies.
/// The result is never smaller than `floor_start`.
///
/// When there is nothing to adjust (`floor_start >= keep_start`, or
/// `keep_start` is at or past the end of `messages`) the function returns
/// `keep_start` clamped to `messages.len()`.
pub fn adjust_keep_start_for_invariants(
    messages: &[ChatMessage],
    keep_start: usize,
    floor_start: usize,
) -> usize {
    let message_count = messages.len();
    if keep_start >= message_count || floor_start >= keep_start {
        return keep_start.min(message_count);
    }

    let mut current = keep_start;
    loop {
        let required = find_required_keep_start(messages, floor_start, current);
        if required == current {
            // Fixed point reached: the kept suffix is self-contained.
            break current;
        }
        current = required;
    }
}
/// Returns the earliest index in `window_start..keep_start` that the kept
/// suffix `messages[keep_start..]` depends on, or `keep_start` if there is none.
///
/// A message in the window is a dependency when either:
/// * it contains a `ToolUse` block whose `call_id` matches a `ToolResult`
///   block in the kept suffix, or
/// * it is an assistant message sharing its `message_id` with an assistant
///   message in the kept suffix.
///
/// # Panics
///
/// Panics if `keep_start > messages.len()`; the caller guarantees
/// `keep_start < messages.len()`.
fn find_required_keep_start(
    messages: &[ChatMessage],
    window_start: usize,
    keep_start: usize,
) -> usize {
    let kept_messages = &messages[keep_start..];

    let required_tool_use_ids = kept_messages
        .iter()
        .flat_map(message_tool_result_ids)
        .collect::<HashSet<_>>();
    // Collected once so the window is scanned a single time instead of once
    // per kept assistant message.
    let kept_assistant_ids = kept_messages
        .iter()
        .filter(|message| matches!(message.role, MessageRole::Assistant))
        .filter_map(|message| message.message_id.as_deref())
        .collect::<HashSet<_>>();

    if required_tool_use_ids.is_empty() && kept_assistant_ids.is_empty() {
        return keep_start;
    }

    // The smallest matching index is the first match of a forward scan.
    (window_start..keep_start)
        .find(|&index| {
            let candidate = &messages[index];
            let provides_required_tool_use =
                message_has_tool_use_with_ids(candidate, &required_tool_use_ids);
            let shares_kept_assistant_id = matches!(candidate.role, MessageRole::Assistant)
                && matches!(
                    candidate.message_id.as_deref(),
                    Some(message_id) if kept_assistant_ids.contains(message_id)
                );
            provides_required_tool_use || shares_kept_assistant_id
        })
        .unwrap_or(keep_start)
}
/// Collects the `call_id` of every `ToolResult` block in `message`, in block
/// order. Duplicates are kept.
fn message_tool_result_ids(message: &ChatMessage) -> Vec<&str> {
    let mut call_ids = Vec::new();
    for block in message.blocks.iter() {
        if let ContentBlock::ToolResult { call_id, .. } = block {
            call_ids.push(call_id.as_str());
        }
    }
    call_ids
}
/// Reports whether `message` contains at least one `ToolUse` block whose
/// `call_id` is a member of `tool_use_ids`.
fn message_has_tool_use_with_ids(message: &ChatMessage, tool_use_ids: &HashSet<&str>) -> bool {
    for block in message.blocks.iter() {
        if let ContentBlock::ToolUse { call_id, .. } = block {
            if tool_use_ids.contains(call_id.as_str()) {
                return true;
            }
        }
    }
    false
}