import asyncio
import copy
from functools import cache, wraps
from itertools import chain, filterfalse
from pathlib import Path
from time import perf_counter, time
from typing import Any, Awaitable, Callable, Iterable, List, Optional, ParamSpec, Tuple, Type, TypeVar

import bentoml
import numpy as np
import tomli
from pydantic import BaseModel, ConfigDict, Field

from vrag.constants import DEFAULT_RETRY_EXCEPTIONS
from vrag.logger import logger
# Type variables shared by the generic helpers in this module.
_S = TypeVar("_S", bound="ArgsBase")  # an ArgsBase (sub)class, used by methods returning their own type
_T = TypeVar("_T")  # arbitrary value / element type
_R = TypeVar("_R")  # return type of a wrapped callable or of a key function
_P = ParamSpec("_P")  # parameter specification of a wrapped callable
# Largest value representable by an unsigned 8-bit integer.
U8_MAX = 255
def into_u8_frames(n_frames: np.ndarray) -> np.ndarray:
    """
    Convert an array of frames to ``uint8``.

    Args:
        n_frames: Array to convert. Floating-point input is scaled by 255, so it is
            presumably normalised to ``[0, 1]`` (confirm against callers); any other
            dtype is cast directly.

    Returns:
        The input itself when it is already ``uint8``, otherwise a new ``uint8`` array.
        Floating-point values are clipped to ``[0, 1]`` before scaling and the
        fractional part of the scaled value is truncated by the cast.
    """
    if n_frames.dtype == np.uint8:
        # Already in the target dtype: hand the same array back without copying.
        return n_frames
    if np.issubdtype(n_frames.dtype, np.floating):
        u8_max = np.iinfo(np.uint8).max  # 255
        # Clip first: casting a float outside [0, 255] to uint8 wraps around or is
        # platform-dependent, which turns slightly out-of-range pixels into garbage.
        scaled = np.clip(n_frames, 0.0, 1.0) * u8_max
        return scaled.astype(np.uint8)
    return n_frames.astype(np.uint8)
