import hashlib
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import wraps
from typing import Any, Callable, Concatenate, List, Optional, ParamSpec, TypeAlias, TypeVar, Union, overload, Coroutine
import numpy as np
from vrag.logger import logger
P = ParamSpec("P")
T = TypeVar("T")
S = TypeVar("S")
ArrayInput: TypeAlias = Union[np.ndarray, List[np.ndarray]]
class CacherBase(metaclass=ABCMeta):
@staticmethod
def _gather_missed(
all_keys: List[str], is_dense_arr: bool, miss_indices: List[int], work_list: List[np.ndarray]
) -> tuple[ArrayInput, List[str]]:
msg = f"Cache miss count (sync): {len(miss_indices)} / {len(work_list)}"
logger.debug(msg)
missed_arrays = [work_list[idx] for idx in miss_indices]
missed_keys = [all_keys[idx] for idx in miss_indices]
batch_input = np.stack(missed_arrays, axis=0) if is_dense_arr else missed_arrays
return batch_input, missed_keys
@staticmethod
def _validate_results(result_map: dict[int, T], expected_length: int) -> List[T]:
final_result: List[T] = []
for i in range(expected_length):
if i not in result_map:
raise RuntimeError(f"Logic error: missing result for index {i}")
final_result.append(result_map[i])
return final_result
@staticmethod
def _normalized_arr(arr_input: ArrayInput) -> tuple[bool, List[np.ndarray]]:
if isinstance(arr_input, np.ndarray):
work_list: List[np.ndarray] = list(arr_input)
is_dense_arr = True
else:
work_list = arr_input
is_dense_arr = False
return is_dense_arr, work_list
@staticmethod
def _get_array_hash(arr: np.ndarray) -> str:
"""Generate hash including shape, dtype, and content to prevent collisions."""
arr_contiguous = np.ascontiguousarray(arr)
shape_bytes = str(arr.shape).encode("utf-8")
dtype_bytes = str(arr.dtype).encode("utf-8")
data_bytes = arr_contiguous.tobytes()
combined = shape_bytes + b"|" + dtype_bytes + b"|" + data_bytes
return hashlib.sha256(combined).hexdigest()
@abstractmethod
def get(self, key: str) -> Any:
raise NotImplementedError(f"{self.__class__} not implements get")
@abstractmethod
def put(self, key: str, data: Any) -> None:
raise NotImplementedError(f"{self.__class__} not implements put")
def cached_sync(
self, func: Callable[Concatenate[ArrayInput, P], List[T]]
) -> Callable[Concatenate[ArrayInput, P], List[T]]:
"""Sync version of cache."""
@wraps(func)
def _wrapper(arr_input: ArrayInput, *args: P.args, **kwargs: P.kwargs) -> List[T]:
return self._cache_logic_sync(arr_input, args, kwargs, func, self._get_array_hash)
return _wrapper
def cached_sync_with(
self, make_key_suffix: Callable[P, str]
) -> Callable[[Callable[Concatenate[ArrayInput, P], List[T]]], Callable[Concatenate[ArrayInput, P], List[T]]]:
def decorator(
func: Callable[Concatenate[ArrayInput, P], List[T]],
) -> Callable[Concatenate[ArrayInput, P], List[T]]:
@wraps(func)
def _wrapper(arr_input: ArrayInput, *args: P.args, **kwargs: P.kwargs) -> List[T]:
suffix = make_key_suffix(*args, **kwargs)
return self._cache_logic_sync(
arr_input, args, kwargs, func, lambda arr: self._get_combined_key(arr, suffix)
)
return _wrapper
return decorator
def cached(
self, func: Callable[Concatenate[ArrayInput, P], Coroutine[Any, Any, List[T]]]
) -> Callable[Concatenate[ArrayInput, P], Coroutine[Any, Any, List[T]]]:
@wraps(func)
async def _wrapper(arr_input: ArrayInput, *args: P.args, **kwargs: P.kwargs) -> List[T]:
return await self._cache_logic(arr_input, args, kwargs, func, self._get_array_hash)
return _wrapper
def cached_with(
self, make_key_suffix: Callable[P, str]
) -> Callable[
[Callable[Concatenate[ArrayInput, P], Coroutine[Any, Any, List[T]]]],
Callable[Concatenate[ArrayInput, P], Coroutine[Any, Any, List[T]]],
]:
def decorator(
func: Callable[Concatenate[ArrayInput, P], Coroutine[Any, Any, List[T]]],
) -> Callable[Concatenate[ArrayInput, P], Coroutine[Any, Any, List[T]]]:
@wraps(func)
async def _wrapper(arr_input: ArrayInput, *args: P.args, **kwargs: P.kwargs) -> List[T]:
suffix = make_key_suffix(*args, **kwargs)
return await self._cache_logic(
arr_input, args, kwargs, func, lambda arr: self._get_combined_key(arr, suffix)
)
return _wrapper
return decorator
def cached_method(
self, func: Callable[Concatenate[S, ArrayInput, P], Coroutine[Any, Any, List[T]]]
) -> Callable[Concatenate[S, ArrayInput, P], Coroutine[Any, Any, List[T]]]:
@wraps(func)
async def _wrapper(slf: S, arr_input: ArrayInput, *args: P.args, **kwargs: P.kwargs) -> List[T]:
return await self._cache_logic(arr_input, args, kwargs, func, self._get_array_hash, slf=slf)
return _wrapper
def cached_method_with(
self, make_key_suffix: Callable[P, str]
) -> Callable[
[Callable[Concatenate[S, ArrayInput, P], Coroutine[Any, Any, List[T]]]],
Callable[Concatenate[S, ArrayInput, P], Coroutine[Any, Any, List[T]]],
]:
def decorator(
func: Callable[Concatenate[S, ArrayInput, P], Coroutine[Any, Any, List[T]]],
) -> Callable[Concatenate[S, ArrayInput, P], Coroutine[Any, Any, List[T]]]:
@wraps(func)
async def _wrapper(slf: S, arr_input: ArrayInput, *args: P.args, **kwargs: P.kwargs) -> List[T]:
suffix = make_key_suffix(*args, **kwargs)
return await self._cache_logic(
arr_input, args, kwargs, func, lambda arr: self._get_combined_key(arr, suffix), slf=slf
)
return _wrapper
return decorator
def _merge_result(
self, batch_result: List[T], miss_indices: List[int], miss_keys: List[str], result_map: dict[int, T]
):
if len(batch_result) != len(miss_indices):
raise ValueError(f"Merge failed with {len(batch_result)} items, expected {len(miss_indices)}")
for i, idx in enumerate(miss_indices):
res_item = batch_result[i]
result_map[idx] = res_item
self.put(miss_keys[i], res_item)
def _gather_hits_misses(
self, make_key: Callable[[np.ndarray], str], work_list: List[np.ndarray]
) -> tuple[List[str], List[int], dict[int, T]]:
result_map: dict[int, T] = {}
miss_indices: List[int] = []
all_keys = [make_key(arr) for arr in work_list]
for idx, key in enumerate(all_keys):
cached_val = self.get(key)
if cached_val is not None:
result_map[idx] = cached_val
else:
miss_indices.append(idx)
return all_keys, miss_indices, result_map
def _get_combined_key(self, arr: np.ndarray, extra_key: str) -> str:
return f"{self._get_array_hash(arr)}_{extra_key}"
@overload
def _cache_logic_sync(
self,
arr_input: ArrayInput,
args: P.args,
kwargs: P.kwargs,
func: Callable[Concatenate[ArrayInput, P], List[T]],
make_key: Callable[[np.ndarray], str],
) -> List[T]:
"""Cache for function."""
@overload
def _cache_logic_sync(
self,
arr_input: ArrayInput,
args: P.args,
kwargs: P.kwargs,
func: Callable[Concatenate[S, ArrayInput, P], List[T]],
make_key: Callable[[np.ndarray], str],
/,
slf: S = None,
) -> List[T]:
"""Cache for method."""
def _cache_logic_sync(
self,
arr_input: ArrayInput,
args: P.args,
kwargs: P.kwargs,
func: Callable[Concatenate[ArrayInput, P], List[T]] | Callable[Concatenate[S, ArrayInput, P], List[T]],
make_key: Callable[[np.ndarray], str],
/,
slf: S = None,
) -> List[T]:
"""Sync version of cache."""
is_dense_arr, work_list = self._normalized_arr(arr_input)
all_keys, miss_indices, result_map = self._gather_hits_misses(make_key, work_list)
if miss_indices:
batch_input, miss_keys = self._gather_missed(all_keys, is_dense_arr, miss_indices, work_list)
batch_results = (func(slf, batch_input, *args, **kwargs)) if slf else func(batch_input, *args, **kwargs)
self._merge_result(batch_results, miss_indices, miss_keys, result_map)
else:
logger.debug("Cache all hit (sync)!")
return self._validate_results(result_map, len(work_list))
@overload
async def _cache_logic(
self,
arr_input: ArrayInput,
args: P.args,
kwargs: P.kwargs,
func: Callable[Concatenate[ArrayInput, P], Coroutine[Any, Any, List[T]]],
make_key: Callable[[np.ndarray], str],
) -> List[T]:
"""Cache for function"""
@overload
async def _cache_logic(
self,
arr_input: ArrayInput,
args: P.args,
kwargs: P.kwargs,
func: Callable[Concatenate[S, ArrayInput, P], Coroutine[Any, Any, List[T]]],
make_key: Callable[[np.ndarray], str],
/,
slf: S = None,
) -> List[T]:
"""Cache for method"""
async def _cache_logic(
self,
arr_input: ArrayInput,
args: P.args,
kwargs: P.kwargs,
func: Callable[Concatenate[ArrayInput, P], Coroutine[Any, Any, List[T]]]
| Callable[Concatenate[S, ArrayInput, P], Coroutine[Any, Any, List[T]]],
make_key: Callable[[np.ndarray], str],
/,
slf: S = None,
) -> List[T]:
"""Async version of cache."""
is_dense_arr, work_list = self._normalized_arr(arr_input)
all_keys, miss_indices, result_map = self._gather_hits_misses(make_key, work_list)
if miss_indices:
batch_input, miss_keys = self._gather_missed(all_keys, is_dense_arr, miss_indices, work_list)
if slf:
batch_results = await func(slf, batch_input, *args, **kwargs)
else:
batch_results = await func(batch_input, *args, **kwargs)
self._merge_result(batch_results, miss_indices, miss_keys, result_map)
else:
logger.debug("Cache all hit (async)!")
return self._validate_results(result_map, len(work_list))
def get_cacher(cap: int) -> CacherBase:
    """Create the default in-memory LRU cacher.

    Args:
        cap: Capacity forwarded to the underlying LRU cache.

    Returns:
        A fresh ``_NumpyLRUCache`` exposed through the ``CacherBase`` interface.
    """
    lru = _NumpyLRUCache(capacity=cap)
    return lru
@dataclass
class _NumpyLRUCache(CacherBase):
    """In-memory least-recently-used cache holding at most ``capacity`` entries.

    A non-positive ``capacity`` disables storage: :meth:`put` stores nothing and
    :meth:`get` always misses (returns ``None``).
    """

    # Maximum number of stored entries; <= 0 disables storage.
    capacity: int
    # Entries ordered from least recently used (first) to most recently used (last).
    cache: OrderedDict = field(init=False, default_factory=OrderedDict)

    def get(self, key: str) -> Optional[Any]:
        """Return the value for ``key`` and mark it most recently used, or ``None`` on a miss."""
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key: str, data: Any) -> None:
        """Store ``data`` under ``key``, evicting the least recently used entry when full."""
        if self.capacity <= 0:
            # Nothing may be stored; without this guard popitem() would be called on an
            # empty dict and raise KeyError.
            return
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)
        self.cache[key] = data