from __future__ import annotations
import threading
from collections import deque
from copy import deepcopy
from datetime import datetime, timezone
from typing import Any
from server.control_plane_store import ControlPlaneStore
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _parse_iso(value: str | None):
if not value:
return None
try:
return datetime.fromisoformat(str(value).replace("Z", "+00:00"))
except (TypeError, ValueError):
return None
def _empty_summary() -> dict:
return {
"llm_tool_rounds": 0,
"tool_calls": 0,
"input_tokens": 0,
"output_tokens": 0,
"cache_read": 0,
"cache_write": 0,
"total_tokens": 0,
}
def _empty_tool_counter() -> dict:
return {
"call_count": 0,
"success_count": 0,
"fail_count": 0,
"total_duration_ms": 0.0,
"round_ids": set(),
"total_tokens": 0,
"attributed_tokens": 0,
"allocated_tokens": 0,
}
def _tool_response(tool_name: str, raw: dict) -> dict:
call_count = int(raw.get("call_count", 0) or 0)
success_count = int(raw.get("success_count", 0) or 0)
total_duration = float(raw.get("total_duration_ms", 0.0) or 0.0)
round_ids = raw.get("round_ids", set()) or set()
if not isinstance(round_ids, set):
round_ids = set(round_ids)
return {
"tool_name": tool_name,
"call_count": call_count,
"success_count": success_count,
"fail_count": int(raw.get("fail_count", 0) or 0),
"success_rate": success_count / call_count if call_count else 0.0,
"avg_duration_ms": total_duration / call_count if call_count else 0.0,
"round_count": len(round_ids),
"attributed_tokens": int(raw.get("attributed_tokens", raw.get("total_tokens", 0)) or 0),
"allocated_tokens": int(raw.get("allocated_tokens", 0) or 0),
"total_tokens": int(raw.get("total_tokens", raw.get("attributed_tokens", 0)) or 0),
}
class InternalToolUsageTracker:
    """In-memory usage tracker for ogmem internal tool calls.

    State is kept per ``session_id``. Each session holds running summary
    counters, per-pipeline tool counters, and two bounded deques (round
    records and tool events) that keep only the most recent
    ``max_rounds_per_session`` entries. The session table is only read or
    mutated while holding ``self._lock``.
    """

    # Statuses counted as a successful tool call; any other status is a failure.
    _SUCCESS_STATUSES = frozenset({"success", "completed", "recorded"})

    # Token fields that are summed from round records into summaries.
    _TOKEN_KEYS = (
        "input_tokens",
        "output_tokens",
        "cache_read",
        "cache_write",
        "total_tokens",
    )

    def __init__(self, max_rounds_per_session: int = 1000) -> None:
        """Create an empty tracker.

        Args:
            max_rounds_per_session: Capacity of the per-session round and
                tool-event deques. Falsy values fall back to 1000 and the
                result is clamped to at least 1.
        """
        self._lock = threading.Lock()
        self._max_rounds = max(1, int(max_rounds_per_session or 1000))
        self._sessions: dict[str, dict] = {}

    def record_round(
        self,
        *,
        account_id: str,
        session_id: str,
        pipeline: str,
        round_id: str,
        tool_names: list[str],
        user_id: str = "",
        model: str = "",
        input_tokens: int = 0,
        output_tokens: int = 0,
        cache_read: int = 0,
        cache_write: int = 0,
        started_at: str = "",
        ended_at: str = "",
    ) -> None:
        """Record one LLM round that issued tool calls.

        Nothing is recorded when ``session_id``, ``round_id`` or
        ``tool_names`` is empty. A ``round_id`` that was already recorded for
        the session is ignored, so the stored round records always agree
        with the summary counters and a repeated call cannot double-count
        tokens in filtered statistics.
        """
        if not session_id or not round_id or not tool_names:
            return
        prompt_tokens = input_tokens or 0
        completion_tokens = output_tokens or 0
        cache_read_tokens = cache_read or 0
        cache_write_tokens = cache_write or 0
        # Cache tokens are tracked separately and are not part of the total.
        round_total = prompt_tokens + completion_tokens
        with self._lock:
            session = self._get_session(account_id, session_id, user_id=user_id)
            if round_id in session["round_ids"]:
                return
            session["round_ids"].add(round_id)
            totals = session["summary"]
            totals["llm_tool_rounds"] += 1
            totals["input_tokens"] += prompt_tokens
            totals["output_tokens"] += completion_tokens
            totals["cache_read"] += cache_read_tokens
            totals["cache_write"] += cache_write_tokens
            totals["total_tokens"] += round_total
            session["rounds"].append({
                "round_id": round_id,
                "session_id": session_id,
                "user_id": user_id or session.get("user_id", ""),
                "pipeline": pipeline,
                "model": model,
                "input_tokens": prompt_tokens,
                "output_tokens": completion_tokens,
                "cache_read": cache_read_tokens,
                "cache_write": cache_write_tokens,
                "total_tokens": round_total,
                "tool_names": list(tool_names),
                "tool_call_count": len(tool_names),
                "started_at": started_at,
                "ended_at": ended_at or _now_iso(),
            })
            # Sticky flag: set once the bounded deque is full and more rounds
            # have been counted than it can hold.
            session["rounds_truncated"] = session["rounds_truncated"] or (
                len(session["rounds"]) == self._max_rounds
                and totals["llm_tool_rounds"] > self._max_rounds
            )

    def record_tool_call(
        self,
        *,
        account_id: str,
        session_id: str,
        pipeline: str,
        round_id: str,
        tool_name: str,
        status: str,
        duration_ms: float | int | None = 0,
        error_type: str = "",
    ) -> None:
        """Record the outcome of a single tool call.

        Nothing is recorded when ``session_id`` or ``tool_name`` is empty.
        A call with an empty ``round_id`` is still counted, but it is not
        associated with any round.
        """
        if not session_id or not tool_name:
            return
        success = status in self._SUCCESS_STATUSES
        with self._lock:
            session = self._get_session(account_id, session_id)
            session["summary"]["tool_calls"] += 1
            pipeline_tools = session["tools_by_pipeline"].setdefault(pipeline or "", {})
            counter = pipeline_tools.setdefault(tool_name, _empty_tool_counter())
            counter["call_count"] += 1
            if success:
                counter["success_count"] += 1
            else:
                counter["fail_count"] += 1
            counter["total_duration_ms"] += duration_ms or 0
            if round_id:
                counter["round_ids"].add(round_id)
            session["tool_events"].append({
                "tool_name": tool_name,
                "pipeline": pipeline,
                "round_id": round_id,
                "status": status,
                "duration_ms": duration_ms or 0,
                "error_type": error_type,
                "created_at": _now_iso(),
            })

    def get_stats(
        self,
        *,
        session_id: str | None = None,
        user_id: str | None = None,
        start_time: str | None = None,
        end_time: str | None = None,
        pipeline: str | None = None,
        include_rounds: bool = False,
    ) -> dict:
        """Return aggregated statistics, optionally narrowed by the filters.

        Sessions are deep-copied under the lock and formatted afterwards, so
        the returned data is detached from the tracker's live state.
        """
        with self._lock:
            sessions = {
                sid: deepcopy(data)
                for sid, data in self._sessions.items()
                if not session_id or sid == session_id
            }
        return self._format_stats(
            sessions,
            session_id=session_id,
            user_id=user_id,
            start_time=start_time,
            end_time=end_time,
            pipeline=pipeline,
            include_rounds=include_rounds,
        )

    def session_snapshot(self, session_id: str) -> dict | None:
        """Return full stats (including rounds) for one session.

        Returns ``None`` when the session is unknown or has neither rounds
        nor tool calls recorded. The result carries an extra ``account_id``
        key.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return None
            # Capture the account while still holding the lock; the live
            # session dict may be mutated once the lock is released.
            account_id = session.get("account_id", "")
            sessions = {session_id: deepcopy(session)}
        stats = self._format_stats(sessions, session_id=session_id, include_rounds=True)
        if not stats["summary"]["llm_tool_rounds"] and not stats["summary"]["tool_calls"]:
            return None
        stats["account_id"] = account_id
        return stats

    def reset(self) -> None:
        """Drop every tracked session."""
        with self._lock:
            self._sessions = {}

    def clear_session(self, session_id: str) -> None:
        """Forget one session; unknown or empty ids are ignored."""
        if not session_id:
            return
        with self._lock:
            self._sessions.pop(session_id, None)

    def count_snapshot(self) -> dict:
        """Return the unfiltered totals across all sessions as a flat dict."""
        stats = self.get_stats()
        summary = stats["summary"]
        return {
            "rounds": summary["llm_tool_rounds"],
            "tool_calls": summary["tool_calls"],
            "input_tokens": summary.get("input_tokens", 0),
            "output_tokens": summary.get("output_tokens", 0),
            "cache_read": summary.get("cache_read", 0),
            "cache_write": summary.get("cache_write", 0),
            "total_tokens": summary.get("total_tokens", 0),
        }

    def _get_session(self, account_id: str, session_id: str, *, user_id: str = "") -> dict:
        """Return the session record, creating it on first use.

        Callers in this class invoke this while holding ``self._lock``.
        ``updated_at`` is refreshed on every access, and ``account_id`` /
        ``user_id`` are filled in only if the session does not have them yet.
        """
        now = _now_iso()
        session = self._sessions.get(session_id)
        if session is None:
            session = {
                "account_id": account_id,
                "user_id": user_id,
                "session_id": session_id,
                "created_at": now,
                "updated_at": now,
                "summary": _empty_summary(),
                "tools_by_pipeline": {},
                "rounds": deque(maxlen=self._max_rounds),
                "round_ids": set(),
                "rounds_truncated": False,
                "tool_events": deque(maxlen=self._max_rounds),
            }
            self._sessions[session_id] = session
        session["updated_at"] = now
        if account_id and not session.get("account_id"):
            session["account_id"] = account_id
        if user_id and not session.get("user_id"):
            session["user_id"] = user_id
        return session

    def _format_stats(
        self,
        sessions: dict[str, dict],
        *,
        session_id: str | None = None,
        user_id: str | None = None,
        start_time: str | None = None,
        end_time: str | None = None,
        pipeline: str | None = None,
        include_rounds: bool = False,
    ) -> dict:
        """Build the stats response from already-copied session records.

        Without a pipeline or time filter, each session contributes its
        running summary. With a filter, the numbers are recomputed from the
        retained round records and tool events, so they only cover what is
        still inside the bounded deques.
        """
        summary = _empty_summary()
        tools: dict[str, dict] = {}
        session_rows = []
        rounds = []
        rounds_truncated = False
        rounds_limit = self._max_rounds
        start_dt = _parse_iso(start_time)
        end_dt = _parse_iso(end_time)
        apply_round_filter = bool(pipeline or start_dt or end_dt)
        for sid, session in sessions.items():
            if user_id and session.get("user_id") != user_id:
                continue
            session_rounds = [
                dict(item) for item in session.get("rounds", [])
                if (
                    (not pipeline or item.get("pipeline") == pipeline)
                    and self._round_in_time_range(item, start_dt, end_dt)
                )
            ]
            allowed_round_ids = {item["round_id"] for item in session_rounds}
            session_tools = self._aggregate_session_tools(
                session,
                pipeline,
                allowed_round_ids=allowed_round_ids if apply_round_filter else None,
            )
            session_summary = _empty_summary()
            if apply_round_filter:
                session_summary["llm_tool_rounds"] = len(allowed_round_ids)
                session_summary["tool_calls"] = sum(
                    int(counter.get("call_count", 0) or 0)
                    for counter in session_tools.values()
                )
                self._add_round_tokens(session_summary, session_rounds)
            else:
                session_summary.update(session.get("summary", _empty_summary()))
            summary["llm_tool_rounds"] += session_summary["llm_tool_rounds"]
            summary["tool_calls"] += session_summary["tool_calls"]
            for key in self._TOKEN_KEYS:
                summary[key] += session_summary.get(key, 0) or 0
            for name, counter in session_tools.items():
                target = tools.setdefault(name, _empty_tool_counter())
                target["call_count"] += counter.get("call_count", 0) or 0
                target["success_count"] += counter.get("success_count", 0) or 0
                target["fail_count"] += counter.get("fail_count", 0) or 0
                target["total_duration_ms"] += counter.get("total_duration_ms", 0.0) or 0.0
                target["round_ids"].update(counter.get("round_ids", set()) or set())
                target["total_tokens"] += counter.get("total_tokens", 0) or 0
                target["attributed_tokens"] += (
                    counter.get("attributed_tokens", counter.get("total_tokens", 0)) or 0
                )
                target["allocated_tokens"] += counter.get("allocated_tokens", 0) or 0
            session_rows.append({
                "session_id": sid,
                "user_id": session.get("user_id", ""),
                "llm_tool_rounds": session_summary["llm_tool_rounds"],
                "tool_calls": session_summary["tool_calls"],
            })
            if include_rounds:
                rounds.extend(session_rounds)
                rounds_truncated = rounds_truncated or bool(session.get("rounds_truncated", False))
        result = {
            "ok": True,
            "source": "internal",
            "filters": {
                "session_id": session_id,
                "user_id": user_id,
                "start_time": start_time,
                "end_time": end_time,
                "pipeline": pipeline,
            },
            "summary": summary,
            "tools": [_tool_response(name, tools[name]) for name in sorted(tools)],
            "sessions": sorted(session_rows, key=lambda item: item["session_id"]),
        }
        if include_rounds:
            # Keep only the newest ``rounds_limit`` records across sessions.
            result["rounds"] = rounds[-rounds_limit:]
            result["rounds_limit"] = rounds_limit
            result["rounds_truncated"] = rounds_truncated or len(rounds) > rounds_limit
        return result

    @staticmethod
    def _ensure_aware(moment):
        """Return *moment* as an aware datetime, reading naive values as UTC.

        ``None`` is passed through unchanged.
        """
        if moment is not None and moment.tzinfo is None:
            return moment.replace(tzinfo=timezone.utc)
        return moment

    @staticmethod
    def _round_in_time_range(round_record: dict, start_dt, end_dt) -> bool:
        """Tell whether a round's ``ended_at`` lies within ``[start_dt, end_dt]``.

        A round without a parseable ``ended_at`` only matches when no bound
        is given. All three datetimes are normalised to aware values first,
        because comparing a naive datetime with an aware one raises
        ``TypeError``.
        """
        to_aware = InternalToolUsageTracker._ensure_aware
        start_dt = to_aware(start_dt)
        end_dt = to_aware(end_dt)
        ended_at = to_aware(_parse_iso(round_record.get("ended_at")))
        if ended_at is None:
            return not (start_dt or end_dt)
        if start_dt and ended_at < start_dt:
            return False
        if end_dt and ended_at > end_dt:
            return False
        return True

    @staticmethod
    def _add_round_tokens(summary: dict, rounds: list[dict]) -> None:
        """Add each round's token fields to *summary* in place."""
        for round_record in rounds:
            for key in InternalToolUsageTracker._TOKEN_KEYS:
                summary[key] += int(round_record.get(key, 0) or 0)

    @staticmethod
    def _round_token_map(session: dict) -> dict[str, int]:
        """Map each retained round id to that round's ``total_tokens``."""
        return {
            str(round_record.get("round_id", "")): int(round_record.get("total_tokens", 0) or 0)
            for round_record in session.get("rounds", [])
            if round_record.get("round_id")
        }

    @classmethod
    def _aggregate_session_tools(
        cls,
        session: dict,
        pipeline: str | None,
        *,
        allowed_round_ids: set[str] | None = None,
    ) -> dict[str, dict]:
        """Build per-tool counters for one session.

        Args:
            session: Session record.
            pipeline: When set, only that pipeline's tools are included.
            allowed_round_ids: When not ``None``, counters are rebuilt from
                the retained tool events and limited to these rounds (events
                without a round id are therefore excluded). When ``None``,
                the running per-pipeline counters are used.

        Returns:
            Mapping of tool name to counter, with token figures attached.
        """
        round_tokens = cls._round_token_map(session)
        combined: dict[str, dict] = {}
        if allowed_round_ids is not None:
            for event in session.get("tool_events", []):
                if pipeline and event.get("pipeline") != pipeline:
                    continue
                round_id = event.get("round_id", "")
                if round_id not in allowed_round_ids:
                    continue
                target = combined.setdefault(event.get("tool_name", ""), _empty_tool_counter())
                target["call_count"] += 1
                if event.get("status") in cls._SUCCESS_STATUSES:
                    target["success_count"] += 1
                else:
                    target["fail_count"] += 1
                target["total_duration_ms"] += event.get("duration_ms", 0.0) or 0.0
                if round_id:
                    target["round_ids"].add(round_id)
            round_tool_counts = cls._round_tool_call_counts(
                session, pipeline, allowed_round_ids=allowed_round_ids,
            )
        elif pipeline:
            # Copy so attaching token figures does not touch the session data.
            combined = deepcopy(session.get("tools_by_pipeline", {}).get(pipeline, {}))
            round_tool_counts = cls._round_tool_call_counts(session, pipeline)
        else:
            for pipeline_tools in session.get("tools_by_pipeline", {}).values():
                for name, counter in pipeline_tools.items():
                    target = combined.setdefault(name, _empty_tool_counter())
                    target["call_count"] += counter.get("call_count", 0) or 0
                    target["success_count"] += counter.get("success_count", 0) or 0
                    target["fail_count"] += counter.get("fail_count", 0) or 0
                    target["total_duration_ms"] += counter.get("total_duration_ms", 0.0) or 0.0
                    target["round_ids"].update(counter.get("round_ids", set()) or set())
            round_tool_counts = cls._round_tool_call_counts(session, pipeline=None)
        cls._attach_tool_tokens(combined, round_tokens, round_tool_counts)
        return combined

    @staticmethod
    def _round_tool_call_counts(
        session: dict,
        pipeline: str | None,
        *,
        allowed_round_ids: set[str] | None = None,
    ) -> dict[str, dict[str, int]]:
        """Count retained tool events per round and tool name.

        Events without a round id or tool name are skipped, as are events
        excluded by *pipeline* or *allowed_round_ids*.
        """
        counts: dict[str, dict[str, int]] = {}
        for event in session.get("tool_events", []):
            if pipeline and event.get("pipeline") != pipeline:
                continue
            round_id = str(event.get("round_id", ""))
            if not round_id:
                continue
            if allowed_round_ids is not None and round_id not in allowed_round_ids:
                continue
            tool_name = str(event.get("tool_name", ""))
            if not tool_name:
                continue
            per_round = counts.setdefault(round_id, {})
            per_round[tool_name] = per_round.get(tool_name, 0) + 1
        return counts

    @staticmethod
    def _attach_tool_tokens(
        tools: dict[str, dict],
        round_tokens: dict[str, int],
        round_tool_counts: dict[str, dict[str, int]],
    ) -> None:
        """Write token figures onto each tool counter in place.

        ``attributed_tokens`` (and ``total_tokens``) is the full token count
        of every round the tool took part in. ``allocated_tokens`` splits
        each round's tokens between tools in proportion to their call counts
        in that round; rounds with no counted calls contribute nothing.
        """
        for tool_name, counter in tools.items():
            attributed = 0
            allocated = 0.0
            for round_id in counter.get("round_ids", set()):
                tokens = round_tokens.get(round_id, 0)
                attributed += tokens
                per_round = round_tool_counts.get(round_id, {})
                total_calls = sum(per_round.values())
                if not total_calls:
                    continue
                allocated += tokens * per_round.get(tool_name, 0) / total_calls
            counter["attributed_tokens"] = int(attributed)
            counter["allocated_tokens"] = int(round(allocated))
            counter["total_tokens"] = int(attributed)
