import asyncio
import time
import pytest
from openjiuwen_deepsearch.common.exception import CustomRuntimeException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.utils.rate_limiter_utils.qps_limiter import QPSRateLimiter
class TestQPSRateLimiter:
    """Unit tests for QPSRateLimiter."""

    @pytest.mark.asyncio
    async def test_acquire_no_limit(self):
        """acquire() returns promptly when max QPS is None, 0 or negative."""
        limiter = QPSRateLimiter()
        for qps in (None, 0, -1):
            limiter.set_max_qps(qps)
            # perf_counter is monotonic, so wall-clock adjustments cannot
            # distort the measured duration.
            started = time.perf_counter()
            await limiter.acquire()
            elapsed = time.perf_counter() - started
            assert elapsed < 0.1

    @pytest.mark.asyncio
    async def test_acquire_with_limit(self):
        """Concurrent acquires beyond max QPS take a measurable amount of time."""
        limiter = QPSRateLimiter()
        max_qps = 5
        limiter.set_max_qps(max_qps)
        num_requests = 8

        started = time.perf_counter()
        await asyncio.gather(*(limiter.acquire() for _ in range(num_requests)))
        elapsed = time.perf_counter() - started

        # Requests in excess of max_qps are expected to wait roughly
        # (excess / max_qps) seconds; the 0.5 factor leaves slack for
        # scheduling jitter so the test stays stable.
        expected_min_time = (num_requests - max_qps) / max_qps
        assert elapsed >= expected_min_time * 0.5

    @pytest.mark.asyncio
    async def test_calculate_timeout(self):
        """_calculate_timeout() yields the expected timeout for each QPS value."""
        limiter = QPSRateLimiter()
        # (max_qps, expected timeout) pairs, checked in order on one limiter.
        cases = (
            (10, 3.0),
            (1, 3.0),
            (0.5, 6.0),
            (0.1, 30.0),
            (0.01, 60.0),
        )
        for max_qps, expected in cases:
            limiter.set_max_qps(max_qps)
            timeout = limiter._calculate_timeout()
            assert timeout == expected, f"max_qps={max_qps}"

    @pytest.mark.asyncio
    async def test_acquire_timeout_with_retry(self):
        """acquire() retries after a timeout and raises once every attempt times out."""
        limiter = QPSRateLimiter()
        limiter.set_max_qps(10)
        call_count = 0

        async def mock_acquire_always_timeout(timeout):
            nonlocal call_count
            call_count += 1
            raise asyncio.TimeoutError()

        original_acquire = limiter._acquire_with_timeout
        limiter._acquire_with_timeout = mock_acquire_always_timeout
        try:
            with pytest.raises(CustomRuntimeException) as exc_info:
                await limiter.acquire()

            # These checks must sit outside the `with` body: statements after
            # the raising call inside it would never execute.
            assert exc_info.value.error_code == StatusCode.RATE_LIMIT_TIMEOUT_ERROR.code
            # The test expects exactly two attempts before giving up.
            assert call_count == 2
        finally:
            # Restore the real method even when an assertion above fails.
            limiter._acquire_with_timeout = original_acquire

    @pytest.mark.asyncio
    async def test_acquire_success_after_retry(self):
        """acquire() succeeds when the first attempt times out and the retry works."""
        limiter = QPSRateLimiter()
        limiter.set_max_qps(10)
        call_count = 0
        original_acquire = limiter._acquire_with_timeout

        async def mock_acquire_once_fail(timeout):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                # Only the first attempt fails; later ones use the real method.
                raise asyncio.TimeoutError()
            return await original_acquire(timeout)

        limiter._acquire_with_timeout = mock_acquire_once_fail
        try:
            await limiter.acquire()
            assert call_count == 2
        finally:
            # Restore the real method even when an assertion above fails.
            limiter._acquire_with_timeout = original_acquire