from __future__ import annotations
from types import SimpleNamespace
import pytest
from langchain_core.messages import HumanMessage
from msagent.utils import rate_limiter as limiter_module
class _FakeClock:
    """Manually driven clock used in place of real time in these tests.

    ``time`` reports the stored value and ``sleep`` moves it forward, so no
    wall-clock time passes when either is called.
    """

    def __init__(self, current: float = 100.0) -> None:
        """Start the clock at ``current`` (100.0 unless overridden)."""
        self.current = current

    def time(self) -> float:
        """Return the clock's current reading without advancing it."""
        return self.current

    def sleep(self, seconds: float) -> None:
        """Advance the clock by ``seconds`` instead of blocking."""
        advanced = self.current + seconds
        self.current = advanced
class _FakeLLM:
    """Stand-in LLM that records how it was invoked.

    Each invocation is stored as a ``(messages, config, kwargs)`` tuple in
    ``sync_calls`` (for ``_call``) or ``async_calls`` (for ``_acall``).
    """

    def __init__(self) -> None:
        """Create the fake with empty sync and async call logs."""
        self.sync_calls: list[tuple[list[HumanMessage], object | None, dict[str, object]]] = []
        self.async_calls: list[tuple[list[HumanMessage], object | None, dict[str, object]]] = []

    def _call(self, messages, config=None, **kwargs):
        """Log a synchronous invocation and return the fixed sync marker."""
        record = (messages, config, kwargs)
        self.sync_calls.append(record)
        return "sync-result"

    async def _acall(self, messages, config=None, **kwargs):
        """Log an asynchronous invocation and return the fixed async marker."""
        record = (messages, config, kwargs)
        self.async_calls.append(record)
        return "async-result"
def test_token_bucket_limiter_updates_and_consumes_tokens(monkeypatch) -> None:
    """Exercise ``_update_buckets`` and ``_consume`` against a fake clock.

    The limiter module's ``time.time`` is replaced by ``_FakeClock.time`` so
    the elapsed interval seen by the limiter is fully controlled by the test.
    """
    clock = _FakeClock()
    monkeypatch.setattr(limiter_module.time, "time", clock.time)
    bucket_limiter = limiter_module.TokenBucketLimiter(
        2.0, 50.0, 40.0, max_bucket_size=10
    )

    # Phase 1: empty every bucket, then place the fake clock exactly one
    # second after last_update_time before requesting a refill.
    for bucket_name in ("request_bucket", "input_token_bucket", "output_token_bucket"):
        setattr(bucket_limiter, bucket_name, 0.0)
    bucket_limiter.last_update_time = 100.0
    clock.current = 101.0
    bucket_limiter._update_buckets()

    # After that one second the buckets are expected to hold 2 / 50 / 40,
    # matching the three positional constructor arguments.
    expected_after_refill = (
        ("request_bucket", 2),
        ("input_token_bucket", 50),
        ("output_token_bucket", 40),
    )
    for bucket_name, expected in expected_after_refill:
        assert getattr(bucket_limiter, bucket_name) == expected

    # Phase 2: capture logger.info output and consume from pre-filled buckets.
    info_messages: list[str] = []
    monkeypatch.setattr(limiter_module.logger, "info", info_messages.append)
    prefilled = (
        ("request_bucket", 1.0),
        ("input_token_bucket", 1000.0),
        ("output_token_bucket", 1000.0),
    )
    for bucket_name, level in prefilled:
        setattr(bucket_limiter, bucket_name, level)
    # Pin last_update_time to "now" so no time has elapsed on the fake clock.
    bucket_limiter.last_update_time = clock.current

    consumed_ok = bucket_limiter._consume(input_tokens=600, output_tokens=700)
    assert consumed_ok is True

    expected_after_consume = (
        ("request_bucket", 0.0),
        ("input_token_bucket", 400.0),
        ("output_token_bucket", 300.0),
    )
    for bucket_name, expected in expected_after_consume:
        assert getattr(bucket_limiter, bucket_name) == expected

    assert list(bucket_limiter.recent_input_tokens) == [600]
    assert list(bucket_limiter.recent_output_tokens) == [700]
    # At least one info message must have been logged, and the first one
    # must carry the usage summary prefix.
    assert info_messages
    assert "Rate usage:" in info_messages[0]
