"""
-------------------------------------------------------------------------
This file is part of the RAGSDK project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
RAGSDK is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
MXRAGCache 核心功能类
该类主要是给RAG框架提供数据缓存的能力,包括以下功能
1、缓存实例的构造(get_cache, new_cache)
2、缓存的查询(search),更新(update),以及刷新(flush)
3、缓存的级联功能(join)
"""
import os
import re
from typing import Any
from gptcache.core import Cache
from loguru import logger
from mx_rag.cache import CacheConfig, SimilarityCacheConfig
from mx_rag.utils.file_check import SecFileCheck
from mx_rag.utils.common import validate_params, GB, TEXT_MAX_LEN, MAX_QUERY_LENGTH
def _default_dump(data: Any) -> str:
    """Pass-through serializer: hand the given value back untouched."""
    dumped = data
    return dumped
def _default_load(data: str) -> Any:
    """Pass-through deserializer: hand the given value back untouched."""
    loaded = data
    return loaded
class MxRAGCache:
    """Wrapper around a gptcache ``Cache`` offering search/update/flush/clear and cache chaining.

    Class-level settings (shared by every instance):
        cache_limit: answers whose length is not strictly below this are not cached.
        cache_join_size_limit: upper bound on the number of joins / chain depth.
        current_join_size: number of successful ``join`` calls made so far, process-wide.
        verbose: when True, Hit/Miss/Update/Flush events are logged at info level.
    """

    cache_limit: int = TEXT_MAX_LEN
    cache_join_size_limit: int = 6
    current_join_size: int = 0
    verbose: bool = False

    @validate_params(
        cache_name=dict(
            validator=lambda x: isinstance(x, str) and 0 < len(x) < 64 and bool(re.fullmatch(r'[0-9a-zA-Z_]+', x)),
            message="param must meets: Type is str, length range (0, 64), match '[0-9a-zA-Z_]+'",
        ),
        config=dict(
            validator=lambda x: isinstance(x, (CacheConfig, SimilarityCacheConfig)),
            message="param must be instance of CacheConfig or SimilarityCacheConfig",
        ),
    )
    def __init__(self, cache_name: str, config: CacheConfig):
        """Create the underlying gptcache object and initialise it from ``config``.

        Args:
            cache_name: identifier of the cache (letters, digits, underscore; fewer than 64 chars).
            config: a ``CacheConfig`` or ``SimilarityCacheConfig`` instance.

        Raises:
            Exception: whatever ``_init_mxrag_cache`` raises is logged and re-raised unchanged.
        """
        self.__cache_obj = Cache()
        self.config = config
        self.cache_name = cache_name
        self.next_cache = None
        try:
            # Imported at call time rather than at module load
            # (presumably to avoid a circular import - TODO confirm).
            from mx_rag.cache.cache_api.cache_init import _init_mxrag_cache

            self.data_save_path = _init_mxrag_cache(self.__cache_obj, cache_name, config)
        except KeyError as e:
            logger.error(f"init rag cache failed because key error: {e}")
            raise
        except Exception as e:
            logger.error(f"init rag cache failed: {e}")
            raise

    @staticmethod
    def _update(llm_data, update_cache_func, *args, **kwargs) -> Any:
        """Forward the data to gptcache's ``_update_cache_callback`` and return it unchanged.

        Passed to ``adapt`` as the update callback by ``search`` and ``update``.
        """
        from gptcache.adapter.api import _update_cache_callback

        _update_cache_callback(llm_data, update_cache_func, *args, **kwargs)
        return llm_data

    @classmethod
    @validate_params(verbose=dict(validator=lambda x: isinstance(x, bool), message="param value must be bool"))
    def set_verbose(cls, verbose: bool):
        """Switch verbose logging on or off for all MxRAGCache instances."""
        cls.verbose = verbose

    @classmethod
    @validate_params(
        cache_limit=dict(
            validator=lambda x: isinstance(x, int) and 0 < x <= TEXT_MAX_LEN,
            message="param value range (0, 1000 * 1000]",
        )
    )
    def set_cache_limit(cls, cache_limit: int):
        """Set the class-wide answer length limit used by ``update``."""
        cls.cache_limit = cache_limit

    def clear(self):
        """Remove all cached entries from this cache and from every cache joined after it.

        Returns silently when the underlying cache is not initialised or has no data manager.
        """
        # The successor's clear() takes care of its own successors, so a single
        # delegated call covers the whole chain; iterating the chain here as well
        # would clear the deeper caches over and over again.
        if self.next_cache is not None:
            self.next_cache.clear()

        if not self.__cache_obj.has_init:
            return
        data_manager = self.__cache_obj.data_manager
        if data_manager is None:
            return

        # Data manager exposing a scalar store (`s`) and a vector store (`v`).
        if hasattr(data_manager, 's') and hasattr(data_manager, 'v'):
            scalar_store = data_manager.s
            vector_store = data_manager.v
            all_ids = scalar_store.get_ids(deleted=False)
            if all_ids:
                scalar_store.mark_deleted(all_ids)
                scalar_store.clear_deleted_data()
                vector_store.delete(all_ids)
                vector_store.flush()

        # Data manager exposing a `data` container persisted at `data_path`.
        if hasattr(data_manager, 'data') and hasattr(data_manager, 'data_path'):
            data_manager.data.clear()
            data_manager.flush()

    @validate_params(
        query=dict(validator=lambda x: isinstance(x, str) and 0 < len(x) <= MAX_QUERY_LENGTH),
    )
    def search(self, query: str):
        """Look up ``query`` in the cache.

        Args:
            query: the question to look up.

        Returns:
            The value produced by gptcache's ``adapt`` for this query; ``None`` is treated as a miss.

        Raises:
            KeyError: the underlying cache has not been initialised.
        """
        self._require_init()
        from gptcache.adapter.api import adapt, _cache_data_converter

        def llm_handle_none(*llm_args, **llm_kwargs) -> None:
            """Do nothing on a cache miss"""
            return None

        answer = adapt(llm_handle_none, _cache_data_converter, self._update, prompt=query, cache_obj=self.__cache_obj)
        self._verbose_log("Hit!" if answer is not None else "Miss!")
        return answer

    @validate_params(
        query=dict(
            validator=lambda x: isinstance(x, str) and 0 < len(x) <= MAX_QUERY_LENGTH,
            message="param must str and char range is (0, 128 * 1024 * 1024]",
        ),
        answer=dict(
            validator=lambda x: isinstance(x, str) and 0 < len(x) <= TEXT_MAX_LEN,
            message="param must str and char range is (0, 1000 * 1000]",
        ),
    )
    def update(self, query: str, answer: str):
        """Store ``answer`` for ``query`` in the cache.

        Answers whose length is not strictly below ``cache_limit`` are skipped.

        Args:
            query: the user question to cache.
            answer: the answer to cache for that question.

        Returns:
            None

        Raises:
            KeyError: the underlying cache has not been initialised.
        """
        self._require_init()
        if not self._check_limit(answer):
            self._verbose_log("context length is large no caching")
            return

        from gptcache.adapter.api import adapt, _cache_data_converter

        def llm_handle(*llm_args, **llm_kwargs):
            # Stand-in for the LLM call: simply yields the answer to be cached.
            return answer

        adapt(
            llm_handle, _cache_data_converter, self._update, cache_skip=True, prompt=query, cache_obj=self.__cache_obj
        )
        self._verbose_log("Update!")

    def flush(self):
        """Force the cached data to be flushed through the underlying gptcache object.

        If the data manager has a ``data_path`` that already exists on disk, that
        path is checked with ``SecFileCheck`` (limit ``100 * GB``) before flushing.

        Returns:
            None

        Raises:
            KeyError: the underlying cache has not been initialised.
        """
        self._require_init()
        data_manager = self.__cache_obj.data_manager
        if hasattr(data_manager, "data_path") and os.path.exists(data_manager.data_path):
            SecFileCheck(data_manager.data_path, 100 * GB).check()
        self.__cache_obj.flush()
        self._verbose_log("Flush!")

    def get_obj(self):
        """Return the underlying gptcache object, e.g. for use with frameworks such as langchain.

        Raises:
            KeyError: the underlying cache has not been initialised.
        """
        self._require_init()
        return self.__cache_obj

    @validate_params(
        next_cache=dict(validator=lambda x: isinstance(x, MxRAGCache)),
    )
    def join(self, next_cache):
        """Chain ``next_cache`` behind this cache.

        Args:
            next_cache: the lower-level MxRAGCache.

        Returns:
            None

        Raises:
            KeyError: this cache or ``next_cache`` has not been initialised.
            OverflowError: the join limit would be exceeded.
            ValueError: the join would create a loop.
        """
        self._require_init()
        self._join_check(next_cache)
        self.__cache_obj.next_cache = next_cache.get_obj()
        MxRAGCache.current_join_size += 1
        self.next_cache = next_cache

    def _require_init(self):
        """Raise ``KeyError`` unless the underlying gptcache object has been initialised."""
        if not self.__cache_obj.has_init:
            raise KeyError("cache not init pls init first")

    def _join_check(self, next_cache):
        """Validate that joining ``next_cache`` respects the join limit and creates no loop."""
        logger.debug(f"cache deepth:{MxRAGCache.current_join_size} cache_limit:{MxRAGCache.cache_join_size_limit}")
        if MxRAGCache.current_join_size >= MxRAGCache.cache_join_size_limit:
            raise OverflowError(
                f"the number of cache join deepth cannot be greater than {MxRAGCache.cache_join_size_limit}"
            )

        # Walk the chain that starts at next_cache: it must stay within the limit
        # and must never lead back to this cache's own gptcache object.
        loop_cnt: int = 0
        next_cache_obj = next_cache.get_obj()
        while next_cache_obj is not None:
            if loop_cnt >= MxRAGCache.cache_join_size_limit:
                raise OverflowError(
                    f"the number of cache join deepth cannot be greater than {MxRAGCache.cache_join_size_limit}"
                )
            # Identity, not equality: the question is whether it is the very same object.
            if next_cache_obj is self.__cache_obj:
                raise ValueError("forbidden loop join")
            next_cache_obj = next_cache_obj.next_cache
            loop_cnt += 1

    def _check_limit(self, input_text: str):
        """Return True when ``input_text`` is strictly shorter than ``cache_limit``."""
        return len(input_text) < self.cache_limit

    def _verbose_log(self, log_str: str):
        """Log ``log_str`` at info level, but only when the ``verbose`` flag is set.

        Args:
            log_str: the message to log.

        Returns:
            None
        """
        if self.verbose:
            logger.info(log_str)