"""Concurrency tests for the per-context TokenBucket attribution.
These tests guard the property that motivated the bucket redesign:
multiple threads can run :class:`providers.token_tracker.TokenTracker`
recordings simultaneously and each thread's bucket sees only its own
calls, with no cross-thread pollution.
"""
from __future__ import annotations
import threading
from concurrent.futures import ThreadPoolExecutor
import pytest
from providers.token_tracker import (
TokenBucket,
TokenTracker,
current_buckets,
pop_bucket,
push_bucket,
)
@pytest.mark.unit
def test_push_pop_isolated_per_thread() -> None:
    """Buckets pushed in one thread MUST NOT appear in another thread.

    Pushes a bucket on the calling thread, starts a plain
    ``threading.Thread`` that records how many buckets it can see, and
    checks that the spawned thread saw none while the calling thread
    still sees exactly one. After popping, the calling thread's stack is
    expected to be empty again.
    """
    # Only the stack depth is inspected here, so the bucket object itself
    # is deliberately ignored.
    _, main_token = push_bucket()
    seen_in_other_thread: list[int] = []

    def _worker() -> None:
        seen_in_other_thread.append(len(current_buckets()))

    try:
        t = threading.Thread(target=_worker)
        t.start()
        t.join()
        assert seen_in_other_thread == [0], (
            "ContextVar bucket stack must NOT leak from main thread to spawned thread"
        )
        assert len(current_buckets()) == 1
    finally:
        # Pop even when an assertion above fails; otherwise the bucket
        # stays on this thread's stack and later tests running on the
        # same thread start with a non-empty stack.
        pop_bucket(main_token)
    assert len(current_buckets()) == 0
@pytest.mark.unit
def test_concurrent_threads_no_token_pollution() -> None:
    """Each thread's bucket reflects ONLY its own recordings.

    Spawn N threads, each with its own bucket and its own TokenTracker
    (simulating distinct providers). Have each thread record a unique
    token count, then verify:

    * its bucket equals exactly what it recorded
    * no bucket contains any other thread's tokens
    """
    n_threads = 16
    # The barrier lines all workers up so their recordings overlap. The
    # timeout turns "one worker died before reaching the barrier" into a
    # BrokenBarrierError in the remaining workers instead of a test run
    # that hangs forever.
    barrier = threading.Barrier(n_threads, timeout=30)
    # Values mix token counts (int) and model names, hence ``object``.
    results: dict[int, dict[str, object]] = {}
    lock = threading.Lock()

    def _worker(idx: int) -> None:
        # Per-thread amounts are unique so any cross-thread leakage
        # produces a value that cannot match the expectation.
        my_input = (idx + 1) * 100
        my_output = (idx + 1) * 50
        my_embed = (idx + 1) * 10
        tracker_llm = TokenTracker(model=f"llm-{idx}")
        tracker_embed = TokenTracker(model=f"embed-{idx}")
        bucket, token = push_bucket()
        try:
            barrier.wait()
            tracker_llm.record_llm(my_input, my_output)
            tracker_embed.record_embed(my_embed)
            # Snapshot while the bucket is still pushed, before cleanup.
            data = bucket.to_attribution_dict()
        finally:
            # Always pop so a failure cannot leave the bucket on this
            # pool thread's stack.
            pop_bucket(token)
        with lock:
            results[idx] = {
                "llm_input": data["llm"]["input_tokens"],
                "llm_output": data["llm"]["output_tokens"],
                "embed": data["embed"]["embed_tokens"],
                "expected_input": my_input,
                "expected_output": my_output,
                "expected_embed": my_embed,
                "llm_model": data.get("llm_model"),
                "embed_model": data.get("embed_model"),
            }

    with ThreadPoolExecutor(max_workers=n_threads) as ex:
        # Consuming the iterator re-raises any exception from a worker.
        list(ex.map(_worker, range(n_threads)))

    assert len(results) == n_threads
    for idx, r in results.items():
        assert r["llm_input"] == r["expected_input"], (
            f"thread {idx}: LLM input contaminated — got {r['llm_input']}, "
            f"expected {r['expected_input']}"
        )
        assert r["llm_output"] == r["expected_output"], (
            f"thread {idx}: LLM output contaminated"
        )
        assert r["embed"] == r["expected_embed"], (
            f"thread {idx}: embed contaminated"
        )
        assert r["llm_model"] == f"llm-{idx}"
        assert r["embed_model"] == f"embed-{idx}"