class InternalToolUsageStore:
    """JSON store for per-session internal tool usage snapshots.

    Each session is written as two documents under
    ``accounts/<account_id>/sessions/<session_id>/tool_usage``:
    ``tool_calls.json`` (summary and per-tool rows) and ``rounds.json``
    (round records).
    """

    def __init__(self, store: ControlPlaneStore):
        """Wrap the control-plane store used for reading and writing JSON."""
        self._store = store

    def write_session(self, snapshot: dict) -> None:
        """Persist one session snapshot as the two tool-usage documents.

        Args:
            snapshot: Stats dict for a single session. Must contain
                ``snapshot["filters"]["session_id"]``. ``account_id`` is
                taken from the snapshot when present, otherwise derived from
                its first round record.

        Raises:
            KeyError: If ``filters`` or ``filters["session_id"]`` is missing.
        """
        if "account_id" in snapshot:
            account_id = snapshot["account_id"]
        else:
            account_id = self._account_from_snapshot(snapshot)
        session_id = snapshot["filters"]["session_id"]
        user_id = self._user_from_snapshot(snapshot)
        # NOTE(review): ``created_at`` is set to the current time on every
        # write, so it reflects the latest write rather than first creation.
        # Confirm whether that is intended.
        now = _now_iso()
        self._store.write_json(
            self._tool_calls_path(account_id, session_id),
            {
                "version": 1,
                "account_id": account_id,
                "session_id": session_id,
                "user_id": user_id,
                "created_at": now,
                "updated_at": now,
                "summary": snapshot.get("summary", _empty_summary()),
                "tools": snapshot.get("tools", []),
            },
        )
        self._store.write_json(
            self._rounds_path(account_id, session_id),
            {
                "version": 1,
                "account_id": account_id,
                "session_id": session_id,
                "user_id": user_id,
                "rounds_limit": snapshot.get("rounds_limit", 1000),
                "rounds_truncated": snapshot.get("rounds_truncated", False),
                "rounds": snapshot.get("rounds", []),
            },
        )

    def read_session(self, account_id: str, session_id: str, *, include_rounds: bool = False) -> dict:
        """Load a stored session in the same shape as the tracker's stats.

        Empty defaults are passed to the store for both documents. The
        rounds document is only read when *include_rounds* is true.
        """
        tool_calls = self._store.read_json(
            self._tool_calls_path(account_id, session_id),
            {
                "summary": _empty_summary(),
                "tools": [],
            },
        )
        stored_summary = tool_calls.get("summary", {})
        result = {
            "ok": True,
            "source": "internal",
            "filters": {
                "session_id": session_id,
                "user_id": None,
                "start_time": None,
                "end_time": None,
                "pipeline": None,
            },
            "summary": tool_calls.get("summary", _empty_summary()),
            "tools": tool_calls.get("tools", []),
            "sessions": [{
                "session_id": session_id,
                "user_id": tool_calls.get("user_id", ""),
                "llm_tool_rounds": stored_summary.get("llm_tool_rounds", 0),
                "tool_calls": stored_summary.get("tool_calls", 0),
            }],
        }
        if include_rounds:
            rounds = self._store.read_json(
                self._rounds_path(account_id, session_id),
                {"rounds": [], "rounds_limit": 1000, "rounds_truncated": False},
            )
            result["rounds"] = rounds.get("rounds", [])
            result["rounds_limit"] = rounds.get("rounds_limit", 1000)
            result["rounds_truncated"] = rounds.get("rounds_truncated", False)
        return result

    def _tool_calls_path(self, account_id: str, session_id: str) -> str:
        """Return the store path of the session's ``tool_calls.json``."""
        return self._store._join("accounts", account_id, "sessions", session_id, "tool_usage", "tool_calls.json")

    def _rounds_path(self, account_id: str, session_id: str) -> str:
        """Return the store path of the session's ``rounds.json``."""
        return self._store._join("accounts", account_id, "sessions", session_id, "tool_usage", "rounds.json")

    @staticmethod
    def _user_from_snapshot(snapshot: dict) -> str:
        """Return the ``user_id`` of the snapshot's first session row.

        Falls back to ``""`` when ``sessions`` is missing or empty, instead
        of failing on the index lookup.
        """
        session_rows = snapshot.get("sessions") or []
        if not session_rows:
            return ""
        return session_rows[0].get("user_id", "")

    @staticmethod
    def _account_from_snapshot(snapshot: dict) -> str:
        """Return the ``account_id`` of the first round record, or ``""``.

        Round records built by ``InternalToolUsageTracker.record_round`` do
        not carry an ``account_id`` key, so for those snapshots this
        fallback yields ``""``.
        """
        rounds = snapshot.get("rounds") or []
        if rounds:
            return str(rounds[0].get("account_id") or "")
        return ""