import asyncio
import functools
import logging
from typing import Optional, Union
from aiolimiter import AsyncLimiter
from openjiuwen_deepsearch.common.exception import CustomRuntimeException
from openjiuwen_deepsearch.common.status_code import StatusCode
logger = logging.getLogger(__name__)


class QPSRateLimiter:
    """Asynchronous QPS (queries-per-second) limiter backed by ``AsyncLimiter``.

    The limit can be changed at runtime through :meth:`set_max_qps`.
    Rate limiting is disabled while the limit is ``None``, zero, negative
    or NaN.  Fractional limits are supported: ``0.5`` means one request
    every two seconds.
    """

    # Lower / upper bound (seconds) for a single wait on the limiter.
    _MIN_TIMEOUT_SECONDS = 3.0
    _MAX_TIMEOUT_SECONDS = 60.0
    # A single wait may last up to this many request intervals (1 / QPS),
    # clamped to the bounds above.
    _TIMEOUT_INTERVALS = 3.0
    # Total number of waits (first try + retries) before giving up.
    _MAX_ATTEMPTS = 2

    def __init__(self) -> None:
        """Create a limiter with no QPS limit set (rate limiting disabled)."""
        self._limiter: Optional["AsyncLimiter"] = None
        self._max_qps: Optional[float] = None

    def get_max_qps(self) -> Optional[float]:
        """Return the currently configured QPS limit, or ``None`` if unset."""
        return self._max_qps

    def set_max_qps(self, max_qps: Optional[Union[int, float]]) -> None:
        """Set the QPS limit.

        Args:
            max_qps: New limit.  ``None`` or a value ``<= 0`` disables
                rate limiting.
        """
        new_qps = float(max_qps) if max_qps is not None else None
        if new_qps != self._max_qps:
            # The cached limiter was built for the previous rate; drop it so
            # the next acquire builds one that matches the new value.
            self._limiter = None
        self._max_qps = new_qps
        logger.info("[QPSRateLimiter] Set max_qps to %s", self._max_qps)

    async def acquire(self) -> None:
        """Wait for a rate-limit permit, retrying once after a timeout.

        Returns immediately when rate limiting is disabled.

        Raises:
            CustomRuntimeException: No permit was obtained within the
                timeout on any attempt.
        """
        if self._get_limiter() is None:
            return

        timeout = self._calculate_timeout()
        max_attempts = self._MAX_ATTEMPTS
        for attempt in range(1, max_attempts + 1):
            try:
                await self._acquire_with_timeout(timeout)
                return
            except asyncio.TimeoutError as exc:
                if attempt < max_attempts:
                    logger.warning(
                        "[QPSRateLimiter] Rate limit timeout, retrying... "
                        "(attempt %d/%d), max_qps=%s, timeout=%.1fs",
                        attempt, max_attempts, self._max_qps, timeout,
                    )
                    continue
                logger.error(
                    "[QPSRateLimiter] Rate limit timeout after %d attempts, "
                    "max_qps=%s, timeout=%.1fs",
                    max_attempts, self._max_qps, timeout,
                )
                raise CustomRuntimeException(
                    StatusCode.RATE_LIMIT_TIMEOUT_ERROR.code,
                    StatusCode.RATE_LIMIT_TIMEOUT_ERROR.errmsg.format(
                        timeout=timeout, max_qps=self._max_qps
                    ),
                ) from exc

    def _get_limiter(self) -> Optional["AsyncLimiter"]:
        """Return the limiter for the current QPS, creating it on first use.

        Returns:
            ``None`` when rate limiting is disabled, otherwise the cached
            (or newly created) limiter.  The cache is invalidated by
            :meth:`set_max_qps` whenever the limit changes.
        """
        max_qps = self._max_qps
        # ``not max_qps > 0`` (rather than ``max_qps <= 0``) also rejects NaN.
        if max_qps is None or not max_qps > 0:
            return None

        if self._limiter is None:
            if max_qps >= 1:
                self._limiter = AsyncLimiter(max_rate=max_qps, time_period=1.0)
            else:
                # Below 1 QPS: one permit per (1 / max_qps) seconds.
                self._limiter = AsyncLimiter(max_rate=1, time_period=1.0 / max_qps)
            logger.info("[QPSRateLimiter] Created new limiter with max_qps=%s", max_qps)
        return self._limiter

    def _calculate_timeout(self) -> float:
        """Compute the per-attempt wait timeout from the current QPS.

        The timeout is three request intervals (``3.0 / max_qps``) clamped
        to the range 3-60 seconds, so any QPS ``>= 1`` yields 3 seconds and
        lower rates get proportionally longer waits.

        Returns:
            Timeout in seconds.
        """
        max_qps = self._max_qps
        if max_qps is None or not max_qps > 0:
            return self._MIN_TIMEOUT_SECONDS
        scaled = self._TIMEOUT_INTERVALS / max_qps
        return max(self._MIN_TIMEOUT_SECONDS, min(scaled, self._MAX_TIMEOUT_SECONDS))

    async def _acquire_with_timeout(self, timeout: float) -> bool:
        """Wait for one permit, giving up after ``timeout`` seconds.

        Args:
            timeout: Maximum time to wait, in seconds.

        Returns:
            ``True`` once the permit is obtained (or when rate limiting is
            disabled).

        Raises:
            asyncio.TimeoutError: The permit was not obtained in time.
        """
        limiter = self._get_limiter()
        if limiter is None:
            return True

        if not limiter.has_capacity():
            logger.info(
                "[QPSRateLimiter] Rate limit exceeded, waiting for permit, max_qps=%s",
                self._max_qps,
            )
        await asyncio.wait_for(limiter.acquire(1), timeout=timeout)
        logger.info("[QPSRateLimiter] Request permitted, max_qps=%s", self._max_qps)
        return True
# Module-level limiter instance used by the ``qps_rate_limit_async`` decorator.
# It starts with no QPS limit set, so rate limiting stays disabled until
# ``set_max_qps`` is called with a positive value.
qps_rate_limiter = QPSRateLimiter()
def qps_rate_limit_async(func):
    """Decorate an async function so each call first obtains a QPS permit.

    Every invocation awaits ``qps_rate_limiter.acquire()`` on the
    module-level limiter before the wrapped coroutine function runs, so the
    timeout and retry behaviour of that limiter applies to each call.

    Usage:
        @qps_rate_limit_async
        async def my_async_function():
            ...
    """

    async def _rate_limited(*call_args, **call_kwargs):
        # Block here until the shared limiter grants a permit (or raises).
        await qps_rate_limiter.acquire()
        outcome = await func(*call_args, **call_kwargs)
        return outcome

    # Copy name, docstring and other metadata of ``func`` onto the wrapper.
    return functools.wraps(func)(_rate_limited)