"""Task time tracker - per-task-type time estimation."""
from __future__ import annotations

import time
from collections import defaultdict, deque
from typing import TYPE_CHECKING, Dict, List

if TYPE_CHECKING:
    from .schemas import ExperimentTask

# Number of most recent durations retained per task type; once more than this
# many have been recorded for a type, the oldest samples are discarded.
MAX_HISTORY_PER_TASK = 10
def _task_key(task: "ExperimentTask") -> str:
"""Generate task type key for grouping.
Key is sim_type: throughput_optimizer, text_generate, or video_generate
"""
return task.sim_type
class TaskTimeTracker:
    """Track recent execution times per task type and estimate future ones.

    Durations are grouped by ``_task_key(task)`` and only the most recent
    ``MAX_HISTORY_PER_TASK`` samples per key are retained.
    """

    # Estimate (seconds) returned when no usable history exists for a task.
    DEFAULT_ESTIMATE_S: float = 60.0

    def __init__(self) -> None:
        # key -> most recent durations in seconds, oldest first. The bounded
        # deque evicts the oldest sample in O(1) once MAX_HISTORY_PER_TASK
        # samples are held.
        self._history: Dict[str, deque[float]] = defaultdict(
            lambda: deque(maxlen=MAX_HISTORY_PER_TASK)
        )
        # key -> wall-clock timestamp (time.time()) of the latest record().
        self._last_seen: Dict[str, float] = {}

    def record(self, task: "ExperimentTask", duration_s: float) -> None:
        """Record one observed execution time for ``task``'s type.

        Args:
            task: The task that ran; only its grouping key is used.
            duration_s: Observed execution time in seconds.
        """
        key = _task_key(task)
        # deque(maxlen=...) drops the oldest sample automatically when full.
        self._history[key].append(duration_s)
        self._last_seen[key] = time.time()

    def get_estimate(self, task: "ExperimentTask") -> float:
        """Return the estimated execution time for ``task`` in seconds.

        Uses the mean of the retained durations for the task's key. If that
        key has no history, falls back to the first tracked key (in insertion
        order) whose name starts with ``task.params["model_id"]``. Only a
        non-empty string ``model_id`` triggers this fallback.

        Returns:
            Estimated seconds, or ``DEFAULT_ESTIMATE_S`` (60.0) if no usable
            history exists.
        """
        key = _task_key(task)
        history = self._history.get(key)
        if history:
            return sum(history) / len(history)

        model_id = task.params.get("model_id")
        # An empty/missing model_id would make startswith() match every key and
        # return an unrelated task type's average, and a non-str would raise;
        # only attempt the prefix fallback for a real model_id string.
        # NOTE(review): keys come from _task_key (sim_type), so matching them
        # against a model_id prefix looks like a leftover from an older key
        # scheme — confirm whether this fallback is still meaningful.
        if isinstance(model_id, str) and model_id:
            for other_key, other_history in self._history.items():
                if other_history and other_key.startswith(model_id):
                    return sum(other_history) / len(other_history)
        return self.DEFAULT_ESTIMATE_S

    def get_stats(self) -> Dict[str, Dict[str, float]]:
        """Summarize the retained history of every tracked task type.

        Returns:
            Mapping of key -> {"count", "avg", "min", "max", "recent"}, where
            "recent" is the most recently recorded duration. Keys without any
            samples are omitted.
        """
        return {
            key: {
                "count": len(history),
                "avg": sum(history) / len(history),
                "min": min(history),
                "max": max(history),
                "recent": history[-1],
            }
            for key, history in self._history.items()
            if history
        }
# Shared tracker instance, created once when this module is imported.
_tracker = TaskTimeTracker()


def get_tracker() -> TaskTimeTracker:
    """Return the module-level shared ``TaskTimeTracker`` instance."""
    shared = _tracker
    return shared