def once(func: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    Decorator that runs a function a single time and replays its result afterwards.

    The arguments of later calls are ignored: whatever the first successful call
    returned is handed back. A call that raises stores nothing, so the next call
    executes the function again. Intended for synchronous functions; for a coroutine
    function the coroutine object itself would be what gets cached.

    Args:
        func: The function to wrap.

    Returns:
        A wrapped function that returns the cached result on subsequent calls.
    """
    # Empty until the first successful call, then holds exactly one element.
    # A list (rather than a None check) lets a legitimate ``None`` result be cached.
    memo: List[_R] = []

    @wraps(func)
    def _wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        if not memo:
            memo.append(func(*args, **kwargs))
        return memo[0]

    return _wrapper
def execute_time(func: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    Decorator that logs the execution time of a function at debug level.

    One message is logged before the call and one after it returns, the latter with
    the elapsed time in seconds. Nothing is logged after the call if it raises.

    Args:
        func: The synchronous function to time.

    Returns:
        The wrapped function with logging capabilities.
    """

    @wraps(func)
    def _wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        msg = f"Executing {func.__name__}"
        logger.debug(msg)
        # perf_counter is monotonic and high-resolution; the wall clock (time.time)
        # can jump when the system clock is adjusted, yielding bogus durations.
        start = perf_counter()
        res = func(*args, **kwargs)
        elapsed = perf_counter() - start
        msg = f"Executing {func.__name__} in {elapsed:.4f}s"
        logger.debug(msg)
        return res

    return _wrapper
def first_available(*val: Optional[_T]) -> _T:
    """
    Return the first non-None value from the provided arguments.

    Falsy values such as ``0`` or ``""`` count as available; only ``None`` is skipped.

    Args:
        val: Variable number of optional values to check.

    Returns:
        The first non-None value among the provided arguments.

    Raises:
        RuntimeError: If all provided values are None (or no values are given).
    """
    for candidate in val:
        if candidate is not None:
            return candidate
    # Explicit raise: a bare ``next(generator)`` would leak StopIteration here
    # instead of the documented RuntimeError.
    raise RuntimeError("No non-None values available")
def downsample(data: List[int], target_length: int) -> List[int]:
    """
    Pick ``target_length`` evenly spaced elements from a list of integers.

    Only shrinking is supported, i.e. ``target_length <= len(data)``.

    Args:
        data: The list of integers to sample from.
        target_length: Number of elements to keep (``1 <= target_length <= len(data)``).

    Returns:
        A new list of ``target_length`` elements. When no reduction is needed a copy
        of ``data`` is returned; when a single element is requested it is the one at
        index ``len(data) // 2``; otherwise the first and last elements are always
        included.

    Raises:
        ValueError: If ``target_length`` exceeds ``len(data)`` or is not positive.
    """
    size = len(data)

    # Validate the request before doing any work.
    if target_length > size:
        raise ValueError("Target length should not be larger than original length (downsample)")
    if target_length <= 0:
        raise ValueError("Target length should be positive (downsample)")

    # Trivial cases that need no index computation.
    if target_length == size:
        return data.copy()
    if target_length == 1:
        middle = size // 2
        return [data[middle]]

    # Evenly spaced positions from the first to the last index, as integers.
    positions = np.linspace(0, size - 1, target_length, dtype=int)
    picked: List[int] = []
    for position in positions.tolist():
        picked.append(data[position])
    return picked
# Shared pydantic base model (presumably for configuration objects; no subclasses
# are visible in this file).
# `use_attribute_docstrings=True` makes pydantic use the string literal placed
# directly after a field definition as that field's description.
class ConfigBase(BaseModel):
    model_config = ConfigDict(use_attribute_docstrings=True)
# Base model for arguments: holds the service settings consumed by `vrag_service`
# and supports overriding values from a TOML file.
# `use_attribute_docstrings=True` makes pydantic use the string literal after each
# field as that field's description, so those strings are runtime data.
class ArgsBase(BaseModel):
    model_config = ConfigDict(use_attribute_docstrings=True)

    config_file_path: Optional[Path] = None
    """Path to the TOML configuration file for overriding default arguments."""
    service_timeout: float = 3600.0
    """The request session timeout for all services."""
    service_max_concurrency: int = 20
    """The maximum number of concurrent service requests allowed."""
    service_connection_request_multiplier: float = Field(2.0, gt=0)
    """The multiplier for the number of concurrent service requests allowed per runner connection."""

    @classmethod
    @cache
    def default(cls: Type[_S]) -> "_S":
        """
        Return the shared, default-constructed instance of ``cls``.

        The instance is cached per class: a single cache shared by the whole
        hierarchy would hand subclasses whichever instance was created first,
        possibly of a different class. The cache is keyed on classes, so it does
        not keep otherwise-collectable instances alive.

        Returns:
            The same default-constructed instance of ``cls`` on every call.
        """
        return cls()

    def override_by(self: _S, path: Optional[Path]) -> _S:
        """
        Return a copy of these arguments with values from a TOML file applied on top.

        Args:
            path: TOML file holding the overrides. When it is ``None`` or does not
                exist, ``self`` is returned unchanged.

        Returns:
            A newly validated instance built from the merged data, or ``self`` when
            there is nothing to apply.
        """
        if path is None or not path.exists():
            return self
        override_data = tomli.loads(path.read_text(encoding="utf-8"))
        # Dump in JSON mode so the current values are plain types that merge with the
        # parsed TOML data and can be validated again.
        current_data = self.model_dump(mode="json")
        merged = self._merge(current_data, override_data)
        return self.model_validate(merged)

    def override(self: _S) -> _S:
        """Apply the overrides from ``self.config_file_path`` (see ``override_by``)."""
        return self.override_by(self.config_file_path)

    @classmethod
    def _merge(cls, base: dict, override: dict) -> dict:
        """
        Recursively merge ``override`` into a deep copy of ``base``.

        Nested dictionaries present on both sides are merged key by key; any other
        value from ``override`` replaces the one in ``base``. Neither input is mutated.
        """
        _merged = copy.deepcopy(base)
        for k, v in override.items():
            if k in base and isinstance(base[k], dict) and isinstance(v, dict):
                _merged[k] = cls._merge(base[k], v)
            else:
                _merged[k] = v
        return _merged
def flatten(list_of_list: Iterable[Iterable[_T]]) -> Iterable[_T]:
    """
    Lazily flatten one level of nesting.

    Args:
        list_of_list: An iterable whose items are themselves iterables.

    Returns:
        An iterator yielding the items of each inner iterable in order.
    """
    flattened = chain.from_iterable(list_of_list)
    return flattened
def vrag_service(args: ArgsBase):
    """
    Build the ``bentoml.service`` call configured from the given arguments.

    Args:
        args: Arguments providing ``service_timeout``, ``service_max_concurrency``
            and ``service_connection_request_multiplier``.

    Returns:
        Whatever ``bentoml.service`` returns for this configuration.
    """
    timeout = args.service_timeout
    concurrency = args.service_max_concurrency

    traffic = {"timeout": timeout, "max_concurrency": concurrency}
    runner_connection = {
        "max_age": timeout,
        # NOTE(review): the multiplier is a float, so this product is a float too —
        # confirm that a non-integer request limit is acceptable downstream.
        "max_request": concurrency * args.service_connection_request_multiplier,
    }
    return bentoml.service(
        traffic=traffic,
        runner_connection=runner_connection,
        max_runner_connections=concurrency,
    )
def unique_everseen(iterable: Iterable[_T], key: Optional[Callable[[_T], _R]] = None) -> Iterable[_T]:
    """
    Lazily yield the elements of ``iterable`` in order, skipping repeats.

    Every distinct element seen so far is remembered, so duplicates are dropped
    even when they are not adjacent.

    Args:
        iterable: The elements to de-duplicate.
        key: Optional function computing the value used to decide uniqueness.
            When omitted, the element itself is used. The value must be hashable.

    Returns:
        A generator over the first occurrence of each distinct element.
    """
    already_seen = set()
    for item in iterable:
        marker = item if key is None else key(item)
        if marker in already_seen:
            continue
        already_seen.add(marker)
        yield item
async def retry_async_request(
    func: Callable[[], Awaitable[Any]],
    label: str,
    *,
    retries: int = 3,
    delay: float = 0.2,
    max_delay: float = 2.0,
    backoff: float = 2.0,
    retry_exceptions: Tuple[Type[Exception], ...] = DEFAULT_RETRY_EXCEPTIONS,
) -> Any:
    """
    Await ``func()`` and retry it with exponential backoff on selected exceptions.

    Args:
        func: Zero-argument callable returning a fresh awaitable on every call.
        label: Name used to identify the call in log messages.
        retries: Number of retries after the initial attempt, so ``func`` is called
            at most ``retries + 1`` times.
        delay: Sleep before the first retry, in seconds.
        max_delay: Upper bound applied to every individual sleep, in seconds.
        backoff: Factor the delay is multiplied by after each failed attempt.
        retry_exceptions: Exception types that trigger a retry; anything else
            propagates immediately.

    Returns:
        The result of the first successful ``await func()``.

    Raises:
        Exception: The last caught exception, re-raised once the retries are used up.
    """
    failures = 0
    next_delay = delay
    while True:
        try:
            return await func()
        except retry_exceptions as exc:
            if attempt_exhausted := failures >= retries:
                msg = f"[retry] function: {label} failed after {retries} tries"
                logger.error(msg)
                raise
            # The growing delay is capped per sleep rather than clamped in place.
            pause = min(next_delay, max_delay)
            msg = (
                f"[retry] function: {label} attempt={failures + 1}/{retries} "
                f"error={type(exc).__name__} sleep={pause:.2f}s"
            )
            logger.warning(msg)
            await asyncio.sleep(pause)
        # Only reached after a handled failure: success returns, exhaustion raises.
        failures += 1
        next_delay *= backoff