@pytest.mark.unit
def test_nested_buckets_both_receive_recording() -> None:
    """When buckets nest, every active bucket receives each recording.

    The outer bucket is active for all three LLM recordings and the embed
    recording; the inner bucket is active only for the middle LLM
    recording and the embed recording.
    """
    tracker = TokenTracker(model="m1")
    outer, outer_token = push_bucket()
    try:
        tracker.record_llm(100, 50)  # outer only
        inner, inner_token = push_bucket()
        try:
            tracker.record_llm(200, 80)  # outer + inner
            tracker.record_embed(15)  # outer + inner
        finally:
            # Pop in a ``finally`` so a failing record call cannot leave
            # the inner bucket on the stack for subsequent tests.
            pop_bucket(inner_token)
        tracker.record_llm(50, 20)  # outer only again
    finally:
        pop_bucket(outer_token)

    inner_data = inner.to_attribution_dict()
    outer_data = outer.to_attribution_dict()

    assert inner_data["llm"]["input_tokens"] == 200
    assert inner_data["llm"]["output_tokens"] == 80
    assert inner_data["embed"]["embed_tokens"] == 15

    # Outer totals: 100 + 200 + 50 input, 50 + 80 + 20 output.
    assert outer_data["llm"]["input_tokens"] == 350
    assert outer_data["llm"]["output_tokens"] == 150
    assert outer_data["embed"]["embed_tokens"] == 15
@pytest.mark.unit
def test_global_tracker_unaffected_by_buckets() -> None:
    """Global cumulative counters keep working alongside buckets.

    Records once while a bucket is pushed and once after it is popped,
    then checks that the tracker's own snapshot holds the sum of both.
    """
    tracker = TokenTracker(model="m1")
    # The bucket's contents are not inspected in this test; only the
    # tracker's snapshot is.
    _, token = push_bucket()
    try:
        tracker.record_llm(100, 50)
    finally:
        # Guarantee the pop so a failing record call cannot leak the
        # bucket into tests that run afterwards on this thread.
        pop_bucket(token)
    tracker.record_llm(7, 3)

    snap = tracker.snapshot()
    assert snap.input_tokens == 107
    assert snap.output_tokens == 53
@pytest.mark.unit
def test_recording_outside_any_bucket_is_safe() -> None:
    """Recording with an empty bucket stack raises nothing.

    With no bucket pushed, every ``record_*`` call must complete and the
    tracker's own snapshot must still reflect the recorded values.
    """
    bare_tracker = TokenTracker(model="m1")

    # Precondition: nothing is pushed on this thread.
    assert current_buckets() == ()

    bare_tracker.record_llm(10, 5)
    bare_tracker.record_embed(20)
    bare_tracker.record_local_cache_hit(tokens_saved=3)
    bare_tracker.record_local_cache_miss()

    totals = bare_tracker.snapshot()
    expected = {
        "input_tokens": 10,
        "embed_tokens": 20,
        "local_cache_hits": 1,
    }
    observed = {field: getattr(totals, field) for field in expected}
    assert observed == expected
@pytest.mark.unit
def test_cache_hit_propagates_to_bucket() -> None:
    """Local cache hits and misses recorded under a bucket show up in it.

    Records one cache hit (with 42 saved tokens) and one cache miss while
    a bucket is pushed, then checks the bucket's ``llm`` attribution
    section for the hit count, miss count, and saved-token total.
    """
    tracker = TokenTracker(model="m1")
    bucket, token = push_bucket()
    try:
        tracker.record_local_cache_hit(tokens_saved=42)
        tracker.record_local_cache_miss()
    finally:
        # Pop unconditionally so an exception from a record call cannot
        # leave this bucket on the stack for later tests.
        pop_bucket(token)

    data = bucket.to_attribution_dict()
    assert data["llm"]["local_cache_hits"] == 1
    assert data["llm"]["local_cache_misses"] == 1
    assert data["llm"]["local_cache_saved_tokens"] == 42