from typing import Any, Callable

from mindiesd.utils.logs.logging import logger

from .cache import CacheConfig, CacheBase
class AttentionCache(CacheBase):
    """Cache per-block attention outputs and reuse them on skipped steps.

    Inside the configured step window, the attention function is only
    re-evaluated every ``step_interval`` steps counted from ``step_start``.
    On the steps in between, the output stored for the current block is
    returned instead of calling the attention function again.
    """

    def __init__(self, config: CacheConfig):
        super().__init__(config)
        # One slot per block; ``None`` means nothing has been cached yet.
        self._cache = [None] * self._config.blocks_count

    def apply_imp(self, func: Callable[..., Any], *args, **kwargs):
        """Return the attention output for the current block and step.

        A step is a *reuse* step when it lies in ``(step_start, step_end]``
        and ``(step - step_start)`` is not a multiple of ``step_interval``.
        On a reuse step the value cached for the current block is returned.
        Otherwise ``func(*args, **kwargs)`` is evaluated, and its result is
        stored for the block when the step lies in ``[step_start, step_end)``.

        If a reuse step finds no cached value for the block (its slot is
        still ``None``, e.g. the refreshing step never ran for this block or
        the cache was released), the function is evaluated instead of
        returning ``None``.

        Args:
            func: The attention computation to run when the cache is not used.
            *args: Positional arguments forwarded unchanged to ``func``.
            **kwargs: Keyword arguments forwarded unchanged to ``func``.

        Returns:
            The cached or freshly computed attention output.
        """
        cfg = self._config
        step = self._cur_step
        block = self._cur_block

        # ``and`` short-circuits, so the modulo is only evaluated inside the
        # reuse window (same evaluation order as before).
        is_reuse_step = (
            cfg.step_start < step <= cfg.step_end
            and (step - cfg.step_start) % cfg.step_interval != 0
        )
        if is_reuse_step:
            cached = self._cache[block]
            if cached is not None:
                logger.debug(f"[AttentionCache] step: {step} block: {block} reuse cache.")
                return cached
            # Cache miss on a reuse step: compute below rather than hand back
            # ``None`` as if it were an attention output.

        attn = func(*args, **kwargs)
        # ``step_end`` itself is excluded: nothing after it can reuse the value.
        if cfg.step_start <= step < cfg.step_end:
            self._cache[block] = attn
            logger.debug(f"[AttentionCache] step: {step} block: {block} update cache.")
        return attn

    def _release(self):
        """Drop every cached output, resetting all block slots to ``None``."""
        self._cache = [None] * self._config.blocks_count