def test_token_bucket_limiter_acquire_variants(monkeypatch) -> None:
    """Cover blocking and non-blocking ``acquire`` with a stubbed ``_consume``.

    ``time.sleep`` in the limiter module is replaced by a recorder, so the
    blocking path does not actually wait.
    """
    bucket_limiter = limiter_module.TokenBucketLimiter(1.0, 10.0, 10.0)

    # Blocking path: the first _consume attempt fails, the second succeeds.
    outcomes = iter([False, True])

    def scripted_consume(*args, **kwargs):
        return next(outcomes)

    monkeypatch.setattr(bucket_limiter, "_consume", scripted_consume)
    recorded_sleeps: list[float] = []
    monkeypatch.setattr(limiter_module.time, "sleep", recorded_sleeps.append)

    acquired = bucket_limiter.acquire(blocking=True)
    assert acquired is True
    # Exactly one wait of the polling interval between the two attempts.
    assert recorded_sleeps == [bucket_limiter.check_every_n_seconds]

    # Non-blocking path: _consume always fails, so acquire reports False.
    def always_refuse(*args, **kwargs):
        return False

    monkeypatch.setattr(bucket_limiter, "_consume", always_refuse)
    refused = bucket_limiter.acquire(blocking=False)
    assert refused is False
@pytest.mark.asyncio
async def test_token_bucket_limiter_async_acquire_and_async_call(monkeypatch) -> None:
    """Cover ``aacquire`` (blocking and non-blocking) and the async ``_acall``.

    ``asyncio.sleep`` is swapped for a recording coroutine via ``monkeypatch``
    so the blocking path records its wait instead of suspending.
    """
    import asyncio

    bucket_limiter = limiter_module.TokenBucketLimiter(1.0, 10.0, 10.0)

    # Blocking path: the first _consume attempt fails, the second succeeds.
    outcomes = iter([False, True])

    def scripted_consume(*args, **kwargs):
        return next(outcomes)

    monkeypatch.setattr(bucket_limiter, "_consume", scripted_consume)

    recorded_sleeps: list[float] = []

    async def fake_sleep(seconds: float) -> None:
        recorded_sleeps.append(seconds)

    monkeypatch.setattr(asyncio, "sleep", fake_sleep)

    acquired = await bucket_limiter.aacquire(blocking=True)
    assert acquired is True
    # Exactly one wait of the polling interval between the two attempts.
    assert recorded_sleeps == [bucket_limiter.check_every_n_seconds]

    # Non-blocking path: _consume succeeds immediately.
    def always_grant(*args, **kwargs):
        return True

    monkeypatch.setattr(bucket_limiter, "_consume", always_grant)
    granted = await bucket_limiter.aacquire(blocking=False)
    assert granted is True

    # Async call path: record the input_tokens handed to _consume and check
    # that extra keyword arguments reach the LLM's _acall unchanged.
    fake_llm = _FakeLLM()
    consumed: list[int] = []

    def recording_consume(input_tokens=0, output_tokens=0):
        consumed.append(input_tokens)
        return True

    monkeypatch.setattr(bucket_limiter, "_consume", recording_consume)

    prompt = [HumanMessage(content="abcdefgh")]
    result = await bucket_limiter._acall(
        fake_llm,
        prompt,
        config=SimpleNamespace(name="cfg"),
        stream=True,
    )

    assert result == "async-result"
    # The 8-character message is expected to be consumed as 2 input tokens.
    assert consumed == [2]
    first_call_kwargs = fake_llm.async_calls[0][2]
    assert first_call_kwargs == {"stream": True}
def test_token_bucket_limiter_call_estimates_tokens_and_invokes_llm(monkeypatch) -> None:
    """Call the limiter directly and check token accounting and LLM dispatch.

    ``_consume`` is replaced by a recorder that always succeeds, and the
    limiter module's ``time.sleep`` is made a no-op.
    """
    bucket_limiter = limiter_module.TokenBucketLimiter(1.0, 10.0, 10.0)
    fake_llm = _FakeLLM()
    consumed: list[int] = []

    def recording_consume(input_tokens=0, output_tokens=0):
        consumed.append(input_tokens)
        return True

    def skip_sleep(_seconds):
        return None

    monkeypatch.setattr(bucket_limiter, "_consume", recording_consume)
    monkeypatch.setattr(limiter_module.time, "sleep", skip_sleep)

    prompt = [HumanMessage(content="abcd"), HumanMessage(content="efgh")]
    call_config = SimpleNamespace(name="cfg")
    result = bucket_limiter(fake_llm, prompt, config=call_config, temperature=0.1)

    assert result == "sync-result"
    # Two 4-character messages are expected to be consumed as 2 input tokens.
    assert consumed == [2]
    # Extra keyword arguments must reach the LLM's _call unchanged.
    first_call_kwargs = fake_llm.sync_calls[0][2]
    assert first_call_kwargs == {"temperature": 0.1}