"""Trajectory schema for tracking task execution history."""
from typing import Optional, Dict, Any, List
from enum import Enum
from pydantic import BaseModel, Field, ConfigDict
class FeedbackType(str, Enum):
    """Closed set of verdicts that can be attached to a trajectory outcome.

    Because the enum mixes in ``str``, every member compares equal to its
    lowercase string value, e.g. ``FeedbackType.HELPFUL == "helpful"``.
    """

    # The outcome was judged helpful.
    HELPFUL = "helpful"
    # The outcome was judged harmful.
    HARMFUL = "harmful"
    # No judgement either way; used as the default feedback of a Trajectory.
    NEUTRAL = "neutral"
class Trajectory(BaseModel):
    """Trajectory representing a task execution with feedback.

    A trajectory captures:
    - The query/task that was executed
    - The response/output generated
    - Feedback on whether the outcome was helpful or harmful
    """

    query: str = Field(description="The query or task that was executed")
    response: str = Field(description="The response or output generated")
    # ``use_enum_values`` (see ``model_config``) swaps enum members for their
    # plain string value during *validation*, and pydantic does not validate
    # defaults unless asked to. Without ``validate_default=True`` the default
    # would be stored as the enum member ``FeedbackType.NEUTRAL``, while any
    # explicitly passed value would be stored as a plain ``str``. That would
    # make ``to_dict()`` / ``__repr__`` output depend on whether feedback was
    # supplied.
    feedback: FeedbackType = Field(
        default=FeedbackType.NEUTRAL,
        description="Feedback on the outcome",
        validate_default=True,
    )
    context: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Additional context about the execution"
    )

    # Store enum fields as their ``.value`` so dumps contain plain strings.
    model_config = ConfigDict(
        use_enum_values=True
    )

    def is_success(self) -> bool:
        """Check if trajectory was successful.

        Returns:
            True if feedback is helpful
        """
        # ``feedback`` holds a plain str; FeedbackType subclasses str, so the
        # comparison is a string comparison.
        return self.feedback == FeedbackType.HELPFUL

    def is_failure(self) -> bool:
        """Check if trajectory was a failure.

        Returns:
            True if feedback is harmful
        """
        return self.feedback == FeedbackType.HARMFUL

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary.

        Returns:
            Dictionary representation (``feedback`` as its string value)
        """
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Trajectory":
        """Create from dictionary.

        Args:
            data: Dictionary data, with keys matching the model's field names

        Returns:
            Trajectory instance
        """
        return cls(**data)

    def __repr__(self) -> str:
        """String representation with the query truncated to 50 characters."""
        query_preview = (
            self.query[:50] + "..." if len(self.query) > 50 else self.query
        )
        return f"Trajectory(query='{query_preview}', feedback={self.feedback})"
class TrajectoryBatch(BaseModel):
    """Batch of trajectories for processing.

    Holds one user's trajectories plus free-form metadata, and offers helpers
    to filter and tally them by feedback.
    """

    trajectories: List[Trajectory] = Field(description="List of trajectories")
    user_id: str = Field(description="User identifier")
    metadata: Dict[str, Any] = Field(
        default_factory=dict, description="Batch metadata"
    )

    def get_success_trajectories(self) -> List[Trajectory]:
        """Get successful trajectories.

        Returns:
            List of trajectories with helpful feedback, in batch order
        """
        return [t for t in self.trajectories if t.is_success()]

    def get_failure_trajectories(self) -> List[Trajectory]:
        """Get failed trajectories.

        Returns:
            List of trajectories with harmful feedback, in batch order
        """
        return [t for t in self.trajectories if t.is_failure()]

    def count_by_feedback(self) -> Dict[str, int]:
        """Count trajectories by feedback type.

        Every ``FeedbackType`` member is present as a key, with 0 when no
        trajectory has that feedback. Keys are ``FeedbackType`` members. Since
        the enum subclasses ``str``, lookups by plain string value (e.g.
        ``counts["helpful"]``) also work.

        Returns:
            Dictionary with counts per feedback type
        """
        # Derive the keys from the enum itself so that adding a FeedbackType
        # member cannot cause a KeyError in the loop below.
        counts = dict.fromkeys(FeedbackType, 0)
        for traj in self.trajectories:
            # ``traj.feedback`` may be a plain str or an enum member; both
            # hash and compare equal to the matching member key.
            counts[traj.feedback] += 1
        return counts

    def __repr__(self) -> str:
        """String representation with total, helpful and harmful counts."""
        counts = self.count_by_feedback()
        return (
            f"TrajectoryBatch(user={self.user_id}, "
            f"total={len(self.trajectories)}, "
            f"helpful={counts[FeedbackType.HELPFUL]}, "
            f"harmful={counts[FeedbackType.HARMFUL]})"
        )